
Conversation


@Giuseppe5 Giuseppe5 commented Dec 24, 2025

Reason for this PR

Some models have unusual forward passes for RMSNorm; Gemma, for example, uses (1 + weight) * input instead of weight * input.

Similarly, we were not correctly handling the fact that Llama/Gemma (and possibly other models) cast up to float32 and then back to (b)float16 during the forward pass.
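
For reference, a Gemma-style RMSNorm forward looks roughly like the sketch below (illustrative only, not code from this PR): the normalization runs in float32 and the scale is applied as (1 + weight) before casting back to the input dtype.

```python
# Illustrative sketch of a Gemma-style RMSNorm forward (not code from this PR).
import torch


class GemmaLikeRMSNorm(torch.nn.Module):

    def __init__(self, dim: int, eps: float = 1e-6):
        super().__init__()
        self.eps = eps
        # Zero-initialized weight: the effective scale is (1 + weight)
        self.weight = torch.nn.Parameter(torch.zeros(dim))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        input_dtype = x.dtype
        x = x.float()  # upcast to float32 for the normalization
        x = x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps)
        x = x * (1.0 + self.weight.float())  # Gemma scales by (1 + weight)
        return x.to(input_dtype)  # cast back to the original (b)float16
```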

Changes Made in this PR

We dynamically create a new class type that inherits from torch.nn.RMSNorm and from whatever RMSNorm class the LLM is using.

Then we initialize the internal dict manually (a bit scary but it works).

Finally, we swap the old instance out for the newly created Frankenstein instance.

Importantly, the new instance passes the check isinstance(module, torch.nn.RMSNorm), which we use to apply rotations.
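
Roughly, the idea looks like the sketch below (a minimal sketch, not the actual PR code; make_torch_rmsnorm is a made-up name):

```python
# Minimal sketch of the dynamic class creation described above; not the actual
# PR code, and make_torch_rmsnorm is a hypothetical name.
import torch


def make_torch_rmsnorm(orig_module: torch.nn.Module) -> torch.nn.Module:
    orig_cls = type(orig_module)
    # Dynamically create a type inheriting both from torch.nn.RMSNorm and from
    # the model-specific RMSNorm class, so isinstance(m, torch.nn.RMSNorm) holds.
    new_cls = type(f"Torch{orig_cls.__name__}", (torch.nn.RMSNorm, orig_cls), {})
    # Allocate without calling __init__, then copy the internal dict manually
    # (parameters, buffers, eps, ...) from the original instance.
    new_module = new_cls.__new__(new_cls)
    new_module.__dict__ = dict(orig_module.__dict__)
    return new_module
```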

However, all of the above does not work, because dynamo cannot trace through custom classes.
Instead, what we do now is temporarily swap in the original torch.nn.RMSNorm just so that dynamo picks it up, and then put back whatever RMSNorm class was originally intended.

As a side effect, we need to carry these class types around for some of our algorithms.

Testing Summary

All existing tests; hopefully I didn't break too many of them.

@Giuseppe5 Giuseppe5 changed the title from "Feat (brevitas_examples/llm): best RMSNorm replacement" to "Feat (brevitas_examples/llm): better RMSNorm replacement" on Dec 24, 2025
@Giuseppe5 Giuseppe5 requested a review from pablomlago December 29, 2025 16:01
    delay_rewriters: bool = False,
-   return_rewriters: bool = False) -> None:
+   return_rewriters: bool = False,
+   extra_rmsnorm_classes: Optional[Tuple] = None) -> None:

@pablomlago pablomlago Dec 30, 2025


I am not a big fan of having this extra argument and requiring these classes to be propagated from the context manager. Can we instead extend _is_scale_invariant_module (https://github.com/Xilinx/brevitas/blob/master/src/brevitas/graph/equalize.py#L855) to have this extra logic? Ideally, part of the logic of set(type(x) for x in model.modules() if 'RMS' in type(x).__name__) could be extracted into a standalone method to be used both by _is_scale_invariant_module and rmsnorm_patch.
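
For illustration, such a standalone helper could look roughly like this (the name find_rmsnorm_classes is just a placeholder):

```python
# Possible standalone helper for the suggestion above; the name is a placeholder.
from typing import Set, Type

import torch


def find_rmsnorm_classes(model: torch.nn.Module) -> Set[Type[torch.nn.Module]]:
    # Collect every module class whose name contains "RMS", so the same set can
    # be shared by _is_scale_invariant_module and rmsnorm_patch.
    return {type(m) for m in model.modules() if 'RMS' in type(m).__name__}
```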

@Giuseppe5 (Collaborator, Author)

The issue is that RMSNorm modules are not always "scale invariant"; for other algorithms, like weight equalization, they aren't.

Also, this better decouples how we identify the RMSNorm modules from how we check whether a function/module is scale invariant.

I agree this is not ideal, but I don't fully agree with your suggestion either.
If anything, there should be a fully general way to customize all the attributes that are used during the region walk algorithm.

The easiest way would be to have a dict where a user can override all the keys, but then we would need to handle that within each class (or GraphRotation to start with), which could be verbose but doable.
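
Purely as an illustration of the dict idea, something like the following (the key names and the region_walk_overrides argument are hypothetical, not existing Brevitas API):

```python
# Purely illustrative: a dict of overridable attributes for the region walk.
# Key names and the region_walk_overrides argument are hypothetical.
import torch

DEFAULT_REGION_WALK_ATTRS = {
    # Classes treated as RMSNorm when deciding where to apply rotations.
    'rmsnorm_classes': (torch.nn.RMSNorm,),
    # Modules considered scale invariant (algorithm-dependent, see above).
    'scale_invariant_modules': (),
}


def resolve_region_walk_attrs(user_overrides=None):
    # Start from the defaults and let the user override individual keys.
    attrs = dict(DEFAULT_REGION_WALK_ATTRS)
    attrs.update(user_overrides or {})
    return attrs

# e.g. GraphRotation(..., region_walk_overrides={'rmsnorm_classes': (MyRMSNorm,)})
```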

@pablomlago (Collaborator)

Why is RMS not "scale invariant" for weight equalization?

@Giuseppe5 (Collaborator, Author)

Because you compute the variance of the input tensor across channels, which changes if you have a per-channel scale factor before/after this op.
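
A quick numerical check of this point (not from the PR): with a plain unit-weight RMSNorm, a scalar input scale cancels out, but a per-channel scale does not.

```python
# Quick check (not from the PR): RMSNorm is invariant to a scalar input scale
# but not to a per-channel one, which is what breaks weight equalization.
import torch


def rms_norm(x, eps=1e-6):
    # Plain RMSNorm with unit weight: x / sqrt(mean(x^2) + eps)
    return x * torch.rsqrt(x.pow(2).mean(dim=-1, keepdim=True) + eps)


x = torch.randn(4, 8)

# Scalar scale: cancels inside the RMS, output is (numerically) unchanged.
print(torch.allclose(rms_norm(3.0 * x), rms_norm(x), atol=1e-4))  # True

# Per-channel scale: the RMS mixes channels, so the scale cannot be commuted
# through the norm.
s = torch.rand(8) + 0.5
print(torch.allclose(rms_norm(s * x), s * rms_norm(x), atol=1e-4))  # False
```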

def __init__(self, model, config, enabled=True):
    self.model = model
    self.config = config
    self.enabled = enabled
@pablomlago (Collaborator)

This attribute seems to only be used in the next line, so I would remove it and just do if enabled.

for r in rewriters:
    self.model = r.apply(self.model)

self.model = self.model.to(dtype)
@pablomlago (Collaborator)

Why is this cast necessary?


@Giuseppe5 Giuseppe5 Dec 30, 2025


Because when we re-init modules, we don't correctly propagate dtype and device to quant modules. It could be redundant, but better safe than sorry.

eps=self.config.rms_norm_eps,
dtype=dtype,
device=device) for rms_cls in self.rmsnorm_classes]
dtype = next(iter(self.model.parameters())).dtype

@pablomlago pablomlago Dec 30, 2025


Why does dtype need to be retrieved again? It seems to have been registered at L35 already.

rewriters = []
dtype = next(self.model.parameters()).dtype
device = next(self.model.parameters()).device
for rms_class in self.rmsnorm_classes:
@pablomlago (Collaborator)

This logic seems fragile to me, e.g. if the signature had more than two parameters this would crash. Moreover, there would be no real need to re-instantiate the original RMSNorm modules if their references were kept and then assigned back to the model. However, this would require some extra logic, as this behaviour is not supported by ModuleToModuleByClass, but I think it could be done easily enough.

@Giuseppe5 (Collaborator, Author)

This is not meant to be robust or to support all types of LLMs (present or future).


@pablomlago pablomlago Dec 30, 2025


Agreed, but this seems too patchy, especially considering that the only reason those modules need to be instantiated again is that their references were lost. If you kept those references, there would be no need to make any assumptions about the method's signature, as you would only need to assign the original modules back to the model.
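
A rough sketch of what keeping and restoring those references could look like (helper names are hypothetical, not existing Brevitas API):

```python
# Rough sketch of keeping references to the original RMSNorm modules and
# assigning them back afterwards; helper names are hypothetical.
import torch


def save_rmsnorm_refs(model: torch.nn.Module, rmsnorm_classes) -> dict:
    # Keep a reference to every original RMSNorm instance, keyed by module path.
    return {
        name: module for name, module in model.named_modules()
        if isinstance(module, tuple(rmsnorm_classes))}


def restore_rmsnorm_refs(model: torch.nn.Module, saved: dict) -> None:
    # Assign the original modules back, without re-instantiating them.
    for name, module in saved.items():
        parent_name, _, attr = name.rpartition('.')
        parent = model.get_submodule(parent_name) if parent_name else model
        setattr(parent, attr, module)
```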

for r in rewriters:
    self.model = r.apply(self.model)

self.model = self.model.to(dtype)
@pablomlago (Collaborator)

Why is this cast necessary?

@Giuseppe5 (Collaborator, Author)

As above


if args.replace_rmsnorm:
    model = replace_rmsnorm_with_torch(model, model.config)
# if args.replace_rmsnorm:
@pablomlago (Collaborator)

Remove?


@pablomlago pablomlago left a comment


Some extra changes might be needed. Also, it might be worth adding some tests for random tiny Gemma models, maybe hf-internal-testing/dummy-gemma?
