From 04b79e9a73b3d8db9864abd777beb17ea1786bdd Mon Sep 17 00:00:00 2001 From: Eric Kerfoot <17726042+ericspod@users.noreply.github.com> Date: Sun, 21 Jun 2026 16:03:23 -0400 Subject: [PATCH 1/3] Enabling MetaTensor Persistent Caching Signed-off-by: Eric Kerfoot <17726042+ericspod@users.noreply.github.com> --- monai/data/dataset.py | 24 +++-- monai/data/utils.py | 18 +--- tests/data/test_persistentdataset.py | 133 ++++++++++++++++++++++++++- 3 files changed, 147 insertions(+), 28 deletions(-) diff --git a/monai/data/dataset.py b/monai/data/dataset.py index 2511ce2219..1ccc0d532d 100644 --- a/monai/data/dataset.py +++ b/monai/data/dataset.py @@ -210,8 +210,12 @@ class PersistentDataset(Dataset): Cached data is expected to be tensors, primitives, or dictionaries keying to these values. Numpy arrays will be converted to tensors, however any other object type returned by transforms will not be loadable since - `torch.load` will be used with `weights_only=True` to prevent loading of potentially malicious objects. - Legacy cache files may not be loadable and may need to be recomputed. + `torch.load` will be used with `weights_only=True` by default to prevent loading of potentially malicious + objects. Legacy cache files may not be loadable and may need to be recomputed. MetaTensor objects can be saved + and loaded with their metadata preserved if `track_meta` is True, however the objects stored in the metadata + must either be types that `torch.load` accepts by default or be allow-listed with + `torch.serialization.add_safe_globals`. Any other object type may be stored but will fail to load and force + a cache recompute. Lazy Resampling: If you make use of the lazy resampling feature of `monai.transforms.Compose`, please refer to @@ -266,13 +270,12 @@ def __init__( When this is enabled, the traced transform instance IDs will be removed from the cached MetaTensors. 
This is useful for skipping the transform instance checks when inverting applied operations using the cached content and with re-created transform instances. - track_meta: whether to track the meta information, if `True`, will convert to `MetaTensor`. - default to `False`. Cannot be used with `weights_only=True`. + track_meta: whether to track the meta information, defaults to False. If `True`, converts to `MetaTensor`. weights_only: keyword argument passed to `torch.load` when reading cached files. - default to `True`. When set to `True`, `torch.load` restricts loading to tensors and - other safe objects. Setting this to `False` is required for loading `MetaTensor` - objects saved with `track_meta=True`, however this creates the possibility of remote - code execution through `torch.load` so be aware of the security implications of doing so. + default to `True`. When `True`, `torch.load` restricts loading to tensors and other safe objects. + Setting to `False` should only be done if it's absolutely necessary to load unsafe pickled data, + eg. MetaTensor objects with unsafe objects in their metadata. Users must verify the safety of the data + they intend to load before doing so. Raises: ValueError: When both `track_meta=True` and `weights_only=True`, since this combination @@ -292,11 +295,6 @@ def __init__( if hash_transform is not None: self.set_transform_hash(hash_transform) self.reset_ops_id = reset_ops_id - if track_meta and weights_only: - raise ValueError( - "Invalid argument combination: `track_meta=True` cannot be used with `weights_only=True`. " - "To cache and reload MetaTensors, set `track_meta=True` and `weights_only=False`." 
- ) self.track_meta = track_meta self.weights_only = weights_only diff --git a/monai/data/utils.py b/monai/data/utils.py index d548ed7248..8d675ed122 100644 --- a/monai/data/utils.py +++ b/monai/data/utils.py @@ -1367,13 +1367,8 @@ def json_hashing(item) -> bytes: """ # TODO: Find way to hash transforms content as part of the cache - cache_key = "" - if sys.version_info.minor < 9: - cache_key = hashlib.md5(json.dumps(item, sort_keys=True).encode("utf-8")).hexdigest() - else: - cache_key = hashlib.md5( - json.dumps(item, sort_keys=True).encode("utf-8"), usedforsecurity=False # type: ignore - ).hexdigest() + dump = json.dumps(item, sort_keys=True).encode("utf-8") + cache_key = hashlib.sha256(dump, usedforsecurity=False).hexdigest() # type: ignore return f"{cache_key}".encode() @@ -1388,13 +1383,8 @@ def pickle_hashing(item, protocol=pickle.HIGHEST_PROTOCOL) -> bytes: Returns: the corresponding hash key """ - cache_key = "" - if sys.version_info.minor < 9: - cache_key = hashlib.md5(pickle.dumps(sorted_dict(item), protocol=protocol)).hexdigest() - else: - cache_key = hashlib.md5( - pickle.dumps(sorted_dict(item), protocol=protocol), usedforsecurity=False # type: ignore - ).hexdigest() + dump = pickle.dumps(sorted_dict(item), protocol=protocol) + cache_key = hashlib.sha256(dump, usedforsecurity=False).hexdigest() # type: ignore return f"{cache_key}".encode() diff --git a/tests/data/test_persistentdataset.py b/tests/data/test_persistentdataset.py index ca62cdb184..73eac832f1 100644 --- a/tests/data/test_persistentdataset.py +++ b/tests/data/test_persistentdataset.py @@ -13,8 +13,11 @@ import contextlib import os +import pickle import tempfile import unittest +from pathlib import Path +from unittest.mock import patch import nibabel as nib import numpy as np @@ -46,7 +49,7 @@ TEST_CASE_4 = [True, False, False, MetaTensor] -TEST_CASE_5 = [True, True, True, None] +TEST_CASE_5 = [True, True, False, MetaTensor] TEST_CASE_6 = [False, False, False, torch.Tensor] @@ -200,6 
+203,134 @@ def test_track_meta_and_weights_only(self, track_meta, weights_only, expected_er im = test_dataset[0]["image"] self.assertIsInstance(im, expected_type) + def test_metatensor_loading(self): + """ + Thorough test of metadata loading correctly with MetaTensor. This will store a MetaTensor with safe object types + in its metadata dictionary, test the cache file exists and can be safely loaded with weights only, and that the + loaded object is another MetaTensor with the correct information + """ + meta = {"test_meta": 123, "foo": "bar", "test_tuple": (1, 2, 3)} + imt = MetaTensor(torch.rand(1, 128, 128, 128), meta=dict(meta), affine=torch.rand(4, 4)) + + with tempfile.TemporaryDirectory() as tempdir: + # cache_dir = os.path.join(os.path.join(tempdir, "cache"), "data") + cache_dir = Path(tempdir) / "cache" / "data" + + test_data = [{"image": imt}] + + test_dataset = PersistentDataset( + data=test_data, + transform=Compose([Identity()]), + cache_dir=str(cache_dir), + track_meta=True, + weights_only=True, + ) + + im = test_dataset[0]["image"] + self.assertIsInstance(im, MetaTensor, "MetaTensor not stored in dataset.") + + for k, v in meta.items(): + self.assertIn(k, im.meta, f"Metadata key {k} missing from loaded object.") + self.assertEqual(im.meta[k], v, f"Metadata key {k} not equal ({im.meta[k]}!={v}).") + + torch.testing.assert_close(imt.affine, im.affine) + + cache_files = list(cache_dir.glob("*")) + self.assertEqual(len(cache_files), 1, "Cached file not present.") + + cache_im = torch.load(cache_files[0], weights_only=True)["image"] + + self.assertIsInstance(cache_im, MetaTensor, "MetaTensor not stored in dataset.") + + for k, v in meta.items(): + self.assertIn(k, cache_im.meta, f"Metadata key {k} missing from loaded object.") + self.assertEqual(cache_im.meta[k], v, f"Metadata key {k} not equal ({cache_im.meta[k]}!={v}).") + + # create a new dataset to be sure + test_dataset2 = PersistentDataset( + data=test_data, + transform=Compose([Identity()]), + 
cache_dir=str(cache_dir), + track_meta=True, + weights_only=True, + ) + + # Replace torch.load with a function returning the same thing wrapped in a tuple, this is used to indicate + # the dataset loaded the cached data rather than recomputed. + old_load = torch.load + + def _mock_load(f, weights_only): + self.assertTrue(weights_only, f"torch.load called with {weights_only=}.") + return (old_load(f, weights_only=weights_only),) + + # check the returned object is a tuple containing the expected dict, if not then _mock_load wasn't called + with patch("torch.load", _mock_load): + im2_t = test_dataset2[0] + self.assertIsInstance(im2_t, tuple, "Special tuple not returned, so mock not used.") + self.assertIsInstance(im2_t[0]["image"], MetaTensor, "MetaTensor not stored in dataset.") + + def test_metatensor_badcache(self): + """ + Test attempting to save then load a MetaTensor with an unsafe metadata item raises an exception. This creates + a MetaTensor with an object in its metadata using unsafe code in __reduce__ which gets stored in the pickle. + When attempting to load this through torch.load, pickle.UnpicklingError should be raised to force a recompute + of the cached data rather than attempting to load something unsafe. + """ + with tempfile.TemporaryDirectory() as tempdir: + cache_dir = Path(tempdir) / "cache" / "data" + + class _BadType: + def __reduce__(self): + # something more insecure than this could be done with os.system + return (os.system, (f'echo "Code injected!" 
> {str(Path(tempdir)/"out.txt")}',)) + + meta = {"test_meta": 123, "foo": "bar", "bad_item": _BadType()} + imt = MetaTensor(torch.rand(1, 128, 128, 128), meta=dict(meta), affine=torch.rand(4, 4)) + test_data = [{"image": imt}] + + test_dataset = PersistentDataset( + data=test_data, + transform=Compose([Identity()]), + cache_dir=str(cache_dir), + track_meta=True, + weights_only=True, + ) + + # This will trigger the _BadType class code injection because deepcopy will use __reduce__, but will still + # write the cache file as needed for the test. The alternative was to write the cache file directly with a + # computed hash value, but computing that hash without using pickle_hashing isn't trivial. + im = test_dataset[0]["image"] + + self.assertIsInstance(im, MetaTensor, "MetaTensor not stored in dataset.") + + cache_files = list(cache_dir.glob("*")) + self.assertEqual(len(cache_files), 1, "Cached file not present.") + + # loading the cache file directly will raise the pickle exception as expected + with self.assertRaises(pickle.UnpicklingError): + torch.load(cache_files[0], weights_only=True) + + # create a new dataset object just to be sure. When loading, a cache hit will occur but this will raise + # the pickle exception again and force a recompute of the cached data as well as a warning, this indicates + # the unsafe data was correctly rejected. 
+ test_dataset2 = PersistentDataset( + data=test_data, + transform=Compose([Identity()]), + cache_dir=str(cache_dir), + track_meta=True, + weights_only=True, + ) + + # warning raised about recomputing the corrupted cache file which raised UnpicklingError + with self.assertWarns(UserWarning): + im = test_dataset2[0]["image"] + + self.assertIsInstance(im, MetaTensor, "MetaTensor not stored in dataset.") + + cache_files2 = list(cache_dir.glob("*")) + + self.assertEqual(cache_files[0], cache_files2[0], "Hashes for cached data differ.") + if __name__ == "__main__": unittest.main() From 0ee969cb8cd93b91c490784afe466e1579f28778 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Sun, 21 Jun 2026 20:43:13 +0000 Subject: [PATCH 2/3] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- monai/data/utils.py | 1 - 1 file changed, 1 deletion(-) diff --git a/monai/data/utils.py b/monai/data/utils.py index 8d675ed122..cc3d95cc13 100644 --- a/monai/data/utils.py +++ b/monai/data/utils.py @@ -17,7 +17,6 @@ import math import os import pickle -import sys from collections import abc, defaultdict from collections.abc import Generator, Iterable, Mapping, Sequence, Sized from copy import deepcopy From 19453a1f89b4d2254b15e1d9f0754ffc5c788950 Mon Sep 17 00:00:00 2001 From: Eric Kerfoot <17726042+ericspod@users.noreply.github.com> Date: Sun, 21 Jun 2026 17:24:48 -0400 Subject: [PATCH 3/3] Tweaks Signed-off-by: Eric Kerfoot <17726042+ericspod@users.noreply.github.com> --- monai/data/dataset.py | 4 ---- tests/data/test_persistentdataset.py | 2 +- 2 files changed, 1 insertion(+), 5 deletions(-) diff --git a/monai/data/dataset.py b/monai/data/dataset.py index 1ccc0d532d..b182aaf235 100644 --- a/monai/data/dataset.py +++ b/monai/data/dataset.py @@ -276,10 +276,6 @@ def __init__( Setting to `False` should only be done if it's absolutely necessary to load unsafe pickled data, eg. 
MetaTensor objects with unsafe objects in their metadata. Users must verify the safety of the data they intend to load before doing so. - - Raises: - ValueError: When both `track_meta=True` and `weights_only=True`, since this combination - prevents cached MetaTensors from being reloaded and causes perpetual cache regeneration. """ super().__init__(data=data, transform=transform) self.cache_dir = Path(cache_dir) if cache_dir is not None else None diff --git a/tests/data/test_persistentdataset.py b/tests/data/test_persistentdataset.py index 73eac832f1..1ddd6f4115 100644 --- a/tests/data/test_persistentdataset.py +++ b/tests/data/test_persistentdataset.py @@ -282,7 +282,7 @@ def test_metatensor_badcache(self): class _BadType: def __reduce__(self): # something more insecure than this could be done with os.system - return (os.system, (f'echo "Code injected!" > {str(Path(tempdir)/"out.txt")}',)) + return (os.system, (f'echo "Code injected!" > {Path(tempdir)/"out.txt"!s}',)) meta = {"test_meta": 123, "foo": "bar", "bad_item": _BadType()} imt = MetaTensor(torch.rand(1, 128, 128, 128), meta=dict(meta), affine=torch.rand(4, 4))