[bug] PerGroup(1) throws RuntimeError #3458

@Freed-Wu

Description

import torch
from torchao.quantization import IntxWeightOnlyConfig, quantize_
from torchao.quantization.quant_primitives import MappingType
from torchao.quantization.granularity import PerGroup


class ToyLinearModel(torch.nn.Module):
    def __init__(self, m: int, n: int, k: int):
        super().__init__()
        self.linear1 = torch.nn.Linear(m, n, bias=False)
        self.linear2 = torch.nn.Linear(n, k, bias=False)

    def forward(self, x):
        x = self.linear1(x)
        x = self.linear2(x)
        return x


model = ToyLinearModel(32, 32, 32).eval()

# Optional: compile model for faster inference and generation
# model = torch.compile(model, mode="max-autotune", fullgraph=True)
# model_bf16 = copy.deepcopy(model)
config = IntxWeightOnlyConfig(torch.int4, PerGroup(1), mapping_type=MappingType.ASYMMETRIC)
quantize_(model, config)
inp = torch.ones(2, 32)
model(inp)

$ python a.py
Traceback (most recent call last):
  File "/home/wzy/Desktop/ao/a.py", line 25, in <module>
    quantize_(model, config)
    ~~~~~~~~~^^^^^^^^^^^^^^^
  File "/home/wzy/Desktop/ao/torchao/quantization/quant_api.py", line 498, in quantize_
    _replace_with_custom_fn_if_matches_filter(
    ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~^
        model,
        ^^^^^^
    ...<3 lines>...
        extra_args=(config,),
        ^^^^^^^^^^^^^^^^^^^^^
    )
    ^
  File "/home/wzy/Desktop/ao/torchao/quantization/quant_api.py", line 214, in _replace_with_custom_fn_if_matches_filter
    new_child = _replace_with_custom_fn_if_matches_filter(
        child,
    ...<4 lines>...
        extra_args,
    )
  File "/home/wzy/Desktop/ao/torchao/quantization/quant_api.py", line 209, in _replace_with_custom_fn_if_matches_filter
    model = replacement_fn(model, *extra_args)
  File "/home/wzy/Desktop/ao/torchao/quantization/quant_api.py", line 2375, in _intx_weight_only_transform
    new_weight = _intx_weight_only_quantize_tensor(
        module.weight,
    ...<2 lines>...
        custom_zero_point=custom_zero_point,
    )
  File "/home/wzy/Desktop/ao/torchao/quantization/quant_api.py", line 2320, in _intx_weight_only_quantize_tensor
    new_weight = IntxUnpackedToInt8Tensor.from_hp(
        weight,
    ...<5 lines>...
        intx_choose_qparams_algorithm=intx_choose_qparams_algorithm,
    )
  File "/home/wzy/Desktop/ao/torchao/quantization/quantize_/workflows/intx/intx_unpacked_to_int8_tensor.py", line 233, in from_hp
    qdata = quantize_affine(
        hp_tensor,
    ...<5 lines>...
        quant_max=qmax,
    )
  File "/usr/lib/python3.13/site-packages/torch/utils/_contextlib.py", line 116, in decorate_context
    return func(*args, **kwargs)
  File "/home/wzy/Desktop/ao/torchao/quantization/quant_primitives.py", line 357, in quantize_affine
    return _quantize_affine(
        input,
    ...<5 lines>...
        quant_max,
    )
  File "/usr/lib/python3.13/site-packages/torch/_ops.py", line 1158, in __call__
    return self._op(*args, **(kwargs or {}))
           ~~~~~~~~^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/wzy/Desktop/ao/torchao/quantization/quant_primitives.py", line 403, in _quantize_affine
    return _quantize_affine_no_dtype_cast(
           ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~^
        input,
        ^^^^^^
    ...<4 lines>...
        quant_max,
        ^^^^^^^^^^
    ).to(output_dtype)
    ^
  File "/home/wzy/Desktop/ao/torchao/quantization/quant_primitives.py", line 460, in _quantize_affine_no_dtype_cast
    scale = scale.view(shape_after_reduction)
RuntimeError: shape '[32, 32]' is invalid for input of size 1
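
The final frame fails in scale.view(shape_after_reduction): for a (32, 32) weight with group_size == 1, every element is its own group, so the view expects 32 x 32 scale values, yet the scale tensor that reaches quantize_affine apparently has a single element. A minimal pure-PyTorch sketch of that mismatch (the size-1 scale is inferred from the error message, not verified against torchao internals):

import torch

weight_shape = (32, 32)   # shape of each ToyLinearModel weight
scale = torch.ones(1)     # a single scale, as the error message implies

# quantize_affine reshapes the scale to one value per group; with
# group_size == 1 the per-group shape equals the full weight shape:
scale.view(weight_shape)  # RuntimeError: shape '[32, 32]' is invalid for input of size 1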

Other PerGroup(x) group sizes work; only PerGroup(1) fails.
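
A quick sanity check, assuming the imports and ToyLinearModel definition from the repro above, showing that larger group sizes that evenly divide the 32-wide weight axis go through (the group sizes here are arbitrary illustrative choices):

import torch
from torchao.quantization import IntxWeightOnlyConfig, quantize_
from torchao.quantization.quant_primitives import MappingType
from torchao.quantization.granularity import PerGroup

for group_size in (2, 4, 8, 16, 32):
    m = ToyLinearModel(32, 32, 32).eval()
    cfg = IntxWeightOnlyConfig(
        torch.int4, PerGroup(group_size), mapping_type=MappingType.ASYMMETRIC
    )
    quantize_(m, cfg)        # succeeds; only PerGroup(1) raises the RuntimeError
    m(torch.ones(2, 32))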
