2 changes: 1 addition & 1 deletion README.md
@@ -69,7 +69,7 @@ OPENAI_API_BASE=https://api.openai.com/v1
* Supported benchmarks:
* Math: `math`, `aime`
* Code: `humaneval`, `mbpp`
* Reasoning: `drop`, `bbh`, `mmlu_pro`, `ifeval`
* Reasoning: `drop`, `bbh`, `mmlu_pro`, `ifeval`, `hotpotqa`
* Supported agent systems:
* Single Agent: `single_agent`
* Multi-Agent: `supervisor_mas`, `swarm`, `agentverse`, `chateval`, `evoagent`, `jarvis`, `metagpt`
29 changes: 29 additions & 0 deletions mas_arena/agents/format_prompts.py
@@ -19,6 +19,7 @@ class DatasetType(Enum):
HUMANEVAL = auto()
MATH = auto()
CODE = auto() # Generic code generation format
HOTPOTQA = auto() # Multi-hop question answering


@dataclass
@@ -154,6 +155,31 @@ class FormatPrompt:
""",
description="Format prompt for math problems",
dataset_type=DatasetType.MATH
),

"hotpotqa": FormatPrompt(
name="HOTPOTQA",
prompt="""
- This is a multi-hop question answering task that requires reasoning across multiple documents.
- Read through all provided context documents carefully to find relevant information.
- The answer should be a specific entity, name, or short phrase (usually 1-5 words).
- Provide your reasoning process to show how you connected information across documents.
- Give your final answer in the format: <answer>your answer here</answer>
- Ensure your answer is:
* Factually accurate and supported by the context
* Precisely answering what is asked (e.g., if asked for a year, give a year; if asked for a name, give a name)
* Concise and specific (avoid unnecessary words or explanations in the answer tags)
* Properly capitalized and formatted

Example format:
Based on the context, I need to find... [your reasoning]
From document X, I can see that... [connection 1]
From document Y, I can see that... [connection 2]
Therefore, connecting these pieces of information...
<answer>specific answer</answer>
""",
description="Format prompt for HotpotQA multi-hop question answering",
dataset_type=DatasetType.HOTPOTQA
)
}

@@ -173,6 +199,9 @@ def get_format_prompt(dataset_name: str) -> Optional[str]:
return FORMAT_PROMPTS["code"].prompt
if dataset_name in ["mmlu_pro", "mmlu"]:
return FORMAT_PROMPTS["mmlu"].prompt
# Handle HotpotQA variants
if dataset_name.lower().startswith("hotpot"):
return FORMAT_PROMPTS["hotpotqa"].prompt
# Get prompt for other datasets
prompt_info = FORMAT_PROMPTS.get(dataset_name.lower())
return prompt_info.prompt if prompt_info else None
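
Note: the new `startswith("hotpot")` branch means any dataset name beginning with "hotpot" resolves to the same HotpotQA format prompt. A minimal usage sketch, assuming the module path shown in this diff (the variant name below is hypothetical):

```python
from mas_arena.agents.format_prompts import get_format_prompt

# Any dataset name beginning with "hotpot" resolves to the HotpotQA prompt,
# so a hypothetical variant name like "hotpotqa_distractor" is covered as well.
prompt = get_format_prompt("hotpotqa_distractor")
assert prompt is not None and "<answer>" in prompt
```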
9 changes: 2 additions & 7 deletions mas_arena/evaluators/aime_evaluator.py
@@ -19,6 +19,7 @@
from mas_arena.evaluators.base_evaluator import BaseEvaluator
from mas_arena.evaluators.registry import register_benchmark
from mas_arena.evaluators.utils.math_equal import calculate_score
from mas_arena.evaluators.utils import extract_answer_numeric


@register_benchmark(
@@ -67,13 +68,7 @@ def math_extract_answer(self, text: str) -> str:
"""
Extract the answer from model output text (last number or string).
"""
# Try to extract the last number (int/float)
matches = re.findall(r"[-+]?\d+(?:,\d{3})*(?:\.\d+)?|\d+\.\d+", str(text))
if matches:
return matches[-1].replace(",", "").strip()
# Fallback: last non-empty line
lines = [line.strip() for line in str(text).splitlines() if line.strip()]
return lines[-1] if lines else str(text).strip()
return extract_answer_numeric(text)

def simple_extract_answer(self, text: str) -> str:
"""
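
The shared helper `extract_answer_numeric` lives in `mas_arena/evaluators/utils` and its body is not part of this diff; judging from the AIME logic it replaces above, it presumably looks roughly like this sketch:

```python
import re


def extract_answer_numeric(text: str) -> str:
    """Return the last number in the text, falling back to the last non-empty line."""
    # Try to extract the last number (int/float), allowing thousands separators
    matches = re.findall(r"[-+]?\d+(?:,\d{3})*(?:\.\d+)?|\d+\.\d+", str(text))
    if matches:
        return matches[-1].replace(",", "").strip()
    # Fallback: last non-empty line
    lines = [line.strip() for line in str(text).splitlines() if line.strip()]
    return lines[-1] if lines else str(text).strip()
```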
45 changes: 2 additions & 43 deletions mas_arena/evaluators/bbh_evaluator.py
@@ -14,6 +14,7 @@

from mas_arena.evaluators.base_evaluator import BaseEvaluator
from mas_arena.evaluators.registry import register_benchmark
from mas_arena.evaluators.utils import extract_answer_generic


@register_benchmark(
@@ -56,49 +57,7 @@ def extract_answer(self, text: str) -> str:
Returns:
The extracted answer (e.g., "(A)", "True", "] >")
"""
text = text.strip()

# Primary pattern: Content within <answer>...</answer> tags
tag_pattern = r"<answer>\s*([\s\S]*?)\s*</answer>"
match = re.search(tag_pattern, text, re.IGNORECASE)
if match:
return match.group(1).strip()

# Fallback: "Final Answer: <answer>"
final_answer_pattern = r"Final Answer:\s*(.+)"
match = re.search(final_answer_pattern, text, re.DOTALL)
if match:
return match.group(1).strip()

# Fallback: Look for multiple-choice options (e.g., (A), A, [A])
option_pattern = r"\([A-Z]\)|[A-Z]\b|\[[A-Z]\]"
matches = re.findall(option_pattern, text, re.DOTALL)
if matches:
last_match = matches[-1]
# Normalize to (A) format
if not last_match.startswith("("):
last_match = f"({last_match[-1]})"
return last_match.strip()

# Fallback: Look for boolean values
boolean_pattern = r"\b(True|False)\b"
boolean_matches = re.findall(boolean_pattern, text, re.DOTALL)
if boolean_matches:
return boolean_matches[-1].strip()

# Fallback: Look for sequence completions (e.g., "> ) }", "] ] ]")
sequence_pattern = r"([>\]\}\)\[]+\s*)+"
sequence_matches = re.findall(sequence_pattern, text, re.DOTALL)
if sequence_matches:
return sequence_matches[-1].strip()

# Fallback: Last non-empty line
lines = [line.strip() for line in text.split("\n") if line.strip()]
if lines:
return lines[-1]

# Final fallback: Return stripped text
return text.strip()
return extract_answer_generic(text)

def normalize_answer(self, answer: str) -> str:
"""
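
As with the AIME change, `extract_answer_generic` itself is defined outside this diff. A sketch reconstructed from the BBH fallback cascade it replaces (the actual helper may differ in detail):

```python
import re


def extract_answer_generic(text: str) -> str:
    """Fallback cascade: <answer> tags, 'Final Answer:', choice letters,
    booleans, bracket/sequence completions, then the last non-empty line."""
    text = text.strip()

    # Primary pattern: content within <answer>...</answer> tags
    match = re.search(r"<answer>\s*([\s\S]*?)\s*</answer>", text, re.IGNORECASE)
    if match:
        return match.group(1).strip()

    # Fallback: "Final Answer: <answer>"
    match = re.search(r"Final Answer:\s*(.+)", text, re.DOTALL)
    if match:
        return match.group(1).strip()

    # Fallback: multiple-choice options (e.g., (A), A, [A]), normalized to (A)
    matches = re.findall(r"\([A-Z]\)|[A-Z]\b|\[[A-Z]\]", text, re.DOTALL)
    if matches:
        last_match = matches[-1]
        if not last_match.startswith("("):
            last_match = f"({last_match[-1]})"
        return last_match.strip()

    # Fallback: boolean values
    matches = re.findall(r"\b(True|False)\b", text, re.DOTALL)
    if matches:
        return matches[-1].strip()

    # Fallback: sequence completions (e.g., "> ) }", "] ] ]")
    matches = re.findall(r"([>\]\}\)\[]+\s*)+", text, re.DOTALL)
    if matches:
        return matches[-1].strip()

    # Final fallback: last non-empty line, else the stripped text
    lines = [line.strip() for line in text.split("\n") if line.strip()]
    return lines[-1] if lines else text.strip()
```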
36 changes: 4 additions & 32 deletions mas_arena/evaluators/drop_evaluator.py
@@ -8,17 +8,16 @@
from __future__ import annotations

import re
import string
import time
from pathlib import Path
from collections import Counter
from typing import Dict, Any, List
from typing import Dict, Any

from langsmith.evaluation import RunEvaluator
from langsmith.schemas import Run

from mas_arena.evaluators.base_evaluator import BaseEvaluator
from mas_arena.evaluators.registry import register_benchmark
from mas_arena.evaluators.utils import calculate_f1_score, normalize_answer

_ANS_TAG_RE = re.compile(r"<answer>\s*([\s\S]*?)\s*</answer>", re.IGNORECASE)
_FINAL_RE = re.compile(r"(?:^|\n)\s*(?:final\s+answer|answer)\s*[:\-]?\s*([\s\S]+)", re.IGNORECASE)
@@ -73,38 +72,11 @@ def _extract_answer(self, raw: Any) -> str:
@staticmethod
def _normalize(s: Any) -> str:
"""DROP normalization: lowercase -> remove articles/punctuation -> collapse whitespace."""
s = str(s)

def remove_articles(t: str) -> str:
return re.sub(r"\b(a|an|the)\b", " ", t)

def white_space_fix(t: str) -> str:
return " ".join(t.split())

def remove_punc(t: str) -> str:
return "".join(ch for ch in t if ch not in string.punctuation)

return white_space_fix(remove_articles(remove_punc(s.lower())))

return normalize_answer(s)

def _f1(self, gold: str, pred: str) -> float:
"""Calculates token-level F1 score (AllenNLP-style)."""
gold_toks: List[str] = self._normalize(gold).split()
pred_toks: List[str] = self._normalize(pred).split()

if not gold_toks and not pred_toks:
return 1.0
if not gold_toks or not pred_toks:
return 0.0

common = Counter(gold_toks) & Counter(pred_toks)
num_same = sum(common.values())
if num_same == 0:
return 0.0

precision = num_same / len(pred_toks)
recall = num_same / len(gold_toks)
return 2 * precision * recall / (precision + recall)
return calculate_f1_score(gold, pred, self._normalize)


def _make_run(
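
The `normalize_answer` and `calculate_f1_score` helpers are likewise defined outside this diff. Based on the DROP code they replace, and on the fact that DROP passes its own normalizer as a third argument while the HotpotQA evaluator below calls `calculate_f1_score` with only two, a plausible sketch (the `normalize_fn` parameter name is an assumption):

```python
import re
import string
from collections import Counter
from typing import Callable, Optional


def normalize_answer(s: str) -> str:
    """Lowercase, remove punctuation and articles, collapse whitespace (SQuAD/DROP-style)."""
    s = str(s).lower()
    s = "".join(ch for ch in s if ch not in string.punctuation)
    s = re.sub(r"\b(a|an|the)\b", " ", s)
    return " ".join(s.split())


def calculate_f1_score(
    gold: str,
    pred: str,
    normalize_fn: Optional[Callable[[str], str]] = None,  # assumed optional normalizer hook
) -> float:
    """Token-level F1 between gold and prediction after normalization."""
    normalize_fn = normalize_fn or normalize_answer
    gold_toks = normalize_fn(gold).split()
    pred_toks = normalize_fn(pred).split()

    if not gold_toks and not pred_toks:
        return 1.0
    if not gold_toks or not pred_toks:
        return 0.0

    common = Counter(gold_toks) & Counter(pred_toks)
    num_same = sum(common.values())
    if num_same == 0:
        return 0.0

    precision = num_same / len(pred_toks)
    recall = num_same / len(gold_toks)
    return 2 * precision * recall / (precision + recall)
```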
68 changes: 29 additions & 39 deletions mas_arena/evaluators/hotpotqa_evaluator.py
@@ -4,24 +4,23 @@
This module provides a standalone evaluator for HotpotQA (Multi-hop Question Answering) problems.
"""

import re
import string
import time
from collections import Counter
from typing import Dict, Any, Tuple

from langsmith.evaluation import RunEvaluator
from langsmith.schemas import Run

from mas_arena.evaluators.base_evaluator import BaseEvaluator
from mas_arena.evaluators.registry import register_benchmark
from mas_arena.evaluators.utils import extract_answer_generic, calculate_f1_score, normalize_answer


@register_benchmark(
name="hotpotqa",
normalization_keys={
"id": "id",
"id": "_id",
"problem": "question",
"context": "context",
"solution": "answer",
}
)
@@ -35,6 +34,18 @@ def __init__(self, name: str = "hotpotqa", config: Dict[str, Any] = None):
@classmethod
def from_config(cls, name: str, config: Dict[str, Any] = None):
return cls(name, config)

def extract_answer(self, text: str) -> str:
"""
Extract the answer from model output text, expecting '<answer>...</answer>' tags first.

Args:
text: The model's output text

Returns:
The extracted answer (typically an entity name, date, or short phrase)
"""
return extract_answer_generic(text)

def normalize_answer(self, s: str) -> str:
"""
@@ -46,20 +57,7 @@ def normalize_answer(self, s: str) -> str:
Returns:
Normalized answer string
"""
def remove_articles(text):
return re.sub(r"\b(a|an|the)\b", " ", text)

def white_space_fix(text):
return " ".join(text.split())

def remove_punc(text):
exclude = set(string.punctuation)
return "".join(ch for ch in text if ch not in exclude)

def lower(text):
return text.lower()

return white_space_fix(remove_articles(remove_punc(lower(s))))
return normalize_answer(s)

def calculate_score(self, ground_truth: str, prediction: str) -> Tuple[float, str]:
"""
@@ -70,22 +68,11 @@ def calculate_score(self, ground_truth: str, prediction: str) -> Tuple[float, str]:
prediction: The predicted answer

Returns:
Tuple of (f1_score, prediction)
Tuple of (f1_score, extracted_answer)
"""
prediction_tokens = self.normalize_answer(prediction).split()
ground_truth_tokens = self.normalize_answer(ground_truth).split()

common = Counter(prediction_tokens) & Counter(ground_truth_tokens)
num_same = sum(common.values())

if num_same == 0:
return 0, prediction

precision = 1.0 * num_same / len(prediction_tokens)
recall = 1.0 * num_same / len(ground_truth_tokens)
f1 = (2 * precision * recall) / (precision + recall)

return f1, prediction
extracted_answer = self.extract_answer(prediction)
f1_score = calculate_f1_score(ground_truth, extracted_answer)
return f1_score, extracted_answer

def create_run(self, problem: Dict[str, Any], final_answer: str, extracted_answer: str, score: float) -> Run:
"""Create a LangSmith run for evaluation"""
@@ -95,13 +82,13 @@ def create_run(self, problem: Dict[str, Any], final_answer: str, extracted_answer: str, score: float) -> Run:
id=str(uuid.uuid4()),
name=f"{self.name.upper()}_Evaluation",
inputs={
"question": problem["question"],
"question": problem["problem"],
"context": problem["context"]
},
outputs={
"prediction": final_answer,
"extracted_answer": extracted_answer,
"expected": problem["answer"],
"expected": problem["solution"],
"score": score,
"passed": score >= 0.3, # HotpotQA uses 0.3 as threshold
},
@@ -129,7 +116,7 @@ def evaluate(self, problem: Dict[str, Any], run_result: Dict[str, Any]) -> Dict[str, Any]:
context_str = "\n".join(" ".join(paragraph) for paragraph in paragraphs)

# Calculate score
score, extracted_answer = self.calculate_score(problem["answer"], final_answer)
score, extracted_answer = self.calculate_score(problem["solution"], final_answer)

# # Create LangSmith run
# run = self.create_run(problem, final_answer, extracted_answer, score)
@@ -138,15 +125,18 @@ def evaluate(self, problem: Dict[str, Any], run_result: Dict[str, Any]) -> Dict[str, Any]:
# Log mismatch if score is too low
if score < 0.3:
with open(f"{self.log_path}/mismatches.log", "a") as f:
f.write(f"\nQuestion: {problem['question']}\n")
f.write(f"\nQuestion: {problem['problem']}\n")
f.write(f"Context: {context_str}\n")
f.write(f"Expected: {problem['answer']}\n")
f.write(f"Expected: {problem['solution']}\n")
f.write(f"Predicted: {final_answer}\n")
f.write(f"Score: {score}\n")

# Final score: 1.0 if score >= 0.3, else use the score directly
final_score = 1 if score >= 0.3 else score

return {
"final_answer": final_answer,
"extracted_answer": extracted_answer,
"score": score,
"score": final_score,
"context": context_str
}
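
To make the 0.3 threshold concrete, the evaluator's extract → F1 → threshold flow amounts to something like the illustrative snippet below (the sample strings are made up; the helpers are the shared utils imported at the top of this file):

```python
from mas_arena.evaluators.utils import extract_answer_generic, calculate_f1_score

raw_output = "From document X I can see ...\nTherefore ...\n<answer>Ed Wood</answer>"
gold = "Ed Wood"

extracted = extract_answer_generic(raw_output)   # -> "Ed Wood"
f1 = calculate_f1_score(gold, extracted)         # -> 1.0
final_score = 1 if f1 >= 0.3 else f1             # threshold applied in evaluate()
```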
9 changes: 2 additions & 7 deletions mas_arena/evaluators/mmlu_pro_evaluator.py
@@ -12,6 +12,7 @@

from mas_arena.evaluators.base_evaluator import BaseEvaluator
from mas_arena.evaluators.registry import register_benchmark
from mas_arena.evaluators.utils import extract_answer_simple_tags


@register_benchmark(
@@ -134,13 +135,7 @@ def extract_answer_from_response(self, response: str) -> str:
Returns:
Extracted answer letter
"""
# Try to extract answer from <answer> tags, allowing for whitespace
match = re.search(r'<answer>\s*(.*?)\s*</answer>', response, re.DOTALL)
if match:
return match.group(1).strip()

# If no tags found, return original response
return response.strip()
return extract_answer_simple_tags(response)

def evaluate(self, problem: Dict[str, Any], run_result: Dict[str, Any]) -> Dict[str, Any]:
"""
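
Again, `extract_answer_simple_tags` is defined elsewhere; judging from the MMLU-Pro code it replaces, it is presumably close to:

```python
import re


def extract_answer_simple_tags(response: str) -> str:
    """Return the content of <answer>...</answer> tags, else the stripped response."""
    match = re.search(r"<answer>\s*(.*?)\s*</answer>", response, re.DOTALL)
    if match:
        return match.group(1).strip()
    return response.strip()
```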
24 changes: 23 additions & 1 deletion mas_arena/evaluators/utils/__init__.py
@@ -3,5 +3,27 @@
"""

from .sanitize import sanitize
from .answer_extraction import (
extract_answer_generic,
extract_answer_numeric,
extract_answer_simple_tags,
extract_answer # backward compatibility
)
from .metrics import (
normalize_answer,
calculate_f1_score,
calculate_exact_match,
calculate_multi_answer_f1
)

__all__ = ["sanitize"]
__all__ = [
"sanitize",
"extract_answer_generic",
"extract_answer_numeric",
"extract_answer_simple_tags",
"extract_answer",
"normalize_answer",
"calculate_f1_score",
"calculate_exact_match",
"calculate_multi_answer_f1"
]
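
The bodies of `calculate_exact_match` and `calculate_multi_answer_f1` do not appear anywhere in this diff; as a rough guess at their contracts based only on the exported names and the other metrics here, something like the following, where the signatures and behavior are assumptions:

```python
from typing import Iterable


def calculate_exact_match(gold: str, pred: str) -> float:
    """Assumed behavior: 1.0 if the normalized strings match exactly, else 0.0."""
    return float(normalize_answer(gold) == normalize_answer(pred))


def calculate_multi_answer_f1(golds: Iterable[str], pred: str) -> float:
    """Assumed behavior: best token-level F1 of the prediction against any gold answer."""
    return max((calculate_f1_score(g, pred) for g in golds), default=0.0)
```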