Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
179 changes: 178 additions & 1 deletion tests/pytorch/test_torch_compile.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,14 +24,18 @@
from transformer_engine.common import recipe
from transformer_engine.pytorch.constants import FP8FwdTensorIdx, FP8BwdTensorIdx
from transformer_engine.pytorch.module.base import TransformerEngineBaseModule
from transformer_engine.pytorch.quantization import QuantizerRole
from transformer_engine.pytorch.ops.basic.basic_linear import BasicLinear
from transformer_engine.pytorch.tensor.float8_tensor import Float8CurrentScalingQuantizer
from transformer_engine.pytorch.quantization import QuantizerRole
from transformer_engine.pytorch.tensor.nvfp4_tensor import NVFP4Quantizer
from transformer_engine.pytorch import (
is_fp8_available,
is_mxfp8_available,
is_fp8_block_scaling_available,
is_nvfp4_available,
Float8Quantizer,
Float8BlockQuantizer,
MXFP8Quantizer,
)
from utils import recipe_id

Expand Down Expand Up @@ -384,3 +388,176 @@ def fn(inp):

out = compiled(inp)
out.sum().backward()


# ---------------------------------------------------------------------------
# Value-opaque quantizers
# ---------------------------------------------------------------------------


def _mxfp8(dtype=tex.DType.kFloat8E4M3):
return MXFP8Quantizer(fp8_dtype=dtype)


def _blockwise(force_pow_2_scales=True):
return Float8BlockQuantizer(
fp8_dtype=tex.DType.kFloat8E4M3,
rowwise=True,
columnwise=True,
force_pow_2_scales=force_pow_2_scales,
)


def _current_scaling(amax_epsilon=0.0):
return Float8CurrentScalingQuantizer(
fp8_dtype=tex.DType.kFloat8E4M3,
device=torch.device("cpu"),
amax_epsilon=amax_epsilon,
)


def _nvfp4(with_rht=True):
# Default with_rht=True so the quantize round-trip below exercises the
# derived ``rht_matrix`` tensor (the field most likely to be dropped on
# value-key reconstruction).
return NVFP4Quantizer(
fp4_dtype=tex.DType.kFloat4E2M1,
rowwise=True,
columnwise=True,
with_rht=with_rht,
)


def _hw_available(quantizer):
"""Whether this HW can actually run the quantize kernel for *quantizer*."""
if isinstance(quantizer, MXFP8Quantizer):
return mxfp8_available
if isinstance(quantizer, NVFP4Quantizer):
return nvfp4_available
if isinstance(quantizer, Float8BlockQuantizer):
return fp8_block_scaling_available
return fp8_available # Float8CurrentScalingQuantizer


# (factory, kwargs producing a different-but-valid config)
_VALUE_QUANTIZERS = [
pytest.param(_mxfp8, {"dtype": tex.DType.kFloat8E5M2}, id="mxfp8"),
pytest.param(_blockwise, {"force_pow_2_scales": False}, id="float8_blockwise"),
pytest.param(_current_scaling, {"amax_epsilon": 1e-4}, id="float8_current_scaling"),
pytest.param(
_nvfp4,
{"with_rht": False},
id="nvfp4",
marks=pytest.mark.skipif(
not torch.cuda.is_available(),
reason="NVFP4Quantizer requires CUDA to construct",
),
),
]


@pytest.mark.parametrize("factory, other_kwargs", _VALUE_QUANTIZERS)
def test_quantizer_value_object(factory, other_kwargs):
"""Value semantics + ``__fx_repr__`` round-trip via the production FX path."""
a, b = factory(), factory()
# Same config -> equal, same hash, interchangeable as a dict/set key.
assert a is not b
assert a == b
assert hash(a) == hash(b)
assert {a: "x"}[b] == "x"
# Different config -> not equal.
assert a != factory(**other_kwargs)

# ``__fx_repr__`` (used by torch.compile codegen) rebuilds an equal object.
repr_str, globals_ = a.__fx_repr__()
rebuilt = eval(repr_str, dict(globals_)) # pylint: disable=eval-used
assert rebuilt == a and rebuilt is not a
assert hash(rebuilt) == hash(a)

