feat: RL support for custom moe models in dtensor v2 #1695
base: main
Conversation
📝 Walkthrough

Introduces a new GRPO Moonlight 16B automodel configuration file, refactors dtensor_policy_worker_v2.py with tensor adaptation and unified context management helpers for HuggingFace compatibility, and adds a structured experimental workflow test script with metrics validation.

Changes
Estimated code review effort: 🎯 3 (Moderate) | ⏱️ ~25 minutes

Possibly related PRs
Suggested labels
Suggested reviewers
Pre-merge checks and finishing touches

❌ Failed checks (1 warning)
✅ Passed checks (3 passed)
✨ Finishing touches
Actionable comments posted: 1
🧹 Nitpick comments (4)
examples/configs/recipes/llm/grpo-moonlight-16b-automodel-1n8g-ep8.yaml (1)
6-7: Minor inconsistency between filename and directory names.

The filename omits `a3b-instruct` but `checkpoint_dir` includes it (`grpo-moonlight-16b-a3b-instruct-automodel-1n8g-ep8`). Consider aligning these for easier discoverability, or this may be intentional to keep the filename shorter.

nemo_rl/models/policy/workers/dtensor_policy_worker_v2.py (2)

439-439: Prefer a generator expression over a list comprehension in `any()`.

Using a list comprehension inside `any()` creates an unnecessary intermediate list. A generator expression is more memory-efficient and idiomatic.

🔎 Proposed fix

```diff
-self.is_moe_model = any(["expert" in key for key in self.model_state_dict_keys])
+self.is_moe_model = any("expert" in key for key in self.model_state_dict_keys)
```

1848-1848: Replace the lambda with a named function per Ruff E731.

Static analysis flags the lambda assignment. Per Python best practices, use a `def` statement for named functions.

🔎 Proposed fix

```diff
-dtensor_post_iter_func = lambda x: x[1]
+def dtensor_post_iter_func(x):
+    return x[1]
```

tests/test_suites/llm/grpo-moonlight-16b-automodel-1n8g-ep8.sh (1)

36-41: Consider parameterizing the step value in metric checks.

The step value `"30"` is hardcoded in the metric checks, but `MAX_STEPS=30` is defined above. If `MAX_STEPS` changes, these would need manual updates.

🔎 Proposed fix

You could interpolate the step value:

```diff
 uv run tests/check_metrics.py $JSON_METRICS \
     'mean(data["train/gen_kl_error"]) < 0.001' \
-    'data["train/gen_kl_error"]["30"] < 0.001 ' \
-    'data["train/reward"]["30"] > 0.4' \
+    "data[\"train/gen_kl_error\"][\"$MAX_STEPS\"] < 0.001" \
+    "data[\"train/reward\"][\"$MAX_STEPS\"] > 0.4" \
     'data["train/grad_norm"] < 0.5' \
     'data["train/grad_norm"] > 0.05'
```
📜 Review details
Configuration used: Path: .coderabbit.yaml
Review profile: CHILL
Plan: Pro
📒 Files selected for processing (3)
- examples/configs/recipes/llm/grpo-moonlight-16b-automodel-1n8g-ep8.yaml
- nemo_rl/models/policy/workers/dtensor_policy_worker_v2.py
- tests/test_suites/llm/grpo-moonlight-16b-automodel-1n8g-ep8.sh
🧰 Additional context used
📓 Path-based instructions (8)
examples/configs/recipes/**/*.yaml
📄 CodeRabbit inference engine (CODING_GUIDELINES.md)
When adding support for a new model, create a recipe YAML under examples/configs/recipes/ in the appropriate domain subdirectory (llm, vlm, etc.)
Files:
examples/configs/recipes/llm/grpo-moonlight-16b-automodel-1n8g-ep8.yaml
examples/configs/recipes/llm/*.yaml
📄 CodeRabbit inference engine (CODING_GUIDELINES.md)
Recipe YAML files should follow the naming pattern: <algo>-<model>-<nodes>n<gpus>g-<strategy-and-params>[-modifiers][-long][.vN].yaml for LLM recipes
Files:
examples/configs/recipes/llm/grpo-moonlight-16b-automodel-1n8g-ep8.yaml
!(**/tests/**|**/test_*.py|**/test_*.sh)
📄 CodeRabbit inference engine (CODING_GUIDELINES.md)
Add the NVIDIA copyright header to all Python files and shell scripts (excluding tests). The header should include the current year
Files:
- examples/configs/recipes/llm/grpo-moonlight-16b-automodel-1n8g-ep8.yaml
- tests/test_suites/llm/grpo-moonlight-16b-automodel-1n8g-ep8.sh
- nemo_rl/models/policy/workers/dtensor_policy_worker_v2.py
**/*.sh
📄 CodeRabbit inference engine (CODING_GUIDELINES.md)
**/*.sh: Use uv run instead of python to execute scripts
Follow the Google Shell Style Guide for shell scripts
Files:
tests/test_suites/llm/grpo-moonlight-16b-automodel-1n8g-ep8.sh
tests/test_suites/**/*.sh
📄 CodeRabbit inference engine (CODING_GUIDELINES.md)
tests/test_suites/**/*.sh: When adding support for a new model, create a corresponding driver shell script under tests/test_suites/ in the matching domain
Driver shell scripts should match the YAML base name with .sh extension and invoke training entrypoint with uv run
Files:
tests/test_suites/llm/grpo-moonlight-16b-automodel-1n8g-ep8.sh
**/*.{py,sh}
📄 CodeRabbit inference engine (CODING_GUIDELINES.md)
The NVIDIA copyright header should appear at the top of all Python files and shell scripts (excluding tests)
Files:
- tests/test_suites/llm/grpo-moonlight-16b-automodel-1n8g-ep8.sh
- nemo_rl/models/policy/workers/dtensor_policy_worker_v2.py
**/*.py
📄 CodeRabbit inference engine (CODING_GUIDELINES.md)
**/*.py: Conform code to Python 3.12+
Indent code with 4 spaces. Do not use tabs
Use snake_case for file names
Use PascalCase for class names
Use snake_case for function and method names
Use snake_case for local variables
Prefix variable names that start with a number with 'k' (e.g., k_99th_percentile)
Use upper snake_case with 'G' prefix for global variables (e.g., G_MY_GLOBAL)
Use upper snake_case for constants
Avoid shadowing variables declared in an outer scope
Initialize all externally visible members of a class in the constructor
Prefer docstrings over comments for interfaces that may be used outside a file
Reserve comments for code within a function or interfaces that are local to a file
If a piece of code is commented out, include a comment describing its usage and why it's commented out. Remove debug comments before merging
Use Google style docstrings for classes and functions in Python, which can be parsed by Sphinx
Avoid using reflection when functionality can be easily achieved without reflection
When using try-except blocks, limit the except clause to the smallest set of specific errors possible
When using try-except blocks for duck-typing, keep the body of the try as small as possible and use the else block for logic
YAML is the single source of truth for configuration defaults. Do not set non-None defaults in code for configuration values
For required configuration attributes, access config directly and expect presence (e.g., policy_cfg['precision']) without hidden defaults
Use typing.NotRequired to mark optional attributes in TypedDict for configuration
When adding a new config key to a TypedDict subclass, document the key's purpose, valid values/types, and recommended default, and reflect the default in exemplar YAMLs under examples/configs/*.yaml
Follow the Google Python Style Guide for Python code
Files:
nemo_rl/models/policy/workers/dtensor_policy_worker_v2.py
nemo_rl/**/*.py
📄 CodeRabbit inference engine (CODING_GUIDELINES.md)
For any source file under nemo_rl/*.py that defines a class or function decorated with @ray.remote, add a coverage pragma (# pragma: no cover) because these run in separate Ray processes
Files:
nemo_rl/models/policy/workers/dtensor_policy_worker_v2.py
🧠 Learnings (6)
📚 Learning: 2025-11-24T17:24:41.976Z
Learnt from: CR
Repo: NVIDIA-NeMo/RL PR: 0
File: CODING_GUIDELINES.md:0-0
Timestamp: 2025-11-24T17:24:41.976Z
Learning: Applies to examples/configs/recipes/llm/*.yaml : Recipe YAML files should follow the naming pattern: <algo>-<model>-<nodes>n<gpus>g-<strategy-and-params>[-modifiers][-long][.vN].yaml for LLM recipes
Applied to files:
examples/configs/recipes/llm/grpo-moonlight-16b-automodel-1n8g-ep8.yaml
📚 Learning: 2025-09-19T03:00:58.662Z
Learnt from: shuo-nvidia
Repo: NVIDIA-NeMo/RL PR: 1006
File: examples/configs/recipes/llm/distillation-qwen3-32b-to-1.7b-base-1n8g-fsdp2tp1.v1.yaml:85-101
Timestamp: 2025-09-19T03:00:58.662Z
Learning: In distillation and GRPO configurations, max_new_tokens is intentionally set to the full context window (max_total_sequence_length) for consistency across the codebase. Overflow cases when prompt + generation tokens exceed max_model_len are handled by safeguards implemented in vllm_worker.py.
Applied to files:
examples/configs/recipes/llm/grpo-moonlight-16b-automodel-1n8g-ep8.yaml
📚 Learning: 2025-10-12T14:46:57.171Z
Learnt from: zpqiu
Repo: NVIDIA-NeMo/RL PR: 1324
File: tests/test_suites/llm/distillation-qwen3-32b-to-1.7b-base-1n8g-megatron-tp2pp2cp2-pack.sh:6-11
Timestamp: 2025-10-12T14:46:57.171Z
Learning: Test scripts in tests/test_suites/llm/ follow a standard configuration pattern that includes NUM_NODES, STEPS_PER_RUN, MAX_STEPS, NUM_RUNS (calculated as `$(( (MAX_STEPS + STEPS_PER_RUN - 1) / STEPS_PER_RUN ))`), and NUM_MINUTES. These variables are part of the test infrastructure's standard interface and should not be flagged as unused even if not directly referenced within the individual script, as they are consumed by external launch tooling or common.env.
Applied to files:
tests/test_suites/llm/grpo-moonlight-16b-automodel-1n8g-ep8.sh
📚 Learning: 2025-10-12T14:46:55.513Z
Learnt from: zpqiu
Repo: NVIDIA-NeMo/RL PR: 1324
File: tests/test_suites/llm/distillation-qwen3-32b-to-1.7b-base-1n8g-megatron-tp2pp2cp2-pack.sh:16-30
Timestamp: 2025-10-12T14:46:55.513Z
Learning: In the NVIDIA-NeMo/RL repository, test scripts under tests/ follow a consistent pattern: use `cd $PROJECT_ROOT` without quotes or error handling, and pass arguments with `$@` unquoted. Maintain this consistency when adding new test scripts.
Applied to files:
tests/test_suites/llm/grpo-moonlight-16b-automodel-1n8g-ep8.sh
📚 Learning: 2025-11-24T17:24:41.976Z
Learnt from: CR
Repo: NVIDIA-NeMo/RL PR: 0
File: CODING_GUIDELINES.md:0-0
Timestamp: 2025-11-24T17:24:41.976Z
Learning: Applies to tests/test_suites/**/*.sh : Driver shell scripts should match the YAML base name with .sh extension and invoke training entrypoint with uv run
Applied to files:
tests/test_suites/llm/grpo-moonlight-16b-automodel-1n8g-ep8.sh
📚 Learning: 2025-10-30T20:50:44.126Z
Learnt from: adil-a
Repo: NVIDIA-NeMo/RL PR: 1440
File: examples/configs/sft_automodel.yaml:48-58
Timestamp: 2025-10-30T20:50:44.126Z
Learning: In DTensor configurations for MoE (Mixture of Experts) models, expert_parallel_size and data_parallel_size can be applied together without multiplying the GPU requirements. Expert Parallelism (EP) only applies to MoE layers, while Data Parallelism/FSDP applies to non-MoE layers. Therefore, configurations like expert_parallel_size: 8 and data_parallel_size: 8 are valid on an 8-GPU cluster for MoE models.
Applied to files:
nemo_rl/models/policy/workers/dtensor_policy_worker_v2.py
🧬 Code graph analysis (1)
nemo_rl/models/policy/workers/dtensor_policy_worker_v2.py (1)
nemo_rl/models/policy/workers/dtensor_policy_worker.py (2)
create_context_parallel_ctx (450-474), dtensor_params_generator (1697-1710)
🪛 Ruff (0.14.10)
nemo_rl/models/policy/workers/dtensor_policy_worker_v2.py
1848-1848: Do not assign a lambda expression, use a def
Rewrite dtensor_post_iter_func as a def
(E731)
🪛 Shellcheck (0.11.0)
tests/test_suites/llm/grpo-moonlight-16b-automodel-1n8g-ep8.sh
[error] 35-35: Couldn't find 'fi' for this 'if'.
(SC1046)
[error] 35-35: Couldn't parse this if expression. Fix to allow more checks.
(SC1073)
[warning] 42-42: Use semicolon or linefeed before 'fi' (or quote to make it literal).
(SC1010)
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (2)
- GitHub Check: Post submodule check comment / Comment on PR
- GitHub Check: Post automodel integration comment / Comment on PR
🔇 Additional comments (11)
examples/configs/recipes/llm/grpo-moonlight-16b-automodel-1n8g-ep8.yaml (1)
27-27: Verify `gate_precision: float64` is intentional.

Using `float64` for gate precision is unusual and may incur performance overhead. Typically `float32` is used. Please verify this is intended for numerical stability in MoE gating.

nemo_rl/models/policy/workers/dtensor_policy_worker_v2.py (9)
105-116: LGTM - Tensor adaptation helper is well-structured.

The function correctly handles the case where no adapter is present by returning the input as-is, and properly returns a list of tuples to handle adapters that may split tensors.
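For readers without the diff at hand, a minimal sketch of the adaptation pattern being described, assuming a module-level helper that takes the adapter explicitly; the `convert_single_tensor_to_hf` call is a placeholder, not necessarily the adapter's real API:

```python
def _maybe_adapt_tensor_to_hf(state_dict_adapter, name, tensor):
    """Map one native (name, tensor) pair to a list of HF-format pairs.

    With no adapter configured, the tensor passes through unchanged; an
    adapter may rename a tensor or split it into several HF tensors
    (e.g. fused expert weights), hence the list of tuples.
    """
    if state_dict_adapter is None:
        return [(name, tensor)]
    # Placeholder call: the real adapter interface may differ.
    return state_dict_adapter.convert_single_tensor_to_hf(name, tensor)
```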
119-145: LGTM - Unified context manager consolidates CP and autocast handling.

The `get_train_context` function correctly combines context parallel and autocast contexts using `ExitStack`, making the code DRY across train/inference paths. The conditional autocast enables the fix for custom MoE models.
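As a rough illustration of that pattern only (not the PR's exact signature; `cp_ctx` and `autocast_enabled` stand in for the worker's real state):

```python
import contextlib

import torch


@contextlib.contextmanager
def get_train_context(cp_ctx=None, autocast_enabled=True, dtype=torch.bfloat16):
    """Combine optional context-parallel and autocast contexts in one place."""
    with contextlib.ExitStack() as stack:
        if cp_ctx is not None:
            # Context-parallel sharding context, when CP is active.
            stack.enter_context(cp_ctx)
        if autocast_enabled:
            # Skipped for custom MoE models that manage precision themselves.
            stack.enter_context(torch.autocast(device_type="cuda", dtype=dtype))
        yield
```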
770-808: LGTM - Training context refactored to use unified helper.

The training path now correctly uses `get_train_context` for consistent CP and autocast handling. The conditional autocast based on `self.autocast_enabled` properly addresses the numerical issues with custom MoE models.
1105-1134: LGTM - Consistent refactoring in get_logprobs.

The `get_logprobs` method follows the same pattern as `train`, using the unified `get_train_context` helper for CP and autocast management.
1364-1382: LGTM - Score method updated consistently.

The `score` method follows the same unified context pattern.
1516-1534: LGTM - get_topk_logits updated consistently.
1738-1751: LGTM - Refit info preparation adapted for HF format.

The `prepare_refit_info` method now correctly uses `_maybe_adapt_tensor_to_hf` to ensure the state dict metadata matches the adapted tensor format that will be streamed.
1788-1802: LGTM - IPC weight streaming adapted for HF format.

The weight streaming generator now correctly adapts tensors to HF format before yielding, ensuring consistency with the metadata prepared in `prepare_refit_info`.
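Schematically, the "adapt before yielding" pattern looks like the sketch below (hypothetical names; `_maybe_adapt_tensor_to_hf` refers to the helper sketched earlier):

```python
def hf_params_generator(model_state_dict, state_dict_adapter, dtype):
    """Yield HF-format (name, tensor) pairs from a possibly sharded state dict."""
    for name, tensor in model_state_dict.items():
        if hasattr(tensor, "full_tensor"):
            tensor = tensor.full_tensor()  # gather DTensor shards
        tensor = tensor.to(dtype, non_blocking=True)
        # One native tensor may map to one or more HF tensors.
        yield from _maybe_adapt_tensor_to_hf(state_dict_adapter, name, tensor)
```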
1831-1855: LGTM - Broadcast weights adapted for HF format.

The collective broadcast path now correctly applies HF tensor adaptation, consistent with the IPC streaming path.
tests/test_suites/llm/grpo-moonlight-16b-automodel-1n8g-ep8.sh (1)
1-14: LGTM - Standard test script configuration.

The script follows the standard test infrastructure pattern with `NUM_NODES`, `STEPS_PER_RUN`, `MAX_STEPS`, `NUM_RUNS`, and `NUM_MINUTES` variables. Uses `uv run` as required per coding guidelines. Based on learnings, these variables are consumed by external launch tooling.
```sh
if [[ $(jq 'to_entries | .[] | select(.key == "train/loss") | .value | keys | map(tonumber) | max' $JSON_METRICS) -ge $MAX_STEPS ]]; then
    uv run tests/check_metrics.py $JSON_METRICS \
        'mean(data["train/gen_kl_error"]) < 0.001' \
        'data["train/gen_kl_error"]["30"] < 0.001 ' \
        'data["train/reward"]["30"] > 0.4' \
        'data["train/grad_norm"] < 0.5' \
        'data["train/grad_norm"] > 0.05' \
fi
```
🧩 Analysis chain
🏁 Script executed:
bash -n tests/test_suites/llm/grpo-moonlight-16b-automodel-1n8g-ep8.sh 2>&1 | head -50

Repository: NVIDIA-NeMo/RL
Length of output: 167
🏁 Script executed:
cat -n tests/test_suites/llm/grpo-moonlight-16b-automodel-1n8g-ep8.sh | sed -n '30,45p'

Repository: NVIDIA-NeMo/RL
Length of output: 733
🏁 Script executed:
cat -n tests/test_suites/llm/grpo-moonlight-16b-automodel-1n8g-ep8.sh | sed -n '35,43p'

Repository: NVIDIA-NeMo/RL
Length of output: 541
Remove trailing backslash from line 41.
The trailing backslash at the end of line 41 signals shell line continuation, but line 42 contains fi (the if-block terminator) instead of command arguments. Remove the backslash from 'data["train/grad_norm"] > 0.05' \ to fix the syntax error.
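With only that backslash removed, the end of the block would read roughly:

```sh
    uv run tests/check_metrics.py $JSON_METRICS \
        'mean(data["train/gen_kl_error"]) < 0.001' \
        'data["train/gen_kl_error"]["30"] < 0.001 ' \
        'data["train/reward"]["30"] > 0.4' \
        'data["train/grad_norm"] < 0.5' \
        'data["train/grad_norm"] > 0.05'
fi
```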
🧰 Tools
🪛 Shellcheck (0.11.0)
[error] 35-35: Couldn't find 'fi' for this 'if'.
(SC1046)
[error] 35-35: Couldn't parse this if expression. Fix to allow more checks.
(SC1073)
[warning] 42-42: Use semicolon or linefeed before 'fi' (or quote to make it literal).
(SC1010)
🤖 Prompt for AI Agents
In tests/test_suites/llm/grpo-moonlight-16b-automodel-1n8g-ep8.sh around lines
35 to 42, there is a trailing backslash at the end of the line containing
'data["train/grad_norm"] > 0.05' which incorrectly escapes the newline and makes
the following fi token part of the continued command; remove that trailing
backslash so the uv run command ends after the last condition and the if-block
can close with fi.
f28ef58 to 7177b19 (Compare)

64c220d to 14dc6fd (Compare)
yuki-97 left a comment
thanks for supporting this! left some comments.
```python
tensor = tensor.full_tensor()
tensor = tensor.to(dtype, non_blocking=True)
return tensor

def dtensor_params_generator():
```
this function looks exactly the same as the one at line 1788; can we move it into a class method and reuse it?
Fixed in 58ca35d
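For context, the deduplication could look roughly like the sketch below; the names and exact behavior are illustrative and may not match what landed in 58ca35d:

```python
class DTensorPolicyWorkerV2:  # pragma: no cover - runs in a Ray worker process
    def _dtensor_params_generator(self, params_iter, dtype):
        """Yield (name, tensor) pairs with DTensors gathered and cast to dtype."""
        for name, tensor in params_iter:
            if hasattr(tensor, "full_tensor"):
                tensor = tensor.full_tensor()  # gather shards into a full tensor
            yield name, tensor.to(dtype, non_blocking=True)
```

Both call sites (IPC streaming and collective broadcast) could then reuse this method instead of defining local closures.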
```yaml
rms_norm: te
enable_deepep: true
fake_balanced_gate: false
enable_hf_state_dict_adapter: true
```
I saw the new function _maybe_adapt_tensor_to_hf may also use state_dict_adapter.
just curious, when do we need to set enable_hf_state_dict_adapter to true? and do we need to update the comment here since it seems not only for checkpoint loading now?
RL/nemo_rl/models/policy/__init__.py, lines 50 to 51 in 1720466:

```python
# Enable HuggingFace state dict adapter for checkpoint loading
enable_hf_state_dict_adapter: NotRequired[bool]
```
Yes, I will update the comment. You only need to enable it when using a custom MoE implementation with a backend set in the automodel kwargs in your config. Alternatively, you can pass `force_hf: true` in the `automodel_kwargs`, and that falls back to the HF implementation.
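In recipe YAML terms, the two options would look roughly like this (illustrative only; exact key placement depends on the config schema, and `deepep` is just a placeholder backend name):

```yaml
# Option 1: custom MoE implementation selected via a backend in automodel kwargs;
# the HF state dict adapter is then needed so refit/weight export stays in HF format.
automodel_kwargs:
  backend: deepep
enable_hf_state_dict_adapter: true

# Option 2 (alternative): fall back to the stock HF implementation instead.
# automodel_kwargs:
#   force_hf: true
```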
Added comment in 58ca35d
> You only need to enable it when using a custom MoE implementation with a backend set in the automodel kwargs in your config. Alternatively, you can pass `force_hf: true` in the `automodel_kwargs`, and that falls back to the HF implementation.

Thanks for the explanation! Can you also add this to the code comment above, so that people know when to enable it?
Moonlight 16b run
Comparison to Megatron
Summary by CodeRabbit
New Features
Chores