From c7e30e6dddc5c6b7303b9d3aa60e4ce3280453a2 Mon Sep 17 00:00:00 2001
From: fw7th
Date: Wed, 24 Dec 2025 07:43:07 +0100
Subject: [PATCH 1/2] feat: rewrite aten_bilinear from Einsum to MatMul

---
 .../function_libs/torch_lib/ops/core.py | 29 +++++++++++++++----
 1 file changed, 24 insertions(+), 5 deletions(-)

diff --git a/onnxscript/function_libs/torch_lib/ops/core.py b/onnxscript/function_libs/torch_lib/ops/core.py
index 0468b9e05d..15d7806bf7 100644
--- a/onnxscript/function_libs/torch_lib/ops/core.py
+++ b/onnxscript/function_libs/torch_lib/ops/core.py
@@ -1208,12 +1208,31 @@ def aten_bilinear(
     # bias shape: (out_features) - optional
     # output shape: (..., out_features)
 
-    # Use Einsum to compute the bilinear transformation
-    # "...i,oij,...j->...o" means:
-    # - input1[..., i] * weight[o, i, j] * input2[..., j] -> output[..., o]
-    result = op.Einsum(input1, weight, input2, equation="...i,oij,...j->...o")
+    # Use MatMul to compute the bilinear transformation
+    batch_size = op.Shape(input1, start=0, end=1)
+    input1_shape = op.Shape(input1, start=-1)
+    input2_shape = op.Shape(input2, start=-1)
+    output_shape = op.Shape(weight, start=0, end=1)
+    neg_1 = op.Constant(value_ints=[-1])
+
+    # (O, H1, H2) -> (H1, O, H2) where O = output_shape, H1 = input1_shape, H2 = input2_shape
+    W_permute = op.Transpose(weight, perm=[1, 0, 2])
+
+    # (H1, O, H2) -> (H1, O * H2)
+    W_flat = op.Reshape(
+        W_permute,
+        op.Concat(input1_shape, op.Mul(output_shape, input2_shape), axis=0),
+    )
+
+    # (B, H1) @ (H1, O*H2) -> (B, O*H2)
+    tmp = op.MatMul(input1, W_flat)
+
+    # (B, O*H2) -> (B, O, H2)
+    tmp = op.Reshape(tmp, op.Concat(batch_size, output_shape, input2_shape, axis=0))
+
+    # (B, H2) -> (B, H2, 1) -> (B, O, H2) @ (B, H2, 1) -> (B, O, 1) -> (B, O)
+    result = op.Squeeze(op.MatMul(tmp, op.Unsqueeze(input2, neg_1)), neg_1)
 
-    # Add bias if provided
     if bias is not None:
         result = op.Add(result, bias)
 

From 292b8eb4eda7ac4aad08462b134f0a65ea54da Mon Sep 17 00:00:00 2001
From: fw7th
Date: Tue, 6 Jan 2026 08:13:39 +0100
Subject: [PATCH 2/2] fix: resolve batch dimension issue from initial commit

---
 onnxscript/function_libs/torch_lib/ops/core.py | 15 +++++++++------
 1 file changed, 9 insertions(+), 6 deletions(-)

diff --git a/onnxscript/function_libs/torch_lib/ops/core.py b/onnxscript/function_libs/torch_lib/ops/core.py
index cf53c07ab8..b35977a81c 100644
--- a/onnxscript/function_libs/torch_lib/ops/core.py
+++ b/onnxscript/function_libs/torch_lib/ops/core.py
@@ -1208,29 +1208,32 @@ def aten_bilinear(
     # bias shape: (out_features) - optional
     # output shape: (..., out_features)
 
+    # input1 and input2 must have identical batch dimensions
     # Use MatMul to compute the bilinear transformation
-    batch_size = op.Shape(input1, start=0, end=1)
+    batch_size = op.Shape(input1, start=0, end=-1)
     input1_shape = op.Shape(input1, start=-1)
     input2_shape = op.Shape(input2, start=-1)
     output_shape = op.Shape(weight, start=0, end=1)
     neg_1 = op.Constant(value_ints=[-1])
 
-    # (O, H1, H2) -> (H1, O, H2) where O = output_shape, H1 = input1_shape, H2 = input2_shape
+    # (out_features, in1_features, in2_features) -> (in1_features, out_features, in2_features)
     W_permute = op.Transpose(weight, perm=[1, 0, 2])
 
-    # (H1, O, H2) -> (H1, O * H2)
+    # (in1_features, out_features, in2_features) -> (in1_features, out_features * in2_features)
     W_flat = op.Reshape(
         W_permute,
         op.Concat(input1_shape, op.Mul(output_shape, input2_shape), axis=0),
     )
 
-    # (B, H1) @ (H1, O*H2) -> (B, O*H2)
+    # (..., in1_features) @ (in1_features, out_features * in2_features) -> (..., out_features * in2_features)
     tmp = op.MatMul(input1, W_flat)
 
-    # (B, O*H2) -> (B, O, H2)
+    # (..., out_features * in2_features) -> (..., out_features, in2_features)
     tmp = op.Reshape(tmp, op.Concat(batch_size, output_shape, input2_shape, axis=0))
 
-    # (B, H2) -> (B, H2, 1) -> (B, O, H2) @ (B, H2, 1) -> (B, O, 1) -> (B, O)
+    # (..., in2_features) -> (..., in2_features, 1)
+    # -> (..., out_features, in2_features) @ (..., in2_features, 1)
+    # -> (..., out_features, 1) -> (..., out_features)
    result = op.Squeeze(op.MatMul(tmp, op.Unsqueeze(input2, neg_1)), neg_1)
 
     if bias is not None:
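
Reviewer note (not part of the patches): below is a minimal NumPy sketch of the
decomposition this series lands on, checked against the Einsum formulation that
patch 1 removes. All shape values and variable names are illustrative only; the
two leading batch dimensions (2, 3) exercise the "..." case that patch 2 fixes,
where the real code derives the batch dims dynamically via op.Shape.

import numpy as np

rng = np.random.default_rng(0)
in1_features, in2_features, out_features = 4, 5, 6
x1 = rng.standard_normal((2, 3, in1_features))   # two leading batch dims
x2 = rng.standard_normal((2, 3, in2_features))   # batch dims must match x1's
weight = rng.standard_normal((out_features, in1_features, in2_features))
bias = rng.standard_normal(out_features)

# Reference: the einsum formulation removed by patch 1.
ref = np.einsum("...i,oij,...j->...o", x1, weight, x2) + bias

# Decomposition used by the patch: permute, flatten, two matmuls.
# (O, H1, H2) -> (H1, O, H2) -> (H1, O*H2)
w_flat = weight.transpose(1, 0, 2).reshape(in1_features, out_features * in2_features)
# (..., H1) @ (H1, O*H2) -> (..., O*H2) -> (..., O, H2)
tmp = (x1 @ w_flat).reshape(2, 3, out_features, in2_features)
# (..., O, H2) @ (..., H2, 1) -> (..., O, 1) -> (..., O)
out = (tmp @ x2[..., None]).squeeze(-1) + bias

assert np.allclose(ref, out)

The commits do not state the motivation, but a plausible reason for the rewrite
is that plain MatMul typically enjoys broader and better-optimized support than
Einsum across ONNX backends.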