[PyTorch][torch.compile] Add TensorProto mechanism#3153
Conversation
…ompile Give tensorless quantizers (MXFP8, FP8 blockwise, FP8 current-scaling, NVFP4) value-object semantics so torch.compile can treat them as baked-in constants: - Add opt-in value identity to the base Quantizer (_value_fields / _value_key / __eq__ / __hash__). Quantizers holding live tensors (delayed-scaling Float8Quantizer) and custom quantizers keep identity semantics. - New transformer_engine/pytorch/dynamo.py houses the torch.compile glue: __fx_repr__, value-key reconstruction and register_value_opaque_quantizer (gracefully a no-op on PyTorch builds without the opaque-object API). - Register the four tensorless quantizers as value opaque types. Also fix CustomRecipe state caching in TransformerEngineBaseModule: set_meta_tensor now rebuilds quantizers when the CustomRecipe instance changes (e.g. nested te.autocast regions) instead of reusing the first recipe's state, since every CustomRecipe shares the CustomRecipeState type but carries its own qfactory. Move the quantizer value-object tests into tests/pytorch/test_torch_compile.py and add that file to the L0 pytorch unittest QA suite. Signed-off-by: Pawel Gadzinski <pgadzinski@nvidia.com>
…globals Follow-up to the value-opaque quantizer support: - Remove the module-level _QUANTIZER_VALUE_REGISTRY (qualname -> class) and _quantizer_from_value_key. __fx_repr__ now captures the quantizer class directly in the FX globals and reconstructs via _rebuild_quantizer(cls, items), matching how PyTorch's own value opaque types (e.g. DTensor placements) reconstruct themselves. This removes global mutable state and the qualname collision risk. - Consolidate the quantizer value-object tests in test_torch_compile.py down to two functions and exercise reconstruction through the public __fx_repr__ path instead of internal helpers. Signed-off-by: Pawel Gadzinski <pgadzinski@nvidia.com>
Replace the single dynamo.py module with a dynamo/ package so the
torch.compile glue can grow with a clear responsibility split across the
stacked branches. This branch owns the value-opaque quantizer layer.
* dynamo/quantizer_opaque.py -- register_value_opaque_quantizer and helpers
* dynamo/__init__.py -- re-exports the public API so callers keep importing
from transformer_engine.pytorch.dynamo unchanged
Signed-off-by: Pawel Gadzinski <pgadzinski@nvidia.com>
A value-opaque quantizer must not carry live distributed state. Scan the quantizer attributes in __fx_repr__ and raise TypeError if any holds a torch.distributed.ProcessGroup (e.g. a non-None deprecated amax_reduction_group), so it cannot be silently baked into a torch.compile FX graph. Clarify the related comments accordingly. Signed-off-by: Pawel Gadzinski <pgadzinski@nvidia.com>
NVFP4Quantizer is registered as a value-opaque quantizer but was missing from the value-semantics / __fx_repr__ round-trip test. Add it to _VALUE_QUANTIZERS (skipped without CUDA, which it needs to construct). Signed-off-by: Pawel Gadzinski <pgadzinski@nvidia.com>
…__/__hash__ The amax reduction group is excluded from the value key, so a value quantizer that stored one would compare/hash equal to a groupless one and let torch.compile reuse a graph that skips the reduction. __eq__/__hash__ now raise (mirroring __fx_repr__, which already rejects any process-group-bearing quantizer). The group should be passed per quantize call, not stored on the quantizer. Co-Authored-By: Claude Opus 4.8 (1M context) <noreply@anthropic.com> Signed-off-by: Pawel Gadzinski <pgadzinski@nvidia.com>
Add is_value_opaque_quantizer() + the _te_compile_value_opaque flag stamped at registration, so dynamo-traced code can detect registered quantizers (and fall back to eager for unregistered ones). Co-Authored-By: Claude Opus 4.8 (1M context) <noreply@anthropic.com> Signed-off-by: Pawel Gadzinski <pgadzinski@nvidia.com>
…fp4 value key - Narrow register_opaque_type except to (RuntimeError, TypeError): the API is already imported above, so ImportError/AttributeError there only mask real errors. - Add test_quantizer_value_object_fullgraph exercising torch.compile(fullgraph=True) end-to-end to verify opaque-type registration took effect. - Restore missing NVFP4Quantizer._with_random_sign_mask assignment required by _value_fields()/_value_key(). Co-Authored-By: Claude Opus 4.8 <noreply@anthropic.com> Signed-off-by: Pawel Gadzinski <pgadzinski@nvidia.com>
Greptile SummaryThis PR introduces
Confidence Score: 4/5Safe to merge after fixing the wrong flag in The transformer_engine/pytorch/tensor/nvfp4_tensor.py — both Important Files Changed
Sequence Diagram%%{init: {'theme': 'neutral'}}%%
sequenceDiagram
participant C as Caller / register_fake
participant TP as TensorProto
participant Q as Quantizer
participant S as QuantizedTensorStorage
participant R as _STORAGE_REGISTRY
Note over C,R: TensorProto.create_tensor() flow
C->>TP: TensorProto(shape, dtype, quantizer)
TP->>Q: copy() [__post_init__]
C->>TP: create_tensor()
TP->>TP: create_metadata()
TP->>Q: create_metadata(shape, dtype)
Q->>Q: _storage_metadata(dtype)
Q-->>TP: "ctx {cls: qualname, nontensor_kwargs, ...}"
TP->>TP: create_inner_tensors()
TP->>Q: alloc_tensors(shape, device)
Q->>Q: _describe_buffers(shape)
Q-->>TP: "{attr: torch.empty(buf_shape, buf_dtype)}"
TP->>TP: inner_names() [reorder to _FLATTEN_TENSOR_BUFFERS order]
TP->>R: _STORAGE_REGISTRY[ctx["cls"]]
R-->>TP: storage_cls
TP->>S: __tensor_unflatten__(inner, ctx, shape, stride)
S->>R: _STORAGE_REGISTRY[ctx["cls"]]
R-->>S: cls
S-->>C: QuantizedTensorStorage instance
Note over C,R: FX graph rebuild via _rebuild_quantizer
C->>C: eval(quantizer.__fx_repr__())
C->>C: _rebuild_quantizer(cls, items)
C->>C: object.__setattr__(obj, field, value) for each field
C->>Q: _rebuild_derived_state() [NVFP4 only]
Q->>Q: get_rht_matrix(_with_random_sign_mask, device)
%%{init: {'theme': 'base', 'themeVariables': {"darkMode": true, "background": "#0d1117", "primaryColor": "#21262d", "primaryTextColor": "#e6edf3", "primaryBorderColor": "#8b949e", "lineColor": "#8b949e", "textColor": "#e6edf3", "edgeLabelBackground": "#161b22", "actorBkg": "#21262d", "actorBorder": "#8b949e", "actorTextColor": "#e6edf3", "actorLineColor": "#8b949e", "signalColor": "#8b949e", "signalTextColor": "#e6edf3", "noteBkgColor": "#373320", "noteBorderColor": "#d4a72c", "noteTextColor": "#f0e6c0", "labelBoxBkgColor": "#21262d", "labelBoxBorderColor": "#8b949e", "labelTextColor": "#e6edf3", "loopTextColor": "#e6edf3", "activationBkgColor": "#30363d", "activationBorderColor": "#8b949e"}}}%%
sequenceDiagram
participant C as Caller / register_fake
participant TP as TensorProto
participant Q as Quantizer
participant S as QuantizedTensorStorage
participant R as _STORAGE_REGISTRY
Note over C,R: TensorProto.create_tensor() flow
C->>TP: TensorProto(shape, dtype, quantizer)
TP->>Q: copy() [__post_init__]
C->>TP: create_tensor()
TP->>TP: create_metadata()
TP->>Q: create_metadata(shape, dtype)
Q->>Q: _storage_metadata(dtype)
Q-->>TP: "ctx {cls: qualname, nontensor_kwargs, ...}"
TP->>TP: create_inner_tensors()
TP->>Q: alloc_tensors(shape, device)
Q->>Q: _describe_buffers(shape)
Q-->>TP: "{attr: torch.empty(buf_shape, buf_dtype)}"
TP->>TP: inner_names() [reorder to _FLATTEN_TENSOR_BUFFERS order]
TP->>R: _STORAGE_REGISTRY[ctx["cls"]]
R-->>TP: storage_cls
TP->>S: __tensor_unflatten__(inner, ctx, shape, stride)
S->>R: _STORAGE_REGISTRY[ctx["cls"]]
R-->>S: cls
S-->>C: QuantizedTensorStorage instance
Note over C,R: FX graph rebuild via _rebuild_quantizer
C->>C: eval(quantizer.__fx_repr__())
C->>C: _rebuild_quantizer(cls, items)
C->>C: object.__setattr__(obj, field, value) for each field
C->>Q: _rebuild_derived_state() [NVFP4 only]
Q->>Q: get_rht_matrix(_with_random_sign_mask, device)
|
…trip _rebuild_quantizer only restores value-key fields, so a reconstructed NVFP4Quantizer was missing the derived rht_matrix tensor (not hashable, so not in the value key) and failed at copy()/quantize time. Add a _rebuild_derived_state hook (called by _rebuild_quantizer) that NVFP4Quantizer uses to rebuild rht_matrix from _with_random_sign_mask (lru_cache -> cheap). Extend test_quantizer_value_object to also quantize with the original and the rebuilt quantizer and require bit-exact results (gated on HW support), so a field the kernel needs but the value key omits can no longer slip through. Co-Authored-By: Claude Opus 4.8 <noreply@anthropic.com> Signed-off-by: Pawel Gadzinski <pgadzinski@nvidia.com>
5131ebc to
77831be
Compare
Move the ProcessGroup guard out of the (overridable) __fx_repr__ into Quantizer._value_key -- the single point every value-materialization path (__eq__/__hash__/__fx_repr__) goes through -- so a custom __fx_repr__ can no longer bypass it. Generalizes the old amax-only check to any field holding a ProcessGroup. Add a test that a value quantizer carrying a live group raises. Addresses review on NVIDIA#3152. Co-Authored-By: Claude Opus 4.8 <noreply@anthropic.com> Signed-off-by: Pawel Gadzinski <pgadzinski@nvidia.com>
77831be to
29e5245
Compare
…assthrough Replace the trivial pass-through fullgraph test with one that drives each production quantizer through a minimal custom op (quantize + dequantize) under torch.compile(fullgraph=True) and compares to eager -- so the opaque-type registration is actually exercised inside the graph (a graph break would make fullgraph=True raise). Op registration sits right before the test. Also drop stale comments referencing the old __fx_repr__-side process-group guard. Co-Authored-By: Claude Opus 4.8 <noreply@anthropic.com> Signed-off-by: Pawel Gadzinski <pgadzinski@nvidia.com>
29e5245 to
99c1377
Compare
…paque flag - rht_matrix_random_sign_mask_t is a device-independent int derived from _with_random_sign_mask (the device only places a throwaway tensor); fix the misleading comment. - Explain why registration uses a class attribute, not a registry set: is_value_opaque_quantizer is traced inside the compile graph and dynamo can bake a getattr constant but cannot do 'type(q) in set' on the opaque class. Co-Authored-By: Claude Opus 4.8 <noreply@anthropic.com> Signed-off-by: Pawel Gadzinski <pgadzinski@nvidia.com>
99c1377 to
afa86ff
Compare
Co-Authored-By: Claude Opus 4.8 <noreply@anthropic.com> Signed-off-by: Pawel Gadzinski <pgadzinski@nvidia.com>
afa86ff to
9e78a6c
Compare
is_opaque_value_type(cls) sat between the import guard and the register_opaque_type guard, so on a partial/experimental opaque-object build it could raise RuntimeError/TypeError and crash TE import. Move it inside the same except so the 'registration never crashes import' promise holds for both calls. Co-Authored-By: Claude Opus 4.8 <noreply@anthropic.com> Signed-off-by: Pawel Gadzinski <pgadzinski@nvidia.com>
Squashed PR #8 (tensor_proto_mechanism) onto the rebased base. Adds TensorProto (pure-Python, torch.compile-traceable quantized-tensor allocation via Quantizer.alloc_tensors + storage __tensor_flatten__/__tensor_unflatten__), Linear fake fwd/bwd impls for the custom-op path, and tests. Co-Authored-By: Claude Opus 4.8 (1M context) <noreply@anthropic.com> Signed-off-by: Pawel Gadzinski <pgadzinski@nvidia.com>
The cached FP8 weight is the same tensor returned as new_weight_workspace (cache miss) or passed in as weight_workspace (cache hit). A custom op may not return a tensor that aliases an input or another return, so mark those slots and reconstruct wt_save in _linear_setup_ctx instead of saving it twice. Mirrored in the fake impl so the saved-slot layout matches. Co-Authored-By: Claude Opus 4.8 (1M context) <noreply@anthropic.com> Signed-off-by: Pawel Gadzinski <pgadzinski@nvidia.com>
NVFP4Quantizer._describe_buffers grouped each amax right after its scale (per-usage), diverging from NVFP4TensorStorage._FLATTEN_TENSOR_BUFFERS (amax buffers last). The order is functionally irrelevant (buffers are consumed by name in alloc_tensors and reordered in TensorProto.inner_names), but aligning it makes describe/flatten agree and fixes test_to_tensor_proto_quantized[nvfp4]. Co-Authored-By: Claude Opus 4.8 (1M context) <noreply@anthropic.com> Signed-off-by: Pawel Gadzinski <pgadzinski@nvidia.com>
…upport - TensorProto.inner_names now raises if the quantizer describes buffer(s) absent from the storage's _FLATTEN_TENSOR_BUFFERS, instead of silently appending them. - Gate the nvfp4 proto-quantizer param on nvfp4_available so it skips on hardware without NVFP4 support rather than failing. Co-Authored-By: Claude Opus 4.8 <noreply@anthropic.com> Signed-off-by: Pawel Gadzinski <pgadzinski@nvidia.com>
9e78a6c to
50c11cd
Compare
…escribe_buffers Access NVFP4Quantizer @staticmethods (convert_shape_for_fp4, get_columnwise_shape) via the class instead of the instance. Under torch.compile, instance access of a @staticmethod on a value-opaque object crashes Dynamo guard generation with "'function' object has no attribute '__func__'" (pytorch/pytorch#182741). Temporary workaround until the PyTorch-side fix lands. Co-Authored-By: Claude Opus 4.8 <noreply@anthropic.com> Signed-off-by: Pawel Gadzinski <pgadzinski@nvidia.com>
| ``_rebuild_quantizer`` calls this hook to rebuild it; the ``lru_cache`` on | ||
| :func:`get_rht_matrix` makes an already-seen (flag, device) a cheap hit. | ||
| """ | ||
| self.rht_matrix = get_rht_matrix(self._with_random_sign_mask, torch.cuda.current_device()) |
There was a problem hiding this comment.
_rebuild_derived_state passes self._with_random_sign_mask to get_rht_matrix, but __init__ passes with_rht. These are distinct parameters: with_rht controls whether the Hadamard transform is applied at all; _with_random_sign_mask controls whether random signs are used within it. When they differ (e.g. with_rht=True, with_random_sign_mask=False) the rebuilt quantizer gets a different matrix than the original. Since get_rht_matrix is LRU-cached by (flag, device), get_rht_matrix(True, d) and get_rht_matrix(False, d) return different objects, so the kernel would receive the wrong transform matrix and silently produce incorrect quantization results.
| self.rht_matrix = get_rht_matrix(self._with_random_sign_mask, torch.cuda.current_device()) | |
| self.rht_matrix = get_rht_matrix(self.with_rht, torch.cuda.current_device()) |
Note
This PR is stacked on top of #3152 ([PyTorch][torch.compile] Make quantizers opaque value objects).
Until #3152 is merged, the diff below also includes its changes. Review/merge #3152 first.
Description
This PR introduces
TensorProto— a data-free prototype of a tensor (or quantized tensor) that captures everything needed to reason about and rebuild a tensor without holding any storage: its logicalshape/dtypeand, for quantized tensors, the value-opaquequantizerdefining the layout.The key property is that
TensorProto.create_tensor()materializes a quantized tensor purely in Python (viaQuantizer.alloc_tensors+ the storage's__tensor_unflatten__), so it traces undertorch.compile(fullgraph=True)with no graph break — unlikemake_empty, which goes through the opaque C++tex.create_empty_quantized_tensor. This is the foundation for writingtorch.librarycustom-op fake implementations of quantized ops.This builds on the value-opaque quantizer work (so a
TensorProtois itself safe to treat as a compile-time constant).Type of change
Changes
dynamo.py: AddTensorProtodataclass (shape,dtype,quantizer,requires_grad,device) withis_quantized,inner_names(),create_metadata()andcreate_tensor(), plus ato_tensor_proto()helper that builds a proto from a plaintorch.Tensoror aQuantizedTensorStorage/QuantizedTensor.quantized_tensor.py:__tensor_flatten__/__tensor_unflatten__) toQuantizedTensorStorage, driven by a per-class_FLATTEN_TENSOR_BUFFERSdeclaration of(attribute_name, constructor_kwarg)pairs._STORAGE_REGISTRY(populated via__init_subclass__) so__tensor_unflatten__can resolve a concrete storage/wrapper class from its qualname inside an FX graph.Quantizer:alloc_tensors,create_metadata, and the opt-in overrides_describe_buffers,_storage_scalars,_resolve_storage_cls.Float8CurrentScalingQuantizer,MXFP8QuantizerandFloat8BlockQuantizer._FLATTEN_TENSOR_BUFFERSforFloat8TensorStorage,MXFP8TensorStorageandFloat8BlockwiseQTensorStorage.ops/basic/basic_linear.py: Add allocation-free_functional_forward_fake/_functional_backward_fakethat operate onTensorProtoand return output/gradient protos, as a basis for custom-op fake impls (single-device only; TP/SP shape effects not yet modeled).tests/pytorch/test_tensor_proto.py(CPU smoke tests for_describe_buffers/alloc_tensors/create_metadata, flatten round-trip, andto_tensor_proto) andtorch.compilefullgraph tests intest_torch_compile.py.Checklist: