[PyTorch][torch.compile] Decouple amax reduction group from the quantizer#3104

Merged
pggPL merged 9 commits into NVIDIA:main from pggPL:remove_process_group_from_quantizers
Jun 29, 2026

Conversation

@pggPL

@pggPL pggPL commented Jun 8, 2026

Collaborator

Description

There are two problems with keeping amax_reduction_group on quantizers:

  • I want to declare quantizers as opaque value objects, which behave like a kind of Python constant. The ProcessGroup and the tensor stored inside the quantizer get in the way.
  • PyTorch's protocol for handling custom tensor classes such as Float8Tensor in torch.compile, __tensor_flatten__() and __tensor_unflatten__(), assumes that all internal tensors and opaque reference objects such as process groups are direct attributes of the tensor. Currently they are attributes of an attribute (the quantizer). This PR changes that: the amax reduction group is also stored on a QuantizedTensor when applicable.

Also, the current design is prone to bugs. The amax reduction group is set in the module's forward(). If forward runs for tp_group1 and then for tp_group2, the second forward overrides the amax_reduction_group of quantizers that are used in both backward passes. I think it is a better design to set the amax reduction group directly in forward and backward.
We may need to slightly refactor the quantizer to mitigate this kind of bug, but for now the change in this PR is sufficient for torch.compile support.
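
The override hazard described above can be reproduced with a toy model. This is a minimal sketch, not Transformer Engine code: `ToyQuantizer`, `forward_old`, `forward_new`, and `backward` are hypothetical stand-ins, and process groups are represented by plain strings.

```python
class ToyQuantizer:
    """Hypothetical stand-in for a quantizer that carries a reduction group."""

    def __init__(self):
        self.amax_reduction_group = None


def forward_old(quantizer, tp_group):
    # Old design: forward() writes the group onto the shared quantizer
    # and saves only the quantizer for backward.
    quantizer.amax_reduction_group = tp_group
    return {"quantizer": quantizer}


def forward_new(quantizer, tp_group):
    # Point-of-use design: save the group next to the quantizer so that
    # backward can set it again right before quantizing.
    return {"quantizer": quantizer, "tp_group": tp_group}


def backward(ctx):
    # Returns the group that the backward quantization would reduce over.
    quantizer = ctx["quantizer"]
    if "tp_group" in ctx:
        quantizer.amax_reduction_group = ctx["tp_group"]
    return quantizer.amax_reduction_group


# Two forwards share one quantizer, then both backwards run.
shared = ToyQuantizer()
old_ctx1 = forward_old(shared, "tp_group1")
old_ctx2 = forward_old(shared, "tp_group2")
old_groups = (backward(old_ctx1), backward(old_ctx2))

shared = ToyQuantizer()
new_ctx1 = forward_new(shared, "tp_group1")
new_ctx2 = forward_new(shared, "tp_group2")
new_groups = (backward(new_ctx1), backward(new_ctx2))
```

In the old flow both backward passes see the group from the second forward; in the point-of-use flow each backward sees its own group.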

Fixes # (N/A)

Type of change

  • Documentation change (change only to the documentation, either a fix or a new content)
  • Bug fix (non-breaking change which fixes an issue)
  • New feature (non-breaking change which adds functionality)
  • Breaking change (fix or feature that would cause existing functionality to not work as expected)
  • Infra/Build change
  • Code refactoring

Changes

  • TP sequence parallel: set the amax reduction group on the input / grad-output
    quantizer at point of use in the fwd/bwd impls (linear, layernorm_linear,
    layernorm_mlp, ops/basic_linear) via a new
    set_quantizer_amax_reduction_group helper, instead of once at module setup.
    Removed the group wiring from _customize_quantizers_float8_current_scaling
    and dropped _customize_quantizers_nvfp4 entirely.
  • FSDP2: store the group on Float8Tensor / NVFP4Tensor
    (amax_reduction_group attribute, set in fsdp_pre_all_gather) and apply it
    to a throwaway quantizer copy during the in-place re-quant
    (update_quantized / _set_data) — the weight's own quantizer is never mutated.
  • Quantizer.quantize() strips the group off the output tensor's quantizer, so
    it never persists on any tensor's _quantizer.
  • No C++ changes.
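
The point-of-use pattern in the list above can be sketched as follows. This is an illustration under stated assumptions rather than the PR's implementation: `ToyQuantizer`, `ToyDebugQuantizer`, its `parent_quantizer` attribute, and `set_amax_reduction_group` are hypothetical, and the one-shot clear inside `quantize()` assumes the output tensor's quantizer is the same object as the calling quantizer.

```python
from typing import Any, Optional


class ToyQuantizer:
    """Hypothetical quantizer that supports amax reduction."""

    def __init__(self):
        self.with_amax_reduction = False
        self.amax_reduction_group = None
        self.groups_used = []  # group seen by each quantize() call

    def quantize(self, x):
        self.groups_used.append(
            self.amax_reduction_group if self.with_amax_reduction else None
        )
        # One-shot clear: the group must not persist after the call.
        self.with_amax_reduction = False
        self.amax_reduction_group = None
        return x


class ToyDebugQuantizer:
    """Hypothetical wrapper that delegates to a parent quantizer."""

    def __init__(self, parent):
        self.parent_quantizer = parent


def set_amax_reduction_group(quantizer: Optional[Any], group: Optional[Any]) -> None:
    """Set the group right before use; tolerate None and unsupported types."""
    if quantizer is None:
        return
    # Unwrap a debug wrapper so the parent quantizer receives the group.
    target = getattr(quantizer, "parent_quantizer", quantizer)
    if not hasattr(target, "with_amax_reduction"):
        return  # this quantizer type has no amax reduction
    target.with_amax_reduction = group is not None
    target.amax_reduction_group = group


q = ToyQuantizer()
set_amax_reduction_group(q, "tp_group")
q.quantize([1.0])  # forward: reduces over tp_group
q.quantize([1.0])  # backward without re-setting: no group
set_amax_reduction_group(ToyDebugQuantizer(q), "tp_group")
q.quantize([1.0])  # group set through the wrapper
set_amax_reduction_group(None, "tp_group")      # no-op
set_amax_reduction_group(object(), "tp_group")  # no-op
```

The second call shows why every quantization site that needs the reduction has to set the group again: after the one-shot clear, a quantizer reused in backward carries no group.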

Checklist:

  • I have read and followed the contributing guidelines
  • The functionality is complete
  • I have commented my code, particularly in hard-to-understand areas
  • I have made corresponding changes to the documentation
  • My changes generate no new warnings
  • I have added tests that prove my fix is effective or that my feature works
  • New and existing unit tests pass locally with my changes

@pggPL

pggPL commented Jun 8, 2026

Collaborator Author

/te-ci pytorch L1

@greptile-apps

greptile-apps Bot commented Jun 8, 2026

Contributor

Greptile Summary

This PR decouples the amax-reduction process group from Quantizer objects by moving the group wiring from module-setup time to point-of-use in each forward/backward implementation, and by storing the FSDP2 mesh group directly on Float8Tensor/NVFP4Tensor rather than on their quantizers. Quantizer.quantize() now strips with_amax_reduction from the result's quantizer after each call, enforcing a one-shot clear so groups never leak across training steps.

  • _common.py introduces a set_quantizer_amax_reduction_group helper; linear.py and basic_linear.py call it in both forward and backward at the right sites; FSDP2 flows in float8_tensor.py and nvfp4_tensor.py use a throwaway quantizer copy so the canonical weight quantizer is never mutated.
  • layernorm_linear.py and layernorm_mlp.py call the helper for their input quantizer in the forward pass but omit the symmetric backward call for ctx.input_quantizer / ctx.fc1_input_quantizer used in the all-gather-for-wgrad block; because Quantizer.quantize() now resets with_amax_reduction after forward, these backward quantizations lose the cross-rank amax reduction that the old setup-time code provided.
  • All three _customize_quantizers_nvfp4 per-module methods and the parallel-group wiring inside _customize_quantizers_float8_current_scaling are correctly removed.

Confidence Score: 4/5

Safe to merge for most configurations, but column-parallel + sequence-parallel backward passes through LayerNormLinear and LayerNormMLP lose cross-rank amax reduction for the input quantizer, which can lead to scale divergence across TP ranks during weight-gradient computation.

The refactoring is well-structured and linear.py/basic_linear.py handle both forward and backward correctly. However, layernorm_linear and layernorm_mlp backward paths are missing the set_quantizer_amax_reduction_group call for the input quantizer in the all-gather-for-wgrad block. The new Quantizer.quantize() strip resets the group after forward, so backward runs without it — a regression from the old always-on setup-time wiring that could produce inconsistent scales across ranks in column-parallel sequence-parallel training.

transformer_engine/pytorch/module/layernorm_linear.py and transformer_engine/pytorch/module/layernorm_mlp.py — specifically the backward all-gather block for the FC1/GEMM1 input quantizer.

Important Files Changed

| Filename | Overview |
| --- | --- |
| transformer_engine/pytorch/module/_common.py | Adds set_quantizer_amax_reduction_group helper that unwraps DebugQuantizer and sets with_amax_reduction / amax_reduction_group at point-of-use; correctly handles None quantizer and non-supporting targets. |
| transformer_engine/pytorch/module/linear.py | Moves amax-reduction group wiring from setup-time _customize_quantizers_* methods to forward/backward call sites; backward correctly sets both input_quantizer (column-parallel) and grad_output_quantizer (row-parallel). |
| transformer_engine/pytorch/module/layernorm_linear.py | Forward correctly calls set_quantizer_amax_reduction_group for input quantizer; backward sets it for grad_output but NOT for ctx.input_quantizer used in the ln_out_needs_gather all-gather block, so the column-parallel backward amax reduction that the old setup-time code provided is missing. |
| transformer_engine/pytorch/module/layernorm_mlp.py | Same omission as layernorm_linear: backward sets group for fc2_grad_output_quantizer but not for fc1_input_quantizer in the tensor-parallel all-gather block, breaking column-parallel + sequence-parallel amax reduction in backward. |
| transformer_engine/pytorch/ops/basic/basic_linear.py | Correctly removes setup-time group wiring and adds set_quantizer_amax_reduction_group at each relevant forward/backward all-gather site; both input_quantizer and grad_output_quantizer paths are covered. |
| transformer_engine/pytorch/quantized_tensor.py | Adds post-quantize() stripping of with_amax_reduction from the result tensor's quantizer; since result._quantizer is self, this implements a one-shot clear of the calling quantizer after each quantize call. |
| transformer_engine/pytorch/tensor/float8_tensor.py | FSDP2 path moves amax_reduction_group from quantizer to tensor; update_quantized and _set_data both read the group from the tensor and use a throwaway quantizer copy, keeping the base quantizer immutable. |
| transformer_engine/pytorch/tensor/nvfp4_tensor.py | Mirrors Float8Tensor FSDP2 pattern: amax_reduction_group class attribute added, set in fsdp_pre_all_gather, and consumed via throwaway quantizer copy in NVFP4Quantizer.update_quantized. |
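
The FSDP2 pattern summarized for float8_tensor.py and nvfp4_tensor.py can be illustrated with a small sketch. It assumes a shallow `copy.copy` is enough to make the throwaway quantizer; `ToyQuantizer`, `ToyQuantizedTensor`, and `set_data` are hypothetical names, and the real classes may copy their quantizers differently.

```python
import copy


class ToyQuantizer:
    """Hypothetical quantizer; reports which group a re-quantization uses."""

    def __init__(self):
        self.with_amax_reduction = False
        self.amax_reduction_group = None

    def update_quantized(self, src, dst):
        dst.data = list(src)
        return self.amax_reduction_group if self.with_amax_reduction else None


class ToyQuantizedTensor:
    """Hypothetical tensor that owns the reduction group itself."""

    def __init__(self, quantizer):
        self._quantizer = quantizer
        self.amax_reduction_group = None  # stored on the tensor
        self.data = []

    def set_data(self, src):
        quantizer = self._quantizer
        if self.amax_reduction_group is not None:
            # Throwaway copy: the tensor's own quantizer is never mutated.
            quantizer = copy.copy(self._quantizer)
            quantizer.with_amax_reduction = True
            quantizer.amax_reduction_group = self.amax_reduction_group
        return quantizer.update_quantized(src, self)


weight = ToyQuantizedTensor(ToyQuantizer())
weight.amax_reduction_group = "fsdp_mesh_group"  # e.g. set before the all-gather
group_used = weight.set_data([1.0, 2.0])
```

The re-quantization sees the group, while the quantizer attached to the weight stays free of process-group state.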

Sequence Diagram

%%{init: {'theme': 'neutral'}}%%
sequenceDiagram
    participant FWD as Forward impl
    participant Q as Quantizer
    participant QT as result._quantizer (same obj)
    participant BWD as Backward impl

    FWD->>Q: set_quantizer_amax_reduction_group(q, tp_group)
    Note right of Q: with_amax_reduction=True
    FWD->>Q: q.quantize(x)
    Q->>QT: "strip: with_amax_reduction=False"
    Note right of Q: with_amax_reduction=False after strip
    FWD-->>BWD: autograd saves ctx.input_quantizer
    BWD->>Q: set_quantizer_amax_reduction_group(q, tp_group)
    Note over Q,BWD: linear.py does this correctly
    Note over Q,BWD: layernorm_linear/mlp skip this for input_quantizer
    BWD->>Q: q.quantize(ln_out_shard)
    Note right of Q: Missing amax reduction in layernorm variants

Reviews (12): Last reviewed commit: "Merge branch 'main' into remove_process_..."

Comment thread transformer_engine/pytorch/quantized_tensor.py Outdated
Comment thread transformer_engine/pytorch/tensor/nvfp4_tensor.py Outdated
@pggPL

pggPL commented Jun 8, 2026

Collaborator Author

Blocked by FSDP bug, refactor in progress.

I plan to store .amax_reduction_group in QuantizedTensor.

@timmoon10 timmoon10 left a comment

Member

This would be a design mistake. The amax reduction does not have a consistent meaning across recipes (including recipes where it doesn't make sense), and this change requires spilling out amax reduction logic into quantizer callsites (even where it doesn't make sense).

Can you go into more detail exactly why torch.compile doesn't work when quantizers have process groups? If we just want the quantizer to hold simple Python objects, maybe we can make the quantizer hold an int for the communicator ID. I envision something like:

class Float8CurrentScalingQuantizer(Quantizer):

    # Maps id(process group) -> process group, so instances hold only an int.
    _communicator_cache = {}
    _amax_reduction_group_id = None

    @property
    def amax_reduction_group(self):
        if self._amax_reduction_group_id is None:
            return None
        return Float8CurrentScalingQuantizer._communicator_cache[self._amax_reduction_group_id]

    @amax_reduction_group.setter
    def amax_reduction_group(self, comm):
        if comm is None:
            self._amax_reduction_group_id = None
            return
        self._amax_reduction_group_id = id(comm)
        Float8CurrentScalingQuantizer._communicator_cache[self._amax_reduction_group_id] = comm

I'm not sure how this would interact with checkpointing though.

dst: QuantizedTensor,
*,
noop_flag: Optional[torch.Tensor] = None,
amax_reduction_group: Optional[Any] = None, # pylint: disable=unused-argument

Member

I strongly oppose this API change. amax reduction is very recipe-specific. It has different meanings for different recipes (FP8 DS might reduce over the TP+DP group, FP8 CS might only reduce over the TP group) and it has no meaning for other recipes (MXFP8 and FP8 block scaling). Moving it into the generic API will leak recipe-specific information, defeating the point of a generic API.

@pggPL

pggPL commented Jun 10, 2026

Collaborator Author

/te-ci pytorch L1

pggPL and others added 4 commits June 15, 2026 16:40
…stants; fix SP memory leak; test suite hook-up

Wrap CommOverlapCore pybind11 methods that return compile-time constants
so torch.compile(fullgraph=True) can trace through them without graph
breaks:
- `is_fp8_ubuf()` → `ub_is_fp8()` / `get_ub_is_fp8()` in base.py;
  `_ub_is_fp8()` in gemm.py
- `with_cublasmp()` → `ub_is_cublasmp()` in base.py

All callers in linear.py, layernorm_linear.py, layernorm_mlp.py,
base.py, gemm.py, userbuffers_backward_linear.py and
userbuffers_forward_linear.py updated.

Fix quantized grad_output not being freed early for column-parallel SP
backward. Row-parallel SP already called clear_tensor_data(grad_output)
to release the gathered tensor; column-parallel SP quantizes grad_output
to Float8TensorStorage but never freed it before returning.  Under
torch.compile reduce-overhead this leaves 3 live pool tensors at
recording end and triggers "Detected 3 tensor(s) in the cudagraph pool
not tracked as outputs".  Extend the existing clear_tensor_data guard to
cover both parallel modes.

Fix custom-recipe quantizer state being re-initialised on every forward
call even when the recipe object has not changed. The existing early-exit
for CustomRecipeState was missing an identity check on the recipe object,
so any repeated call with the same recipe would bypass the early-return
and rebuild quantizers unnecessarily.  Add `if recipe_state.recipe is
recipe: return` to restore the intended caching behaviour.

Add test_torch_compile.py to L0_pytorch_unittest so the autocast and
existing compile tests run in CI.

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
Signed-off-by: Pawel Gadzinski <pgadzinski@nvidia.com>
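
The custom-recipe fix in the commit message above comes down to an identity check before rebuilding state. A minimal sketch of that idea, with hypothetical `ToyRecipeState` and `ToyModule` classes standing in for the real ones:

```python
class ToyRecipeState:
    """Hypothetical holder for quantizers built from a recipe."""

    def __init__(self, recipe):
        self.recipe = recipe


class ToyModule:
    """Hypothetical module that rebuilds state only when the recipe changes."""

    def __init__(self):
        self.recipe_state = None
        self.build_count = 0

    def init_recipe_state(self, recipe):
        state = self.recipe_state
        # Early exit: same recipe object as last time, keep the existing state.
        if state is not None and state.recipe is recipe:
            return
        self.recipe_state = ToyRecipeState(recipe)
        self.build_count += 1


recipe_a, recipe_b = object(), object()
module = ToyModule()
module.init_recipe_state(recipe_a)
module.init_recipe_state(recipe_a)  # cached, no rebuild
module.init_recipe_state(recipe_b)  # different object, rebuild
```

The check uses `is` rather than `==`, so two distinct recipe objects always trigger a rebuild even if they compare equal.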
…-accumulator booleans

LinearBwdArgs stored the entire FP8 recipe object so the backward could
extract fp8_gemm_dgrad.use_split_accumulator and
fp8_gemm_wgrad.use_split_accumulator at GEMM time.  Recipe objects hold
process-group references and are not serialisable as compile-time
constants, making them incompatible with torch.compile custom-op paths.

Replace fp8_recipe with two plain bool fields:
- dgrad_use_split_accumulator (default _2X_ACC_DGRAD)
- wgrad_use_split_accumulator (default _2X_ACC_WGRAD)

These are resolved once in _linear_setup_ctx and passed into the args
struct, so the backward consumes scalars instead of a live recipe object.

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
Signed-off-by: Pawel Gadzinski <pgadzinski@nvidia.com>
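
Replacing a stored recipe object with two plain booleans, as the commit message above describes, might look like the sketch below. The field names follow the commit message; `ToyLinearBwdArgs`, `ToyGemmConfig`, `ToyRecipe`, and `make_bwd_args` are hypothetical, and the default values are placeholders rather than the library's actual constants.

```python
from dataclasses import dataclass

# Placeholder defaults for this sketch only.
_2X_ACC_DGRAD = True
_2X_ACC_WGRAD = True


@dataclass(frozen=True)
class ToyLinearBwdArgs:
    """Backward arguments that carry scalars instead of a live recipe."""

    dgrad_use_split_accumulator: bool = _2X_ACC_DGRAD
    wgrad_use_split_accumulator: bool = _2X_ACC_WGRAD


@dataclass
class ToyGemmConfig:
    use_split_accumulator: bool


@dataclass
class ToyRecipe:
    fp8_gemm_dgrad: ToyGemmConfig
    fp8_gemm_wgrad: ToyGemmConfig


def make_bwd_args(recipe=None):
    # Resolve the flags once at setup; backward never touches the recipe.
    if recipe is None:
        return ToyLinearBwdArgs()
    return ToyLinearBwdArgs(
        dgrad_use_split_accumulator=recipe.fp8_gemm_dgrad.use_split_accumulator,
        wgrad_use_split_accumulator=recipe.fp8_gemm_wgrad.use_split_accumulator,
    )


recipe = ToyRecipe(ToyGemmConfig(False), ToyGemmConfig(True))
args = make_bwd_args(recipe)
default_args = make_bwd_args()
```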
…t_result

get_ub_is_fp8 bakes is_fp8_ubuf() as a compile-time constant; without a
reset, destroy_ub + re-init with different FP8 settings would read stale
values until recompile. Only affects in-memory caches, not disk.

Signed-off-by: Pawel Gadzinski <pgadzinski@nvidia.com>
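
The stale-value problem in the commit message above has a plain-Python analogue. The sketch below uses `functools.lru_cache` as a stand-in for a value baked in as a constant; it is not the mechanism the PR uses under torch.compile, and `_buffer`, `cached_is_fp8`, and `reinit_buffer` are hypothetical names.

```python
import functools

# Hypothetical stand-in for the live buffer whose FP8 setting can change.
_buffer = {"is_fp8": False}


@functools.lru_cache(maxsize=None)
def cached_is_fp8():
    # The first call fixes the result until the cache is cleared.
    return _buffer["is_fp8"]


def reinit_buffer(is_fp8, reset_cache):
    # Destroy and re-create the buffer with a new setting.
    _buffer["is_fp8"] = is_fp8
    if reset_cache:
        cached_is_fp8.cache_clear()


first = cached_is_fp8()
reinit_buffer(True, reset_cache=False)
stale = cached_is_fp8()  # cache was not reset, old value is returned
reinit_buffer(True, reset_cache=True)
fresh = cached_is_fp8()  # cache was reset, new value is read
```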
@pggPL pggPL force-pushed the remove_process_group_from_quantizers branch from e9097d6 to 948cd6d Compare June 16, 2026 12:23
pggPL and others added 2 commits June 16, 2026 16:32
ToyLinear now overrides get_quantizer_roles so CustomRecipeState doesn't hit
the no-roles warning, which graph-breaks under fullgraph=True. qfactory
dispatches on role.tensor_type instead of a pre-baked string key.

Signed-off-by: Pawel Gadzinski <pgadzinski@nvidia.com>
The amax reduction process group is no longer stored persistently on a module
quantizer or on a tensor's quantizer. No C++ changes.

- TP sequence parallel: the group is set on the input/grad-output quantizer at
  point of use in the fwd/bwd impls (linear, layernorm_linear, layernorm_mlp,
  ops basic_linear), replacing the setup-time _customize_quantizers wiring.
- FSDP2: the group is stored on Float8Tensor/NVFP4Tensor (set in
  fsdp_pre_all_gather) and applied to a throwaway quantizer copy during the
  in-place re-quant (update_quantized / _set_data).
- quantize() strips the group off the output tensor's quantizer so it never
  persists on any tensor's quantizer (breaks flatten/pickle otherwise).

Co-Authored-By: Claude Opus 4.8 (1M context) <noreply@anthropic.com>
Signed-off-by: Pawel Gadzinski <pgadzinski@nvidia.com>
@pggPL pggPL force-pushed the remove_process_group_from_quantizers branch from b8c1bec to 6c9b986 Compare June 16, 2026 14:56
@pggPL

pggPL commented Jun 17, 2026

Collaborator Author

/te-ci pytorch L1

@pggPL pggPL changed the title [PyTorch][torch.compile] Remove process group from quantizers [PyTorch][torch.compile] Decouple amax reduction group from the quantizer Jun 17, 2026
set_quantizer_amax_reduction_group was a no-op on a DebugQuantizer (it
lacks with_amax_reduction), so with nvinspect enabled the parent
quantizer never got the SP amax reduction group, breaking fp8 current
scaling column-parallel sequence-parallel numerics (debug test_numerics).

Co-Authored-By: Claude Opus 4.8 (1M context) <noreply@anthropic.com>
Signed-off-by: Pawel Gadzinski <pgadzinski@nvidia.com>
@pggPL

pggPL commented Jun 17, 2026

Collaborator Author

@timmoon10 I have changed the concept of this PR. Note that it also includes the changes from #3130.

@pggPL

pggPL commented Jun 17, 2026

Collaborator Author

/te-ci pytorch L1

@ksivaman

Member

/te-ci pytorch L0 L1

@ksivaman ksivaman left a comment

Member

LGTM

@pggPL

pggPL commented Jun 29, 2026

Collaborator Author

Failures are unrelated

@pggPL pggPL closed this Jun 29, 2026
@pggPL pggPL reopened this Jun 29, 2026
@pggPL pggPL dismissed timmoon10’s stale review June 29, 2026 09:16

I have changed the approach, which does not change C++ code.

@pggPL pggPL merged commit a076917 into NVIDIA:main Jun 29, 2026
16 of 24 checks passed
3 participants