-
Notifications
You must be signed in to change notification settings - Fork 22
GEMMTestSuite: use rocrand for input data generation #417
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: dev
Are you sure you want to change the base?
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -783,7 +783,6 @@ std::pair<double, double> getTolerances(const DType type) { | |
| template <typename T> | ||
| void generate_data_uniformly(T* data, const size_t size, std::mt19937* gen) { | ||
| #ifdef __HIP_PLATFORM_AMD__ | ||
| // TODO: Introduce a parallel RNG library (Random123, PCG, rocRAND) | ||
| std::uniform_real_distribution<> dis(-2.0, 1.0); | ||
| for (int i = 0; i < size; i++) { | ||
| data[i] = static_cast<T>(dis(*gen)); | ||
|
|
@@ -822,21 +821,71 @@ void generate_data_uniformly(T* data, const size_t size, std::mt19937* gen) { | |
| #endif | ||
| } | ||
|
|
||
| #ifdef __HIP_PLATFORM_AMD__ | ||
| #include <rocrand/rocrand.h> | ||
|
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Even though it does not cause errors, it is better to move the #include to the top of the file, out of the test namespace. |
||
|
|
||
| template <typename T> | ||
| __global__ void affine_transform_and_cast(float* __restrict__ in, T* __restrict__ out, size_t n, float lo, float hi) { | ||
| // Rescale each in[idx] to lo + (hi - lo) * in[idx] (an affine map, not a clamp), write the rescaled value back into *in*, then cast it to type *T* for *out*. One thread handles one element; threads with idx >= n do nothing. | ||
| size_t idx = blockIdx.x * blockDim.x + threadIdx.x; | ||
| if (idx < n) { | ||
| in[idx] = lo + (hi - lo) * in[idx]; | ||
| out[idx] = static_cast<T>(in[idx]); | ||
| } | ||
| } | ||
|
|
||
| void fillUniformDevice(Tensor* t) { | ||
| void* dst = t->rowwise() ? t->rowwise_dptr() : t->columnwise_dptr(); | ||
| const auto shape = t->rowwise() ? t->rowwise_shape() : t->columnwise_shape(); | ||
| const size_t N = product(shape); | ||
|
|
||
| float* tmp = nullptr; | ||
| hipMalloc(&tmp, N * sizeof(float)); | ||
|
|
||
| // per-tensor deterministic seed | ||
| const unsigned long long seed = static_cast<unsigned long long>(t->gen()()); | ||
| rocrand_generator gen; | ||
| rocrand_create_generator(&gen, ROCRAND_RNG_PSEUDO_PHILOX4_32_10); | ||
| rocrand_set_seed(gen, seed); | ||
|
|
||
| rocrand_generate_uniform(gen, tmp, N); | ||
|
|
||
| // map to [-2.0, 1.0] (like generate_data_uniformly) and cast into tensor dtype | ||
| TRANSFORMER_ENGINE_TYPE_SWITCH_ALL(t->dtype(), T, { | ||
|
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. T should either be a template parameter of this function, with no TRANSFORMER_ENGINE_TYPE_SWITCH_ALL here, or the call to this function should be moved out of TRANSFORMER_ENGINE_TYPE_SWITCH_ALL in fillUniform. |
||
| dim3 block(256); | ||
| dim3 grid((N + block.x - 1) / block.x); | ||
| hipLaunchKernelGGL(affine_transform_and_cast<T>, grid, block, 0, 0, | ||
| tmp, reinterpret_cast<T*>(dst), N, -2.0f, 1.0f); | ||
| }); | ||
|
|
||
| rocrand_destroy_generator(gen); | ||
| hipFree(tmp); | ||
| } | ||
| #endif | ||
|
|
||
| void fillUniform(Tensor *t) { | ||
| if (t->rowwise()) { | ||
| const size_t size = product(t->rowwise_shape()); | ||
| TRANSFORMER_ENGINE_TYPE_SWITCH_ALL(t->dtype(), T, | ||
| { | ||
| #ifdef __HIP_PLATFORM_AMD__ | ||
| fillUniformDevice(t); | ||
|
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Is there any test that exercises this generation path? I think using GPU generation here does not produce a correct result, because t->from_cpu() is used below in this method. |
||
| #else | ||
| T *data = t->rowwise_cpu_dptr<T>(); | ||
| generate_data_uniformly(data, size, &(t->gen())); | ||
| #endif | ||
| } | ||
| ); | ||
| } else { | ||
| const size_t size = product(t->columnwise_shape()); | ||
| TRANSFORMER_ENGINE_TYPE_SWITCH_ALL(t->dtype(), T, | ||
| { | ||
| #ifdef __HIP_PLATFORM_AMD__ | ||
| fillUniformDevice(t); | ||
| #else | ||
| T *data = t->columnwise_cpu_dptr<T>(); | ||
| generate_data_uniformly(data, size, &(t->gen())); | ||
| #endif | ||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Are we able to use rocRAND for fillCase_special as well? Also, I think there were a few tests that, for some reason, generate their own data... I might be wrong about that, or it may have been updated. |
||
| } | ||
| ); | ||
| } | ||
|
|
||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
We can probably remove this entire #ifdef guarded section, right?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I think so, unless we want to keep it around for future reference?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Let's remove, we can always revert if we need it again.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
The whole method seems unused on ROCm and can be guarded