From ab90496adaba3f5eb4d2a41555692ef356a50faa Mon Sep 17 00:00:00 2001 From: N!no Date: Tue, 16 Dec 2025 20:11:16 -0500 Subject: [PATCH 01/13] Update maac.py --- comlrl/trainers/maac.py | 170 ++++++++++++++++++++++++++++++++++++++++ 1 file changed, 170 insertions(+) diff --git a/comlrl/trainers/maac.py b/comlrl/trainers/maac.py index 830d676..1844ac0 100644 --- a/comlrl/trainers/maac.py +++ b/comlrl/trainers/maac.py @@ -57,6 +57,8 @@ class MAACConfig: reward_norm_eps: float = 1e-3 num_return_sequences: int = 1 critic_model_name_or_path: Optional[Union[str, PreTrainedModel]] = None + num_turns: int = 1 + discount: float = 0.9 def __post_init__(self) -> None: if self.rollout_buffer_size < 1: @@ -73,6 +75,12 @@ def __post_init__(self) -> None: raise ValueError("num_return_sequences must be >= 1.") if self.critic_model_name_or_path is None: raise ValueError("critic_model_name_or_path must be provided for MAAC.") + if self.num_turns < 1: + raise ValueError("num_turns must be >= 1.") + if self.num_turns > 1 and self.num_return_sequences != 1: + raise ValueError( + "Multi-turn MAAC currently supports num_return_sequences == 1." + ) class MAACTrainer: @@ -91,6 +99,7 @@ def __init__( model_config: Optional[Dict[str, Any]] = None, wandb_config: Optional[Dict[str, Any]] = None, metrics_callback: Optional[MetricsCallback] = None, + external_transition: Optional[Callable] = None, ) -> None: if reward_func is None or not callable(reward_func): raise ValueError("A callable reward_func must be provided.") @@ -123,6 +132,10 @@ def __init__( "Multi-agent MAAC requires `model` to be a pretrained identifier string." ) + if self.args.num_turns > 1 and external_transition is None: + raise ValueError("Multi-turn MAAC requires an external_transition.") + self.external_transition = external_transition + self.actor_models: List[CausalLMWithValueHead] = [] for _ in range(self.args.num_agents): actor_model = self._load_actor_model(model) @@ -395,6 +408,10 @@ def _generate(self, actor_model, prompt: str) -> Dict[str, Any]: } def _collect_rollouts(self, item: Dict[str, Any]) -> List[RolloutSample]: + num_turns = max(1, int(getattr(self.args, "num_turns", 1))) + if num_turns > 1: + return self._collect_rollouts_multi_turn(item, num_turns) + prompts: List[str] = [] completions_per_agent: List[List[str]] = [] rollout_data: List[Dict[str, Any]] = [] @@ -496,6 +513,159 @@ def _collect_rollouts(self, item: Dict[str, Any]) -> List[RolloutSample]: return rollouts + def _collect_rollouts_multi_turn( + self, item: Dict[str, Any], num_turns: int + ) -> List[RolloutSample]: + if self.args.num_return_sequences != 1: + raise ValueError( + "Multi-turn MAAC currently supports num_return_sequences == 1." 
+ ) + + prompt_history = [[] for _ in range(self.args.num_agents)] + response_history = [[] for _ in range(self.args.num_agents)] + previous_completions: List[Optional[str]] = [None] * self.args.num_agents + per_agent_samples: List[List[RolloutSample]] = [ + [] for _ in range(self.args.num_agents) + ] + rollouts: List[RolloutSample] = [] + gamma = float(getattr(self.args, "discount", 0.9)) + + for turn_idx in range(num_turns): + if turn_idx == 0: + turn_prompts = [ + self._format_prompt(item, agent_idx) + for agent_idx in range(self.args.num_agents) + ] + else: + if self.external_transition is None: + raise ValueError("external_transition is required for multi-turn.") + transition_result = self.external_transition( + prompt=item.get("prompt", ""), + agent_completions=previous_completions, + num_agents=self.args.num_agents, + prompt_history_per_agent=prompt_history, + response_history_per_agent=response_history, + ) + if ( + not isinstance(transition_result, (list, tuple)) + or len(transition_result) != self.args.num_agents + ): + raise ValueError( + "External transition must return per-agent prompts" + ) + turn_prompts = list(transition_result) + + completions_per_agent: List[List[str]] = [] + rollout_data: List[Dict[str, Any]] = [] + for agent_idx, actor_model in enumerate(self.actor_models): + prompt = turn_prompts[agent_idx] + gen = self._generate(actor_model, prompt) + completions_per_agent.append(gen["completions"]) + rollout_data.append( + { + "agent_idx": agent_idx, + "prompt": prompt, + "prompt_len": gen["prompt_len"], + "sequences": gen["sequences"], + "attention_mask": gen["attention_mask"], + "response_lens": gen["response_lens"], + "completion_texts": gen["completions"], + } + ) + prompt_history[agent_idx].append(prompt) + + rewards = self._call_reward_func(turn_prompts, completions_per_agent) + rewards_matrix = self._expand_rewards(rewards, num_ret=1) + + joint_prompt = self._build_joint_prompt(turn_prompts) + joint_encoded = self._encode_prompt(joint_prompt) + joint_ids = joint_encoded["input_ids"] + joint_mask = joint_encoded["attention_mask"] + joint_len = joint_ids.size(1) + with torch.no_grad(): + joint_value = self._value_on_prompt_only( + self.critic_model, joint_ids, joint_mask, joint_len + ) + + for data in rollout_data: + agent_idx = data["agent_idx"] + seq = data["sequences"][0] + attn = data["attention_mask"][0] + resp_len = data["response_lens"][0] + reward_val = float(rewards_matrix[agent_idx][0]) + reward_tensor = torch.tensor([reward_val], device=self.device) + + logprob, _ = self._policy_eval( + self.actor_models[agent_idx], + seq.unsqueeze(0), + attn.unsqueeze(0), + data["prompt_len"], + resp_len, + output_values=False, + ) + + value = joint_value.detach().cpu() + completion_text = data["completion_texts"][0] + sample = RolloutSample( + agent_idx=agent_idx, + prompt=data["prompt"], + completion=completion_text, + full_input_ids=seq.detach().cpu(), + attention_mask=attn.detach().cpu(), + prompt_len=data["prompt_len"], + response_len=resp_len, + old_logprob=logprob.detach().cpu(), + old_value=value.detach().cpu(), + reward=reward_tensor.detach().cpu(), + returns=reward_tensor.detach().cpu(), + advantage=torch.zeros_like(reward_tensor).detach().cpu(), + normalized_advantage=None, + metadata={ + "joint_input_ids": joint_ids.detach().cpu(), + "joint_attention_mask": joint_mask.detach().cpu(), + "joint_prompt_len": joint_len, + "turn_idx": turn_idx, + }, + ) + rollouts.append(sample) + per_agent_samples[agent_idx].append(sample) + 
response_history[agent_idx].append(completion_text) + previous_completions[agent_idx] = completion_text + + for agent_idx in range(self.args.num_agents): + future = 0.0 + for sample in reversed(per_agent_samples[agent_idx]): + immediate = float(sample.reward.view(-1)[0].item()) + future = immediate + gamma * future + sample.returns = ( + torch.tensor([future], device=self.device).detach().cpu() + ) + sample.advantage = torch.zeros_like(sample.returns) + sample.normalized_advantage = None + + if self.metrics_callback is not None: + try: + extra = self.metrics_callback(rollouts) + if isinstance(extra, dict): + self._log_metrics(extra) + except Exception: + pass + + return rollouts + + def _expand_rewards(self, rewards: List[float], num_ret: int) -> List[List[float]]: + """Map reward list to [num_agents x num_ret] matrix.""" + num_agents = self.args.num_agents + if len(rewards) == 1: + return [[rewards[0]] * num_ret for _ in range(num_agents)] + if len(rewards) == num_ret: + return [list(rewards) for _ in range(num_agents)] + if len(rewards) == num_agents: + return [[rewards[a]] * num_ret for a in range(num_agents)] + raise ValueError( + "Reward function must return 1 value, num_return_sequences values, or num_agents values." + ) + # ------------------------------------------------------------------ # # Advantage prep # ------------------------------------------------------------------ # From fd182ef97c33eec8d47afbbf3e0ad34448589a13 Mon Sep 17 00:00:00 2001 From: N!no Date: Tue, 16 Dec 2025 20:26:15 -0500 Subject: [PATCH 02/13] Update maac.py --- comlrl/trainers/maac.py | 32 ++++++++++++++++++++++++-------- 1 file changed, 24 insertions(+), 8 deletions(-) diff --git a/comlrl/trainers/maac.py b/comlrl/trainers/maac.py index 1844ac0..ed0618e 100644 --- a/comlrl/trainers/maac.py +++ b/comlrl/trainers/maac.py @@ -933,9 +933,10 @@ def train(self) -> None: # Logging and persistence # ------------------------------------------------------------------ # def _tag_metrics( - self, metrics: Dict[str, float], agent_idx: int + self, metrics: Dict[str, float], agent_idx: int, turn_idx: int = 0 ) -> Dict[str, float]: - return {f"turn_1/{key}": value for key, value in metrics.items()} + prefix = f"turn_{turn_idx + 1}/" + return {prefix + key: value for key, value in metrics.items()} def _log_metrics(self, metrics: Dict[str, float]) -> None: if not metrics: @@ -949,13 +950,28 @@ def _process_buffer( buffer: List[RolloutSample], epoch_metrics: Dict[str, List[float]], ) -> None: - metrics = self._update(agent_idx, buffer) + if not buffer: + return + + # Group samples by turn (if available); default to turn 0. 
+ has_turn_idx = any( + "turn_idx" in (getattr(s, "metadata", {}) or {}) for s in buffer + ) + turn_groups: Dict[int, List[RolloutSample]] = {} + for sample in buffer: + t_idx = int(sample.metadata.get("turn_idx", 0)) if has_turn_idx else 0 + turn_groups.setdefault(t_idx, []).append(sample) + buffer.clear() - tagged = self._tag_metrics(metrics, agent_idx) - self._log_metrics(tagged) - self.global_step += 1 - for key, value in tagged.items(): - epoch_metrics[key].append(value) + + for t_idx in sorted(turn_groups.keys()): + samples = turn_groups[t_idx] + metrics = self._update(agent_idx, samples) + tagged = self._tag_metrics(metrics, agent_idx, turn_idx=t_idx) + self._log_metrics(tagged) + self.global_step += 1 + for key, value in tagged.items(): + epoch_metrics[key].append(value) def save_model(self, output_dir: str) -> None: os.makedirs(output_dir, exist_ok=True) From f6a36be6343a5cd3527099a1619feb06cf67cf9a Mon Sep 17 00:00:00 2001 From: N!no Date: Tue, 16 Dec 2025 20:51:50 -0500 Subject: [PATCH 03/13] monkey patch --- comlrl/trainers/maac.py | 23 ++++++++++++++++++++++- 1 file changed, 22 insertions(+), 1 deletion(-) diff --git a/comlrl/trainers/maac.py b/comlrl/trainers/maac.py index ed0618e..b1fe583 100644 --- a/comlrl/trainers/maac.py +++ b/comlrl/trainers/maac.py @@ -926,8 +926,29 @@ def train(self) -> None: for key, values in epoch_metrics.items() if values } + num_turns = max(1, int(getattr(self.args, "num_turns", 1))) + epoch_log: Dict[str, float] = {} + for turn_idx in range(num_turns): + prefix = f"turn_{turn_idx + 1}/" + + def _maybe_log(metric_key: str, epoch_key: str) -> None: + values = epoch_metrics.get(prefix + metric_key) + if values: + epoch_log[prefix + epoch_key] = float(sum(values) / len(values)) + + _maybe_log("reward_mean", "epoch_reward_mean") + _maybe_log("expected_return", "epoch_avg_return") + _maybe_log("value_variance", "epoch_value_variance") + _maybe_log("policy_loss", "epoch_policy_loss") + _maybe_log("value_loss", "epoch_value_loss") + + if epoch_log: + self._log_metrics(epoch_log) + self.global_step += 1 + if summary: - print(f"Epoch {epoch + 1}/{total_epochs} metrics: {summary}") + to_print = epoch_log if epoch_log else summary + print(f"Epoch {epoch + 1}/{total_epochs} metrics: {to_print}") # ------------------------------------------------------------------ # # Logging and persistence From 6397ee67bc9786d244fddf803c21e5a142dbb8e8 Mon Sep 17 00:00:00 2001 From: N!no Date: Tue, 16 Dec 2025 22:09:20 -0500 Subject: [PATCH 04/13] Update maac.py --- comlrl/trainers/maac.py | 35 +++++++++++++++++++++++++++++------ 1 file changed, 29 insertions(+), 6 deletions(-) diff --git a/comlrl/trainers/maac.py b/comlrl/trainers/maac.py index b1fe583..b042949 100644 --- a/comlrl/trainers/maac.py +++ b/comlrl/trainers/maac.py @@ -81,6 +81,11 @@ def __post_init__(self) -> None: raise ValueError( "Multi-turn MAAC currently supports num_return_sequences == 1." ) + if self.num_turns > 1 and self.rollout_buffer_size % self.num_turns != 0: + raise ValueError( + "For multi-turn MAAC, rollout_buffer_size must be a multiple of num_turns " + "so per-turn metrics align (e.g., num_turns=2 => buffer_size even)." + ) class MAACTrainer: @@ -859,11 +864,10 @@ def _update( ) -> Dict[str, float]: if not rollouts: return {} - self._prepare_advantages(rollouts) - random.shuffle(rollouts) metrics = defaultdict(list) - # Per-prompt reward variance across generations (all agents combined) + # Log metrics using raw rewards/returns before normalization. 
+ # Note: _prepare_advantages() may normalize sample.returns in-place. prompt_rewards: Dict[str, List[float]] = defaultdict(list) for sample in rollouts: prompt_rewards[sample.prompt].append( @@ -883,7 +887,21 @@ def _update( if rewards.numel() > 0 and torch.isfinite(rewards).all(): mean_reward = float(rewards.mean().item()) metrics["reward_mean"].append(mean_reward) - metrics["expected_return"].append(mean_reward) + + returns = torch.stack( + [sample.returns.view(-1)[0] for sample in rollouts] + ).float() + if returns.numel() > 0 and torch.isfinite(returns).all(): + metrics["expected_return"].append(float(returns.mean().item())) + + values = torch.stack( + [sample.old_value.view(-1)[0] for sample in rollouts] + ).float() + if values.numel() > 0 and torch.isfinite(values).all(): + metrics["value_pred_mean"].append(float(values.mean().item())) + + self._prepare_advantages(rollouts) + random.shuffle(rollouts) for start in range(0, len(rollouts), self.args.mini_batch_size): batch = rollouts[start : start + self.args.mini_batch_size] @@ -939,6 +957,7 @@ def _maybe_log(metric_key: str, epoch_key: str) -> None: _maybe_log("reward_mean", "epoch_reward_mean") _maybe_log("expected_return", "epoch_avg_return") _maybe_log("value_variance", "epoch_value_variance") + _maybe_log("value_pred_mean", "epoch_value_pred_mean") _maybe_log("policy_loss", "epoch_policy_loss") _maybe_log("value_loss", "epoch_value_loss") @@ -985,15 +1004,19 @@ def _process_buffer( buffer.clear() + # Log all turns at the same wandb step so turn_1 and turn_2 align. + combined_log: Dict[str, float] = {} for t_idx in sorted(turn_groups.keys()): samples = turn_groups[t_idx] metrics = self._update(agent_idx, samples) tagged = self._tag_metrics(metrics, agent_idx, turn_idx=t_idx) - self._log_metrics(tagged) - self.global_step += 1 + combined_log.update(tagged) for key, value in tagged.items(): epoch_metrics[key].append(value) + self._log_metrics(combined_log) + self.global_step += 1 + def save_model(self, output_dir: str) -> None: os.makedirs(output_dir, exist_ok=True) for agent_idx, actor in enumerate(self.actor_models): From b3a8d121cd1b516ee5c7be82c43fd176621e9a5f Mon Sep 17 00:00:00 2001 From: N!no Date: Wed, 17 Dec 2025 20:54:12 -0500 Subject: [PATCH 05/13] set td or mc update critic configurable --- comlrl/trainers/maac.py | 161 ++++++++++++++++++++++++++++++++-------- 1 file changed, 131 insertions(+), 30 deletions(-) diff --git a/comlrl/trainers/maac.py b/comlrl/trainers/maac.py index b042949..ce04840 100644 --- a/comlrl/trainers/maac.py +++ b/comlrl/trainers/maac.py @@ -59,6 +59,8 @@ class MAACConfig: critic_model_name_or_path: Optional[Union[str, PreTrainedModel]] = None num_turns: int = 1 discount: float = 0.9 + critic_type: str = "v" # "v" (V(s)) or "q" (Q(s,a)) + critic_target: str = "td0" # "mc" (Monte Carlo) or "td0" (TD(0) on policy) def __post_init__(self) -> None: if self.rollout_buffer_size < 1: @@ -86,6 +88,12 @@ def __post_init__(self) -> None: "For multi-turn MAAC, rollout_buffer_size must be a multiple of num_turns " "so per-turn metrics align (e.g., num_turns=2 => buffer_size even)." 
) + critic_type = (self.critic_type or "v").lower() + if critic_type not in ("v", "q"): + raise ValueError("critic_type must be one of: 'v', 'q'.") + critic_target = (self.critic_target or "mc").lower() + if critic_target not in ("mc", "td0"): + raise ValueError("critic_target must be one of: 'mc', 'td0'.") class MAACTrainer: @@ -314,6 +322,38 @@ def _build_joint_prompt(self, prompts: Sequence[str]) -> str: pieces = [f"[Agent {idx}] {p}" for idx, p in enumerate(prompts)] return "\n\n".join(pieces) + def _build_critic_input( + self, prompts: Sequence[str], action_completions: Optional[Sequence[str]] = None + ) -> str: + """Build centralized critic conditioning input. + + - critic_type='v': V(s) conditioned on joint prompt only. + - critic_type='q': Q(s,a) conditioned on joint prompt + joint action text. + """ + base = self._build_joint_prompt(prompts) + if (self.args.critic_type or "v").lower() == "v": + return base + + action_completions = list(action_completions or []) + action_lines: List[str] = ["[Joint Action]"] + for idx, comp in enumerate(action_completions): + action_lines.append(f"[Agent {idx} action]\n{comp}") + return base + "\n\n" + "\n\n".join(action_lines) + + def _critic_value_from_text(self, critic_input: str) -> Dict[str, Any]: + encoded = self._encode_prompt(critic_input) + ids = encoded["input_ids"] + mask = encoded["attention_mask"] + prompt_len = ids.size(1) + value = self._value_on_prompt_only(self.critic_model, ids, mask, prompt_len) + return { + "critic_input": critic_input, + "input_ids": ids, + "attention_mask": mask, + "prompt_len": prompt_len, + "value": value, + } + def _infer_reward_signature(self, reward_func: Callable) -> inspect.Signature: try: return inspect.signature(reward_func) @@ -452,17 +492,24 @@ def _collect_rollouts(self, item: Dict[str, Any]) -> List[RolloutSample]: "or num_agents values." 
) - joint_prompt = self._build_joint_prompt(prompts) - joint_encoded = self._encode_prompt(joint_prompt) - joint_ids = joint_encoded["input_ids"] - joint_mask = joint_encoded["attention_mask"] - joint_len = joint_ids.size(1) - with torch.no_grad(): - joint_value = self._value_on_prompt_only( - self.critic_model, joint_ids, joint_mask, joint_len - ) - rollouts: List[RolloutSample] = [] + critic_type = (self.args.critic_type or "v").lower() + critic_values_by_i: List[Dict[str, Any]] = [] + if critic_type == "v": + critic_input = self._build_critic_input(prompts) + with torch.no_grad(): + critic_values_by_i = [self._critic_value_from_text(critic_input)] + else: + for i in range(num_ret): + joint_action = [ + completions_per_agent[a][i] for a in range(self.args.num_agents) + ] + critic_input = self._build_critic_input(prompts, joint_action) + with torch.no_grad(): + critic_values_by_i.append( + self._critic_value_from_text(critic_input) + ) + for data in rollout_data: agent_idx = data["agent_idx"] for i in range(num_ret): @@ -481,7 +528,15 @@ def _collect_rollouts(self, item: Dict[str, Any]) -> List[RolloutSample]: output_values=False, ) - value = joint_value.detach().cpu() + critic_pack = ( + critic_values_by_i[0] + if critic_type == "v" + else critic_values_by_i[i] + ) + joint_ids = critic_pack["input_ids"] + joint_mask = critic_pack["attention_mask"] + joint_len = int(critic_pack["prompt_len"]) + value = critic_pack["value"].detach().cpu() rollouts.append( RolloutSample( agent_idx=agent_idx, @@ -504,10 +559,17 @@ def _collect_rollouts(self, item: Dict[str, Any]) -> List[RolloutSample]: "joint_input_ids": joint_ids.detach().cpu(), "joint_attention_mask": joint_mask.detach().cpu(), "joint_prompt_len": joint_len, + "turn_idx": 0, + "adv_target": reward_tensor.detach().cpu(), }, ) ) + if (self.args.critic_target or "mc").lower() == "td0": + for sample in rollouts: + r = float(sample.reward.view(-1)[0].item()) + sample.metadata["value_target"] = torch.tensor([r]).detach().cpu() + if self.metrics_callback is not None: try: extra = self.metrics_callback(rollouts) @@ -581,16 +643,16 @@ def _collect_rollouts_multi_turn( rewards = self._call_reward_func(turn_prompts, completions_per_agent) rewards_matrix = self._expand_rewards(rewards, num_ret=1) - - joint_prompt = self._build_joint_prompt(turn_prompts) - joint_encoded = self._encode_prompt(joint_prompt) - joint_ids = joint_encoded["input_ids"] - joint_mask = joint_encoded["attention_mask"] - joint_len = joint_ids.size(1) + critic_input = self._build_critic_input( + turn_prompts, + action_completions=[c[0] for c in completions_per_agent], + ) with torch.no_grad(): - joint_value = self._value_on_prompt_only( - self.critic_model, joint_ids, joint_mask, joint_len - ) + critic_pack = self._critic_value_from_text(critic_input) + joint_ids = critic_pack["input_ids"] + joint_mask = critic_pack["attention_mask"] + joint_len = int(critic_pack["prompt_len"]) + joint_value = critic_pack["value"] for data in rollout_data: agent_idx = data["agent_idx"] @@ -637,6 +699,22 @@ def _collect_rollouts_multi_turn( response_history[agent_idx].append(completion_text) previous_completions[agent_idx] = completion_text + use_td_target = (self.args.critic_target or "mc").lower() == "td0" + for agent_idx in range(self.args.num_agents): + traj = per_agent_samples[agent_idx] + for t, sample in enumerate(traj): + r = float(sample.reward.view(-1)[0].item()) + if t < len(traj) - 1: + next_v = float(traj[t + 1].old_value.view(-1)[0].item()) + target = r + gamma * next_v + else: + 
target = r + sample.metadata["adv_target"] = torch.tensor([target]).detach().cpu() + if use_td_target: + sample.metadata["value_target"] = ( + torch.tensor([target]).detach().cpu() + ) + for agent_idx in range(self.args.num_agents): future = 0.0 for sample in reversed(per_agent_samples[agent_idx]): @@ -701,17 +779,28 @@ def _normalize_returns(self, rollouts: List[RolloutSample]) -> None: def _prepare_advantages(self, rollouts: List[RolloutSample]) -> None: if not rollouts: return - self._normalize_returns(rollouts) + if (self.args.critic_target or "mc").lower() == "mc": + self._normalize_returns(rollouts) - advantages = torch.stack( - [sample.advantage.to(torch.float32).view(-1)[0] for sample in rollouts] - ) + advantages = [] + for sample in rollouts: + target = sample.metadata.get("adv_target") or sample.metadata.get( + "value_target" + ) + if target is None: + target = sample.returns + adv = target.to(torch.float32) - sample.old_value.to(torch.float32) + sample.advantage = adv.to(sample.returns.dtype) + advantages.append(adv.view(-1)[0]) + advantages = torch.stack(advantages) if self.args.advantage_normalization and advantages.numel() > 1: mean = advantages.mean() std = advantages.std(unbiased=False).clamp(min=1e-6) for sample in rollouts: - sample.normalized_advantage = (sample.advantage - mean) / std + sample.normalized_advantage = ( + sample.advantage.to(torch.float32) - mean + ) / std else: for sample in rollouts: sample.normalized_advantage = sample.advantage.clone() @@ -817,7 +906,11 @@ def _ac_step(self, agent_idx: int, batch: List[RolloutSample]) -> Dict[str, floa old_value = sample.old_value.to(self.device, dtype=value.dtype) advantage = sample.normalized_advantage.to(self.device, dtype=value.dtype) - returns = sample.returns.to(self.device, dtype=value.dtype) + value_target = sample.metadata.get("value_target") + if value_target is None: + returns = sample.returns.to(self.device, dtype=value.dtype) + else: + returns = value_target.to(self.device, dtype=value.dtype) if not torch.isfinite(logprob).all(): raise FloatingPointError( @@ -866,8 +959,6 @@ def _update( return {} metrics = defaultdict(list) - # Log metrics using raw rewards/returns before normalization. - # Note: _prepare_advantages() may normalize sample.returns in-place. prompt_rewards: Dict[str, List[float]] = defaultdict(list) for sample in rollouts: prompt_rewards[sample.prompt].append( @@ -900,6 +991,17 @@ def _update( if values.numel() > 0 and torch.isfinite(values).all(): metrics["value_pred_mean"].append(float(values.mean().item())) + targets = [sample.metadata.get("value_target") for sample in rollouts] + if any(t is not None for t in targets): + target_vals = torch.stack( + [ + (t if t is not None else sample.returns).view(-1)[0] + for sample, t in zip(rollouts, targets) + ] + ).float() + if target_vals.numel() > 0 and torch.isfinite(target_vals).all(): + metrics["value_target_mean"].append(float(target_vals.mean().item())) + self._prepare_advantages(rollouts) random.shuffle(rollouts) @@ -958,6 +1060,7 @@ def _maybe_log(metric_key: str, epoch_key: str) -> None: _maybe_log("expected_return", "epoch_avg_return") _maybe_log("value_variance", "epoch_value_variance") _maybe_log("value_pred_mean", "epoch_value_pred_mean") + _maybe_log("value_target_mean", "epoch_value_target_mean") _maybe_log("policy_loss", "epoch_policy_loss") _maybe_log("value_loss", "epoch_value_loss") @@ -993,7 +1096,6 @@ def _process_buffer( if not buffer: return - # Group samples by turn (if available); default to turn 0. 
has_turn_idx = any( "turn_idx" in (getattr(s, "metadata", {}) or {}) for s in buffer ) @@ -1004,7 +1106,6 @@ def _process_buffer( buffer.clear() - # Log all turns at the same wandb step so turn_1 and turn_2 align. combined_log: Dict[str, float] = {} for t_idx in sorted(turn_groups.keys()): samples = turn_groups[t_idx] From b8e1d411b39253c319a3461c1884bc617d9847e1 Mon Sep 17 00:00:00 2001 From: N!no Date: Wed, 17 Dec 2025 21:14:27 -0500 Subject: [PATCH 06/13] add early termination --- comlrl/trainers/maac.py | 7 +++++++ 1 file changed, 7 insertions(+) diff --git a/comlrl/trainers/maac.py b/comlrl/trainers/maac.py index ce04840..107fa7b 100644 --- a/comlrl/trainers/maac.py +++ b/comlrl/trainers/maac.py @@ -61,6 +61,7 @@ class MAACConfig: discount: float = 0.9 critic_type: str = "v" # "v" (V(s)) or "q" (Q(s,a)) critic_target: str = "td0" # "mc" (Monte Carlo) or "td0" (TD(0) on policy) + early_termination_threshold: Optional[float] = None def __post_init__(self) -> None: if self.rollout_buffer_size < 1: @@ -699,6 +700,12 @@ def _collect_rollouts_multi_turn( response_history[agent_idx].append(completion_text) previous_completions[agent_idx] = completion_text + term_threshold = getattr(self.args, "early_termination_threshold", None) + if term_threshold is not None: + mean_reward = float(sum(rewards) / len(rewards)) if rewards else 0.0 + if mean_reward > float(term_threshold): + break + use_td_target = (self.args.critic_target or "mc").lower() == "td0" for agent_idx in range(self.args.num_agents): traj = per_agent_samples[agent_idx] From 2b74685658d2c2933583104b64433e5d41d5a9a6 Mon Sep 17 00:00:00 2001 From: N!no Date: Thu, 18 Dec 2025 01:21:02 -0500 Subject: [PATCH 07/13] add eval --- comlrl/trainers/maac.py | 106 ++++++++++++++++++++++++++++++++-------- 1 file changed, 85 insertions(+), 21 deletions(-) diff --git a/comlrl/trainers/maac.py b/comlrl/trainers/maac.py index 107fa7b..459fb98 100644 --- a/comlrl/trainers/maac.py +++ b/comlrl/trainers/maac.py @@ -62,6 +62,8 @@ class MAACConfig: critic_type: str = "v" # "v" (V(s)) or "q" (Q(s,a)) critic_target: str = "td0" # "mc" (Monte Carlo) or "td0" (TD(0) on policy) early_termination_threshold: Optional[float] = None + eval_interval: int = 4 + eval_num_samples: int = 4 def __post_init__(self) -> None: if self.rollout_buffer_size < 1: @@ -95,6 +97,10 @@ def __post_init__(self) -> None: critic_target = (self.critic_target or "mc").lower() if critic_target not in ("mc", "td0"): raise ValueError("critic_target must be one of: 'mc', 'td0'.") + if self.eval_interval < 0: + raise ValueError("eval_interval must be >= 0.") + if self.eval_num_samples < 1: + raise ValueError("eval_num_samples must be >= 1.") class MAACTrainer: @@ -301,6 +307,16 @@ def get_train_dataloader(self) -> DataLoader: collate_fn=lambda batch: batch, ) + def get_eval_dataloader(self) -> DataLoader: + if self.eval_dataset is None: + raise ValueError("Evaluation requires a dataset.") + return DataLoader( + self.eval_dataset, + batch_size=1, + shuffle=False, + collate_fn=lambda batch: batch, + ) + def _format_prompt(self, item: Dict[str, Any], agent_idx: int) -> str: formatter = self.formatters[agent_idx] prompt = formatter(item) @@ -964,8 +980,32 @@ def _update( ) -> Dict[str, float]: if not rollouts: return {} - metrics = defaultdict(list) + metrics = self._summarize_rollout_metrics(rollouts) + + self._prepare_advantages(rollouts) + random.shuffle(rollouts) + + loss_metrics = defaultdict(list) + for start in range(0, len(rollouts), self.args.mini_batch_size): + batch = rollouts[start : 
start + self.args.mini_batch_size] + step_metrics = self._ac_step(agent_idx, batch) + for key, value in step_metrics.items(): + loss_metrics[key].append(value) + averaged_losses = { + key: float(sum(values) / len(values)) + for key, values in loss_metrics.items() + if values + } + metrics.update(averaged_losses) + return metrics + + def _summarize_rollout_metrics( + self, rollouts: List[RolloutSample] + ) -> Dict[str, float]: + if not rollouts: + return {} + metrics: Dict[str, float] = {} prompt_rewards: Dict[str, List[float]] = defaultdict(list) for sample in rollouts: prompt_rewards[sample.prompt].append( @@ -977,26 +1017,25 @@ def _update( t = torch.tensor(vals, dtype=torch.float32) prompt_vars.append(float(torch.var(t, unbiased=False).item())) if prompt_vars: - metrics["value_variance"].append(float(sum(prompt_vars) / len(prompt_vars))) + metrics["value_variance"] = float(sum(prompt_vars) / len(prompt_vars)) rewards = torch.stack( [sample.reward.view(-1)[0] for sample in rollouts] ).float() if rewards.numel() > 0 and torch.isfinite(rewards).all(): - mean_reward = float(rewards.mean().item()) - metrics["reward_mean"].append(mean_reward) + metrics["reward_mean"] = float(rewards.mean().item()) returns = torch.stack( [sample.returns.view(-1)[0] for sample in rollouts] ).float() if returns.numel() > 0 and torch.isfinite(returns).all(): - metrics["expected_return"].append(float(returns.mean().item())) + metrics["expected_return"] = float(returns.mean().item()) values = torch.stack( [sample.old_value.view(-1)[0] for sample in rollouts] ).float() if values.numel() > 0 and torch.isfinite(values).all(): - metrics["value_pred_mean"].append(float(values.mean().item())) + metrics["value_pred_mean"] = float(values.mean().item()) targets = [sample.metadata.get("value_target") for sample in rollouts] if any(t is not None for t in targets): @@ -1007,22 +1046,41 @@ def _update( ] ).float() if target_vals.numel() > 0 and torch.isfinite(target_vals).all(): - metrics["value_target_mean"].append(float(target_vals.mean().item())) + metrics["value_target_mean"] = float(target_vals.mean().item()) - self._prepare_advantages(rollouts) - random.shuffle(rollouts) + return metrics - for start in range(0, len(rollouts), self.args.mini_batch_size): - batch = rollouts[start : start + self.args.mini_batch_size] - step_metrics = self._ac_step(agent_idx, batch) - for key, value in step_metrics.items(): - metrics[key].append(value) - averaged = { - key: float(sum(values) / len(values)) - for key, values in metrics.items() - if values - } - return averaged + def evaluate(self) -> Dict[str, float]: + if self.eval_dataset is None: + return {} + + dataloader = self.get_eval_dataloader() + num_samples = int(self.args.eval_num_samples) + turn_groups: Dict[int, List[RolloutSample]] = {} + seen = 0 + + with torch.no_grad(): + for batch in dataloader: + for item in batch: + rollouts = self._collect_rollouts(item) + for sample in rollouts: + t_idx = int(sample.metadata.get("turn_idx", 0)) + turn_groups.setdefault(t_idx, []).append(sample) + seen += 1 + if seen >= num_samples: + break + if seen >= num_samples: + break + + eval_log: Dict[str, float] = {} + for turn_idx, samples in sorted(turn_groups.items()): + metrics = self._summarize_rollout_metrics(samples) + for key, value in metrics.items(): + eval_log[f"eval/turn_{turn_idx + 1}/{key}"] = value + + if eval_log: + self._log_metrics(eval_log) + return eval_log # ------------------------------------------------------------------ # # Training loop @@ -1033,7 +1091,13 @@ def 
train(self) -> None: for epoch in range(total_epochs): epoch_metrics = defaultdict(list) - for batch in dataloader: + for batch_idx, batch in enumerate(dataloader): + if ( + self.eval_dataset is not None + and self.args.eval_interval > 0 + and batch_idx % int(self.args.eval_interval) == 0 + ): + self.evaluate() for item in batch: rollouts = self._collect_rollouts(item) for sample in rollouts: From 7f0c4b81ef5b4476947e1ab1c68ec22e72bc40cc Mon Sep 17 00:00:00 2001 From: N!no Date: Thu, 18 Dec 2025 10:26:02 -0500 Subject: [PATCH 08/13] remove monte carlo update --- comlrl/trainers/maac.py | 20 ++++---------------- 1 file changed, 4 insertions(+), 16 deletions(-) diff --git a/comlrl/trainers/maac.py b/comlrl/trainers/maac.py index 459fb98..a37416c 100644 --- a/comlrl/trainers/maac.py +++ b/comlrl/trainers/maac.py @@ -60,7 +60,6 @@ class MAACConfig: num_turns: int = 1 discount: float = 0.9 critic_type: str = "v" # "v" (V(s)) or "q" (Q(s,a)) - critic_target: str = "td0" # "mc" (Monte Carlo) or "td0" (TD(0) on policy) early_termination_threshold: Optional[float] = None eval_interval: int = 4 eval_num_samples: int = 4 @@ -94,9 +93,6 @@ def __post_init__(self) -> None: critic_type = (self.critic_type or "v").lower() if critic_type not in ("v", "q"): raise ValueError("critic_type must be one of: 'v', 'q'.") - critic_target = (self.critic_target or "mc").lower() - if critic_target not in ("mc", "td0"): - raise ValueError("critic_target must be one of: 'mc', 'td0'.") if self.eval_interval < 0: raise ValueError("eval_interval must be >= 0.") if self.eval_num_samples < 1: @@ -582,10 +578,9 @@ def _collect_rollouts(self, item: Dict[str, Any]) -> List[RolloutSample]: ) ) - if (self.args.critic_target or "mc").lower() == "td0": - for sample in rollouts: - r = float(sample.reward.view(-1)[0].item()) - sample.metadata["value_target"] = torch.tensor([r]).detach().cpu() + for sample in rollouts: + r = float(sample.reward.view(-1)[0].item()) + sample.metadata["value_target"] = torch.tensor([r]).detach().cpu() if self.metrics_callback is not None: try: @@ -722,7 +717,6 @@ def _collect_rollouts_multi_turn( if mean_reward > float(term_threshold): break - use_td_target = (self.args.critic_target or "mc").lower() == "td0" for agent_idx in range(self.args.num_agents): traj = per_agent_samples[agent_idx] for t, sample in enumerate(traj): @@ -733,10 +727,7 @@ def _collect_rollouts_multi_turn( else: target = r sample.metadata["adv_target"] = torch.tensor([target]).detach().cpu() - if use_td_target: - sample.metadata["value_target"] = ( - torch.tensor([target]).detach().cpu() - ) + sample.metadata["value_target"] = torch.tensor([target]).detach().cpu() for agent_idx in range(self.args.num_agents): future = 0.0 @@ -802,9 +793,6 @@ def _normalize_returns(self, rollouts: List[RolloutSample]) -> None: def _prepare_advantages(self, rollouts: List[RolloutSample]) -> None: if not rollouts: return - if (self.args.critic_target or "mc").lower() == "mc": - self._normalize_returns(rollouts) - advantages = [] for sample in rollouts: target = sample.metadata.get("adv_target") or sample.metadata.get( From 4d83fde0c828e9923a3d207f2b33fae5147a61b3 Mon Sep 17 00:00:00 2001 From: N!no Date: Thu, 18 Dec 2025 19:41:54 -0500 Subject: [PATCH 09/13] Update maac.py --- comlrl/trainers/maac.py | 1 - 1 file changed, 1 deletion(-) diff --git a/comlrl/trainers/maac.py b/comlrl/trainers/maac.py index a37416c..9e050fd 100644 --- a/comlrl/trainers/maac.py +++ b/comlrl/trainers/maac.py @@ -41,7 +41,6 @@ class MAACConfig: max_grad_norm: float = 0.5 
rollout_buffer_size: int = 8 mini_batch_size: int = 4 - ac_epochs: int = 1 value_loss_coef: float = 0.5 entropy_coef: float = 0.0 advantage_normalization: bool = True From ef29106070b4876d119e0a4220752b0c09b3a4f8 Mon Sep 17 00:00:00 2001 From: N!no Date: Fri, 19 Dec 2025 08:35:28 -0500 Subject: [PATCH 10/13] remove variance cal --- comlrl/trainers/maac.py | 14 -------------- comlrl/trainers/magrpo.py | 9 --------- 2 files changed, 23 deletions(-) diff --git a/comlrl/trainers/maac.py b/comlrl/trainers/maac.py index 9e050fd..ee534f1 100644 --- a/comlrl/trainers/maac.py +++ b/comlrl/trainers/maac.py @@ -993,19 +993,6 @@ def _summarize_rollout_metrics( return {} metrics: Dict[str, float] = {} - prompt_rewards: Dict[str, List[float]] = defaultdict(list) - for sample in rollouts: - prompt_rewards[sample.prompt].append( - float(sample.reward.view(-1)[0].item()) - ) - prompt_vars: List[float] = [] - for vals in prompt_rewards.values(): - if len(vals) > 1: - t = torch.tensor(vals, dtype=torch.float32) - prompt_vars.append(float(torch.var(t, unbiased=False).item())) - if prompt_vars: - metrics["value_variance"] = float(sum(prompt_vars) / len(prompt_vars)) - rewards = torch.stack( [sample.reward.view(-1)[0] for sample in rollouts] ).float() @@ -1116,7 +1103,6 @@ def _maybe_log(metric_key: str, epoch_key: str) -> None: _maybe_log("reward_mean", "epoch_reward_mean") _maybe_log("expected_return", "epoch_avg_return") - _maybe_log("value_variance", "epoch_value_variance") _maybe_log("value_pred_mean", "epoch_value_pred_mean") _maybe_log("value_target_mean", "epoch_value_target_mean") _maybe_log("policy_loss", "epoch_policy_loss") diff --git a/comlrl/trainers/magrpo.py b/comlrl/trainers/magrpo.py index 2cdadc9..4130c56 100644 --- a/comlrl/trainers/magrpo.py +++ b/comlrl/trainers/magrpo.py @@ -787,13 +787,6 @@ def train(self, **kwargs): epoch_log[f"turn_{turn_idx + 1}/epoch_avg_return"] = float( np.mean(epoch_turn_returns[turn_idx]) ) - if ( - epoch_turn_value_variances - and epoch_turn_value_variances[turn_idx] - ): - epoch_log[f"turn_{turn_idx + 1}/epoch_value_variance"] = float( - np.mean(epoch_turn_value_variances[turn_idx]) - ) if epoch_log: wandb.log(epoch_log) @@ -1081,8 +1074,6 @@ def post_order_update(node): stats["batch_expected_return"] = float( np.mean(turn_return_node_means[t]) ) - if turn_return_variances[t]: - stats["value_variance"] = float(np.mean(turn_return_variances[t])) # No per-reward-function means; use a single reward function batch_stats[t] = stats From 37baa64c285b3aab4f726781977a5da9bbe39dc3 Mon Sep 17 00:00:00 2001 From: N!no Date: Fri, 19 Dec 2025 08:40:29 -0500 Subject: [PATCH 11/13] Update maac.py --- comlrl/trainers/maac.py | 1 - 1 file changed, 1 deletion(-) diff --git a/comlrl/trainers/maac.py b/comlrl/trainers/maac.py index ee534f1..f42d71a 100644 --- a/comlrl/trainers/maac.py +++ b/comlrl/trainers/maac.py @@ -42,7 +42,6 @@ class MAACConfig: rollout_buffer_size: int = 8 mini_batch_size: int = 4 value_loss_coef: float = 0.5 - entropy_coef: float = 0.0 advantage_normalization: bool = True max_new_tokens: int = 128 temperature: float = 0.7 From e305927cbd4b1b2c7a9f101ab517c62a7dda1b32 Mon Sep 17 00:00:00 2001 From: N!no Date: Fri, 19 Dec 2025 08:45:26 -0500 Subject: [PATCH 12/13] Update magrpo.py --- comlrl/trainers/magrpo.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/comlrl/trainers/magrpo.py b/comlrl/trainers/magrpo.py index 4130c56..f6dd80e 100644 --- a/comlrl/trainers/magrpo.py +++ b/comlrl/trainers/magrpo.py @@ -74,7 +74,7 @@ class 
MAGRPOConfig(TrainingArguments): # Evaluation eval_interval: int = field( - default=4, + default=8, metadata={"help": "Run evaluation every N training batches."}, ) eval_num_samples: int = field( From 7ac0e47114236fb4c182c60c0a6da402a6b0b3d2 Mon Sep 17 00:00:00 2001 From: N!no Date: Fri, 19 Dec 2025 09:12:58 -0500 Subject: [PATCH 13/13] remove redundant logic and change step --- comlrl/trainers/maac.py | 6 +++--- comlrl/trainers/magrpo.py | 23 +++++------------------ 2 files changed, 8 insertions(+), 21 deletions(-) diff --git a/comlrl/trainers/maac.py b/comlrl/trainers/maac.py index f42d71a..349d4bf 100644 --- a/comlrl/trainers/maac.py +++ b/comlrl/trainers/maac.py @@ -192,6 +192,7 @@ def __init__( self.wandb_config = wandb_config self.wandb_initialized = False + self.data_step = 0 if wandb_config is not None: self._init_wandb() @@ -1079,6 +1080,7 @@ def train(self) -> None: buffer.append(sample) if len(buffer) >= self.args.rollout_buffer_size: self._process_buffer(agent_idx, buffer, epoch_metrics) + self.data_step += 1 for agent_idx, buffer in enumerate(self.rollout_buffers): if not buffer: @@ -1109,7 +1111,6 @@ def _maybe_log(metric_key: str, epoch_key: str) -> None: if epoch_log: self._log_metrics(epoch_log) - self.global_step += 1 if summary: to_print = epoch_log if epoch_log else summary @@ -1128,7 +1129,7 @@ def _log_metrics(self, metrics: Dict[str, float]) -> None: if not metrics: return if self.wandb_initialized and wandb is not None: - wandb.log(metrics, step=self.global_step) + wandb.log(metrics, step=self.data_step) def _process_buffer( self, @@ -1159,7 +1160,6 @@ def _process_buffer( epoch_metrics[key].append(value) self._log_metrics(combined_log) - self.global_step += 1 def save_model(self, output_dir: str) -> None: os.makedirs(output_dir, exist_ok=True) diff --git a/comlrl/trainers/magrpo.py b/comlrl/trainers/magrpo.py index f6dd80e..4f1e4b9 100644 --- a/comlrl/trainers/magrpo.py +++ b/comlrl/trainers/magrpo.py @@ -148,6 +148,7 @@ def __init__( # Training arguments self.args = args if args is not None else MAGRPOConfig() + self.data_step = 0 # Reward and formatting self._setup_formatters(formatters, num_agents) @@ -685,7 +686,7 @@ def _log_eval_metrics( # Log evaluation metrics if self.wandb_initialized: - wandb.log(eval_metrics) + wandb.log(eval_metrics, step=self.data_step) return eval_metrics @@ -712,10 +713,6 @@ def train(self, **kwargs): [] for _ in range(self.args.num_turns) ] # immediate rewards epoch_turn_returns = [[] for _ in range(self.args.num_turns)] # returns - epoch_turn_value_variances = [ - [] for _ in range(self.args.num_turns) - ] # variance of per-node returns used as value estimates - dl = self.get_train_dataloader() if not getattr(self, "verbose", True): it = enumerate( @@ -728,6 +725,7 @@ def train(self, **kwargs): else: it = enumerate(dl) for batch_idx, batch in it: + self.data_step += len(batch) # Periodic evaluation based on configuration if int(self.args.eval_interval) > 0 and ( batch_idx % int(self.args.eval_interval) == 0 @@ -751,10 +749,6 @@ def train(self, **kwargs): n_turns = max(1, int(self.args.num_turns)) for t in range(n_turns): stats = batch_stats.get(t) or {} - if "value_variance" in stats: - epoch_turn_value_variances[t].append( - stats["value_variance"] - ) if self.wandb_initialized: prefix = f"turn_{t + 1}/" if "batch_mean_reward" in stats: @@ -765,14 +759,10 @@ def train(self, **kwargs): batch_log[prefix + "expected_return"] = stats[ "batch_expected_return" ] - if "value_variance" in stats: - batch_log[prefix + "value_variance"] = 
stats[ - "value_variance" - ] # No per-function reward splitting in single reward mode if self.wandb_initialized and batch_log: - wandb.log(batch_log) + wandb.log(batch_log, step=self.data_step) # Log per-turn epoch averages inline (avoid custom system/* metrics) if self.wandb_initialized: @@ -788,7 +778,7 @@ def train(self, **kwargs): np.mean(epoch_turn_returns[turn_idx]) ) if epoch_log: - wandb.log(epoch_log) + wandb.log(epoch_log, step=self.data_step) def _train_step_returns( self, @@ -814,7 +804,6 @@ def _train_step_returns( # Per-turn accumulators for batch-level summaries turn_reward_node_means: List[List[float]] = [[] for _ in range(num_turns)] turn_return_node_means: List[List[float]] = [[] for _ in range(num_turns)] - turn_return_variances: List[List[float]] = [[] for _ in range(num_turns)] # No per-function accumulation in single reward mode turn_node_counts: List[int] = [0 for _ in range(num_turns)] @@ -1013,8 +1002,6 @@ def record_turn_returns(node): mean_ret = float(np.mean(vals)) epoch_turn_returns[t].append(mean_ret) turn_return_node_means[t].append(mean_ret) - var_ret = float(np.var(vals)) if len(vals) > 1 else 0.0 - turn_return_variances[t].append(var_ret) for ch in node["children"]: record_turn_returns(ch)
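

Usage sketch (illustrative, not part of the patches). The series above adds multi-turn support to MAACTrainer driven by an `external_transition` callback, which is invoked between turns with the keyword arguments `prompt`, `agent_completions`, `num_agents`, `prompt_history_per_agent`, and `response_history_per_agent`, and must return one follow-up prompt per agent. Only those keyword names, the new MAACConfig fields, and the `model` / `reward_func` / `external_transition` constructor parameters come from the diffs; the callback body, the model identifiers, the reward-function signature, and the `args=` argument name below are assumptions for illustration.

from typing import List, Optional

def example_external_transition(
    prompt: str,
    agent_completions: List[Optional[str]],
    num_agents: int,
    prompt_history_per_agent: List[List[str]],
    response_history_per_agent: List[List[str]],
) -> List[str]:
    # Build the next-turn prompt for each agent from the shared task prompt
    # and every agent's previous completion (a simple "revise" transition).
    peer_summary = "\n\n".join(
        f"[Agent {idx} previous answer]\n{completion or ''}"
        for idx, completion in enumerate(agent_completions)
    )
    return [
        f"{prompt}\n\n{peer_summary}\n\n"
        f"[Agent {agent_idx}] Revise your answer using the responses above."
        for agent_idx in range(num_agents)
    ]

def my_reward_func(prompts, completions):
    # Placeholder reward; the trainer inspects the actual signature and accepts
    # 1, num_return_sequences, or num_agents values per call (_expand_rewards).
    return [0.0]

config = MAACConfig(
    critic_model_name_or_path="org/critic-model",  # placeholder identifier
    num_agents=2,
    num_turns=2,                        # enables the multi-turn rollout path
    discount=0.9,                       # gamma for returns and TD(0) targets
    critic_type="q",                    # condition critic on joint prompt + joint action text
    num_return_sequences=1,             # required when num_turns > 1
    rollout_buffer_size=8,              # must be a multiple of num_turns
    early_termination_threshold=0.95,   # optional: skip remaining turns once mean reward exceeds this
)

trainer = MAACTrainer(
    model="org/actor-model",            # pretrained identifier string (placeholder)
    args=config,                        # argument name assumed
    reward_func=my_reward_func,
    external_transition=example_external_transition,  # required for num_turns > 1
)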
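

For reference, the per-agent targets built in `_collect_rollouts_multi_turn` reduce to two pieces of arithmetic: a discounted Monte Carlo return stored in `sample.returns`, and a TD(0) target r_t + discount * V_{t+1} stored as `adv_target` / `value_target` (the final turn falls back to the raw reward); the advantage is then this target minus the critic's old value prediction. The helper names below are illustrative, not part of the codebase; they mirror the loops in the patches under the assumption of a single scalar reward per agent per turn.

from typing import List

def discounted_returns(rewards: List[float], gamma: float) -> List[float]:
    # G_t = r_t + gamma * G_{t+1}, computed backwards over one agent's turns.
    returns: List[float] = []
    future = 0.0
    for reward in reversed(rewards):
        future = reward + gamma * future
        returns.append(future)
    return list(reversed(returns))

def td0_targets(rewards: List[float], values: List[float], gamma: float) -> List[float]:
    # target_t = r_t + gamma * V_{t+1}; the final turn uses target_T = r_T.
    targets: List[float] = []
    for t, reward in enumerate(rewards):
        if t < len(rewards) - 1:
            targets.append(reward + gamma * values[t + 1])
        else:
            targets.append(reward)
    return targets

# Example with gamma = 0.9 and two turns:
# rewards = [0.2, 1.0], critic values = [0.5, 0.8]
# discounted_returns -> [0.2 + 0.9 * 1.0, 1.0] = [1.1, 1.0]
# td0_targets        -> [0.2 + 0.9 * 0.8, 1.0] = [0.92, 1.0]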