2 changes: 1 addition & 1 deletion opto/trace/projections/__init__.py
@@ -1,2 +1,2 @@
from opto.trace.projections.projections import Projection
from opto.trace.projections.code_projections import BlackCodeFormatter, DocstringProjection
from opto.trace.projections.code_projections import BlackCodeFormatter, DocstringProjection, SuggestionNormalizationProjection
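With this re-export in place, the new projection is importable from the package root (as the updated test below does), e.g.:

from opto.trace.projections import SuggestionNormalizationProjection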
54 changes: 53 additions & 1 deletion opto/trace/projections/code_projections.py
@@ -1,5 +1,7 @@

from opto.trace.projections import Projection
import re
import ast

class BlackCodeFormatter(Projection):
# This requires the `black` package to be installed.
@@ -28,4 +30,54 @@ def project(self, x: str) -> str:
x = f'{x[0]}"""{self.docstring}"""{x[2]}'
else:
x = f'{x[0]}"""{self.docstring}"""'
return x
return x

class SuggestionNormalizationProjection(Projection):
"""
Normalize LLM-generated suggestion dicts:
- Literal-eval strings to their true types
- Resolve common key aliases such as "__code:8" ↔ "__code8"
- Black-reformat any code snippets
"""
def __init__(self, parameters):
self.parameters = parameters

def project(self, suggestion: dict) -> dict:
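# requires the `black` package to be installed, as with BlackCodeFormatter above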
from black import format_str, FileMode
def _find_key(node_name: str):
# exact match
if node_name in suggestion:
return node_name
# strip a colon before digits ("__code:8" → "__code8")
norm = re.sub(r":(?=\d+$)", "", node_name)
for k in suggestion:
if re.sub(r":(?=\d+$)", "", k) == norm:
return k
return None

normalized: dict = {}
for node in self.parameters:
if not getattr(node, "trainable", False):
continue
key = _find_key(node.py_name)
if key is None:
continue

raw_val = suggestion[key]
# re-format any code suggestions with black
# (code parameter keys start with "__code")
if isinstance(raw_val, str) and key.startswith("__code"):
raw_val = format_str(raw_val, mode=FileMode())

# convert "123" → 123, "[1,2]" → [1,2], etc.
target_type = type(node.data)
if isinstance(raw_val, str) and target_type is not str:
try:
raw_val = target_type(ast.literal_eval(raw_val))
except Exception:
pass

# map by the parameter’s name, not the node itself
normalized[node.py_name] = raw_val

return normalized
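For review context, a minimal self-contained sketch of the two normalization steps above (key aliasing and literal-eval coercion); the SimpleNamespace node here is a hypothetical stand-in for a trainable parameter, not the real trace node class:

import ast
import re
from types import SimpleNamespace

# hypothetical stand-in for a trainable parameter node (not a real trace node)
node = SimpleNamespace(py_name="__lr0", trainable=True, data=0.0)
suggestion = {"__lr:0": "0.05"}

def norm(name: str) -> str:
    # same aliasing rule as _find_key: drop a colon right before trailing digits
    return re.sub(r":(?=\d+$)", "", name)

# key aliasing: "__lr:0" and "__lr0" normalize to the same name
key = next(k for k in suggestion if norm(k) == norm(node.py_name))

# literal-eval coercion: the string "0.05" becomes the float 0.05
value = type(node.data)(ast.literal_eval(suggestion[key]))
assert key == "__lr:0" and value == 0.05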
54 changes: 52 additions & 2 deletions tests/unit_tests/test_projection.py
@@ -1,4 +1,5 @@
from opto.trace.projections import BlackCodeFormatter, DocstringProjection
from opto.trace.projections import BlackCodeFormatter, DocstringProjection, SuggestionNormalizationProjection
from types import SimpleNamespace

def test_black_code_formatter():
code = """
@@ -35,4 +36,53 @@ def example_function():
assert formatted_code == new_code

# assert '"""This is a new docstring."""' in formatted_code
# assert 'print("Hello, World!")' in formatted_code
# assert 'print("Hello, World!")' in formatted_code

def test_suggestion_normalization_projection():
import re
import pytest
# Prepare a mock parameter list with various py_names, types, and trainable flags
params = [
# code param: key comes in as "__code:1", should alias to "__code1" and be black-formatted
SimpleNamespace(py_name="__code1", trainable=True, data=""),
# learning rate param: as float, but suggestion comes as a literal string
SimpleNamespace(py_name="__lr", trainable=True, data=0.0),
# should be skipped because not trainable
SimpleNamespace(py_name="__frozen", trainable=False, data=123),
# some other param, no suggestion provided
SimpleNamespace(py_name="__missing", trainable=True, data=1)
]

raw_suggestion = {
"__code:1": "def foo(x):return x*2", # needs black formatting
"__lr": "\"0.01\"", # needs literal‐eval → float
"__frozen": "999", # should be ignored
"unrelated": "[1,2,3]", # not in params
}

proj = SuggestionNormalizationProjection(params)
normalized = proj.project(raw_suggestion)

# It should only contain keys for trainable params that were suggested
assert set(normalized.keys()) == {"__code1", "__lr"}

# 1) __code1 should be black-formatted: 'def foo(x):' on one line, then an indented 'return x * 2'
code_out = normalized["__code1"]
# check that there's exactly one indent (4 spaces) before the return,
# and that black added a trailing newline
assert re.search(r"def foo\(x\):\n {4}return x \* 2\n$", code_out)

# 2) __lr should have been converted from the string "0.01" to float 0.01
assert isinstance(normalized["__lr"], float)
assert normalized["__lr"] == pytest.approx(0.01)

# 3) Non-trainable or missing params should not appear
assert "__frozen" not in normalized
assert "__missing" not in normalized

# --- literal-eval failure should be left unchanged ---
# If ast.literal_eval raises, the original string remains
params_bad = [SimpleNamespace(py_name="__bad", trainable=True, data=100)]
raw_suggestion_bad = {"__bad": "not_a_number"}
normalized_bad = SuggestionNormalizationProjection(params_bad).project(raw_suggestion_bad)
assert normalized_bad["__bad"] == "not_a_number"
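As a side note, a quick sanity sketch of the black output that the regex assertion above expects, assuming `black` is installed (the projection already requires it):

from black import format_str, FileMode

formatted = format_str("def foo(x):return x*2", mode=FileMode())
assert formatted == "def foo(x):\n    return x * 2\n"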