Dtype gen pass #593
base: develop
Thanks for your contribution!
Originally, 19 samples failed under both float16 and bfloat16. With the special handling, 11 of them now convert successfully, but new errors appear on nodes whose type is call_function (operators) or call_method.
Test samples:
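In the logs, most of the sample errors are related to [...]. This PR handles the AMP whitelist operators; the remaining errors are related to AMP blacklist operators, so there is no need to pursue them for now.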
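For readers following the whitelist discussion, here is a minimal sketch of what special handling of AMP whitelist operators could look like as a `torch.fx` pass. It is not the code from this PR: the set name `AMP_CALL_FUNCTION` and the helper `cast_whitelisted_inputs` are hypothetical, the whitelist is simply the operators named in the PR description plus `operator.matmul` (which is how `torch.fx` records the `@` operator), and the example uses bfloat16 so that it runs on CPU.

```python
import operator

import torch
import torch.fx as fx

# Hypothetical whitelist for this sketch: the operators named in the PR
# description, plus operator.matmul, the target torch.fx records for `a @ b`.
AMP_CALL_FUNCTION = {
    torch.matmul,
    torch.bmm,
    torch.nn.functional.linear,
    torch.nn.functional.conv2d,
    torch.nn.functional.scaled_dot_product_attention,
    operator.matmul,
}


def cast_whitelisted_inputs(gm: fx.GraphModule, dtype: torch.dtype) -> fx.GraphModule:
    """Insert `.to(dtype)` before every positional Node argument of whitelisted ops.

    Simplifications: keyword arguments are left untouched, and every Node
    argument is assumed to be a tensor.
    """
    graph = gm.graph
    for node in list(graph.nodes):
        if node.op != "call_function" or node.target not in AMP_CALL_FUNCTION:
            continue
        new_args = []
        for arg in node.args:
            if isinstance(arg, fx.Node):
                # Add a call_method node `arg.to(dtype)` right before the op.
                with graph.inserting_before(node):
                    arg = graph.call_method("to", args=(arg, dtype))
            new_args.append(arg)
        node.args = tuple(new_args)
    graph.lint()
    gm.recompile()
    return gm


class Attn(torch.nn.Module):
    def forward(self, attn, v):
        # Traced as a call_function node whose target is operator.matmul.
        return attn @ v


gm = cast_whitelisted_inputs(fx.symbolic_trace(Attn()), torch.bfloat16)
out = gm(torch.randn(2, 3, 4), torch.randn(2, 4, 5))
print(gm.graph)
print(out.dtype)
```

Printing `gm.graph` shows both node kinds mentioned in this thread: the inserted casts are `call_method` nodes with target `to`, and the matmul remains a `call_function` node. A whitelist that only contains `torch.matmul` would skip the `@` form, because its target is `operator.matmul`.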
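Manually set the following whitelist: AMP_CALL_METHOD = {
Originally, 19 samples failed under both float16 and bfloat16. With the whitelist set, 12 samples convert successfully.
2 samples do not need a rewrite:
5 samples fail to convert; the failing operators are:
Only 1 sample fails on the operator matmul_3 = attn_9 @ v_1: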
#593 (comment) only uses
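Reset the whitelist: AMP_CALL_METHOD = {
Originally, 19 samples failed under both float16 and bfloat16. With the revised whitelist, 15 samples convert successfully.
2 samples do not need a rewrite:
2 samples fail to convert; the failures are on blacklist operators: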
Xreki
left a comment
LGTM
PR Category
other
Description
Apply special handling to the following operators:
torch.matmul,
torch.nn.functional.linear,
torch.nn.functional.conv2d,
torch.bmm,
torch.nn.functional.scaled_dot_product_attention,