# The rebuilt quantizer must also *behave* identically, not just compare
# equal: equality only looks at the value key, so a field the kernel needs
# but that is absent from the key (e.g. NVFP4's derived ``rht_matrix``) would
# slip through the checks above and only blow up at quantize time. Run the
# real quantize kernel on both and require bit-exact results.
if torch.cuda.is_available() and _hw_available(a):
x = torch.randn(128, 256, dtype=torch.bfloat16, device="cuda")
torch.testing.assert_close(rebuilt(x).dequantize(), a(x).dequantize(), rtol=0.0, atol=0.0)


def test_value_quantizer_rejects_process_group():
"""A value quantizer holding a live ProcessGroup must refuse to be turned
into a value key / FX constant (raise), not silently drop the group."""
import torch.distributed as dist # pylint: disable=import-outside-toplevel

created = not dist.is_initialized()
if created:
dist.init_process_group(backend="gloo", store=dist.HashStore(), rank=0, world_size=1)
try:
q = MXFP8Quantizer(fp8_dtype=tex.DType.kFloat8E4M3)
q.amax_reduction_group = dist.group.WORLD
# Every value-materialization path must reject it (hash, eq, __fx_repr__).
with pytest.raises(TypeError):
hash(q)
with pytest.raises(TypeError):
q.__fx_repr__()
finally:
if created:
dist.destroy_process_group()


if _opaque_available:
# A minimal custom op taking a tensor and a value-opaque quantizer that
# quantizes + dequantizes inside it, one per production quantizer class.
# ``test_quantizer_value_object_fullgraph`` drives this under
# ``torch.compile(fullgraph=True)`` so the quantizer is used *inside* the
# graph -- proving the opaque-type registration took effect (a graph break
# would make ``fullgraph=True`` raise).
_qdq_lib = torch.library.Library("test_te_qdq", "DEF")
_QDQ_OPS = {}
for _qcls in (
MXFP8Quantizer,
Float8BlockQuantizer,
Float8CurrentScalingQuantizer,
NVFP4Quantizer,
):
_op = f"qdq_{_qcls.__name__}"
_qdq_lib.define(f"{_op}(Tensor x, {get_opaque_type_name(_qcls)} q) -> Tensor")

@torch.library.impl(f"test_te_qdq::{_op}", "CompositeExplicitAutograd", lib=_qdq_lib)
def _qdq_impl(x, q):
return q(x).dequantize()

@torch.library.register_fake(f"test_te_qdq::{_op}", lib=_qdq_lib)
def _qdq_fake(x, q):
return torch.empty_like(x)

_QDQ_OPS[_qcls] = getattr(torch.ops.test_te_qdq, _op)


@pytest.mark.skipif(
not _opaque_available,
reason="torch.compile opaque-object support requires PyTorch >= 2.11",
)
@pytest.mark.parametrize("factory, other_kwargs", _VALUE_QUANTIZERS)
def test_quantizer_value_object_fullgraph(factory, other_kwargs):
"""Quantizer is usable *inside* a torch.compile(fullgraph=True) graph.

A custom op quantizes+dequantizes with the (opaque value) quantizer; the
compiled result must match eager. ``fullgraph=True`` raises on any graph
break, so this proves the opaque-type registration actually took effect --
unlike merely passing the quantizer through.
"""
q = factory()
if not (torch.cuda.is_available() and _hw_available(q)):
pytest.skip("format not supported on this HW")

op = _QDQ_OPS[type(q)]
x = torch.randn(128, 256, dtype=torch.bfloat16, device="cuda")

def fn(inp):
return op(inp, q)

ref = fn(x)
torch._dynamo.reset()
out = torch.compile(fn, fullgraph=True)(x)
torch.testing.assert_close(out, ref, rtol=0.0, atol=0.0)
12 changes: 12 additions & 0 deletions transformer_engine/pytorch/dynamo/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,12 @@
# Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.

"""torch.compile glue for Transformer Engine."""

from .quantizer_opaque import register_value_opaque_quantizer, is_value_opaque_quantizer

__all__ = [
"register_value_opaque_quantizer",
"is_value_opaque_quantizer",
]
116 changes: 116 additions & 0 deletions transformer_engine/pytorch/dynamo/quantizer_opaque.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,116 @@
# Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.

"""Value-opaque quantizers for torch.compile."""

from __future__ import annotations
from typing import Any, Dict, Tuple

from ..constants import DType


