-
Notifications
You must be signed in to change notification settings - Fork 97
feat: modify aten_bilinear from einsum to matmul #2746
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
|
|
@@ -1208,12 +1208,31 @@ def aten_bilinear( | |||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| # bias shape: (out_features) - optional | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| # output shape: (..., out_features) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| # Use Einsum to compute the bilinear transformation | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| # "...i,oij,...j->...o" means: | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| # - input1[..., i] * weight[o, i, j] * input2[..., j] -> output[..., o] | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| result = op.Einsum(input1, weight, input2, equation="...i,oij,...j->...o") | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| # Use MatMul to compute the bilinear transformation | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| batch_size = op.Shape(input1, start=0, end=1) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| input1_shape = op.Shape(input1, start=-1) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| input2_shape = op.Shape(input2, start=-1) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| output_shape = op.Shape(weight, start=0, end=1) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| neg_1 = op.Constant(value_ints=[-1]) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| # (O, H1, H2) -> (H1, O, H2) where O = output_shape, H1 = input1_shape, H2 = input2_shape | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| W_permute = op.Transpose(weight, perm=[1, 0, 2]) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| # (H1, O, H2) -> (H1, O * H2) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| W_flat = op.Reshape( | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| W_permute, | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| op.Concat(input1_shape, op.Mul(output_shape, input2_shape), axis=0), | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| ) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| # (B, H1) @ (H1, O*H2) -> (B, O*H2) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| tmp = op.MatMul(input1, W_flat) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| # (B, O*H2) -> (B, O, H2) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| tmp = op.Reshape(tmp, op.Concat(batch_size, output_shape, input2_shape, axis=0)) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| # (B, H2) -> (B, H2, 1) -> (B, O, H2) @ (B, H2, 1) -> (B, O, 1) -> (B, O) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
Comment on lines
+1218
to
+1233
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| # (O, H1, H2) -> (H1, O, H2) where O = output_shape, H1 = input1_shape, H2 = input2_shape | |
| W_permute = op.Transpose(weight, perm=[1, 0, 2]) | |
| # (H1, O, H2) -> (H1, O * H2) | |
| W_flat = op.Reshape( | |
| W_permute, | |
| op.Concat(input1_shape, op.Mul(output_shape, input2_shape), axis=0), | |
| ) | |
| # (B, H1) @ (H1, O*H2) -> (B, O*H2) | |
| tmp = op.MatMul(input1, W_flat) | |
| # (B, O*H2) -> (B, O, H2) | |
| tmp = op.Reshape(tmp, op.Concat(batch_size, output_shape, input2_shape, axis=0)) | |
| # (B, H2) -> (B, H2, 1) -> (B, O, H2) @ (B, H2, 1) -> (B, O, 1) -> (B, O) | |
| # (out_features, in1_features, in2_features) -> (in1_features, out_features, in2_features) | |
| W_permute = op.Transpose(weight, perm=[1, 0, 2]) | |
| # (in1_features, out_features, in2_features) -> (in1_features, out_features * in2_features) | |
| W_flat = op.Reshape( | |
| W_permute, | |
| op.Concat(input1_shape, op.Mul(output_shape, input2_shape), axis=0), | |
| ) | |
| # (batch_size, in1_features) @ (in1_features, out_features * in2_features) | |
| # -> (batch_size, out_features * in2_features) | |
| tmp = op.MatMul(input1, W_flat) | |
| # (batch_size, out_features * in2_features) | |
| # -> (batch_size, out_features, in2_features) | |
| tmp = op.Reshape(tmp, op.Concat(batch_size, output_shape, input2_shape, axis=0)) | |
| # (batch_size, in2_features) -> (batch_size, in2_features, 1) | |
| # -> (batch_size, out_features, in2_features) @ (batch_size, in2_features, 1) | |
| # -> (batch_size, out_features, 1) -> (batch_size, out_features) |
Copilot
AI
Dec 30, 2025
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
The test cases in sample_inputs_bilinear only test 2D inputs with shape (batch_size, features), but the bilinear function is documented to support arbitrary batch dimensions with shape (..., features). The test suite should include cases with multi-dimensional batch inputs like (B1, B2, in_features) to ensure the implementation correctly handles these cases. This would have caught the current implementation bug where only the first batch dimension is extracted.
Copilot
AI
Dec 30, 2025
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This operation assumes input2 has shape (B, H2) with a single batch dimension, but input2 can have arbitrary batch dimensions (..., H2). The Unsqueeze and subsequent MatMul operations will fail or produce incorrect results when input2 has multi-dimensional batch shape. For example, if input2 has shape (B1, B2, H2), the unsqueeze would create (B1, B2, H2, 1), but tmp has shape (B1, O, H2), making the matmul dimension-incompatible.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Output has shape (..., out_features). The final output shape should be op.Concat(op.Shape(input1, start=0, end=-1), op.Shape(weight, start=0, end=1))
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This line only extracts the first dimension of input1, but the bilinear function should support arbitrary batch dimensions (as indicated by the "..." notation in the comments above). The original einsum implementation correctly handled multiple batch dimensions. To fix this, you should extract all batch dimensions except the last one, which would be op.Shape(input1, end=-1).