From a57f7765a04fa668f5339861842cf9b1aa801557 Mon Sep 17 00:00:00 2001 From: hiyuchang Date: Tue, 20 Jan 2026 16:56:43 +0800 Subject: [PATCH 1/6] fix metrics --- trinity/explorer/scheduler.py | 80 ++++++++++++++++++++++++++++++----- trinity/utils/monitor.py | 12 ++++-- 2 files changed, 77 insertions(+), 15 deletions(-) diff --git a/trinity/explorer/scheduler.py b/trinity/explorer/scheduler.py index f84bb6ea26..761312a22e 100644 --- a/trinity/explorer/scheduler.py +++ b/trinity/explorer/scheduler.py @@ -6,8 +6,9 @@ import traceback from collections import defaultdict, deque from dataclasses import dataclass, field, replace -from typing import Dict, List, Optional, Tuple, Union +from typing import Any, Callable, Dict, List, Optional, Tuple, Union +import numpy as np import ray from trinity.common.config import Config @@ -31,6 +32,48 @@ class TaskWrapper: results: List[Tuple[Status, List[Experience]]] = field(default_factory=list) +# Adapted from verl/trainer/ppo/metric_utils.py +def bootstrap_metric( + data: list[Any], + subset_size: int, + reduce_fns: list[Callable[[np.ndarray], float]], + n_bootstrap: int = 1000, + seed: int = 42, +) -> list[tuple[float, float]]: + """ + Performs bootstrap resampling to estimate statistics of metrics. + + This function uses bootstrap resampling to estimate the mean and standard deviation + of metrics computed by the provided reduction functions on random subsets of the data. + + Args: + data: List of data points to bootstrap from. + subset_size: Size of each bootstrap sample. + reduce_fns: List of functions that compute a metric from a subset of data. + n_bootstrap: Number of bootstrap iterations. Defaults to 1000. + seed: Random seed for reproducibility. Defaults to 42. + + Returns: + A list of tuples, where each tuple contains (mean, std) for a metric + corresponding to each reduction function in reduce_fns. + + Example: + >>> data = [1, 2, 3, 4, 5] + >>> reduce_fns = [np.mean, np.max] + >>> bootstrap_metric(data, 3, reduce_fns) + [(3.0, 0.5), (4.5, 0.3)] # Example values + """ + np.random.seed(seed) + + bootstrap_metric_lsts = [[] for _ in range(len(reduce_fns))] + for _ in range(n_bootstrap): + bootstrap_idxs = np.random.choice(len(data), size=subset_size, replace=True) + bootstrap_data = [data[i] for i in bootstrap_idxs] + for i, reduce_fn in enumerate(reduce_fns): + bootstrap_metric_lsts[i].append(reduce_fn(bootstrap_data)) + return [(np.mean(lst), np.std(lst)) for lst in bootstrap_metric_lsts] + + def calculate_task_level_metrics(metrics: List[Dict], is_eval: bool) -> Dict[str, float]: """Calculate task level metrics (mean) from multiple runs of the same task. @@ -54,16 +97,31 @@ def calculate_task_level_metrics(metrics: List[Dict], is_eval: bool) -> Dict[str if "time/task_execution" in key or "time/run_execution" in key: result[key] = sum(values) / len(values) continue - k_list = [] - k = 2 - while k < len(values): - k_list.append(k) - k *= 2 - k_list.append(len(values)) - for k in k_list: - result[f"{key}/mean@{k}"] = sum(values[:k]) / k - result[f"{key}/best@{k}"] = max(values[:k]) - result[f"{key}/worst@{k}"] = min(values[:k]) + + n_values = len(values) + result[f"{key}/mean@{n_values}"] = np.mean(values) + + if n_values > 1: + result[f"{key}/std@{n_values}"] = np.std(values) + ns = [] + n = 2 + while n < n_values: + ns.append(n) + n *= 2 + ns.append(n_values) + + for n in ns: + [(bon_mean, bon_std), (won_mean, won_std)] = bootstrap_metric( + data=values, subset_size=n, reduce_fns=[np.max, np.min], seed=42 + ) + result[f"{key}/best@{n}/mean"], result[f"{key}/best@{n}/std"] = ( + bon_mean, + bon_std, + ) + result[f"{key}/worst@{n}/mean"], result[f"{key}/worst@{n}/std"] = ( + won_mean, + won_std, + ) return result else: return { diff --git a/trinity/utils/monitor.py b/trinity/utils/monitor.py index 21ef7726f1..df90650c08 100644 --- a/trinity/utils/monitor.py +++ b/trinity/utils/monitor.py @@ -47,11 +47,15 @@ def gather_metrics( try: df = pd.DataFrame(metric_list) numeric_df = df.select_dtypes(include=[np.number]) - stats_df = numeric_df.agg(output_stats) metric = {} - for col in stats_df.columns: - for stats in output_stats: - metric[f"{prefix}/{col}/{stats}"] = stats_df.loc[stats, col].item() + for col in numeric_df.columns: + # Skip the columns that are already aggregated + if "std" in col.lower() or "mean" in col.lower(): + metric[f"{prefix}/{col}"] = numeric_df[col].mean() + else: + stats_df = numeric_df[[col]].agg(output_stats) + for stats in output_stats: + metric[f"{prefix}/{col}/{stats}"] = stats_df.loc[stats, col].item() return metric except Exception as e: raise ValueError(f"Failed to gather metrics: {e}") from e From 669729348ff052dfe24c6589ca2d395322cc05eb Mon Sep 17 00:00:00 2001 From: hiyuchang Date: Wed, 21 Jan 2026 10:54:54 +0800 Subject: [PATCH 2/6] add eval_metrics --- tests/explorer/explorer_test.py | 12 ++++++++++-- tests/trainer/trainer_test.py | 18 ++++++++++++++++-- trinity/explorer/explorer.py | 8 +++----- trinity/utils/monitor.py | 23 +++++++++++++++++++++-- 4 files changed, 50 insertions(+), 11 deletions(-) diff --git a/tests/explorer/explorer_test.py b/tests/explorer/explorer_test.py index b17bf7709b..f5f79a3beb 100644 --- a/tests/explorer/explorer_test.py +++ b/tests/explorer/explorer_test.py @@ -69,10 +69,18 @@ def test_explorer(self): self.assertEqual(parser.metric_max_step(rollout_metrics[0]), 8) self.assertEqual(parser.metric_max_step(eval_metrics[0]), 8) for eval_taskset, k_list in zip(eval_tasksets, [[1], [2, 4, 6], [2, 4, 8, 10]]): - for eval_stats in ["mean", "best", "worst"]: + metric_name = "score" if eval_taskset.name == "countdown" else "accuracy" + for eval_stats in ["mean", "std"]: + k = k_list[-1] + self.assertIn( + f"eval/{eval_taskset.name}/{metric_name}/{eval_stats}@{k}", + eval_metrics, + ) + for eval_stats in ["best", "worst"]: for k in k_list: + if k == 1: + continue for stats in ["mean", "std"]: - metric_name = "score" if eval_taskset.name == "countdown" else "accuracy" self.assertIn( f"eval/{eval_taskset.name}/{metric_name}/{eval_stats}@{k}/{stats}", eval_metrics, diff --git a/tests/trainer/trainer_test.py b/tests/trainer/trainer_test.py index 99eb711975..116d0d8086 100644 --- a/tests/trainer/trainer_test.py +++ b/tests/trainer/trainer_test.py @@ -166,7 +166,14 @@ def test_trainer(self): for taskset_name in ["countdown", "copy_countdown"]: metrics = parser.metric_list(f"{prefix}/{taskset_name}") self.assertGreater(len(metrics), 0, f"{prefix}/{taskset_name} metrics not found") - for eval_stats in ["mean", "best", "worst"]: + # mean@k, std@k + for eval_stats in ["mean", "std"]: + k = 4 + metric_name = f"{prefix}/{taskset_name}/score/{eval_stats}@{k}" + metric_steps = parser.metric_steps(metric_name) + self.assertEqual(metric_steps, [0, 4, 8]) + # best@k/mean, best@k/std, worst@k/mean, worst@k/std + for eval_stats in ["best", "worst"]: for k in [2, 4]: for stats in ["mean", "std"]: metric_name = f"{prefix}/{taskset_name}/score/{eval_stats}@{k}/{stats}" @@ -1332,7 +1339,14 @@ def test_trainer(self): for prefix in ["eval", "bench"]: gsm8k_metrics = parser.metric_list(f"{prefix}/gsm8k") self.assertGreater(len(gsm8k_metrics), 0, f"{prefix}/gsm8k metrics not found") - for eval_stats in ["mean", "best", "worst"]: + # mean@k, std@k + for eval_stats in ["mean", "std"]: + k = 4 + metric_name = f"{prefix}/gsm8k/accuracy/{eval_stats}@{k}" + metric_steps = parser.metric_steps(metric_name) + self.assertEqual(metric_steps, [0, 2]) + # best@k/mean, best@k/std, worst@k/mean, worst@k/std + for eval_stats in ["best", "worst"]: for k in [2, 4, 8]: for stats in ["mean", "std"]: metric_name = f"{prefix}/gsm8k/accuracy/{eval_stats}@{k}/{stats}" diff --git a/trinity/explorer/explorer.py b/trinity/explorer/explorer.py index b0893b8c52..539fdf2dc1 100644 --- a/trinity/explorer/explorer.py +++ b/trinity/explorer/explorer.py @@ -30,7 +30,7 @@ from trinity.manager.synchronizer import Synchronizer from trinity.utils.annotations import Experimental from trinity.utils.log import get_logger -from trinity.utils.monitor import MONITOR, gather_metrics +from trinity.utils.monitor import MONITOR, gather_eval_metrics, gather_metrics from trinity.utils.plugin_loader import load_plugins from trinity.utils.timer import Timer @@ -431,10 +431,8 @@ async def _finish_eval_step(self, step: Optional[int] = None, prefix: str = "eva statuses, _ = await self.scheduler.get_results(batch_id=f"{step}/{eval_task_name}") metric[f"{prefix}/{eval_task_name}/finished_task_count"] = len(statuses) metric.update( - gather_metrics( - [status.metrics[0] for status in statuses], - f"{prefix}/{eval_task_name}", - output_stats=["mean", "std"], + gather_eval_metrics( + [status.metrics[0] for status in statuses], f"{prefix}/{eval_task_name}" ) ) if self.eval_start_time is not None: diff --git a/trinity/utils/monitor.py b/trinity/utils/monitor.py index df90650c08..8b2b6364c4 100644 --- a/trinity/utils/monitor.py +++ b/trinity/utils/monitor.py @@ -41,6 +41,24 @@ def gather_metrics( metric_list: List[Dict], prefix: str, output_stats: List[str] = ["mean", "max", "min"] +) -> Dict: + if not metric_list: + return {} + try: + df = pd.DataFrame(metric_list) + numeric_df = df.select_dtypes(include=[np.number]) + stats_df = numeric_df.agg(output_stats) + metric = {} + for col in stats_df.columns: + for stats in output_stats: + metric[f"{prefix}/{col}/{stats}"] = stats_df.loc[stats, col].item() + return metric + except Exception as e: + raise ValueError(f"Failed to gather metrics: {e}") from e + + +def gather_eval_metrics( + metric_list: List[Dict], prefix: str, output_stats: List[str] = ["mean", "max", "min"] ) -> Dict: if not metric_list: return {} @@ -50,7 +68,8 @@ def gather_metrics( metric = {} for col in numeric_df.columns: # Skip the columns that are already aggregated - if "std" in col.lower() or "mean" in col.lower(): + key_words = ["std", "mean", "min", "max"] + if any(key_word in col.lower() for key_word in key_words): metric[f"{prefix}/{col}"] = numeric_df[col].mean() else: stats_df = numeric_df[[col]].agg(output_stats) @@ -58,7 +77,7 @@ def gather_metrics( metric[f"{prefix}/{col}/{stats}"] = stats_df.loc[stats, col].item() return metric except Exception as e: - raise ValueError(f"Failed to gather metrics: {e}") from e + raise ValueError(f"Failed to gather eval metrics: {e}") from e class Monitor(ABC): From a2fc766714eb6f12e614ee0e1379246d999f5614 Mon Sep 17 00:00:00 2001 From: hiyuchang Date: Wed, 21 Jan 2026 10:58:01 +0800 Subject: [PATCH 3/6] fix typo --- trinity/explorer/scheduler.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/trinity/explorer/scheduler.py b/trinity/explorer/scheduler.py index 761312a22e..ed58fff2c2 100644 --- a/trinity/explorer/scheduler.py +++ b/trinity/explorer/scheduler.py @@ -100,9 +100,9 @@ def calculate_task_level_metrics(metrics: List[Dict], is_eval: bool) -> Dict[str n_values = len(values) result[f"{key}/mean@{n_values}"] = np.mean(values) + result[f"{key}/std@{n_values}"] = np.std(values) if n_values > 1: - result[f"{key}/std@{n_values}"] = np.std(values) ns = [] n = 2 while n < n_values: From 19dc713224a64af8f5469ec452bc4bc8be7c7673 Mon Sep 17 00:00:00 2001 From: hiyuchang Date: Wed, 21 Jan 2026 17:17:21 +0800 Subject: [PATCH 4/6] fix unittest --- tests/conftest.py | 35 +++++++++++++++++++++++++++++++++++ tests/trainer/trainer_test.py | 2 +- 2 files changed, 36 insertions(+), 1 deletion(-) create mode 100644 tests/conftest.py diff --git a/tests/conftest.py b/tests/conftest.py new file mode 100644 index 0000000000..90b9a2bfb4 --- /dev/null +++ b/tests/conftest.py @@ -0,0 +1,35 @@ +import pytest +import datetime + +# Get the result of each test +@pytest.hookimpl(tryfirst=True, hookwrapper=True) +def pytest_runtest_makereport(item, call): + outcome = yield + rep = outcome.get_result() + setattr(item, "rep_" + rep.when, rep) + +# Real-time print of start and end of test +@pytest.fixture(autouse=True) +def log_test_lifecycle(request): + node_id = request.node.nodeid + start_time = datetime.datetime.now().strftime("%H:%M:%S") + + print(f"\n[START] {start_time} - Running: {node_id}") + + yield + + end_time = datetime.datetime.now().strftime("%H:%M:%S") + # Get the result of each test (setup, call, teardown) + report = getattr(request.node, "rep_call", None) + + if report: + if report.passed: + status = "PASSED" + elif report.failed: + status = "FAILED" + else: + status = report.outcome.upper() + else: + status = "UNKNOWN" + + print(f"\n[END ] {end_time} - Result: {status} - {node_id}") diff --git a/tests/trainer/trainer_test.py b/tests/trainer/trainer_test.py index 116d0d8086..8dae2a921f 100644 --- a/tests/trainer/trainer_test.py +++ b/tests/trainer/trainer_test.py @@ -1341,7 +1341,7 @@ def test_trainer(self): self.assertGreater(len(gsm8k_metrics), 0, f"{prefix}/gsm8k metrics not found") # mean@k, std@k for eval_stats in ["mean", "std"]: - k = 4 + k = 8 metric_name = f"{prefix}/gsm8k/accuracy/{eval_stats}@{k}" metric_steps = parser.metric_steps(metric_name) self.assertEqual(metric_steps, [0, 2]) From 280bda18e577d4156e021976a674dbb4b7c86525 Mon Sep 17 00:00:00 2001 From: hiyuchang Date: Wed, 21 Jan 2026 17:18:42 +0800 Subject: [PATCH 5/6] fix comment --- trinity/utils/monitor.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/trinity/utils/monitor.py b/trinity/utils/monitor.py index 8b2b6364c4..f655433f79 100644 --- a/trinity/utils/monitor.py +++ b/trinity/utils/monitor.py @@ -69,7 +69,7 @@ def gather_eval_metrics( for col in numeric_df.columns: # Skip the columns that are already aggregated key_words = ["std", "mean", "min", "max"] - if any(key_word in col.lower() for key_word in key_words): + if any(col.endswith(key_word) for key_word in key_words): metric[f"{prefix}/{col}"] = numeric_df[col].mean() else: stats_df = numeric_df[[col]].agg(output_stats) From cf689002627ca1ca394ca3a84c61bec2e36b333f Mon Sep 17 00:00:00 2001 From: hiyuchang Date: Wed, 21 Jan 2026 17:20:48 +0800 Subject: [PATCH 6/6] fix pre-commit --- tests/conftest.py | 13 ++++++++----- 1 file changed, 8 insertions(+), 5 deletions(-) diff --git a/tests/conftest.py b/tests/conftest.py index 90b9a2bfb4..c3b949136f 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -1,6 +1,8 @@ -import pytest import datetime +import pytest + + # Get the result of each test @pytest.hookimpl(tryfirst=True, hookwrapper=True) def pytest_runtest_makereport(item, call): @@ -8,20 +10,21 @@ def pytest_runtest_makereport(item, call): rep = outcome.get_result() setattr(item, "rep_" + rep.when, rep) + # Real-time print of start and end of test @pytest.fixture(autouse=True) def log_test_lifecycle(request): node_id = request.node.nodeid start_time = datetime.datetime.now().strftime("%H:%M:%S") - + print(f"\n[START] {start_time} - Running: {node_id}") - + yield - + end_time = datetime.datetime.now().strftime("%H:%M:%S") # Get the result of each test (setup, call, teardown) report = getattr(request.node, "rep_call", None) - + if report: if report.passed: status = "PASSED"