Skip to content

[PyTorch] Expert Parallelism: PyTorch wrapper + autograd ops with symm-mem zero-copy#3035

Open
phu0ngng wants to merge 43 commits into
NVIDIA:mainfrom
phu0ngng:phuong/ep-3-pytorch-on-commwindow
Open

[PyTorch] Expert Parallelism: PyTorch wrapper + autograd ops with symm-mem zero-copy#3035
phu0ngng wants to merge 43 commits into
NVIDIA:mainfrom
phu0ngng:phuong/ep-3-pytorch-on-commwindow

Conversation

@phu0ngng

@phu0ngng phu0ngng commented May 22, 2026

Copy link
Copy Markdown
Collaborator

Summary

Second PR in the TE Expert Parallelism (EP) series. Adds the PyTorch binding on top of the common C API (#3034): exposes EP dispatch/combine as torch.library custom ops with autograd, and plumbs NCCL symmetric-memory windows through for the zero-copy path.

Payload tensors allocated via te.pytorch.ep.symm_mem_alloc take the one-sided zero-copy path when ep_bootstrap(zero_copy=True); anything else falls back to staged-copy, so the API stays drop-in compatible with any allocator.

Implementation

Public Python API (transformer_engine/pytorch/ep.py)

    EpBuffer, ep_bootstrap, ep_finalize,                                                                                                                                                                                                                                                        ep_dispatch, ep_combine,
    symm_mem_alloc,                                                                                                                                                                                                                                                                         )
  • ep_bootstrap / ep_finalize - one-time per-process init/teardown. Borrows the NCCL comm from ep_group via ProcessGroupNCCL._comm_ptr() (no separate ncclUniqueId bootstrap). ep_finalize is optional - an atexit handler covers normal shutdown; call it explicitly before dist.destroy_process_group(). Requires ep_group.size() >= 2.
  • symm_mem_alloc(shape, dtype, ep_group) - per-rank tensor backed by NCCL symmetric memory, rendezvoused on ep_group.
  • EpBuffer - per-layer state: routing handle + persistent payload slots (recv_tokens, combine_in, grad buffers). One per concurrently-in-flight call (e.g. PP-1F1B microbatch). Symm-mem-backed when zero_copy=True.
  • ep_dispatch / ep_combine - autograd-aware per-step ops, registered as torch.library.custom_op with correct mutates_args, so they compose with torch.compile fullgraph and CUDA graphs.
    Current payload dtype is restricted to bfloat16; FP8 quantize/dequantize stays outside the EP boundary.

C++ bindings (transformer_engine/pytorch/csrc/extensions/ep.cpp)

  • POD-only pybind boundary (primitives + pybind11::object for dtype) - no c10d ABI on the boundary. - maybe_make_window() resolves each payload tensor to an NVTECommWindow via c10d::symmetric_memory::rendezvous; non-symm-mem tensors return kNoWindow and the backend picks staged-copy automatically.
  • Zero-copy toggle captured at ep_initialize and forwarded into NVTEEpGroupConfig.zero_copy.

Build

build_tools/pytorch.py propagates -DNVTE_WITH_NCCL_EP (gated on NVTE_BUILD_WITH_NCCL_EP=1, default on) and -DUSE_NCCL so PyTorch's symm-mem feature macros are visible. When NCCL EP is off, ep.cpp no-ops behind the #ifdef.

Testing

  • tests/pytorch/distributed/run_ep.py - 8-test suite: prepare correctness, raw dispatch/combine identity round-trip, dispatch fwd+bwd VJP, full fwd+bwd round-trip, multi-iter bit-stability, CUDA graph capture, PP-1F1B 3-buffer interleave, int64 topk_idx validation. Launcher run_test_ep.sh auto-detects GPUs (skips with <4). Pytest driver: tests/pytorch/distributed/test_ep.py.
  • Example: examples/pytorch/ep/ep_moe.py - minimal end-to-end MoE fwd+bwd driver with --check against an analytical reference.
  • Bench: examples/pytorch/ep/bench/ep_bench.py - times raw + autograd dispatch/combine, optional --cuda-graph capture and --kineto/--nsys profiling.

