Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
28 changes: 11 additions & 17 deletions monai/data/dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -210,8 +210,12 @@ class PersistentDataset(Dataset):

Cached data is expected to be tensors, primitives, or dictionaries keying to these values. Numpy arrays will
be converted to tensors, however any other object type returned by transforms will not be loadable since
`torch.load` will be used with `weights_only=True` to prevent loading of potentially malicious objects.
Legacy cache files may not be loadable and may need to be recomputed.
`torch.load` will be used with `weights_only=True` by default to prevent loading of potentially malicious
objects. Legacy cache files may not be loadable and may need to be recomputed. MetaTensor objects can be saved
and loaded with their metadata preserved if `track_meta` is True, however the objects stored in the metadata
must be acceptable as serialisable by `torch.load` by default or if they have been white-listed with
`torch.serialization.add_safe_globals`. Any other object type may be stored but will fail to load and force
a cache recompute.

Lazy Resampling:
If you make use of the lazy resampling feature of `monai.transforms.Compose`, please refer to
Expand Down Expand Up @@ -266,17 +270,12 @@ def __init__(
When this is enabled, the traced transform instance IDs will be removed from the cached MetaTensors.
This is useful for skipping the transform instance checks when inverting applied operations
using the cached content and with re-created transform instances.
track_meta: whether to track the meta information, if `True`, will convert to `MetaTensor`.
default to `False`. Cannot be used with `weights_only=True`.
track_meta: whether to track the meta information, defaults to False. If `True`, converts to `MetaTensor`.
weights_only: keyword argument passed to `torch.load` when reading cached files.
default to `True`. When set to `True`, `torch.load` restricts loading to tensors and
other safe objects. Setting this to `False` is required for loading `MetaTensor`
objects saved with `track_meta=True`, however this creates the possibility of remote
code execution through `torch.load` so be aware of the security implications of doing so.

Raises:
ValueError: When both `track_meta=True` and `weights_only=True`, since this combination
prevents cached MetaTensors from being reloaded and causes perpetual cache regeneration.
default to `True`. When `True`, `torch.load` restricts loading to tensors and other safe objects.
Setting to `False` should only be done if it's absolutely necessary to load unsafe pickled data,
eg. MetaTensor objects with unsafe objects in their metadata. Users must verify the safety of the data
they intend to load before doing so.
"""
super().__init__(data=data, transform=transform)
self.cache_dir = Path(cache_dir) if cache_dir is not None else None
Expand All @@ -292,11 +291,6 @@ def __init__(
if hash_transform is not None:
self.set_transform_hash(hash_transform)
self.reset_ops_id = reset_ops_id
if track_meta and weights_only:
raise ValueError(
"Invalid argument combination: `track_meta=True` cannot be used with `weights_only=True`. "
"To cache and reload MetaTensors, set `track_meta=True` and `weights_only=False`."
)
self.track_meta = track_meta
self.weights_only = weights_only

Expand Down
19 changes: 4 additions & 15 deletions monai/data/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,6 @@
import math
import os
import pickle
import sys
from collections import abc, defaultdict
from collections.abc import Generator, Iterable, Mapping, Sequence, Sized
from copy import deepcopy
Expand Down Expand Up @@ -1367,13 +1366,8 @@ def json_hashing(item) -> bytes:

"""
# TODO: Find way to hash transforms content as part of the cache
cache_key = ""
if sys.version_info.minor < 9:
cache_key = hashlib.md5(json.dumps(item, sort_keys=True).encode("utf-8")).hexdigest()
else:
cache_key = hashlib.md5(
json.dumps(item, sort_keys=True).encode("utf-8"), usedforsecurity=False # type: ignore
).hexdigest()
dump = json.dumps(item, sort_keys=True).encode("utf-8")
cache_key = hashlib.sha256(dump, usedforsecurity=False).hexdigest() # type: ignore
return f"{cache_key}".encode()


Expand All @@ -1388,13 +1382,8 @@ def pickle_hashing(item, protocol=pickle.HIGHEST_PROTOCOL) -> bytes:
Returns: the corresponding hash key

"""
cache_key = ""
if sys.version_info.minor < 9:
cache_key = hashlib.md5(pickle.dumps(sorted_dict(item), protocol=protocol)).hexdigest()
else:
cache_key = hashlib.md5(
pickle.dumps(sorted_dict(item), protocol=protocol), usedforsecurity=False # type: ignore
).hexdigest()
dump = pickle.dumps(sorted_dict(item), protocol=protocol)
cache_key = hashlib.sha256(dump, usedforsecurity=False).hexdigest() # type: ignore
return f"{cache_key}".encode()


Expand Down
133 changes: 132 additions & 1 deletion tests/data/test_persistentdataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,8 +13,11 @@

import contextlib
import os
import pickle
import tempfile
import unittest
from pathlib import Path
from unittest.mock import patch

import nibabel as nib
import numpy as np
Expand Down Expand Up @@ -46,7 +49,7 @@

TEST_CASE_4 = [True, False, False, MetaTensor]

TEST_CASE_5 = [True, True, True, None]
TEST_CASE_5 = [True, True, False, MetaTensor]

TEST_CASE_6 = [False, False, False, torch.Tensor]

Expand Down Expand Up @@ -200,6 +203,134 @@ def test_track_meta_and_weights_only(self, track_meta, weights_only, expected_er
im = test_dataset[0]["image"]
self.assertIsInstance(im, expected_type)

def test_metatensor_loading(self):
"""
Thorough test of metadata loading correctly with MetaTensor. This will store a MetaTensor with safe object types
in its metadata dictionary, test the cache file exists and can be safely loaded with weights only, and that the
loaded object is another MetaTensor with the correct information
"""
meta = {"test_meta": 123, "foo": "bar", "test_tuple": (1, 2, 3)}
imt = MetaTensor(torch.rand(1, 128, 128, 128), meta=dict(meta), affine=torch.rand(4, 4))

with tempfile.TemporaryDirectory() as tempdir:
# cache_dir = os.path.join(os.path.join(tempdir, "cache"), "data")
cache_dir = Path(tempdir) / "cache" / "data"

test_data = [{"image": imt}]

test_dataset = PersistentDataset(
data=test_data,
transform=Compose([Identity()]),
cache_dir=str(cache_dir),
track_meta=True,
weights_only=True,
)

im = test_dataset[0]["image"]
self.assertIsInstance(im, MetaTensor, "MetaTensor not stored in dataset.")

for k, v in meta.items():
self.assertIn(k, im.meta, f"Metadata key {k} missing from loaded object.")
self.assertEqual(im.meta[k], v, f"Metadata key {k} not equal ({im.meta[k]}!={v}).")

torch.testing.assert_close(imt.affine, im.affine)

cache_files = list(cache_dir.glob("*"))
self.assertEqual(len(cache_files), 1, "Cached file not present.")

cache_im = torch.load(cache_files[0], weights_only=True)["image"]

self.assertIsInstance(cache_im, MetaTensor, "MetaTensor not stored in dataset.")

for k, v in meta.items():
self.assertIn(k, cache_im.meta, f"Metadata key {k} missing from loaded object.")
self.assertEqual(cache_im.meta[k], v, f"Metadata key {k} not equal ({cache_im.meta[k]}!={v}).")

# create a new dataset to be sure
test_dataset2 = PersistentDataset(
data=test_data,
transform=Compose([Identity()]),
cache_dir=str(cache_dir),
track_meta=True,
weights_only=True,
)

# Replace torch.load with a function returning the same thing wrapped in a tuple, this is used to indicate
# the dataset loaded the cached data rather than recomputed.
old_load = torch.load

def _mock_load(f, weights_only):
self.assertTrue(weights_only, f"torch.load called with {weights_only=}.")
return (old_load(f, weights_only=weights_only),)

# check the returned object is a tuple containing the expected dict, if not then _mock_load wasn't called
with patch("torch.load", _mock_load):
im2_t = test_dataset2[0]
self.assertIsInstance(im2_t, tuple, "Special tuple not returned, so mock not used.")
self.assertIsInstance(im2_t[0]["image"], MetaTensor, "MetaTensor not stored in dataset.")

def test_metatensor_badcache(self):
"""
Test attempting to save then load a MetaTensor with an unsafe metadata item raises an exception. This creates
a MetaTensor with an object in its metadata using unsafe code in __reduce__ which gets stored in the pickle.
When attempting to load this through torch.load, pickle.UnpicklingError should be raised to force a recompute
of the cached data rather than attempting to load something unsafe.
"""
with tempfile.TemporaryDirectory() as tempdir:
cache_dir = Path(tempdir) / "cache" / "data"

class _BadType:
def __reduce__(self):
# something more insecure than this could be done with os.system
return (os.system, (f'echo "Code injected!" > {Path(tempdir)/"out.txt"!s}',))

meta = {"test_meta": 123, "foo": "bar", "bad_item": _BadType()}
imt = MetaTensor(torch.rand(1, 128, 128, 128), meta=dict(meta), affine=torch.rand(4, 4))
test_data = [{"image": imt}]

test_dataset = PersistentDataset(
data=test_data,
transform=Compose([Identity()]),
cache_dir=str(cache_dir),
track_meta=True,
weights_only=True,
)

# This will trigger the _BadType class code injection because deepcopy will use __reduce__, but will still
# write the cache file as needed for the test. The alternative was to write the cache file directly with a
# computed hash value, but computing that hash without using pickle_hashing isn't trivial.
im = test_dataset[0]["image"]

self.assertIsInstance(im, MetaTensor, "MetaTensor not stored in dataset.")

cache_files = list(cache_dir.glob("*"))
self.assertEqual(len(cache_files), 1, "Cached file not present.")

# loading the cache file directly will raise the pickle exception as expected
with self.assertRaises(pickle.UnpicklingError):
torch.load(cache_files[0], weights_only=True)

# create a new dataset object just to be sure. When loading, a cache hit will occur but this will raise
# the pickle exception again and force a recompute of the cached data as well as a warning, this indicates
# the unsafe data was correctly rejected.
test_dataset2 = PersistentDataset(
data=test_data,
transform=Compose([Identity()]),
cache_dir=str(cache_dir),
track_meta=True,
weights_only=True,
)

# warning raised about recomputing the corrupted cache file which raised UnpicklingError
with self.assertWarns(UserWarning):
im = test_dataset2[0]["image"]

self.assertIsInstance(im, MetaTensor, "MetaTensor not stored in dataset.")

cache_files2 = list(cache_dir.glob("*"))

self.assertEqual(cache_files[0], cache_files2[0], "Hashes for cached data differ.")


if __name__ == "__main__":
unittest.main()
Loading