[PyTorch][torch.compile] Decouple amax reduction group from the quantizer#3104
Conversation
|
/te-ci pytorch L1 |
Greptile SummaryThis PR decouples the amax-reduction process group from
Confidence Score: 4/5Safe to merge for most configurations, but column-parallel + sequence-parallel backward passes through LayerNormLinear and LayerNormMLP lose cross-rank amax reduction for the input quantizer, which can lead to scale divergence across TP ranks during weight-gradient computation. The refactoring is well-structured and linear.py/basic_linear.py handle both forward and backward correctly. However, layernorm_linear and layernorm_mlp backward paths are missing the set_quantizer_amax_reduction_group call for the input quantizer in the all-gather-for-wgrad block. The new Quantizer.quantize() strip resets the group after forward, so backward runs without it — a regression from the old always-on setup-time wiring that could produce inconsistent scales across ranks in column-parallel sequence-parallel training. transformer_engine/pytorch/module/layernorm_linear.py and transformer_engine/pytorch/module/layernorm_mlp.py — specifically the backward all-gather block for the FC1/GEMM1 input quantizer. Important Files Changed
Sequence Diagram%%{init: {'theme': 'neutral'}}%%
sequenceDiagram
participant FWD as Forward impl
participant Q as Quantizer
participant QT as result._quantizer (same obj)
participant BWD as Backward impl
FWD->>Q: set_quantizer_amax_reduction_group(q, tp_group)
Note right of Q: with_amax_reduction=True
FWD->>Q: q.quantize(x)
Q->>QT: "strip: with_amax_reduction=False"
Note right of Q: with_amax_reduction=False after strip
FWD-->>BWD: autograd saves ctx.input_quantizer
BWD->>Q: set_quantizer_amax_reduction_group(q, tp_group)
Note over Q,BWD: linear.py does this correctly
Note over Q,BWD: layernorm_linear/mlp skip this for input_quantizer
BWD->>Q: q.quantize(ln_out_shard)
Note right of Q: Missing amax reduction in layernorm variants
%%{init: {'theme': 'base', 'themeVariables': {"darkMode": true, "background": "#0d1117", "primaryColor": "#21262d", "primaryTextColor": "#e6edf3", "primaryBorderColor": "#8b949e", "lineColor": "#8b949e", "textColor": "#e6edf3", "edgeLabelBackground": "#161b22", "actorBkg": "#21262d", "actorBorder": "#8b949e", "actorTextColor": "#e6edf3", "actorLineColor": "#8b949e", "signalColor": "#8b949e", "signalTextColor": "#e6edf3", "noteBkgColor": "#373320", "noteBorderColor": "#d4a72c", "noteTextColor": "#f0e6c0", "labelBoxBkgColor": "#21262d", "labelBoxBorderColor": "#8b949e", "labelTextColor": "#e6edf3", "loopTextColor": "#e6edf3", "activationBkgColor": "#30363d", "activationBorderColor": "#8b949e"}}}%%
sequenceDiagram
participant FWD as Forward impl
participant Q as Quantizer
participant QT as result._quantizer (same obj)
participant BWD as Backward impl
FWD->>Q: set_quantizer_amax_reduction_group(q, tp_group)
Note right of Q: with_amax_reduction=True
FWD->>Q: q.quantize(x)
Q->>QT: "strip: with_amax_reduction=False"
Note right of Q: with_amax_reduction=False after strip
FWD-->>BWD: autograd saves ctx.input_quantizer
BWD->>Q: set_quantizer_amax_reduction_group(q, tp_group)
Note over Q,BWD: linear.py does this correctly
Note over Q,BWD: layernorm_linear/mlp skip this for input_quantizer
BWD->>Q: q.quantize(ln_out_shard)
Note right of Q: Missing amax reduction in layernorm variants
Reviews (12): Last reviewed commit: "Merge branch 'main' into remove_process_..." | Re-trigger Greptile |
|
Blocked by FSDP bug, refactor in progress. I plan to store .amax_reduction_group in QuantizedTensor. |
There was a problem hiding this comment.
This would be a design mistake. The amax reduction does not have a consistent meaning across recipes (including recipes where it doesn't make sense), and this change requires spilling out amax reduction logic into quantizer callsites (even where it doesn't make sense).
Can you go into more detail exactly why torch.compile doesn't work when quantizers have process groups? If we just want the quantizer to hold simple Python objects, maybe we can make the quantizer hold an int for the communicator ID. I envision something like:
class Float8CurrentScalingQuantizer(Quantizer):
_communicator_cache = {}
@property
def amax_reduction_group(self):
if self._amax_reduction_group_id is None:
return None
return Float8CurrentScalingQuantizer._communicator_cache[self._amax_reduction_group_id]
@property.setter
def amax_reduction_group(self, comm):
if comm is None:
self._amax_reduction_group_id = None
self._amax_reduction_group_id = id(comm)
Float8CurrentScalingQuantizer._communicator_cache[self._amax_reduction_group_id] = commI'm not sure how this would interact with checkpointing though.
| dst: QuantizedTensor, | ||
| *, | ||
| noop_flag: Optional[torch.Tensor] = None, | ||
| amax_reduction_group: Optional[Any] = None, # pylint: disable=unused-argument |
There was a problem hiding this comment.
I strongly oppose this API change. amax reduction is very recipe-specific. It has different meanings for different recipes (FP8 DS might reduce over the TP+DP group, FP8 CS might only reduce over the TP group) and it has no meaning for other recipes (MXFP8 and FP8 block scaling). Moving it into the generic API will leak recipe-specific information, defeating the point of a generic API.
|
/te-ci pytorch L1 |
…stants; fix SP memory leak; test suite hook-up Wrap CommOverlapCore pybind11 methods that return compile-time constants so torch.compile(fullgraph=True) can trace through them without graph breaks: - `is_fp8_ubuf()` → `ub_is_fp8()` / `get_ub_is_fp8()` in base.py; `_ub_is_fp8()` in gemm.py - `with_cublasmp()` → `ub_is_cublasmp()` in base.py All callers in linear.py, layernorm_linear.py, layernorm_mlp.py, base.py, gemm.py, userbuffers_backward_linear.py and userbuffers_forward_linear.py updated. Fix quantized grad_output not being freed early for column-parallel SP backward. Row-parallel SP already called clear_tensor_data(grad_output) to release the gathered tensor; column-parallel SP quantizes grad_output to Float8TensorStorage but never freed it before returning. Under torch.compile reduce-overhead this leaves 3 live pool tensors at recording end and triggers "Detected 3 tensor(s) in the cudagraph pool not tracked as outputs". Extend the existing clear_tensor_data guard to cover both parallel modes. Fix custom-recipe quantizer state being re-initialised on every forward call even when the recipe object has not changed. The existing early-exit for CustomRecipeState was missing an identity check on the recipe object, so any repeated call with the same recipe would bypass the early-return and rebuild quantizers unnecessarily. Add `if recipe_state.recipe is recipe: return` to restore the intended caching behaviour. Add test_torch_compile.py to L0_pytorch_unittest so the autocast and existing compile tests run in CI. Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com> Signed-off-by: Pawel Gadzinski <pgadzinski@nvidia.com>
…-accumulator booleans LinearBwdArgs stored the entire FP8 recipe object so the backward could extract fp8_gemm_dgrad.use_split_accumulator and fp8_gemm_wgrad.use_split_accumulator at GEMM time. Recipe objects hold process-group references and are not serialisable as compile-time constants, making them incompatible with torch.compile custom-op paths. Replace fp8_recipe with two plain bool fields: - dgrad_use_split_accumulator (default _2X_ACC_DGRAD) - wgrad_use_split_accumulator (default _2X_ACC_WGRAD) These are resolved once in _linear_setup_ctx and passed into the args struct, so the backward consumes scalars instead of a live recipe object. Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com> Signed-off-by: Pawel Gadzinski <pgadzinski@nvidia.com>
for more information, see https://pre-commit.ci
…t_result get_ub_is_fp8 bakes is_fp8_ubuf() as a compile-time constant; without a reset, destroy_ub + re-init with different FP8 settings would read stale values until recompile. Only affects in-memory caches, not disk. Signed-off-by: Pawel Gadzinski <pgadzinski@nvidia.com>
e9097d6 to
948cd6d
Compare
ToyLinear now overrides get_quantizer_roles so CustomRecipeState doesn't hit the no-roles warning, which graph-breaks under fullgraph=True. qfactory dispatches on role.tensor_type instead of a pre-baked string key. Signed-off-by: Pawel Gadzinski <pgadzinski@nvidia.com>
The amax reduction process group is no longer stored persistently on a module quantizer or on a tensor's quantizer. No C++ changes. - TP sequence parallel: the group is set on the input/grad-output quantizer at point of use in the fwd/bwd impls (linear, layernorm_linear, layernorm_mlp, ops basic_linear), replacing the setup-time _customize_quantizers wiring. - FSDP2: the group is stored on Float8Tensor/NVFP4Tensor (set in fsdp_pre_all_gather) and applied to a throwaway quantizer copy during the in-place re-quant (update_quantized / _set_data). - quantize() strips the group off the output tensor's quantizer so it never persists on any tensor's quantizer (breaks flatten/pickle otherwise). Co-Authored-By: Claude Opus 4.8 (1M context) <noreply@anthropic.com> Signed-off-by: Pawel Gadzinski <pgadzinski@nvidia.com>
b8c1bec to
6c9b986
Compare
|
/te-ci pytorch L1 |
set_quantizer_amax_reduction_group was a no-op on a DebugQuantizer (it lacks with_amax_reduction), so with nvinspect enabled the parent quantizer never got the SP amax reduction group, breaking fp8 current scaling column-parallel sequence-parallel numerics (debug test_numerics). Co-Authored-By: Claude Opus 4.8 (1M context) <noreply@anthropic.com> Signed-off-by: Pawel Gadzinski <pgadzinski@nvidia.com>
|
@timmoon10 I have changed the concept of this PR. Note that it also consist the changed from #3130 |
|
/te-ci pytorch L1 |
|
/te-ci pytorch L0 L1 |
|
Failures are unrelated |
I have changed the approach, which does not change C++ code.
Description
There are 2 problems with amax_reduction_group on Quantizers:
Also, current design is prone to bugs. Amax reduction group is set in module forward(). So if we have forward for tp_group1, then forward for tp_group2, this second forwad overrides the amax_reduction_group of quantizers which are used in both backwards. So I think it is better design to set amax reduction group in forward and backward directly.
We may need to slightly refactor the quantizer to mitigate this kinds of bugs, but for now the change in this PR will be sufficient for torch.compile support.
Fixes # (N/A)
Type of change
Changes
quantizer at point of use in the fwd/bwd impls (
linear,layernorm_linear,layernorm_mlp,ops/basic_linear) via a newset_quantizer_amax_reduction_grouphelper, instead of once at module setup.Removed the group wiring from
_customize_quantizers_float8_current_scalingand dropped
_customize_quantizers_nvfp4entirely.Float8Tensor/NVFP4Tensor(
amax_reduction_groupattribute, set infsdp_pre_all_gather) and apply itto a throwaway quantizer copy during the in-place re-quant
(
update_quantized/_set_data) — the weight's own quantizer is never mutated.Quantizer.quantize()strips the group off the output tensor's quantizer, soit never persists on any tensor's
_quantizer.Checklist: