5 changes: 5 additions & 0 deletions src/gimbench/arguments.py
@@ -113,6 +113,11 @@ def _add_mcqa_eval_args(parser):
         default=0,
         help="Number of reasoning steps to include in the prompt",
     )
+    parser.add_argument(
+        "--auto_budget",
+        action="store_true",
+        help="Automatically determine the reasoning budget (overrides --reason_budget if both are set)",
+    )


 def validate_and_standardize(args: argparse.Namespace) -> argparse.Namespace:
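A note for reviewers: the new switch is a plain argparse `store_true` flag, and the "overrides --reason_budget" behavior is enforced in the evaluator rather than in the parser. Below is a minimal, self-contained sketch of that precedence; the parser is a reconstruction for illustration only (the real flags live in `_add_mcqa_eval_args`, and `--reason_budget` being an `int` option is an assumption).

```python
# Illustrative reconstruction only -- not the actual gimbench parser.
import argparse

parser = argparse.ArgumentParser()
parser.add_argument("--reason_budget", type=int, default=0)  # assumed to be an int flag
parser.add_argument("--auto_budget", action="store_true")    # new flag added in this PR

args = parser.parse_args(["--reason_budget", "4", "--auto_budget"])

# argparse keeps both values; the evaluator's _get_reason_budget is what
# ignores --reason_budget whenever --auto_budget is set.
print(args.reason_budget, args.auto_budget)  # -> 4 True
```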
53 changes: 45 additions & 8 deletions src/gimbench/mcqa/evaluators.py
@@ -33,6 +33,7 @@ class EvalItemResult(BaseModel):
     response_tokens: int = -1
     query_len: int = -1
     response_len: int = -1
+    reason_budget: int = -1

     error_msg: str = ""
     additional_info: dict = {}
@@ -52,6 +53,7 @@ class EvalResult(BaseEvalResult):
     avg_response_tokens: float
     avg_query_len: float
     avg_response_len: float
+    avg_reason_budget: float

     evaled_items: list[EvalItemResult]

@@ -64,7 +66,10 @@ def __init__(self, args: Namespace, dataset: Dataset):
         logger.info(f"Loaded tokenizer {args.counter_tokenizer} for token counting.")

     @abstractmethod
-    def _form_cot_query(self, question: str, choices: list[str]) -> str: ...
+    def _get_reason_budget(self, question: str) -> int: ...
+
+    @abstractmethod
+    def _form_cot_query(self, question: str, choices: list[str], *args) -> str: ...

     @abstractmethod
     def _model_call(self, query: str) -> Any: ...
@@ -80,9 +85,15 @@ def _evaluate_item(self, item: dict) -> EvalItemResult:
             item["choices"],
             item["correct_choice"],
         )
-        query = self._form_cot_query(question, choices)
         try:
-            raw_response = self._model_call(query)
+            if self.args.no_gimkit:
+                reason_budget = -1
+                query = self._form_cot_query(question, choices)
+                raw_response = self._model_call(query)
+            else:
+                reason_budget = self._get_reason_budget(question)
+                query = self._form_cot_query(question, choices, reason_budget)
+                raw_response = self._model_call(query)
             response, model_choice, additional_info = self._parse_response(raw_response, choices)
             conclusion = model_choice == correct_choice
             error_msg = ""
@@ -103,6 +114,7 @@ def _evaluate_item(self, item: dict) -> EvalItemResult:
             response_tokens=self._count_tokens(response) if response != "ERROR" else -1,
             query_len=len(query),
             response_len=len(response),
+            reason_budget=reason_budget,
             error_msg=error_msg,
             additional_info=additional_info,
         )
@@ -141,6 +153,7 @@ def evaluate(self) -> EvalResult:
             avg_response_tokens=self._safe_average(evaled_items, "response_tokens"),
             avg_query_len=self._safe_average(evaled_items, "query_len"),
             avg_response_len=self._safe_average(evaled_items, "response_len"),
+            avg_reason_budget=self._safe_average(evaled_items, "reason_budget"),
             start_time=self.start_time,
             end_time=self.end_time,
             elapsed_minutes=(self.end_time - self.start_time).total_seconds() / 60.0,
@@ -164,13 +177,34 @@ def __init__(self, args: Namespace, dataset: Dataset):
         super().__init__(args, dataset)
         self.model = SimpleGIM(args)

-    def _form_cot_query(self, question: str, choices: list[str]) -> str:
+    def _get_reason_budget(self, question: str) -> int:
+        if self.args.auto_budget:
+            try:
+                r = self.model.generate(
+                    f"I'll show you a question. "
+                    f"You need to determine how many reasoning steps are required to accurately answer it.\n\n"
+                    f"## Question: Find the sum of first 5 positive integers.\n\n"
+                    f"## Reasoning steps: 2\n\n"
+                    f"## Question: {question}\n\n"
+                    f"## Reasoning steps: "
+                    + guide(name="reason_budget", desc="A positive integer number", regex=r"\d+")
+                )
+                budget = int(r.tags["reason_budget"].content or "1")
+            except Exception as e:
+                logger.warning(f"Auto-budget determination failed: {e}")
+                budget = 1
+            reason_budget = max(1, budget)
+            logger.info(f"Auto-determined reasoning budget: {reason_budget}")
+        else:
+            reason_budget = self.args.reason_budget
+        return reason_budget
+
+    def _form_cot_query(self, question: str, choices: list[str], reason_budget: int) -> str:
         reasoning_guides = [
-            f"## Step {idx + 1}\n\n" + guide(desc="One thinking step. About 60 words")
-            for idx in range(self.args.reason_budget)
+            f"## Step {idx + 1}\n\n" + guide(desc="One thinking step. About 60 words") for idx in range(reason_budget)
         ]
         prompt = SHARED_PROMPT_PREFIX + f"\n\nQuestion: {question}\n\n"
-        if self.args.reason_budget > 0:
+        if reason_budget > 0:
             prompt += "Let's think step by step.\n\n" + "\n\n".join(reasoning_guides) + "\n\n"
         prompt += "## Conclusion\n\nFinal answer: " + guide.select(choices=choices, name="predicted_choice")
         return prompt
@@ -193,7 +227,10 @@ def __init__(self, args: Namespace, dataset: Dataset):
         super().__init__(args, dataset)
         self.model = OpenAI(api_key=args.api_key, base_url=args.base_url)

-    def _form_cot_query(self, question: str, choices: list[str]) -> str:
+    def _get_reason_budget(self, question: str) -> int:
+        raise NotImplementedError("CommonEvaluator does not support reason budget.")
+
+    def _form_cot_query(self, question: str, choices: list[str], *args) -> str:
         prompt = SHARED_PROMPT_PREFIX + (
             " Remember to end with `The answer is: xxx`.\n\n"
             f"Question: {question}\n\n"
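To make the auto-budget logic easier to reason about outside gimkit, here is a runnable sketch of the same prompt-and-clamp behavior with the model call stubbed out. `estimate_budget` and the `model` callable are hypothetical stand-ins for `SimpleGIM.generate`, and the regex extraction mirrors the `guide(..., regex=r"\d+")` constraint used in the diff above.

```python
# Standalone sketch of the auto-budget behavior; the model is stubbed with a callable.
import re

FEW_SHOT = (
    "I'll show you a question. "
    "You need to determine how many reasoning steps are required to accurately answer it.\n\n"
    "## Question: Find the sum of first 5 positive integers.\n\n"
    "## Reasoning steps: 2\n\n"
)

def estimate_budget(question: str, model) -> int:
    prompt = FEW_SHOT + f"## Question: {question}\n\n## Reasoning steps: "
    try:
        # `model` stands in for SimpleGIM.generate; the regex mirrors the guide constraint.
        match = re.search(r"\d+", model(prompt) or "")
        budget = int(match.group()) if match else 1
    except Exception:
        budget = 1
    return max(1, budget)  # same floor as the PR: always at least one reasoning step

print(estimate_budget("What is 17 * 24?", lambda p: "3"))            # -> 3
print(estimate_budget("Name the capital of France.", lambda p: ""))  # -> 1
```

The fallback path matters here: any parse or generation failure degrades to a budget of 1 rather than aborting the item, which matches how the PR's `_get_reason_budget` catches exceptions and logs a warning.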