diff --git a/trinity/common/workflows/__init__.py b/trinity/common/workflows/__init__.py
index ea7390b4a4..ad38cb935c 100644
--- a/trinity/common/workflows/__init__.py
+++ b/trinity/common/workflows/__init__.py
@@ -48,6 +48,8 @@
         # on-policy distillation workflows
         "on_policy_distill_workflow": "trinity.common.workflows.on_policy_distill_workflow.OnPolicyDistillWorkflow",
         "on_policy_distill_math_workflow": "trinity.common.workflows.on_policy_distill_workflow.OnPolicyDistillMathWorkflow",
+        # custom workflows
+        "sudoku_workflow": "trinity.common.workflows.sudoku_workflow.SudokuWorkflow",
     },
 )
 
diff --git a/trinity/common/workflows/sudoku_generator.py b/trinity/common/workflows/sudoku_generator.py
new file mode 100644
index 0000000000..bcca7e12cc
--- /dev/null
+++ b/trinity/common/workflows/sudoku_generator.py
@@ -0,0 +1,104 @@
+import random
+
+
+class SudokuGenerator:
+    """Generate random Sudoku puzzles with a randomized backtracking search.
+
+    A full 9x9 grid is first filled by backtracking over shuffled candidate
+    digits, then a difficulty-dependent number of cells is blanked out (set
+    to 0).
+
+    Note:
+        Cells are removed without checking that the puzzle keeps a unique
+        solution, so the returned ``solution`` is one valid completion of
+        the puzzle but not necessarily the only one.
+    """
+
+    # Number of cells blanked out for each named difficulty level.
+    HOLES_BY_DIFFICULTY = {
+        "easy": 30,
+        "medium": 40,
+        "hard": 50,
+    }
+    # Used when ``difficulty`` is not a known level name.
+    DEFAULT_HOLES = 40
+
+    def generate(self, difficulty="medium"):
+        """Create a new puzzle together with the grid it was carved from.
+
+        Args:
+            difficulty: ``"easy"``, ``"medium"`` or ``"hard"``. Any other
+                value falls back to the medium hole count.
+
+        Returns:
+            A ``(puzzle, solution)`` tuple of 9x9 lists of ints, where
+            ``puzzle`` uses 0 for empty cells and ``solution`` is the fully
+            filled grid.
+        """
+        holes = self.HOLES_BY_DIFFICULTY.get(difficulty, self.DEFAULT_HOLES)
+
+        board = [[0] * 9 for _ in range(9)]
+        self._fill_board(board)
+
+        # Snapshot the solved grid before any cell is blanked out.
+        solution = [row[:] for row in board]
+        self._remove_cells(board, holes)
+
+        return board, solution
+
+    def _fill_board(self, board):
+        """Fill every empty cell of ``board`` in place by backtracking.
+
+        Candidate digits are tried in random order so each call can yield a
+        different grid. Returns True once the board is full, False if the
+        current partial assignment cannot be completed.
+        """
+        empty = self._find_empty(board)
+        if not empty:
+            return True
+
+        r, c = empty
+        nums = list(range(1, 10))
+        random.shuffle(nums)
+
+        for v in nums:
+            if self._is_valid(board, r, c, v):
+                board[r][c] = v
+                if self._fill_board(board):
+                    return True
+                # Dead end: undo the placement and try the next digit.
+                board[r][c] = 0
+
+        return False
+
+    def _find_empty(self, board):
+        """Return the first empty ``(row, col)`` in row-major order, or None."""
+        for i in range(9):
+            for j in range(9):
+                if board[i][j] == 0:
+                    return i, j
+        return None
+
+    def _is_valid(self, board, r, c, v):
+        """Return True if ``v`` at ``(r, c)`` clashes with no row/column/block."""
+        if v in board[r]:
+            return False
+
+        if any(board[i][c] == v for i in range(9)):
+            return False
+
+        br, bc = (r // 3) * 3, (c // 3) * 3
+        return all(
+            board[i][j] != v
+            for i in range(br, br + 3)
+            for j in range(bc, bc + 3)
+        )
+
+    def _remove_cells(self, board, holes):
+        """Blank out ``holes`` randomly chosen cells of ``board`` in place."""
+        cells = [(i, j) for i in range(9) for j in range(9)]
+        random.shuffle(cells)
+
+        # Clamp to the grid size; a negative count removes nothing.
+        for r, c in cells[: max(0, min(holes, 81))]:
+            board[r][c] = 0
diff --git a/trinity/common/workflows/sudoku_judge.py b/trinity/common/workflows/sudoku_judge.py
new file mode 100644
index 0000000000..9fee423710
--- /dev/null
+++ b/trinity/common/workflows/sudoku_judge.py
@@ -0,0 +1,63 @@
+class SudokuJudge:
+    """Check Sudoku board states.
+
+    Boards are 9x9 lists of ints in which 0 marks an empty cell.
+    """
+
+    @staticmethod
+    def _has_duplicates(values):
+        """Return True if any non-zero value occurs more than once."""
+        filled = [v for v in values if v != 0]
+        return len(filled) != len(set(filled))
+
+    @staticmethod
+    def is_valid(board):
+        """Return True if no row, column or 3x3 block repeats a digit.
+
+        Empty cells (0) are ignored, so a partially filled board can be
+        valid.
+        """
+        # Check rows
+        if any(SudokuJudge._has_duplicates(row) for row in board):
+            return False
+
+        # Check columns
+        for col in range(9):
+            column = [board[row][col] for row in range(9)]
+            if SudokuJudge._has_duplicates(column):
+                return False
+
+        # Check 3x3 sub-grids
+        for br in range(0, 9, 3):
+            for bc in range(0, 9, 3):
+                block = [
+                    board[r][c]
+                    for r in range(br, br + 3)
+                    for c in range(bc, bc + 3)
+                ]
+                if SudokuJudge._has_duplicates(block):
+                    return False
+
+        return True
+
+    @staticmethod
+    def is_complete(board):
+        """Return True if every cell holds a digit from 1 to 9."""
+        return all(1 <= v <= 9 for row in board for v in row)
+
+    @staticmethod
+    def is_solved(board, solution=None):
+        """Return True if ``board`` is a finished, rule-abiding grid.
+
+        A puzzle can have several solutions, so a board also counts as
+        solved when it is completely filled and valid, even if it differs
+        from the reference ``solution``.
+
+        Args:
+            board: Current 9x9 board.
+            solution: Optional reference solution; an exact match is
+                accepted directly.
+        """
+        if solution is not None and board == solution:
+            return True
+        return SudokuJudge.is_complete(board) and SudokuJudge.is_valid(board)
diff --git a/trinity/common/workflows/sudoku_workflow.py b/trinity/common/workflows/sudoku_workflow.py
new file mode 100644
index 0000000000..e65604ca30
--- /dev/null
+++ b/trinity/common/workflows/sudoku_workflow.py
@@ -0,0 +1,174 @@
+from trinity.common.experience import Experience
+from trinity.common.workflows.workflow import Workflow
+
+from .sudoku_generator import SudokuGenerator
+from .sudoku_judge import SudokuJudge
+
+
+class SudokuWorkflow(Workflow):
+    """Multi-step Sudoku game in which the model fills in empty cells.
+
+    Every step shows the model the current board and asks for up to
+    ``max_moves_per_step`` moves. Each model response becomes one
+    ``Experience`` with reward ``1.0`` when the board is solved, ``-1.0``
+    when the response is malformed or breaks the rules (which also ends the
+    episode), and ``0.0`` otherwise.
+    """
+
+    can_reset = True
+
+    def __init__(self, task, model, auxiliary_models=None):
+        super().__init__(task=task, model=model, auxiliary_models=auxiliary_models)
+
+        self.judge = SudokuJudge()
+        self.max_steps = 20
+        self.max_moves_per_step = 5
+
+        self._load_puzzle(task)
+
+    def reset(self, task):
+        """Start a fresh episode for ``task``.
+
+        Like the constructor, this falls back to a generated puzzle when the
+        task does not carry ``puzzle`` and ``solution`` grids.
+        """
+        self.task = task
+        self._load_puzzle(task)
+
+    def _load_puzzle(self, task):
+        """Initialise the board from ``task`` and clear the episode state."""
+        raw_task = task.raw_task or {}
+        if "puzzle" in raw_task and "solution" in raw_task:
+            # Copy row by row so moves never mutate the task's own grids.
+            self.board = [row[:] for row in raw_task["puzzle"]]
+            self.solution = [row[:] for row in raw_task["solution"]]
+        else:
+            self.board, self.solution = SudokuGenerator().generate()
+
+        self.current_step = 0
+        self.last_board = None
+        self.last_action = None
+
+    def render_board(self):
+        """Return the board as nine lines of space-separated digits."""
+        return "\n".join(" ".join(str(v) for v in row) for row in self.board)
+
+    def _build_prompt(self):
+        """Build the user prompt for the current step."""
+        prompt = (
+            "You are playing a Sudoku game.\n\n"
+            "Rules:\n"
+            "- The board is 9x9.\n"
+            "- 0 means empty.\n"
+            "- Numbers 1–9 must appear exactly once in every row, column, and 3x3 block.\n"
+            "- You may only fill empty cells.\n\n"
+            "Task:\n"
+            "- In each step, output ONE OR MORE valid moves.\n"
+            f"- You may output up to {self.max_moves_per_step} moves per step.\n\n"
+            "Output format (STRICT):\n"
+            "row col value\n"
+            "row col value\n\n"
+            "- row and col are 0-indexed (0-8); value is 1-9.\n\n"
+            "Example:\n"
+            "0 2 4\n"
+            "1 3 5\n\n"
+            f"Current step: {self.current_step}\n"
+            f"Remaining steps: {self.max_steps - self.current_step}\n\n"
+            f"Current board:\n{self.render_board()}\n"
+        )
+
+        # NOTE: run() ends the episode on the first invalid response, so this
+        # hint is only reached when the steps are driven some other way.
+        if self.last_board is not None and self.board == self.last_board:
+            prompt += (
+                "\nYour previous response was invalid or had no effect. "
+                "Please follow the rules and output format strictly."
+            )
+
+        return prompt
+
+    def parse_action(self, text):
+        """Parse a model response into a list of ``(row, col, value)`` moves.
+
+        Every non-blank line must hold exactly three integers, with ``row``
+        and ``col`` in 0-8 and ``value`` in 1-9.
+
+        Returns:
+            The list of moves, or ``None`` if ``text`` is empty or missing,
+            any line is malformed, or the number of moves exceeds
+            ``max_moves_per_step``.
+        """
+        if not text:
+            return None
+
+        actions = []
+        for line in text.strip().splitlines():
+            parts = line.split()
+            if not parts:
+                continue
+            if len(parts) != 3:
+                return None
+            try:
+                r, c, v = map(int, parts)
+            except ValueError:
+                return None
+            if not (0 <= r <= 8 and 0 <= c <= 8 and 1 <= v <= 9):
+                return None
+            actions.append((r, c, v))
+
+        if not actions or len(actions) > self.max_moves_per_step:
+            return None
+
+        return actions
+
+    def _apply_actions(self, actions):
+        """Write ``actions`` to the board; return False if a rule is broken.
+
+        A response is rejected when a move targets a non-empty cell or when
+        the board contains a duplicate after all moves. Moves made before a
+        rejected one stay on the board; ``run`` ends the episode then.
+        """
+        for r, c, v in actions:
+            if self.board[r][c] != 0:
+                return False
+            self.board[r][c] = v
+        return self.judge.is_valid(self.board)
+
+    def _make_experience(self, resp, reward):
+        """Wrap one model response and its reward into an ``Experience``."""
+        return Experience(
+            tokens=resp.tokens,
+            prompt_length=resp.prompt_length,
+            reward=reward,
+            logprobs=resp.logprobs,
+        )
+
+    def run(self):
+        """Play one episode and return one ``Experience`` per model response."""
+        experiences = []
+
+        for _ in range(self.max_steps):
+            prompt = self._build_prompt()
+            responses = self.model.chat([{"role": "user", "content": prompt}])
+            if not responses:
+                # Nothing to score; stop rather than fail on responses[0].
+                break
+            resp = responses[0]
+
+            # Snapshot of the board as the model saw it for this step.
+            self.last_board = [row[:] for row in self.board]
+
+            actions = self.parse_action(resp.response_text)
+            if actions is None or not self._apply_actions(actions):
+                experiences.append(self._make_experience(resp, -1.0))
+                break
+
+            if self.judge.is_solved(self.board, self.solution):
+                experiences.append(self._make_experience(resp, 1.0))
+                break
+
+            experiences.append(self._make_experience(resp, 0.0))
+            self.last_action = actions
+            self.current_step += 1
+
+        return experiences