29 changes: 24 additions & 5 deletions onnxscript/function_libs/torch_lib/ops/core.py
@@ -1208,12 +1208,31 @@ def aten_bilinear(
# bias shape: (out_features) - optional
# output shape: (..., out_features)

# Use Einsum to compute the bilinear transformation
# "...i,oij,...j->...o" means:
# - input1[..., i] * weight[o, i, j] * input2[..., j] -> output[..., o]
result = op.Einsum(input1, weight, input2, equation="...i,oij,...j->...o")
# Use MatMul to compute the bilinear transformation
batch_size = op.Shape(input1, start=0, end=1)
Copilot AI Dec 30, 2025

This line only extracts the first dimension of input1, but the bilinear function should support arbitrary batch dimensions (as indicated by the "..." notation in the comments above). The original einsum implementation correctly handled multiple batch dimensions. To fix this, you should extract all batch dimensions except the last one, which would be op.Shape(input1, end=-1).

Suggested change
batch_size = op.Shape(input1, start=0, end=1)
batch_size = op.Shape(input1, end=-1)

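The suggested `op.Shape(input1, end=-1)` fix can be sanity-checked outside ONNX. Below is a minimal NumPy sketch (illustrative shapes, not the actual onnxscript code) of the MatMul decomposition checked against the original einsum formulation, taking the batch as everything except the last axis:

```python
import numpy as np

rng = np.random.default_rng(0)
B1, B2, H1, H2, O = 2, 3, 4, 5, 6
x1 = rng.standard_normal((B1, B2, H1))  # (..., in1_features)
x2 = rng.standard_normal((B1, B2, H2))  # (..., in2_features)
w = rng.standard_normal((O, H1, H2))    # (out_features, in1_features, in2_features)

# Reference: the einsum the op originally used.
ref = np.einsum("...i,oij,...j->...o", x1, w, x2)

# MatMul decomposition with the full batch shape x1.shape[:-1],
# i.e. the analogue of op.Shape(input1, end=-1).
batch = x1.shape[:-1]
w_flat = w.transpose(1, 0, 2).reshape(H1, O * H2)  # (H1, O*H2)
tmp = x1 @ w_flat                                  # (..., O*H2)
tmp = tmp.reshape(*batch, O, H2)                   # (..., O, H2)
out = (tmp @ x2[..., None])[..., 0]                # (..., O)

assert np.allclose(out, ref)
```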
input1_shape = op.Shape(input1, start=-1)
input2_shape = op.Shape(input2, start=-1)
output_shape = op.Shape(weight, start=0, end=1)
neg_1 = op.Constant(value_ints=[-1])

# (O, H1, H2) -> (H1, O, H2) where O = output_shape, H1 = input1_shape, H2 = input2_shape
W_permute = op.Transpose(weight, perm=[1, 0, 2])

# (H1, O, H2) -> (H1, O * H2)
W_flat = op.Reshape(
W_permute,
op.Concat(input1_shape, op.Mul(output_shape, input2_shape), axis=0),
)

# (B, H1) @ (H1, O*H2) -> (B, O*H2)
tmp = op.MatMul(input1, W_flat)
Copilot AI Dec 30, 2025

This MatMul operation assumes input1 has shape (B, H1) with only a single batch dimension, but input1 can have arbitrary batch dimensions like (B1, B2, ..., Bn, H1). The matmul needs to handle multi-dimensional batch inputs correctly. When input1 has shape (..., H1) where ... represents multiple dimensions, this needs to be flattened to (*, H1) before the matmul, or the weight needs to be broadcasted appropriately.

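The flatten-first alternative mentioned here can be sketched in NumPy (hypothetical shapes; `w_flat` stands in for the already transposed-and-reshaped weight): collapse all batch dimensions into one before the 2-D matmul, then restore them afterwards:

```python
import numpy as np

rng = np.random.default_rng(1)
x1 = rng.standard_normal((2, 3, 4))    # (B1, B2, H1)
w_flat = rng.standard_normal((4, 30))  # (H1, O*H2) with O=6, H2=5

flat = x1.reshape(-1, x1.shape[-1])       # (B1*B2, H1)
tmp = flat @ w_flat                       # (B1*B2, O*H2)
tmp = tmp.reshape(*x1.shape[:-1], 6, 5)   # (B1, B2, O, H2)

# NumPy (and ONNX MatMul) also broadcast a 2-D right operand over
# leading batch dims, so flattening gives the same values as x1 @ w_flat.
assert np.allclose(tmp, (x1 @ w_flat).reshape(2, 3, 6, 5))
```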

# (B, O*H2) -> (B, O, H2)
tmp = op.Reshape(tmp, op.Concat(batch_size, output_shape, input2_shape, axis=0))
Copilot AI Dec 30, 2025

This reshape only uses batch_size (the first dimension), but should use all batch dimensions. For inputs with shape (..., H1), the result of the previous matmul would have shape (*, O*H2) where * represents the flattened batch dimensions, and this needs to be reshaped to (..., O, H2) where ... represents the original multi-dimensional batch shape.

Collaborator

Instead of using batch_size, I think you should use op.Shape(input1, start=0, end=-1) to get all of the leading shapes.


# (B, H2) -> (B, H2, 1) -> (B, O, H2) @ (B, H2, 1) -> (B, O, 1) -> (B, O)
Comment on lines +1218 to +1233
Copilot AI Dec 30, 2025
The variable names in the comments don't match the actual variable names used in the code. The comment mentions 'B' for batch dimension, but the code uses 'batch_size'. Similarly, the comment uses 'H1', 'H2', and 'O' which aren't actual variables. For clarity, either use the actual variable names in comments or define these as actual variables in the code.

Suggested change
# (O, H1, H2) -> (H1, O, H2) where O = output_shape, H1 = input1_shape, H2 = input2_shape
W_permute = op.Transpose(weight, perm=[1, 0, 2])
# (H1, O, H2) -> (H1, O * H2)
W_flat = op.Reshape(
W_permute,
op.Concat(input1_shape, op.Mul(output_shape, input2_shape), axis=0),
)
# (B, H1) @ (H1, O*H2) -> (B, O*H2)
tmp = op.MatMul(input1, W_flat)
# (B, O*H2) -> (B, O, H2)
tmp = op.Reshape(tmp, op.Concat(batch_size, output_shape, input2_shape, axis=0))
# (B, H2) -> (B, H2, 1) -> (B, O, H2) @ (B, H2, 1) -> (B, O, 1) -> (B, O)
# (out_features, in1_features, in2_features) -> (in1_features, out_features, in2_features)
W_permute = op.Transpose(weight, perm=[1, 0, 2])
# (in1_features, out_features, in2_features) -> (in1_features, out_features * in2_features)
W_flat = op.Reshape(
W_permute,
op.Concat(input1_shape, op.Mul(output_shape, input2_shape), axis=0),
)
# (batch_size, in1_features) @ (in1_features, out_features * in2_features)
# -> (batch_size, out_features * in2_features)
tmp = op.MatMul(input1, W_flat)
# (batch_size, out_features * in2_features)
# -> (batch_size, out_features, in2_features)
tmp = op.Reshape(tmp, op.Concat(batch_size, output_shape, input2_shape, axis=0))
# (batch_size, in2_features) -> (batch_size, in2_features, 1)
# -> (batch_size, out_features, in2_features) @ (batch_size, in2_features, 1)
# -> (batch_size, out_features, 1) -> (batch_size, out_features)

result = op.Squeeze(op.MatMul(tmp, op.Unsqueeze(input2, neg_1)), neg_1)
Comment on lines +1211 to +1234
Copilot AI Dec 30, 2025
The test cases in sample_inputs_bilinear only test 2D inputs with shape (batch_size, features), but the bilinear function is documented to support arbitrary batch dimensions with shape (..., features). The test suite should include cases with multi-dimensional batch inputs like (B1, B2, in_features) to ensure the implementation correctly handles these cases. This would have caught the current implementation bug where only the first batch dimension is extracted.

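A sketch of the kind of multi-dimensional-batch case this comment asks for, written against `torch.nn.functional.bilinear` as the reference (the shapes are illustrative; this is not the actual `sample_inputs_bilinear` code):

```python
import torch

torch.manual_seed(0)
x1 = torch.randn(2, 3, 4)  # (B1, B2, in1_features)
x2 = torch.randn(2, 3, 5)  # (B1, B2, in2_features)
w = torch.randn(6, 4, 5)   # (out_features, in1_features, in2_features)
b = torch.randn(6)

expected = torch.nn.functional.bilinear(x1, x2, w, b)
# Same computation via the einsum formulation the op originally used.
manual = torch.einsum("...i,oij,...j->...o", x1, w, x2) + b

assert expected.shape == (2, 3, 6)
assert torch.allclose(expected, manual, atol=1e-5)
```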
Copilot AI Dec 30, 2025
This operation assumes input2 has shape (B, H2) with a single batch dimension, but input2 can have arbitrary batch dimensions (..., H2). The Unsqueeze and subsequent MatMul operations will fail or produce incorrect results when input2 has multi-dimensional batch shape. For example, if input2 has shape (B1, B2, H2), the unsqueeze would create (B1, B2, H2, 1), but tmp has shape (B1, O, H2), making the matmul dimension-incompatible.

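The incompatibility described here can be reproduced in NumPy (illustrative shapes): if only the first batch dimension survives into `tmp`, the batched matmul against the fully-batched, unsqueezed input2 cannot broadcast:

```python
import numpy as np

tmp = np.zeros((2, 6, 5))  # (B1, O, H2): batch collapsed to B1 only
x2 = np.zeros((2, 3, 5))   # (B1, B2, H2)

try:
    tmp @ x2[..., None]    # batch dims (2,) vs (2, 3) cannot broadcast
    broadcast_ok = True
except ValueError:
    broadcast_ok = False

assert not broadcast_ok
```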
Comment on lines +1233 to +1234
Collaborator
Output has shape (..., out_features). The final output shape should be op.Concat(op.Shape(input1, start=0, end=-1), op.Shape(weight, start=0, end=1), axis=0)
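The output shape this comment proposes can be checked with a tiny Python sketch (shapes are illustrative): concatenating all leading dimensions of input1 with the first dimension of weight yields the documented (..., out_features):

```python
input1_shape = (2, 3, 4)  # (..., in1_features)
weight_shape = (6, 4, 5)  # (out_features, in1_features, in2_features)

# Analogue of op.Concat(op.Shape(input1, start=0, end=-1),
#                       op.Shape(weight, start=0, end=1), axis=0)
out_shape = (*input1_shape[:-1], weight_shape[0])
assert out_shape == (2, 3, 6)
```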


# Add bias if provided
if bias is not None:
result = op.Add(result, bias)