# Registration marks the class with this attribute rather than recording it in a
# module-level set. It looks odd but is a deliberate workaround: the check must
# stay traceable when it runs inside a torch.compile graph -- Dynamo can bake a
# ``getattr`` on the opaque quantizer into a constant, but cannot evaluate
# ``type(q) in some_set`` (no equality/hash rules for the opaque class object),
# which would graph-break under ``fullgraph=True``.
_VALUE_OPAQUE_FLAG = "_te_compile_value_opaque"


def is_value_opaque_quantizer(quantizer: Any) -> bool:
"""Whether *quantizer*'s class is registered as a torch.compile value-opaque
type."""
return getattr(quantizer, _VALUE_OPAQUE_FLAG, False)


def _rebuild_quantizer(cls: type, items: Tuple[Tuple[str, Any], ...]) -> Any:
"""Rebuild a tensorless quantizer of type *cls* from its value items.

Referenced by the ``__fx_repr__`` emitted for value-opaque quantizers; the
generated FX code calls this to materialize the quantizer constant.
"""
# Bypass ``__init__`` and restore the value attributes directly: the value
# items already capture every value-defining field (including derived ones),
# and the constructors have heterogeneous signatures / side effects.
obj = cls.__new__(cls)
field_names = set()
for name, value in items:
if name == "dtype":
value = DType.cast(value)
object.__setattr__(obj, name, value)
field_names.add(name)
# The deprecated amax-reduction group is not a value field; initialize it to
# None so attribute access keeps working on the rebuilt quantizer.
if "with_amax_reduction" in field_names and not hasattr(obj, "amax_reduction_group"):
object.__setattr__(obj, "amax_reduction_group", None)
# Restore non-value derived state that ``__init__`` would normally build but
# that cannot live in the value key (e.g. NVFP4's ``rht_matrix`` tensor).
finalize = getattr(obj, "_rebuild_derived_state", None)
if finalize is not None:
finalize()
return obj


def _quantizer_fx_repr(self: Any) -> Tuple[str, Dict[str, Any]]:
"""``__fx_repr__`` for value-opaque quantizers (attached at registration).

Returns an evaluable expression that rebuilds the quantizer via
:func:`_rebuild_quantizer`, capturing both the helper and the quantizer
class itself in the FX globals so codegen can resolve them with no global
registry and no qualname collisions.

Raises ``TypeError`` (via :meth:`Quantizer._value_key`) if the quantizer
stores a process group (e.g. a non-``None`` deprecated
``amax_reduction_group``): live distributed state must never be baked into
the graph as a constant. Pass the reduction group per quantize call instead
of storing it on the quantizer.
"""
cls = type(self)
items = self._value_key()[1]
return (
f"_rebuild_quantizer({cls.__name__}, {items!r})",
{"_rebuild_quantizer": _rebuild_quantizer, cls.__name__: cls},
)


def register_value_opaque_quantizer(cls: type) -> None:
"""Register a tensorless quantizer class as a torch.compile value opaque type.

Attaches ``__fx_repr__`` and registers the class with
``torch._library.opaque_object``. Safe to call on any PyTorch build: on
versions without the opaque-object API it only attaches ``__fx_repr__``
(harmless), so Transformer Engine keeps importing and running in eager mode.

The quantizer class must already provide value ``__eq__`` / ``__hash__`` and
a non-``None`` ``_value_fields`` (see
:class:`transformer_engine.pytorch.quantized_tensor.Quantizer`).
"""
# Stamp the class so it can be recognized as value-opaque in dynamo-traced
# code (used to fall back to eager for unregistered quantizers).
setattr(cls, _VALUE_OPAQUE_FLAG, True)

# ``register_opaque_type`` requires ``__fx_repr__`` to already exist on the
# class, so attach it before registering.
if "__fx_repr__" not in cls.__dict__:
cls.__fx_repr__ = _quantizer_fx_repr
Comment thread
pggPL marked this conversation as resolved.

try:
from torch._library.opaque_object import ( # pylint: disable=import-outside-toplevel
register_opaque_type,
is_opaque_value_type,
)
except (ImportError, AttributeError):
# Older PyTorch without the opaque-object API: eager value semantics
# still work; torch.compile specialization on the quantizer does not.
return

try:
if not is_opaque_value_type(cls):
register_opaque_type(cls, typ="value")
except (RuntimeError, TypeError):
# Keep TE importable: neither the opaque-type query nor the registration
# must crash the import, e.g. on PyTorch versions with only partial /
# experimental opaque-object support.
pass
Loading
Loading