34 changes: 34 additions & 0 deletions src/brevitas_examples/common/generative/quant_blocks.py
@@ -5,12 +5,17 @@

from typing import Callable

from dependencies import this
from dependencies import value
from torch import Tensor
import torch.nn as nn

from brevitas.core.restrict_val import _RestrictClampValue
from brevitas.core.stats.stats_wrapper import SCALAR_SHAPE
from brevitas.core.zero_point import _ScaleShiftZeroPoint
from brevitas.function.ops_ste import abs_binary_sign_grad
from brevitas.inject import ExtendedInjector
from brevitas.inject.enum import ScalingPerOutputType


# TODO: restore JIT compatibility
@@ -74,3 +79,32 @@ def forward(self, x, scale, bit_width) -> Tensor:
x = abs_binary_sign_grad(x)
x = self.scale_shift_zero_point(x, scale, bit_width)
return x


class QuantScaleScaleShapeMixin(ExtendedInjector):

@value
def scaling_shape(
scaling_per_output,
scaling_per_output_channel_shape,
expanded_groupwise_shape,
group_dim,
upstream_scaling):
if scaling_per_output == ScalingPerOutputType.TENSOR:
scaling = SCALAR_SHAPE
elif scaling_per_output == ScalingPerOutputType.CHANNEL:
scaling = scaling_per_output_channel_shape
elif scaling_per_output == ScalingPerOutputType.GROUP:
# Scaling shape is like expanded_groupwise_shape but has 1 in position group_dim + 1
assert expanded_groupwise_shape is not None, "Per Group scaling not correctly configured"
assert group_dim is not None, "Per Group scaling not correctly configured"
size = list(expanded_groupwise_shape)
size[group_dim + 1] = 1
scaling = tuple(size)

        # When the scale of a groupwise quantizer is itself quantized, there is one extra dim compared to the normal case
if upstream_scaling == ScalingPerOutputType.GROUP:
scaling = list(scaling)
scaling.insert(-1, 1)
scaling = tuple(scaling)
return scaling
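
To make the shape arithmetic above concrete, here is a minimal standalone sketch of the same resolution logic. It runs without the Brevitas injector machinery; the enum, SCALAR_SHAPE, and the function name resolve_scaling_shape are stand-ins for illustration.

from enum import Enum

class ScalingPerOutputType(Enum):  # stand-in for brevitas.inject.enum
    TENSOR = "tensor"
    CHANNEL = "channel"
    GROUP = "group"

SCALAR_SHAPE = ()  # a per-tensor scale is a 0-dim tensor

def resolve_scaling_shape(scaling_per_output, per_channel_shape,
                          expanded_groupwise_shape, group_dim, upstream_scaling):
    if scaling_per_output == ScalingPerOutputType.TENSOR:
        scaling = SCALAR_SHAPE
    elif scaling_per_output == ScalingPerOutputType.CHANNEL:
        scaling = per_channel_shape
    elif scaling_per_output == ScalingPerOutputType.GROUP:
        size = list(expanded_groupwise_shape)
        size[group_dim + 1] = 1  # one scale per group
        scaling = tuple(size)
    # A quantized scale consumed by a groupwise upstream quantizer carries one extra unit dim
    if upstream_scaling == ScalingPerOutputType.GROUP:
        scaling = list(scaling)
        scaling.insert(-1, 1)
        scaling = tuple(scaling)
    return scaling

# A (64, 128) weight with group size 16 on dim 1 expands to (64, 8, 16),
# so its scales live in shape (64, 8, 1); the groupwise upstream adds a dim.
print(resolve_scaling_shape(
    ScalingPerOutputType.GROUP, None, (64, 8, 16), 1,
    ScalingPerOutputType.GROUP))  # -> (64, 8, 1, 1)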
83 changes: 78 additions & 5 deletions src/brevitas_examples/common/generative/quantize.py
@@ -60,6 +60,7 @@
from brevitas.quant.shifted_scaled_int import ShiftedUint8WeightPerTensorFloatMSE
from brevitas_examples.common.generative.nn import LoRACompatibleQuantConv2d
from brevitas_examples.common.generative.nn import LoRACompatibleQuantLinear
from brevitas_examples.common.generative.quantizers import DynamicQuantScaleMXFloat8e4m3Act
from brevitas_examples.common.generative.quantizers import Fp8e4m3DynamicActPerGroupFloat
from brevitas_examples.common.generative.quantizers import FP8e4m3FNUZDynamicActPerRowFloat
from brevitas_examples.common.generative.quantizers import Fp8e4m3FNUZDynamicActPerTensorFloat
@@ -75,13 +76,24 @@
from brevitas_examples.common.generative.quantizers import Int8DynamicActPerRowFloat
from brevitas_examples.common.generative.quantizers import Int8DynamicActPerTensorFloat
from brevitas_examples.common.generative.quantizers import IntWeightSymmetricGroupQuant
from brevitas_examples.common.generative.quantizers import QuantScaleMXFloat8e4m3Weight
from brevitas_examples.common.generative.quantizers import QuantScaleMXFloat8e4m3WeightMSE
from brevitas_examples.common.generative.quantizers import QuantScaleIntWeightSymmetricGroupQuant
from brevitas_examples.common.generative.quantizers import QuantScaleIntWeightSymmetricGroupQuantMSE
from brevitas_examples.common.generative.quantizers import RuntimeDynamicStatsZeroPoint
from brevitas_examples.common.generative.quantizers import ShiftedUint8DynamicActPerGroupFloat
from brevitas_examples.common.generative.quantizers import ShiftedUint8DynamicActPerRowFloat
from brevitas_examples.common.generative.quantizers import ShiftedUint8DynamicActPerTensorFloat

WEIGHT_QUANT_MAP = {
'int': {
'fixed_quant_scale': {
'stats': {
'per_group': {
'sym': QuantScaleIntWeightSymmetricGroupQuant}},
'mse': {
'per_group': {
'sym': QuantScaleIntWeightSymmetricGroupQuantMSE}}},
'float_scale': {
'stats': {
'per_tensor': {
@@ -143,6 +155,13 @@
'mse': {
'per_channel': {
'sym': Fp8e4m3OCPWeightPerChannelFloatMSE}}},
'float_quant_scale': {
'stats': {
'per_group': {
'sym': QuantScaleMXFloat8e4m3Weight}},
'mse': {
'per_group': {
'sym': QuantScaleMXFloat8e4m3WeightMSE}}},
'po2_scale': {
'stats': {
'per_group': {
@@ -222,6 +241,10 @@
'sym': FP8e4m3OCPDynamicActPerRowFloat},
'per_group': {
'sym': Fp8e4m3OCPDynamicActPerGroupFloat}}},
'float_quant_scale': {
'stats': {
'per_group': {
'sym': DynamicQuantScaleMXFloat8e4m3Act}}},
'po2_scale': {
'stats': {
'per_row': {
@@ -277,20 +300,41 @@ def generate_quantizers(
"""
    # Retrieve the base input and weight quantizers
# match against custom float format
    fpre = re.compile(r'e[1-8]m[1-8]')
    if fpre.findall(weight_quant_format):
        format = fpre.findall(weight_quant_format)[0]
weight_quant_format = weight_quant_format.replace('_' + format, '')
weight_float_format = {
'exponent_bit_width': int(format[1]), 'mantissa_bit_width': int(format[3])}
else:
weight_float_format = {}
    if fpre.findall(input_quant_format):
        format = fpre.findall(input_quant_format)[0]
input_quant_format = input_quant_format.replace('_' + format, '')
input_float_format = {
'exponent_bit_width': int(format[1]), 'mantissa_bit_width': int(format[3])}
else:
input_float_format = {}
if fpre.findall(weight_scale_precision):
format = fpre.findall(weight_scale_precision)[0]
        weight_scale_precision_format = {
'exponent_bit_width': int(format[1]),
'mantissa_bit_width': int(format[3]),
'bit_width': int(format[1]) + int(format[3]) + 1}
weight_scale_precision = "float_quant_scale"
else:
weight_scale_precision_format = {}
if fpre.findall(input_scale_precision):
format = fpre.findall(input_scale_precision)[0]
        input_scale_precision_format = {
'exponent_bit_width': int(format[1]),
'mantissa_bit_width': int(format[3]),
'bit_width': int(format[1]) + int(format[3]) + 1}
input_scale_precision = "float_quant_scale"
else:
input_scale_precision_format = {}
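
For concreteness, a standalone sketch of what this parsing yields for the "e4m3_scale" value used in the tests below (same regex; variable names are illustrative):

import re

fpre = re.compile(r'e[1-8]m[1-8]')
scale_precision = "e4m3_scale"
fmt = fpre.findall(scale_precision)[0]           # 'e4m3'
precision_format = {
    'exponent_bit_width': int(fmt[1]),           # 4
    'mantissa_bit_width': int(fmt[3]),           # 3
    'bit_width': int(fmt[1]) + int(fmt[3]) + 1}  # 8 = 4 + 3 + sign
print(precision_format)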

weight_quant = WEIGHT_QUANT_MAP[weight_quant_format][weight_scale_precision][
weight_param_method][weight_quant_granularity][weight_quant_type]
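
The lookup itself is a plain nested-dict access; a minimal sketch of the resolution path for the new MX entry (the 'float_ocp' top-level key is an assumption inferred from the CLI format names, and the class name stands in as a string):

WEIGHT_QUANT_MAP_SKETCH = {
    'float_ocp': {
        'float_quant_scale': {
            'stats': {
                'per_group': {
                    'sym': 'QuantScaleMXFloat8e4m3Weight'}}}}}

weight_quant = WEIGHT_QUANT_MAP_SKETCH['float_ocp']['float_quant_scale'][
    'stats']['per_group']['sym']
print(weight_quant)  # QuantScaleMXFloat8e4m3Weight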
@@ -345,7 +389,30 @@ def generate_quantizers(
q_scaled_quant = q_scaled_quant.let(**input_kwargs) if q_scaled_quant is not None else None
attn_output_weights_quant = attn_output_weights_quant.let(
**input_kwargs) if attn_output_weights_quant is not None else None

if input_scale_precision == "float_quant_scale":
# Set the format of the input's quantized scale
input_quant = input_quant.let(
scaling_float_quant=input_quant.scaling_float_quant.let(
**input_scale_precision_format))
sym_input_quant = sym_input_quant.let(
scaling_float_quant=sym_input_quant.scaling_float_quant.let(
**input_scale_precision_format))
linear_input_quant = linear_input_quant.let(
scaling_float_quant=linear_input_quant.scaling_float_quant.let(
**input_scale_precision_format))
v_quant = v_quant.let(
scaling_float_quant=v_quant.scaling_float_quant.let(**input_scale_precision_format))
k_transposed_quant = k_transposed_quant.let(
scaling_float_quant=k_transposed_quant.scaling_float_quant.let(
**input_scale_precision_format))
if q_scaled_quant is not None:
q_scaled_quant = q_scaled_quant.let(
scaling_float_quant=q_scaled_quant.scaling_float_quant.let(
**input_scale_precision_format))
if attn_output_weights_quant is not None:
attn_output_weights_quant = attn_output_weights_quant.let(
scaling_float_quant=attn_output_weights_quant.scaling_float_quant.let(
**input_scale_precision_format))
else:
input_quant = None
sym_input_quant = None
@@ -385,6 +452,12 @@ def generate_quantizers(
if weight_quant_type == 'asym' and weight_scaling_impl_type == 'parameter_from_stats':
weight_quant = weight_quant.let(zero_point_impl=ParameterFromStatsFromParameterZeroPoint)

# Set the format of the weight's quantized scale
if weight_scale_precision == "float_quant_scale":
weight_quant = weight_quant.let(
scaling_float_quant=weight_quant.scaling_float_quant.let(
**weight_scale_precision_format))

if quant_attn_mode == 'sdpa':
kv_permute_dims = (0, 1, 3, 2)
kv_broadcastable_shape_lambda = lambda x, shape: x.view(shape[0], shape[1], 1, shape[-1])
79 changes: 79 additions & 0 deletions src/brevitas_examples/common/generative/quantizers.py
@@ -8,6 +8,9 @@
from brevitas.core.function_wrapper.ops_ste import FloorSte
from brevitas.core.function_wrapper.shape import OverOutputFeaturesView
from brevitas.core.function_wrapper.shape import OverTensorView
from brevitas.core.quant.float import FloatQuant
from brevitas.core.quant import RescalingIntQuant
from brevitas.core.restrict_val import QuantRestrictValue
from brevitas.core.scaling.runtime import RuntimeDynamicGroupStatsScaling
from brevitas.core.stats import AbsMinMax
from brevitas.core.stats import NegativeMinOrZero
@@ -35,9 +38,14 @@
from brevitas.quant.experimental.float_quant_fnuz import Fp8e4m3FNUZActPerTensorFloat
from brevitas.quant.experimental.float_quant_ocp import Fp8e4m3OCPActPerTensorFloat
from brevitas.quant.experimental.float_quant_ocp import Fp8e4m3OCPWeightPerChannelFloat
from brevitas.quant.experimental.float_quant_ocp import Fp8e4m3OCPWeightPerTensorFloat
from brevitas.quant.experimental.mx_quant_ocp import MXFloat8e4m3Act
from brevitas.quant.experimental.mx_quant_ocp import MXFloat8e4m3Weight
from brevitas.quant.fixed_point import Uint8ActPerTensorFixedPoint
from brevitas.quant.scaled_int import Int8ActPerTensorFloat
from brevitas.quant.scaled_int import Int8WeightPerChannelFloat
from brevitas.quant.scaled_int import Int8WeightPerChannelFloatHQO
from brevitas.quant.scaled_int import Int8WeightPerTensorFloat
from brevitas.quant.shifted_scaled_int import ShiftedUint8ActPerTensorFloat
from brevitas.quant.shifted_scaled_int import ShiftedUint8WeightPerChannelFloat

@@ -213,3 +221,74 @@ class FP8e4m3FNUZDynamicActPerRowFloat(Fp8e4m3FNUZActPerTensorFloat):
scaling_stats_op = 'min_max'
scaling_per_output_channel = True
proxy_class = DynamicActFloatQuantProxyFromInjector


class DynamicQuantScalingFloat(QuantScaleScaleShapeMixin,
DynamicActProxyMixin,
Fp8e4m3OCPActPerTensorFloat):
module = (this << 1).module
upstream_scaling = (this << 1).scaling_per_output_type
float_quant = FloatQuant
scaling_impl = RuntimeDynamicStatsScaling
scaling_stats_input_view_shape_impl = OverTensorView
scaling_stats_op = 'min_max'
dynamic_scaling_broadcastable_fn = lambda x, shape: x.view(SCALAR_SHAPE)


class DynamicQuantScaleMXFloat8e4m3Act(MXFloat8e4m3Act):
scaling_float_quant = DynamicQuantScalingFloat
restrict_scaling_impl = QuantRestrictValue

@value
def restrict_value_float_to_int_impl():
return this.scaling_float_quant.float_quant
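
The @value / this idiom above comes from the dependencies package that ExtendedInjector builds on: attributes are resolved lazily by name from the same injector. A minimal sketch with the plain library (the attribute names here are illustrative):

from dependencies import Injector, value

class FloatFormat(Injector):
    exponent_bit_width = 4
    mantissa_bit_width = 3

    @value
    def bit_width(exponent_bit_width, mantissa_bit_width):
        # Computed from sibling attributes, much like restrict_value_float_to_int_impl above
        return exponent_bit_width + mantissa_bit_width + 1

print(FloatFormat.bit_width)  # 8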


class QuantScalingFloat(QuantScaleScaleShapeMixin, Fp8e4m3OCPWeightPerTensorFloat):
module = (this << 1).module
tracked_parameter_list = (this << 1).tracked_parameter_list
upstream_scaling = (this << 1).scaling_per_output_type
float_quant = FloatQuant


class QuantScaleMXFloat8e4m3Weight(MXFloat8e4m3Weight):
scaling_float_quant = QuantScalingFloat
restrict_scaling_impl = QuantRestrictValue

@value
def restrict_value_float_to_int_impl():
return this.scaling_float_quant.float_quant


class QuantScaleMXFloat8e4m3WeightMSE(MSESymmetricScale, MXFloat8e4m3Weight):
scaling_float_quant = QuantScalingFloat
restrict_scaling_impl = QuantRestrictValue

@value
def restrict_value_float_to_int_impl():
return this.scaling_float_quant.float_quant


class QuantWeightScalingFixed(QuantScaleScaleShapeMixin, Int8WeightPerTensorFloat):
module = (this << 1).module
upstream_scaling = (this << 1).scaling_per_output_type
tracked_parameter_list = (this << 1).tracked_parameter_list
signed = False


class QuantScaleIntWeightSymmetricGroupQuant(IntWeightSymmetricGroupQuant):
scaling_int_quant = QuantWeightScalingFixed
restrict_scaling_impl = QuantRestrictValue

@value
def restrict_value_float_to_int_impl():
return this.scaling_int_quant.tensor_quant


class QuantScaleIntWeightSymmetricGroupQuantMSE(MSESymmetricScale, IntWeightSymmetricGroupQuant):
scaling_int_quant = QuantWeightScalingFixed
restrict_scaling_impl = QuantRestrictValue

@value
def restrict_value_float_to_int_impl():
return this.scaling_int_quant.tensor_quant
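
These classes are consumed like any other Brevitas weight quantizer, normally wired up through generate_quantizers. As a hedged usage sketch with the stock MX quantizer they extend (the QuantScale* variants plug into weight_quant= the same way, assuming their extra injector dependencies resolve outside that flow):

import brevitas.nn as qnn
from brevitas.quant.experimental.mx_quant_ocp import MXFloat8e4m3Weight

# 128 input features with the MX group size of 32 gives 4 groups per output row
layer = qnn.QuantLinear(128, 64, bias=False, weight_quant=MXFloat8e4m3Weight)
qw = layer.quant_weight()  # groupwise quant tensor produced by the quantizer
print(qw.scale.shape)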
32 changes: 32 additions & 0 deletions src/brevitas_examples/llm/benchmark/test_scale_format.py
@@ -0,0 +1,32 @@

import brevitas.nn as qnn
from brevitas_examples.common.generative.quantizers import Int8DynamicActPerTensorFloat


class Uint8DynamicActPerTensorFloat(Int8DynamicActPerTensorFloat):
    signed = False
    narrow_range = True


class Uint7DynamicActPerTensorFloat(Uint8DynamicActPerTensorFloat):
    bit_width = 7


def test_scale_quant(model):
    # Scales already quantized to 8 bits should be reproduced exactly by the
    # unsigned 8-bit quantizer, but not by the 7-bit one.
    uint8 = qnn.QuantIdentity(act_quant=Uint8DynamicActPerTensorFloat)
    uint7 = qnn.QuantIdentity(act_quant=Uint7DynamicActPerTensorFloat)
    layers_tested = 0
    layers_passed = 0
    layers_failed = 0
    for name, module in model.named_modules():
        if isinstance(module, qnn.QuantLinear):
            try:
                weight_scale = module.quant_weight().scale
                uint8.to(device=weight_scale.device)
                uint7.to(device=weight_scale.device)
                assert (weight_scale == uint8(weight_scale)).all()
                assert not (weight_scale == uint7(weight_scale)).all()
                layers_passed += 1
            except AssertionError:
                layers_failed += 1
            layers_tested += 1
    print(f"Layers passed: {layers_passed}, Layers failed: {layers_failed}, Layers tested: {layers_tested}")
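
The assertions rely on a simple round-trip property: quantizing a value that already lies on the quantizer's grid is the identity. A self-contained illustration of that property (uint8-style affine quantization, with values chosen to sit on the grid):

import torch

x = torch.tensor([0.5, 1.0, 2.0])
scale = 0.5
q = torch.clamp(torch.round(x / scale), 0, 255) * scale
assert torch.equal(q, x)  # on-grid values come back unchanged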
3 changes: 3 additions & 0 deletions src/brevitas_examples/llm/main.py
@@ -63,6 +63,7 @@
from brevitas_examples.llm.llm_quant.run_utils import get_fx
from brevitas_examples.llm.llm_quant.svd_quant import apply_svd_quant

from brevitas_examples.llm.benchmark.test_scale_format import test_scale_quant

def filter_results(results, tasks):
# filter out what we actually want to track
@@ -555,6 +556,8 @@ def quantize_llm(args, extra_args=None):
for k, v in dict_hooks.items():
k._hf_hook.post_forward = v

test_scale_quant(model)

if args.eval and not args.no_quantize:

print("Model eval...")
34 changes: 32 additions & 2 deletions tests/brevitas_examples/test_llm.py
@@ -332,7 +332,8 @@ def test_small_models_acc(caplog, acc_args_and_acc):
"mistral-int8-quant-last-layer",
"llama-int8-svd_quant",
"opt-replace-mha",
"opt-quant-sdpa",],
"opt-quant-sdpa",
"llama-mxfp4-quant-scale",],
params=[
{
"model": "hf-internal-testing/tiny-random-MistralForCausalLM",
@@ -449,7 +450,36 @@ def test_small_models_acc(caplog, acc_args_and_acc):
"quant_sdpa": True,
"exp_layer_types": {
"scaled_dot_product_attention":
"<class 'brevitas.nn.quant_sdpa.QuantScaledDotProductAttention'>",}},])
"<class 'brevitas.nn.quant_sdpa.QuantScaledDotProductAttention'>",}},
{
"model": "hf-internal-testing/tiny-random-LlamaForCausalLM",
"weight_bit_width": 4,
"weight_quant_format": "float_ocp_e2m1",
"weight_scale_precision": "e4m3_scale",
"weight_param_method": "stats",
"weight_quant_granularity": "per_group",
"weight_group_size": 16,
"weight_quant_type": "sym",
"input_bit_width": 4,
"input_quant_format": "float_ocp_e2m1",
"input_scale_type": "dynamic",
"input_scale_precision": "e4m3_scale",
"input_param_method": "stats",
"input_quant_granularity": "per_group",
"input_group_size": 16,
"input_quant_type": "sym",
"act_calibration": False,
"exp_layer_types": {
"model.layers.0.self_attn.q_proj.input_quant.fused_activation_quant_proxy.tensor_quant.scaling_impl.restrict_clamp_scaling.restrict_value_impl":
"<class 'brevitas.core.restrict_val.QuantRestrictValue'>",
"model.layers.0.self_attn.q_proj.input_quant.fused_activation_quant_proxy.tensor_quant.scaling_impl.restrict_clamp_scaling.restrict_value_impl.float_to_int_impl":
"<class 'brevitas.core.quant.float.FloatQuant'>",
"model.layers.0.self_attn.q_proj.weight_quant.tensor_quant.scaling_impl.stats_scaling_impl.restrict_clamp_scaling.restrict_value_impl":
"<class 'brevitas.core.restrict_val.QuantRestrictValue'>",
"model.layers.0.self_attn.q_proj.weight_quant.tensor_quant.scaling_impl.stats_scaling_impl.restrict_clamp_scaling.restrict_value_impl.float_to_int_impl":
"<class 'brevitas.core.quant.float.FloatQuant'>",},
}, # MX weights/activations with minifloat-quantized scales
])
def layer_args(default_run_args, request):
args = default_run_args
layer_dict = request.param