Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions trinity/common/workflows/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,8 @@
# on-policy distillation workflows
"on_policy_distill_workflow": "trinity.common.workflows.on_policy_distill_workflow.OnPolicyDistillWorkflow",
"on_policy_distill_math_workflow": "trinity.common.workflows.on_policy_distill_workflow.OnPolicyDistillMathWorkflow",
# custom workflows
"sudoku_workflow": "trinity.common.workflows.sudoku_workflow.SudokuWorkflow",
},
)

Expand Down
75 changes: 75 additions & 0 deletions trinity/common/workflows/sudoku_generator.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,75 @@
import random


class SudokuGenerator:
"""
Sudoku puzzle generator inspired by standard backtracking-based generators.

- Generates a fresh solved Sudoku board using backtracking
- Removes cells based on difficulty (number of empty cells)
- Avoids relying on a single canonical solution
"""

def generate(self, difficulty="medium"):
holes_map = {
"easy": 30,
"medium": 40,
"hard": 50,
}
holes = holes_map.get(difficulty, 40)

board = [[0 for _ in range(9)] for _ in range(9)]
self._fill_board(board)

solution = [row[:] for row in board]
self._remove_cells(board, holes)

return board, solution

def _fill_board(self, board):
empty = self._find_empty(board)
if not empty:
return True

r, c = empty
nums = list(range(1, 10))
random.shuffle(nums)

for v in nums:
if self._is_valid(board, r, c, v):
board[r][c] = v
if self._fill_board(board):
return True
board[r][c] = 0

return False

def _find_empty(self, board):
for i in range(9):
for j in range(9):
if board[i][j] == 0:
return i, j
return None

def _is_valid(self, board, r, c, v):
if v in board[r]:
return False

if v in [board[i][c] for i in range(9)]:
return False

