Add MAGNUS: Multi-Attention Guided Network for Unified Segmentation #8717
base: dev
Conversation
- Add MAGNUS hybrid CNN-Transformer architecture for medical image segmentation
- Implement CNNPath for hierarchical feature extraction
- Implement TransformerPath for global context modeling
- Add CrossModalAttentionFusion for bidirectional cross-attention
- Add ScaleAdaptiveConv for multi-scale feature extraction
- Add SEBlock for channel recalibration
- Support both 2D and 3D medical images
- Add deep supervision option
- Add comprehensive unit tests

Reference: Aras et al., IEEE Access 2026, DOI: 10.1109/ACCESS.2026.3656667

Signed-off-by: Sefa Aras <sefa666@hotmail.com>
for more information, see https://pre-commit.ci
📝 Walkthrough

Adds MAGNUS, a CNN–ViT fusion segmentation model and its components to `monai.networks.nets`. New classes: MAGNUS, CNNPath, TransformerPath, CrossModalAttentionFusion, ScaleAdaptiveConv, SEBlock, and DecoderBlock. Implements feature extraction (CNNPath), patch embedding and Transformer encoding (TransformerPath) with positional interpolation, bidirectional cross-attention fusion, multi-scale convolution aggregation, decoder blocks with optional SE attention, deep supervision, weight initialization, and input validations. Also exposes these symbols in the package init and adds comprehensive unit tests for the new modules and configurations.

Estimated code review effort: 🎯 4 (Complex) | ⏱️ ~50 minutes

🚥 Pre-merge checks: ✅ 2 passed | ❌ 1 failed (warning)
Actionable comments posted: 2
🤖 Fix all issues with AI agents
In `@monai/networks/nets/magnus.py`:
- Around line 147-166: The transformer path lacks positional embeddings: add a
learnable positional embedding parameter (e.g., self.pos_embed =
nn.Parameter(torch.zeros(1, num_patches, hidden_dim))) initialized (truncated
normal or normal) and sized to match the sequence length produced by
self.embedding (compute num_patches from input spatial dimensions divided by
patch_size or infer from the flattened embedding shape at runtime), then in the
forward pass add this positional embedding to the flattened patch tokens before
passing them into self.transformer; ensure the parameter is registered on the
correct device and that self.norm still applies after the transformer.
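The suggested fix can be sketched as a small, self-contained module. The `PosEmbedMixin` name below is hypothetical (not code from this PR): a learnable embedding sized for a reference patch count, with linear interpolation along the sequence dimension whenever the runtime input produces a different number of tokens.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class PosEmbedMixin(nn.Module):
    """Minimal sketch of the suggested fix: a learnable positional embedding
    initialized with a truncated normal, interpolated at runtime when the
    input yields a different sequence length. Names are illustrative."""

    def __init__(self, num_patches: int, hidden_dim: int) -> None:
        super().__init__()
        self.pos_embed = nn.Parameter(torch.zeros(1, num_patches, hidden_dim))
        nn.init.trunc_normal_(self.pos_embed, std=0.02)

    def add_pos(self, tokens: torch.Tensor) -> torch.Tensor:
        # tokens: (B, N, C) flattened patch embeddings
        n = tokens.shape[1]
        pos = self.pos_embed
        if n != pos.shape[1]:
            # interpolate along the sequence dimension for new input sizes
            pos = F.interpolate(
                pos.transpose(1, 2), size=n, mode="linear", align_corners=False
            ).transpose(1, 2)
        return tokens + pos
```

Because `pos_embed` is a registered `nn.Parameter`, it moves with the module on `.to(device)`, and any post-transformer norm is unaffected since the addition happens before the encoder.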
🧹 Nitpick comments (4)
monai/networks/nets/magnus.py (2)
37-37: Sort `__all__` alphabetically.

Per Ruff RUF022: apply isort-style sorting to `__all__`.

Proposed fix:

```diff
-__all__ = ["MAGNUS", "CNNPath", "TransformerPath", "CrossModalAttentionFusion", "ScaleAdaptiveConv"]
+__all__ = ["CNNPath", "CrossModalAttentionFusion", "MAGNUS", "ScaleAdaptiveConv", "TransformerPath"]
```
703-704: Add `strict=True` to `zip()`.

Ensures `decoder_blocks` and `cnn_skips` have matching lengths, catching bugs if construction changes.

Proposed fix:

```diff
-        for i, (decoder_block, skip) in enumerate(zip(self.decoder_blocks, cnn_skips)):
+        for i, (decoder_block, skip) in enumerate(zip(self.decoder_blocks, cnn_skips, strict=True)):
```

tests/networks/nets/test_magnus.py (2)

29-35: Consider importing from `monai.networks.nets` to verify public exports.

Current imports bypass the public API. Testing via `from monai.networks.nets import MAGNUS` would also validate the `__init__.py` exports.
190-194: Add `strict=True` to `zip()`.

Although length is asserted above, adding `strict=True` makes intent explicit.

Proposed fix:

```diff
-        for i, (feat, out) in enumerate(zip(features, outputs)):
+        for i, (feat, out) in enumerate(zip(features, outputs, strict=True)):
```
```python
        aux_weights: Sequence[float] = (0.4, 0.3, 0.3),
    ) -> None:
        super().__init__()

        if spatial_dims not in (2, 3):
            raise ValueError(f"spatial_dims must be 2 or 3, got {spatial_dims}.")

        self.spatial_dims = spatial_dims
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.features = list(features)
        self.deep_supervision = deep_supervision
        self.aux_weights = list(aux_weights)
```
aux_weights is stored but never used.
The aux_weights parameter is documented and stored as an attribute but not applied anywhere in the model. Either apply them in the forward pass or remove from constructor and document that users should handle weighting externally.
🧰 Tools
🪛 Ruff (0.14.13)
558-558: Avoid specifying long messages outside the exception class (TRY003)
- Add learnable positional embeddings to TransformerPath for proper spatial reasoning
- Implement dynamic positional embedding interpolation for varying input sizes
- Add positional dropout for regularization
- Update aux_weights docstring to clarify it's for external use only

Addresses CodeRabbit review comments on PR Project-MONAI#8717

Signed-off-by: Sefa Aras <sefa666@hotmail.com>
Actionable comments posted: 1
🤖 Fix all issues with AI agents
In `@monai/networks/nets/magnus.py`:
- Line 37: The __all__ export list in magnus.py omits the public classes SEBlock
and DecoderBlock; update the __all__ variable to include "SEBlock" and
"DecoderBlock" (and reorder the entries alphabetically) so these symbols are
properly exported for external importers; ensure the string names exactly match
the class names SEBlock and DecoderBlock and preserve existing entries like
MAGNUS, CNNPath, TransformerPath, CrossModalAttentionFusion, ScaleAdaptiveConv.
🧹 Nitpick comments (5)
monai/networks/nets/magnus.py (5)
177-200: Unused parameter `x` in `_interpolate_pos_encoding`.

Docstring says "for device reference" but it's not used. Either remove it or use `x.device` if device placement is needed.

Proposed fix (remove unused parameter):

```diff
-    def _interpolate_pos_encoding(self, x: torch.Tensor, num_patches: int) -> torch.Tensor:
+    def _interpolate_pos_encoding(self, num_patches: int) -> torch.Tensor:
         """
         Interpolate positional embeddings to match the number of patches.

         Args:
-            x: input tensor for device reference.
             num_patches: target number of patches.

         Returns:
             Interpolated positional embeddings of shape (1, num_patches, hidden_dim).
         """
```

And update the call site at line 222:

```diff
-        pos_embed = self._interpolate_pos_encoding(x_flat, num_patches)
+        pos_embed = self._interpolate_pos_encoding(num_patches)
```
746-746: Add `strict=True` to `zip()` for safety.

Lengths should match by construction, but explicit strictness prevents silent bugs if refactored later.

Proposed fix:

```diff
-        for i, (decoder_block, skip) in enumerate(zip(self.decoder_blocks, cnn_skips)):
+        for i, (decoder_block, skip) in enumerate(zip(self.decoder_blocks, cnn_skips, strict=True)):
```
690-705: Minor: Kaiming initialization with `nonlinearity="relu"` applied to GELU layers.

TransformerEncoderLayer uses GELU activation. While Kaiming init defaults assume ReLU, PyTorch lacks a GELU option. This is acceptable but worth noting.
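One hedged workaround for the missing GELU gain is to split the init policy by layer type: Kaiming for the conv layers (ReLU-style paths) and Xavier for the transformer's linear layers. The `init_weights` helper below is an illustrative sketch, not the PR's actual implementation.

```python
import torch
import torch.nn as nn


def init_weights(module: nn.Module) -> None:
    """Sketch of a mixed init policy (names and choices are illustrative):
    Kaiming for conv layers, Xavier for linear layers whose GELU gain
    PyTorch's `nonlinearity` argument does not model directly."""
    for m in module.modules():
        if isinstance(m, (nn.Conv2d, nn.Conv3d)):
            nn.init.kaiming_normal_(m.weight, nonlinearity="relu")
            if m.bias is not None:
                nn.init.zeros_(m.bias)
        elif isinstance(m, nn.Linear):
            nn.init.xavier_uniform_(m.weight)
            if m.bias is not None:
                nn.init.zeros_(m.bias)
```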
580-601: Missing `Raises` section in docstring.

Per coding guidelines, raised exceptions should be documented. `ValueError` is raised at line 601 but not documented.

Proposed docstring addition (before `Example::`):

```
Raises:
    ValueError: If spatial_dims is not 2 or 3.
```
250-261: Missing `Raises` section in docstring.

`ValueError` is raised when channels is not divisible by num_heads.

Proposed docstring addition (in the class docstring):

```
Raises:
    ValueError: If channels is not divisible by num_heads.
```
```python
from monai.networks.blocks import Convolution, UpSample
from monai.networks.layers.utils import get_act_layer, get_norm_layer

__all__ = ["MAGNUS", "CNNPath", "TransformerPath", "CrossModalAttentionFusion", "ScaleAdaptiveConv"]
```
`__all__` is missing SEBlock and DecoderBlock.

These are public classes that should be exported. Also, consider sorting the list alphabetically per static analysis.
Proposed fix:

```diff
-__all__ = ["MAGNUS", "CNNPath", "TransformerPath", "CrossModalAttentionFusion", "ScaleAdaptiveConv"]
+__all__ = [
+    "CNNPath",
+    "CrossModalAttentionFusion",
+    "DecoderBlock",
+    "MAGNUS",
+    "ScaleAdaptiveConv",
+    "SEBlock",
+    "TransformerPath",
+]
```

🧰 Tools
🪛 Ruff (0.14.13)
37-37: `__all__` is not sorted. Apply an isort-style sorting to `__all__` (RUF022)
🤖 Prompt for AI Agents
In `@monai/networks/nets/magnus.py` at line 37, The __all__ export list in
magnus.py omits the public classes SEBlock and DecoderBlock; update the __all__
variable to include "SEBlock" and "DecoderBlock" (and reorder the entries
alphabetically) so these symbols are properly exported for external importers;
ensure the string names exactly match the class names SEBlock and DecoderBlock
and preserve existing entries like MAGNUS, CNNPath, TransformerPath,
CrossModalAttentionFusion, ScaleAdaptiveConv.
Description
This PR adds MAGNUS (Multi-Attention Guided Network for Unified Segmentation), a hybrid CNN-Transformer architecture for medical image segmentation.
Key Features
New Files

- `monai/networks/nets/magnus.py` - Main implementation
- `tests/networks/nets/test_magnus.py` - Unit tests (17 tests)

Modified Files

- `monai/networks/nets/__init__.py` - Export MAGNUS and components

Usage Example
```python
from monai.networks.nets import MAGNUS

model = MAGNUS(
    spatial_dims=3,
    in_channels=1,
    out_channels=2,
    features=(64, 128, 256, 512),
)
```
Test Results
All 17 unit tests pass ✅
Reference
Aras, E., Kayikcioglu, T., Aras, S., & Merd, N. (2026). MAGNUS: Multi-Attention Guided Network for Unified Segmentation via CNN-ViT Fusion. IEEE Access. DOI: 10.1109/ACCESS.2026.3656667