Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
436 changes: 436 additions & 0 deletions examples/interpolation_demo.ipynb

Large diffs are not rendered by default.

62 changes: 50 additions & 12 deletions src/pyetm/clients/session.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
import asyncio
import threading
from typing import Optional, Dict, Any
from pydantic import BaseModel, field_validator, ValidationInfo
from pydantic import BaseModel

import aiohttp
from pyetm.config.settings import get_settings
Expand Down Expand Up @@ -43,22 +43,57 @@ def content(self) -> bytes:
"""Get response content as bytes."""
return self._content or self.text.encode("utf-8")

@field_validator("status_code", mode="before")
@classmethod
def raise_for_status(cls, value, info: ValidationInfo) -> None:
def _format_error_message(self) -> str:
"""Format error message from API response in a readable way."""
import json

try:
# Try to parse as JSON
data = json.loads(self.text)

# Handle {"errors": {...}} format
if isinstance(data, dict) and "errors" in data:
errors = data["errors"]

# Handle {"errors": {"field": ["message1", "message2"]}}
if isinstance(errors, dict):
messages = []
for field, field_errors in errors.items():
if isinstance(field_errors, list):
for error in field_errors:
messages.append(f"{field}: {error}")
else:
messages.append(f"{field}: {field_errors}")

if messages:
return f"{self.status_code}: {', '.join(messages)}"

# Handle {"errors": ["message1", "message2"]}
elif isinstance(errors, list):
return f"{self.status_code}: {', '.join(str(e) for e in errors)}"

# Handle {"error": "message"} format
if isinstance(data, dict) and "error" in data:
return f"{self.status_code}: {data['error']}"

except (json.JSONDecodeError, KeyError, TypeError):
pass

# Fall back to raw text (strip to avoid unnecessary whitespace)
text = self.text.strip()
return f"{self.status_code}: {text}" if text else f"{self.status_code}"

def raise_for_status(self) -> None:
"""Raise appropriate exception for HTTP errors."""
if value == 401:
if self.status_code == 401:
raise PermissionError("Invalid or missing ETM_API_TOKEN")

text = info.data.get("text", "")

if 400 <= value < 500:
raise ValueError(f"HTTP {value}: {text}")
if 400 <= self.status_code < 500:
raise ValueError(self._format_error_message())

if 500 <= value < 600:
raise ConnectionError(f"HTTP {value}: {text}")
if 500 <= self.status_code < 600:
raise ConnectionError(self._format_error_message())

return value


# TODO: Extract utils and organise private methods
Expand Down Expand Up @@ -173,6 +208,9 @@ async def async_request(self, method: str, url: str, **kwargs) -> ETMResponse:
etm_response.text = await response.text()
etm_response._content = await response.read()

# Check for HTTP errors after response is fully constructed
etm_response.raise_for_status()

return etm_response

