Muon torch submission [WIP] #15
Niccolo-Ajroldi wants to merge 27 commits into mlcommons:main from Niccolo-Ajroldi:muon_torch
New Submission: Muon
Submission Information
Credits to the original Muon implementation by Keller Jordan.
Evidence for the Submission's Performance
Muon original blogpost
Muon is Scalable for LLM Training
Modded-nanogpt
Submission Details
The final goal is to have a single implementation to score. During development, we identified three possible implementations of Muon:
- MuonVanilla: the optimization algorithm is replicated identically across devices; each rank orthogonalizes all parameters.
- MuonKJ: this design follows the up-to-date code in the Muon GitHub repo. Each device updates a distinct subset of parameters locally, and the updated parameters are later AllGathered. This exploits PyTorch 2.5's support for all-gathering parameters of different shapes (docs).
- MuonBucketed: parameters are grouped into buckets of same-shaped params. Each device updates a distinct subset of parameters locally, and the updated parameters are later AllGathered. I believe a similar approach is followed in Dion.
A visualization of the different parallelization strategies is reported in this diagram.
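For concreteness, here is a minimal sketch of a MuonKJ-style step. It assumes a hypothetical round-robin assignment of parameters to ranks, omits momentum, uses a per-parameter broadcast as a stand-in for the uneven AllGather used by the actual code, and relies on a `newton_schulz` orthogonalization routine (sketched under Notes below):

```python
import torch
import torch.distributed as dist

@torch.no_grad()
def muon_kj_step(muon_params, lr):
    # Gradients are assumed to be already averaged across ranks (DDP all-reduce).
    rank, world_size = dist.get_rank(), dist.get_world_size()
    for i, p in enumerate(muon_params):
        owner = i % world_size  # hypothetical round-robin ownership of parameters
        if rank == owner:
            # Only the owning rank orthogonalizes and applies this update.
            g = p.grad.reshape(p.size(0), -1)
            p.add_(newton_schulz(g).reshape(p.shape), alpha=-lr)
        # Every rank then receives the owner's updated parameter.
        dist.broadcast(p, src=owner)
```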
The MuonBucketed parallelization strategy is similar to the original MuonKJ implementation, resulting in a similar number of NCCL calls. The main advantage is that AllGather calls operate on tensors of the same shape, and are therefore likely faster than AllGathering tensors of different shapes. The downside of this approach is that each bucket might need padding, which leaves some GPUs idle while others compute updates. This appears to be less efficient than MuonKJ, hence we will not further consider MuonBucketed.

Since different devices update different parameters, each device only requires the reduced gradients for the subset it updates. Therefore, when using MuonKJ (or MuonBucketed), it is superfluous to AllReduce gradients before the optimizer step, as traditionally done by DDP when calling loss.backward(). We just need to ReduceScatter the gradients, apply Muon, and AllGather the updated parameters. Crucially, the scatter operation must match the block structure of the MuonKJ distributed update.

Notice, however, that the torch AllReduce call is optimized to overlap with gradient computation during the backward pass. Therefore, we decided to test MuonKJ both with and without a custom ReduceScatter implementation, to check whether we can improve over the highly optimized torch AllReduce. A minimal sketch of the custom gradient-communication path is given below.
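The following hedged sketch illustrates that idea, using a per-parameter reduce-to-owner as a simple stand-in for the block-structured ReduceScatter described above (function names and the exact communication pattern are illustrative assumptions, not the submission's actual code; `newton_schulz` is sketched under Notes below):

```python
import torch
import torch.distributed as dist

@torch.no_grad()
def muon_kj_reduce_to_owner_step(muon_params, lr):
    # Instead of all-reducing every gradient, reduce each gradient only to the
    # rank that owns the corresponding parameter, update there, then share it.
    rank, world_size = dist.get_rank(), dist.get_world_size()
    handles = [
        dist.reduce(p.grad, dst=i % world_size, op=dist.ReduceOp.SUM, async_op=True)
        for i, p in enumerate(muon_params)
    ]
    for i, (p, h) in enumerate(zip(muon_params, handles)):
        h.wait()
        owner = i % world_size
        if rank == owner:
            g = p.grad.div_(world_size)  # average after the sum-reduce
            p.add_(newton_schulz(g.reshape(g.size(0), -1)).reshape(p.shape), alpha=-lr)
        dist.broadcast(p, src=owner)     # all ranks receive the updated parameter
```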
We implement 3 submissions:
- muon_vanilla.py: uses MuonVanilla, with traditional DP all-reduce of grads.
- muon_kj.py: uses MuonKJ, with traditional DP all-reduce of grads.
- muon_kj_custom_reduce_scatter.py: uses MuonKJ, with the aforementioned reduce-scatter of grads.

We compare these implementations:
Comparison results
We compare the implementations in terms of the resulting loss (in a deterministic run) and the accumulated submission time, in a dedicated wandb report.
Implementation equivalence
We verify that the three implementations are equivalent across workloads: loss profiles overlap for all workloads except criteo1tb. A further check on criteo1tb might be necessary.
Speed
We compare the implementations on 4xA100-80GB, training for 5% of step_hint and repeating the experiment with 5 different random seeds. We report here the mean total time (in minutes) accumulated across workloads. We observe high variability in some workloads, and the ranking varies across workloads. Overall, KJ is significantly faster than the vanilla implementation, though the ReduceScatter version does not clearly improve over just using AllReduce.
Parameter allocation: Muon vs Adam
We use AdamW as the backup optimizer, optimizing the following parameters with it:
- parameters with embeddings in their name

We attach txt files of the resulting parameter split for each workload, for ease of inspection.
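A hedged sketch of what such a split could look like (only the embeddings rule is stated above; routing all non-matrix parameters to AdamW is an assumption, common in Muon setups, and should be checked against the attached txt files):

```python
def split_params(named_params):
    # Assumption: 2D+ weight matrices go to Muon, while embeddings (matched by
    # name) and everything else (biases, norms, scalars) go to AdamW.
    muon_params, adamw_params = [], []
    for name, p in named_params:
        if p.ndim >= 2 and 'embeddings' not in name:
            muon_params.append(p)
        else:
            adamw_params.append(p)
    return muon_params, adamw_params
```

For example: `muon_params, adamw_params = split_params(model.named_parameters())`.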
Momentum implementation
We follow the Adam-style EMA implementation of momentum, as also done in the official Muon repo and in modded-nanogpt. Notice, however, that the original formulation of Muon uses PyTorch-SGD-style momentum, and a similar implementation is followed by MoonshotAI.
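To make the difference concrete, here is a generic illustration of the two conventions (beta is the momentum coefficient; this is not the submission's exact code):

```python
import torch

def muon_momentum(grad, buf, beta=0.95, nesterov=True, ema_style=True, dampening=0.0):
    if ema_style:
        # Adam-style EMA momentum (used in this submission, as in modded-nanogpt):
        #   buf <- beta * buf + (1 - beta) * grad
        buf.lerp_(grad, 1 - beta)
        return grad.lerp(buf, beta) if nesterov else buf
    else:
        # PyTorch-SGD-style momentum (original Muon formulation):
        #   buf <- beta * buf + (1 - dampening) * grad
        buf.mul_(beta).add_(grad, alpha=1 - dampening)
        return grad.add(buf, alpha=beta) if nesterov else buf
```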
Notes
3D and 4D parameters are flattened along the trailing dimensions and Newton-Schulz (NS) orthogonalization is applied (see the sketch after this list).
Nesterov momentum is supported, and we also support dampening in the Muon momentum implementation.
We use separate learning rates for Muon and AdamW.
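A sketch of the flattening note above, together with a quintic Newton-Schulz iteration along the lines of Keller Jordan's reference implementation (the coefficients below are taken from that reference and may differ from the submission's):

```python
import torch

def newton_schulz(G, steps=5, eps=1e-7):
    # Approximately orthogonalize a 2D matrix via a quintic Newton-Schulz iteration.
    a, b, c = 3.4445, -4.7750, 2.0315
    X = G.bfloat16()
    X = X / (X.norm() + eps)
    transposed = X.size(0) > X.size(1)
    if transposed:
        X = X.T
    for _ in range(steps):
        A = X @ X.T
        B = b * A + c * A @ A
        X = a * X + B @ X
    if transposed:
        X = X.T
    return X.to(G.dtype)

def orthogonalized_update(grad):
    # Flatten trailing dimensions of 3D/4D gradients (e.g. conv filters),
    # orthogonalize the resulting 2D matrix, then restore the original shape.
    if grad.ndim > 2:
        return newton_schulz(grad.reshape(grad.size(0), -1)).reshape(grad.shape)
    return newton_schulz(grad)
```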
Discussion points
Next steps
tests/).