[CK_TILE] Allow UniversalGemmKernel::RunGemm to be called using tensor descriptors
#3457
+126
−131
Proposed changes [WIP]

BatchedContractionKernel implements a custom RunGemm that can be called with tensor descriptors, unlike UniversalGemmKernel::RunGemm, which computes tensor views directly from the given stride parameters.

This adds UniversalGemmKernel::RunGemmDesc, which takes in descriptors and thus replaces the custom implementation in BatchedContractionKernel. It also has a (partial) POC to implement the original UniversalGemmKernel::RunGemm using the descriptor version; see the discussion below.

Checklist

Please put an x into the boxes that apply. You can also fill these out after creating the PR. If you're not sure, please don't hesitate to ask.

clang-format on all changed files

Discussion

UniversalGemmKernel::RunGemm generates tensor views directly from the given parameters through MakeGemmTensorViews, rather than separately building the descriptors and then generating views. In order to refactor the rest of the code path to RunGemmDesc (which would make it simple to have both entry points available), MakeGemmTensorView would need to be refactored to separately build the descriptors and views.

OTOH, GroupedConvolutionKernel already has a similar MakeGemmTensorView which takes in descriptors, as does RunGemm, with no version that expects to construct these internally. Should UniversalGemmKernel follow suit? Alternatively, could we expect the caller to construct & pass in the tensor views directly, instead of just the descriptors?
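
To make the two-entry-point idea concrete, here is a minimal host-side sketch of a stride-based entry point that only builds descriptors and then forwards to a descriptor-based one. It is not CK_TILE code: Desc2D, make_row_major_desc, run_gemm_desc and run_gemm are hypothetical names invented for this illustration, and the naive loop stands in for the real pipeline.

```cpp
#include <cassert>
#include <cstddef>

// Hypothetical stand-in for a 2D tensor descriptor: lengths plus strides.
// (Not a CK_TILE type; introduced only for this sketch.)
struct Desc2D
{
    std::size_t rows;
    std::size_t cols;
    std::size_t row_stride;
    std::size_t col_stride;

    // Linear offset of logical element (r, c).
    std::size_t offset(std::size_t r, std::size_t c) const
    {
        return r * row_stride + c * col_stride;
    }
};

// Hypothetical helper: descriptor for a row-major matrix with leading dimension ld.
inline Desc2D make_row_major_desc(std::size_t rows, std::size_t cols, std::size_t ld)
{
    return Desc2D{rows, cols, ld, 1};
}

// Descriptor-based entry point (plays the role of RunGemmDesc in this sketch).
// Computes C = A * B using only the layouts described by the descriptors.
inline void run_gemm_desc(const float* a, const Desc2D& a_desc,
                          const float* b, const Desc2D& b_desc,
                          float* c, const Desc2D& c_desc)
{
    assert(a_desc.cols == b_desc.rows);
    assert(c_desc.rows == a_desc.rows && c_desc.cols == b_desc.cols);

    for(std::size_t m = 0; m < c_desc.rows; ++m)
    {
        for(std::size_t n = 0; n < c_desc.cols; ++n)
        {
            float acc = 0.0f;
            for(std::size_t k = 0; k < a_desc.cols; ++k)
            {
                acc += a[a_desc.offset(m, k)] * b[b_desc.offset(k, n)];
            }
            c[c_desc.offset(m, n)] = acc;
        }
    }
}

// Stride-based entry point (plays the role of the original RunGemm).
// It builds the descriptors from the strides and delegates everything else.
inline void run_gemm(const float* a, const float* b, float* c,
                     std::size_t M, std::size_t N, std::size_t K,
                     std::size_t lda, std::size_t ldb, std::size_t ldc)
{
    run_gemm_desc(a, make_row_major_desc(M, K, lda),
                  b, make_row_major_desc(K, N, ldb),
                  c, make_row_major_desc(M, N, ldc));
}
```

With this split, a caller such as a batched contraction can pass its own descriptors (for example a column-major B, Desc2D{K, N, 1, K}) straight to run_gemm_desc, while existing stride-based callers keep using run_gemm unchanged. Whether the same boundary should sit at the descriptor level or at the tensor-view level is the open question above.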