Conversation

@WHoutstanding
Contributor

PR Category

other

Description

Apply special handling to the following operators (a sketch follows the list):
torch.matmul
torch.nn.functional.linear
torch.nn.functional.conv2d
torch.bmm
torch.nn.functional.scaled_dot_product_attention
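
As a rough illustration of what such special handling could look like, here is a minimal sketch assuming a torch.fx graph pass; the pass name cast_inputs_for_amp and the overall structure are illustrative, not the actual implementation in this PR:

import torch
import torch.fx

SPECIAL_OPS = {
    torch.matmul,
    torch.nn.functional.linear,
    torch.nn.functional.conv2d,
    torch.bmm,
    torch.nn.functional.scaled_dot_product_attention,
}

def cast_inputs_for_amp(gm: torch.fx.GraphModule, dtype=torch.float16):
    # Illustrative pass: cast the tensor inputs of the special-cased
    # call_function nodes to the target low-precision dtype.
    for node in gm.graph.nodes:
        if node.op == "call_function" and node.target in SPECIAL_OPS:
            with gm.graph.inserting_before(node):
                node.args = tuple(
                    gm.graph.call_method("to", (a, dtype))
                    if isinstance(a, torch.fx.Node) else a
                    for a in node.args
                )
    gm.recompile()
    return gm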

@paddle-bot

paddle-bot bot commented Jan 20, 2026

Thanks for your contribution!

paddle-bot bot added the contributor (External developers) label on Jan 20, 2026
@WHoutstanding
Contributor Author

There were originally 19 samples that failed under both float16 and bfloat16; after the special handling, 11 of them convert successfully.
The remaining 8 failing samples no longer error in the following operators:

torch.matmul
torch.nn.functional.linear
torch.nn.functional.conv2d
torch.bmm
torch.nn.functional.scaled_dot_product_attention

However, new errors then appear in other operators whose node type is call_function, or in nodes whose node type is call_method.

@WHoutstanding
Contributor Author

Test samples:

samples/transformers-auto-model/hf-tiny-model-private_tiny-random-Swinv2ForImageClassification
samples/torchgeometric/RECT_L
samples/transformers-auto-model/google_byt5_base
samples/transformers-auto-model/Fsoft-AIC_videberta-base
samples/mmseg/SegNeXt-b
samples/transformers-auto-model/OFA-Sys_chinese-clip-vit-large-patch14
samples/timm/vit_small_patch16_rope_mixed_ape_224.naver_in1k
samples/transformers-auto-model/google/t5-efficient-large-kv128
samples/transformers-auto-model/MoritzLaurer_xtremedistil-l6-h256-mnli-fever-anli-ling-binary
samples/mmseg/CCNet_R101
samples/transformers-auto-model/facebook_sam-vit-large
samples/transformers-auto-model/ogoshi2000_stance-nystromformer
samples/transformers-auto-model/TinyLlama/TinyLlama-1.1B-Chat-v0.4
samples/transformers-auto-model/google-t5_t5-large
samples/transformers-auto-model/apple_aimv2-huge-patch14-224
samples/transformers-auto-model/microsoft_swin-base-patch4-window12-384-in22k
samples/mmpose/pose_swin_b
samples/transformers-auto-model/all-mpnet-base-v2
samples/transformers-auto-model/Neurora_opus-tatoeba-heb-eng

Run log:
call_function_f16_bf16_error_log_init.txt

@Xreki
Collaborator

Xreki commented Jan 20, 2026

In the logs, most of the sample failures are related to matmul_3 = attn_9 @ v_1 and denominator = coef_1.bmm(bmm_2), where @ is actually the matmul operator. Please take another look at these two.

This PR handles AMP white-list operators; the other errors are related to AMP black-list operators, so those can be left for now.
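
For what it's worth, the @ operator shows up in a torch.fx trace as operator.matmul rather than torch.matmul, which is why a white list keyed only on torch.matmul misses these lines; a quick check:

import operator
import torch
import torch.fx

class M(torch.nn.Module):
    def forward(self, a, b):
        return a @ b  # the Python matmul operator

gm = torch.fx.symbolic_trace(M())
for node in gm.graph.nodes:
    if node.op == "call_function":
        # Prints "<built-in function matmul> True": the node target is
        # operator.matmul, not torch.matmul.
        print(node.target, node.target is operator.matmul)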

@WHoutstanding
Contributor Author

Manually set the following white lists:

import torch

AMP_CALL_FUNCTION = {
    torch.matmul,
    torch.mm,
    torch.bmm,
    torch.nn.functional.linear,
    torch.nn.functional.conv1d,
    torch.nn.functional.conv2d,
    torch.nn.functional.conv3d,
    torch.nn.functional.scaled_dot_product_attention,
}

AMP_CALL_METHOD = {
    "matmul",
    "mm",
    "bmm",
}
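
A minimal sketch of how these white lists might be consulted when rewriting the graph, assuming torch.fx nodes (the helper name should_autocast is hypothetical):

def should_autocast(node):
    # call_function targets are the callables themselves (e.g. torch.bmm);
    # call_method targets are method-name strings (e.g. "bmm").
    if node.op == "call_function":
        return node.target in AMP_CALL_FUNCTION
    if node.op == "call_method":
        return node.target in AMP_CALL_METHOD
    return False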

Of the 19 samples that originally failed under both float16 and bfloat16, 12 convert successfully after setting these white lists.

2 samples do not need rewriting:
model_path='samples/transformers-auto-model/Neurora_opus-tatoeba-heb-eng'
model_path='samples/transformers-auto-model/MoritzLaurer_xtremedistil-l6-h256-mnli-fever-anli-ling-binary'

5 samples still fail to convert; the erroring operators are:

    attention_scores_3 = attention_scores_2.masked_fill(invert, -3.4028234663852886e+38);  attention_scores_2 = invert = None
RuntimeError: value cannot be converted to type at::Half without overflow
    freqs_x = unsqueeze @ unsqueeze_1;  unsqueeze = unsqueeze_1 = None
RuntimeError: expected scalar type Float but found Half
    return _VF.einsum(equation, operands)  # type: ignore[attr-defined]
RuntimeError: expected scalar type Float but found Half
    coordinates_3 = coordinates_2 @ to_1;  coordinates_2 = to_1 = None
RuntimeError: expected mat1 and mat2 to have the same dtype, but got: float != c10::Half
    matmul_3 = attn_9 @ v_1;  attn_9 = v_1 = None
RuntimeError: expected scalar type BFloat16 but found Float
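
A note on the first failure above: -3.4028234663852886e+38 is the most negative float32 value, which lies outside the float16 range, as a quick check shows:

import torch
print(torch.finfo(torch.float32).min)  # -3.4028234663852886e+38
print(torch.finfo(torch.float16).min)  # -65504.0, hence the overflow error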

Only 1 sample errors on matmul_3 = attn_9 @ v_1:
samples/mmpose/pose_swin_b

@Xreki
Collaborator

Xreki commented Jan 21, 2026

#593 (comment) only used matmul_3 = attn_9 @ v_1 as an example, standing for the whole @ category; the freqs_x = unsqueeze @ unsqueeze_1 and coordinates_3 = coordinates_2 @ to_1 lines you posted both belong to the same category. What is the reason these operators fail to convert?

@WHoutstanding
Contributor Author

Updated white lists (adding operator.matmul so that @ expressions are matched, plus torch.addmm and torch.einsum):

import operator
import torch

AMP_CALL_FUNCTION = {
    torch.matmul,
    torch.mm,
    torch.bmm,
    operator.matmul,
    torch.nn.functional.linear,
    torch.nn.functional.conv1d,
    torch.nn.functional.conv2d,
    torch.nn.functional.conv3d,
    torch.nn.functional.scaled_dot_product_attention,
    torch.addmm,
    torch.einsum,
}

AMP_CALL_METHOD = {
    "matmul",
    "mm",
    "bmm",
}

Of the 19 samples that originally failed under both float16 and bfloat16, 15 convert successfully with the updated white lists.

2 samples do not need rewriting:
model_path='samples/transformers-auto-model/Neurora_opus-tatoeba-heb-eng'
model_path='samples/transformers-auto-model/MoritzLaurer_xtremedistil-l6-h256-mnli-fever-anli-ling-binary'

2 samples still fail to convert; the failing operators are on the black list:

    attention_scores_3 = attention_scores_2.masked_fill(invert, -3.4028234663852886e+38);  attention_scores_2 = invert = None
RuntimeError: value cannot be converted to type at::Half without overflow
    return torch.layer_norm(
RuntimeError: expected scalar type Float but found BFloat16
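
Both remaining failures are consistent with the earlier comment about AMP black-list operators: masked_fill feeds the float32 minimum into a half-precision tensor, and layer_norm is typically kept in float32 under AMP. A minimal sketch of the complementary black-list handling, assuming the same torch.fx setting (the set name and helper are illustrative, not part of this PR):

import torch
import torch.fx

AMP_BLACK_LIST = {torch.layer_norm}  # ops to keep in float32

def keep_fp32(gm: torch.fx.GraphModule):
    # Illustrative: cast the tensor inputs of black-listed ops
    # back to float32 before calling them.
    for node in gm.graph.nodes:
        if node.op == "call_function" and node.target in AMP_BLACK_LIST:
            with gm.graph.inserting_before(node):
                node.args = tuple(
                    gm.graph.call_method("to", (a, torch.float32))
                    if isinstance(a, torch.fx.Node) else a
                    for a in node.args
                )
    gm.recompile()
    return gm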

Collaborator

Xreki left a comment


LGTM