Type of change

  • Documentation change (change only to the documentation, either a fix or a new content)
  • Bug fix (non-breaking change which fixes an issue)
  • New feature (non-breaking change which adds functionality)
  • Breaking change (fix or feature that would cause existing functionality to not work as expected)
  • Infra/Build change
  • Code refactoring

Checklist:

  • I have read and followed the contributing guidelines
  • The functionality is complete
  • I have commented my code, particularly in hard-to-understand areas
  • I have made corresponding changes to the documentation
  • My changes generate no new warnings
  • I have added tests that prove my fix is effective or that my feature works
  • New and existing unit tests pass locally with my changes

@phu0ngng phu0ngng requested review from ksivaman and ptrendx as code owners May 22, 2026 02:54
@greptile-apps

greptile-apps Bot commented May 22, 2026

Copy link
Copy Markdown
Contributor

Greptile Summary

This PR adds the PyTorch binding layer for NCCL Expert Parallelism (EP): a public ep.py API (ep_bootstrap, ep_finalize, EpBuffer, ep_dispatch, ep_combine, symm_mem_alloc), C++ pybind glue in ep.cpp, build-system integration, an 8-test distributed suite, an end-to-end MoE example, and a perf benchmark. Several previously flagged gaps (missing contiguity guards, token-count mismatch checks, topk_weights dtype validation, wrong symm_mem.rendezvous argument type) are resolved in the current code; the C++ backward paths now use maybe_make_window instead of check_symm_mem_required for autograd-allocated gradients, fixing the zero-copy backward crash.

  • ep_bench.py ep_dispatch_fwd_bwd stage is broken: fwd_bwd_dispatch_fn returns only recv_tokens[0], so g_recv_topk_weights=None reaches _EpDispatch.backward, crashing on .contiguous() in every warmup and timed iteration.
  • ep_moe.py --check flag is a no-op: action=\"store_true\" with default=True means the reference-check all-gather runs unconditionally; there is no way to disable it from the command line.
  • The test suite's workaround (0.0 * rw.float().sum() in every backward loss) keeps run_ep.py functional, but the same pattern is missing from the benchmark.

Confidence Score: 3/5

The benchmark crashes unconditionally on every run of ep_dispatch_fwd_bwd; the core library and test suite are functional.

The benchmark crashes in warmup on every invocation of ep_dispatch_fwd_bwd because g_recv_topk_weights arrives as None in _EpDispatch.backward. The test suite avoids this by including a 0.0*rw.float().sum() term, but ep_bench.py omits that guard and indexes only [0] from ep_dispatch.

examples/pytorch/ep/bench/ep_bench.py (ep_dispatch_fwd_bwd stage crashes) and examples/pytorch/ep/ep_moe.py (--check flag inversion)

Important Files Changed

Filename Overview
transformer_engine/pytorch/ep.py Core EP Python API; autograd wrappers, EpBuffer, and public ep_dispatch/ep_combine. Several known correctness bugs remain (None grad crash in dispatch backward, wrong zero-grad shape), though many previously flagged gaps (contiguity, dtype validation) are now fixed.
transformer_engine/pytorch/csrc/extensions/ep.cpp C++ pybind layer; adds contiguity/dtype/token-count guards, moves backward grad inputs to maybe_make_window (staged-copy fallback) instead of check_symm_mem_required, resolving the previously flagged zero-copy backward crash.
transformer_engine/pytorch/distributed.py Adds symm_mem_alloc helper; correctly passes ProcessGroup object to symm_mem.rendezvous (the previously flagged group_name string bug is not present here).
examples/pytorch/ep/ep_moe.py MoE end-to-end example; --check flag is unusable as an opt-in because default=True + action="store_true" means it is always True and cannot be disabled from the command line.
examples/pytorch/ep/bench/ep_bench.py EP performance benchmark; ep_dispatch_fwd_bwd stage crashes during warmup because the loss is computed only on recv_tokens[0], leaving g_recv_topk_weights=None in backward.
tests/pytorch/distributed/run_ep.py 8-test distributed suite; correctly works around the None-grad issue by including 0.0*rw.float().sum() in all backward losses. Zero-copy tests properly stage grads through symm-mem via _GradToSymm.
build_tools/pytorch.py Propagates -DNVTE_WITH_NCCL_EP and -DUSE_NCCL from NVTE_WITH_NCCL_EP env var; consistent with setup.py and build_tools/utils.py.

