Feat (brevitas_examples/llm): Support for batched inputs in GPXQ/Qronos forward passes. #1427
base: dev
Conversation
Co-authored-by: Pablo Monteagudo Lago <44771380+pablomlago@users.noreply.github.com>
pablomlago left a comment:
LGTM, I'll wait for @Giuseppe5 to review it before merging.
Update the description above to match the new version. You could paste the results in the PR description above.
Reason for this PR
Currently in `brevitas_examples/llm`, when GPTQ, GPFQ or Qronos are applied, a separate LLM forward pass is run for each sample in the calibration data. Performing these forward passes per batch instead of per sample saves considerable time when running these algorithms. For example, for Llama-3.2-1B, applying GPTQ with the current setup and a calibration size of 512 takes around 1h, whereas the same configuration with the calibration data processed in batches of 128 samples reduces the runtime to ~0.5h (similar or greater speedups were observed for Qronos and GPFQ).
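The idea can be illustrated with a minimal, self-contained sketch (the model, shapes and names below are placeholders for illustration, not the `brevitas_examples/llm` code): the same calibration set is consumed in fewer, larger forward passes, amortizing the per-call overhead.

```python
import torch
from torch.utils.data import DataLoader

model = torch.nn.Linear(16, 16)  # stand-in for the LLM
dataset = torch.randn(512, 16)   # stand-in for 512 calibration samples

def run_calibration(batch_size: int) -> None:
    loader = DataLoader(dataset, batch_size=batch_size, shuffle=False)
    with torch.no_grad():
        for batch in loader:     # 512 iterations at batch_size=1, only 4 at 128
            model(batch)

run_calibration(batch_size=1)    # previous behaviour: one forward pass per sample
run_calibration(batch_size=128)  # batched: amortizes the per-call overhead
```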
In private discussions during the PR review it was decided to extend the use of batched inputs to the other algorithms in `brevitas_examples/llm/main.py` that could directly support it.

Changes Made in this PR
- Added `--calibration-batchsize` as a new argument for LLM experiments, allowing the user to choose the batch size for the forward pass of the LLM in the GPXQ/Qronos algorithms as well as in the other algorithms supporting batched inputs. The default value is 1 to mimic the existing setup.
- ~~Added a function in `llm_quant/data_utils.py` to build a DataLoader from a `DatasetToDevice`, which is the class currently used to handle the calibration data.~~ Deleted after PR review; this is now done directly in `main.py`.
- Added a collate function in `llm_quant/data_utils.py` that is needed to create the DataLoader (see the sketch after this list).
- In `main.py`, added code to instantiate and use the DataLoader. The data loader is stored in the variable `calibration_loader`, and the dataset (previously named `calibration_loader`) is now named `calibration_dataset`. Similarly, the variable `validation_dataloader` (which was actually storing a dataset) has been renamed to `validation_dataset`.
- In `src/brevitas_examples/llm/llm_quant/rotation_optimization.py`, removed the `collate_fn` to avoid duplicated code and instead added a new function named `data_collator` that conforms to Hugging Face's interface and internally relies on the collate function this PR adds in `llm_quant/data_utils.py`.
- Added a test in `tests/brevitas_examples/test_llm_data.py` to ensure the DataLoader is built correctly.
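The PR description does not include the code itself, so the following is only a hedged sketch of what a collate function for dict-of-tensor samples and the DataLoader construction might look like; `ToyCalibrationDataset` and the `(1, seq_len)` sample layout are assumptions for illustration, not the actual `DatasetToDevice` implementation.

```python
from typing import Dict, List

import torch
from torch.utils.data import DataLoader, Dataset

class ToyCalibrationDataset(Dataset):
    """Stand-in for the real calibration data: each sample is a dict of tensors."""

    def __init__(self, num_samples: int = 8, seq_len: int = 16) -> None:
        self.samples = [{
            "input_ids": torch.randint(0, 100, (1, seq_len)),
            "attention_mask": torch.ones(1, seq_len, dtype=torch.long),
        } for _ in range(num_samples)]

    def __len__(self) -> int:
        return len(self.samples)

    def __getitem__(self, idx: int) -> Dict[str, torch.Tensor]:
        return self.samples[idx]

def collate_fn(samples: List[Dict[str, torch.Tensor]]) -> Dict[str, torch.Tensor]:
    # Concatenate each field along the batch dimension; assumes all samples
    # share the same sequence length.
    return {key: torch.cat([s[key] for s in samples], dim=0) for key in samples[0]}

calibration_loader = DataLoader(
    ToyCalibrationDataset(),
    batch_size=4,  # in the PR this would come from --calibration-batchsize
    shuffle=False,
    collate_fn=collate_fn)

batch = next(iter(calibration_loader))
assert batch["input_ids"].shape == (4, 16)
```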
Warning

Had to edit the quantized perplexity values in some test cases because, for PyTorch-internal reasons, the random state changes after calling `iter(calibration_loader)` (even when `shuffle` is False and everything in the data loader is set to be deterministic). The change in the random state affected some algorithms since, for example, random matrices may be used. A short reproduction is sketched below.
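The effect can be reproduced in a few lines: creating a DataLoader iterator draws a base seed from PyTorch's global RNG, advancing its state even with `shuffle=False`. The `generator=` variant at the end is an assumption about how the effect could be isolated, not something this PR does.

```python
import torch
from torch.utils.data import DataLoader

loader = DataLoader(list(range(8)), batch_size=4, shuffle=False)

before = torch.get_rng_state()
_ = iter(loader)  # creating the iterator draws a base seed from the global RNG
after = torch.get_rng_state()
print(torch.equal(before, after))  # False: the global random state has changed

# Assumed workaround (not applied in this PR): give the DataLoader its own
# generator so the base-seed draw does not touch the global RNG.
loader = DataLoader(
    list(range(8)), batch_size=4, shuffle=False,
    generator=torch.Generator().manual_seed(0))
before = torch.get_rng_state()
_ = iter(loader)
print(torch.equal(before, torch.get_rng_state()))  # True: global RNG untouched
```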
Testing Summary

- The tests in `tests/brevitas_examples/test_llm_data.py`, including the new test for the DataLoader, were run locally.
- Results for `HuggingFaceTB/SmolLM-135M` with weight-only quantization to int4 on the `wikitext2` dataset: