12 changes: 10 additions & 2 deletions tests/explorer/explorer_test.py
@@ -69,10 +69,18 @@ def test_explorer(self):
self.assertEqual(parser.metric_max_step(rollout_metrics[0]), 8)
self.assertEqual(parser.metric_max_step(eval_metrics[0]), 8)
for eval_taskset, k_list in zip(eval_tasksets, [[1], [2, 4, 6], [2, 4, 8, 10]]):
for eval_stats in ["mean", "best", "worst"]:
metric_name = "score" if eval_taskset.name == "countdown" else "accuracy"
for eval_stats in ["mean", "std"]:
k = k_list[-1]
self.assertIn(
f"eval/{eval_taskset.name}/{metric_name}/{eval_stats}@{k}",
eval_metrics,
)
for eval_stats in ["best", "worst"]:
for k in k_list:
if k == 1:
continue
for stats in ["mean", "std"]:
metric_name = "score" if eval_taskset.name == "countdown" else "accuracy"
self.assertIn(
f"eval/{eval_taskset.name}/{metric_name}/{eval_stats}@{k}/{stats}",
eval_metrics,
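For reference, a minimal sketch of the metric-key layout the updated assertions expect, assuming a hypothetical pairing of the countdown taskset (metric "score") with k_list = [2, 4, 8, 10], one of the k_lists used above; the snippet is illustrative and not part of the PR:

    taskset, metric_name, k_list = "countdown", "score", [2, 4, 8, 10]

    # mean@k and std@k are only asserted for the largest k.
    expected = [f"eval/{taskset}/{metric_name}/{stat}@{k_list[-1]}" for stat in ("mean", "std")]

    # best@k and worst@k carry bootstrap mean/std sub-keys for every k > 1.
    for k in k_list:
        if k == 1:
            continue
        for agg in ("best", "worst"):
            expected += [f"eval/{taskset}/{metric_name}/{agg}@{k}/{stat}" for stat in ("mean", "std")]

    print(expected[:3])
    # ['eval/countdown/score/mean@10', 'eval/countdown/score/std@10', 'eval/countdown/score/best@2/mean']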
18 changes: 16 additions & 2 deletions tests/trainer/trainer_test.py
@@ -166,7 +166,14 @@ def test_trainer(self):
for taskset_name in ["countdown", "copy_countdown"]:
metrics = parser.metric_list(f"{prefix}/{taskset_name}")
self.assertGreater(len(metrics), 0, f"{prefix}/{taskset_name} metrics not found")
for eval_stats in ["mean", "best", "worst"]:
# mean@k, std@k
for eval_stats in ["mean", "std"]:
k = 4
metric_name = f"{prefix}/{taskset_name}/score/{eval_stats}@{k}"
metric_steps = parser.metric_steps(metric_name)
self.assertEqual(metric_steps, [0, 4, 8])
# best@k/mean, best@k/std, worst@k/mean, worst@k/std
for eval_stats in ["best", "worst"]:
for k in [2, 4]:
for stats in ["mean", "std"]:
metric_name = f"{prefix}/{taskset_name}/score/{eval_stats}@{k}/{stats}"
@@ -1332,7 +1339,14 @@ def test_trainer(self):
for prefix in ["eval", "bench"]:
gsm8k_metrics = parser.metric_list(f"{prefix}/gsm8k")
self.assertGreater(len(gsm8k_metrics), 0, f"{prefix}/gsm8k metrics not found")
for eval_stats in ["mean", "best", "worst"]:
# mean@k, std@k
for eval_stats in ["mean", "std"]:
k = 4
metric_name = f"{prefix}/gsm8k/accuracy/{eval_stats}@{k}"
metric_steps = parser.metric_steps(metric_name)
self.assertEqual(metric_steps, [0, 2])
# best@k/mean, best@k/std, worst@k/mean, worst@k/std
for eval_stats in ["best", "worst"]:
for k in [2, 4, 8]:
for stats in ["mean", "std"]:
metric_name = f"{prefix}/gsm8k/accuracy/{eval_stats}@{k}/{stats}"
8 changes: 3 additions & 5 deletions trinity/explorer/explorer.py
@@ -30,7 +30,7 @@
from trinity.manager.synchronizer import Synchronizer
from trinity.utils.annotations import Experimental
from trinity.utils.log import get_logger
from trinity.utils.monitor import MONITOR, gather_metrics
from trinity.utils.monitor import MONITOR, gather_eval_metrics, gather_metrics
from trinity.utils.plugin_loader import load_plugins
from trinity.utils.timer import Timer

@@ -431,10 +431,8 @@ async def _finish_eval_step(self, step: Optional[int] = None, prefix: str = "eva
statuses, _ = await self.scheduler.get_results(batch_id=f"{step}/{eval_task_name}")
metric[f"{prefix}/{eval_task_name}/finished_task_count"] = len(statuses)
metric.update(
gather_metrics(
[status.metrics[0] for status in statuses],
f"{prefix}/{eval_task_name}",
output_stats=["mean", "std"],
gather_eval_metrics(
[status.metrics[0] for status in statuses], f"{prefix}/{eval_task_name}"
)
)
if self.eval_start_time is not None:
80 changes: 69 additions & 11 deletions trinity/explorer/scheduler.py
@@ -6,8 +6,9 @@
import traceback
from collections import defaultdict, deque
from dataclasses import dataclass, field, replace
from typing import Dict, List, Optional, Tuple, Union
from typing import Any, Callable, Dict, List, Optional, Tuple, Union

import numpy as np
import ray

from trinity.common.config import Config
@@ -31,6 +32,48 @@ class TaskWrapper:
results: List[Tuple[Status, List[Experience]]] = field(default_factory=list)


# Adapted from verl/trainer/ppo/metric_utils.py
def bootstrap_metric(
data: list[Any],
subset_size: int,
reduce_fns: list[Callable[[np.ndarray], float]],
n_bootstrap: int = 1000,
seed: int = 42,
) -> list[tuple[float, float]]:
"""
Performs bootstrap resampling to estimate statistics of metrics.
This function uses bootstrap resampling to estimate the mean and standard deviation
of metrics computed by the provided reduction functions on random subsets of the data.
Args:
data: List of data points to bootstrap from.
subset_size: Size of each bootstrap sample.
reduce_fns: List of functions that compute a metric from a subset of data.
n_bootstrap: Number of bootstrap iterations. Defaults to 1000.
seed: Random seed for reproducibility. Defaults to 42.
Returns:
A list of tuples, where each tuple contains (mean, std) for a metric
corresponding to each reduction function in reduce_fns.
Example:
>>> data = [1, 2, 3, 4, 5]
>>> reduce_fns = [np.mean, np.max]
>>> bootstrap_metric(data, 3, reduce_fns)
[(3.0, 0.5), (4.5, 0.3)] # Example values
"""
np.random.seed(seed)

bootstrap_metric_lsts = [[] for _ in range(len(reduce_fns))]
for _ in range(n_bootstrap):
bootstrap_idxs = np.random.choice(len(data), size=subset_size, replace=True)
bootstrap_data = [data[i] for i in bootstrap_idxs]
for i, reduce_fn in enumerate(reduce_fns):
bootstrap_metric_lsts[i].append(reduce_fn(bootstrap_data))
return [(np.mean(lst), np.std(lst)) for lst in bootstrap_metric_lsts]


def calculate_task_level_metrics(metrics: List[Dict], is_eval: bool) -> Dict[str, float]:
"""Calculate task level metrics (mean) from multiple runs of the same task.
@@ -54,16 +97,31 @@ def calculate_task_level_metrics(metrics: List[Dict], is_eval: bool) -> Dict[str, float]:
if "time/task_execution" in key or "time/run_execution" in key:
result[key] = sum(values) / len(values)
continue
k_list = []
k = 2
while k < len(values):
k_list.append(k)
k *= 2
k_list.append(len(values))
for k in k_list:
result[f"{key}/mean@{k}"] = sum(values[:k]) / k
result[f"{key}/best@{k}"] = max(values[:k])
result[f"{key}/worst@{k}"] = min(values[:k])

n_values = len(values)
result[f"{key}/mean@{n_values}"] = np.mean(values)
result[f"{key}/std@{n_values}"] = np.std(values)

if n_values > 1:
ns = []
n = 2
while n < n_values:
ns.append(n)
n *= 2
ns.append(n_values)

for n in ns:
[(bon_mean, bon_std), (won_mean, won_std)] = bootstrap_metric(
data=values, subset_size=n, reduce_fns=[np.max, np.min], seed=42
)
result[f"{key}/best@{n}/mean"], result[f"{key}/best@{n}/std"] = (
bon_mean,
bon_std,
)
result[f"{key}/worst@{n}/mean"], result[f"{key}/worst@{n}/std"] = (
won_mean,
won_std,
)
return result
else:
return {
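For intuition, a small sketch of how the eval branch above turns repeated-run scores into the logged statistics. It imports bootstrap_metric from the module this hunk modifies; the numbers are illustrative only:

    import numpy as np

    from trinity.explorer.scheduler import bootstrap_metric

    values = [0.0, 1.0, 1.0, 0.0]  # scores from 4 runs of one eval task

    # Logged as <key>/mean@4 and <key>/std@4.
    print(np.mean(values), np.std(values))  # 0.5 0.5

    # Logged as <key>/best@2/{mean,std} and <key>/worst@2/{mean,std}; the same is
    # repeated for subset size 4. Each bootstrap draw picks a subset with
    # replacement, reduces it with max (best) and min (worst), and the draws
    # are then averaged.
    (best_mean, best_std), (worst_mean, worst_std) = bootstrap_metric(
        data=values, subset_size=2, reduce_fns=[np.max, np.min], seed=42
    )
    # best_mean is roughly 0.75 and worst_mean roughly 0.25 for this data;
    # the exact figures depend on the 1000 resamples.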
23 changes: 23 additions & 0 deletions trinity/utils/monitor.py
@@ -57,6 +57,29 @@ def gather_metrics(
raise ValueError(f"Failed to gather metrics: {e}") from e


def gather_eval_metrics(
metric_list: List[Dict], prefix: str, output_stats: List[str] = ["mean", "max", "min"]
) -> Dict:
if not metric_list:
return {}
try:
df = pd.DataFrame(metric_list)
numeric_df = df.select_dtypes(include=[np.number])
metric = {}
for col in numeric_df.columns:
# Skip the columns that are already aggregated
key_words = ["std", "mean", "min", "max"]
if any(key_word in col.lower() for key_word in key_words):
metric[f"{prefix}/{col}"] = numeric_df[col].mean()
else:
stats_df = numeric_df[[col]].agg(output_stats)
for stats in output_stats:
metric[f"{prefix}/{col}/{stats}"] = stats_df.loc[stats, col].item()
return metric
except Exception as e:
raise ValueError(f"Failed to gather eval metrics: {e}") from e


class Monitor(ABC):
"""Monitor"""

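A brief usage sketch for the gather_eval_metrics helper added above. Per-task columns whose names already contain an aggregate keyword ("mean", "std", "min", "max"), such as the score/...@k keys produced by the scheduler, are simply averaged across tasks, while any other numeric column gets mean/max/min stats. The input rows below, including the "latency" column, are made up for illustration:

    from trinity.utils.monitor import gather_eval_metrics

    rows = [
        {"score/mean@4": 0.5, "score/best@4/mean": 0.75, "latency": 1.2},
        {"score/mean@4": 0.7, "score/best@4/mean": 0.85, "latency": 0.8},
    ]
    print(gather_eval_metrics(rows, "eval/countdown"))
    # Roughly: {'eval/countdown/score/mean@4': 0.6,
    #           'eval/countdown/score/best@4/mean': 0.8,
    #           'eval/countdown/latency/mean': 1.0,
    #           'eval/countdown/latency/max': 1.2,
    #           'eval/countdown/latency/min': 0.8}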