Feat (graph/equalize): implement permute regions #1380
base: dev
Conversation
Co-authored-by: Pablo Monteagudo Lago <44771380+pablomlago@users.noreply.github.com>
…itas into refactor_rotation
```python
self.rewriters = None

def __enter__(self):
    model, rewriters = self.rotation.apply(self.model)
```
```diff
-model, rewriters = self.rotation.apply(self.model)
+self.model, self.rewriters = self.rotation.apply(self.model)
```
and remove the two following lines.
This file is over 2K lines of code at the moment and contains a lot of functionality, so it may be worth splitting in a future refactor. One possibility would be to have one file with the base equalization functionality and another with the code for scalar/rotation/permutation equalization, but other options could be explored.
> One possibility would be to have a file with the base equalization functionality, and then other file with the code for scalar/rotation/permutation equalization, but other options could be explored.
Agreed, this would be better... Do we do that in this PR or in a future refactor of this section?
It will be done in the next release, as part of the re-org of our PTQ algorithms.
I understand it can be done without necessarily breaking anything on the outside, but it would still be a half-job if we did it that way.
```python
weight.size(0), -1))[self.equalization_indexes.start:self.equalization_indexes.end]

def permute(self, permute_index):
    permutation_list = []
```
Maybe I would opt for building the index with a generator expression (where `permute_dim` is the axis being permuted):

```python
permutation_tuple = tuple(
    permute_index if dim == permute_dim else slice(size)
    for dim, size in enumerate(self.module.weight.shape))
```

Also, for indexing the tensor I would pass a tuple instead of a list, to prevent the following warning:

```
<stdin>:1: UserWarning: Using a non-tuple sequence for multidimensional indexing is deprecated and will be changed in pytorch 2.9; use x[tuple(seq)] instead of x[seq]. In pytorch 2.9 this will be interpreted as tensor index, x[torch.tensor(seq)], which will result either in an error or a different result (Triggered internally at /pytorch/torch/csrc/autograd/python_variable_indexing.cpp:306.)
```
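A minimal torch-free sketch of the suggested tuple construction; `shape` stands in for `self.module.weight.shape`, and `permute_dim` is a hypothetical name for the axis being permuted:

```python
# Hypothetical weight shape; permute along dim 0 (output channels).
shape = (4, 3, 2)
permute_index = [2, 0, 1, 3]
permute_dim = 0

# Full slices on every axis except the permuted one.
permutation_tuple = tuple(
    permute_index if dim == permute_dim else slice(size)
    for dim, size in enumerate(shape))

# Indexing with `weight[permutation_tuple]` (a tuple) avoids the deprecation
# warning that `weight[permutation_list]` (a list) would trigger.
print(permutation_tuple)
```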
```python
class PermuteGraph():

    def __init__(self):
        super().__init__()
```
Is this needed? PermuteGraph does not seem to inherit from another class.
```python
device = next(single_module.parameters()).device
dtype = next(single_module.parameters()).dtype

# If equalization criteria are not met, we return a scalar one to indicate that no equalization
```
Is this supposed to return a scalar 1?
```python
if delay_rewriters:
    return model

if not hasattr(model, '_hf_map'):
```
Let's move this to a method in accelerate_utils.
```python
if offload_model is None or remove_hooks is None:
    raise RuntimeError("Accelerate is not installed")
# if we use _hf_map to check and all the model is on a single GPU, then all rewriters are safe
if len(model._hf_map.values()) > 1:
```
Let's move this to a method in accelerate_utils.
```python
equalized_layers.update(r[1])

# Check that we found all the expected regions
print(len(regions))
```
Remove or use logger?
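If the output is worth keeping, a minimal sketch of swapping the `print` for the standard `logging` module (the helper name `check_regions` is hypothetical; whether Brevitas has its own logging utilities is an assumption to verify):

```python
import logging

# Module-level logger, following the usual convention.
logger = logging.getLogger(__name__)

def check_regions(regions, expected_regions):
    # Debug-level output is silenced by default, unlike a bare print().
    logger.debug("found %d regions, expected %d", len(regions), len(expected_regions))
    return len(regions) == len(expected_regions)

print(check_regions([1, 2], [1, 2]))
```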
```python
sinks = region.sinks_names
sinks_check = set(sinks) == set(expected_region[1])
print(len(srcs), len(expected_region[0]))
print(srcs)
```
Remove or use logger?
```python
weight = weight.cpu().to(torch.float32)
return scale_fn(weight.reshape(weight.size(0), -1))

def permute(self, permute_index):
```
This logic is repeated in SinkWrapper (excluding the bias permutation). I would factor out the common logic into ModuleWrapper (or even a standalone method) and then make the appropriate calls in Sink/SourceWrapper.
```python
self.module.weight.data = self.module.weight.data[permutation_list]


def new_axis(x, block_size=32):
```
I would add a docstring to this method, and potentially change its name.
```python
# If act_val is enabled, use source or sink weights to determine the activation channel
# For example, if the source is BatchNorm, we need to use the information coming from the sinks
if not region.is_valid_activation_equalization:
    return _no_permute()
```
Note that with the current implementation of _no_permute, this line is returning None.
```python
@torch.no_grad()
def _permute(region, list_of_act_val):
```
I feel that at the moment there are too many functions related to permutations, which makes the code difficult to follow: permute in both EqualizationSinkWrapper and EqualizationSourceWrapper, _permute as a standalone method, and apply_permute on Region. From my point of view, we could either move the logic of _permute into apply_permute or vice versa. Also, I would consider refactoring the common logic of permute (EqualizationSink/SourceWrapper) into the base class.
```python
# scale_fn = permute_op_type
single_module = region.get_module_from_name(next(iter(region.sinks_names)))
device = next(single_module.parameters()).device
```
Are device and dtype used in this method?
```python
if not region.is_valid_activation_equalization:
    return _no_permute()

list_of_act_val_shapes = [act_val.shape for act_val in list_of_act_val]
```
This piece of code does not seem self-explanatory. I would consider adding some comments to explain its functionality.
```python
# If act_val is enabled, use source or sink weights to determine the activation channel
# For example, if the source is BatchNorm, we need to use the information coming from the sinks
if list_of_act_val is not None:
```
There is a lot of logic around list_of_act_val scattered around this file. I feel that we should stop using a dictionary for float_act_map and define a class for it to encapsulate its functionality. Probably this should be done in a future PR.
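A rough sketch of what encapsulating float_act_map might look like. The class and method names (`ActivationValueMap`, `record`, `get`) are hypothetical; the real interface would follow the existing conventions in equalize.py:

```python
class ActivationValueMap:
    """Wraps the float_act_map dict so callers stop re-checking for None."""

    def __init__(self):
        self._map = {}

    def record(self, name, value):
        self._map[name] = value

    def get(self, names):
        # Return values only when every requested activation was recorded,
        # mirroring the scattered `if list_of_act_val is not None` checks.
        if any(n not in self._map for n in names):
            return None
        return [self._map[n] for n in names]

acts = ActivationValueMap()
acts.record("layer0", 0.5)
print(acts.get(["layer0"]))   # [0.5]
print(acts.get(["layer1"]))   # None
```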
```python
@torch.no_grad()
def apply_rewriters(
```
If moving the accelerate-related functionality into equalize.py is not strictly needed to support permutations, I would leave these changes for a future PR, as we might need to discuss whether to introduce this dependency at the src/brevitas level.
```python
self.model = model
self.rewriters = rewriters
self.rotation.permute_class.setup_permute()
```
```python
return self
```
to use rewriters outside the context.
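A small sketch of why `__enter__` should return `self`: `with ... as ctx` binds whatever `__enter__` returns, so without that line `ctx` is `None` and the rewriters are unreachable outside. `_FakeRotation` and `RotationContext` are hypothetical stand-ins for the PR's classes:

```python
class _FakeRotation:
    """Stand-in whose apply() mimics returning (model, rewriters)."""

    def apply(self, model):
        return model, ["rewriter_a"]

class RotationContext:
    def __init__(self, model, rotation):
        self.model = model
        self.rotation = rotation
        self.rewriters = None

    def __enter__(self):
        self.model, self.rewriters = self.rotation.apply(self.model)
        return self  # without this line, `as ctx` would bind None

    def __exit__(self, exc_type, exc, tb):
        return False

with RotationContext("model", _FakeRotation()) as ctx:
    pass
print(ctx.rewriters)  # ['rewriter_a']
```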
Reason for this PR
Refactor the graph equalization code and add support for permutations in regions.
Changes Made in this PR
Testing Summary
Tests for permutation are still missing.
Risk Highlight
Checklist
dev branch.