feat: modify aten_bilinear from einsum to matmul #2746
Conversation
Note on test failure: The failure appears unrelated to the bilinear changes. The error is raised during test input generation, which suggests an RNG seed issue in the test harness rather than a problem with the op itself.
Codecov Report: ✅ All modified and coverable lines are covered by tests.

```
@@           Coverage Diff           @@
##             main    #2746   +/-  ##
======================================
  Coverage   70.10%   70.11%
======================================
  Files         228      228
  Lines       27394    27403      +9
  Branches     2781     2781
======================================
+ Hits        19204    19213      +9
  Misses       7234     7234
  Partials      956      956
```
Thanks. @xadupre or @gramalingam for another review
Pull request overview
This PR aims to optimize the aten_bilinear function by replacing the Einstein summation (einsum) implementation with matrix multiplication (matmul) operations for improved performance.
Key changes:
- Replaced the single einsum operation with a sequence of Transpose, Reshape, MatMul, and Squeeze operations
- Added explicit shape extraction for batch dimensions and feature dimensions
- Modified computation flow to use matmul instead of einsum for the bilinear transformation
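To see why the decomposition is equivalent to the einsum, here is a NumPy sketch for the single-batch-dimension case (the function names and concrete shapes are illustrative, not part of the PR; NumPy stands in for the ONNX ops):

```python
import numpy as np

def bilinear_einsum(x1, x2, weight):
    # Reference semantics of aten_bilinear (without bias):
    # output[b, o] = sum_ij x1[b, i] * weight[o, i, j] * x2[b, j]
    return np.einsum("bi,oij,bj->bo", x1, weight, x2)

def bilinear_matmul(x1, x2, weight):
    # Same computation decomposed into transpose/reshape/matmul,
    # mirroring the sequence of ONNX ops introduced by this PR.
    out_f, in1_f, in2_f = weight.shape
    batch = x1.shape[0]
    w_flat = weight.transpose(1, 0, 2).reshape(in1_f, out_f * in2_f)  # (H1, O*H2)
    tmp = (x1 @ w_flat).reshape(batch, out_f, in2_f)                  # (B, O, H2)
    return (tmp @ x2[:, :, None]).squeeze(-1)                         # (B, O)

rng = np.random.default_rng(0)
x1 = rng.standard_normal((4, 3))
x2 = rng.standard_normal((4, 5))
w = rng.standard_normal((2, 3, 5))
assert np.allclose(bilinear_einsum(x1, x2, w), bilinear_matmul(x1, x2, w))
```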
```python
# (O, H1, H2) -> (H1, O, H2) where O = output_shape, H1 = input1_shape, H2 = input2_shape
W_permute = op.Transpose(weight, perm=[1, 0, 2])

# (H1, O, H2) -> (H1, O * H2)
W_flat = op.Reshape(
    W_permute,
    op.Concat(input1_shape, op.Mul(output_shape, input2_shape), axis=0),
)

# (B, H1) @ (H1, O*H2) -> (B, O*H2)
tmp = op.MatMul(input1, W_flat)

# (B, O*H2) -> (B, O, H2)
tmp = op.Reshape(tmp, op.Concat(batch_size, output_shape, input2_shape, axis=0))

# (B, H2) -> (B, H2, 1) -> (B, O, H2) @ (B, H2, 1) -> (B, O, 1) -> (B, O)
```
Copilot AI · Dec 30, 2025
The variable names in the comments don't match the actual variable names used in the code. The comment mentions 'B' for batch dimension, but the code uses 'batch_size'. Similarly, the comment uses 'H1', 'H2', and 'O' which aren't actual variables. For clarity, either use the actual variable names in comments or define these as actual variables in the code.
Suggested change (use the actual variable names in the shape comments):

```python
# (out_features, in1_features, in2_features) -> (in1_features, out_features, in2_features)
W_permute = op.Transpose(weight, perm=[1, 0, 2])
# (in1_features, out_features, in2_features) -> (in1_features, out_features * in2_features)
W_flat = op.Reshape(
    W_permute,
    op.Concat(input1_shape, op.Mul(output_shape, input2_shape), axis=0),
)
# (batch_size, in1_features) @ (in1_features, out_features * in2_features)
# -> (batch_size, out_features * in2_features)
tmp = op.MatMul(input1, W_flat)
# (batch_size, out_features * in2_features)
# -> (batch_size, out_features, in2_features)
tmp = op.Reshape(tmp, op.Concat(batch_size, output_shape, input2_shape, axis=0))
# (batch_size, in2_features) -> (batch_size, in2_features, 1)
# -> (batch_size, out_features, in2_features) @ (batch_size, in2_features, 1)
# -> (batch_size, out_features, 1) -> (batch_size, out_features)
```
```python
# Use MatMul to compute the bilinear transformation
batch_size = op.Shape(input1, start=0, end=1)
input1_shape = op.Shape(input1, start=-1)
input2_shape = op.Shape(input2, start=-1)
output_shape = op.Shape(weight, start=0, end=1)
neg_1 = op.Constant(value_ints=[-1])

# (O, H1, H2) -> (H1, O, H2) where O = output_shape, H1 = input1_shape, H2 = input2_shape
W_permute = op.Transpose(weight, perm=[1, 0, 2])

# (H1, O, H2) -> (H1, O * H2)
W_flat = op.Reshape(
    W_permute,
    op.Concat(input1_shape, op.Mul(output_shape, input2_shape), axis=0),
)

# (B, H1) @ (H1, O*H2) -> (B, O*H2)
tmp = op.MatMul(input1, W_flat)

# (B, O*H2) -> (B, O, H2)
tmp = op.Reshape(tmp, op.Concat(batch_size, output_shape, input2_shape, axis=0))

# (B, H2) -> (B, H2, 1) -> (B, O, H2) @ (B, H2, 1) -> (B, O, 1) -> (B, O)
result = op.Squeeze(op.MatMul(tmp, op.Unsqueeze(input2, neg_1)), neg_1)
```
Copilot AI · Dec 30, 2025
The test cases in sample_inputs_bilinear only test 2D inputs with shape (batch_size, features), but the bilinear function is documented to support arbitrary batch dimensions with shape (..., features). The test suite should include cases with multi-dimensional batch inputs like (B1, B2, in_features) to ensure the implementation correctly handles these cases. This would have caught the current implementation bug where only the first batch dimension is extracted.
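A NumPy sketch (shapes assumed for illustration) of the kind of multi-dimensional-batch case this comment asks for, using the original einsum equation from the diff as the reference:

```python
import numpy as np

rng = np.random.default_rng(0)
x1 = rng.standard_normal((2, 3, 4))   # (B1, B2, in1_features)
x2 = rng.standard_normal((2, 3, 5))   # (B1, B2, in2_features)
w = rng.standard_normal((6, 4, 5))    # (out_features, in1_features, in2_features)

# The original einsum equation handles arbitrary batch dims:
expected = np.einsum("...i,oij,...j->...o", x1, w, x2)
assert expected.shape == (2, 3, 6)

# The matmul rewrite, as written, extracts only the first batch dimension,
# so its Reshape target (batch_size, out_features, in2_features) = (2, 6, 5)
# cannot hold the 2 * 3 batched rows produced by the first MatMul.
```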
```python
# - input1[..., i] * weight[o, i, j] * input2[..., j] -> output[..., o]
result = op.Einsum(input1, weight, input2, equation="...i,oij,...j->...o")
# Use MatMul to compute the bilinear transformation
batch_size = op.Shape(input1, start=0, end=1)
```
Copilot AI · Dec 30, 2025
This line only extracts the first dimension of input1, but the bilinear function should support arbitrary batch dimensions (as indicated by the "..." notation in the comments above). The original einsum implementation correctly handled multiple batch dimensions. To fix this, you should extract all batch dimensions except the last one, which would be op.Shape(input1, end=-1).
Suggested change:

```diff
- batch_size = op.Shape(input1, start=0, end=1)
+ batch_size = op.Shape(input1, end=-1)
```
```python
)

# (B, H1) @ (H1, O*H2) -> (B, O*H2)
tmp = op.MatMul(input1, W_flat)
```
Copilot AI · Dec 30, 2025
This MatMul operation assumes input1 has shape (B, H1) with only a single batch dimension, but input1 can have arbitrary batch dimensions like (B1, B2, ..., Bn, H1). The matmul needs to handle multi-dimensional batch inputs correctly. When input1 has shape (..., H1) where ... represents multiple dimensions, this needs to be flattened to (*, H1) before the matmul, or the weight needs to be broadcasted appropriately.
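As a quick check: the ONNX MatMul spec says it behaves like numpy.matmul, which maps a 2-D right operand over all leading batch dims. In a NumPy sketch (shapes assumed for illustration), the first MatMul itself already tolerates extra batch dimensions; it is the Reshape that follows which bakes in the single-batch-dim assumption:

```python
import numpy as np

rng = np.random.default_rng(0)
x1 = rng.standard_normal((2, 3, 4))      # (..., H1) with two batch dims
w_flat = rng.standard_normal((4, 30))    # (H1, O*H2)

# numpy.matmul (and ONNX MatMul, which follows the same semantics) treats
# a 2-D right operand as a plain matrix and applies it over all leading dims:
tmp = x1 @ w_flat
assert tmp.shape == (2, 3, 30)           # batch dims preserved, no flattening needed
```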
```python
tmp = op.MatMul(input1, W_flat)

# (B, O*H2) -> (B, O, H2)
tmp = op.Reshape(tmp, op.Concat(batch_size, output_shape, input2_shape, axis=0))
```
Copilot AI · Dec 30, 2025
This reshape only uses batch_size (the first dimension), but should use all batch dimensions. For inputs with shape (..., H1), the result of the previous matmul would have shape (*, O\*H2), where * represents the flattened batch dimensions, and this needs to be reshaped to (..., O, H2), where ... represents the original multi-dimensional batch shape.
```python
tmp = op.Reshape(tmp, op.Concat(batch_size, output_shape, input2_shape, axis=0))

# (B, H2) -> (B, H2, 1) -> (B, O, H2) @ (B, H2, 1) -> (B, O, 1) -> (B, O)
result = op.Squeeze(op.MatMul(tmp, op.Unsqueeze(input2, neg_1)), neg_1)
```
Copilot AI · Dec 30, 2025
This operation assumes input2 has shape (B, H2) with a single batch dimension, but input2 can have arbitrary batch dimensions (..., H2). The Unsqueeze and subsequent MatMul operations will fail or produce incorrect results when input2 has multi-dimensional batch shape. For example, if input2 has shape (B1, B2, H2), the unsqueeze would create (B1, B2, H2, 1), but tmp has shape (B1, O, H2), making the matmul dimension-incompatible.
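The mismatch described here can be reproduced in a NumPy sketch (shapes chosen purely for illustration):

```python
import numpy as np

rng = np.random.default_rng(0)
tmp = rng.standard_normal((2, 6, 5))      # (B1, O, H2): reshape kept only the first batch dim
input2 = rng.standard_normal((2, 3, 5))   # (B1, B2, H2): input2 still carries both batch dims
x2_col = input2[..., None]                # (B1, B2, H2, 1) after the Unsqueeze

raised = False
try:
    np.matmul(tmp, x2_col)                # batch dims (2,) vs (2, 3) do not broadcast
except ValueError:
    raised = True
assert raised
```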
```python
# (B, H2) -> (B, H2, 1) -> (B, O, H2) @ (B, H2, 1) -> (B, O, 1) -> (B, O)
result = op.Squeeze(op.MatMul(tmp, op.Unsqueeze(input2, neg_1)), neg_1)
```
Output has shape (..., out_features). The final output shape should be `op.Concat(op.Shape(input1, start=0, end=-1), op.Shape(weight, start=0, end=1))`.
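A NumPy sketch of how the suggested fix could look end to end (the function name is illustrative; `x1.shape[:-1]` plays the role of `op.Shape(input1, start=0, end=-1)`):

```python
import numpy as np

def bilinear_matmul_nd(x1, x2, weight):
    # Multi-batch-safe variant following the reviewer's suggestion:
    # use all leading dims of x1 rather than only the first one.
    out_f, in1_f, in2_f = weight.shape
    batch_shape = x1.shape[:-1]                        # all leading (batch) dims
    w_flat = weight.transpose(1, 0, 2).reshape(in1_f, out_f * in2_f)
    tmp = (x1 @ w_flat).reshape(*batch_shape, out_f, in2_f)
    return (tmp @ x2[..., None]).squeeze(-1)           # (..., out_features)

rng = np.random.default_rng(0)
x1 = rng.standard_normal((2, 3, 4))
x2 = rng.standard_normal((2, 3, 5))
w = rng.standard_normal((6, 4, 5))
ref = np.einsum("...i,oij,...j->...o", x1, w, x2)
assert np.allclose(bilinear_matmul_nd(x1, x2, w), ref)
```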
```python
tmp = op.MatMul(input1, W_flat)

# (B, O*H2) -> (B, O, H2)
tmp = op.Reshape(tmp, op.Concat(batch_size, output_shape, input2_shape, axis=0))
```
Instead of using batch_size, I think you should use op.Shape(input1, start=0, end=-1) to get all of the leading shapes.
justinchuby left a comment:

Found some issues
What this does
Why
Testing
Resolves #2573