Sequence Diagram

%%{init: {'theme': 'neutral'}}%%
sequenceDiagram
    participant PY as ep.py (Python)
    participant CPP as ep.cpp (C++ pybind)
    participant NVTE as NVTE EP Backend
    participant NCCL as NCCL EP

    Note over PY,NCCL: Bootstrap (once per process)
    PY->>CPP: ep_initialize(comm_ptr, group_name, cfg)
    CPP->>NVTE: nvte_ep_initialize(ep_comm, NVTEEpGroupConfig)
    NVTE->>NCCL: ncclEpCreateGroup(ep_comm, cfg)

    Note over PY,NCCL: Per-step Forward
    PY->>CPP: ep_prepare + ep_dispatch(tokens, recv_tokens)
    CPP->>CPP: maybe_make_window(recv_tokens)
    CPP->>NVTE: nvte_ep_dispatch(..., recv_tokens_win)
    NVTE->>NCCL: ncclEpDispatch (zero-copy or staged)
    PY->>CPP: ep_combine(expert_out, result)
    CPP->>NVTE: nvte_ep_combine(..., expert_out_win)
    NVTE->>NCCL: ncclEpCombine

    Note over PY,NCCL: Backward
    PY->>CPP: ep_dispatch_bwd(g_recv_tokens, g_recv_topk_weights)
    CPP->>CPP: maybe_make_window(grad) kNoWindow for autograd grads
    CPP->>NVTE: nvte_ep_dispatch_bwd (staged-copy path)
    PY->>CPP: ep_combine_bwd(g_result, grad_expert_out)
    CPP->>CPP: check_symm_mem_required(grad_expert_out) zero-copy only
    CPP->>NVTE: nvte_ep_combine_bwd
Loading
%%{init: {'theme': 'base', 'themeVariables': {"darkMode": true, "background": "#0d1117", "primaryColor": "#21262d", "primaryTextColor": "#e6edf3", "primaryBorderColor": "#8b949e", "lineColor": "#8b949e", "textColor": "#e6edf3", "edgeLabelBackground": "#161b22", "actorBkg": "#21262d", "actorBorder": "#8b949e", "actorTextColor": "#e6edf3", "actorLineColor": "#8b949e", "signalColor": "#8b949e", "signalTextColor": "#e6edf3", "noteBkgColor": "#373320", "noteBorderColor": "#d4a72c", "noteTextColor": "#f0e6c0", "labelBoxBkgColor": "#21262d", "labelBoxBorderColor": "#8b949e", "labelTextColor": "#e6edf3", "loopTextColor": "#e6edf3", "activationBkgColor": "#30363d", "activationBorderColor": "#8b949e"}}}%%
sequenceDiagram
    participant PY as ep.py (Python)
    participant CPP as ep.cpp (C++ pybind)
    participant NVTE as NVTE EP Backend
    participant NCCL as NCCL EP

    Note over PY,NCCL: Bootstrap (once per process)
    PY->>CPP: ep_initialize(comm_ptr, group_name, cfg)
    CPP->>NVTE: nvte_ep_initialize(ep_comm, NVTEEpGroupConfig)
    NVTE->>NCCL: ncclEpCreateGroup(ep_comm, cfg)

    Note over PY,NCCL: Per-step Forward
    PY->>CPP: ep_prepare + ep_dispatch(tokens, recv_tokens)
    CPP->>CPP: maybe_make_window(recv_tokens)
    CPP->>NVTE: nvte_ep_dispatch(..., recv_tokens_win)
    NVTE->>NCCL: ncclEpDispatch (zero-copy or staged)
    PY->>CPP: ep_combine(expert_out, result)
    CPP->>NVTE: nvte_ep_combine(..., expert_out_win)
    NVTE->>NCCL: ncclEpCombine

    Note over PY,NCCL: Backward
    PY->>CPP: ep_dispatch_bwd(g_recv_tokens, g_recv_topk_weights)
    CPP->>CPP: maybe_make_window(grad) kNoWindow for autograd grads
    CPP->>NVTE: nvte_ep_dispatch_bwd (staged-copy path)
    PY->>CPP: ep_combine_bwd(g_result, grad_expert_out)
    CPP->>CPP: check_symm_mem_required(grad_expert_out) zero-copy only
    CPP->>NVTE: nvte_ep_combine_bwd
