9 changes: 4 additions & 5 deletions src/brevitas/graph/equalize.py
@@ -1832,7 +1832,10 @@ def find_sink(node):
             return output_node
 
         for sdpa_node in sdpa_nodes:
-            value_input = sdpa_node.args[-1]
+            if sdpa_node.kwargs.get('value', False):
+                value_input = sdpa_node.kwargs['value']  # value passed as kwarg
+            else:
+                value_input = sdpa_node.args[-1]  # value passed as last arg
 
             value_node = find_src(value_input)
             output_node = find_sink(value_input)
@@ -1855,10 +1858,6 @@ def find_sink(node):
                     'value_sdpa': value_module, 'output_sdpa': output_module})
             regions.append(region)
 
-        for m in graph_module.modules():
-            if isinstance(m, ScaledDotProductAttention):
-                m.pre_process_q = functional_rotate_input
-                m.pre_process_k = functional_rotate_input
         return regions
 
     def apply(self,
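For context on the equalize.py hunk: the traced FX graph may carry the SDPA `value` operand either positionally or as a keyword, so the region search now checks `kwargs` first and only then falls back to the last positional argument. Below is a small standalone sketch of that lookup against a plain torch.fx trace; it mirrors the diff's logic but is not the Brevitas code, and the `Attn` module and `get_sdpa_value_input` helper are made-up names for illustration.

```python
# Minimal sketch (not the Brevitas implementation): resolve the `value`
# operand of a traced scaled_dot_product_attention call, whether it was
# passed as a keyword or as the last positional argument.
import torch
import torch.fx as fx
import torch.nn.functional as F


def get_sdpa_value_input(sdpa_node: fx.Node) -> fx.Node:
    if 'value' in sdpa_node.kwargs:
        # e.g. F.scaled_dot_product_attention(q, k, value=v)
        return sdpa_node.kwargs['value']
    # otherwise value is the last positional argument: (q, k, v)
    return sdpa_node.args[-1]


class Attn(torch.nn.Module):

    def forward(self, q, k, v):
        return F.scaled_dot_product_attention(q, k, value=v)


gm = fx.symbolic_trace(Attn())
sdpa_nodes = [
    n for n in gm.graph.nodes if n.op == 'call_function' and
    getattr(n.target, '__name__', '') == 'scaled_dot_product_attention']
print(get_sdpa_value_input(sdpa_nodes[0]))  # -> the FX node feeding `value`
```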
6 changes: 5 additions & 1 deletion src/brevitas/nn/equalized_layer.py
@@ -81,7 +81,7 @@ def __init__(
         self.expansion_step = expansion_step
         self.expand_input = expand_input
 
-    def forward(self, inp, **kwargs):
+    def rotate(self, inp, **kwargs):
        is_cuda = 'cuda' in str(inp.device) and torch.version.cuda is not None
         if self.expand_input:
             # TODO: This only works for Linear layers. We have an assert in equalize.py to check for this
@@ -100,6 +100,10 @@ def forward(self, inp, **kwargs):
             inp = matmul_hadU_cuda(inp, had_K, K)
         else:
             inp = matmul_hadU(inp)
+        return inp
+
+    def forward(self, inp, **kwargs):
+        inp = self.rotate(inp, **kwargs)
         o = self.layer(inp)
 
         return o
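The equalized_layer.py change factors the Hadamard input rotation out of `forward` into a standalone `rotate` method, with `forward` reduced to rotate-then-call-the-wrapped-layer. That lets callers apply or inspect the rotation without running the layer, which is exactly what the main.py debug block below does. A simplified sketch of the pattern follows, with an explicit 4x4 Hadamard matrix standing in for the `matmul_hadU`/`matmul_hadU_cuda` kernels; `RotatedLinearSketch` is illustrative, not the real `RotatedModule`.

```python
# Simplified sketch of the rotate/forward split (illustrative only; the real
# RotatedModule applies fast Hadamard kernels instead of a stored matrix).
import torch


class RotatedLinearSketch(torch.nn.Module):

    def __init__(self, layer: torch.nn.Linear, had: torch.Tensor):
        super().__init__()
        self.layer = layer
        self.register_buffer('had', had)  # orthogonal rotation matrix

    def rotate(self, inp: torch.Tensor) -> torch.Tensor:
        # Input rotation only; can be called on its own for debugging.
        return inp @ self.had

    def forward(self, inp: torch.Tensor) -> torch.Tensor:
        return self.layer(self.rotate(inp))


had = torch.tensor([
    [1., 1., 1., 1.],
    [1., -1., 1., -1.],
    [1., 1., -1., -1.],
    [1., -1., -1., 1.]]) / 2.0  # normalized 4x4 Hadamard, orthogonal
rot = RotatedLinearSketch(torch.nn.Linear(4, 4), had)

# Probing rotate() with the identity recovers the rotation matrix itself,
# mirroring the torch.eye(in_features) probe in the main.py debug code.
print(torch.allclose(rot.rotate(torch.eye(4)), had))  # True
```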
40 changes: 40 additions & 0 deletions src/brevitas_examples/llm/main.py
@@ -557,6 +557,46 @@ def quantize_llm(args, extra_args=None):
         remove_hooks(model)
         with load_quant_model_mode(model):
             model.load_state_dict(torch.load(args.checkpoint_name, map_location='cpu'))
+        model.eval()
+        import brevitas.nn as qnn
+        for name, module in model.named_modules():
+            if isinstance(module, qnn.QuantLinear):
+                qw = module.quant_weight()
+                print(f"{name}")
+                print(torch.unique(qw.value / qw.scale))
+                print(qw.scale.shape)
+                print(qw.bit_width)
+                print(qw.zero_point)
+                print(qw.signed)
+
+                # In case you want to test the activation quantization function
+                qi = module.input_quant
+                x = torch.rand((2, module.in_features), device=model.device, dtype=model.dtype)
+                qx = qi(x)
+                print(torch.unique((qx.value / qx.scale) - qx.zero_point))
+                # print(qx.scale)  # Dynamic!
+                # print(qx.zero_point)  # Dynamic!
+                print(qx.bit_width)
+                print(qx.signed_t)
+            elif isinstance(module, qnn.QuantScaledDotProductAttention):
+                # Find the attention quantization modules
+                for ni, qi in module.named_modules():
+                    if isinstance(qi, qnn.QuantIdentity):
+                        rqt = qi.return_quant_tensor
+                        qi.return_quant_tensor = True
+                        qa = qi(torch.rand((1, 2, 3, 4), device=model.device, dtype=model.dtype))
+                        qi.return_quant_tensor = rqt
+                        print(f"{name}-{ni}")
+                        print(qa.scale)
+                        print(qa.bit_width)
+                        print(qa.zero_point)
+                        print(qa.signed_t)
+            elif isinstance(module, qnn.equalized_layer.RotatedModule):
+                x = torch.eye(module.layer.in_features, device=model.device, dtype=model.dtype)
+                had = module.rotate(x)
+                print(had.shape)
+                print(had[:4, :4])
+                print(had[:8, :8])
         model = offload_model(model)
 
     if args.gptq and not args.load_checkpoint:
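The main.py addition is checkpoint-debugging output: after reloading a quantized checkpoint it walks the model and, for each QuantLinear, prints the recovered integer levels (`value / scale`), the scale shape, bit width, zero-point, and sign, plus similar probes for the SDPA quant modules and the RotatedModule rotations. The same weight check on an isolated layer might look like the sketch below; it assumes a Brevitas install, and the layer sizes and 4-bit weight override are arbitrary example settings, not taken from the PR.

```python
# Standalone version of the weight-inspection idea from the diff above.
# The layer configuration here is an arbitrary example, not the LLM model.
import torch
import brevitas.nn as qnn

layer = qnn.QuantLinear(16, 8, bias=False, weight_bit_width=4)
qw = layer.quant_weight()

# Dividing the (de)quantized weight by its scale should land on integer
# levels, at most 2**bit_width distinct values (up to floating-point noise).
print(torch.unique(qw.value / qw.scale))
print(qw.scale.shape)   # scalar-like for the default per-tensor quantizer
print(qw.bit_width)     # e.g. tensor(4.)
print(qw.zero_point)    # 0 for the default symmetric weight quantizer
print(qw.signed)        # True
```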