From 4cabf0326f45dd58b77536c0b65404006b5a930f Mon Sep 17 00:00:00 2001 From: Phuong Nguyen Date: Mon, 29 Jun 2026 10:28:03 +0000 Subject: [PATCH 1/8] versioning EP C configs Signed-off-by: Phuong Nguyen --- tests/cpp_distributed/test_ep.cu | 26 +++++--- tests/cpp_distributed/test_ep_common.h | 8 +-- transformer_engine/common/ep/ep_api.cpp | 65 ++++++++++++++----- transformer_engine/common/ep/ep_backend.cpp | 2 - .../common/include/transformer_engine/ep.h | 35 +++++++--- transformer_engine/jax/csrc/extensions/ep.cpp | 21 ++++-- 6 files changed, 111 insertions(+), 46 deletions(-) diff --git a/tests/cpp_distributed/test_ep.cu b/tests/cpp_distributed/test_ep.cu index c7fee7720c..aabbad6b8c 100644 --- a/tests/cpp_distributed/test_ep.cu +++ b/tests/cpp_distributed/test_ep.cu @@ -144,11 +144,15 @@ struct EPBuffers { size_t recv_capacity = 0; int top_k_ = 0; size_t alignment_ = 0; + NVTEEpLayerConfig layer_cfg_{}; void alloc(int num_tokens, int top_k, int hidden_dim, int num_local_experts, int ep_size, int max_tokens_per_rank, size_t alignment = 0) { top_k_ = top_k; alignment_ = alignment; + layer_cfg_ = NVTE_EP_LAYER_CONFIG_INIT; + layer_cfg_.top_k = top_k; + layer_cfg_.dispatch_output_per_expert_alignment = alignment; recv_capacity = static_cast(ep_size) * max_tokens_per_rank * 2; topk_idx.alloc(num_tokens * top_k); @@ -159,7 +163,7 @@ struct EPBuffers { recv_topk_weights.alloc(recv_capacity); result.alloc(num_tokens * hidden_dim); - handle_mem_size = nvte_ep_handle_mem_size(NVTEEpLayerConfig{top_k, alignment}); + handle_mem_size = nvte_ep_handle_mem_size(&layer_cfg_); handle_mem.alloc(handle_mem_size); grad_result.alloc(num_tokens * hidden_dim); @@ -181,11 +185,15 @@ struct EPTensors { int top_k_ = 0; size_t alignment_ = 0; + NVTEEpLayerConfig layer_cfg_{}; EPTensors(EPBuffers& b, int num_tokens, int top_k, int hidden_dim, int num_local_experts) { top_k_ = top_k; alignment_ = b.alignment_; + layer_cfg_ = NVTE_EP_LAYER_CONFIG_INIT; + layer_cfg_.top_k = top_k; + layer_cfg_.dispatch_output_per_expert_alignment = b.alignment_; constexpr DType kTokDType = test::TypeInfo::dtype; using Shape = std::vector; topk_idx = TensorWrapper(b.topk_idx.get(), @@ -300,7 +308,7 @@ TYPED_TEST(EPDispatchTest, PrepareAndDispatch) { cudaStream_t stream; NVTE_CHECK_CUDA(cudaStreamCreate(&stream)); - ASSERT_NO_THROW(nvte_ep_prepare(t.handle_mem.data(), t.topk_idx.data(), t.token_counts.data(), NVTEEpLayerConfig{t.top_k_, t.alignment_}, stream)); + ASSERT_NO_THROW(nvte_ep_prepare(t.handle_mem.data(), t.topk_idx.data(), t.token_counts.data(), &t.layer_cfg_, stream)); ASSERT_NO_THROW(nvte_ep_dispatch(t.handle_mem.data(), t.topk_idx.data(), t.tokens.data(), NVTECommWindow{}, t.topk_weights.data(), NVTECommWindow{}, t.recv_tokens.data(), NVTECommWindow{}, @@ -379,7 +387,7 @@ TYPED_TEST(EPCombineTest, Combine) { cudaStream_t stream; NVTE_CHECK_CUDA(cudaStreamCreate(&stream)); - ASSERT_NO_THROW(nvte_ep_prepare(t.handle_mem.data(), t.topk_idx.data(), t.token_counts.data(), NVTEEpLayerConfig{t.top_k_, t.alignment_}, stream)); + ASSERT_NO_THROW(nvte_ep_prepare(t.handle_mem.data(), t.topk_idx.data(), t.token_counts.data(), &t.layer_cfg_, stream)); ASSERT_NO_THROW(nvte_ep_dispatch(t.handle_mem.data(), t.topk_idx.data(), t.tokens.data(), NVTECommWindow{}, t.topk_weights.data(), NVTECommWindow{}, t.recv_tokens.data(), NVTECommWindow{}, @@ -426,7 +434,7 @@ TYPED_TEST(EPCombineBwdTest, CombineBwdCheck) { cudaStream_t stream; NVTE_CHECK_CUDA(cudaStreamCreate(&stream)); - ASSERT_NO_THROW(nvte_ep_prepare(t.handle_mem.data(), t.topk_idx.data(), t.token_counts.data(), NVTEEpLayerConfig{t.top_k_, t.alignment_}, stream)); + ASSERT_NO_THROW(nvte_ep_prepare(t.handle_mem.data(), t.topk_idx.data(), t.token_counts.data(), &t.layer_cfg_, stream)); ASSERT_NO_THROW(nvte_ep_dispatch(t.handle_mem.data(), t.topk_idx.data(), t.tokens.data(), NVTECommWindow{}, t.topk_weights.data(), NVTECommWindow{}, t.recv_tokens.data(), NVTECommWindow{}, @@ -495,7 +503,7 @@ TYPED_TEST(EPDispatchBwdTest, DispatchBwdCheck) { cudaStream_t stream; NVTE_CHECK_CUDA(cudaStreamCreate(&stream)); - ASSERT_NO_THROW(nvte_ep_prepare(t.handle_mem.data(), t.topk_idx.data(), t.token_counts.data(), NVTEEpLayerConfig{t.top_k_, t.alignment_}, stream)); + ASSERT_NO_THROW(nvte_ep_prepare(t.handle_mem.data(), t.topk_idx.data(), t.token_counts.data(), &t.layer_cfg_, stream)); ASSERT_NO_THROW(nvte_ep_dispatch(t.handle_mem.data(), t.topk_idx.data(), t.tokens.data(), NVTECommWindow{}, t.topk_weights.data(), NVTECommWindow{}, t.recv_tokens.data(), NVTECommWindow{}, @@ -563,7 +571,7 @@ TYPED_TEST(EPDispatchBwdGradWeightsTest, RoundTrip) { cudaStream_t stream; NVTE_CHECK_CUDA(cudaStreamCreate(&stream)); - ASSERT_NO_THROW(nvte_ep_prepare(t.handle_mem.data(), t.topk_idx.data(), t.token_counts.data(), NVTEEpLayerConfig{t.top_k_, t.alignment_}, stream)); + ASSERT_NO_THROW(nvte_ep_prepare(t.handle_mem.data(), t.topk_idx.data(), t.token_counts.data(), &t.layer_cfg_, stream)); NVTE_CHECK_CUDA(cudaMemsetAsync(buf.recv_topk_weights.get(), 0, buf.recv_topk_weights.bytes(), stream)); ASSERT_NO_THROW(nvte_ep_dispatch(t.handle_mem.data(), t.topk_idx.data(), @@ -634,7 +642,7 @@ class EPPipelineTest : public EpOpTestBase, public ::testing::WithParamInterface cudaStream_t stream; NVTE_CHECK_CUDA(cudaStreamCreate(&stream)); - ASSERT_NO_THROW(nvte_ep_prepare(t.handle_mem.data(), t.topk_idx.data(), t.token_counts.data(), NVTEEpLayerConfig{t.top_k_, t.alignment_}, stream)); + ASSERT_NO_THROW(nvte_ep_prepare(t.handle_mem.data(), t.topk_idx.data(), t.token_counts.data(), &t.layer_cfg_, stream)); ASSERT_NO_THROW(nvte_ep_dispatch(t.handle_mem.data(), t.topk_idx.data(), t.tokens.data(), NVTECommWindow{}, t.topk_weights.data(), NVTECommWindow{}, t.recv_tokens.data(), NVTECommWindow{}, @@ -759,7 +767,7 @@ TYPED_TEST(EPZeroCopyTest, IdentityAllSymm) { cudaStream_t stream; NVTE_CHECK_CUDA(cudaStreamCreate(&stream)); - ASSERT_NO_THROW(nvte_ep_prepare(ref_t.handle_mem.data(), ref_t.topk_idx.data(), ref_t.token_counts.data(), NVTEEpLayerConfig{ref_t.top_k_, ref_t.alignment_}, stream)); + ASSERT_NO_THROW(nvte_ep_prepare(ref_t.handle_mem.data(), ref_t.topk_idx.data(), ref_t.token_counts.data(), &ref_t.layer_cfg_, stream)); ASSERT_NO_THROW(nvte_ep_dispatch(ref_t.handle_mem.data(), ref_t.topk_idx.data(), ref_t.tokens.data(), NVTECommWindow{}, ref_t.topk_weights.data(), NVTECommWindow{}, ref_t.recv_tokens.data(), NVTECommWindow{}, @@ -800,7 +808,7 @@ TYPED_TEST(EPZeroCopyTest, IdentityAllSymm) { sym_t.recv_tokens = TensorWrapper(sym_recv.ptr, std::vector{sym_buf.recv_capacity, (size_t)hidden_dim_}, kTokDType); - ASSERT_NO_THROW(nvte_ep_prepare(sym_t.handle_mem.data(), sym_t.topk_idx.data(), sym_t.token_counts.data(), NVTEEpLayerConfig{sym_t.top_k_, sym_t.alignment_}, stream)); + ASSERT_NO_THROW(nvte_ep_prepare(sym_t.handle_mem.data(), sym_t.topk_idx.data(), sym_t.token_counts.data(), &sym_t.layer_cfg_, stream)); ASSERT_NO_THROW(nvte_ep_dispatch(sym_t.handle_mem.data(), sym_t.topk_idx.data(), sym_t.tokens.data(), symm_window(sym_tokens), sym_t.topk_weights.data(), NVTECommWindow{}, diff --git a/tests/cpp_distributed/test_ep_common.h b/tests/cpp_distributed/test_ep_common.h index d5e006cef6..7cf6017090 100644 --- a/tests/cpp_distributed/test_ep_common.h +++ b/tests/cpp_distributed/test_ep_common.h @@ -146,7 +146,7 @@ static bool ep_bootstrap(int argc, char* argv[]) { ncclUniqueId uid{}; exchange_unique_id(&uid); - NVTEEpGroupConfig group_config{}; + NVTEEpGroupConfig group_config = NVTE_EP_GROUP_CONFIG_INIT; group_config.ep_size = g_ep_size; group_config.num_experts = g_num_experts; group_config.max_tokens_per_rank = g_max_tokens_per_rank; @@ -156,7 +156,7 @@ static bool ep_bootstrap(int argc, char* argv[]) { group_config.max_token_dtype = g_max_token_dtype; NVTE_CHECK_NCCL(ncclCommInitRank(&g_ep_comm, g_num_processes, uid, g_process_id)); - nvte_ep_initialize(static_cast(g_ep_comm), group_config); + nvte_ep_initialize(static_cast(g_ep_comm), &group_config); if (g_process_id == 0) { printf("EP initialized: ep_size=%d num_experts=%d " @@ -173,7 +173,7 @@ static bool ep_bootstrap(int argc, char* argv[]) { static void ep_reinitialize(int zero_copy) { if (!g_ep_initialized) return; nvte_ep_shutdown(); - NVTEEpGroupConfig group_config{}; + NVTEEpGroupConfig group_config = NVTE_EP_GROUP_CONFIG_INIT; group_config.ep_size = g_ep_size; group_config.num_experts = g_num_experts; group_config.max_tokens_per_rank = g_max_tokens_per_rank; @@ -181,7 +181,7 @@ static void ep_reinitialize(int zero_copy) { group_config.hidden_dim = g_hidden_dim; group_config.max_token_dtype = g_max_token_dtype; group_config.zero_copy = zero_copy; - nvte_ep_initialize(static_cast(g_ep_comm), group_config); + nvte_ep_initialize(static_cast(g_ep_comm), &group_config); } // Tear down in dependency order: backend's ep_group reads from ep_comm, diff --git a/transformer_engine/common/ep/ep_api.cpp b/transformer_engine/common/ep/ep_api.cpp index 66ee3dc8d9..fde4882806 100644 --- a/transformer_engine/common/ep/ep_api.cpp +++ b/transformer_engine/common/ep/ep_api.cpp @@ -13,6 +13,10 @@ #include +#include +#include +#include + #include "../util/logging.h" #if defined(NVTE_WITH_NCCL_EP) @@ -24,18 +28,30 @@ using transformer_engine::ep::EPBackend; -void nvte_ep_initialize(void* ep_comm, NVTEEpGroupConfig group_config) { - NVTE_CHECK(ep_comm != nullptr, "ep_comm must not be null"); - EPBackend::initialize(static_cast(ep_comm), group_config); -} - -void nvte_ep_shutdown(void) { EPBackend::shutdown(); } - -size_t nvte_ep_handle_mem_size(NVTEEpLayerConfig layer_cfg) { - return EPBackend::get().handle_mem_size(layer_cfg); +namespace { +// Smallest accepted struct_size: covers the base (required) fields. Frozen; +// never raise these. Later fields are read only when struct_size covers them. +constexpr size_t kGroupConfigMinSize = offsetof(NVTEEpGroupConfig, zero_copy) + sizeof(int); +constexpr size_t kLayerConfigMinSize = + offsetof(NVTEEpLayerConfig, dispatch_output_per_expert_alignment) + sizeof(size_t); + +// Copy a caller's versioned config into a full current-layout struct: fields +// the caller did not provide stay zero, extra trailing fields are dropped. +// struct_size 0 is read as the base layout. Requires a size_t struct_size +// first member. +template +Cfg normalize_ep_config(const Cfg* user, size_t min_size, const char* name) { + NVTE_CHECK(user != nullptr, name, " must not be null"); + const size_t want = (user->struct_size == 0) ? min_size : user->struct_size; + NVTE_CHECK(want >= min_size, name, ".struct_size (", user->struct_size, + ") is below the required minimum ", min_size, + "; zero-initialize the struct or set struct_size via NVTE_EP_*_CONFIG_INIT"); + Cfg cfg{}; + std::memcpy(&cfg, user, std::min(want, sizeof(Cfg))); + cfg.struct_size = sizeof(Cfg); + return cfg; } -namespace { inline void* handle_mem_ptr(NVTETensor mem) { void* p = nvte_tensor_data(mem); NVTE_CHECK(p != nullptr, "handle_mem tensor data must not be null"); @@ -43,9 +59,26 @@ inline void* handle_mem_ptr(NVTETensor mem) { } } // namespace +void nvte_ep_initialize(void* ep_comm, const NVTEEpGroupConfig* group_config) { + NVTE_CHECK(ep_comm != nullptr, "ep_comm must not be null"); + NVTEEpGroupConfig cfg = + normalize_ep_config(group_config, kGroupConfigMinSize, "group_config"); + EPBackend::initialize(static_cast(ep_comm), cfg); +} + +void nvte_ep_shutdown(void) { EPBackend::shutdown(); } + +size_t nvte_ep_handle_mem_size(const NVTEEpLayerConfig* layer_cfg) { + NVTEEpLayerConfig cfg = + normalize_ep_config(layer_cfg, kLayerConfigMinSize, "layer_cfg"); + return EPBackend::get().handle_mem_size(cfg); +} + void nvte_ep_prepare(NVTETensor handle_mem, NVTETensor topk_idx, NVTETensor token_counts, - NVTEEpLayerConfig layer_cfg, cudaStream_t stream) { - EPBackend::get().prepare(handle_mem_ptr(handle_mem), topk_idx, token_counts, layer_cfg, stream); + const NVTEEpLayerConfig* layer_cfg, cudaStream_t stream) { + NVTEEpLayerConfig cfg = + normalize_ep_config(layer_cfg, kLayerConfigMinSize, "layer_cfg"); + EPBackend::get().prepare(handle_mem_ptr(handle_mem), topk_idx, token_counts, cfg, stream); } void nvte_ep_dispatch(NVTETensor handle_mem, NVTETensor topk_idx, NVTETensor tokens, @@ -88,14 +121,16 @@ namespace { } } // namespace -void nvte_ep_initialize(void* /*ep_comm*/, NVTEEpGroupConfig /*group_config*/) { ep_not_built(); } +void nvte_ep_initialize(void* /*ep_comm*/, const NVTEEpGroupConfig* /*group_config*/) { + ep_not_built(); +} void nvte_ep_shutdown(void) {} -size_t nvte_ep_handle_mem_size(NVTEEpLayerConfig /*layer_cfg*/) { ep_not_built(); } +size_t nvte_ep_handle_mem_size(const NVTEEpLayerConfig* /*layer_cfg*/) { ep_not_built(); } void nvte_ep_prepare(NVTETensor /*handle_mem*/, NVTETensor /*topk_idx*/, - NVTETensor /*token_counts*/, NVTEEpLayerConfig /*layer_cfg*/, + NVTETensor /*token_counts*/, const NVTEEpLayerConfig* /*layer_cfg*/, cudaStream_t /*stream*/) { ep_not_built(); } diff --git a/transformer_engine/common/ep/ep_backend.cpp b/transformer_engine/common/ep/ep_backend.cpp index f1510693bb..d20364ac66 100644 --- a/transformer_engine/common/ep/ep_backend.cpp +++ b/transformer_engine/common/ep/ep_backend.cpp @@ -97,8 +97,6 @@ void EPBackend::validate_config(const NVTEEpGroupConfig& config) { NVTE_CHECK(config.max_recv_tokens_per_rank > 0, "max_recv_tokens_per_rank must be positive, got ", config.max_recv_tokens_per_rank); NVTE_CHECK(config.hidden_dim > 0, "hidden_dim must be positive, got ", config.hidden_dim); - NVTE_CHECK(config.max_token_dtype >= 0 && config.max_token_dtype < kNVTENumTypes, - "max_token_dtype out of range, got ", static_cast(config.max_token_dtype)); const size_t elem_bytes = typeToSize(static_cast(config.max_token_dtype)); const size_t row_bytes = static_cast(config.hidden_dim) * elem_bytes; NVTE_CHECK(row_bytes >= 16, diff --git a/transformer_engine/common/include/transformer_engine/ep.h b/transformer_engine/common/include/transformer_engine/ep.h index 8928b92825..e03323cd65 100644 --- a/transformer_engine/common/include/transformer_engine/ep.h +++ b/transformer_engine/common/include/transformer_engine/ep.h @@ -28,11 +28,16 @@ extern "C" { #endif /* -- Config structs ------------------------------------------------------- */ -/* TODO: add a struct_size/version field to these configs (and align with other - * TE public structs) once a TE-wide convention for ABI versioning lands. */ +/* Each config begins with struct_size so the API can add fields without + * breaking ABI. The backend reads only the bytes struct_size covers and + * zero-defaults the rest; struct_size 0 means the base layout. Append new + * fields at the end only; never reorder, resize, or remove existing ones. */ /*! \brief Group-level EP configuration (fixed for the EP group lifetime). */ typedef struct { + /*! Struct size in bytes, or 0 for the base layout. Set to + * sizeof(NVTEEpGroupConfig) to include fields added in newer versions. */ + size_t struct_size; /*! EP world size. */ int ep_size; /*! Total experts across all ranks. */ @@ -46,7 +51,8 @@ typedef struct { /*! Max SMs for EP kernels. 0 = auto. */ int max_num_sms; /*! Widest token dtype the group will dispatch; sizes staging buffers. - * Per-dispatch tensors may use any dtype with element size <= this. */ + * Required (no default): must be set to a real token dtype. Per-dispatch + * tensors may use any dtype with element size <= this. */ NVTEDType max_token_dtype; /*! Zero-copy dispatch/combine. When nonzero, payload tensors must be backed * by NVTECommWindow handles and transfer in place (no staging copies); @@ -59,6 +65,9 @@ typedef struct { * overflow policy, ...). */ typedef struct { + /*! Struct size in bytes, or 0 for the base layout. Set to + * sizeof(NVTEEpLayerConfig) to include fields added in newer versions. */ + size_t struct_size; /*! Per-token expert fan-out (> 0). */ int top_k; /*! Per-expert recv-slab alignment in tokens (power of two; 0/1 disables). @@ -67,6 +76,14 @@ typedef struct { size_t dispatch_output_per_expert_alignment; } NVTEEpLayerConfig; +/* Zero-init a config with struct_size set to the current layout: + * NVTEEpGroupConfig cfg = NVTE_EP_GROUP_CONFIG_INIT; + * cfg.ep_size = ...; */ +#define NVTE_EP_GROUP_CONFIG_INIT \ + { sizeof(NVTEEpGroupConfig) } +#define NVTE_EP_LAYER_CONFIG_INIT \ + { sizeof(NVTEEpLayerConfig) } + /* -- Bootstrap ------------------------------------------------------------ */ /*! \brief Bootstrap the EP backend from an existing NCCL EP sub-communicator. @@ -78,9 +95,9 @@ typedef struct { * group per process, bound to the current CUDA device. * * \param[in] ep_comm Opaque ncclComm_t for the EP sub-group. - * \param[in] group_config Group-level EP configuration. + * \param[in] group_config Group-level EP configuration (struct_size set). */ -void nvte_ep_initialize(void* ep_comm, NVTEEpGroupConfig group_config); +void nvte_ep_initialize(void* ep_comm, const NVTEEpGroupConfig* group_config); /*! \brief Tear down the EP backend. Idempotent. Does not destroy ep_comm. */ void nvte_ep_shutdown(void); @@ -94,10 +111,10 @@ void nvte_ep_shutdown(void); * for that layer (the backend keys its cache on the pointer). Host-only; * size is stable for a given (group, layer) pair. * - * \param[in] layer_cfg Per-call layer configuration. + * \param[in] layer_cfg Per-call layer configuration (struct_size set). * \return size in bytes for the handle_mem buffer. */ -size_t nvte_ep_handle_mem_size(NVTEEpLayerConfig layer_cfg); +size_t nvte_ep_handle_mem_size(const NVTEEpLayerConfig* layer_cfg); /* -- Per-step ops (all allocation-free, CUDA graph-capturable) ------------ */ @@ -112,11 +129,11 @@ size_t nvte_ep_handle_mem_size(NVTEEpLayerConfig layer_cfg); * \param[in] handle_mem uint8 routing-state buffer. * \param[in] topk_idx [T, top_k] int64 routing indices. * \param[out] token_counts [num_local_experts] int32 counts. - * \param[in] layer_cfg Per-call layer configuration. + * \param[in] layer_cfg Per-call layer configuration (struct_size set). * \param[in] stream CUDA stream. */ void nvte_ep_prepare(NVTETensor handle_mem, NVTETensor topk_idx, NVTETensor token_counts, - NVTEEpLayerConfig layer_cfg, cudaStream_t stream); + const NVTEEpLayerConfig* layer_cfg, cudaStream_t stream); /*! \brief Dispatch tokens (and routing weights) to expert ranks. * diff --git a/transformer_engine/jax/csrc/extensions/ep.cpp b/transformer_engine/jax/csrc/extensions/ep.cpp index ee204e7594..370066e6ff 100644 --- a/transformer_engine/jax/csrc/extensions/ep.cpp +++ b/transformer_engine/jax/csrc/extensions/ep.cpp @@ -45,7 +45,8 @@ class EpResources { NVTE_CHECK_NCCL(ncclCommInitRank(&comm_, p.ep_size, uid, p.rank_within_group)); // zero_copy=0: JAX EP path always stages payloads; the zero-copy fast path // requires NVTECommWindow-backed tensors, which JAX bindings don't expose. - NVTEEpGroupConfig cfg{.ep_size = p.ep_size, + NVTEEpGroupConfig cfg{.struct_size = sizeof(NVTEEpGroupConfig), + .ep_size = p.ep_size, .num_experts = p.num_experts, .max_tokens_per_rank = p.max_tokens_per_rank, .max_recv_tokens_per_rank = p.max_recv_tokens_per_rank, @@ -54,7 +55,7 @@ class EpResources { .max_token_dtype = p.max_token_dtype, .zero_copy = 0}; try { - nvte_ep_initialize(static_cast(comm_), cfg); + nvte_ep_initialize(static_cast(comm_), &cfg); } catch (...) { ncclCommDestroy(comm_); comm_ = nullptr; @@ -162,8 +163,11 @@ void ReleaseEpResources() { } size_t EpHandleMemSize(int top_k, size_t dispatch_output_per_expert_alignment) { - NVTEEpLayerConfig layer_cfg{top_k, dispatch_output_per_expert_alignment}; - return nvte_ep_handle_mem_size(layer_cfg); + NVTEEpLayerConfig layer_cfg{.struct_size = sizeof(NVTEEpLayerConfig), + .top_k = top_k, + .dispatch_output_per_expert_alignment = + dispatch_output_per_expert_alignment}; + return nvte_ep_handle_mem_size(&layer_cfg); } pybind11::capsule GetEpInstanceStateTypeIdCapsule() { @@ -224,9 +228,12 @@ Error_Type EpPrepareFFI(cudaStream_t stream, EpInstanceState* ep_state, Buffer_T std::vector hm_shape = {static_cast(handle_mem->element_count())}; auto handle_mem_ = TensorWrapper(handle_mem->untyped_data(), hm_shape, DType::kByte); - NVTEEpLayerConfig layer_cfg{static_cast(config.top_k), - static_cast(config.dispatch_output_per_expert_alignment)}; - nvte_ep_prepare(handle_mem_.data(), topk_idx_.data(), token_counts_.data(), layer_cfg, stream); + NVTEEpLayerConfig layer_cfg{ + .struct_size = sizeof(NVTEEpLayerConfig), + .top_k = static_cast(config.top_k), + .dispatch_output_per_expert_alignment = + static_cast(config.dispatch_output_per_expert_alignment)}; + nvte_ep_prepare(handle_mem_.data(), topk_idx_.data(), token_counts_.data(), &layer_cfg, stream); return ffi_with_cuda_error_check(); } From ed3d74056a9ba2c92d6a26e01b028e891a35e514 Mon Sep 17 00:00:00 2001 From: Phuong Nguyen Date: Mon, 29 Jun 2026 03:36:25 -0700 Subject: [PATCH 2/8] Rename EP prepare token_counts to recv_tokens_per_expert Signed-off-by: Phuong Nguyen --- tests/cpp_distributed/test_ep.cu | 34 +++++++++---------- transformer_engine/common/ep/ep_api.cpp | 6 ++-- transformer_engine/common/ep/ep_backend.cpp | 12 +++---- transformer_engine/common/ep/ep_backend.h | 2 +- .../common/include/transformer_engine/ep.h | 18 +++++----- transformer_engine/jax/csrc/extensions/ep.cpp | 10 +++--- 6 files changed, 41 insertions(+), 41 deletions(-) diff --git a/tests/cpp_distributed/test_ep.cu b/tests/cpp_distributed/test_ep.cu index aabbad6b8c..98796ab091 100644 --- a/tests/cpp_distributed/test_ep.cu +++ b/tests/cpp_distributed/test_ep.cu @@ -60,7 +60,7 @@ static std::vector generate_tokens(int rank, int num_tokens, int hidden_dim) return v; } -static std::vector expected_token_counts( +static std::vector expected_recv_tokens_per_expert( int recv_rank, int num_processes, int num_tokens, int top_k, int num_experts, int num_local_experts) { int base = recv_rank * num_local_experts; @@ -128,7 +128,7 @@ struct EPBuffers { DevBuf topk_idx; DevBuf topk_weights; DevBuf tokens; - DevBuf token_counts; + DevBuf recv_tokens_per_expert; DevBuf handle_mem; DevBuf recv_tokens; DevBuf recv_topk_weights; @@ -158,7 +158,7 @@ struct EPBuffers { topk_idx.alloc(num_tokens * top_k); topk_weights.alloc(num_tokens * top_k); tokens.alloc(num_tokens * hidden_dim); - token_counts.alloc(num_local_experts); + recv_tokens_per_expert.alloc(num_local_experts); recv_tokens.alloc(recv_capacity * hidden_dim); recv_topk_weights.alloc(recv_capacity); result.alloc(num_tokens * hidden_dim); @@ -178,7 +178,7 @@ struct EPBuffers { // expects. template struct EPTensors { - TensorWrapper topk_idx, topk_weights, token_counts, handle_mem, tokens; + TensorWrapper topk_idx, topk_weights, recv_tokens_per_expert, handle_mem, tokens; TensorWrapper recv_tokens, recv_topk_weights, result; TensorWrapper grad_result, grad_expert, grad_tokens; TensorWrapper g_recv_topk_weights, grad_topk_weights; @@ -200,7 +200,7 @@ struct EPTensors { Shape{(size_t)num_tokens, (size_t)top_k}, DType::kInt64); topk_weights = TensorWrapper(b.topk_weights.get(), Shape{(size_t)num_tokens, (size_t)top_k}, DType::kFloat32); - token_counts = TensorWrapper(b.token_counts.get(), + recv_tokens_per_expert = TensorWrapper(b.recv_tokens_per_expert.get(), Shape{(size_t)num_local_experts}, DType::kInt32); handle_mem = TensorWrapper(b.handle_mem.get(), Shape{b.handle_mem_size}, DType::kByte); @@ -267,7 +267,7 @@ class EpOpTestBase : public ::testing::Test { template int read_total_recv(const EPBuffers& buf) const { std::vector cnt(num_local_experts_); - NVTE_CHECK_CUDA(cudaMemcpy(cnt.data(), buf.token_counts.get(), + NVTE_CHECK_CUDA(cudaMemcpy(cnt.data(), buf.recv_tokens_per_expert.get(), num_local_experts_ * sizeof(int32_t), cudaMemcpyDeviceToHost)); int total = 0; for (int c : cnt) total += c; @@ -308,7 +308,7 @@ TYPED_TEST(EPDispatchTest, PrepareAndDispatch) { cudaStream_t stream; NVTE_CHECK_CUDA(cudaStreamCreate(&stream)); - ASSERT_NO_THROW(nvte_ep_prepare(t.handle_mem.data(), t.topk_idx.data(), t.token_counts.data(), &t.layer_cfg_, stream)); + ASSERT_NO_THROW(nvte_ep_prepare(t.handle_mem.data(), t.topk_idx.data(), t.recv_tokens_per_expert.data(), &t.layer_cfg_, stream)); ASSERT_NO_THROW(nvte_ep_dispatch(t.handle_mem.data(), t.topk_idx.data(), t.tokens.data(), NVTECommWindow{}, t.topk_weights.data(), NVTECommWindow{}, t.recv_tokens.data(), NVTECommWindow{}, @@ -317,9 +317,9 @@ TYPED_TEST(EPDispatchTest, PrepareAndDispatch) { // 1. Per-expert counts. std::vector got_counts(num_local_experts_); - NVTE_CHECK_CUDA(cudaMemcpy(got_counts.data(), buf.token_counts.get(), + NVTE_CHECK_CUDA(cudaMemcpy(got_counts.data(), buf.recv_tokens_per_expert.get(), num_local_experts_ * sizeof(int32_t), cudaMemcpyDeviceToHost)); - auto exp_counts = expected_token_counts(g_process_id, g_num_processes, num_tokens_, top_k_, + auto exp_counts = expected_recv_tokens_per_expert(g_process_id, g_num_processes, num_tokens_, top_k_, num_experts_, num_local_experts_); int total_recv = 0; for (int i = 0; i < num_local_experts_; ++i) { @@ -387,7 +387,7 @@ TYPED_TEST(EPCombineTest, Combine) { cudaStream_t stream; NVTE_CHECK_CUDA(cudaStreamCreate(&stream)); - ASSERT_NO_THROW(nvte_ep_prepare(t.handle_mem.data(), t.topk_idx.data(), t.token_counts.data(), &t.layer_cfg_, stream)); + ASSERT_NO_THROW(nvte_ep_prepare(t.handle_mem.data(), t.topk_idx.data(), t.recv_tokens_per_expert.data(), &t.layer_cfg_, stream)); ASSERT_NO_THROW(nvte_ep_dispatch(t.handle_mem.data(), t.topk_idx.data(), t.tokens.data(), NVTECommWindow{}, t.topk_weights.data(), NVTECommWindow{}, t.recv_tokens.data(), NVTECommWindow{}, @@ -434,7 +434,7 @@ TYPED_TEST(EPCombineBwdTest, CombineBwdCheck) { cudaStream_t stream; NVTE_CHECK_CUDA(cudaStreamCreate(&stream)); - ASSERT_NO_THROW(nvte_ep_prepare(t.handle_mem.data(), t.topk_idx.data(), t.token_counts.data(), &t.layer_cfg_, stream)); + ASSERT_NO_THROW(nvte_ep_prepare(t.handle_mem.data(), t.topk_idx.data(), t.recv_tokens_per_expert.data(), &t.layer_cfg_, stream)); ASSERT_NO_THROW(nvte_ep_dispatch(t.handle_mem.data(), t.topk_idx.data(), t.tokens.data(), NVTECommWindow{}, t.topk_weights.data(), NVTECommWindow{}, t.recv_tokens.data(), NVTECommWindow{}, @@ -455,7 +455,7 @@ TYPED_TEST(EPCombineBwdTest, CombineBwdCheck) { int total_recv = this->template read_total_recv(buf); std::vector cnt(num_local_experts_); - NVTE_CHECK_CUDA(cudaMemcpy(cnt.data(), buf.token_counts.get(), + NVTE_CHECK_CUDA(cudaMemcpy(cnt.data(), buf.recv_tokens_per_expert.get(), num_local_experts_ * sizeof(int32_t), cudaMemcpyDeviceToHost)); std::vector h_ge(buf.recv_capacity * hidden_dim_); NVTE_CHECK_CUDA(cudaMemcpy(h_ge.data(), buf.grad_expert.get(), @@ -503,7 +503,7 @@ TYPED_TEST(EPDispatchBwdTest, DispatchBwdCheck) { cudaStream_t stream; NVTE_CHECK_CUDA(cudaStreamCreate(&stream)); - ASSERT_NO_THROW(nvte_ep_prepare(t.handle_mem.data(), t.topk_idx.data(), t.token_counts.data(), &t.layer_cfg_, stream)); + ASSERT_NO_THROW(nvte_ep_prepare(t.handle_mem.data(), t.topk_idx.data(), t.recv_tokens_per_expert.data(), &t.layer_cfg_, stream)); ASSERT_NO_THROW(nvte_ep_dispatch(t.handle_mem.data(), t.topk_idx.data(), t.tokens.data(), NVTECommWindow{}, t.topk_weights.data(), NVTECommWindow{}, t.recv_tokens.data(), NVTECommWindow{}, @@ -571,7 +571,7 @@ TYPED_TEST(EPDispatchBwdGradWeightsTest, RoundTrip) { cudaStream_t stream; NVTE_CHECK_CUDA(cudaStreamCreate(&stream)); - ASSERT_NO_THROW(nvte_ep_prepare(t.handle_mem.data(), t.topk_idx.data(), t.token_counts.data(), &t.layer_cfg_, stream)); + ASSERT_NO_THROW(nvte_ep_prepare(t.handle_mem.data(), t.topk_idx.data(), t.recv_tokens_per_expert.data(), &t.layer_cfg_, stream)); NVTE_CHECK_CUDA(cudaMemsetAsync(buf.recv_topk_weights.get(), 0, buf.recv_topk_weights.bytes(), stream)); ASSERT_NO_THROW(nvte_ep_dispatch(t.handle_mem.data(), t.topk_idx.data(), @@ -642,7 +642,7 @@ class EPPipelineTest : public EpOpTestBase, public ::testing::WithParamInterface cudaStream_t stream; NVTE_CHECK_CUDA(cudaStreamCreate(&stream)); - ASSERT_NO_THROW(nvte_ep_prepare(t.handle_mem.data(), t.topk_idx.data(), t.token_counts.data(), &t.layer_cfg_, stream)); + ASSERT_NO_THROW(nvte_ep_prepare(t.handle_mem.data(), t.topk_idx.data(), t.recv_tokens_per_expert.data(), &t.layer_cfg_, stream)); ASSERT_NO_THROW(nvte_ep_dispatch(t.handle_mem.data(), t.topk_idx.data(), t.tokens.data(), NVTECommWindow{}, t.topk_weights.data(), NVTECommWindow{}, t.recv_tokens.data(), NVTECommWindow{}, @@ -767,7 +767,7 @@ TYPED_TEST(EPZeroCopyTest, IdentityAllSymm) { cudaStream_t stream; NVTE_CHECK_CUDA(cudaStreamCreate(&stream)); - ASSERT_NO_THROW(nvte_ep_prepare(ref_t.handle_mem.data(), ref_t.topk_idx.data(), ref_t.token_counts.data(), &ref_t.layer_cfg_, stream)); + ASSERT_NO_THROW(nvte_ep_prepare(ref_t.handle_mem.data(), ref_t.topk_idx.data(), ref_t.recv_tokens_per_expert.data(), &ref_t.layer_cfg_, stream)); ASSERT_NO_THROW(nvte_ep_dispatch(ref_t.handle_mem.data(), ref_t.topk_idx.data(), ref_t.tokens.data(), NVTECommWindow{}, ref_t.topk_weights.data(), NVTECommWindow{}, ref_t.recv_tokens.data(), NVTECommWindow{}, @@ -808,7 +808,7 @@ TYPED_TEST(EPZeroCopyTest, IdentityAllSymm) { sym_t.recv_tokens = TensorWrapper(sym_recv.ptr, std::vector{sym_buf.recv_capacity, (size_t)hidden_dim_}, kTokDType); - ASSERT_NO_THROW(nvte_ep_prepare(sym_t.handle_mem.data(), sym_t.topk_idx.data(), sym_t.token_counts.data(), &sym_t.layer_cfg_, stream)); + ASSERT_NO_THROW(nvte_ep_prepare(sym_t.handle_mem.data(), sym_t.topk_idx.data(), sym_t.recv_tokens_per_expert.data(), &sym_t.layer_cfg_, stream)); ASSERT_NO_THROW(nvte_ep_dispatch(sym_t.handle_mem.data(), sym_t.topk_idx.data(), sym_t.tokens.data(), symm_window(sym_tokens), sym_t.topk_weights.data(), NVTECommWindow{}, diff --git a/transformer_engine/common/ep/ep_api.cpp b/transformer_engine/common/ep/ep_api.cpp index fde4882806..ff20af8d0a 100644 --- a/transformer_engine/common/ep/ep_api.cpp +++ b/transformer_engine/common/ep/ep_api.cpp @@ -74,11 +74,11 @@ size_t nvte_ep_handle_mem_size(const NVTEEpLayerConfig* layer_cfg) { return EPBackend::get().handle_mem_size(cfg); } -void nvte_ep_prepare(NVTETensor handle_mem, NVTETensor topk_idx, NVTETensor token_counts, +void nvte_ep_prepare(NVTETensor handle_mem, NVTETensor topk_idx, NVTETensor recv_tokens_per_expert, const NVTEEpLayerConfig* layer_cfg, cudaStream_t stream) { NVTEEpLayerConfig cfg = normalize_ep_config(layer_cfg, kLayerConfigMinSize, "layer_cfg"); - EPBackend::get().prepare(handle_mem_ptr(handle_mem), topk_idx, token_counts, cfg, stream); + EPBackend::get().prepare(handle_mem_ptr(handle_mem), topk_idx, recv_tokens_per_expert, cfg, stream); } void nvte_ep_dispatch(NVTETensor handle_mem, NVTETensor topk_idx, NVTETensor tokens, @@ -130,7 +130,7 @@ void nvte_ep_shutdown(void) {} size_t nvte_ep_handle_mem_size(const NVTEEpLayerConfig* /*layer_cfg*/) { ep_not_built(); } void nvte_ep_prepare(NVTETensor /*handle_mem*/, NVTETensor /*topk_idx*/, - NVTETensor /*token_counts*/, const NVTEEpLayerConfig* /*layer_cfg*/, + NVTETensor /*recv_tokens_per_expert*/, const NVTEEpLayerConfig* /*layer_cfg*/, cudaStream_t /*stream*/) { ep_not_built(); } diff --git a/transformer_engine/common/ep/ep_backend.cpp b/transformer_engine/common/ep/ep_backend.cpp index d20364ac66..0699cd9b8c 100644 --- a/transformer_engine/common/ep/ep_backend.cpp +++ b/transformer_engine/common/ep/ep_backend.cpp @@ -317,7 +317,7 @@ size_t EPBackend::handle_mem_size(NVTEEpLayerConfig layer_cfg) { return hm_size; } -void EPBackend::prepare(void* handle_mem, const NVTETensor topk_idx, NVTETensor token_counts, +void EPBackend::prepare(void* handle_mem, const NVTETensor topk_idx, NVTETensor recv_tokens_per_expert, NVTEEpLayerConfig layer_cfg, cudaStream_t stream) { NVTE_CHECK(handle_mem != nullptr, "handle_mem must not be null"); NVTE_CHECK(layer_cfg.top_k > 0, "top_k must be > 0, got ", layer_cfg.top_k); @@ -327,13 +327,13 @@ void EPBackend::prepare(void* handle_mem, const NVTETensor topk_idx, NVTETensor ncclEpTensor_t nccl_topk_idx = make_nccl_ep_tensor(topk_idx, topk_idx_shape); // ncclEpUpdateHandle writes per-expert counts via expert_counters. - NVTEShape token_counts_shape; - ncclEpTensor_t token_counts_desc; - if (token_counts != nullptr) { - token_counts_desc = make_nccl_ep_tensor(token_counts, token_counts_shape); + NVTEShape recv_tokens_per_expert_shape; + ncclEpTensor_t recv_tokens_per_expert_desc; + if (recv_tokens_per_expert != nullptr) { + recv_tokens_per_expert_desc = make_nccl_ep_tensor(recv_tokens_per_expert, recv_tokens_per_expert_shape); } ncclEpLayoutInfo_t layout_info = NCCL_EP_LAYOUT_INFO_INIT; - layout_info.expert_counters = (token_counts != nullptr) ? &token_counts_desc : nullptr; + layout_info.expert_counters = (recv_tokens_per_expert != nullptr) ? &recv_tokens_per_expert_desc : nullptr; std::lock_guard lock(mutex_); NVTE_CHECK(initialized_, "EPBackend not initialized"); diff --git a/transformer_engine/common/ep/ep_backend.h b/transformer_engine/common/ep/ep_backend.h index 2325baafca..bbf7bae1f8 100644 --- a/transformer_engine/common/ep/ep_backend.h +++ b/transformer_engine/common/ep/ep_backend.h @@ -46,7 +46,7 @@ class EPBackend { size_t handle_mem_size(NVTEEpLayerConfig layer_cfg); // Seeds the cache for handle_mem with layer_cfg and runs the routing AllGather. - void prepare(void* handle_mem, const NVTETensor topk_idx, NVTETensor token_counts, + void prepare(void* handle_mem, const NVTETensor topk_idx, NVTETensor recv_tokens_per_expert, NVTEEpLayerConfig layer_cfg, cudaStream_t stream); // Per-step ops below require a prior prepare(). diff --git a/transformer_engine/common/include/transformer_engine/ep.h b/transformer_engine/common/include/transformer_engine/ep.h index e03323cd65..59d424adbe 100644 --- a/transformer_engine/common/include/transformer_engine/ep.h +++ b/transformer_engine/common/include/transformer_engine/ep.h @@ -123,16 +123,16 @@ size_t nvte_ep_handle_mem_size(const NVTEEpLayerConfig* layer_cfg); * AllGathers topk_idx across the EP group and stages per-expert offsets and * counts into handle_mem so the matching dispatch/combine/_bwd can run with * no further routing computation. Must precede every dispatch/combine/_bwd - * that uses this handle_mem. token_counts becomes host-valid after a stream - * sync. - * - * \param[in] handle_mem uint8 routing-state buffer. - * \param[in] topk_idx [T, top_k] int64 routing indices. - * \param[out] token_counts [num_local_experts] int32 counts. - * \param[in] layer_cfg Per-call layer configuration (struct_size set). - * \param[in] stream CUDA stream. + * that uses this handle_mem. recv_tokens_per_expert becomes host-valid after a + * stream sync. + * + * \param[in] handle_mem uint8 routing-state buffer. + * \param[in] topk_idx [T, top_k] int64 routing indices. + * \param[out] recv_tokens_per_expert [num_local_experts] int32 counts. + * \param[in] layer_cfg Per-call layer configuration (struct_size set). + * \param[in] stream CUDA stream. */ -void nvte_ep_prepare(NVTETensor handle_mem, NVTETensor topk_idx, NVTETensor token_counts, +void nvte_ep_prepare(NVTETensor handle_mem, NVTETensor topk_idx, NVTETensor recv_tokens_per_expert, const NVTEEpLayerConfig* layer_cfg, cudaStream_t stream); /*! \brief Dispatch tokens (and routing weights) to expert ranks. diff --git a/transformer_engine/jax/csrc/extensions/ep.cpp b/transformer_engine/jax/csrc/extensions/ep.cpp index 370066e6ff..399f47e842 100644 --- a/transformer_engine/jax/csrc/extensions/ep.cpp +++ b/transformer_engine/jax/csrc/extensions/ep.cpp @@ -196,7 +196,7 @@ XLA_FFI_DEFINE_HANDLER_SYMBOL(EpInstantiateHandler, EpInstantiateImpl, FFI::Bind // ── ep_prepare ──────────────────────────────────────────────────────────────── Error_Type EpPrepareFFI(cudaStream_t stream, EpInstanceState* ep_state, Buffer_Type topk_idx, - Result_Type token_counts, Result_Type handle_mem, Result_Type workspace, + Result_Type recv_tokens_per_expert, Result_Type handle_mem, Result_Type workspace, EpConfig config) { (void)ep_state; // lifetime only. auto topk_dims = topk_idx.dimensions(); @@ -222,8 +222,8 @@ Error_Type EpPrepareFFI(cudaStream_t stream, EpInstanceState* ep_state, Buffer_T } auto topk_idx_ = TensorWrapper(topk_idx_data, topk_shape, DType::kInt64); - std::vector tc_shape = {static_cast(token_counts->element_count())}; - auto token_counts_ = TensorWrapper(token_counts->untyped_data(), tc_shape, DType::kInt32); + std::vector tc_shape = {static_cast(recv_tokens_per_expert->element_count())}; + auto recv_tokens_per_expert_ = TensorWrapper(recv_tokens_per_expert->untyped_data(), tc_shape, DType::kInt32); std::vector hm_shape = {static_cast(handle_mem->element_count())}; auto handle_mem_ = TensorWrapper(handle_mem->untyped_data(), hm_shape, DType::kByte); @@ -233,7 +233,7 @@ Error_Type EpPrepareFFI(cudaStream_t stream, EpInstanceState* ep_state, Buffer_T .top_k = static_cast(config.top_k), .dispatch_output_per_expert_alignment = static_cast(config.dispatch_output_per_expert_alignment)}; - nvte_ep_prepare(handle_mem_.data(), topk_idx_.data(), token_counts_.data(), &layer_cfg, stream); + nvte_ep_prepare(handle_mem_.data(), topk_idx_.data(), recv_tokens_per_expert_.data(), &layer_cfg, stream); return ffi_with_cuda_error_check(); } @@ -242,7 +242,7 @@ XLA_FFI_DEFINE_HANDLER_SYMBOL(EpPrepareHandler, EpPrepareFFI, .Ctx() // stream .Ctx<::xla::ffi::State>() // EP state .Arg() // topk_idx - .Ret() // token_counts + .Ret() // recv_tokens_per_expert .Ret() // handle_mem .Ret() // workspace (FFI scratch) .Attrs(), From 1dbbddb5c180cb42a922e67da697a0d35c3df9c5 Mon Sep 17 00:00:00 2001 From: Phuong Nguyen Date: Mon, 29 Jun 2026 03:37:55 -0700 Subject: [PATCH 3/8] Add total_recv_tokens_per_rank placeholder to nvte_ep_prepare Signed-off-by: Phuong Nguyen --- tests/cpp_distributed/test_ep.cu | 16 ++++++++-------- transformer_engine/common/ep/ep_api.cpp | 11 +++++++---- transformer_engine/common/ep/ep_backend.cpp | 4 +++- transformer_engine/common/ep/ep_backend.h | 3 ++- .../common/include/transformer_engine/ep.h | 14 ++++++++------ transformer_engine/jax/csrc/extensions/ep.cpp | 3 ++- 6 files changed, 30 insertions(+), 21 deletions(-) diff --git a/tests/cpp_distributed/test_ep.cu b/tests/cpp_distributed/test_ep.cu index 98796ab091..7dbbcdce9d 100644 --- a/tests/cpp_distributed/test_ep.cu +++ b/tests/cpp_distributed/test_ep.cu @@ -308,7 +308,7 @@ TYPED_TEST(EPDispatchTest, PrepareAndDispatch) { cudaStream_t stream; NVTE_CHECK_CUDA(cudaStreamCreate(&stream)); - ASSERT_NO_THROW(nvte_ep_prepare(t.handle_mem.data(), t.topk_idx.data(), t.recv_tokens_per_expert.data(), &t.layer_cfg_, stream)); + ASSERT_NO_THROW(nvte_ep_prepare(t.handle_mem.data(), t.topk_idx.data(), t.recv_tokens_per_expert.data(), nullptr, &t.layer_cfg_, stream)); ASSERT_NO_THROW(nvte_ep_dispatch(t.handle_mem.data(), t.topk_idx.data(), t.tokens.data(), NVTECommWindow{}, t.topk_weights.data(), NVTECommWindow{}, t.recv_tokens.data(), NVTECommWindow{}, @@ -387,7 +387,7 @@ TYPED_TEST(EPCombineTest, Combine) { cudaStream_t stream; NVTE_CHECK_CUDA(cudaStreamCreate(&stream)); - ASSERT_NO_THROW(nvte_ep_prepare(t.handle_mem.data(), t.topk_idx.data(), t.recv_tokens_per_expert.data(), &t.layer_cfg_, stream)); + ASSERT_NO_THROW(nvte_ep_prepare(t.handle_mem.data(), t.topk_idx.data(), t.recv_tokens_per_expert.data(), nullptr, &t.layer_cfg_, stream)); ASSERT_NO_THROW(nvte_ep_dispatch(t.handle_mem.data(), t.topk_idx.data(), t.tokens.data(), NVTECommWindow{}, t.topk_weights.data(), NVTECommWindow{}, t.recv_tokens.data(), NVTECommWindow{}, @@ -434,7 +434,7 @@ TYPED_TEST(EPCombineBwdTest, CombineBwdCheck) { cudaStream_t stream; NVTE_CHECK_CUDA(cudaStreamCreate(&stream)); - ASSERT_NO_THROW(nvte_ep_prepare(t.handle_mem.data(), t.topk_idx.data(), t.recv_tokens_per_expert.data(), &t.layer_cfg_, stream)); + ASSERT_NO_THROW(nvte_ep_prepare(t.handle_mem.data(), t.topk_idx.data(), t.recv_tokens_per_expert.data(), nullptr, &t.layer_cfg_, stream)); ASSERT_NO_THROW(nvte_ep_dispatch(t.handle_mem.data(), t.topk_idx.data(), t.tokens.data(), NVTECommWindow{}, t.topk_weights.data(), NVTECommWindow{}, t.recv_tokens.data(), NVTECommWindow{}, @@ -503,7 +503,7 @@ TYPED_TEST(EPDispatchBwdTest, DispatchBwdCheck) { cudaStream_t stream; NVTE_CHECK_CUDA(cudaStreamCreate(&stream)); - ASSERT_NO_THROW(nvte_ep_prepare(t.handle_mem.data(), t.topk_idx.data(), t.recv_tokens_per_expert.data(), &t.layer_cfg_, stream)); + ASSERT_NO_THROW(nvte_ep_prepare(t.handle_mem.data(), t.topk_idx.data(), t.recv_tokens_per_expert.data(), nullptr, &t.layer_cfg_, stream)); ASSERT_NO_THROW(nvte_ep_dispatch(t.handle_mem.data(), t.topk_idx.data(), t.tokens.data(), NVTECommWindow{}, t.topk_weights.data(), NVTECommWindow{}, t.recv_tokens.data(), NVTECommWindow{}, @@ -571,7 +571,7 @@ TYPED_TEST(EPDispatchBwdGradWeightsTest, RoundTrip) { cudaStream_t stream; NVTE_CHECK_CUDA(cudaStreamCreate(&stream)); - ASSERT_NO_THROW(nvte_ep_prepare(t.handle_mem.data(), t.topk_idx.data(), t.recv_tokens_per_expert.data(), &t.layer_cfg_, stream)); + ASSERT_NO_THROW(nvte_ep_prepare(t.handle_mem.data(), t.topk_idx.data(), t.recv_tokens_per_expert.data(), nullptr, &t.layer_cfg_, stream)); NVTE_CHECK_CUDA(cudaMemsetAsync(buf.recv_topk_weights.get(), 0, buf.recv_topk_weights.bytes(), stream)); ASSERT_NO_THROW(nvte_ep_dispatch(t.handle_mem.data(), t.topk_idx.data(), @@ -642,7 +642,7 @@ class EPPipelineTest : public EpOpTestBase, public ::testing::WithParamInterface cudaStream_t stream; NVTE_CHECK_CUDA(cudaStreamCreate(&stream)); - ASSERT_NO_THROW(nvte_ep_prepare(t.handle_mem.data(), t.topk_idx.data(), t.recv_tokens_per_expert.data(), &t.layer_cfg_, stream)); + ASSERT_NO_THROW(nvte_ep_prepare(t.handle_mem.data(), t.topk_idx.data(), t.recv_tokens_per_expert.data(), nullptr, &t.layer_cfg_, stream)); ASSERT_NO_THROW(nvte_ep_dispatch(t.handle_mem.data(), t.topk_idx.data(), t.tokens.data(), NVTECommWindow{}, t.topk_weights.data(), NVTECommWindow{}, t.recv_tokens.data(), NVTECommWindow{}, @@ -767,7 +767,7 @@ TYPED_TEST(EPZeroCopyTest, IdentityAllSymm) { cudaStream_t stream; NVTE_CHECK_CUDA(cudaStreamCreate(&stream)); - ASSERT_NO_THROW(nvte_ep_prepare(ref_t.handle_mem.data(), ref_t.topk_idx.data(), ref_t.recv_tokens_per_expert.data(), &ref_t.layer_cfg_, stream)); + ASSERT_NO_THROW(nvte_ep_prepare(ref_t.handle_mem.data(), ref_t.topk_idx.data(), ref_t.recv_tokens_per_expert.data(), nullptr, &ref_t.layer_cfg_, stream)); ASSERT_NO_THROW(nvte_ep_dispatch(ref_t.handle_mem.data(), ref_t.topk_idx.data(), ref_t.tokens.data(), NVTECommWindow{}, ref_t.topk_weights.data(), NVTECommWindow{}, ref_t.recv_tokens.data(), NVTECommWindow{}, @@ -808,7 +808,7 @@ TYPED_TEST(EPZeroCopyTest, IdentityAllSymm) { sym_t.recv_tokens = TensorWrapper(sym_recv.ptr, std::vector{sym_buf.recv_capacity, (size_t)hidden_dim_}, kTokDType); - ASSERT_NO_THROW(nvte_ep_prepare(sym_t.handle_mem.data(), sym_t.topk_idx.data(), sym_t.recv_tokens_per_expert.data(), &sym_t.layer_cfg_, stream)); + ASSERT_NO_THROW(nvte_ep_prepare(sym_t.handle_mem.data(), sym_t.topk_idx.data(), sym_t.recv_tokens_per_expert.data(), nullptr, &sym_t.layer_cfg_, stream)); ASSERT_NO_THROW(nvte_ep_dispatch(sym_t.handle_mem.data(), sym_t.topk_idx.data(), sym_t.tokens.data(), symm_window(sym_tokens), sym_t.topk_weights.data(), NVTECommWindow{}, diff --git a/transformer_engine/common/ep/ep_api.cpp b/transformer_engine/common/ep/ep_api.cpp index ff20af8d0a..8402506e5d 100644 --- a/transformer_engine/common/ep/ep_api.cpp +++ b/transformer_engine/common/ep/ep_api.cpp @@ -75,10 +75,12 @@ size_t nvte_ep_handle_mem_size(const NVTEEpLayerConfig* layer_cfg) { } void nvte_ep_prepare(NVTETensor handle_mem, NVTETensor topk_idx, NVTETensor recv_tokens_per_expert, - const NVTEEpLayerConfig* layer_cfg, cudaStream_t stream) { + NVTETensor total_recv_tokens_per_rank, const NVTEEpLayerConfig* layer_cfg, + cudaStream_t stream) { NVTEEpLayerConfig cfg = normalize_ep_config(layer_cfg, kLayerConfigMinSize, "layer_cfg"); - EPBackend::get().prepare(handle_mem_ptr(handle_mem), topk_idx, recv_tokens_per_expert, cfg, stream); + EPBackend::get().prepare(handle_mem_ptr(handle_mem), topk_idx, recv_tokens_per_expert, + total_recv_tokens_per_rank, cfg, stream); } void nvte_ep_dispatch(NVTETensor handle_mem, NVTETensor topk_idx, NVTETensor tokens, @@ -130,8 +132,9 @@ void nvte_ep_shutdown(void) {} size_t nvte_ep_handle_mem_size(const NVTEEpLayerConfig* /*layer_cfg*/) { ep_not_built(); } void nvte_ep_prepare(NVTETensor /*handle_mem*/, NVTETensor /*topk_idx*/, - NVTETensor /*recv_tokens_per_expert*/, const NVTEEpLayerConfig* /*layer_cfg*/, - cudaStream_t /*stream*/) { + NVTETensor /*recv_tokens_per_expert*/, + NVTETensor /*total_recv_tokens_per_rank*/, + const NVTEEpLayerConfig* /*layer_cfg*/, cudaStream_t /*stream*/) { ep_not_built(); } diff --git a/transformer_engine/common/ep/ep_backend.cpp b/transformer_engine/common/ep/ep_backend.cpp index 0699cd9b8c..7c0afd8b8a 100644 --- a/transformer_engine/common/ep/ep_backend.cpp +++ b/transformer_engine/common/ep/ep_backend.cpp @@ -318,7 +318,9 @@ size_t EPBackend::handle_mem_size(NVTEEpLayerConfig layer_cfg) { } void EPBackend::prepare(void* handle_mem, const NVTETensor topk_idx, NVTETensor recv_tokens_per_expert, - NVTEEpLayerConfig layer_cfg, cudaStream_t stream) { + NVTETensor /*total_recv_tokens_per_rank*/, NVTEEpLayerConfig layer_cfg, + cudaStream_t stream) { + // total_recv_tokens_per_rank is a reserved placeholder; not yet populated. NVTE_CHECK(handle_mem != nullptr, "handle_mem must not be null"); NVTE_CHECK(layer_cfg.top_k > 0, "top_k must be > 0, got ", layer_cfg.top_k); NVTE_CHECK(nvte_tensor_shape(topk_idx).ndim == 2, "topk_idx must be 2D [T, top_k]"); diff --git a/transformer_engine/common/ep/ep_backend.h b/transformer_engine/common/ep/ep_backend.h index bbf7bae1f8..80c9b9cea3 100644 --- a/transformer_engine/common/ep/ep_backend.h +++ b/transformer_engine/common/ep/ep_backend.h @@ -47,7 +47,8 @@ class EPBackend { // Seeds the cache for handle_mem with layer_cfg and runs the routing AllGather. void prepare(void* handle_mem, const NVTETensor topk_idx, NVTETensor recv_tokens_per_expert, - NVTEEpLayerConfig layer_cfg, cudaStream_t stream); + NVTETensor total_recv_tokens_per_rank, NVTEEpLayerConfig layer_cfg, + cudaStream_t stream); // Per-step ops below require a prior prepare(). void dispatch(void* handle_mem, const NVTETensor topk_idx, const NVTETensor tokens, diff --git a/transformer_engine/common/include/transformer_engine/ep.h b/transformer_engine/common/include/transformer_engine/ep.h index 59d424adbe..e5cadf5e0d 100644 --- a/transformer_engine/common/include/transformer_engine/ep.h +++ b/transformer_engine/common/include/transformer_engine/ep.h @@ -126,14 +126,16 @@ size_t nvte_ep_handle_mem_size(const NVTEEpLayerConfig* layer_cfg); * that uses this handle_mem. recv_tokens_per_expert becomes host-valid after a * stream sync. * - * \param[in] handle_mem uint8 routing-state buffer. - * \param[in] topk_idx [T, top_k] int64 routing indices. - * \param[out] recv_tokens_per_expert [num_local_experts] int32 counts. - * \param[in] layer_cfg Per-call layer configuration (struct_size set). - * \param[in] stream CUDA stream. + * \param[in] handle_mem uint8 routing-state buffer. + * \param[in] topk_idx [T, top_k] int64 routing indices. + * \param[out] recv_tokens_per_expert [num_local_experts] int32 counts. + * \param[out] total_recv_tokens_per_rank Reserved placeholder; may be null. Unused for now. + * \param[in] layer_cfg Per-call layer configuration (struct_size set). + * \param[in] stream CUDA stream. */ void nvte_ep_prepare(NVTETensor handle_mem, NVTETensor topk_idx, NVTETensor recv_tokens_per_expert, - const NVTEEpLayerConfig* layer_cfg, cudaStream_t stream); + NVTETensor total_recv_tokens_per_rank, const NVTEEpLayerConfig* layer_cfg, + cudaStream_t stream); /*! \brief Dispatch tokens (and routing weights) to expert ranks. * diff --git a/transformer_engine/jax/csrc/extensions/ep.cpp b/transformer_engine/jax/csrc/extensions/ep.cpp index 399f47e842..0f6520a1a5 100644 --- a/transformer_engine/jax/csrc/extensions/ep.cpp +++ b/transformer_engine/jax/csrc/extensions/ep.cpp @@ -233,7 +233,8 @@ Error_Type EpPrepareFFI(cudaStream_t stream, EpInstanceState* ep_state, Buffer_T .top_k = static_cast(config.top_k), .dispatch_output_per_expert_alignment = static_cast(config.dispatch_output_per_expert_alignment)}; - nvte_ep_prepare(handle_mem_.data(), topk_idx_.data(), recv_tokens_per_expert_.data(), &layer_cfg, stream); + nvte_ep_prepare(handle_mem_.data(), topk_idx_.data(), recv_tokens_per_expert_.data(), + /*total_recv_tokens_per_rank=*/nullptr, &layer_cfg, stream); return ffi_with_cuda_error_check(); } From afa0656d8138c8d72146785822122a03b60dc4f7 Mon Sep 17 00:00:00 2001 From: Phuong Nguyen Date: Mon, 29 Jun 2026 09:48:32 -0700 Subject: [PATCH 4/8] Adapt PyTorch EP binding to versioned nvte_ep C config API Signed-off-by: Phuong Nguyen --- .../pytorch/csrc/extensions/ep.cpp | 33 ++++++++++--------- 1 file changed, 18 insertions(+), 15 deletions(-) diff --git a/transformer_engine/pytorch/csrc/extensions/ep.cpp b/transformer_engine/pytorch/csrc/extensions/ep.cpp index d1ef76af40..58b260e979 100644 --- a/transformer_engine/pytorch/csrc/extensions/ep.cpp +++ b/transformer_engine/pytorch/csrc/extensions/ep.cpp @@ -136,16 +136,17 @@ void ep_initialize(uintptr_t comm_ptr, const std::string& group_name, int64_t nu NVTE_CHECK(ncclCommCount(ep_comm, &ep_size) == ncclSuccess, "ncclCommCount failed"); auto torch_dtype = max_token_dtype.cast(); NVTEEpGroupConfig cfg{ - /*ep_size=*/ep_size, - /*num_experts=*/static_cast(num_experts), - /*max_tokens_per_rank=*/static_cast(max_tokens_per_rank), - /*max_recv_tokens_per_rank=*/static_cast(max_recv_tokens_per_rank), - /*hidden_dim=*/static_cast(hidden_dim), - /*max_num_sms=*/static_cast(max_num_sms), - /*max_token_dtype=*/static_cast(GetTransformerEngineDType(torch_dtype)), - /*zero_copy=*/zero_copy ? 1 : 0, + .struct_size = sizeof(NVTEEpGroupConfig), + .ep_size = ep_size, + .num_experts = static_cast(num_experts), + .max_tokens_per_rank = static_cast(max_tokens_per_rank), + .max_recv_tokens_per_rank = static_cast(max_recv_tokens_per_rank), + .hidden_dim = static_cast(hidden_dim), + .max_num_sms = static_cast(max_num_sms), + .max_token_dtype = static_cast(GetTransformerEngineDType(torch_dtype)), + .zero_copy = zero_copy ? 1 : 0, }; - nvte_ep_initialize(static_cast(ep_comm), cfg); + nvte_ep_initialize(static_cast(ep_comm), &cfg); g_zero_copy_enabled.store(zero_copy, std::memory_order_relaxed); g_ep_initialized = true; g_ep_group_name = group_name; @@ -164,17 +165,18 @@ namespace { NVTEEpLayerConfig make_layer_cfg(int64_t top_k, int64_t dispatch_output_per_expert_alignment) { return NVTEEpLayerConfig{ - /*top_k=*/static_cast(top_k), - /*dispatch_output_per_expert_alignment=*/ - static_cast(dispatch_output_per_expert_alignment), + .struct_size = sizeof(NVTEEpLayerConfig), + .top_k = static_cast(top_k), + .dispatch_output_per_expert_alignment = + static_cast(dispatch_output_per_expert_alignment), }; } } // namespace int64_t ep_handle_mem_size(int64_t top_k, int64_t dispatch_output_per_expert_alignment) { - return static_cast( - nvte_ep_handle_mem_size(make_layer_cfg(top_k, dispatch_output_per_expert_alignment))); + auto layer_cfg = make_layer_cfg(top_k, dispatch_output_per_expert_alignment); + return static_cast(nvte_ep_handle_mem_size(&layer_cfg)); } // ── Per-step ops ───────────────────────────────────────────────────────────── @@ -194,8 +196,9 @@ void ep_prepare(at::Tensor handle_mem, at::Tensor topk_idx, at::Tensor token_cou auto handle_mem_te = makeTransformerEngineTensor( handle_mem.data_ptr(), Shape{static_cast(handle_mem.numel())}, DType::kByte); + auto layer_cfg = make_layer_cfg(top_k, dispatch_output_per_expert_alignment); nvte_ep_prepare(handle_mem_te.data(), topk_idx_te.data(), token_counts_te.data(), - make_layer_cfg(top_k, dispatch_output_per_expert_alignment), stream); + /*total_recv_tokens_per_rank=*/nullptr, &layer_cfg, stream); } void ep_dispatch(at::Tensor handle_mem, at::Tensor topk_idx, at::Tensor tokens, From c92997e4fcf5402f3aba173a9de030a789958b81 Mon Sep 17 00:00:00 2001 From: Phuong Nguyen Date: Mon, 29 Jun 2026 10:15:07 -0700 Subject: [PATCH 5/8] Rename EP group config max_num_sms to num_comm_sms Signed-off-by: Phuong Nguyen --- transformer_engine/common/ep/ep_backend.cpp | 8 ++++---- transformer_engine/common/include/transformer_engine/ep.h | 4 ++-- transformer_engine/jax/csrc/extensions/ep.cpp | 2 +- transformer_engine/pytorch/csrc/extensions/ep.cpp | 2 +- 4 files changed, 8 insertions(+), 8 deletions(-) diff --git a/transformer_engine/common/ep/ep_backend.cpp b/transformer_engine/common/ep/ep_backend.cpp index 7c0afd8b8a..7c93b24526 100644 --- a/transformer_engine/common/ep/ep_backend.cpp +++ b/transformer_engine/common/ep/ep_backend.cpp @@ -108,8 +108,8 @@ void EPBackend::validate_config(const NVTEEpGroupConfig& config) { "hidden_dim * sizeof(max_token_dtype) exceeds 4 GiB; got ", row_bytes, " bytes"); NVTE_CHECK(config.num_experts % config.ep_size == 0, "num_experts (", config.num_experts, ") must be divisible by ep_size (", config.ep_size, ")"); - NVTE_CHECK(config.max_num_sms >= 0, "max_num_sms must be >= 0 (0 = auto), got ", - config.max_num_sms); + NVTE_CHECK(config.num_comm_sms >= 0, "num_comm_sms must be >= 0 (0 = auto), got ", + config.num_comm_sms); const int sm = cuda::sm_arch(); NVTE_CHECK(sm >= 90, "NCCL EP requires SM_90+ (Hopper or later), but current device is SM_", sm); @@ -205,8 +205,8 @@ void EPBackend::init(ncclComm_t ep_comm, NVTEEpGroupConfig group_config) { cfg.rdma_buffer_size = NCCL_EP_AUTO; cfg.num_qp_per_rank = NCCL_EP_AUTO; cfg.num_channels = NCCL_EP_AUTO; - cfg.max_num_sms = group_config.max_num_sms > 0 - ? static_cast(group_config.max_num_sms) + cfg.max_num_sms = group_config.num_comm_sms > 0 + ? static_cast(group_config.num_comm_sms) : NCCL_EP_AUTO; // Must be > 0; NCCL EP errors out on 0. cfg.max_recv_tokens_per_rank = static_cast(group_config.max_recv_tokens_per_rank); diff --git a/transformer_engine/common/include/transformer_engine/ep.h b/transformer_engine/common/include/transformer_engine/ep.h index e5cadf5e0d..52085fcc0e 100644 --- a/transformer_engine/common/include/transformer_engine/ep.h +++ b/transformer_engine/common/include/transformer_engine/ep.h @@ -48,8 +48,8 @@ typedef struct { int max_recv_tokens_per_rank; /*! Token hidden dimension. */ int hidden_dim; - /*! Max SMs for EP kernels. 0 = auto. */ - int max_num_sms; + /*! Max SMs for NCCL EP dispatch/combine kernels. 0 = auto. */ + int num_comm_sms; /*! Widest token dtype the group will dispatch; sizes staging buffers. * Required (no default): must be set to a real token dtype. Per-dispatch * tensors may use any dtype with element size <= this. */ diff --git a/transformer_engine/jax/csrc/extensions/ep.cpp b/transformer_engine/jax/csrc/extensions/ep.cpp index 0f6520a1a5..34124eb674 100644 --- a/transformer_engine/jax/csrc/extensions/ep.cpp +++ b/transformer_engine/jax/csrc/extensions/ep.cpp @@ -51,7 +51,7 @@ class EpResources { .max_tokens_per_rank = p.max_tokens_per_rank, .max_recv_tokens_per_rank = p.max_recv_tokens_per_rank, .hidden_dim = p.hidden_dim, - .max_num_sms = p.max_num_sms, + .num_comm_sms = p.max_num_sms, .max_token_dtype = p.max_token_dtype, .zero_copy = 0}; try { diff --git a/transformer_engine/pytorch/csrc/extensions/ep.cpp b/transformer_engine/pytorch/csrc/extensions/ep.cpp index 58b260e979..ae23c705e5 100644 --- a/transformer_engine/pytorch/csrc/extensions/ep.cpp +++ b/transformer_engine/pytorch/csrc/extensions/ep.cpp @@ -142,7 +142,7 @@ void ep_initialize(uintptr_t comm_ptr, const std::string& group_name, int64_t nu .max_tokens_per_rank = static_cast(max_tokens_per_rank), .max_recv_tokens_per_rank = static_cast(max_recv_tokens_per_rank), .hidden_dim = static_cast(hidden_dim), - .max_num_sms = static_cast(max_num_sms), + .num_comm_sms = static_cast(max_num_sms), .max_token_dtype = static_cast(GetTransformerEngineDType(torch_dtype)), .zero_copy = zero_copy ? 1 : 0, }; From 70aee42ae90041abcc654bca0b5c5dbf5292c810 Mon Sep 17 00:00:00 2001 From: Phuong Nguyen Date: Mon, 29 Jun 2026 12:53:52 -0700 Subject: [PATCH 6/8] Detect active NVLink via nvlink --status link bandwidth in PyTorch EP test Signed-off-by: Phuong Nguyen --- tests/pytorch/distributed/run_test_ep.sh | 7 +++++-- 1 file changed, 5 insertions(+), 2 deletions(-) diff --git a/tests/pytorch/distributed/run_test_ep.sh b/tests/pytorch/distributed/run_test_ep.sh index ae40c8ba4b..8ef0876be0 100755 --- a/tests/pytorch/distributed/run_test_ep.sh +++ b/tests/pytorch/distributed/run_test_ep.sh @@ -18,10 +18,13 @@ if [ "${DETECTED_GPUS}" -lt 4 ]; then exit 0 fi -# NCCL EP requires NVLink/NVSwitch between GPUs. +# NCCL EP requires active NVLink P2P among ranks on the node. # On PCIe-only nodes (no NVLink) it falls back to the network # transport and deadlocks, so skip cleanly there. -if ! nvidia-smi topo -m 2>/dev/null | grep -qE "\bNV[0-9]+\b"; then +# Capture first: piping into grep -q closes the pipe early and SIGPIPEs +# nvidia-smi, which under pipefail would falsely report "no NVLink". +NVLINK_STATUS="$(nvidia-smi nvlink --status 2>/dev/null)" +if ! grep -qE 'Link [0-9]+:.*GB/s' <<<"${NVLINK_STATUS}"; then echo "No NVLink between GPUs (PCIe-only fabric); NCCL EP is unsupported here. SKIPPING." exit 0 fi From 1db0c8060fae7cd0cf41ada1d2f9e00385538bc0 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Mon, 29 Jun 2026 19:55:21 +0000 Subject: [PATCH 7/8] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- transformer_engine/common/ep/ep_api.cpp | 9 ++---- transformer_engine/common/ep/ep_backend.cpp | 9 ++++-- .../common/include/transformer_engine/ep.h | 6 ++-- transformer_engine/jax/csrc/extensions/ep.cpp | 28 +++++++++---------- 4 files changed, 25 insertions(+), 27 deletions(-) diff --git a/transformer_engine/common/ep/ep_api.cpp b/transformer_engine/common/ep/ep_api.cpp index 8402506e5d..0981289ffe 100644 --- a/transformer_engine/common/ep/ep_api.cpp +++ b/transformer_engine/common/ep/ep_api.cpp @@ -61,24 +61,21 @@ inline void* handle_mem_ptr(NVTETensor mem) { void nvte_ep_initialize(void* ep_comm, const NVTEEpGroupConfig* group_config) { NVTE_CHECK(ep_comm != nullptr, "ep_comm must not be null"); - NVTEEpGroupConfig cfg = - normalize_ep_config(group_config, kGroupConfigMinSize, "group_config"); + NVTEEpGroupConfig cfg = normalize_ep_config(group_config, kGroupConfigMinSize, "group_config"); EPBackend::initialize(static_cast(ep_comm), cfg); } void nvte_ep_shutdown(void) { EPBackend::shutdown(); } size_t nvte_ep_handle_mem_size(const NVTEEpLayerConfig* layer_cfg) { - NVTEEpLayerConfig cfg = - normalize_ep_config(layer_cfg, kLayerConfigMinSize, "layer_cfg"); + NVTEEpLayerConfig cfg = normalize_ep_config(layer_cfg, kLayerConfigMinSize, "layer_cfg"); return EPBackend::get().handle_mem_size(cfg); } void nvte_ep_prepare(NVTETensor handle_mem, NVTETensor topk_idx, NVTETensor recv_tokens_per_expert, NVTETensor total_recv_tokens_per_rank, const NVTEEpLayerConfig* layer_cfg, cudaStream_t stream) { - NVTEEpLayerConfig cfg = - normalize_ep_config(layer_cfg, kLayerConfigMinSize, "layer_cfg"); + NVTEEpLayerConfig cfg = normalize_ep_config(layer_cfg, kLayerConfigMinSize, "layer_cfg"); EPBackend::get().prepare(handle_mem_ptr(handle_mem), topk_idx, recv_tokens_per_expert, total_recv_tokens_per_rank, cfg, stream); } diff --git a/transformer_engine/common/ep/ep_backend.cpp b/transformer_engine/common/ep/ep_backend.cpp index 7c93b24526..d5b7300570 100644 --- a/transformer_engine/common/ep/ep_backend.cpp +++ b/transformer_engine/common/ep/ep_backend.cpp @@ -317,7 +317,8 @@ size_t EPBackend::handle_mem_size(NVTEEpLayerConfig layer_cfg) { return hm_size; } -void EPBackend::prepare(void* handle_mem, const NVTETensor topk_idx, NVTETensor recv_tokens_per_expert, +void EPBackend::prepare(void* handle_mem, const NVTETensor topk_idx, + NVTETensor recv_tokens_per_expert, NVTETensor /*total_recv_tokens_per_rank*/, NVTEEpLayerConfig layer_cfg, cudaStream_t stream) { // total_recv_tokens_per_rank is a reserved placeholder; not yet populated. @@ -332,10 +333,12 @@ void EPBackend::prepare(void* handle_mem, const NVTETensor topk_idx, NVTETensor NVTEShape recv_tokens_per_expert_shape; ncclEpTensor_t recv_tokens_per_expert_desc; if (recv_tokens_per_expert != nullptr) { - recv_tokens_per_expert_desc = make_nccl_ep_tensor(recv_tokens_per_expert, recv_tokens_per_expert_shape); + recv_tokens_per_expert_desc = + make_nccl_ep_tensor(recv_tokens_per_expert, recv_tokens_per_expert_shape); } ncclEpLayoutInfo_t layout_info = NCCL_EP_LAYOUT_INFO_INIT; - layout_info.expert_counters = (recv_tokens_per_expert != nullptr) ? &recv_tokens_per_expert_desc : nullptr; + layout_info.expert_counters = + (recv_tokens_per_expert != nullptr) ? &recv_tokens_per_expert_desc : nullptr; std::lock_guard lock(mutex_); NVTE_CHECK(initialized_, "EPBackend not initialized"); diff --git a/transformer_engine/common/include/transformer_engine/ep.h b/transformer_engine/common/include/transformer_engine/ep.h index 52085fcc0e..224622fd41 100644 --- a/transformer_engine/common/include/transformer_engine/ep.h +++ b/transformer_engine/common/include/transformer_engine/ep.h @@ -79,10 +79,8 @@ typedef struct { /* Zero-init a config with struct_size set to the current layout: * NVTEEpGroupConfig cfg = NVTE_EP_GROUP_CONFIG_INIT; * cfg.ep_size = ...; */ -#define NVTE_EP_GROUP_CONFIG_INIT \ - { sizeof(NVTEEpGroupConfig) } -#define NVTE_EP_LAYER_CONFIG_INIT \ - { sizeof(NVTEEpLayerConfig) } +#define NVTE_EP_GROUP_CONFIG_INIT {sizeof(NVTEEpGroupConfig)} +#define NVTE_EP_LAYER_CONFIG_INIT {sizeof(NVTEEpLayerConfig)} /* -- Bootstrap ------------------------------------------------------------ */ diff --git a/transformer_engine/jax/csrc/extensions/ep.cpp b/transformer_engine/jax/csrc/extensions/ep.cpp index 34124eb674..bfd96776c7 100644 --- a/transformer_engine/jax/csrc/extensions/ep.cpp +++ b/transformer_engine/jax/csrc/extensions/ep.cpp @@ -163,10 +163,10 @@ void ReleaseEpResources() { } size_t EpHandleMemSize(int top_k, size_t dispatch_output_per_expert_alignment) { - NVTEEpLayerConfig layer_cfg{.struct_size = sizeof(NVTEEpLayerConfig), - .top_k = top_k, - .dispatch_output_per_expert_alignment = - dispatch_output_per_expert_alignment}; + NVTEEpLayerConfig layer_cfg{ + .struct_size = sizeof(NVTEEpLayerConfig), + .top_k = top_k, + .dispatch_output_per_expert_alignment = dispatch_output_per_expert_alignment}; return nvte_ep_handle_mem_size(&layer_cfg); } @@ -196,8 +196,8 @@ XLA_FFI_DEFINE_HANDLER_SYMBOL(EpInstantiateHandler, EpInstantiateImpl, FFI::Bind // ── ep_prepare ──────────────────────────────────────────────────────────────── Error_Type EpPrepareFFI(cudaStream_t stream, EpInstanceState* ep_state, Buffer_Type topk_idx, - Result_Type recv_tokens_per_expert, Result_Type handle_mem, Result_Type workspace, - EpConfig config) { + Result_Type recv_tokens_per_expert, Result_Type handle_mem, + Result_Type workspace, EpConfig config) { (void)ep_state; // lifetime only. auto topk_dims = topk_idx.dimensions(); NVTE_CHECK(topk_dims.size() >= 2, @@ -223,16 +223,16 @@ Error_Type EpPrepareFFI(cudaStream_t stream, EpInstanceState* ep_state, Buffer_T auto topk_idx_ = TensorWrapper(topk_idx_data, topk_shape, DType::kInt64); std::vector tc_shape = {static_cast(recv_tokens_per_expert->element_count())}; - auto recv_tokens_per_expert_ = TensorWrapper(recv_tokens_per_expert->untyped_data(), tc_shape, DType::kInt32); + auto recv_tokens_per_expert_ = + TensorWrapper(recv_tokens_per_expert->untyped_data(), tc_shape, DType::kInt32); std::vector hm_shape = {static_cast(handle_mem->element_count())}; auto handle_mem_ = TensorWrapper(handle_mem->untyped_data(), hm_shape, DType::kByte); - NVTEEpLayerConfig layer_cfg{ - .struct_size = sizeof(NVTEEpLayerConfig), - .top_k = static_cast(config.top_k), - .dispatch_output_per_expert_alignment = - static_cast(config.dispatch_output_per_expert_alignment)}; + NVTEEpLayerConfig layer_cfg{.struct_size = sizeof(NVTEEpLayerConfig), + .top_k = static_cast(config.top_k), + .dispatch_output_per_expert_alignment = + static_cast(config.dispatch_output_per_expert_alignment)}; nvte_ep_prepare(handle_mem_.data(), topk_idx_.data(), recv_tokens_per_expert_.data(), /*total_recv_tokens_per_rank=*/nullptr, &layer_cfg, stream); return ffi_with_cuda_error_check(); @@ -243,8 +243,8 @@ XLA_FFI_DEFINE_HANDLER_SYMBOL(EpPrepareHandler, EpPrepareFFI, .Ctx() // stream .Ctx<::xla::ffi::State>() // EP state .Arg() // topk_idx - .Ret() // recv_tokens_per_expert - .Ret() // handle_mem + .Ret() // recv_tokens_per_expert + .Ret() // handle_mem .Ret() // workspace (FFI scratch) .Attrs(), FFI_CudaGraph_Traits); From a2b8fd335ea7dd823aad6aaeda32defb5a5054fa Mon Sep 17 00:00:00 2001 From: Phuong Nguyen Date: Mon, 29 Jun 2026 12:55:45 -0700 Subject: [PATCH 8/8] Add max_token_dtype range check to nvte_ep_init for clearer error Signed-off-by: Phuong Nguyen --- transformer_engine/common/ep/ep_backend.cpp | 2 ++ 1 file changed, 2 insertions(+) diff --git a/transformer_engine/common/ep/ep_backend.cpp b/transformer_engine/common/ep/ep_backend.cpp index d5b7300570..a82ec1c98d 100644 --- a/transformer_engine/common/ep/ep_backend.cpp +++ b/transformer_engine/common/ep/ep_backend.cpp @@ -97,6 +97,8 @@ void EPBackend::validate_config(const NVTEEpGroupConfig& config) { NVTE_CHECK(config.max_recv_tokens_per_rank > 0, "max_recv_tokens_per_rank must be positive, got ", config.max_recv_tokens_per_rank); NVTE_CHECK(config.hidden_dim > 0, "hidden_dim must be positive, got ", config.hidden_dim); + NVTE_CHECK(config.max_token_dtype >= 0 && config.max_token_dtype < kNVTENumTypes, + "max_token_dtype out of range, got ", static_cast(config.max_token_dtype)); const size_t elem_bytes = typeToSize(static_cast(config.max_token_dtype)); const size_t row_bytes = static_cast(config.hidden_dim) * elem_bytes; NVTE_CHECK(row_bytes >= 16,