From 6da09c01d6caef6750dccca754e23f9e44648e0b Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?=E6=A0=96=E9=B8=9F=E9=97=B2=E6=A6=86?= <59347549+coxine@users.noreply.github.com>
Date: Mon, 30 Jun 2025 00:11:34 +0800
Subject: [PATCH 1/8] Add context normalization for HotpotQA evaluator and update problem keys

---
 mas_arena/evaluators/hotpotqa_evaluator.py | 11 ++++++-----
 mas_arena/evaluators/utils/normalization.py | 1 +
 2 files changed, 7 insertions(+), 5 deletions(-)

diff --git a/mas_arena/evaluators/hotpotqa_evaluator.py b/mas_arena/evaluators/hotpotqa_evaluator.py
index f2db35e..6f3197e 100644
--- a/mas_arena/evaluators/hotpotqa_evaluator.py
+++ b/mas_arena/evaluators/hotpotqa_evaluator.py
@@ -22,6 +22,7 @@
     normalization_keys={
         "id": "id",
         "problem": "question",
+        "context": "context",
         "solution": "answer",
     }
 )
@@ -95,13 +96,13 @@ def create_run(self, problem: Dict[str, Any], final_answer: str, extracted_answe
             id=str(uuid.uuid4()),
             name=f"{self.name.upper()}_Evaluation",
             inputs={
-                "question": problem["question"],
+                "question": problem["problem"],
                 "context": problem["context"]
             },
             outputs={
                 "prediction": final_answer,
                 "extracted_answer": extracted_answer,
-                "expected": problem["answer"],
+                "expected": problem["solution"],
                 "score": score,
                 "passed": score >= 0.3,  # HotpotQA uses 0.3 as threshold
             },
@@ -129,7 +130,7 @@ def evaluate(self, problem: Dict[str, Any], run_result: Dict[str, Any]) -> Dict[
         context_str = "\n".join(" ".join(paragraph) for paragraph in paragraphs)
 
         # Calculate score
-        score, extracted_answer = self.calculate_score(problem["answer"], final_answer)
+        score, extracted_answer = self.calculate_score(problem["solution"], final_answer)
 
         # # Create LangSmith run
         # run = self.create_run(problem, final_answer, extracted_answer, score)
@@ -138,9 +139,9 @@ def evaluate(self, problem: Dict[str, Any], run_result: Dict[str, Any]) -> Dict[
         # Log mismatch if score is too low
         if score < 0.3:
             with open(f"{self.log_path}/mismatches.log", "a") as f:
-                f.write(f"\nQuestion: {problem['question']}\n")
+                f.write(f"\nQuestion: {problem['problem']}\n")
                 f.write(f"Context: {context_str}\n")
-                f.write(f"Expected: {problem['answer']}\n")
+                f.write(f"Expected: {problem['solution']}\n")
                 f.write(f"Predicted: {final_answer}\n")
                 f.write(f"Score: {score}\n")
 
diff --git a/mas_arena/evaluators/utils/normalization.py b/mas_arena/evaluators/utils/normalization.py
index 1cb7800..2a1c7df 100644
--- a/mas_arena/evaluators/utils/normalization.py
+++ b/mas_arena/evaluators/utils/normalization.py
@@ -72,6 +72,7 @@ def normalize_problem_keys(problem: Dict[str, Any], key_mapping: Dict[str, str],
     key_definitions = {
         "id": "id",
         "problem": "problem",
+        "context": "context",
         "solution": "solution",
         "test": "test",
         "entry_point": "entry_point",

From e814d2b0166fc835fa67e16c4a4fa6d01f5978e3 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?=E6=A0=96=E9=B8=9F=E9=97=B2=E6=A6=86?= <59347549+coxine@users.noreply.github.com>
Date: Mon, 30 Jun 2025 14:04:50 +0800
Subject: [PATCH 2/8] Add HotpotQA format prompt and update dataset handling

---
 mas_arena/agents/format_prompts.py | 29 +++++++++++++++++++++++++++++
 1 file changed, 29 insertions(+)

diff --git a/mas_arena/agents/format_prompts.py b/mas_arena/agents/format_prompts.py
index 9699797..91c4ac2 100644
--- a/mas_arena/agents/format_prompts.py
+++ b/mas_arena/agents/format_prompts.py
@@ -19,6 +19,7 @@ class DatasetType(Enum):
     HUMANEVAL = auto()
     MATH = auto()
     CODE = auto()  # Generic code generation format
+    HOTPOTQA = auto()  # Multi-hop question answering
 
 
 @dataclass
@@ -154,6 +155,31 @@ class FormatPrompt:
 """,
         description="Format prompt for math problems",
         dataset_type=DatasetType.MATH
     ),
+
+    "hotpotqa": FormatPrompt(
+        name="HOTPOTQA",
+        prompt="""
+- This is a multi-hop question answering task that requires reasoning across multiple documents.
+- Read through all provided context documents carefully to find relevant information.
+- The answer should be a specific entity, name, or short phrase (usually 1-5 words).
+- Provide your reasoning process to show how you connected information across documents.
+- Give your final answer in the format: <answer>your answer here</answer>
+- Ensure your answer is:
+  * Factually accurate and supported by the context
+  * Precisely answering what is asked (e.g., if asked for a year, give a year; if asked for a name, give a name)
+  * Concise and specific (avoid unnecessary words or explanations in the answer tags)
+  * Properly capitalized and formatted
+
+Example format:
+Based on the context, I need to find... [your reasoning]
+From document X, I can see that... [connection 1]
+From document Y, I can see that... [connection 2]
+Therefore, connecting these pieces of information...
+<answer>specific answer</answer>
+""",
+        description="Format prompt for HotpotQA multi-hop question answering",
+        dataset_type=DatasetType.HOTPOTQA
+    )
 }
 
@@ -173,6 +199,9 @@ def get_format_prompt(dataset_name: str) -> Optional[str]:
         return FORMAT_PROMPTS["code"].prompt
     if dataset_name in ["mmlu_pro", "mmlu"]:
         return FORMAT_PROMPTS["mmlu"].prompt
+    # Handle HotpotQA variants
+    if dataset_name.lower().startswith("hotpot"):
+        return FORMAT_PROMPTS["hotpotqa"].prompt
     # Get prompt for other datasets
     prompt_info = FORMAT_PROMPTS.get(dataset_name.lower())
     return prompt_info.prompt if prompt_info else None

From 167b9301de2f4fc47bf7dc14e13914d1feffb288 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?=E6=A0=96=E9=B8=9F=E9=97=B2=E6=A6=86?= <59347549+coxine@users.noreply.github.com>
Date: Mon, 30 Jun 2025 15:06:08 +0800
Subject: [PATCH 3/8] Implement answer extraction logic in HotpotQAEvaluator and update score calculation

---
 mas_arena/evaluators/hotpotqa_evaluator.py | 62 ++++++++++++++++++++--
 1 file changed, 59 insertions(+), 3 deletions(-)

diff --git a/mas_arena/evaluators/hotpotqa_evaluator.py b/mas_arena/evaluators/hotpotqa_evaluator.py
index 6f3197e..b99a1ca 100644
--- a/mas_arena/evaluators/hotpotqa_evaluator.py
+++ b/mas_arena/evaluators/hotpotqa_evaluator.py
@@ -36,6 +36,60 @@ def __init__(self, name: str = "hotpotqa", config: Dict[str, Any] = None):
     @classmethod
     def from_config(cls, name: str, config: Dict[str, Any] = None):
         return cls(name, config)
+
+    def extract_answer(self, text: str) -> str:
+        """
+        Extract the answer from model output text, expecting '<answer>...</answer>' tags first.
+
+        Args:
+            text: The model's output text
+
+        Returns:
+            The extracted answer (e.g., "(A)", "True", "] >")
+        """
+        text = text.strip()
+
+        # Primary pattern: Content within <answer>...</answer> tags
+        tag_pattern = r"<answer>\s*([\s\S]*?)\s*</answer>"
+        match = re.search(tag_pattern, text, re.IGNORECASE)
+        if match:
+            return match.group(1).strip()
+
+        # Fallback: "Final Answer: <answer>"
+        final_answer_pattern = r"Final Answer:\s*(.+)"
+        match = re.search(final_answer_pattern, text, re.DOTALL)
+        if match:
+            return match.group(1).strip()
+
+        # Fallback: Look for multiple-choice options (e.g., (A), A, [A])
+        option_pattern = r"\([A-Z]\)|[A-Z]\b|\[[A-Z]\]"
+        matches = re.findall(option_pattern, text, re.DOTALL)
+        if matches:
+            last_match = matches[-1]
+            # Normalize to (A) format
+            if not last_match.startswith("("):
+                last_match = f"({last_match[-1]})"
+            return last_match.strip()
+
+        # Fallback: Look for boolean values
+        boolean_pattern = r"\b(True|False)\b"
+        boolean_matches = re.findall(boolean_pattern, text, re.DOTALL)
+        if boolean_matches:
+            return boolean_matches[-1].strip()
+
+        # Fallback: Look for sequence completions (e.g., "> ) }", "] ] ]")
+        sequence_pattern = r"([>\]\}\)\[]+\s*)+"
+        sequence_matches = re.findall(sequence_pattern, text, re.DOTALL)
+        if sequence_matches:
+            return sequence_matches[-1].strip()
+
+        # Fallback: Last non-empty line
+        lines = [line.strip() for line in text.split("\n") if line.strip()]
+        if lines:
+            return lines[-1]
+
+        # Final fallback: Return stripped text
+        return text.strip()
 
     def normalize_answer(self, s: str) -> str:
         """
@@ -73,20 +127,22 @@ def calculate_score(self, ground_truth: str, prediction: str) -> Tuple[float, st
         Returns:
             Tuple of (f1_score, prediction)
         """
-        prediction_tokens = self.normalize_answer(prediction).split()
+        extracted_answer = self.extract_answer(prediction)
+
+        prediction_tokens = self.normalize_answer(extracted_answer).split()
         ground_truth_tokens = self.normalize_answer(ground_truth).split()
 
         common = Counter(prediction_tokens) & Counter(ground_truth_tokens)
         num_same = sum(common.values())
 
         if num_same == 0:
-            return 0, prediction
+            return 0, extracted_answer
 
         precision = 1.0 * num_same / len(prediction_tokens)
         recall = 1.0 * num_same / len(ground_truth_tokens)
         f1 = (2 * precision * recall) / (precision + recall)
 
-        return f1, prediction
+        return f1, extracted_answer
 
     def create_run(self, problem: Dict[str, Any], final_answer: str, extracted_answer: str, score: float) -> Run:
         """Create a LangSmith run for evaluation"""

From 362f918991a8bcfa67e16c4a4fa6d01f5978e3b Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?=E6=A0=96=E9=B8=9F=E9=97=B2=E6=A6=86?= <59347549+coxine@users.noreply.github.com>
Date: Mon, 30 Jun 2025 15:49:54 +0800
Subject: [PATCH 4/8] Refactor answer extraction methods across evaluators and add utility functions to decrease redundant code

---
 mas_arena/evaluators/aime_evaluator.py | 9 +-
 mas_arena/evaluators/bbh_evaluator.py | 45 +-------
 mas_arena/evaluators/hotpotqa_evaluator.py | 45 +-------
 mas_arena/evaluators/mmlu_pro_evaluator.py | 9 +-
 mas_arena/evaluators/utils/__init__.py | 14 ++-
 .../evaluators/utils/answer_extraction.py | 108 ++++++++++++++++++
 6 files changed, 129 insertions(+), 101 deletions(-)
 create mode 100644 mas_arena/evaluators/utils/answer_extraction.py

diff --git a/mas_arena/evaluators/aime_evaluator.py b/mas_arena/evaluators/aime_evaluator.py
index 66226b2..989ee3f 100644
--- a/mas_arena/evaluators/aime_evaluator.py
+++ b/mas_arena/evaluators/aime_evaluator.py
@@ -10,6 +10,7 @@
 from mas_arena.evaluators.base_evaluator import BaseEvaluator
 from mas_arena.evaluators.registry import register_benchmark
 from mas_arena.evaluators.utils.math_equal import calculate_score
+from mas_arena.evaluators.utils import extract_answer_numeric
 
 @register_benchmark(
     name="aime",
@@ -36,13 +37,7 @@ def extract_answer(self, text: str) -> str:
         """
         Extract the answer from model output text (last number or string).
         """
-        # Try to extract the last number (int/float)
-        matches = re.findall(r"[-+]?\d+(?:,\d{3})*(?:\.\d+)?|\d+\.\d+", str(text))
-        if matches:
-            return matches[-1].replace(",", "").strip()
-        # Fallback: last non-empty line
-        lines = [line.strip() for line in str(text).splitlines() if line.strip()]
-        return lines[-1] if lines else str(text).strip()
+        return extract_answer_numeric(text)
 
     def calculate_score(self, expected_output: str, prediction: str) -> Tuple[int, str]:
         return calculate_score(expected_output, prediction)
diff --git a/mas_arena/evaluators/bbh_evaluator.py b/mas_arena/evaluators/bbh_evaluator.py
index 3d61198..92549db 100644
--- a/mas_arena/evaluators/bbh_evaluator.py
+++ b/mas_arena/evaluators/bbh_evaluator.py
@@ -14,6 +14,7 @@
 
 from mas_arena.evaluators.base_evaluator import BaseEvaluator
 from mas_arena.evaluators.registry import register_benchmark
+from mas_arena.evaluators.utils import extract_answer_generic
 
 
 @register_benchmark(
@@ -56,49 +57,7 @@ def extract_answer(self, text: str) -> str:
         Returns:
             The extracted answer (e.g., "(A)", "True", "] >")
         """
-        text = text.strip()
-
-        # Primary pattern: Content within <answer>...</answer> tags
-        tag_pattern = r"<answer>\s*([\s\S]*?)\s*</answer>"
-        match = re.search(tag_pattern, text, re.IGNORECASE)
-        if match:
-            return match.group(1).strip()
-
-        # Fallback: "Final Answer: <answer>"
-        final_answer_pattern = r"Final Answer:\s*(.+)"
-        match = re.search(final_answer_pattern, text, re.DOTALL)
-        if match:
-            return match.group(1).strip()
-
-        # Fallback: Look for multiple-choice options (e.g., (A), A, [A])
-        option_pattern = r"\([A-Z]\)|[A-Z]\b|\[[A-Z]\]"
-        matches = re.findall(option_pattern, text, re.DOTALL)
-        if matches:
-            last_match = matches[-1]
-            # Normalize to (A) format
-            if not last_match.startswith("("):
-                last_match = f"({last_match[-1]})"
-            return last_match.strip()
-
-        # Fallback: Look for boolean values
-        boolean_pattern = r"\b(True|False)\b"
-        boolean_matches = re.findall(boolean_pattern, text, re.DOTALL)
-        if boolean_matches:
-            return boolean_matches[-1].strip()
-
-        # Fallback: Look for sequence completions (e.g., "> ) }", "] ] ]")
-        sequence_pattern = r"([>\]\}\)\[]+\s*)+"
-        sequence_matches = re.findall(sequence_pattern, text, re.DOTALL)
-        if sequence_matches:
-            return sequence_matches[-1].strip()
-
-        # Fallback: Last non-empty line
-        lines = [line.strip() for line in text.split("\n") if line.strip()]
-        if lines:
-            return lines[-1]
-
-        # Final fallback: Return stripped text
-        return text.strip()
+        return extract_answer_generic(text)
 
     def normalize_answer(self, answer: str) -> str:
         """
diff --git a/mas_arena/evaluators/hotpotqa_evaluator.py b/mas_arena/evaluators/hotpotqa_evaluator.py
index b99a1ca..c46eeb9 100644
--- a/mas_arena/evaluators/hotpotqa_evaluator.py
+++ b/mas_arena/evaluators/hotpotqa_evaluator.py
@@ -15,6 +15,7 @@
 
 from mas_arena.evaluators.base_evaluator import BaseEvaluator
 from mas_arena.evaluators.registry import register_benchmark
+from mas_arena.evaluators.utils import extract_answer_generic
 
 
 @register_benchmark(
@@ -47,49 +48,7 @@ def extract_answer(self, text: str) -> str:
         Returns:
             The extracted answer (e.g., "(A)", "True", "] >")
         """
-        text = text.strip()
-
-        # Primary pattern: Content within <answer>...</answer> tags
-        tag_pattern = r"<answer>\s*([\s\S]*?)\s*</answer>"
-        match = re.search(tag_pattern, text, re.IGNORECASE)
-        if match:
-            return match.group(1).strip()
-
-        # Fallback: "Final Answer: <answer>"
-        final_answer_pattern = r"Final Answer:\s*(.+)"
-        match = re.search(final_answer_pattern, text, re.DOTALL)
-        if match:
-            return match.group(1).strip()
-
-        # Fallback: Look for multiple-choice options (e.g., (A), A, [A])
-        option_pattern = r"\([A-Z]\)|[A-Z]\b|\[[A-Z]\]"
-        matches = re.findall(option_pattern, text, re.DOTALL)
-        if matches:
-            last_match = matches[-1]
-            # Normalize to (A) format
-            if not last_match.startswith("("):
-                last_match = f"({last_match[-1]})"
-            return last_match.strip()
-
-        # Fallback: Look for boolean values
-        boolean_pattern = r"\b(True|False)\b"
-        boolean_matches = re.findall(boolean_pattern, text, re.DOTALL)
-        if boolean_matches:
-            return boolean_matches[-1].strip()
-
-        # Fallback: Look for sequence completions (e.g., "> ) }", "] ] ]")
-        sequence_pattern = r"([>\]\}\)\[]+\s*)+"
-        sequence_matches = re.findall(sequence_pattern, text, re.DOTALL)
-        if sequence_matches:
-            return sequence_matches[-1].strip()
-
-        # Fallback: Last non-empty line
-        lines = [line.strip() for line in text.split("\n") if line.strip()]
-        if lines:
-            return lines[-1]
-
-        # Final fallback: Return stripped text
-        return text.strip()
+        return extract_answer_generic(text)
 
     def normalize_answer(self, s: str) -> str:
         """
diff --git a/mas_arena/evaluators/mmlu_pro_evaluator.py b/mas_arena/evaluators/mmlu_pro_evaluator.py
index 54c833f..2b5cdd4 100644
--- a/mas_arena/evaluators/mmlu_pro_evaluator.py
+++ b/mas_arena/evaluators/mmlu_pro_evaluator.py
@@ -12,6 +12,7 @@
 
 from mas_arena.evaluators.base_evaluator import BaseEvaluator
 from mas_arena.evaluators.registry import register_benchmark
+from mas_arena.evaluators.utils import extract_answer_simple_tags
 
 
 @register_benchmark(
@@ -134,13 +135,7 @@ def extract_answer_from_response(self, response: str) -> str:
         Returns:
             Extracted answer letter
         """
-        # Try to extract answer from <answer> tags, allowing for whitespace
-        match = re.search(r'<answer>\s*(.*?)\s*</answer>', response, re.DOTALL)
-        if match:
-            return match.group(1).strip()
-
-        # If no <answer> tags found, return original response
-        return response.strip()
+        return extract_answer_simple_tags(response)
 
     def evaluate(self, problem: Dict[str, Any], run_result: Dict[str, Any]) -> Dict[str, Any]:
         """
diff --git a/mas_arena/evaluators/utils/__init__.py b/mas_arena/evaluators/utils/__init__.py
index ce43f5f..187ed0e 100644
--- a/mas_arena/evaluators/utils/__init__.py
+++ b/mas_arena/evaluators/utils/__init__.py
@@ -3,5 +3,17 @@
 """
 
 from .sanitize import sanitize
+from .answer_extraction import (
+    extract_answer_generic,
+    extract_answer_numeric,
+    extract_answer_simple_tags,
+    extract_answer  # backward compatibility
+)
 
-__all__ = ["sanitize"]
+__all__ = [
+    "sanitize",
+    "extract_answer_generic",
+    "extract_answer_numeric",
+    "extract_answer_simple_tags",
+    "extract_answer"
+]
diff --git a/mas_arena/evaluators/utils/answer_extraction.py b/mas_arena/evaluators/utils/answer_extraction.py
new file mode 100644
index 0000000..93c8bfb
--- /dev/null
+++ b/mas_arena/evaluators/utils/answer_extraction.py
@@ -0,0 +1,108 @@
+"""
+Answer extraction utilities for evaluators.
+
+This module provides common functions for extracting answers from model outputs
+across different benchmark evaluators.
+"""
+
+import re
+
+
+def extract_answer_generic(text: str) -> str:
+    """
+    Generic answer extraction with comprehensive fallback patterns.
+    Suitable for most text-based benchmarks like HotpotQA, BBH, etc.
+
+    Args:
+        text: The model's output text
+
+    Returns:
+        The extracted answer (e.g., "(A)", "True", text content)
+    """
+    text = text.strip()
+
+    # Primary pattern: Content within <answer>...</answer> tags
+    tag_pattern = r"<answer>\s*([\s\S]*?)\s*</answer>"
+    match = re.search(tag_pattern, text, re.IGNORECASE)
+    if match:
+        return match.group(1).strip()
+
+    # Fallback: "Final Answer: <answer>"
+    final_answer_pattern = r"Final Answer:\s*(.+)"
+    match = re.search(final_answer_pattern, text, re.DOTALL)
+    if match:
+        return match.group(1).strip()
+
+    # Fallback: Look for multiple-choice options (e.g., (A), A, [A])
+    option_pattern = r"\([A-Z]\)|[A-Z]\b|\[[A-Z]\]"
+    matches = re.findall(option_pattern, text, re.DOTALL)
+    if matches:
+        last_match = matches[-1]
+        # Normalize to (A) format
+        if not last_match.startswith("("):
+            last_match = f"({last_match[-1]})"
+        return last_match.strip()
+
+    # Fallback: Look for boolean values
+    boolean_pattern = r"\b(True|False)\b"
+    boolean_matches = re.findall(boolean_pattern, text, re.DOTALL)
+    if boolean_matches:
+        return boolean_matches[-1].strip()
+
+    # Fallback: Look for sequence completions (e.g., "> ) }", "] ] ]")
+    sequence_pattern = r"([>\]\}\)\[]+\s*)+"
+    sequence_matches = re.findall(sequence_pattern, text, re.DOTALL)
+    if sequence_matches:
+        return sequence_matches[-1].strip()
+
+    # Fallback: Last non-empty line
+    lines = [line.strip() for line in text.split("\n") if line.strip()]
+    if lines:
+        return lines[-1]
+
+    # Final fallback: Return stripped text
+    return text.strip()
+
+
+def extract_answer_numeric(text: str) -> str:
+    """
+    Extract numeric answers, suitable for math problems like AIME.
+
+    Args:
+        text: The model's output text
+
+    Returns:
+        The extracted numeric answer
+    """
+    # Try to extract the last number (int/float)
+    matches = re.findall(r"[-+]?\d+(?:,\d{3})*(?:\.\d+)?|\d+\.\d+", str(text))
+    if matches:
+        return matches[-1].replace(",", "").strip()
+
+    # Fallback: last non-empty line
+    lines = [line.strip() for line in str(text).splitlines() if line.strip()]
+    return lines[-1] if lines else str(text).strip()
+
+
+def extract_answer_simple_tags(text: str) -> str:
+    """
+    Simple answer extraction for evaluators that primarily use <answer> tags.
+ + Args: + text: The model's output text + + Returns: + The extracted answer + """ + # Try to extract answer from tags, allowing for whitespace + match = re.search(r'\s*(.*?)\s*', text, re.DOTALL) + if match: + return match.group(1).strip() + + # If no tags found, return original response + return text.strip() + + +# Backward compatibility aliases +extract_answer = extract_answer_generic From c47e38c20df92568f1cba5a58a00c77124119016 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E6=A0=96=E9=B8=9F=E9=97=B2=E6=A6=86?= <59347549+coxine@users.noreply.github.com> Date: Mon, 30 Jun 2025 19:52:10 +0800 Subject: [PATCH 5/8] Refactor normalization and F1 score calculation in DROP and HotpotQA evaluators; add metrics utility functions --- mas_arena/evaluators/drop_evaluator.py | 36 +------ mas_arena/evaluators/hotpotqa_evaluator.py | 39 +------- mas_arena/evaluators/utils/__init__.py | 12 ++- mas_arena/evaluators/utils/metrics.py | 111 +++++++++++++++++++++ 4 files changed, 131 insertions(+), 67 deletions(-) create mode 100644 mas_arena/evaluators/utils/metrics.py diff --git a/mas_arena/evaluators/drop_evaluator.py b/mas_arena/evaluators/drop_evaluator.py index e7e8069..fce125f 100644 --- a/mas_arena/evaluators/drop_evaluator.py +++ b/mas_arena/evaluators/drop_evaluator.py @@ -8,17 +8,16 @@ from __future__ import annotations import re -import string import time from pathlib import Path -from collections import Counter -from typing import Dict, Any, List +from typing import Dict, Any from langsmith.evaluation import RunEvaluator from langsmith.schemas import Run from mas_arena.evaluators.base_evaluator import BaseEvaluator from mas_arena.evaluators.registry import register_benchmark +from mas_arena.evaluators.utils import calculate_f1_score, normalize_answer _ANS_TAG_RE = re.compile(r"\s*([\s\S]*?)\s*", re.IGNORECASE) _FINAL_RE = re.compile(r"(?:^|\n)\s*(?:final\s+answer|answer)\s*[:\-]?\s*([\s\S]+)", re.IGNORECASE) @@ -73,38 +72,11 @@ def _extract_answer(self, raw: Any) -> str: @staticmethod def _normalize(s: Any) -> str: """DROP normalization: lowercase -> remove articles/punctuation -> collapse whitespace.""" - s = str(s) - - def remove_articles(t: str) -> str: - return re.sub(r"\b(a|an|the)\b", " ", t) - - def white_space_fix(t: str) -> str: - return " ".join(t.split()) - - def remove_punc(t: str) -> str: - return "".join(ch for ch in t if ch not in string.punctuation) - - return white_space_fix(remove_articles(remove_punc(s.lower()))) - + return normalize_answer(s) def _f1(self, gold: str, pred: str) -> float: """Calculates token-level F1 score (AllenNLP-style).""" - gold_toks: List[str] = self._normalize(gold).split() - pred_toks: List[str] = self._normalize(pred).split() - - if not gold_toks and not pred_toks: - return 1.0 - if not gold_toks or not pred_toks: - return 0.0 - - common = Counter(gold_toks) & Counter(pred_toks) - num_same = sum(common.values()) - if num_same == 0: - return 0.0 - - precision = num_same / len(pred_toks) - recall = num_same / len(gold_toks) - return 2 * precision * recall / (precision + recall) + return calculate_f1_score(gold, pred, self._normalize) def _make_run( diff --git a/mas_arena/evaluators/hotpotqa_evaluator.py b/mas_arena/evaluators/hotpotqa_evaluator.py index c46eeb9..5bd675e 100644 --- a/mas_arena/evaluators/hotpotqa_evaluator.py +++ b/mas_arena/evaluators/hotpotqa_evaluator.py @@ -4,10 +4,7 @@ This module provides a standalone evaluator for HotpotQA (Multi-hop Question Answering) problems. 
""" -import re -import string import time -from collections import Counter from typing import Dict, Any, Tuple from langsmith.evaluation import RunEvaluator @@ -15,7 +12,7 @@ from mas_arena.evaluators.base_evaluator import BaseEvaluator from mas_arena.evaluators.registry import register_benchmark -from mas_arena.evaluators.utils import extract_answer_generic +from mas_arena.evaluators.utils import extract_answer_generic, calculate_f1_score, normalize_answer @register_benchmark( @@ -60,20 +57,7 @@ def normalize_answer(self, s: str) -> str: Returns: Normalized answer string """ - def remove_articles(text): - return re.sub(r"\b(a|an|the)\b", " ", text) - - def white_space_fix(text): - return " ".join(text.split()) - - def remove_punc(text): - exclude = set(string.punctuation) - return "".join(ch for ch in text if ch not in exclude) - - def lower(text): - return text.lower() - - return white_space_fix(remove_articles(remove_punc(lower(s)))) + return normalize_answer(s) def calculate_score(self, ground_truth: str, prediction: str) -> Tuple[float, str]: """ @@ -84,24 +68,11 @@ def calculate_score(self, ground_truth: str, prediction: str) -> Tuple[float, st prediction: The predicted answer Returns: - Tuple of (f1_score, prediction) + Tuple of (f1_score, extracted_answer) """ extracted_answer = self.extract_answer(prediction) - - prediction_tokens = self.normalize_answer(extracted_answer).split() - ground_truth_tokens = self.normalize_answer(ground_truth).split() - - common = Counter(prediction_tokens) & Counter(ground_truth_tokens) - num_same = sum(common.values()) - - if num_same == 0: - return 0, extracted_answer - - precision = 1.0 * num_same / len(prediction_tokens) - recall = 1.0 * num_same / len(ground_truth_tokens) - f1 = (2 * precision * recall) / (precision + recall) - - return f1, extracted_answer + f1_score = calculate_f1_score(ground_truth, extracted_answer) + return f1_score, extracted_answer def create_run(self, problem: Dict[str, Any], final_answer: str, extracted_answer: str, score: float) -> Run: """Create a LangSmith run for evaluation""" diff --git a/mas_arena/evaluators/utils/__init__.py b/mas_arena/evaluators/utils/__init__.py index 187ed0e..2407db7 100644 --- a/mas_arena/evaluators/utils/__init__.py +++ b/mas_arena/evaluators/utils/__init__.py @@ -9,11 +9,21 @@ extract_answer_simple_tags, extract_answer # backward compatibility ) +from .metrics import ( + normalize_answer, + calculate_f1_score, + calculate_exact_match, + calculate_multi_answer_f1 +) __all__ = [ "sanitize", "extract_answer_generic", "extract_answer_numeric", "extract_answer_simple_tags", - "extract_answer" + "extract_answer", + "normalize_answer", + "calculate_f1_score", + "calculate_exact_match", + "calculate_multi_answer_f1" ] diff --git a/mas_arena/evaluators/utils/metrics.py b/mas_arena/evaluators/utils/metrics.py new file mode 100644 index 0000000..4067e43 --- /dev/null +++ b/mas_arena/evaluators/utils/metrics.py @@ -0,0 +1,111 @@ +""" +Evaluation metrics utilities. + +This module provides common metric calculations used across different evaluators. +""" + +import re +import string +from collections import Counter +from typing import List, Any + + +def normalize_answer(s: Any) -> str: + """ + Normalize answer text for evaluation. + Standard normalization: lowercase -> remove articles/punctuation -> collapse whitespace. + Used by DROP, HotpotQA and similar benchmarks. 
+    Args:
+        s: The text to normalize
+
+    Returns:
+        Normalized text string
+    """
+    s = str(s)
+
+    def remove_articles(text: str) -> str:
+        return re.sub(r"\b(a|an|the)\b", " ", text)
+
+    def white_space_fix(text: str) -> str:
+        return " ".join(text.split())
+
+    def remove_punc(text: str) -> str:
+        return "".join(ch for ch in text if ch not in string.punctuation)
+
+    return white_space_fix(remove_articles(remove_punc(s.lower())))
+
+
+def calculate_f1_score(gold: str, pred: str, normalize_fn=None) -> float:
+    """
+    Calculate token-level F1 score between gold and predicted answers.
+
+    Args:
+        gold: Ground truth answer
+        pred: Predicted answer
+        normalize_fn: Optional normalization function. If None, uses normalize_answer.
+
+    Returns:
+        F1 score between 0.0 and 1.0
+    """
+    if normalize_fn is None:
+        normalize_fn = normalize_answer
+
+    gold_toks: List[str] = normalize_fn(gold).split()
+    pred_toks: List[str] = normalize_fn(pred).split()
+
+    if not gold_toks and not pred_toks:
+        return 1.0
+    if not gold_toks or not pred_toks:
+        return 0.0
+
+    common = Counter(gold_toks) & Counter(pred_toks)
+    num_same = sum(common.values())
+    if num_same == 0:
+        return 0.0
+
+    precision = num_same / len(pred_toks)
+    recall = num_same / len(gold_toks)
+    return 2 * precision * recall / (precision + recall)
+
+
+def calculate_exact_match(gold: str, pred: str, normalize_fn=None) -> float:
+    """
+    Calculate exact match score between gold and predicted answers.
+
+    Args:
+        gold: Ground truth answer
+        pred: Predicted answer
+        normalize_fn: Optional normalization function. If None, uses normalize_answer.
+
+    Returns:
+        1.0 if exact match, 0.0 otherwise
+    """
+    if normalize_fn is None:
+        normalize_fn = normalize_answer
+
+    return 1.0 if normalize_fn(gold) == normalize_fn(pred) else 0.0
+
+
+def calculate_multi_answer_f1(gold_answers: List[str], pred_answers: List[str], normalize_fn=None) -> float:
+    """
+    Calculate the best F1 score when there are multiple possible gold and predicted answers.
+    Used by DROP and similar benchmarks that support multiple valid answers.
+
+    Args:
+        gold_answers: List of valid ground truth answers
+        pred_answers: List of predicted answers
+        normalize_fn: Optional normalization function. If None, uses normalize_answer.
+
+    Returns:
+        Best F1 score found between any gold-pred pair
+    """
+    if not gold_answers or not pred_answers:
+        return 0.0
+
+    scores = [
+        calculate_f1_score(gold, pred, normalize_fn)
+        for gold in gold_answers for pred in pred_answers
+    ]
+
+    return max(scores) if scores else 0.0

From 7bd67d2e78c6d56bcbcae8493824be5ca8486814 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?=E6=A0=96=E9=B8=9F=E9=97=B2=E6=A6=86?= <59347549+coxine@users.noreply.github.com>
Date: Mon, 30 Jun 2025 21:24:08 +0800
Subject: [PATCH 6/8] Update score calculation in HotpotQAEvaluator to set final score to 1.0 if score is 0.3 or higher

---
 mas_arena/evaluators/hotpotqa_evaluator.py | 5 ++++-
 1 file changed, 4 insertions(+), 1 deletion(-)

diff --git a/mas_arena/evaluators/hotpotqa_evaluator.py b/mas_arena/evaluators/hotpotqa_evaluator.py
index 5bd675e..a7066b2 100644
--- a/mas_arena/evaluators/hotpotqa_evaluator.py
+++ b/mas_arena/evaluators/hotpotqa_evaluator.py
@@ -131,9 +131,12 @@ def evaluate(self, problem: Dict[str, Any], run_result: Dict[str, Any]) -> Dict[
                 f.write(f"Predicted: {final_answer}\n")
                 f.write(f"Score: {score}\n")
 
+        # Final score: 1.0 if score >= 0.3, else use the score directly
+        final_score = 1 if score >= 0.3 else score
+
         return {
             "final_answer": final_answer,
             "extracted_answer": extracted_answer,
-            "score": score,
+            "score": final_score,
             "context": context_str
         }
\ No newline at end of file

From 28d55dee86538e5ff0684c6dc73a3bbbd77405d4 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?=E6=A0=96=E9=B8=9F=E9=97=B2=E6=A6=86?= <59347549+coxine@users.noreply.github.com>
Date: Tue, 1 Jul 2025 10:36:22 +0800
Subject: [PATCH 7/8] Update normalization key for 'id' in HotpotQAEvaluator to use '_id'

---
 mas_arena/evaluators/hotpotqa_evaluator.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/mas_arena/evaluators/hotpotqa_evaluator.py b/mas_arena/evaluators/hotpotqa_evaluator.py
index a7066b2..d1f336b 100644
--- a/mas_arena/evaluators/hotpotqa_evaluator.py
+++ b/mas_arena/evaluators/hotpotqa_evaluator.py
@@ -18,7 +18,7 @@
 @register_benchmark(
     name="hotpotqa",
     normalization_keys={
-        "id": "id",
+        "id": "_id",
         "problem": "question",
         "context": "context",
         "solution": "answer",

From d24a180559bc9d120894733f9de1c02cd368481a Mon Sep 17 00:00:00 2001
From: RuishanFang <22015012@zju.edu.cn>
Date: Mon, 28 Jul 2025 19:24:21 +0800
Subject: [PATCH 8/8] feat: add 'hotpotqa' to supported reasoning benchmarks in README(#6)

---
 README.md | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/README.md b/README.md
index 8c8efeb..d1bbccf 100644
--- a/README.md
+++ b/README.md
@@ -69,7 +69,7 @@ OPENAI_API_BASE=https://api.openai.com/v1
 * Supported benchmarks:
   * Math: `math`, `aime`
   * Code: `humaneval`, `mbpp`
-  * Reasoning: `drop`, `bbh`, `mmlu_pro`, `ifeval`
+  * Reasoning: `drop`, `bbh`, `mmlu_pro`, `ifeval`, `hotpotqa`
 * Supported agent systems:
   * Single Agent: `single_agent`
   * Multi-Agent: `supervisor_mas`, `swarm`, `agentverse`, `chateval`, `evoagent`, `jarvis`, `metagpt`