-
Notifications
You must be signed in to change notification settings - Fork 759
[Common] EP C API: version config structs and extend nvte_ep_prepare with total_recv_tokens_per_rank placeholder
#3154
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Changes from all commits
799badc
5b104ce
3d3abf2
c7da769
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -97,8 +97,6 @@ void EPBackend::validate_config(const NVTEEpGroupConfig& config) { | |
| NVTE_CHECK(config.max_recv_tokens_per_rank > 0, "max_recv_tokens_per_rank must be positive, got ", | ||
| config.max_recv_tokens_per_rank); | ||
| NVTE_CHECK(config.hidden_dim > 0, "hidden_dim must be positive, got ", config.hidden_dim); | ||
| NVTE_CHECK(config.max_token_dtype >= 0 && config.max_token_dtype < kNVTENumTypes, | ||
| "max_token_dtype out of range, got ", static_cast<int>(config.max_token_dtype)); | ||
| const size_t elem_bytes = typeToSize(static_cast<DType>(config.max_token_dtype)); | ||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
The explicit Note: If this suggestion doesn't match your team's coding style, reply to this and let me know. I'll remember it for next time!
Collaborator
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Add it back! |
||
| const size_t row_bytes = static_cast<size_t>(config.hidden_dim) * elem_bytes; | ||
| NVTE_CHECK(row_bytes >= 16, | ||
|
|
@@ -319,8 +317,11 @@ size_t EPBackend::handle_mem_size(NVTEEpLayerConfig layer_cfg) { | |
| return hm_size; | ||
| } | ||
|
|
||
| void EPBackend::prepare(void* handle_mem, const NVTETensor topk_idx, NVTETensor token_counts, | ||
| NVTEEpLayerConfig layer_cfg, cudaStream_t stream) { | ||
| void EPBackend::prepare(void* handle_mem, const NVTETensor topk_idx, | ||
| NVTETensor recv_tokens_per_expert, | ||
| NVTETensor /*total_recv_tokens_per_rank*/, NVTEEpLayerConfig layer_cfg, | ||
| cudaStream_t stream) { | ||
| // total_recv_tokens_per_rank is a reserved placeholder; not yet populated. | ||
| NVTE_CHECK(handle_mem != nullptr, "handle_mem must not be null"); | ||
| NVTE_CHECK(layer_cfg.top_k > 0, "top_k must be > 0, got ", layer_cfg.top_k); | ||
| NVTE_CHECK(nvte_tensor_shape(topk_idx).ndim == 2, "topk_idx must be 2D [T, top_k]"); | ||
|
|
@@ -329,13 +330,15 @@ void EPBackend::prepare(void* handle_mem, const NVTETensor topk_idx, NVTETensor | |
| ncclEpTensor_t nccl_topk_idx = make_nccl_ep_tensor(topk_idx, topk_idx_shape); | ||
|
|
||
| // ncclEpUpdateHandle writes per-expert counts via expert_counters. | ||
| NVTEShape token_counts_shape; | ||
| ncclEpTensor_t token_counts_desc; | ||
| if (token_counts != nullptr) { | ||
| token_counts_desc = make_nccl_ep_tensor(token_counts, token_counts_shape); | ||
| NVTEShape recv_tokens_per_expert_shape; | ||
| ncclEpTensor_t recv_tokens_per_expert_desc; | ||
| if (recv_tokens_per_expert != nullptr) { | ||
| recv_tokens_per_expert_desc = | ||
| make_nccl_ep_tensor(recv_tokens_per_expert, recv_tokens_per_expert_shape); | ||
| } | ||
| ncclEpLayoutInfo_t layout_info = NCCL_EP_LAYOUT_INFO_INIT; | ||
| layout_info.expert_counters = (token_counts != nullptr) ? &token_counts_desc : nullptr; | ||
| layout_info.expert_counters = | ||
| (recv_tokens_per_expert != nullptr) ? &recv_tokens_per_expert_desc : nullptr; | ||
|
|
||
| std::lock_guard<std::mutex> lock(mutex_); | ||
| NVTE_CHECK(initialized_, "EPBackend not initialized"); | ||
|
|
||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
struct_size == 0and truncated memcpy can silently copy a wrong number of bytesWhen
user->struct_sizeis betweenmin_sizeandsizeof(Cfg)thememcpycopies exactlywantbytes, which is correct. However, the NVTE_CHECK below only fires whenwant < min_sizeafter the ternary substitution. Ifstruct_sizeis a small non-zero garbage value (e.g. 1), the check accepts it as long asmin_size == 1, andmemcpywill copy just 1 byte intocfgleaving almost every field zeroed. Consider making the guard explicit:user->struct_size == 0 || want >= min_size.