diff --git a/alphatrion/experiment/base.py b/alphatrion/experiment/base.py index 83c53c5..db1aa2b 100644 --- a/alphatrion/experiment/base.py +++ b/alphatrion/experiment/base.py @@ -2,8 +2,22 @@ from abc import ABC, abstractmethod from dataclasses import dataclass +from pydantic import BaseModel, Field + from alphatrion.runtime.runtime import global_runtime from alphatrion.trial import trial +from alphatrion.utils import context + + +class ExperimentConfig(BaseModel): + """Configuration for a Experiment.""" + + max_execution_seconds: int = Field( + default=-1, + description="Maximum execution seconds for the experiment. \ + Once exceeded, the experiment and all its trials will be cancelled. \ + Default is -1 (no limit).", + ) @dataclass @@ -12,14 +26,34 @@ class Experiment(ABC): Base Experiment class. One instance one experiment, multiple trials. """ - __slots__ = ("_runtime", "_id", "_trials") + __slots__ = ("_runtime", "_id", "_trials", "_config") + + @classmethod + @abstractmethod + def setup( + cls, + name: str, + description: str | None = None, + meta: dict | None = None, + config: ExperimentConfig | None = None, + ) -> "Experiment": + """Return a new experiment.""" + ... - def __init__(self): + def __init__(self, config: ExperimentConfig | None = None): self._runtime = global_runtime() # All trials in this experiment, key is trial_id, value is Trial instance. - self._trials = dict() + self._trials: dict[int, trial.Trial] = {} + self._config = config or ExperimentConfig() self._runtime.current_exp = self + self._context = context.Context( + cancel_func=self._stop, + timeout=self._config.max_execution_seconds + if self._config.max_execution_seconds > 0 + else None, + ) + @property def id(self): return self._id @@ -40,7 +74,14 @@ async def __aenter__(self): async def __aexit__(self, exc_type, exc_val, exc_tb): self.done() + # done() is safe to call multiple times. def done(self): + self._cancel() + + def _cancel(self): + return self._context.cancel() + + def _stop(self): for t in list(self._trials.values()): t.done() self._trials = dict() @@ -48,14 +89,6 @@ def done(self): # it will be used in trial.done(). self._runtime.current_exp = None - @classmethod - @abstractmethod - def setup( - cls, name: str, description: str | None = None, meta: dict | None = None - ) -> "Experiment": - """Return a new experiment.""" - ... - def register_trial(self, id: uuid.UUID, instance: trial.Trial): self._trials[id] = instance diff --git a/alphatrion/experiment/craft_exp.py b/alphatrion/experiment/craft_exp.py index 61001c8..4bc4fa0 100644 --- a/alphatrion/experiment/craft_exp.py +++ b/alphatrion/experiment/craft_exp.py @@ -1,4 +1,4 @@ -from alphatrion.experiment.base import Experiment +from alphatrion.experiment.base import Experiment, ExperimentConfig from alphatrion.trial.trial import Trial, TrialConfig @@ -10,8 +10,8 @@ class CraftExperiment(Experiment): Opposite to other experiment classes, you need to call all these methods yourself. """ - def __init__(self): - super().__init__() + def __init__(self, config: ExperimentConfig | None = None): + super().__init__(config=config) @classmethod def setup( @@ -19,13 +19,14 @@ def setup( name: str, description: str | None = None, meta: dict | None = None, + config: ExperimentConfig | None = None, ) -> "CraftExperiment": """ Setup the experiment. If the name already exists in the same project, it will refer to the existing experiment instead of creating a new one. """ - exp = CraftExperiment() + exp = CraftExperiment(config) exp_obj = exp._get_by_name(name=name, project_id=exp._runtime._project_id) # If experiment with the same name exists in the project, use it. diff --git a/alphatrion/utils/context.py b/alphatrion/utils/context.py index c08b54f..7a91ecd 100644 --- a/alphatrion/utils/context.py +++ b/alphatrion/utils/context.py @@ -3,6 +3,7 @@ # Inspired by golang context package +# TODO: support inheritable context class Context: def __init__(self, cancel_func: Callable | None = None, timeout: int | None = None): """A context for managing cancellation and timeouts. diff --git a/tests/unit/experiment/test_craft_exp.py b/tests/unit/experiment/test_craft_exp.py index ffc4c75..2ef8f79 100644 --- a/tests/unit/experiment/test_craft_exp.py +++ b/tests/unit/experiment/test_craft_exp.py @@ -5,7 +5,7 @@ import pytest -from alphatrion.experiment.craft_exp import CraftExperiment +from alphatrion.experiment.craft_exp import CraftExperiment, ExperimentConfig from alphatrion.metadata.sql_models import Status from alphatrion.runtime.runtime import global_runtime, init from alphatrion.trial.trial import Trial, TrialConfig, current_trial_id @@ -192,7 +192,7 @@ async def fake_work(timeout: int): @pytest.mark.asyncio -async def test_craft_experiment_with_context(): +async def test_craft_experiment_with_max_execution_seconds(): init(project_id=uuid.uuid4(), artifact_insecure=True, init_tables=True) async with CraftExperiment.setup( @@ -243,3 +243,70 @@ async def fake_work(): fake_work(), ) print("All trials finished.") + + +@pytest.mark.asyncio +async def test_craft_experiment_with_config(): + init(project_id=uuid.uuid4(), artifact_insecure=True, init_tables=True) + + async with CraftExperiment.setup( + name="context_exp", + description="Context manager test", + meta={"key": "value"}, + config=ExperimentConfig(max_execution_seconds=2), + ) as exp: + trial = exp.start_trial(name="first-trial") + await trial.wait() + assert trial.is_done() + + trial = trial._get_obj() + assert trial.status == Status.COMPLETED + + +@pytest.mark.asyncio +async def test_craft_experiment_with_hierarchy_timeout(): + init(project_id=uuid.uuid4(), artifact_insecure=True, init_tables=True) + + async with CraftExperiment.setup( + name="context_exp", + description="Context manager test", + meta={"key": "value"}, + config=ExperimentConfig(max_execution_seconds=2), + ) as exp: + start_time = datetime.now() + trial = exp.start_trial( + name="first-trial", config=TrialConfig(max_execution_seconds=5) + ) + await trial.wait() + assert trial.is_done() + + assert (datetime.now() - start_time).total_seconds() >= 2 + assert (datetime.now() - start_time).total_seconds() < 5 + + trial = trial._get_obj() + assert trial.status == Status.COMPLETED + +@pytest.mark.asyncio +async def test_craft_experiment_with_hierarchy_timeout_2(): + init(project_id=uuid.uuid4(), artifact_insecure=True, init_tables=True) + + start_time = datetime.now() + + async with CraftExperiment.setup( + name="context_exp", + description="Context manager test", + meta={"key": "value"}, + config=ExperimentConfig(max_execution_seconds=5), + ) as exp: + trial = exp.start_trial( + name="first-trial", config=TrialConfig(max_execution_seconds=2) + ) + await trial.wait() + assert trial.is_done() + + assert (datetime.now() - start_time).total_seconds() >= 2 + + trial = trial._get_obj() + assert trial.status == Status.COMPLETED + + assert (datetime.now() - start_time).total_seconds() < 5