def _build_request_kwargs(self, **kwargs) -> dict:
Expand Down
49 changes: 45 additions & 4 deletions src/pyetm/models/scenario.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,7 +67,7 @@ def create(

Args:
params: Dictionary with required keys (scenario_id, title) and optional keys
(description, private)
(private)
client: Optional BaseClient instance

Returns:
Expand Down Expand Up @@ -111,7 +111,7 @@ def from_scenario(
scenario: Scenario instance to save
title: Title for the saved scenario
client: Optional BaseClient instance
**kwargs: Optional params (description, private)
**kwargs: Optional params (private)

Returns:
SavedScenario instance
Expand Down Expand Up @@ -169,7 +169,7 @@ def new(
scenario_id: The ETEngine scenario ID to save
title: Title for the saved scenario
client: Optional BaseClient instance
**kwargs: Optional params (description, private)
**kwargs: Optional params (private)

Returns:
SavedScenario instance
Expand Down Expand Up @@ -209,7 +209,7 @@ def update(self, client: Optional[BaseClient] = None, **kwargs) -> None:

Args:
client: Optional BaseClient instance
**kwargs: Fields to update (title, description, private, discarded)
**kwargs: Fields to update (title, private, discarded)
"""
if client is None:
client = BaseClient()
Expand Down Expand Up @@ -397,6 +397,47 @@ def deep_copy(self, **overrides) -> "Session":
"""Create a deep copy of the underlying session."""
return self.session.deep_copy(**overrides)

@classmethod
def interpolate(
cls,
scenarios: Union["Scenario", List["Scenario"]],
*end_years: int,
titles: Optional[List[str]] = None,
client: Optional[BaseClient] = None,
**kwargs,
) -> List["Scenario"]:
"""
Interpolate one or more saved scenarios to target years and save to MyETM.
"""
end_years_list = list(end_years)

if titles is not None and len(titles) != len(end_years_list):
raise ValueError(
f"Length of titles ({len(titles)}) must match length of "
f"end_years ({len(end_years_list)})"
)

# Get underlying sessions and perform interpolation
from pyetm.models.session import Session

scenario_list = scenarios if isinstance(scenarios, list) else [scenarios]
sessions = [sc.session for sc in scenario_list]
interpolated_sessions = Session.interpolate(sessions, *end_years, client=client)

# Save each interpolated session as a SavedScenario
saved_scenarios_list = []
for i, session in enumerate(interpolated_sessions):
# Generate title if not provided
if titles:
title = titles[i]
else:
title = f"Interpolated to {session.end_year}"

saved = session.save(client=client, title=title, **kwargs)
saved_scenarios_list.append(saved)

return saved_scenarios_list

def to_excel(self, path: PathLike | str, **export_options) -> None:
"""Export this saved scenario to Excel."""
self.session.to_excel(path, **export_options)
Expand Down
34 changes: 34 additions & 0 deletions src/pyetm/models/session.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,9 @@
from pyetm.services.scenario_runners.update_couplings import UpdateCouplingsRunner
from pyetm.services.scenario_runners.copy_scenario import CopyScenarioRunner
from pyetm.services.scenario_runners.break_preset_link import BreakPresetLinkRunner
from pyetm.services.scenario_runners.interpolate_scenarios import (
InterpolateScenariosRunner,
)


class ScenarioError(Exception):
Expand Down Expand Up @@ -159,6 +162,37 @@ def copy(self, **overrides) -> "Session":

return new_scenario

@classmethod
def interpolate(
cls,
sessions: Union["Session", List["Session"]],
*end_years: int,
client: Optional[BaseClient] = None,
) -> List["Session"]:
"""
Interpolate one or more sessions to create scenarios at target years.
"""
client = client or BaseClient()

session_list = sessions if isinstance(sessions, list) else [sessions]
scenario_ids = [s.id for s in session_list]
end_years_list = list(end_years)

result = InterpolateScenariosRunner.run(client, scenario_ids, end_years_list)

if not result.success:
raise ScenarioError(f"Interpolation failed: {result.errors}")

# Create Session instances from the response data
interpolated_sessions = []
for scenario_data in result.data:
scenario = cls.model_validate(scenario_data)
for warning in result.errors:
scenario.add_warning("base", warning)
interpolated_sessions.append(scenario)

return interpolated_sessions

@classmethod
def from_excel(cls, xlsx_path: PathLike | str) -> "Session":
"""
Expand Down
5 changes: 2 additions & 3 deletions src/pyetm/services/scenario_runners/create_saved_scenario.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,12 +12,12 @@ class CreateSavedScenarioRunner(BaseRunner[Dict[str, Any]]):

Args:
client: The HTTP client to use
saved_scenario_data: Dictionary with scenario_id, title, description, private
saved_scenario_data: Dictionary with scenario_id, title, private
**kwargs: Additional arguments passed to the request
"""

REQUIRED_KEYS = ["scenario_id", "title"]
OPTIONAL_KEYS = ["description", "private"]
OPTIONAL_KEYS = ["private"]

@staticmethod
def run(
Expand All @@ -32,7 +32,6 @@ def run(
saved_scenario_data={
"scenario_id": 123,
"title": "My Saved Scenario",
"description": "Optional description",
"private": False
}
)
Expand Down
35 changes: 35 additions & 0 deletions src/pyetm/services/scenario_runners/interpolate_scenarios.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,35 @@
from typing import Any, Dict, List
from pyetm.services.scenario_runners.base_runner import BaseRunner
from ..service_result import ServiceResult
from pyetm.clients.base_client import BaseClient


class InterpolateScenariosRunner(BaseRunner[List[Dict[str, Any]]]):
"""
Runner for interpolating one or more scenarios to create intermediate year scenarios.

POST /api/v3/scenarios/interpolate

Supports both single scenario and batch interpolation use cases.
"""

@staticmethod
def run(
client: BaseClient,
scenario_ids: List[int],
end_years: List[int],
**kwargs,
) -> ServiceResult[List[Dict[str, Any]]]:
payload = {
"scenario_ids": scenario_ids,
"end_years": end_years,
}

result = InterpolateScenariosRunner._make_request(
client=client,
method="post",
path="/scenarios/interpolate",
payload=payload,
)

return result
4 changes: 2 additions & 2 deletions src/pyetm/services/scenario_runners/update_saved_scenario.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@ class UpdateSavedScenarioRunner(BaseRunner[Dict[str, Any]]):
Args:
client: The HTTP client to use
saved_scenario_id: ID of the SavedScenario to update
update_data: Dictionary with fields to update (title, description, private, discarded)
update_data: Dictionary with fields to update (title, private, discarded)
**kwargs: Additional arguments passed to the request
"""

Expand All @@ -35,7 +35,7 @@ def run(
saved_scenario_id=123,
update_data={
"title": "Updated Title",
"description": "New description"
"private": False
}
)
"""
Expand Down
Loading