Feat (brevitas_examples/llm): better RMSNorm replacement #1436
base: dev
Conversation
Force-pushed from 9ed374b to 030d735
            delay_rewriters: bool = False,
-           return_rewriters: bool = False) -> None:
+           return_rewriters: bool = False,
+           extra_rmsnorm_classes: Optional[Tuple] = None) -> None:
I am not a big fan of having this extra argument and requiring these classes to be propagated from the context manager. Can we instead extend _is_scale_invariant_module (https://github.com/Xilinx/brevitas/blob/master/src/brevitas/graph/equalize.py#L855) to have this extra logic? Ideally, part of the logic of set(type(x) for x in model.modules() if 'RMS' in type(x).__name__) could be extracted into a standalone method to be used both by _is_scale_invariant_module and rmsnorm_patch.
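For illustration, the extracted helper could look roughly like this (helper names are hypothetical, not an existing Brevitas API):

from typing import Set, Type

import torch.nn as nn


def find_rmsnorm_classes(model: nn.Module) -> Set[Type[nn.Module]]:
    # Collect every module class in the model whose name contains 'RMS'.
    return {type(m) for m in model.modules() if 'RMS' in type(m).__name__}


def is_rmsnorm_like(module: nn.Module, model: nn.Module) -> bool:
    # Shared check, usable both by _is_scale_invariant_module and by the patch.
    return isinstance(module, tuple(find_rmsnorm_classes(model)))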
The issue is that RMSNorm is not always "scale invariant"; for example, for other algorithms like weight equalization, it isn't.
Also, this better decouples how we identify the RMSNorm modules from how we check whether something is a scale-invariant function/module.
I agree this is not ideal, but I don't fully agree with your suggestion either.
If anything, there should be a fully general way to customize all the attributes that are used during the region walk algorithm.
The easiest way would be to have a dict where a user can override all the keys, but then we would need to handle that within each class (or GraphRotation to start with), which could be verbose but doable.
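For example, such an override dict could look roughly like this (illustrative names only, not the actual Brevitas API):

from typing import Callable, Dict, Optional

import torch.nn as nn

# Defaults used by the region-walk predicates; a user can override any key.
DEFAULT_WALK_CHECKS: Dict[str, Callable[[nn.Module], bool]] = {
    'is_scale_invariant': lambda m: isinstance(m, nn.LayerNorm) or 'RMS' in type(m).__name__,
    'is_sink': lambda m: isinstance(m, (nn.Linear, nn.Conv2d)),
}


class GraphRotationWithOverrides:

    def __init__(self, walk_check_overrides: Optional[Dict[str, Callable]] = None):
        # Start from the defaults and let the caller replace any of the keys.
        self.walk_checks = {**DEFAULT_WALK_CHECKS, **(walk_check_overrides or {})}

    def is_scale_invariant(self, module: nn.Module) -> bool:
        return self.walk_checks['is_scale_invariant'](module)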
Why is RMS not "scale invariant" for weight equalization?
Because you compute the per-channel variance of the input tensor, which changes if you have a scale factor per channel before/after this op
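A quick numeric sketch of that point (illustrative only, not code from the PR): a single global scale is absorbed by the RMS statistic, but a per-channel scale is not, and it cannot be pushed through the op either.

import torch

torch.manual_seed(0)
x = torch.randn(4, 8)  # (tokens, channels)
w = torch.ones(8)      # RMSNorm weight


def rmsnorm(x, w, eps=1e-6):
    return x * torch.rsqrt(x.pow(2).mean(dim=-1, keepdim=True) + eps) * w


s_channel = torch.rand(8) + 0.5  # per-channel scales

# Global scalar scale: absorbed by the RMS statistic (up to eps effects).
print(torch.allclose(rmsnorm(3.0 * x, w), rmsnorm(x, w), atol=1e-4))                    # True
# Per-channel scale: the output changes and the scale cannot be pushed through.
print(torch.allclose(rmsnorm(s_channel * x, w), rmsnorm(x, w), atol=1e-4))              # False
print(torch.allclose(rmsnorm(s_channel * x, w), s_channel * rmsnorm(x, w), atol=1e-4))  # False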
def __init__(self, model, config, enabled=True):
    self.model = model
    self.config = config
    self.enabled = enabled
This attribute seems to only be used in the next line, so I would remove it and just do if enabled.
for r in rewriters:
    self.model = r.apply(self.model)

self.model = self.model.to(dtype)
Why is this cast necessary?
Because when we re-init modules, we don't correctly propagate dtype and device to quant modules. It could be redundant, but better safe than sorry.
        eps=self.config.rms_norm_eps,
        dtype=dtype,
        device=device) for rms_cls in self.rmsnorm_classes]
dtype = next(iter(self.model.parameters())).dtype
Why does dtype need to be retrieved again? It seems to have been registered at L35 already.
rewriters = []
dtype = next(self.model.parameters()).dtype
device = next(self.model.parameters()).device
for rms_class in self.rmsnorm_classes:
This logic seems fragile to me, e.g. if the signature had more than two parameters, this would crash. Moreover, there would be no real need to re-instantiate the original RMSNorm modules if their references were kept and then assigned back to the model. However, this would require some extra logic, as this behaviour is not supported by ModuleToModuleByClass, but I think it could be done easily enough.
This is not meant to be robust or to support all types of LLMs (present or future).
Agreed, but this seems too patchy, especially considering that the only reason those modules need to be instantiated again is that their references were lost. If you kept those references, there would be no need to make any assumptions about the method's signature, as you would only need to assign the original modules back to the model.
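A sketch of the keep-the-references idea (illustrative names, not the PR's code): stash the original RMSNorm objects before rewriting, then assign those very objects back, so no assumption about their __init__ signature is needed.

import torch.nn as nn


def stash_rmsnorms(model: nn.Module, rmsnorm_classes):
    # Record (parent module, attribute name, original child) for later restoration.
    stashed = []
    for parent in model.modules():
        for attr, child in parent.named_children():
            if isinstance(child, tuple(rmsnorm_classes)):
                stashed.append((parent, attr, child))
    return stashed


def restore_rmsnorms(stashed):
    # Put the original module objects back in place, untouched.
    for parent, attr, original in stashed:
        setattr(parent, attr, original)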
for r in rewriters:
    self.model = r.apply(self.model)

self.model = self.model.to(dtype)
Why is this cast necessary?
As above
if args.replace_rmsnorm:
    model = replace_rmsnorm_with_torch(model, model.config)
# if args.replace_rmsnorm:
Remove?
pablomlago left a comment
Some extra changes might be needed. Also, it might be worth adding some tests for random tiny Gemma models, maybe hf-internal-testing/dummy-gemma?
Reason for this PR
Some models have unusual forward passes for RMSNorm, like Gemma, which uses (1 + weight) * input instead of weight * input. Similarly, we were not correctly handling the fact that Llama/Gemma (and maybe other models) cast up to float32 and then back to (b)float16 during the forward pass.
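For reference, a simplified paraphrase of a Gemma-style RMSNorm forward showing both quirks (not the verbatim upstream code):

import torch
import torch.nn as nn


class GemmaStyleRMSNorm(nn.Module):

    def __init__(self, dim: int, eps: float = 1e-6):
        super().__init__()
        self.eps = eps
        # Weight is initialised to zero, so the effective scale is (1 + weight).
        self.weight = nn.Parameter(torch.zeros(dim))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Quirk 2: compute in float32 regardless of the input dtype...
        x_fp32 = x.float()
        norm = x_fp32 * torch.rsqrt(x_fp32.pow(2).mean(-1, keepdim=True) + self.eps)
        # Quirk 1: scale by (1 + weight) instead of weight...
        out = norm * (1.0 + self.weight.float())
        # ...then cast back to the original (b)float16 input dtype.
        return out.type_as(x)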
Changes Made in this PR
We dynamically create a new class type that inherits from both torch.nn.RMSNorm and whatever RMSNorm class the LLM is using. Then we initialize the internal dict manually (a bit scary, but it works). Finally, we swap the newly created Frankenstein instance in for the old one. Importantly, the new instance passes the check isinstance(module, torch.nn.RMSNorm), which we use to apply rotations. All of the above does not work, because dynamo cannot trace through custom classes.
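A rough reconstruction of that abandoned Frankenstein approach (hypothetical names; assumes a torch version that provides torch.nn.RMSNorm):

import torch.nn as nn


def frankenstein_rmsnorm(orig: nn.Module) -> nn.Module:
    # Dynamically create a class that inherits from both the model's own
    # RMSNorm class and torch.nn.RMSNorm, so the original forward is kept
    # while isinstance(module, torch.nn.RMSNorm) passes.
    franken_cls = type('FrankenRMSNorm', (type(orig), nn.RMSNorm), {})
    new = franken_cls.__new__(franken_cls)
    # Initialise the internal state manually by copying the original
    # instance's dict (the "a bit scary but it works" part).
    new.__dict__.update(orig.__dict__)
    return new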
Instead, what we do now is replace the modules just long enough for dynamo to pick up the stock torch.nn.RMSNorm, and then we put back whatever RMSNorm class was originally intended.
As a side effect, we need to carry these class types around for some of our algorithms.
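A minimal, self-contained sketch of this swap-and-restore flow (illustrative names and structure, not the actual PR code):

from contextlib import contextmanager

import torch.nn as nn


@contextmanager
def torch_rmsnorm_for_tracing(model: nn.Module, eps: float = 1e-6):
    # Discover whatever RMSNorm classes this LLM uses (e.g. a Gemma-style RMSNorm).
    rmsnorm_classes = tuple({type(m) for m in model.modules() if 'RMS' in type(m).__name__})
    stashed = []
    for parent in model.modules():
        for attr, child in list(parent.named_children()):
            if isinstance(child, rmsnorm_classes):
                # Swap in a stock torch.nn.RMSNorm with matching shape/dtype/device.
                replacement = nn.RMSNorm(
                    child.weight.shape[0],
                    eps=eps,
                    device=child.weight.device,
                    dtype=child.weight.dtype)
                setattr(parent, attr, replacement)
                stashed.append((parent, attr, child))
    try:
        # Trace/export inside the with-block: dynamo only ever sees nn.RMSNorm.
        yield rmsnorm_classes
    finally:
        # Put back whatever RMSNorm class was originally intended.
        for parent, attr, child in stashed:
            setattr(parent, attr, child)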
Testing Summary
All existing tests; hopefully I didn't break too many.