From 25c9b1cbbbdefb37425f5f397944d2d7fb5f154a Mon Sep 17 00:00:00 2001 From: Colin Son Date: Sun, 21 Jun 2026 10:38:40 -0500 Subject: [PATCH 1/2] Add ConformalRiskCalibrator and ConformalRiskPredictor for conformal risk control prediction sets Signed-off-by: Colin Son --- docs/source/metrics.rst | 20 ++ monai/metrics/__init__.py | 8 + monai/metrics/conformal_risk.py | 477 +++++++++++++++++++++++++++ tests/metrics/test_conformal_risk.py | 292 ++++++++++++++++ 4 files changed, 797 insertions(+) create mode 100644 monai/metrics/conformal_risk.py create mode 100644 tests/metrics/test_conformal_risk.py diff --git a/docs/source/metrics.rst b/docs/source/metrics.rst index 654958bbbf..a5c3fa91ab 100644 --- a/docs/source/metrics.rst +++ b/docs/source/metrics.rst @@ -203,6 +203,26 @@ Metrics :members: +`Conformal Risk Control` +------------------------ +.. autoclass:: ConformalRiskCalibrator + :members: + +.. autoclass:: ConformalRiskPredictor + :members: + :special-members: __call__ + +.. autoclass:: Coverage + :members: + +.. autoclass:: SetSize + :members: + +.. autofunction:: compute_coverage + +.. autofunction:: compute_set_size + + Utilities --------- .. 
automodule:: monai.metrics.utils diff --git a/monai/metrics/__init__.py b/monai/metrics/__init__.py index 2265dd3a3f..d7a8d94f8e 100644 --- a/monai/metrics/__init__.py +++ b/monai/metrics/__init__.py @@ -14,6 +14,14 @@ from .active_learning_metrics import LabelQualityScore, VarianceMetric, compute_variance, label_quality_score from .average_precision import AveragePrecisionMetric, compute_average_precision from .calibration import CalibrationErrorMetric, CalibrationReduction, calibration_binning +from .conformal_risk import ( + ConformalRiskCalibrator, + ConformalRiskPredictor, + Coverage, + SetSize, + compute_coverage, + compute_set_size, +) from .confusion_matrix import ConfusionMatrixMetric, compute_confusion_matrix_metric, get_confusion_matrix from .cumulative_average import CumulativeAverage from .embedding_collapse import EmbeddingCollapseMetric, compute_embedding_collapse diff --git a/monai/metrics/conformal_risk.py b/monai/metrics/conformal_risk.py new file mode 100644 index 0000000000..52e2e02b1d --- /dev/null +++ b/monai/metrics/conformal_risk.py @@ -0,0 +1,477 @@ +# Copyright (c) MONAI Consortium +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Conformal risk control for segmentation / classification (issue #8935, part 2). 
+ +Implements the recipe of Angelopoulos, Bates, Lei, Wasserman & Jordan, +"Conformal Risk Control" (arXiv:2208.02814, 2022): pick a single threshold +``lambda_hat`` on a held-out calibration split via the finite-sample-corrected +selection ``lambda_hat = inf { lambda : (n * R_hat(lambda) + B) / (n + 1) <= alpha }`` +(``B = 1`` is the loss upper bound), which guarantees ``E[L] <= alpha`` on a +fresh sample. At inference, the same threshold yields a prediction set per +voxel / per sample, and the per-voxel *uncertainty mask* flags locations where +the set contains more than one class (i.e. the model is ambiguous). + +This module mirrors the :class:`ConformalPredictor` / :class:`ConformalCalibrator` +split from ``monai/inferers/conformal_predictor.py`` but lives in ``metrics`` +because the calibration target is a *loss bounded on a calibration set* +rather than a marginal-coverage quantile, and because the natural outputs +(``Coverage``, ``SetSize``) are evaluation metrics. +""" + +from __future__ import annotations + +from collections.abc import Callable +from typing import Any + +import torch + +from monai.metrics.metric import CumulativeIterationMetric +from monai.utils import MetricReduction +from monai.utils.module import optional_import + +__all__ = [ + "ConformalRiskCalibrator", + "ConformalRiskPredictor", + "Coverage", + "SetSize", + "compute_set_size", + "compute_coverage", +] + +tqdm, has_tqdm = optional_import("tqdm", name="tqdm") + + +# ---------------------------------------------------------------------------------------------- +# Losses for conformal risk control. Each returns a per-sample (image-level) loss in [0, 1] +# given the prediction-set mask (B, C, spatial...) and the integer label (B, 1, spatial...). 
# ----------------------------------------------------------------------------------------------


def _set_from_threshold(scores: torch.Tensor, lam: float) -> torch.Tensor:
    """Boolean prediction set ``{ y : score(y) <= lam }``.

    Args:
        scores: non-conformity scores. The comparison is elementwise, so the result has
            the same shape as ``scores``.
        lam: scalar threshold.

    Returns:
        Bool tensor that is ``True`` wherever the score is at most ``lam``.
    """
    return scores <= lam


def _flatten_spatial(sets: torch.Tensor, labels: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
    """Bring a set mask and its labels to a flat ``(N, C)`` / ``(N,)`` layout.

    Flattening lets the same loss code serve classification (no spatial dims) and
    segmentation (any number of spatial dims).

    Args:
        sets: prediction-set mask ``(B, C, spatial...)``.
        labels: class indices ``(B, 1, spatial...)`` or ``(B, spatial...)``; cast to ``long``.

    Returns:
        ``(sets_flat, labels_flat)`` with shapes ``(N, C)`` and ``(N,)``, ``N = B * prod(spatial)``.

    Raises:
        ValueError: if ``sets`` has fewer than 2 dims, if ``labels`` does not hold exactly one
            entry per voxel of ``sets``, or if any label lies outside ``[0, C)``.
    """
    if sets.ndim < 2:
        raise ValueError(f"sets must be (B, C, spatial...), got shape {tuple(sets.shape)}.")
    c = sets.shape[1]
    sets_flat = sets.movedim(1, -1).reshape(-1, c)
    labels_flat = labels.reshape(-1).long()
    if labels_flat.numel() != sets_flat.shape[0]:
        raise ValueError(
            f"labels holds {labels_flat.numel()} entries but sets describes {sets_flat.shape[0]} voxels "
            f"(sets shape {tuple(sets.shape)}, labels shape {tuple(labels.shape)})."
        )
    if labels_flat.numel() > 0:
        lo, hi = int(labels_flat.min()), int(labels_flat.max())
        # Reject rather than clamp: clamping would silently turn e.g. an ignore-index of 255
        # into class C-1 and bias every loss computed from it.
        if lo < 0 or hi >= c:
            raise ValueError(f"labels must be class indices in [0, {c}), got values in [{lo}, {hi}].")
    return sets_flat, labels_flat


def _covered_per_image(sets: torch.Tensor, labels: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
    """Return ``(covered, labels)`` as ``(B, P)`` tensors; ``covered`` is True where the label is in the set."""
    sets_flat, labels_flat = _flatten_spatial(sets, labels)
    b = sets.shape[0]
    # Explicit per-image count (not -1) so empty batches / empty spatial dims reshape cleanly.
    n_per_image = sets_flat.shape[0] // b if b else 0
    covered = sets_flat.gather(1, labels_flat.unsqueeze(1)).squeeze(1).bool()
    return covered.reshape(b, n_per_image), labels_flat.reshape(b, n_per_image)


def miscoverage_loss(sets: torch.Tensor, labels: torch.Tensor) -> torch.Tensor:
    """Per-image mean miscoverage ``mean_voxel 1{ y_v not in S_v }`` in [0, 1].

    Args:
        sets: bool prediction-set mask ``(B, C, spatial...)``.
        labels: integer class indices ``(B, 1, spatial...)`` or ``(B, spatial...)`` in ``[0, C)``.

    Returns:
        ``(B,)`` float tensor with the fraction of voxels whose true label is missing from the set.

    Raises:
        ValueError: on malformed ``sets`` / ``labels`` (see :func:`_flatten_spatial`).
    """
    covered, _ = _covered_per_image(sets, labels)
    return (~covered).float().mean(dim=1)


def false_negative_loss(sets: torch.Tensor, labels: torch.Tensor) -> torch.Tensor:
    """Per-image false-negative rate among foreground voxels, in [0, 1].

    Voxels where ``labels == 0`` are excluded from both numerator and denominator, which
    keeps a dominant background from diluting the loss of the classes that matter.

    Args:
        sets: bool prediction-set mask ``(B, C, spatial...)``.
        labels: integer class indices ``(B, 1, spatial...)`` or ``(B, spatial...)`` in ``[0, C)``.

    Returns:
        ``(B,)`` float tensor. Images with no foreground voxel get ``0``.

    Raises:
        ValueError: on malformed ``sets`` / ``labels`` (see :func:`_flatten_spatial`).
    """
    covered, labels_2d = _covered_per_image(sets, labels)
    fg = labels_2d != 0
    missed_fg = (~covered & fg).float().sum(dim=1)
    # clamp the denominator so an all-background image yields 0 / 1 = 0 instead of 0 / 0
    denom = fg.float().sum(dim=1).clamp(min=1.0)
    return missed_fg / denom


# Registry of the built-in losses selectable by name.
_LOSSES: dict[str, Callable[[torch.Tensor, torch.Tensor], torch.Tensor]] = {
    "miscoverage": miscoverage_loss,
    "false_negative": false_negative_loss,
}
+ + The non-conformity score is ``1 - softmax[y]`` (LAC, same as + :class:`monai.inferers.ConformalCalibrator`); the set at threshold ``lambda`` is + ``S_lambda(x) = { y : 1 - softmax[y] <= lambda }``. + + Args: + alpha: target risk, e.g. ``0.1`` bounds the expected loss at ~``0.1``. + loss: image-level loss bounded in [0, 1]. Either a callable + ``(sets, labels) -> (B,)`` tensor, or one of ``"miscoverage"`` / + ``"false_negative"``. + include_background: when ``False`` drop background-labeled (class 0) voxels from + the score pool before computing the loss. Defaults to ``True``. + lam_grid: grid of candidate thresholds (1-D tensor in ``[0, 1]``) used for the + ``inf`` search. Defaults to ``torch.linspace(0, 1, 101)``. Finer grids give a + tighter bound at a small compute cost. + + Example: + + .. code-block:: python + + import torch + from monai.metrics import ConformalRiskCalibrator + + cal = ConformalRiskCalibrator(alpha=0.1, loss="miscoverage") + for batch in cal_loader: + probs = model(batch["image"]).softmax(dim=1) + cal.accumulate(probs, batch["label"]) + lam = cal.calibrate() + # lam is a scalar threshold; pass it to ConformalRiskPredictor + + References: + - Angelopoulos, A.; Bates, S.; Lei, J.; Wasserman, L.; Jordan, M. "Conformal + Risk Control." arXiv:2208.02814, 2022. 
https://arxiv.org/abs/2208.02814 + """ + + def __init__( + self, + alpha: float = 0.1, + loss: str | Callable[[torch.Tensor, torch.Tensor], torch.Tensor] = "miscoverage", + include_background: bool = True, + lam_grid: torch.Tensor | None = None, + ) -> None: + if not 0.0 < alpha < 1.0: + raise ValueError(f"alpha must be in (0, 1), got {alpha}.") + if isinstance(loss, str): + if loss not in _LOSSES: + raise ValueError(f"Unknown loss {loss!r}; available: {sorted(_LOSSES)}.") + loss_fn = _LOSSES[loss] + elif callable(loss): + loss_fn = loss + else: + raise TypeError(f"loss must be a str or callable, got {type(loss)}.") + self.alpha = float(alpha) + self.loss_fn = loss_fn + self.include_background = include_background + if lam_grid is None: + lam_grid = torch.linspace(0.0, 1.0, 101) + if lam_grid.ndim != 1 or (lam_grid < 0).any() or (lam_grid > 1).any(): + raise ValueError("lam_grid must be a 1-D tensor with values in [0, 1].") + self.lam_grid = lam_grid.float() + # Per-image score/label tensors, stored one entry per calibration image so spatial + # size may vary across images and across accumulate() calls (variable-size volumes). + self._scores: list[torch.Tensor] = [] # each (P_i, C) + self._labels: list[torch.Tensor] = [] # each (P_i,) + self._num_classes: int | None = None + + def accumulate(self, probs: torch.Tensor, labels: torch.Tensor) -> None: + """Accumulate calibration data from one batch. + + Spatial size may differ from batch to batch; each image is stored separately so + calibration works on variable-size volumes. The channel count ``C`` must stay fixed. + + Args: + probs: softmax probabilities ``(B, C, spatial...)`` in [0, 1] summing to 1 over C. + labels: integer class indices ``(B, 1, spatial...)`` or ``(B, spatial...)`` in [0, C). 
+ """ + if probs.ndim < 2: + raise ValueError(f"probs must be (B, C, spatial...), got shape {tuple(probs.shape)}.") + b, c = probs.shape[:2] + if self._num_classes is None: + self._num_classes = c + elif c != self._num_classes: + raise ValueError(f"channel count C changed across accumulate() calls: {self._num_classes} -> {c}.") + spatial = probs.shape[2:] + per_image = int(torch.tensor(spatial).prod().item()) if spatial else 1 + # (B, per_image, C): move class to last then flatten spatial + scores = (1.0 - probs).movedim(1, -1).reshape(b, per_image, c).detach() + # labels (B, 1, spatial...) or (B, spatial...) -> (B, per_image) + labels_flat = labels.reshape(b, per_image).long().clamp(min=0, max=c - 1).detach() + for i in range(b): + self._scores.append(scores[i]) # (per_image, C) + self._labels.append(labels_flat[i]) # (per_image,) + + def calibrate(self) -> torch.Tensor: + """Search ``lam_grid`` for the smallest threshold whose risk-controlled bound holds. + + Selects ``lambda_hat = inf { lambda : (n * R_hat(lambda) + B) / (n + 1) <= alpha }`` + with ``B = 1`` (Angelopoulos et al. 2022, Thm 1), which bounds the expected loss on a + fresh sample by ``alpha``. + + Returns: + Scalar tensor ``lambda_hat``. If no grid point satisfies the finite-sample bound + (only possible when ``alpha < 1 / (n + 1)`` — the calibration set is too small for + the requested risk), the largest grid value is returned (full sets); callers + should check the achieved risk with :class:`Coverage` / :class:`SetSize`. + """ + if not self._scores: + raise RuntimeError("No calibration data accumulated; call accumulate(probs, labels) first.") + n = len(self._scores) + device, dtype = self._scores[0].device, self._scores[0].dtype + lam_grid = self.lam_grid.to(device) + n_lam = lam_grid.numel() + # Sum each image's per-lambda loss; images vary in size so we loop per image but + # vectorize over the whole lambda grid (n_lam acts as the batch dim into loss_fn). 
+ risk_sum = torch.zeros(n_lam, device=device, dtype=torch.float32) + for scores_i, labels_i in zip(self._scores, self._labels): + if not self.include_background: + keep = labels_i != 0 + if not bool(keep.any()): + continue # all-background image: 0 loss, but still counted in n + scores_i, labels_i = scores_i[keep], labels_i[keep] + sets = scores_i.unsqueeze(0) <= lam_grid.view(-1, 1, 1) # (n_lam, P_i, C) + sets_shaped = sets.movedim(-1, 1) # (n_lam, C, P_i) + labels_rep = labels_i.view(1, 1, -1).expand(n_lam, 1, -1) # (n_lam, 1, P_i) + risk_sum += self.loss_fn(sets_shaped, labels_rep).float() + emp_risk = risk_sum / n + # Finite-sample-corrected selection. B = 1 is the loss upper bound (losses are in + # [0, 1]); losses are non-increasing in lambda, so the leftmost lambda clearing the + # bound is the infimum. + b_bound = 1.0 + alpha_eff = ((n + 1) * self.alpha - b_bound) / n + within = (emp_risk <= alpha_eff).nonzero(as_tuple=True)[0] + if within.numel() == 0: + lam_hat = lam_grid[-1] + else: + lam_hat = lam_grid[within[0]] + self.reset() # ponytail: one-shot; caller keeps lam_hat + return lam_hat.to(dtype).to(device) + + def reset(self) -> None: + self._scores, self._labels = [], [] + self._num_classes = None + + +class ConformalRiskPredictor: + """Apply a pre-calibrated threshold ``lambda_hat`` at inference and return both the + prediction set and the per-voxel *uncertainty mask*. + + The uncertainty mask flags voxels where the prediction set contains more than one + class — i.e. the model cannot commit to a single label at the calibrated risk level. + Voxels in the mask are candidates for review, defer-to-human, or downstream refinement. + + This is intentionally *not* an :class:`monai.inferers.Inferer` subclass: it does not + own a network. Pair it with any inferer (e.g. ``SimpleInferer`` or + ``SlidingWindowInferer``) that produces logits, then call this on the softmax. + + Args: + lam: calibrated threshold (scalar tensor). Required. 
class ConformalRiskPredictor:
    """Apply a pre-calibrated threshold ``lambda_hat`` and return prediction sets plus an
    *uncertainty mask*.

    Class ``c`` is in the set at a location when ``1 - probs[c] <= lam``. The uncertainty
    mask flags locations whose set holds more than one class. This class does not own a
    network: call it on the softmax output of any inferer.

    Args:
        lam: calibrated threshold, a tensor holding exactly one value in ``[0, 1]``.
        include_background: if ``False``, locations whose argmax class is 0 are removed from
            the uncertainty mask. Defaults to ``True``.

    Raises:
        TypeError: if ``lam`` is not a ``torch.Tensor``.
        ValueError: if ``lam`` does not hold exactly one value in ``[0, 1]``.

    Example:

    .. code-block:: python

        import torch
        from monai.inferers import SlidingWindowInferer
        from monai.metrics import ConformalRiskPredictor

        lam = torch.tensor(0.4)
        crp = ConformalRiskPredictor(lam=lam)
        with torch.no_grad():
            logits = sliding_inferer(imgs, model)
            sets, mask, probs = crp(logits.softmax(dim=1))
        # sets: (B, C, ...) bool, mask: (B, 1, ...) bool, probs: (B, C, ...) float
    """

    def __init__(self, lam: torch.Tensor, include_background: bool = True) -> None:
        self.set_threshold(lam)
        self.include_background = include_background

    def set_threshold(self, lam: torch.Tensor) -> None:
        """Set (or update) the calibrated threshold.

        The value is stored as a detached 0-dim copy, so later changes to ``lam`` do not
        affect the predictor and the threshold always broadcasts as a plain scalar.

        Raises:
            TypeError: if ``lam`` is not a ``torch.Tensor``.
            ValueError: if ``lam`` does not hold exactly one value in ``[0, 1]`` (NaN included).
        """
        if not isinstance(lam, torch.Tensor):
            raise TypeError(f"lam must be a torch.Tensor, got {type(lam)}.")
        if lam.numel() != 1:
            # A multi-element threshold would broadcast against probs and corrupt the set shape.
            raise ValueError(f"lam must hold exactly one value, got shape {tuple(lam.shape)}.")
        value = lam.detach().clone().reshape(())
        # Scores are 1 - prob, i.e. in [0, 1]; the positive range test also rejects NaN.
        if not bool((value >= 0) & (value <= 1)):
            raise ValueError(f"lam must be in [0, 1], got {value.item()}.")
        self.lam = value

    def __call__(self, probs: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """Build prediction sets and the uncertainty mask for ``probs``.

        Args:
            probs: class probabilities ``(B, C, spatial...)``.

        Returns:
            ``(sets, uncertainty_mask, probs)``. ``sets`` is a bool tensor ``(B, C, spatial...)``,
            ``True`` where the class is in the set. ``uncertainty_mask`` is a bool tensor
            ``(B, 1, spatial...)``, ``True`` where the set holds more than one class (and the
            argmax is not class 0 when ``include_background=False``). ``probs`` is the input,
            returned unchanged.

        Raises:
            ValueError: if ``probs`` has fewer than 2 dims.
        """
        if probs.ndim < 2:
            raise ValueError(f"probs must be (B, C, spatial...), got shape {tuple(probs.shape)}.")
        lam = self.lam.to(probs.device, probs.dtype)
        sets = (1.0 - probs) <= lam
        # more than one class in the set -> ambiguous location
        uncertainty_mask = sets.sum(dim=1, keepdim=True) > 1
        if not self.include_background:
            # class 0 is treated as background: drop locations the model assigns to it
            uncertainty_mask = uncertainty_mask & (probs.argmax(dim=1, keepdim=True) != 0)
        return sets, uncertainty_mask, probs
def compute_coverage(sets: torch.Tensor, labels: torch.Tensor) -> torch.Tensor:
    """Per-image fraction of voxels whose true label is in the prediction set, in [0, 1].

    Args:
        sets: bool tensor ``(B, C, spatial...)``.
        labels: integer class indices ``(B, 1, spatial...)`` or ``(B, spatial...)`` in ``[0, C)``.

    Returns:
        ``(B,)`` float tensor of per-image coverage. Higher is better (1 = full coverage).

    Raises:
        ValueError: if ``sets`` has fewer than 2 dims, ``labels`` does not hold one entry per
            voxel of ``sets``, or a label lies outside ``[0, C)``.
    """
    if sets.ndim < 2:
        raise ValueError(f"sets must be (B, C, spatial...), got shape {tuple(sets.shape)}.")
    b, c = sets.shape[:2]
    n_per_image = sets.shape[2:].numel()  # 1 when there are no spatial dims
    if labels.numel() != b * n_per_image:
        raise ValueError(
            f"labels shape {tuple(labels.shape)} does not match sets shape {tuple(sets.shape)}: "
            f"expected {b * n_per_image} entries, got {labels.numel()}."
        )
    index = labels.reshape(b, 1, n_per_image).long()
    if index.numel() > 0:
        lo, hi = int(index.min()), int(index.max())
        # Reject rather than clamp: a clamped out-of-range label would be scored as class C-1.
        if lo < 0 or hi >= c:
            raise ValueError(f"labels must be class indices in [0, {c}), got values in [{lo}, {hi}].")
    # pick, for every voxel, the set-membership flag of its true class
    covered = sets.reshape(b, c, n_per_image).gather(1, index).squeeze(1).bool()
    return covered.float().mean(dim=1)


def compute_set_size(sets: torch.Tensor) -> torch.Tensor:
    """Per-image mean prediction-set size (classes in the set per voxel), in [0, C].

    Args:
        sets: bool tensor ``(B, C, spatial...)``.

    Returns:
        ``(B,)`` float tensor of per-image mean set size. Smaller means tighter sets.

    Raises:
        ValueError: if ``sets`` has fewer than 2 dims.
    """
    if sets.ndim < 2:
        raise ValueError(f"sets must be (B, C, spatial...), got shape {tuple(sets.shape)}.")
    b = sets.shape[0]
    sizes = sets.sum(dim=1).float()  # (B, spatial...)
    # explicit per-image count (not -1) so an empty batch reshapes cleanly
    return sizes.reshape(b, sets.shape[2:].numel()).mean(dim=1)
class Coverage(CumulativeIterationMetric):
    """Cumulative per-image coverage of conformal prediction sets.

    Each call stores, per image, the fraction of voxels / samples whose true label lies
    inside the prediction set (see :func:`compute_coverage`); ``aggregate()`` reduces the
    stored values.

    Args:
        metric_reduction: reduction applied on ``aggregate()``. Defaults to ``"mean"``.
        get_not_nans: if ``True``, ``aggregate()`` returns ``(metric, not_nans)``.

    Example:

    .. code-block:: python

        import torch
        from monai.metrics import Coverage, ConformalRiskPredictor

        cov = Coverage()
        predictor = ConformalRiskPredictor(lam=torch.tensor(0.4))
        for batch in test_loader:
            probs = model(batch["image"]).softmax(dim=1)
            sets, _, _ = predictor(probs)
            cov(sets, batch["label"])
        print(cov.aggregate())
    """

    def __init__(
        self, metric_reduction: MetricReduction | str = MetricReduction.MEAN, get_not_nans: bool = False
    ) -> None:
        super().__init__()
        self.metric_reduction = metric_reduction
        self.get_not_nans = get_not_nans

    def _compute_tensor(self, y_pred: torch.Tensor, y: torch.Tensor, **kwargs: Any) -> torch.Tensor:  # type: ignore[override]
        """Return per-image coverage as a ``(B, 1)`` tensor (``y_pred`` = sets, ``y`` = labels)."""
        both_tensors = isinstance(y_pred, torch.Tensor) and isinstance(y, torch.Tensor)
        if not both_tensors:
            raise TypeError("Coverage expects torch.Tensor inputs (sets, labels).")
        per_image = compute_coverage(y_pred, y)
        # trailing singleton dim gives do_metric_reduction a (batch, channel) layout
        return per_image.unsqueeze(1)

    def aggregate(
        self, reduction: MetricReduction | str | None = None
    ) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
        """Reduce the buffered values; ``reduction`` overrides ``self.metric_reduction`` when truthy."""
        from monai.metrics.utils import do_metric_reduction

        buffered = self.get_buffer()
        if not isinstance(buffered, torch.Tensor):
            raise ValueError("the data to aggregate must be a PyTorch Tensor.")
        mode = reduction if reduction else self.metric_reduction
        value, not_nans = do_metric_reduction(buffered, mode)
        if self.get_not_nans:
            return value, not_nans
        return value
class SetSize(CumulativeIterationMetric):
    """Cumulative per-image mean prediction-set size.

    Each call stores, per image, the average number of classes in the prediction set
    (see :func:`compute_set_size`); ``aggregate()`` reduces the stored values. Smaller
    values mean tighter sets.

    Args:
        metric_reduction: reduction applied on ``aggregate()``. Defaults to ``"mean"``.
        get_not_nans: if ``True``, ``aggregate()`` returns ``(metric, not_nans)``.
    """

    def __init__(
        self, metric_reduction: MetricReduction | str = MetricReduction.MEAN, get_not_nans: bool = False
    ) -> None:
        super().__init__()
        self.metric_reduction = metric_reduction
        self.get_not_nans = get_not_nans

    def _compute_tensor(self, y_pred: torch.Tensor, y: torch.Tensor | None = None, **kwargs: Any) -> torch.Tensor:  # type: ignore[override]
        """Return per-image mean set size as a ``(B, 1)`` tensor; ``y`` is accepted but unused."""
        if isinstance(y_pred, torch.Tensor):
            # trailing singleton dim gives do_metric_reduction a (batch, channel) layout
            return compute_set_size(y_pred).unsqueeze(1)
        raise TypeError("SetSize expects a torch.Tensor input (sets).")

    def aggregate(
        self, reduction: MetricReduction | str | None = None
    ) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
        """Reduce the buffered values; ``reduction`` overrides ``self.metric_reduction`` when truthy."""
        from monai.metrics.utils import do_metric_reduction

        buffered = self.get_buffer()
        if not isinstance(buffered, torch.Tensor):
            raise ValueError("the data to aggregate must be a PyTorch Tensor.")
        mode = reduction if reduction else self.metric_reduction
        value, not_nans = do_metric_reduction(buffered, mode)
        if self.get_not_nans:
            return value, not_nans
        return value
+# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import annotations + +import unittest + +import torch + +from monai.metrics import ( + ConformalRiskCalibrator, + ConformalRiskPredictor, + Coverage, + SetSize, + compute_coverage, + compute_set_size, +) +from monai.metrics.conformal_risk import false_negative_loss, miscoverage_loss +from tests.test_utils import assert_allclose + + +class TestMiscoverageLoss(unittest.TestCase): + def test_full_coverage_zero_loss(self): + # 2 images, 3 classes, 2 voxels. True class 0 always included. + sets = torch.ones(2, 3, 2, dtype=torch.bool) + labels = torch.zeros(2, 1, 2, dtype=torch.long) + loss = miscoverage_loss(sets, labels) + assert_allclose(loss, torch.zeros(2), atol=1e-6) + + def test_no_coverage_unit_loss(self): + sets = torch.zeros(2, 3, 2, dtype=torch.bool) # empty sets + labels = torch.zeros(2, 1, 2, dtype=torch.long) + loss = miscoverage_loss(sets, labels) + assert_allclose(loss, torch.ones(2), atol=1e-6) + + def test_half_coverage(self): + # image 0: voxel 0 covered (class 0 in set), voxel 1 not -> loss 0.5 + sets = torch.tensor([[[True, False], [False, False], [False, False]]]) # (1, 3, 2) + labels = torch.zeros(1, 1, 2, dtype=torch.long) + loss = miscoverage_loss(sets, labels) + assert_allclose(loss, torch.tensor([0.5]), atol=1e-6) + + +class TestFalseNegativeLoss(unittest.TestCase): + def test_ignores_background(self): + # image: voxel 0 bg (label 0), voxel 1 fg class 1 not covered -> FNR 1.0 + sets = torch.zeros(1, 3, 2, dtype=torch.bool) + labels = torch.tensor([[[0, 1]]]) + loss = false_negative_loss(sets, labels) + assert_allclose(loss, 
torch.tensor([1.0]), atol=1e-6) + + def test_no_foreground_zero_loss(self): + # all background -> denom 0 -> loss 0 by convention + sets = torch.zeros(1, 3, 2, dtype=torch.bool) + labels = torch.zeros(1, 1, 2, dtype=torch.long) + loss = false_negative_loss(sets, labels) + assert_allclose(loss, torch.tensor([0.0]), atol=1e-6) + + +class TestConformalRiskCalibrator(unittest.TestCase): + def test_calibrate_miscoverage_monotone(self): + # 4 images, 2 classes, 2 voxels each. True class 0; softmax[0] gives true-prob + # 0.55, 0.65, 0.75, 0.85 -> scores 0.45, 0.35, 0.25, 0.15. With grid step 0.1: + # at lam=0.2 only the 0.15 image is covered -> risk 0.75; at lam=0.3 the + # 0.25 and 0.15 images are covered -> risk 0.5; at lam=0.4 three covered -> risk 0.25. + # Smallest grid point with risk <= alpha_eff (0.25, see below) is 0.4. + p = [0.55, 0.65, 0.75, 0.85] + probs = torch.stack([torch.tensor([[pi, pi], [1 - pi, 1 - pi]]) for pi in p]) # (4, C=2, spatial=2) + labels = torch.zeros(4, 1, 2, dtype=torch.long) + # n=4, B=1: alpha_eff = (5*0.4 - 1)/4 = 0.25; emp_risk hits 0.25 first at lam=0.4. 
+ cal = ConformalRiskCalibrator(alpha=0.4, loss="miscoverage", lam_grid=torch.linspace(0.0, 1.0, 11)) # step 0.1 + cal.accumulate(probs, labels) + lam = cal.calibrate() + assert_allclose(lam, torch.tensor(0.4), atol=1e-6) + + def test_calibrate_returns_scalar_in_grid(self): + probs = torch.rand(5, 3, 2, 2) + probs = probs / probs.sum(dim=1, keepdim=True) + labels = torch.randint(0, 3, (5, 1, 2, 2)) + cal = ConformalRiskCalibrator(alpha=0.1) + cal.accumulate(probs, labels) + lam = cal.calibrate() + self.assertEqual(lam.ndim, 0) + self.assertTrue(0.0 <= lam.item() <= 1.0) + + def test_calibrate_multi_batch(self): + # two batches each with 2 images; same probs as the monotone test split in half + probs_a = torch.stack([torch.tensor([[0.55, 0.55], [0.45, 0.45]]), torch.tensor([[0.65, 0.65], [0.35, 0.35]])]) + probs_b = torch.stack([torch.tensor([[0.75, 0.75], [0.25, 0.25]]), torch.tensor([[0.85, 0.85], [0.15, 0.15]])]) + labels = torch.zeros(2, 1, 2, dtype=torch.long) + cal = ConformalRiskCalibrator(alpha=0.4, lam_grid=torch.linspace(0.0, 1.0, 11)) + cal.accumulate(probs_a, labels) + cal.accumulate(probs_b, labels) + lam = cal.calibrate() + assert_allclose(lam, torch.tensor(0.4), atol=1e-6) + + def test_include_background_drops_bg_voxels(self): + # 1 image, 3 classes, 2 voxels. voxel 0 bg (label 0) has a HARD true class (score 0.5); + # voxel 1 fg class 1 is easy (score 0.2). n=1, alpha=0.5 -> alpha_eff = (2*0.5-1)/1 = 0, + # so we need risk exactly 0. Dropping the bg voxel lets lam = 0.2; keeping it forces + # lam = 0.5 to also cover the hard bg voxel -> the flag genuinely changes the result. 
+ probs = torch.tensor([[[0.5, 0.1], [0.3, 0.8], [0.2, 0.1]]]) # (1, 3, 2) + labels = torch.tensor([[[0, 1]]]) # (1, 1, 2) + grid = torch.linspace(0.0, 1.0, 11) + lam_drop = ConformalRiskCalibrator(alpha=0.5, loss="miscoverage", include_background=False, lam_grid=grid) + lam_drop.accumulate(probs, labels) + assert_allclose(lam_drop.calibrate(), torch.tensor(0.2), atol=1e-6) + lam_keep = ConformalRiskCalibrator(alpha=0.5, loss="miscoverage", include_background=True, lam_grid=grid) + lam_keep.accumulate(probs, labels) + assert_allclose(lam_keep.calibrate(), torch.tensor(0.5), atol=1e-6) + + def test_variable_size_volumes(self): + # Calibration images with DIFFERENT spatial sizes (4 vs 9 voxels) must work — each image + # is stored separately. label 0 everywhere; class-0 scores: imgs 0.3, 0.4 (batch1), 0.1 + # (batch2). n=3, alpha=0.5 -> alpha_eff = (4*0.5-1)/3 = 1/3; emp_risk first <= 1/3 at lam=0.3. + def img(p0, n): # one image (C=2, n voxels) with class-0 prob p0 + return torch.stack([torch.full((n,), p0), torch.full((n,), 1 - p0)]) + + probs_1 = torch.stack([img(0.7, 4), img(0.6, 4)]) # (2, 2, 4) -> class-0 scores 0.3, 0.4 + probs_2 = img(0.9, 9).unsqueeze(0) # (1, 2, 9) -> class-0 score 0.1 + cal = ConformalRiskCalibrator(alpha=0.5, loss="miscoverage", lam_grid=torch.linspace(0.0, 1.0, 11)) + cal.accumulate(probs_1, torch.zeros(2, 1, 4, dtype=torch.long)) + cal.accumulate(probs_2, torch.zeros(1, 1, 9, dtype=torch.long)) + lam = cal.calibrate() + assert_allclose(lam, torch.tensor(0.3), atol=1e-6) + + def test_false_negative_loss_calibrate(self): + # 2 images, 3 classes, 2 voxels. img0 bg-only, img1 fg class 1 prob 0.65 -> score 0.35. + # n=2, alpha=0.5 -> alpha_eff = (3*0.5-1)/2 = 0.25. emp_risk = 0.5*1{0.35 > lam}, so it + # clears 0.25 first at lam=0.4 (img1's fg voxel covered). 
+ probs = torch.tensor( + [ + [[0.9, 0.9], [0.05, 0.05], [0.05, 0.05]], # img0: both voxels bg-dominant + [[0.25, 0.25], [0.65, 0.65], [0.1, 0.1]], # img1: both voxels class 1 prob 0.65 + ] + ) # (2, 3, 2) + labels = torch.tensor([[[0, 0]], [[1, 1]]]) # (2, 1, 2) + cal = ConformalRiskCalibrator(alpha=0.5, loss="false_negative", lam_grid=torch.linspace(0.0, 1.0, 11)) + cal.accumulate(probs, labels) + lam = cal.calibrate() + assert_allclose(lam, torch.tensor(0.4), atol=1e-6) + + def test_rejects_bad_alpha(self): + with self.assertRaises(ValueError): + ConformalRiskCalibrator(alpha=0.0) + with self.assertRaises(ValueError): + ConformalRiskCalibrator(alpha=1.0) + + def test_rejects_unknown_loss(self): + with self.assertRaises(ValueError): + ConformalRiskCalibrator(loss="brier") + + def test_rejects_bad_lam_grid(self): + with self.assertRaises(ValueError): + ConformalRiskCalibrator(lam_grid=torch.tensor([[-0.1], [0.5]])) + + def test_calibrate_empty_raises(self): + with self.assertRaises(RuntimeError): + ConformalRiskCalibrator().calibrate() + + def test_reset_clears_buffers(self): + cal = ConformalRiskCalibrator() + # probs (1, 2, 1): 1 image, 2 classes, 1 voxel; labels (1, 1, 1) + cal.accumulate(torch.tensor([[[0.5], [0.5]]]), torch.zeros(1, 1, 1, dtype=torch.long)) + cal.reset() + with self.assertRaises(RuntimeError): + cal.calibrate() + + +class TestConformalRiskPredictor(unittest.TestCase): + def test_sets_and_mask(self): + # 1 image, 3 classes, 3 voxels. 
+ # voxel 0: probs [0.9, 0.05, 0.05] -> scores [0.1, 0.95, 0.95] -> only class 0 in set + # voxel 1: probs [0.5, 0.3, 0.2] -> scores [0.5, 0.7, 0.8] -> only class 0 in set (lam=0.6) + # voxel 2: probs [0.4, 0.4, 0.2] -> scores [0.6, 0.6, 0.8] -> classes 0,1 in set (ambiguous) + probs = torch.tensor([[[0.9, 0.5, 0.4], [0.05, 0.3, 0.4], [0.05, 0.2, 0.2]]]) # (1, 3, 3): B=1, C=3, spatial=3 + lam = torch.tensor(0.6) + predictor = ConformalRiskPredictor(lam=lam) + sets, mask, probs_out = predictor(probs) + self.assertEqual(sets.shape, (1, 3, 3)) + self.assertEqual(mask.shape, (1, 1, 3)) + self.assertEqual(mask.dtype, torch.bool) + # voxel 0: only class 0 -> not ambiguous + self.assertFalse(mask[0, 0, 0].item()) + # voxel 1: only class 0 -> not ambiguous + self.assertFalse(mask[0, 0, 1].item()) + # voxel 2: classes 0 and 1 -> ambiguous + self.assertTrue(mask[0, 0, 2].item()) + # sets correct + self.assertTrue(sets[0, 0].all()) # class 0 always in set + self.assertFalse(sets[0, 1, 0].item()) + self.assertTrue(sets[0, 1, 2].item()) # class 1 in set for voxel 2 + self.assertFalse(sets[0, 2].any().item()) # class 2 never in set + + def test_include_background_zeroes_bg_argmax_mask(self): + # 1 image, 3 classes, 2 voxels. voxel 0 argmax 0 (bg), set ambiguous; with + # include_background=False mask should be False there. 
+ # voxel 0: probs [0.4, 0.3, 0.3] -> scores [0.6, 0.7, 0.7] -> set {0,1,2} at lam=0.75 + # voxel 1: probs [0.1, 0.8, 0.1] -> scores [0.9, 0.2, 0.9] -> set {1} only -> not ambiguous + probs = torch.tensor([[[0.4, 0.1], [0.3, 0.8], [0.3, 0.1]]]) # (1, 3, 2) + lam = torch.tensor(0.75) + predictor = ConformalRiskPredictor(lam=lam, include_background=False) + _, mask, _ = predictor(probs) + # voxel 0: argmax==0 (bg) -> masked out despite ambiguous set + self.assertFalse(mask[0, 0, 0].item()) + # voxel 1: only class 1 in set -> not ambiguous + self.assertFalse(mask[0, 0, 1].item()) + + def test_set_threshold(self): + predictor = ConformalRiskPredictor(lam=torch.tensor(0.5)) + predictor.set_threshold(torch.tensor(0.9)) + assert_allclose(predictor.lam, torch.tensor(0.9), atol=1e-6) + + def test_rejects_non_tensor_lam(self): + with self.assertRaises(TypeError): + ConformalRiskPredictor(lam=0.5) + + def test_rejects_bad_probs_shape(self): + predictor = ConformalRiskPredictor(lam=torch.tensor(0.5)) + with self.assertRaises(ValueError): + predictor(torch.zeros(5)) # 1-D + + +class TestCoverageMetric(unittest.TestCase): + def test_compute_coverage_full(self): + sets = torch.ones(2, 3, 4, dtype=torch.bool) + labels = torch.zeros(2, 1, 4, dtype=torch.long) + cov = compute_coverage(sets, labels) + assert_allclose(cov, torch.ones(2), atol=1e-6) + + def test_compute_coverage_partial(self): + # image 0: voxel 0 covered (class 0 in set), voxel 1 not -> cov 0.5 + sets = torch.tensor([[[True, False], [False, False], [False, False]]]) + labels = torch.zeros(1, 1, 2, dtype=torch.long) + cov = compute_coverage(sets, labels) + assert_allclose(cov, torch.tensor([0.5]), atol=1e-6) + + def test_coverage_aggregate_mean(self): + metric = Coverage() + sets = torch.ones(2, 3, 4, dtype=torch.bool) + labels = torch.zeros(2, 1, 4, dtype=torch.long) + metric(sets, labels) + result = metric.aggregate() + assert_allclose(result, torch.tensor(1.0), atol=1e-6) + + def 
test_coverage_aggregate_get_not_nans(self): + metric = Coverage(get_not_nans=True) + sets = torch.ones(1, 2, 2, dtype=torch.bool) + labels = torch.zeros(1, 1, 2, dtype=torch.long) + metric(sets, labels) + result, not_nans = metric.aggregate() + self.assertEqual(not_nans.item(), 1.0) + assert_allclose(result, torch.tensor(1.0), atol=1e-6) + + +class TestSetSizeMetric(unittest.TestCase): + def test_compute_set_size(self): + # 2 images, 3 classes, 2 voxels. + # img0: voxel0 {0,1} size 2, voxel1 {0} size 1 -> mean 1.5 + # img1: voxel0 {0,1,2} size 3, voxel1 {0,1} size 2 -> mean 2.5 + sets = torch.tensor( + [ + [[True, True], [True, False], [False, False]], # img0: voxel0 {0,1}, voxel1 {0} + [[True, True], [True, True], [True, False]], # img1: voxel0 {0,1,2}, voxel1 {0,1} + ] + ) # (2, 3, 2) + sizes = compute_set_size(sets) + assert_allclose(sizes, torch.tensor([1.5, 2.5]), atol=1e-6) + + def test_setsize_aggregate_mean(self): + metric = SetSize() + sets = torch.ones(2, 3, 4, dtype=torch.bool) + metric(sets, None) + result = metric.aggregate() + assert_allclose(result, torch.tensor(3.0), atol=1e-6) + + def test_setsize_rejects_non_tensor(self): + metric = SetSize() + with self.assertRaises(TypeError): + metric._compute_tensor([[True, True]], None) # type: ignore[arg-type] + + +if __name__ == "__main__": + unittest.main() From 6b0fc270ab2e1fdbdcfb30103d20f1e6539e7c7f Mon Sep 17 00:00:00 2001 From: Colin Son Date: Sun, 21 Jun 2026 16:04:58 -0500 Subject: [PATCH 2/2] Address Code Rabbit review: validate labels, lam_grid, loss_fn; chunk lambda loop; sort __all__; scalar/range check set_threshold; zip(strict=True); docstrings - conformal_risk.py: reject out-of-range labels instead of silent clamp (lines ~76, ~226) - conformal_risk.py: chunk the lambda grid in calibrate() to avoid materializing (n_lam, P_i, C) at once - conformal_risk.py: validate lam_grid is non-empty and sorted ascending (prevents IndexError and wrong infimum) - conformal_risk.py: validate loss_fn output 
shape and NaN after each call - conformal_risk.py: enforce set_threshold lam is scalar in [0, 1] - conformal_risk.py: zip(strict=True) over _scores/_labels - conformal_risk.py: alphabetical __all__ (RUF022), reset() docstring - test_conformal_risk.py: assert predictor returns input probs unchanged (RUF059) Signed-off-by: Colin Son --- monai/metrics/conformal_risk.py | 62 +++++++++++++++++++++++----- tests/metrics/test_conformal_risk.py | 4 ++ 2 files changed, 55 insertions(+), 11 deletions(-) diff --git a/monai/metrics/conformal_risk.py b/monai/metrics/conformal_risk.py index 52e2e02b1d..d29a012511 100644 --- a/monai/metrics/conformal_risk.py +++ b/monai/metrics/conformal_risk.py @@ -43,8 +43,8 @@ "ConformalRiskPredictor", "Coverage", "SetSize", - "compute_set_size", "compute_coverage", + "compute_set_size", ] tqdm, has_tqdm = optional_import("tqdm", name="tqdm") @@ -73,7 +73,10 @@ def _flatten_spatial(sets: torch.Tensor, labels: torch.Tensor) -> tuple[torch.Te c = sets.shape[1] sets_flat = sets.movedim(1, -1).reshape(-1, c) labels_flat = labels.reshape(-1).long() - labels_flat = labels_flat.clamp(min=0, max=c - 1) + if (labels_flat < 0).any() or (labels_flat >= c).any(): + raise ValueError( + f"labels must lie in [0, {c - 1}], got min={int(labels_flat.min())}, max={int(labels_flat.max())}." 
+ ) return sets_flat, labels_flat @@ -192,8 +195,10 @@ def __init__( self.include_background = include_background if lam_grid is None: lam_grid = torch.linspace(0.0, 1.0, 101) - if lam_grid.ndim != 1 or (lam_grid < 0).any() or (lam_grid > 1).any(): - raise ValueError("lam_grid must be a 1-D tensor with values in [0, 1].") + if lam_grid.ndim != 1 or lam_grid.numel() == 0 or (lam_grid < 0).any() or (lam_grid > 1).any(): + raise ValueError("lam_grid must be a non-empty 1-D tensor with values in [0, 1].") + if not bool((lam_grid[1:] >= lam_grid[:-1]).all()): + raise ValueError("lam_grid must be sorted in ascending order for the infimum search.") self.lam_grid = lam_grid.float() # Per-image score/label tensors, stored one entry per calibration image so spatial # size may vary across images and across accumulate() calls (variable-size volumes). @@ -223,7 +228,11 @@ def accumulate(self, probs: torch.Tensor, labels: torch.Tensor) -> None: # (B, per_image, C): move class to last then flatten spatial scores = (1.0 - probs).movedim(1, -1).reshape(b, per_image, c).detach() # labels (B, 1, spatial...) or (B, spatial...) -> (B, per_image) - labels_flat = labels.reshape(b, per_image).long().clamp(min=0, max=c - 1).detach() + labels_flat = labels.reshape(b, per_image).long().detach() + if (labels_flat < 0).any() or (labels_flat >= c).any(): + raise ValueError( + f"labels must lie in [0, {c - 1}], got min={int(labels_flat.min())}, max={int(labels_flat.max())}." + ) for i in range(b): self._scores.append(scores[i]) # (per_image, C) self._labels.append(labels_flat[i]) # (per_image,) @@ -250,16 +259,32 @@ def calibrate(self) -> torch.Tensor: # Sum each image's per-lambda loss; images vary in size so we loop per image but # vectorize over the whole lambda grid (n_lam acts as the batch dim into loss_fn). 
risk_sum = torch.zeros(n_lam, device=device, dtype=torch.float32) - for scores_i, labels_i in zip(self._scores, self._labels): + # ponytail: chunk over the lambda grid to bound peak memory; the full + # (n_lam, P_i, C) tensor would OOM on large 3D volumes. 1 << 12 lambdas + # at a time keeps the working set modest while preserving the cumulative + # sum; lower if calibration volumes are very large. + lam_chunk = 1 << 12 + for scores_i, labels_i in zip(self._scores, self._labels, strict=True): if not self.include_background: keep = labels_i != 0 if not bool(keep.any()): continue # all-background image: 0 loss, but still counted in n scores_i, labels_i = scores_i[keep], labels_i[keep] - sets = scores_i.unsqueeze(0) <= lam_grid.view(-1, 1, 1) # (n_lam, P_i, C) - sets_shaped = sets.movedim(-1, 1) # (n_lam, C, P_i) - labels_rep = labels_i.view(1, 1, -1).expand(n_lam, 1, -1) # (n_lam, 1, P_i) - risk_sum += self.loss_fn(sets_shaped, labels_rep).float() + p_i = scores_i.shape[0] + for start in range(0, n_lam, lam_chunk): + end = min(start + lam_chunk, n_lam) + lam_chunk_grid = lam_grid[start:end] # (n_chunk,) + sets = scores_i.unsqueeze(0) <= lam_chunk_grid.view(-1, 1, 1) # (n_chunk, P_i, C) + sets_shaped = sets.movedim(-1, 1) # (n_chunk, C, P_i) + labels_rep = labels_i.view(1, 1, -1).expand(sets_shaped.shape[0], 1, p_i) # (n_chunk, 1, P_i) + loss = self.loss_fn(sets_shaped, labels_rep).float() + if loss.shape != (sets_shaped.shape[0],): + raise ValueError( + f"loss_fn must return per-image loss of shape (n_chunk,), got {tuple(loss.shape)}." + ) + if bool(torch.isnan(loss).any()): + raise ValueError("loss_fn returned NaN; check inputs or loss implementation.") + risk_sum[start:end] += loss emp_risk = risk_sum / n # Finite-sample-corrected selection. 
B = 1 is the loss upper bound (losses are in # [0, 1]); losses are non-increasing in lambda, so the leftmost lambda clearing the @@ -275,6 +300,11 @@ def calibrate(self) -> torch.Tensor: return lam_hat.to(dtype).to(device) def reset(self) -> None: + """Reset internal calibration state. + + Clears the per-image score/label buffers and the cached class count so + the calibrator can be reused on a fresh calibration split. + """ self._scores, self._labels = [], [] self._num_classes = None @@ -317,9 +347,19 @@ def __init__(self, lam: torch.Tensor, include_background: bool = True) -> None: self.include_background = include_background def set_threshold(self, lam: torch.Tensor) -> None: - """Set (or update) the calibrated threshold.""" + """Set (or update) the calibrated threshold. + + Args: + lam: scalar tensor in ``[0, 1]``. A non-scalar would broadcast over + spatial dims at inference and silently produce wrong sets. + """ if not isinstance(lam, torch.Tensor): raise TypeError(f"lam must be a torch.Tensor, got {type(lam)}.") + if lam.ndim != 0: + raise ValueError(f"lam must be a scalar tensor, got shape {tuple(lam.shape)}.") + lam_val = float(lam.detach().item()) + if not 0.0 <= lam_val <= 1.0: + raise ValueError(f"lam must lie in [0, 1], got {lam_val}.") self.lam = lam.detach().clone() def __call__(self, probs: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]: diff --git a/tests/metrics/test_conformal_risk.py b/tests/metrics/test_conformal_risk.py index 72458a2689..3e6a7d0098 100644 --- a/tests/metrics/test_conformal_risk.py +++ b/tests/metrics/test_conformal_risk.py @@ -188,6 +188,10 @@ def test_sets_and_mask(self): self.assertEqual(sets.shape, (1, 3, 3)) self.assertEqual(mask.shape, (1, 1, 3)) self.assertEqual(mask.dtype, torch.bool) + # contract: predictor returns the input probs unchanged + self.assertIsInstance(probs_out, torch.Tensor) + self.assertEqual(probs_out.shape, (1, 3, 3)) + assert_allclose(probs_out, probs, atol=0.0) # voxel 0: only class 
0 -> not ambiguous self.assertFalse(mask[0, 0, 0].item()) # voxel 1: only class 0 -> not ambiguous