Loading

Reviews (19): Last reviewed commit: "Merge branch 'main' into phuong/ep-3-pyt..." | Re-trigger Greptile

Comment thread transformer_engine/pytorch/ep.py Outdated
Comment thread transformer_engine/pytorch/csrc/extensions/ep.cpp Outdated
Comment thread transformer_engine/pytorch/ep.py Outdated
Comment on lines +558 to +568
@contextlib.contextmanager
def _zero_copy_scope(enabled: bool):
"""Toggles whether per-step ops apply the symm-mem NCCL window annotation."""
if enabled:
yield
return
tex.ep_set_zero_copy(False)
try:
yield
finally:
tex.ep_set_zero_copy(True)

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

P2 _zero_copy_scope does not save/restore the previous flag value

When enabled=False, the manager unconditionally sets g_zero_copy_enabled=False on entry and g_zero_copy_enabled=True on exit. If two callers both use zero_copy=False concurrently (e.g., pipeline-parallel microbatches dispatched from separate Python threads) or if the context is nested, the inner scope's finally block prematurely re-enables zero-copy while the outer scope is still active. The outer scope's finally then sets True again, but between the inner finally and the outer finally the C++ layer sees True unexpectedly.

The fix is to capture the previous value before writing and restore it unconditionally: save old = tex.ep_get_zero_copy() (adding a corresponding getter), then tex.ep_set_zero_copy(old) in the finally block. At minimum, document the single-caller-at-a-time assumption prominently so pipeline-parallel users know to serialize.

Comment thread transformer_engine/common/ep/ep_backend.cpp Outdated
@phu0ngng phu0ngng marked this pull request as draft May 22, 2026 03:03
@phu0ngng phu0ngng force-pushed the phuong/ep-3-pytorch-on-commwindow branch 4 times, most recently from 540ef54 to bacae5f Compare May 24, 2026 00:06
Comment thread transformer_engine/pytorch/ep.py Outdated
device = expert_out.device
# Weight in payload dtype: single fused broadcast multiply into combine_in.
w = recv_topk_weights.unsqueeze(-1).to(expert_out.dtype)
torch.mul(expert_out, w, out=combine_in)

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

why we need this?🤔
At the training scenario, the weight gets multiplied onto the activation between fc1 and fc2 (we also dispatch the weight at the same time as dispatching the tokens), or am I misunderstanding something here?

My understanding is that this multiplication is unnecessary. Furthermore, if it is removed, another problem becomes more prominent: how do we add symm buffer support for the combine input? This would require changes on the grouped GEMM side.

Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Second this. I saw unexpected kernel here and found this same problem. A potential solution is to provide a separate path when the weight is not provided. This means the weight multiplication is handled elsewhere, and in this case skip the multiplication here.

@phu0ngng phu0ngng May 26, 2026

Copy link
Copy Markdown
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Good to learn that we can fuse the weight x to the activation. I will make this optional.

We will need to change the GG to return the symmetric memory buf.

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes. we need change the grouped gemm I think

ep_group: dist.ProcessGroup,
num_experts: int,
max_tokens_per_rank: int,
recv_capacity_per_rank: int,

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

When allocating the buffer, we need to allocate according to the worst case. There are two scenarios here:

  • The first is rank-major, where the memory footprint is max_tokens_per_rank × num_of_ranks. This generally stays below 10 GB, which is the primary memory overhead of typical EP setups and is acceptable.
  • The second is expert-major, where the memory footprint is max_tokens_per_rank × num_of_ranks × min(topk, num_of_experts). This could reach 40–50 GB, which is unacceptable.

If I understand this correctly, we must find a way to optimize the memory usage in the expert-major layout — or alternatively, we need to fall back to the rank-major layout + explicit permutation approach.

Copy link
Copy Markdown
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

With the rank-major, you still need to overallocate the output buffer of local permute as in expert-major. Right?

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

There are two types of buffers:

