diff --git a/tests/pytorch/test_torch_compile.py b/tests/pytorch/test_torch_compile.py index 1286492a6e..63cb82eca8 100644 --- a/tests/pytorch/test_torch_compile.py +++ b/tests/pytorch/test_torch_compile.py @@ -24,14 +24,18 @@ from transformer_engine.common import recipe from transformer_engine.pytorch.constants import FP8FwdTensorIdx, FP8BwdTensorIdx from transformer_engine.pytorch.module.base import TransformerEngineBaseModule +from transformer_engine.pytorch.quantization import QuantizerRole from transformer_engine.pytorch.ops.basic.basic_linear import BasicLinear from transformer_engine.pytorch.tensor.float8_tensor import Float8CurrentScalingQuantizer -from transformer_engine.pytorch.quantization import QuantizerRole +from transformer_engine.pytorch.tensor.nvfp4_tensor import NVFP4Quantizer from transformer_engine.pytorch import ( is_fp8_available, is_mxfp8_available, is_fp8_block_scaling_available, is_nvfp4_available, + Float8Quantizer, + Float8BlockQuantizer, + MXFP8Quantizer, ) from utils import recipe_id @@ -384,3 +388,176 @@ def fn(inp): out = compiled(inp) out.sum().backward() + + +# --------------------------------------------------------------------------- +# Value-opaque quantizers +# --------------------------------------------------------------------------- + + +def _mxfp8(dtype=tex.DType.kFloat8E4M3): + return MXFP8Quantizer(fp8_dtype=dtype) + + +def _blockwise(force_pow_2_scales=True): + return Float8BlockQuantizer( + fp8_dtype=tex.DType.kFloat8E4M3, + rowwise=True, + columnwise=True, + force_pow_2_scales=force_pow_2_scales, + ) + + +def _current_scaling(amax_epsilon=0.0): + return Float8CurrentScalingQuantizer( + fp8_dtype=tex.DType.kFloat8E4M3, + device=torch.device("cpu"), + amax_epsilon=amax_epsilon, + ) + + +def _nvfp4(with_rht=True): + # Default with_rht=True so the quantize round-trip below exercises the + # derived ``rht_matrix`` tensor (the field most likely to be dropped on + # value-key reconstruction). + return NVFP4Quantizer( + fp4_dtype=tex.DType.kFloat4E2M1, + rowwise=True, + columnwise=True, + with_rht=with_rht, + ) + + +def _hw_available(quantizer): + """Whether this HW can actually run the quantize kernel for *quantizer*.""" + if isinstance(quantizer, MXFP8Quantizer): + return mxfp8_available + if isinstance(quantizer, NVFP4Quantizer): + return nvfp4_available + if isinstance(quantizer, Float8BlockQuantizer): + return fp8_block_scaling_available + return fp8_available # Float8CurrentScalingQuantizer + + +# (factory, kwargs producing a different-but-valid config) +_VALUE_QUANTIZERS = [ + pytest.param(_mxfp8, {"dtype": tex.DType.kFloat8E5M2}, id="mxfp8"), + pytest.param(_blockwise, {"force_pow_2_scales": False}, id="float8_blockwise"), + pytest.param(_current_scaling, {"amax_epsilon": 1e-4}, id="float8_current_scaling"), + pytest.param( + _nvfp4, + {"with_rht": False}, + id="nvfp4", + marks=pytest.mark.skipif( + not torch.cuda.is_available(), + reason="NVFP4Quantizer requires CUDA to construct", + ), + ), +] + + +@pytest.mark.parametrize("factory, other_kwargs", _VALUE_QUANTIZERS) +def test_quantizer_value_object(factory, other_kwargs): + """Value semantics + ``__fx_repr__`` round-trip via the production FX path.""" + a, b = factory(), factory() + # Same config -> equal, same hash, interchangeable as a dict/set key. + assert a is not b + assert a == b + assert hash(a) == hash(b) + assert {a: "x"}[b] == "x" + # Different config -> not equal. + assert a != factory(**other_kwargs) + + # ``__fx_repr__`` (used by torch.compile codegen) rebuilds an equal object. + repr_str, globals_ = a.__fx_repr__() + rebuilt = eval(repr_str, dict(globals_)) # pylint: disable=eval-used + assert rebuilt == a and rebuilt is not a + assert hash(rebuilt) == hash(a) + + # The rebuilt quantizer must also *behave* identically, not just compare + # equal: equality only looks at the value key, so a field the kernel needs + # but that is absent from the key (e.g. NVFP4's derived ``rht_matrix``) would + # slip through the checks above and only blow up at quantize time. Run the + # real quantize kernel on both and require bit-exact results. + if torch.cuda.is_available() and _hw_available(a): + x = torch.randn(128, 256, dtype=torch.bfloat16, device="cuda") + torch.testing.assert_close(rebuilt(x).dequantize(), a(x).dequantize(), rtol=0.0, atol=0.0) + + +def test_value_quantizer_rejects_process_group(): + """A value quantizer holding a live ProcessGroup must refuse to be turned + into a value key / FX constant (raise), not silently drop the group.""" + import torch.distributed as dist # pylint: disable=import-outside-toplevel + + created = not dist.is_initialized() + if created: + dist.init_process_group(backend="gloo", store=dist.HashStore(), rank=0, world_size=1) + try: + q = MXFP8Quantizer(fp8_dtype=tex.DType.kFloat8E4M3) + q.amax_reduction_group = dist.group.WORLD + # Every value-materialization path must reject it (hash, eq, __fx_repr__). + with pytest.raises(TypeError): + hash(q) + with pytest.raises(TypeError): + q.__fx_repr__() + finally: + if created: + dist.destroy_process_group() + + +if _opaque_available: + # A minimal custom op taking a tensor and a value-opaque quantizer that + # quantizes + dequantizes inside it, one per production quantizer class. + # ``test_quantizer_value_object_fullgraph`` drives this under + # ``torch.compile(fullgraph=True)`` so the quantizer is used *inside* the + # graph -- proving the opaque-type registration took effect (a graph break + # would make ``fullgraph=True`` raise). + _qdq_lib = torch.library.Library("test_te_qdq", "DEF") + _QDQ_OPS = {} + for _qcls in ( + MXFP8Quantizer, + Float8BlockQuantizer, + Float8CurrentScalingQuantizer, + NVFP4Quantizer, + ): + _op = f"qdq_{_qcls.__name__}" + _qdq_lib.define(f"{_op}(Tensor x, {get_opaque_type_name(_qcls)} q) -> Tensor") + + @torch.library.impl(f"test_te_qdq::{_op}", "CompositeExplicitAutograd", lib=_qdq_lib) + def _qdq_impl(x, q): + return q(x).dequantize() + + @torch.library.register_fake(f"test_te_qdq::{_op}", lib=_qdq_lib) + def _qdq_fake(x, q): + return torch.empty_like(x) + + _QDQ_OPS[_qcls] = getattr(torch.ops.test_te_qdq, _op) + + +@pytest.mark.skipif( + not _opaque_available, + reason="torch.compile opaque-object support requires PyTorch >= 2.11", +) +@pytest.mark.parametrize("factory, other_kwargs", _VALUE_QUANTIZERS) +def test_quantizer_value_object_fullgraph(factory, other_kwargs): + """Quantizer is usable *inside* a torch.compile(fullgraph=True) graph. + + A custom op quantizes+dequantizes with the (opaque value) quantizer; the + compiled result must match eager. ``fullgraph=True`` raises on any graph + break, so this proves the opaque-type registration actually took effect -- + unlike merely passing the quantizer through. + """ + q = factory() + if not (torch.cuda.is_available() and _hw_available(q)): + pytest.skip("format not supported on this HW") + + op = _QDQ_OPS[type(q)] + x = torch.randn(128, 256, dtype=torch.bfloat16, device="cuda") + + def fn(inp): + return op(inp, q) + + ref = fn(x) + torch._dynamo.reset() + out = torch.compile(fn, fullgraph=True)(x) + torch.testing.assert_close(out, ref, rtol=0.0, atol=0.0) diff --git a/transformer_engine/pytorch/dynamo/__init__.py b/transformer_engine/pytorch/dynamo/__init__.py new file mode 100644 index 0000000000..ee860c78e3 --- /dev/null +++ b/transformer_engine/pytorch/dynamo/__init__.py @@ -0,0 +1,12 @@ +# Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# See LICENSE for license information. + +"""torch.compile glue for Transformer Engine.""" + +from .quantizer_opaque import register_value_opaque_quantizer, is_value_opaque_quantizer + +__all__ = [ + "register_value_opaque_quantizer", + "is_value_opaque_quantizer", +] diff --git a/transformer_engine/pytorch/dynamo/quantizer_opaque.py b/transformer_engine/pytorch/dynamo/quantizer_opaque.py new file mode 100644 index 0000000000..8b8b3caa69 --- /dev/null +++ b/transformer_engine/pytorch/dynamo/quantizer_opaque.py @@ -0,0 +1,116 @@ +# Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# See LICENSE for license information. + +"""Value-opaque quantizers for torch.compile.""" + +from __future__ import annotations +from typing import Any, Dict, Tuple + +from ..constants import DType + + +# Registration marks the class with this attribute rather than recording it in a +# module-level set. It looks odd but is a deliberate workaround: the check must +# stay traceable when it runs inside a torch.compile graph -- Dynamo can bake a +# ``getattr`` on the opaque quantizer into a constant, but cannot evaluate +# ``type(q) in some_set`` (no equality/hash rules for the opaque class object), +# which would graph-break under ``fullgraph=True``. +_VALUE_OPAQUE_FLAG = "_te_compile_value_opaque" + + +def is_value_opaque_quantizer(quantizer: Any) -> bool: + """Whether *quantizer*'s class is registered as a torch.compile value-opaque + type.""" + return getattr(quantizer, _VALUE_OPAQUE_FLAG, False) + + +def _rebuild_quantizer(cls: type, items: Tuple[Tuple[str, Any], ...]) -> Any: + """Rebuild a tensorless quantizer of type *cls* from its value items. + + Referenced by the ``__fx_repr__`` emitted for value-opaque quantizers; the + generated FX code calls this to materialize the quantizer constant. + """ + # Bypass ``__init__`` and restore the value attributes directly: the value + # items already capture every value-defining field (including derived ones), + # and the constructors have heterogeneous signatures / side effects. + obj = cls.__new__(cls) + field_names = set() + for name, value in items: + if name == "dtype": + value = DType.cast(value) + object.__setattr__(obj, name, value) + field_names.add(name) + # The deprecated amax-reduction group is not a value field; initialize it to + # None so attribute access keeps working on the rebuilt quantizer. + if "with_amax_reduction" in field_names and not hasattr(obj, "amax_reduction_group"): + object.__setattr__(obj, "amax_reduction_group", None) + # Restore non-value derived state that ``__init__`` would normally build but + # that cannot live in the value key (e.g. NVFP4's ``rht_matrix`` tensor). + finalize = getattr(obj, "_rebuild_derived_state", None) + if finalize is not None: + finalize() + return obj + + +def _quantizer_fx_repr(self: Any) -> Tuple[str, Dict[str, Any]]: + """``__fx_repr__`` for value-opaque quantizers (attached at registration). + + Returns an evaluable expression that rebuilds the quantizer via + :func:`_rebuild_quantizer`, capturing both the helper and the quantizer + class itself in the FX globals so codegen can resolve them with no global + registry and no qualname collisions. + + Raises ``TypeError`` (via :meth:`Quantizer._value_key`) if the quantizer + stores a process group (e.g. a non-``None`` deprecated + ``amax_reduction_group``): live distributed state must never be baked into + the graph as a constant. Pass the reduction group per quantize call instead + of storing it on the quantizer. + """ + cls = type(self) + items = self._value_key()[1] + return ( + f"_rebuild_quantizer({cls.__name__}, {items!r})", + {"_rebuild_quantizer": _rebuild_quantizer, cls.__name__: cls}, + ) + + +def register_value_opaque_quantizer(cls: type) -> None: + """Register a tensorless quantizer class as a torch.compile value opaque type. + + Attaches ``__fx_repr__`` and registers the class with + ``torch._library.opaque_object``. Safe to call on any PyTorch build: on + versions without the opaque-object API it only attaches ``__fx_repr__`` + (harmless), so Transformer Engine keeps importing and running in eager mode. + + The quantizer class must already provide value ``__eq__`` / ``__hash__`` and + a non-``None`` ``_value_fields`` (see + :class:`transformer_engine.pytorch.quantized_tensor.Quantizer`). + """ + # Stamp the class so it can be recognized as value-opaque in dynamo-traced + # code (used to fall back to eager for unregistered quantizers). + setattr(cls, _VALUE_OPAQUE_FLAG, True) + + # ``register_opaque_type`` requires ``__fx_repr__`` to already exist on the + # class, so attach it before registering. + if "__fx_repr__" not in cls.__dict__: + cls.__fx_repr__ = _quantizer_fx_repr + + try: + from torch._library.opaque_object import ( # pylint: disable=import-outside-toplevel + register_opaque_type, + is_opaque_value_type, + ) + except (ImportError, AttributeError): + # Older PyTorch without the opaque-object API: eager value semantics + # still work; torch.compile specialization on the quantizer does not. + return + + try: + if not is_opaque_value_type(cls): + register_opaque_type(cls, typ="value") + except (RuntimeError, TypeError): + # Keep TE importable: neither the opaque-type query nor the registration + # must crash the import, e.g. on PyTorch versions with only partial / + # experimental opaque-object support. + pass diff --git a/transformer_engine/pytorch/quantized_tensor.py b/transformer_engine/pytorch/quantized_tensor.py index cfe488aae5..033a35f8e1 100644 --- a/transformer_engine/pytorch/quantized_tensor.py +++ b/transformer_engine/pytorch/quantized_tensor.py @@ -16,6 +16,7 @@ import transformer_engine_torch as tex from transformer_engine.common.recipe import Recipe +from transformer_engine.pytorch.constants import dist_group_type from transformer_engine.pytorch.tensor._quantization_helpers import ( _QuantizeFunc, _IdentityFunc, @@ -23,6 +24,19 @@ ) +def _contains_process_group(value: Any) -> bool: + """Whether *value* is (or nests) a ``torch.distributed.ProcessGroup``. + + Checks the value directly and one level of ``tuple``/``list`` nesting, which + covers the shapes a quantizer value field could plausibly take. + """ + if isinstance(value, dist_group_type): + return True + if isinstance(value, (tuple, list)): + return any(_contains_process_group(item) for item in value) + return False + + # Custom ops that should pass through __torch_dispatch__ without unwrapping # QuantizedTensor subclasses (e.g. Float8Tensor). Register ops here that # handle quantized tensors internally. @@ -408,6 +422,78 @@ def get_usages(self) -> Dict[str, bool]: "columnwise": self.columnwise_usage, } + #: Attributes shared by every quantizer that take part in value identity. + _BASE_VALUE_FIELDS: Tuple[str, ...] = ( + "rowwise_usage", + "columnwise_usage", + "internal", + "optimize_for_gemm", + ) + + def _value_fields(self) -> Optional[Tuple[str, ...]]: + """Subclass-specific value-defining attribute names, or ``None``. + + Returning ``None`` (the default) means the quantizer cannot be represented as + a value opaque object and keeps identity-based equality/hashing. + This also means that passing such a quantizer as an argument to a custom op + causes a graph break under torch.compile, since it cannot be baked into the + FX graph as a constant. + """ + return None + + def _check_value_has_no_process_group(self) -> None: + # A value quantizer is baked into the FX graph as a constant via its + # value key, which cannot carry live distributed state. Enforced here -- + # the single point every value-materialization path (``__eq__`` / + # ``__hash__`` / ``__fx_repr__``) goes through -- so a custom + # ``__fx_repr__`` cannot bypass it. Reject any field holding a + # ProcessGroup (e.g. the deprecated ``amax_reduction_group``) rather than + # silently dropping it; pass the reduction group per quantize call. + for name, value in vars(self).items(): + if _contains_process_group(value): + raise TypeError( + f"{type(self).__name__} cannot be used as a torch.compile value " + f"object: attribute {name!r} holds a torch.distributed.ProcessGroup, " + "which is live distributed state and must not be baked into an FX " + "graph. Pass the amax reduction group per quantize call instead of " + "storing it on the quantizer." + ) + + def _value_key(self) -> Tuple[Any, ...]: + """Hashable, reproducible key identifying this quantizer's value. + + Only valid for value quantizers (``_value_fields()`` is not ``None``). + """ + fields = self._value_fields() # pylint: disable=assignment-from-none + assert fields is not None, f"{type(self).__name__} is not a value quantizer" + self._check_value_has_no_process_group() + items = [] + for name in self._BASE_VALUE_FIELDS + tuple(fields): + value = getattr(self, name) + if name == "dtype": + # ``DType`` is an ``IntEnum``; store the int so the key stays + # plain: hashable and ``repr``-reproducible for FX codegen. + value = int(value) + items.append((name, value)) + return (type(self).__qualname__, tuple(items)) + + def __eq__(self, other: object) -> Any: + # Value quantizers compare by configuration; everything else keeps the + # default identity semantics (returning ``NotImplemented`` makes Python + # fall back to identity). ``_value_key`` rejects a stored ProcessGroup. + if self is other: + return True + if self._value_fields() is None or type(self) is not type(other): + return NotImplemented + if other._value_fields() is None: + return NotImplemented + return self._value_key() == other._value_key() + + def __hash__(self) -> int: + if self._value_fields() is None: + return object.__hash__(self) + return hash(self._value_key()) + class QuantizedTensor(torch.Tensor): """Abstract base class for tensor with quantized data diff --git a/transformer_engine/pytorch/tensor/float8_blockwise_tensor.py b/transformer_engine/pytorch/tensor/float8_blockwise_tensor.py index ba46508d74..92c22acd0b 100644 --- a/transformer_engine/pytorch/tensor/float8_blockwise_tensor.py +++ b/transformer_engine/pytorch/tensor/float8_blockwise_tensor.py @@ -14,6 +14,7 @@ from transformer_engine.common.recipe import Float8BlockScaling, Recipe from .storage.float8_blockwise_tensor_storage import Float8BlockwiseQTensorStorage from ..quantized_tensor import QuantizedTensor, Quantizer +from ..dynamo import register_value_opaque_quantizer from ._quantization_helpers import _IdentityFunc from ..constants import DType from ..utils import devices_match, round_up_to_nearest_multiple @@ -69,6 +70,9 @@ def copy(self) -> Float8BlockQuantizer: return quantizer + def _value_fields(self) -> Tuple[str, ...]: + return ("dtype", "block_len", "amax_epsilon", "force_pow_2_scales", "block_scaling_dim") + def update_quantized( self, src: torch.Tensor, @@ -211,6 +215,9 @@ def _get_compatible_recipe(self) -> Union[type[Recipe], None]: return Float8BlockScaling +register_value_opaque_quantizer(Float8BlockQuantizer) + + class Float8BlockwiseQTensor(Float8BlockwiseQTensorStorage, QuantizedTensor): """Tensor class with FP8 data quantized via NxN blocks or 1xN blocks. diff --git a/transformer_engine/pytorch/tensor/float8_tensor.py b/transformer_engine/pytorch/tensor/float8_tensor.py index e26abf7df0..0310c3855c 100644 --- a/transformer_engine/pytorch/tensor/float8_tensor.py +++ b/transformer_engine/pytorch/tensor/float8_tensor.py @@ -18,6 +18,7 @@ from ..utils import canonicalize_process_group, devices_match from .storage.float8_tensor_storage import Float8TensorStorage, _FromFloat8Func from ..quantized_tensor import QuantizedTensor, Quantizer +from ..dynamo import register_value_opaque_quantizer from ._quantization_helpers import _IdentityFunc from ..constants import dist_group_type, DType @@ -386,6 +387,15 @@ def supports_only_rowwise_all_gather(self) -> bool: """ return True + def _value_fields(self) -> Tuple[str, ...]: + # ``amax_reduction_group`` is intentionally excluded: it is a deprecated + # process group (not a value). If one is actually stored, ``__fx_repr__`` + # raises so it can never be baked into a torch.compile graph. + return ("dtype", "force_pow_2_scales", "amax_epsilon", "with_amax_reduction") + + +register_value_opaque_quantizer(Float8CurrentScalingQuantizer) + class Float8Tensor(Float8TensorStorage, QuantizedTensor): """Experimental tensor class with FP8 data diff --git a/transformer_engine/pytorch/tensor/mxfp8_tensor.py b/transformer_engine/pytorch/tensor/mxfp8_tensor.py index d759aaf5c4..a3746b3088 100644 --- a/transformer_engine/pytorch/tensor/mxfp8_tensor.py +++ b/transformer_engine/pytorch/tensor/mxfp8_tensor.py @@ -18,6 +18,7 @@ from ..utils import devices_match, round_up_to_nearest_multiple from .storage.mxfp8_tensor_storage import MXFP8TensorStorage, _FromMXFP8Func from ..quantized_tensor import QuantizedTensor, Quantizer +from ..dynamo import register_value_opaque_quantizer from ._quantization_helpers import _IdentityFunc aten = torch.ops.aten @@ -57,6 +58,9 @@ def copy(self) -> MXFP8Quantizer: return quantizer + def _value_fields(self) -> Tuple[str, ...]: + return ("dtype",) + def update_quantized( self, src: torch.Tensor, @@ -1058,3 +1062,6 @@ def backward( ) return dgrad, None return grad.view(ctx.shape), None + + +register_value_opaque_quantizer(MXFP8Quantizer) diff --git a/transformer_engine/pytorch/tensor/nvfp4_tensor.py b/transformer_engine/pytorch/tensor/nvfp4_tensor.py index aa92be004f..c8c0a7b854 100644 --- a/transformer_engine/pytorch/tensor/nvfp4_tensor.py +++ b/transformer_engine/pytorch/tensor/nvfp4_tensor.py @@ -23,6 +23,7 @@ from .storage.nvfp4_tensor_storage import NVFP4TensorStorage, _FromNVFP4Func from ..quantized_tensor import QuantizedTensor, Quantizer +from ..dynamo import register_value_opaque_quantizer from ._quantization_helpers import _IdentityFunc aten = torch.ops.aten @@ -173,6 +174,7 @@ def __init__( self.nvfp4_4over6_err_mode = nvfp4_4over6_err_mode.upper() if self.nvfp4_4over6_err_mode not in ("MAE", "MSE"): raise ValueError("nvfp4_4over6_err_mode must be 'MAE' or 'MSE'.") + self._with_random_sign_mask = with_random_sign_mask self.rht_matrix_random_sign_mask_t = get_random_sign_mask_for_rht( with_random_sign_mask, torch.cuda.current_device() ) @@ -184,6 +186,16 @@ def __getstate__(self): state["amax_reduction_group"] = None return state + def _rebuild_derived_state(self) -> None: + """Restore the derived ``rht_matrix`` after value-key reconstruction. + + ``rht_matrix`` is a ``torch.Tensor`` built from ``_with_random_sign_mask`` + and the device, so it cannot be part of the (hashable) value key. + ``_rebuild_quantizer`` calls this hook to rebuild it; the ``lru_cache`` on + :func:`get_rht_matrix` makes an already-seen (flag, device) a cheap hit. + """ + self.rht_matrix = get_rht_matrix(self._with_random_sign_mask, torch.cuda.current_device()) + def update_quantized( self, src: torch.Tensor, @@ -333,6 +345,30 @@ def _canonicalized_amax_reduction_group(self) -> dist_group_type: def _get_compatible_recipe(self) -> Union[type[Recipe], None]: return NVFP4BlockScaling + def _value_fields(self) -> Tuple[str, ...]: + # ``amax_reduction_group`` is intentionally excluded: it is a deprecated + # process group, not a value (``_value_key`` rejects a stored group). + # ``rht_matrix_random_sign_mask_t`` is a device-independent int derived + # from ``_with_random_sign_mask``; kept in the key so the rebuilt + # quantizer carries it without recomputation. + return ( + "dtype", + "with_rht", + "with_post_rht_amax", + "with_2d_quantization", + "stochastic_rounding", + "row_scaled_nvfp4", + "nvfp4_use_4over6", + "nvfp4_e4m3_max", + "nvfp4_4over6_err_mode", + "_with_random_sign_mask", + "rht_matrix_random_sign_mask_t", + "with_amax_reduction", + ) + + +register_value_opaque_quantizer(NVFP4Quantizer) + class NVFP4Tensor(NVFP4TensorStorage, QuantizedTensor): """Quantized tensor class with FP4 data