
Conversation

Wohox (Contributor) commented Jan 22, 2026

Description

This PR adds get_backward_dw_params for TE modules, which helps manage the hooks registered on parameters.

For Megatron-LM, get_backward_dw_params is called once the wgrad CUDA graph has been executed. Currently, the backward_post_hook of the wgrad computation is discarded, which causes parameters to skip gradient reduction.
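
A minimal caller-side sketch of the intended flow (only get_backward_dw_params comes from this PR; the helper name and the reduce callback are illustrative assumptions, not actual Megatron-LM code):

```python
def flush_wgrad_hooks(te_modules, reduce_param_grad):
    """Hypothetical helper: after the wgrad CUDA graph has been replayed,
    re-apply the gradient-reduce logic that the skipped backward_post_hook
    would normally perform."""
    for module in te_modules:
        for param in module.get_backward_dw_params():
            if param.grad is not None:
                # Stand-in for Megatron-LM's grad-reduce hook on this parameter.
                reduce_param_grad(param)
```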

Type of change

  • Documentation change (change only to the documentation, either a fix or new content)
  • Bug fix (non-breaking change which fixes an issue)
  • New feature (non-breaking change which adds functionality)
  • Breaking change (fix or feature that would cause existing functionality to not work as expected)
  • Infra/Build change
  • Code refactoring

Changes

Please list the changes introduced in this PR:

  • Add get_backward_dw_params for TE modules (see Description above)
  • Trigger the wgrad accumulation and grad-reduce hooks after the wgrad CUDA graph is replayed (transformer_engine/pytorch/graph.py, transformer_engine/pytorch/module/base.py)

Checklist:

  • I have read and followed the contributing guidelines
  • The functionality is complete
  • I have commented my code, particularly in hard-to-understand areas
  • I have made corresponding changes to the documentation
  • My changes generate no new warnings
  • I have added tests that prove my fix is effective or that my feature works
  • New and existing unit tests pass locally with my changes

greptile-apps bot (Contributor) commented Jan 22, 2026

Greptile Overview

Greptile Summary

This PR fixes a bug where weight gradient (wgrad) accumulation and reduce hooks were being skipped when CUDA graphs are used with delayed weight gradient computation in Megatron-LM.

Key changes:

  • Refactored hook triggering logic in TransformerEngineBaseModule.backward_dw() into a reusable _trigger_wgrad_accumulation_and_reduce_hooks() method
  • Modified make_graphed_attribute_functions() in graph.py to accept te_modules parameter and trigger hooks after wgrad CUDA graph replay
  • Ensures parameter gradient reduction hooks execute correctly when backward_dw() is called from within CUDA graph context

How it works:
After the wgrad CUDA graph is replayed, the code now iterates through all TE modules and calls _trigger_wgrad_accumulation_and_reduce_hooks() on modules that have delayed wgrad computation enabled. This ensures registered hooks (e.g., for gradient reduction in distributed training) are properly executed.
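
A minimal sketch of this post-replay pass, assuming the names given in the summary and sequence diagram (te_modules, need_backward_dw, _trigger_wgrad_accumulation_and_reduce_hooks); the wrapper function is a stand-in for the graphed wgrad callable built in transformer_engine/pytorch/graph.py, not the actual implementation:

```python
from transformer_engine.pytorch.module.base import TransformerEngineBaseModule


def replay_wgrad_graph_and_fire_hooks(bwd_dw_graph, te_modules):
    """Stand-in for the graphed backward_dw step: replay the captured wgrad
    kernels, then fire the hooks that graph capture skipped."""
    bwd_dw_graph.replay()  # replay captured kernels (assumes a torch.cuda.CUDAGraph)
    for module in te_modules:
        # Only TE modules with delayed wgrad computation still owe their hooks.
        if isinstance(module, TransformerEngineBaseModule) and module.need_backward_dw():
            module._trigger_wgrad_accumulation_and_reduce_hooks()
```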

Confidence Score: 4/5

  • Safe to merge with minor considerations around edge cases
  • The implementation follows established patterns in the codebase (similar to layernorm_mlp.py and grouped_linear.py). The refactoring is clean and the hook triggering logic is straightforward. Score is 4 instead of 5 due to: (1) potential edge case where modules in te_modules might not all be TransformerEngineBaseModule instances despite the check, and (2) the hook is triggered for all modules regardless of whether they actually executed wgrad computation in the graph
  • No files require special attention - the changes are localized and follow existing patterns

Important Files Changed

| Filename | Overview |
| --- | --- |
| transformer_engine/pytorch/graph.py | Modified make_graphed_attribute_functions to pass te_modules and trigger wgrad hooks after graph replay |
| transformer_engine/pytorch/module/base.py | Refactored hook triggering logic into _trigger_wgrad_accumulation_and_reduce_hooks method for reusability |

Sequence Diagram

```mermaid
sequenceDiagram
    participant MCore as Megatron-LM
    participant Graph as CUDA Graph
    participant Module as TE Module
    participant Hooks as Wgrad Hooks

    MCore->>Graph: Execute wgrad CUDA graph
    Graph->>Graph: backward_dw()
    Graph->>Graph: Replay bwd_dw_graph
    Graph->>Module: Check need_backward_dw()
    Module-->>Graph: true
    Graph->>Module: _trigger_wgrad_accumulation_and_reduce_hooks()
    Module->>Hooks: Execute registered hooks
    Hooks-->>Module: Hooks complete (grad reduce, etc.)
    Module-->>Graph: Hook execution complete
    Graph-->>MCore: Graph execution complete
```

greptile-apps bot (Contributor) commented Jan 22, 2026

Greptile's behavior is changing!

From now on, if a review finishes with no comments, we will not post an additional "statistics" comment to confirm that our review found nothing to comment on. However, you can confirm that we reviewed your changes in the status check section.

This feature can be toggled off in your Code Review Settings by deselecting "Create a status check for each PR".

Wohox (Contributor, Author) commented Jan 22, 2026

@buptzyb @lhb8125 Please help review this PR, thanks!

greptile-apps bot (Contributor) left a comment

1 file reviewed, 1 comment

Reviewed hunk (from the new get_backward_dw_params method):

```python
    Get the parameters for the backward weight gradient computation.
    """
    params = []
    params.append(noop_cat(self._get_weight_tensors()))
```
greptile-apps bot (Contributor) commented:
logic: In backward_dw() (lines 1520-1522), the weight tensors are only accessed when not self.fuse_wgrad_accumulation, but this method returns the weight parameters unconditionally. Depending on Megatron-LM's usage, this could cause hooks to be registered on parameters that should not have them when fuse_wgrad_accumulation=True.

Wohox (Contributor, Author) replied:

commit content reverted.

greptile-apps bot (Contributor) left a comment

No files reviewed, no comments
