-
Notifications
You must be signed in to change notification settings - Fork 208
feat: Add attention_backend config support for Megatron policy #1628
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
feat: Add attention_backend config support for Megatron policy #1628
Conversation
📝 WalkthroughWalkthroughAdds optional wiring to configure Megatron's attention backend in the policy worker initialization. When Changes
Estimated code review effort🎯 3 (Moderate) | ⏱️ ~15–20 minutes
Pre-merge checks and finishing touches❌ Failed checks (1 warning)
✅ Passed checks (3 passed)
✨ Finishing touches
🧪 Generate unit tests (beta)
Thanks for using CodeRabbit! It's free for OSS, and your support helps us grow. If you like it, consider giving us a shout-out. Comment |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Actionable comments posted: 1
📜 Review details
Configuration used: Path: .coderabbit.yaml
Review profile: CHILL
Plan: Pro
📒 Files selected for processing (1)
nemo_rl/models/policy/workers/megatron_policy_worker.py(1 hunks)
🧰 Additional context used
📓 Path-based instructions (4)
**/*.py
📄 CodeRabbit inference engine (CODING_GUIDELINES.md)
**/*.py: Conform code to Python 3.12+
Indent code with 4 spaces. Do not use tabs
Use snake_case for file names
Use PascalCase for class names
Use snake_case for function and method names
Use snake_case for local variables
Prefix variable names that start with a number with 'k' (e.g., k_99th_percentile)
Use upper snake_case with 'G' prefix for global variables (e.g., G_MY_GLOBAL)
Use upper snake_case for constants
Avoid shadowing variables declared in an outer scope
Initialize all externally visible members of a class in the constructor
Prefer docstrings over comments for interfaces that may be used outside a file
Reserve comments for code within a function or interfaces that are local to a file
If a piece of code is commented out, include a comment describing its usage and why it's commented out. Remove debug comments before merging
Use Google style docstrings for classes and functions in Python, which can be parsed by Sphinx
Avoid using reflection when functionality can be easily achieved without reflection
When using try-except blocks, limit the except clause to the smallest set of specific errors possible
When using try-except blocks for duck-typing, keep the body of the try as small as possible and use the else block for logic
YAML is the single source of truth for configuration defaults. Do not set non-None defaults in code for configuration values
For required configuration attributes, access config directly and expect presence (e.g., policy_cfg['precision']) without hidden defaults
Use typing.NotRequired to mark optional attributes in TypedDict for configuration
When adding a new config key to a TypedDict subclass, document the key's purpose, valid values/types, and recommended default, and reflect the default in exemplar YAMLs under examples/configs/*.yaml
Follow the Google Python Style Guide for Python code
Files:
nemo_rl/models/policy/workers/megatron_policy_worker.py
nemo_rl/**/*.py
📄 CodeRabbit inference engine (CODING_GUIDELINES.md)
For any source file under nemo_rl/*.py that defines a class or function decorated with @ray.remote, add a coverage pragma (# pragma: no cover) because these run in separate Ray processes
Files:
nemo_rl/models/policy/workers/megatron_policy_worker.py
!(**/tests/**|**/test_*.py|**/test_*.sh)
📄 CodeRabbit inference engine (CODING_GUIDELINES.md)
Add the NVIDIA copyright header to all Python files and shell scripts (excluding tests). The header should include the current year
Files:
nemo_rl/models/policy/workers/megatron_policy_worker.py
**/*.{py,sh}
📄 CodeRabbit inference engine (CODING_GUIDELINES.md)
The NVIDIA copyright header should appear at the top of all Python files and shell scripts (excluding tests)
Files:
nemo_rl/models/policy/workers/megatron_policy_worker.py
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (2)
- GitHub Check: Post submodule check comment / Comment on PR
- GitHub Check: Post automodel integration comment / Comment on PR
| if "attention_backend" in self.cfg["megatron_cfg"]: | ||
| from megatron.core.transformer.enums import AttnBackend | ||
| attention_backend_str = self.cfg["megatron_cfg"]["attention_backend"] | ||
| model_cfg.attention_backend = AttnBackend[attention_backend_str] |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
🧩 Analysis chain
🏁 Script executed:
#!/bin/bash
# Search for AttnBackend usage and imports
rg "AttnBackend" --type py -B 2 -A 2Repository: NVIDIA-NeMo/RL
Length of output: 802
🏁 Script executed:
#!/bin/bash
# Look for attention_backend configuration examples in YAML files
find . -name "*.yaml" -o -name "*.yml" | xargs grep -l "attention_backend" 2>/dev/nullRepository: NVIDIA-NeMo/RL
Length of output: 42
🏁 Script executed:
#!/bin/bash
# Check if megatron is imported or available in the project dependencies
fd "requirements" | xargs cat 2>/dev/null | grep -i megatronRepository: NVIDIA-NeMo/RL
Length of output: 40
🏁 Script executed:
#!/bin/bash
# Search for example configs with attention_backend
fd "\.yaml$" -o "\.yml$" | xargs grep -H "attention_backend" 2>/dev/null | head -20Repository: NVIDIA-NeMo/RL
Length of output: 189
🌐 Web query:
megatron-core AttnBackend enum member names values
💡 Result:
AttnBackend enum members (megatron-core) — names and integer values:
- auto = 5. [1]
- flash = 1. [1]
- fused = 2. [1]
- local = 4. [1]
- unfused = 3. [1]
Sources: Megatron Core transformer enums documentation. [1]
🏁 Script executed:
#!/bin/bash
# Search for example configs with attention_backend
find . -type f \( -name "*.yaml" -o -name "*.yml" \) -exec grep -l "attention_backend" {} \;Repository: NVIDIA-NeMo/RL
Length of output: 40
🏁 Script executed:
#!/bin/bash
# Search in examples and tests directories
find ./examples -type f \( -name "*.yaml" -o -name "*.yml" \) 2>/dev/null | xargs grep -H "attention_backend" 2>/dev/nullRepository: NVIDIA-NeMo/RL
Length of output: 42
🏁 Script executed:
#!/bin/bash
# Check test files for megatron_policy_worker
find . -type f -name "*test*.py" -exec grep -l "megatron_policy_worker\|attention_backend" {} \;Repository: NVIDIA-NeMo/RL
Length of output: 40
🏁 Script executed:
#!/bin/bash
# Examine the megatron_policy_worker.py file context around lines 628-631
sed -n '620,640p' nemo_rl/models/policy/workers/megatron_policy_worker.pyRepository: NVIDIA-NeMo/RL
Length of output: 1165
🏁 Script executed:
#!/bin/bash
# Search for schema definitions or TypedDict for megatron_cfg
rg "megatron_cfg" --type py -B 3 -A 3 | grep -A 5 -B 5 "TypedDict\|class.*Config"Repository: NVIDIA-NeMo/RL
Length of output: 40
🏁 Script executed:
#!/bin/bash
# Check for any validation or schema files
find . -type f \( -name "*schema*" -o -name "*config*.py" \) | head -20Repository: NVIDIA-NeMo/RL
Length of output: 259
🏁 Script executed:
#!/bin/bash
# Look for get() usage pattern with attention_backend or similar enum config
rg "attention_backend\|AttnBackend" --type py -B 5 -A 5Repository: NVIDIA-NeMo/RL
Length of output: 40
🏁 Script executed:
#!/bin/bash
# Check if there are any comments or docstrings about attention_backend in the file
sed -n '600,650p' nemo_rl/models/policy/workers/megatron_policy_worker.py | head -50Repository: NVIDIA-NeMo/RL
Length of output: 2771
🏁 Script executed:
#!/bin/bash
# Check the config validation module
cat nemo_rl/utils/config.py | head -100Repository: NVIDIA-NeMo/RL
Length of output: 3196
🏁 Script executed:
#!/bin/bash
# Search for megatron_cfg configuration definition
rg "megatron_cfg" --type py -B 2 -A 2 | grep -A 10 "TypedDict\|class\|Dict" | head -40Repository: NVIDIA-NeMo/RL
Length of output: 40
🌐 Web query:
megatron core AttnBackend enum member names case sensitivity Python
💡 Result:
Short answer: Enum member names are case‑sensitive in Python, so you must use the exact AttnBackend member names defined by Megatron Core. The Megatron Core docs list AttnBackend members (auto, flash, fused, local, unfused). Use those exact names (or access by value) or normalize your input before lookup. [1][2]
References:
- Megatron Core API — AttnBackend members. [1]
- Python enum documentation (case sensitivity / name lookup). [2]
Consider hardening attention_backend lookup to handle null values and provide clearer error messages.
The current code uses AttnBackend[attention_backend_str] which relies on enum member name lookup. While the enum members are lowercase (flash, fused, local, unfused, auto) matching expected YAML values, two edge cases could be improved:
- If YAML provides
attention_backend: null, the guard passes but the lookup fails with a bare KeyError - Invalid member names produce unclear error messages
Consider using .get() with explicit None check and wrapping the enum lookup in a try-except for clearer error reporting:
if "attention_backend" in self.cfg["megatron_cfg"]:
from megatron.core.transformer.enums import AttnBackend
- attention_backend_str = self.cfg["megatron_cfg"]["attention_backend"]
- model_cfg.attention_backend = AttnBackend[attention_backend_str]
+ attention_backend_str = self.cfg["megatron_cfg"].get("attention_backend")
+ if attention_backend_str is not None:
+ try:
+ model_cfg.attention_backend = AttnBackend[attention_backend_str]
+ except KeyError as exc:
+ allowed = ", ".join(m.name for m in AttnBackend)
+ raise ValueError(
+ f"Invalid megatron_cfg.attention_backend={attention_backend_str!r}. Allowed: {allowed}"
+ ) from exc🤖 Prompt for AI Agents
In nemo_rl/models/policy/workers/megatron_policy_worker.py around lines 628 to
631, the code directly indexes AttnBackend with attention_backend_str which will
raise a raw KeyError when the YAML provides null/empty or an invalid name;
update the block to first check for None/empty and skip or handle accordingly,
then wrap the enum lookup in a try/except catching KeyError (or ValueError) and
raise/log a clearer error that includes the provided value and the list of
allowed enum members (optionally normalize case with .lower() before lookup if
desired).
284865b to
5b401cc
Compare
Signed-off-by: Sahger Lad <lad.sahger@gmail.com>
Adds support for testing attention_backend config in megatron_cfg. Includes a new test case that verifies attention_backend='flash' works with the Megatron policy worker. Signed-off-by: Sahger Lad <lad.sahger@gmail.com>
7884009 to
9a97ffe
Compare
What does this PR do ?
Enable configuring the attention backend (flash, fused, unfused, local, auto) via megatron_cfg.attention_backend in the YAML configuration.
policy.megatron_cfg.attention_backendin YAML configurationflash,fused,unfused,local, orautoattention implementationsUsage
Summary by CodeRabbit
✏️ Tip: You can customize this high-level summary in your review settings.