Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
44 changes: 26 additions & 18 deletions tests/cpp_distributed/test_ep.cu
Original file line number Diff line number Diff line change
Expand Up @@ -60,7 +60,7 @@ static std::vector<T> generate_tokens(int rank, int num_tokens, int hidden_dim)
return v;
}

static std::vector<int32_t> expected_token_counts(
static std::vector<int32_t> expected_recv_tokens_per_expert(
int recv_rank, int num_processes, int num_tokens, int top_k,
int num_experts, int num_local_experts) {
int base = recv_rank * num_local_experts;
Expand Down Expand Up @@ -128,7 +128,7 @@ struct EPBuffers {
DevBuf<int64_t> topk_idx;
DevBuf<float> topk_weights;
DevBuf<T> tokens;
DevBuf<int32_t> token_counts;
DevBuf<int32_t> recv_tokens_per_expert;
DevBuf<uint8_t> handle_mem;
DevBuf<T> recv_tokens;
DevBuf<float> recv_topk_weights;
Expand All @@ -144,22 +144,26 @@ struct EPBuffers {
size_t recv_capacity = 0;
int top_k_ = 0;
size_t alignment_ = 0;
NVTEEpLayerConfig layer_cfg_{};

void alloc(int num_tokens, int top_k, int hidden_dim, int num_local_experts,
int ep_size, int max_tokens_per_rank, size_t alignment = 0) {
top_k_ = top_k;
alignment_ = alignment;
layer_cfg_ = NVTE_EP_LAYER_CONFIG_INIT;
layer_cfg_.top_k = top_k;
layer_cfg_.dispatch_output_per_expert_alignment = alignment;
recv_capacity = static_cast<size_t>(ep_size) * max_tokens_per_rank * 2;

topk_idx.alloc(num_tokens * top_k);
topk_weights.alloc(num_tokens * top_k);
tokens.alloc(num_tokens * hidden_dim);
token_counts.alloc(num_local_experts);
recv_tokens_per_expert.alloc(num_local_experts);
recv_tokens.alloc(recv_capacity * hidden_dim);
recv_topk_weights.alloc(recv_capacity);
result.alloc(num_tokens * hidden_dim);

handle_mem_size = nvte_ep_handle_mem_size(NVTEEpLayerConfig{top_k, alignment});
handle_mem_size = nvte_ep_handle_mem_size(&layer_cfg_);
handle_mem.alloc(handle_mem_size);

grad_result.alloc(num_tokens * hidden_dim);
Expand All @@ -174,25 +178,29 @@ struct EPBuffers {
// expects.
template <typename T = nv_bfloat16>
struct EPTensors {
TensorWrapper topk_idx, topk_weights, token_counts, handle_mem, tokens;
TensorWrapper topk_idx, topk_weights, recv_tokens_per_expert, handle_mem, tokens;
TensorWrapper recv_tokens, recv_topk_weights, result;
TensorWrapper grad_result, grad_expert, grad_tokens;
TensorWrapper g_recv_topk_weights, grad_topk_weights;

int top_k_ = 0;
size_t alignment_ = 0;
NVTEEpLayerConfig layer_cfg_{};

EPTensors(EPBuffers<T>& b, int num_tokens, int top_k, int hidden_dim,
int num_local_experts) {
top_k_ = top_k;
alignment_ = b.alignment_;
layer_cfg_ = NVTE_EP_LAYER_CONFIG_INIT;
layer_cfg_.top_k = top_k;
layer_cfg_.dispatch_output_per_expert_alignment = b.alignment_;
constexpr DType kTokDType = test::TypeInfo<T>::dtype;
using Shape = std::vector<size_t>;
topk_idx = TensorWrapper(b.topk_idx.get(),
Shape{(size_t)num_tokens, (size_t)top_k}, DType::kInt64);
topk_weights = TensorWrapper(b.topk_weights.get(),
Shape{(size_t)num_tokens, (size_t)top_k}, DType::kFloat32);
token_counts = TensorWrapper(b.token_counts.get(),
recv_tokens_per_expert = TensorWrapper(b.recv_tokens_per_expert.get(),
Shape{(size_t)num_local_experts}, DType::kInt32);
handle_mem = TensorWrapper(b.handle_mem.get(),
Shape{b.handle_mem_size}, DType::kByte);
Expand Down Expand Up @@ -259,7 +267,7 @@ class EpOpTestBase : public ::testing::Test {
template <typename T = nv_bfloat16>
int read_total_recv(const EPBuffers<T>& buf) const {
std::vector<int32_t> cnt(num_local_experts_);
NVTE_CHECK_CUDA(cudaMemcpy(cnt.data(), buf.token_counts.get(),
NVTE_CHECK_CUDA(cudaMemcpy(cnt.data(), buf.recv_tokens_per_expert.get(),
num_local_experts_ * sizeof(int32_t), cudaMemcpyDeviceToHost));
int total = 0;
for (int c : cnt) total += c;
Expand Down Expand Up @@ -300,7 +308,7 @@ TYPED_TEST(EPDispatchTest, PrepareAndDispatch) {
cudaStream_t stream;
NVTE_CHECK_CUDA(cudaStreamCreate(&stream));

ASSERT_NO_THROW(nvte_ep_prepare(t.handle_mem.data(), t.topk_idx.data(), t.token_counts.data(), NVTEEpLayerConfig{t.top_k_, t.alignment_}, stream));
ASSERT_NO_THROW(nvte_ep_prepare(t.handle_mem.data(), t.topk_idx.data(), t.recv_tokens_per_expert.data(), nullptr, &t.layer_cfg_, stream));
ASSERT_NO_THROW(nvte_ep_dispatch(t.handle_mem.data(), t.topk_idx.data(),
t.tokens.data(), NVTECommWindow{}, t.topk_weights.data(),
NVTECommWindow{}, t.recv_tokens.data(), NVTECommWindow{},
Expand All @@ -309,9 +317,9 @@ TYPED_TEST(EPDispatchTest, PrepareAndDispatch) {

// 1. Per-expert counts.
std::vector<int32_t> got_counts(num_local_experts_);
NVTE_CHECK_CUDA(cudaMemcpy(got_counts.data(), buf.token_counts.get(),
NVTE_CHECK_CUDA(cudaMemcpy(got_counts.data(), buf.recv_tokens_per_expert.get(),
num_local_experts_ * sizeof(int32_t), cudaMemcpyDeviceToHost));
auto exp_counts = expected_token_counts(g_process_id, g_num_processes, num_tokens_, top_k_,
auto exp_counts = expected_recv_tokens_per_expert(g_process_id, g_num_processes, num_tokens_, top_k_,
num_experts_, num_local_experts_);
int total_recv = 0;
for (int i = 0; i < num_local_experts_; ++i) {
Expand Down Expand Up @@ -379,7 +387,7 @@ TYPED_TEST(EPCombineTest, Combine) {
cudaStream_t stream;
NVTE_CHECK_CUDA(cudaStreamCreate(&stream));

ASSERT_NO_THROW(nvte_ep_prepare(t.handle_mem.data(), t.topk_idx.data(), t.token_counts.data(), NVTEEpLayerConfig{t.top_k_, t.alignment_}, stream));
ASSERT_NO_THROW(nvte_ep_prepare(t.handle_mem.data(), t.topk_idx.data(), t.recv_tokens_per_expert.data(), nullptr, &t.layer_cfg_, stream));
ASSERT_NO_THROW(nvte_ep_dispatch(t.handle_mem.data(), t.topk_idx.data(),
t.tokens.data(), NVTECommWindow{}, t.topk_weights.data(),
NVTECommWindow{}, t.recv_tokens.data(), NVTECommWindow{},
Expand Down Expand Up @@ -426,7 +434,7 @@ TYPED_TEST(EPCombineBwdTest, CombineBwdCheck) {
cudaStream_t stream;
NVTE_CHECK_CUDA(cudaStreamCreate(&stream));

ASSERT_NO_THROW(nvte_ep_prepare(t.handle_mem.data(), t.topk_idx.data(), t.token_counts.data(), NVTEEpLayerConfig{t.top_k_, t.alignment_}, stream));
ASSERT_NO_THROW(nvte_ep_prepare(t.handle_mem.data(), t.topk_idx.data(), t.recv_tokens_per_expert.data(), nullptr, &t.layer_cfg_, stream));
ASSERT_NO_THROW(nvte_ep_dispatch(t.handle_mem.data(), t.topk_idx.data(),
t.tokens.data(), NVTECommWindow{}, t.topk_weights.data(),
NVTECommWindow{}, t.recv_tokens.data(), NVTECommWindow{},
Expand All @@ -447,7 +455,7 @@ TYPED_TEST(EPCombineBwdTest, CombineBwdCheck) {
int total_recv = this->template read_total_recv<Tok>(buf);

std::vector<int32_t> cnt(num_local_experts_);
NVTE_CHECK_CUDA(cudaMemcpy(cnt.data(), buf.token_counts.get(),
NVTE_CHECK_CUDA(cudaMemcpy(cnt.data(), buf.recv_tokens_per_expert.get(),
num_local_experts_ * sizeof(int32_t), cudaMemcpyDeviceToHost));
std::vector<Tok> h_ge(buf.recv_capacity * hidden_dim_);
NVTE_CHECK_CUDA(cudaMemcpy(h_ge.data(), buf.grad_expert.get(),
Expand Down Expand Up @@ -495,7 +503,7 @@ TYPED_TEST(EPDispatchBwdTest, DispatchBwdCheck) {
cudaStream_t stream;
NVTE_CHECK_CUDA(cudaStreamCreate(&stream));

ASSERT_NO_THROW(nvte_ep_prepare(t.handle_mem.data(), t.topk_idx.data(), t.token_counts.data(), NVTEEpLayerConfig{t.top_k_, t.alignment_}, stream));
ASSERT_NO_THROW(nvte_ep_prepare(t.handle_mem.data(), t.topk_idx.data(), t.recv_tokens_per_expert.data(), nullptr, &t.layer_cfg_, stream));
ASSERT_NO_THROW(nvte_ep_dispatch(t.handle_mem.data(), t.topk_idx.data(),
t.tokens.data(), NVTECommWindow{}, t.topk_weights.data(),
NVTECommWindow{}, t.recv_tokens.data(), NVTECommWindow{},
Expand Down Expand Up @@ -563,7 +571,7 @@ TYPED_TEST(EPDispatchBwdGradWeightsTest, RoundTrip) {
cudaStream_t stream;
NVTE_CHECK_CUDA(cudaStreamCreate(&stream));

ASSERT_NO_THROW(nvte_ep_prepare(t.handle_mem.data(), t.topk_idx.data(), t.token_counts.data(), NVTEEpLayerConfig{t.top_k_, t.alignment_}, stream));
ASSERT_NO_THROW(nvte_ep_prepare(t.handle_mem.data(), t.topk_idx.data(), t.recv_tokens_per_expert.data(), nullptr, &t.layer_cfg_, stream));
NVTE_CHECK_CUDA(cudaMemsetAsync(buf.recv_topk_weights.get(), 0,
buf.recv_topk_weights.bytes(), stream));
ASSERT_NO_THROW(nvte_ep_dispatch(t.handle_mem.data(), t.topk_idx.data(),
Expand Down Expand Up @@ -634,7 +642,7 @@ class EPPipelineTest : public EpOpTestBase, public ::testing::WithParamInterface
cudaStream_t stream;
NVTE_CHECK_CUDA(cudaStreamCreate(&stream));

ASSERT_NO_THROW(nvte_ep_prepare(t.handle_mem.data(), t.topk_idx.data(), t.token_counts.data(), NVTEEpLayerConfig{t.top_k_, t.alignment_}, stream));
ASSERT_NO_THROW(nvte_ep_prepare(t.handle_mem.data(), t.topk_idx.data(), t.recv_tokens_per_expert.data(), nullptr, &t.layer_cfg_, stream));
ASSERT_NO_THROW(nvte_ep_dispatch(t.handle_mem.data(), t.topk_idx.data(),
t.tokens.data(), NVTECommWindow{}, t.topk_weights.data(),
NVTECommWindow{}, t.recv_tokens.data(), NVTECommWindow{},
Expand Down Expand Up @@ -759,7 +767,7 @@ TYPED_TEST(EPZeroCopyTest, IdentityAllSymm) {
cudaStream_t stream;
NVTE_CHECK_CUDA(cudaStreamCreate(&stream));

ASSERT_NO_THROW(nvte_ep_prepare(ref_t.handle_mem.data(), ref_t.topk_idx.data(), ref_t.token_counts.data(), NVTEEpLayerConfig{ref_t.top_k_, ref_t.alignment_}, stream));
ASSERT_NO_THROW(nvte_ep_prepare(ref_t.handle_mem.data(), ref_t.topk_idx.data(), ref_t.recv_tokens_per_expert.data(), nullptr, &ref_t.layer_cfg_, stream));
ASSERT_NO_THROW(nvte_ep_dispatch(ref_t.handle_mem.data(), ref_t.topk_idx.data(),
ref_t.tokens.data(), NVTECommWindow{}, ref_t.topk_weights.data(),
NVTECommWindow{}, ref_t.recv_tokens.data(), NVTECommWindow{},
Expand Down Expand Up @@ -800,7 +808,7 @@ TYPED_TEST(EPZeroCopyTest, IdentityAllSymm) {
sym_t.recv_tokens = TensorWrapper(sym_recv.ptr,
std::vector<size_t>{sym_buf.recv_capacity, (size_t)hidden_dim_}, kTokDType);

ASSERT_NO_THROW(nvte_ep_prepare(sym_t.handle_mem.data(), sym_t.topk_idx.data(), sym_t.token_counts.data(), NVTEEpLayerConfig{sym_t.top_k_, sym_t.alignment_}, stream));
ASSERT_NO_THROW(nvte_ep_prepare(sym_t.handle_mem.data(), sym_t.topk_idx.data(), sym_t.recv_tokens_per_expert.data(), nullptr, &sym_t.layer_cfg_, stream));
ASSERT_NO_THROW(nvte_ep_dispatch(sym_t.handle_mem.data(), sym_t.topk_idx.data(),
sym_t.tokens.data(), symm_window(sym_tokens),
sym_t.topk_weights.data(), NVTECommWindow{},
Expand Down
8 changes: 4 additions & 4 deletions tests/cpp_distributed/test_ep_common.h
Original file line number Diff line number Diff line change
Expand Up @@ -146,7 +146,7 @@ static bool ep_bootstrap(int argc, char* argv[]) {
ncclUniqueId uid{};
exchange_unique_id(&uid);

NVTEEpGroupConfig group_config{};
NVTEEpGroupConfig group_config = NVTE_EP_GROUP_CONFIG_INIT;
group_config.ep_size = g_ep_size;
group_config.num_experts = g_num_experts;
group_config.max_tokens_per_rank = g_max_tokens_per_rank;
Expand All @@ -156,7 +156,7 @@ static bool ep_bootstrap(int argc, char* argv[]) {
group_config.max_token_dtype = g_max_token_dtype;

NVTE_CHECK_NCCL(ncclCommInitRank(&g_ep_comm, g_num_processes, uid, g_process_id));
nvte_ep_initialize(static_cast<void*>(g_ep_comm), group_config);
nvte_ep_initialize(static_cast<void*>(g_ep_comm), &group_config);

if (g_process_id == 0) {
printf("EP initialized: ep_size=%d num_experts=%d "
Expand All @@ -173,15 +173,15 @@ static bool ep_bootstrap(int argc, char* argv[]) {
static void ep_reinitialize(int zero_copy) {
if (!g_ep_initialized) return;
nvte_ep_shutdown();
NVTEEpGroupConfig group_config{};
NVTEEpGroupConfig group_config = NVTE_EP_GROUP_CONFIG_INIT;
group_config.ep_size = g_ep_size;
group_config.num_experts = g_num_experts;
group_config.max_tokens_per_rank = g_max_tokens_per_rank;
group_config.max_recv_tokens_per_rank = g_ep_size * g_max_tokens_per_rank * 2;
group_config.hidden_dim = g_hidden_dim;
group_config.max_token_dtype = g_max_token_dtype;
group_config.zero_copy = zero_copy;
nvte_ep_initialize(static_cast<void*>(g_ep_comm), group_config);
nvte_ep_initialize(static_cast<void*>(g_ep_comm), &group_config);
}

// Tear down in dependency order: backend's ep_group reads from ep_comm,
Expand Down
69 changes: 52 additions & 17 deletions transformer_engine/common/ep/ep_api.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,10 @@

#include <transformer_engine/ep.h>

#include <algorithm>
#include <cstddef>
#include <cstring>

#include "../util/logging.h"

#if defined(NVTE_WITH_NCCL_EP)
Expand All @@ -24,28 +28,56 @@

using transformer_engine::ep::EPBackend;

void nvte_ep_initialize(void* ep_comm, NVTEEpGroupConfig group_config) {
NVTE_CHECK(ep_comm != nullptr, "ep_comm must not be null");
EPBackend::initialize(static_cast<ncclComm_t>(ep_comm), group_config);
}

void nvte_ep_shutdown(void) { EPBackend::shutdown(); }

size_t nvte_ep_handle_mem_size(NVTEEpLayerConfig layer_cfg) {
return EPBackend::get().handle_mem_size(layer_cfg);
namespace {
// Smallest accepted struct_size: covers the base (required) fields. Frozen;
// never raise these. Later fields are read only when struct_size covers them.
constexpr size_t kGroupConfigMinSize = offsetof(NVTEEpGroupConfig, zero_copy) + sizeof(int);
constexpr size_t kLayerConfigMinSize =
offsetof(NVTEEpLayerConfig, dispatch_output_per_expert_alignment) + sizeof(size_t);

// Copy a caller's versioned config into a full current-layout struct: fields
// the caller did not provide stay zero, extra trailing fields are dropped.
// struct_size 0 is read as the base layout. Requires a size_t struct_size
// first member.
template <typename Cfg>
Cfg normalize_ep_config(const Cfg* user, size_t min_size, const char* name) {
NVTE_CHECK(user != nullptr, name, " must not be null");
const size_t want = (user->struct_size == 0) ? min_size : user->struct_size;
NVTE_CHECK(want >= min_size, name, ".struct_size (", user->struct_size,
Comment on lines +45 to +46

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

P2 struct_size == 0 and truncated memcpy can silently copy a wrong number of bytes

When user->struct_size is between min_size and sizeof(Cfg), the memcpy copies exactly want bytes, which is correct. However, the NVTE_CHECK below only fires when want < min_size after the ternary substitution, so the special treatment of struct_size == 0 is implicit in the guard rather than stated by it. If struct_size is a small non-zero garbage value (e.g. 1), the check accepts it only when min_size is equally small (e.g. min_size == 1), in which case memcpy will copy just 1 byte into cfg, leaving almost every field zeroed. Consider making the guard explicit: user->struct_size == 0 || want >= min_size.

Suggested change
const size_t want = (user->struct_size == 0) ? min_size : user->struct_size;
NVTE_CHECK(want >= min_size, name, ".struct_size (", user->struct_size,
const size_t want = (user->struct_size == 0) ? min_size : user->struct_size;
NVTE_CHECK(user->struct_size == 0 || want >= min_size, name, ".struct_size (", user->struct_size,

") is below the required minimum ", min_size,
"; zero-initialize the struct or set struct_size via NVTE_EP_*_CONFIG_INIT");
Cfg cfg{};
std::memcpy(&cfg, user, std::min(want, sizeof(Cfg)));
cfg.struct_size = sizeof(Cfg);
return cfg;
}

namespace {
// Returns the data pointer of the handle_mem tensor as reported by
// nvte_tensor_data(); NVTE_CHECK fails if that pointer is null.
inline void* handle_mem_ptr(NVTETensor mem) {
  void* const data = nvte_tensor_data(mem);
  NVTE_CHECK(data != nullptr, "handle_mem tensor data must not be null");
  return data;
}
} // namespace

void nvte_ep_prepare(NVTETensor handle_mem, NVTETensor topk_idx, NVTETensor token_counts,
NVTEEpLayerConfig layer_cfg, cudaStream_t stream) {
EPBackend::get().prepare(handle_mem_ptr(handle_mem), topk_idx, token_counts, layer_cfg, stream);
void nvte_ep_initialize(void* ep_comm, const NVTEEpGroupConfig* group_config) {
NVTE_CHECK(ep_comm != nullptr, "ep_comm must not be null");
NVTEEpGroupConfig cfg = normalize_ep_config(group_config, kGroupConfigMinSize, "group_config");
EPBackend::initialize(static_cast<ncclComm_t>(ep_comm), cfg);
}

// C entry point for EP teardown; simply delegates to EPBackend::shutdown().
void nvte_ep_shutdown(void) {
  EPBackend::shutdown();
}

size_t nvte_ep_handle_mem_size(const NVTEEpLayerConfig* layer_cfg) {
NVTEEpLayerConfig cfg = normalize_ep_config(layer_cfg, kLayerConfigMinSize, "layer_cfg");
return EPBackend::get().handle_mem_size(cfg);
}

// Forwards a prepare request to the EP backend.
//
// handle_mem                 - tensor whose data pointer is used as the handle
//                              memory; a null data pointer fails the check in
//                              handle_mem_ptr().
// topk_idx                   - forwarded unchanged to the backend.
// recv_tokens_per_expert     - forwarded unchanged to the backend.
// total_recv_tokens_per_rank - forwarded unchanged to the backend.
// layer_cfg                  - caller's versioned layer config; normalized to
//                              the current layout before use.
// stream                     - CUDA stream passed through to the backend.
void nvte_ep_prepare(NVTETensor handle_mem, NVTETensor topk_idx, NVTETensor recv_tokens_per_expert,
                     NVTETensor total_recv_tokens_per_rank, const NVTEEpLayerConfig* layer_cfg,
                     cudaStream_t stream) {
  const NVTEEpLayerConfig normalized =
      normalize_ep_config(layer_cfg, kLayerConfigMinSize, "layer_cfg");

  // Fetch the backend before resolving the handle pointer, matching the
  // order in which the single-expression form evaluates them.
  auto&& backend = EPBackend::get();
  void* const handle_ptr = handle_mem_ptr(handle_mem);

  backend.prepare(handle_ptr, topk_idx, recv_tokens_per_expert, total_recv_tokens_per_rank,
                  normalized, stream);
}

void nvte_ep_dispatch(NVTETensor handle_mem, NVTETensor topk_idx, NVTETensor tokens,
Expand Down Expand Up @@ -88,15 +120,18 @@ namespace {
}
} // namespace

void nvte_ep_initialize(void* /*ep_comm*/, NVTEEpGroupConfig /*group_config*/) { ep_not_built(); }
void nvte_ep_initialize(void* /*ep_comm*/, const NVTEEpGroupConfig* /*group_config*/) {
ep_not_built();
}

void nvte_ep_shutdown(void) {}

size_t nvte_ep_handle_mem_size(NVTEEpLayerConfig /*layer_cfg*/) { ep_not_built(); }
size_t nvte_ep_handle_mem_size(const NVTEEpLayerConfig* /*layer_cfg*/) { ep_not_built(); }

void nvte_ep_prepare(NVTETensor /*handle_mem*/, NVTETensor /*topk_idx*/,
NVTETensor /*token_counts*/, NVTEEpLayerConfig /*layer_cfg*/,
cudaStream_t /*stream*/) {
NVTETensor /*recv_tokens_per_expert*/,
NVTETensor /*total_recv_tokens_per_rank*/,
const NVTEEpLayerConfig* /*layer_cfg*/, cudaStream_t /*stream*/) {
ep_not_built();
}

Expand Down
21 changes: 12 additions & 9 deletions transformer_engine/common/ep/ep_backend.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -97,8 +97,6 @@ void EPBackend::validate_config(const NVTEEpGroupConfig& config) {
NVTE_CHECK(config.max_recv_tokens_per_rank > 0, "max_recv_tokens_per_rank must be positive, got ",
config.max_recv_tokens_per_rank);
NVTE_CHECK(config.hidden_dim > 0, "hidden_dim must be positive, got ", config.hidden_dim);
NVTE_CHECK(config.max_token_dtype >= 0 && config.max_token_dtype < kNVTENumTypes,
"max_token_dtype out of range, got ", static_cast<int>(config.max_token_dtype));
const size_t elem_bytes = typeToSize(static_cast<DType>(config.max_token_dtype));

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

P2 Removed max_token_dtype range check degrades error messages

The explicit NVTE_CHECK(config.max_token_dtype >= 0 && config.max_token_dtype < kNVTENumTypes, ...) guard was removed, so typeToSize() is now the first line of defense. An out-of-range value (e.g. an accidentally unset field left as a garbage integer) will produce either a generic TRANSFORMER_ENGINE_TYPE_SWITCH_ALL error or — if typeToNumBits silently returns 0 for an unknown type — a confusing "row_bytes < 16" failure rather than a clear "max_token_dtype out of range" message. Given that the header comment now says the field is "Required (no default)", a fast, explicit bounds check here would give users a much better diagnostic.

Note: If this suggestion doesn't match your team's coding style, reply to this and let me know. I'll remember it for next time!

Copy link
Copy Markdown
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Add it back!

const size_t row_bytes = static_cast<size_t>(config.hidden_dim) * elem_bytes;
NVTE_CHECK(row_bytes >= 16,
Expand Down Expand Up @@ -319,8 +317,11 @@ size_t EPBackend::handle_mem_size(NVTEEpLayerConfig layer_cfg) {
return hm_size;
}

void EPBackend::prepare(void* handle_mem, const NVTETensor topk_idx, NVTETensor token_counts,
NVTEEpLayerConfig layer_cfg, cudaStream_t stream) {
void EPBackend::prepare(void* handle_mem, const NVTETensor topk_idx,
NVTETensor recv_tokens_per_expert,
NVTETensor /*total_recv_tokens_per_rank*/, NVTEEpLayerConfig layer_cfg,
cudaStream_t stream) {
// total_recv_tokens_per_rank is a reserved placeholder; not yet populated.
NVTE_CHECK(handle_mem != nullptr, "handle_mem must not be null");
NVTE_CHECK(layer_cfg.top_k > 0, "top_k must be > 0, got ", layer_cfg.top_k);
NVTE_CHECK(nvte_tensor_shape(topk_idx).ndim == 2, "topk_idx must be 2D [T, top_k]");
Expand All @@ -329,13 +330,15 @@ void EPBackend::prepare(void* handle_mem, const NVTETensor topk_idx, NVTETensor
ncclEpTensor_t nccl_topk_idx = make_nccl_ep_tensor(topk_idx, topk_idx_shape);

// ncclEpUpdateHandle writes per-expert counts via expert_counters.
NVTEShape token_counts_shape;
ncclEpTensor_t token_counts_desc;
if (token_counts != nullptr) {
token_counts_desc = make_nccl_ep_tensor(token_counts, token_counts_shape);
NVTEShape recv_tokens_per_expert_shape;
ncclEpTensor_t recv_tokens_per_expert_desc;
if (recv_tokens_per_expert != nullptr) {
recv_tokens_per_expert_desc =
make_nccl_ep_tensor(recv_tokens_per_expert, recv_tokens_per_expert_shape);
}
ncclEpLayoutInfo_t layout_info = NCCL_EP_LAYOUT_INFO_INIT;
layout_info.expert_counters = (token_counts != nullptr) ? &token_counts_desc : nullptr;
layout_info.expert_counters =
(recv_tokens_per_expert != nullptr) ? &recv_tokens_per_expert_desc : nullptr;

std::lock_guard<std::mutex> lock(mutex_);
NVTE_CHECK(initialized_, "EPBackend not initialized");
Expand Down
Loading
Loading