Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
55 changes: 44 additions & 11 deletions alphatrion/experiment/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,22 @@
from abc import ABC, abstractmethod
from dataclasses import dataclass

from pydantic import BaseModel, Field

from alphatrion.runtime.runtime import global_runtime
from alphatrion.trial import trial
from alphatrion.utils import context


class ExperimentConfig(BaseModel):
    """Configuration for an Experiment."""

    # Upper bound on the experiment's run time, in seconds. Any value <= 0
    # (the default is -1) is treated as "no limit" by Experiment.__init__.
    # Adjacent string literals are used instead of backslash continuations
    # so that source indentation can never leak into the description text.
    max_execution_seconds: int = Field(
        default=-1,
        description=(
            "Maximum execution seconds for the experiment. "
            "Once exceeded, the experiment and all its trials will be cancelled. "
            "Default is -1 (no limit)."
        ),
    )


@dataclass
Expand All @@ -12,14 +26,34 @@ class Experiment(ABC):
Base Experiment class. One instance one experiment, multiple trials.
"""

# Every attribute assigned in __init__ is declared here, including the
# cancellation context stored in `_context`.
__slots__ = ("_runtime", "_id", "_trials", "_config", "_context")

@classmethod
@abstractmethod
def setup(
    cls,
    name: str,
    description: str | None = None,
    meta: dict | None = None,
    config: ExperimentConfig | None = None,
) -> "Experiment":
    """Return a new experiment.

    Abstract alternate constructor; concrete subclasses must implement it.

    Args:
        name: Name of the experiment.
        description: Optional description of the experiment.
        meta: Optional metadata dictionary attached to the experiment.
        config: Optional experiment configuration.

    Returns:
        An ``Experiment`` instance.
    """
    ...

def __init__(self, config: ExperimentConfig | None = None):
    """Bind the runtime, create the trial registry and the cancel context.

    Args:
        config: Optional experiment configuration. A default
            ``ExperimentConfig`` is used when none is given.
    """
    self._runtime = global_runtime()
    # All trials in this experiment, key is trial_id, value is Trial instance.
    # Keys are UUIDs, matching the `id` parameter of register_trial().
    self._trials: dict[uuid.UUID, trial.Trial] = {}
    self._config = config or ExperimentConfig()
    self._runtime.current_exp = self

    # A non-positive limit (the default is -1) means "no limit", which is
    # passed to Context as timeout=None.
    max_seconds = self._config.max_execution_seconds
    self._context = context.Context(
        cancel_func=self._stop,
        timeout=max_seconds if max_seconds > 0 else None,
    )

@property
def id(self):
    """Return the experiment identifier stored in ``_id``."""
    # NOTE(review): `_id` is not assigned in __init__; presumably it is set
    # by the subclass's setup() -- confirm.
    return self._id
Expand All @@ -40,22 +74,21 @@ async def __aenter__(self):
async def __aexit__(self, exc_type, exc_val, exc_tb):
    """Finish the experiment when leaving the ``async with`` block.

    Implicitly returns ``None``, so an exception raised inside the block
    is not suppressed.
    """
    self.done()

# done() is safe to call multiple times.
def done(self):
    """Finish the experiment by delegating to ``_cancel()``."""
    self._cancel()

def _cancel(self):
    """Cancel the experiment's context and return ``Context.cancel()``'s result."""
    # The context was created with cancel_func=self._stop, so cancelling
    # presumably invokes _stop() -- confirm against Context.cancel().
    return self._context.cancel()

def _stop(self):
    """Finish every registered trial, then detach from the runtime."""
    # Work on a snapshot of the registry; presumably trial.done() may
    # touch `_trials` while we iterate -- confirm.
    registered = list(self._trials.values())
    for registered_trial in registered:
        registered_trial.done()
    self._trials = {}
    # current_exp is cleared only after all trials are finished because
    # trial.done() still relies on it.
    self._runtime.current_exp = None

@classmethod
@abstractmethod
def setup(
cls, name: str, description: str | None = None, meta: dict | None = None
) -> "Experiment":
"""Return a new experiment."""
...

def register_trial(self, id: uuid.UUID, instance: trial.Trial):
self._trials[id] = instance

Expand Down
9 changes: 5 additions & 4 deletions alphatrion/experiment/craft_exp.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from alphatrion.experiment.base import Experiment
from alphatrion.experiment.base import Experiment, ExperimentConfig
from alphatrion.trial.trial import Trial, TrialConfig


Expand All @@ -10,22 +10,23 @@ class CraftExperiment(Experiment):
Unlike other experiment classes, you need to call all these methods yourself.
"""

