@holgerroth

Fixes # .

Description

Key Optimizations:

  1. For the first contribution, use v.clone().mul_(weight) instead of v * weight - this makes a single copy that is then scaled in-place, rather than allocating a separate intermediate tensor.
  2. Use add_(v, alpha=weight), which is equivalent to += v * weight but performed in-place without allocating any intermediate tensors. This is the biggest memory saver.
  3. Use div_(self.counts[k]) for in-place division instead of creating a new tensor for the normalization step.
  4. Backward compatibility: the code checks whether the values support in-place ops and falls back to the original approach for non-PyTorch data types (see the sketch after this list).
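
A minimal sketch of the accumulation pattern described above, assuming PyTorch tensors; the container and function names (totals, counts, data, accumulate, finalize) are illustrative only and not necessarily the names used in the actual aggregator code:

```python
import torch

def accumulate(totals: dict, counts: dict, data: dict, weight: float) -> None:
    """Fold one client's update into the running totals, in-place where possible."""
    for k, v in data.items():
        if isinstance(v, torch.Tensor):
            if k not in totals:
                # First contribution: one copy, scaled in-place (no extra intermediate).
                totals[k] = v.clone().mul_(weight)
            else:
                # Equivalent to totals[k] += v * weight, without a temporary tensor.
                totals[k].add_(v, alpha=weight)
        else:
            # Fallback for non-PyTorch data types: original out-of-place path.
            totals[k] = totals.get(k, 0) + v * weight
        counts[k] = counts.get(k, 0.0) + weight

def finalize(totals: dict, counts: dict) -> dict:
    """Normalize each accumulated entry by its total weight."""
    for k, t in totals.items():
        if isinstance(t, torch.Tensor):
            t.div_(counts[k])  # in-place division, no new tensor
        else:
            totals[k] = t / counts[k]
    return totals
```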

Memory Savings:

For N clients with model size M:

  • Before: allocates roughly 2M of temporary tensor memory during aggregation (a weighted value plus a sum result for each parameter).
  • After: allocates roughly 0.5M of temporary tensor memory (only the initial clone); all other operations are in-place.

For a large model (e.g., 1B parameters as float32 = 4GB), this saves approximately 4-8GB of peak memory during aggregation with just a few clients.
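
Rough arithmetic behind that estimate, using the ~2M vs. ~0.5M figures above: with M = 4 GB, the previous path peaks at about 2 × 4 GB = 8 GB of temporaries, while the in-place path peaks at about 0.5 × 4 GB = 2 GB, a reduction of roughly 6 GB, consistent with the 4-8 GB range quoted above.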

Types of changes

  • Non-breaking change (fix or new feature that would not break existing functionality).
  • Breaking change (fix or new feature that would cause existing functionality to change).
  • New tests added to cover the changes.
  • Quick tests passed locally by running ./runtest.sh.
  • In-line docstrings updated.
  • Documentation updated.
