diff --git a/src/gimbench/base.py b/src/gimbench/base.py
index 0c240f6..f48ca31 100644
--- a/src/gimbench/base.py
+++ b/src/gimbench/base.py
@@ -87,3 +87,12 @@ def __init__(self, args: Namespace, dataset: Dataset):
     def _safe_average(items: list, attr: str) -> float:
         values = [getattr(item, attr) for item in items if getattr(item, attr) != -1]
         return sum(values) / len(values) if values else 0.0
+
+    def _log_progress(self, total: int, curr_idx: int, log_interval: int = 10) -> None:
+        completed = curr_idx + 1
+        if completed % log_interval == 0:
+            speed = (datetime.now() - self.start_time).total_seconds() / completed
+            logger.info(
+                f"Progress: {completed}/{total} items evaluated with speed {speed:.2f} seconds/item. "
+                f"Time Remaining: {(total - completed) * speed / 60:.2f} minutes"
+            )
diff --git a/src/gimbench/ctp/evaluators.py b/src/gimbench/ctp/evaluators.py
index 189a77b..18b1209 100644
--- a/src/gimbench/ctp/evaluators.py
+++ b/src/gimbench/ctp/evaluators.py
@@ -104,6 +104,8 @@ def evaluate(self) -> EvalResult:
             result = self._evaluate_item(self.dataset[idx])
             evaled_items.append(result)
+
+            self._log_progress(total, idx)
 
         self.end_time = datetime.now()
         logger.info(f"Evaluation completed at {self.end_time}")
 
diff --git a/src/gimbench/match/evaluators.py b/src/gimbench/match/evaluators.py
index f24218b..4213f16 100644
--- a/src/gimbench/match/evaluators.py
+++ b/src/gimbench/match/evaluators.py
@@ -102,6 +102,8 @@ def evaluate(self) -> EvalResult:
             result = self._evaluate_item(self.dataset[idx])
             evaled_items.append(result)
+
+            self._log_progress(total, idx)
 
         self.end_time = datetime.now()
         logger.info(f"Evaluation completed at {self.end_time}")
 
diff --git a/src/gimbench/mcqa/evaluators.py b/src/gimbench/mcqa/evaluators.py
index f8e5c58..9c7e276 100644
--- a/src/gimbench/mcqa/evaluators.py
+++ b/src/gimbench/mcqa/evaluators.py
@@ -128,10 +128,13 @@ def evaluate(self) -> EvalResult:
             for idx in tqdm(range(total), desc=f"Evaluating {self.args.model_name}"):
                 result = self._evaluate_item(self.dataset[idx])
                 evaled_items.append(result)
+
+                self._log_progress(total, idx)
         else:
             with ThreadPoolExecutor(max_workers=self.args.num_proc) as executor:
                 results = executor.map(self._evaluate_item, (self.dataset[i] for i in range(total)))
                 evaled_items = list(tqdm(results, total=total, desc=f"Evaluating {self.args.model_name}"))
+                # TODO: Add progress logging for multi-threaded evaluation
 
         errors = sum(1 for item in evaled_items if item.error_msg)
         corrects = sum(1 for item in evaled_items if item.conclusion)
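
The TODO in mcqa/evaluators.py could be addressed by submitting items individually instead of using executor.map, which yields results in input order and so hides how many items have actually finished. A minimal sketch follows, not part of the patch, assuming the evaluator keeps the members visible in the diff (self.args, self.dataset, self._evaluate_item, self._log_progress); the method name _evaluate_multithreaded is hypothetical.

from concurrent.futures import ThreadPoolExecutor, as_completed

from tqdm import tqdm


def _evaluate_multithreaded(self, total: int) -> list:
    """Sketch: multi-threaded evaluation with periodic progress logging."""
    # Preallocate so results keep dataset order even though futures
    # complete out of order.
    evaled_items = [None] * total
    with ThreadPoolExecutor(max_workers=self.args.num_proc) as executor:
        # Submit items one by one so completions can be observed as they happen.
        futures = {
            executor.submit(self._evaluate_item, self.dataset[i]): i
            for i in range(total)
        }
        progress = tqdm(as_completed(futures), total=total,
                        desc=f"Evaluating {self.args.model_name}")
        for n_done, future in enumerate(progress):
            evaled_items[futures[future]] = future.result()
            # _log_progress expects the zero-based index of the latest
            # completed item and only logs every `log_interval` completions.
            self._log_progress(total, n_done)
    return evaled_items

This mirrors the logging cadence of the single-threaded loop. Note that _log_progress divides elapsed time by completed items, so with multiple workers the reported seconds/item reflects overall wall-clock throughput rather than per-item latency.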