94 changes: 94 additions & 0 deletions docsrc/source/tutorials/llm_rotations.rst
@@ -0,0 +1,94 @@
=================================
Rotations in Brevitas
=================================

Why are rotations important?
----------------------------------------------------------

Large Language Models exhibit *computational invariance* [1]_, meaning that applying an invertible linear operation to the output of certain modules (sources), and its inverse to the input of others (sinks), leaves the model's output unchanged (assuming sufficient numerical precision). This property allows random orthogonal transformations to be applied selectively, which effectively mitigates weight and activation outliers and makes both more amenable to quantization [2]_. Moreover, some rotations can be fused into the module weights, thus preserving floating-point inference performance.
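
As a quick illustration (a minimal sketch with plain PyTorch tensors, not the Brevitas API), fusing a random orthogonal matrix into a source linear layer and its transpose into the following sink leaves the composed output numerically unchanged:

.. code-block:: python

    import torch

    torch.manual_seed(0)
    d = 8
    x = torch.randn(4, d, dtype=torch.float64)
    w_src = torch.randn(d, d, dtype=torch.float64)  # source weight
    w_snk = torch.randn(d, d, dtype=torch.float64)  # sink weight

    # Random orthogonal matrix from a QR decomposition
    q, _ = torch.linalg.qr(torch.randn(d, d, dtype=torch.float64))

    y = x @ w_src.T @ w_snk.T              # original source -> sink output
    w_src_rot = q.T @ w_src                # source output is now rotated by q
    w_snk_rot = w_snk @ q                  # q.T is absorbed into the sink weights
    y_rot = x @ w_src_rot.T @ w_snk_rot.T

    print(torch.allclose(y, y_rot))        # True, up to floating-point precision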

Although random orthogonal rotations generally improve quantization amenability in low-bit regimes, the performance of the quantized network exhibits a large variance across different random rotations, as observed in [4]_. Consequently, the authors propose to further optimize the rotations to improve quantized performance, leveraging the Cayley SGD optimizer [5]_ to ensure that the optimized rotations remain on the Stiefel manifold throughout optimization.
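
As a brief reminder of why this optimizer fits the problem (a standard linear-algebra fact, not specific to Brevitas): for any skew-symmetric matrix :math:`A` (i.e. :math:`A^\top = -A`), the Cayley transform

.. math::

    Q = (I - A)^{-1}(I + A)

yields an orthogonal matrix :math:`Q`, so parameter updates expressed through this transform keep the learned rotations orthogonal (square orthogonal matrices being a special case of the Stiefel manifold). The optimizer of [5]_ additionally replaces the explicit matrix inverse with an inexpensive iterative approximation.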


Rotations in Brevitas
----------------------------------------------------------

Brevitas can add rotations to an arbitrary model in a fine-grained manner through a number of options, specified in the LLM entrypoint (``brevitas_examples/llm/llm_args.py``):

- ``--rotation`` [``'fx'``, ``'layerwise'``, ``'fused_no_fx'``]. If ``'layerwise'``, each linear layer is wrapped in a ``RotatedModule``, which rotates the input to the module by an orthogonal (Hadamard) matrix, while its inverse is fused into the weights of the linear layer. On the other hand, for ``'fx'`` or ``'fused_no_fx'``, Brevitas automatically detects the regions exhibiting rotation invariance, fusing the rotations into the weights of sources/sinks.
- ``--rotation-mode`` [``'had'``, ``'ort'``]. If ``'had'``, random Hadamard matrices are used for rotations, which provide tighter incoherence bounds and are more efficient to apply [3]_. Therefore, this option is generally preferable to ``'ort'``, which uses arbitrary random orthogonal matrices.
- ``--rotation-orphan-sink``. If enabled, linear layers that are not sinks in any other rotation-invariant region are wrapped in a ``RotatedModule``, as described for ``--rotation layerwise``.
- ``--rotation-sdpa-regions``. If enabled, the value/output region (R₂ in [4]_) is rotated.

Similarly to [4]_, Brevitas can leverage the Cayley SGD optimizer [5]_ to further optimize the rotations, which can be enabled by setting the flag ``--optimize-rotations``. The rotation training procedure relies on the `HF Trainer <https://huggingface.co/docs/transformers/en/main_classes/trainer>`_ class and can therefore be configured by passing any argument accepted by the `TrainingArguments <https://huggingface.co/docs/transformers/en/main_classes/trainer#transformers.TrainingArguments>`_ dataclass. The number of samples used for rotation calibration can be configured through the parameter ``--nsamples-rot-calibration``.
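
For instance, the trainer-related keys of the example configuration below map directly onto fields of ``TrainingArguments``; a rough, purely illustrative equivalent in Python (the ``output_dir`` value is a hypothetical placeholder) would be:

.. code-block:: python

    from transformers import TrainingArguments

    # Sketch only: mirrors the trainer-related entries of the YAML example below.
    training_args = TrainingArguments(
        output_dir="./rotation-optimization",  # hypothetical placeholder
        learning_rate=1.5,
        weight_decay=0.0,
        lr_scheduler_type="cosine",
        max_steps=100,
        per_device_train_batch_size=2,
        gradient_accumulation_steps=4,
        save_safetensors=False,
        logging_steps=10,
        log_on_each_node=False)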

Below is a minimal example configuration for optimizing, on a single GPU, the rotations of a ``HuggingfaceTB/SmolLM2-135M`` model with its weights quantized to 4 bits:

.. code-block:: yaml

    dataset: wikitext2
    eval: true
    model: HuggingfaceTB/SmolLM2-135M
    rotation: fused_no_fx
    optimize_rotations: true
    nsamples_rot_calibration: 800
    replace_rmsnorm: true
    weight_bit_width: 4
    dtype: float32
    learning_rate: 1.5
    weight_decay: 0.0
    lr_scheduler_type: cosine
    max_steps: 100
    per_device_train_batch_size: 2
    gradient_accumulation_steps: 4
    save_safetensors: false
    logging_steps: 10
    log_on_each_node: false

Note that the training parameters used in the SpinQuant paper [4]_ can be found in their `repository <https://github.com/facebookresearch/SpinQuant>`_.

Optimizing rotations on multiple GPUs
----------------------------------------------------------

As mentioned before, rotation optimization leverages the `HF Trainer <https://huggingface.co/docs/transformers/en/main_classes/trainer>`_ class. Therefore, to optimize rotations in a distributed environment, the LLM entrypoint has to be launched as an `accelerate script <https://huggingface.co/docs/accelerate/en/basic_tutorials/launch>`_ using the command ``accelerate launch``.

To do so, the first step is to generate the environment configuration through the command ``accelerate config``, which provides an easy-to-use interface to specify the distributed environment. Once the prompts are completed, a configuration file is generated, which can be passed to ``accelerate launch`` via the ``--config_file`` flag. Below is an example configuration for a single-node environment with 2 GPUs:

.. code-block:: yaml

    compute_environment: LOCAL_MACHINE
    debug: false
    distributed_type: MULTI_GPU
    downcast_bf16: 'no'
    enable_cpu_affinity: false
    gpu_ids: 0,1
    machine_rank: 0
    main_training_function: main
    mixed_precision: 'no'
    num_machines: 1
    num_processes: 2
    rdzv_backend: static
    same_network: true
    tpu_env: []
    tpu_use_cluster: false
    tpu_use_sudo: false
    use_cpu: false

Once the configuration file is generated, the LLM entrypoint can be run in a distributed fashion as follows, where ``${configFolder}`` and ``${workspaceFolder}`` are placeholders for the directory containing the YAML configuration files and the root of the Brevitas repository, respectively:

.. code-block:: shell

    accelerate launch --config_file ${configFolder}/accelerate_config.yaml ${workspaceFolder}/src/brevitas_examples/llm/main.py --config ${configFolder}/experiment_config.yaml

Caveats
----------------------------------------------------------

Currently, we only support distributed training using ``DistributedDataParallel``, and we plan to provide support for ``DeepSpeed`` and ``FullyShardedDataParallel`` in the future.

.. rubric:: References

.. [1] Ashkboos, S., Croci, M. L., Nascimento, M. G. D., Hoefler, T., & Hensman, J. (2024). SliceGPT: Compress large language models by deleting rows and columns. arXiv preprint arXiv:2401.15024.
.. [2] Ashkboos, S., Mohtashami, A., Croci, M., Li, B., Cameron, P., Jaggi, M., ... & Hensman, J. (2025). QuaRot: Outlier-free 4-bit inference in rotated LLMs. Advances in Neural Information Processing Systems, 37, 100213-100240.
.. [3] Tseng, A., Chee, J., Sun, Q., Kuleshov, V., & De Sa, C. (2024). QuIP#: Even better LLM quantization with Hadamard incoherence and lattice codebooks. arXiv preprint arXiv:2402.04396.
.. [4] Liu, Z., Zhao, C., Fedorov, I., Soran, B., Choudhary, D., Krishnamoorthi, R., ... & Blankevoort, T. (2024). SpinQuant: LLM quantization with learned rotations. arXiv preprint arXiv:2405.16406.
.. [5] Li, J., Fuxin, L., & Todorovic, S. (2020). Efficient Riemannian optimization on the Stiefel manifold via the Cayley transform. arXiv preprint arXiv:2002.01113.
8 changes: 6 additions & 2 deletions src/brevitas/graph/equalize.py
@@ -1585,6 +1585,7 @@ def __init__(
sdpa_regions: bool = False,
rotate_matmul: bool = False,
use_parametrized_rotations: bool = False,
apply_inplace_rotations: bool = True,
full_rotation_method: str = 'had',
layers_to_expand: Optional[List[str]] = None,
return_rewriters: bool = False) -> None:
@@ -1613,6 +1614,7 @@ def __init__(
"Using parametrized results might break type-checking, which could lead to unexpected behaviour."
)
self.use_parametrized_rotations = use_parametrized_rotations
self.apply_inplace_rotations = apply_inplace_rotations

def rotate_matmuls(self, graph_module):
matmul_nodes = list(graph_module.graph.nodes)
@@ -1737,13 +1739,15 @@ def apply(self,
graph_model,
first_set,
self.full_rotation_method,
fuse_rotations=not self.use_parametrized_rotations))
fuse_rotations=not self.use_parametrized_rotations,
apply_inplace_rotations=self.apply_inplace_rotations))
rewriters.extend(
_apply_rotate(
graph_model,
second_set,
self.full_rotation_method,
fuse_rotations=not self.use_parametrized_rotations))
fuse_rotations=not self.use_parametrized_rotations,
apply_inplace_rotations=self.apply_inplace_rotations))
if len(expanded_regions) > 0:
parameter_number_post = 0
for m in graph_model.parameters():
44 changes: 44 additions & 0 deletions src/brevitas_examples/llm/config/optimized_rotation_template.yml
@@ -0,0 +1,44 @@
convert_layernorm_to_rmsnorm: false
dataset: wikitext2
dtype: float32
eval: true
# Input quantization parameters
input_bit_width: 4
input_group_size: 32
input_param_method: stats
input_quant_format: int
input_quant_granularity: per_row
input_quant_type: asym
input_scale_precision: float_scale
input_scale_type: dynamic
# Model to quantize
model: HuggingfaceTB/SmolLM2-135M
# Rotation-related parameters
optimize_rotations: true
replace_rmsnorm: true
rotation: fused_no_fx
rotation_sdpa_regions: true
rotation_mode: had
rotation_orphan_sink: true
# Weight quantization parameters
weight_bit_width: 4
weight_equalization: false
weight_group_dim: null
weight_group_size: null
weight_param_method: mse
weight_quant_format: int
weight_quant_granularity: per_channel
weight_quant_type: sym
weight_scale_precision: float_scale
# HuggingFace TrainerArguments
learning_rate: 1.5
weight_decay: 0.0
lr_scheduler_type: cosine
max_steps: 100
save_safetensors: false
per_device_train_batch_size: 4
logging_steps: 10
gradient_accumulation_steps: 2
log_on_each_node: false
torch_empty_cache_steps: 1
gradient_checkpointing: true
65 changes: 65 additions & 0 deletions src/brevitas_examples/llm/llm_quant/distributed_utils.py
@@ -0,0 +1,65 @@
# Copyright (C) 2024, Advanced Micro Devices, Inc. All rights reserved.
# SPDX-License-Identifier: BSD-3-Clause

from argparse import Namespace
import functools
import logging
import os
from typing import Callable, List

import torch

from brevitas_examples.common.accelerate_utils.accelerate import offload_model
from brevitas_examples.common.accelerate_utils.accelerate import remove_hooks


# If the environment variable 'LOCAL_RANK' is not set, a single
# process is running, so os.environ.get('LOCAL_RANK', -1) returns
# -1.
def is_multi_process():
    return int(os.environ.get('LOCAL_RANK', -1)) != -1


def is_main_process():
    return int(os.environ.get('LOCAL_RANK', -1)) in [-1, 0]


def on_process(func: Callable, process_index: int):

    @functools.wraps(func)
    def _wrapper(model, *args, **kwargs):
        curr_process_index = int(os.environ.get('LOCAL_RANK', -1))
        if curr_process_index == -1:
            logging.debug(f"Applying {func.__name__} on main process")
            return func(model, *args, **kwargs)
        elif process_index == curr_process_index:
            logging.debug(f"Applying {func.__name__} on process index {curr_process_index}")
            return func(model, *args, **kwargs)
        else:
            logging.debug(
                f"Skipping function {func.__name__} on process index {curr_process_index}")
            return model

    return _wrapper


on_main_process = functools.partial(on_process, process_index=0)


def validate_distributed_args(args: Namespace) -> None:
    assert args.optimize_rotations, "The entry-point should be run as a single-process if rotations are not being optimized."


class dist_offload_model:

    def __init__(self, model: torch.nn.Module) -> None:
        self.model = model

    def __enter__(self):
        if is_main_process():
            self.model = offload_model(self.model)

    def __exit__(self, type, value, traceback):
        if is_main_process():
            remove_hooks(self.model)
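
# ---------------------------------------------------------------------------
# Usage sketch appended for illustration; not part of distributed_utils.py.
# `run_calibration` and `calib_loader` are hypothetical names; the decorator
# and the context manager are the helpers defined above.
# ---------------------------------------------------------------------------
from brevitas_examples.llm.llm_quant.distributed_utils import dist_offload_model, on_main_process


@on_main_process
def run_calibration(model, calib_loader):
    # Runs on the main process only; other ranks get the model back unchanged.
    for batch in calib_loader:
        model(**batch)
    return model


def calibrate(model, calib_loader):
    with dist_offload_model(model):
        # The model is offloaded on the main process and the accelerate hooks
        # are removed when the block exits.
        model = run_calibration(model, calib_loader)
    return model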