diff --git a/plugins/openhands/platoon/openhands/agent.py b/plugins/openhands/platoon/openhands/agent.py
index 9b07d60..1f4d117 100644
--- a/plugins/openhands/platoon/openhands/agent.py
+++ b/plugins/openhands/platoon/openhands/agent.py
@@ -3,7 +3,7 @@ import asyncio
 from copy import deepcopy
 
 from platoon.envs.base import Task
-from platoon.openhands.types import OpenHandsObservation, OpenHandsAction
+from .types import OpenHandsObservation, OpenHandsAction
 from platoon.utils.openhands_utils import get_actions_for_last_obs
 from platoon.utils.openhands_utils import is_finished
diff --git a/plugins/openhands/platoon/openhands/env.py b/plugins/openhands/platoon/openhands/env.py
index fe757e9..280d08e 100644
--- a/plugins/openhands/platoon/openhands/env.py
+++ b/plugins/openhands/platoon/openhands/env.py
@@ -3,7 +3,7 @@ from __future__ import annotations
 
 from platoon.envs.base import Task
-from .types import OpenHandsObservation, OpenHandsTrajectoryStep, OpenHandsAction
+from platoon.openhands.types import OpenHandsObservation, OpenHandsTrajectoryStep, OpenHandsAction
 from openhands.sdk.conversation import get_agent_final_response
 from openhands.sdk.conversation.base import BaseConversation
 from openhands.sdk.agent.base import AgentBase
@@ -75,6 +75,9 @@ async def step(self, action: OpenHandsAction) -> OpenHandsObservation:
         if self._state.conversation_state.agent_status == ConversationExecutionStatus.STUCK:
             error_message.set("Agent got stuck")
             self._state.misc["error_message"] = error_message.get()
+        elif self._state.conversation_state.agent_status == ConversationExecutionStatus.ERROR: #TODO: check
+            error_message.set("Agent encountered an error")
+            self._state.misc["error_message"] = error_message.get()
 
         traj_collection = current_trajectory_collection.get()
         traj = current_trajectory.get()
@@ -88,6 +91,9 @@ async def close(self) -> None:
         if self._conversation is not None:
             self._conversation.close()
             self._conversation = None
+        # TODO: check if cleaning up workspace manually is required
+        if isinstance(self._workspace, BaseWorkspace):
+            await self._workspace.cleanup()
 
     # TODO: Consider adding a return_copy option here.
     async def observe(self) -> OpenHandsObservation:
diff --git a/plugins/openhands_rl/README.md b/plugins/openhands_rl/README.md
new file mode 100644
index 0000000..bbb1c7c
--- /dev/null
+++ b/plugins/openhands_rl/README.md
@@ -0,0 +1,41 @@
+# platoon-openhands-rl
+
+Platoon plugin for intermediate rewards with the OpenHands software agent SDK.
+
+## Installation
+
+This plugin depends on:
+- **platoon** (core library)
+- **platoon-openhands** (OpenHands plugin)
+- **areal** backend (for RL training)
+
+### Prerequisites
+
+- Python 3.12
+- [uv](https://docs.astral.sh/uv/) package manager
+
+### Step-by-step installation
+
+We recommend installing into a dedicated virtual environment (not in your home directory, because of Babel's storage constraints). The instructions below use a custom location (`/data/user_data/$USER/uv_cache/platoon/`), but you can use any path. On Babel, use a GPU compute node (not a CPU-only node) so that the GPU is detected during the torch installation.
+
+The commands below assume you are in the project root directory.
+
+```bash
+# Create directory for the environment
+mkdir -p /data/user_data/$USER/uv_cache/platoon
+export VIRTUAL_ENV=/data/user_data/$USER/uv_cache/platoon
+export UV_CACHE_DIR=/data/user_data/$USER/uv_cache/.cache
+uv sync --active --extra areal --extra wandb
+source /data/user_data/$USER/uv_cache/platoon/bin/activate
+uv pip install -e plugins/openhands
+uv pip install -e plugins/openhands_rl
+```
+
+### Verify installation
+
+```bash
+python -c "
+from platoon.openhands import *
+from platoon.openhands_rl import *
+print('All packages imported successfully!')
+"
+```
diff --git a/plugins/openhands_rl/platoon/openhands_rl/__init__.py b/plugins/openhands_rl/platoon/openhands_rl/__init__.py
new file mode 100644
index 0000000..e69de29
diff --git a/plugins/openhands_rl/platoon/openhands_rl/env.py b/plugins/openhands_rl/platoon/openhands_rl/env.py
new file mode 100644
index 0000000..98623b4
--- /dev/null
+++ b/plugins/openhands_rl/platoon/openhands_rl/env.py
@@ -0,0 +1,24 @@
+from pathlib import Path
+from platoon.utils.openhands_utils import is_finished
+from platoon.episode.context import current_trajectory_collection, current_trajectory, finish_message, error_message
+
+from openhands.sdk import get_logger
+from platoon.envs.base import Task
+from openhands.sdk.agent import AgentBase
+from openhands.sdk.workspace import BaseWorkspace
+from openhands.sdk.conversation import Conversation, BaseConversation, get_agent_final_response
+#TODO: check below imports
+from platoon.openhands.env import OpenHandsEnv
+from platoon.openhands.types import OpenHandsObservation, OpenHandsAction, OpenHandsTrajectoryStep
+import threading
+import asyncio
+from platoon.utils.openhands_utils import get_obs_for_last_action
+
+logger = get_logger(__name__)
+
+# TODO: double-check if we really need to over-ride any other methods from OpenHandsEnv
+# NOTE: The primary job of this class is to implement the step-wise reward functionality.
+class OpenHandsRLEnv(OpenHandsEnv):
+    async def evaluate(self) -> tuple[float, dict]:
+        # TODO(Aditya): implement the step-wise reward computation.
+        raise NotImplementedError("OpenHandsRLEnv does not implement evaluate() yet.")
diff --git a/plugins/openhands_rl/platoon/openhands_rl/prompts/default.j2 b/plugins/openhands_rl/platoon/openhands_rl/prompts/default.j2
new file mode 100644
index 0000000..fa80543
--- /dev/null
+++ b/plugins/openhands_rl/platoon/openhands_rl/prompts/default.j2
@@ -0,0 +1,63 @@
+I have access to a Python code repository in the directory {{ instance.repo_path }}. You can explore and modify files using the available tools. Consider the following issue description:
+
+<issue_description>
+{{ instance.problem_statement }}
+</issue_description>
+
+Can you help me implement the necessary changes to the repository so that the requirements specified in the <issue_description> are met?
+I've already taken care of all changes to any of the test files described in the <issue_description>. This means you DON'T have to modify the testing logic or any of the tests in any way!
+Also the development Python environment is already set up for you (i.e., all dependencies already installed), so you don't need to install other packages.
+Your task is to make the minimal changes to non-test files in the {{ instance.repo_path }} directory to ensure the <issue_description> is satisfied.
+
+Follow these phases to resolve the issue:
+
+Phase 1. READING: read the problem and reword it in clearer terms
+    1.1 If there are code or config snippets, express in words any best practices or conventions in them.
+    1.2 Highlight error messages, method names, variables, file names, stack traces, and technical details.
+    1.3 Explain the problem in clear terms.
+    1.4 Enumerate the steps to reproduce the problem.
+    1.5 Highlight any best practices to take into account when testing and fixing the issue
+
+Phase 2. RUNNING: install and run the tests on the repository
+    2.1 Activate the environment by running
+        source /opt/miniconda3/etc/profile.d/conda.sh ; conda activate testbed
+    2.2 Follow the readme
+    2.3 Install the environment and anything needed
+    2.4 Iterate and figure out how to run the tests
+
+Phase 3. EXPLORATION: find the files that are related to the problem and possible solutions
+    3.1 Use `grep` to search for relevant methods, classes, keywords and error messages.
+    3.2 Identify all files related to the problem statement.
+    3.3 Propose the methods and files to fix the issue and explain why.
+    3.4 From the possible file locations, select the most likely location to fix the issue.
+
+Phase 4. TEST CREATION: before implementing any fix, create a script to reproduce and verify the issue.
+    4.1 Look at existing test files in the repository to understand the test format/structure.
+    4.2 Create a minimal reproduction script that reproduces the located issue.
+    4.3 Run the reproduction script to confirm you are reproducing the issue.
+    4.4 Adjust the reproduction script as necessary.
+
+Phase 5. FIX ANALYSIS: state clearly the problem and how to fix it
+    5.1 State clearly what the problem is.
+    5.2 State clearly where the problem is located.
+    5.3 State clearly how the test reproduces the issue.
+    5.4 State clearly the best practices to take into account in the fix.
+    5.5 State clearly how to fix the problem.
+
+Phase 6. FIX IMPLEMENTATION: Edit the source code to implement your chosen solution.
+    6.1 Make minimal, focused changes to fix the issue.
+
+Phase 7. VERIFICATION: Test your implementation thoroughly.
+    7.1 Run your reproduction script to verify the fix works.
+    7.2 Add edge cases to your test script to ensure comprehensive coverage.
+    7.3 Run existing tests related to the modified code to ensure you haven't broken anything.
+
+Phase 8. FINAL REVIEW: Carefully re-read the problem description and compare your changes with the base commit {{ instance.base_commit }}.
+    8.1 Ensure you've fully addressed all requirements.
+    8.2 Run any tests in the repository related to:
+        8.2.1 The issue you are fixing
+        8.2.2 The files you modified
+        8.2.3 The functions you changed
+    8.3 If any tests fail, revise your implementation until all tests pass
+
+Be thorough in your exploration, testing, and reasoning. It's fine if your thinking process is lengthy - quality and completeness are more important than brevity.
\ No newline at end of file diff --git a/plugins/openhands_rl/platoon/openhands_rl/rollout.py b/plugins/openhands_rl/platoon/openhands_rl/rollout.py new file mode 100644 index 0000000..0157f5e --- /dev/null +++ b/plugins/openhands_rl/platoon/openhands_rl/rollout.py @@ -0,0 +1,247 @@ +import os +from jinja2 import Environment, FileSystemLoader +import asyncio +from platoon.envs.base import Task +from .env import OpenHandsRLEnv +from platoon.utils.llm_client import LLMClient +import subprocess +from pathlib import Path +from openhands.sdk import LLM, get_logger, Agent, AgentBase, Tool +from openhands.workspace import DockerWorkspace, APIRemoteWorkspace, ApptainerWorkspace +from platoon.episode.trajectory import TrajectoryCollection +from platoon.config_defs import RolloutConfig +from openhands.sdk.workspace import BaseWorkspace +from openhands.tools.preset import get_default_agent +from platoon.episode.loop import run_episode +from platoon.episode.context import current_trajectory_collection +from pydantic import SecretStr +from platoon.visualization.event_sinks import JsonlFileSink +from .tasks import EVAL_AGENT_SERVER_IMAGE, SDK_SHORT_SHA, ENV_SETUP_COMMANDS, PROMPT_FILENAME +from platoon.openhands.agent import OpenHandsAgent +import platform +logger = get_logger(__name__) + +# TODO: consider pre-building all docker images, and adding their names in instance on Huggingface dataset for simpler code here +def get_official_docker_image( + instance_id: str, + docker_image_prefix="docker.io/xingyaoww/", #NOTE: default changed to match SWE-Gym + # dataset: str = "swe-gym" #TODO: add dataset parameter in future +) -> str: + # Official SWE-Bench image + # swebench/sweb.eval.x86_64.django_1776_django-11333:v1 + # SWE-Gym image: docker.io/xingyaoww/sweb.eval.x86_64.project-monai_s_monai-6969 + image_name = 'sweb.eval.x86_64.' + instance_id + image_name = image_name.replace('__', '_s_') # to comply with docker image naming convention + official_image_name = (docker_image_prefix.rstrip('/') + '/' + image_name).lower() + logger.info(f"Official {docker_image_prefix} image: {official_image_name}") + return official_image_name + +# NOTE: the below function is for SWE-Bench. +# def get_official_docker_image( +# instance_id: str, +# docker_image_prefix="docker.io/swebench/", +# ) -> str: +# # Official SWE-Bench image +# # swebench/sweb.eval.x86_64.django_1776_django-11333:v1 +# repo, name = instance_id.split("__") +# official_image_name = docker_image_prefix.rstrip("/") +# official_image_name += f"/sweb.eval.x86_64.{repo}_1776_{name}:latest".lower() +# logger.debug(f"Official SWE-Bench image: {official_image_name}") +# return official_image_name + +def extract_custom_tag(base_image: str) -> str: + """ + Extract SWE-Bench instance ID from official SWE-Bench image name. 
+ + Example: + docker.io/swebench/sweb.eval.x86_64.django_1776_django-12155:latest + -> sweb.eval.x86_64.django_1776_django-12155 + """ + name_tag = base_image.split("/")[-1] + name = name_tag.split(":")[0] + return name + +def detect_platform(): + """Detects the correct platform string.""" + machine = platform.machine().lower() + if "arm" in machine or "aarch64" in machine: + return "linux/arm64" + return "linux/amd64" + +def get_instruction( + instance: dict, + workspace_path: str, + prompt_path: str +) -> str: + """Generate user instruction for the agent for SWE-Bench-style tasks.""" + # Set up Jinja2 environment + # NOTE: Template will not work for SWE-Smith as its base commit is None + prompts_dir = os.path.dirname(prompt_path) + template_name = os.path.basename(prompt_path) + env = Environment(loader=FileSystemLoader(prompts_dir)) + template = env.get_template(template_name) + + instance["repo_path"] = workspace_path + # Prepare context for rendering + context = { + "instance": instance, + "actual_workspace_path": workspace_path, + } + context["test_instructions"] = "" + + # Render the instruction + instruction = template.render(context) + return instruction + +def prepare_workspace(instance: dict, task: Task) -> BaseWorkspace: + official_docker_image: str = get_official_docker_image(instance["instance_id"]) + build_target: str = "source-minimal" #NOTE: no other targets work, so this is hard-coded for the time being + custom_tag: str = extract_custom_tag(official_docker_image) + suffix: str = f"-{build_target}" if build_target != "binary" else "" + agent_server_image: str = f"{EVAL_AGENT_SERVER_IMAGE}:{SDK_SHORT_SHA}-{custom_tag}{suffix}" + + workspace_type: str = instance.get("workspace_type", "apptainer") #TODO: make sure the instance dict has this key + env_setup_commands = instance.get("env_setup_commands", ENV_SETUP_COMMANDS) #TODO: make sure the instance dict has this key + if workspace_type == "apptainer": + workspace = ApptainerWorkspace( + server_image=agent_server_image, + working_dir="/workspace", + platform=detect_platform(), + ) + elif workspace_type == "remote": + # TODO: check if the environment variables are passed till this point by AReaL + runtime_api_key = os.getenv("RUNTIME_API_KEY") + runtime_api_url = os.getenv("RUNTIME_API_URL", "https://runtime.eval.all-hands.dev") + workspace = APIRemoteWorkspace( + runtime_api_url=runtime_api_url, + runtime_api_key=runtime_api_key, + server_image=agent_server_image, + target_type="source" if "source" in build_target else "binary", + ) + else: #NOTE: Docker workspace not supported yet since Babel doesn't allow docker access + raise NotImplementedError(f"Workspace type {workspace_type} not implemented yet.") + for cmd in env_setup_commands: + res = workspace.execute_command(cmd) + if res.exit_code != 0: + raise RuntimeError( + f"Failed to run env setup command '{cmd}': {res.stderr}" + ) + logger.debug(f"Ran env setup command '{cmd}': {res.stdout}") + # NOTE: Setup repository in workspace (note that we assume the workspace is remote and has the repo pre-configured from SWE-{Bench, Gym, Smith}'s docker containers) + repo_path = f"/workspace/{instance['repo'].split('/')[-1]}/" + logger.info(f"Repo path in Remote workspace: {repo_path}") + instance["repo_path"] = repo_path + + cp_testbed_repo = workspace.execute_command( + (f"mkdir -p {repo_path} ; cp -r /testbed/. 
{repo_path}") + ) + assert cp_testbed_repo.exit_code == 0, ( + f"cp_testbed_repo failed: {cp_testbed_repo.stderr}" + ) + git_reset = workspace.execute_command(f"cd {repo_path} ; git reset --hard") + assert git_reset.exit_code == 0, f"git reset failed: {git_reset.stderr}" + return workspace + +def prepare_llm(config: RolloutConfig) -> LLM: + is_train: bool = config.train + # TODO: make more adjustments based on training phase + if is_train: + temperature = 1.0 + else: + temperature = 0.6 + + return LLM( + usage_id="agent", + model=config.model_name, + base_url=config.model_endpoint, + api_key=SecretStr(config.model_api_key) if config.model_api_key is not None else None, + temperature=temperature, + litellm_extra_body={ + # "return_token_ids": True, + "include_stop_str_in_output": False, + "add_generation_prompt": True, + "chat_template_kwargs": { + "enable_thinking": False, + } + }, + ) + +def prepare_agent(llm: LLM) -> AgentBase: + # TODO: make tools configurable via instance/env vars or config + # current behaviour: uses default tools without browser + return get_default_agent(llm=llm, cli_mode=True) # browser is added iff cli_mode is False + +async def run_rollout(task: Task, config: RolloutConfig) -> dict | TrajectoryCollection: + agent = env = None + try: + """ + Steps: + 1. Create a new workspace (apptainer/remote/docker), openhands agent, and initialize env + 2. Create trajectory collection and register event handlers + """ + instance: dict = task.misc # SWE-Bench styled instance, with extra keys: "workspace_type", "docker_image_prefix", "dataset_type", etc. + workspace: BaseWorkspace = prepare_workspace(instance) + + # Get task-specific instruction and configure task parameters + prompt_filename = instance.get("prompt_filename", PROMPT_FILENAME) #NOTE: make sure the instance dict has this key if customized prompt is desired + prompt_dir = (Path(__file__).parent / "prompts").resolve() + prompt_path = prompt_dir / prompt_filename + assert prompt_path.exists(), f"Prompt path {prompt_path} not found" + prompt_path = str(prompt_path) + repo_path = f"/workspace/{instance['repo'].split('/')[-1]}/" + instruction = get_instruction(instance, repo_path, prompt_path) + task.goal = instruction + task.max_steps = config.max_steps if config.max_steps is not None else 100 + + llm: LLM = prepare_llm(config) + agent: AgentBase = prepare_agent(llm) + agent_wrapper_platoon: OpenHandsAgent = OpenHandsAgent() + env: OpenHandsRLEnv = OpenHandsRLEnv(task=task, agent=agent, workspace=workspace) + + traj_collection = TrajectoryCollection() + current_trajectory_collection.set(traj_collection) + + events_path = os.path.join( + config.output_dir, + "events", + f"events_{task.id}_{traj_collection.id}.jsonl" + ) + + traj_collection.register_event_handlers( + JsonlFileSink( + events_path, + collection_id=traj_collection.id, + process_id=os.getpid() + ) + ) + + if config.verbose: + logger.info(f"Process {os.getpid()}: Starting rollout for task {task.id}") + + rollout_task = asyncio.create_task(run_episode(agent_wrapper_platoon, env)) #NOTE: run_episode only calls agent_act which will check the event stream for new actions/observations from the agent-sdk's conversation state + + try: + _ = await asyncio.wait_for(rollout_task, timeout=config.timeout) + except asyncio.TimeoutError: + if config.verbose: + logger.error(f"Process {os.getpid()}: Rollout timed out for task {task.id}") + rollout_task.cancel() + # Don't wait indefinitely - tinker's sample_async may not be cancellable + try: + await 
asyncio.wait_for(rollout_task, timeout=5.0) + except (asyncio.TimeoutError, asyncio.CancelledError): + logger.warning(f"Process {os.getpid()}: Task cancellation did not complete in 5s for {task.id}, abandoning") + raise + if config.return_dict: + return current_trajectory_collection.get().to_dict() + else: + return current_trajectory_collection.get() + except Exception as e: + if config.verbose: + print(f"Error running rollout for task {task.id}: {e}") + raise + finally: + if agent_wrapper_platoon is not None: + await agent_wrapper_platoon.close() + if env is not None: + await env.close() \ No newline at end of file diff --git a/plugins/openhands_rl/platoon/openhands_rl/tasks.py b/plugins/openhands_rl/platoon/openhands_rl/tasks.py new file mode 100644 index 0000000..f2cc0e3 --- /dev/null +++ b/plugins/openhands_rl/platoon/openhands_rl/tasks.py @@ -0,0 +1,61 @@ +from platoon.envs.base import Task +import pandas as pd +import os +from typing import Dict, Literal, Optional, List +import numpy as np + +EVAL_AGENT_SERVER_IMAGE = "ghcr.io/openhands/eval-agent-server" +SDK_SHORT_SHA = "main" +ENV_SETUP_COMMANDS = ["export PIP_CACHE_DIR=~/.cache/pip"] +PROMPT_FILENAME = "default.j2" +data_loaded: bool = False +train_data_map: Optional[Dict[str, Task]] = {} +val_data_map: Optional[Dict[str, Task]] = {} + +def create_task_from_instance(x: dict) -> Task: + task = Task( + id=x['instance_id'], + misc=x, + # NOTE: optionally add new parameters to instance dicts here if needed + # misc={ + # "instance_id": x['instance_id'], + # "repo": x['repo'], + # "base_commit": x['base_commit'], + # "problem_statement": x['problem_statement'], + # "target": x['target'], + # "workspace_type": x.get("workspace_type", "docker"), # default to docker + # "docker_image_prefix": x.get("docker_image_prefix", "docker.io/xingyaoww/"), + # "dataset_type": x.get("dataset_type", "swe-bench"), + # "prompt_filename": x.get("prompt_filename", PROMPT_FILENAME), + # } + ) + return task + +def load_data(): + global data_loaded, train_data_map, val_data_map + if data_loaded: + return train_data_map, val_data_map + data_path = os.path.join(os.path.dirname(__file__), "train.parquet") #NOTE: make it huggingface dataset if possible + dataset = pd.read_parquet(data_path) + np.random.seed(42) + split_indices = np.random.rand(len(dataset)) < 0.8 + train_df = dataset.iloc[split_indices] + val_df = dataset.iloc[~split_indices] + for _, row in train_df.iterrows(): + train_data_map[row['instance_id']] = create_task_from_instance(row.to_dict()) + for _, row in val_df.iterrows(): + val_data_map[row['instance_id']] = create_task_from_instance(row.to_dict()) + data_loaded = True + return train_data_map, val_data_map + + +# NOTE: we should have enough RAM to hold all the training instances in RAM since they are only <200MB in size, so no need to lazy load from disk for now +def get_task(task_id: str) -> Task: + load_data() + global train_data_map, val_data_map + if task_id in train_data_map: + return train_data_map[task_id] + elif task_id in val_data_map: + return val_data_map[task_id] + else: + raise ValueError(f"Task ID {task_id} not found in training or validation data.") \ No newline at end of file diff --git a/plugins/openhands_rl/platoon/openhands_rl/train.py b/plugins/openhands_rl/platoon/openhands_rl/train.py new file mode 100644 index 0000000..7c13b22 --- /dev/null +++ b/plugins/openhands_rl/platoon/openhands_rl/train.py @@ -0,0 +1,38 @@ +import sys +import logging +from datasets import Dataset +from areal.api.cli_args import load_expr_config 
+# Enable debug logging for platoon workflows +logging.basicConfig(level=logging.WARNING) # Quiet by default +logging.getLogger("platoon.train.areal.workflows").setLevel(logging.DEBUG) +logging.getLogger("httpx").setLevel(logging.WARNING) # Silence httpx spam + +from platoon.openhands_rl.tasks import get_task, load_data +from platoon.openhands_rl.rollout import run_rollout +from platoon.train.areal import PlatoonArealRLTrainer, PlatoonArealRLTrainerConfig +from platoon.train.areal.workflows import StepWiseArealWorkflow + +def main(args): + config, _ = load_expr_config(args, PlatoonArealRLTrainerConfig) + config: PlatoonArealRLTrainerConfig = config + + train_datamap, val_datamap = load_data() + train_dataset = Dataset.from_list([{ "task_id": x } for x in train_datamap.keys()][:1000]) + val_dataset = Dataset.from_list([{ "task_id": x } for x in val_datamap.keys()][:100]) + + with PlatoonArealRLTrainer( + config=config, + train_dataset=train_dataset, + val_dataset=val_dataset, + ) as trainer: + + proxy_server = trainer.proxy_server + # TODO: do we need custom reward processor here? + workflow = StepWiseArealWorkflow(run_rollout, get_task, config.workflow_config, proxy_server, 'train_rollout', trainer.actor.device, filter_errors=True) + eval_workflow = StepWiseArealWorkflow(run_rollout, get_task, config.workflow_config, proxy_server, 'eval_rollout', trainer.actor.device) + + trainer.train( + workflow=workflow, + eval_workflow=eval_workflow, + ) + diff --git a/plugins/openhands_rl/pyproject.toml b/plugins/openhands_rl/pyproject.toml new file mode 100644 index 0000000..8266890 --- /dev/null +++ b/plugins/openhands_rl/pyproject.toml @@ -0,0 +1,64 @@ +[project] +name = "platoon-openhands-rl" +version = "0.1.0" +description = "Platoon plugin for intermediate rewards with the openhands software agent sdk." +requires-python = "~=3.12.0" +authors = [ + {name = "Aditya Soni", email = "adityabs@cs.cmu.edu"} +] +dependencies = [ + "platoon >= 0.1.0", + "openhands-sdk", + "openhands-tools", + "openhands-workspace", + "openhands-agent-server" +] +[project.optional-dependencies] +# Training backends - install one of these for training +tinker = [ + "platoon[tinker]", +] +# NOTE: areal backend requires uv for installation (not available on PyPI) +areal = [ + "platoon[areal]", +] +# Logging integrations +wandb = [ + "platoon[wandb]", +] +# uv-specific configuration +[tool.uv] +no-build-isolation-package = ['flash-attn'] +# tinker and areal backends are mutually exclusive +conflicts = [ + [ + { extra = "tinker" }, + { extra = "areal" }, + ], +] +override-dependencies = [ + "fastapi[standard]>=0.115.0", + "openai==1.99.6", + "xgrammar==0.1.24", + "outlines-core==0.1.26", + "pyarrow==20.0.0", + "huggingface_hub==0.34", + "datasets==4.3.0", + "networkx==3.3.0" # This can be removed if ai-rubric pins 3.3.0 or areal relaxes the pin. +] +[tool.uv.sources] +platoon = { path = "../..", editable = true } + +[tool.ruff] +line-length = 100 + +[tool.ruff.lint] +select = ["E", "F", "I"] + +[build-system] +requires = ["hatchling"] +build-backend = "hatchling.build" + + +[tool.hatch.build.targets.wheel] +packages = ["platoon"]
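For reference, the new entry points compose roughly as follows. This is a minimal, hypothetical sketch and not part of the diff: it assumes `RolloutConfig` can be constructed with the keyword fields that `rollout.py` reads (`train`, `model_name`, `model_endpoint`, `model_api_key`, `max_steps`, `timeout`, `output_dir`, `verbose`, `return_dict`), and the instance ID and endpoint shown are only illustrative.

```python
# Hypothetical usage sketch: run a single rollout outside of AReaL training.
# All field names below mirror the attributes accessed in rollout.py; whether
# RolloutConfig accepts them as constructor arguments is an assumption.
import asyncio

from platoon.config_defs import RolloutConfig
from platoon.openhands_rl.rollout import run_rollout
from platoon.openhands_rl.tasks import get_task


def main() -> None:
    # Look up a SWE-Gym style instance by its instance_id (loaded from train.parquet).
    task = get_task("project-monai__monai-6969")  # illustrative instance_id

    config = RolloutConfig(
        train=False,                  # eval-style sampling temperature (0.6)
        model_name="my-model",        # model served behind an OpenAI-compatible endpoint
        model_endpoint="http://localhost:8000/v1",
        model_api_key=None,
        max_steps=50,
        timeout=3600,
        output_dir="./rollout_outputs",
        verbose=True,
        return_dict=True,
    )

    # run_rollout builds the workspace, agent, and OpenHandsRLEnv, then runs one episode.
    result = asyncio.run(run_rollout(task, config))
    print(result)


if __name__ == "__main__":
    main()
```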