The first is the EP buffer, which serves as the destination for communication (NCCL EP is a push-based design), so it requires a relatively costly registration process. These are reused globally as static buffers as much as possible, so they are allocated based on the worst-case size. In HEP, the rank-major output buffer is an EP buffer, so we only need a rank-major worst-case-size buffer. I haven't studied NCCL EP in detail, but my understanding is that if our output is a symmetric buffer, we don't need a built-in static comm buffer inside NCCL EP — meaning recv_capacity_per_rank is not needed when the output buffer is a symm buffer. I think this is worth discussing and clarifying.

The second type is regular GPU memory, which can be managed by the caching allocator. In HEP, the output of the permute operation falls into this category — it can be dynamically allocated each iteration based on the scan result, with just one additional sync required. Additionally, in sync-free mode, the size of this buffer is specified by the user.

To summarize, we may need to confirm whether recv_capacity_per_rank requires building an expert-major worst-case-size buffer inside NCCL EP. If the output is a symm buffer, we theoretically don't need such a buffer. However, if it is necessary, then we cannot accept an expert-major worst-case-size buffer. I also observed in my draft PR that NCCL EP uses more memory.

Copy link
Copy Markdown
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Hi,
It's correct that if the output buffer is a symmem, then we should not need to register the gigantic IPC/MC buffer in ep_group with the size based on recv_capacity_per_rank. Let's request NCCL EP to add an option to skip this buffer allocation.

However, I think we should still ask users to specify this recv_capacity_per_rank so that we can handle overflow policy in the metadata_preprocessing rather than delaying it to dispatch phase.

@Autumn1998 Autumn1998 May 28, 2026

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We need an option to skip this internal buffer.
Also, are you thinking of using recv_capacity_per_rank to support the sync-free mechanism? That is, tokens exceeding the threshold get dropped, and then trigger the flipping of the overflow flag? I think this is not correct — we should not set it at buffer initialization, but instead pass it as a parameter before the preprocess step of each dispatch, because the threshold changes every iteration.
cc @nanz-nv plz correct me if I made mistakes

Copy link
Copy Markdown
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

because the threshold changes every iteration.

I'm curious to learn about this possibility. From my understanding, the output buffers need to have a static size for CUDA Graph replay, and so does the recv_capacity.

@Autumn1998 Autumn1998 May 29, 2026

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think for each global batch, we recalculate a new output size, since each batch has its own CUDA graph — but I'm not 100% sure on this. You may want to confirm with @nanz-nv.

Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think it is something in between. With the current way of doing full-iteration cuda graph, ideally recv_capacity_per_rank should stay the same across training, but it can sometimes gets updated. So I'd treat it as something that may change but not frequently.

@timmoon10 timmoon10 self-requested a review June 1, 2026 17:34
@phu0ngng phu0ngng force-pushed the phuong/ep-3-pytorch-on-commwindow branch 4 times, most recently from 40d8011 to 2153492 Compare June 10, 2026 01:27
@phu0ngng phu0ngng marked this pull request as ready for review June 10, 2026 01:28
Comment thread transformer_engine/pytorch/csrc/extensions/ep.cpp
@phu0ngng phu0ngng force-pushed the phuong/ep-3-pytorch-on-commwindow branch from 9ec1aff to 7ce8d8b Compare June 10, 2026 03:20
Comment thread transformer_engine/pytorch/ep.py
@phu0ngng phu0ngng force-pushed the phuong/ep-3-pytorch-on-commwindow branch from b2ab069 to c8c54fd Compare June 11, 2026 00:22
Comment thread transformer_engine/pytorch/ep.py
Comment thread transformer_engine/pytorch/csrc/extensions/ep.cpp
@phu0ngng phu0ngng force-pushed the phuong/ep-3-pytorch-on-commwindow branch from df732a5 to 67917a3 Compare June 11, 2026 16:16
@phu0ngng

Copy link
Copy Markdown
Collaborator Author

/te-ci pytorch L1

@phu0ngng

Copy link
Copy Markdown
Collaborator Author

Pipeline #54455868 TE EP tests passed in L1_pytorch_distributed_unittest--B200_8GPU and L1_pytorch_distributed_unittest--H100_4GPU. There are other failures that are unrelated to TE EP.

