feat: Megatron SFT LoRA #1629
base: main
Conversation
Signed-off-by: adithyare <adithyare@nvidia.com>
…s w.o crashing. WIP check correctness Signed-off-by: adithyare <adithyare@nvidia.com>
terrykong
left a comment
@yaoyu-33 to review
terrykong
left a comment
📝 Walkthrough

This pull request introduces LoRA/PEFT integration for Megatron policy workers with conditional pre-wrap hooking and checkpoint-aware loading. Changes include updated documentation clarifying DTensor v2 as the default, Megatron-specific LoRA configuration alongside DTensor, example configurations, implementation changes to the policy worker, and functional testing for the LoRA SFT workflow.
Sequence Diagram(s)

sequenceDiagram
participant PolicyWorker as Policy Worker
participant Config as Policy Config
participant PEFTHook as PEFT Hook<br/>Manager
participant Model as Megatron<br/>Model
participant Checkpoint as Checkpoint<br/>Loader
PolicyWorker->>Config: Load policy.lora_cfg
alt LoRA Enabled
Config-->>PolicyWorker: lora_cfg config
PolicyWorker->>PEFTHook: Create LoRA config<br/>from lora_cfg
PEFTHook->>PEFTHook: Generate<br/>_create_peft_pre_wrap_hook()
PEFTHook->>PEFTHook: Compose peft_hook<br/>with pre-wrap behavior
PolicyWorker->>Model: Construct with<br/>pre_wrap_hook=peft_hook
else LoRA Disabled
Config-->>PolicyWorker: lora_cfg = None
PEFTHook-->>PolicyWorker: peft_hook = []
PolicyWorker->>Model: Construct with<br/>pre_wrap_hook=[]
end
Model-->>PolicyWorker: Model created
alt should_load_checkpoint
PolicyWorker->>Checkpoint: Check checkpoint<br/>availability
Checkpoint-->>PolicyWorker: Checkpoint found
alt LoRA Enabled
PolicyWorker->>Checkpoint: Set finetune=False<br/>for state loading
end
Checkpoint->>Checkpoint: Load checkpoint
else
Checkpoint-->>PolicyWorker: No checkpoint
end
Estimated code review effort: 🎯 3 (Moderate) | ⏱️ ~20 minutes
Pre-merge checks and finishing touches: ❌ Failed checks (1 warning)
✅ Passed checks (3 passed)
Actionable comments posted: 4
📜 Review details
Configuration used: Path: .coderabbit.yaml
Review profile: CHILL
Plan: Pro
📒 Files selected for processing (4)
- docs/guides/sft.md (3 hunks)
- examples/configs/sft.yaml (1 hunk)
- nemo_rl/models/policy/workers/megatron_policy_worker.py (6 hunks)
- tests/functional/test_mbridge_lora_sft.sh (1 hunk)
🧰 Additional context used
📓 Path-based instructions (6)
docs/**/*.md
📄 CodeRabbit inference engine (CODING_GUIDELINES.md)
Update docs/index.md when a new markdown doc is added under docs/**/*.md or a markdown file is renamed, ensuring the document appears in the most appropriate section
Files:
docs/guides/sft.md
!(**/tests/**|**/test_*.py|**/test_*.sh)
📄 CodeRabbit inference engine (CODING_GUIDELINES.md)
Add the NVIDIA copyright header to all Python files and shell scripts (excluding tests). The header should include the current year
Files:
docs/guides/sft.md
examples/configs/sft.yaml
tests/functional/test_mbridge_lora_sft.sh
nemo_rl/models/policy/workers/megatron_policy_worker.py
**/*.sh
📄 CodeRabbit inference engine (CODING_GUIDELINES.md)
**/*.sh: Use uv run instead of python to execute scripts
Follow the Google Shell Style Guide for shell scripts
Files:
tests/functional/test_mbridge_lora_sft.sh
**/*.{py,sh}
📄 CodeRabbit inference engine (CODING_GUIDELINES.md)
The NVIDIA copyright header should appear at the top of all Python files and shell scripts (excluding tests)
Files:
tests/functional/test_mbridge_lora_sft.sh
nemo_rl/models/policy/workers/megatron_policy_worker.py
**/*.py
📄 CodeRabbit inference engine (CODING_GUIDELINES.md)
**/*.py: Conform code to Python 3.12+
Indent code with 4 spaces. Do not use tabs
Use snake_case for file names
Use PascalCase for class names
Use snake_case for function and method names
Use snake_case for local variables
Prefix variable names that start with a number with 'k' (e.g., k_99th_percentile)
Use upper snake_case with 'G' prefix for global variables (e.g., G_MY_GLOBAL)
Use upper snake_case for constants
Avoid shadowing variables declared in an outer scope
Initialize all externally visible members of a class in the constructor
Prefer docstrings over comments for interfaces that may be used outside a file
Reserve comments for code within a function or interfaces that are local to a file
If a piece of code is commented out, include a comment describing its usage and why it's commented out. Remove debug comments before merging
Use Google style docstrings for classes and functions in Python, which can be parsed by Sphinx
Avoid using reflection when functionality can be easily achieved without reflection
When using try-except blocks, limit the except clause to the smallest set of specific errors possible
When using try-except blocks for duck-typing, keep the body of the try as small as possible and use the else block for logic
YAML is the single source of truth for configuration defaults. Do not set non-None defaults in code for configuration values
For required configuration attributes, access config directly and expect presence (e.g., policy_cfg['precision']) without hidden defaults
Use typing.NotRequired to mark optional attributes in TypedDict for configuration
When adding a new config key to a TypedDict subclass, document the key's purpose, valid values/types, and recommended default, and reflect the default in exemplar YAMLs under examples/configs/*.yaml
Follow the Google Python Style Guide for Python code
Files:
nemo_rl/models/policy/workers/megatron_policy_worker.py
nemo_rl/**/*.py
📄 CodeRabbit inference engine (CODING_GUIDELINES.md)
For any source file under nemo_rl/*.py that defines a class or function decorated with @ray.remote, add a coverage pragma (# pragma: no cover) because these run in separate Ray processes
Files:
nemo_rl/models/policy/workers/megatron_policy_worker.py
🧠 Learnings (3)
📚 Learning: 2025-09-19T07:28:29.887Z
Learnt from: shuo-nvidia
Repo: NVIDIA-NeMo/RL PR: 1006
File: tests/test_suites/llm/distillation-qwen3-32b-to-4b-base-2n8g-fsdp2tp2-long.v1.sh:1-4
Timestamp: 2025-09-19T07:28:29.887Z
Learning: The NVIDIA-NeMo/RL project prefers to maintain consistent formatting across test scripts rather than applying individual bash hardening improvements like `set -euo pipefail` or proper quoting for sourcing files.
Applied to files:
tests/functional/test_mbridge_lora_sft.sh
📚 Learning: 2025-11-24T17:24:41.976Z
Learnt from: CR
Repo: NVIDIA-NeMo/RL PR: 0
File: CODING_GUIDELINES.md:0-0
Timestamp: 2025-11-24T17:24:41.976Z
Learning: Applies to tests/test_suites/**/*.sh : Driver shell scripts should match the YAML base name with .sh extension and invoke training entrypoint with uv run
Applied to files:
tests/functional/test_mbridge_lora_sft.sh
📚 Learning: 2025-10-12T14:46:55.513Z
Learnt from: zpqiu
Repo: NVIDIA-NeMo/RL PR: 1324
File: tests/test_suites/llm/distillation-qwen3-32b-to-1.7b-base-1n8g-megatron-tp2pp2cp2-pack.sh:16-30
Timestamp: 2025-10-12T14:46:55.513Z
Learning: In the NVIDIA-NeMo/RL repository, test scripts under tests/ follow a consistent pattern: use `cd $PROJECT_ROOT` without quotes or error handling, and pass arguments with `$@` unquoted. Maintain this consistency when adding new test scripts.
Applied to files:
tests/functional/test_mbridge_lora_sft.sh
🪛 markdownlint-cli2 (0.18.1)
docs/guides/sft.md
233-233: Multiple headings with the same content
(MD024, no-duplicate-heading)
236-236: Unordered list indentation
Expected: 2; Actual: 4
(MD007, ul-indent)
237-237: Unordered list indentation
Expected: 2; Actual: 4
(MD007, ul-indent)
238-238: Unordered list indentation
Expected: 2; Actual: 4
(MD007, ul-indent)
239-239: Unordered list indentation
Expected: 2; Actual: 4
(MD007, ul-indent)
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (2)
- GitHub Check: Post automodel integration comment / Comment on PR
- GitHub Check: Post submodule check comment / Comment on PR
🔇 Additional comments (10)
tests/functional/test_mbridge_lora_sft.sh (1)
1-51: LGTM! Well-structured functional test for Megatron LoRA SFT.

The test script correctly:
- Uses `uv run` per coding guidelines
- Implements proper error handling with `set -eou pipefail`
- Sets up a cleanup trap for temporary checkpoints
- Exercises the new Megatron LoRA configuration (`policy.megatron_cfg.lora_cfg.enabled=true`)
- Validates training metrics with reasonable thresholds
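For orientation, a minimal sketch of the pattern those bullets describe; the entrypoint path, the non-LoRA config keys, and the variable names are assumptions, not the actual contents of tests/functional/test_mbridge_lora_sft.sh:

```bash
#!/bin/bash
# Sketch only: script path and overrides other than the LoRA toggle are illustrative assumptions.
set -eou pipefail

CKPT_DIR=$(mktemp -d)
trap 'rm -rf "$CKPT_DIR"' EXIT   # clean up temporary checkpoints on exit

uv run examples/run_sft.py \
    policy.megatron_cfg.lora_cfg.enabled=true \
    checkpointing.checkpoint_dir="$CKPT_DIR" \
    sft.max_num_steps=50
```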
examples/configs/sft.yaml (2)
117-129: LGTM! Comprehensive Megatron LoRA configuration.

The configuration block:
- Includes all necessary LoRA parameters
- Maintains `enabled: false` by default to preserve existing behavior (as per past review feedback)
- Aligns with the DTensor LoRA configuration structure
- Provides clear documentation for each parameter
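For readers without the diff at hand, a rough sketch of what such a block could look like under `policy.megatron_cfg` in examples/configs/sft.yaml; only the `enabled: false` default and the config path are taken from this review, the remaining keys and values are illustrative assumptions:

```yaml
megatron_cfg:
  # ... existing Megatron settings ...
  lora_cfg:
    enabled: false            # off by default so existing full-finetune behavior is unchanged
    # The knobs below are assumptions, shown only to indicate the kind of parameters involved.
    dim: 8                    # LoRA rank
    alpha: 16                 # LoRA scaling factor
    dropout: 0.0
    target_modules: ["linear_qkv", "linear_proj"]
```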
132-135: Helpful clarification about AdamW usage.

The added comments correctly note that when `weight_decay` is set, the optimizer effectively uses AdamW. This helps users understand the actual optimizer behavior.

nemo_rl/models/policy/workers/megatron_policy_worker.py (4)
26-26: LGTM! Necessary imports for LoRA integration.

The imports correctly bring in:
- `LoRA` for PEFT configuration
- `_create_peft_pre_wrap_hook` for model wrapping
- `MegatronModule` for type annotations

Also applies to: 54-54, 90-90
280-295: LGTM! Correct LoRA configuration extraction.

The code properly:
- Checks if LoRA is enabled via policy configuration
- Extracts all required LoRA parameters
- Instantiates the LoRA PEFT configuration
- Assigns it to `cfg.peft` for use in model setup
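A minimal sketch of the extraction flow described above, reusing the `LoRA` class mentioned in the import comment; the dictionary keys, keyword arguments, and variable names are assumptions rather than the worker's actual identifiers:

```python
# Sketch only: identifiers and keyword names are illustrative assumptions, not the PR's real code.
lora_cfg = policy_cfg["megatron_cfg"].get("lora_cfg")
if lora_cfg is not None and lora_cfg["enabled"]:
    peft_config = LoRA(                     # LoRA PEFT configuration object
        dim=lora_cfg["dim"],
        alpha=lora_cfg["alpha"],
        dropout=lora_cfg["dropout"],
        target_modules=lora_cfg["target_modules"],
    )
    cfg.peft = peft_config                  # consumed during Megatron model setup
else:
    peft_config = None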
297-305: LGTM! Correct PEFT pre-wrap hook composition.

The code properly:
- Creates the PEFT pre-wrap hook using the configuration
- Registers it with the model configuration
- Composes it into a callable hook for model construction
- Falls back to an empty list when PEFT is not configured
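Correspondingly, a sketch of the hook plumbing these bullets describe; the exact signatures of `_create_peft_pre_wrap_hook` and `get_model` are assumptions:

```python
# Sketch only: signatures are illustrative assumptions.
if peft_config is not None:
    # The pre-wrap hook applies LoRA adapters to the bare Megatron modules
    # before they are wrapped for distributed execution.
    peft_hook = [_create_peft_pre_wrap_hook(peft_config)]
else:
    peft_hook = []  # no-op when PEFT is not configured

model = get_model(cfg, pre_wrap_hook=peft_hook)  # hook runs during model construction
```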
314-314: LGTM! Correct integration of PEFT hook into model construction.

The `peft_hook` is properly passed as the `pre_wrap_hook` parameter to `get_model`, ensuring the PEFT transformations are applied during model instantiation.

docs/guides/sft.md (3)
171-172: LGTM! Clear documentation of LoRA backend support.

The notes correctly clarify:
- DTensor v2 and Megatron backends support LoRA
- DTensor v1 does not support LoRA
- Triton kernel usage details
174-210: LGTM! Comprehensive DTensor LoRA documentation.

The section provides:
- Clear configuration example with all parameters
- Detailed parameter descriptions
- Usage example with command-line override
- Important note about Triton kernel compatibility with TP > 1
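As one concrete illustration of the command-line override usage mentioned above (the entrypoint and the exact DTensor LoRA key are assumptions, not quoted from the guide):

```bash
# Hypothetical override enabling LoRA on the DTensor v2 path; verify key names against docs/guides/sft.md.
uv run examples/run_sft.py \
    policy.dtensor_cfg.enabled=true \
    policy.dtensor_cfg.lora_cfg.enabled=true
```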
252-259: LGTM! Clear example of enabling Megatron LoRA.

The example correctly demonstrates:
- Disabling DTensor backend
- Enabling Megatron backend
- Enabling Megatron LoRA configuration
This helps users understand the backend switching requirements.
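Spelled out as overrides (the entrypoint name is an assumption; the three toggles follow directly from the bullets above):

```bash
uv run examples/run_sft.py \
    policy.dtensor_cfg.enabled=false \
    policy.megatron_cfg.enabled=true \
    policy.megatron_cfg.lora_cfg.enabled=true
```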
Signed-off-by: Yuki Huang <yukih@nvidia.com>
Signed-off-by: Yuki Huang <yukih@nvidia.com> Co-authored-by: Yuki Huang <yukih@nvidia.com>
examples/configs/recipes/llm/sft-llama3.1-8b-1n8g-fsdp2tp1-lora-megatron.yaml
Outdated
thanks for the update, LGTM.
As discussed in the meeting, we still have a perf issue with the tulu3 dataset; it's tracked at #1719.
# Only run metrics if the target step is reached
if [[ $(jq 'to_entries | .[] | select(.key == "train/loss") | .value | keys | map(tonumber) | max' $JSON_METRICS) -ge $MAX_STEPS ]]; then
  uv run tests/check_metrics.py $JSON_METRICS \
    'data["train/loss"]["1"] < 1.0' \
    'data["train/loss"]["50"] < 0.8' \
    'max(data["ray/node.0.gpu.0.mem_gb"]) < 50' \
    'mean(data["timing/train/total_step_time"], 2) < 10'
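For context on the guard above: the jq pipeline pulls the step keys recorded under `train/loss` and returns the largest one, so the metric checks only run once training has actually reached `MAX_STEPS`. A standalone illustration with made-up data:

```bash
echo '{"train/loss": {"1": 0.95, "50": 0.71}, "other/metric": {"1": 2.0}}' \
  | jq 'to_entries | .[] | select(.key == "train/loss") | .value | keys | map(tonumber) | max'
# prints 50 — the latest step for which train/loss was logged
```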
According to the latest test, max(data["ray/node.0.gpu.0.mem_gb"]) = 52.06640625 and mean(data["timing/train/total_step_time"], 2) = 22.511235587450923.
Suggested change (relaxing the memory and step-time thresholds):

# Revert to `mean(data["timing/train/total_step_time"], 2) < 30` once https://github.com/NVIDIA-NeMo/RL/issues/1719 resolved
# Only run metrics if the target step is reached
if [[ $(jq 'to_entries | .[] | select(.key == "train/loss") | .value | keys | map(tonumber) | max' $JSON_METRICS) -ge $MAX_STEPS ]]; then
  uv run tests/check_metrics.py $JSON_METRICS \
    'data["train/loss"]["1"] < 1.0' \
    'data["train/loss"]["50"] < 0.8' \
    'max(data["ray/node.0.gpu.0.mem_gb"]) < 60' \
    'mean(data["timing/train/total_step_time"], 2) < 30'
What does this PR do?
Adds Llama LoRA SFT support via Megatron Bridge. To verify correctness, we compared train and validation curves with the LoRA SFT DTensor path on the squad and tulu3 datasets. The wandb link for the verification runs
Issues
List issues that this PR closes (syntax):
Usage
Before your PR is "Ready for review"
Pre checks:
Additional Information
Summary by CodeRabbit
Documentation
New Features
Tests