Description
Issue Type
Bug
Source
source
MCT Version
2.3.0
OS Platform and Distribution
No response
Python version
3.11
Describe the issue
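MCT quantization crashes in ModelCollector during statistics collection when an activation tensor is a scalar (a single-element, 0-dimensional tensor). In the example below this is presumably the output of torch.exp(self.scale), since self.scale has shape [].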
Expected behaviour
MCT quantization should complete without errors.
Code to reproduce the issue
import numpy as np
import model_compression_toolkit as mct
import torch
import torch.nn as nn
import edgemdt_tpc


class Model(nn.Module):
    def __init__(self):
        super(Model, self).__init__()
        self.conv = nn.Conv2d(3, 3, kernel_size=3, padding=1)
        self.bias = nn.Parameter(torch.tensor([-1.0]))  # 1-D parameter, shape [1]
        self.scale = nn.Parameter(-1.0 * torch.ones([]))  # 0-dim (scalar) parameter, shape []

    def forward(self, x):
        x = self.conv(x)
        # torch.exp(self.scale) produces a 0-dim tensor
        y = torch.exp(self.scale) * x + self.bias
        return y


float_model = Model()
# a = float_model(torch.randn((1, 3, 224, 224)))  # optional float forward-pass check


def representative_data_gen():
    yield [np.random.random((1, 3, 224, 224))]


tpc = edgemdt_tpc.get_target_platform_capabilities("4.0")
quantized_model, _ = mct.ptq.pytorch_post_training_quantization(float_model,
                                                                representative_data_gen=representative_data_gen,
                                                                target_platform_capabilities=tpc)
Log output
2025-05-26 09:53:54.049632: I tensorflow/core/util/port.cc:113] oneDNN custom operations are on. You may see slightly different numerical results due to floating-point round-off errors from different computation orders. To turn them off, set the environment variable TF_ENABLE_ONEDNN_OPTS=0.
WARNING:tensorflow:From C:\Users\1000322838\PycharmProjects\model_optimization\venv\lib\site-packages\keras\src\losses.py:2976: The name tf.losses.sparse_softmax_cross_entropy is deprecated. Please use tf.compat.v1.losses.sparse_softmax_cross_entropy instead.
C:\Users\1000322838\PycharmProjects\model_optimization\venv\lib\site-packages\torchvision\io\image.py:13: UserWarning: Failed to load image Python extension: '[WinError 127] The specified procedure could not be found'If you don't plan on using image functionality from torchvision.io, you can ignore this warning. Otherwise, there might be something wrong with your environment. Did you have libjpeg or libpng installed before building torchvision from source?
warn(
representative_data_gen generates a batch size of 1 which can be slow for optimization: consider increasing the batch size
WARNING:Model Compression Toolkit:DepthwiseConv2D is not in model.
Statistics Collection: 0it [00:00, ?it/s]
Traceback (most recent call last):
  File "C:\Users\1000322838\PycharmProjects\model_optimization\test2.py", line 29, in <module>
    quantized_model, _ = mct.ptq.pytorch_post_training_quantization(float_model,
  File "C:\Users\1000322838\PycharmProjects\model_optimization\model_compression_toolkit\ptq\pytorch\quantization_facade.py", line 125, in pytorch_post_training_quantization
    tg, bit_widths_config, _, scheduling_info = core_runner(in_model=in_module,
  File "C:\Users\1000322838\PycharmProjects\model_optimization\model_compression_toolkit\core\runner.py", line 112, in core_runner
    tg = quantization_preparation_runner(graph=graph,
  File "C:\Users\1000322838\PycharmProjects\model_optimization\model_compression_toolkit\core\quantization_prep_runner.py", line 76, in quantization_preparation_runner
    mi.infer(_data)
  File "C:\Users\1000322838\PycharmProjects\model_optimization\model_compression_toolkit\core\common\model_collector.py", line 264, in infer
    stats_container.update_statistics(self.fw_impl.to_numpy(activation_tensor),
  File "C:\Users\1000322838\PycharmProjects\model_optimization\model_compression_toolkit\core\common\collectors\statistics_collector.py", line 91, in update_statistics
    self.mc.update(x)
  File "C:\Users\1000322838\PycharmProjects\model_optimization\model_compression_toolkit\core\common\collectors\mean_collector.py", line 91, in update
    n = x.shape[axis]
IndexError: tuple index out of range
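The last frame fails at `n = x.shape[axis]`: a 0-dimensional NumPy array has the empty shape `()`, so indexing any axis raises `IndexError: tuple index out of range`. The sketch below reproduces that failure mode in isolation with NumPy only, and shows one possible guard. It assumes the collector reads the channel axis (1 for PyTorch's NCHW layout); the functions `channel_count` and `channel_count_safe` are hypothetical and not part of MCT, and promoting the tensor with `np.atleast_1d` is an assumed workaround, not the project's actual fix.

```python
import numpy as np


def channel_count(x: np.ndarray, axis: int = 1) -> int:
    # Mirrors the failing line in mean_collector.py: indexes the shape tuple directly.
    return x.shape[axis]


def channel_count_safe(x: np.ndarray, axis: int = 1) -> int:
    # Hypothetical guard: treat a 0-dim (scalar) tensor as a 1-D tensor with one element.
    x = np.atleast_1d(x)
    # Fall back to the last axis when the requested axis does not exist.
    if axis >= x.ndim or axis < -x.ndim:
        axis = -1
    return x.shape[axis]


# 0-dim array, analogous to torch.exp(self.scale) after conversion to NumPy
scalar = np.asarray(np.exp(-1.0))
print(scalar.shape)  # ()

try:
    channel_count(scalar)
except IndexError as err:
    print("IndexError:", err)  # IndexError: tuple index out of range

print(channel_count_safe(scalar))  # 1
print(channel_count_safe(np.zeros((1, 3, 224, 224))))  # 3
```

Falling back to the last axis keeps the regular 4-D case unchanged while letting a scalar be collected as a single channel; whether MCT should instead skip statistics for such nodes is a separate design question.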