Skip to content

Conversation

@timmoon10
Copy link
Collaborator

@timmoon10 timmoon10 commented Jan 24, 2026

Description

This PR adds a grouped linear op, which can be used in the grouped MLP block in Mixture-of-Experts models. It also adds an experimental fused operation for a grouped MLP block, using a CuTe DSL kernel that computes an MXFP8 grouped GEMM and SwiGLU.

Type of change

  • Documentation change (change only to the documentation, either a fix or a new content)
  • Bug fix (non-breaking change which fixes an issue)
  • New feature (non-breaking change which adds functionality)
  • Breaking change (fix or feature that would cause existing functionality to not work as expected)
  • Infra/Build change
  • Code refactoring

Changes

  • Add a grouped linear operation
  • Add a post-scaled SwiGLU op and add support for interleaving SwiGLU gate and linear units
  • Add a fused operation for grouped MLP

Checklist:

  • I have read and followed the contributing guidelines
  • The functionality is complete
  • I have commented my code, particularly in hard-to-understand areas
  • I have made corresponding changes to the documentation
  • My changes generate no new warnings
  • I have added tests that prove my fix is effective or that my feature works
  • New and existing unit tests pass locally with my changes

timmoon10 and others added 30 commits January 7, 2026 00:15
Signed-off-by: Tim Moon <tmoon@nvidia.com>
Signed-off-by: Tim Moon <tmoon@nvidia.com>
Signed-off-by: Tim Moon <tmoon@nvidia.com>
Signed-off-by: Tim Moon <tmoon@nvidia.com>
Signed-off-by: Tim Moon <tmoon@nvidia.com>
Signed-off-by: Tim Moon <tmoon@nvidia.com>
Refactor fusion functions to remove index bookkeeping. Refactor fused ops to use consistent operation order.

Signed-off-by: Tim Moon <tmoon@nvidia.com>
Signed-off-by: Tim Moon <tmoon@nvidia.com>
Signed-off-by: Tim Moon <tmoon@nvidia.com>
Signed-off-by: Tim Moon <tmoon@nvidia.com>
Signed-off-by: Tim Moon <tmoon@nvidia.com>
Signed-off-by: Tim Moon <tmoon@nvidia.com>
Signed-off-by: Tim Moon <tmoon@nvidia.com>
Signed-off-by: Tim Moon <tmoon@nvidia.com>
Signed-off-by: Tim Moon <tmoon@nvidia.com>
Signed-off-by: Tim Moon <tmoon@nvidia.com>
Signed-off-by: Tim Moon <tmoon@nvidia.com>
Signed-off-by: Tim Moon <tmoon@nvidia.com>
Test is too permissive since the test should still be failing. The weights are not properly interleaved yet.

Signed-off-by: Tim Moon <tmoon@nvidia.com>
Signed-off-by: Tim Moon <tmoon@nvidia.com>
Signed-off-by: Tim Moon <tmoon@nvidia.com>
Signed-off-by: Tim Moon <tmoon@nvidia.com>
Signed-off-by: Tim Moon <tmoon@nvidia.com>
Signed-off-by: Tim Moon <tmoon@nvidia.com>
Signed-off-by: Tim Moon <tmoon@nvidia.com>
Signed-off-by: Tim Moon <tmoon@nvidia.com>
Signed-off-by: Tim Moon <tmoon@nvidia.com>
Signed-off-by: Tim Moon <tmoon@nvidia.com>
Signed-off-by: Tim Moon <tmoon@nvidia.com>
@timmoon10 timmoon10 added the performance Performance issues label Jan 24, 2026
Signed-off-by: Tim Moon <tmoon@nvidia.com>
timmoon10 added a commit to timmoon10/TransformerEngine that referenced this pull request Jan 24, 2026
Signed-off-by: Tim Moon <tmoon@nvidia.com>
timmoon10 added a commit that referenced this pull request Jan 25, 2026
* Expose option for custom op fusions

Refactor fusion functions to remove index bookkeeping. Refactor fused ops to use consistent operation order.

Signed-off-by: Tim Moon <tmoon@nvidia.com>

* Add tests for custom ops

Signed-off-by: Tim Moon <tmoon@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* Fix linter warnings and numerical test failures

Signed-off-by: Tim Moon <tmoon@nvidia.com>

* Tweak pattern matching logic with fixed window sizes

Signed-off-by: Tim Moon <tmoon@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* Use TF32 tols in fused op tests

Signed-off-by: Tim Moon <tmoon@nvidia.com>

* Review suggestion from @greptile-apps

