Feat (brevitas_examples/common): support for mse scale for weights quantized to float formats #1433
Reason for this PR
It was not possible to run the LLM Examples with a configuration using `weight_quant_format: float_eXmY` and `weight_param_method: mse`. That is, if a "regular" (non-OCP) float format was used to quantize the weights, mean squared error could not be used to choose the scale. I need this configuration for some of my experiments.

Changes Made in this PR
In `src/brevitas_examples/common/generative/quantizers.py`, I added the class `Fp8e4m3WeightPerChannelFloatMSE`, which inherits from `MSESymmetricScale` and `Fp8e4m3WeightPerChannelFloat`. This follows the same pattern already used to support MSE scale for Float OCP per-channel weights. In `src/brevitas_examples/common/generative/quantize.py`, I added the entry for `float`, `float_scale` to the dictionary `WEIGHT_QUANT_MAP`, mirroring how other data types (e.g. Float OCP) are handled.

Testing Summary
I extended `tests/brevitas_examples/test_llm_cases.py` to ensure that the combination of `weight_quant_format: float_e2m1` with `weight_param_method: mse` runs.
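The inheritance pattern described in the changes above can be sketched as follows. This is a minimal illustration, not the actual Brevitas code: the two base classes here are stand-ins for the real quantizer classes in Brevitas, and the nesting of `WEIGHT_QUANT_MAP` shown is assumed for illustration only.

```python
# Hedged sketch of the PR's pattern. The base classes below are stand-ins
# for the real Brevitas classes; only the class names and the multiple-
# inheritance composition come from the PR description.

class MSESymmetricScale:
    """Stand-in for Brevitas' MSE-based symmetric scale mixin."""
    scaling_impl_type = "mse"

class Fp8e4m3WeightPerChannelFloat:
    """Stand-in for the per-channel FP8 e4m3 float weight quantizer."""
    scaling_impl_type = "stats"

# The composition the PR adds: listing MSESymmetricScale first means
# Python's MRO resolves scale selection to the MSE mixin, overriding the
# default statistics-based scale of the base quantizer.
class Fp8e4m3WeightPerChannelFloatMSE(MSESymmetricScale,
                                      Fp8e4m3WeightPerChannelFloat):
    pass

# Illustrative shape of the registry entry added in quantize.py; the real
# key structure of WEIGHT_QUANT_MAP may differ.
WEIGHT_QUANT_MAP = {
    "float": {
        "float_scale": {
            "per_channel": {"mse": Fp8e4m3WeightPerChannelFloatMSE},
        },
    },
}

print(Fp8e4m3WeightPerChannelFloatMSE.scaling_impl_type)
```

Because the MSE mixin comes first in the base-class list, `scaling_impl_type` resolves to `"mse"`, which is the same mechanism the existing Float OCP MSE quantizers rely on.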