6 changes: 3 additions & 3 deletions src/brevitas/graph/base.py
@@ -435,18 +435,18 @@ class ModuleToModuleByClass(ModuleToModule):
def __init__(self, old_module_class, new_module_class, **kwargs):
super().__init__(new_module_class, **kwargs)
self.old_module_class = old_module_class
self.old_new_module_dict = {}

def apply(self, model: GraphModule) -> GraphModule:
old_new_module_dict = {}
for old_module in model.modules():
# check for equality, not inheritance
if type(old_module) == self.old_module_class:
# init the new module based on the old one
new_module = self.init_new_module(old_module)
# register modules pair to be replaced
old_new_module_dict[old_module] = new_module
self.old_new_module_dict[old_module] = new_module
# replace all pairs registered
for old_module, new_module in old_new_module_dict.items():
for old_module, new_module in self.old_new_module_dict.items():
replace_module(model, old_module, new_module)
return model
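
For context, a minimal usage sketch of the change above (illustration only, not part of the diff): keeping the mapping on the instance means the old->new pairs remain available after apply(), which is what the rmsnorm_patch context manager further down relies on. The toy model is hypothetical, and the sketch assumes the rewriter also accepts a plain nn.Module, as in the replace_rmsnorm_with_torch usage this PR refactors.

import torch.nn as nn

from brevitas.graph import ModuleToModuleByClass

# Hypothetical toy model, used only to exercise the new attribute.
model = nn.Sequential(nn.Linear(8, 8), nn.ReLU(), nn.Linear(8, 8), nn.ReLU())

rewriter = ModuleToModuleByClass(nn.ReLU, nn.ReLU6)
model = rewriter.apply(model)

# With this patch the replacement pairs persist on the rewriter instance,
# so callers can consume them after apply(), e.g. to revert the swap later.
for old_module, new_module in rewriter.old_new_module_dict.items():
    print(type(old_module).__name__, '->', type(new_module).__name__)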

102 changes: 62 additions & 40 deletions src/brevitas/graph/equalize.py
@@ -1785,10 +1785,23 @@ def _merge_ln(layer_norm, next_module, scale_bias_by_weight):
_replace_bias(next_module, new_bias)


class StateMixin:

def __init__(self, base_state_kwargs, extra_state_kwargs=None):

self.full_state_kwargs = dict()
extra_state_kwargs = dict() if extra_state_kwargs is None else extra_state_kwargs

for d in (base_state_kwargs, extra_state_kwargs):
for key, value in d.items():
current_value = self.full_state_kwargs.get(key, ())
current_value = current_value + value
self.full_state_kwargs[key] = current_value
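
For reference, a standalone sketch of the merge rule StateMixin implements (illustration only, not part of the diff; MyRMSNorm is a hypothetical placeholder): tuple values under the same key are concatenated with the base entries first, which is how extra_state_kwargs lets callers extend the default region-extraction state.

import torch.nn as nn

def merge_state_kwargs(base_state_kwargs, extra_state_kwargs=None):
    # Same logic as StateMixin.__init__, restated as a free function for clarity.
    full_state_kwargs = {}
    extra_state_kwargs = {} if extra_state_kwargs is None else extra_state_kwargs
    for d in (base_state_kwargs, extra_state_kwargs):
        for key, value in d.items():
            # Tuples under the same key are concatenated, base entries first.
            full_state_kwargs[key] = full_state_kwargs.get(key, ()) + value
    return full_state_kwargs

class MyRMSNorm(nn.Module):  # hypothetical stand-in for a model-specific RMSNorm class
    pass

base = {'supported_sinks': (nn.Linear,), 'scale_invariant_layers': (nn.LayerNorm,)}
extra = {'scale_invariant_layers': (MyRMSNorm,)}
merged = merge_state_kwargs(base, extra)
# merged == {'supported_sinks': (nn.Linear,),
#            'scale_invariant_layers': (nn.LayerNorm, MyRMSNorm)}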


class RotationEqualization(GraphTransform):

def __init__(self, blacklist_layers, layers_to_expand) -> None:
super(RotationEqualization, self).__init__()
if blacklist_layers is not None:
self.blacklist_layers = blacklist_layers
else:
@@ -1797,19 +1810,19 @@ def __init__(self, blacklist_layers, layers_to_expand) -> None:
self.layers_to_expand = layers_to_expand
else:
self.layers_to_expand = []
self.supported_sinks = ()

def find_module(
self,
model: nn.Module,
regions: List[Region],
supported_sinks: tuple,
prefix: str = '',
blacklist_layers: Optional[List[str]] = None):
"""
Iterate through the model, inspecting the immediate children of every module, looking for supported modules.
This allows us to stop the search when we meet a top-level module that is supported.
"""
if isinstance(model, self.supported_sinks):
if isinstance(model, supported_sinks):
if prefix in blacklist_layers:
return
weight = get_weight_sink(model)
@@ -1820,7 +1833,7 @@ def find_module(
else:
for name, module in model.named_children():
full_name = prefix + '.' + name if prefix != '' else name
self.find_module(module, regions, full_name, blacklist_layers)
self.find_module(module, regions, supported_sinks, full_name, blacklist_layers)

def find_module_by_name(self, model: nn.Module, regions: List[Region], prefix: str = ''):
"""
@@ -1852,7 +1865,7 @@ def transform_model(
return apply_rewriters(model, rewriters)


class GraphRotationEqualization(RotationEqualization):
class GraphRotationEqualization(RotationEqualization, StateMixin):

def __init__(
self,
@@ -1866,16 +1879,20 @@ def __init__(
layers_to_expand: Optional[List[str]] = None,
expansion_step: int = None,
delay_rewriters: bool = False,
return_rewriters: bool = False) -> None:
super(GraphRotationEqualization, self).__init__(blacklist_layers, layers_to_expand)
return_rewriters: bool = False,
extra_state_kwargs: Optional[Dict[str, Tuple]] = None) -> None:
RotationEqualization.__init__(self, blacklist_layers, layers_to_expand)

self.supported_srcs = (nn.Linear, nn.Embedding)
self.supported_sinks = (nn.Linear)
common_scale_invariant = list(_scale_invariant_layers)
common_scale_invariant.remove(torch.nn.ReLU)
common_scale_invariant.remove(torch.nn.LeakyReLU)
self.scale_invariant_layers = tuple(common_scale_invariant) + (RMSNorm,)
self.scale_invariant_function = ()
base_state_kwargs = {
'supported_srcs': (nn.Linear, nn.Embedding),
'supported_sinks': (nn.Linear,),
'scale_invariant_layers': tuple(common_scale_invariant) + (RMSNorm,),
'scale_invariant_function': ()}
StateMixin.__init__(self, base_state_kwargs, extra_state_kwargs)

self.orphan_sink = orphan_sink
self.rotate_matmul = rotate_matmul
self.full_rotation_method = full_rotation_method
@@ -1992,13 +2009,7 @@ def find_sink(node):
def apply(self,
graph_model: GraphModule) -> Union[Tuple[GraphModule, List[Transform]], GraphModule]:
rewriters = []
regions = _extract_regions(
graph_model,
state_impl_kwargs={
'supported_srcs': self.supported_srcs,
'supported_sinks': self.supported_sinks,
'scale_invariant_layers': self.scale_invariant_layers,
'scale_invariant_function': self.scale_invariant_function})
regions = _extract_regions(graph_model, state_impl_kwargs=self.full_state_kwargs)

expanded_regions = []
self.find_module_by_name(graph_model, expanded_regions)
@@ -2007,7 +2018,11 @@ def apply(self,

if self.orphan_sink:
blacklist_orphan_layers = self.blacklist_layers + self.layers_to_expand
self.find_module(graph_model, orphan_regions, blacklist_layers=blacklist_orphan_layers)
self.find_module(
graph_model,
orphan_regions,
self.full_state_kwargs['supported_sinks'],
blacklist_layers=blacklist_orphan_layers)

if len(expanded_regions) > 0:
parameter_number_pre = 0
@@ -2095,20 +2110,23 @@ def apply_rewriters(
return model


class LayerNormToRMS(GraphTransform):
class LayerNormToRMS(GraphTransform, StateMixin):

def __init__(
self,
return_rewriters: bool = False,
extra_state_kwargs: Optional[Dict[str, Tuple]] = None) -> None:
GraphTransform.__init__(self)

base_state_kwargs = {
'supported_srcs': (nn.Linear, nn.Embedding), 'supported_sinks': (nn.LayerNorm,)}
StateMixin.__init__(self, base_state_kwargs, extra_state_kwargs)

def __init__(self, return_rewriters=False) -> None:
super(LayerNormToRMS, self).__init__()
self.supported_srcs = (nn.Linear, nn.Embedding)
self.supported_sinks = (nn.LayerNorm)
self.return_rewriters = return_rewriters
assert RMSNorm is not object, 'Update your Pytorch version to 2.4+'

def apply(self, graph_model: GraphModule) -> GraphModule:
regions = _extract_regions(
graph_model,
state_impl_kwargs={
'supported_srcs': self.supported_srcs, 'supported_sinks': self.supported_sinks})
regions = _extract_regions(graph_model, state_impl_kwargs=self.full_state_kwargs)

rewriters = []
if len(regions) > 0:
@@ -2141,18 +2159,17 @@ def apply(self, graph_model: GraphModule) -> GraphModule:
return graph_model


class MergeLnAffine(GraphTransform):
class MergeLnAffine(GraphTransform, StateMixin):

def __init__(self) -> None:
super(MergeLnAffine, self).__init__()
def __init__(self, extra_state_kwargs: Optional[Dict[str, Tuple]] = None) -> None:
GraphTransform.__init__(self)
self.supported_srcs = (RMSNorm, nn.LayerNorm)
self.supported_sinks = (nn.Linear)
base_state_kwargs = {
'supported_srcs': (RMSNorm, nn.LayerNorm), 'supported_sinks': (nn.Linear,)}
StateMixin.__init__(self, base_state_kwargs, extra_state_kwargs)

def apply(self, graph_model: GraphModule) -> GraphModule:
regions = _extract_regions(
graph_model,
state_impl_kwargs={
'supported_srcs': self.supported_srcs, 'supported_sinks': self.supported_sinks})
regions = _extract_regions(graph_model, state_impl_kwargs=self.full_state_kwargs)

if len(regions) > 0:
scaled_biases = set()
@@ -2180,18 +2197,23 @@ def __init__(
blacklist_layer: Optional[List] = None,
layers_to_expand: Optional[List] = None,
expansion_step: int = 0,
block_rotation_dim: Optional[int] = None):
super().__init__(blacklist_layer, layers_to_expand)
block_rotation_dim: Optional[int] = None,
extra_state_kwargs: Optional[Dict[str, Tuple]] = None):

RotationEqualization.__init__(self, blacklist_layer, layers_to_expand)
self.expansion_step = expansion_step
self.supported_sinks = (nn.Linear)
self.block_rotation_dim = block_rotation_dim
self.supported_sinks = (nn.Linear,)
# base_state_kwargs = {'supported_sinks': (nn.Linear,)}
# StateMixin.__init__(self, base_state_kwargs, extra_state_kwargs)

def apply(self, model: nn.Module) -> nn.Module:
regions: List[Region] = []
rewriters: List[Transform] = []

blacklist_orphan_layers = self.blacklist_layers + self.layers_to_expand
self.find_module(model, regions, blacklist_layers=blacklist_orphan_layers)
self.find_module(
model, regions, self.supported_sinks, blacklist_layers=blacklist_orphan_layers)
expanded_regions = []
self.find_module_by_name(model, expanded_regions)

73 changes: 52 additions & 21 deletions src/brevitas_examples/llm/llm_quant/ln_affine_merge.py
@@ -3,36 +3,67 @@
SPDX-License-Identifier: MIT
"""

from inspect import signature

from packaging import version
import torch
from torch import nn

from brevitas import torch_version
from brevitas.graph.base import ModuleToModuleByClass
from brevitas.graph import ModuleInstanceToModuleInstance
from brevitas.graph import ModuleToModuleByClass
from brevitas.graph.equalize import _is_scale_invariant_module
from brevitas.graph.equalize import LayerNormToRMS
from brevitas.graph.equalize import MergeLnAffine
from brevitas.graph.utils import get_module


def replace_rmsnorm_with_torch(model, config):
assert torch_version >= version.parse('2.4'), "torch.nn.RMSNorm requires torch 2.4 or greater"
set_of_layers = set(type(x) for x in model.modules() if 'RMS' in type(x).__name__)
dtype = next(model.parameters()).dtype
device = next(model.parameters()).device
rewriters = [
ModuleToModuleByClass(
rms_cls,
torch.nn.RMSNorm,
normalized_shape=lambda module: module.weight.shape[0],
eps=config.rms_norm_eps,
dtype=dtype,
device=device) for rms_cls in set_of_layers]
dtype = next(iter(model.parameters())).dtype
for r in rewriters:
model = r.apply(model)
model = model.to(dtype)
return model
class rmsnorm_patch:

def __init__(self, model, config, enabled=True):
self.model = model
self.config = config
if enabled:
self.rmsnorm_classes = tuple(
set(type(x) for x in model.modules() if 'RMS' in type(x).__name__))
else:
self.rmsnorm_classes = tuple()
self.mapping = dict()

def __enter__(self):
assert torch_version >= version.parse('2.4'), "torch.nn.RMSNorm requires torch 2.4 or greater"

dtype = next(self.model.parameters()).dtype
device = next(self.model.parameters()).device

rewriters = [
ModuleToModuleByClass(
rms_cls,
torch.nn.RMSNorm,
normalized_shape=lambda module: module.weight.shape[0],
eps=self.config.rms_norm_eps,
dtype=dtype,
device=device) for rms_cls in self.rmsnorm_classes]

for r in rewriters:
self.model = r.apply(self.model)
self.mapping.update(r.old_new_module_dict)

self.model = self.model.to(dtype)
return self

def __exit__(self, *args, **kwargs):
rewriters = []
dtype = next(self.model.parameters()).dtype

for old_module, new_module in self.mapping.items():
[Collaborator review comment] No need to iterate twice. It can be done as:

for old_module, new_module in self.mapping.items():
  r = ModuleInstanceToModuleInstance(old_module, new_module)
  self.model = r.apply(self.model)

rewriter = ModuleInstanceToModuleInstance(old_module, new_module)
rewriters.append(rewriter)

for r in rewriters:
self.model = r.apply(self.model)

self.model = self.model.to(dtype)
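
A usage sketch of the context manager (illustration only; it mirrors how fused_rotation_no_fx in main.py below calls it, and assumes model is a Hugging Face model whose config exposes rms_norm_eps):

import torch

from brevitas_examples.llm.llm_quant.ln_affine_merge import rmsnorm_patch

with torch.no_grad(), rmsnorm_patch(model, model.config) as patcher:
    # Inside the context, every module whose class name contains 'RMS' has been
    # swapped for torch.nn.RMSNorm; the original classes stay queryable so later
    # passes (e.g. equalization) can treat them as scale-invariant layers.
    rmsnorm_classes = patcher.rmsnorm_classes
    ...  # trace / equalize the patched model here
# On exit, the recorded module pairs are replayed through
# ModuleInstanceToModuleInstance rewriters, as in __exit__ above.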


def replace_bias(next_module, new_bias):
@@ -106,8 +137,8 @@ def merge_layernorm_affine_params(graph_model):


@torch.no_grad()
def apply_layernorm_affine_merge(graph_model):
eq = MergeLnAffine()
def apply_layernorm_affine_merge(graph_model, rmsnorm_classes):
eq = MergeLnAffine(extra_state_kwargs={'scale_invariant_layers': rmsnorm_classes})
graph_model = eq.apply(graph_model)
return graph_model

22 changes: 12 additions & 10 deletions src/brevitas_examples/llm/main.py
@@ -56,7 +56,7 @@
from brevitas_examples.llm.llm_quant.learned_round_utils import apply_learned_round
from brevitas_examples.llm.llm_quant.ln_affine_merge import apply_layernorm_affine_merge
from brevitas_examples.llm.llm_quant.ln_affine_merge import apply_layernorm_to_rmsnorm
from brevitas_examples.llm.llm_quant.ln_affine_merge import replace_rmsnorm_with_torch
from brevitas_examples.llm.llm_quant.ln_affine_merge import rmsnorm_patch
from brevitas_examples.llm.llm_quant.prepare_for_quantize import add_zero_bias_to_linear
from brevitas_examples.llm.llm_quant.prepare_for_quantize import make_dynamo_compatible
from brevitas_examples.llm.llm_quant.prepare_for_quantize import \
@@ -89,7 +89,8 @@ def filter_results(results, tasks):


def fused_rotation_no_fx(model, calibration_loader, args):
with torch.no_grad():
with torch.no_grad(), rmsnorm_patch(model, model.config) as patcher:
rmsnorm_classes = patcher.rmsnorm_classes
with make_dynamo_compatible(model) as dynamo_comp:
fx_model, guards = torch._dynamo.export(dynamo_comp.model)(**calibration_loader[0])
if hasattr(model, str(torch.nn.functional.scaled_dot_product_attention)):
@@ -102,7 +103,7 @@ def fused_rotation_no_fx(model, calibration_loader, args):
if any(map(lambda x: x in name, args.rotation_layers_to_expand)):
layers_to_expand.append(name)

apply_layernorm_affine_merge(fx_model)
apply_layernorm_affine_merge(fx_model, rmsnorm_classes)
# NOTE: This call breaks ties between the lm_head and the embedding layer
fx_model, rewriters = apply_layernorm_to_rmsnorm(fx_model, return_rewriters=True)
rewriters = fix_rewriter(rewriters, model, 'weight')
@@ -124,7 +125,8 @@ def fused_rotation_no_fx(model, calibration_loader, args):
delay_rewriters=delay_rewriters,
expansion_step=args.expansion_step,
layers_to_expand=layers_to_expand,
block_rotation_dim=args.block_rotation_dim)
block_rotation_dim=args.block_rotation_dim,
extra_state_kwargs={'scale_invariant_layers': rmsnorm_classes})
fx_model, rewriters = eq.apply(fx_model)

model = offload_model(model)
@@ -283,11 +285,10 @@ def quantize_llm(args, extra_args=None):
remove_hooks(model)
print(f"Float perplexity ({args.dataset}): {float_ppl:.3f}")

if args.replace_rmsnorm:
model = replace_rmsnorm_with_torch(model, model.config)

rmsnorm_classes = ()
if require_fx:
with torch.no_grad():
with torch.no_grad(), rmsnorm_patch(model, model.config, enabled=args.replace_rmsnorm) as patcher:
rmsnorm_classes = patcher.rmsnorm_classes
with make_dynamo_compatible(model) as dynamo_comp:
model, guards = torch._dynamo.export(dynamo_comp.model)(**calibration_loader[0])
# Blockwise optimization does not work with FX at the moment
@@ -298,7 +299,7 @@ def quantize_llm(args, extra_args=None):
# since currently there is support only for merging into Linear
if args.ln_affine_merge:
print("Apply LN affine merge...")
apply_layernorm_affine_merge(model)
apply_layernorm_affine_merge(model, rmsnorm_classes)
print("LN affine merge applied.")

if args.convert_layernorm_to_rmsnorm:
@@ -337,7 +338,8 @@ def quantize_llm(args, extra_args=None):
use_parametrized_rotations=args.optimize_rotations,
expansion_step=args.expansion_step,
layers_to_expand=layers_to_expand,
block_rotation_dim=args.block_rotation_dim)
block_rotation_dim=args.block_rotation_dim,
extra_state_kwargs={'scale_invariant_layers': rmsnorm_classes})
model = eq.apply(model)
remove_hooks(model)
elif args.rotation == 'layerwise':