Signed-off-by: Tim Moon <tmoon@nvidia.com>

* Backpropagate fixes from #2622

Signed-off-by: Tim Moon <tmoon@nvidia.com>

---------

Signed-off-by: Tim Moon <tmoon@nvidia.com>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
@timmoon10 timmoon10 mentioned this pull request Jan 25, 2026
13 tasks
@timmoon10 timmoon10 changed the title [PyTorch] Prototype of fused operation for grouped MLP [PyTorch] Add grouped linear op and experimental fusion for grouped MLP Jan 25, 2026
Signed-off-by: Tim Moon <tmoon@nvidia.com>
@timmoon10 timmoon10 marked this pull request as ready for review January 25, 2026 01:00
@timmoon10
Copy link
Collaborator Author

/te-ci pytorch L1

@greptile-apps
Copy link
Contributor

greptile-apps bot commented Jan 25, 2026

Greptile Overview

Greptile Summary

This PR adds grouped linear operations and an experimental fused grouped MLP kernel for Mixture-of-Experts models. The implementation includes a new GroupedLinear op that applies multiple linear transformations to input splits, a ScaledSwiGLU activation with post-scaling support for gate/linear interleaving, and a CuTe DSL-based fused kernel for SM100+ GPUs that combines grouped GEMM, SwiGLU, and post-multiplication in MXFP8.

Critical Issues Found:

  • Undefined variable group_idx in grouped_linear.py:419 that will cause runtime failure in the FP8 quantization path
  • Duplicate dimension checks in forward_grouped_mlp.py (lines 68, 73) checking the same dimension twice instead of validating both in_features and out_features

Changes:

  • New GroupedLinear operation supporting FP8/MXFP8 quantization with multiple grouped linear transformations
  • ScaledSwiGLU activation supporting 32-wide gate/linear interleaving for the fused kernel
  • ForwardGroupedMLP_CuTeGEMMSwiGLU_MXFP8 fused operation requiring SM100+ compute capability
  • Test coverage for GLU interleaving and grouped operations

Confidence Score: 2/5

  • This PR has critical bugs that will cause runtime failures in production
  • The undefined variable bug in grouped_linear.py:419 will cause a NameError when the FP8 quantization path is triggered with non-quantized weights. The duplicate dimension validation checks are also bugs that could allow invalid tensor dimensions to pass validation. These are definite runtime errors, not potential issues.
  • Pay close attention to transformer_engine/pytorch/ops/basic/grouped_linear.py (critical bug on line 419) and transformer_engine/pytorch/ops/fused/forward_grouped_mlp.py (validation bugs on lines 68, 73)

Important Files Changed

Filename Overview
transformer_engine/pytorch/ops/basic/grouped_linear.py New GroupedLinear op with undefined variable bug in quantization path on line 419
transformer_engine/pytorch/ops/fused/forward_grouped_mlp.py Fused grouped MLP kernel with duplicate dimension checks (lines 68, 73)
transformer_engine/pytorch/ops/basic/swiglu.py Added ScaledSwiGLU op with post-scaling, implementation looks correct

Sequence Diagram

sequenceDiagram
    participant User
    participant GroupedLinear1 as GroupedLinear (FC1)
    participant ScaledSwiGLU
    participant GroupedLinear2 as GroupedLinear (FC2)
    participant FusedKernel as ForwardGroupedMLP_CuTeGEMMSwiGLU_MXFP8

    Note over User,FusedKernel: Forward Pass - Grouped MLP Block

    User->>GroupedLinear1: input tensor + split_sizes
    Note over GroupedLinear1: Split input by split_sizes
    Note over GroupedLinear1: Quantize inputs (FP8/MXFP8)
    Note over GroupedLinear1: Quantize weights (FP8/MXFP8)
    GroupedLinear1->>GroupedLinear1: general_grouped_gemm()
    Note over GroupedLinear1: Compute y = xW^T for each group
    GroupedLinear1->>ScaledSwiGLU: FC1 output (interleaved gate/linear)

    ScaledSwiGLU->>ScaledSwiGLU: Remove gate interleaving
    ScaledSwiGLU->>ScaledSwiGLU: tex.swiglu(x)
    Note over ScaledSwiGLU: SiLU(gate) * linear
    ScaledSwiGLU->>ScaledSwiGLU: Multiply by scales
    ScaledSwiGLU->>GroupedLinear2: Scaled SwiGLU output

    GroupedLinear2->>GroupedLinear2: Split input by split_sizes
    Note over GroupedLinear2: Quantize inputs (FP8/MXFP8)
    GroupedLinear2->>GroupedLinear2: general_grouped_gemm()
    Note over GroupedLinear2: Compute y = xW^T for each group
    GroupedLinear2->>User: Final output

    Note over FusedKernel: Alternative: Fused Path (SM100+, MXFP8)
    
    User->>FusedKernel: input + split_sizes + scales
    Note over FusedKernel: Pack MXFP8 tensors for kernel
    FusedKernel->>FusedKernel: grouped_gemm_swiglu_kernel()
    Note over FusedKernel: Fused FC1 + SwiGLU + scale
    FusedKernel->>FusedKernel: Unpack MXFP8 outputs
    FusedKernel->>FusedKernel: general_grouped_gemm() for FC2
    FusedKernel->>User: Final output
