131 changes: 4 additions & 127 deletions nemo_rl/models/megatron/common.py
@@ -11,27 +11,17 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from functools import partial
from typing import Any, Iterator, Optional

from typing import Optional, Any

import torch
import torch.distributed as dist
from megatron.bridge.training.state import GlobalState
from megatron.core.models.gpt import GPTModel
from megatron.core.parallel_state import (
get_context_parallel_group,
get_context_parallel_world_size,
get_tensor_model_parallel_group,
get_tensor_model_parallel_rank,
)

from megatron.core.transformer.moe.moe_utils import (
clear_aux_losses_tracker,
get_moe_layer_wise_logging_tracker,
reduce_aux_losses_tracker_across_ranks,
)
from nemo_rl.algorithms.loss_functions import LossFunction, SequencePackingLossWrapper
from nemo_rl.distributed.batched_data_dict import BatchedDataDict


def _round_up_to_multiple(value: int, multiple: int) -> int:
return (
@@ -40,119 +30,6 @@ def _round_up_to_multiple(value: int, multiple: int) -> int:
else value
)

def forward_step_arbitrary_loss(
state: GlobalState,
global_valid_seqs: torch.Tensor,
global_valid_toks: torch.Tensor,
data_iterator: Iterator[BatchedDataDict[Any]],
model: GPTModel,
loss_fn: LossFunction,
pack_sequences: bool = False,
defer_fp32_logits: Optional[bool] = None,
cp_normalize: bool = True,
policy_cfg: Optional[dict] = None,
):
"""Forward training step with support for packed sequences and context parallelism.

Args:
state (GlobalState): Global state for the run
global_valid_seqs: Global count of valid sequences
global_valid_toks: Global count of valid tokens
data_iterator: Input data iterator
model (GPTModel): The GPT Model
loss_fn (LossFunction): Loss function to apply
pack_sequences (bool): Whether to pack sequences for efficiency
defer_fp32_logits (Optional[bool]): Whether to skip the conversion of logits to fp32
cp_normalize (bool): Whether to normalize the loss by the cp_size
policy_cfg (Optional[dict]): Policy configuration containing generation parameters

Notes on packed sequences with context parallelism (CP):
- When CP > 1, each sequence is padded to a multiple of (cp_size * 2)
- The factor of 2 ensures load balancing for causal attention
- cu_seqlens tracks actual sequence boundaries
- cu_seqlens_padded tracks padded sequence boundaries for CP
- Requires TransformerEngine >= 1.10 for CP support
"""
straggler_timer = state.straggler_timer

# Get the pre-processed microbatch from the iterator
processed_mb = next(data_iterator)

# Extract the processed components
data_dict = processed_mb.data_dict
input_ids = processed_mb.input_ids
input_ids_cp_sharded = processed_mb.input_ids_cp_sharded
attention_mask = processed_mb.attention_mask
position_ids = processed_mb.position_ids
packed_seq_params = processed_mb.packed_seq_params
cu_seqlens_padded = processed_mb.cu_seqlens_padded

multimodal_data = data_dict.get_multimodal_dict(
as_tensors=True, device=input_ids_cp_sharded.device
)
if len(multimodal_data) > 0:
position_ids = None

additional_kwargs = {}
# Mamba models currently do not support packed_seq_params
if packed_seq_params is not None:
additional_kwargs["packed_seq_params"] = packed_seq_params

if defer_fp32_logits:
additional_kwargs["fp32_output"] = False

with straggler_timer:
output_tensor = model(
input_ids=input_ids_cp_sharded,
position_ids=position_ids,
attention_mask=attention_mask,
**additional_kwargs,
**multimodal_data,
)

# Apply temperature scaling to logits for training
# This matches the dtensor worker's _apply_temperature_scaling in the train method
if (
policy_cfg is not None
and "generation" in policy_cfg
and policy_cfg["generation"] is not None
):
output_tensor.div_(policy_cfg["generation"]["temperature"])

# Unpack the output tensor if we did packed sequences
if pack_sequences and packed_seq_params is not None:
# remove padding
loss_fn = SequencePackingLossWrapper(
loss_fn=loss_fn,
cu_seqlens_q=packed_seq_params.cu_seqlens_q,
cu_seqlens_q_padded=packed_seq_params.cu_seqlens_q_padded,
)

loss_data = data_dict

loss_fn_wrapped = partial(
loss_fn,
data=loss_data,
global_valid_seqs=global_valid_seqs,
global_valid_toks=global_valid_toks,
vocab_parallel_rank=get_tensor_model_parallel_rank(),
vocab_parallel_group=get_tensor_model_parallel_group(),
context_parallel_group=get_context_parallel_group(),
)

if cp_normalize:
cp_size = get_context_parallel_world_size()
orig_loss_fn_wrapped = loss_fn_wrapped

def _div_by_cp_size(*args, **kwargs):
loss, metrics = orig_loss_fn_wrapped(*args, **kwargs)
return loss / cp_size, metrics

loss_fn_wrapped = _div_by_cp_size

return output_tensor, loss_fn_wrapped
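For orientation only (not part of the diff): a minimal sketch of how the padded sequence boundaries described in the docstring above behave when CP > 1; cp_size and the sequence lengths below are made-up values.

import torch

cp_size = 2
multiple = cp_size * 2  # the factor of 2 balances causal-attention work across CP ranks
seq_lens = torch.tensor([5, 9, 12])
padded_lens = ((seq_lens + multiple - 1) // multiple) * multiple  # tensor([ 8, 12, 12])
zero = torch.zeros(1, dtype=torch.long)
cu_seqlens = torch.cumsum(torch.cat([zero, seq_lens]), dim=0)            # tensor([ 0,  5, 14, 26])
cu_seqlens_padded = torch.cumsum(torch.cat([zero, padded_lens]), dim=0)  # tensor([ 0,  8, 20, 32])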


def broadcast_tensor(
tensor: torch.Tensor | None, src_rank: int, group: dist.ProcessGroup
) -> torch.Tensor:
@@ -285,4 +162,4 @@ def get_moe_metrics(
metrics[f"moe/{name}_layer_{i}"] = float(loss)

clear_aux_losses_tracker()
return metrics
return metrics
141 changes: 141 additions & 0 deletions nemo_rl/models/megatron/pipeline_parallel.py
@@ -0,0 +1,141 @@
# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Pipeline parallel utilities for Megatron models."""

from typing import Any, Optional

import torch
from megatron.core.parallel_state import (
get_pipeline_model_parallel_group,
get_pipeline_model_parallel_last_rank,
get_pipeline_model_parallel_world_size,
is_pipeline_last_stage,
)

def broadcast_obj_from_pp_rank(obj: Any) -> Any:
"""Broadcast an object across pipeline parallel ranks.
This utility function handles broadcasting an object from the rank that owns it
to all other pipeline parallel ranks. If only one rank has the object (non-None),
it will be broadcast to all other ranks.
Args:
obj: The object to broadcast. Can be None on ranks that don't own it.
Returns:
The object on all ranks (either the original or the broadcast copy).
Raises:
ValueError: If the object doesn't exist on any pipeline parallel rank.
"""
pp_size = get_pipeline_model_parallel_world_size()
pp_group = get_pipeline_model_parallel_group()

if pp_size == 1:
return obj

# ------------------------------------------------------------------
# 1. Gather presence flags from all PP ranks to find the source rank
# ------------------------------------------------------------------
has_obj = obj is not None
obj_flags = [None] * pp_size
torch.distributed.all_gather_object(obj_flags, has_obj, group=pp_group)

# ------------------------------------------------------------------
# 2. Identify the owning rank (the only rank with True flag)
# ------------------------------------------------------------------
src_rank = None # Rank *inside* the PP group
for rank, flag in enumerate(obj_flags):
if flag:
src_rank = rank
break

if src_rank is None:
raise ValueError("Object must exist on at least one PP rank")

# ------------------------------------------------------------------
# 3. Broadcast the object from the source rank to all ranks
# ------------------------------------------------------------------
# Use broadcast_object_list which is more robust than all_gather_object
obj_list = [obj]
pp_ranks = torch.distributed.get_process_group_ranks(pp_group)
global_src = pp_ranks[src_rank]
torch.distributed.broadcast_object_list(obj_list, src=global_src, group=pp_group)

return obj_list[0]
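A hypothetical usage sketch (not part of this diff); the sampled_metadata name and values are illustrative:

def _example_broadcast_obj():
    # Only the last pipeline stage holds the metadata; the other ranks pass None.
    if is_pipeline_last_stage(ignore_virtual=True):
        sampled_metadata = {"num_samples": 32, "temperature": 1.0}
    else:
        sampled_metadata = None
    # After the call, every pipeline rank holds the same dictionary.
    return broadcast_obj_from_pp_rank(sampled_metadata)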

def broadcast_loss_metrics_from_last_stage(loss_metrics: Optional[list] = None) -> list:
"""Broadcast loss metrics from the last pipeline stage to all stages.

This utility handles the common pattern where loss computation happens on the last
pipeline stage and needs to be broadcast to all other stages.

Args:
loss_metrics: List of loss metrics if on last stage, None otherwise

Returns:
List of loss metrics on all ranks
"""
pp_group = get_pipeline_model_parallel_group()
last_rank = get_pipeline_model_parallel_last_rank()

if is_pipeline_last_stage(ignore_virtual=True):
metrics_to_broadcast = [loss_metrics]
torch.distributed.broadcast_object_list(
metrics_to_broadcast,
src=last_rank,
group=pp_group,
)
return loss_metrics
else:
metrics_to_broadcast = [None]
torch.distributed.broadcast_object_list(
metrics_to_broadcast,
src=last_rank,
group=pp_group,
)
return metrics_to_broadcast[0]
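A hypothetical usage sketch with illustrative metric values, not taken from the diff:

def _example_broadcast_loss_metrics():
    # The last stage computes the metrics during its forward/backward pass;
    # earlier stages pass None and receive the broadcast copy.
    if is_pipeline_last_stage(ignore_virtual=True):
        metrics = [{"loss": 0.42, "num_valid_tokens": 1024}]
    else:
        metrics = None
    return broadcast_loss_metrics_from_last_stage(metrics)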


def broadcast_tensors_from_last_stage(
tensors: dict[str, Optional[torch.Tensor]],
) -> dict[str, torch.Tensor]:
"""Broadcast multiple tensors from the last pipeline stage to all stages.

Args:
        tensors: Dictionary mapping tensor names to tensors (None on non-last
            stages); the pipeline parallel group is auto-detected.

Returns:
Dictionary of broadcasted tensors on all ranks
"""
pp_group = get_pipeline_model_parallel_group()

from nemo_rl.models.megatron.common import broadcast_tensor

last_rank = get_pipeline_model_parallel_last_rank()
current_rank = torch.distributed.get_rank()

broadcasted_tensors = {}

if is_pipeline_last_stage(ignore_virtual=True):
# Broadcast tensors from last stage
for name, tensor in tensors.items():
if tensor is not None:
broadcasted_tensors[name] = broadcast_tensor(tensor, current_rank, pp_group)
else:
broadcasted_tensors[name] = None
else:
# Receive tensors on other stages
for name in tensors.keys():
broadcasted_tensors[name] = broadcast_tensor(None, last_rank, pp_group)

return broadcasted_tensors
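A hypothetical usage sketch, assuming a log-probabilities tensor that exists only on the last stage; the key, shape, and device below are made up:

def _example_broadcast_tensors():
    if is_pipeline_last_stage(ignore_virtual=True):
        outputs = {"logprobs": torch.randn(4, 128, device=torch.cuda.current_device())}
    else:
        outputs = {"logprobs": None}
    # Non-last stages receive the broadcast tensor under the same key.
    return broadcast_tensors_from_last_stage(outputs)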