@Niccolo-Ajroldi Niccolo-Ajroldi commented Sep 30, 2025

New Submission: Muon

Submission Information

submission_name: "muon_torch"  # As it will appear on the leaderboard
submission_folder: "submissions/external_tuning/"  # Name of folder within `submissions/external_tuning/` or `submissions/self_tuning/` (lowercase, no spaces)
authors: "Niccolò Ajroldi*"  # List authors separated by commas
affiliations: "ELLIS Institute Tübingen, Max Planck Institute for Intelligent Systems"  # List all affiliations of the authors, separated by commas
version: "1.0"  # Optional version number of your submission
ruleset: "external"  # Either "external" or "self-tuning"
framework: "PyTorch"  # Either "PyTorch" or "JAX"
description: "Muon implementation in PyTorch."  # A short, high-level description of the algorithm

*credits to original Muon implementation from Keller Jordan

Evidence for the Submission's Performance

Muon original blogpost
Muon is Scalable for LLM Training
Modded-nanogpt

Submission Details

The final goal is to have a single implementation to score. During development, we considered three different implementations of Muon:

  • MuonVanilla: the optimization algorithm is replicated identically across devices: each rank orthogonalizes all parameters.
  • MuonKJ: this design follows the up-to-date code in the Muon GitHub repo.
    Each device updates a distinct subset of parameters locally, and the updated parameters are later AllGathered.
    This exploits PyTorch 2.5's support for all-gathering tensors of different shapes (docs).
  • MuonBucketed: parameters are grouped into buckets of same-shaped tensors.
    Each device updates a distinct subset of parameters locally, and the updated parameters are later AllGathered.
    I believe a similar approach is followed in Dion.
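
The partitioning shared by MuonKJ and MuonBucketed can be sketched in plain Python. This is a hypothetical helper, not the submission's actual code; it illustrates a round-robin assignment in the spirit of Keller Jordan's implementation, where each rank orthogonalizes only the parameters it owns and the updated tensors are then AllGathered so every rank ends up with all of them:

```python
# Hypothetical sketch of the per-rank parameter ownership behind MuonKJ:
# parameter i is updated only by rank i % world_size.

def assign_params_to_ranks(num_params: int, world_size: int) -> dict[int, list[int]]:
    """Round-robin assignment of parameter indices to ranks."""
    owned = {rank: [] for rank in range(world_size)}
    for i in range(num_params):
        owned[i % world_size].append(i)
    return owned

# Example: 10 parameters on 4 ranks.
owned = assign_params_to_ranks(10, 4)
# Every parameter is owned by exactly one rank, so an AllGather of the
# per-rank updates reconstructs the full, replicated parameter set.
assert sorted(i for ids in owned.values() for i in ids) == list(range(10))
```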

A visualization of the different parallelization strategies is shown in this diagram.

The MuonBucketed parallelization strategy is similar to the original MuonKJ implementation, resulting in a similar number of NCCL calls. Its main advantage is that AllGather calls operate on tensors of the same shape, which is likely faster than all-gathering tensors of different shapes. The downside is that each bucket might need padding, so some GPUs sit idle while others compute updates. In practice this appears less efficient than MuonKJ, hence we do not consider MuonBucketed further.
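
The bucketing-plus-padding idea can be illustrated with a toy helper (hypothetical names, plain Python): parameters are grouped by shape, and each bucket is padded to a multiple of the world size; a padded slot corresponds to a rank that sits idle for that bucket.

```python
# Toy sketch of shape-based bucketing with padding (not the submission's code).
from collections import defaultdict

def bucket_by_shape(shapes: list[tuple], world_size: int) -> dict:
    """Group parameter indices by shape; pad each bucket to a multiple of
    world_size with None entries (idle slots)."""
    buckets = defaultdict(list)
    for i, shape in enumerate(shapes):
        buckets[shape].append(i)
    return {s: ids + [None] * (-len(ids) % world_size)
            for s, ids in buckets.items()}

# Three params, two sharing a shape, on 2 ranks: the odd-sized bucket is padded,
# so one rank is idle while the other computes that bucket's update.
buckets = bucket_by_shape([(64, 64), (64, 64), (128, 64)], world_size=2)
assert buckets == {(64, 64): [0, 1], (128, 64): [2, None]}
```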

Since different devices update different parameters, each device only requires the reduced gradients for the subset it updates. Therefore, when using MuonKJ (or MuonBucketed), it is superfluous to AllReduce the gradients before the optimizer step, as traditionally done by DDP during loss.backward(). We just need to ReduceScatter the gradients, apply Muon, and AllGather the updated parameters. Crucially, the scatter operation must match the block structure of the MuonKJ distributed update.
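
The communication pattern above can be simulated without NCCL. The following plain-Python sketch (hypothetical helpers) mimics the semantics: ReduceScatter leaves each rank with only the reduced chunk it needs for its share of the update, and AllGather of the updated chunks restores full replication.

```python
# Toy single-process simulation of ReduceScatter + AllGather semantics.

def reduce_scatter(per_rank_grads: list[list[float]], world_size: int) -> list[list[float]]:
    """Each rank r receives the element-wise sum of chunk r across all ranks."""
    n = len(per_rank_grads[0])
    chunk = n // world_size
    summed = [sum(g[i] for g in per_rank_grads) for i in range(n)]
    return [summed[r * chunk:(r + 1) * chunk] for r in range(world_size)]

def all_gather(per_rank_chunks: list[list[float]]) -> list[float]:
    """Concatenate every rank's chunk so all ranks see the full tensor."""
    return [x for chunk in per_rank_chunks for x in chunk]

# 2 ranks, a 4-element gradient: each rank keeps only its reduced half.
grads = [[1.0, 2.0, 3.0, 4.0], [10.0, 20.0, 30.0, 40.0]]
chunks = reduce_scatter(grads, world_size=2)
assert chunks == [[11.0, 22.0], [33.0, 44.0]]
# After each rank applies its local update, AllGather restores replication.
assert all_gather(chunks) == [11.0, 22.0, 33.0, 44.0]
```

An AllReduce would instead hand every rank the full summed gradient, which is wasted communication when each rank only updates its own block.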

Notice, however, that the AllReduce issued by PyTorch DDP is optimized to overlap with gradient computation during the backward pass. We therefore test MuonKJ both with and without a custom ReduceScatter implementation, to see whether we can improve over the highly optimized torch AllReduce.

We implement three submissions:

  1. muon_vanilla.py: uses MuonVanilla, with traditional DP all-reduce of gradients.
  2. muon_kj.py: uses MuonKJ, with traditional DP all-reduce of gradients.
  3. muon_kj_custom_reduce_scatter.py: uses MuonKJ, with the aforementioned reduce-scatter of gradients.

We compare these implementations:

  • Equivalence: do they yield the same updates?
  • Are the efficient DP algorithms faster than the vanilla one?
  • Is the KJ version with reduce-scatter faster than the KJ version with traditional DP all-reduce?

Comparison results

We compare the implementations by means of yielded loss (in a deterministic run) and accumulated submission time in a dedicated wandb report.

Implementation equivalence

We verify that the three implementations are equivalent across workloads: loss profiles overlap for all workloads except criteo1tb. Further checks on criteo1tb might be necessary.

Speed

We compare implementations on 4xA100-80GB, training for 5% of step_hint and repeating the experiment with 5 different random seeds. We report here the mean total time (in minutes) accumulated across workloads.

| optim_name | accumulated_submission_time_min | time_saved_over_vanilla_min |
|------------|--------------------------------:|----------------------------:|
| vanilla    | 349.66                          | 0.00                        |
| KJ         | 337.26                          | 12.41                       |
| KJ_RS      | 337.98                          | 11.69                       |

We observe high variability in some workloads, and the ranking varies across workloads. Overall, KJ is significantly faster than the vanilla implementation, though the ReduceScatter version does not clearly improve over plain AllReduce.

Parameters allocation: Muon vs Adam

We use AdamW as the backup optimizer, optimizing the following parameters with it:

  • 1D params (biases, layernorm, batchnorm)
  • Embeddings of WMT and Criteo, identified by embeddings in the parameter name
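
The routing rule above can be expressed as a small predicate. This is a hedged sketch (the predicate name and exact string match are illustrative, not the submission's code): 1D tensors and embedding parameters go to AdamW, everything else to Muon.

```python
# Illustrative parameter-routing predicate (hypothetical helper).

def use_muon(param_name: str, param_ndim: int) -> bool:
    """Return True if this parameter should be optimized with Muon."""
    if param_ndim <= 1:                # biases, layernorm/batchnorm scales
        return False
    if "embedding" in param_name:      # WMT / Criteo embedding tables
        return False
    return True

assert not use_muon("encoder.layer_norm.weight", 1)
assert not use_muon("shared_embedding.weight", 2)
assert use_muon("decoder.attention.query.weight", 2)
```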

We attach txt files with the resulting parameter split for each workload, for ease of inspection.

Momentum implementation

We follow the Adam-style EMA implementation of momentum, as done in the official Muon repo and in
modded-nanogpt. Notice, however, that the original formulation of Muon uses PyTorch-SGD-style momentum, and a similar implementation is followed by Moonshot AI.
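
The difference between the two conventions is a constant scale. A scalar illustration (assumed forms, matching the description above): with a constant gradient g, Adam-style EMA momentum converges to g, while PyTorch-SGD-style momentum (without dampening) converges to g / (1 - beta).

```python
# Scalar comparison of the two momentum conventions mentioned in the text.
beta, g = 0.95, 1.0
ema_buf, sgd_buf = 0.0, 0.0
for _ in range(500):
    ema_buf = beta * ema_buf + (1.0 - beta) * g   # Adam-style EMA momentum
    sgd_buf = beta * sgd_buf + g                  # PyTorch-SGD-style momentum

assert abs(ema_buf - g) < 1e-6                    # EMA converges to g
assert abs(sgd_buf - g / (1.0 - beta)) < 1e-6     # SGD-style: g / (1 - beta)
```

In practice the two are equivalent up to a rescaling of the learning rate, which is why implementations can pick either convention.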

Notes

3D and 4D parameters are flattened over their trailing dimensions before Newton-Schulz (NS) orthogonalization is applied.
Nesterov momentum is supported, and we support dampening in the Muon momentum implementation.
We use separate learning rates for Muon and AdamW.
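
For illustration, a minimal numpy sketch of the quintic Newton-Schulz orthogonalization (coefficients from Keller Jordan's public Muon implementation; this is a simplified single-matrix version, not the submission's code), together with the trailing-dimension flattening applied to conv weights:

```python
import numpy as np

def newton_schulz_orth(G: np.ndarray, steps: int = 5) -> np.ndarray:
    """Approximately orthogonalize G via the quintic Newton-Schulz iteration."""
    a, b, c = 3.4445, -4.7750, 2.0315
    X = G / (np.linalg.norm(G) + 1e-7)   # Frobenius normalization
    transposed = X.shape[0] > X.shape[1]
    if transposed:                        # work with the wide orientation
        X = X.T
    for _ in range(steps):
        A = X @ X.T
        X = a * X + (b * A + c * A @ A) @ X
    return X.T if transposed else X

# 4D conv weights are flattened on trailing dims before orthogonalization.
w = np.ones((8, 4, 3, 3))                 # (C_out, C_in, k, k)
w2d = w.reshape(w.shape[0], -1)
assert w2d.shape == (8, 36)
```

Note that the iteration drives singular values toward 1 only approximately (they oscillate in a band around 1), which is sufficient for the optimizer's purposes.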

Discussion points

  • Should we further optimize the ReduceScatter to potentially overlap with backward?
  • Should we use AdamW or NAdamW as a backup optimizer?

Next steps

  • Efficient DP implementation
  • Momentum with dampening
  • Update with new dropout
  • Identify which layers to optimize with Muon and which with AdamW.
  • Test equivalence on toy problem (tests/).
  • Test equivalence on AlgoPerf workloads (deterministic + fixed eval_every_steps)
  • Compare speed across implementations
    • Is the efficient DP version faster?
    • Is it worth it to manually ReduceScatter gradients?
  • Decide on a single implementation to score.

@Niccolo-Ajroldi Niccolo-Ajroldi requested a review from a team as a code owner September 30, 2025 18:54
github-actions bot commented Sep 30, 2025

MLCommons CLA bot All contributors have signed the MLCommons CLA ✍️ ✅

@Niccolo-Ajroldi Niccolo-Ajroldi changed the title muon torch vanilla DP Muon torch submission (vanilla DP) Sep 30, 2025
@Niccolo-Ajroldi Niccolo-Ajroldi changed the title Muon torch submission (vanilla DP) Muon torch submission Oct 6, 2025
@Niccolo-Ajroldi Niccolo-Ajroldi changed the title Muon torch submission Muon torch submission [WIP] Oct 6, 2025