From f3401df703909f211dc8617e8832ba2f2d3592a0 Mon Sep 17 00:00:00 2001 From: Pawel Gadzinski Date: Sat, 6 Jun 2026 14:11:29 +0200 Subject: [PATCH 01/14] [PyTorch] Make tensorless quantizers opaque value objects for torch.compile Give tensorless quantizers (MXFP8, FP8 blockwise, FP8 current-scaling, NVFP4) value-object semantics so torch.compile can treat them as baked-in constants: - Add opt-in value identity to the base Quantizer (_value_fields / _value_key / __eq__ / __hash__). Quantizers holding live tensors (delayed-scaling Float8Quantizer) and custom quantizers keep identity semantics. - New transformer_engine/pytorch/dynamo.py houses the torch.compile glue: __fx_repr__, value-key reconstruction and register_value_opaque_quantizer (gracefully a no-op on PyTorch builds without the opaque-object API). - Register the four tensorless quantizers as value opaque types. Also fix CustomRecipe state caching in TransformerEngineBaseModule: set_meta_tensor now rebuilds quantizers when the CustomRecipe instance changes (e.g. nested te.autocast regions) instead of reusing the first recipe's state, since every CustomRecipe shares the CustomRecipeState type but carries its own qfactory. Move the quantizer value-object tests into tests/pytorch/test_torch_compile.py and add that file to the L0 pytorch unittest QA suite. Signed-off-by: Pawel Gadzinski --- tests/pytorch/test_torch_compile.py | 154 ++++++++++++++++++ transformer_engine/pytorch/dynamo.py | 120 ++++++++++++++ .../pytorch/quantized_tensor.py | 67 ++++++++ .../pytorch/tensor/float8_blockwise_tensor.py | 7 + .../pytorch/tensor/float8_tensor.py | 9 + .../pytorch/tensor/mxfp8_tensor.py | 7 + .../pytorch/tensor/nvfp4_tensor.py | 25 +++ 7 files changed, 389 insertions(+) create mode 100644 transformer_engine/pytorch/dynamo.py diff --git a/tests/pytorch/test_torch_compile.py b/tests/pytorch/test_torch_compile.py index 1286492a6e..8adcc88e61 100644 --- a/tests/pytorch/test_torch_compile.py +++ b/tests/pytorch/test_torch_compile.py @@ -24,6 +24,7 @@ from transformer_engine.common import recipe from transformer_engine.pytorch.constants import FP8FwdTensorIdx, FP8BwdTensorIdx from transformer_engine.pytorch.module.base import TransformerEngineBaseModule +from transformer_engine.pytorch.quantization import QuantizerRole from transformer_engine.pytorch.ops.basic.basic_linear import BasicLinear from transformer_engine.pytorch.tensor.float8_tensor import Float8CurrentScalingQuantizer from transformer_engine.pytorch.quantization import QuantizerRole @@ -32,6 +33,14 @@ is_mxfp8_available, is_fp8_block_scaling_available, is_nvfp4_available, + Float8Quantizer, + Float8BlockQuantizer, + MXFP8Quantizer, + NVFP4Quantizer, +) +from transformer_engine.pytorch.dynamo import ( + register_value_opaque_quantizer, + _quantizer_from_value_key, ) from utils import recipe_id @@ -384,3 +393,148 @@ def fn(inp): out = compiled(inp) out.sum().backward() + + +# --------------------------------------------------------------------------- +# Value-opaque quantizers: eager value semantics + FX reconstruction +# +# The tensorless quantizers (current-scaling FP8, FP8 blockwise, MXFP8, NVFP4) +# are torch.compile *value* opaque types: they provide value-based +# ``__eq__`` / ``__hash__`` and an evaluable ``__fx_repr__`` (see +# ``torch._library.opaque_object`` Note [Opaque Objects]). These tests exercise +# the eager value semantics and the FX reconstruction round-trip. They are +# CPU-friendly except for NVFP4 (whose constructor touches the current CUDA +# device). +# --------------------------------------------------------------------------- + + +def _mxfp8(dtype=tex.DType.kFloat8E4M3, rowwise=True, columnwise=True): + return MXFP8Quantizer(fp8_dtype=dtype, rowwise=rowwise, columnwise=columnwise) + + +def _blockwise(dtype=tex.DType.kFloat8E4M3, force_pow_2_scales=True, block_scaling_dim=2): + return Float8BlockQuantizer( + fp8_dtype=dtype, + rowwise=True, + columnwise=True, + force_pow_2_scales=force_pow_2_scales, + block_scaling_dim=block_scaling_dim, + ) + + +def _current_scaling(dtype=tex.DType.kFloat8E4M3, force_pow_2_scales=False, amax_epsilon=0.0): + return Float8CurrentScalingQuantizer( + fp8_dtype=dtype, + device=torch.device("cpu"), + force_pow_2_scales=force_pow_2_scales, + amax_epsilon=amax_epsilon, + ) + + +# (factory, kwargs_for_a_different_but_valid_config) +_CPU_VALUE_QUANTIZERS = [ + pytest.param(_mxfp8, {"dtype": tex.DType.kFloat8E5M2}, id="mxfp8"), + pytest.param(_blockwise, {"force_pow_2_scales": False}, id="float8_blockwise"), + pytest.param(_current_scaling, {"amax_epsilon": 1e-4}, id="float8_current_scaling"), +] + + +@pytest.mark.parametrize("factory, other_kwargs", _CPU_VALUE_QUANTIZERS) +def test_quantizer_value_equality(factory, other_kwargs): + """Same config -> equal & same hash; different config -> not equal.""" + a = factory() + b = factory() + assert a is not b + assert a == b + assert hash(a) == hash(b) + + c = factory(**other_kwargs) + assert a != c + # Usage flags participate in the value. + d = factory() + d.set_usage(rowwise=False, columnwise=False) + assert a != d + + +@pytest.mark.parametrize("factory, other_kwargs", _CPU_VALUE_QUANTIZERS) +def test_quantizer_usable_in_set_and_dict(factory, other_kwargs): + a = factory() + b = factory() + c = factory(**other_kwargs) + assert len({a, b, c}) == 2 + mapping = {a: "x"} + assert mapping[b] == "x" + + +@pytest.mark.parametrize("factory, other_kwargs", _CPU_VALUE_QUANTIZERS) +def test_quantizer_cross_type_inequality(factory, other_kwargs): + a = factory() + other = _current_scaling() if not isinstance(a, Float8CurrentScalingQuantizer) else _mxfp8() + assert a != other + + +@pytest.mark.parametrize("factory, other_kwargs", _CPU_VALUE_QUANTIZERS) +def test_quantizer_fx_repr_roundtrip(factory, other_kwargs): + """``__fx_repr__`` returns an evaluable expression rebuilding an equal object.""" + a = factory() + repr_str, globals_ = a.__fx_repr__() + assert isinstance(repr_str, str) + assert isinstance(globals_, dict) + rebuilt = eval(repr_str, dict(globals_)) # pylint: disable=eval-used + assert rebuilt == a + assert rebuilt is not a + assert hash(rebuilt) == hash(a) + + +@pytest.mark.parametrize("factory, other_kwargs", _CPU_VALUE_QUANTIZERS) +def test_quantizer_value_key_reconstruction(factory, other_kwargs): + a = factory() + rebuilt = _quantizer_from_value_key(a._value_key()) + assert type(rebuilt) is type(a) + assert rebuilt == a + # The deprecated amax-reduction process group is never carried in the value. + assert getattr(rebuilt, "amax_reduction_group", None) is None + + +def test_quantizer_delayed_scaling_keeps_identity_semantics(): + """Float8Quantizer holds live tensors -> identity (not value) semantics.""" + scale = torch.ones(1) + amax = torch.zeros(1) + a = Float8Quantizer(scale=scale, amax=amax, fp8_dtype=tex.DType.kFloat8E4M3) + b = Float8Quantizer(scale=scale, amax=amax, fp8_dtype=tex.DType.kFloat8E4M3) + assert a._value_fields() is None + assert a == a + assert a != b # distinct instances are not equal despite identical config + assert hash(a) == object.__hash__(a) + + +@pytest.mark.parametrize("factory, other_kwargs", _CPU_VALUE_QUANTIZERS) +def test_quantizer_registration_is_idempotent_and_tolerant(factory, other_kwargs): + """Re-registering must not raise, regardless of PyTorch opaque-object support.""" + cls = type(factory()) + register_value_opaque_quantizer(cls) + register_value_opaque_quantizer(cls) + + if not _opaque_available: + pytest.skip("PyTorch build without opaque-object API") + from torch._library.opaque_object import is_opaque_value_type + + assert is_opaque_value_type(cls) + + +@pytest.mark.skipif(not torch.cuda.is_available(), reason="NVFP4Quantizer requires CUDA") +def test_quantizer_nvfp4_value_semantics(): + a = NVFP4Quantizer(fp4_dtype=tex.DType.kFloat4E2M1) + b = NVFP4Quantizer(fp4_dtype=tex.DType.kFloat4E2M1) + assert a == b + assert hash(a) == hash(b) + + c = NVFP4Quantizer(fp4_dtype=tex.DType.kFloat4E2M1, with_rht=not a.with_rht) + assert a != c + + rebuilt = _quantizer_from_value_key(a._value_key()) + assert rebuilt == a + assert rebuilt.amax_reduction_group is None + + repr_str, globals_ = a.__fx_repr__() + assert eval(repr_str, dict(globals_)) == a # pylint: disable=eval-used diff --git a/transformer_engine/pytorch/dynamo.py b/transformer_engine/pytorch/dynamo.py new file mode 100644 index 0000000000..26aff6f59c --- /dev/null +++ b/transformer_engine/pytorch/dynamo.py @@ -0,0 +1,120 @@ +# Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# See LICENSE for license information. + +"""torch.compile glue for Transformer Engine quantizers. + +This module isolates the torch.compile-specific plumbing that turns a +*tensorless* quantizer into a torch.compile **value** opaque type: + + * :func:`register_value_opaque_quantizer` -- attaches the ``__fx_repr__`` used + by FX codegen and registers the quantizer class with + ``torch._library.opaque_object``. It is a no-op (other than populating the + local registry) on PyTorch builds without the opaque-object API, so + importing Transformer Engine never fails on older PyTorch -- only + torch.compile specialization on the quantizer is unavailable there. + * :func:`_quantizer_from_value_key` -- rebuilds a quantizer constant from its + value key inside the generated FX graph. + +The eager value semantics (``__eq__`` / ``__hash__`` / ``_value_key`` / +``_value_fields``) live on the quantizer itself; see +:class:`transformer_engine.pytorch.quantized_tensor.Quantizer`. + +See ``torch._library.opaque_object`` Note [Opaque Objects] for the contract a +value-typed opaque object must satisfy (``__eq__`` / ``__hash__`` / +``__fx_repr__``). +""" + +from __future__ import annotations +from typing import Any, Dict, Tuple + +from .constants import DType + + +# Maps a quantizer class qualname to the class object. A value key stores only +# the qualname, so reconstruction looks the class up here. Populated by +# ``register_value_opaque_quantizer`` at import time of each tensor module; this +# avoids importing the tensor modules into this module (which would create an +# import cycle). +_QUANTIZER_VALUE_REGISTRY: Dict[str, type] = {} + + +def _quantizer_from_value_key(key: Tuple[Any, ...]) -> Any: + """Rebuild a tensorless quantizer from its value key. + + Referenced by the ``__fx_repr__`` emitted for value-opaque quantizers; the + generated FX code calls this to materialize the quantizer constant. The + deprecated amax-reduction process group is never part of the value, so a + reconstructed quantizer always starts with no stored group. + """ + qualname, items = key[0], key[1] + cls = _QUANTIZER_VALUE_REGISTRY[qualname] + # Bypass ``__init__`` and restore the value attributes directly: the value + # key already captures every value-defining field (including derived ones), + # and the constructors have heterogeneous signatures / side effects. + obj = cls.__new__(cls) + field_names = set() + for name, value in items: + if name == "dtype": + value = DType.cast(value) + object.__setattr__(obj, name, value) + field_names.add(name) + # The deprecated amax-reduction process group is excluded from the value; + # restore it as ``None`` for quantizers that still carry the fallback so + # attribute access keeps working. + if "with_amax_reduction" in field_names and not hasattr(obj, "amax_reduction_group"): + object.__setattr__(obj, "amax_reduction_group", None) + return obj + + +def _quantizer_fx_repr(self: Any) -> Tuple[str, Dict[str, Any]]: + """``__fx_repr__`` for value-opaque quantizers (attached at registration). + + Returns an evaluable expression that rebuilds the quantizer via + :func:`_quantizer_from_value_key`, together with the globals needed to + evaluate it. + """ + return ( + f"_quantizer_from_value_key({self._value_key()!r})", + {"_quantizer_from_value_key": _quantizer_from_value_key}, + ) + + +def register_value_opaque_quantizer(cls: type) -> None: + """Register a tensorless quantizer class as a torch.compile value opaque type. + + Attaches ``__fx_repr__`` and registers the class with + ``torch._library.opaque_object``. Safe to call on any PyTorch build: on + versions without the opaque-object API it only records the class in the + local registry and attaches ``__fx_repr__`` (both harmless), so Transformer + Engine keeps importing and running in eager mode. + + The quantizer class must already provide value ``__eq__`` / ``__hash__`` and + a non-``None`` ``_value_fields`` (see + :class:`transformer_engine.pytorch.quantized_tensor.Quantizer`). + """ + _QUANTIZER_VALUE_REGISTRY[cls.__qualname__] = cls + + # ``register_opaque_type`` requires ``__fx_repr__`` to already exist on the + # class, so attach it before registering. + if "__fx_repr__" not in cls.__dict__: + cls.__fx_repr__ = _quantizer_fx_repr + + try: + from torch._library.opaque_object import ( # pylint: disable=import-outside-toplevel + register_opaque_type, + is_opaque_value_type, + ) + except (ImportError, AttributeError): + # Older PyTorch without the opaque-object API: eager value semantics + # still work; torch.compile specialization on the quantizer does not. + return + + if is_opaque_value_type(cls): + return + + try: + register_opaque_type(cls, typ="value") + except (ImportError, AttributeError, RuntimeError, TypeError): + # Tolerate partial / experimental opaque-object support. + pass diff --git a/transformer_engine/pytorch/quantized_tensor.py b/transformer_engine/pytorch/quantized_tensor.py index cfe488aae5..f37b3cc63d 100644 --- a/transformer_engine/pytorch/quantized_tensor.py +++ b/transformer_engine/pytorch/quantized_tensor.py @@ -408,6 +408,73 @@ def get_usages(self) -> Dict[str, bool]: "columnwise": self.columnwise_usage, } + # ----- Value-object identity (torch.compile opaque value support) ----- + # A *tensorless* quantizer (one whose entire state is a handful of plain, + # reproducible scalars -- no live tensors, no process groups) behaves like a + # value: two instances with the same configuration are interchangeable. Such + # quantizers opt into value-based ``__eq__`` / ``__hash__`` by overriding + # ``_value_fields``. Quantizers that keep the default (e.g. delayed-scaling + # ``Float8Quantizer``, which holds live scale/amax tensors, and any custom + # quantizer) retain the default identity semantics. + # + # This is the eager-side half of registering the quantizer as a torch.compile + # *value* opaque type; the torch.compile glue (``__fx_repr__``, FX + # reconstruction and ``register_opaque_type``) lives in + # ``transformer_engine.pytorch.dynamo``. + + #: Attributes shared by every quantizer that take part in value identity. + _BASE_VALUE_FIELDS: Tuple[str, ...] = ( + "rowwise_usage", + "columnwise_usage", + "internal", + "optimize_for_gemm", + ) + + def _value_fields(self) -> Optional[Tuple[str, ...]]: + """Subclass-specific value-defining attribute names, or ``None``. + + Returning ``None`` (the default) means the quantizer is *not* a value + object and keeps identity-based equality/hashing. Tensorless quantizers + override this to return the tuple of attribute names that, together with + :attr:`_BASE_VALUE_FIELDS`, fully determine their value (excluding + non-value state such as a deprecated amax-reduction process group). + """ + return None + + def _value_key(self) -> Tuple[Any, ...]: + """Hashable, reproducible key identifying this quantizer's value. + + Only valid for value quantizers (``_value_fields()`` is not ``None``). + """ + fields = self._value_fields() # pylint: disable=assignment-from-none + assert fields is not None, f"{type(self).__name__} is not a value quantizer" + items = [] + for name in self._BASE_VALUE_FIELDS + tuple(fields): + value = getattr(self, name) + if name == "dtype": + # ``DType`` is an ``IntEnum``; store the int so the key stays + # plain: hashable and ``repr``-reproducible for FX codegen. + value = int(value) + items.append((name, value)) + return (type(self).__qualname__, tuple(items)) + + def __eq__(self, other: object) -> Any: + # Value quantizers compare by configuration; everything else keeps the + # default identity semantics (returning ``NotImplemented`` makes Python + # fall back to identity). + if self is other: + return True + if self._value_fields() is None or type(self) is not type(other): + return NotImplemented + if other._value_fields() is None: + return NotImplemented + return self._value_key() == other._value_key() + + def __hash__(self) -> int: + if self._value_fields() is None: + return object.__hash__(self) + return hash(self._value_key()) + class QuantizedTensor(torch.Tensor): """Abstract base class for tensor with quantized data diff --git a/transformer_engine/pytorch/tensor/float8_blockwise_tensor.py b/transformer_engine/pytorch/tensor/float8_blockwise_tensor.py index ba46508d74..92c22acd0b 100644 --- a/transformer_engine/pytorch/tensor/float8_blockwise_tensor.py +++ b/transformer_engine/pytorch/tensor/float8_blockwise_tensor.py @@ -14,6 +14,7 @@ from transformer_engine.common.recipe import Float8BlockScaling, Recipe from .storage.float8_blockwise_tensor_storage import Float8BlockwiseQTensorStorage from ..quantized_tensor import QuantizedTensor, Quantizer +from ..dynamo import register_value_opaque_quantizer from ._quantization_helpers import _IdentityFunc from ..constants import DType from ..utils import devices_match, round_up_to_nearest_multiple @@ -69,6 +70,9 @@ def copy(self) -> Float8BlockQuantizer: return quantizer + def _value_fields(self) -> Tuple[str, ...]: + return ("dtype", "block_len", "amax_epsilon", "force_pow_2_scales", "block_scaling_dim") + def update_quantized( self, src: torch.Tensor, @@ -211,6 +215,9 @@ def _get_compatible_recipe(self) -> Union[type[Recipe], None]: return Float8BlockScaling +register_value_opaque_quantizer(Float8BlockQuantizer) + + class Float8BlockwiseQTensor(Float8BlockwiseQTensorStorage, QuantizedTensor): """Tensor class with FP8 data quantized via NxN blocks or 1xN blocks. diff --git a/transformer_engine/pytorch/tensor/float8_tensor.py b/transformer_engine/pytorch/tensor/float8_tensor.py index e26abf7df0..56b1ecfb09 100644 --- a/transformer_engine/pytorch/tensor/float8_tensor.py +++ b/transformer_engine/pytorch/tensor/float8_tensor.py @@ -18,6 +18,7 @@ from ..utils import canonicalize_process_group, devices_match from .storage.float8_tensor_storage import Float8TensorStorage, _FromFloat8Func from ..quantized_tensor import QuantizedTensor, Quantizer +from ..dynamo import register_value_opaque_quantizer from ._quantization_helpers import _IdentityFunc from ..constants import dist_group_type, DType @@ -386,6 +387,14 @@ def supports_only_rowwise_all_gather(self) -> bool: """ return True + def _value_fields(self) -> Tuple[str, ...]: + # ``amax_reduction_group`` is intentionally excluded: it is a deprecated + # process group (not a value) and is restored as ``None`` on rebuild. + return ("dtype", "force_pow_2_scales", "amax_epsilon", "with_amax_reduction") + + +register_value_opaque_quantizer(Float8CurrentScalingQuantizer) + class Float8Tensor(Float8TensorStorage, QuantizedTensor): """Experimental tensor class with FP8 data diff --git a/transformer_engine/pytorch/tensor/mxfp8_tensor.py b/transformer_engine/pytorch/tensor/mxfp8_tensor.py index d759aaf5c4..a3746b3088 100644 --- a/transformer_engine/pytorch/tensor/mxfp8_tensor.py +++ b/transformer_engine/pytorch/tensor/mxfp8_tensor.py @@ -18,6 +18,7 @@ from ..utils import devices_match, round_up_to_nearest_multiple from .storage.mxfp8_tensor_storage import MXFP8TensorStorage, _FromMXFP8Func from ..quantized_tensor import QuantizedTensor, Quantizer +from ..dynamo import register_value_opaque_quantizer from ._quantization_helpers import _IdentityFunc aten = torch.ops.aten @@ -57,6 +58,9 @@ def copy(self) -> MXFP8Quantizer: return quantizer + def _value_fields(self) -> Tuple[str, ...]: + return ("dtype",) + def update_quantized( self, src: torch.Tensor, @@ -1058,3 +1062,6 @@ def backward( ) return dgrad, None return grad.view(ctx.shape), None + + +register_value_opaque_quantizer(MXFP8Quantizer) diff --git a/transformer_engine/pytorch/tensor/nvfp4_tensor.py b/transformer_engine/pytorch/tensor/nvfp4_tensor.py index aa92be004f..573c78907c 100644 --- a/transformer_engine/pytorch/tensor/nvfp4_tensor.py +++ b/transformer_engine/pytorch/tensor/nvfp4_tensor.py @@ -23,6 +23,7 @@ from .storage.nvfp4_tensor_storage import NVFP4TensorStorage, _FromNVFP4Func from ..quantized_tensor import QuantizedTensor, Quantizer +from ..dynamo import register_value_opaque_quantizer from ._quantization_helpers import _IdentityFunc aten = torch.ops.aten @@ -333,6 +334,30 @@ def _canonicalized_amax_reduction_group(self) -> dist_group_type: def _get_compatible_recipe(self) -> Union[type[Recipe], None]: return NVFP4BlockScaling + def _value_fields(self) -> Tuple[str, ...]: + # ``amax_reduction_group`` is intentionally excluded: it is a deprecated + # process group (not a value) and is restored as ``None`` on rebuild. + # ``rht_matrix_random_sign_mask_t`` is derived (from + # ``_with_random_sign_mask`` and the device) but is stored verbatim so + # reconstruction does not need to touch the device. + return ( + "dtype", + "with_rht", + "with_post_rht_amax", + "with_2d_quantization", + "stochastic_rounding", + "row_scaled_nvfp4", + "nvfp4_use_4over6", + "nvfp4_e4m3_max", + "nvfp4_4over6_err_mode", + "_with_random_sign_mask", + "rht_matrix_random_sign_mask_t", + "with_amax_reduction", + ) + + +register_value_opaque_quantizer(NVFP4Quantizer) + class NVFP4Tensor(NVFP4TensorStorage, QuantizedTensor): """Quantized tensor class with FP4 data From c4ad54c1b1f28d9d3402cff92ce244fb9f1a2bc4 Mon Sep 17 00:00:00 2001 From: Pawel Gadzinski Date: Sat, 6 Jun 2026 14:51:06 +0200 Subject: [PATCH 02/14] [PyTorch] Drop quantizer value registry; reconstruct via __fx_repr__ globals Follow-up to the value-opaque quantizer support: - Remove the module-level _QUANTIZER_VALUE_REGISTRY (qualname -> class) and _quantizer_from_value_key. __fx_repr__ now captures the quantizer class directly in the FX globals and reconstructs via _rebuild_quantizer(cls, items), matching how PyTorch's own value opaque types (e.g. DTensor placements) reconstruct themselves. This removes global mutable state and the qualname collision risk. - Consolidate the quantizer value-object tests in test_torch_compile.py down to two functions and exercise reconstruction through the public __fx_repr__ path instead of internal helpers. Signed-off-by: Pawel Gadzinski --- tests/pytorch/test_torch_compile.py | 138 +++--------------- transformer_engine/pytorch/dynamo.py | 53 +++---- .../pytorch/quantized_tensor.py | 22 +-- 3 files changed, 51 insertions(+), 162 deletions(-) diff --git a/tests/pytorch/test_torch_compile.py b/tests/pytorch/test_torch_compile.py index 8adcc88e61..4491e23881 100644 --- a/tests/pytorch/test_torch_compile.py +++ b/tests/pytorch/test_torch_compile.py @@ -36,11 +36,6 @@ Float8Quantizer, Float8BlockQuantizer, MXFP8Quantizer, - NVFP4Quantizer, -) -from transformer_engine.pytorch.dynamo import ( - register_value_opaque_quantizer, - _quantizer_from_value_key, ) from utils import recipe_id @@ -396,145 +391,60 @@ def fn(inp): # --------------------------------------------------------------------------- -# Value-opaque quantizers: eager value semantics + FX reconstruction +# Value-opaque quantizers # -# The tensorless quantizers (current-scaling FP8, FP8 blockwise, MXFP8, NVFP4) -# are torch.compile *value* opaque types: they provide value-based -# ``__eq__`` / ``__hash__`` and an evaluable ``__fx_repr__`` (see -# ``torch._library.opaque_object`` Note [Opaque Objects]). These tests exercise -# the eager value semantics and the FX reconstruction round-trip. They are -# CPU-friendly except for NVFP4 (whose constructor touches the current CUDA -# device). +# Tensorless quantizers (MXFP8, FP8 blockwise, FP8 current-scaling) are +# torch.compile *value* opaque types: value-based ``__eq__`` / ``__hash__`` plus +# an evaluable ``__fx_repr__`` that rebuilds an equal object (see +# ``torch._library.opaque_object`` Note [Opaque Objects]). # --------------------------------------------------------------------------- -def _mxfp8(dtype=tex.DType.kFloat8E4M3, rowwise=True, columnwise=True): - return MXFP8Quantizer(fp8_dtype=dtype, rowwise=rowwise, columnwise=columnwise) +def _mxfp8(dtype=tex.DType.kFloat8E4M3): + return MXFP8Quantizer(fp8_dtype=dtype) -def _blockwise(dtype=tex.DType.kFloat8E4M3, force_pow_2_scales=True, block_scaling_dim=2): +def _blockwise(force_pow_2_scales=True): return Float8BlockQuantizer( - fp8_dtype=dtype, + fp8_dtype=tex.DType.kFloat8E4M3, rowwise=True, columnwise=True, force_pow_2_scales=force_pow_2_scales, - block_scaling_dim=block_scaling_dim, ) -def _current_scaling(dtype=tex.DType.kFloat8E4M3, force_pow_2_scales=False, amax_epsilon=0.0): +def _current_scaling(amax_epsilon=0.0): return Float8CurrentScalingQuantizer( - fp8_dtype=dtype, + fp8_dtype=tex.DType.kFloat8E4M3, device=torch.device("cpu"), - force_pow_2_scales=force_pow_2_scales, amax_epsilon=amax_epsilon, ) -# (factory, kwargs_for_a_different_but_valid_config) -_CPU_VALUE_QUANTIZERS = [ +# (factory, kwargs producing a different-but-valid config) +_VALUE_QUANTIZERS = [ pytest.param(_mxfp8, {"dtype": tex.DType.kFloat8E5M2}, id="mxfp8"), pytest.param(_blockwise, {"force_pow_2_scales": False}, id="float8_blockwise"), pytest.param(_current_scaling, {"amax_epsilon": 1e-4}, id="float8_current_scaling"), ] -@pytest.mark.parametrize("factory, other_kwargs", _CPU_VALUE_QUANTIZERS) -def test_quantizer_value_equality(factory, other_kwargs): - """Same config -> equal & same hash; different config -> not equal.""" - a = factory() - b = factory() +@pytest.mark.parametrize("factory, other_kwargs", _VALUE_QUANTIZERS) +def test_quantizer_value_object(factory, other_kwargs): + """Value semantics + ``__fx_repr__`` round-trip via the production FX path.""" + a, b = factory(), factory() + # Same config -> equal, same hash, interchangeable as a dict/set key. assert a is not b assert a == b assert hash(a) == hash(b) + assert {a: "x"}[b] == "x" + # Different config -> not equal. + assert a != factory(**other_kwargs) - c = factory(**other_kwargs) - assert a != c - # Usage flags participate in the value. - d = factory() - d.set_usage(rowwise=False, columnwise=False) - assert a != d - - -@pytest.mark.parametrize("factory, other_kwargs", _CPU_VALUE_QUANTIZERS) -def test_quantizer_usable_in_set_and_dict(factory, other_kwargs): - a = factory() - b = factory() - c = factory(**other_kwargs) - assert len({a, b, c}) == 2 - mapping = {a: "x"} - assert mapping[b] == "x" - - -@pytest.mark.parametrize("factory, other_kwargs", _CPU_VALUE_QUANTIZERS) -def test_quantizer_cross_type_inequality(factory, other_kwargs): - a = factory() - other = _current_scaling() if not isinstance(a, Float8CurrentScalingQuantizer) else _mxfp8() - assert a != other - - -@pytest.mark.parametrize("factory, other_kwargs", _CPU_VALUE_QUANTIZERS) -def test_quantizer_fx_repr_roundtrip(factory, other_kwargs): - """``__fx_repr__`` returns an evaluable expression rebuilding an equal object.""" - a = factory() + # ``__fx_repr__`` (used by torch.compile codegen) rebuilds an equal object. repr_str, globals_ = a.__fx_repr__() - assert isinstance(repr_str, str) - assert isinstance(globals_, dict) rebuilt = eval(repr_str, dict(globals_)) # pylint: disable=eval-used - assert rebuilt == a - assert rebuilt is not a + assert rebuilt == a and rebuilt is not a assert hash(rebuilt) == hash(a) - - -@pytest.mark.parametrize("factory, other_kwargs", _CPU_VALUE_QUANTIZERS) -def test_quantizer_value_key_reconstruction(factory, other_kwargs): - a = factory() - rebuilt = _quantizer_from_value_key(a._value_key()) - assert type(rebuilt) is type(a) - assert rebuilt == a - # The deprecated amax-reduction process group is never carried in the value. + # The deprecated amax-reduction group is never part of the value. assert getattr(rebuilt, "amax_reduction_group", None) is None - - -def test_quantizer_delayed_scaling_keeps_identity_semantics(): - """Float8Quantizer holds live tensors -> identity (not value) semantics.""" - scale = torch.ones(1) - amax = torch.zeros(1) - a = Float8Quantizer(scale=scale, amax=amax, fp8_dtype=tex.DType.kFloat8E4M3) - b = Float8Quantizer(scale=scale, amax=amax, fp8_dtype=tex.DType.kFloat8E4M3) - assert a._value_fields() is None - assert a == a - assert a != b # distinct instances are not equal despite identical config - assert hash(a) == object.__hash__(a) - - -@pytest.mark.parametrize("factory, other_kwargs", _CPU_VALUE_QUANTIZERS) -def test_quantizer_registration_is_idempotent_and_tolerant(factory, other_kwargs): - """Re-registering must not raise, regardless of PyTorch opaque-object support.""" - cls = type(factory()) - register_value_opaque_quantizer(cls) - register_value_opaque_quantizer(cls) - - if not _opaque_available: - pytest.skip("PyTorch build without opaque-object API") - from torch._library.opaque_object import is_opaque_value_type - - assert is_opaque_value_type(cls) - - -@pytest.mark.skipif(not torch.cuda.is_available(), reason="NVFP4Quantizer requires CUDA") -def test_quantizer_nvfp4_value_semantics(): - a = NVFP4Quantizer(fp4_dtype=tex.DType.kFloat4E2M1) - b = NVFP4Quantizer(fp4_dtype=tex.DType.kFloat4E2M1) - assert a == b - assert hash(a) == hash(b) - - c = NVFP4Quantizer(fp4_dtype=tex.DType.kFloat4E2M1, with_rht=not a.with_rht) - assert a != c - - rebuilt = _quantizer_from_value_key(a._value_key()) - assert rebuilt == a - assert rebuilt.amax_reduction_group is None - - repr_str, globals_ = a.__fx_repr__() - assert eval(repr_str, dict(globals_)) == a # pylint: disable=eval-used diff --git a/transformer_engine/pytorch/dynamo.py b/transformer_engine/pytorch/dynamo.py index 26aff6f59c..16f73920ae 100644 --- a/transformer_engine/pytorch/dynamo.py +++ b/transformer_engine/pytorch/dynamo.py @@ -9,12 +9,14 @@ * :func:`register_value_opaque_quantizer` -- attaches the ``__fx_repr__`` used by FX codegen and registers the quantizer class with - ``torch._library.opaque_object``. It is a no-op (other than populating the - local registry) on PyTorch builds without the opaque-object API, so - importing Transformer Engine never fails on older PyTorch -- only - torch.compile specialization on the quantizer is unavailable there. - * :func:`_quantizer_from_value_key` -- rebuilds a quantizer constant from its - value key inside the generated FX graph. + ``torch._library.opaque_object``. It is a no-op on PyTorch builds without + the opaque-object API, so importing Transformer Engine never fails on older + PyTorch -- only torch.compile specialization on the quantizer is + unavailable there. + * :func:`_rebuild_quantizer` -- rebuilds a quantizer constant from its value + items inside the generated FX graph. The quantizer class is captured + directly in the FX globals (see :func:`_quantizer_fx_repr`), so no global + class registry is needed. The eager value semantics (``__eq__`` / ``__hash__`` / ``_value_key`` / ``_value_fields``) live on the quantizer itself; see @@ -22,7 +24,10 @@ See ``torch._library.opaque_object`` Note [Opaque Objects] for the contract a value-typed opaque object must satisfy (``__eq__`` / ``__hash__`` / -``__fx_repr__``). +``__fx_repr__``). The ``__fx_repr__`` contract -- ``(repr_string, {name: type})`` +where ``repr_string`` references the names in the dict -- is exactly how +PyTorch's own value opaque types (e.g. DTensor placements) reconstruct +themselves, including across the on-disk compile cache. """ from __future__ import annotations @@ -31,26 +36,16 @@ from .constants import DType -# Maps a quantizer class qualname to the class object. A value key stores only -# the qualname, so reconstruction looks the class up here. Populated by -# ``register_value_opaque_quantizer`` at import time of each tensor module; this -# avoids importing the tensor modules into this module (which would create an -# import cycle). -_QUANTIZER_VALUE_REGISTRY: Dict[str, type] = {} - - -def _quantizer_from_value_key(key: Tuple[Any, ...]) -> Any: - """Rebuild a tensorless quantizer from its value key. +def _rebuild_quantizer(cls: type, items: Tuple[Tuple[str, Any], ...]) -> Any: + """Rebuild a tensorless quantizer of type *cls* from its value items. Referenced by the ``__fx_repr__`` emitted for value-opaque quantizers; the generated FX code calls this to materialize the quantizer constant. The deprecated amax-reduction process group is never part of the value, so a reconstructed quantizer always starts with no stored group. """ - qualname, items = key[0], key[1] - cls = _QUANTIZER_VALUE_REGISTRY[qualname] # Bypass ``__init__`` and restore the value attributes directly: the value - # key already captures every value-defining field (including derived ones), + # items already capture every value-defining field (including derived ones), # and the constructors have heterogeneous signatures / side effects. obj = cls.__new__(cls) field_names = set() @@ -71,12 +66,15 @@ def _quantizer_fx_repr(self: Any) -> Tuple[str, Dict[str, Any]]: """``__fx_repr__`` for value-opaque quantizers (attached at registration). Returns an evaluable expression that rebuilds the quantizer via - :func:`_quantizer_from_value_key`, together with the globals needed to - evaluate it. + :func:`_rebuild_quantizer`, capturing both the helper and the quantizer + class itself in the FX globals so codegen can resolve them with no global + registry and no qualname collisions. """ + cls = type(self) + items = self._value_key()[1] return ( - f"_quantizer_from_value_key({self._value_key()!r})", - {"_quantizer_from_value_key": _quantizer_from_value_key}, + f"_rebuild_quantizer({cls.__name__}, {items!r})", + {"_rebuild_quantizer": _rebuild_quantizer, cls.__name__: cls}, ) @@ -85,16 +83,13 @@ def register_value_opaque_quantizer(cls: type) -> None: Attaches ``__fx_repr__`` and registers the class with ``torch._library.opaque_object``. Safe to call on any PyTorch build: on - versions without the opaque-object API it only records the class in the - local registry and attaches ``__fx_repr__`` (both harmless), so Transformer - Engine keeps importing and running in eager mode. + versions without the opaque-object API it only attaches ``__fx_repr__`` + (harmless), so Transformer Engine keeps importing and running in eager mode. The quantizer class must already provide value ``__eq__`` / ``__hash__`` and a non-``None`` ``_value_fields`` (see :class:`transformer_engine.pytorch.quantized_tensor.Quantizer`). """ - _QUANTIZER_VALUE_REGISTRY[cls.__qualname__] = cls - # ``register_opaque_type`` requires ``__fx_repr__`` to already exist on the # class, so attach it before registering. if "__fx_repr__" not in cls.__dict__: diff --git a/transformer_engine/pytorch/quantized_tensor.py b/transformer_engine/pytorch/quantized_tensor.py index f37b3cc63d..6809893a40 100644 --- a/transformer_engine/pytorch/quantized_tensor.py +++ b/transformer_engine/pytorch/quantized_tensor.py @@ -408,20 +408,6 @@ def get_usages(self) -> Dict[str, bool]: "columnwise": self.columnwise_usage, } - # ----- Value-object identity (torch.compile opaque value support) ----- - # A *tensorless* quantizer (one whose entire state is a handful of plain, - # reproducible scalars -- no live tensors, no process groups) behaves like a - # value: two instances with the same configuration are interchangeable. Such - # quantizers opt into value-based ``__eq__`` / ``__hash__`` by overriding - # ``_value_fields``. Quantizers that keep the default (e.g. delayed-scaling - # ``Float8Quantizer``, which holds live scale/amax tensors, and any custom - # quantizer) retain the default identity semantics. - # - # This is the eager-side half of registering the quantizer as a torch.compile - # *value* opaque type; the torch.compile glue (``__fx_repr__``, FX - # reconstruction and ``register_opaque_type``) lives in - # ``transformer_engine.pytorch.dynamo``. - #: Attributes shared by every quantizer that take part in value identity. _BASE_VALUE_FIELDS: Tuple[str, ...] = ( "rowwise_usage", @@ -433,11 +419,9 @@ def get_usages(self) -> Dict[str, bool]: def _value_fields(self) -> Optional[Tuple[str, ...]]: """Subclass-specific value-defining attribute names, or ``None``. - Returning ``None`` (the default) means the quantizer is *not* a value - object and keeps identity-based equality/hashing. Tensorless quantizers - override this to return the tuple of attribute names that, together with - :attr:`_BASE_VALUE_FIELDS`, fully determine their value (excluding - non-value state such as a deprecated amax-reduction process group). + Returning ``None`` (the default) means the quantizer cannot be represented as + a value opaque object and keeps identity-based equality/hashing. + This also means, that torch.compile will not be able to optimize the quantizer. """ return None From a06324bd4f10354d76b5fbec133f07709bd6a1da Mon Sep 17 00:00:00 2001 From: Pawel Gadzinski Date: Sun, 7 Jun 2026 15:42:31 +0200 Subject: [PATCH 03/14] [PyTorch] Split dynamo.py into a dynamo/ package Replace the single dynamo.py module with a dynamo/ package so the torch.compile glue can grow with a clear responsibility split across the stacked branches. This branch owns the value-opaque quantizer layer. * dynamo/quantizer_opaque.py -- register_value_opaque_quantizer and helpers * dynamo/__init__.py -- re-exports the public API so callers keep importing from transformer_engine.pytorch.dynamo unchanged Signed-off-by: Pawel Gadzinski --- transformer_engine/pytorch/dynamo/__init__.py | 18 ++++++++++++++++++ .../{dynamo.py => dynamo/quantizer_opaque.py} | 7 +++---- 2 files changed, 21 insertions(+), 4 deletions(-) create mode 100644 transformer_engine/pytorch/dynamo/__init__.py rename transformer_engine/pytorch/{dynamo.py => dynamo/quantizer_opaque.py} (95%) diff --git a/transformer_engine/pytorch/dynamo/__init__.py b/transformer_engine/pytorch/dynamo/__init__.py new file mode 100644 index 0000000000..44ca61d470 --- /dev/null +++ b/transformer_engine/pytorch/dynamo/__init__.py @@ -0,0 +1,18 @@ +# Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# See LICENSE for license information. + +"""torch.compile glue for Transformer Engine. + +Public API is re-exported here so callers keep importing from +``transformer_engine.pytorch.dynamo`` regardless of the internal module layout: + + * :mod:`.quantizer_opaque` -- make a tensorless quantizer a torch.compile + *value* opaque type (:func:`register_value_opaque_quantizer`). +""" + +from .quantizer_opaque import register_value_opaque_quantizer + +__all__ = [ + "register_value_opaque_quantizer", +] diff --git a/transformer_engine/pytorch/dynamo.py b/transformer_engine/pytorch/dynamo/quantizer_opaque.py similarity index 95% rename from transformer_engine/pytorch/dynamo.py rename to transformer_engine/pytorch/dynamo/quantizer_opaque.py index 16f73920ae..c7455846ee 100644 --- a/transformer_engine/pytorch/dynamo.py +++ b/transformer_engine/pytorch/dynamo/quantizer_opaque.py @@ -2,10 +2,9 @@ # # See LICENSE for license information. -"""torch.compile glue for Transformer Engine quantizers. +"""Value-opaque quantizers for torch.compile. -This module isolates the torch.compile-specific plumbing that turns a -*tensorless* quantizer into a torch.compile **value** opaque type: +Turns a *tensorless* quantizer into a torch.compile **value** opaque type: * :func:`register_value_opaque_quantizer` -- attaches the ``__fx_repr__`` used by FX codegen and registers the quantizer class with @@ -33,7 +32,7 @@ class registry is needed. from __future__ import annotations from typing import Any, Dict, Tuple -from .constants import DType +from ..constants import DType def _rebuild_quantizer(cls: type, items: Tuple[Tuple[str, Any], ...]) -> Any: From ea5b396b9e1e3ee67905c58b50e30989f5ab095d Mon Sep 17 00:00:00 2001 From: Pawel Gadzinski Date: Mon, 8 Jun 2026 12:39:40 +0200 Subject: [PATCH 04/14] [PyTorch] Raise in quantizer __fx_repr__ when a process group is stored A value-opaque quantizer must not carry live distributed state. Scan the quantizer attributes in __fx_repr__ and raise TypeError if any holds a torch.distributed.ProcessGroup (e.g. a non-None deprecated amax_reduction_group), so it cannot be silently baked into a torch.compile FX graph. Clarify the related comments accordingly. Signed-off-by: Pawel Gadzinski --- tests/pytorch/test_torch_compile.py | 7 -- transformer_engine/pytorch/dynamo/__init__.py | 9 +-- .../pytorch/dynamo/quantizer_opaque.py | 70 ++++++++++--------- .../pytorch/quantized_tensor.py | 4 +- .../pytorch/tensor/float8_tensor.py | 3 +- .../pytorch/tensor/nvfp4_tensor.py | 3 +- 6 files changed, 45 insertions(+), 51 deletions(-) diff --git a/tests/pytorch/test_torch_compile.py b/tests/pytorch/test_torch_compile.py index 4491e23881..9a7a4a356a 100644 --- a/tests/pytorch/test_torch_compile.py +++ b/tests/pytorch/test_torch_compile.py @@ -392,11 +392,6 @@ def fn(inp): # --------------------------------------------------------------------------- # Value-opaque quantizers -# -# Tensorless quantizers (MXFP8, FP8 blockwise, FP8 current-scaling) are -# torch.compile *value* opaque types: value-based ``__eq__`` / ``__hash__`` plus -# an evaluable ``__fx_repr__`` that rebuilds an equal object (see -# ``torch._library.opaque_object`` Note [Opaque Objects]). # --------------------------------------------------------------------------- @@ -446,5 +441,3 @@ def test_quantizer_value_object(factory, other_kwargs): rebuilt = eval(repr_str, dict(globals_)) # pylint: disable=eval-used assert rebuilt == a and rebuilt is not a assert hash(rebuilt) == hash(a) - # The deprecated amax-reduction group is never part of the value. - assert getattr(rebuilt, "amax_reduction_group", None) is None diff --git a/transformer_engine/pytorch/dynamo/__init__.py b/transformer_engine/pytorch/dynamo/__init__.py index 44ca61d470..aae8b9cff6 100644 --- a/transformer_engine/pytorch/dynamo/__init__.py +++ b/transformer_engine/pytorch/dynamo/__init__.py @@ -2,14 +2,7 @@ # # See LICENSE for license information. -"""torch.compile glue for Transformer Engine. - -Public API is re-exported here so callers keep importing from -``transformer_engine.pytorch.dynamo`` regardless of the internal module layout: - - * :mod:`.quantizer_opaque` -- make a tensorless quantizer a torch.compile - *value* opaque type (:func:`register_value_opaque_quantizer`). -""" +"""torch.compile glue for Transformer Engine.""" from .quantizer_opaque import register_value_opaque_quantizer diff --git a/transformer_engine/pytorch/dynamo/quantizer_opaque.py b/transformer_engine/pytorch/dynamo/quantizer_opaque.py index c7455846ee..55bc8326e6 100644 --- a/transformer_engine/pytorch/dynamo/quantizer_opaque.py +++ b/transformer_engine/pytorch/dynamo/quantizer_opaque.py @@ -2,46 +2,35 @@ # # See LICENSE for license information. -"""Value-opaque quantizers for torch.compile. - -Turns a *tensorless* quantizer into a torch.compile **value** opaque type: - - * :func:`register_value_opaque_quantizer` -- attaches the ``__fx_repr__`` used - by FX codegen and registers the quantizer class with - ``torch._library.opaque_object``. It is a no-op on PyTorch builds without - the opaque-object API, so importing Transformer Engine never fails on older - PyTorch -- only torch.compile specialization on the quantizer is - unavailable there. - * :func:`_rebuild_quantizer` -- rebuilds a quantizer constant from its value - items inside the generated FX graph. The quantizer class is captured - directly in the FX globals (see :func:`_quantizer_fx_repr`), so no global - class registry is needed. - -The eager value semantics (``__eq__`` / ``__hash__`` / ``_value_key`` / -``_value_fields``) live on the quantizer itself; see -:class:`transformer_engine.pytorch.quantized_tensor.Quantizer`. - -See ``torch._library.opaque_object`` Note [Opaque Objects] for the contract a -value-typed opaque object must satisfy (``__eq__`` / ``__hash__`` / -``__fx_repr__``). The ``__fx_repr__`` contract -- ``(repr_string, {name: type})`` -where ``repr_string`` references the names in the dict -- is exactly how -PyTorch's own value opaque types (e.g. DTensor placements) reconstruct -themselves, including across the on-disk compile cache. -""" +"""Value-opaque quantizers for torch.compile.""" from __future__ import annotations from typing import Any, Dict, Tuple -from ..constants import DType +from ..constants import DType, dist_group_type + + +def _contains_process_group(value: Any) -> bool: + """Whether *value* is (or nests) a ``torch.distributed.ProcessGroup``. + + Checks the value directly and one level of ``tuple``/``list`` nesting, which + covers the shapes a quantizer value field could plausibly take. + """ + if isinstance(value, dist_group_type): + return True + if isinstance(value, (tuple, list)): + return any(_contains_process_group(item) for item in value) + return False def _rebuild_quantizer(cls: type, items: Tuple[Tuple[str, Any], ...]) -> Any: """Rebuild a tensorless quantizer of type *cls* from its value items. Referenced by the ``__fx_repr__`` emitted for value-opaque quantizers; the - generated FX code calls this to materialize the quantizer constant. The - deprecated amax-reduction process group is never part of the value, so a - reconstructed quantizer always starts with no stored group. + generated FX code calls this to materialize the quantizer constant. A + quantizer that actually stores a process group never reaches this path: + ``__fx_repr__`` raises for it. The deprecated amax-reduction group is not a + value field, so the rebuilt quantizer simply has no group attribute. """ # Bypass ``__init__`` and restore the value attributes directly: the value # items already capture every value-defining field (including derived ones), @@ -53,9 +42,9 @@ def _rebuild_quantizer(cls: type, items: Tuple[Tuple[str, Any], ...]) -> Any: value = DType.cast(value) object.__setattr__(obj, name, value) field_names.add(name) - # The deprecated amax-reduction process group is excluded from the value; - # restore it as ``None`` for quantizers that still carry the fallback so - # attribute access keeps working. + # The deprecated amax-reduction group is not a value field. Quantizers that + # actually hold a group error out in ``__fx_repr__`` before reaching here, so + # this only initializes the (groupless) attribute to keep access working. if "with_amax_reduction" in field_names and not hasattr(obj, "amax_reduction_group"): object.__setattr__(obj, "amax_reduction_group", None) return obj @@ -68,8 +57,23 @@ def _quantizer_fx_repr(self: Any) -> Tuple[str, Dict[str, Any]]: :func:`_rebuild_quantizer`, capturing both the helper and the quantizer class itself in the FX globals so codegen can resolve them with no global registry and no qualname collisions. + + Raises ``TypeError`` if the quantizer stores a process group (e.g. a + non-``None`` deprecated ``amax_reduction_group``): live distributed state + must never be baked into the graph as a constant, so such a quantizer cannot + be used with ``torch.compile``. Pass the reduction group per quantize call + instead of storing it on the quantizer. """ cls = type(self) + for name, value in vars(self).items(): + if _contains_process_group(value): + raise TypeError( + f"{cls.__name__} cannot be used with torch.compile: attribute " + f"{name!r} holds a torch.distributed.ProcessGroup, which is live " + "distributed state and must not be baked into an FX graph as a " + "constant. Pass the amax reduction group per quantize call instead " + "of storing it on the quantizer." + ) items = self._value_key()[1] return ( f"_rebuild_quantizer({cls.__name__}, {items!r})", diff --git a/transformer_engine/pytorch/quantized_tensor.py b/transformer_engine/pytorch/quantized_tensor.py index 6809893a40..3612c1080d 100644 --- a/transformer_engine/pytorch/quantized_tensor.py +++ b/transformer_engine/pytorch/quantized_tensor.py @@ -421,7 +421,9 @@ def _value_fields(self) -> Optional[Tuple[str, ...]]: Returning ``None`` (the default) means the quantizer cannot be represented as a value opaque object and keeps identity-based equality/hashing. - This also means, that torch.compile will not be able to optimize the quantizer. + This also means that passing such a quantizer as an argument to a custom op + causes a graph break under torch.compile, since it cannot be baked into the + FX graph as a constant. """ return None diff --git a/transformer_engine/pytorch/tensor/float8_tensor.py b/transformer_engine/pytorch/tensor/float8_tensor.py index 56b1ecfb09..0310c3855c 100644 --- a/transformer_engine/pytorch/tensor/float8_tensor.py +++ b/transformer_engine/pytorch/tensor/float8_tensor.py @@ -389,7 +389,8 @@ def supports_only_rowwise_all_gather(self) -> bool: def _value_fields(self) -> Tuple[str, ...]: # ``amax_reduction_group`` is intentionally excluded: it is a deprecated - # process group (not a value) and is restored as ``None`` on rebuild. + # process group (not a value). If one is actually stored, ``__fx_repr__`` + # raises so it can never be baked into a torch.compile graph. return ("dtype", "force_pow_2_scales", "amax_epsilon", "with_amax_reduction") diff --git a/transformer_engine/pytorch/tensor/nvfp4_tensor.py b/transformer_engine/pytorch/tensor/nvfp4_tensor.py index 573c78907c..4bca783922 100644 --- a/transformer_engine/pytorch/tensor/nvfp4_tensor.py +++ b/transformer_engine/pytorch/tensor/nvfp4_tensor.py @@ -336,7 +336,8 @@ def _get_compatible_recipe(self) -> Union[type[Recipe], None]: def _value_fields(self) -> Tuple[str, ...]: # ``amax_reduction_group`` is intentionally excluded: it is a deprecated - # process group (not a value) and is restored as ``None`` on rebuild. + # process group (not a value). If one is actually stored, ``__fx_repr__`` + # raises so it can never be baked into a torch.compile graph. # ``rht_matrix_random_sign_mask_t`` is derived (from # ``_with_random_sign_mask`` and the device) but is stored verbatim so # reconstruction does not need to touch the device. From aa65e34e2ed8ba784543caa703ef7962062d3546 Mon Sep 17 00:00:00 2001 From: Pawel Gadzinski Date: Mon, 8 Jun 2026 16:17:36 +0200 Subject: [PATCH 05/14] [PyTorch] Cover NVFP4 in quantizer value-object test NVFP4Quantizer is registered as a value-opaque quantizer but was missing from the value-semantics / __fx_repr__ round-trip test. Add it to _VALUE_QUANTIZERS (skipped without CUDA, which it needs to construct). Signed-off-by: Pawel Gadzinski --- tests/pytorch/test_torch_compile.py | 20 +++++++++++++++++++- 1 file changed, 19 insertions(+), 1 deletion(-) diff --git a/tests/pytorch/test_torch_compile.py b/tests/pytorch/test_torch_compile.py index 9a7a4a356a..3405001e04 100644 --- a/tests/pytorch/test_torch_compile.py +++ b/tests/pytorch/test_torch_compile.py @@ -27,7 +27,7 @@ from transformer_engine.pytorch.quantization import QuantizerRole from transformer_engine.pytorch.ops.basic.basic_linear import BasicLinear from transformer_engine.pytorch.tensor.float8_tensor import Float8CurrentScalingQuantizer -from transformer_engine.pytorch.quantization import QuantizerRole +from transformer_engine.pytorch.tensor.nvfp4_tensor import NVFP4Quantizer from transformer_engine.pytorch import ( is_fp8_available, is_mxfp8_available, @@ -416,11 +416,29 @@ def _current_scaling(amax_epsilon=0.0): ) +def _nvfp4(with_rht=False): + return NVFP4Quantizer( + fp4_dtype=tex.DType.kFloat4E2M1, + rowwise=True, + columnwise=True, + with_rht=with_rht, + ) + + # (factory, kwargs producing a different-but-valid config) _VALUE_QUANTIZERS = [ pytest.param(_mxfp8, {"dtype": tex.DType.kFloat8E5M2}, id="mxfp8"), pytest.param(_blockwise, {"force_pow_2_scales": False}, id="float8_blockwise"), pytest.param(_current_scaling, {"amax_epsilon": 1e-4}, id="float8_current_scaling"), + pytest.param( + _nvfp4, + {"with_rht": True}, + id="nvfp4", + marks=pytest.mark.skipif( + not torch.cuda.is_available(), + reason="NVFP4Quantizer requires CUDA to construct", + ), + ), ] From e1b1db6b1a9de58a6868f1a706659733792476cc Mon Sep 17 00:00:00 2001 From: Pawel Gadzinski Date: Tue, 16 Jun 2026 17:13:15 +0200 Subject: [PATCH 06/14] Reject a value quantizer that carries an amax reduction group in __eq__/__hash__ The amax reduction group is excluded from the value key, so a value quantizer that stored one would compare/hash equal to a groupless one and let torch.compile reuse a graph that skips the reduction. __eq__/__hash__ now raise (mirroring __fx_repr__, which already rejects any process-group-bearing quantizer). The group should be passed per quantize call, not stored on the quantizer. Co-Authored-By: Claude Opus 4.8 (1M context) Signed-off-by: Pawel Gadzinski --- transformer_engine/pytorch/quantized_tensor.py | 15 +++++++++++++++ 1 file changed, 15 insertions(+) diff --git a/transformer_engine/pytorch/quantized_tensor.py b/transformer_engine/pytorch/quantized_tensor.py index 3612c1080d..b7357b0e7b 100644 --- a/transformer_engine/pytorch/quantized_tensor.py +++ b/transformer_engine/pytorch/quantized_tensor.py @@ -444,6 +444,18 @@ def _value_key(self) -> Tuple[Any, ...]: items.append((name, value)) return (type(self).__qualname__, tuple(items)) + def _check_value_has_no_amax_reduction_group(self) -> None: + # The amax reduction group is not part of the value key, so a value + # quantizer that stores one would compare/hash equal to a groupless one + # and let torch.compile reuse a graph that skips the reduction. Reject it + # (mirrors ``__fx_repr__``); pass the group per quantize call instead. + if getattr(self, "amax_reduction_group", None) is not None: + raise TypeError( + f"{type(self).__name__} with a non-None amax_reduction_group cannot be " + "used as a value object; pass the amax reduction group per quantize call " + "instead of storing it on the quantizer." + ) + def __eq__(self, other: object) -> Any: # Value quantizers compare by configuration; everything else keeps the # default identity semantics (returning ``NotImplemented`` makes Python @@ -454,11 +466,14 @@ def __eq__(self, other: object) -> Any: return NotImplemented if other._value_fields() is None: return NotImplemented + self._check_value_has_no_amax_reduction_group() + other._check_value_has_no_amax_reduction_group() return self._value_key() == other._value_key() def __hash__(self) -> int: if self._value_fields() is None: return object.__hash__(self) + self._check_value_has_no_amax_reduction_group() return hash(self._value_key()) From 8c33d0ec4dd8ff255a75827fd643d3619e6ae9af Mon Sep 17 00:00:00 2001 From: Pawel Gadzinski Date: Tue, 16 Jun 2026 18:03:29 +0200 Subject: [PATCH 07/14] Recognize value-opaque quantizers via a class flag Add is_value_opaque_quantizer() + the _te_compile_value_opaque flag stamped at registration, so dynamo-traced code can detect registered quantizers (and fall back to eager for unregistered ones). Co-Authored-By: Claude Opus 4.8 (1M context) Signed-off-by: Pawel Gadzinski --- transformer_engine/pytorch/dynamo/__init__.py | 3 ++- .../pytorch/dynamo/quantizer_opaque.py | 15 +++++++++++++++ 2 files changed, 17 insertions(+), 1 deletion(-) diff --git a/transformer_engine/pytorch/dynamo/__init__.py b/transformer_engine/pytorch/dynamo/__init__.py index aae8b9cff6..ee860c78e3 100644 --- a/transformer_engine/pytorch/dynamo/__init__.py +++ b/transformer_engine/pytorch/dynamo/__init__.py @@ -4,8 +4,9 @@ """torch.compile glue for Transformer Engine.""" -from .quantizer_opaque import register_value_opaque_quantizer +from .quantizer_opaque import register_value_opaque_quantizer, is_value_opaque_quantizer __all__ = [ "register_value_opaque_quantizer", + "is_value_opaque_quantizer", ] diff --git a/transformer_engine/pytorch/dynamo/quantizer_opaque.py b/transformer_engine/pytorch/dynamo/quantizer_opaque.py index 55bc8326e6..6cb9552b71 100644 --- a/transformer_engine/pytorch/dynamo/quantizer_opaque.py +++ b/transformer_engine/pytorch/dynamo/quantizer_opaque.py @@ -10,6 +10,17 @@ from ..constants import DType, dist_group_type +# Class attribute stamped on quantizers registered as torch.compile value-opaque +# types. +_VALUE_OPAQUE_FLAG = "_te_compile_value_opaque" + + +def is_value_opaque_quantizer(quantizer: Any) -> bool: + """Whether *quantizer*'s class is registered as a torch.compile value-opaque + type.""" + return getattr(quantizer, _VALUE_OPAQUE_FLAG, False) + + def _contains_process_group(value: Any) -> bool: """Whether *value* is (or nests) a ``torch.distributed.ProcessGroup``. @@ -93,6 +104,10 @@ def register_value_opaque_quantizer(cls: type) -> None: a non-``None`` ``_value_fields`` (see :class:`transformer_engine.pytorch.quantized_tensor.Quantizer`). """ + # Stamp the class so it can be recognized as value-opaque in dynamo-traced + # code (used to fall back to eager for unregistered quantizers). + setattr(cls, _VALUE_OPAQUE_FLAG, True) + # ``register_opaque_type`` requires ``__fx_repr__`` to already exist on the # class, so attach it before registering. if "__fx_repr__" not in cls.__dict__: From 945f62dadd18d5d0bb7b68c54f4bfb9767699521 Mon Sep 17 00:00:00 2001 From: Pawel Gadzinski Date: Mon, 29 Jun 2026 11:34:37 +0200 Subject: [PATCH 08/14] Address review: narrow opaque-type except, add fullgraph test, fix nvfp4 value key - Narrow register_opaque_type except to (RuntimeError, TypeError): the API is already imported above, so ImportError/AttributeError there only mask real errors. - Add test_quantizer_value_object_fullgraph exercising torch.compile(fullgraph=True) end-to-end to verify opaque-type registration took effect. - Restore missing NVFP4Quantizer._with_random_sign_mask assignment required by _value_fields()/_value_key(). Co-Authored-By: Claude Opus 4.8 Signed-off-by: Pawel Gadzinski --- tests/pytorch/test_torch_compile.py | 18 ++++++++++++++++++ .../pytorch/dynamo/quantizer_opaque.py | 5 +++-- .../pytorch/tensor/nvfp4_tensor.py | 1 + 3 files changed, 22 insertions(+), 2 deletions(-) diff --git a/tests/pytorch/test_torch_compile.py b/tests/pytorch/test_torch_compile.py index 3405001e04..5e1f753f08 100644 --- a/tests/pytorch/test_torch_compile.py +++ b/tests/pytorch/test_torch_compile.py @@ -459,3 +459,21 @@ def test_quantizer_value_object(factory, other_kwargs): rebuilt = eval(repr_str, dict(globals_)) # pylint: disable=eval-used assert rebuilt == a and rebuilt is not a assert hash(rebuilt) == hash(a) + + +@pytest.mark.skipif( + not _opaque_available, + reason="torch.compile opaque-object support requires PyTorch >= 2.11", +) +@pytest.mark.parametrize("factory, other_kwargs", _VALUE_QUANTIZERS) +def test_quantizer_value_object_fullgraph(factory, other_kwargs): + """Quantizer survives torch.compile(fullgraph=True) - verifies registration took effect.""" + + def fn(quantizer): + return quantizer + + torch._dynamo.reset() + compiled = torch.compile(fn, fullgraph=True) + + quantizer = factory() + assert compiled(quantizer) is quantizer diff --git a/transformer_engine/pytorch/dynamo/quantizer_opaque.py b/transformer_engine/pytorch/dynamo/quantizer_opaque.py index 6cb9552b71..6d630d665c 100644 --- a/transformer_engine/pytorch/dynamo/quantizer_opaque.py +++ b/transformer_engine/pytorch/dynamo/quantizer_opaque.py @@ -128,6 +128,7 @@ def register_value_opaque_quantizer(cls: type) -> None: try: register_opaque_type(cls, typ="value") - except (ImportError, AttributeError, RuntimeError, TypeError): - # Tolerate partial / experimental opaque-object support. + except (RuntimeError, TypeError): + # Keep TE importable: registration must never crash the import, e.g. on + # PyTorch versions with only partial / experimental opaque-object support. pass diff --git a/transformer_engine/pytorch/tensor/nvfp4_tensor.py b/transformer_engine/pytorch/tensor/nvfp4_tensor.py index 4bca783922..ffc5f97eca 100644 --- a/transformer_engine/pytorch/tensor/nvfp4_tensor.py +++ b/transformer_engine/pytorch/tensor/nvfp4_tensor.py @@ -174,6 +174,7 @@ def __init__( self.nvfp4_4over6_err_mode = nvfp4_4over6_err_mode.upper() if self.nvfp4_4over6_err_mode not in ("MAE", "MSE"): raise ValueError("nvfp4_4over6_err_mode must be 'MAE' or 'MSE'.") + self._with_random_sign_mask = with_random_sign_mask self.rht_matrix_random_sign_mask_t = get_random_sign_mask_for_rht( with_random_sign_mask, torch.cuda.current_device() ) From e3c8f430883762ac7be289616bf79f40eb522a0d Mon Sep 17 00:00:00 2001 From: Pawel Gadzinski Date: Mon, 29 Jun 2026 12:05:01 +0200 Subject: [PATCH 09/14] Restore NVFP4 rht_matrix on value-key rebuild; assert quantize round-trip _rebuild_quantizer only restores value-key fields, so a reconstructed NVFP4Quantizer was missing the derived rht_matrix tensor (not hashable, so not in the value key) and failed at copy()/quantize time. Add a _rebuild_derived_state hook (called by _rebuild_quantizer) that NVFP4Quantizer uses to rebuild rht_matrix from _with_random_sign_mask (lru_cache -> cheap). Extend test_quantizer_value_object to also quantize with the original and the rebuilt quantizer and require bit-exact results (gated on HW support), so a field the kernel needs but the value key omits can no longer slip through. Co-Authored-By: Claude Opus 4.8 Signed-off-by: Pawel Gadzinski --- tests/pytorch/test_torch_compile.py | 27 +++++++++++++++++-- .../pytorch/dynamo/quantizer_opaque.py | 5 ++++ .../pytorch/tensor/nvfp4_tensor.py | 10 +++++++ 3 files changed, 40 insertions(+), 2 deletions(-) diff --git a/tests/pytorch/test_torch_compile.py b/tests/pytorch/test_torch_compile.py index 5e1f753f08..da4f96d2f3 100644 --- a/tests/pytorch/test_torch_compile.py +++ b/tests/pytorch/test_torch_compile.py @@ -416,7 +416,10 @@ def _current_scaling(amax_epsilon=0.0): ) -def _nvfp4(with_rht=False): +def _nvfp4(with_rht=True): + # Default with_rht=True so the quantize round-trip below exercises the + # derived ``rht_matrix`` tensor (the field most likely to be dropped on + # value-key reconstruction). return NVFP4Quantizer( fp4_dtype=tex.DType.kFloat4E2M1, rowwise=True, @@ -425,6 +428,17 @@ def _nvfp4(with_rht=False): ) +def _hw_available(quantizer): + """Whether this HW can actually run the quantize kernel for *quantizer*.""" + if isinstance(quantizer, MXFP8Quantizer): + return mxfp8_available + if isinstance(quantizer, NVFP4Quantizer): + return nvfp4_available + if isinstance(quantizer, Float8BlockQuantizer): + return fp8_block_scaling_available + return fp8_available # Float8CurrentScalingQuantizer + + # (factory, kwargs producing a different-but-valid config) _VALUE_QUANTIZERS = [ pytest.param(_mxfp8, {"dtype": tex.DType.kFloat8E5M2}, id="mxfp8"), @@ -432,7 +446,7 @@ def _nvfp4(with_rht=False): pytest.param(_current_scaling, {"amax_epsilon": 1e-4}, id="float8_current_scaling"), pytest.param( _nvfp4, - {"with_rht": True}, + {"with_rht": False}, id="nvfp4", marks=pytest.mark.skipif( not torch.cuda.is_available(), @@ -460,6 +474,15 @@ def test_quantizer_value_object(factory, other_kwargs): assert rebuilt == a and rebuilt is not a assert hash(rebuilt) == hash(a) + # The rebuilt quantizer must also *behave* identically, not just compare + # equal: equality only looks at the value key, so a field the kernel needs + # but that is absent from the key (e.g. NVFP4's derived ``rht_matrix``) would + # slip through the checks above and only blow up at quantize time. Run the + # real quantize kernel on both and require bit-exact results. + if torch.cuda.is_available() and _hw_available(a): + x = torch.randn(128, 256, dtype=torch.bfloat16, device="cuda") + torch.testing.assert_close(rebuilt(x).dequantize(), a(x).dequantize(), rtol=0.0, atol=0.0) + @pytest.mark.skipif( not _opaque_available, diff --git a/transformer_engine/pytorch/dynamo/quantizer_opaque.py b/transformer_engine/pytorch/dynamo/quantizer_opaque.py index 6d630d665c..97532b2790 100644 --- a/transformer_engine/pytorch/dynamo/quantizer_opaque.py +++ b/transformer_engine/pytorch/dynamo/quantizer_opaque.py @@ -58,6 +58,11 @@ def _rebuild_quantizer(cls: type, items: Tuple[Tuple[str, Any], ...]) -> Any: # this only initializes the (groupless) attribute to keep access working. if "with_amax_reduction" in field_names and not hasattr(obj, "amax_reduction_group"): object.__setattr__(obj, "amax_reduction_group", None) + # Restore non-value derived state that ``__init__`` would normally build but + # that cannot live in the value key (e.g. NVFP4's ``rht_matrix`` tensor). + finalize = getattr(obj, "_rebuild_derived_state", None) + if finalize is not None: + finalize() return obj diff --git a/transformer_engine/pytorch/tensor/nvfp4_tensor.py b/transformer_engine/pytorch/tensor/nvfp4_tensor.py index ffc5f97eca..f2f30cdcc5 100644 --- a/transformer_engine/pytorch/tensor/nvfp4_tensor.py +++ b/transformer_engine/pytorch/tensor/nvfp4_tensor.py @@ -186,6 +186,16 @@ def __getstate__(self): state["amax_reduction_group"] = None return state + def _rebuild_derived_state(self) -> None: + """Restore the derived ``rht_matrix`` after value-key reconstruction. + + ``rht_matrix`` is a ``torch.Tensor`` built from ``_with_random_sign_mask`` + and the device, so it cannot be part of the (hashable) value key. + ``_rebuild_quantizer`` calls this hook to rebuild it; the ``lru_cache`` on + :func:`get_rht_matrix` makes an already-seen (flag, device) a cheap hit. + """ + self.rht_matrix = get_rht_matrix(self._with_random_sign_mask, torch.cuda.current_device()) + def update_quantized( self, src: torch.Tensor, From 3f6862137b30aa30c2980cc1cb9cc750599b02fa Mon Sep 17 00:00:00 2001 From: Pawel Gadzinski Date: Mon, 29 Jun 2026 14:46:49 +0200 Subject: [PATCH 10/14] Enforce process-group rejection in _value_key, not __fx_repr__; add test Move the ProcessGroup guard out of the (overridable) __fx_repr__ into Quantizer._value_key -- the single point every value-materialization path (__eq__/__hash__/__fx_repr__) goes through -- so a custom __fx_repr__ can no longer bypass it. Generalizes the old amax-only check to any field holding a ProcessGroup. Add a test that a value quantizer carrying a live group raises. Addresses review on NVIDIA#3152. Co-Authored-By: Claude Opus 4.8 Signed-off-by: Pawel Gadzinski --- tests/pytorch/test_torch_compile.py | 21 ++++++++ .../pytorch/dynamo/quantizer_opaque.py | 34 +++---------- .../pytorch/quantized_tensor.py | 50 +++++++++++++------ 3 files changed, 61 insertions(+), 44 deletions(-) diff --git a/tests/pytorch/test_torch_compile.py b/tests/pytorch/test_torch_compile.py index da4f96d2f3..dc98e13708 100644 --- a/tests/pytorch/test_torch_compile.py +++ b/tests/pytorch/test_torch_compile.py @@ -484,6 +484,27 @@ def test_quantizer_value_object(factory, other_kwargs): torch.testing.assert_close(rebuilt(x).dequantize(), a(x).dequantize(), rtol=0.0, atol=0.0) +def test_value_quantizer_rejects_process_group(): + """A value quantizer holding a live ProcessGroup must refuse to be turned + into a value key / FX constant (raise), not silently drop the group.""" + import torch.distributed as dist # pylint: disable=import-outside-toplevel + + created = not dist.is_initialized() + if created: + dist.init_process_group(backend="gloo", store=dist.HashStore(), rank=0, world_size=1) + try: + q = MXFP8Quantizer(fp8_dtype=tex.DType.kFloat8E4M3) + q.amax_reduction_group = dist.group.WORLD + # Every value-materialization path must reject it (hash, eq, __fx_repr__). + with pytest.raises(TypeError): + hash(q) + with pytest.raises(TypeError): + q.__fx_repr__() + finally: + if created: + dist.destroy_process_group() + + @pytest.mark.skipif( not _opaque_available, reason="torch.compile opaque-object support requires PyTorch >= 2.11", diff --git a/transformer_engine/pytorch/dynamo/quantizer_opaque.py b/transformer_engine/pytorch/dynamo/quantizer_opaque.py index 97532b2790..37e718689d 100644 --- a/transformer_engine/pytorch/dynamo/quantizer_opaque.py +++ b/transformer_engine/pytorch/dynamo/quantizer_opaque.py @@ -7,7 +7,7 @@ from __future__ import annotations from typing import Any, Dict, Tuple -from ..constants import DType, dist_group_type +from ..constants import DType # Class attribute stamped on quantizers registered as torch.compile value-opaque @@ -21,19 +21,6 @@ def is_value_opaque_quantizer(quantizer: Any) -> bool: return getattr(quantizer, _VALUE_OPAQUE_FLAG, False) -def _contains_process_group(value: Any) -> bool: - """Whether *value* is (or nests) a ``torch.distributed.ProcessGroup``. - - Checks the value directly and one level of ``tuple``/``list`` nesting, which - covers the shapes a quantizer value field could plausibly take. - """ - if isinstance(value, dist_group_type): - return True - if isinstance(value, (tuple, list)): - return any(_contains_process_group(item) for item in value) - return False - - def _rebuild_quantizer(cls: type, items: Tuple[Tuple[str, Any], ...]) -> Any: """Rebuild a tensorless quantizer of type *cls* from its value items. @@ -74,22 +61,13 @@ def _quantizer_fx_repr(self: Any) -> Tuple[str, Dict[str, Any]]: class itself in the FX globals so codegen can resolve them with no global registry and no qualname collisions. - Raises ``TypeError`` if the quantizer stores a process group (e.g. a - non-``None`` deprecated ``amax_reduction_group``): live distributed state - must never be baked into the graph as a constant, so such a quantizer cannot - be used with ``torch.compile``. Pass the reduction group per quantize call - instead of storing it on the quantizer. + Raises ``TypeError`` (via :meth:`Quantizer._value_key`) if the quantizer + stores a process group (e.g. a non-``None`` deprecated + ``amax_reduction_group``): live distributed state must never be baked into + the graph as a constant. Pass the reduction group per quantize call instead + of storing it on the quantizer. """ cls = type(self) - for name, value in vars(self).items(): - if _contains_process_group(value): - raise TypeError( - f"{cls.__name__} cannot be used with torch.compile: attribute " - f"{name!r} holds a torch.distributed.ProcessGroup, which is live " - "distributed state and must not be baked into an FX graph as a " - "constant. Pass the amax reduction group per quantize call instead " - "of storing it on the quantizer." - ) items = self._value_key()[1] return ( f"_rebuild_quantizer({cls.__name__}, {items!r})", diff --git a/transformer_engine/pytorch/quantized_tensor.py b/transformer_engine/pytorch/quantized_tensor.py index b7357b0e7b..033a35f8e1 100644 --- a/transformer_engine/pytorch/quantized_tensor.py +++ b/transformer_engine/pytorch/quantized_tensor.py @@ -16,6 +16,7 @@ import transformer_engine_torch as tex from transformer_engine.common.recipe import Recipe +from transformer_engine.pytorch.constants import dist_group_type from transformer_engine.pytorch.tensor._quantization_helpers import ( _QuantizeFunc, _IdentityFunc, @@ -23,6 +24,19 @@ ) +def _contains_process_group(value: Any) -> bool: + """Whether *value* is (or nests) a ``torch.distributed.ProcessGroup``. + + Checks the value directly and one level of ``tuple``/``list`` nesting, which + covers the shapes a quantizer value field could plausibly take. + """ + if isinstance(value, dist_group_type): + return True + if isinstance(value, (tuple, list)): + return any(_contains_process_group(item) for item in value) + return False + + # Custom ops that should pass through __torch_dispatch__ without unwrapping # QuantizedTensor subclasses (e.g. Float8Tensor). Register ops here that # handle quantized tensors internally. @@ -427,6 +441,24 @@ def _value_fields(self) -> Optional[Tuple[str, ...]]: """ return None + def _check_value_has_no_process_group(self) -> None: + # A value quantizer is baked into the FX graph as a constant via its + # value key, which cannot carry live distributed state. Enforced here -- + # the single point every value-materialization path (``__eq__`` / + # ``__hash__`` / ``__fx_repr__``) goes through -- so a custom + # ``__fx_repr__`` cannot bypass it. Reject any field holding a + # ProcessGroup (e.g. the deprecated ``amax_reduction_group``) rather than + # silently dropping it; pass the reduction group per quantize call. + for name, value in vars(self).items(): + if _contains_process_group(value): + raise TypeError( + f"{type(self).__name__} cannot be used as a torch.compile value " + f"object: attribute {name!r} holds a torch.distributed.ProcessGroup, " + "which is live distributed state and must not be baked into an FX " + "graph. Pass the amax reduction group per quantize call instead of " + "storing it on the quantizer." + ) + def _value_key(self) -> Tuple[Any, ...]: """Hashable, reproducible key identifying this quantizer's value. @@ -434,6 +466,7 @@ def _value_key(self) -> Tuple[Any, ...]: """ fields = self._value_fields() # pylint: disable=assignment-from-none assert fields is not None, f"{type(self).__name__} is not a value quantizer" + self._check_value_has_no_process_group() items = [] for name in self._BASE_VALUE_FIELDS + tuple(fields): value = getattr(self, name) @@ -444,36 +477,21 @@ def _value_key(self) -> Tuple[Any, ...]: items.append((name, value)) return (type(self).__qualname__, tuple(items)) - def _check_value_has_no_amax_reduction_group(self) -> None: - # The amax reduction group is not part of the value key, so a value - # quantizer that stores one would compare/hash equal to a groupless one - # and let torch.compile reuse a graph that skips the reduction. Reject it - # (mirrors ``__fx_repr__``); pass the group per quantize call instead. - if getattr(self, "amax_reduction_group", None) is not None: - raise TypeError( - f"{type(self).__name__} with a non-None amax_reduction_group cannot be " - "used as a value object; pass the amax reduction group per quantize call " - "instead of storing it on the quantizer." - ) - def __eq__(self, other: object) -> Any: # Value quantizers compare by configuration; everything else keeps the # default identity semantics (returning ``NotImplemented`` makes Python - # fall back to identity). + # fall back to identity). ``_value_key`` rejects a stored ProcessGroup. if self is other: return True if self._value_fields() is None or type(self) is not type(other): return NotImplemented if other._value_fields() is None: return NotImplemented - self._check_value_has_no_amax_reduction_group() - other._check_value_has_no_amax_reduction_group() return self._value_key() == other._value_key() def __hash__(self) -> int: if self._value_fields() is None: return object.__hash__(self) - self._check_value_has_no_amax_reduction_group() return hash(self._value_key()) From 32d17683f036390a6825902ebacba0e7bf592050 Mon Sep 17 00:00:00 2001 From: Pawel Gadzinski Date: Mon, 29 Jun 2026 15:08:39 +0200 Subject: [PATCH 11/14] Strengthen fullgraph test: quantize/dequantize via a custom op, not passthrough Replace the trivial pass-through fullgraph test with one that drives each production quantizer through a minimal custom op (quantize + dequantize) under torch.compile(fullgraph=True) and compares to eager -- so the opaque-type registration is actually exercised inside the graph (a graph break would make fullgraph=True raise). Op registration sits right before the test. Also drop stale comments referencing the old __fx_repr__-side process-group guard. Co-Authored-By: Claude Opus 4.8 Signed-off-by: Pawel Gadzinski --- tests/pytorch/test_torch_compile.py | 54 ++++++++++++++++--- .../pytorch/dynamo/quantizer_opaque.py | 10 ++-- .../pytorch/tensor/nvfp4_tensor.py | 3 +- 3 files changed, 51 insertions(+), 16 deletions(-) diff --git a/tests/pytorch/test_torch_compile.py b/tests/pytorch/test_torch_compile.py index dc98e13708..63cb82eca8 100644 --- a/tests/pytorch/test_torch_compile.py +++ b/tests/pytorch/test_torch_compile.py @@ -505,19 +505,59 @@ def test_value_quantizer_rejects_process_group(): dist.destroy_process_group() +if _opaque_available: + # A minimal custom op taking a tensor and a value-opaque quantizer that + # quantizes + dequantizes inside it, one per production quantizer class. + # ``test_quantizer_value_object_fullgraph`` drives this under + # ``torch.compile(fullgraph=True)`` so the quantizer is used *inside* the + # graph -- proving the opaque-type registration took effect (a graph break + # would make ``fullgraph=True`` raise). + _qdq_lib = torch.library.Library("test_te_qdq", "DEF") + _QDQ_OPS = {} + for _qcls in ( + MXFP8Quantizer, + Float8BlockQuantizer, + Float8CurrentScalingQuantizer, + NVFP4Quantizer, + ): + _op = f"qdq_{_qcls.__name__}" + _qdq_lib.define(f"{_op}(Tensor x, {get_opaque_type_name(_qcls)} q) -> Tensor") + + @torch.library.impl(f"test_te_qdq::{_op}", "CompositeExplicitAutograd", lib=_qdq_lib) + def _qdq_impl(x, q): + return q(x).dequantize() + + @torch.library.register_fake(f"test_te_qdq::{_op}", lib=_qdq_lib) + def _qdq_fake(x, q): + return torch.empty_like(x) + + _QDQ_OPS[_qcls] = getattr(torch.ops.test_te_qdq, _op) + + @pytest.mark.skipif( not _opaque_available, reason="torch.compile opaque-object support requires PyTorch >= 2.11", ) @pytest.mark.parametrize("factory, other_kwargs", _VALUE_QUANTIZERS) def test_quantizer_value_object_fullgraph(factory, other_kwargs): - """Quantizer survives torch.compile(fullgraph=True) - verifies registration took effect.""" + """Quantizer is usable *inside* a torch.compile(fullgraph=True) graph. - def fn(quantizer): - return quantizer + A custom op quantizes+dequantizes with the (opaque value) quantizer; the + compiled result must match eager. ``fullgraph=True`` raises on any graph + break, so this proves the opaque-type registration actually took effect -- + unlike merely passing the quantizer through. + """ + q = factory() + if not (torch.cuda.is_available() and _hw_available(q)): + pytest.skip("format not supported on this HW") - torch._dynamo.reset() - compiled = torch.compile(fn, fullgraph=True) + op = _QDQ_OPS[type(q)] + x = torch.randn(128, 256, dtype=torch.bfloat16, device="cuda") + + def fn(inp): + return op(inp, q) - quantizer = factory() - assert compiled(quantizer) is quantizer + ref = fn(x) + torch._dynamo.reset() + out = torch.compile(fn, fullgraph=True)(x) + torch.testing.assert_close(out, ref, rtol=0.0, atol=0.0) diff --git a/transformer_engine/pytorch/dynamo/quantizer_opaque.py b/transformer_engine/pytorch/dynamo/quantizer_opaque.py index 37e718689d..409cb979d1 100644 --- a/transformer_engine/pytorch/dynamo/quantizer_opaque.py +++ b/transformer_engine/pytorch/dynamo/quantizer_opaque.py @@ -25,10 +25,7 @@ def _rebuild_quantizer(cls: type, items: Tuple[Tuple[str, Any], ...]) -> Any: """Rebuild a tensorless quantizer of type *cls* from its value items. Referenced by the ``__fx_repr__`` emitted for value-opaque quantizers; the - generated FX code calls this to materialize the quantizer constant. A - quantizer that actually stores a process group never reaches this path: - ``__fx_repr__`` raises for it. The deprecated amax-reduction group is not a - value field, so the rebuilt quantizer simply has no group attribute. + generated FX code calls this to materialize the quantizer constant. """ # Bypass ``__init__`` and restore the value attributes directly: the value # items already capture every value-defining field (including derived ones), @@ -40,9 +37,8 @@ def _rebuild_quantizer(cls: type, items: Tuple[Tuple[str, Any], ...]) -> Any: value = DType.cast(value) object.__setattr__(obj, name, value) field_names.add(name) - # The deprecated amax-reduction group is not a value field. Quantizers that - # actually hold a group error out in ``__fx_repr__`` before reaching here, so - # this only initializes the (groupless) attribute to keep access working. + # The deprecated amax-reduction group is not a value field; initialize it to + # None so attribute access keeps working on the rebuilt quantizer. if "with_amax_reduction" in field_names and not hasattr(obj, "amax_reduction_group"): object.__setattr__(obj, "amax_reduction_group", None) # Restore non-value derived state that ``__init__`` would normally build but diff --git a/transformer_engine/pytorch/tensor/nvfp4_tensor.py b/transformer_engine/pytorch/tensor/nvfp4_tensor.py index f2f30cdcc5..d9cd534606 100644 --- a/transformer_engine/pytorch/tensor/nvfp4_tensor.py +++ b/transformer_engine/pytorch/tensor/nvfp4_tensor.py @@ -347,8 +347,7 @@ def _get_compatible_recipe(self) -> Union[type[Recipe], None]: def _value_fields(self) -> Tuple[str, ...]: # ``amax_reduction_group`` is intentionally excluded: it is a deprecated - # process group (not a value). If one is actually stored, ``__fx_repr__`` - # raises so it can never be baked into a torch.compile graph. + # process group, not a value (``_value_key`` rejects a stored group). # ``rht_matrix_random_sign_mask_t`` is derived (from # ``_with_random_sign_mask`` and the device) but is stored verbatim so # reconstruction does not need to touch the device. From 28bde9e7040a8b4c87b992cd5717b7518d19e000 Mon Sep 17 00:00:00 2001 From: Pawel Gadzinski Date: Mon, 29 Jun 2026 15:29:11 +0200 Subject: [PATCH 12/14] Clarify comments: rht_matrix_random_sign_mask_t derivation; why the opaque flag - rht_matrix_random_sign_mask_t is a device-independent int derived from _with_random_sign_mask (the device only places a throwaway tensor); fix the misleading comment. - Explain why registration uses a class attribute, not a registry set: is_value_opaque_quantizer is traced inside the compile graph and dynamo can bake a getattr constant but cannot do 'type(q) in set' on the opaque class. Co-Authored-By: Claude Opus 4.8 Signed-off-by: Pawel Gadzinski --- transformer_engine/pytorch/dynamo/quantizer_opaque.py | 8 ++++++-- transformer_engine/pytorch/tensor/nvfp4_tensor.py | 6 +++--- 2 files changed, 9 insertions(+), 5 deletions(-) diff --git a/transformer_engine/pytorch/dynamo/quantizer_opaque.py b/transformer_engine/pytorch/dynamo/quantizer_opaque.py index 409cb979d1..33e831ef38 100644 --- a/transformer_engine/pytorch/dynamo/quantizer_opaque.py +++ b/transformer_engine/pytorch/dynamo/quantizer_opaque.py @@ -10,8 +10,12 @@ from ..constants import DType -# Class attribute stamped on quantizers registered as torch.compile value-opaque -# types. +# Registration marks the class with this attribute instead of recording it in a +# module-level set. ``is_value_opaque_quantizer`` runs *inside* the torch.compile +# graph (``Linear.forward`` consults it): Dynamo can trace a ``getattr`` on the +# opaque quantizer and bake the result as a constant, but cannot evaluate +# ``type(q) in some_set`` -- it has no equality/hash rules for the opaque class +# object, so a set/dict lookup graph-breaks under ``fullgraph=True``. _VALUE_OPAQUE_FLAG = "_te_compile_value_opaque" diff --git a/transformer_engine/pytorch/tensor/nvfp4_tensor.py b/transformer_engine/pytorch/tensor/nvfp4_tensor.py index d9cd534606..c8c0a7b854 100644 --- a/transformer_engine/pytorch/tensor/nvfp4_tensor.py +++ b/transformer_engine/pytorch/tensor/nvfp4_tensor.py @@ -348,9 +348,9 @@ def _get_compatible_recipe(self) -> Union[type[Recipe], None]: def _value_fields(self) -> Tuple[str, ...]: # ``amax_reduction_group`` is intentionally excluded: it is a deprecated # process group, not a value (``_value_key`` rejects a stored group). - # ``rht_matrix_random_sign_mask_t`` is derived (from - # ``_with_random_sign_mask`` and the device) but is stored verbatim so - # reconstruction does not need to touch the device. + # ``rht_matrix_random_sign_mask_t`` is a device-independent int derived + # from ``_with_random_sign_mask``; kept in the key so the rebuilt + # quantizer carries it without recomputation. return ( "dtype", "with_rht", From 2c3c5df0fb3488b8e95670066ac53f958640a6de Mon Sep 17 00:00:00 2001 From: Pawel Gadzinski Date: Mon, 29 Jun 2026 15:30:44 +0200 Subject: [PATCH 13/14] Reword opaque-flag comment: self-contained, no Linear reference Co-Authored-By: Claude Opus 4.8 Signed-off-by: Pawel Gadzinski --- .../pytorch/dynamo/quantizer_opaque.py | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/transformer_engine/pytorch/dynamo/quantizer_opaque.py b/transformer_engine/pytorch/dynamo/quantizer_opaque.py index 33e831ef38..98349e12ba 100644 --- a/transformer_engine/pytorch/dynamo/quantizer_opaque.py +++ b/transformer_engine/pytorch/dynamo/quantizer_opaque.py @@ -10,12 +10,12 @@ from ..constants import DType -# Registration marks the class with this attribute instead of recording it in a -# module-level set. ``is_value_opaque_quantizer`` runs *inside* the torch.compile -# graph (``Linear.forward`` consults it): Dynamo can trace a ``getattr`` on the -# opaque quantizer and bake the result as a constant, but cannot evaluate -# ``type(q) in some_set`` -- it has no equality/hash rules for the opaque class -# object, so a set/dict lookup graph-breaks under ``fullgraph=True``. +# Registration marks the class with this attribute rather than recording it in a +# module-level set. It looks odd but is a deliberate workaround: the check must +# stay traceable when it runs inside a torch.compile graph -- Dynamo can bake a +# ``getattr`` on the opaque quantizer into a constant, but cannot evaluate +# ``type(q) in some_set`` (no equality/hash rules for the opaque class object), +# which would graph-break under ``fullgraph=True``. _VALUE_OPAQUE_FLAG = "_te_compile_value_opaque" From 826f271ebb28466cdd94b9ee5cc6875eb61d4a97 Mon Sep 17 00:00:00 2001 From: Pawel Gadzinski Date: Mon, 29 Jun 2026 15:38:53 +0200 Subject: [PATCH 14/14] Cover is_opaque_value_type with the import-safety guard too is_opaque_value_type(cls) sat between the import guard and the register_opaque_type guard, so on a partial/experimental opaque-object build it could raise RuntimeError/TypeError and crash TE import. Move it inside the same except so the 'registration never crashes import' promise holds for both calls. Co-Authored-By: Claude Opus 4.8 Signed-off-by: Pawel Gadzinski --- transformer_engine/pytorch/dynamo/quantizer_opaque.py | 11 +++++------ 1 file changed, 5 insertions(+), 6 deletions(-) diff --git a/transformer_engine/pytorch/dynamo/quantizer_opaque.py b/transformer_engine/pytorch/dynamo/quantizer_opaque.py index 98349e12ba..8b8b3caa69 100644 --- a/transformer_engine/pytorch/dynamo/quantizer_opaque.py +++ b/transformer_engine/pytorch/dynamo/quantizer_opaque.py @@ -106,12 +106,11 @@ def register_value_opaque_quantizer(cls: type) -> None: # still work; torch.compile specialization on the quantizer does not. return - if is_opaque_value_type(cls): - return - try: - register_opaque_type(cls, typ="value") + if not is_opaque_value_type(cls): + register_opaque_type(cls, typ="value") except (RuntimeError, TypeError): - # Keep TE importable: registration must never crash the import, e.g. on - # PyTorch versions with only partial / experimental opaque-object support. + # Keep TE importable: neither the opaque-type query nor the registration + # must crash the import, e.g. on PyTorch versions with only partial / + # experimental opaque-object support. pass