def __init__(self):
super().__init__()
def __init__(self, config: ExperimentConfig | None = None):
    """Create the experiment, forwarding ``config`` to ``Experiment.__init__``.

    Args:
        config: Optional experiment configuration; ``None`` is passed
            through unchanged to the base class.
    """
    super().__init__(config=config)

@classmethod
def setup(
cls,
name: str,
description: str | None = None,
meta: dict | None = None,
config: ExperimentConfig | None = None,
) -> "CraftExperiment":
"""
Setup the experiment. If the name already exists in the same project,
it will refer to the existing experiment instead of creating a new one.
"""

exp = CraftExperiment()
exp = CraftExperiment(config)
exp_obj = exp._get_by_name(name=name, project_id=exp._runtime._project_id)

# If experiment with the same name exists in the project, use it.
Expand Down
1 change: 1 addition & 0 deletions alphatrion/utils/context.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@


# Inspired by Go's context package
# TODO: support inheritable context
class Context:
def __init__(self, cancel_func: Callable | None = None, timeout: int | None = None):
"""A context for managing cancellation and timeouts.
Expand Down
71 changes: 69 additions & 2 deletions tests/unit/experiment/test_craft_exp.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@

import pytest

from alphatrion.experiment.craft_exp import CraftExperiment
from alphatrion.experiment.craft_exp import CraftExperiment, ExperimentConfig
from alphatrion.metadata.sql_models import Status
from alphatrion.runtime.runtime import global_runtime, init
from alphatrion.trial.trial import Trial, TrialConfig, current_trial_id
Expand Down Expand Up @@ -192,7 +192,7 @@ async def fake_work(timeout: int):


@pytest.mark.asyncio
async def test_craft_experiment_with_context():
async def test_craft_experiment_with_max_execution_seconds():
init(project_id=uuid.uuid4(), artifact_insecure=True, init_tables=True)

async with CraftExperiment.setup(
Expand Down Expand Up @@ -243,3 +243,70 @@ async def fake_work():
fake_work(),
)
print("All trials finished.")


@pytest.mark.asyncio
async def test_craft_experiment_with_config():
    """A trial without its own limit finishes under a 2s experiment limit."""
    init(project_id=uuid.uuid4(), artifact_insecure=True, init_tables=True)

    exp_config = ExperimentConfig(max_execution_seconds=2)
    experiment = CraftExperiment.setup(
        name="context_exp",
        description="Context manager test",
        meta={"key": "value"},
        config=exp_config,
    )
    async with experiment as exp:
        first_trial = exp.start_trial(name="first-trial")
        await first_trial.wait()
        assert first_trial.is_done()

        # Check the stored trial object, not the in-memory handle.
        trial_record = first_trial._get_obj()
        assert trial_record.status == Status.COMPLETED


@pytest.mark.asyncio
async def test_craft_experiment_with_hierarchy_timeout():
    """The experiment limit (2s) ends a trial before the trial's own limit (5s)."""
    init(project_id=uuid.uuid4(), artifact_insecure=True, init_tables=True)

    exp_config = ExperimentConfig(max_execution_seconds=2)
    trial_config = TrialConfig(max_execution_seconds=5)
    experiment = CraftExperiment.setup(
        name="context_exp",
        description="Context manager test",
        meta={"key": "value"},
        config=exp_config,
    )
    async with experiment as exp:
        start_time = datetime.now()
        first_trial = exp.start_trial(name="first-trial", config=trial_config)
        await first_trial.wait()
        assert first_trial.is_done()

        # The wait must last at least the experiment limit and end before
        # the trial limit.
        assert (datetime.now() - start_time).total_seconds() >= 2
        assert (datetime.now() - start_time).total_seconds() < 5

        trial_record = first_trial._get_obj()
        assert trial_record.status == Status.COMPLETED

@pytest.mark.asyncio
async def test_craft_experiment_with_hierarchy_timeout_2():
    """The trial's own limit (2s) ends it before the experiment limit (5s)."""
    init(project_id=uuid.uuid4(), artifact_insecure=True, init_tables=True)

    start_time = datetime.now()

    exp_config = ExperimentConfig(max_execution_seconds=5)
    trial_config = TrialConfig(max_execution_seconds=2)
    experiment = CraftExperiment.setup(
        name="context_exp",
        description="Context manager test",
        meta={"key": "value"},
        config=exp_config,
    )
    async with experiment as exp:
        first_trial = exp.start_trial(name="first-trial", config=trial_config)
        await first_trial.wait()
        assert first_trial.is_done()

        assert (datetime.now() - start_time).total_seconds() >= 2

        trial_record = first_trial._get_obj()
        assert trial_record.status == Status.COMPLETED

    # The whole experiment, including teardown, must end before its own limit.
    assert (datetime.now() - start_time).total_seconds() < 5
Loading