
Conversation

@Giuseppe5 (Collaborator)

Reason for this PR

Refactor the graph equalization code and add support for permutations in regions.

Changes Made in this PR

Testing Summary

Tests for permutation are still missing.

Risk Highlight

  • This PR includes code from another work (please detail).
  • This PR contains API-breaking changes.
  • This PR depends on work in another PR (please provide links/details).
  • This PR introduces new dependencies (please detail).
  • There are coverage gaps not covered by tests.
  • Documentation updates required in subsequent PR.

Checklist

  • Code comments added to any hard-to-understand areas, if applicable.
  • Changes generate no new warnings.
  • Updated any relevant tests, if applicable.
  • No conflicts with destination dev branch.
  • I reviewed my own code changes.
  • Initial CI/CD passing.
  • 1+ reviews given, and any review issues addressed and approved.
  • Post-review full CI/CD passing.

        self.rewriters = None

    def __enter__(self):
        model, rewriters = self.rotation.apply(self.model)
Collaborator:

Suggested change:

- model, rewriters = self.rotation.apply(self.model)
+ self.model, self.rewriters = self.rotation.apply(self.model)

and remove the two following lines.

Collaborator:

This file is over 2K lines of code at the moment and contains a lot of functionality, so it may be worth splitting it in a future refactor. One possibility would be to have a file with the base equalization functionality, and another file with the code for scalar/rotation/permutation equalization, but other options could be explored.
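For instance, one possible layout along those lines (file names are purely illustrative, not a decided structure):

equalize/
    base.py      # regions, wrappers, shared graph utilities
    scale.py     # scalar equalization
    rotate.py    # rotation equalization
    permute.py   # permutation equalization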

Collaborator:

> One possibility would be to have a file with the base equalization functionality, and another file with the code for scalar/rotation/permutation equalization, but other options could be explored.

Agreed, this would be better... Do we do that in this PR or in a future refactor of this section?

@Giuseppe5 (Collaborator, Author), Oct 9, 2025:

It will be done in the next release, as part of the re-org of our PTQ algorithms.
I understand it can be done without necessarily breaking anything on the outside, but it would still be half a job if we did it that way.

        weight.size(0), -1))[self.equalization_indexes.start:self.equalization_indexes.end]

    def permute(self, permute_index):
        permutation_list = []
@pablomlago (Collaborator), Oct 3, 2025:

I would maybe opt for a comprehension here:

permutation_tuple = tuple(permute_index if i == dim else slice(size) for dim, size in enumerate(self.module.weight.shape))

Also, for indexing the tensor I would pass a tuple, instead of a list, to prevent the following warning:

<stdin>:1: UserWarning: Using a non-tuple sequence for multidimensional indexing is deprecated and will be changed in pytorch 2.9; use x[tuple(seq)] instead of x[seq]. In pytorch 2.9 this will be interpreted as tensor index, x[torch.tensor(seq)], which will result either in an error or a different result (Triggered internally at /pytorch/torch/csrc/autograd/python_variable_indexing.cpp:306.)
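A minimal runnable sketch of the tuple-based indexing (variable names are illustrative; `target_dim` is an assumption for the sketch):

import torch

weight = torch.randn(4, 3, 2)
target_dim = 0  # axis to permute; an assumption for this sketch
permute_index = torch.randperm(weight.size(target_dim))

# Build the indexer as a tuple: advanced indexing on the permuted axis,
# full slices everywhere else. Passing a tuple (not a list) avoids the
# multidimensional-indexing deprecation warning.
indexer = tuple(
    permute_index if d == target_dim else slice(size)
    for d, size in enumerate(weight.shape))
weight = weight[indexer]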

class PermuteGraph():

    def __init__(self):
        super().__init__()
Collaborator:

Is this needed? PermuteGraph does not seem to inherit from another class.


device = next(single_module.parameters()).device
dtype = next(single_module.parameters()).dtype

# If equalization criteria are not met, we return a scalar one to indicate that no equalization
Collaborator:

Is this supposed to return a scalar 1?
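If so, a minimal sketch of what the fallback might look like, reusing the device/dtype lookup from the snippet above (the function name is hypothetical and the contract is an assumption):

import torch

def _fallback_scale(single_module):
    # Assumption: a scalar one on the module's device/dtype signals
    # that no equalization is applied.
    device = next(single_module.parameters()).device
    dtype = next(single_module.parameters()).dtype
    return torch.ones(1, device=device, dtype=dtype)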

if delay_rewriters:
    return model

if not hasattr(model, '_hf_map'):
Collaborator:

Let's move this to a method in accelerate_utils.

if offload_model is None or remove_hooks is None:
    raise RuntimeError("Accelerate is not installed")
# If _hf_map shows that the whole model is on a single GPU, then all rewriters are safe
if len(model._hf_map.values()) > 1:
Collaborator:

Let's move this to a method in accelerate_utils.
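For instance, a hypothetical helper (name and docstring are assumptions; the check simply mirrors the inline code above):

def has_multi_device_map(model) -> bool:
    # Mirrors the inline _hf_map check; whether len(values) > 1 truly
    # implies multiple devices depends on _hf_map's structure.
    return hasattr(model, '_hf_map') and len(model._hf_map.values()) > 1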

equalized_layers.update(r[1])

# Check that we found all the expected regions
print(len(regions))
Collaborator:

Remove or use logger?
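For the logger route, a drop-in sketch using the standard logging module (`regions` comes from the snippet above):

import logging

logger = logging.getLogger(__name__)

# Instead of print(len(regions)):
logger.debug("Found %d regions", len(regions))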

sinks = region.sinks_names
sinks_check = set(sinks) == set(expected_region[1])
print(len(srcs), len(expected_region[0]))
print(srcs)
Collaborator:

Remove or use logger?

        weight = weight.cpu().to(torch.float32)
        return scale_fn(weight.reshape(weight.size(0), -1))

    def permute(self, permute_index):
Collaborator:

This logic is repeated in SinkWrapper (excluding the bias permutation). I would factor out the common logic to ModuleWrapper (or even a standalone method) and then make the appropriate calls in Sink/SourceWrapper.
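A rough sketch of that factoring (class names from the discussion; the permuted axes and the bias handling are assumptions, not the actual implementation):

import torch

class ModuleWrapper:
    # Shared permutation logic, parameterized by the axis to permute.
    def _permute_weight(self, permute_index, dim):
        indexer = tuple(
            permute_index if d == dim else slice(None)
            for d in range(self.module.weight.dim()))
        self.module.weight.data = self.module.weight.data[indexer]

class EqualizationSourceWrapper(ModuleWrapper):
    def permute(self, permute_index):
        # Assumption: sources permute output channels, plus the bias.
        self._permute_weight(permute_index, dim=0)
        if getattr(self.module, 'bias', None) is not None:
            self.module.bias.data = self.module.bias.data[permute_index]

class EqualizationSinkWrapper(ModuleWrapper):
    def permute(self, permute_index):
        # Assumption: sinks permute input channels only.
        self._permute_weight(permute_index, dim=1)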

        self.module.weight.data = self.module.weight.data[permutation_list]


def new_axis(x, block_size=32):
Collaborator:

I would add a docstring to this method, and potentially change its name.

# If act_val is enabled, use source or sink weights to determine the activation channel
# For example, if the source is BatchNorm, we need to use the information coming from the sinks
if not region.is_valid_activation_equalization:
    return _no_permute()
Collaborator:

Note that with the current implementation of _no_permute, this line is returning None.



@torch.no_grad()
def _permute(region, list_of_act_val):
Collaborator:

I feel that at the moment there are too many functions related to permutations, which makes the code difficult to follow: permute in both EqualizationSinkWrapper and EqualizationSourceWrapper, _permute as a standalone method, and apply_permute for Region. From my point of view, we could either move the logic of _permute into apply_permute or vice versa. Also, I would consider refactoring the common logic of permute (EqualizationSink/SourceWrapper) into the base class.


# scale_fn = permute_op_type
single_module = region.get_module_from_name(next(iter(region.sinks_names)))
device = next(single_module.parameters()).device
Collaborator:

Are device and dtype used in this method?

if not region.is_valid_activation_equalization:
    return _no_permute()

list_of_act_val_shapes = [act_val.shape for act_val in list_of_act_val]
@pablomlago (Collaborator), Oct 8, 2025:

This piece of code does not seem self-explanatory. I would consider adding some comments to explain its functionality.


# If act_val is enabled, use source or sink weights to determine the activation channel
# For example, if the source is BatchNorm, we need to use the information coming from the sinks
if list_of_act_val is not None:
Collaborator:

There is a lot of logic around list_of_act_val scattered around this file. I feel that we should stop using a dictionary for float_act_map and define a class for it to encapsulate its functionality. Probably, this should be done in a future PR.
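Something along these lines, as a sketch (the class and method names are hypothetical, assumptions about what callers would need):

class FloatActMap:
    # Hypothetical encapsulation of the float_act_map dictionary.
    def __init__(self):
        self._values = {}

    def record(self, name, act_val):
        self._values[name] = act_val

    def values_for(self, names):
        # Stored activation values for the given layer names, if present.
        return [self._values[n] for n in names if n in self._values]

    def has_all(self, names):
        return all(n in self._values for n in names)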



@torch.no_grad()
def apply_rewriters(
@pablomlago (Collaborator), Oct 8, 2025:

If moving the accelerate-related functionality into equalize.py is not strictly needed to support permutations, I would leave these changes for a future PR, as we might need to discuss whether to introduce this dependency at the src/brevitas level.

        self.model = model
        self.rewriters = rewriters
        self.rotation.permute_class.setup_permute()

Collaborator:

Suggested change:

+         return self

to use rewriters outside the context.
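A minimal runnable illustration of why that matters (the class here is hypothetical, not the PR's actual context manager):

class EqContext:
    # Returning self from __enter__ is what makes `with EqContext() as ctx`
    # bind ctx to the instance instead of None.
    def __init__(self):
        self.rewriters = None

    def __enter__(self):
        self.rewriters = []  # populated by the real implementation
        return self

    def __exit__(self, *exc):
        return False

with EqContext() as ctx:
    pass
print(ctx.rewriters)  # rewriters remain usable outside the context body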
