diff --git a/tests/cpp_distributed/test_ep.cu b/tests/cpp_distributed/test_ep.cu index c7fee7720c..7dbbcdce9d 100644 --- a/tests/cpp_distributed/test_ep.cu +++ b/tests/cpp_distributed/test_ep.cu @@ -60,7 +60,7 @@ static std::vector generate_tokens(int rank, int num_tokens, int hidden_dim) return v; } -static std::vector expected_token_counts( +static std::vector expected_recv_tokens_per_expert( int recv_rank, int num_processes, int num_tokens, int top_k, int num_experts, int num_local_experts) { int base = recv_rank * num_local_experts; @@ -128,7 +128,7 @@ struct EPBuffers { DevBuf topk_idx; DevBuf topk_weights; DevBuf tokens; - DevBuf token_counts; + DevBuf recv_tokens_per_expert; DevBuf handle_mem; DevBuf recv_tokens; DevBuf recv_topk_weights; @@ -144,22 +144,26 @@ struct EPBuffers { size_t recv_capacity = 0; int top_k_ = 0; size_t alignment_ = 0; + NVTEEpLayerConfig layer_cfg_{}; void alloc(int num_tokens, int top_k, int hidden_dim, int num_local_experts, int ep_size, int max_tokens_per_rank, size_t alignment = 0) { top_k_ = top_k; alignment_ = alignment; + layer_cfg_ = NVTE_EP_LAYER_CONFIG_INIT; + layer_cfg_.top_k = top_k; + layer_cfg_.dispatch_output_per_expert_alignment = alignment; recv_capacity = static_cast(ep_size) * max_tokens_per_rank * 2; topk_idx.alloc(num_tokens * top_k); topk_weights.alloc(num_tokens * top_k); tokens.alloc(num_tokens * hidden_dim); - token_counts.alloc(num_local_experts); + recv_tokens_per_expert.alloc(num_local_experts); recv_tokens.alloc(recv_capacity * hidden_dim); recv_topk_weights.alloc(recv_capacity); result.alloc(num_tokens * hidden_dim); - handle_mem_size = nvte_ep_handle_mem_size(NVTEEpLayerConfig{top_k, alignment}); + handle_mem_size = nvte_ep_handle_mem_size(&layer_cfg_); handle_mem.alloc(handle_mem_size); grad_result.alloc(num_tokens * hidden_dim); @@ -174,25 +178,29 @@ struct EPBuffers { // expects. template struct EPTensors { - TensorWrapper topk_idx, topk_weights, token_counts, handle_mem, tokens; + TensorWrapper topk_idx, topk_weights, recv_tokens_per_expert, handle_mem, tokens; TensorWrapper recv_tokens, recv_topk_weights, result; TensorWrapper grad_result, grad_expert, grad_tokens; TensorWrapper g_recv_topk_weights, grad_topk_weights; int top_k_ = 0; size_t alignment_ = 0; + NVTEEpLayerConfig layer_cfg_{}; EPTensors(EPBuffers& b, int num_tokens, int top_k, int hidden_dim, int num_local_experts) { top_k_ = top_k; alignment_ = b.alignment_; + layer_cfg_ = NVTE_EP_LAYER_CONFIG_INIT; + layer_cfg_.top_k = top_k; + layer_cfg_.dispatch_output_per_expert_alignment = b.alignment_; constexpr DType kTokDType = test::TypeInfo::dtype; using Shape = std::vector; topk_idx = TensorWrapper(b.topk_idx.get(), Shape{(size_t)num_tokens, (size_t)top_k}, DType::kInt64); topk_weights = TensorWrapper(b.topk_weights.get(), Shape{(size_t)num_tokens, (size_t)top_k}, DType::kFloat32); - token_counts = TensorWrapper(b.token_counts.get(), + recv_tokens_per_expert = TensorWrapper(b.recv_tokens_per_expert.get(), Shape{(size_t)num_local_experts}, DType::kInt32); handle_mem = TensorWrapper(b.handle_mem.get(), Shape{b.handle_mem_size}, DType::kByte); @@ -259,7 +267,7 @@ class EpOpTestBase : public ::testing::Test { template int read_total_recv(const EPBuffers& buf) const { std::vector cnt(num_local_experts_); - NVTE_CHECK_CUDA(cudaMemcpy(cnt.data(), buf.token_counts.get(), + NVTE_CHECK_CUDA(cudaMemcpy(cnt.data(), buf.recv_tokens_per_expert.get(), num_local_experts_ * sizeof(int32_t), cudaMemcpyDeviceToHost)); int total = 0; for (int c : cnt) total += c; @@ -300,7 +308,7 @@ TYPED_TEST(EPDispatchTest, PrepareAndDispatch) { cudaStream_t stream; NVTE_CHECK_CUDA(cudaStreamCreate(&stream)); - ASSERT_NO_THROW(nvte_ep_prepare(t.handle_mem.data(), t.topk_idx.data(), t.token_counts.data(), NVTEEpLayerConfig{t.top_k_, t.alignment_}, stream)); + ASSERT_NO_THROW(nvte_ep_prepare(t.handle_mem.data(), t.topk_idx.data(), t.recv_tokens_per_expert.data(), nullptr, &t.layer_cfg_, stream)); ASSERT_NO_THROW(nvte_ep_dispatch(t.handle_mem.data(), t.topk_idx.data(), t.tokens.data(), NVTECommWindow{}, t.topk_weights.data(), NVTECommWindow{}, t.recv_tokens.data(), NVTECommWindow{}, @@ -309,9 +317,9 @@ TYPED_TEST(EPDispatchTest, PrepareAndDispatch) { // 1. Per-expert counts. std::vector got_counts(num_local_experts_); - NVTE_CHECK_CUDA(cudaMemcpy(got_counts.data(), buf.token_counts.get(), + NVTE_CHECK_CUDA(cudaMemcpy(got_counts.data(), buf.recv_tokens_per_expert.get(), num_local_experts_ * sizeof(int32_t), cudaMemcpyDeviceToHost)); - auto exp_counts = expected_token_counts(g_process_id, g_num_processes, num_tokens_, top_k_, + auto exp_counts = expected_recv_tokens_per_expert(g_process_id, g_num_processes, num_tokens_, top_k_, num_experts_, num_local_experts_); int total_recv = 0; for (int i = 0; i < num_local_experts_; ++i) { @@ -379,7 +387,7 @@ TYPED_TEST(EPCombineTest, Combine) { cudaStream_t stream; NVTE_CHECK_CUDA(cudaStreamCreate(&stream)); - ASSERT_NO_THROW(nvte_ep_prepare(t.handle_mem.data(), t.topk_idx.data(), t.token_counts.data(), NVTEEpLayerConfig{t.top_k_, t.alignment_}, stream)); + ASSERT_NO_THROW(nvte_ep_prepare(t.handle_mem.data(), t.topk_idx.data(), t.recv_tokens_per_expert.data(), nullptr, &t.layer_cfg_, stream)); ASSERT_NO_THROW(nvte_ep_dispatch(t.handle_mem.data(), t.topk_idx.data(), t.tokens.data(), NVTECommWindow{}, t.topk_weights.data(), NVTECommWindow{}, t.recv_tokens.data(), NVTECommWindow{}, @@ -426,7 +434,7 @@ TYPED_TEST(EPCombineBwdTest, CombineBwdCheck) { cudaStream_t stream; NVTE_CHECK_CUDA(cudaStreamCreate(&stream)); - ASSERT_NO_THROW(nvte_ep_prepare(t.handle_mem.data(), t.topk_idx.data(), t.token_counts.data(), NVTEEpLayerConfig{t.top_k_, t.alignment_}, stream)); + ASSERT_NO_THROW(nvte_ep_prepare(t.handle_mem.data(), t.topk_idx.data(), t.recv_tokens_per_expert.data(), nullptr, &t.layer_cfg_, stream)); ASSERT_NO_THROW(nvte_ep_dispatch(t.handle_mem.data(), t.topk_idx.data(), t.tokens.data(), NVTECommWindow{}, t.topk_weights.data(), NVTECommWindow{}, t.recv_tokens.data(), NVTECommWindow{}, @@ -447,7 +455,7 @@ TYPED_TEST(EPCombineBwdTest, CombineBwdCheck) { int total_recv = this->template read_total_recv(buf); std::vector cnt(num_local_experts_); - NVTE_CHECK_CUDA(cudaMemcpy(cnt.data(), buf.token_counts.get(), + NVTE_CHECK_CUDA(cudaMemcpy(cnt.data(), buf.recv_tokens_per_expert.get(), num_local_experts_ * sizeof(int32_t), cudaMemcpyDeviceToHost)); std::vector h_ge(buf.recv_capacity * hidden_dim_); NVTE_CHECK_CUDA(cudaMemcpy(h_ge.data(), buf.grad_expert.get(), @@ -495,7 +503,7 @@ TYPED_TEST(EPDispatchBwdTest, DispatchBwdCheck) { cudaStream_t stream; NVTE_CHECK_CUDA(cudaStreamCreate(&stream)); - ASSERT_NO_THROW(nvte_ep_prepare(t.handle_mem.data(), t.topk_idx.data(), t.token_counts.data(), NVTEEpLayerConfig{t.top_k_, t.alignment_}, stream)); + ASSERT_NO_THROW(nvte_ep_prepare(t.handle_mem.data(), t.topk_idx.data(), t.recv_tokens_per_expert.data(), nullptr, &t.layer_cfg_, stream)); ASSERT_NO_THROW(nvte_ep_dispatch(t.handle_mem.data(), t.topk_idx.data(), t.tokens.data(), NVTECommWindow{}, t.topk_weights.data(), NVTECommWindow{}, t.recv_tokens.data(), NVTECommWindow{}, @@ -563,7 +571,7 @@ TYPED_TEST(EPDispatchBwdGradWeightsTest, RoundTrip) { cudaStream_t stream; NVTE_CHECK_CUDA(cudaStreamCreate(&stream)); - ASSERT_NO_THROW(nvte_ep_prepare(t.handle_mem.data(), t.topk_idx.data(), t.token_counts.data(), NVTEEpLayerConfig{t.top_k_, t.alignment_}, stream)); + ASSERT_NO_THROW(nvte_ep_prepare(t.handle_mem.data(), t.topk_idx.data(), t.recv_tokens_per_expert.data(), nullptr, &t.layer_cfg_, stream)); NVTE_CHECK_CUDA(cudaMemsetAsync(buf.recv_topk_weights.get(), 0, buf.recv_topk_weights.bytes(), stream)); ASSERT_NO_THROW(nvte_ep_dispatch(t.handle_mem.data(), t.topk_idx.data(), @@ -634,7 +642,7 @@ class EPPipelineTest : public EpOpTestBase, public ::testing::WithParamInterface cudaStream_t stream; NVTE_CHECK_CUDA(cudaStreamCreate(&stream)); - ASSERT_NO_THROW(nvte_ep_prepare(t.handle_mem.data(), t.topk_idx.data(), t.token_counts.data(), NVTEEpLayerConfig{t.top_k_, t.alignment_}, stream)); + ASSERT_NO_THROW(nvte_ep_prepare(t.handle_mem.data(), t.topk_idx.data(), t.recv_tokens_per_expert.data(), nullptr, &t.layer_cfg_, stream)); ASSERT_NO_THROW(nvte_ep_dispatch(t.handle_mem.data(), t.topk_idx.data(), t.tokens.data(), NVTECommWindow{}, t.topk_weights.data(), NVTECommWindow{}, t.recv_tokens.data(), NVTECommWindow{}, @@ -759,7 +767,7 @@ TYPED_TEST(EPZeroCopyTest, IdentityAllSymm) { cudaStream_t stream; NVTE_CHECK_CUDA(cudaStreamCreate(&stream)); - ASSERT_NO_THROW(nvte_ep_prepare(ref_t.handle_mem.data(), ref_t.topk_idx.data(), ref_t.token_counts.data(), NVTEEpLayerConfig{ref_t.top_k_, ref_t.alignment_}, stream)); + ASSERT_NO_THROW(nvte_ep_prepare(ref_t.handle_mem.data(), ref_t.topk_idx.data(), ref_t.recv_tokens_per_expert.data(), nullptr, &ref_t.layer_cfg_, stream)); ASSERT_NO_THROW(nvte_ep_dispatch(ref_t.handle_mem.data(), ref_t.topk_idx.data(), ref_t.tokens.data(), NVTECommWindow{}, ref_t.topk_weights.data(), NVTECommWindow{}, ref_t.recv_tokens.data(), NVTECommWindow{}, @@ -800,7 +808,7 @@ TYPED_TEST(EPZeroCopyTest, IdentityAllSymm) { sym_t.recv_tokens = TensorWrapper(sym_recv.ptr, std::vector{sym_buf.recv_capacity, (size_t)hidden_dim_}, kTokDType); - ASSERT_NO_THROW(nvte_ep_prepare(sym_t.handle_mem.data(), sym_t.topk_idx.data(), sym_t.token_counts.data(), NVTEEpLayerConfig{sym_t.top_k_, sym_t.alignment_}, stream)); + ASSERT_NO_THROW(nvte_ep_prepare(sym_t.handle_mem.data(), sym_t.topk_idx.data(), sym_t.recv_tokens_per_expert.data(), nullptr, &sym_t.layer_cfg_, stream)); ASSERT_NO_THROW(nvte_ep_dispatch(sym_t.handle_mem.data(), sym_t.topk_idx.data(), sym_t.tokens.data(), symm_window(sym_tokens), sym_t.topk_weights.data(), NVTECommWindow{}, diff --git a/tests/cpp_distributed/test_ep_common.h b/tests/cpp_distributed/test_ep_common.h index d5e006cef6..7cf6017090 100644 --- a/tests/cpp_distributed/test_ep_common.h +++ b/tests/cpp_distributed/test_ep_common.h @@ -146,7 +146,7 @@ static bool ep_bootstrap(int argc, char* argv[]) { ncclUniqueId uid{}; exchange_unique_id(&uid); - NVTEEpGroupConfig group_config{}; + NVTEEpGroupConfig group_config = NVTE_EP_GROUP_CONFIG_INIT; group_config.ep_size = g_ep_size; group_config.num_experts = g_num_experts; group_config.max_tokens_per_rank = g_max_tokens_per_rank; @@ -156,7 +156,7 @@ static bool ep_bootstrap(int argc, char* argv[]) { group_config.max_token_dtype = g_max_token_dtype; NVTE_CHECK_NCCL(ncclCommInitRank(&g_ep_comm, g_num_processes, uid, g_process_id)); - nvte_ep_initialize(static_cast(g_ep_comm), group_config); + nvte_ep_initialize(static_cast(g_ep_comm), &group_config); if (g_process_id == 0) { printf("EP initialized: ep_size=%d num_experts=%d " @@ -173,7 +173,7 @@ static bool ep_bootstrap(int argc, char* argv[]) { static void ep_reinitialize(int zero_copy) { if (!g_ep_initialized) return; nvte_ep_shutdown(); - NVTEEpGroupConfig group_config{}; + NVTEEpGroupConfig group_config = NVTE_EP_GROUP_CONFIG_INIT; group_config.ep_size = g_ep_size; group_config.num_experts = g_num_experts; group_config.max_tokens_per_rank = g_max_tokens_per_rank; @@ -181,7 +181,7 @@ static void ep_reinitialize(int zero_copy) { group_config.hidden_dim = g_hidden_dim; group_config.max_token_dtype = g_max_token_dtype; group_config.zero_copy = zero_copy; - nvte_ep_initialize(static_cast(g_ep_comm), group_config); + nvte_ep_initialize(static_cast(g_ep_comm), &group_config); } // Tear down in dependency order: backend's ep_group reads from ep_comm, diff --git a/transformer_engine/common/ep/ep_api.cpp b/transformer_engine/common/ep/ep_api.cpp index 66ee3dc8d9..0981289ffe 100644 --- a/transformer_engine/common/ep/ep_api.cpp +++ b/transformer_engine/common/ep/ep_api.cpp @@ -13,6 +13,10 @@ #include +#include +#include +#include + #include "../util/logging.h" #if defined(NVTE_WITH_NCCL_EP) @@ -24,18 +28,30 @@ using transformer_engine::ep::EPBackend; -void nvte_ep_initialize(void* ep_comm, NVTEEpGroupConfig group_config) { - NVTE_CHECK(ep_comm != nullptr, "ep_comm must not be null"); - EPBackend::initialize(static_cast(ep_comm), group_config); -} - -void nvte_ep_shutdown(void) { EPBackend::shutdown(); } - -size_t nvte_ep_handle_mem_size(NVTEEpLayerConfig layer_cfg) { - return EPBackend::get().handle_mem_size(layer_cfg); +namespace { +// Smallest accepted struct_size: covers the base (required) fields. Frozen; +// never raise these. Later fields are read only when struct_size covers them. +constexpr size_t kGroupConfigMinSize = offsetof(NVTEEpGroupConfig, zero_copy) + sizeof(int); +constexpr size_t kLayerConfigMinSize = + offsetof(NVTEEpLayerConfig, dispatch_output_per_expert_alignment) + sizeof(size_t); + +// Copy a caller's versioned config into a full current-layout struct: fields +// the caller did not provide stay zero, extra trailing fields are dropped. +// struct_size 0 is read as the base layout. Requires a size_t struct_size +// first member. +template +Cfg normalize_ep_config(const Cfg* user, size_t min_size, const char* name) { + NVTE_CHECK(user != nullptr, name, " must not be null"); + const size_t want = (user->struct_size == 0) ? min_size : user->struct_size; + NVTE_CHECK(want >= min_size, name, ".struct_size (", user->struct_size, + ") is below the required minimum ", min_size, + "; zero-initialize the struct or set struct_size via NVTE_EP_*_CONFIG_INIT"); + Cfg cfg{}; + std::memcpy(&cfg, user, std::min(want, sizeof(Cfg))); + cfg.struct_size = sizeof(Cfg); + return cfg; } -namespace { inline void* handle_mem_ptr(NVTETensor mem) { void* p = nvte_tensor_data(mem); NVTE_CHECK(p != nullptr, "handle_mem tensor data must not be null"); @@ -43,9 +59,25 @@ inline void* handle_mem_ptr(NVTETensor mem) { } } // namespace -void nvte_ep_prepare(NVTETensor handle_mem, NVTETensor topk_idx, NVTETensor token_counts, - NVTEEpLayerConfig layer_cfg, cudaStream_t stream) { - EPBackend::get().prepare(handle_mem_ptr(handle_mem), topk_idx, token_counts, layer_cfg, stream); +void nvte_ep_initialize(void* ep_comm, const NVTEEpGroupConfig* group_config) { + NVTE_CHECK(ep_comm != nullptr, "ep_comm must not be null"); + NVTEEpGroupConfig cfg = normalize_ep_config(group_config, kGroupConfigMinSize, "group_config"); + EPBackend::initialize(static_cast(ep_comm), cfg); +} + +void nvte_ep_shutdown(void) { EPBackend::shutdown(); } + +size_t nvte_ep_handle_mem_size(const NVTEEpLayerConfig* layer_cfg) { + NVTEEpLayerConfig cfg = normalize_ep_config(layer_cfg, kLayerConfigMinSize, "layer_cfg"); + return EPBackend::get().handle_mem_size(cfg); +} + +void nvte_ep_prepare(NVTETensor handle_mem, NVTETensor topk_idx, NVTETensor recv_tokens_per_expert, + NVTETensor total_recv_tokens_per_rank, const NVTEEpLayerConfig* layer_cfg, + cudaStream_t stream) { + NVTEEpLayerConfig cfg = normalize_ep_config(layer_cfg, kLayerConfigMinSize, "layer_cfg"); + EPBackend::get().prepare(handle_mem_ptr(handle_mem), topk_idx, recv_tokens_per_expert, + total_recv_tokens_per_rank, cfg, stream); } void nvte_ep_dispatch(NVTETensor handle_mem, NVTETensor topk_idx, NVTETensor tokens, @@ -88,15 +120,18 @@ namespace { } } // namespace -void nvte_ep_initialize(void* /*ep_comm*/, NVTEEpGroupConfig /*group_config*/) { ep_not_built(); } +void nvte_ep_initialize(void* /*ep_comm*/, const NVTEEpGroupConfig* /*group_config*/) { + ep_not_built(); +} void nvte_ep_shutdown(void) {} -size_t nvte_ep_handle_mem_size(NVTEEpLayerConfig /*layer_cfg*/) { ep_not_built(); } +size_t nvte_ep_handle_mem_size(const NVTEEpLayerConfig* /*layer_cfg*/) { ep_not_built(); } void nvte_ep_prepare(NVTETensor /*handle_mem*/, NVTETensor /*topk_idx*/, - NVTETensor /*token_counts*/, NVTEEpLayerConfig /*layer_cfg*/, - cudaStream_t /*stream*/) { + NVTETensor /*recv_tokens_per_expert*/, + NVTETensor /*total_recv_tokens_per_rank*/, + const NVTEEpLayerConfig* /*layer_cfg*/, cudaStream_t /*stream*/) { ep_not_built(); } diff --git a/transformer_engine/common/ep/ep_backend.cpp b/transformer_engine/common/ep/ep_backend.cpp index f1510693bb..94f2e2413f 100644 --- a/transformer_engine/common/ep/ep_backend.cpp +++ b/transformer_engine/common/ep/ep_backend.cpp @@ -97,8 +97,6 @@ void EPBackend::validate_config(const NVTEEpGroupConfig& config) { NVTE_CHECK(config.max_recv_tokens_per_rank > 0, "max_recv_tokens_per_rank must be positive, got ", config.max_recv_tokens_per_rank); NVTE_CHECK(config.hidden_dim > 0, "hidden_dim must be positive, got ", config.hidden_dim); - NVTE_CHECK(config.max_token_dtype >= 0 && config.max_token_dtype < kNVTENumTypes, - "max_token_dtype out of range, got ", static_cast(config.max_token_dtype)); const size_t elem_bytes = typeToSize(static_cast(config.max_token_dtype)); const size_t row_bytes = static_cast(config.hidden_dim) * elem_bytes; NVTE_CHECK(row_bytes >= 16, @@ -319,8 +317,11 @@ size_t EPBackend::handle_mem_size(NVTEEpLayerConfig layer_cfg) { return hm_size; } -void EPBackend::prepare(void* handle_mem, const NVTETensor topk_idx, NVTETensor token_counts, - NVTEEpLayerConfig layer_cfg, cudaStream_t stream) { +void EPBackend::prepare(void* handle_mem, const NVTETensor topk_idx, + NVTETensor recv_tokens_per_expert, + NVTETensor /*total_recv_tokens_per_rank*/, NVTEEpLayerConfig layer_cfg, + cudaStream_t stream) { + // total_recv_tokens_per_rank is a reserved placeholder; not yet populated. NVTE_CHECK(handle_mem != nullptr, "handle_mem must not be null"); NVTE_CHECK(layer_cfg.top_k > 0, "top_k must be > 0, got ", layer_cfg.top_k); NVTE_CHECK(nvte_tensor_shape(topk_idx).ndim == 2, "topk_idx must be 2D [T, top_k]"); @@ -329,13 +330,15 @@ void EPBackend::prepare(void* handle_mem, const NVTETensor topk_idx, NVTETensor ncclEpTensor_t nccl_topk_idx = make_nccl_ep_tensor(topk_idx, topk_idx_shape); // ncclEpUpdateHandle writes per-expert counts via expert_counters. - NVTEShape token_counts_shape; - ncclEpTensor_t token_counts_desc; - if (token_counts != nullptr) { - token_counts_desc = make_nccl_ep_tensor(token_counts, token_counts_shape); + NVTEShape recv_tokens_per_expert_shape; + ncclEpTensor_t recv_tokens_per_expert_desc; + if (recv_tokens_per_expert != nullptr) { + recv_tokens_per_expert_desc = + make_nccl_ep_tensor(recv_tokens_per_expert, recv_tokens_per_expert_shape); } ncclEpLayoutInfo_t layout_info = NCCL_EP_LAYOUT_INFO_INIT; - layout_info.expert_counters = (token_counts != nullptr) ? &token_counts_desc : nullptr; + layout_info.expert_counters = + (recv_tokens_per_expert != nullptr) ? &recv_tokens_per_expert_desc : nullptr; std::lock_guard lock(mutex_); NVTE_CHECK(initialized_, "EPBackend not initialized"); diff --git a/transformer_engine/common/ep/ep_backend.h b/transformer_engine/common/ep/ep_backend.h index 2325baafca..80c9b9cea3 100644 --- a/transformer_engine/common/ep/ep_backend.h +++ b/transformer_engine/common/ep/ep_backend.h @@ -46,8 +46,9 @@ class EPBackend { size_t handle_mem_size(NVTEEpLayerConfig layer_cfg); // Seeds the cache for handle_mem with layer_cfg and runs the routing AllGather. - void prepare(void* handle_mem, const NVTETensor topk_idx, NVTETensor token_counts, - NVTEEpLayerConfig layer_cfg, cudaStream_t stream); + void prepare(void* handle_mem, const NVTETensor topk_idx, NVTETensor recv_tokens_per_expert, + NVTETensor total_recv_tokens_per_rank, NVTEEpLayerConfig layer_cfg, + cudaStream_t stream); // Per-step ops below require a prior prepare(). void dispatch(void* handle_mem, const NVTETensor topk_idx, const NVTETensor tokens, diff --git a/transformer_engine/common/include/transformer_engine/ep.h b/transformer_engine/common/include/transformer_engine/ep.h index 8928b92825..716d6aca35 100644 --- a/transformer_engine/common/include/transformer_engine/ep.h +++ b/transformer_engine/common/include/transformer_engine/ep.h @@ -28,11 +28,16 @@ extern "C" { #endif /* -- Config structs ------------------------------------------------------- */ -/* TODO: add a struct_size/version field to these configs (and align with other - * TE public structs) once a TE-wide convention for ABI versioning lands. */ +/* Each config begins with struct_size so the API can add fields without + * breaking ABI. The backend reads only the bytes struct_size covers and + * zero-defaults the rest; struct_size 0 means the base layout. Append new + * fields at the end only; never reorder, resize, or remove existing ones. */ /*! \brief Group-level EP configuration (fixed for the EP group lifetime). */ typedef struct { + /*! Struct size in bytes, or 0 for the base layout. Set to + * sizeof(NVTEEpGroupConfig) to include fields added in newer versions. */ + size_t struct_size; /*! EP world size. */ int ep_size; /*! Total experts across all ranks. */ @@ -46,7 +51,8 @@ typedef struct { /*! Max SMs for EP kernels. 0 = auto. */ int max_num_sms; /*! Widest token dtype the group will dispatch; sizes staging buffers. - * Per-dispatch tensors may use any dtype with element size <= this. */ + * Required (no default): must be set to a real token dtype. Per-dispatch + * tensors may use any dtype with element size <= this. */ NVTEDType max_token_dtype; /*! Zero-copy dispatch/combine. When nonzero, payload tensors must be backed * by NVTECommWindow handles and transfer in place (no staging copies); @@ -59,6 +65,9 @@ typedef struct { * overflow policy, ...). */ typedef struct { + /*! Struct size in bytes, or 0 for the base layout. Set to + * sizeof(NVTEEpLayerConfig) to include fields added in newer versions. */ + size_t struct_size; /*! Per-token expert fan-out (> 0). */ int top_k; /*! Per-expert recv-slab alignment in tokens (power of two; 0/1 disables). @@ -67,6 +76,12 @@ typedef struct { size_t dispatch_output_per_expert_alignment; } NVTEEpLayerConfig; +/* Zero-init a config with struct_size set to the current layout: + * NVTEEpGroupConfig cfg = NVTE_EP_GROUP_CONFIG_INIT; + * cfg.ep_size = ...; */ +#define NVTE_EP_GROUP_CONFIG_INIT {sizeof(NVTEEpGroupConfig)} +#define NVTE_EP_LAYER_CONFIG_INIT {sizeof(NVTEEpLayerConfig)} + /* -- Bootstrap ------------------------------------------------------------ */ /*! \brief Bootstrap the EP backend from an existing NCCL EP sub-communicator. @@ -78,9 +93,9 @@ typedef struct { * group per process, bound to the current CUDA device. * * \param[in] ep_comm Opaque ncclComm_t for the EP sub-group. - * \param[in] group_config Group-level EP configuration. + * \param[in] group_config Group-level EP configuration (struct_size set). */ -void nvte_ep_initialize(void* ep_comm, NVTEEpGroupConfig group_config); +void nvte_ep_initialize(void* ep_comm, const NVTEEpGroupConfig* group_config); /*! \brief Tear down the EP backend. Idempotent. Does not destroy ep_comm. */ void nvte_ep_shutdown(void); @@ -94,10 +109,10 @@ void nvte_ep_shutdown(void); * for that layer (the backend keys its cache on the pointer). Host-only; * size is stable for a given (group, layer) pair. * - * \param[in] layer_cfg Per-call layer configuration. + * \param[in] layer_cfg Per-call layer configuration (struct_size set). * \return size in bytes for the handle_mem buffer. */ -size_t nvte_ep_handle_mem_size(NVTEEpLayerConfig layer_cfg); +size_t nvte_ep_handle_mem_size(const NVTEEpLayerConfig* layer_cfg); /* -- Per-step ops (all allocation-free, CUDA graph-capturable) ------------ */ @@ -106,17 +121,19 @@ size_t nvte_ep_handle_mem_size(NVTEEpLayerConfig layer_cfg); * AllGathers topk_idx across the EP group and stages per-expert offsets and * counts into handle_mem so the matching dispatch/combine/_bwd can run with * no further routing computation. Must precede every dispatch/combine/_bwd - * that uses this handle_mem. token_counts becomes host-valid after a stream - * sync. - * - * \param[in] handle_mem uint8 routing-state buffer. - * \param[in] topk_idx [T, top_k] int64 routing indices. - * \param[out] token_counts [num_local_experts] int32 counts. - * \param[in] layer_cfg Per-call layer configuration. - * \param[in] stream CUDA stream. + * that uses this handle_mem. recv_tokens_per_expert becomes host-valid after a + * stream sync. + * + * \param[in] handle_mem uint8 routing-state buffer. + * \param[in] topk_idx [T, top_k] int64 routing indices. + * \param[out] recv_tokens_per_expert [num_local_experts] int32 counts. + * \param[out] total_recv_tokens_per_rank Reserved placeholder; may be null. Unused for now. + * \param[in] layer_cfg Per-call layer configuration (struct_size set). + * \param[in] stream CUDA stream. */ -void nvte_ep_prepare(NVTETensor handle_mem, NVTETensor topk_idx, NVTETensor token_counts, - NVTEEpLayerConfig layer_cfg, cudaStream_t stream); +void nvte_ep_prepare(NVTETensor handle_mem, NVTETensor topk_idx, NVTETensor recv_tokens_per_expert, + NVTETensor total_recv_tokens_per_rank, const NVTEEpLayerConfig* layer_cfg, + cudaStream_t stream); /*! \brief Dispatch tokens (and routing weights) to expert ranks. * diff --git a/transformer_engine/jax/csrc/extensions/ep.cpp b/transformer_engine/jax/csrc/extensions/ep.cpp index ee204e7594..14f90b7b72 100644 --- a/transformer_engine/jax/csrc/extensions/ep.cpp +++ b/transformer_engine/jax/csrc/extensions/ep.cpp @@ -45,7 +45,8 @@ class EpResources { NVTE_CHECK_NCCL(ncclCommInitRank(&comm_, p.ep_size, uid, p.rank_within_group)); // zero_copy=0: JAX EP path always stages payloads; the zero-copy fast path // requires NVTECommWindow-backed tensors, which JAX bindings don't expose. - NVTEEpGroupConfig cfg{.ep_size = p.ep_size, + NVTEEpGroupConfig cfg{.struct_size = sizeof(NVTEEpGroupConfig), + .ep_size = p.ep_size, .num_experts = p.num_experts, .max_tokens_per_rank = p.max_tokens_per_rank, .max_recv_tokens_per_rank = p.max_recv_tokens_per_rank, @@ -54,7 +55,7 @@ class EpResources { .max_token_dtype = p.max_token_dtype, .zero_copy = 0}; try { - nvte_ep_initialize(static_cast(comm_), cfg); + nvte_ep_initialize(static_cast(comm_), &cfg); } catch (...) { ncclCommDestroy(comm_); comm_ = nullptr; @@ -162,8 +163,11 @@ void ReleaseEpResources() { } size_t EpHandleMemSize(int top_k, size_t dispatch_output_per_expert_alignment) { - NVTEEpLayerConfig layer_cfg{top_k, dispatch_output_per_expert_alignment}; - return nvte_ep_handle_mem_size(layer_cfg); + NVTEEpLayerConfig layer_cfg{ + .struct_size = sizeof(NVTEEpLayerConfig), + .top_k = top_k, + .dispatch_output_per_expert_alignment = dispatch_output_per_expert_alignment}; + return nvte_ep_handle_mem_size(&layer_cfg); } pybind11::capsule GetEpInstanceStateTypeIdCapsule() { @@ -192,8 +196,8 @@ XLA_FFI_DEFINE_HANDLER_SYMBOL(EpInstantiateHandler, EpInstantiateImpl, FFI::Bind // ── ep_prepare ──────────────────────────────────────────────────────────────── Error_Type EpPrepareFFI(cudaStream_t stream, EpInstanceState* ep_state, Buffer_Type topk_idx, - Result_Type token_counts, Result_Type handle_mem, Result_Type workspace, - EpConfig config) { + Result_Type recv_tokens_per_expert, Result_Type handle_mem, + Result_Type workspace, EpConfig config) { (void)ep_state; // lifetime only. auto topk_dims = topk_idx.dimensions(); NVTE_CHECK(topk_dims.size() >= 2, @@ -218,15 +222,19 @@ Error_Type EpPrepareFFI(cudaStream_t stream, EpInstanceState* ep_state, Buffer_T } auto topk_idx_ = TensorWrapper(topk_idx_data, topk_shape, DType::kInt64); - std::vector tc_shape = {static_cast(token_counts->element_count())}; - auto token_counts_ = TensorWrapper(token_counts->untyped_data(), tc_shape, DType::kInt32); + std::vector tc_shape = {static_cast(recv_tokens_per_expert->element_count())}; + auto recv_tokens_per_expert_ = + TensorWrapper(recv_tokens_per_expert->untyped_data(), tc_shape, DType::kInt32); std::vector hm_shape = {static_cast(handle_mem->element_count())}; auto handle_mem_ = TensorWrapper(handle_mem->untyped_data(), hm_shape, DType::kByte); - NVTEEpLayerConfig layer_cfg{static_cast(config.top_k), - static_cast(config.dispatch_output_per_expert_alignment)}; - nvte_ep_prepare(handle_mem_.data(), topk_idx_.data(), token_counts_.data(), layer_cfg, stream); + NVTEEpLayerConfig layer_cfg{.struct_size = sizeof(NVTEEpLayerConfig), + .top_k = static_cast(config.top_k), + .dispatch_output_per_expert_alignment = + static_cast(config.dispatch_output_per_expert_alignment)}; + nvte_ep_prepare(handle_mem_.data(), topk_idx_.data(), recv_tokens_per_expert_.data(), + /*total_recv_tokens_per_rank=*/nullptr, &layer_cfg, stream); return ffi_with_cuda_error_check(); } @@ -235,8 +243,8 @@ XLA_FFI_DEFINE_HANDLER_SYMBOL(EpPrepareHandler, EpPrepareFFI, .Ctx() // stream .Ctx<::xla::ffi::State>() // EP state .Arg() // topk_idx - .Ret() // token_counts - .Ret() // handle_mem + .Ret() // recv_tokens_per_expert + .Ret() // handle_mem .Ret() // workspace (FFI scratch) .Attrs(), FFI_CudaGraph_Traits);