diff --git a/monai/data/dataset.py b/monai/data/dataset.py index 2511ce2219..b182aaf235 100644 --- a/monai/data/dataset.py +++ b/monai/data/dataset.py @@ -210,8 +210,12 @@ class PersistentDataset(Dataset): Cached data is expected to be tensors, primitives, or dictionaries keying to these values. Numpy arrays will be converted to tensors, however any other object type returned by transforms will not be loadable since - `torch.load` will be used with `weights_only=True` to prevent loading of potentially malicious objects. - Legacy cache files may not be loadable and may need to be recomputed. + `torch.load` will be used with `weights_only=True` by default to prevent loading of potentially malicious + objects. Legacy cache files may not be loadable and may need to be recomputed. MetaTensor objects can be saved + and loaded with their metadata preserved if `track_meta` is True, however the objects stored in the metadata + must either be of types that `torch.load` accepts by default or be white-listed with + `torch.serialization.add_safe_globals`. Any other object type may be stored but will fail to load and force + a cache recompute. Lazy Resampling: If you make use of the lazy resampling feature of `monai.transforms.Compose`, please refer to @@ -266,17 +270,12 @@ def __init__( When this is enabled, the traced transform instance IDs will be removed from the cached MetaTensors. This is useful for skipping the transform instance checks when inverting applied operations using the cached content and with re-created transform instances. - track_meta: whether to track the meta information, if `True`, will convert to `MetaTensor`. - default to `False`. Cannot be used with `weights_only=True`. + track_meta: whether to track the meta information, defaults to False. If `True`, converts to `MetaTensor`. weights_only: keyword argument passed to `torch.load` when reading cached files. - default to `True`. 
When set to `True`, `torch.load` restricts loading to tensors and - other safe objects. Setting this to `False` is required for loading `MetaTensor` - objects saved with `track_meta=True`, however this creates the possibility of remote - code execution through `torch.load` so be aware of the security implications of doing so. - - Raises: - ValueError: When both `track_meta=True` and `weights_only=True`, since this combination - prevents cached MetaTensors from being reloaded and causes perpetual cache regeneration. + default to `True`. When `True`, `torch.load` restricts loading to tensors and other safe objects. + Setting to `False` should only be done if it's absolutely necessary to load unsafe pickled data, + e.g. MetaTensor objects with unsafe objects in their metadata. Users must verify the safety of the data + they intend to load before doing so. """ super().__init__(data=data, transform=transform) self.cache_dir = Path(cache_dir) if cache_dir is not None else None @@ -292,11 +291,6 @@ def __init__( if hash_transform is not None: self.set_transform_hash(hash_transform) self.reset_ops_id = reset_ops_id - if track_meta and weights_only: - raise ValueError( - "Invalid argument combination: `track_meta=True` cannot be used with `weights_only=True`. " - "To cache and reload MetaTensors, set `track_meta=True` and `weights_only=False`." 
- ) self.track_meta = track_meta self.weights_only = weights_only diff --git a/monai/data/utils.py b/monai/data/utils.py index d548ed7248..cc3d95cc13 100644 --- a/monai/data/utils.py +++ b/monai/data/utils.py @@ -17,7 +17,6 @@ import math import os import pickle -import sys from collections import abc, defaultdict from collections.abc import Generator, Iterable, Mapping, Sequence, Sized from copy import deepcopy @@ -1367,13 +1366,8 @@ def json_hashing(item) -> bytes: """ # TODO: Find way to hash transforms content as part of the cache - cache_key = "" - if sys.version_info.minor < 9: - cache_key = hashlib.md5(json.dumps(item, sort_keys=True).encode("utf-8")).hexdigest() - else: - cache_key = hashlib.md5( - json.dumps(item, sort_keys=True).encode("utf-8"), usedforsecurity=False # type: ignore - ).hexdigest() + dump = json.dumps(item, sort_keys=True).encode("utf-8") + cache_key = hashlib.sha256(dump, usedforsecurity=False).hexdigest() # type: ignore return f"{cache_key}".encode() @@ -1388,13 +1382,8 @@ def pickle_hashing(item, protocol=pickle.HIGHEST_PROTOCOL) -> bytes: Returns: the corresponding hash key """ - cache_key = "" - if sys.version_info.minor < 9: - cache_key = hashlib.md5(pickle.dumps(sorted_dict(item), protocol=protocol)).hexdigest() - else: - cache_key = hashlib.md5( - pickle.dumps(sorted_dict(item), protocol=protocol), usedforsecurity=False # type: ignore - ).hexdigest() + dump = pickle.dumps(sorted_dict(item), protocol=protocol) + cache_key = hashlib.sha256(dump, usedforsecurity=False).hexdigest() # type: ignore return f"{cache_key}".encode() diff --git a/tests/data/test_persistentdataset.py b/tests/data/test_persistentdataset.py index ca62cdb184..1ddd6f4115 100644 --- a/tests/data/test_persistentdataset.py +++ b/tests/data/test_persistentdataset.py @@ -13,8 +13,11 @@ import contextlib import os +import pickle import tempfile import unittest +from pathlib import Path +from unittest.mock import patch import nibabel as nib import numpy as np @@ -46,7 
+49,7 @@ TEST_CASE_4 = [True, False, False, MetaTensor] -TEST_CASE_5 = [True, True, True, None] +TEST_CASE_5 = [True, True, False, MetaTensor] TEST_CASE_6 = [False, False, False, torch.Tensor] @@ -200,6 +203,134 @@ def test_track_meta_and_weights_only(self, track_meta, weights_only, expected_er im = test_dataset[0]["image"] self.assertIsInstance(im, expected_type) + def test_metatensor_loading(self): + """ + Thorough test of metadata loading correctly with MetaTensor. This will store a MetaTensor with safe object types + in its metadata dictionary, test the cache file exists and can be safely loaded with weights only, and that the + loaded object is another MetaTensor with the correct information + """ + meta = {"test_meta": 123, "foo": "bar", "test_tuple": (1, 2, 3)} + imt = MetaTensor(torch.rand(1, 128, 128, 128), meta=dict(meta), affine=torch.rand(4, 4)) + + with tempfile.TemporaryDirectory() as tempdir: + # cache_dir = os.path.join(os.path.join(tempdir, "cache"), "data") + cache_dir = Path(tempdir) / "cache" / "data" + + test_data = [{"image": imt}] + + test_dataset = PersistentDataset( + data=test_data, + transform=Compose([Identity()]), + cache_dir=str(cache_dir), + track_meta=True, + weights_only=True, + ) + + im = test_dataset[0]["image"] + self.assertIsInstance(im, MetaTensor, "MetaTensor not stored in dataset.") + + for k, v in meta.items(): + self.assertIn(k, im.meta, f"Metadata key {k} missing from loaded object.") + self.assertEqual(im.meta[k], v, f"Metadata key {k} not equal ({im.meta[k]}!={v}).") + + torch.testing.assert_close(imt.affine, im.affine) + + cache_files = list(cache_dir.glob("*")) + self.assertEqual(len(cache_files), 1, "Cached file not present.") + + cache_im = torch.load(cache_files[0], weights_only=True)["image"] + + self.assertIsInstance(cache_im, MetaTensor, "MetaTensor not stored in dataset.") + + for k, v in meta.items(): + self.assertIn(k, cache_im.meta, f"Metadata key {k} missing from loaded object.") + 
self.assertEqual(cache_im.meta[k], v, f"Metadata key {k} not equal ({cache_im.meta[k]}!={v}).") + + # create a new dataset to be sure + test_dataset2 = PersistentDataset( + data=test_data, + transform=Compose([Identity()]), + cache_dir=str(cache_dir), + track_meta=True, + weights_only=True, + ) + + # Replace torch.load with a function returning the same thing wrapped in a tuple, this is used to indicate + # the dataset loaded the cached data rather than recomputed. + old_load = torch.load + + def _mock_load(f, weights_only): + self.assertTrue(weights_only, f"torch.load called with {weights_only=}.") + return (old_load(f, weights_only=weights_only),) + + # check the returned object is a tuple containing the expected dict, if not then _mock_load wasn't called + with patch("torch.load", _mock_load): + im2_t = test_dataset2[0] + self.assertIsInstance(im2_t, tuple, "Special tuple not returned, so mock not used.") + self.assertIsInstance(im2_t[0]["image"], MetaTensor, "MetaTensor not stored in dataset.") + + def test_metatensor_badcache(self): + """ + Test attempting to save then load a MetaTensor with an unsafe metadata item raises an exception. This creates + a MetaTensor with an object in its metadata using unsafe code in __reduce__ which gets stored in the pickle. + When attempting to load this through torch.load, pickle.UnpicklingError should be raised to force a recompute + of the cached data rather than attempting to load something unsafe. + """ + with tempfile.TemporaryDirectory() as tempdir: + cache_dir = Path(tempdir) / "cache" / "data" + + class _BadType: + def __reduce__(self): + # something more insecure than this could be done with os.system + return (os.system, (f'echo "Code injected!" 
> {Path(tempdir)/"out.txt"!s}',)) + + meta = {"test_meta": 123, "foo": "bar", "bad_item": _BadType()} + imt = MetaTensor(torch.rand(1, 128, 128, 128), meta=dict(meta), affine=torch.rand(4, 4)) + test_data = [{"image": imt}] + + test_dataset = PersistentDataset( + data=test_data, + transform=Compose([Identity()]), + cache_dir=str(cache_dir), + track_meta=True, + weights_only=True, + ) + + # This will trigger the _BadType class code injection because deepcopy will use __reduce__, but will still + # write the cache file as needed for the test. The alternative was to write the cache file directly with a + # computed hash value, but computing that hash without using pickle_hashing isn't trivial. + im = test_dataset[0]["image"] + + self.assertIsInstance(im, MetaTensor, "MetaTensor not stored in dataset.") + + cache_files = list(cache_dir.glob("*")) + self.assertEqual(len(cache_files), 1, "Cached file not present.") + + # loading the cache file directly will raise the pickle exception as expected + with self.assertRaises(pickle.UnpicklingError): + torch.load(cache_files[0], weights_only=True) + + # create a new dataset object just to be sure. When loading, a cache hit will occur but this will raise + # the pickle exception again and force a recompute of the cached data as well as a warning, this indicates + # the unsafe data was correctly rejected. + test_dataset2 = PersistentDataset( + data=test_data, + transform=Compose([Identity()]), + cache_dir=str(cache_dir), + track_meta=True, + weights_only=True, + ) + + # warning raised about recomputing the corrupted cache file which raised UnpicklingError + with self.assertWarns(UserWarning): + im = test_dataset2[0]["image"] + + self.assertIsInstance(im, MetaTensor, "MetaTensor not stored in dataset.") + + cache_files2 = list(cache_dir.glob("*")) + + self.assertEqual(cache_files[0], cache_files2[0], "Hashes for cached data differ.") + if __name__ == "__main__": unittest.main()