From c9566c698c740f6e8a77a69a6b963d835db894b4 Mon Sep 17 00:00:00 2001 From: ytl0623 Date: Sun, 18 Jan 2026 16:14:26 +0800 Subject: [PATCH 1/4] Add sigmoid/softmax interface for AsymmetricUnifiedFocalLoss Signed-off-by: ytl0623 --- monai/losses/unified_focal_loss.py | 263 +++++++++--------------- tests/losses/test_unified_focal_loss.py | 136 ++++++++++-- 2 files changed, 208 insertions(+), 191 deletions(-) diff --git a/monai/losses/unified_focal_loss.py b/monai/losses/unified_focal_loss.py index 06704c0104..7b3687800f 100644 --- a/monai/losses/unified_focal_loss.py +++ b/monai/losses/unified_focal_loss.py @@ -20,221 +20,146 @@ from monai.utils import LossReduction -class AsymmetricFocalTverskyLoss(_Loss): +class AsymmetricUnifiedFocalLoss(_Loss): """ - AsymmetricFocalTverskyLoss is a variant of FocalTverskyLoss, which attentions to the foreground class. + AsymmetricUnifiedFocalLoss is a variant of Focal Loss that combines Asymmetric Focal Loss + and Asymmetric Focal Tversky Loss to handle imbalanced medical image segmentation. - Actually, it's only supported for binary image segmentation now. + It supports multi-class segmentation by treating channel 0 as background and + channels 1..N as foreground, applying asymmetric weighting controlled by `delta`. 
- Reimplementation of the Asymmetric Focal Tversky Loss described in: + Reimplementation of the Asymmetric Unified Focal Loss described in: - "Unified Focal Loss: Generalising Dice and Cross Entropy-based Losses to Handle Class Imbalanced Medical Image Segmentation", Michael Yeung, Computerized Medical Imaging and Graphics + + Example: + >>> import torch + >>> from monai.losses import AsymmetricUnifiedFocalLoss + >>> # B, C, H, W = 1, 3, 32, 32 + >>> pred_logits = torch.randn(1, 3, 32, 32) + >>> # Ground truth indices (B, 1, H, W) + >>> grnd = torch.randint(0, 3, (1, 1, 32, 32)) + >>> # Use use_softmax=True if input is logits + >>> loss_func = AsymmetricUnifiedFocalLoss(to_onehot_y=True, use_softmax=True) + >>> loss = loss_func(pred_logits, grnd) """ def __init__( self, + weight: float = 0.5, + delta: float = 0.6, + gamma: float = 0.5, + include_background: bool = True, to_onehot_y: bool = False, - delta: float = 0.7, - gamma: float = 0.75, - epsilon: float = 1e-7, reduction: LossReduction | str = LossReduction.MEAN, + use_softmax: bool = False, + epsilon: float = 1e-7, ) -> None: """ Args: - to_onehot_y: whether to convert `y` into the one-hot format. Defaults to False. - delta : weight of the background. Defaults to 0.7. - gamma : value of the exponent gamma in the definition of the Focal loss . Defaults to 0.75. - epsilon : it defines a very small number each time. simmily smooth value. Defaults to 1e-7. + weight: The weighting factor between Asymmetric Focal Loss and Asymmetric Focal Tversky Loss. + Final Loss = weight * AFL + (1 - weight) * AFTL. Defaults to 0.5. + delta: The balancing factor that controls the weight of background vs foreground classes. + Values > 0.5 give more weight to foreground (False Negatives). Defaults to 0.6. + gamma: The focal exponent. Higher values focus more on hard examples. Defaults to 0.5. + include_background: If False, channel index 0 (background category) is excluded from the loss calculation. + Defaults to True. 
+ to_onehot_y: Whether to convert the label `target` into the one-hot format. Defaults to False. + reduction: {``"none"``, ``"mean"``, ``"sum"``} + Specifies the reduction to apply to the output. Defaults to ``"mean"``. + use_softmax: Whether to use softmax to transform the original logits into probabilities. + If True, softmax is used. If False, assumes input is already probabilities. Defaults to False. + epsilon: Small value to prevent division by zero or log(0). Defaults to 1e-7. """ super().__init__(reduction=LossReduction(reduction).value) - self.to_onehot_y = to_onehot_y + self.weight = weight self.delta = delta self.gamma = gamma + self.include_background = include_background + self.to_onehot_y = to_onehot_y + self.use_softmax = use_softmax self.epsilon = epsilon - def forward(self, y_pred: torch.Tensor, y_true: torch.Tensor) -> torch.Tensor: - n_pred_ch = y_pred.shape[1] - - if self.to_onehot_y: - if n_pred_ch == 1: - warnings.warn("single channel prediction, `to_onehot_y=True` ignored.") - else: - y_true = one_hot(y_true, num_classes=n_pred_ch) - - if y_true.shape != y_pred.shape: - raise ValueError(f"ground truth has different shape ({y_true.shape}) from input ({y_pred.shape})") - - # clip the prediction to avoid NaN - y_pred = torch.clamp(y_pred, self.epsilon, 1.0 - self.epsilon) - axis = list(range(2, len(y_pred.shape))) - - # Calculate true positives (tp), false negatives (fn) and false positives (fp) - tp = torch.sum(y_true * y_pred, dim=axis) - fn = torch.sum(y_true * (1 - y_pred), dim=axis) - fp = torch.sum((1 - y_true) * y_pred, dim=axis) - dice_class = (tp + self.epsilon) / (tp + self.delta * fn + (1 - self.delta) * fp + self.epsilon) - - # Calculate losses separately for each class, enhancing both classes - back_dice = 1 - dice_class[:, 0] - fore_dice = (1 - dice_class[:, 1]) * torch.pow(1 - dice_class[:, 1], -self.gamma) - - # Average class scores - loss = torch.mean(torch.stack([back_dice, fore_dice], dim=-1)) - return loss - - -class 
AsymmetricFocalLoss(_Loss): - """ - AsymmetricFocalLoss is a variant of FocalTverskyLoss, which attentions to the foreground class. - - Actually, it's only supported for binary image segmentation now. - - Reimplementation of the Asymmetric Focal Loss described in: - - - "Unified Focal Loss: Generalising Dice and Cross Entropy-based Losses to Handle Class Imbalanced Medical Image Segmentation", - Michael Yeung, Computerized Medical Imaging and Graphics - """ - - def __init__( - self, - to_onehot_y: bool = False, - delta: float = 0.7, - gamma: float = 2, - epsilon: float = 1e-7, - reduction: LossReduction | str = LossReduction.MEAN, - ): + def forward(self, input: torch.Tensor, target: torch.Tensor) -> torch.Tensor: """ Args: - to_onehot_y : whether to convert `y` into the one-hot format. Defaults to False. - delta : weight of the background. Defaults to 0.7. - gamma : value of the exponent gamma in the definition of the Focal loss . Defaults to 0.75. - epsilon : it defines a very small number each time. simmily smooth value. Defaults to 1e-7. + input: the shape should be BNH[WD], where N is the number of classes. + target: the shape should be BNH[WD] or B1H[WD]. + + Raises: + ValueError: When input and target have incompatible shapes. 
""" - super().__init__(reduction=LossReduction(reduction).value) - self.to_onehot_y = to_onehot_y - self.delta = delta - self.gamma = gamma - self.epsilon = epsilon + if self.use_softmax: + input = torch.nn.functional.softmax(input, dim=1) - def forward(self, y_pred: torch.Tensor, y_true: torch.Tensor) -> torch.Tensor: - n_pred_ch = y_pred.shape[1] + n_pred_ch = input.shape[1] if self.to_onehot_y: if n_pred_ch == 1: warnings.warn("single channel prediction, `to_onehot_y=True` ignored.") else: - y_true = one_hot(y_true, num_classes=n_pred_ch) - - if y_true.shape != y_pred.shape: - raise ValueError(f"ground truth has different shape ({y_true.shape}) from input ({y_pred.shape})") - - y_pred = torch.clamp(y_pred, self.epsilon, 1.0 - self.epsilon) - cross_entropy = -y_true * torch.log(y_pred) + if target.shape[1] == 1: + target = one_hot(target, num_classes=n_pred_ch) - back_ce = torch.pow(1 - y_pred[:, 0], self.gamma) * cross_entropy[:, 0] - back_ce = (1 - self.delta) * back_ce + if target.shape != input.shape: + raise ValueError(f"ground truth has different shape ({target.shape}) from input ({input.shape})") - fore_ce = cross_entropy[:, 1] - fore_ce = self.delta * fore_ce - - loss = torch.mean(torch.sum(torch.stack([back_ce, fore_ce], dim=1), dim=1)) - return loss - - -class AsymmetricUnifiedFocalLoss(_Loss): - """ - AsymmetricUnifiedFocalLoss is a variant of Focal Loss. 
+ # Clip values for numerical stability + input = torch.clamp(input, self.epsilon, 1.0 - self.epsilon) - Actually, it's only supported for binary image segmentation now + # Part A: Asymmetric Focal Loss + # Cross Entropy: -target * log(input) + cross_entropy = -target * torch.log(input) - Reimplementation of the Asymmetric Unified Focal Tversky Loss described in: + # Background (Channel 0): (1 - delta) * (1 - p)^gamma * CE + back_ce = (1 - self.delta) * torch.pow(1 - input[:, 0:1], self.gamma) * cross_entropy[:, 0:1] - - "Unified Focal Loss: Generalising Dice and Cross Entropy-based Losses to Handle Class Imbalanced Medical Image Segmentation", - Michael Yeung, Computerized Medical Imaging and Graphics - """ + # Foreground (Channel 1..N): delta * CE + fore_ce = self.delta * cross_entropy[:, 1:] - def __init__( - self, - to_onehot_y: bool = False, - num_classes: int = 2, - weight: float = 0.5, - gamma: float = 0.5, - delta: float = 0.7, - reduction: LossReduction | str = LossReduction.MEAN, - ): - """ - Args: - to_onehot_y : whether to convert `y` into the one-hot format. Defaults to False. - num_classes : number of classes, it only supports 2 now. Defaults to 2. - delta : weight of the background. Defaults to 0.7. - gamma : value of the exponent gamma in the definition of the Focal loss. Defaults to 0.75. - epsilon : it defines a very small number each time. simmily smooth value. Defaults to 1e-7. - weight : weight for each loss function, if it's none it's 0.5. Defaults to None. 
- - Example: - >>> import torch - >>> from monai.losses import AsymmetricUnifiedFocalLoss - >>> pred = torch.ones((1,1,32,32), dtype=torch.float32) - >>> grnd = torch.ones((1,1,32,32), dtype=torch.int64) - >>> fl = AsymmetricUnifiedFocalLoss(to_onehot_y=True) - >>> fl(pred, grnd) - """ - super().__init__(reduction=LossReduction(reduction).value) - self.to_onehot_y = to_onehot_y - self.num_classes = num_classes - self.gamma = gamma - self.delta = delta - self.weight: float = weight - self.asy_focal_loss = AsymmetricFocalLoss(gamma=self.gamma, delta=self.delta) - self.asy_focal_tversky_loss = AsymmetricFocalTverskyLoss(gamma=self.gamma, delta=self.delta) + # Combine + if self.include_background: + asy_focal_loss = torch.cat([back_ce, fore_ce], dim=1) + else: + asy_focal_loss = fore_ce - # TODO: Implement this function to support multiple classes segmentation - def forward(self, y_pred: torch.Tensor, y_true: torch.Tensor) -> torch.Tensor: - """ - Args: - y_pred : the shape should be BNH[WD], where N is the number of classes. - It only supports binary segmentation. - The input should be the original logits since it will be transformed by - a sigmoid in the forward function. - y_true : the shape should be BNH[WD], where N is the number of classes. - It only supports binary segmentation. 
+ # Part B: Asymmetric Focal Tversky Loss + # Sum over spatial dimensions (Batch and Channel dims are preserved) + reduce_axis = list(range(2, input.dim())) - Raises: - ValueError: When input and target are different shape - ValueError: When len(y_pred.shape) != 4 and len(y_pred.shape) != 5 - ValueError: When num_classes - ValueError: When the number of classes entered does not match the expected number - """ - if y_pred.shape != y_true.shape: - raise ValueError(f"ground truth has different shape ({y_true.shape}) from input ({y_pred.shape})") + tp = torch.sum(target * input, dim=reduce_axis) + fn = torch.sum(target * (1 - input), dim=reduce_axis) + fp = torch.sum((1 - target) * input, dim=reduce_axis) - if len(y_pred.shape) != 4 and len(y_pred.shape) != 5: - raise ValueError(f"input shape must be 4 or 5, but got {y_pred.shape}") + # Tversky Index + dice_class = (tp + self.epsilon) / (tp + self.delta * fn + (1 - self.delta) * fp + self.epsilon) - if y_pred.shape[1] == 1: - y_pred = one_hot(y_pred, num_classes=self.num_classes) - y_true = one_hot(y_true, num_classes=self.num_classes) + # Background: 1 - Dice + back_dice_loss = 1 - dice_class[:, 0:1] - if torch.max(y_true) != self.num_classes - 1: - raise ValueError(f"Please make sure the number of classes is {self.num_classes - 1}") + # Foreground: (1 - Dice)^(1 - gamma) + fore_dice_loss = (1 - dice_class[:, 1:]) * torch.pow(1 - dice_class[:, 1:], -self.gamma) - n_pred_ch = y_pred.shape[1] - if self.to_onehot_y: - if n_pred_ch == 1: - warnings.warn("single channel prediction, `to_onehot_y=True` ignored.") - else: - y_true = one_hot(y_true, num_classes=n_pred_ch) + # Combine + if self.include_background: + asy_focal_tversky_loss = torch.cat([back_dice_loss, fore_dice_loss], dim=1) + else: + asy_focal_tversky_loss = fore_dice_loss - asy_focal_loss = self.asy_focal_loss(y_pred, y_true) - asy_focal_tversky_loss = self.asy_focal_tversky_loss(y_pred, y_true) + # Part C: Unified Combination & Reduction + # Aggregate Focal 
Loss spatial dimensions to match Tversky Loss shape (B, C) + if asy_focal_loss.dim() > 2: + asy_focal_loss = torch.mean(asy_focal_loss, dim=reduce_axis) - loss: torch.Tensor = self.weight * asy_focal_loss + (1 - self.weight) * asy_focal_tversky_loss + # Weighted sum + total_loss = self.weight * asy_focal_loss + (1 - self.weight) * asy_focal_tversky_loss if self.reduction == LossReduction.SUM.value: - return torch.sum(loss) # sum over the batch and channel dims + return torch.sum(total_loss) if self.reduction == LossReduction.NONE.value: - return loss # returns [N, num_classes] losses + return total_loss if self.reduction == LossReduction.MEAN.value: - return torch.mean(loss) + return torch.mean(total_loss) + raise ValueError(f'Unsupported reduction: {self.reduction}, available options are ["mean", "sum", "none"].') diff --git a/tests/losses/test_unified_focal_loss.py b/tests/losses/test_unified_focal_loss.py index 3b868a560e..2fc3418617 100644 --- a/tests/losses/test_unified_focal_loss.py +++ b/tests/losses/test_unified_focal_loss.py @@ -20,46 +20,138 @@ from monai.losses import AsymmetricUnifiedFocalLoss TEST_CASES = [ - [ # shape: (2, 1, 2, 2), (2, 1, 2, 2) + # Case 0: Binary Classification, Perfect Prediction (Probs) + # Input is already probabilities (use_softmax=False), perfect prediction -> Loss should be close to 0 + [ { - "y_pred": torch.tensor([[[[1.0, 0], [0, 1.0]]], [[[1.0, 0], [0, 1.0]]]]), - "y_true": torch.tensor([[[[1.0, 0], [0, 1.0]]], [[[1.0, 0], [0, 1.0]]]]), + "init_kwargs": {"use_softmax": False, "to_onehot_y": False}, + "forward_kwargs": { + "input": torch.tensor([[[[1.0, 0.0], [0.0, 1.0]], [[0.0, 1.0], [1.0, 0.0]]]]), + "target": torch.tensor([[[[1.0, 0.0], [0.0, 1.0]], [[0.0, 1.0], [1.0, 0.0]]]]), + }, }, 0.0, ], - [ # shape: (2, 1, 2, 2), (2, 1, 2, 2) + # Case 1: Multi-class (3 Classes), Perfect Prediction (Logits) + # Input is Logits (use_softmax=True), large value difference implies high confidence -> Loss should be close to 0 + [ { - 
"y_pred": torch.tensor([[[[1.0, 0], [0, 1.0]]], [[[1.0, 0], [0, 1.0]]]]), - "y_true": torch.tensor([[[[1.0, 0], [0, 1.0]]], [[[1.0, 0], [0, 1.0]]]]), + "init_kwargs": {"use_softmax": True, "to_onehot_y": False}, + "forward_kwargs": { + # Logits: Large positive values indicate high probability + "input": torch.tensor( + [[[[10.0, -10.0], [-10.0, -10.0]], [[-10.0, 10.0], [-10.0, -10.0]], [[-10.0, -10.0], [10.0, 10.0]]]] + ), + "target": torch.tensor( + [[[[1.0, 0.0], [0.0, 0.0]], [[0.0, 1.0], [0.0, 0.0]], [[0.0, 0.0], [1.0, 1.0]]]] + ), + }, + }, + 0.0, + ], + # Case 2: Label Indices Input (to_onehot_y=True) + # Test automatic conversion from Index to One-Hot + [ + { + "init_kwargs": {"use_softmax": False, "to_onehot_y": True}, + "forward_kwargs": { + "input": torch.tensor([[[[1.0, 0.0], [0.0, 0.0]], [[0.0, 1.0], [0.0, 0.0]], [[0.0, 0.0], [1.0, 1.0]]]]), + "target": torch.tensor([[[[0, 1], [2, 2]]]]), # Shape (1, 1, 2, 2) + }, }, 0.0, ], ] +TEST_CASES_REDUCTION = [ + # Case: Reduction = 'none' + # Output shape should be (B, C) + [ + { + "init_kwargs": {"reduction": "none", "use_softmax": False}, + "forward_kwargs": { + "input": torch.randn(2, 3, 4, 4).sigmoid(), # B=2, C=3 + "target": torch.randint(0, 2, (2, 3, 4, 4)).float(), + }, + }, + (2, 3), + ], + # Case: Reduction = 'none' AND include_background=False + # Output shape should be (B, C-1) -> (2, 2) + [ + { + "init_kwargs": {"reduction": "none", "include_background": False, "use_softmax": False}, + "forward_kwargs": { + "input": torch.randn(2, 3, 4, 4).sigmoid(), + "target": torch.randint(0, 2, (2, 3, 4, 4)).float(), + }, + }, + (2, 2), + ], +] + class TestAsymmetricUnifiedFocalLoss(unittest.TestCase): @parameterized.expand(TEST_CASES) - def test_result(self, input_data, expected_val): - loss = AsymmetricUnifiedFocalLoss() - result = loss(**input_data) - np.testing.assert_allclose(result.detach().cpu().numpy(), expected_val, atol=1e-4, rtol=1e-4) + def test_result(self, input_params, expected_val): + """Test 
numerical accuracy of the loss.""" + init_kwargs = input_params.get("init_kwargs", {}) + forward_kwargs = input_params.get("forward_kwargs", {}) + + loss_func = AsymmetricUnifiedFocalLoss(**init_kwargs) + result = loss_func(**forward_kwargs) + + np.testing.assert_allclose(result.detach().cpu().numpy(), expected_val, atol=1e-3, rtol=1e-4) + + @parameterized.expand(TEST_CASES_REDUCTION) + def test_reduction_shape(self, input_params, expected_shape): + """Test output shapes under different Reduction modes.""" + init_kwargs = input_params.get("init_kwargs", {}) + forward_kwargs = input_params.get("forward_kwargs", {}) + + loss_func = AsymmetricUnifiedFocalLoss(**init_kwargs) + result = loss_func(**forward_kwargs) + + self.assertEqual(result.shape, expected_shape, msg=f"Expected shape {expected_shape} but got {result.shape}") def test_ill_shape(self): - loss = AsymmetricUnifiedFocalLoss() - with self.assertRaisesRegex(ValueError, ""): - loss(torch.ones((2, 2, 2)), torch.ones((2, 2, 2, 2))) + """Test handling of incorrect shapes.""" + loss_func = AsymmetricUnifiedFocalLoss() + with self.assertRaisesRegex(ValueError, "ground truth has different shape"): + loss_func(torch.ones((2, 2, 2)), torch.ones((2, 2, 2, 2))) + + def test_mismatch_shape(self): + """Test completely mismatched input and target shapes.""" + loss_func = AsymmetricUnifiedFocalLoss() + with self.assertRaisesRegex(ValueError, "ground truth has different shape"): + loss_func(torch.ones((1, 2, 4, 4)), torch.ones((1, 2, 3, 3))) + + def test_script(self): + """Test TorchScript compatibility.""" + loss_func = AsymmetricUnifiedFocalLoss() + input_data = torch.rand(1, 2, 4, 4) + target_data = torch.rand(1, 2, 4, 4) + try: + scripted_loss = torch.jit.script(loss_func) + scripted_loss(input_data, target_data) + except Exception as e: + self.fail(f"TorchScript failed: {e}") def test_with_cuda(self): - loss = AsymmetricUnifiedFocalLoss() - i = torch.tensor([[[[1.0, 0], [0, 1.0]]], [[[1.0, 0], [0, 1.0]]]]) - j = 
torch.tensor([[[[1.0, 0], [0, 1.0]]], [[[1.0, 0], [0, 1.0]]]]) - if torch.cuda.is_available(): - i = i.cuda() - j = j.cuda() - output = loss(i, j) - print(output) - np.testing.assert_allclose(output.detach().cpu().numpy(), 0.0, atol=1e-4, rtol=1e-4) + """Test CUDA support.""" + if not torch.cuda.is_available(): + return + + loss_func = AsymmetricUnifiedFocalLoss() + input_data = torch.rand(1, 2, 4, 4).cuda() + target_data = torch.rand(1, 2, 4, 4).cuda() + + try: + output = loss_func(input_data, target_data) + self.assertTrue(output.is_cuda, "Output should be on CUDA") + except Exception as e: + self.fail(f"CUDA forward pass failed: {e}") if __name__ == "__main__": From 358cf8b50f22271da0942edcb2e06893a217a7ad Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Sun, 18 Jan 2026 08:17:56 +0000 Subject: [PATCH 2/4] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- monai/losses/unified_focal_loss.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/monai/losses/unified_focal_loss.py b/monai/losses/unified_focal_loss.py index 7b3687800f..fbaa335eec 100644 --- a/monai/losses/unified_focal_loss.py +++ b/monai/losses/unified_focal_loss.py @@ -32,7 +32,7 @@ class AsymmetricUnifiedFocalLoss(_Loss): - "Unified Focal Loss: Generalising Dice and Cross Entropy-based Losses to Handle Class Imbalanced Medical Image Segmentation", Michael Yeung, Computerized Medical Imaging and Graphics - + Example: >>> import torch >>> from monai.losses import AsymmetricUnifiedFocalLoss From 05142ed20dbfcc7098dc66b392294977d9d580ea Mon Sep 17 00:00:00 2001 From: ytl0623 Date: Sun, 18 Jan 2026 17:36:06 +0800 Subject: [PATCH 3/4] add imperfect prediction case and fix bugs Signed-off-by: ytl0623 --- monai/losses/unified_focal_loss.py | 10 ++++++++-- tests/losses/test_unified_focal_loss.py | 25 ++++++++++++++++++++----- 2 files changed, 28 insertions(+), 7 deletions(-) 
diff --git a/monai/losses/unified_focal_loss.py b/monai/losses/unified_focal_loss.py index fbaa335eec..baa1ff77be 100644 --- a/monai/losses/unified_focal_loss.py +++ b/monai/losses/unified_focal_loss.py @@ -73,6 +73,12 @@ def __init__( epsilon: Small value to prevent division by zero or log(0). Defaults to 1e-7. """ super().__init__(reduction=LossReduction(reduction).value) + if not 0 <= weight <= 1: + raise ValueError(f"weight must be in [0, 1], got {weight}") + if not 0 <= delta <= 1: + raise ValueError(f"delta must be in [0, 1], got {delta}") + if gamma <= 0: + raise ValueError(f"gamma must be > 0, got {gamma}") self.weight = weight self.delta = delta self.gamma = gamma @@ -97,7 +103,7 @@ def forward(self, input: torch.Tensor, target: torch.Tensor) -> torch.Tensor: if self.to_onehot_y: if n_pred_ch == 1: - warnings.warn("single channel prediction, `to_onehot_y=True` ignored.") + warnings.warn("single channel prediction, `to_onehot_y=True` ignored.", stacklevel=2) else: if target.shape[1] == 1: target = one_hot(target, num_classes=n_pred_ch) @@ -139,7 +145,7 @@ def forward(self, input: torch.Tensor, target: torch.Tensor) -> torch.Tensor: back_dice_loss = 1 - dice_class[:, 0:1] # Foreground: (1 - Dice)^(1 - gamma) - fore_dice_loss = (1 - dice_class[:, 1:]) * torch.pow(1 - dice_class[:, 1:], -self.gamma) + fore_dice_loss = torch.pow(torch.clamp(1 - dice_class[:, 1:], min=self.epsilon), 1 - self.gamma) # Combine if self.include_background: diff --git a/tests/losses/test_unified_focal_loss.py b/tests/losses/test_unified_focal_loss.py index 2fc3418617..a8f223f7fa 100644 --- a/tests/losses/test_unified_focal_loss.py +++ b/tests/losses/test_unified_focal_loss.py @@ -20,7 +20,7 @@ from monai.losses import AsymmetricUnifiedFocalLoss TEST_CASES = [ - # Case 0: Binary Classification, Perfect Prediction (Probs) + # Case: Binary Classification, Perfect Prediction (Probs) # Input is already probabilities (use_softmax=False), perfect prediction -> Loss should be close to 0 [ { 
@@ -32,7 +32,7 @@ }, 0.0, ], - # Case 1: Multi-class (3 Classes), Perfect Prediction (Logits) + # Case: Multi-class (3 Classes), Perfect Prediction (Logits) # Input is Logits (use_softmax=True), large value difference implies high confidence -> Loss should be close to 0 [ { @@ -49,7 +49,7 @@ }, 0.0, ], - # Case 2: Label Indices Input (to_onehot_y=True) + # Case: Label Indices Input (to_onehot_y=True) # Test automatic conversion from Index to One-Hot [ { @@ -61,6 +61,21 @@ }, 0.0, ], + # Case: Imperfect Prediction + [ + { + "init_kwargs": {"use_softmax": True, "to_onehot_y": False}, + "forward_kwargs": { + "input": torch.tensor( + [[[[ -10.0, -10.0]], [[ 10.0, 10.0]], [[ -10.0, -10.0]]]] # B=1, C=3, H=1, W=2 + ), + "target": torch.tensor( + [[[[ 1.0, 1.0]], [[ 0.0, 0.0]], [[ 0.0, 0.0]]]] + ), + }, + }, + 1.518984, + ], ] TEST_CASES_REDUCTION = [ @@ -70,7 +85,7 @@ { "init_kwargs": {"reduction": "none", "use_softmax": False}, "forward_kwargs": { - "input": torch.randn(2, 3, 4, 4).sigmoid(), # B=2, C=3 + "input": torch.randn(2, 3, 4, 4).sigmoid(), "target": torch.randint(0, 2, (2, 3, 4, 4)).float(), }, }, @@ -141,7 +156,7 @@ def test_script(self): def test_with_cuda(self): """Test CUDA support.""" if not torch.cuda.is_available(): - return + self.skipTest("CUDA not available") loss_func = AsymmetricUnifiedFocalLoss() input_data = torch.rand(1, 2, 4, 4).cuda() From 307cb54f7a69a573eb82139f7dcab6804b6497d2 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Sun, 18 Jan 2026 09:36:36 +0000 Subject: [PATCH 4/4] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci Signed-off-by: ytl0623 --- tests/losses/test_unified_focal_loss.py | 12 ++++-------- 1 file changed, 4 insertions(+), 8 deletions(-) diff --git a/tests/losses/test_unified_focal_loss.py b/tests/losses/test_unified_focal_loss.py index a8f223f7fa..7d12824ad4 100644 --- 
a/tests/losses/test_unified_focal_loss.py +++ b/tests/losses/test_unified_focal_loss.py @@ -56,7 +56,7 @@ "init_kwargs": {"use_softmax": False, "to_onehot_y": True}, "forward_kwargs": { "input": torch.tensor([[[[1.0, 0.0], [0.0, 0.0]], [[0.0, 1.0], [0.0, 0.0]], [[0.0, 0.0], [1.0, 1.0]]]]), - "target": torch.tensor([[[[0, 1], [2, 2]]]]), # Shape (1, 1, 2, 2) + "target": torch.tensor([[[[0, 1], [2, 2]]]]), }, }, 0.0, @@ -66,15 +66,11 @@ { "init_kwargs": {"use_softmax": True, "to_onehot_y": False}, "forward_kwargs": { - "input": torch.tensor( - [[[[ -10.0, -10.0]], [[ 10.0, 10.0]], [[ -10.0, -10.0]]]] # B=1, C=3, H=1, W=2 - ), - "target": torch.tensor( - [[[[ 1.0, 1.0]], [[ 0.0, 0.0]], [[ 0.0, 0.0]]]] - ), + "input": torch.tensor([[[[-10.0, -10.0]], [[10.0, 10.0]], [[-10.0, -10.0]]]]), + "target": torch.tensor([[[[1.0, 1.0]], [[0.0, 0.0]], [[0.0, 0.0]]]]), }, }, - 1.518984, + 1.518984, ], ]