Merging code from IFU branch. #8
base: main
Conversation
# To build rocSHMEM with MPI disabled, please add this flag -DUSE_EXTERNAL_MPI=OFF
MPI_ROOT=$BUILD_DIR/ompi ../rocSHMEM/scripts/build_configs/gda_mlx5 --fresh \
    -DUSE_IPC=ON \
Why are you reverting the README?
parser.add_argument("--verbose", action="store_true", help="Verbose build")
parser.add_argument("--enable_timer", action="store_true", help="Enable timer to debug time out in internode")
parser.add_argument("--rocm-disable-ctx", action="store_true", help="Disable workgroup context optimization in internode")
parser.add_argument("--disable-mpi", action="store_true", help="Disable MPI detection and configuration")
disable-mpi should be kept.
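For reference, a minimal standalone sketch of how the flags in this hunk wire up with argparse (the parser setup and the sample invocation below are illustrative assumptions, not the actual build script):

```python
import argparse

# Sketch of the build-script flags under discussion; the surrounding
# script and its defaults are assumptions, not the real setup code.
parser = argparse.ArgumentParser(description="Build configuration sketch")
parser.add_argument("--verbose", action="store_true", help="Verbose build")
parser.add_argument("--enable_timer", action="store_true",
                    help="Enable timer to debug time out in internode")
parser.add_argument("--rocm-disable-ctx", action="store_true",
                    help="Disable workgroup context optimization in internode")
# Kept per review: lets users build with MPI detection turned off.
parser.add_argument("--disable-mpi", action="store_true",
                    help="Disable MPI detection and configuration")

args = parser.parse_args(["--disable-mpi"])
print(args.disable_mpi)  # argparse maps "--disable-mpi" to args.disable_mpi
```

Note that argparse converts the dashes in `--disable-mpi` and `--rocm-disable-ctx` to underscores in the resulting namespace.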
for (int j = 0; j < kNumElemsPerRead; j += 2) {
    float2 fp32x2 = {fp32_values[j] * scale, fp32_values[j + 1] * scale};
#ifdef USE_ROCM
#if defined(__gfx942__)
These changes need to be reverted; they break the build for MI350.
csrc/kernels/internode_ll.cu
Outdated
internode::shmem_ctx_schar_put_nbi_warp(ctx,
#endif
    reinterpret_cast<signed char*>(dst_ptr), reinterpret_cast<signed char*>(src_ptr), num_bytes_per_msg, dst_rank);
#if defined(ROCM_DISABLE_CTX)
These changes also need to be reverted.
// Assign bias pointers
/*auto bias_opts = std::vector<std::optional<torch::Tensor>>({bias_0, bias_1});
void* bias_ptrs[2] = {nullptr, nullptr};
for (int i = 0; i < 2; ++i)
    if (bias_opts[i].has_value()) {
        auto bias = bias_opts[i].value();
        EP_HOST_ASSERT(bias.dim() == 2 and bias.is_contiguous());
        EP_HOST_ASSERT(bias.scalar_type() == x.scalar_type());
        EP_HOST_ASSERT(bias.size(0) == num_recv_tokens and bias.size(1) == hidden);
        bias_ptrs[i] = bias.data_ptr();
    }
*/
Let's remove it or comment that it might be needed for future work
I've added a comment to say it's not supported at this time.
/*for (auto& to : {topk_weights, recv_topk_weights, bias_0, bias_1}) {
    to.has_value() ? to->record_stream(comm_stream) : void();
    if (allocate_on_comm_stream)
        to.has_value() ? to->record_stream(compute_stream) : void();
}*/
Let's remove it or comment that it might be needed for future work
Added a comment.
csrc/deep_ep.hpp
Outdated
//const std::optional<torch::Tensor>& bias_0,
//const std::optional<torch::Tensor>& bias_1,
Let's remove it
namespace intranode {

void barrier(int **task_fifo_ptrs, int head, int rank, int num_ranks, cudaStream_t stream);
//void barrier(int **task_fifo_ptrs, int head, int rank, int num_ranks, cudaStream_t stream);
Let's remove it
Done
if (not (cond)) { \
    printf("Assertion failed: %s:%d, condition: %s\n", __FILE__, __LINE__, #cond); \
    trap(); \
    abort();\
Why was that changed? As far as I remember, the abort() function was unavailable on the device side.
trap() was unrecognized during compilation.
csrc/kernels/internode_ll.cu
Outdated
#if !defined(ROCM_DISABLE_CTX)
__shared__ internode::shmem_ctx_t ctx;
internode::shmem_wg_ctx_create(&ctx);
EP_DEVICE_ASSERT(internode::shmem_wg_ctx_create(&ctx) == 0);
Maybe there's something like INVALID_CTX to compare against, rather than zero?
csrc/kernels/intranode.cu
Outdated
//#pragma unroll
//for (int i = 0; i < kNumRanks; ++ i)
//    per_rank_buffer[rank * kNumRanks + i] = num_tokens_per_rank[i];
Let's clean this up.
csrc/kernels/runtime.cu
Outdated
}

void barrier(int** task_fifo_ptrs, int head, int rank, int num_ranks, cudaStream_t stream) {
/*void barrier(int** task_fifo_ptrs, int head, int rank, int num_ranks, cudaStream_t stream) {
Let's remove the old version.
csrc/kernels/utils.cuh
Outdated
#include "exception.cuh"

#ifdef USE_ROCM
#define syncthreads() __syncthreads()
Why can't we just use __syncthreads() everywhere? There's no custom functionality added behind this function, and the __ prefix will explicitly mark that we're using the runtime one.
I wondered about this, but was just following how it's always done and assumed there was some good reason for it. Probably just some debugging at some point?
It seems like there's no point in wrapping that particular function. It is (was) necessary for some other calls, like __shfl_sync for example, because there we have a different number of arguments compared to the CUDA runtime, so a wrapper is required. Let's revert to __syncthreads().
Removed unused definition.