diff --git a/OVERVIEW.md b/OVERVIEW.md
index 947180c6..e53a8c19 100644
--- a/OVERVIEW.md
+++ b/OVERVIEW.md
@@ -1,32 +1,57 @@
 # Overview of Trace and Development Guide
-The library of Trace is designed to be a lightweight, modularized package to allow developers to easily try new ideas on generative optimization and integrate learning wtih their pipelines.
-
-Currently, the Trace library has three main modules collected under the `opto` top module.
-
-1. `opto.trace` provides the infrastructure for tracing computational workflows. It defines two primitives `trace.node` and `@trace.bundle`. They can be applied to Python objects and methods, respectively, which define the root nodes and operators of the directed acyclic graph (DAG) of computation. They both have a `trainable` flag. When set `True`, the wrapped objects are viewed as *parameters* of the computational worflow. Users can use `trace.node` and `@trace.bundle` to declare the data and computation that they wish to trace and/or adapt, and we call the resulting workflow defined by these two primitives a *traced* workflow. When running a traced workflow, a DAG will be automatiically created by Trace as a data structure, which will later be sent to optimizers in `opto.optimizers`for updates (upon calling `node.backward` with soem feedback).
-
-2. `opto.optimizers` has a collection of generative optimization algorithms, whose API is defined by an abstract class `Optimizer`. Think them like gradient algorithms. Their job is to propose a new version of the parameters (i.e. those set with `trainable=True`) when receiving a computational graph (DAG) and the feedback given to the computed output. Typically, these algorithms can be viewed as an LLM agent, which makes calls to LLM to analyze the computational graph and the feedback, and to propose updates. In Trace library, we provide implementation of several popular optimizers, such `OptoPrime`, `TextGrad`, and `OPRO`.
-
-3. `opto.trainers` are a collection of training algorithms (under the `AlgorithmBase` class) that use optimizers in `opto.optimizers` as subroutines to improve a given workflow following a feedback oracle constructed by datasets, interactive environments, etc. While `Optimizer` defines a low-level *optimization* API, `AlgorithmBase` defines a high-level *learning* API which standarizes the format of agent (by the `Module` class created by `@trace.model`), the data loader (by the `DataLoader` class), and the feedback oracle (by the `AutoGuide` class). With this common abstraction, we offer training algorithms, from the basic `MinibatchAlgorithm` which trains minibatches of samples to search algorithms like `BeamSearch`. The `AlgorithmBase` also handles logging of the training process. While there are overlapping between the functions of `Optimizer` and `AlgorithmBase`, the main distinction is that algorithms under `AlgorithmBase` are meta algorithms, as they should work for different optimizers in `opto.optimizers`.
-
-
-4. `opto.utils` has a collection of helper functions and backends, which are reusable for various applications. This includes, e.g., abstraction of LLMs, database, etc. Making use of all these utils would requie installing optional depedencies.
-
-
-In summary, `opto.trace` is the infrastructure, `opto.optimizers` are algorithms that process feedback and propose new parameter candidates, and `opto.trainers` are algorithms built on top of `opto.trace` and `opto.optimizers` to train learning agents.
-
-## Common Workflow of Using Trace
-
-1. Use `trace.node` and `@trace.bundle` to define the traceable workflow and its trainable parameter.
-2. Wrap the workflow as a `trace.Module` using `@trace.model`
-3. Create a dataloader using `DataLoader` and define the feedback oracle (an analogy of loss function) using `AutoGuide`.
-4. Create a trainer from `opto.trainers` using optimizers from `opto.optimizers` and the above module, dataloader, and feedback oracle.
+The Trace library is a lightweight, modular package designed to let developers easily experiment with generative optimization and integrate feedback-driven learning into their computational workflows.
+The library has four modules within the `opto` top-level namespace:
+
+1. `opto.trace` provides the infrastructure for recording the execution of Python code as symbolic directed acyclic graphs (DAGs).
+It defines two tracing primitives:
+ - `trace.node`: Wraps Python objects, designating them as nodes within the computational graph.
+ - `@trace.bundle`: Decorates Python methods/functions, marking them as operators within the graph.
+
+Each primitive has a `trainable` flag.
+When set to `True`, these marked nodes and bundles become the trainable *parameters* of the workflow.
+By using these primitives, developers can create a *traced workflow* represented as a DAG.
+This DAG structure is automatically constructed at runtime, capturing both computational dependencies and trainable parameters, ready for optimization.
+
+2. `opto.optimizers` provides the abstract class `Optimizer`, which defines algorithms that take a computation DAG and its associated feedback as input and output new values for the trainable parameters.
+These algorithms are analogous to gradient-based optimizers in PyTorch, but they are typically implemented as generative optimization agents that use LLMs to analyze the feedback and propose parameter updates.
+We provide implementations of several generative optimizers:
+ - `OptoPrime`
+ - `TextGrad`
+ - `OPRO`
+
+3. `opto.trainers` provides the `AlgorithmBase` abstraction, which orchestrates the overall training process.
+Trainers manage data handling, tracing control, feedback collection, optimizer invocation, and iteration/stopping. Specifically, a trainer:
+ - Controls data sampling (via `DataLoader`).
+ - Determines when DAGs are constructed and when feedback (e.g., via `AutoGuide`) is collected.
+ - Invokes `optimizers` for parameter updates (possibly repeatedly) and manages the training loop.
+ - Logs training progress.
+
+Although `optimizers` handle the lower-level optimization decisions, trainers under `AlgorithmBase` manage the broader training logic and are designed to be compatible across various `optimizers`.
+We provide implementations of common trainers: `MinibatchAlgorithm` (basic minibatch training) and `BeamSearch` (an example of search-based training).
+
+4. `opto.utils` has a collection of reusable helper functions and backend utilities, including abstractions for:
+ - Large Language Models (LLMs)
+ - Databases
+ - Miscellaneous support tools
+
+Note: Some utilities may require installing optional dependencies.
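+
+The following minimal sketch shows how these pieces fit together (`greet` and the feedback string are toy examples used only for illustration):
+
+```python
+from opto import trace
+from opto.optimizers import OptoPrime
+
+# A root node of the DAG; trainable=True marks it as a parameter.
+x = trace.node("Hello", trainable=True)
+
+# A traced operator; its inputs and outputs become nodes of the DAG.
+@trace.bundle()
+def greet(text):
+    return text + ", world!"
+
+y = greet(x)                 # running the traced workflow builds the DAG
+
+optimizer = OptoPrime([x])   # optimizes the trainable parameters
+optimizer.zero_feedback()
+optimizer.backward(y, "Make the greeting more enthusiastic.")  # attach feedback to the output
+optimizer.step()             # propose and apply new values for the parameters
+```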
+
+## Concise Summary of Abstractions
+ - `trace`: Infrastructure to construct symbolic computational DAGs.
+ - `optimizers`: Receive a DAG and feedback, and output new parameter values.
+ - `trainers`: Manage DAG construction, data sampling, feedback collection, optimizer invocation, and overall training workflow control.
+
+## Common Workflow for Using Trace
+
+1. Define a traceable workflow with `trace.node` and `@trace.bundle`, marking trainable parameters.
+2. Wrap this workflow into a `trace.Module` with `@trace.model`.
+3. Define a dataloader (`DataLoader`) and a feedback oracle (analogous to a loss function, e.g., using `AutoGuide`).
+4. Instantiate a trainer from `opto.trainers`, specifying an optimizer from `opto.optimizers` together with the module, dataloader, and feedback oracle defined above.
 5. Run the trainer.
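+
+A condensed sketch of these five steps, adapted from `examples/gsm8k_trainer_example.py` (the agent, prompts, and dataset below are illustrative placeholders):
+
+```python
+from opto import trace
+from opto.utils.llm import LLM
+from opto.optimizers import OptoPrime
+from opto.trainer.algorithms.basic_algorithms import MinibatchAlgorithm
+from opto.trainer.guide import VerbalJudgeGuide
+from opto.trainer.loggers import DefaultLogger
+
+
+@trace.model
+class Agent:
+    def __init__(self):
+        # Step 1: declare the trainable parameter of the workflow.
+        self.system_prompt = trace.node("You're a helpful agent", trainable=True)
+        self.llm = LLM(model=None)  # None selects the default model
+
+    @trace.bundle()
+    def model(self, system_prompt, message):
+        """Call the LLM with the (trainable) system prompt and the user message."""
+        response = self.llm(messages=[{"role": "system", "content": system_prompt},
+                                      {"role": "user", "content": message}])
+        return response.choices[0].message.content
+
+    def forward(self, message):
+        return self.model(self.system_prompt, message)
+
+
+# Steps 2-4: module, feedback oracle (LLM-as-a-judge), optimizer, and trainer.
+agent = Agent()
+guide = VerbalJudgeGuide()
+optimizer = OptoPrime(agent.parameters())
+algorithm = MinibatchAlgorithm(agent=agent, optimizer=optimizer, logger=DefaultLogger())
+
+# Step 5: run the trainer; the dataset dict is passed directly and the trainer handles data sampling.
+train_dataset = dict(inputs=["What is 2 + 2?"], infos=["4"])
+algorithm.train(guide, train_dataset, num_epochs=1, batch_size=1,
+                test_dataset=train_dataset, eval_frequency=-1)
+```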
-
-## Common Workflow of Improving Trace
-- **Developing new optimization agent** Contribute to `trace.optimizers` and design new algorithms under `Optimizer`
-- **Developing new learning algorithms** Contribute to `trace.trainers` (and `trace.optimizers` when necessary). Design new algorithms under `AlgorithmBase`, new dataloader under `DataLoader`, or new feedback oracle under `AutoGuide`.
-- **Improving infrastructure** Propose updates to change `opto.trace` (e.g., to improve UI, add new tracing, etc.)
-- **Onboarding other utility tools** Add to `opto.utils` and update `setup.py` with optional requirements.
\ No newline at end of file
+## Guidelines for Improving and Extending Trace
+ - **New optimization agents**: Contribute to `opto.optimizers`; sub-class the `Optimizer` abstraction.
+ - **New learning algorithms**: Contribute to `opto.trainers` (and to `opto.optimizers` when necessary). Design new algorithms by sub-classing `AlgorithmBase`, new dataloaders under `DataLoader`, or new feedback oracles under `AutoGuide`.
+ - **Improving infrastructure**: Propose modifications to `opto.trace` to improve tracing capability, improve the user experience, or add functionality.
+ - **Onboarding other utility tools**: Add helpful tools to `opto.utils` and update `setup.py` accordingly with optional dependencies.
diff --git a/tests/llm_optimizers_tests/test_trainer_refactored.py b/examples/gsm8k_trainer_example.py
similarity index 50%
rename from tests/llm_optimizers_tests/test_trainer_refactored.py
rename to examples/gsm8k_trainer_example.py
index 58b32dcd..61c604f0 100644
--- a/tests/llm_optimizers_tests/test_trainer_refactored.py
+++ b/examples/gsm8k_trainer_example.py
@@ -2,16 +2,16 @@
 import numpy as np
 from opto import trace
 from opto.utils.llm import LLM, LiteLLM
-from opto.optimizers.utils import print_color
 from opto.optimizers import OptoPrime
-from opto.trainer.algorithms.basic_algorithm import BatchedFeedback
+from opto.trainer.algorithms.basic_algorithms import MinibatchAlgorithm
+from opto.trainer.loggers import DefaultLogger, TensorboardLogger
 from opto.trainer.guide import VerbalJudgeGuide
 from typing import Any
 
 
 @trace.model
 class Learner:
-    # A basic LLM agent.
+    """ A basic LLM agent. """
 
     def __init__(self, system_prompt: str = "You're a helpful agent",
                  user_prompt_template: str = "Query: {message}",
@@ -22,9 +22,15 @@ def __init__(self, system_prompt: str = "You're a helpful agent",
 
     @trace.bundle()
     def model(self, system_prompt: str, user_prompt_template: str, message: str) -> str:
-        """ Call the LLM model. system_prompt specifies
-        the behavior of the agent. user prompt is the input to the agent, which
-        is formatted as user_prompt_template.format(message=message)."""
+        """Call the LLM model.
+
+        Args:
+            system_prompt: the system prompt to the agent. By tuning this prompt, we can control the behavior of the agent. For example, it can be used to provide instructions to the agent (such as how to reason about the problem or how to answer the question), or to provide in-context examples of how to solve the problem.
+            user_prompt_template: the user prompt template to the agent. It is used to format the input to the agent as user_prompt_template.format(message=message).
+            message: the input to the agent. It can be a query, a task, code, etc.
+        Returns:
+            The response from the agent.
+        """
 
         if '{message}' not in user_prompt_template:
             raise ValueError("user_prompt_template must contain '{message}'")
@@ -39,9 +45,9 @@ def forward(self, message: Any) -> Any:
         """ Forward pass of the agent. """
         return self.model(self.system_prompt, self.user_prompt_template, message)
 
-class Logger:
-    def log(self, *messages, color=None, **kwargs):
-        print_color(messages, color=color)
+
+Guide = VerbalJudgeGuide
+Logger = TensorboardLogger
 
 
 def main():
@@ -49,32 +55,38 @@ def main():
     seed = 42
     num_epochs = 1
     batch_size = 1
-    eval_frequency = 1
-    teacher_model = "gpt-4o-mini" #"gpt-4o-mini_2024-07-18"
-    student_model = "gpt-35-turbo_1106"
+    eval_frequency = -1
+    verbose = True
+    teacher_model = None  # use the default model
+    student_model = None  # use the default model
 
     np.random.seed(seed)
 
-    train_dataset = datasets.load_dataset('openai/gsm8k', 'main')['train'][
-        :10]  # NOTE for now, we train on a smaller portion
+    # In this example, we use the GSM8K dataset, which is a dataset of math word problems.
+    # We will look at the training error of the agent on a small portion of this dataset.
+    train_dataset = datasets.load_dataset('openai/gsm8k', 'main')['train'][:10]
     train_dataset = dict(inputs=train_dataset['question'], infos=train_dataset['answer'])
-    test_dataset = train_dataset  # NOTE for now, we just look at training error
-
-    agent = Learner(llm=LiteLLM(model="gpt-3.5-turbo"))
-
-    guide = VerbalJudgeGuide(model=teacher_model)
-
-    alg = BatchedFeedback(agent=agent,
-                          optimizer=OptoPrime(agent.parameters()),
-                          logger=Logger())
-
+    test_dataset = train_dataset
+
+    agent = Learner(llm=LLM(student_model))
+    guide = Guide(model=teacher_model)
+    optimizer = OptoPrime(agent.parameters())
+    # set use_json_object_format=False in OptoPrime if the LLM does not support the JSON object response format
+    logger = Logger(verbose=verbose)
+
+    alg = MinibatchAlgorithm(
+        agent=agent,
+        optimizer=optimizer,
+        logger=logger)
+
     alg.train(guide,
               train_dataset,
              num_epochs=num_epochs,
              batch_size=batch_size,
              eval_frequency=eval_frequency,
              test_dataset=test_dataset,
-             num_threads=3)
+             num_threads=3,
+             verbose='output' if verbose else False)
 
 
 if __name__ == "__main__":
diff --git a/examples/minibatch_bbh_aynsc/run_bigbench_trace_async.py b/examples/minibatch_bbh_aynsc/run_bigbench_trace_async.py
index b0ed9b28..7e12339f 100644
--- a/examples/minibatch_bbh_aynsc/run_bigbench_trace_async.py
+++ b/examples/minibatch_bbh_aynsc/run_bigbench_trace_async.py
@@ -10,7 +10,7 @@
 import autogen
 import pickle
 import os
-from opto.trainer.algorithms.basic_algorithm import MinibatchAlgorithm, evaluate
+from opto.trainer.algorithms.basic_algorithms import MinibatchAlgorithm, evaluate
 from opto.trainer.guide import AutoGuide
 
 
diff --git a/opto/optimizers/__init__.py b/opto/optimizers/__init__.py
index e03b7f93..9b0b2007 100644
--- a/opto/optimizers/__init__.py
+++ b/opto/optimizers/__init__.py
@@ -1,7 +1,9 @@
-from opto.optimizers.optoprime import OptoPrime
+from opto.optimizers.optoprime import OptoPrime as OptoPrimeV1
 from opto.optimizers.optoprimemulti import OptoPrimeMulti
 from opto.optimizers.opro import OPRO
 from opto.optimizers.textgrad import TextGrad
-from opto.optimizers.optoprime_batchopt import OptoprimeBatchOpt
+from opto.optimizers.optoprime_v2 import OptoPrimeV2
 
-__all__ = ["OPRO", "OptoPrime", "OptoPrimeMulti", "TextGrad", "OptoprimeBatchOpt"]
\ No
newline at end of file +OptoPrime = OptoPrimeV1 + +__all__ = ["OPRO", "OptoPrime", "OptoPrimeMulti", "TextGrad", "OptoPrimeV2", "OptoPrimeV1"] \ No newline at end of file diff --git a/opto/optimizers/optimizer.py b/opto/optimizers/optimizer.py index ea2a0503..77ee10db 100644 --- a/opto/optimizers/optimizer.py +++ b/opto/optimizers/optimizer.py @@ -54,10 +54,19 @@ def trace_graph(self): def step(self, bypassing=False, *args, **kwargs): update_dict = self.propose(*args, **kwargs) + self.project(update_dict) if not bypassing: self.update(update_dict) return update_dict # TODO add reasoning + def project(self, update_dict: Dict[ParameterNode, Any]): + """Project the update dictionary onto the feasible set.""" + for p, d in update_dict.items(): + if p.trainable: + for projection in p.projections: + d = projection.project(d) + update_dict[p] = d + def propose(self, *args, **kwargs): """Propose the new data of the parameters based on the feedback.""" return self._step(*args, **kwargs) diff --git a/opto/optimizers/optoprime.py b/opto/optimizers/optoprime.py index 6ac4ce95..5a5c5c36 100644 --- a/opto/optimizers/optoprime.py +++ b/opto/optimizers/optoprime.py @@ -259,6 +259,7 @@ def __init__( max_tokens=4096, log=True, prompt_symbols=None, + use_json_object_format=True, # whether to use json object format for the response when calling LLM **kwargs, ): super().__init__(parameters, *args, propagator=propagator, **kwargs) @@ -294,6 +295,7 @@ def __init__( self.prompt_symbols = copy.deepcopy(self.default_prompt_symbols) if prompt_symbols is not None: self.prompt_symbols.update(prompt_symbols) + self.use_json_object_format = use_json_object_format def default_propagator(self): """Return the default Propagator object of the optimizer.""" @@ -478,11 +480,7 @@ def construct_update_dict( for node in self.parameters: if node.trainable and node.py_name in suggestion: try: - from black import format_str, FileMode formatted_suggestion = suggestion[node.py_name] - # use black formatter for code reformatting - if type(formatted_suggestion) == str and 'def' in formatted_suggestion: - formatted_suggestion = format_str(formatted_suggestion, mode=FileMode()) update_dict[node] = type(node.data)(formatted_suggestion) except (ValueError, KeyError) as e: # catch error due to suggestion missing the key or wrong data type @@ -561,15 +559,13 @@ def call_llm( {"role": "system", "content": system_prompt}, {"role": "user", "content": user_prompt}, ] - + + response_format = {"type": "json_object"} if self.use_json_object_format else None try: # Try tp force it to be a json object - response = self.llm( - messages=messages, - response_format={"type": "json_object"}, - max_tokens=max_tokens, - ) + response = self.llm(messages=messages, max_tokens=max_tokens, response_format=response_format) except Exception: response = self.llm(messages=messages, max_tokens=max_tokens) + response = response.choices[0].message.content if verbose: diff --git a/opto/optimizers/optoprime_batchopt.py b/opto/optimizers/optoprime_v2.py similarity index 99% rename from opto/optimizers/optoprime_batchopt.py rename to opto/optimizers/optoprime_v2.py index c34265dd..f0c78258 100644 --- a/opto/optimizers/optoprime_batchopt.py +++ b/opto/optimizers/optoprime_v2.py @@ -3,7 +3,7 @@ from opto.optimizers.optoprime import OptoPrime -class OptoprimeBatchOpt(OptoPrime): +class OptoPrimeV2(OptoPrime): # This is generic representation prompt, which just explains how to read the problem. 
representation_prompt = dedent( """ diff --git a/opto/trace/__init__.py b/opto/trace/__init__.py index ddf2a778..ddf01300 100644 --- a/opto/trace/__init__.py +++ b/opto/trace/__init__.py @@ -4,6 +4,7 @@ from opto.trace.broadcast import apply_op import opto.trace.propagators as propagators import opto.trace.operators as operators +import opto.trace.projections as projections from opto.trace.nodes import Node, GRAPH from opto.trace.nodes import node diff --git a/opto/trace/bundle.py b/opto/trace/bundle.py index ce080360..db51f8eb 100644 --- a/opto/trace/bundle.py +++ b/opto/trace/bundle.py @@ -39,6 +39,7 @@ def bundle( catch_execution_error=True, allow_external_dependencies=False, overwrite_python_recursion=False, + projections=None, ): """Wrap a function as a FunModule which returns node objects. @@ -53,6 +54,7 @@ def bundle( catch_execution_error (bool, optional): Whether to catch exceptions during operator execution. Defaults to True. allow_external_dependencies (bool, optional): Whether to allow external dependencies. Defaults to False. overwrite_python_recursion (bool, optional): Whether to overwrite Python recursion behavior. Defaults to False. + projections (List[Projection], optional): List of projections to be used in updating trainable parameter. Defaults to None. Returns: FunModule: The wrapped function that returns node objects. @@ -70,6 +72,7 @@ def decorator(fun): allow_external_dependencies=allow_external_dependencies, overwrite_python_recursion=overwrite_python_recursion, _ldict=prev_f_locals, # Get the locals of the calling function + projections=projections, ) return fun_module @@ -124,6 +127,7 @@ def __init__( catch_execution_error=True, allow_external_dependencies=False, overwrite_python_recursion=False, + projections=None, _ldict=None, ): @@ -183,10 +187,12 @@ def __init__( signature = re.search(r"\s*(def.*:)", source).group(1) else: signature = signature_sr.group(1) + self.parameter = ParameterNode( self.info["source"], name="__code", constraint="The code should start with:\n" + signature, + projections=projections, ) @property diff --git a/opto/trace/modules.py b/opto/trace/modules.py index a85d1efb..bdfbcda3 100644 --- a/opto/trace/modules.py +++ b/opto/trace/modules.py @@ -1,8 +1,11 @@ import os import pickle import copy -from opto.trace.containers import ParameterContainer +import inspect +import textwrap +from opto.trace.containers import ParameterContainer, trainable_method from opto.trace.nodes import ParameterNode +from opto.trace.projections import Projection, BlackCodeFormatter def model(cls): @@ -11,7 +14,61 @@ def model(cls): """ class ModelWrapper(cls, Module): - pass + def model_dump(self, filename, projection: Projection = BlackCodeFormatter()): + """Dump the model's source code to a file, including all methods and attributes. + Ignores dunder methods unless they were overridden by the user. 
+ """ + trace_model_body = f"class {cls.__name__}:\n" + + # Get all members of the class + all_members = inspect.getmembers(self) + cls_members = inspect.getmembers(cls) + cls_member_names = [m[0] for m in cls_members] + + # Filter out dunder methods unless they were overridden + filtered_members = [] + for name, member in all_members: + # Skip internal trace reserved members + if name.startswith('__TRACE_RESERVED_'): + continue + + if name not in cls_member_names: + continue + + # Include if it's not a dunder method or if it was overridden + if not name.startswith('__'): + filtered_members.append((name, member)) + elif name.startswith('__'): + # For dunder methods, check if they were overridden + try: + if hasattr(member, '__qualname__') and member.__qualname__.split('.')[0] == cls.__name__: + filtered_members.append((name, member)) + except (AttributeError, TypeError): + # Skip if we can't determine if it was overridden + continue + + # Process each member + for i, (name, member) in enumerate(filtered_members): + if 'FunModule' in str(member): + # Handle methods + source = member.parameter.data + source = textwrap.dedent(source) + indented = textwrap.indent(source, " ") + trace_model_body += indented + else: # this is a class method + source = inspect.getsource(member) + source = textwrap.dedent(source) + indented = textwrap.indent(source, " ") + trace_model_body += indented + + if i < len(all_members) - 1: + trace_model_body += "\n" # only one newline between members + + if projection is not None: + trace_model_body = projection.project(trace_model_body) + + with open(filename, "w") as f: + f.write(trace_model_body) return ModelWrapper @@ -25,8 +82,8 @@ def forward(self, *args, **kwargs): def __call__(self, *args, **kwargs): return self.forward(*args, **kwargs) - def save(self, file_name): - """Save the parameters of the model to a file.""" + def save(self, file_name: str): + """Save the parameters of the model to a pickle file.""" # detect if the directory exists directory = os.path.dirname(file_name) if directory != "": @@ -35,7 +92,7 @@ def save(self, file_name): pickle.dump(copy.deepcopy(self.parameters_dict()), f) def load(self, file_name): - """Load the parameters of the model from a file.""" + """Load the parameters of the model from a pickle file.""" with open(file_name, "rb") as f: loaded_data = pickle.load(f) self._set(loaded_data) @@ -62,4 +119,4 @@ def _set(self, new_parameters): parameters_dict[k]._set(v) else: # if the parameter does not exist assert k not in self.__dict__ - setattr(self, k, v) + setattr(self, k, v) \ No newline at end of file diff --git a/opto/trace/nodes.py b/opto/trace/nodes.py index a05e662c..c159624d 100644 --- a/opto/trace/nodes.py +++ b/opto/trace/nodes.py @@ -2007,6 +2007,7 @@ def __init__( trainable=True, description="[ParameterNode] This is a ParameterNode in a computational graph.", constraint=None, + projections=None, # a list of Projection info=None, ) -> None: if description is None or description == "": @@ -2027,6 +2028,15 @@ def __init__( info=info, ) self._dependencies["parameter"].add(self) + if projections is not None: + assert isinstance( + projections, list + ), "Projections must be a list of Projection objects." + from opto.trace.projections import Projection + assert all( + isinstance(p, Projection) for p in projections + ), "All projections must be instances of Projection." 
+            self._projections = projections
 
     def __str__(self) -> str:
         # str(node) allows us to look up in the feedback dictionary easily
diff --git a/opto/trace/projections/__init__.py b/opto/trace/projections/__init__.py
new file mode 100644
index 00000000..7264d5bd
--- /dev/null
+++ b/opto/trace/projections/__init__.py
@@ -0,0 +1,2 @@
+from opto.trace.projections.projections import Projection
+from opto.trace.projections.code_projections import BlackCodeFormatter, DocstringProjection
\ No newline at end of file
diff --git a/opto/trace/projections/code_projections.py b/opto/trace/projections/code_projections.py
new file mode 100644
index 00000000..78a4642c
--- /dev/null
+++ b/opto/trace/projections/code_projections.py
@@ -0,0 +1,31 @@
+
+from opto.trace.projections import Projection
+
+class BlackCodeFormatter(Projection):
+    # This requires the `black` package to be installed.
+
+    def project(self, x: str) -> str:
+        # import here to avoid a hard dependency on black
+        # use the black formatter to reformat code
+        from black import format_str, FileMode
+        if type(x) == str and 'def' in x:
+            x = format_str(x, mode=FileMode())
+        return x
+
+class DocstringProjection(Projection):
+    """
+    Projection that replaces the docstring in code with a stored docstring.
+    """
+    def __init__(self, docstring: str):
+        self.docstring = docstring
+
+    def project(self, x: str) -> str:
+        """ Replace the docstring in the code with the stored docstring. """
+        if type(x) == str and '"""' in x:
+            # replace the docstring in the code with the stored docstring
+            x = x.split('"""', 2)
+            if len(x) > 2:
+                x = f'{x[0]}"""{self.docstring}"""{x[2]}'
+            else:
+                x = f'{x[0]}"""{self.docstring}"""'
+        return x
\ No newline at end of file
diff --git a/opto/trace/projections/projections.py b/opto/trace/projections/projections.py
new file mode 100644
index 00000000..9be4227c
--- /dev/null
+++ b/opto/trace/projections/projections.py
@@ -0,0 +1,29 @@
+from typing import Any
+
+
+class Projection:
+    """
+    Abstract base class for projection methods.
+    """
+
+    def __init__(self, *args, **kwargs):
+        pass
+
+    def __call__(self, x: Any) -> Any:
+        """
+        Call the projection on the proposed parameter value `x`.
+
+        Args:
+            x: The proposed parameter value to project.
+
+        Returns:
+            The projected value.
+        """
+        return self.project(x)
+
+    def project(self, x: Any) -> Any:
+        """
+        Project the proposed parameter value `x` onto the feasible set.
+ """ + raise NotImplementedError("Subclasses should implement this method.") + \ No newline at end of file diff --git a/opto/trainer/algorithms/__init__.py b/opto/trainer/algorithms/__init__.py index aac6a494..ea5dde63 100644 --- a/opto/trainer/algorithms/__init__.py +++ b/opto/trainer/algorithms/__init__.py @@ -1 +1 @@ -from opto.trainer.algorithms.basic_algorithm import Minibatch, MinibatchAlgorithm, BasicSearchAlgorithm +from opto.trainer.algorithms.basic_algorithms import Minibatch, MinibatchAlgorithm, BasicSearchAlgorithm diff --git a/opto/trainer/algorithms/aggregator.py b/opto/trainer/algorithms/aggregator.py index 4f94d999..a1d30a67 100644 --- a/opto/trainer/algorithms/aggregator.py +++ b/opto/trainer/algorithms/aggregator.py @@ -9,7 +9,7 @@ from opto.trace.nodes import ParameterNode from opto.optimizers.utils import print_color from opto.trainer.algorithms import Minibatch -from opto.trainer.algorithms.basic_algorithm import standard_optimization_step +from opto.trainer.algorithms.basic_algorithms import standard_optimization_step from opto.utils.llm import LLM, AbstractModel diff --git a/opto/trainer/algorithms/algorithm.py b/opto/trainer/algorithms/algorithm.py index 927d186f..9ec35fcc 100644 --- a/opto/trainer/algorithms/algorithm.py +++ b/opto/trainer/algorithms/algorithm.py @@ -1,7 +1,8 @@ import warnings from opto import trace from opto.trace.modules import Module -from opto.trainer.utils import async_run, DefaultLogger +from opto.trainer.utils import async_run +from opto.trainer.loggers import DefaultLogger import os diff --git a/opto/trainer/algorithms/basic_algorithm.py b/opto/trainer/algorithms/basic_algorithms.py similarity index 99% rename from opto/trainer/algorithms/basic_algorithm.py rename to opto/trainer/algorithms/basic_algorithms.py index d443367e..66596580 100644 --- a/opto/trainer/algorithms/basic_algorithm.py +++ b/opto/trainer/algorithms/basic_algorithms.py @@ -272,9 +272,6 @@ def update(self, outputs, *args, **kwargs): feedback = batchify(*feedbacks).data # str average_score = np.mean(scores) if all([s is not None for s in scores]) else None - fig = target.backward(visualize=True, retain_graph=True) - fig.render("minibatch.pdf") - # Update the agent using the feedback self.optimizer.zero_feedback() self.optimizer.backward(target, feedback) diff --git a/opto/trainer/guide.py b/opto/trainer/guide.py index 5a11dcce..30c428a6 100644 --- a/opto/trainer/guide.py +++ b/opto/trainer/guide.py @@ -53,22 +53,26 @@ class VerbalJudgeGuide(AutoGuide): This is an implementation of LLM-as-a-judge. """ + DEFAULT_CORRECTNESS_TEMPLATE = "Correct [TERMINATE]" + DEFAULT_INCORRECTNESS_TEMPLATE = "Incorrect" + DEFAULT_PROMPT_TEMPLATE = ( - "The query is: {query}. The student answered: {response}. The correct answer is: {reference}. " - "If the student answer is correct, please say 'Correct [TERMINATE]'. " - "Otherwise, if the student answer is incorrect, please provide feedback to the student. " + "The query is: {query}.\n\n\nThe student answered: {response}.\n\n\nThe correct answer is: {reference}.\n\n\n" + "Reason whether the student answer is correct. If the student answer is correct, please say {correctness_template}. " + "Otherwise, if the student answer is incorrect, say {incorrectness_template} and provide feedback to the student. " "The feedback should be specific and actionable." ) DEFAULT_SYSTEM_PROMPT = "You're a helpful teacher who provides clear and constructive feedback." 
- DEFAULT_CORRECTNESS_TEMPLATE = "Correct [TERMINATE]" def __init__(self, model: Optional[str] = None, llm: Optional[AbstractModel] = None, prompt_template: Optional[str] = None, system_prompt: Optional[str] = None, - correctness_template: Optional[str] = None): + correctness_template: Optional[str] = None, + use_formatted_response: bool = True + ): """ Initialize the VerbalGuide with an LLM and prompt templates. @@ -78,12 +82,14 @@ def __init__(self, prompt_template: Custom prompt template with {response} and {reference} placeholders system_prompt: Custom system prompt for the LLM correctness_template: Template to use when response is deemed correct by metric + use_formatted_response: Whether to format the response with additional context; if False, the raw LLM response is returned """ self.model = model self.llm = llm or LLM(model=model) self.prompt_template = prompt_template or self.DEFAULT_PROMPT_TEMPLATE self.system_prompt = system_prompt or self.DEFAULT_SYSTEM_PROMPT self.correctness_template = correctness_template or self.DEFAULT_CORRECTNESS_TEMPLATE + self.use_formatted_response = use_formatted_response def get_feedback(self, query: str, response: str, reference: Optional[str] = None, **kwargs) -> Tuple[float, str]: """ @@ -103,7 +109,12 @@ def get_feedback(self, query: str, response: str, reference: Optional[str] = Non raise ValueError("ReferenceGuide requires reference information to generate feedback") # Check if metric function indicates perfect match - user_prompt = self.prompt_template.format(query=query, response=response, reference=reference) + user_prompt = self.prompt_template.format( + query=query, + response=response, + reference=reference, + correctness_template=self.DEFAULT_CORRECTNESS_TEMPLATE, + incorrectness_template=self.DEFAULT_INCORRECTNESS_TEMPLATE) messages = [ {"role": "system", "content": self.system_prompt}, @@ -128,7 +139,7 @@ def get_feedback(self, query: str, response: str, reference: Optional[str] = Non score = 1 if 'Correct [TERMINATE]' in llm_response else 0 - return score, formatted_response + return score, formatted_response if self.use_formatted_response else llm_response def forward(self, task: str, response: str, info: Any, **kwargs) -> Tuple[float, str]: score, feedback = self.get_feedback(task, response, info, **kwargs) diff --git a/opto/trainer/loggers.py b/opto/trainer/loggers.py new file mode 100644 index 00000000..5f82a4ac --- /dev/null +++ b/opto/trainer/loggers.py @@ -0,0 +1,84 @@ + + +class BaseLogger: + + def __init__(self, log_dir='./logs', **kwargs): + """Initialize the logger. This method can be overridden by subclasses.""" + self.log_dir = log_dir + pass + + def log(self, name, data, step, **kwargs): + """Log a message with the given name and data at the specified step. + + Args: + name: Name of the metric + data: Value of the metric + step: Current step/iteration + **kwargs: Additional arguments (e.g., color) + """ + raise NotImplementedError("Subclasses should implement this method.") + + +class ConsoleLogger(BaseLogger): + """A simple logger that prints messages to the console.""" + + def log(self, name, data, step, **kwargs): + """Log a message to the console. 
+ + Args: + name: Name of the metric + data: Value of the metric + step: Current step/iteration + **kwargs: Additional arguments (e.g., color) + """ + color = kwargs.get('color', None) + # Simple color formatting for terminal output + color_codes = { + 'green': '\033[92m', + 'red': '\033[91m', + 'blue': '\033[94m', + 'end': '\033[0m' + } + + start_color = color_codes.get(color, '') + end_color = color_codes['end'] if color in color_codes else '' + + print(f"[Step {step}] {start_color}{name}: {data}{end_color}") + + +class TensorboardLogger(ConsoleLogger): + """A logger that writes metrics to TensorBoard.""" + + def __init__(self, log_dir='./logs', verbose=True, **kwargs): + super().__init__(log_dir, **kwargs) + self.verbose = verbose + # Late import to avoid dependency issues + try: + from tensorboardX import SummaryWriter + except ImportError: + # try importing from torch.utils.tensorboard if tensorboardX is not available + from torch.utils.tensorboard import SummaryWriter + + self.writer = SummaryWriter(self.log_dir) + + def log(self, name, data, step, **kwargs): + """Log a message to TensorBoard. + + Args: + name: Name of the metric + data: Value of the metric + step: Current step/iteration + **kwargs: Additional arguments (not used here) + """ + if self.verbose: + super().log(name, data, step, **kwargs) + if isinstance(data, str): + # If data is a string, log it as text + self.writer.add_text(name, data, step) + else: + # Otherwise, log it as a scalar + self.writer.add_scalar(name, data, step) + +# TODO add wandb logger + +DefaultLogger = ConsoleLogger \ No newline at end of file diff --git a/opto/trainer/utils.py b/opto/trainer/utils.py index b8dad65c..717ff23b 100644 --- a/opto/trainer/utils.py +++ b/opto/trainer/utils.py @@ -47,33 +47,6 @@ async def _run(): return asyncio.run(_run()) -class DefaultLogger: - """A simple logger that prints messages to the console.""" - - def log(self, name, data, step, **kwargs): - """Log a message to the console. - - Args: - name: Name of the metric - data: Value of the metric - step: Current step/iteration - **kwargs: Additional arguments (e.g., color) - """ - color = kwargs.get('color', None) - # Simple color formatting for terminal output - color_codes = { - 'green': '\033[92m', - 'red': '\033[91m', - 'blue': '\033[94m', - 'end': '\033[0m' - } - - start_color = color_codes.get(color, '') - end_color = color_codes['end'] if color in color_codes else '' - - print(f"[Step {step}] {start_color}{name}: {data}{end_color}") - - if __name__ == "__main__": def tester(t): # regular time-consuming function diff --git a/setup.py b/setup.py index 5ab3a9a1..4fa7eef5 100644 --- a/setup.py +++ b/setup.py @@ -14,8 +14,10 @@ "litellm", "black", "scikit-learn", + "tensorboardX", + "tensorboard" ] - + setuptools.setup( name="trace-opt", version=__version__, diff --git a/tests/llm_optimizers_tests/test_trainer.py b/tests/llm_optimizers_tests/test_trainer.py deleted file mode 100644 index 3f88cccb..00000000 --- a/tests/llm_optimizers_tests/test_trainer.py +++ /dev/null @@ -1,94 +0,0 @@ -import datasets -import numpy as np -from opto import trace -from opto.utils.llm import AutoGenLLM -from opto.optimizers.utils import print_color -from opto.optimizers import OptoPrime -from opto.trainer import train -from typing import Any - - -@trace.model -class Student: - # A basic LLM agent. 
- - def __init__(self, system_prompt: str = "You're a helpful agent", - user_prompt_template: str = "Query: {message}", - llm: AutoGenLLM = None): - self.system_prompt = trace.node(system_prompt, trainable=True) - self.user_prompt_template = trace.node(user_prompt_template) - self.llm = llm or AutoGenLLM() - - @trace.bundle() - def model(self, system_prompt: str, user_prompt_template: str, message: str) -> str: - """ Call the LLM model. system_prompt specifies - the behavior of the agent. user prompt is the input to the agent, which - is formatted as user_prompt_template.format(message=message).""" - - if '{message}' not in user_prompt_template: - raise ValueError("user_prompt_template must contain '{message}'") - - response = self.llm( - messages = [{"role": "system", "content": system_prompt}, - {"role": "user", "content": user_prompt_template.format(message=message)}] - ) - return response.choices[0].message.content - - def forward(self, message: Any) -> Any: - """ Forward pass of the agent. """ - return self.model(self.system_prompt, self.user_prompt_template, message) - - -def teacher(student_answer, info, model="gpt-4o-mini_2024-07-18"): - """ Use LLM to evaluate the student answer. """ - llm = AutoGenLLM(filter_dict={"model": [model]}) - system_prompt = "You're a match teacher who helps students to learn. " - user_prompt_template = "The student answered: {}. The correct answer is {}. If the student answer is correct, please say 'Correct [TERMINATE]'. Otherwise, if the student answer is incorrect, please provide feedback to the student. The feedback should be specific and actionable." - true_answer = info - - response = llm( - messages = [{"role": "system", "content": system_prompt}, - {"role": "user", "content": user_prompt_template.format(student_answer, true_answer)}] - ) - - response = response.choices[0].message.content - score = 1 if 'Correct [TERMINATE]' in response else 0 - return score, response - - - -class Logger: - def log(self, message, color=None, **kwargs): - print_color(message, color=color) - - - -def main(): - # set seed - seed = 42 - num_epochs = 1 - batch_size = 1 - eval_frequency = 1 - teacher_model = "gpt-4o-mini_2024-07-18" - student_model = "gpt-35-turbo_1106" - - np.random.seed(seed) - - train_dataset = datasets.load_dataset('openai/gsm8k', 'main')['train'][:10] # NOTE for now, we train on a smaller portion - train_dataset = dict(inputs=train_dataset['question'], infos=train_dataset['answer']) - test_dataset = train_dataset # NOTE for now, we just look at training error - - - train(agent=Student(llm=AutoGenLLM(filter_dict={"model": ["gpt-35-turbo_1106"]})), - teacher=lambda *args, **kwargs : teacher(model=teacher_model, *args, **kwargs), - train_dataset=train_dataset, - num_epochs=num_epochs, - logger=Logger(), - batch_size=batch_size, - test_dataset=test_dataset, - eval_frequency=eval_frequency - ) - - -if __name__ == "__main__": - main() \ No newline at end of file diff --git a/tests/unit_tests/test_projection.py b/tests/unit_tests/test_projection.py new file mode 100644 index 00000000..794fffcd --- /dev/null +++ b/tests/unit_tests/test_projection.py @@ -0,0 +1,38 @@ +from opto.trace.projections import BlackCodeFormatter, DocstringProjection + +def test_black_code_formatter(): + code = """ +def example_function(): + print("Hello, World!") + + + print("This is a test function.") + + + + """ + projection = BlackCodeFormatter() + formatted_code = projection.project(code) + assert formatted_code == 'def example_function():\n print("Hello, World!")\n\n 
print("This is a test function.")\n' + + +def test_docstring_projection(): + code = """ +def example_function(): + \"\"\"This is an example function.\"\"\" + print("Hello, World!") + """ + docstring = "This is a new docstring." + projection = DocstringProjection(docstring) + formatted_code = projection.project(code) + + new_code = """ +def example_function(): + \"\"\"This is a new docstring.\"\"\" + print("Hello, World!") + """ + + assert formatted_code == new_code + + # assert '"""This is a new docstring."""' in formatted_code + # assert 'print("Hello, World!")' in formatted_code \ No newline at end of file