
Conversation

@jberchtold-nvidia
Collaborator

Description

Please include a brief summary of the changes, relevant motivation and context.

Fixes # (issue)

Type of change

  • Documentation change (change only to the documentation, either a fix or new content)
  • Bug fix (non-breaking change which fixes an issue)
  • New feature (non-breaking change which adds functionality)
  • Breaking change (fix or feature that would cause existing functionality to not work as expected)
  • Infra/Build change
  • Code refactoring

Changes

Please list the changes introduced in this PR:

  • Change A
  • Change B

Checklist:

  • I have read and followed the contributing guidelines
  • The functionality is complete
  • I have commented my code, particularly in hard-to-understand areas
  • I have made corresponding changes to the documentation
  • My changes generate no new warnings
  • I have added tests that prove my fix is effective or that my feature works
  • New and existing unit tests pass locally with my changes

phu0ngng and others added 30 commits December 3, 2025 13:07
Signed-off-by: Phuong Nguyen <phuonguyen@nvidia.com>
Signed-off-by: Pawel Gadzinski <pgadzinski@nvidia.com>
- Add FP8 scale_inv pointer handling in nvte_grouped_gemm for proper FP8 GEMM
- Fix random padding in tests to ensure 16-byte alignment for all dtypes
- Reorder GroupedGemmSetupWorkspace members for natural alignment
- Remove debug prints
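The 16-byte alignment fix above can be illustrated with a small sketch; `align_up` is a hypothetical helper name, not code from this PR:

```python
# Hypothetical helper (not from this PR): round a byte offset up to the next
# 16-byte boundary, the alignment the fix above enforces for all dtypes.
def align_up(offset: int, alignment: int = 16) -> int:
    # The bitmask trick assumes alignment is a power of two.
    return (offset + alignment - 1) & ~(alignment - 1)

print(align_up(1))   # 16
print(align_up(16))  # 16
print(align_up(17))  # 32
```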

Signed-off-by: Pawel Gadzinski <pgadzinski@nvidia.com>
Signed-off-by: Pawel Gadzinski <pgadzinski@nvidia.com>
- Remove unused alignment parameter from GroupedGemmSetupWorkspace::from_buffers
- Simplify select_grouped_operand by removing dead code branches
- Add GroupedOperandSelection.tensor field to avoid passing tensor separately
- Extract set_fp8_scale_pointers and init_matrix_layouts helpers
- Add safety check for FP8 on Hopper column-wise fallback
- Support NULL C tensor when beta=0 (uses D as placeholder)
- Remove unused get_scale_inv() from test
- Add use_null_c test parameter and test case
- Fix documentation: alpha/beta are single element tensors only

Signed-off-by: Piotr Gadzinski <pgadzinski@nvidia.com>
Signed-off-by: Pawel Gadzinski <pgadzinski@nvidia.com>
- Change alpha/beta from single values to per-matrix arrays
- Validate alpha/beta have exactly num_tensors elements
- Update kernel to index alpha_ptr[idx] and beta_ptr[idx]
- Move alpha/beta validation to validate_grouped_gemm_inputs
- Update tests to use per-matrix alpha/beta arrays
- Update documentation
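The per-matrix alpha/beta change described above can be sketched as a NumPy reference; the function name and shapes are illustrative, not the actual `nvte_grouped_gemm` API:

```python
import numpy as np

def grouped_gemm_reference(A, B, C, alpha, beta):
    """Reference semantics: D[i] = alpha[i] * A[i] @ B[i] + beta[i] * C[i]."""
    num_tensors = len(A)
    # Mirrors the validation in the commit: alpha/beta must have
    # exactly num_tensors elements.
    assert len(alpha) == num_tensors and len(beta) == num_tensors
    return [alpha[i] * (A[i] @ B[i]) + beta[i] * C[i] for i in range(num_tensors)]

# Groups may have different shapes, e.g. (2x3)@(3x4) and (1x5)@(5x2).
A = [np.ones((2, 3)), np.ones((1, 5))]
B = [np.ones((3, 4)), np.ones((5, 2))]
C = [np.zeros((2, 4)), np.zeros((1, 2))]
D = grouped_gemm_reference(A, B, C, alpha=[1.0, 2.0], beta=[0.0, 0.0])
# D[0] is all 3.0 (1.0 * dot of ones over K=3); D[1] is all 10.0 (2.0 * K=5).
```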

Signed-off-by: Piotr Gadzinski <pgadzinski@nvidia.com>
Signed-off-by: Pawel Gadzinski <pgadzinski@nvidia.com>
Signed-off-by: Piotr Gadzinski <pgadzinski@nvidia.com>
Signed-off-by: Pawel Gadzinski <pgadzinski@nvidia.com>
Signed-off-by: Pawel Gadzinski <pgadzinski@nvidia.com>
Signed-off-by: Pawel Gadzinski <pgadzinski@nvidia.com>
Signed-off-by: Jeremy Berchtold <jberchtold@nvidia.com>
Signed-off-by: Pawel Gadzinski <pgadzinski@nvidia.com>
Signed-off-by: Pawel Gadzinski <pgadzinski@nvidia.com>
Signed-off-by: Pawel Gadzinski <pgadzinski@nvidia.com>
@jberchtold-nvidia jberchtold-nvidia marked this pull request as draft January 23, 2026 00:27
@greptile-apps
Contributor

greptile-apps bot commented Jan 23, 2026

Greptile Summary

This PR implements grouped GEMM support for JAX, enabling efficient batched matrix multiplications with varying shapes per group - a key optimization for Mixture-of-Experts (MoE) models with per-expert quantization.

Key Changes

  • C++ grouped GEMM kernel (cublaslt_grouped_gemm.cu): New implementation using cuBLASLt's grouped matmul API for Blackwell (SM100+) GPUs with cuBLAS 13.1+. Handles per-tensor shape metadata, workspace allocation, and proper row-wise/column-wise storage selection for FP8 Hopper TN constraints.

  • JAX einsum wrapper (einsum.py): New high-level API that decomposes einsum operations into vmapped dense layers with per-expert quantization support. Designed for MoE workloads (e.g., EBCM,EMH->EBCH).

  • FFI bindings (gemm.cpp, gemm.py): Integrates grouped GEMM into JAX's XLA FFI infrastructure with proper shape inference, sharding rules, and abstract evaluation.

  • Configuration management (config.cpp/.h): Added GroupedMatmulConfig API with optional heuristic hints (avg_m, avg_n, avg_k) for cuBLASLt algorithm selection.

  • Test coverage: Comprehensive C++ tests for various dtype/transpose/shape combinations, and JAX tests validating MoE patterns with gradients and multiple FP8 recipes.
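As a concreteness check on the einsum decomposition described above, a pure-NumPy sketch (sizes are illustrative, not from the PR) showing that the fused "EBCM,EMH->EBCH" contraction equals one dense matmul per expert, which is the shape of computation the wrapper vmaps:

```python
import numpy as np

E, B, C, M, H = 4, 2, 3, 8, 16   # illustrative sizes, not from the PR
rng = np.random.default_rng(0)
x = rng.normal(size=(E, B, C, M))  # activations, one slice per expert
w = rng.normal(size=(E, M, H))     # per-expert weights

# Fused MoE contraction.
fused = np.einsum("ebcm,emh->ebch", x, w)

# Equivalent: one dense (NN-layout) matmul per expert, stacked back together.
per_expert = np.stack([x[e].reshape(-1, M) @ w[e] for e in range(E)])
per_expert = per_expert.reshape(E, B, C, H)

assert np.allclose(fused, per_expert)
```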

Notes

  • This is marked as a DRAFT PR with commit message "wgrad", suggesting work-in-progress status
  • The PR description is incomplete (empty checklist, no description of changes)
  • Debug print statement found in production code (base.py:223)

Confidence Score: 3/5

  • This PR introduces significant new functionality with proper architecture but is marked as DRAFT with incomplete documentation
  • Score reflects that while the implementation appears technically sound with good test coverage, this is explicitly a DRAFT PR (per the title and "wgrad" commit message suggesting work-in-progress). The PR description is incomplete with empty checklists and no change documentation. A debug print statement was left in production code. The core grouped GEMM implementation looks solid with proper validation and error handling, but the draft status and incomplete documentation warrant caution before merging.
  • Pay attention to transformer_engine/jax/cpp_extensions/base.py (contains debug print), and verify the PR description is completed before merging

Important Files Changed

Filename Overview
transformer_engine/common/gemm/cublaslt_grouped_gemm.cu New grouped GEMM implementation for Blackwell GPUs (SM100+). Properly handles per-tensor shape metadata, workspace allocation, and FP8/BF16 dtypes. Implements row-wise/column-wise storage selection logic for Hopper TN constraints.
transformer_engine/common/include/transformer_engine/gemm.h API declarations for grouped GEMM with clear documentation of requirements (cuBLAS 13.1+, Blackwell SM100+). Well-documented config attributes and C++/C wrapper classes.
transformer_engine/jax/einsum.py New einsum implementation for MoE with per-expert quantization using vmap+dense. Includes comprehensive docstrings explaining limitations (NN layout only, single batch dim). Validates quantizer_dim requirements.
transformer_engine/jax/csrc/extensions/gemm.cpp C++ FFI bindings for grouped GEMM primitive. Integrates with existing GEMM infrastructure, handles workspace calculations, and validates tensor shapes.
transformer_engine/jax/cpp_extensions/base.py Added batcher_impl helper for standard JAX batching pattern. Includes debug print statements that should be removed before production.

Sequence Diagram

sequenceDiagram
    participant User as JAX User Code
    participant Einsum as einsum.py
    participant Dense as dense.py
    participant GemmPrim as gemm.py (Primitive)
    participant FFI as gemm.cpp (FFI)
    participant Kernel as cublaslt_grouped_gemm.cu
    participant cuBLAS as cuBLASLt

    User->>Einsum: einsum("EBCM,EMH->EBCH", x, w, quantizer_sets, quantizer_dim='E')
    Einsum->>Einsum: Parse equation & validate quantizer_dim
    Einsum->>Einsum: Stack quantizer_sets into pytree
    Einsum->>Dense: vmap(dense_with_quantizer) over expert dimension
    
    loop For each expert
        Dense->>Dense: Quantize operands (if needed)
        Dense->>GemmPrim: grouped_gemm.bind(lhs, rhs, ...)
        GemmPrim->>GemmPrim: Shape inference & sharding rules
        GemmPrim->>FFI: GroupedGemmFFI (XLA custom call)
        FFI->>FFI: Validate inputs & allocate workspaces
        FFI->>Kernel: nvte_grouped_gemm(A, B, C, D, alpha, beta, ...)
        Kernel->>Kernel: setup_grouped_gemm_kernel (compute pointers/dims)
        Kernel->>Kernel: Select operand storage (row/column-wise)
        Kernel->>cuBLAS: cublasLtMatmul (grouped API)
        cuBLAS-->>Kernel: Compute D = alpha*A@B + beta*C
        Kernel-->>FFI: Return output
        FFI-->>GemmPrim: Return result
        GemmPrim-->>Dense: Return expert output
    end
    
    Dense-->>Einsum: Stack expert outputs
    Einsum-->>User: Return final result


@greptile-apps greptile-apps bot left a comment


29 files reviewed, 1 comment


f"Got batched_args={[arg.shape for arg, bdim in zip(batched_args, batch_dims) if bdim is not None]}."
)
assert batch_dim is not None and batch_size is not None, "Invalid batching config!"


style: debug print statement left in production code

Suggested change
# print(f"[{cls.__name__}] Batching with size {batch_size}")



3 participants