@phu0ngng phu0ngng force-pushed the phuong/ep-3-pytorch-on-commwindow branch 2 times, most recently from 52bbf88 to d6c5745 Compare June 13, 2026 00:08
Signed-off-by: Phuong Nguyen <phuonguyen@nvidia.com>
@phu0ngng phu0ngng added the 2.7.0 label Jun 24, 2026
phu0ngng and others added 5 commits June 24, 2026 06:39
…ler_provides_combine_grad_buffer; recv_topk_weights is always buffer-owned

Signed-off-by: Phuong Nguyen <phuonguyen@nvidia.com>
…ovides_combine_grad_buffer CLI flags to ep_moe example and ep_bench (default False)

Signed-off-by: Phuong Nguyen <phuonguyen@nvidia.com>
…buffer, dispatch/combine, tests and examples

Signed-off-by: Phuong Nguyen <phuonguyen@nvidia.com>
Signed-off-by: Phuong Nguyen <phuonguyen@nvidia.com>
Comment thread transformer_engine/pytorch/ep.py Outdated
# to opt in; the C++ backend then operates the EP group in zero-copy mode.


def symm_mem_alloc(

Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nit: This function is pretty generic to allocate symm mem, maybe consider to move it to general utils?

Copy link
Copy Markdown
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Done

self.token_counts = torch.empty(self.num_local_experts, dtype=torch.int32, device=device)
# Persistent tensor; keep resident if activation CPU offloading is on.
mark_not_offload(self.handle_mem)
self._alloc_symm_buffers()

Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Just note this buffer allocation logic might be pending for future change.

Copy link
Copy Markdown
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

There is an existing warning that the API related to zero-copy is subject to change.

ctx.mark_non_differentiable(token_counts)
# Detach so the long-lived buffers aren't tracked as differentiable outputs;
# autograd re-attaches grad_fn pointing back at this Function.
return recv_tokens.detach(), recv_topk_weights.detach(), token_counts

Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Not super clear to me why we do detach here. Autograd function is running in no_grad context anyway. If these tensors are long-lived buffers, user should allocate them as requires_grad=False?

Copy link
Copy Markdown
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think when recv_tokens are symmem, we need to do detach to avoid the grad_fn from sticking with this tensor, while when it is a non-symmem, we have an in-place modification which requires dirty-mark, which detachs can trigger a similar effect. I'm new to PyTorch so let me know if this is incorrect.

Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

If we allocate recv_tokens as requires_grad=False, it should not have a grad_fn. I think we should let user manage the responsibility if this tensor requires grad, i.e. if user explicitly want the output of dispatch carries grad_fn for some reason, we should not forbid it. Otherwise they can just make the tensor not require grad.

@phu0ngng phu0ngng Jun 26, 2026

Copy link
Copy Markdown
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I don't think the detach is about requires_grad or blocking grad on the output - the returned tensor still carries grad_fn, so gradients flow back to tokens. And I suspect allocating the buffer requires_grad=False wouldn't help here, since I think the real issue is views, not grad ownership.

My understanding: when I return recv_tokens as-is, autograd treats the output as a differentiable view of the persistent buffer ("because an input was returned as-is"). The next dispatch overwrites that buffer in place, so if a prior output still has a live graph - grad accumulation, multiple microbatches before backward, or CUDA-graph capture, autograd should raise Output 0 of DispatchBackward is a view ... modified inplace ... forbidden.

I wasn't fully sure, so I made a small repro and confirmed it:

import torch
  
def make_fn(detach):
    class Dispatch(torch.autograd.Function):
        @staticmethod
        def forward(ctx, recv, x):
            recv.copy_(x.detach() * 2)          # kernel writes into recv in place
            ctx.save_for_backward(x)
            return recv.detach() if detach else recv
        @staticmethod
        def backward(ctx, g):
            return None, g
    return Dispatch
 
def run(detach):
    print("detach" if detach else "raw", end=": ")
    Fn = make_fn(detach)
    recv = torch.zeros(4)                         # persistent buffer, requires_grad=False
    x0 = torch.ones(4, requires_grad=True)
    out0 = Fn.apply(recv, x0)                     # out0 aliases recv
    x1 = torch.ones(4, requires_grad=True)
    try:
        Fn.apply(recv, x1)                        # overwrite recv while out0 still live
        out0.sum().backward()
        print("OK")
    except RuntimeError as e:
        print(e.args[0].splitlines()[0])

run(detach=False)   # raw:    Output 0 of DispatchBackward is a view ... modified inplace ... forbidden.
run(detach=True)    # detach: OK

Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Not sure I'm following the idea, the example contains 2 autograd fn, which shared the same buffer, but we are starting backward from the out0, meaning the entire x1's autograd is not part of the backprop. That looks invalid to me. If you directly backprop from the second autograd function, it should be able to run, i.e.

try:
  out1=Fn.apply(recv, x1)                        # overwrite recv while out0 still live
  out1.sum().backward()
  print("OK")

Could you help me understand in real case, when we have dispatch->mlp->combine->att->dispatch..., what are potential issue without detach? My major concern is if the recv buffer is allocated as requires_grad=True, but this detach might break that. Otherwise, like what to today, with/without symm buffer, the buffer is already requires_grad=False, I think adding an additional detach should not break anything. If detach can make it safer I'm okay to move forward with it.

torch.ops.transformer_engine_ep.combine(handle_mem, expert_out, result)
ctx.save_for_backward(handle_mem)
ctx.grad_symm_buf = grad_symm_buf
if grad_symm_buf is None:

Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

If grad_symm_buf is not None, we should use ctx.save_for_backward to save it?

Copy link
Copy Markdown
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think save_for_backward should be used only when you need to read the value of the tensor in the backward path.
Here, we only want to pass the reference to the buffer so that we can write to it.

Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

All tensors that will be used by backward should be passed by ctx.save_for_backward, this is to prevent memory leak, check https://docs.pytorch.org/docs/2.12/generated/torch.autograd.function.FunctionCtx.save_for_backward.html. I think it is needed for torch to manage its autograd graph lifecycle. If the tensor does not require grad, it is probably okay to assign it with ctx directly, but it is always safe to use ctx.save_for_backward

Copy link
Copy Markdown
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I suspect save_for_backward is for tensors backward reads. It version-checks them and breaks the grad_fn_ctx_tensor cycle that leaks when you save an output. grad_symm_buf is the opposite: backward only writes into it, and it's a requires_grad=False buffer owned by EpBuffer (no grad_fn, so no cycle, nothing to leak).

So I made a small code and ran a quick check, and saving broke things - since backward mutates it in place, a second backward fails the version check (on torch 2.12 it raises ...modified by an inplace operation: is at version 1; expected version 0.

Copy link
Copy Markdown
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

import torch

def make_fn(use_save):
    class Combine(torch.autograd.Function):
        @staticmethod
        def forward(ctx, buf, x):
            ctx.save_for_backward(buf, x) if use_save else setattr(ctx, "buf", buf)
            if not use_save:
                ctx.save_for_backward(x)
            return x * 2.0                       # differentiable output
        @staticmethod
        def backward(ctx, g):
            buf = ctx.saved_tensors[0] if use_save else ctx.buf
            buf.add_(g)                          # in-place scatter into the target
            return None, g * 2.0
    return Combine
    
def run(use_save):
    print("save_for_backward(buf)" if use_save else "ctx.buf attr", end=": ")
    Fn = make_fn(use_save)
    x = torch.ones(4, requires_grad=True)
    buf = torch.zeros(4)                          # write target, requires_grad=False
    out = Fn.apply(buf, x)
    try:
        out.sum().backward(retain_graph=True)     # 1st backward: writes buf
        out.sum().backward(retain_graph=True)     # 2nd: re-unpacks saved tensors
        print("OK")
    except RuntimeError as e:
        print(e.args[0].splitlines()[0])

run(use_save=True)    # save:  ...modified by an inplace operation: ... is at version 1; expected version 0
run(use_save=False)   # attr:  OK

Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

That makes sense. Do you mind to add a safety check, if the grad_symm_buf does not require grad, we put it in ctx, otherwise use save_for_backward?

Copy link
Copy Markdown
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The version check when using save_for_backward will fail regardless of requires_grad.

Anyway, it's good to add an assertion in a follow-up PR. So I will leave this thread unresolved for tracking.

assert grad_symm_buf is None or not grad_symm_buf.requires_grad

phu0ngng added 2 commits June 25, 2026 05:09
Signed-off-by: Phuong Nguyen <phuonguyen@nvidia.com>
…/combine_grad_expert_out buffers on EpBuffer

Signed-off-by: Phuong Nguyen <phuonguyen@nvidia.com>
Comment thread examples/pytorch/ep/ep_moe.py Outdated
phu0ngng and others added 2 commits June 25, 2026 15:54
Co-authored-by: greptile-apps[bot] <165735046+greptile-apps[bot]@users.noreply.github.com>
Signed-off-by: Phuong Nguyen <phuonguyen@nvidia.com>
…ard docstring

Signed-off-by: Phuong Nguyen <phuonguyen@nvidia.com>
@phu0ngng

Copy link
Copy Markdown
Collaborator Author

/te-ci pytorch L1

Comment thread transformer_engine/pytorch/distributed.py Outdated
Co-authored-by: greptile-apps[bot] <165735046+greptile-apps[bot]@users.noreply.github.com>
Signed-off-by: Phuong Nguyen <phuonguyen@nvidia.com>
@phu0ngng

Copy link
Copy Markdown
Collaborator Author

/te-ci pytorch L1

vthumbe1503
vthumbe1503 previously approved these changes Jun 26, 2026

@vthumbe1503 vthumbe1503 left a comment

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM. This looks pretty amazing. Enabling NCCL EP seems pretty simple now from TE. This also would enable us to have a performant MOE Block implementation self contained in TE. All my comments are Nits

gc.collect()
torch.cuda.synchronize()
# Release NCCL EP's borrowed comm before torch destroys it.
ep_finalize()

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Not blocking but From the usage example perspective, I think showing thr usage of API something like te.ep_finalize or te.pytorch.ep_finalize might be a better choice. Stresses that we are using this utility from TE.

Copy link
Copy Markdown
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

ep_finalize() is optional and idempotent. ep_bootstrap registers an atexit handler that covers normal shutdown, so most users never call it.
Two cases need an explicit call:

  1. before dist.destroy_process_group(), since EP's borrowed NCCL comm goes invalid once the PG is destroyed;
  2. before re-ep_bootstrap() in the same process, to clear the double-init guard.

ep_bench hits case (1): it destroys the PG at exit, so finalizing first keeps NCCL teardown clean and warning-free. The example and unit tests, for example, do not call ep_finalize().

Copy link
Copy Markdown
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Wait. I see what you mean now - using them from a util imported package instead of importing individual functions. Will make a follow-up PR. Thanks!

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes, right.. Alternatively, these utilities can also be made available to transformer_engine.pytorch init file. And we can directly do te.pytorch.ep_finalize

if [ -z "${KEEP_EP_LOGS:-}" ]; then rm -f "${log}"; fi
}

run_pass "default" 0

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Non Blocking: But i have two comments regarding the script

  1. I think this parameterization of default and zero_copy is better suted inside the pytest. This will also allow more parametrizations for us in the future and add more tests
  2. If we can share the torchrun among multiple parametrization, that would be more ideal. I had done something similar for FSDP2 tests here. Distributed tests take a lot of time in general because of lot of torchruns launched throughout our CI. Not a big deal currently because we are just running two tests here.

In general I wont consider this to be blocker for this PR

Copy link
Copy Markdown
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

For 1, the reason zero_copy needs to be in the run script instead of pytest parameterize is that one needs to bootstrap accordingly wrt zero_copy and we only allow to bootstrap once per program.

Copy link
Copy Markdown
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Perhaps we can rewrite the test, using ep_finalize then re-bootstrap.

@phu0ngng phu0ngng added 2.17 and removed 2.7.0 labels Jun 26, 2026
Signed-off-by: Phuong Nguyen <phuonguyen@nvidia.com>
@phu0ngng

Copy link
Copy Markdown
Collaborator Author

/te-ci pytorch L1

@phu0ngng

Copy link
Copy Markdown
Collaborator Author

/te-ci pytorch L1

Comment thread examples/pytorch/ep/bench/ep_bench.py
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

Projects

None yet

Development

Successfully merging this pull request may close these issues.

6 participants