From fc62be880d282b951a58e5f9e782a4c3cf61f1a2 Mon Sep 17 00:00:00 2001 From: ashors1 Date: Wed, 17 Dec 2025 14:53:32 -0800 Subject: [PATCH 1/4] fix Signed-off-by: ashors1 --- nemo_rl/models/policy/workers/megatron_policy_worker.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/nemo_rl/models/policy/workers/megatron_policy_worker.py b/nemo_rl/models/policy/workers/megatron_policy_worker.py index 3517061559..27429aaee1 100644 --- a/nemo_rl/models/policy/workers/megatron_policy_worker.py +++ b/nemo_rl/models/policy/workers/megatron_policy_worker.py @@ -1563,7 +1563,7 @@ def collection_fn(_): for out in list_of_outputs: tk = out["topk_logits"] ti = out["topk_indices"] - pad_len = input_seq_dim_size - tk.shape[1] + pad_len = padded_seq_length - tk.shape[1] if pad_len > 0: tk = torch.nn.functional.pad(tk, (0, 0, 0, pad_len), value=0.0) ti = torch.nn.functional.pad(ti, (0, 0, 0, pad_len), value=0) From 50689b57965e6fdbd6c29f98a7f9dc211e3608a1 Mon Sep 17 00:00:00 2001 From: ashors1 Date: Wed, 17 Dec 2025 15:42:43 -0800 Subject: [PATCH 2/4] refactor train/forward utilities Signed-off-by: ashors1 --- nemo_rl/models/megatron/common.py | 131 +---- nemo_rl/models/megatron/pipeline_parallel.py | 141 +++++ nemo_rl/models/megatron/train.py | 509 ++++++++++++++++++ .../policy/workers/megatron_policy_worker.py | 431 +++------------ 4 files changed, 719 insertions(+), 493 deletions(-) create mode 100644 nemo_rl/models/megatron/pipeline_parallel.py create mode 100644 nemo_rl/models/megatron/train.py diff --git a/nemo_rl/models/megatron/common.py b/nemo_rl/models/megatron/common.py index 6c83851076..bf50da147b 100644 --- a/nemo_rl/models/megatron/common.py +++ b/nemo_rl/models/megatron/common.py @@ -11,27 +11,17 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from functools import partial -from typing import Any, Iterator, Optional + +from typing import Optional, Any import torch import torch.distributed as dist -from megatron.bridge.training.state import GlobalState -from megatron.core.models.gpt import GPTModel -from megatron.core.parallel_state import ( - get_context_parallel_group, - get_context_parallel_world_size, - get_tensor_model_parallel_group, - get_tensor_model_parallel_rank, -) + from megatron.core.transformer.moe.moe_utils import ( clear_aux_losses_tracker, get_moe_layer_wise_logging_tracker, reduce_aux_losses_tracker_across_ranks, ) -from nemo_rl.algorithms.loss_functions import LossFunction, SequencePackingLossWrapper -from nemo_rl.distributed.batched_data_dict import BatchedDataDict - def _round_up_to_multiple(value: int, multiple: int) -> int: return ( @@ -40,119 +30,6 @@ def _round_up_to_multiple(value: int, multiple: int) -> int: else value ) -def forward_step_arbitrary_loss( - state: GlobalState, - global_valid_seqs: torch.Tensor, - global_valid_toks: torch.Tensor, - data_iterator: Iterator[BatchedDataDict[Any]], - model: GPTModel, - loss_fn: LossFunction, - pack_sequences: bool = False, - defer_fp32_logits: Optional[bool] = None, - cp_normalize: bool = True, - policy_cfg: Optional[dict] = None, -): - """Forward training step with support for packed sequences and context parallelism. 
- - Args: - state (GlobalState): Global state for the run - global_valid_seqs: Global count of valid sequences - global_valid_toks: Global count of valid tokens - data_iterator: Input data iterator - model (GPTModel): The GPT Model - loss_fn (LossFunction): Loss function to apply - pack_sequences (bool): Whether to pack sequences for efficiency - defer_fp32_logits (Optional[bool]): Whether to skip the conversion of logits to fp32 - cp_normalize (bool): Whether to normalize the loss by the cp_size - policy_cfg (Optional[dict]): Policy configuration containing generation parameters - - Notes on packed sequences with context parallelism (CP): - - When CP > 1, each sequence is padded to a multiple of (cp_size * 2) - - The factor of 2 ensures load balancing for causal attention - - cu_seqlens tracks actual sequence boundaries - - cu_seqlens_padded tracks padded sequence boundaries for CP - - Requires TransformerEngine >= 1.10 for CP support - """ - straggler_timer = state.straggler_timer - - # Get the pre-processed microbatch from the iterator - processed_mb = next(data_iterator) - - # Extract the processed components - data_dict = processed_mb.data_dict - input_ids = processed_mb.input_ids - input_ids_cp_sharded = processed_mb.input_ids_cp_sharded - attention_mask = processed_mb.attention_mask - position_ids = processed_mb.position_ids - packed_seq_params = processed_mb.packed_seq_params - cu_seqlens_padded = processed_mb.cu_seqlens_padded - - multimodal_data = data_dict.get_multimodal_dict( - as_tensors=True, device=input_ids_cp_sharded.device - ) - if len(multimodal_data) > 0: - position_ids = None - - additional_kwargs = {} - # Mamba models currently do not support packed_seq_params - if packed_seq_params is not None: - additional_kwargs["packed_seq_params"] = packed_seq_params - - if defer_fp32_logits: - additional_kwargs["fp32_output"] = False - - with straggler_timer: - output_tensor = model( - input_ids=input_ids_cp_sharded, - position_ids=position_ids, - attention_mask=attention_mask, - **additional_kwargs, - **multimodal_data, - ) - - # Apply temperature scaling to logits for training - # This matches the dtensor worker's _apply_temperature_scaling in the train method - if ( - policy_cfg is not None - and "generation" in policy_cfg - and policy_cfg["generation"] is not None - ): - output_tensor.div_(policy_cfg["generation"]["temperature"]) - - # Unpack the output tensor if we did packed sequences - if pack_sequences and packed_seq_params is not None: - # remove padding - loss_fn = SequencePackingLossWrapper( - loss_fn=loss_fn, - cu_seqlens_q=packed_seq_params.cu_seqlens_q, - cu_seqlens_q_padded=packed_seq_params.cu_seqlens_q_padded, - ) - - loss_data = data_dict - - loss_fn_wrapped = partial( - loss_fn, - data=loss_data, - global_valid_seqs=global_valid_seqs, - global_valid_toks=global_valid_toks, - vocab_parallel_rank=get_tensor_model_parallel_rank(), - vocab_parallel_group=get_tensor_model_parallel_group(), - context_parallel_group=get_context_parallel_group(), - ) - - if cp_normalize: - cp_size = get_context_parallel_world_size() - orig_loss_fn_wrapped = loss_fn_wrapped - - def _div_by_cp_size(*args, **kwargs): - loss, metrics = orig_loss_fn_wrapped(*args, **kwargs) - return loss / cp_size, metrics - - loss_fn_wrapped = _div_by_cp_size - - return output_tensor, loss_fn_wrapped - - def broadcast_tensor( tensor: torch.Tensor | None, src_rank: int, group: dist.ProcessGroup ) -> torch.Tensor: @@ -285,4 +162,4 @@ def get_moe_metrics( metrics[f"moe/{name}_layer_{i}"] = float(loss) 
clear_aux_losses_tracker() - return metrics + return metrics \ No newline at end of file diff --git a/nemo_rl/models/megatron/pipeline_parallel.py b/nemo_rl/models/megatron/pipeline_parallel.py new file mode 100644 index 0000000000..afb1bfda8f --- /dev/null +++ b/nemo_rl/models/megatron/pipeline_parallel.py @@ -0,0 +1,141 @@ +# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Pipeline parallel utilities for Megatron models.""" + +from typing import Any, Optional + +import torch +from megatron.core.parallel_state import ( + get_pipeline_model_parallel_group, + get_pipeline_model_parallel_last_rank, + get_pipeline_model_parallel_world_size, + is_pipeline_last_stage, +) + +def broadcast_obj_from_pp_rank(obj: Any) -> Any: + """Broadcast an object across pipeline parallel ranks. + This utility function handles broadcasting an object from the rank that owns it + to all other pipeline parallel ranks. If only one rank has the object (non-None), + it will be broadcast to all other ranks. + Args: + obj: The object to broadcast. Can be None on ranks that don't own it. + Returns: + The object on all ranks (either the original or the broadcast copy). + Raises: + ValueError: If the object doesn't exist on any pipeline parallel rank. + """ + pp_size = get_pipeline_model_parallel_world_size() + pp_group = get_pipeline_model_parallel_group() + + if pp_size == 1: + return obj + + # ------------------------------------------------------------------ + # 1. Gather presence flags from all PP ranks to find the source rank + # ------------------------------------------------------------------ + has_obj = obj is not None + obj_flags = [None] * pp_size + torch.distributed.all_gather_object(obj_flags, has_obj, group=pp_group) + + # ------------------------------------------------------------------ + # 2. Identify the owning rank (the only rank with True flag) + # ------------------------------------------------------------------ + src_rank = None # Rank *inside* the PP group + for rank, flag in enumerate(obj_flags): + if flag: + src_rank = rank + break + + if src_rank is None: + raise ValueError("Object must exist on at least one PP rank") + + # ------------------------------------------------------------------ + # 3. Broadcast the object from the source rank to all ranks + # ------------------------------------------------------------------ + # Use broadcast_object_list which is more robust than all_gather_object + obj_list = [obj] + pp_ranks = torch.distributed.get_process_group_ranks(pp_group) + global_src = pp_ranks[src_rank] + torch.distributed.broadcast_object_list(obj_list, src=global_src, group=pp_group) + + return obj_list[0] + +def broadcast_loss_metrics_from_last_stage(loss_metrics: Optional[list] = None) -> list: + """Broadcast loss metrics from the last pipeline stage to all stages. + + This utility handles the common pattern where loss computation happens on the last + pipeline stage and needs to be broadcast to all other stages. 
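# Illustrative sketch (not part of the patch): how broadcast_loss_metrics_from_last_stage
# is meant to be called from the policy worker, mirroring the train() changes later in
# this diff. Only the last pipeline stage holds real metrics; every other stage passes
# None and receives the broadcast copy. `losses_reduced` is an assumed name here.
from megatron.core.parallel_state import is_pipeline_last_stage
from nemo_rl.models.megatron.pipeline_parallel import broadcast_loss_metrics_from_last_stage

def gather_loss_metrics_on_all_stages(losses_reduced):
    local = losses_reduced if is_pipeline_last_stage(ignore_virtual=True) else None
    return broadcast_loss_metrics_from_last_stage(local)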
+ + Args: + loss_metrics: List of loss metrics if on last stage, None otherwise + + Returns: + List of loss metrics on all ranks + """ + pp_group = get_pipeline_model_parallel_group() + last_rank = get_pipeline_model_parallel_last_rank() + + if is_pipeline_last_stage(ignore_virtual=True): + metrics_to_broadcast = [loss_metrics] + torch.distributed.broadcast_object_list( + metrics_to_broadcast, + src=last_rank, + group=pp_group, + ) + return loss_metrics + else: + metrics_to_broadcast = [None] + torch.distributed.broadcast_object_list( + metrics_to_broadcast, + src=last_rank, + group=pp_group, + ) + return metrics_to_broadcast[0] + + +def broadcast_tensors_from_last_stage( + tensors: dict[str, Optional[torch.Tensor]], +) -> dict[str, torch.Tensor]: + """Broadcast multiple tensors from the last pipeline stage to all stages. + + Args: + tensors: Dictionary mapping tensor names to tensors (None on non-last stages) + pp_group: Pipeline parallel group (auto-detected if None) + + Returns: + Dictionary of broadcasted tensors on all ranks + """ + pp_group = get_pipeline_model_parallel_group() + + from nemo_rl.models.megatron.common import broadcast_tensor + + last_rank = get_pipeline_model_parallel_last_rank() + current_rank = torch.distributed.get_rank() + + broadcasted_tensors = {} + + if is_pipeline_last_stage(ignore_virtual=True): + # Broadcast tensors from last stage + for name, tensor in tensors.items(): + if tensor is not None: + broadcasted_tensors[name] = broadcast_tensor(tensor, current_rank, pp_group) + else: + broadcasted_tensors[name] = None + else: + # Receive tensors on other stages + for name in tensors.keys(): + broadcasted_tensors[name] = broadcast_tensor(None, last_rank, pp_group) + + return broadcasted_tensors \ No newline at end of file diff --git a/nemo_rl/models/megatron/train.py b/nemo_rl/models/megatron/train.py new file mode 100644 index 0000000000..4b2345a73c --- /dev/null +++ b/nemo_rl/models/megatron/train.py @@ -0,0 +1,509 @@ + +from functools import partial +from typing import Any, Callable, Dict, Iterator, Optional, Tuple, Union + +import torch + +from megatron.core.models.gpt import GPTModel +from megatron.core.packed_seq_params import PackedSeqParams +from megatron.core.parallel_state import ( + get_context_parallel_group, + get_context_parallel_world_size, + get_tensor_model_parallel_group, + get_tensor_model_parallel_rank, +) +from megatron.core.pipeline_parallel import get_forward_backward_func + +from nemo_rl.algorithms.loss_functions import LossFunction, SequencePackingLossWrapper +from nemo_rl.distributed.batched_data_dict import BatchedDataDict +from nemo_rl.distributed.model_utils import ( + allgather_cp_sharded_tensor, + distributed_vocab_topk, + from_parallel_logits_to_logprobs, + from_parallel_logits_to_logprobs_packed_sequences, +) +from nemo_rl.models.megatron.data import ProcessedMicrobatch + + +# Union type for any post-processing function (defined after classes below) +PostProcessingFunction = Union[ + "LossPostProcessor", + "LogprobsPostProcessor", + "TopkLogitsPostProcessor", +] + + +def model_forward( + model: GPTModel, + data_dict: BatchedDataDict[Any], + cfg: Dict[str, Any], + input_ids_cp_sharded: torch.Tensor, + position_ids: torch.Tensor, + attention_mask: torch.Tensor, + packed_seq_params: Optional[PackedSeqParams] = None, + defer_fp32_logits: Optional[bool] = None, +) -> torch.Tensor: + """ + Perform a single forward pass through the model. 
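# Illustrative sketch (not part of the patch): a single call to model_forward using the
# fields of a ProcessedMicrobatch. `processed_mb`, `model`, and `cfg` are assumed to come
# from the surrounding worker code (see get_microbatch_iterator in nemo_rl.models.megatron.data).
logits = model_forward(
    model=model,
    data_dict=processed_mb.data_dict,
    cfg=cfg,
    input_ids_cp_sharded=processed_mb.input_ids_cp_sharded,
    position_ids=processed_mb.position_ids,
    attention_mask=processed_mb.attention_mask,
    packed_seq_params=processed_mb.packed_seq_params,
    defer_fp32_logits=True,  # skip the fp32 upcast of logits; post-processors consume them as-is
)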
+ + Args: + model: The model to run forward pass on + data_dict (BatchedDataDict): Dictionary containing batch data + cfg (dict): Configuration dictionary + input_ids_cp_sharded: Context-parallel sharded input token IDs + position_ids: Position IDs for tokens + attention_mask: Attention mask for the sequence + packed_seq_params: Parameters for packed sequences (optional) + defer_fp32_logits (Optional[bool]): Whether to skip the conversion of logits to fp32 + + Returns: + torch.Tensor: Output tensor from the model (logits) + """ + multimodal_data = data_dict.get_multimodal_dict( + as_tensors=True, device=input_ids_cp_sharded.device + ) + if len(multimodal_data) > 0: + position_ids = None + + additional_kwargs = {} + # Mamba models currently do not support packed_seq_params + if packed_seq_params is not None: + additional_kwargs["packed_seq_params"] = packed_seq_params + if defer_fp32_logits: + additional_kwargs["fp32_output"] = False + #with straggler_timer: + output_tensor = model( + input_ids=input_ids_cp_sharded, + position_ids=position_ids, + attention_mask=attention_mask, + **additional_kwargs, + **multimodal_data, + ) + + # Apply temperature scaling to logits for training + # This matches the dtensor worker's _apply_temperature_scaling in the train method + if ( + "generation" in cfg + and cfg["generation"] is not None + ): + output_tensor.div_(cfg["generation"]["temperature"]) + + return output_tensor + +def forward_with_post_processing_fn( + data_iterator: Iterator[ProcessedMicrobatch], + model: GPTModel, + cfg: Dict[str, Any], + post_processing_fn: PostProcessingFunction, + defer_fp32_logits: Optional[bool] = True, + global_valid_seqs: Optional[torch.Tensor] = None, + global_valid_toks: Optional[torch.Tensor] = None, +) -> Tuple[torch.Tensor, Callable]: + """ + Perform forward pass with pre-processed microbatch and return output tensor and post-processing function. + + This function takes a pre-processed microbatch (with sequence packing already handled), + runs the forward step through the model, and prepares a post-processing function for + post-processing the outputs. 
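# Illustrative sketch (not part of the patch): the (output_tensor, post_processing_fn)
# contract that Megatron's forward-backward engine expects from a forward step.
# `mb_iterator`, `model`, and `cfg` are assumptions; in practice the engine, not user
# code, invokes the returned callable on the last pipeline stage.
output_tensor, postproc = forward_with_post_processing_fn(
    data_iterator=mb_iterator,
    model=model,
    cfg=cfg,
    post_processing_fn=LogprobsPostProcessor(cfg=cfg),
)
dummy_loss, outputs = postproc(output_tensor)   # outputs holds {"logprobs": ...} for this processor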
+ + Args: + data_iterator: Iterator yielding ProcessedMicrobatch objects (already processed) + model: The model to run forward pass on + cfg (dict): Configuration dictionary + post_processing_fn: Post-processing function to post-process the logits + defer_fp32_logits: Whether to defer FP32 conversion of logits + global_valid_seqs: Global valid sequence count for loss normalization + global_valid_toks: Global valid token count for loss normalization + + Returns: + tuple: (output_tensor, post_processing_fn_wrapped) + - output_tensor: Raw model outputs (logits) + - post_processing_fn_wrapped: Function to create output post-processing function when called + """ + # Get the pre-processed microbatch from the iterator + processed_mb = next(data_iterator) + + # Extract the processed components + data_dict = processed_mb.data_dict + input_ids = processed_mb.input_ids + input_ids_cp_sharded = processed_mb.input_ids_cp_sharded + attention_mask = processed_mb.attention_mask + position_ids = processed_mb.position_ids + packed_seq_params = processed_mb.packed_seq_params + cu_seqlens_padded = processed_mb.cu_seqlens_padded + + output_tensor = model_forward( + model, + data_dict, + cfg, + input_ids_cp_sharded, + position_ids, + attention_mask, + packed_seq_params, + defer_fp32_logits, + ) + + ## calling post_processing_fn will return a function that takes the output tensor and returns a tuple of (loss, metrics) + # Use type checking to dispatch to the correct post-processing method + if isinstance(post_processing_fn, LossPostProcessor): + post_processing_fn_wrapped = post_processing_fn( + data_dict=data_dict, + packed_seq_params=packed_seq_params, + global_valid_seqs=global_valid_seqs, + global_valid_toks=global_valid_toks, + ) + elif isinstance(post_processing_fn, LogprobsPostProcessor): + post_processing_fn_wrapped = post_processing_fn( + data_dict=data_dict, + input_ids=input_ids, + cu_seqlens_padded=cu_seqlens_padded, + ) + elif isinstance(post_processing_fn, TopkLogitsPostProcessor): + post_processing_fn_wrapped = post_processing_fn( + data_dict=data_dict, + cu_seqlens_padded=cu_seqlens_padded, + ) + else: + raise TypeError(f"Unknown post-processing function type: {type(post_processing_fn)}") + + return output_tensor, post_processing_fn_wrapped + +def megatron_forward_backward( + model: GPTModel, + cfg: Dict[str, Any], + data_iterator: Iterator[ProcessedMicrobatch], + num_microbatches: int, + seq_length: int, + mbs: int, + post_processing_fn: PostProcessingFunction, + forward_only: bool = False, + defer_fp32_logits: Optional[bool] = None, + global_valid_seqs: Optional[torch.Tensor] = None, + global_valid_toks: Optional[torch.Tensor] = None, +) -> Any: + """ + Execute forward and backward passes using Megatron's utilities. + + This is the main training loop function that coordinates forward and backward + passes across multiple microbatches using Megatron's pipeline parallel + execution framework. 
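# Illustrative sketch (not part of the patch): forward-only usage of
# megatron_forward_backward, mirroring the refactored get_logprobs() in the worker diff
# below. `self`, `data`, and `logprob_batch_size` are assumed worker state.
(
    mb_iterator,
    num_microbatches,
    micro_batch_size,
    seq_length,
    padded_seq_length,
) = get_microbatch_iterator(data, self.cfg, logprob_batch_size)

list_of_logprobs = megatron_forward_backward(
    model=self.model,
    cfg=self.cfg,
    data_iterator=mb_iterator,
    num_microbatches=num_microbatches,
    seq_length=padded_seq_length,
    mbs=micro_batch_size,
    post_processing_fn=LogprobsPostProcessor(cfg=self.cfg),
    forward_only=True,
    defer_fp32_logits=self.defer_fp32_logits,
)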
+ + Args: + model: The model to train + cfg (dict): Configuration dictionary + data_iterator: Iterator yielding ProcessedMicrobatch objects (already processed) + num_microbatches (int): Number of microbatches to process + seq_length (int): Sequence length + mbs (int): Micro batch size + post_processing_fn: Post-processing function to post-process the logits + forward_only (bool): If True, skip backward pass + defer_fp32_logits (Optional[bool]): Whether to skip the conversion of logits to fp32 + global_valid_seqs: Global valid sequence count for loss normalization + global_valid_toks: Global valid token count for loss normalization + + Returns: + BatchedDataDict: Results from the forward/backward execution + """ + forward_step = partial( + forward_with_post_processing_fn, + cfg=cfg, + post_processing_fn=post_processing_fn, + defer_fp32_logits=defer_fp32_logits, + global_valid_seqs=global_valid_seqs, + global_valid_toks=global_valid_toks, + ) + forward_backward_func = get_forward_backward_func() + return forward_backward_func( + forward_step_func=forward_step, + data_iterator=data_iterator, + model=model, + num_microbatches=num_microbatches, + seq_length=seq_length, + micro_batch_size=mbs, + decoder_seq_length=seq_length, + forward_only=forward_only, + ) + +class LossPostProcessor: + + def __init__( + self, + loss_fn: LossFunction, + cfg: Dict[str, Any], + cp_normalize: bool = True, + ): + self.loss_fn = loss_fn + self.cfg = cfg + self.cp_normalize = cp_normalize + + def __call__(self, + data_dict: BatchedDataDict[Any], + packed_seq_params: Optional[PackedSeqParams] = None, + global_valid_seqs: Optional[torch.Tensor] = None, + global_valid_toks: Optional[torch.Tensor] = None, + ) -> Callable[[torch.Tensor], Tuple[torch.Tensor, Dict[str, Any]]]: + """ + Create a loss post-processing function for training. + + This function wraps a loss function with the necessary context and parameters + to compute loss and metrics from model outputs. It handles sequence packing + and context parallelism normalization. 
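# Illustrative sketch (not part of the patch): a self-contained toy of the cp_normalize
# behavior documented above — the wrapped loss is divided by the context-parallel size so
# that the later reduction across CP ranks restores the original scale.
import torch

def div_by_cp_size(loss_fn_wrapped, cp_size):
    def wrapped(*args, **kwargs):
        loss, metrics = loss_fn_wrapped(*args, **kwargs)
        return loss / cp_size, metrics
    return wrapped

base_fn = lambda output_tensor: (torch.tensor(6.0), {"num_valid": 3})
print(div_by_cp_size(base_fn, cp_size=2)(None))   # (tensor(3.), {'num_valid': 3})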
+ + Args: + loss_fn: The base loss function to wrap + cfg (dict): Configuration dictionary + data_dict: Batched data dictionary + packed_seq_params: Parameters for packed sequences (optional) + cp_normalize (bool): Whether to normalize by context parallel size + + Returns: + Callable: Function that takes output tensor and returns (loss, metrics) tuple + """ + + loss_fn = self.loss_fn + pack_sequences = self.cfg["sequence_packing"]["enabled"] + if pack_sequences and packed_seq_params is not None: + # remove padding + loss_fn = SequencePackingLossWrapper( + loss_fn=loss_fn, + cu_seqlens_q=packed_seq_params.cu_seqlens_q, + cu_seqlens_q_padded=packed_seq_params.cu_seqlens_q_padded, + ) + + loss_fn_wrapped = partial( + loss_fn, + data=data_dict, + global_valid_seqs=global_valid_seqs, + global_valid_toks=global_valid_toks, + vocab_parallel_rank=get_tensor_model_parallel_rank(), + vocab_parallel_group=get_tensor_model_parallel_group(), + context_parallel_group=get_context_parallel_group(), + ) + + if self.cp_normalize: + cp_size = get_context_parallel_world_size() + orig_loss_fn_wrapped = loss_fn_wrapped + + def _div_by_cp_size(*args, **kwargs): + loss, metrics = orig_loss_fn_wrapped(*args, **kwargs) + return loss / cp_size, metrics + + loss_fn_wrapped = _div_by_cp_size + + return loss_fn_wrapped + +class LogprobsPostProcessor: + + def __init__(self, cfg: Dict[str, Any]): + self.cfg = cfg + + def __call__( + self, + data_dict: BatchedDataDict[Any], + input_ids: torch.Tensor, + cu_seqlens_padded: torch.Tensor, + ) -> Callable[[torch.Tensor], Tuple[torch.Tensor, Dict[str, torch.Tensor]]]: + """ + Create a post-processing function that computes token log probabilities. + + This function returns a processor that takes model logits and converts them + to token-level log probabilities, handling both packed and unpacked sequences. 
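# Illustrative sketch (not part of the patch): a single-process toy of what the returned
# processor computes — log-probabilities of the input tokens under the model, with a
# leading zero so the result keeps the same sequence length as input_ids. The real code
# shards the vocab across TP ranks via from_parallel_logits_to_logprobs*.
import torch
import torch.nn.functional as F

logits = torch.randn(2, 5, 11)                        # [batch, seq, vocab]
input_ids = torch.randint(0, 11, (2, 5))
logprobs = F.log_softmax(logits[:, :-1], dim=-1)      # position t predicts token t+1
token_logprobs = logprobs.gather(-1, input_ids[:, 1:].unsqueeze(-1)).squeeze(-1)
token_logprobs = torch.cat(
    [torch.zeros_like(token_logprobs[:, :1]), token_logprobs], dim=1
)
assert token_logprobs.shape == input_ids.shape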
+ + Args: + data_dict: Batched data dictionary containing input sequences + input_ids: Processed input token IDs + cu_seqlens_padded: Cumulative sequence lengths for packed sequences + + Returns: + Callable: Function that takes output tensor and returns (dummy_loss, {"logprobs": token_logprobs}) + """ + unpacked_input_ids = data_dict["input_ids"] + original_seq_length = unpacked_input_ids.shape[1] + + def processor_fn_inner(output_tensor): + tp_grp = get_tensor_model_parallel_group() + tp_rank = get_tensor_model_parallel_rank() + logprob_chunk_size = self.cfg.get("logprob_chunk_size", None) + if self.cfg["sequence_packing"]["enabled"]: + token_logprobs = from_parallel_logits_to_logprobs_packed_sequences( + output_tensor, + target=input_ids, + cu_seqlens_padded=cu_seqlens_padded, + unpacked_seqlen=original_seq_length, + vocab_start_index=tp_rank * output_tensor.shape[-1], + vocab_end_index=(tp_rank + 1) * output_tensor.shape[-1], + group=tp_grp, + inference_only=True, + cp_group=get_context_parallel_group(), + chunk_size=logprob_chunk_size, + ) + else: + token_logprobs = from_parallel_logits_to_logprobs( + output_tensor, + target=unpacked_input_ids, + vocab_start_index=tp_rank * output_tensor.shape[-1], + vocab_end_index=(tp_rank + 1) * output_tensor.shape[-1], + tp_group=tp_grp, + inference_only=True, + chunk_size=logprob_chunk_size, + ) + + # Prepend 0 logprob for first token to maintain same sequence length as input + token_logprobs = torch.cat( + [torch.zeros_like(token_logprobs[:, :1]), token_logprobs], dim=1 + ) + return torch.tensor(0.0, device=token_logprobs.device), { + "logprobs": token_logprobs + } + return processor_fn_inner + + +class TopkLogitsPostProcessor: + + def __init__(self, cfg: Dict[str, Any], k: int): + self.cfg = cfg + self.k = k + + def __call__( + self, + data_dict: BatchedDataDict[Any], + cu_seqlens_padded: torch.Tensor, + ) -> Callable[[torch.Tensor], Tuple[torch.Tensor, Dict[str, torch.Tensor]]]: + """ + Create a post-processing function that computes top-k logits and indices. + + This function returns a processor that extracts the top-k highest logits + and their corresponding vocabulary indices from model outputs. It handles + tensor parallelism, context parallelism, and sequence packing. 
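# Illustrative sketch (not part of the patch): a single-process toy of top-k over a
# vocabulary sharded across TP ranks, which is what distributed_vocab_topk does with
# collectives — take a local top-k per shard, offset indices by the shard's
# vocab_start_index, then take the top-k of the gathered candidates.
import torch

k, shard = 4, 6
logits = torch.randn(1, 3, 2 * shard)                 # full vocab split into 2 shards
candidate_vals, candidate_idx = [], []
for rank, local in enumerate((logits[..., :shard], logits[..., shard:])):
    vals, idx = local.topk(k, dim=-1)
    candidate_vals.append(vals)
    candidate_idx.append(idx + rank * shard)          # offset to global vocab indices
vals, pos = torch.cat(candidate_vals, dim=-1).topk(k, dim=-1)
idx = torch.cat(candidate_idx, dim=-1).gather(-1, pos)
assert torch.equal(vals, logits.topk(k, dim=-1).values)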
+ + Args: + data_dict: Batched data dictionary + cu_seqlens_padded: Cumulative sequence lengths for packed sequences + + Returns: + Callable: Function that takes output tensor and returns + (dummy_loss, {"topk_logits": values, "topk_indices": indices}) + """ + + pack = self.cfg["sequence_packing"]["enabled"] + cp_size = self.cfg["megatron_cfg"]["context_parallel_size"] + unpacked_seqlen = data_dict["input_ids"].shape[1] + seq_lengths = data_dict["input_lengths"] + + def processor_fn_inner(output_tensor): + # Only the last PP stage produces final logits/top-k; earlier stages return empty + # if not is_pipeline_last_stage(ignore_virtual=True): + # return output_tensor.new_zeros(()), {} + + tp_grp = get_tensor_model_parallel_group() + tp_rank = get_tensor_model_parallel_rank() + vocab_shard_size = output_tensor.shape[-1] + vocab_start_index = tp_rank * vocab_shard_size + + chunk_size = None + if "logprob_chunk_size" in self.cfg: + chunk_size = self.cfg["logprob_chunk_size"] + + topk_vals_local, topk_idx_local = distributed_vocab_topk( + output_tensor, + self.k, + tp_grp, + vocab_start_index=vocab_start_index, + vocab_end_index=vocab_start_index + vocab_shard_size, + chunk_size=chunk_size, + ) + + if self.cfg["megatron_cfg"]["context_parallel_size"] > 1: + cp_grp = get_context_parallel_group() + if pack: + # Per-sequence CP allgather following packed-sequence logic + batch_size = data_dict["input_ids"].shape[0] + total_packed_len = int(cu_seqlens_padded[-1].item()) + + topk_vals_full = torch.zeros( + (1, total_packed_len, self.k), + dtype=topk_vals_local.dtype, + device=topk_vals_local.device, + ) + topk_idx_full = torch.zeros( + (1, total_packed_len, self.k), + dtype=topk_idx_local.dtype, + device=topk_idx_local.device, + ) + + for i in range(batch_size): + start_idx = int(cu_seqlens_padded[i].item()) + end_idx = int(cu_seqlens_padded[i + 1].item()) + if end_idx > start_idx: + local_vals_slice = topk_vals_local[ + :, start_idx // cp_size : end_idx // cp_size, : + ] + local_idx_slice = topk_idx_local[ + :, start_idx // cp_size : end_idx // cp_size, : + ] + gathered_vals = allgather_cp_sharded_tensor( + local_vals_slice, cp_grp, seq_dim=1 + ) + gathered_idx = allgather_cp_sharded_tensor( + local_idx_slice, cp_grp, seq_dim=1 + ) + # Some kernels may return [X, Y, k] where X*Y = (end_idx - start_idx). + # Flatten leading dims and reshape to [1, expected_len, k] to match target. + expected_len = end_idx - start_idx + if ( + gathered_vals.dim() == 3 + and gathered_vals.shape[1] != expected_len + ): + gathered_vals = gathered_vals.reshape( + 1, expected_len, gathered_vals.shape[-1] + ) + if ( + gathered_idx.dim() == 3 + and gathered_idx.shape[1] != expected_len + ): + gathered_idx = gathered_idx.reshape( + 1, expected_len, gathered_idx.shape[-1] + ) + topk_vals_full[:, start_idx:end_idx, :] = gathered_vals + topk_idx_full[:, start_idx:end_idx, :] = gathered_idx + else: + # Sequence packing must be enabled when CP > 1 + raise RuntimeError( + "Context Parallelism (CP>1) requires sequence packing to be enabled." 
+ ) + else: + topk_vals_full = topk_vals_local + topk_idx_full = topk_idx_local + + if pack: + batch_size = data_dict["input_ids"].shape[0] + out_vals = torch.zeros( + (batch_size, unpacked_seqlen, self.k), + dtype=topk_vals_full.dtype, + device=topk_vals_full.device, + ) + out_idx = torch.zeros( + (batch_size, unpacked_seqlen, self.k), + dtype=topk_idx_full.dtype, + device=topk_idx_full.device, + ) + for i in range(batch_size): + seq_len = int(seq_lengths[i].item()) + start_idx = int(cu_seqlens_padded[i].item()) + if seq_len > 0: + out_vals[i, :seq_len, :] = topk_vals_full[ + 0, start_idx : start_idx + seq_len, : + ] + out_idx[i, :seq_len, :] = topk_idx_full[ + 0, start_idx : start_idx + seq_len, : + ] + return output_tensor.new_zeros(()), { + "topk_logits": out_vals, + "topk_indices": out_idx, + } + else: + return output_tensor.new_zeros(()), { + "topk_logits": topk_vals_full, + "topk_indices": topk_idx_full, + } + return processor_fn_inner \ No newline at end of file diff --git a/nemo_rl/models/policy/workers/megatron_policy_worker.py b/nemo_rl/models/policy/workers/megatron_policy_worker.py index 27429aaee1..bdce4233a4 100644 --- a/nemo_rl/models/policy/workers/megatron_policy_worker.py +++ b/nemo_rl/models/policy/workers/megatron_policy_worker.py @@ -71,18 +71,11 @@ from megatron.core.inference.text_generation_controllers.text_generation_controller import ( TextGenerationController, ) -from megatron.core.models.gpt import GPTModel from megatron.core.optimizer import ChainedOptimizer from megatron.core.parallel_state import ( - get_context_parallel_group, get_pipeline_model_parallel_group, - get_pipeline_model_parallel_last_rank, - get_pipeline_model_parallel_world_size, - get_tensor_model_parallel_group, - get_tensor_model_parallel_rank, is_pipeline_last_stage, ) -from megatron.core.pipeline_parallel import get_forward_backward_func from megatron.core.rerun_state_machine import get_rerun_state_machine from megatron.core.transformer.module import Float16Module from megatron.core.transformer.transformer_config import TransformerConfig @@ -91,12 +84,6 @@ from nemo_rl.algorithms.interfaces import LossFunction from nemo_rl.distributed.batched_data_dict import BatchedDataDict -from nemo_rl.distributed.model_utils import ( - allgather_cp_sharded_tensor, - distributed_vocab_topk, - from_parallel_logits_to_logprobs, - from_parallel_logits_to_logprobs_packed_sequences, -) from nemo_rl.distributed.named_sharding import NamedSharding from nemo_rl.models.generation.fp8 import ( convert_calibration_to_vllm_format, @@ -108,15 +95,22 @@ verify_right_padding, ) from nemo_rl.models.generation.vllm.config import VllmConfig -from nemo_rl.models.megatron.common import ( - broadcast_tensor, - forward_step_arbitrary_loss, - get_moe_metrics, -) +from nemo_rl.models.megatron.common import get_moe_metrics from nemo_rl.models.megatron.data import ( get_microbatch_iterator, process_global_batch, ) +from nemo_rl.models.megatron.pipeline_parallel import ( + broadcast_obj_from_pp_rank, + broadcast_loss_metrics_from_last_stage, + broadcast_tensors_from_last_stage, +) +from nemo_rl.models.megatron.train import ( + megatron_forward_backward, + LossPostProcessor, + LogprobsPostProcessor, + TopkLogitsPostProcessor, +) from nemo_rl.models.megatron.community_import import import_model_from_hf_name from nemo_rl.models.policy import PolicyConfig from nemo_rl.models.policy.interfaces import ( @@ -144,59 +138,6 @@ TokenizerType = TypeVar("TokenizerType", bound=PreTrainedTokenizerBase) -def 
broadcast_object_across_pp_ranks(obj): - """Broadcast an object across pipeline parallel ranks. - - This utility function handles broadcasting an object from the rank that owns it - to all other pipeline parallel ranks. If only one rank has the object (non-None), - it will be broadcast to all other ranks. - - Args: - obj: The object to broadcast. Can be None on ranks that don't own it. - - Returns: - The object on all ranks (either the original or the broadcast copy). - - Raises: - ValueError: If the object doesn't exist on any pipeline parallel rank. - """ - pp_size = get_pipeline_model_parallel_world_size() - pp_group = get_pipeline_model_parallel_group() - - if pp_size == 1: - return obj - - # ------------------------------------------------------------------ - # 1. Gather presence flags from all PP ranks to find the source rank - # ------------------------------------------------------------------ - has_obj = obj is not None - obj_flags = [None] * pp_size - torch.distributed.all_gather_object(obj_flags, has_obj, group=pp_group) - - # ------------------------------------------------------------------ - # 2. Identify the owning rank (the only rank with True flag) - # ------------------------------------------------------------------ - src_rank = None # Rank *inside* the PP group - for rank, flag in enumerate(obj_flags): - if flag: - src_rank = rank - break - - if src_rank is None: - raise ValueError("Object must exist on at least one PP rank") - - # ------------------------------------------------------------------ - # 3. Broadcast the object from the source rank to all ranks - # ------------------------------------------------------------------ - # Use broadcast_object_list which is more robust than all_gather_object - obj_list = [obj] - pp_ranks = torch.distributed.get_process_group_ranks(pp_group) - global_src = pp_ranks[src_rank] - torch.distributed.broadcast_object_list(obj_list, src=global_src, group=pp_group) - - return obj_list[0] - - def setup_megatron_model( policy_cfg: PolicyConfig, cfg: ConfigContainer, @@ -984,9 +925,6 @@ def train( f"Dim 1 must be the sequence dim, expected dim 1={seq_dim_size} but got shape {v.shape}" ) - forward_step = partial( - forward_step_arbitrary_loss, loss_fn=loss_fn, policy_cfg=self.cfg - ) all_mb_metrics = [] losses = [] total_num_microbatches = 0 @@ -1013,6 +951,11 @@ def train( # Track total microbatches for MoE aux-loss averaging total_num_microbatches += int(num_microbatches) + loss_fn_wrapped = LossPostProcessor( + loss_fn=loss_fn, + cfg=self.cfg, + ) + rerun_state_machine = get_rerun_state_machine() while rerun_state_machine.should_run_forward_backward(data_iterator): # Set grad to zero. @@ -1020,24 +963,19 @@ def train( self.optimizer.zero_grad() # Forward pass. - forward_backward_func = get_forward_backward_func() - losses_reduced = forward_backward_func( - forward_step_func=partial( - forward_step, - self.mcore_state, - global_valid_seqs, - global_valid_toks, - pack_sequences=self.cfg["sequence_packing"]["enabled"], - defer_fp32_logits=self.defer_fp32_logits, - ), - data_iterator=data_iterator, + losses_reduced = megatron_forward_backward( model=self.model, + cfg=self.cfg, + data_iterator=data_iterator, num_microbatches=num_microbatches, - seq_length=seq_dim_size, - micro_batch_size=mbs, - decoder_seq_length=seq_dim_size, + seq_length=padded_seq_length, + mbs=micro_batch_size, + post_processing_fn=loss_fn_wrapped, forward_only=eval_mode, - do_not_average_loss=True, + #do_not_average_loss=True, ## TODO! 
+ defer_fp32_logits=self.defer_fp32_logits, + global_valid_seqs=global_valid_seqs, + global_valid_toks=global_valid_toks, ) # Empty unused memory. @@ -1092,19 +1030,13 @@ def train( loss_metrics["global_valid_toks"] = global_valid_toks.item() mb_losses.append(loss_metrics["loss"]) - torch.distributed.broadcast_object_list( - [gb_loss_metrics], - src=get_pipeline_model_parallel_last_rank(), - group=get_pipeline_model_parallel_group(), - ) else: - loss_metrics = [None] # type: ignore - torch.distributed.broadcast_object_list( - loss_metrics, - src=get_pipeline_model_parallel_last_rank(), - group=get_pipeline_model_parallel_group(), - ) - gb_loss_metrics = loss_metrics[0] + gb_loss_metrics = None + + # Broadcast loss metrics from last stage to all stages + ## TODO: check with PP > 1 + gb_loss_metrics = broadcast_loss_metrics_from_last_stage(gb_loss_metrics) + if not parallel_state.is_pipeline_last_stage(ignore_virtual=True): mb_losses = [x["loss"] for x in gb_loss_metrics] all_mb_metrics.extend(gb_loss_metrics) @@ -1187,97 +1119,18 @@ def get_logprobs( padded_seq_length, ) = get_microbatch_iterator(data, self.cfg, logprob_batch_size) - def forward_step_fn( - data_iterator: Iterator[BatchedDataDict[Any]], model: GPTModel - ): - processed_mb = next(data_iterator) - # Extract the processed components - data_dict = processed_mb.data_dict - input_ids = processed_mb.input_ids - input_ids_cp_sharded = processed_mb.input_ids_cp_sharded - attention_mask = processed_mb.attention_mask - position_ids = processed_mb.position_ids - packed_seq_params = processed_mb.packed_seq_params - cu_seqlens_padded = processed_mb.cu_seqlens_padded - unpacked_input_ids = data_dict["input_ids"] - - multimodal_data = data_dict.get_multimodal_dict( - as_tensors=True, device=input_ids.device - ) - if len(multimodal_data) > 0: - position_ids = None - - additional_kwargs = {} - # Mamba models currently do not support packed_seq_params - if packed_seq_params is not None: - additional_kwargs["packed_seq_params"] = packed_seq_params - - if self.defer_fp32_logits: - additional_kwargs["fp32_output"] = False - - output_tensor = model( - input_ids=input_ids_cp_sharded, - position_ids=position_ids, - attention_mask=attention_mask, - **multimodal_data, - **additional_kwargs, - ) - - # Apply temperature scaling to logits for training - # This matches the dtensor worker's _apply_temperature_scaling in the train method - if "generation" in self.cfg and self.cfg["generation"] is not None: - output_tensor.div_(self.cfg["generation"]["temperature"]) - - def collection_fn(output_tensor): - stc = time.time() - tp_grp = get_tensor_model_parallel_group() - tp_rank = get_tensor_model_parallel_rank() - logprob_chunk_size = self.cfg.get("logprob_chunk_size", None) - if self.cfg["sequence_packing"]["enabled"]: - token_logprobs = from_parallel_logits_to_logprobs_packed_sequences( - output_tensor, - target=input_ids, - cu_seqlens_padded=cu_seqlens_padded, - unpacked_seqlen=seq_length, - vocab_start_index=tp_rank * output_tensor.shape[-1], - vocab_end_index=(tp_rank + 1) * output_tensor.shape[-1], - group=tp_grp, - inference_only=True, - cp_group=get_context_parallel_group(), - chunk_size=logprob_chunk_size, - ) - else: - token_logprobs = from_parallel_logits_to_logprobs( - output_tensor, - target=unpacked_input_ids, - vocab_start_index=tp_rank * output_tensor.shape[-1], - vocab_end_index=(tp_rank + 1) * output_tensor.shape[-1], - tp_group=tp_grp, - inference_only=True, - chunk_size=logprob_chunk_size, - ) - - # Prepend 0 logprob for first token to 
maintain same sequence length as input - token_logprobs = torch.cat( - [torch.zeros_like(token_logprobs[:, :1]), token_logprobs], dim=1 - ) - return torch.tensor(0.0, device=token_logprobs.device), { - "logprobs": token_logprobs - } - - return output_tensor, collection_fn - - forward_backward_func = get_forward_backward_func() - list_of_logprobs = forward_backward_func( - forward_step_func=forward_step_fn, - data_iterator=mb_iterator, + list_of_logprobs = megatron_forward_backward( model=self.model, - num_microbatches=num_microbatches, + cfg=self.cfg, + data_iterator=mb_iterator, seq_length=padded_seq_length, - micro_batch_size=micro_batch_size, - decoder_seq_length=padded_seq_length, + mbs=micro_batch_size, + num_microbatches=num_microbatches, + post_processing_fn=LogprobsPostProcessor(cfg=self.cfg), forward_only=True, + defer_fp32_logits=self.defer_fp32_logits, ) + if is_pipeline_last_stage(ignore_virtual=True): all_log_probs_padded = [] all_logprobs = [l["logprobs"] for l in list_of_logprobs] @@ -1290,12 +1143,10 @@ def collection_fn(output_tensor): all_log_probs_padded.append(lp) logprobs = torch.cat(all_log_probs_padded, dim=0) - # broadcast logprobs to first pp rank - broadcast_tensor(logprobs, torch.distributed.get_rank(), pp_grp) + tensors = {"logprobs": logprobs} else: - logprobs = broadcast_tensor( - None, get_pipeline_model_parallel_last_rank(), pp_grp - ) + tensors = {"logprobs": None} + logprobs = broadcast_tensors_from_last_stage(tensors)["logprobs"] no_grad.__exit__(None, None, None) return BatchedDataDict[LogprobOutputSpec](logprobs=logprobs).to("cpu") @@ -1389,172 +1240,16 @@ def get_topk_logits( padded_seq_length, ) = get_microbatch_iterator(data, self.cfg, logprob_batch_size) - def forward_step_fn( - data_iterator: Iterator[BatchedDataDict[Any]], model: GPTModel - ): - processed_mb = next(data_iterator).to("cuda") - # Extract the processed components - data_dict = processed_mb.data_dict - input_ids = processed_mb.input_ids - input_ids_cp_sharded = processed_mb.input_ids_cp_sharded - attention_mask = processed_mb.attention_mask - position_ids = processed_mb.position_ids - packed_seq_params = processed_mb.packed_seq_params - cu_seqlens_padded = processed_mb.cu_seqlens_padded - unpacked_input_ids = data_dict["input_ids"] - - multimodal_data = data_dict.get_multimodal_dict( - as_tensors=True, device=input_ids_cp_sharded.device - ) - if len(multimodal_data) > 0: - position_ids = None - - additional_kwargs = {} - if packed_seq_params is not None: - additional_kwargs["packed_seq_params"] = packed_seq_params - - output_tensor = model( - input_ids=input_ids_cp_sharded, - position_ids=position_ids, - attention_mask=attention_mask, - **additional_kwargs, - **multimodal_data, - ) - - if "generation" in self.cfg and self.cfg["generation"] is not None: - output_tensor.div_(self.cfg["generation"]["temperature"]) - - def collection_fn(_): - # Only the last PP stage produces final logits/top-k; earlier stages return empty - # if not is_pipeline_last_stage(ignore_virtual=True): - # return output_tensor.new_zeros(()), {} - - tp_grp = get_tensor_model_parallel_group() - tp_rank = get_tensor_model_parallel_rank() - vocab_shard_size = output_tensor.shape[-1] - vocab_start_index = tp_rank * vocab_shard_size - - chunk_size = None - if "logprob_chunk_size" in self.cfg: - chunk_size = self.cfg["logprob_chunk_size"] - - topk_vals_local, topk_idx_local = distributed_vocab_topk( - output_tensor, - k, - tp_grp, - vocab_start_index=vocab_start_index, - vocab_end_index=vocab_start_index + 
vocab_shard_size, - chunk_size=chunk_size, - ) - - if self.cfg["megatron_cfg"]["context_parallel_size"] > 1: - cp_grp = get_context_parallel_group() - if self.cfg["sequence_packing"]["enabled"]: - cp_size = self.cfg["megatron_cfg"]["context_parallel_size"] - # Per-sequence CP allgather following packed-sequence logic - batch_size = data_dict["input_ids"].shape[0] - total_packed_len = int(cu_seqlens_padded[-1].item()) - - topk_vals_full = torch.zeros( - (1, total_packed_len, k), - dtype=topk_vals_local.dtype, - device=topk_vals_local.device, - ) - topk_idx_full = torch.zeros( - (1, total_packed_len, k), - dtype=topk_idx_local.dtype, - device=topk_idx_local.device, - ) - - for i in range(batch_size): - start_idx = int(cu_seqlens_padded[i].item()) - end_idx = int(cu_seqlens_padded[i + 1].item()) - if end_idx > start_idx: - local_vals_slice = topk_vals_local[ - :, start_idx // cp_size : end_idx // cp_size, : - ] - local_idx_slice = topk_idx_local[ - :, start_idx // cp_size : end_idx // cp_size, : - ] - gathered_vals = allgather_cp_sharded_tensor( - local_vals_slice, cp_grp, seq_dim=1 - ) - gathered_idx = allgather_cp_sharded_tensor( - local_idx_slice, cp_grp, seq_dim=1 - ) - # Some kernels may return [X, Y, k] where X*Y = (end_idx - start_idx). - # Flatten leading dims and reshape to [1, expected_len, k] to match target. - expected_len = end_idx - start_idx - if ( - gathered_vals.dim() == 3 - and gathered_vals.shape[1] != expected_len - ): - gathered_vals = gathered_vals.reshape( - 1, expected_len, gathered_vals.shape[-1] - ) - if ( - gathered_idx.dim() == 3 - and gathered_idx.shape[1] != expected_len - ): - gathered_idx = gathered_idx.reshape( - 1, expected_len, gathered_idx.shape[-1] - ) - topk_vals_full[:, start_idx:end_idx, :] = gathered_vals - topk_idx_full[:, start_idx:end_idx, :] = gathered_idx - else: - # Sequence packing must be enabled when CP > 1 - raise RuntimeError( - "Context Parallelism (CP>1) requires sequence packing to be enabled." 
- ) - else: - topk_vals_full = topk_vals_local - topk_idx_full = topk_idx_local - - if self.cfg["sequence_packing"]["enabled"]: - batch_size = data_dict["input_ids"].shape[0] - seq_lengths = data_dict["input_lengths"] - out_vals = torch.zeros( - (batch_size, seq_length, k), - dtype=topk_vals_full.dtype, - device=topk_vals_full.device, - ) - out_idx = torch.zeros( - (batch_size, seq_length, k), - dtype=topk_idx_full.dtype, - device=topk_idx_full.device, - ) - for i in range(batch_size): - seq_len = int(seq_lengths[i].item()) - start_idx = int(cu_seqlens_padded[i].item()) - if seq_len > 0: - out_vals[i, :seq_len, :] = topk_vals_full[ - 0, start_idx : start_idx + seq_len, : - ] - out_idx[i, :seq_len, :] = topk_idx_full[ - 0, start_idx : start_idx + seq_len, : - ] - return output_tensor.new_zeros(()), { - "topk_logits": out_vals, - "topk_indices": out_idx, - } - else: - return output_tensor.new_zeros(()), { - "topk_logits": topk_vals_full, - "topk_indices": topk_idx_full, - } - - return output_tensor, collection_fn - - forward_backward_func = get_forward_backward_func() - list_of_outputs = forward_backward_func( - forward_step_func=forward_step_fn, - data_iterator=mb_iterator, + list_of_outputs = megatron_forward_backward( model=self.model, - num_microbatches=num_microbatches, + cfg=self.cfg, + data_iterator=mb_iterator, seq_length=padded_seq_length, - micro_batch_size=micro_batch_size, - decoder_seq_length=padded_seq_length, + mbs=micro_batch_size, + num_microbatches=num_microbatches, + post_processing_fn=TopkLogitsPostProcessor(cfg=self.cfg, k=k), forward_only=True, + defer_fp32_logits=self.defer_fp32_logits, ) if is_pipeline_last_stage(ignore_virtual=True): @@ -1573,16 +1268,20 @@ def collection_fn(_): topk_logits = torch.cat(logits_chunks, dim=0) topk_indices = torch.cat(indices_chunks, dim=0) - topk_logits = broadcast_tensor( - topk_logits, torch.distributed.get_rank(), pp_grp - ) - topk_indices = broadcast_tensor( - topk_indices, torch.distributed.get_rank(), pp_grp - ) + tensors_to_broadcast = { + "topk_logits": topk_logits, + "topk_indices": topk_indices, + } else: - last_pp_rank = get_pipeline_model_parallel_last_rank() - topk_logits = broadcast_tensor(None, last_pp_rank, pp_grp) - topk_indices = broadcast_tensor(None, last_pp_rank, pp_grp) + tensors_to_broadcast = { + "topk_logits": None, + "topk_indices": None, + } + + # Broadcast tensors from last stage to all stages + broadcasted = broadcast_tensors_from_last_stage(tensors_to_broadcast) + topk_logits = broadcasted["topk_logits"] + topk_indices = broadcasted["topk_indices"] no_grad.__exit__(None, None, None) return BatchedDataDict.from_batches( @@ -1875,7 +1574,7 @@ def calculate_size_in_bytes(param, tp_size, ep_size): ) # Broadcast size_in_bytes across pipeline parallel ranks - return broadcast_object_across_pp_ranks(size_in_bytes) + return broadcast_obj_from_pp_rank(size_in_bytes) for task in self.refit_conversion_tasks: param_info.append( From 54174911118bf7d84591e2d5f26dc7ab775ca60c Mon Sep 17 00:00:00 2001 From: ashors1 Date: Wed, 17 Dec 2025 16:39:56 -0800 Subject: [PATCH 3/4] add do_not_average_loss arg Signed-off-by: ashors1 --- nemo_rl/models/megatron/train.py | 2 ++ nemo_rl/models/policy/workers/megatron_policy_worker.py | 2 +- 2 files changed, 3 insertions(+), 1 deletion(-) diff --git a/nemo_rl/models/megatron/train.py b/nemo_rl/models/megatron/train.py index 4b2345a73c..0db32f8564 100644 --- a/nemo_rl/models/megatron/train.py +++ b/nemo_rl/models/megatron/train.py @@ -180,6 +180,7 @@ def megatron_forward_backward( 
defer_fp32_logits: Optional[bool] = None, global_valid_seqs: Optional[torch.Tensor] = None, global_valid_toks: Optional[torch.Tensor] = None, + do_not_average_loss: bool = False, ) -> Any: """ Execute forward and backward passes using Megatron's utilities. @@ -222,6 +223,7 @@ def megatron_forward_backward( micro_batch_size=mbs, decoder_seq_length=seq_length, forward_only=forward_only, + do_not_average_loss=do_not_average_loss, ) class LossPostProcessor: diff --git a/nemo_rl/models/policy/workers/megatron_policy_worker.py b/nemo_rl/models/policy/workers/megatron_policy_worker.py index 215761e902..f995c0f768 100644 --- a/nemo_rl/models/policy/workers/megatron_policy_worker.py +++ b/nemo_rl/models/policy/workers/megatron_policy_worker.py @@ -963,10 +963,10 @@ def train( mbs=micro_batch_size, post_processing_fn=loss_fn_wrapped, forward_only=eval_mode, - #do_not_average_loss=True, ## TODO! defer_fp32_logits=self.defer_fp32_logits, global_valid_seqs=global_valid_seqs, global_valid_toks=global_valid_toks, + do_not_average_loss=True, ) # Empty unused memory. From 5b65462c3bd9e692326bebd27a998a2813247ab2 Mon Sep 17 00:00:00 2001 From: ashors1 Date: Thu, 18 Dec 2025 15:48:50 -0800 Subject: [PATCH 4/4] remove unused import Signed-off-by: ashors1 --- nemo_rl/models/policy/workers/megatron_policy_worker.py | 1 - 1 file changed, 1 deletion(-) diff --git a/nemo_rl/models/policy/workers/megatron_policy_worker.py b/nemo_rl/models/policy/workers/megatron_policy_worker.py index f995c0f768..b4a76a6b20 100644 --- a/nemo_rl/models/policy/workers/megatron_policy_worker.py +++ b/nemo_rl/models/policy/workers/megatron_policy_worker.py @@ -18,7 +18,6 @@ import warnings from collections import defaultdict from contextlib import AbstractContextManager, contextmanager, nullcontext -from functools import partial from typing import Any, Iterator, Optional, TypeVar, cast import ray
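# Illustrative sketch (not part of the patch series): the training-mode call after all
# four patches, mirroring the worker's train() diff above. `self`, `data_iterator`, and
# the batch-size variables are assumed worker state; do_not_average_loss=True preserves
# the pre-refactor behavior of returning per-microbatch losses for the worker to aggregate.
losses_reduced = megatron_forward_backward(
    model=self.model,
    cfg=self.cfg,
    data_iterator=data_iterator,
    num_microbatches=num_microbatches,
    seq_length=padded_seq_length,
    mbs=micro_batch_size,
    post_processing_fn=LossPostProcessor(loss_fn=loss_fn, cfg=self.cfg),
    forward_only=eval_mode,
    defer_fp32_logits=self.defer_fp32_logits,
    global_valid_seqs=global_valid_seqs,
    global_valid_toks=global_valid_toks,
    do_not_average_loss=True,
)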