From c6158848ba85dd9d813cc53c55d54e10bc40ec83 Mon Sep 17 00:00:00 2001 From: louispt1 Date: Mon, 22 Dec 2025 08:46:40 +0100 Subject: [PATCH] Implement mixed scenarios model to hold both sessions and scenarios --- examples/saved_scenarios_everywhere.ipynb | 36 +++-- src/pyetm/models/scenarios.py | 69 +++++---- tests/models/test_scenarios_collections.py | 155 +++++++++++++-------- 3 files changed, 165 insertions(+), 95 deletions(-) diff --git a/examples/saved_scenarios_everywhere.ipynb b/examples/saved_scenarios_everywhere.ipynb index 9d52ff9..c837fc4 100644 --- a/examples/saved_scenarios_everywhere.ipynb +++ b/examples/saved_scenarios_everywhere.ipynb @@ -5,7 +5,9 @@ "id": "cba27484", "metadata": {}, "source": [ - "Testing notebook for interpreting everything as a saved scenario and delegating everywhere to scenario methods.\n" + "Testing notebook for loading mixed Scenario and Session types from Excel.\n", + "\n", + "This notebook demonstrates the simplified API where `Scenarios.from_excel()` loads both SavedScenario (from MyETM) and Session (from ETEngine) instances." ] }, { @@ -37,10 +39,9 @@ "outputs": [], "source": [ "from pyetm.models.scenarios import Scenarios\n", - "from pyetm.models.sessions import Sessions\n", "\n", - "scenarios = Scenarios.from_excel(\"../examples/inputs/example_input_excel.xlsx\")\n", - "sessions = Sessions.from_excel(\"../examples/inputs/example_input_excel.xlsx\")" + "# Load all scenarios (both SavedScenarios and Sessions) from the Excel file\n", + "mixed_scenarios = Scenarios.from_excel(\"../examples/inputs/example_input_excel.xlsx\")" ] }, { @@ -50,8 +51,9 @@ "metadata": {}, "outputs": [], "source": [ - "# Metadata\n", - "for scenario in scenarios:\n", + "# Iterate over all scenarios (both Scenario and Session instances)\n", + "# The delegation pattern means both types have the same interface\n", + "for scenario in mixed_scenarios:\n", " print(f\"Title: {scenario.title}\")\n", " print(f\"ID: {scenario.id}\")\n", " print(f\"Area: {scenario.area_code}\")\n", @@ -68,8 +70,9 @@ "metadata": {}, "outputs": [], "source": [ - "# Metadata\n", - "sessions.extend(scenarios.sessions)\n", + "# Access the underlying Session instances from all items\n", + "# For Scenario instances, this unwraps the session; for Session instances, returns them directly\n", + "sessions = mixed_scenarios.sessions\n", "\n", "for session in sessions:\n", " print(f\"Title: {session.title}\")\n", @@ -88,11 +91,13 @@ "metadata": {}, "outputs": [], "source": [ - "# Accessing the session data from the saved scenarios\n", - "for scenario in scenarios:\n", - " print(f\"Keep combatible: {scenario.keep_compatible}\")\n", - " print(f\"ID: {scenario.scenario_id}\")\n", - " print(f\"Inputs: {scenario.inputs.to_dataframe()}\")\n", + "# Accessing delegated properties works on all types\n", + "for scenario in mixed_scenarios:\n", + " print(f\"Keep compatible: {scenario.keep_compatible}\")\n", + " print(f\"ID: {scenario.id}\")\n", + " print(f\"Type: {'SavedScenario' if hasattr(scenario, 'scenario_id') and hasattr(scenario, 'session') else 'Session'}\")\n", + " # Uncomment to see inputs:\n", + " # print(f\"Inputs: {scenario.inputs.to_dataframe()}\")\n", " print(\"\")" ] }, @@ -111,8 +116,9 @@ "metadata": {}, "outputs": [], "source": [ - "# Export the scenarios to excel\n", - "# scenarios.to_excel(\"../examples/outputs/scenarios.xlsx\") # This will create scenarios.xlsx and scenarios_exports.xlsx (if you've set exports to true in the output config)." + "# Export all scenarios (both types) to Excel\n", + "# mixed_scenarios.to_excel(\"../examples/outputs/mixed_scenarios.xlsx\")\n", + "# This will create mixed_scenarios.xlsx (and mixed_scenarios_exports.xlsx if you've configured exports)" ] } ], diff --git a/src/pyetm/models/scenarios.py b/src/pyetm/models/scenarios.py index 5f13aed..04dab26 100644 --- a/src/pyetm/models/scenarios.py +++ b/src/pyetm/models/scenarios.py @@ -1,7 +1,7 @@ from __future__ import annotations from os import PathLike from pathlib import Path -from typing import Iterable, Iterator, List +from typing import Iterable, Iterator, List, Union from pydantic import Field from pyetm.models.session import Session from pyetm.models.base import Base @@ -10,37 +10,46 @@ class Scenarios(Base): """ - A collection of SavedScenario objects. + A collection of Scenario and Session objects. + + This class can hold both SavedScenario (Scenario) instances from MyETM + and Session instances from ETEngine. Since Scenario delegates all operations + to its underlying Session, both types can be used interchangeably. """ - items: List[Scenario] = Field(default_factory=list) + items: List[Union[Scenario, Session]] = Field(default_factory=list) - def __iter__(self) -> Iterator[Scenario]: + def __iter__(self) -> Iterator[Union[Scenario, Session]]: return iter(self.items) def __len__(self) -> int: return len(self.items) - def __getitem__(self, index: int) -> Scenario: + def __getitem__(self, index: int) -> Union[Scenario, Session]: return self.items[index] - def add(self, *saved_scenarios: Scenario) -> None: - self.items.extend(saved_scenarios) + def add(self, *scenarios: Union[Scenario, Session]) -> None: + self.items.extend(scenarios) - def extend(self, saved_scenarios: Iterable[Scenario]) -> None: - self.items.extend(list(saved_scenarios)) + def extend(self, scenarios: Iterable[Union[Scenario, Session]]) -> None: + self.items.extend(list(scenarios)) @property def sessions(self) -> List["Session"]: """ - Get the underlying ETEngine Scenario objects for all SavedScenarios. + Get the underlying ETEngine Session objects for all items. + + For Scenario (SavedScenario) instances, returns the underlying session. + For Session instances, returns them directly. Returns: - List of Scenario instances (the underlying sessions) + List of Session instances """ from pyetm.models.session import Session - return [saved.session for saved in self.items] + return [ + item.session if isinstance(item, Scenario) else item for item in self.items + ] @classmethod def load_many(cls, saved_scenario_ids: Iterable[int]) -> "Scenarios": @@ -63,15 +72,21 @@ def load_many(cls, saved_scenario_ids: Iterable[int]) -> "Scenarios": def to_excel(self, path: PathLike | str, **export_options) -> None: """ - Export all saved scenarios to Excel. + Export all scenarios to Excel. + + Exports both Scenario (SavedScenario) and Session instances. + For SavedScenario instances, the scenario_id column will contain + the MyETM SavedScenario ID. For Session instances, it will contain + the ETEngine session ID. - Note: This exports the underlying session data from each SavedScenario. - The scenario_id column will contain SavedScenario IDs (MyETM IDs). + Args: + path: Output path for the Excel file + **export_options: Additional export options to pass to ScenarioExcelService """ from pyetm.utils.scenario_excel_service import ScenarioExcelService if not self.items: - raise ValueError("No saved scenarios to export") + raise ValueError("No scenarios to export") resolved_path = Path(path).expanduser().resolve() ScenarioExcelService.export_to_excel( @@ -81,10 +96,17 @@ def to_excel(self, path: PathLike | str, **export_options) -> None: @classmethod def from_excel(cls, xlsx_path: PathLike | str) -> "Scenarios": """ - Import SavedScenarios from Excel file. + Import all scenarios from Excel file. + + Loads both SavedScenarios (Scenario instances from MyETM) and Sessions + (Session instances from ETEngine) based on the 'session' column value. + All scenarios are included regardless of type. - Only loads scenarios where the 'session' column is False or missing. - Scenarios with session=True are ignored. + Args: + xlsx_path: Path to Excel file + + Returns: + Scenarios collection containing all loaded scenarios (both types) """ from pyetm.models.scenario_packer import ScenarioPacker @@ -92,10 +114,9 @@ def from_excel(cls, xlsx_path: PathLike | str) -> "Scenarios": packer = ScenarioPacker.from_excel(str(resolved_path)) all_scenarios = list(packer._scenarios()) - saved_scenarios = [s for s in all_scenarios if isinstance(s, Scenario)] - if not saved_scenarios: - print(f"No SavedScenarios found in Excel file: {resolved_path}") + if not all_scenarios: + print(f"No scenarios found in Excel file: {resolved_path}") - saved_scenarios.sort(key=lambda s: s.id if hasattr(s, "id") else 0) - return cls(items=saved_scenarios) + all_scenarios.sort(key=lambda s: s.id) + return cls(items=all_scenarios) diff --git a/tests/models/test_scenarios_collections.py b/tests/models/test_scenarios_collections.py index 4cd36e0..34379c3 100644 --- a/tests/models/test_scenarios_collections.py +++ b/tests/models/test_scenarios_collections.py @@ -1,5 +1,3 @@ -"""Tests for Scenarios and SavedScenarios collection classes with from_excel filtering.""" - import pytest import pandas as pd import tempfile @@ -96,11 +94,11 @@ def test_from_excel_sorts_by_id(self, monkeypatch): assert result.items[2].id == 300 -class TestSavedScenariosFromExcel: - """Test SavedScenarios.from_excel() filtering for SavedScenario instances only.""" +class TestScenariosFromExcelMixed: + """Test Scenarios.from_excel() with mixed Scenario and Session instances.""" - def test_from_excel_filters_saved_scenarios_only(self, monkeypatch): - """Test that from_excel only returns SavedScenario instances.""" + def test_from_excel_includes_all_types(self, monkeypatch): + """Test that from_excel includes both SavedScenario and Session instances.""" session1 = Mock(spec=Session) session1.id = 100 saved1 = Mock(spec=Scenario) @@ -114,14 +112,14 @@ def test_from_excel_filters_saved_scenarios_only(self, monkeypatch): with patch.object(ScenarioPacker, "from_excel", return_value=mock_packer): result = Scenarios.from_excel("test.xlsx") - # Should only include SavedScenario instances - assert len(result.items) == 2 + # Should include all types + assert len(result.items) == 3 assert saved1 in result.items assert saved2 in result.items - assert session1 not in result.items + assert session1 in result.items - def test_from_excel_empty_when_no_saved_scenarios(self, monkeypatch): - """Test that from_excel returns empty collection when only Sessions exist.""" + def test_from_excel_all_sessions(self, monkeypatch): + """Test that from_excel returns all Sessions when only Sessions exist.""" session1 = Mock(spec=Session) session1.id = 100 session2 = Mock(spec=Session) @@ -133,8 +131,10 @@ def test_from_excel_empty_when_no_saved_scenarios(self, monkeypatch): with patch.object(ScenarioPacker, "from_excel", return_value=mock_packer): result = Scenarios.from_excel("test.xlsx") - # Should be empty - assert len(result.items) == 0 + # Should include all sessions + assert len(result.items) == 2 + assert session1 in result.items + assert session2 in result.items def test_from_excel_all_saved_scenarios(self, monkeypatch): """Test that from_excel returns all SavedScenarios when no Sessions exist.""" @@ -158,16 +158,16 @@ def test_from_excel_all_saved_scenarios(self, monkeypatch): assert saved3 in result.items def test_from_excel_sorts_by_id(self, monkeypatch): - """Test that SavedScenarios are sorted by ID.""" + """Test that mixed scenarios are sorted by ID.""" saved1 = Mock(spec=Scenario) saved1.id = 300 + session1 = Mock(spec=Session) + session1.id = 100 saved2 = Mock(spec=Scenario) - saved2.id = 100 - saved3 = Mock(spec=Scenario) - saved3.id = 200 + saved2.id = 200 mock_packer = Mock(spec=ScenarioPacker) - mock_packer._scenarios.return_value = {saved1, saved2, saved3} + mock_packer._scenarios.return_value = {saved1, session1, saved2} with patch.object(ScenarioPacker, "from_excel", return_value=mock_packer): result = Scenarios.from_excel("test.xlsx") @@ -178,35 +178,63 @@ def test_from_excel_sorts_by_id(self, monkeypatch): assert result.items[2].id == 300 -class TestSavedScenariosSessionsProperty: - """Test SavedScenarios.sessions property for accessing underlying Scenario objects.""" +class TestScenariosSessionsProperty: + """Test Scenarios.sessions property for accessing underlying Session objects.""" - def test_sessions_property_returns_list_of_scenarios(self): - """Test that sessions property returns list of underlying Scenario objects.""" + def test_sessions_property_returns_list_of_sessions(self): + """Test that sessions property returns list of underlying Session objects.""" # Create mock SavedScenarios with mock sessions - scenario1 = Mock(spec=Session) - scenario1.id = 100 - scenario2 = Mock(spec=Session) - scenario2.id = 200 + session1 = Mock(spec=Session) + session1.id = 100 + session2 = Mock(spec=Session) + session2.id = 200 saved1 = Mock(spec=Scenario) saved1.id = 1 - saved1.session = scenario1 + saved1.session = session1 saved2 = Mock(spec=Scenario) saved2.id = 2 - saved2.session = scenario2 + saved2.session = session2 collection = Scenarios(items=[saved1, saved2]) # Access sessions property sessions = collection.sessions - # Should return list of Scenario objects + # Should return list of Session objects assert isinstance(sessions, list) assert len(sessions) == 2 - assert sessions[0] is scenario1 - assert sessions[1] is scenario2 + assert sessions[0] is session1 + assert sessions[1] is session2 + + def test_sessions_property_mixed_types(self): + """Test that sessions property handles mixed Scenario and Session instances.""" + # Create mock instances + session1 = Mock(spec=Session) + session1.id = 100 + + session2 = Mock(spec=Session) + session2.id = 200 + + saved_session = Mock(spec=Session) + saved_session.id = 300 + + saved = Mock(spec=Scenario) + saved.id = 3 + saved.session = saved_session + + collection = Scenarios(items=[session1, saved, session2]) + + # Access sessions property + sessions = collection.sessions + + # Should return all Sessions, unwrapping Scenario instances + assert isinstance(sessions, list) + assert len(sessions) == 3 + assert sessions[0] is session1 + assert sessions[1] is saved_session # Unwrapped from Scenario + assert sessions[2] is session2 def test_sessions_property_empty_collection(self): """Test that sessions property returns empty list for empty collection.""" @@ -218,25 +246,39 @@ def test_sessions_property_empty_collection(self): def test_sessions_property_single_saved_scenario(self): """Test sessions property with single SavedScenario.""" - scenario = Mock(spec=Session) - scenario.id = 100 + session = Mock(spec=Session) + session.id = 100 saved = Mock(spec=Scenario) saved.id = 1 - saved.session = scenario + saved.session = session collection = Scenarios(items=[saved]) sessions = collection.sessions assert len(sessions) == 1 - assert sessions[0] is scenario + assert sessions[0] is session + + def test_sessions_property_only_sessions(self): + """Test sessions property with only Session instances.""" + session1 = Mock(spec=Session) + session1.id = 100 + session2 = Mock(spec=Session) + session2.id = 200 + + collection = Scenarios(items=[session1, session2]) + sessions = collection.sessions + + assert len(sessions) == 2 + assert sessions[0] is session1 + assert sessions[1] is session2 -class TestMixedScenariosSeparation: - """Test that mixed Excel files correctly separate Sessions and SavedScenarios.""" +class TestMixedScenariosHandling: + """Test that Scenarios handles mixed types while Sessions filters.""" - def test_mixed_excel_separates_correctly(self, monkeypatch): - """Test that same Excel file returns different results for Scenarios vs SavedScenarios.""" + def test_scenarios_includes_all_sessions_filters(self, monkeypatch): + """Test that Scenarios includes all while Sessions filters to Sessions only.""" # Create mixed scenarios session1 = Mock(spec=Session) session1.id = 100 @@ -251,28 +293,28 @@ def test_mixed_excel_separates_correctly(self, monkeypatch): mock_packer._scenarios.return_value = {session1, session2, saved1, saved2} with patch.object(ScenarioPacker, "from_excel", return_value=mock_packer): - # Load as Scenarios (Sessions only) + # Load as Sessions (Sessions only) sessions_result = Sessions.from_excel("test.xlsx") - # Load as SavedScenarios (SavedScenarios only) - saved_result = Scenarios.from_excel("test.xlsx") + # Load as Scenarios (all types) + scenarios_result = Scenarios.from_excel("test.xlsx") - # Verify Sessions collection + # Verify Sessions collection - only Session instances assert len(sessions_result.items) == 2 assert session1 in sessions_result.items assert session2 in sessions_result.items assert saved1 not in sessions_result.items assert saved2 not in sessions_result.items - # Verify SavedScenarios collection - assert len(saved_result.items) == 2 - assert saved1 in saved_result.items - assert saved2 in saved_result.items - assert session1 not in saved_result.items - assert session2 not in saved_result.items + # Verify Scenarios collection - includes all types + assert len(scenarios_result.items) == 4 + assert saved1 in scenarios_result.items + assert saved2 in scenarios_result.items + assert session1 in scenarios_result.items + assert session2 in scenarios_result.items - def test_no_overlap_between_collections(self, monkeypatch): - """Test that there's no overlap between the two collections.""" + def test_scenarios_is_superset_of_sessions(self, monkeypatch): + """Test that Scenarios contains everything Sessions contains (and more).""" session1 = Mock(spec=Session) session1.id = 100 saved1 = Mock(spec=Scenario) @@ -283,12 +325,13 @@ def test_no_overlap_between_collections(self, monkeypatch): with patch.object(ScenarioPacker, "from_excel", return_value=mock_packer): sessions_result = Sessions.from_excel("test.xlsx") - saved_result = Scenarios.from_excel("test.xlsx") + scenarios_result = Scenarios.from_excel("test.xlsx") # Get all items from both collections all_sessions = set(sessions_result.items) - all_saved = set(saved_result.items) + all_scenarios = set(scenarios_result.items) - # Verify no overlap - assert len(all_sessions & all_saved) == 0 - assert len(all_sessions) + len(all_saved) == 2 + # Verify Sessions is a subset of Scenarios + assert all_sessions.issubset(all_scenarios) + assert len(all_scenarios) == 2 + assert len(all_sessions) == 1