br, bc = (r // 3) * 3, (c // 3) * 3
for i in range(br, br + 3):
for j in range(bc, bc + 3):
if board[i][j] == v:
return False

return True

def _remove_cells(self, board, holes):
cells = [(i, j) for i in range(9) for j in range(9)]
random.shuffle(cells)

for i in range(min(holes, 81)):
r, c = cells[i]
board[r][c] = 0
43 changes: 43 additions & 0 deletions trinity/common/workflows/sudoku_judge.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,43 @@
class SudokuJudge:
"""
Judge Sudoku board state.
- Checks row validity
- Checks column validity
- Checks 3x3 block validity
"""

@staticmethod
def is_valid(board):
# Check rows
for row in board:
nums = [v for v in row if v != 0]
if len(nums) != len(set(nums)):
return False

# Check columns
for col in range(9):
nums = []
for row in range(9):
v = board[row][col]
if v != 0:
nums.append(v)
if len(nums) != len(set(nums)):
return False

# Check 3x3 sub-grids
for br in range(0, 9, 3):
for bc in range(0, 9, 3):
nums = []
for r in range(br, br + 3):
for c in range(bc, bc + 3):
v = board[r][c]
if v != 0:
nums.append(v)
if len(nums) != len(set(nums)):
return False

return True

@staticmethod
def is_solved(board, solution):
return board == solution
159 changes: 159 additions & 0 deletions trinity/common/workflows/sudoku_workflow.py
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks for your contribution. The Sudoku has some similarities to frozen lake.
And the current version has significant room for improvement.

A qualified Sudoku workflow should include three parts:

1.A Sudoku generator: Automatically generate solvable Sudoku puzzles and allow you to set the difficulty level.
2. An agentic workflow to solve the Sudoku: Some Sudoku is hard to solve in just one step, so an agentic workflow should be designed to solve the game in multiple steps.
3. A general judge function: Some Sudoku puzzles may have multiple possible solutions, the judge function should correctly parse the model's output and determine the correctness of the result according to the Sudoku rules, not just exactly match.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Hi @pan-x-c ,
Thanks for the detailed feedback earlier.

I’ve implemented all the requested changes:

  • Added a SudokuGenerator that produces solvable puzzles (with adjustable difficulty via hole count)
  • Reworked the workflow into a multi-step agentic loop, similar in structure to FrozenLakeWorkflow
  • Added a SudokuJudge that validates rows, columns, and 3×3 blocks instead of exact string matching
  • Integrated generator + judge inside the workflow
  • Updated workflow registry

Please have a look and let me know if you’d like further improvements or additional refinements.

Original file line number Diff line number Diff line change
@@ -0,0 +1,159 @@
from trinity.common.experience import Experience
from trinity.common.workflows.workflow import Workflow

from .sudoku_generator import SudokuGenerator
from .sudoku_judge import SudokuJudge


class SudokuWorkflow(Workflow):
can_reset = True

def __init__(self, task, model, auxiliary_models=None):
super().__init__(task=task, model=model, auxiliary_models=auxiliary_models)

if "puzzle" in task.raw_task and "solution" in task.raw_task:
self.board = [row[:] for row in task.raw_task["puzzle"]]
self.solution = [row[:] for row in task.raw_task["solution"]]
else:
generator = SudokuGenerator()
self.board, self.solution = generator.generate()

self.judge = SudokuJudge()
self.max_steps = 20
self.max_moves_per_step = 5

self.current_step = 0
self.last_board = None
self.last_action = None

def reset(self, task):
self.board = [row[:] for row in task.raw_task["puzzle"]]
self.solution = [row[:] for row in task.raw_task["solution"]]
self.current_step = 0
self.last_board = None
self.last_action = None

def render_board(self):
return "\n".join(" ".join(str(v) for v in row) for row in self.board)

def _build_prompt(self):
prompt = (
"You are playing a Sudoku game.\n\n"
"Rules:\n"
"- The board is 9x9.\n"
"- 0 means empty.\n"
"- Numbers 1–9 must appear exactly once in every row, column, and 3x3 block.\n"
"- You may only fill empty cells.\n\n"
"Task:\n"
"- In each step, output ONE OR MORE valid moves.\n"
f"- You may output up to {self.max_moves_per_step} moves per step.\n\n"
"Output format (STRICT):\n"
"row col value\n"
"row col value\n\n"
"Example:\n"
"0 2 4\n"
"1 3 5\n\n"
f"Current step: {self.current_step}\n"
f"Remaining steps: {self.max_steps - self.current_step}\n\n"
f"Current board:\n{self.render_board()}\n"
)

if self.last_board is not None and self.board == self.last_board:
prompt += (
"\nYour previous response was invalid or had no effect. "
"Please follow the rules and output format strictly."
)

return prompt

def parse_action(self, text):
lines = text.strip().splitlines()
actions = []

for line in lines:
line = line.strip()
if not line:
continue
parts = line.split()
if len(parts) != 3:
return None
try:
r, c, v = map(int, parts)
except ValueError:
return None
if not (0 <= r <= 8 and 0 <= c <= 8 and 1 <= v <= 9):
return None
actions.append((r, c, v))

if not actions or len(actions) > self.max_moves_per_step:
return None

return actions

def run(self):
experiences = []

for _ in range(self.max_steps):
prompt = self._build_prompt()
responses = self.model.chat([{"role": "user", "content": prompt}])
resp = responses[0]

self.last_board = [row[:] for row in self.board]

actions = self.parse_action(resp.response_text)
if actions is None:
experiences.append(
Experience(
tokens=resp.tokens,
prompt_length=resp.prompt_length,
reward=-1.0,
logprobs=resp.logprobs,
)
)
break

board_changed = False
invalid_move = False

for r, c, v in actions:
if self.board[r][c] != 0:
invalid_move = True
break
self.board[r][c] = v
board_changed = True

if invalid_move or not board_changed or not self.judge.is_valid(self.board):
experiences.append(
Experience(
tokens=resp.tokens,
prompt_length=resp.prompt_length,
reward=-1.0,
logprobs=resp.logprobs,
)
)
break

if self.judge.is_solved(self.board, self.solution):
experiences.append(
Experience(
tokens=resp.tokens,
prompt_length=resp.prompt_length,
reward=1.0,
logprobs=resp.logprobs,
)
)
break

experiences.append(
Experience(
tokens=resp.tokens,
prompt_length=resp.prompt_length,
reward=0.0,
logprobs=resp.logprobs,
)
)

self.last_action = actions
self.current_step += 1

return experiences