Loading

Copy link
Contributor

@greptile-apps greptile-apps bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

3 files reviewed, 3 comments

Edit Code Review Agent Settings | Greptile

if not with_quantized_compute:
w = maybe_dequantize(w, dtype)
elif with_quantized_compute and not is_quantized_tensor(w):
quantizer = weight_quantizers[group_idx]
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

group_idx is undefined in this scope - loop variable is w, quantizer from the zip on line 415. Should use quantizer directly (already assigned) or add enumeration to the loop.

Suggested change
quantizer = weight_quantizers[group_idx]
# quantizer is already assigned from the zip, use it directly

if not self.is_supported():
self.grouped_gemm_swiglu_kernel() # Try triggering import error
raise RuntimeError(f"{self.__class__.__name__} is not supported on this system.")
if fc1.in_features % 256 != 0 or fc1.in_features % 256 != 0:
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

checking fc1.in_features % 256 != 0 twice instead of checking both in_features and out_features

Suggested change
if fc1.in_features % 256 != 0 or fc1.in_features % 256 != 0:
if fc1.in_features % 256 != 0 or fc1.out_features % 256 != 0:

f"Unsupported dims for FC1 (group_size={fc1.group_size}, "
f"in_features={fc1.in_features}, out_features={fc1.out_features})."
)
if fc2.in_features % 256 != 0 or fc2.in_features % 256 != 0:
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

checking fc2.in_features % 256 != 0 twice instead of checking both in_features and out_features

Suggested change
if fc2.in_features % 256 != 0 or fc2.in_features % 256 != 0:
if fc2.in_features % 256 != 0 or fc2.out_features % 256 != 0:

@greptile-apps
Copy link
Contributor

greptile-apps bot commented Jan 25, 2026

Greptile Overview

Greptile Summary

This PR adds grouped linear operations and SwiGLU activation variants to support Mixture-of-Experts (MoE) models with grouped MLP blocks.

Key Changes:

  • Added GroupedLinear operation that applies multiple linear transformations by splitting input along the first dimension
  • Refactored SwiGLU operations into separate module (swiglu.py) with three variants: SwiGLU, ClampedSwiGLU, and ScaledSwiGLU
  • Added experimental fused kernel ForwardGroupedMLP_CuTeGEMMSwiGLU_MXFP8 using CuTe DSL for SM100+ GPUs that fuses GroupedLinear + ScaledSwiGLU + GroupedLinear
  • Added support for gate/activation interleaving in SwiGLU operations (controlled by glu_interleave_size parameter)
  • Comprehensive test coverage for all new operations

Issues Found:

  • Critical bug in grouped_linear.py:419 where undefined variable group_idx is used instead of the loop variable quantizer
  • Duplicate validation conditions in fused kernel that check fc1.in_features and fc2.in_features twice instead of checking output features
  • Minor typo in docstring ("Paramters" instead of "Parameters")

Confidence Score: 2/5

  • This PR has critical bugs that will cause runtime failures
  • The undefined variable group_idx in grouped_linear.py:419 will cause a NameError when the code path is executed (quantized compute with unquantized weights). The duplicate validation conditions may allow invalid configurations to pass validation. These bugs need to be fixed before merging.
  • Pay close attention to transformer_engine/pytorch/ops/basic/grouped_linear.py and transformer_engine/pytorch/ops/fused/forward_grouped_mlp.py - both contain logic errors

Important Files Changed

Filename Overview
transformer_engine/pytorch/ops/basic/grouped_linear.py New GroupedLinear op with typo in docstring and undefined variable bug in forward pass
transformer_engine/pytorch/ops/basic/swiglu.py New SwiGLU ops (SwiGLU, ClampedSwiGLU, ScaledSwiGLU) with gate interleaving support
transformer_engine/pytorch/ops/fused/forward_grouped_mlp.py New fused grouped MLP kernel with duplicate validation conditions for FC1/FC2 dimensions

