diff --git a/docs/source/inferers.rst b/docs/source/inferers.rst index 326f56e96c..f478e93bc0 100644 --- a/docs/source/inferers.rst +++ b/docs/source/inferers.rst @@ -73,6 +73,17 @@ Inferers :members: :special-members: __call__ +`ConformalPredictor` +~~~~~~~~~~~~~~~~~~~~~ +.. autoclass:: ConformalPredictor + :members: + :special-members: __call__ + +`ConformalCalibrator` +~~~~~~~~~~~~~~~~~~~~~~ +.. autoclass:: ConformalCalibrator + :members: + Splitters --------- .. currentmodule:: monai.inferers diff --git a/monai/inferers/__init__.py b/monai/inferers/__init__.py index fc78b9f7c4..bfc5c4d671 100644 --- a/monai/inferers/__init__.py +++ b/monai/inferers/__init__.py @@ -11,6 +11,7 @@ from __future__ import annotations +from .conformal_predictor import ConformalCalibrator, ConformalPredictor from .inferer import ( ControlNetDiffusionInferer, ControlNetLatentDiffusionInferer, diff --git a/monai/inferers/conformal_predictor.py b/monai/inferers/conformal_predictor.py new file mode 100644 index 0000000000..cbd6c9274b --- /dev/null +++ b/monai/inferers/conformal_predictor.py @@ -0,0 +1,290 @@ +# Copyright (c) MONAI Consortium +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import annotations + +import math +from collections.abc import Callable +from typing import Any + +import torch + +from monai.inferers.inferer import Inferer +from monai.utils.module import optional_import + +__all__ = ["ConformalPredictor", "ConformalCalibrator"] + +tqdm, has_tqdm = optional_import("tqdm", name="tqdm") + + +def _quantile_threshold(scores: torch.Tensor, alpha: float) -> torch.Tensor: + """Split-conformal threshold ``qhat`` at the ceil((n+1)(1-alpha))/n quantile of ``scores``. + + Implements the finite-sample marginal coverage guarantee of split conformal prediction + (Vovk et al. 2005; Angelopoulos & Bates 2021). ``scores`` is a 1-D tensor of + non-conformity scores from the held-out calibration split; the returned scalar + lives on the same device/dtype as ``scores``. + + Args: + scores: 1-D tensor of non-conformity scores from the calibration split. + alpha: mis-coverage level in ``(0, 1)``. + + Returns: + Scalar tensor ``qhat`` on the same device/dtype as ``scores``. + + Raises: + ValueError: if ``scores`` is empty or ``alpha`` is not in ``(0, 1)``. + """ + n = scores.numel() + if n <= 0: + raise ValueError("Cannot calibrate from an empty calibration set.") + if not 0.0 < alpha < 1.0: + raise ValueError(f"alpha must be in (0, 1), got {alpha}.") + rank = math.ceil((n + 1) * (1.0 - alpha)) + rank = max(1, min(rank, n)) # clamp into [1, n]; kthvalue is 1-indexed + return torch.kthvalue(scores.float(), rank).values.to(scores.dtype) + + +class ConformalCalibrator: + """Collect softmax on a held-out calibration split and turn it into a split-conformal + threshold using the LAC (Least Ambiguous set-Valued Classifier) non-conformity score + ``s_i = 1 - p_i(y_i)`` (Sadinle et al. 2019, arXiv:1905.12581). + + The calibrator is decoupled from any network/transform pipeline: it consumes softmax + probabilities and integer labels, so it works for image-level classification directly + and for per-voxel segmentation by reshaping ``(B, C, spatial...)`` to ``(N, C)``. + + Args: + alpha: mis-coverage level, e.g. ``0.1`` gives 90% marginal coverage. + score: non-conformity score, currently only ``"lac"`` (``1 - softmax[y]``). + include_background: when ``False`` exclude background-labeled (class 0) voxels from the + calibration set (useful for segmentation where background voxels dominate). + + Example: + + .. code-block:: python + + import torch + from monai.inferers.conformal_predictor import ConformalCalibrator + + cal = ConformalCalibrator(alpha=0.1) + for batch in cal_loader: + logits = model(batch["image"]) # (B, C, ...) + probs = logits.softmax(dim=1) + cal.accumulate(probs, batch["label"]) # label: (B, 1, ...) int + qhat = cal.calibrate() + # qhat is the threshold used by ConformalPredictor + + """ + + def __init__(self, alpha: float = 0.1, score: str = "lac", include_background: bool = True) -> None: + if score != "lac": + raise ValueError(f"Unsupported score {score!r}; only 'lac' (1 - softmax[y]) is implemented.") + self.alpha = float(alpha) + self.score = score + self.include_background = include_background + self._scores: list[torch.Tensor] = [] + + def accumulate(self, probs: torch.Tensor, labels: torch.Tensor) -> None: + """Accumulate non-conformity scores from one calibration batch. + + Args: + probs: softmax probabilities ``(B, C, spatial...)`` in ``[0, 1]`` summing to 1 over C. + labels: integer class indices ``(B, 1, spatial...)`` or ``(B, spatial...)`` with values + in ``[0, C)``. Shape must broadcast against ``probs`` spatial dims. + + Raises: + ValueError: if ``probs`` has fewer than 2 dimensions. + """ + if probs.ndim < 2: + raise ValueError(f"probs must be (B, C, spatial...), got shape {tuple(probs.shape)}.") + c = probs.shape[1] + # flatten to (N, C) + probs_flat = probs.reshape(probs.shape[0], c, -1).movedim(1, -1).reshape(-1, c) + labels_flat = labels.reshape(-1).long() + if not self.include_background: + # drop background-labeled voxels (class 0) from calibration so the threshold isn't + # dominated by easy background; the full softmax is kept so ``1 - softmax[y]`` stays a + # valid LAC score. (Relabeling/renormalizing instead would corrupt the scores.) + keep = labels_flat != 0 + probs_flat = probs_flat[keep] + labels_flat = labels_flat[keep] + # reject invalid labels (negative or >= C) outright rather than silently clamping them, + # which would corrupt the non-conformity scores. + valid = (labels_flat >= 0) & (labels_flat < c) + probs_flat = probs_flat[valid] + labels_flat = labels_flat[valid] + if labels_flat.numel() == 0: + return # nothing to accumulate this batch (all labels invalid or all bg-excluded) + true_p = probs_flat.gather(1, labels_flat.unsqueeze(1)).squeeze(1) + # move to CPU to avoid GPU OOM on large per-voxel calibration sets + self._scores.append((1.0 - true_p).detach().cpu()) + + def calibrate(self) -> torch.Tensor: + """Return the split-conformal threshold ``qhat`` from all accumulated scores. + + Returns: + Scalar tensor ``qhat`` on CPU. + + Raises: + RuntimeError: if no calibration scores have been accumulated. + """ + if not self._scores: + raise RuntimeError("No calibration scores accumulated; call accumulate(probs, labels) first.") + all_scores = torch.cat(self._scores) + qhat = _quantile_threshold(all_scores, self.alpha) + self._scores = [] # ponytail: one-shot; caller keeps qhat + return qhat + + def reset(self) -> None: + self._scores = [] + + +class ConformalPredictor(Inferer): + """Inferer that wraps a network and a pre-calibrated split-conformal threshold ``qhat`` + to return prediction sets with a marginal coverage guarantee ``1 - alpha``. + + This implements the LAC (Least Ambiguous set-Valued Classifier) recipe from issue #8935: + + 1. (out of band) calibrate ``qhat`` on a held-out split with + :class:`ConformalCalibrator` using non-conformity score ``1 - softmax[y]``. + 2. at inference compute ``softmax(logits)`` and return, for each sample, the set + ``{ y : 1 - softmax[y] <= qhat }``. + + Args: + qhat: pre-calibrated threshold (scalar tensor). Pass ``None`` to use ``alpha`` together + with calibration data passed via :meth:`calibrate`; otherwise ``qhat`` is used directly. + alpha: mis-coverage level used only by :meth:`calibrate` when ``qhat`` is ``None``. + score: non-conformity score, currently only ``"lac"``. + include_background: forwarded to :class:`ConformalCalibrator` for :meth:`calibrate`. + + The inferer calls ``network(inputs, *args, **kwargs)`` exactly as :class:`SimpleInferer` + does, then applies ``softmax`` over channel dim 1 and builds the set. + + References: + - Sadinle, M.; Lei, J.; Wasserman, L. "Least Ambiguous Set-Valued Classifiers with + Bounded Error Rates." arXiv:1905.12581, 2019. https://arxiv.org/abs/1905.12581 + - Angelopoulos, A.; Bates, S. "A Gentle Introduction to Conformal Prediction and + Distribution-Free Uncertainty Quantification." arXiv:2107.07511, 2021. + + Example: + + .. code-block:: python + + import torch + from monai.inferers import ConformalPredictor + + qhat = torch.tensor(0.65) + inferer = ConformalPredictor(qhat=qhat) + with torch.no_grad(): + sets = inferer(imgs, model) # sets: (B, C, ...) bool + # sets[b, c, ...] True means class c is in the prediction set at that location. + + """ + + def __init__( + self, qhat: torch.Tensor | None = None, alpha: float = 0.1, score: str = "lac", include_background: bool = True + ) -> None: + Inferer.__init__(self) + self.score = score + self.alpha = float(alpha) + self.include_background = include_background + self.qhat: torch.Tensor | None = None + if qhat is not None: + self.set_threshold(qhat) + + def set_threshold(self, qhat: torch.Tensor) -> None: + """Set (or update) the calibrated threshold. Lets you keep one inferer and re-calibrate. + + Args: + qhat: scalar ``torch.Tensor`` threshold. + + Raises: + TypeError: if ``qhat`` is not a ``torch.Tensor``. + ValueError: if ``qhat`` is not a scalar (has more than one element). + """ + if not isinstance(qhat, torch.Tensor): + raise TypeError(f"qhat must be a torch.Tensor, got {type(qhat)}.") + if qhat.numel() != 1: + raise ValueError(f"qhat must be a scalar tensor, got {qhat.numel()} elements.") + self.qhat = qhat.detach().clone() + + def calibrate( + self, network: torch.nn.Module, cal_loader: Any, device: torch.device | str | None = None + ) -> torch.Tensor: + """Run the network on ``cal_loader`` to calibrate ``qhat`` in-band. + + Args: + network: ``nn.Module`` returning logits ``(B, C, spatial...)`` (needs ``.parameters()`` + for device inference and ``.eval()``). + cal_loader: iterable yielding ``dict``-like items with keys ``"image"`` and ``"label"``. + ``"label"`` is integer class indices ``(B, 1, ...)`` or ``(B, ...)``. + device: device to run the network on; defaults to ``next(network.parameters()).device``. + + Returns: + The calibrated ``qhat`` (also stored and used by subsequent ``__call__`` invocations). + + Raises: + TypeError: if the network returns a non-Tensor. + """ + if device is None: + param = next(network.parameters(), None) + device = param.device if param is not None else torch.device("cpu") + was_training = getattr(network, "training", False) + network.eval() + cal = ConformalCalibrator(alpha=self.alpha, score=self.score, include_background=self.include_background) + iterator = cal_loader + if has_tqdm: + iterator = tqdm(cal_loader, desc="conformal calibration") + with torch.no_grad(): + for batch in iterator: + imgs = batch["image"] + if isinstance(imgs, torch.Tensor): + imgs = imgs.to(device) + logits = network(imgs) + if isinstance(logits, torch.Tensor): + probs = logits.softmax(dim=1).detach() + else: # ponytail: dict/tuple outputs left as a follow-up if needed + raise TypeError(f"network must return a Tensor of logits, got {type(logits)}.") + cal.accumulate(probs.to(device), batch["label"].to(device)) + qhat = cal.calibrate() + self.set_threshold(qhat) + if was_training: + network.train() + return qhat + + def __call__( + self, inputs: torch.Tensor, network: Callable[..., torch.Tensor], *args: Any, **kwargs: Any + ) -> torch.Tensor: + """Run inference and return the prediction-set mask. + + Args: + inputs: input batch ``(B, C_in, spatial...)``. + network: callable returning logits ``(B, C, spatial...)``. + args/kwargs: forwarded to ``network``. + + Returns: + sets: bool tensor ``(B, C, spatial...)``, ``True`` where class ``c`` is in the set. + For the underlying softmax, call ``network(inputs).softmax(1)``. + + Raises: + RuntimeError: if no threshold has been set. + TypeError: if the network returns a non-Tensor. + """ + if self.qhat is None: + raise RuntimeError("No threshold set; call set_threshold(qhat) or calibrate(...) first.") + logits = network(inputs, *args, **kwargs) + if not isinstance(logits, torch.Tensor): + raise TypeError(f"network must return a Tensor of logits, got {type(logits)}.") + probs = logits.softmax(dim=1) + qhat = self.qhat.to(probs.device, probs.dtype) + sets: torch.Tensor = (1.0 - probs) <= qhat + return sets diff --git a/tests/inferers/test_conformal_predictor.py b/tests/inferers/test_conformal_predictor.py new file mode 100644 index 0000000000..5e44d12959 --- /dev/null +++ b/tests/inferers/test_conformal_predictor.py @@ -0,0 +1,234 @@ +# Copyright (c) MONAI Consortium +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import annotations + +import unittest + +import torch + +from monai.inferers import ConformalCalibrator, ConformalPredictor +from monai.inferers.conformal_predictor import _quantile_threshold +from tests.test_utils import assert_allclose + + +class TestQuantileThreshold(unittest.TestCase): + def test_exact_coverage_rank(self): + # alpha=0.1, n=9 -> rank = ceil(10*0.9) = 9 -> kth=9 -> scores[8] + scores = torch.arange(1, 10, dtype=torch.float32) # 1..9 + qhat = _quantile_threshold(scores, alpha=0.1) + assert_allclose(qhat, torch.tensor(9.0)) + + def test_clamps_to_valid_rank(self): + # alpha very small -> rank clamps to n + scores = torch.linspace(0.0, 1.0, 5) + qhat = _quantile_threshold(scores, alpha=1e-6) + assert_allclose(qhat, scores.max()) + + def test_rejects_empty(self): + with self.assertRaises(ValueError): + _quantile_threshold(torch.empty(0), alpha=0.1) + + def test_rejects_bad_alpha(self): + with self.assertRaises(ValueError): + _quantile_threshold(torch.ones(4), alpha=0.0) + with self.assertRaises(ValueError): + _quantile_threshold(torch.ones(4), alpha=1.0) + + +class TestConformalCalibrator(unittest.TestCase): + def _cal_batch(self, probs, labels, include_background=True): + cal = ConformalCalibrator(alpha=0.1, include_background=include_background) + cal.accumulate(probs, labels) + return cal.calibrate() + + def test_classification_single_batch(self): + # 3 classes, 10 samples; true class prob 0.8 each -> score 0.2 each + probs = torch.full((10, 3), 0.1) + probs[:, 0] = 0.8 + probs = probs / probs.sum(dim=1, keepdim=True) + labels = torch.zeros(10, dtype=torch.long) + qhat = self._cal_batch(probs, labels) + # all scores == 0.2 -> qhat == 0.2 + assert_allclose(qhat, torch.tensor(0.2), atol=1e-6) + + def test_classification_multi_batch(self): + cal = ConformalCalibrator(alpha=0.1) + # batch 1: 5 samples, true prob 0.9 -> score 0.1 + probs1 = torch.full((5, 2), 0.1) + probs1[:, 0] = 0.9 + probs1 = probs1 / probs1.sum(dim=1, keepdim=True) + cal.accumulate(probs1, torch.zeros(5, dtype=torch.long)) + # batch 2: 5 samples, true prob 0.5 -> score 0.5 + probs2 = torch.full((5, 2), 0.5) + probs2 = probs2 / probs2.sum(dim=1, keepdim=True) + cal.accumulate(probs2, torch.zeros(5, dtype=torch.long)) + # combined 10 scores: five 0.1, five 0.5; rank=ceil(11*0.9)=10 -> kth=10 -> max + qhat = cal.calibrate() + assert_allclose(qhat, torch.tensor(0.5), atol=1e-6) + + def test_segmentation_voxel_reshape(self): + # (B=2, C=3, H=2, W=2) with constant true-prob 0.8 -> score 0.2 everywhere + probs = torch.zeros(2, 3, 2, 2) + probs[:, 0] = 0.8 + probs[:, 1] = 0.1 + probs[:, 2] = 0.1 # already normalized: 0.8 + 0.1 + 0.1 = 1 + labels = torch.zeros(2, 1, 2, 2, dtype=torch.long) + qhat = self._cal_batch(probs, labels) + assert_allclose(qhat, torch.tensor(0.2), atol=1e-6) + + def test_exclude_background_scores_foreground_only(self): + # foreground voxel (class 1): full softmax kept, score = 1 - softmax[1] + probs = torch.tensor([[[0.8], [0.1], [0.1]]]) # (B=1, C=3, spatial=1) + labels = torch.tensor([[[1]]]) # (1,1,1) class 1 + qhat = self._cal_batch(probs, labels, include_background=False) + # true class-1 prob 0.1 -> score 0.9 + assert_allclose(qhat, torch.tensor(0.9), atol=1e-6) + + def test_exclude_background_drops_bg_voxels(self): + # one bg voxel (class 0) + one fg voxel (class 2); bg must be excluded so only the + # fg score (1 - 0.7 = 0.3) calibrates the threshold, not bg's tiny score. + probs = torch.tensor([[[0.9, 0.2], [0.05, 0.1], [0.05, 0.7]]]) # (B=1, C=3, spatial=2) + labels = torch.tensor([[[0, 2]]]) # (1,1,2): voxel0 bg, voxel1 class 2 + qhat = self._cal_batch(probs, labels, include_background=False) + assert_allclose(qhat, torch.tensor(0.3), atol=1e-6) + + def test_unsupported_score_raises(self): + with self.assertRaises(ValueError): + ConformalCalibrator(score="aps") + + def test_calibrate_empty_raises(self): + with self.assertRaises(RuntimeError): + ConformalCalibrator().calibrate() + + def test_invalid_labels_are_dropped_not_clamped(self): + # labels -1 and 99 (out of range for C=3) must be dropped, not clamped to 0/2. + # If clamped, score for "label 99" would be 1 - softmax[2] = 0.9, corrupting qhat. + probs = torch.tensor([[[0.8, 0.8], [0.1, 0.1], [0.1, 0.1]]]) # (1, 3, 2) + labels = torch.tensor([[[-1, 99]]]) # both invalid + cal = ConformalCalibrator(alpha=0.1) + cal.accumulate(probs, labels) + with self.assertRaises(RuntimeError): + # all labels dropped -> no scores -> calibrate raises + cal.calibrate() + + def test_mixed_valid_invalid_labels_keeps_valid(self): + # one valid label (class 0, score 0.2) + one invalid (class 99, dropped). + # qhat should reflect only the valid score. + probs = torch.tensor([[[0.8, 0.8], [0.1, 0.1], [0.1, 0.1]]]) # (1, 3, 2) + labels = torch.tensor([[[0, 99]]]) # voxel0 valid, voxel1 invalid + qhat = self._cal_batch(probs, labels) + assert_allclose(qhat, torch.tensor(0.2), atol=1e-6) + + +class TestConformalPredictor(unittest.TestCase): + def test_set_and_predict(self): + qhat = torch.tensor(0.3) + inferer = ConformalPredictor(qhat=qhat) + + def net(x): + # logits: (1, 3, 2, 2) where channel 0 dominates + logits = torch.zeros(1, 3, 2, 2) + logits[:, 0] = 5.0 + logits[:, 1] = 0.0 + logits[:, 2] = 0.0 + return logits + + sets = inferer(torch.zeros(1, 1, 2, 2), net) + # softmax[0] ~ 0.99 -> score 0.01 <= 0.3 -> True; others score ~0.99 -> False + self.assertEqual(sets.shape, (1, 3, 2, 2)) + self.assertEqual(sets.dtype, torch.bool) + self.assertTrue(sets[:, 0].all()) + self.assertFalse(sets[:, 1].any()) + self.assertFalse(sets[:, 2].any()) + + def test_no_threshold_raises(self): + inferer = ConformalPredictor() # qhat None + + def net(x): + return torch.zeros(1, 2) + + with self.assertRaises(RuntimeError): + inferer(torch.zeros(1, 1), net) + + def test_calibrate_then_predict(self): + # tiny deterministic "model": always predict [0.9, 0.1] + class Net(torch.nn.Module): + def forward(self, x): + logits = torch.zeros(x.shape[0], 2) + logits[:, 0] = 2.1972 # log(0.9 / 0.1) + return logits + + net = Net() + # calibration set: 10 samples, true label always 0 + cal_loader = [{"image": torch.zeros(2, 1), "label": torch.zeros(2, 1, dtype=torch.long)} for _ in range(5)] + + inferer = ConformalPredictor(alpha=0.1) + qhat = inferer.calibrate(net, cal_loader, device=torch.device("cpu")) + # true prob 0.9 -> score 0.1 for all 10 samples; rank=ceil(11*0.9)=10 -> max score 0.1 + assert_allclose(qhat, torch.tensor(0.1), atol=1e-4) + sets = inferer(torch.zeros(1, 1), net) + self.assertTrue(sets[0, 0].item()) + self.assertFalse(sets[0, 1].item()) + + def test_bad_network_output_raises(self): + inferer = ConformalPredictor(qhat=torch.tensor(0.5)) + + def net(x): + return [0.0, 0.0] # not a tensor + + with self.assertRaises(TypeError): + inferer(torch.zeros(1, 1), net) + + def test_set_threshold_rejects_non_scalar(self): + inferer = ConformalPredictor() + with self.assertRaises(ValueError): + inferer.set_threshold(torch.tensor([0.1, 0.2])) + + def test_set_threshold_rejects_non_tensor(self): + inferer = ConformalPredictor() + with self.assertRaises(TypeError): + inferer.set_threshold(0.5) # type: ignore[arg-type] + + def test_calibrate_restores_training_state(self): + class Net(torch.nn.Module): + def forward(self, x): + logits = torch.zeros(x.shape[0], 2) + logits[:, 0] = 2.1972 + return logits + + net = Net() + net.train() + self.assertTrue(net.training) + cal_loader = [{"image": torch.zeros(2, 1), "label": torch.zeros(2, 1, dtype=torch.long)} for _ in range(5)] + inferer = ConformalPredictor(alpha=0.1) + inferer.calibrate(net, cal_loader, device=torch.device("cpu")) + # training state must be restored to True + self.assertTrue(net.training) + + def test_calibrate_preserves_eval_state(self): + class Net(torch.nn.Module): + def forward(self, x): + logits = torch.zeros(x.shape[0], 2) + logits[:, 0] = 2.1972 + return logits + + net = Net() + net.eval() + self.assertFalse(net.training) + cal_loader = [{"image": torch.zeros(2, 1), "label": torch.zeros(2, 1, dtype=torch.long)} for _ in range(5)] + inferer = ConformalPredictor(alpha=0.1) + inferer.calibrate(net, cal_loader, device=torch.device("cpu")) + self.assertFalse(net.training) + + +if __name__ == "__main__": + unittest.main()