diff --git a/README.md b/README.md
index 8c8efeb..d1bbccf 100644
--- a/README.md
+++ b/README.md
@@ -69,7 +69,7 @@ OPENAI_API_BASE=https://api.openai.com/v1
* Supported benchmarks:
* Math: `math`, `aime`
* Code: `humaneval`, `mbpp`
- * Reasoning: `drop`, `bbh`, `mmlu_pro`, `ifeval`
+ * Reasoning: `drop`, `bbh`, `mmlu_pro`, `ifeval`, `hotpotqa`
* Supported agent systems:
* Single Agent: `single_agent`
* Multi-Agent: `supervisor_mas`, `swarm`, `agentverse`, `chateval`, `evoagent`, `jarvis`, `metagpt`
diff --git a/mas_arena/agents/format_prompts.py b/mas_arena/agents/format_prompts.py
index 9699797..91c4ac2 100644
--- a/mas_arena/agents/format_prompts.py
+++ b/mas_arena/agents/format_prompts.py
@@ -19,6 +19,7 @@ class DatasetType(Enum):
HUMANEVAL = auto()
MATH = auto()
CODE = auto() # Generic code generation format
+ HOTPOTQA = auto() # Multi-hop question answering
@dataclass
@@ -154,6 +155,31 @@ class FormatPrompt:
""",
description="Format prompt for math problems",
dataset_type=DatasetType.MATH
+ ),
+
+ "hotpotqa": FormatPrompt(
+ name="HOTPOTQA",
+ prompt="""
+- This is a multi-hop question answering task that requires reasoning across multiple documents.
+- Read through all provided context documents carefully to find relevant information.
+- The answer should be a specific entity, name, or short phrase (usually 1-5 words).
+- Provide your reasoning process to show how you connected information across documents.
+- Give your final answer in the format: <answer>your answer here</answer>
+- Ensure your answer is:
+ * Factually accurate and supported by the context
+ * Precisely answering what is asked (e.g., if asked for a year, give a year; if asked for a name, give a name)
+ * Concise and specific (avoid unnecessary words or explanations in the answer tags)
+ * Properly capitalized and formatted
+
+Example format:
+Based on the context, I need to find... [your reasoning]
+From document X, I can see that... [connection 1]
+From document Y, I can see that... [connection 2]
+Therefore, connecting these pieces of information...
+<answer>specific answer</answer>
+""",
+ description="Format prompt for HotpotQA multi-hop question answering",
+ dataset_type=DatasetType.HOTPOTQA
)
}
@@ -173,6 +199,9 @@ def get_format_prompt(dataset_name: str) -> Optional[str]:
return FORMAT_PROMPTS["code"].prompt
if dataset_name in ["mmlu_pro", "mmlu"]:
return FORMAT_PROMPTS["mmlu"].prompt
+ # Handle HotpotQA variants
+ if dataset_name.lower().startswith("hotpot"):
+ return FORMAT_PROMPTS["hotpotqa"].prompt
# Get prompt for other datasets
prompt_info = FORMAT_PROMPTS.get(dataset_name.lower())
return prompt_info.prompt if prompt_info else None
diff --git a/mas_arena/evaluators/aime_evaluator.py b/mas_arena/evaluators/aime_evaluator.py
index 83c5f29..4d40c1b 100644
--- a/mas_arena/evaluators/aime_evaluator.py
+++ b/mas_arena/evaluators/aime_evaluator.py
@@ -19,6 +19,7 @@
from mas_arena.evaluators.base_evaluator import BaseEvaluator
from mas_arena.evaluators.registry import register_benchmark
from mas_arena.evaluators.utils.math_equal import calculate_score
+from mas_arena.evaluators.utils import extract_answer_numeric
@register_benchmark(
@@ -67,13 +68,7 @@ def math_extract_answer(self, text: str) -> str:
"""
Extract the answer from model output text (last number or string).
"""
- # Try to extract the last number (int/float)
- matches = re.findall(r"[-+]?\d+(?:,\d{3})*(?:\.\d+)?|\d+\.\d+", str(text))
- if matches:
- return matches[-1].replace(",", "").strip()
- # Fallback: last non-empty line
- lines = [line.strip() for line in str(text).splitlines() if line.strip()]
- return lines[-1] if lines else str(text).strip()
+ return extract_answer_numeric(text)
def simple_extract_answer(self, text: str) -> str:
"""
diff --git a/mas_arena/evaluators/bbh_evaluator.py b/mas_arena/evaluators/bbh_evaluator.py
index 3d61198..92549db 100644
--- a/mas_arena/evaluators/bbh_evaluator.py
+++ b/mas_arena/evaluators/bbh_evaluator.py
@@ -14,6 +14,7 @@
from mas_arena.evaluators.base_evaluator import BaseEvaluator
from mas_arena.evaluators.registry import register_benchmark
+from mas_arena.evaluators.utils import extract_answer_generic
@register_benchmark(
@@ -56,49 +57,7 @@ def extract_answer(self, text: str) -> str:
Returns:
The extracted answer (e.g., "(A)", "True", "] >")
"""
- text = text.strip()
-
-        # Primary pattern: Content within <answer>...</answer> tags
-        tag_pattern = r"<answer>\s*([\s\S]*?)\s*</answer>"
- match = re.search(tag_pattern, text, re.IGNORECASE)
- if match:
- return match.group(1).strip()
-
- # Fallback: "Final Answer: "
- final_answer_pattern = r"Final Answer:\s*(.+)"
- match = re.search(final_answer_pattern, text, re.DOTALL)
- if match:
- return match.group(1).strip()
-
- # Fallback: Look for multiple-choice options (e.g., (A), A, [A])
- option_pattern = r"\([A-Z]\)|[A-Z]\b|\[[A-Z]\]"
- matches = re.findall(option_pattern, text, re.DOTALL)
- if matches:
- last_match = matches[-1]
- # Normalize to (A) format
- if not last_match.startswith("("):
- last_match = f"({last_match[-1]})"
- return last_match.strip()
-
- # Fallback: Look for boolean values
- boolean_pattern = r"\b(True|False)\b"
- boolean_matches = re.findall(boolean_pattern, text, re.DOTALL)
- if boolean_matches:
- return boolean_matches[-1].strip()
-
- # Fallback: Look for sequence completions (e.g., "> ) }", "] ] ]")
- sequence_pattern = r"([>\]\}\)\[]+\s*)+"
- sequence_matches = re.findall(sequence_pattern, text, re.DOTALL)
- if sequence_matches:
- return sequence_matches[-1].strip()
-
- # Fallback: Last non-empty line
- lines = [line.strip() for line in text.split("\n") if line.strip()]
- if lines:
- return lines[-1]
-
- # Final fallback: Return stripped text
- return text.strip()
+ return extract_answer_generic(text)
def normalize_answer(self, answer: str) -> str:
"""
diff --git a/mas_arena/evaluators/drop_evaluator.py b/mas_arena/evaluators/drop_evaluator.py
index e7e8069..fce125f 100644
--- a/mas_arena/evaluators/drop_evaluator.py
+++ b/mas_arena/evaluators/drop_evaluator.py
@@ -8,17 +8,16 @@
from __future__ import annotations
import re
-import string
import time
from pathlib import Path
-from collections import Counter
-from typing import Dict, Any, List
+from typing import Dict, Any
from langsmith.evaluation import RunEvaluator
from langsmith.schemas import Run
from mas_arena.evaluators.base_evaluator import BaseEvaluator
from mas_arena.evaluators.registry import register_benchmark
+from mas_arena.evaluators.utils import calculate_f1_score, normalize_answer
_ANS_TAG_RE = re.compile(r"<answer>\s*([\s\S]*?)\s*</answer>", re.IGNORECASE)
_FINAL_RE = re.compile(r"(?:^|\n)\s*(?:final\s+answer|answer)\s*[:\-]?\s*([\s\S]+)", re.IGNORECASE)
@@ -73,38 +72,11 @@ def _extract_answer(self, raw: Any) -> str:
@staticmethod
def _normalize(s: Any) -> str:
"""DROP normalization: lowercase -> remove articles/punctuation -> collapse whitespace."""
- s = str(s)
-
- def remove_articles(t: str) -> str:
- return re.sub(r"\b(a|an|the)\b", " ", t)
-
- def white_space_fix(t: str) -> str:
- return " ".join(t.split())
-
- def remove_punc(t: str) -> str:
- return "".join(ch for ch in t if ch not in string.punctuation)
-
- return white_space_fix(remove_articles(remove_punc(s.lower())))
-
+ return normalize_answer(s)
def _f1(self, gold: str, pred: str) -> float:
"""Calculates token-level F1 score (AllenNLP-style)."""
- gold_toks: List[str] = self._normalize(gold).split()
- pred_toks: List[str] = self._normalize(pred).split()
-
- if not gold_toks and not pred_toks:
- return 1.0
- if not gold_toks or not pred_toks:
- return 0.0
-
- common = Counter(gold_toks) & Counter(pred_toks)
- num_same = sum(common.values())
- if num_same == 0:
- return 0.0
-
- precision = num_same / len(pred_toks)
- recall = num_same / len(gold_toks)
- return 2 * precision * recall / (precision + recall)
+ return calculate_f1_score(gold, pred, self._normalize)
def _make_run(
diff --git a/mas_arena/evaluators/hotpotqa_evaluator.py b/mas_arena/evaluators/hotpotqa_evaluator.py
index f2db35e..d1f336b 100644
--- a/mas_arena/evaluators/hotpotqa_evaluator.py
+++ b/mas_arena/evaluators/hotpotqa_evaluator.py
@@ -4,10 +4,7 @@
This module provides a standalone evaluator for HotpotQA (Multi-hop Question Answering) problems.
"""
-import re
-import string
import time
-from collections import Counter
from typing import Dict, Any, Tuple
from langsmith.evaluation import RunEvaluator
@@ -15,13 +12,15 @@
from mas_arena.evaluators.base_evaluator import BaseEvaluator
from mas_arena.evaluators.registry import register_benchmark
+from mas_arena.evaluators.utils import extract_answer_generic, calculate_f1_score, normalize_answer
@register_benchmark(
name="hotpotqa",
normalization_keys={
- "id": "id",
+ "id": "_id",
"problem": "question",
+ "context": "context",
"solution": "answer",
}
)
@@ -35,6 +34,18 @@ def __init__(self, name: str = "hotpotqa", config: Dict[str, Any] = None):
@classmethod
def from_config(cls, name: str, config: Dict[str, Any] = None):
return cls(name, config)
+
+ def extract_answer(self, text: str) -> str:
+ """
+        Extract the answer from model output text, expecting '<answer>...</answer>' tags first.
+
+ Args:
+ text: The model's output text
+
+ Returns:
+            The extracted answer (typically a short entity, name, or phrase)
+ """
+ return extract_answer_generic(text)
def normalize_answer(self, s: str) -> str:
"""
@@ -46,20 +57,7 @@ def normalize_answer(self, s: str) -> str:
Returns:
Normalized answer string
"""
- def remove_articles(text):
- return re.sub(r"\b(a|an|the)\b", " ", text)
-
- def white_space_fix(text):
- return " ".join(text.split())
-
- def remove_punc(text):
- exclude = set(string.punctuation)
- return "".join(ch for ch in text if ch not in exclude)
-
- def lower(text):
- return text.lower()
-
- return white_space_fix(remove_articles(remove_punc(lower(s))))
+ return normalize_answer(s)
def calculate_score(self, ground_truth: str, prediction: str) -> Tuple[float, str]:
"""
@@ -70,22 +68,11 @@ def calculate_score(self, ground_truth: str, prediction: str) -> Tuple[float, st
prediction: The predicted answer
Returns:
- Tuple of (f1_score, prediction)
+ Tuple of (f1_score, extracted_answer)
"""
- prediction_tokens = self.normalize_answer(prediction).split()
- ground_truth_tokens = self.normalize_answer(ground_truth).split()
-
- common = Counter(prediction_tokens) & Counter(ground_truth_tokens)
- num_same = sum(common.values())
-
- if num_same == 0:
- return 0, prediction
-
- precision = 1.0 * num_same / len(prediction_tokens)
- recall = 1.0 * num_same / len(ground_truth_tokens)
- f1 = (2 * precision * recall) / (precision + recall)
-
- return f1, prediction
+ extracted_answer = self.extract_answer(prediction)
+ f1_score = calculate_f1_score(ground_truth, extracted_answer)
+ return f1_score, extracted_answer
def create_run(self, problem: Dict[str, Any], final_answer: str, extracted_answer: str, score: float) -> Run:
"""Create a LangSmith run for evaluation"""
@@ -95,13 +82,13 @@ def create_run(self, problem: Dict[str, Any], final_answer: str, extracted_answe
id=str(uuid.uuid4()),
name=f"{self.name.upper()}_Evaluation",
inputs={
- "question": problem["question"],
+ "question": problem["problem"],
"context": problem["context"]
},
outputs={
"prediction": final_answer,
"extracted_answer": extracted_answer,
- "expected": problem["answer"],
+ "expected": problem["solution"],
"score": score,
"passed": score >= 0.3, # HotpotQA uses 0.3 as threshold
},
@@ -129,7 +116,7 @@ def evaluate(self, problem: Dict[str, Any], run_result: Dict[str, Any]) -> Dict[
context_str = "\n".join(" ".join(paragraph) for paragraph in paragraphs)
# Calculate score
- score, extracted_answer = self.calculate_score(problem["answer"], final_answer)
+ score, extracted_answer = self.calculate_score(problem["solution"], final_answer)
# # Create LangSmith run
# run = self.create_run(problem, final_answer, extracted_answer, score)
@@ -138,15 +125,18 @@ def evaluate(self, problem: Dict[str, Any], run_result: Dict[str, Any]) -> Dict[
# Log mismatch if score is too low
if score < 0.3:
with open(f"{self.log_path}/mismatches.log", "a") as f:
- f.write(f"\nQuestion: {problem['question']}\n")
+ f.write(f"\nQuestion: {problem['problem']}\n")
f.write(f"Context: {context_str}\n")
- f.write(f"Expected: {problem['answer']}\n")
+ f.write(f"Expected: {problem['solution']}\n")
f.write(f"Predicted: {final_answer}\n")
f.write(f"Score: {score}\n")
+ # Final score: 1.0 if score >= 0.3, else use the score directly
+ final_score = 1 if score >= 0.3 else score
+
return {
"final_answer": final_answer,
"extracted_answer": extracted_answer,
- "score": score,
+ "score": final_score,
"context": context_str
}
\ No newline at end of file
diff --git a/mas_arena/evaluators/mmlu_pro_evaluator.py b/mas_arena/evaluators/mmlu_pro_evaluator.py
index 54c833f..2b5cdd4 100644
--- a/mas_arena/evaluators/mmlu_pro_evaluator.py
+++ b/mas_arena/evaluators/mmlu_pro_evaluator.py
@@ -12,6 +12,7 @@
from mas_arena.evaluators.base_evaluator import BaseEvaluator
from mas_arena.evaluators.registry import register_benchmark
+from mas_arena.evaluators.utils import extract_answer_simple_tags
@register_benchmark(
@@ -134,13 +135,7 @@ def extract_answer_from_response(self, response: str) -> str:
Returns:
Extracted answer letter
"""
-        # Try to extract answer from <answer> tags, allowing for whitespace
-        match = re.search(r'<answer>\s*(.*?)\s*</answer>', response, re.DOTALL)
- if match:
- return match.group(1).strip()
-
- # If no tags found, return original response
- return response.strip()
+ return extract_answer_simple_tags(response)
def evaluate(self, problem: Dict[str, Any], run_result: Dict[str, Any]) -> Dict[str, Any]:
"""
diff --git a/mas_arena/evaluators/utils/__init__.py b/mas_arena/evaluators/utils/__init__.py
index ce43f5f..2407db7 100644
--- a/mas_arena/evaluators/utils/__init__.py
+++ b/mas_arena/evaluators/utils/__init__.py
@@ -3,5 +3,27 @@
"""
from .sanitize import sanitize
+from .answer_extraction import (
+ extract_answer_generic,
+ extract_answer_numeric,
+ extract_answer_simple_tags,
+ extract_answer # backward compatibility
+)
+from .metrics import (
+ normalize_answer,
+ calculate_f1_score,
+ calculate_exact_match,
+ calculate_multi_answer_f1
+)
-__all__ = ["sanitize"]
+__all__ = [
+ "sanitize",
+ "extract_answer_generic",
+ "extract_answer_numeric",
+ "extract_answer_simple_tags",
+ "extract_answer",
+ "normalize_answer",
+ "calculate_f1_score",
+ "calculate_exact_match",
+ "calculate_multi_answer_f1"
+]
diff --git a/mas_arena/evaluators/utils/answer_extraction.py b/mas_arena/evaluators/utils/answer_extraction.py
new file mode 100644
index 0000000..93c8bfb
--- /dev/null
+++ b/mas_arena/evaluators/utils/answer_extraction.py
@@ -0,0 +1,108 @@
+"""
+Answer extraction utilities for evaluators.
+
+This module provides common functions for extracting answers from model outputs
+across different benchmark evaluators.
+"""
+
+import re
+
+
+def extract_answer_generic(text: str) -> str:
+ """
+ Generic answer extraction with comprehensive fallback patterns.
+ Suitable for most text-based benchmarks like HotpotQA, BBH, etc.
+
+ Args:
+ text: The model's output text
+
+ Returns:
+ The extracted answer (e.g., "(A)", "True", text content)
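+
+    Example (illustrative):
+        >>> extract_answer_generic("The capital is <answer>Paris</answer>.")
+        'Paris'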
+ """
+ text = text.strip()
+
+    # Primary pattern: Content within <answer>...</answer> tags
+    tag_pattern = r"<answer>\s*([\s\S]*?)\s*</answer>"
+ match = re.search(tag_pattern, text, re.IGNORECASE)
+ if match:
+ return match.group(1).strip()
+
+ # Fallback: "Final Answer: "
+ final_answer_pattern = r"Final Answer:\s*(.+)"
+ match = re.search(final_answer_pattern, text, re.DOTALL)
+ if match:
+ return match.group(1).strip()
+
+ # Fallback: Look for multiple-choice options (e.g., (A), A, [A])
+ option_pattern = r"\([A-Z]\)|[A-Z]\b|\[[A-Z]\]"
+ matches = re.findall(option_pattern, text, re.DOTALL)
+ if matches:
+ last_match = matches[-1]
+ # Normalize to (A) format
+ if not last_match.startswith("("):
+ last_match = f"({last_match[-1]})"
+ return last_match.strip()
+
+ # Fallback: Look for boolean values
+ boolean_pattern = r"\b(True|False)\b"
+ boolean_matches = re.findall(boolean_pattern, text, re.DOTALL)
+ if boolean_matches:
+ return boolean_matches[-1].strip()
+
+ # Fallback: Look for sequence completions (e.g., "> ) }", "] ] ]")
+ sequence_pattern = r"([>\]\}\)\[]+\s*)+"
+ sequence_matches = re.findall(sequence_pattern, text, re.DOTALL)
+ if sequence_matches:
+ return sequence_matches[-1].strip()
+
+ # Fallback: Last non-empty line
+ lines = [line.strip() for line in text.split("\n") if line.strip()]
+ if lines:
+ return lines[-1]
+
+ # Final fallback: Return stripped text
+ return text.strip()
+
+
+def extract_answer_numeric(text: str) -> str:
+ """
+ Extract numeric answers, suitable for math problems like AIME.
+
+ Args:
+ text: The model's output text
+
+ Returns:
+ The extracted numeric answer
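+
+    Example (illustrative):
+        >>> extract_answer_numeric("The total is 1,234.5 apples")
+        '1234.5'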
+ """
+ # Try to extract the last number (int/float)
+ matches = re.findall(r"[-+]?\d+(?:,\d{3})*(?:\.\d+)?|\d+\.\d+", str(text))
+ if matches:
+ return matches[-1].replace(",", "").strip()
+
+ # Fallback: last non-empty line
+ lines = [line.strip() for line in str(text).splitlines() if line.strip()]
+ return lines[-1] if lines else str(text).strip()
+
+
+def extract_answer_simple_tags(text: str) -> str:
+ """
+    Simple answer extraction for evaluators that primarily use <answer> tags.
+ Suitable for MMLU-Pro and similar benchmarks.
+
+ Args:
+ text: The model's output text
+
+ Returns:
+ The extracted answer
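+
+    Example (illustrative):
+        >>> extract_answer_simple_tags("<answer> B </answer>")
+        'B'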
+ """
+    # Try to extract answer from <answer> tags, allowing for whitespace
+    match = re.search(r'<answer>\s*(.*?)\s*</answer>', text, re.DOTALL)
+ if match:
+ return match.group(1).strip()
+
+ # If no tags found, return original response
+ return text.strip()
+
+
+# Backward compatibility aliases
+extract_answer = extract_answer_generic
diff --git a/mas_arena/evaluators/utils/metrics.py b/mas_arena/evaluators/utils/metrics.py
new file mode 100644
index 0000000..4067e43
--- /dev/null
+++ b/mas_arena/evaluators/utils/metrics.py
@@ -0,0 +1,111 @@
+"""
+Evaluation metrics utilities.
+
+This module provides common metric calculations used across different evaluators.
+"""
+
+import re
+import string
+from collections import Counter
+from typing import List, Any
+
+
+def normalize_answer(s: Any) -> str:
+ """
+ Normalize answer text for evaluation.
+ Standard normalization: lowercase -> remove articles/punctuation -> collapse whitespace.
+ Used by DROP, HotpotQA and similar benchmarks.
+
+ Args:
+ s: The text to normalize
+
+ Returns:
+ Normalized text string
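+
+    Example (illustrative):
+        >>> normalize_answer("The Eiffel Tower, in Paris!")
+        'eiffel tower in paris'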
+ """
+ s = str(s)
+
+ def remove_articles(text: str) -> str:
+ return re.sub(r"\b(a|an|the)\b", " ", text)
+
+ def white_space_fix(text: str) -> str:
+ return " ".join(text.split())
+
+ def remove_punc(text: str) -> str:
+ return "".join(ch for ch in text if ch not in string.punctuation)
+
+ return white_space_fix(remove_articles(remove_punc(s.lower())))
+
+
+def calculate_f1_score(gold: str, pred: str, normalize_fn=None) -> float:
+ """
+ Calculate token-level F1 score between gold and predicted answers.
+
+ Args:
+ gold: Ground truth answer
+ pred: Predicted answer
+ normalize_fn: Optional normalization function. If None, uses normalize_answer.
+
+ Returns:
+ F1 score between 0.0 and 1.0
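+
+    Example (illustrative):
+        >>> round(calculate_f1_score("the Eiffel Tower", "Eiffel Tower in Paris"), 2)
+        0.67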
+ """
+ if normalize_fn is None:
+ normalize_fn = normalize_answer
+
+ gold_toks: List[str] = normalize_fn(gold).split()
+ pred_toks: List[str] = normalize_fn(pred).split()
+
+ if not gold_toks and not pred_toks:
+ return 1.0
+ if not gold_toks or not pred_toks:
+ return 0.0
+
+ common = Counter(gold_toks) & Counter(pred_toks)
+ num_same = sum(common.values())
+ if num_same == 0:
+ return 0.0
+
+ precision = num_same / len(pred_toks)
+ recall = num_same / len(gold_toks)
+ return 2 * precision * recall / (precision + recall)
+
+
+def calculate_exact_match(gold: str, pred: str, normalize_fn=None) -> float:
+ """
+ Calculate exact match score between gold and predicted answers.
+
+ Args:
+ gold: Ground truth answer
+ pred: Predicted answer
+ normalize_fn: Optional normalization function. If None, uses normalize_answer.
+
+ Returns:
+ 1.0 if exact match, 0.0 otherwise
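+
+    Example (illustrative):
+        >>> calculate_exact_match("The Eiffel Tower", "eiffel tower!")
+        1.0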
+ """
+ if normalize_fn is None:
+ normalize_fn = normalize_answer
+
+ return 1.0 if normalize_fn(gold) == normalize_fn(pred) else 0.0
+
+
+def calculate_multi_answer_f1(gold_answers: List[str], pred_answers: List[str], normalize_fn=None) -> float:
+ """
+ Calculate the best F1 score when there are multiple possible gold and predicted answers.
+ Used by DROP and similar benchmarks that support multiple valid answers.
+
+ Args:
+ gold_answers: List of valid ground truth answers
+ pred_answers: List of predicted answers
+ normalize_fn: Optional normalization function. If None, uses normalize_answer.
+
+ Returns:
+ Best F1 score found between any gold-pred pair
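+
+    Example (illustrative):
+        >>> calculate_multi_answer_f1(["Paris", "City of Paris"], ["paris"])
+        1.0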
+ """
+ if not gold_answers or not pred_answers:
+ return 0.0
+
+ scores = [
+ calculate_f1_score(gold, pred, normalize_fn)
+ for gold in gold_answers for pred in pred_answers
+ ]
+
+ return max(scores) if scores else 0.0
diff --git a/mas_arena/evaluators/utils/normalization.py b/mas_arena/evaluators/utils/normalization.py
index 1cb7800..2a1c7df 100644
--- a/mas_arena/evaluators/utils/normalization.py
+++ b/mas_arena/evaluators/utils/normalization.py
@@ -72,6 +72,7 @@ def normalize_problem_keys(problem: Dict[str, Any], key_mapping: Dict[str, str],
key_definitions = {
"id": "id",
"problem": "problem",
+ "context": "context",
"solution": "solution",
"test": "test",
"entry_point": "entry_point",