Sequence Diagram

sequenceDiagram
    participant User
    participant GroupedLinear_FC1
    participant ScaledSwiGLU
    participant GroupedLinear_FC2
    participant CuTeGEMM as CuTe Fused Kernel
    
    User->>GroupedLinear_FC1: Input (batched tokens)
    Note over GroupedLinear_FC1: Split input by group sizes
    GroupedLinear_FC1->>GroupedLinear_FC1: Quantize inputs (FP8)
    GroupedLinear_FC1->>GroupedLinear_FC1: Quantize weights (FP8)
    
    alt MXFP8 + SM100 + no bias
        GroupedLinear_FC1->>CuTeGEMM: Fused path
        Note over CuTeGEMM: GroupedGEMM + SwiGLU + Scale
        CuTeGEMM->>GroupedLinear_FC2: FC2 input (quantized)
    else Standard path
        GroupedLinear_FC1->>GroupedLinear_FC1: general_grouped_gemm
        GroupedLinear_FC1->>ScaledSwiGLU: FC1 output
        ScaledSwiGLU->>ScaledSwiGLU: Remove gate interleaving
        ScaledSwiGLU->>ScaledSwiGLU: SwiGLU(x) = silu(x1) * x2
        ScaledSwiGLU->>ScaledSwiGLU: Multiply by scales
        ScaledSwiGLU->>GroupedLinear_FC2: Scaled activation
    end
    
    GroupedLinear_FC2->>GroupedLinear_FC2: Quantize inputs (FP8)
    GroupedLinear_FC2->>GroupedLinear_FC2: general_grouped_gemm
    GroupedLinear_FC2->>User: Output (batched tokens)
Loading

Copy link
Contributor

@greptile-apps greptile-apps bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

4 files reviewed, 4 comments

Edit Code Review Agent Settings | Greptile

dimension, applying a separate ``torch.nn.Linear`` to each split,
and concatenating along the first dimension.

Paramters
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

typo: "Paramters" should be "Parameters"

Suggested change
Paramters
Parameters

if not with_quantized_compute:
w = maybe_dequantize(w, dtype)
elif with_quantized_compute and not is_quantized_tensor(w):
quantizer = weight_quantizers[group_idx]
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

group_idx is undefined in this scope - loop uses w, quantizer from zip. Should use quantizer directly

Suggested change
quantizer = weight_quantizers[group_idx]
quantizer.set_usage(rowwise=True, columnwise=input_requires_grad)

Comment on lines +68 to +72
if fc1.in_features % 256 != 0 or fc1.in_features % 256 != 0:
raise ValueError(
f"Unsupported dims for FC1 (group_size={fc1.group_size}, "
f"in_features={fc1.in_features}, out_features={fc1.out_features})."
)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

duplicate condition checks fc1.in_features % 256 twice instead of checking fc1.out_features

Suggested change
if fc1.in_features % 256 != 0 or fc1.in_features % 256 != 0:
raise ValueError(
f"Unsupported dims for FC1 (group_size={fc1.group_size}, "
f"in_features={fc1.in_features}, out_features={fc1.out_features})."
)
if fc1.in_features % 256 != 0 or fc1.out_features % 256 != 0:
raise ValueError(
f"Unsupported dims for FC1 (group_size={fc1.group_size}, "
f"in_features={fc1.in_features}, out_features={fc1.out_features})."
)

Comment on lines +73 to +77
if fc2.in_features % 256 != 0 or fc2.in_features % 256 != 0:
raise ValueError(
f"Unsupported dims for FC2 (group_size={fc2.group_size}, "
f"in_features={fc2.in_features}, out_features={fc2.out_features})."
)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

duplicate condition checks fc2.in_features % 256 twice instead of checking fc2.out_features

Suggested change
if fc2.in_features % 256 != 0 or fc2.in_features % 256 != 0:
raise ValueError(
f"Unsupported dims for FC2 (group_size={fc2.group_size}, "
f"in_features={fc2.in_features}, out_features={fc2.out_features})."
)
if fc2.in_features % 256 != 0 or fc2.out_features % 256 != 0:
raise ValueError(
f"Unsupported dims for FC2 (group_size={fc2.group_size}, "
f"in_features={fc2.in_features}, out_features={fc2.out_features})."
)

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

performance Performance issues

Projects

None yet

Development

Successfully merging this pull request may close these issues.

1 participant