From b6f08c97bba8232a6c36cef2f99503b8de53b6c2 Mon Sep 17 00:00:00 2001 From: "copilot-swe-agent[bot]" <198982749+Copilot@users.noreply.github.com> Date: Tue, 28 Oct 2025 17:42:19 +0000 Subject: [PATCH 01/24] Initial plan From 5c99ac094b5b826263be54fc1439b99a7593570a Mon Sep 17 00:00:00 2001 From: "copilot-swe-agent[bot]" <198982749+Copilot@users.noreply.github.com> Date: Tue, 28 Oct 2025 17:54:12 +0000 Subject: [PATCH 02/24] Merge RBMDataSet and RBMDataSetElPaso functionality into single RBMDataSet class Co-authored-by: sahiljhawar <55475299+sahiljhawar@users.noreply.github.com> --- swvo/io/RBMDataSet/RBMDataSet.py | 212 +++++++++++++++--- swvo/io/RBMDataSet/__init__.py | 1 - .../bin_and_interpolate_to_model_grid.py | 4 +- tests/io/RBMDataSet/test_RBMDataset.py | 36 +-- tests/io/RBMDataSet/test_RBMDatasetElPaso.py | 14 +- 5 files changed, 207 insertions(+), 60 deletions(-) diff --git a/swvo/io/RBMDataSet/RBMDataSet.py b/swvo/io/RBMDataSet/RBMDataSet.py index c72e5e7..bf30284 100644 --- a/swvo/io/RBMDataSet/RBMDataSet.py +++ b/swvo/io/RBMDataSet/RBMDataSet.py @@ -6,9 +6,10 @@ import datetime as dt import typing +from dataclasses import replace from datetime import timedelta, timezone from pathlib import Path -from typing import Any, Literal +from typing import Any, ClassVar, Literal, TypeVar import distance import numpy as np @@ -27,17 +28,45 @@ Variable, VariableEnum, ) +from swvo.io.RBMDataSet.custom_enums import ElPasoMFMEnum from swvo.io.RBMDataSet.utils import ( get_file_path_any_format, join_var, load_file_any_format, matlab2python, + python2matlab, ) +ElPasoVariable = TypeVar("ElPasoVariable") # Placeholder for ElPaso Variable type + class RBMDataSet: """RBMDataSet class for loading and managing data. + This class can load data either from files or from a dictionary (ElPaso format). + + For file-based loading, provide start_time, end_time, and folder_path. + For dictionary-based loading, initialize without these parameters and use update_from_dict(). 
+ + Parameters + ---------- + satellite : :class:`SatelliteLike` + Satellite identifier as enum or string. + instrument : :class:`InstrumentLike` + Instrument enumeration or string. + mfm : :class:`MfmLike` + Magnetic field model enum or string. + start_time : dt.datetime, optional + Start time for file-based loading. + end_time : dt.datetime, optional + End time for file-based loading. + folder_path : Path, optional + Base folder path for file-based loading. + preferred_extension : Literal["mat", "pickle"], optional + Preferred file extension for file-based loading. Default is "pickle". + verbose : bool, optional + Whether to print verbose output. Default is True. + Attributes ---------- datetime : list[dt.datetime] @@ -65,6 +94,37 @@ class RBMDataSet: """ _preferred_ext: str + _variable_mapping: ClassVar[dict[str, str]] = { + "Epoch_posixtime": "time", + "Energy_FEDU": "energy_channels", + "Energy_FPDU": "energy_channels", + "Energy_FEIU": "energy_channels", + "Energy_FEDO": "energy_channels", + "PA_local": "alpha_local", + "PA_eq_": "alpha_eq_model", + "alpha_eq_real": "alpha_eq_real", + "invMu_": "InvMu", + "InvMu_real": "InvMu_real", + "invK_": "InvK", + # "InvV": "InvV",# computed property + "Lstar_": "Lstar", + "FEDU": "Flux", + "FPDU": "Flux", + "FEIU": "Flux", + "FEDO": "Flux", + "PSD_FEDU": "PSD", + "PSD_FPDU": "PSD", + "PSD_FEIU": "PSD", + "PSD_FEDO": "PSD", + "MLT_": "MLT", + "B_SM": "B_SM", + "B_eq_": "B_total", + "B_local_": "B_sat", + "xGEO": "xGEO", + # "P": "P",# computed property + "R_eq_": "R0", + "density": "density", + } datetime: list[dt.datetime] time: NDArray[np.float64] @@ -90,50 +150,74 @@ class RBMDataSet: def __init__( self, - start_time: dt.datetime, - end_time: dt.datetime, - folder_path: Path, satellite: SatelliteLike, instrument: InstrumentLike, mfm: MfmLike, + start_time: dt.datetime | None = None, + end_time: dt.datetime | None = None, + folder_path: Path | None = None, preferred_extension: Literal["mat", "pickle"] = "pickle", *, 
verbose: bool = True, ) -> None: - if start_time.tzinfo is None: - start_time = start_time.replace(tzinfo=timezone.utc) - - if end_time.tzinfo is None: - end_time = end_time.replace(tzinfo=timezone.utc) - - self._start_time = start_time - self._end_time = end_time - + # Handle satellite conversion with special cases for GOES if isinstance(satellite, str): - satellite = SatelliteEnum[satellite.upper()] - self._satellite = satellite - + if satellite.lower() == "goesprimary": + satellite = SatelliteEnum["GOESPrimary"] + elif satellite.lower() == "goessecondary": + satellite = SatelliteEnum["GOESSecondary"] + else: + satellite = SatelliteEnum[satellite.upper()] + if isinstance(instrument, str): instrument = InstrumentEnum[instrument.upper()] - self._instrument = instrument - + if isinstance(mfm, str): mfm = MfmEnum[mfm.upper()] + + # Store the original satellite enum for properties and other attributes + self._satellite_enum = satellite + self._instrument = instrument self._mfm = mfm - - self._folder_path = Path(folder_path) - - self._preferred_ext = preferred_extension - self._folder_type = self._satellite.folder_type self._verbose = verbose - - self._file_path_stem = self._create_file_path_stem() - self._file_name_stem = self._create_file_name_stem() - self._file_cadence = self._satellite.file_cadence - self._date_of_files = self._create_date_list() + + # For dict-based loading (ElPaso mode), modify satellite properties + if start_time is None and end_time is None and folder_path is None: + # ElPaso mode: no file loading needed + satellite_obj = replace( + satellite.value, + folder_type=FolderTypeEnum.NoFolder, + file_cadence=FileCadenceEnum.NoCadence, + ) + self._satellite = satellite_obj + self._mfm_prefix = ElPasoMFMEnum[mfm.name].value + self._file_loading_mode = False + else: + # File loading mode: need all parameters + if start_time is None or end_time is None or folder_path is None: + msg = "For file-based loading, start_time, end_time, and folder_path must 
all be provided" + raise ValueError(msg) + + if start_time.tzinfo is None: + start_time = start_time.replace(tzinfo=timezone.utc) + + if end_time.tzinfo is None: + end_time = end_time.replace(tzinfo=timezone.utc) + + self._start_time = start_time + self._end_time = end_time + self._satellite = satellite + self._folder_path = Path(folder_path) + self._preferred_ext = preferred_extension + self._folder_type = self._satellite.folder_type + self._file_path_stem = self._create_file_path_stem() + self._file_name_stem = self._create_file_name_stem() + self._file_cadence = self._satellite.file_cadence + self._date_of_files = self._create_date_list() + self._file_loading_mode = True def __repr__(self): - return f"{self.__class__.__name__}({self._satellite}, {self._instrument}, {self._mfm})" + return f"{self.__class__.__name__}({self._satellite_enum}, {self._instrument}, {self._mfm})" def __str__(self): return self.__repr__() @@ -142,6 +226,18 @@ def __dir__(self): return list(super().__dir__()) + [var.var_name for var in VariableEnum] def __getattr__(self, name: str): + # Avoid recursion for internal attributes + if name.startswith("_"): + raise AttributeError(f"'{self.__class__.__name__}' object has no attribute '{name}'") + + # Handle computed properties for both modes + if name == "P" and hasattr(self, "MLT"): + return ((self.MLT + 12) / 12 * np.pi) % (2 * np.pi) + + if name == "InvV" and hasattr(self, "InvK") and hasattr(self, "InvMu"): + inv_K_repeated = np.repeat(self.InvK[:, np.newaxis, :], self.InvMu.shape[1], axis=1) + return self.InvMu * (inv_K_repeated + 0.5) ** 2 + # check if a sat variable is requested # if we find a similar word, suggest that to the user sat_variable = None @@ -159,11 +255,17 @@ def __getattr__(self, name: str): levenstein_info["min_distance"] = dist levenstein_info["var_name"] = var.var_name - # if yes, load it - if sat_variable is not None: + # if yes, load it (only in file loading mode) + if sat_variable is not None and hasattr(self, 
"_file_loading_mode") and self._file_loading_mode: self._load_variable(sat_variable) - return getattr(self, name) + + # For dict-loading mode, check if it's a mapped variable + if hasattr(self, "_file_loading_mode") and not self._file_loading_mode and name in self._variable_mapping.values(): + raise AttributeError( + f"Attribute '{name}' is mapped but has not been set. " + "Make sure data is loaded or that this attribute is properly assigned." + ) if levenstein_info["min_distance"] <= 2: msg = f"{self.__class__.__name__} object has no attribute {name}. Maybe you meant {levenstein_info['var_name']}?" @@ -178,6 +280,52 @@ def __getattr__(self, name: str): # def __setitem__(self, key, value): # setattr(self, key, value) + @property + def satellite(self) -> SatelliteEnum: + """Returns the satellite enum.""" + return self._satellite_enum + + @property + def instrument(self) -> InstrumentEnum: + """Returns the instrument enum.""" + return self._instrument + + @property + def mfm(self) -> MfmEnum: + """Returns the MFM enum.""" + return self._mfm + + @property + def variable_mapping(self) -> dict[str, str]: + """Returns the variable mapping dictionary.""" + return self._variable_mapping + + def update_from_dict(self, source_dict: dict[str, ElPasoVariable]) -> None: + """Get data from ElPaso data dictionary and update the object. + + Parameters + ---------- + source_dict : dict[str, Any] + Dictionary containing the data to be loaded into the object. 
+ + """ + for _, value in source_dict.items(): + if value.standard_name in self._variable_mapping: + target_attr = self._variable_mapping[value.standard_name] + + if value.standard_name == "Epoch_posixtime" and target_attr == "time": + datetimes = [dt.datetime.fromtimestamp(ts, tz=timezone.utc) for ts in value.data] + setattr(self, "datetime", datetimes) + setattr(self, "time", [python2matlab(i) for i in datetimes]) + else: + setattr(self, target_attr, value.data) + + elif hasattr(self, "_mfm_prefix") and value.standard_name.endswith(self._mfm_prefix): + base_key = value.standard_name.replace(self._mfm_prefix, "") + if base_key in self._variable_mapping: + target_attr = self._variable_mapping[base_key] + setattr(self, target_attr, value.data) + def get_var(self, var: VariableEnum): return getattr(self, var.var_name) diff --git a/swvo/io/RBMDataSet/__init__.py b/swvo/io/RBMDataSet/__init__.py index f77d064..8312052 100644 --- a/swvo/io/RBMDataSet/__init__.py +++ b/swvo/io/RBMDataSet/__init__.py @@ -23,4 +23,3 @@ from swvo.io.RBMDataSet.scripts.create_RBSP_line_data import create_RBSP_line_data as create_RBSP_line_data from swvo.io.RBMDataSet.RBMDataSet import RBMDataSet as RBMDataSet from swvo.io.RBMDataSet.RBMNcDataSet import RBMNcDataSet as RBMNcDataSet -from swvo.io.RBMDataSet.RBMDataSetElPaso import RBMDataSetElPaso as RBMDataSetElPaso diff --git a/swvo/io/RBMDataSet/bin_and_interpolate_to_model_grid.py b/swvo/io/RBMDataSet/bin_and_interpolate_to_model_grid.py index 7d80409..2504023 100644 --- a/swvo/io/RBMDataSet/bin_and_interpolate_to_model_grid.py +++ b/swvo/io/RBMDataSet/bin_and_interpolate_to_model_grid.py @@ -18,11 +18,11 @@ from tqdm import tqdm if TYPE_CHECKING: - from swvo.io.RBMDataSet import RBMDataSet, RBMDataSetElPaso + from swvo.io.RBMDataSet import RBMDataSet def bin_and_interpolate_to_model_grid( - self: RBMDataSet | RBMDataSetElPaso, + self: RBMDataSet, sim_time: list[datetime], grid_R: NDArray[np.float64], grid_mu_V: NDArray[np.float64], diff 
--git a/tests/io/RBMDataSet/test_RBMDataset.py b/tests/io/RBMDataSet/test_RBMDataset.py index 118cbef..2e8a70e 100644 --- a/tests/io/RBMDataSet/test_RBMDataset.py +++ b/tests/io/RBMDataSet/test_RBMDataset.py @@ -42,12 +42,12 @@ def mock_dataset(): } dataset = RBMDataSet( - start_time=start_time, - end_time=end_time, - folder_path=Path("/mock/path"), satellite=SatelliteEnum.RBSPA, instrument=InstrumentEnum.MAGEIS, mfm=MfmEnum.T89, + start_time=start_time, + end_time=end_time, + folder_path=Path("/mock/path"), preferred_extension="pickle", verbose=False, ) @@ -67,12 +67,12 @@ def test_init_datetime_timezone(mock_module_string): mock.patch(f"{mock_module_string}._create_file_name_stem"), ): dataset = RBMDataSet( - start_time=start_time, - end_time=end_time, - folder_path=Path("/mock/path"), satellite=SatelliteEnum.RBSPA, instrument=InstrumentEnum.MAGEIS, mfm=MfmEnum.T89, + start_time=start_time, + end_time=end_time, + folder_path=Path("/mock/path"), preferred_extension="pickle", ) @@ -101,12 +101,12 @@ def test_satellite_string_input(mock_module_string): with mock.patch(f"{mock_module_string}._create_file_path_stem"): with mock.patch(f"{mock_module_string}._create_file_name_stem"): dataset = RBMDataSet( - start_time=dt.datetime(2023, 1, 1, tzinfo=timezone.utc), - end_time=dt.datetime(2023, 1, 31, tzinfo=timezone.utc), - folder_path=Path("/mock/path"), satellite="RBSPA", instrument=InstrumentEnum.MAGEIS, mfm=MfmEnum.T89, + start_time=dt.datetime(2023, 1, 1, tzinfo=timezone.utc), + end_time=dt.datetime(2023, 1, 31, tzinfo=timezone.utc), + folder_path=Path("/mock/path"), preferred_extension="pickle", ) @@ -169,12 +169,12 @@ def test_all_satellites_work(satellite, mock_module_string): with mock.patch(f"{mock_module_string}._create_file_path_stem"): with mock.patch(f"{mock_module_string}._create_file_name_stem"): dataset = RBMDataSet( - start_time=dt.datetime(2023, 1, 1, tzinfo=timezone.utc), - end_time=dt.datetime(2023, 1, 31, tzinfo=timezone.utc), - 
folder_path=Path("/mock/path"), satellite=satellite, instrument=InstrumentEnum.HOPE, mfm=MfmEnum.T89, + start_time=dt.datetime(2023, 1, 1, tzinfo=timezone.utc), + end_time=dt.datetime(2023, 1, 31, tzinfo=timezone.utc), + folder_path=Path("/mock/path"), ) assert dataset._satellite == satellite @@ -186,12 +186,12 @@ def test_all_instruments_work(instrument, mock_module_string): with mock.patch(f"{mock_module_string}._create_file_path_stem"): with mock.patch(f"{mock_module_string}._create_file_name_stem"): dataset = RBMDataSet( - start_time=dt.datetime(2023, 1, 1, tzinfo=timezone.utc), - end_time=dt.datetime(2023, 1, 31, tzinfo=timezone.utc), - folder_path=Path("/mock/path"), satellite=SatelliteEnum.RBSPA, instrument=instrument, mfm=MfmEnum.T89, + start_time=dt.datetime(2023, 1, 1, tzinfo=timezone.utc), + end_time=dt.datetime(2023, 1, 31, tzinfo=timezone.utc), + folder_path=Path("/mock/path"), ) assert dataset._instrument == instrument @@ -250,12 +250,12 @@ def test_load_variable_real_file(): end_time = dt.datetime(2025, 4, 30, tzinfo=dt.timezone.utc) dataset = RBMDataSet( - start_time=start_time, - end_time=end_time, - folder_path=Path("path/to/real/files"), # this does not matter for the test satellite=SatelliteEnum.GOESSecondary, instrument=InstrumentEnum.MAGED, mfm=MfmEnum.T89, + start_time=start_time, + end_time=end_time, + folder_path=Path("path/to/real/files"), # this does not matter for the test preferred_extension="pickle", verbose=True, ) diff --git a/tests/io/RBMDataSet/test_RBMDatasetElPaso.py b/tests/io/RBMDataSet/test_RBMDatasetElPaso.py index efa1244..e859fde 100644 --- a/tests/io/RBMDataSet/test_RBMDatasetElPaso.py +++ b/tests/io/RBMDataSet/test_RBMDatasetElPaso.py @@ -11,7 +11,7 @@ from swvo.io.RBMDataSet import ( InstrumentEnum, MfmEnum, - RBMDataSetElPaso, + RBMDataSet, SatelliteEnum, SatelliteLiteral, VariableEnum, @@ -29,7 +29,7 @@ def __init__(self, standard_name, data=None): @pytest.fixture def dataset(): - return RBMDataSetElPaso( + return 
RBMDataSet( satellite=SatelliteEnum.GOESSecondary, instrument=InstrumentEnum.MAGED, mfm=MfmEnum.T89, @@ -38,7 +38,7 @@ def dataset(): def test_init_accepts_string_inputs(): """Test that the class can be initialized with string inputs.""" - ds = RBMDataSetElPaso(satellite="GOESSecondary", instrument="MAGED", mfm=MfmEnum.T89) + ds = RBMDataSet(satellite="GOESSecondary", instrument="MAGED", mfm=MfmEnum.T89) assert ds.satellite.sat_name == "secondary" assert ds.instrument == InstrumentEnum.MAGED assert ds.mfm == MfmEnum.T89 @@ -51,12 +51,12 @@ def test_variable_mapping_exposed(dataset): def test_repr_and_str(dataset): - assert "RBMDataSetElPaso" in repr(dataset) + assert "RBMDataSet" in repr(dataset) assert str(dataset.satellite) in repr(dataset) assert str(dataset.instrument) in repr(dataset) assert str(dataset.mfm) in repr(dataset) - assert "RBMDataSetElPaso" in str(dataset) + assert "RBMDataSet" in str(dataset) assert str(dataset.satellite) in str(dataset) assert str(dataset.instrument) in str(dataset) assert str(dataset.mfm) in str(dataset) @@ -221,7 +221,7 @@ def test_all_mfm_specific_mappings(dataset): @pytest.mark.parametrize("satellite, expected", [("goessecondary", "secondary"), ("goesprimary", "primary")]) def test_goes_lowercase(satellite, expected): - goes_dataset = RBMDataSetElPaso( + goes_dataset = RBMDataSet( satellite=satellite, instrument=InstrumentEnum.MAGED, mfm=MfmEnum.T89, @@ -234,7 +234,7 @@ def test_goes_lowercase(satellite, expected): @pytest.mark.parametrize("satellite", [i.lower() for i in uppercase_satellites]) def test_satellite_lowercase(satellite): - dataset = RBMDataSetElPaso( + dataset = RBMDataSet( satellite=satellite, instrument=InstrumentEnum.MAGED, mfm=MfmEnum.T89, From 822fa2e6648cd429ef777599cc5157418ba51bd3 Mon Sep 17 00:00:00 2001 From: "copilot-swe-agent[bot]" <198982749+Copilot@users.noreply.github.com> Date: Tue, 28 Oct 2025 17:56:04 +0000 Subject: [PATCH 03/24] Remove RBMDataSetElPaso.py and update all imports Co-authored-by: 
sahiljhawar <55475299+sahiljhawar@users.noreply.github.com> --- swvo/io/RBMDataSet/RBMDataSetElPaso.py | 239 ------------------------ swvo/io/RBMDataSet/RBMDataSetManager.py | 12 +- swvo/io/RBMDataSet/RBMNcDataSet.py | 10 +- 3 files changed, 11 insertions(+), 250 deletions(-) delete mode 100644 swvo/io/RBMDataSet/RBMDataSetElPaso.py diff --git a/swvo/io/RBMDataSet/RBMDataSetElPaso.py b/swvo/io/RBMDataSet/RBMDataSetElPaso.py deleted file mode 100644 index d5a4fac..0000000 --- a/swvo/io/RBMDataSet/RBMDataSetElPaso.py +++ /dev/null @@ -1,239 +0,0 @@ -# SPDX-FileCopyrightText: 2025 GFZ Helmholtz Centre for Geosciences -# -# SPDX-License-Identifier: Apache-2.0 - -from __future__ import annotations - -import datetime as dt -from dataclasses import replace -from datetime import timezone -from typing import Any, ClassVar, TypeVar - -import numpy as np -from numpy.typing import NDArray - -from swvo.io.RBMDataSet import ( - FileCadenceEnum, - FolderTypeEnum, - InstrumentEnum, - MfmEnum, - SatelliteEnum, - SatelliteLike, - VariableEnum, -) -from swvo.io.RBMDataSet.custom_enums import ElPasoMFMEnum -from swvo.io.RBMDataSet.utils import python2matlab - -Variable = TypeVar( - "Variable" -) # this is a placeholder for the actual Variable class from elpaso and not the one in RBMDataSet - - -class RBMDataSetElPaso: - """RBMDataSetElPaso class for loading ElPaso data to RBMDataSet. - - Parameters - ---------- - satellite : :class:`SatelliteLike` - Satellite identifier as enum or string. - instrument : :class:`InstrumentEnum` - Instrument enumeration. - mfm : :class:`MfmEnum` - Magnetic field model enum. 
- - - Attributes - ---------- - datetime : list[dt.datetime] - time : NDArray[np.float64] - energy_channels : NDArray[np.float64] - alpha_local : NDArray[np.float64] - alpha_eq_model : NDArray[np.float64] - alpha_eq_real : NDArray[np.float64] - InvMu : NDArray[np.float64] - InvMu_real : NDArray[np.float64] - InvK : NDArray[np.float64] - InvV : NDArray[np.float64] - Lstar : NDArray[np.float64] - Flux : NDArray[np.float64] - PSD : NDArray[np.float64] - MLT : NDArray[np.float64] - B_SM : NDArray[np.float64] - B_total : NDArray[np.float64] - B_sat : NDArray[np.float64] - xGEO : NDArray[np.float64] - P : NDArray[np.float64] - R0 : NDArray[np.float64] - density : NDArray[np.float64] - - """ - - _variable_mapping: ClassVar[dict[str, str]] = { - "Epoch_posixtime": "time", - "Energy_FEDU": "energy_channels", - "Energy_FPDU": "energy_channels", - "Energy_FEIU": "energy_channels", - "Energy_FEDO": "energy_channels", - "PA_local": "alpha_local", - "PA_eq_": "alpha_eq_model", - "alpha_eq_real": "alpha_eq_real", - "invMu_": "InvMu", - "InvMu_real": "InvMu_real", - "invK_": "InvK", - # "InvV": "InvV",# computed property - "Lstar_": "Lstar", - "FEDU": "Flux", - "FPDU": "Flux", - "FEIU": "Flux", - "FEDO": "Flux", - "PSD_FEDU": "PSD", - "PSD_FPDU": "PSD", - "PSD_FEIU": "PSD", - "PSD_FEDO": "PSD", - "MLT_": "MLT", - "B_SM": "B_SM", - "B_eq_": "B_total", - "B_local_": "B_sat", - "xGEO": "xGEO", - # "P": "P",# computed property - "R_eq_": "R0", - "density": "density", - } - - datetime: list[dt.datetime] - time: NDArray[np.float64] - energy_channels: NDArray[np.float64] - alpha_local: NDArray[np.float64] - alpha_eq_model: NDArray[np.float64] - alpha_eq_real: NDArray[np.float64] - InvMu: NDArray[np.float64] - InvMu_real: NDArray[np.float64] - InvK: NDArray[np.float64] - InvV: NDArray[np.float64] - Lstar: NDArray[np.float64] - Flux: NDArray[np.float64] - PSD: NDArray[np.float64] - MLT: NDArray[np.float64] - B_SM: NDArray[np.float64] - B_total: NDArray[np.float64] - B_sat: 
NDArray[np.float64] - xGEO: NDArray[np.float64] - P: NDArray[np.float64] - R0: NDArray[np.float64] - density: NDArray[np.float64] - - def __init__( - self, - satellite: SatelliteLike, - instrument: InstrumentEnum, - mfm: MfmEnum, - ) -> None: - if isinstance(satellite, str): - if satellite.lower() == "goesprimary": - satellite_enum = SatelliteEnum["GOESPrimary"] - elif satellite.lower() == "goessecondary": - satellite_enum = SatelliteEnum["GOESSecondary"] - else: - satellite_enum = SatelliteEnum[satellite.upper()] - else: - satellite_enum = satellite - - if isinstance(instrument, str): - instrument = InstrumentEnum[instrument.upper()] - satellite_obj = replace( - satellite_enum.value, - folder_type=FolderTypeEnum.NoFolder, - file_cadence=FileCadenceEnum.NoCadence, - ) - - self._satellite = satellite_obj - self._instrument = instrument - self._mfm = mfm - self._mfm_prefix = ElPasoMFMEnum[self._mfm.name].value - - @property - def satellite(self) -> SatelliteEnum: - """Returns the satellite enum.""" - return self._satellite - - @property - def instrument(self) -> InstrumentEnum: - """Returns the instrument enum.""" - return self._instrument - - @property - def mfm(self) -> MfmEnum: - """Returns the MFM enum.""" - return self._mfm - - @property - def variable_mapping(self) -> dict[str, str]: - """Returns the variable mapping dictionary.""" - return self._variable_mapping - - def __dir__(self): - return super().__dir__() + [var.var_name for var in VariableEnum] - - def update_from_dict(self, source_dict: dict[str, Variable]) -> None: - """Get data from ElPaso data dictionary and update the object. - - Parameters - ---------- - source_dict : dict[str, Any] - Dictionary containing the data to be loaded into the object. 
- - """ - for _, value in source_dict.items(): - if value.standard_name in self._variable_mapping: - target_attr = self._variable_mapping[value.standard_name] - - if value.standard_name == "Epoch_posixtime" and target_attr == "time": - datetimes = [dt.datetime.fromtimestamp(ts, tz=timezone.utc) for ts in value.data] - setattr(self, "datetime", datetimes) - setattr(self, "time", [python2matlab(i) for i in datetimes]) - else: - setattr(self, target_attr, value.data) - - elif value.standard_name.endswith(self._mfm_prefix): - base_key = value.standard_name.replace(self._mfm_prefix, "") - if base_key in self._variable_mapping: - target_attr = self._variable_mapping[base_key] - setattr(self, target_attr, value.data) - - @property - def P(self) -> NDArray[np.float64]: - """Calculate P. - - Returns - ------- - NDArray[np.float64] - The P value calculated from the MLT. - """ - return ((self.MLT + 12) / 12 * np.pi) % (2 * np.pi) - - @property - def InvV(self) -> NDArray[np.float64]: - """Calculate InvV. - - Returns - ------- - NDArray[np.float64] - The InvV value calculated from InvMu and InvK. - """ - inv_K_repeated = np.repeat(self.InvK[:, np.newaxis, :], self.InvMu.shape[1], axis=1) - InvV = self.InvMu * (inv_K_repeated + 0.5) ** 2 - return InvV - - def __getattr__(self, name: str) -> Any: - if name in self._variable_mapping.values(): - raise AttributeError( - f"Attribute '{name}' is mapped but has not been set. " - "Make sure data is loaded or that this attribute is properly assigned." 
- ) - raise AttributeError(f"'{self.__class__.__name__}' object has no attribute '{name}'") - - def __repr__(self): - return f"{self.__class__.__name__}({self.satellite}, {self.instrument}, {self.mfm})" - - def __str__(self): - return self.__repr__() diff --git a/swvo/io/RBMDataSet/RBMDataSetManager.py b/swvo/io/RBMDataSet/RBMDataSetManager.py index 5a08c66..17da461 100644 --- a/swvo/io/RBMDataSet/RBMDataSetManager.py +++ b/swvo/io/RBMDataSet/RBMDataSetManager.py @@ -152,12 +152,12 @@ def load( return_list.append(cls._instance.data_set_dict[key_tuple]) else: cls._instance.data_set_dict[key_tuple] = RBMDataSet( - start_time, - end_time, - folder_path, - sat, - instrument, - mfm, + satellite=sat, + instrument=instrument, + mfm=mfm, + start_time=start_time, + end_time=end_time, + folder_path=folder_path, verbose=verbose, preferred_extension=preferred_extension, ) diff --git a/swvo/io/RBMDataSet/RBMNcDataSet.py b/swvo/io/RBMDataSet/RBMNcDataSet.py index ee5e211..7eb273b 100644 --- a/swvo/io/RBMDataSet/RBMNcDataSet.py +++ b/swvo/io/RBMDataSet/RBMNcDataSet.py @@ -98,12 +98,12 @@ def __init__( verbose: bool = True, ) -> None: super().__init__( - start_time, - end_time, - folder_path, - satellite, - instrument, + satellite=satellite, + instrument=instrument, mfm=mfm, + start_time=start_time, + end_time=end_time, + folder_path=folder_path, verbose=verbose, ) From fea06192865c4fc87d75f083360d9fe5c05359a7 Mon Sep 17 00:00:00 2001 From: "copilot-swe-agent[bot]" <198982749+Copilot@users.noreply.github.com> Date: Tue, 28 Oct 2025 17:57:20 +0000 Subject: [PATCH 04/24] Run linters and format code Co-authored-by: sahiljhawar <55475299+sahiljhawar@users.noreply.github.com> --- swvo/io/RBMDataSet/RBMDataSet.py | 30 +++++++++++++++++------------- 1 file changed, 17 insertions(+), 13 deletions(-) diff --git a/swvo/io/RBMDataSet/RBMDataSet.py b/swvo/io/RBMDataSet/RBMDataSet.py index bf30284..97a844c 100644 --- a/swvo/io/RBMDataSet/RBMDataSet.py +++ b/swvo/io/RBMDataSet/RBMDataSet.py @@ 
-44,7 +44,7 @@ class RBMDataSet: """RBMDataSet class for loading and managing data. This class can load data either from files or from a dictionary (ElPaso format). - + For file-based loading, provide start_time, end_time, and folder_path. For dictionary-based loading, initialize without these parameters and use update_from_dict(). @@ -168,19 +168,19 @@ def __init__( satellite = SatelliteEnum["GOESSecondary"] else: satellite = SatelliteEnum[satellite.upper()] - + if isinstance(instrument, str): instrument = InstrumentEnum[instrument.upper()] - + if isinstance(mfm, str): mfm = MfmEnum[mfm.upper()] - + # Store the original satellite enum for properties and other attributes self._satellite_enum = satellite self._instrument = instrument self._mfm = mfm self._verbose = verbose - + # For dict-based loading (ElPaso mode), modify satellite properties if start_time is None and end_time is None and folder_path is None: # ElPaso mode: no file loading needed @@ -197,13 +197,13 @@ def __init__( if start_time is None or end_time is None or folder_path is None: msg = "For file-based loading, start_time, end_time, and folder_path must all be provided" raise ValueError(msg) - + if start_time.tzinfo is None: start_time = start_time.replace(tzinfo=timezone.utc) - + if end_time.tzinfo is None: end_time = end_time.replace(tzinfo=timezone.utc) - + self._start_time = start_time self._end_time = end_time self._satellite = satellite @@ -229,15 +229,15 @@ def __getattr__(self, name: str): # Avoid recursion for internal attributes if name.startswith("_"): raise AttributeError(f"'{self.__class__.__name__}' object has no attribute '{name}'") - + # Handle computed properties for both modes if name == "P" and hasattr(self, "MLT"): return ((self.MLT + 12) / 12 * np.pi) % (2 * np.pi) - + if name == "InvV" and hasattr(self, "InvK") and hasattr(self, "InvMu"): inv_K_repeated = np.repeat(self.InvK[:, np.newaxis, :], self.InvMu.shape[1], axis=1) return self.InvMu * (inv_K_repeated + 0.5) ** 2 - + # 
check if a sat variable is requested # if we find a similar word, suggest that to the user sat_variable = None @@ -259,9 +259,13 @@ def __getattr__(self, name: str): if sat_variable is not None and hasattr(self, "_file_loading_mode") and self._file_loading_mode: self._load_variable(sat_variable) return getattr(self, name) - + # For dict-loading mode, check if it's a mapped variable - if hasattr(self, "_file_loading_mode") and not self._file_loading_mode and name in self._variable_mapping.values(): + if ( + hasattr(self, "_file_loading_mode") + and not self._file_loading_mode + and name in self._variable_mapping.values() + ): raise AttributeError( f"Attribute '{name}' is mapped but has not been set. " "Make sure data is loaded or that this attribute is properly assigned." From 4a2b5d3cd8e48376e99f98232ee99d1b830a5f5c Mon Sep 17 00:00:00 2001 From: Sahil Jhawar Date: Wed, 29 Oct 2025 15:04:07 +0100 Subject: [PATCH 05/24] new exception for RBMDS --- swvo/io/exceptions.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/swvo/io/exceptions.py b/swvo/io/exceptions.py index c44563f..8e4f485 100644 --- a/swvo/io/exceptions.py +++ b/swvo/io/exceptions.py @@ -5,3 +5,6 @@ class ModelError(Exception): pass + +class VariableNotFoundError(Exception): + pass \ No newline at end of file From 0e3d46cf272ff7fc24ded93b78fd80d85f2e31e2 Mon Sep 17 00:00:00 2001 From: Sahil Jhawar Date: Wed, 29 Oct 2025 15:04:44 +0100 Subject: [PATCH 06/24] remove elpaso enum and add dummy enum --- swvo/io/RBMDataSet/custom_enums.py | 19 ++++++++++++++----- 1 file changed, 14 insertions(+), 5 deletions(-) diff --git a/swvo/io/RBMDataSet/custom_enums.py b/swvo/io/RBMDataSet/custom_enums.py index edb24cd..95522a8 100644 --- a/swvo/io/RBMDataSet/custom_enums.py +++ b/swvo/io/RBMDataSet/custom_enums.py @@ -180,13 +180,22 @@ class MfmEnum(Enum): T96 = "T96" TS04 = "T04s" OP77 = "OP77" + T04 = "T04" -MfmEnumLiteral = Literal["T89", "T04s", "TS04", "T96", "OP77"] +MfmEnumLiteral = Literal["T89", "T04s", 
"TS04", "T96", "OP77", "T04"] MfmLike: TypeAlias = MfmEnumLiteral | MfmEnum -class ElPasoMFMEnum(Enum): - T89 = "T89" - T04s = "T04s" - T04 = "T04" +class DummyEnum(Enum): + SATELLITE = Satellite( + sat_name="dummy", + mission="dummy", + folder_type=FolderTypeEnum.NoFolder, + file_cadence=FileCadenceEnum.NoCadence, + ) + INSTRUMENT = "dummy instrument" + MFM = "dummy mfm" + + +DummyLike: TypeAlias = Literal["dummy"] | DummyEnum From d021c95b6859dcfec3ac15eb4fcbb891f7c47da0 Mon Sep 17 00:00:00 2001 From: Sahil Jhawar Date: Wed, 29 Oct 2025 15:04:57 +0100 Subject: [PATCH 07/24] add new imports --- swvo/io/RBMDataSet/__init__.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/swvo/io/RBMDataSet/__init__.py b/swvo/io/RBMDataSet/__init__.py index 8312052..d307c3b 100644 --- a/swvo/io/RBMDataSet/__init__.py +++ b/swvo/io/RBMDataSet/__init__.py @@ -15,8 +15,10 @@ InstrumentLike as InstrumentLike, MfmEnum as MfmEnum, MfmLike as MfmLike, - ElPasoMFMEnum as ElPasoMFMEnum, SatelliteLiteral as SatelliteLiteral, + VariableLiteral as VariableLiteral, + DummyLike as DummyLike, + DummyEnum as DummyEnum, ) from swvo.io.RBMDataSet.RBMDataSetManager import RBMDataSetManager as RBMDataSetManager from swvo.io.RBMDataSet.interp_functions import TargetType as TargetType From 6a1617fa268809ab4ce93c25dee57b2d62e26e6d Mon Sep 17 00:00:00 2001 From: Sahil Jhawar Date: Wed, 29 Oct 2025 15:05:25 +0100 Subject: [PATCH 08/24] new exception for RBMDS --- swvo/io/exceptions.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/swvo/io/exceptions.py b/swvo/io/exceptions.py index 8e4f485..8e275e0 100644 --- a/swvo/io/exceptions.py +++ b/swvo/io/exceptions.py @@ -6,5 +6,6 @@ class ModelError(Exception): pass + class VariableNotFoundError(Exception): - pass \ No newline at end of file + pass From 220e2958c1dfe5f577c52c668a5ad9a3b31d6629 Mon Sep 17 00:00:00 2001 From: Sahil Jhawar Date: Wed, 29 Oct 2025 15:07:13 +0100 Subject: [PATCH 09/24] refactor: RBMDataSet to 
use VariableLiteral and improve error handling - Replace variable mapping dictionary with direct VariableLiteral usage - Remove ElPaso-specific type annotations and use generic DummyLike types - Improve error handling for computed properties (P, InvV) with proper validation - Simplify update_from_dict() method to use direct key matching instead of mapping - Add find_similar_variable() helper method for better error suggestions - Remove unused imports and replace ElPasoMFMEnum with DummyEnum --- swvo/io/RBMDataSet/RBMDataSet.py | 145 +++++++++++-------------------- 1 file changed, 52 insertions(+), 93 deletions(-) diff --git a/swvo/io/RBMDataSet/RBMDataSet.py b/swvo/io/RBMDataSet/RBMDataSet.py index 97a844c..01fad02 100644 --- a/swvo/io/RBMDataSet/RBMDataSet.py +++ b/swvo/io/RBMDataSet/RBMDataSet.py @@ -9,13 +9,14 @@ from dataclasses import replace from datetime import timedelta, timezone from pathlib import Path -from typing import Any, ClassVar, Literal, TypeVar +from typing import Any, Literal import distance import numpy as np from dateutil.relativedelta import relativedelta from numpy.typing import NDArray +from swvo.io.exceptions import VariableNotFoundError from swvo.io.RBMDataSet import ( FileCadenceEnum, FolderTypeEnum, @@ -27,18 +28,16 @@ SatelliteLike, Variable, VariableEnum, + VariableLiteral, ) -from swvo.io.RBMDataSet.custom_enums import ElPasoMFMEnum +from swvo.io.RBMDataSet.custom_enums import DummyEnum, DummyLike from swvo.io.RBMDataSet.utils import ( get_file_path_any_format, join_var, load_file_any_format, matlab2python, - python2matlab, ) -ElPasoVariable = TypeVar("ElPasoVariable") # Placeholder for ElPaso Variable type - class RBMDataSet: """RBMDataSet class for loading and managing data. 
@@ -94,37 +93,6 @@ class RBMDataSet: """ _preferred_ext: str - _variable_mapping: ClassVar[dict[str, str]] = { - "Epoch_posixtime": "time", - "Energy_FEDU": "energy_channels", - "Energy_FPDU": "energy_channels", - "Energy_FEIU": "energy_channels", - "Energy_FEDO": "energy_channels", - "PA_local": "alpha_local", - "PA_eq_": "alpha_eq_model", - "alpha_eq_real": "alpha_eq_real", - "invMu_": "InvMu", - "InvMu_real": "InvMu_real", - "invK_": "InvK", - # "InvV": "InvV",# computed property - "Lstar_": "Lstar", - "FEDU": "Flux", - "FPDU": "Flux", - "FEIU": "Flux", - "FEDO": "Flux", - "PSD_FEDU": "PSD", - "PSD_FPDU": "PSD", - "PSD_FEIU": "PSD", - "PSD_FEDO": "PSD", - "MLT_": "MLT", - "B_SM": "B_SM", - "B_eq_": "B_total", - "B_local_": "B_sat", - "xGEO": "xGEO", - # "P": "P",# computed property - "R_eq_": "R0", - "density": "density", - } datetime: list[dt.datetime] time: NDArray[np.float64] @@ -150,9 +118,9 @@ class RBMDataSet: def __init__( self, - satellite: SatelliteLike, - instrument: InstrumentLike, - mfm: MfmLike, + satellite: SatelliteLike | DummyLike, + instrument: InstrumentLike | DummyLike, + mfm: MfmLike | DummyLike, start_time: dt.datetime | None = None, end_time: dt.datetime | None = None, folder_path: Path | None = None, @@ -160,6 +128,7 @@ def __init__( *, verbose: bool = True, ) -> None: + self.ep_variables = list(VariableLiteral.__args__) # Handle satellite conversion with special cases for GOES if isinstance(satellite, str): if satellite.lower() == "goesprimary": @@ -183,14 +152,14 @@ def __init__( # For dict-based loading (ElPaso mode), modify satellite properties if start_time is None and end_time is None and folder_path is None: - # ElPaso mode: no file loading needed + # no file loading needed satellite_obj = replace( satellite.value, folder_type=FolderTypeEnum.NoFolder, file_cadence=FileCadenceEnum.NoCadence, ) self._satellite = satellite_obj - self._mfm_prefix = ElPasoMFMEnum[mfm.name].value + self._mfm_prefix = DummyEnum.MFM.value if isinstance(mfm, 
DummyEnum) else MfmEnum[mfm.name].value self._file_loading_mode = False else: # File loading mode: need all parameters @@ -231,44 +200,34 @@ def __getattr__(self, name: str): raise AttributeError(f"'{self.__class__.__name__}' object has no attribute '{name}'") # Handle computed properties for both modes - if name == "P" and hasattr(self, "MLT"): + if name == "P": + if not hasattr(self, "MLT") or getattr(self, "MLT") is None or not isinstance(self.MLT, np.ndarray): + raise AttributeError("Cannot compute `P` because `MLT` is missing, not loaded or is not valid array.") return ((self.MLT + 12) / 12 * np.pi) % (2 * np.pi) - if name == "InvV" and hasattr(self, "InvK") and hasattr(self, "InvMu"): + if name == "InvV": + if not all(hasattr(self, attr) for attr in ("InvK", "InvMu")): + raise AttributeError("Cannot compute `InvV` because `InvK` or `InvMu` is missing.") + if not isinstance(self.InvK, np.ndarray) or not isinstance(self.InvMu, np.ndarray): + raise AttributeError("Cannot compute `InvV` because required arrays are invalid or not loaded.") + if self.InvK.ndim < 1 or self.InvMu.ndim < 2: + raise AttributeError("Cannot compute `InvV` because array dimensions are insufficient.") inv_K_repeated = np.repeat(self.InvK[:, np.newaxis, :], self.InvMu.shape[1], axis=1) return self.InvMu * (inv_K_repeated + 0.5) ** 2 # check if a sat variable is requested # if we find a similar word, suggest that to the user sat_variable = None - levenstein_info: dict[str, Any] = {"min_distance": 10, "var_name": ""} - for var in VariableEnum: - if name == var.var_name: - sat_variable = var - break - else: - dist = distance.levenshtein(name, var.var_name) - if name.lower() in var.name.lower(): - dist = 1 - - if dist < levenstein_info["min_distance"]: - levenstein_info["min_distance"] = dist - levenstein_info["var_name"] = var.var_name + sat_variable, levenstein_info = self.find_similar_variable(name) - # if yes, load it (only in file loading mode) - if sat_variable is not None and 
hasattr(self, "_file_loading_mode") and self._file_loading_mode: + if sat_variable is not None and self._file_loading_mode: self._load_variable(sat_variable) return getattr(self, name) - # For dict-loading mode, check if it's a mapped variable - if ( - hasattr(self, "_file_loading_mode") - and not self._file_loading_mode - and name in self._variable_mapping.values() - ): + if not self._file_loading_mode and name in self.ep_variables: raise AttributeError( - f"Attribute '{name}' is mapped but has not been set. " - "Make sure data is loaded or that this attribute is properly assigned." + f"Attribute '{name}' exists in `VariableLiteral` but has not been set. " + "Call `update_from_dict()` before accessing it." ) if levenstein_info["min_distance"] <= 2: @@ -278,11 +237,22 @@ def __getattr__(self, name: str): raise AttributeError(msg) - # def __getitem__(self, key:str): - # return getattr(self, key:str) + def find_similar_variable(self, name): + levenstein_info: dict[str, Any] = {"min_distance": 10, "var_name": ""} + sat_variable = None + for var in self.ep_variables: + if name == var: + sat_variable = var + break + else: + dist = distance.levenshtein(name, var) + if name.lower() in var.lower(): + dist = 1 - # def __setitem__(self, key, value): - # setattr(self, key, value) + if dist < levenstein_info["min_distance"]: + levenstein_info["min_distance"] = dist + levenstein_info["var_name"] = var + return sat_variable, levenstein_info @property def satellite(self) -> SatelliteEnum: @@ -299,36 +269,25 @@ def mfm(self) -> MfmEnum: """Returns the MFM enum.""" return self._mfm - @property - def variable_mapping(self) -> dict[str, str]: - """Returns the variable mapping dictionary.""" - return self._variable_mapping - - def update_from_dict(self, source_dict: dict[str, ElPasoVariable]) -> None: + def update_from_dict(self, source_dict: dict[str, VariableLiteral]) -> None: """Get data from ElPaso data dictionary and update the object. 
Parameters ---------- - source_dict : dict[str, Any] + source_dict : dict[str, VariableLiteral] Dictionary containing the data to be loaded into the object. """ - for _, value in source_dict.items(): - if value.standard_name in self._variable_mapping: - target_attr = self._variable_mapping[value.standard_name] - - if value.standard_name == "Epoch_posixtime" and target_attr == "time": - datetimes = [dt.datetime.fromtimestamp(ts, tz=timezone.utc) for ts in value.data] - setattr(self, "datetime", datetimes) - setattr(self, "time", [python2matlab(i) for i in datetimes]) - else: - setattr(self, target_attr, value.data) - - elif hasattr(self, "_mfm_prefix") and value.standard_name.endswith(self._mfm_prefix): - base_key = value.standard_name.replace(self._mfm_prefix, "") - if base_key in self._variable_mapping: - target_attr = self._variable_mapping[base_key] - setattr(self, target_attr, value.data) + for key, value in source_dict.items(): + _, levenstein_info = self.find_similar_variable(key) + if key in self.ep_variables: + setattr(self, key, value) + elif levenstein_info["min_distance"] <= 2: + msg = f"Key '{key}' is not a valid `VariableLiteral`. Maybe you meant '{levenstein_info['var_name']}'?" + raise VariableNotFoundError(msg) + else: + msg = f"Key '{key}' is not a valid `VariableLiteral`." 
+ raise VariableNotFoundError(msg) def get_var(self, var: VariableEnum): return getattr(self, var.var_name) From c672e0a9976f2ea818234dab0967dff9319b359f Mon Sep 17 00:00:00 2001 From: "copilot-swe-agent[bot]" <198982749+Copilot@users.noreply.github.com> Date: Wed, 29 Oct 2025 14:18:42 +0000 Subject: [PATCH 10/24] Update tests to match refactored RBMDataSet implementation - Remove MockVariable wrapper class as it's no longer needed - Update tests to use direct VariableLiteral keys instead of mapped names - Simplify update_from_dict tests to pass data arrays directly - Add tests for invalid key error handling - Remove obsolete variable mapping tests - All 136 tests passing Co-authored-by: sahiljhawar <55475299+sahiljhawar@users.noreply.github.com> --- tests/io/RBMDataSet/test_RBMDatasetElPaso.py | 178 +++++++------------ 1 file changed, 66 insertions(+), 112 deletions(-) diff --git a/tests/io/RBMDataSet/test_RBMDatasetElPaso.py b/tests/io/RBMDataSet/test_RBMDatasetElPaso.py index e859fde..48ad13f 100644 --- a/tests/io/RBMDataSet/test_RBMDatasetElPaso.py +++ b/tests/io/RBMDataSet/test_RBMDatasetElPaso.py @@ -2,7 +2,6 @@ # # SPDX-License-Identifier: Apache-2.0 -from datetime import datetime, timezone from typing import get_args import numpy as np @@ -16,15 +15,6 @@ SatelliteLiteral, VariableEnum, ) -from swvo.io.RBMDataSet.utils import python2matlab - - -class MockVariable: - """Create a mock Elpaso Variable class for testing""" - - def __init__(self, standard_name, data=None): - self.standard_name = standard_name - self.data = data @pytest.fixture @@ -45,9 +35,10 @@ def test_init_accepts_string_inputs(): def test_variable_mapping_exposed(dataset): - assert isinstance(dataset.variable_mapping, dict) - assert "FEDU" in dataset.variable_mapping - assert dataset.variable_mapping["FEDU"] == "Flux" + """Test that ep_variables contains the expected variable names""" + assert isinstance(dataset.ep_variables, list) + assert "Flux" in dataset.ep_variables + assert 
"energy_channels" in dataset.ep_variables def test_repr_and_str(dataset): @@ -63,36 +54,36 @@ def test_repr_and_str(dataset): def test_update_from_dict_sets_variables(dataset): - """Test that the correct variable is set with the standard name""" - fedu_data = np.array([[1.0, 2.0]]) + """Test that the correct variable is set with direct key""" + flux_data = np.array([[1.0, 2.0]]) - source_dict = {"FEDU": MockVariable(standard_name="FEDU", data=fedu_data)} + source_dict = {"Flux": flux_data} dataset.update_from_dict(source_dict) - np.testing.assert_array_equal(dataset.Flux, fedu_data) + np.testing.assert_array_equal(dataset.Flux, flux_data) def test_update_from_dict_sets_time(dataset): - """Test that the correct variable is set with the time standard name""" - ts = [datetime(2025, 4, 1, tzinfo=timezone.utc).timestamp()] + """Test that the correct variable is set with direct time key""" + time_data = np.array([738000.0]) # MATLAB datenum format - source_dict = {"Epoch": MockVariable(standard_name="Epoch_posixtime", data=ts)} + source_dict = {"time": time_data} dataset.update_from_dict(source_dict) assert hasattr(dataset, "time") - assert hasattr(dataset, "datetime") - assert dataset.time[0] == python2matlab(datetime(2025, 4, 1, tzinfo=timezone.utc)) + np.testing.assert_array_equal(dataset.time, time_data) -def test_update_from_dict_with_mfm_suffix(dataset): - """Test that the correct variable is set with the MFM suffix""" - mfm_suffix = "_" + dataset._mfm_prefix +def test_update_from_dict_with_multiple_variables(dataset): + """Test that multiple variables can be set at once""" lstar_data = np.array([4.5, 5.0, 5.5]) + energy_data = np.array([100.0, 200.0, 300.0]) - source_dict = {"Lstar": MockVariable(standard_name=f"Lstar{mfm_suffix}", data=lstar_data)} + source_dict = {"Lstar": lstar_data, "energy_channels": energy_data} dataset.update_from_dict(source_dict) np.testing.assert_array_equal(dataset.Lstar, lstar_data) + 
np.testing.assert_array_equal(dataset.energy_channels, energy_data) def test_computed_p_property(dataset): @@ -115,7 +106,7 @@ def test_computed_invv_property(dataset): def test_getattr_errors(dataset): - with pytest.raises(AttributeError, match="mapped but has not been set"): + with pytest.raises(AttributeError, match="exists in `VariableLiteral` but has not been set"): _ = dataset.Flux with pytest.raises(AttributeError, match="no attribute"): @@ -128,94 +119,57 @@ def test_dir_contains_variable_names(dataset): assert name in dir(dataset) -def test_all_variable_mappings(dataset): - """Test that all variable mappings work correctly.""" - - expected_mappings = { - "Epoch_posixtime": "time", - "Energy_FEDU": "energy_channels", - "PA_local": "alpha_local", - "PA_eq_": "alpha_eq_model", - "alpha_eq_real": "alpha_eq_real", - "invMu_": "InvMu", - "InvMu_real": "InvMu_real", - "invK_": "InvK", - "Lstar_": "Lstar", - "FEDU": "Flux", - "PSD_FEDU": "PSD", - "MLT_": "MLT", - "B_SM": "B_SM", - "B_eq_": "B_total", - "B_local_": "B_sat", - "xGEO": "xGEO", - "R_eq_": "R0", - "density": "density", +def test_update_from_dict_invalid_key(dataset): + """Test that invalid keys raise VariableNotFoundError""" + from swvo.io.exceptions import VariableNotFoundError + + source_dict = {"InvalidKey": np.array([1.0, 2.0])} + + with pytest.raises(VariableNotFoundError, match="not a valid `VariableLiteral`"): + dataset.update_from_dict(source_dict) + + +def test_update_from_dict_similar_key(dataset): + """Test that similar keys suggest the correct variable""" + from swvo.io.exceptions import VariableNotFoundError + + source_dict = {"Flx": np.array([1.0, 2.0])} # Typo: should be "Flux" + + with pytest.raises(VariableNotFoundError, match="Maybe you meant 'Flux'"): + dataset.update_from_dict(source_dict) + + +def test_all_variable_literals(dataset): + """Test that all VariableLiteral values can be set and retrieved.""" + + # Test with a subset of common variables + test_variables = { + "time": 
np.array([738000.0, 738001.0]), + "energy_channels": np.array([100.0, 200.0, 300.0]), + "alpha_local": np.array([0.1, 0.2, 0.3]), + "alpha_eq_model": np.array([45.0, 60.0, 90.0]), + "InvMu": np.array([[0.1, 0.2]]), + "InvK": np.array([[1.0, 2.0]]), + "Lstar": np.array([4.5, 5.0, 5.5]), + "Flux": np.array([[1.0, 2.0, 3.0]]), + "PSD": np.array([[0.1, 0.2, 0.3]]), + "MLT": np.array([0.0, 6.0, 12.0]), + "B_SM": np.array([100.0, 200.0, 300.0]), + "B_total": np.array([50.0, 60.0, 70.0]), + "B_sat": np.array([45.0, 55.0, 65.0]), + "xGEO": np.array([6.6, 6.7, 6.8]), + "R0": np.array([5.0, 5.5, 6.0]), + "density": np.array([100.0, 200.0, 300.0]), } - for source, target in expected_mappings.items(): - assert source in dataset.variable_mapping - assert dataset.variable_mapping[source] == target - - test_data = {} - for source, target in expected_mappings.items(): - if source == "Epoch_posixtime": - data = [datetime(2025, 4, 1, tzinfo=timezone.utc).timestamp()] - else: - data = np.array([float(hash(source) % 1000) / 10.0]) - - test_data[source] = MockVariable(standard_name=source, data=data) - - dataset.update_from_dict(test_data) - - for source, target in expected_mappings.items(): - if source == "Epoch_posixtime": - assert hasattr(dataset, "time") - assert hasattr(dataset, "datetime") - assert isinstance(dataset.datetime[0], datetime) - assert dataset.time[0] == python2matlab(dataset.datetime[0]) - elif source in ["P", "InvV"]: - pass - else: - assert hasattr(dataset, target), f"Attribute {target} not set from {source}" - - np.testing.assert_array_equal( - dataset.__getattribute__(target), - test_data[source].data, - err_msg=f"Data mismatch for {target} from {source}", - ) - - -def test_all_mfm_specific_mappings(dataset): - """Test that all MFM-specific variable mappings work correctly.""" - - mfm_variables = [ - ("PA_eq", "alpha_eq_model"), - ("invMu", "InvMu"), - ("invK", "InvK"), - ("Lstar", "Lstar"), - ("B_eq", "B_total"), - ("B_local", "B_sat"), - ("R_eq", "R0"), - ] - 
- mfm_suffix = "_" + dataset._mfm_prefix - - test_data = {} - for source_base, target in mfm_variables: - source = f"{source_base}{mfm_suffix}" - - data = np.array([float(hash(source) % 1000) / 10.0]) - test_data[source] = MockVariable(standard_name=source, data=data) - - dataset.update_from_dict(test_data) - - for source_base, target in mfm_variables: - source = f"{source_base}{mfm_suffix}" - assert hasattr(dataset, target), f"Attribute {target} not set from {source}" + dataset.update_from_dict(test_variables) + + for var_name, expected_data in test_variables.items(): + assert hasattr(dataset, var_name), f"Attribute {var_name} not set" np.testing.assert_array_equal( - dataset.__getattribute__(target), - test_data[source].data, - err_msg=f"Data mismatch for {target} from {source}", + getattr(dataset, var_name), + expected_data, + err_msg=f"Data mismatch for {var_name}", ) From 4044cfdee1cc160ad667ac478d03049352096428 Mon Sep 17 00:00:00 2001 From: "copilot-swe-agent[bot]" <198982749+Copilot@users.noreply.github.com> Date: Wed, 29 Oct 2025 14:39:17 +0000 Subject: [PATCH 11/24] Merge test_RBMDatasetElPaso.py into test_RBMDataset.py Since RBMDataSetElPaso no longer exists as a separate class, consolidate all tests into a single test file. 
The merged file now contains: - File-based loading tests (original test_RBMDataset.py) - Dictionary-based loading tests (from test_RBMDatasetElPaso.py) All 73 tests for RBMDataSet now in one file with clear naming: - Original tests remain unchanged - Dict mode tests prefixed with "test_dict_mode_" for clarity - Total: 136 tests passing across all RBMDataSet test files Co-authored-by: sahiljhawar <55475299+sahiljhawar@users.noreply.github.com> --- tests/io/RBMDataSet/test_RBMDataset.py | 192 ++++++++++++++++++ tests/io/RBMDataSet/test_RBMDatasetElPaso.py | 196 ------------------- 2 files changed, 192 insertions(+), 196 deletions(-) delete mode 100644 tests/io/RBMDataSet/test_RBMDatasetElPaso.py diff --git a/tests/io/RBMDataSet/test_RBMDataset.py b/tests/io/RBMDataSet/test_RBMDataset.py index 2e8a70e..62ca462 100644 --- a/tests/io/RBMDataSet/test_RBMDataset.py +++ b/tests/io/RBMDataSet/test_RBMDataset.py @@ -5,6 +5,7 @@ import datetime as dt from datetime import timezone from pathlib import Path +from typing import get_args from unittest import mock import numpy as np @@ -16,6 +17,7 @@ MfmEnum, RBMDataSet, SatelliteEnum, + SatelliteLiteral, VariableEnum, ) @@ -293,3 +295,193 @@ def test_all_variables_in_dir(mock_dataset: RBMDataSet): for var in vars: assert var in mock_dataset.__dir__() + + +# ======================================== +# Dictionary-based loading tests (formerly in test_RBMDatasetElPaso.py) +# ======================================== + + +@pytest.fixture +def dict_dataset(): + """Fixture for dictionary-based loading mode (no file parameters)""" + return RBMDataSet( + satellite=SatelliteEnum.GOESSecondary, + instrument=InstrumentEnum.MAGED, + mfm=MfmEnum.T89, + ) + + +def test_dict_mode_init_accepts_string_inputs(): + """Test that the class can be initialized with string inputs in dict mode.""" + ds = RBMDataSet(satellite="GOESSecondary", instrument="MAGED", mfm=MfmEnum.T89) + assert ds.satellite.sat_name == "secondary" + assert ds.instrument == 
InstrumentEnum.MAGED + assert ds.mfm == MfmEnum.T89 + + +def test_dict_mode_variable_mapping_exposed(dict_dataset): + """Test that ep_variables contains the expected variable names""" + assert isinstance(dict_dataset.ep_variables, list) + assert "Flux" in dict_dataset.ep_variables + assert "energy_channels" in dict_dataset.ep_variables + + +def test_dict_mode_repr_and_str(dict_dataset): + """Test repr and str for dict mode""" + assert "RBMDataSet" in repr(dict_dataset) + assert str(dict_dataset.satellite) in repr(dict_dataset) + assert str(dict_dataset.instrument) in repr(dict_dataset) + assert str(dict_dataset.mfm) in repr(dict_dataset) + + assert "RBMDataSet" in str(dict_dataset) + assert str(dict_dataset.satellite) in str(dict_dataset) + assert str(dict_dataset.instrument) in str(dict_dataset) + assert str(dict_dataset.mfm) in str(dict_dataset) + + +def test_update_from_dict_sets_variables(dict_dataset): + """Test that the correct variable is set with direct key""" + flux_data = np.array([[1.0, 2.0]]) + + source_dict = {"Flux": flux_data} + + dict_dataset.update_from_dict(source_dict) + np.testing.assert_array_equal(dict_dataset.Flux, flux_data) + + +def test_update_from_dict_sets_time(dict_dataset): + """Test that the correct variable is set with direct time key""" + time_data = np.array([738000.0]) # MATLAB datenum format + + source_dict = {"time": time_data} + + dict_dataset.update_from_dict(source_dict) + assert hasattr(dict_dataset, "time") + np.testing.assert_array_equal(dict_dataset.time, time_data) + + +def test_update_from_dict_with_multiple_variables(dict_dataset): + """Test that multiple variables can be set at once""" + lstar_data = np.array([4.5, 5.0, 5.5]) + energy_data = np.array([100.0, 200.0, 300.0]) + + source_dict = {"Lstar": lstar_data, "energy_channels": energy_data} + + dict_dataset.update_from_dict(source_dict) + np.testing.assert_array_equal(dict_dataset.Lstar, lstar_data) + np.testing.assert_array_equal(dict_dataset.energy_channels, 
energy_data) + + +def test_dict_mode_computed_p_property(dict_dataset): + """Test P property with correct dimensions in dict mode""" + dict_dataset.MLT = np.array([0.0, 6.0, 12.0]) + + expected_p = ((dict_dataset.MLT + 12) / 12 * np.pi) % (2 * np.pi) + np.testing.assert_allclose(dict_dataset.P, expected_p) + + +def test_dict_mode_computed_invv_property(dict_dataset): + """Test InvV property with correct dimensions in dict mode""" + dict_dataset.InvMu = np.array([[0.1, 0.2]]) + dict_dataset.InvK = np.array([[1.0]]) + + inv_K_repeated = np.repeat(dict_dataset.InvK[:, np.newaxis, :], dict_dataset.InvMu.shape[1], axis=1) + expected_invv = dict_dataset.InvMu * (inv_K_repeated + 0.5) ** 2 + + np.testing.assert_allclose(dict_dataset.InvV, expected_invv) + + +def test_dict_mode_getattr_errors(dict_dataset): + """Test error handling for unset attributes in dict mode""" + with pytest.raises(AttributeError, match="exists in `VariableLiteral` but has not been set"): + _ = dict_dataset.Flux + + with pytest.raises(AttributeError, match="no attribute"): + _ = dict_dataset.NonExistent + + +def test_dict_mode_dir_contains_variable_names(dict_dataset): + """Test that dir() includes variable names in dict mode""" + variable_names = [var.var_name for var in VariableEnum] + for name in variable_names: + assert name in dir(dict_dataset) + + +def test_update_from_dict_invalid_key(dict_dataset): + """Test that invalid keys raise VariableNotFoundError""" + from swvo.io.exceptions import VariableNotFoundError + + source_dict = {"InvalidKey": np.array([1.0, 2.0])} + + with pytest.raises(VariableNotFoundError, match="not a valid `VariableLiteral`"): + dict_dataset.update_from_dict(source_dict) + + +def test_update_from_dict_similar_key(dict_dataset): + """Test that similar keys suggest the correct variable""" + from swvo.io.exceptions import VariableNotFoundError + + source_dict = {"Flx": np.array([1.0, 2.0])} # Typo: should be "Flux" + + with pytest.raises(VariableNotFoundError, match="Maybe 
you meant 'Flux'"): + dict_dataset.update_from_dict(source_dict) + + +def test_all_variable_literals(dict_dataset): + """Test that all VariableLiteral values can be set and retrieved.""" + + # Test with a subset of common variables + test_variables = { + "time": np.array([738000.0, 738001.0]), + "energy_channels": np.array([100.0, 200.0, 300.0]), + "alpha_local": np.array([0.1, 0.2, 0.3]), + "alpha_eq_model": np.array([45.0, 60.0, 90.0]), + "InvMu": np.array([[0.1, 0.2]]), + "InvK": np.array([[1.0, 2.0]]), + "Lstar": np.array([4.5, 5.0, 5.5]), + "Flux": np.array([[1.0, 2.0, 3.0]]), + "PSD": np.array([[0.1, 0.2, 0.3]]), + "MLT": np.array([0.0, 6.0, 12.0]), + "B_SM": np.array([100.0, 200.0, 300.0]), + "B_total": np.array([50.0, 60.0, 70.0]), + "B_sat": np.array([45.0, 55.0, 65.0]), + "xGEO": np.array([6.6, 6.7, 6.8]), + "R0": np.array([5.0, 5.5, 6.0]), + "density": np.array([100.0, 200.0, 300.0]), + } + + dict_dataset.update_from_dict(test_variables) + + for var_name, expected_data in test_variables.items(): + assert hasattr(dict_dataset, var_name), f"Attribute {var_name} not set" + np.testing.assert_array_equal( + getattr(dict_dataset, var_name), + expected_data, + err_msg=f"Data mismatch for {var_name}", + ) + + +@pytest.mark.parametrize("satellite, expected", [("goessecondary", "secondary"), ("goesprimary", "primary")]) +def test_dict_mode_goes_lowercase(satellite, expected): + """Test GOES satellite lowercase string handling in dict mode""" + goes_dataset = RBMDataSet( + satellite=satellite, + instrument=InstrumentEnum.MAGED, + mfm=MfmEnum.T89, + ) + assert goes_dataset.satellite.sat_name == expected + + +uppercase_satellites = set(get_args(SatelliteLiteral)) - set(["GOESPrimary", "GOESSecondary"]) + + +@pytest.mark.parametrize("satellite", [i.lower() for i in uppercase_satellites]) +def test_dict_mode_satellite_lowercase(satellite): + """Test satellite lowercase string handling in dict mode""" + dataset = RBMDataSet( + satellite=satellite, + 
instrument=InstrumentEnum.MAGED, + mfm=MfmEnum.T89, + ) + assert dataset.satellite.sat_name == satellite.lower() diff --git a/tests/io/RBMDataSet/test_RBMDatasetElPaso.py b/tests/io/RBMDataSet/test_RBMDatasetElPaso.py deleted file mode 100644 index 48ad13f..0000000 --- a/tests/io/RBMDataSet/test_RBMDatasetElPaso.py +++ /dev/null @@ -1,196 +0,0 @@ -# SPDX-FileCopyrightText: 2025 GFZ Helmholtz Centre for Geosciences -# -# SPDX-License-Identifier: Apache-2.0 - -from typing import get_args - -import numpy as np -import pytest - -from swvo.io.RBMDataSet import ( - InstrumentEnum, - MfmEnum, - RBMDataSet, - SatelliteEnum, - SatelliteLiteral, - VariableEnum, -) - - -@pytest.fixture -def dataset(): - return RBMDataSet( - satellite=SatelliteEnum.GOESSecondary, - instrument=InstrumentEnum.MAGED, - mfm=MfmEnum.T89, - ) - - -def test_init_accepts_string_inputs(): - """Test that the class can be initialized with string inputs.""" - ds = RBMDataSet(satellite="GOESSecondary", instrument="MAGED", mfm=MfmEnum.T89) - assert ds.satellite.sat_name == "secondary" - assert ds.instrument == InstrumentEnum.MAGED - assert ds.mfm == MfmEnum.T89 - - -def test_variable_mapping_exposed(dataset): - """Test that ep_variables contains the expected variable names""" - assert isinstance(dataset.ep_variables, list) - assert "Flux" in dataset.ep_variables - assert "energy_channels" in dataset.ep_variables - - -def test_repr_and_str(dataset): - assert "RBMDataSet" in repr(dataset) - assert str(dataset.satellite) in repr(dataset) - assert str(dataset.instrument) in repr(dataset) - assert str(dataset.mfm) in repr(dataset) - - assert "RBMDataSet" in str(dataset) - assert str(dataset.satellite) in str(dataset) - assert str(dataset.instrument) in str(dataset) - assert str(dataset.mfm) in str(dataset) - - -def test_update_from_dict_sets_variables(dataset): - """Test that the correct variable is set with direct key""" - flux_data = np.array([[1.0, 2.0]]) - - source_dict = {"Flux": flux_data} - - 
dataset.update_from_dict(source_dict) - np.testing.assert_array_equal(dataset.Flux, flux_data) - - -def test_update_from_dict_sets_time(dataset): - """Test that the correct variable is set with direct time key""" - time_data = np.array([738000.0]) # MATLAB datenum format - - source_dict = {"time": time_data} - - dataset.update_from_dict(source_dict) - assert hasattr(dataset, "time") - np.testing.assert_array_equal(dataset.time, time_data) - - -def test_update_from_dict_with_multiple_variables(dataset): - """Test that multiple variables can be set at once""" - lstar_data = np.array([4.5, 5.0, 5.5]) - energy_data = np.array([100.0, 200.0, 300.0]) - - source_dict = {"Lstar": lstar_data, "energy_channels": energy_data} - - dataset.update_from_dict(source_dict) - np.testing.assert_array_equal(dataset.Lstar, lstar_data) - np.testing.assert_array_equal(dataset.energy_channels, energy_data) - - -def test_computed_p_property(dataset): - """Test P property with correct dimensions""" - dataset.MLT = np.array([0.0, 6.0, 12.0]) - - expected_p = ((dataset.MLT + 12) / 12 * np.pi) % (2 * np.pi) - np.testing.assert_allclose(dataset.P, expected_p) - - -def test_computed_invv_property(dataset): - """Test InvV property with correct dimensions""" - dataset.InvMu = np.array([[0.1, 0.2]]) - dataset.InvK = np.array([[1.0]]) - - inv_K_repeated = np.repeat(dataset.InvK[:, np.newaxis, :], dataset.InvMu.shape[1], axis=1) - expected_invv = dataset.InvMu * (inv_K_repeated + 0.5) ** 2 - - np.testing.assert_allclose(dataset.InvV, expected_invv) - - -def test_getattr_errors(dataset): - with pytest.raises(AttributeError, match="exists in `VariableLiteral` but has not been set"): - _ = dataset.Flux - - with pytest.raises(AttributeError, match="no attribute"): - _ = dataset.NonExistent - - -def test_dir_contains_variable_names(dataset): - variable_names = [var.var_name for var in VariableEnum] - for name in variable_names: - assert name in dir(dataset) - - -def 
test_update_from_dict_invalid_key(dataset): - """Test that invalid keys raise VariableNotFoundError""" - from swvo.io.exceptions import VariableNotFoundError - - source_dict = {"InvalidKey": np.array([1.0, 2.0])} - - with pytest.raises(VariableNotFoundError, match="not a valid `VariableLiteral`"): - dataset.update_from_dict(source_dict) - - -def test_update_from_dict_similar_key(dataset): - """Test that similar keys suggest the correct variable""" - from swvo.io.exceptions import VariableNotFoundError - - source_dict = {"Flx": np.array([1.0, 2.0])} # Typo: should be "Flux" - - with pytest.raises(VariableNotFoundError, match="Maybe you meant 'Flux'"): - dataset.update_from_dict(source_dict) - - -def test_all_variable_literals(dataset): - """Test that all VariableLiteral values can be set and retrieved.""" - - # Test with a subset of common variables - test_variables = { - "time": np.array([738000.0, 738001.0]), - "energy_channels": np.array([100.0, 200.0, 300.0]), - "alpha_local": np.array([0.1, 0.2, 0.3]), - "alpha_eq_model": np.array([45.0, 60.0, 90.0]), - "InvMu": np.array([[0.1, 0.2]]), - "InvK": np.array([[1.0, 2.0]]), - "Lstar": np.array([4.5, 5.0, 5.5]), - "Flux": np.array([[1.0, 2.0, 3.0]]), - "PSD": np.array([[0.1, 0.2, 0.3]]), - "MLT": np.array([0.0, 6.0, 12.0]), - "B_SM": np.array([100.0, 200.0, 300.0]), - "B_total": np.array([50.0, 60.0, 70.0]), - "B_sat": np.array([45.0, 55.0, 65.0]), - "xGEO": np.array([6.6, 6.7, 6.8]), - "R0": np.array([5.0, 5.5, 6.0]), - "density": np.array([100.0, 200.0, 300.0]), - } - - dataset.update_from_dict(test_variables) - - for var_name, expected_data in test_variables.items(): - assert hasattr(dataset, var_name), f"Attribute {var_name} not set" - np.testing.assert_array_equal( - getattr(dataset, var_name), - expected_data, - err_msg=f"Data mismatch for {var_name}", - ) - - -@pytest.mark.parametrize("satellite, expected", [("goessecondary", "secondary"), ("goesprimary", "primary")]) -def test_goes_lowercase(satellite, 
expected): - goes_dataset = RBMDataSet( - satellite=satellite, - instrument=InstrumentEnum.MAGED, - mfm=MfmEnum.T89, - ) - assert goes_dataset.satellite.sat_name == expected - - -uppercase_satellites = set(get_args(SatelliteLiteral)) - set(["GOESPrimary", "GOESSecondary"]) - - -@pytest.mark.parametrize("satellite", [i.lower() for i in uppercase_satellites]) -def test_satellite_lowercase(satellite): - dataset = RBMDataSet( - satellite=satellite, - instrument=InstrumentEnum.MAGED, - mfm=MfmEnum.T89, - ) - assert dataset.satellite.sat_name == satellite.lower() From edc06a9c2ad0654cb0cfbfa87386420cd23bc352 Mon Sep 17 00:00:00 2001 From: Sahil Jhawar Date: Wed, 29 Oct 2025 15:50:23 +0100 Subject: [PATCH 12/24] test: Replace hardcoded enums with DummyEnum in RBMDataSet tests - Remove obsolete test_dict_mode_variable_mapping_exposed function --- tests/io/RBMDataSet/test_RBMDataset.py | 21 +++++---------------- 1 file changed, 5 insertions(+), 16 deletions(-) diff --git a/tests/io/RBMDataSet/test_RBMDataset.py b/tests/io/RBMDataSet/test_RBMDataset.py index 62ca462..a70fe00 100644 --- a/tests/io/RBMDataSet/test_RBMDataset.py +++ b/tests/io/RBMDataSet/test_RBMDataset.py @@ -20,6 +20,7 @@ SatelliteLiteral, VariableEnum, ) +from swvo.io.RBMDataSet.custom_enums import DummyEnum @pytest.fixture @@ -297,18 +298,13 @@ def test_all_variables_in_dir(mock_dataset: RBMDataSet): assert var in mock_dataset.__dir__() -# ======================================== -# Dictionary-based loading tests (formerly in test_RBMDatasetElPaso.py) -# ======================================== - - @pytest.fixture def dict_dataset(): """Fixture for dictionary-based loading mode (no file parameters)""" return RBMDataSet( - satellite=SatelliteEnum.GOESSecondary, - instrument=InstrumentEnum.MAGED, - mfm=MfmEnum.T89, + satellite=DummyEnum.SATELLITE, + instrument=DummyEnum.INSTRUMENT, + mfm=DummyEnum.MFM, ) @@ -320,13 +316,6 @@ def test_dict_mode_init_accepts_string_inputs(): assert ds.mfm == MfmEnum.T89 -def 
test_dict_mode_variable_mapping_exposed(dict_dataset): - """Test that ep_variables contains the expected variable names""" - assert isinstance(dict_dataset.ep_variables, list) - assert "Flux" in dict_dataset.ep_variables - assert "energy_channels" in dict_dataset.ep_variables - - def test_dict_mode_repr_and_str(dict_dataset): """Test repr and str for dict mode""" assert "RBMDataSet" in repr(dict_dataset) @@ -352,7 +341,7 @@ def test_update_from_dict_sets_variables(dict_dataset): def test_update_from_dict_sets_time(dict_dataset): """Test that the correct variable is set with direct time key""" - time_data = np.array([738000.0]) # MATLAB datenum format + time_data = np.array([738000.0]) source_dict = {"time": time_data} From 0d7257c4e9bba2c16982f3ed48306c780df800e7 Mon Sep 17 00:00:00 2001 From: Sahil Jhawar Date: Wed, 29 Oct 2025 16:19:20 +0100 Subject: [PATCH 13/24] refactor: Add get_loaded_variables method and improve __eq__ comparison - Add get_loaded_variables() method to return list of currently loaded variable names - Refactor __eq__ method to handle both file loading and dict variable --- swvo/io/RBMDataSet/RBMDataSet.py | 48 +++++++++++++++++++++----------- 1 file changed, 31 insertions(+), 17 deletions(-) diff --git a/swvo/io/RBMDataSet/RBMDataSet.py b/swvo/io/RBMDataSet/RBMDataSet.py index 01fad02..366a20a 100644 --- a/swvo/io/RBMDataSet/RBMDataSet.py +++ b/swvo/io/RBMDataSet/RBMDataSet.py @@ -438,28 +438,42 @@ def _load_variable(self, var: Variable | VariableEnum) -> None: setattr(self, var_name, loaded_var_arrs[var_name]) + def get_loaded_variables(self) -> list[str]: + """Get a list of currently loaded variable names.""" + loaded_vars = [] + for var in VariableEnum: + if hasattr(self, var.var_name): + loaded_vars.append(var.var_name) + return loaded_vars + def __eq__(self, other: RBMDataSet) -> bool: - if self._satellite != other._satellite: - return False - if self._instrument != other._instrument: - return False - if self._mfm != other._mfm: + if 
self._file_loading_mode != other._file_loading_mode: return False - for var in VariableEnum: - self_var = getattr(self, var.var_name) - other_var = getattr(other, var.var_name) + if self._file_loading_mode: + if self._satellite != other._satellite or self._instrument != other._instrument or self._mfm != other._mfm: + return False + variables = [v.var_name for v in VariableEnum] - if isinstance(self_var, list) and isinstance(other_var, list): - if len(self_var) != len(other_var): - return False - for dt1, dt2 in zip(self_var, other_var): - if dt1 != dt2: - return False - elif isinstance(self_var, np.ndarray) and isinstance(other_var, np.ndarray): - if self_var.shape != other_var.shape: + else: + self_vars = {v for v in self.ep_variables if hasattr(self, v)} + other_vars = {v for v in other.ep_variables if hasattr(other, v)} + if self_vars != other_vars: + return False + variables = self_vars + + for var in variables: + self_var = getattr(self, var) + other_var = getattr(other, var) + + if not isinstance(other_var, type(self_var)): + return False + + if isinstance(self_var, list): + if len(self_var) != len(other_var) or any(a != b for a, b in zip(self_var, other_var)): return False - if not np.allclose(self_var, other_var, equal_nan=True): + elif isinstance(self_var, np.ndarray): + if self_var.shape != other_var.shape or not np.allclose(self_var, other_var, equal_nan=True): return False elif self_var != other_var: return False From 4ebd597d75c844755b7e0c2e38d012269b3f5578 Mon Sep 17 00:00:00 2001 From: Sahil Jhawar Date: Wed, 29 Oct 2025 16:36:26 +0100 Subject: [PATCH 14/24] refactor: only use get_loaded_variables to compare objects --- swvo/io/RBMDataSet/RBMDataSet.py | 23 +++++++++++------------ 1 file changed, 11 insertions(+), 12 deletions(-) diff --git a/swvo/io/RBMDataSet/RBMDataSet.py b/swvo/io/RBMDataSet/RBMDataSet.py index 366a20a..41ecd02 100644 --- a/swvo/io/RBMDataSet/RBMDataSet.py +++ b/swvo/io/RBMDataSet/RBMDataSet.py @@ -447,20 +447,19 @@ def 
get_loaded_variables(self) -> list[str]: return loaded_vars def __eq__(self, other: RBMDataSet) -> bool: - if self._file_loading_mode != other._file_loading_mode: + if ( + self._file_loading_mode != other._file_loading_mode + or self._satellite != other._satellite + or self._instrument != other._instrument + or self._mfm != other._mfm + ): return False - if self._file_loading_mode: - if self._satellite != other._satellite or self._instrument != other._instrument or self._mfm != other._mfm: - return False - variables = [v.var_name for v in VariableEnum] - - else: - self_vars = {v for v in self.ep_variables if hasattr(self, v)} - other_vars = {v for v in other.ep_variables if hasattr(other, v)} - if self_vars != other_vars: - return False - variables = self_vars + self_vars = self.get_loaded_variables() + other_vars = other.get_loaded_variables() + if self_vars != other_vars: + return False + variables = self_vars for var in variables: self_var = getattr(self, var) From 7f3519ad97888401a61ac0c1fec6ab7bee90487b Mon Sep 17 00:00:00 2001 From: Sahil Jhawar Date: Wed, 29 Oct 2025 16:36:40 +0100 Subject: [PATCH 15/24] tests: add tests for __eq__ method --- tests/io/RBMDataSet/test_RBMDataset.py | 340 +++++++++++++++++++++++++ 1 file changed, 340 insertions(+) diff --git a/tests/io/RBMDataSet/test_RBMDataset.py b/tests/io/RBMDataSet/test_RBMDataset.py index a70fe00..d6c0bed 100644 --- a/tests/io/RBMDataSet/test_RBMDataset.py +++ b/tests/io/RBMDataSet/test_RBMDataset.py @@ -474,3 +474,343 @@ def test_dict_mode_satellite_lowercase(satellite): mfm=MfmEnum.T89, ) assert dataset.satellite.sat_name == satellite.lower() + + +def test_eq_file_loading_mode_identical(mock_module_string): + """Test equality for identical file loading mode datasets.""" + start_time = dt.datetime(2023, 1, 1, tzinfo=timezone.utc) + end_time = dt.datetime(2023, 1, 31, tzinfo=timezone.utc) + folder_path = Path("/mock/path") + + with mock.patch(f"{mock_module_string}._create_date_list"): + with 
mock.patch(f"{mock_module_string}._create_file_path_stem"): + with mock.patch(f"{mock_module_string}._create_file_name_stem"): + dataset1 = RBMDataSet( + satellite=SatelliteEnum.RBSPA, + instrument=InstrumentEnum.MAGEIS, + mfm=MfmEnum.T89, + start_time=start_time, + end_time=end_time, + folder_path=folder_path, + verbose=False, + ) + + dataset2 = RBMDataSet( + satellite=SatelliteEnum.RBSPA, + instrument=InstrumentEnum.MAGEIS, + mfm=MfmEnum.T89, + start_time=start_time, + end_time=end_time, + folder_path=folder_path, + verbose=False, + ) + + dataset1.Flux = np.array([[1.0, 2.0, 3.0]]) + dataset1.time = np.array([738000.0]) + dataset1.datetime = [dt.datetime(2023, 1, 15, tzinfo=timezone.utc)] + + dataset2.Flux = np.array([[1.0, 2.0, 3.0]]) + dataset2.time = np.array([738000.0]) + dataset2.datetime = [dt.datetime(2023, 1, 15, tzinfo=timezone.utc)] + + assert dataset1 == dataset2 + + +def test_eq_file_loading_mode_different_satellite(mock_module_string): + """Test inequality for file loading mode datasets with different satellites.""" + start_time = dt.datetime(2023, 1, 1, tzinfo=timezone.utc) + end_time = dt.datetime(2023, 1, 31, tzinfo=timezone.utc) + folder_path = Path("/mock/path") + + with mock.patch(f"{mock_module_string}._create_date_list"): + with mock.patch(f"{mock_module_string}._create_file_path_stem"): + with mock.patch(f"{mock_module_string}._create_file_name_stem"): + dataset1 = RBMDataSet( + satellite=SatelliteEnum.RBSPA, + instrument=InstrumentEnum.MAGEIS, + mfm=MfmEnum.T89, + start_time=start_time, + end_time=end_time, + folder_path=folder_path, + verbose=False, + ) + + dataset2 = RBMDataSet( + satellite=SatelliteEnum.RBSPB, # Different satellite + instrument=InstrumentEnum.MAGEIS, + mfm=MfmEnum.T89, + start_time=start_time, + end_time=end_time, + folder_path=folder_path, + verbose=False, + ) + + print("sahil") + + assert dataset1 != dataset2 + + +def test_eq_file_loading_mode_different_instrument(mock_module_string): + """Test inequality for file 
loading mode datasets with different instruments.""" + start_time = dt.datetime(2023, 1, 1, tzinfo=timezone.utc) + end_time = dt.datetime(2023, 1, 31, tzinfo=timezone.utc) + folder_path = Path("/mock/path") + + with mock.patch(f"{mock_module_string}._create_date_list"): + with mock.patch(f"{mock_module_string}._create_file_path_stem"): + with mock.patch(f"{mock_module_string}._create_file_name_stem"): + dataset1 = RBMDataSet( + satellite=SatelliteEnum.RBSPA, + instrument=InstrumentEnum.MAGEIS, + mfm=MfmEnum.T89, + start_time=start_time, + end_time=end_time, + folder_path=folder_path, + verbose=False, + ) + + dataset2 = RBMDataSet( + satellite=SatelliteEnum.RBSPA, + instrument=InstrumentEnum.HOPE, # Different instrument + mfm=MfmEnum.T89, + start_time=start_time, + end_time=end_time, + folder_path=folder_path, + verbose=False, + ) + + assert dataset1 != dataset2 + + +def test_eq_file_loading_mode_different_mfm(mock_module_string): + """Test inequality for file loading mode datasets with different MFM.""" + start_time = dt.datetime(2023, 1, 1, tzinfo=timezone.utc) + end_time = dt.datetime(2023, 1, 31, tzinfo=timezone.utc) + folder_path = Path("/mock/path") + + with mock.patch(f"{mock_module_string}._create_date_list"): + with mock.patch(f"{mock_module_string}._create_file_path_stem"): + with mock.patch(f"{mock_module_string}._create_file_name_stem"): + dataset1 = RBMDataSet( + satellite=SatelliteEnum.RBSPA, + instrument=InstrumentEnum.MAGEIS, + mfm=MfmEnum.T89, + start_time=start_time, + end_time=end_time, + folder_path=folder_path, + verbose=False, + ) + + dataset2 = RBMDataSet( + satellite=SatelliteEnum.RBSPA, + instrument=InstrumentEnum.MAGEIS, + mfm=MfmEnum.T96, # Different MFM + start_time=start_time, + end_time=end_time, + folder_path=folder_path, + verbose=False, + ) + + assert dataset1 != dataset2 + + +def test_eq_dict_mode_identical(): + """Test equality for identical dict mode datasets.""" + dataset1 = RBMDataSet( + satellite=DummyEnum.SATELLITE, + 
instrument=DummyEnum.INSTRUMENT, + mfm=DummyEnum.MFM, + ) + + dataset2 = RBMDataSet( + satellite=DummyEnum.SATELLITE, + instrument=DummyEnum.INSTRUMENT, + mfm=DummyEnum.MFM, + ) + + test_data = { + "Flux": np.array([[1.0, 2.0, 3.0]]), + "time": np.array([738000.0]), + "energy_channels": np.array([100.0, 200.0, 300.0]), + "Lstar": np.array([4.5, 5.0, 5.5]), + } + + dataset1.update_from_dict(test_data) + dataset2.update_from_dict(test_data.copy()) + + assert dataset1 == dataset2 + + +def test_eq_dict_mode_different_variables(): + """Test inequality for dict mode datasets with different variables.""" + dataset1 = RBMDataSet( + satellite=DummyEnum.SATELLITE, + instrument=DummyEnum.INSTRUMENT, + mfm=DummyEnum.MFM, + ) + + dataset2 = RBMDataSet( + satellite=DummyEnum.SATELLITE, + instrument=DummyEnum.INSTRUMENT, + mfm=DummyEnum.MFM, + ) + + dataset1.update_from_dict({"Flux": np.array([[1.0, 2.0, 3.0]])}) + dataset2.update_from_dict({"time": np.array([738000.0])}) + + assert dataset1 != dataset2 + + +def test_eq_dict_mode_same_variables_different_values(): + """Test inequality for dict mode datasets with same variables but different values.""" + dataset1 = RBMDataSet( + satellite=DummyEnum.SATELLITE, + instrument=DummyEnum.INSTRUMENT, + mfm=DummyEnum.MFM, + ) + + dataset2 = RBMDataSet( + satellite=DummyEnum.SATELLITE, + instrument=DummyEnum.INSTRUMENT, + mfm=DummyEnum.MFM, + ) + + dataset1.update_from_dict({"Flux": np.array([[1.0, 2.0, 3.0]])}) + dataset2.update_from_dict({"Flux": np.array([[4.0, 5.0, 6.0]])}) + + assert dataset1 != dataset2 + + +def test_eq_different_modes(mock_module_string): + """Test inequality between file loading and dict mode datasets.""" + # File loading mode dataset + start_time = dt.datetime(2023, 1, 1, tzinfo=timezone.utc) + end_time = dt.datetime(2023, 1, 31, tzinfo=timezone.utc) + folder_path = Path("/mock/path") + + with mock.patch(f"{mock_module_string}._create_date_list"): + with mock.patch(f"{mock_module_string}._create_file_path_stem"): 
+ with mock.patch(f"{mock_module_string}._create_file_name_stem"): + file_dataset = RBMDataSet( + satellite=SatelliteEnum.RBSPA, + instrument=InstrumentEnum.MAGEIS, + mfm=MfmEnum.T89, + start_time=start_time, + end_time=end_time, + folder_path=folder_path, + verbose=False, + ) + + # Dict mode dataset + dict_dataset = RBMDataSet( + satellite=DummyEnum.SATELLITE, + instrument=DummyEnum.INSTRUMENT, + mfm=DummyEnum.MFM, + ) + + assert file_dataset != dict_dataset + + +def test_eq_array_comparison_with_nan(): + """Test equality with NaN values in arrays.""" + dataset1 = RBMDataSet( + satellite=DummyEnum.SATELLITE, + instrument=DummyEnum.INSTRUMENT, + mfm=DummyEnum.MFM, + ) + + dataset2 = RBMDataSet( + satellite=DummyEnum.SATELLITE, + instrument=DummyEnum.INSTRUMENT, + mfm=DummyEnum.MFM, + ) + + test_array = np.array([[1.0, np.nan, 3.0]]) + dataset1.update_from_dict({"Flux": test_array}) + dataset2.update_from_dict({"Flux": test_array.copy()}) + + assert dataset1 == dataset2 + + +def test_eq_list_comparison(): + """Test equality with list variables.""" + dataset1 = RBMDataSet( + satellite=DummyEnum.SATELLITE, + instrument=DummyEnum.INSTRUMENT, + mfm=DummyEnum.MFM, + ) + + dataset2 = RBMDataSet( + satellite=DummyEnum.SATELLITE, + instrument=DummyEnum.INSTRUMENT, + mfm=DummyEnum.MFM, + ) + + test_datetime_list = [dt.datetime(2023, 1, 15, tzinfo=timezone.utc)] + dataset1.datetime = test_datetime_list + dataset2.datetime = test_datetime_list.copy() + + assert dataset1 == dataset2 + + +def test_eq_list_different_lengths(): + """Test inequality with lists of different lengths.""" + dataset1 = RBMDataSet( + satellite=DummyEnum.SATELLITE, + instrument=DummyEnum.INSTRUMENT, + mfm=DummyEnum.MFM, + ) + + dataset2 = RBMDataSet( + satellite=DummyEnum.SATELLITE, + instrument=DummyEnum.INSTRUMENT, + mfm=DummyEnum.MFM, + ) + + dataset1.datetime = [dt.datetime(2023, 1, 15, tzinfo=timezone.utc)] + dataset2.datetime = [ + dt.datetime(2023, 1, 15, tzinfo=timezone.utc), + dt.datetime(2023, 
1, 16, tzinfo=timezone.utc), + ] + + assert dataset1 != dataset2 + + +def test_eq_array_different_shapes(): + """Test inequality with arrays of different shapes.""" + dataset1 = RBMDataSet( + satellite=DummyEnum.SATELLITE, + instrument=DummyEnum.INSTRUMENT, + mfm=DummyEnum.MFM, + ) + + dataset2 = RBMDataSet( + satellite=DummyEnum.SATELLITE, + instrument=DummyEnum.INSTRUMENT, + mfm=DummyEnum.MFM, + ) + + dataset1.update_from_dict({"Flux": np.array([[1.0, 2.0]])}) + dataset2.update_from_dict({"Flux": np.array([1.0, 2.0])}) + + assert dataset1 != dataset2 + + +def test_eq_different_types(): + """Test inequality when same variable has different types.""" + dataset1 = RBMDataSet( + satellite=DummyEnum.SATELLITE, + instrument=DummyEnum.INSTRUMENT, + mfm=DummyEnum.MFM, + ) + + dataset2 = RBMDataSet( + satellite=DummyEnum.SATELLITE, + instrument=DummyEnum.INSTRUMENT, + mfm=DummyEnum.MFM, + ) + + dataset1.time = np.array([738000.0]) + dataset2.time = [738000.0] + + assert dataset1 != dataset2 From 06a34b15e86db8b45e3d05b950d5e0c53a137ecb Mon Sep 17 00:00:00 2001 From: Sahil Jhawar Date: Wed, 29 Oct 2025 16:48:17 +0100 Subject: [PATCH 16/24] remove mentions of elpaso --- swvo/io/RBMDataSet/RBMDataSet.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/swvo/io/RBMDataSet/RBMDataSet.py b/swvo/io/RBMDataSet/RBMDataSet.py index 41ecd02..fd73808 100644 --- a/swvo/io/RBMDataSet/RBMDataSet.py +++ b/swvo/io/RBMDataSet/RBMDataSet.py @@ -42,7 +42,7 @@ class RBMDataSet: """RBMDataSet class for loading and managing data. - This class can load data either from files or from a dictionary (ElPaso format). + This class can load data either from files or from a dictionary. For file-based loading, provide start_time, end_time, and folder_path. For dictionary-based loading, initialize without these parameters and use update_from_dict(). 
@@ -150,7 +150,7 @@ def __init__( self._mfm = mfm self._verbose = verbose - # For dict-based loading (ElPaso mode), modify satellite properties + # For dict-based loading, modify satellite properties if start_time is None and end_time is None and folder_path is None: # no file loading needed satellite_obj = replace( @@ -270,7 +270,7 @@ def mfm(self) -> MfmEnum: return self._mfm def update_from_dict(self, source_dict: dict[str, VariableLiteral]) -> None: - """Get data from ElPaso data dictionary and update the object. + """Get data from data dictionary and update the object. Parameters ---------- From ee5295c9cff3acc706115ba3ada222746e220df1 Mon Sep 17 00:00:00 2001 From: Sahil Jhawar Date: Wed, 29 Oct 2025 17:36:15 +0100 Subject: [PATCH 17/24] add enable_dict_loading i.e., optionally allow dict loading when loading from files --- swvo/io/RBMDataSet/RBMDataSet.py | 7 +++++++ tests/io/RBMDataSet/test_RBMDataset.py | 14 ++++++++++++++ 2 files changed, 21 insertions(+) diff --git a/swvo/io/RBMDataSet/RBMDataSet.py b/swvo/io/RBMDataSet/RBMDataSet.py index fd73808..8dbc8bc 100644 --- a/swvo/io/RBMDataSet/RBMDataSet.py +++ b/swvo/io/RBMDataSet/RBMDataSet.py @@ -65,6 +65,8 @@ class RBMDataSet: Preferred file extension for file-based loading. Default is "pickle". verbose : bool, optional Whether to print verbose output. Default is True. + enable_dict_loading : bool, optional + Enable dictionary-based loading even in file mode. Default is False. 
Attributes ---------- @@ -127,6 +129,7 @@ def __init__( preferred_extension: Literal["mat", "pickle"] = "pickle", *, verbose: bool = True, + enable_dict_loading: bool = False, ) -> None: self.ep_variables = list(VariableLiteral.__args__) # Handle satellite conversion with special cases for GOES @@ -184,6 +187,7 @@ def __init__( self._file_cadence = self._satellite.file_cadence self._date_of_files = self._create_date_list() self._file_loading_mode = True + self._enable_dict_loading = enable_dict_loading def __repr__(self): return f"{self.__class__.__name__}({self._satellite_enum}, {self._instrument}, {self._mfm})" @@ -278,6 +282,9 @@ def update_from_dict(self, source_dict: dict[str, VariableLiteral]) -> None: Dictionary containing the data to be loaded into the object. """ + if self._file_loading_mode and not self._enable_dict_loading: + msg = "RBMDataSet is in file loading mode. Cannot update from dictionary. To use dictionary-based loading, set `enable_dict_loading=True` during initialization." 
+ raise RuntimeError(msg) for key, value in source_dict.items(): _, levenstein_info = self.find_similar_variable(key) if key in self.ep_variables: diff --git a/tests/io/RBMDataSet/test_RBMDataset.py b/tests/io/RBMDataSet/test_RBMDataset.py index d6c0bed..02e7374 100644 --- a/tests/io/RBMDataSet/test_RBMDataset.py +++ b/tests/io/RBMDataSet/test_RBMDataset.py @@ -298,6 +298,20 @@ def test_all_variables_in_dir(mock_dataset: RBMDataSet): assert var in mock_dataset.__dir__() +def test_disable_dict_loading_mode(mock_dataset: RBMDataSet): + """Test that default file loading mode raises RuntimeError on update_from_dict.""" + with pytest.raises(RuntimeError): + mock_dataset.update_from_dict({"Flux": np.array([[1.0, 2.0, 3.0]])}) + + +def test_enable_dict_loading_mode(mock_dataset: RBMDataSet): + """Test that default file loading mode raises RuntimeError on update_from_dict.""" + mock_dataset._enable_dict_loading = True + mock_dataset.update_from_dict({"Flux": np.array([[1.0, 2.0, 3.0]])}) + + assert hasattr(mock_dataset, "Flux") + + @pytest.fixture def dict_dataset(): """Fixture for dictionary-based loading mode (no file parameters)""" From 3901c03bf02cdc88404b0bb4af19b6fef5d0c813 Mon Sep 17 00:00:00 2001 From: Sahil Jhawar Date: Thu, 30 Oct 2025 11:29:39 +0100 Subject: [PATCH 18/24] fix type annotation --- swvo/io/RBMDataSet/RBMDataSet.py | 10 +++++----- swvo/io/RBMDataSet/RBMDataSetManager.py | 2 +- 2 files changed, 6 insertions(+), 6 deletions(-) diff --git a/swvo/io/RBMDataSet/RBMDataSet.py b/swvo/io/RBMDataSet/RBMDataSet.py index 8dbc8bc..7afa0e1 100644 --- a/swvo/io/RBMDataSet/RBMDataSet.py +++ b/swvo/io/RBMDataSet/RBMDataSet.py @@ -44,16 +44,16 @@ class RBMDataSet: This class can load data either from files or from a dictionary. - For file-based loading, provide start_time, end_time, and folder_path. - For dictionary-based loading, initialize without these parameters and use update_from_dict(). 
+ For file-based loading, provide `start_time`, `end_time`, and `folder_path`. + For dictionary-based loading, initialize without these parameters and use `update_from_dict()`. Parameters ---------- - satellite : :class:`SatelliteLike` + satellite : Union[:class:`SatelliteLike`, :class:`DummyLike`] Satellite identifier as enum or string. - instrument : :class:`InstrumentLike` + instrument : Union[:class:`InstrumentLike`, :class:`DummyLike`] Instrument enumeration or string. - mfm : :class:`MfmLike` + mfm : Union[:class:`MfmLike`, :class:`DummyLike`] Magnetic field model enum or string. start_time : dt.datetime, optional Start time for file-based loading. diff --git a/swvo/io/RBMDataSet/RBMDataSetManager.py b/swvo/io/RBMDataSet/RBMDataSetManager.py index 17da461..7533a48 100644 --- a/swvo/io/RBMDataSet/RBMDataSetManager.py +++ b/swvo/io/RBMDataSet/RBMDataSetManager.py @@ -110,7 +110,7 @@ def load( Satellite identifier(s) as enum or string. If a single satellite is provided, it can be a string or an enum. instrument : :class:`InstrumentEnum` Instrument enumeration, e.g., :class:`InstrumentEnum.HOPE`. - mfm : MfmEnum + mfm : :class:`MfmEnum` Magnetic field model enum, e.g., :class:`MfmEnum.T89`. folder_type : :class:`FolderTypeEnum`, optional Type of folder where the data is stored, by default :class:`FolderTypeEnum.DataServer`. 
From 2b31188d273784596a2af0b459c7fffa3fe92b17 Mon Sep 17 00:00:00 2001 From: Sahil Jhawar Date: Thu, 30 Oct 2025 11:30:29 +0100 Subject: [PATCH 19/24] add .env to gitignore --- .gitignore | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/.gitignore b/.gitignore index e6070e0..a88fd58 100644 --- a/.gitignore +++ b/.gitignore @@ -22,4 +22,5 @@ cron_check.sh swace_commented_code docs/_build/* docs/api/* -.nfs* \ No newline at end of file +.nfs* +.env From cbb89fb5c224fda01bc9f1eb2a37b833dca15a70 Mon Sep 17 00:00:00 2001 From: Sahil Jhawar Date: Thu, 30 Oct 2025 13:36:18 +0100 Subject: [PATCH 20/24] return object in update_from_dict --- swvo/io/RBMDataSet/RBMDataSet.py | 15 ++++++++++++++- tests/io/RBMDataSet/test_RBMDataset.py | 8 ++++++++ 2 files changed, 22 insertions(+), 1 deletion(-) diff --git a/swvo/io/RBMDataSet/RBMDataSet.py b/swvo/io/RBMDataSet/RBMDataSet.py index 7afa0e1..ae94bcf 100644 --- a/swvo/io/RBMDataSet/RBMDataSet.py +++ b/swvo/io/RBMDataSet/RBMDataSet.py @@ -273,7 +273,7 @@ def mfm(self) -> MfmEnum: """Returns the MFM enum.""" return self._mfm - def update_from_dict(self, source_dict: dict[str, VariableLiteral]) -> None: + def update_from_dict(self, source_dict: dict[str, VariableLiteral]) -> RBMDataSet: """Get data from data dictionary and update the object. Parameters @@ -281,6 +281,18 @@ def update_from_dict(self, source_dict: dict[str, VariableLiteral]) -> None: source_dict : dict[str, VariableLiteral] Dictionary containing the data to be loaded into the object. + Returns + ------- + RBMDataSet + The updated RBMDataSet object. + + Raises + ------ + VariableNotFoundError + If a key in the `source_dict` is not a valid `VariableLiteral`. + RuntimeError + If the `RBMDataSet` is in file loading mode and dictionary loading is not enabled. + """ if self._file_loading_mode and not self._enable_dict_loading: msg = "RBMDataSet is in file loading mode. Cannot update from dictionary. 
To use dictionary-based loading, set `enable_dict_loading=True` during initialization." @@ -295,6 +307,7 @@ def update_from_dict(self, source_dict: dict[str, VariableLiteral]) -> None: else: msg = f"Key '{key}' is not a valid `VariableLiteral`." raise VariableNotFoundError(msg) + return self def get_var(self, var: VariableEnum): return getattr(self, var.var_name) diff --git a/tests/io/RBMDataSet/test_RBMDataset.py b/tests/io/RBMDataSet/test_RBMDataset.py index 02e7374..21ec896 100644 --- a/tests/io/RBMDataSet/test_RBMDataset.py +++ b/tests/io/RBMDataSet/test_RBMDataset.py @@ -343,6 +343,14 @@ def test_dict_mode_repr_and_str(dict_dataset): assert str(dict_dataset.mfm) in str(dict_dataset) +def test_update_from_dict_returns_self(dict_dataset): + """Test that update_from_dict returns self for method chaining""" + source_dict = {"Flux": np.array([[1.0, 2.0]])} + + result = dict_dataset.update_from_dict(source_dict) + assert isinstance(result, RBMDataSet) + + def test_update_from_dict_sets_variables(dict_dataset): """Test that the correct variable is set with direct key""" flux_data = np.array([[1.0, 2.0]]) From 17b533eedd819dafed3d988cd400a3e8daea958f Mon Sep 17 00:00:00 2001 From: Bernhard Haas Date: Thu, 30 Oct 2025 14:09:35 +0100 Subject: [PATCH 21/24] Removed Dummy enum. Small type changes. 
--- swvo/io/RBMDataSet/RBMDataSet.py | 62 ++++--- swvo/io/RBMDataSet/RBMNcDataSet.py | 4 +- swvo/io/RBMDataSet/__init__.py | 4 +- .../bin_and_interpolate_to_model_grid.py | 152 ++++++++++++++---- swvo/io/RBMDataSet/custom_enums.py | 28 ++-- tests/io/RBMDataSet/test_RBMDataset.py | 109 +++++++------ 6 files changed, 221 insertions(+), 138 deletions(-) diff --git a/swvo/io/RBMDataSet/RBMDataSet.py b/swvo/io/RBMDataSet/RBMDataSet.py index ae94bcf..2d248cc 100644 --- a/swvo/io/RBMDataSet/RBMDataSet.py +++ b/swvo/io/RBMDataSet/RBMDataSet.py @@ -20,17 +20,19 @@ from swvo.io.RBMDataSet import ( FileCadenceEnum, FolderTypeEnum, + Instrument, InstrumentEnum, InstrumentLike, + Mfm, MfmEnum, MfmLike, + Satellite, SatelliteEnum, SatelliteLike, Variable, VariableEnum, VariableLiteral, ) -from swvo.io.RBMDataSet.custom_enums import DummyEnum, DummyLike from swvo.io.RBMDataSet.utils import ( get_file_path_any_format, join_var, @@ -120,9 +122,9 @@ class RBMDataSet: def __init__( self, - satellite: SatelliteLike | DummyLike, - instrument: InstrumentLike | DummyLike, - mfm: MfmLike | DummyLike, + satellite: SatelliteLike, + instrument: InstrumentLike, + mfm: MfmLike, start_time: dt.datetime | None = None, end_time: dt.datetime | None = None, folder_path: Path | None = None, @@ -131,7 +133,7 @@ def __init__( verbose: bool = True, enable_dict_loading: bool = False, ) -> None: - self.ep_variables = list(VariableLiteral.__args__) + self.possible_variables: list[str] = list(VariableLiteral.__args__) # Handle satellite conversion with special cases for GOES if isinstance(satellite, str): if satellite.lower() == "goesprimary": @@ -148,21 +150,13 @@ def __init__( mfm = MfmEnum[mfm.upper()] # Store the original satellite enum for properties and other attributes - self._satellite_enum = satellite + self._satellite = satellite self._instrument = instrument self._mfm = mfm self._verbose = verbose # For dict-based loading, modify satellite properties if start_time is None and end_time is None 
and folder_path is None: - # no file loading needed - satellite_obj = replace( - satellite.value, - folder_type=FolderTypeEnum.NoFolder, - file_cadence=FileCadenceEnum.NoCadence, - ) - self._satellite = satellite_obj - self._mfm_prefix = DummyEnum.MFM.value if isinstance(mfm, DummyEnum) else MfmEnum[mfm.name].value self._file_loading_mode = False else: # File loading mode: need all parameters @@ -190,7 +184,7 @@ def __init__( self._enable_dict_loading = enable_dict_loading def __repr__(self): - return f"{self.__class__.__name__}({self._satellite_enum}, {self._instrument}, {self._mfm})" + return f"{self.__class__.__name__}({self._satellite}, {self._instrument}, {self._mfm})" def __str__(self): return self.__repr__() @@ -205,14 +199,14 @@ def __getattr__(self, name: str): # Handle computed properties for both modes if name == "P": - if not hasattr(self, "MLT") or getattr(self, "MLT") is None or not isinstance(self.MLT, np.ndarray): + if not hasattr(self, "MLT") or getattr(self, "MLT") is None or not isinstance(self.MLT, np.ndarray): # type: ignore[reportUnnecessaryIsInstance] raise AttributeError("Cannot compute `P` because `MLT` is missing, not loaded or is not valid array.") return ((self.MLT + 12) / 12 * np.pi) % (2 * np.pi) if name == "InvV": if not all(hasattr(self, attr) for attr in ("InvK", "InvMu")): raise AttributeError("Cannot compute `InvV` because `InvK` or `InvMu` is missing.") - if not isinstance(self.InvK, np.ndarray) or not isinstance(self.InvMu, np.ndarray): + if not isinstance(self.InvK, np.ndarray) or not isinstance(self.InvMu, np.ndarray): # type: ignore[reportUnnecessaryIsInstance] raise AttributeError("Cannot compute `InvV` because required arrays are invalid or not loaded.") if self.InvK.ndim < 1 or self.InvMu.ndim < 2: raise AttributeError("Cannot compute `InvV` because array dimensions are insufficient.") @@ -228,7 +222,7 @@ def __getattr__(self, name: str): self._load_variable(sat_variable) return getattr(self, name) - if not 
self._file_loading_mode and name in self.ep_variables: + if not self._file_loading_mode and name in self.possible_variables: raise AttributeError( f"Attribute '{name}' exists in `VariableLiteral` but has not been set. " "Call `update_from_dict()` before accessing it." @@ -241,35 +235,37 @@ def __getattr__(self, name: str): raise AttributeError(msg) - def find_similar_variable(self, name): + def find_similar_variable(self, name: str) -> tuple[None | VariableEnum, dict[str, Any]]: levenstein_info: dict[str, Any] = {"min_distance": 10, "var_name": ""} sat_variable = None - for var in self.ep_variables: - if name == var: + for var in VariableEnum: + + if name == var.var_name: sat_variable = var break else: - dist = distance.levenshtein(name, var) - if name.lower() in var.lower(): + dist = distance.levenshtein(name, var.var_name) + if name.lower() in var.var_name.lower(): dist = 1 if dist < levenstein_info["min_distance"]: levenstein_info["min_distance"] = dist - levenstein_info["var_name"] = var + levenstein_info["var_name"] = var.var_name + return sat_variable, levenstein_info @property - def satellite(self) -> SatelliteEnum: + def satellite(self) -> SatelliteEnum | Satellite: """Returns the satellite enum.""" - return self._satellite_enum + return self._satellite @property - def instrument(self) -> InstrumentEnum: + def instrument(self) -> InstrumentEnum | Instrument: """Returns the instrument enum.""" return self._instrument @property - def mfm(self) -> MfmEnum: + def mfm(self) -> MfmEnum | Mfm: """Returns the MFM enum.""" return self._mfm @@ -299,7 +295,7 @@ def update_from_dict(self, source_dict: dict[str, VariableLiteral]) -> RBMDataSe raise RuntimeError(msg) for key, value in source_dict.items(): _, levenstein_info = self.find_similar_variable(key) - if key in self.ep_variables: + if key in self.possible_variables: setattr(self, key, value) elif levenstein_info["min_distance"] <= 2: msg = f"Key '{key}' is not a valid `VariableLiteral`. 
Maybe you meant '{levenstein_info['var_name']}'?" @@ -347,13 +343,13 @@ def _create_file_name_stem(self) -> str: # if self._satellite == SatelliteEnum.THEMIS: # pass - return self._satellite.sat_name + "_" + self._instrument.value + "_" + return self._satellite.sat_name + "_" + self._instrument.instrument_name + "_" def get_satellite_name(self) -> str: return self._satellite.sat_name def get_satellite_and_instrument_name(self) -> str: - return self._satellite.sat_name + "_" + self._instrument.value + return self._satellite.sat_name + "_" + self._instrument.instrument_name def set_file_path_stem(self, file_path_stem: Path): self._file_path_stem = file_path_stem @@ -369,7 +365,7 @@ def set_file_cadence(self, file_cadence: FileCadenceEnum): return self def get_print_name(self) -> str: - return self._satellite.sat_name + " " + self._instrument.value + return self._satellite.sat_name + " " + self._instrument.instrument_name def _load_variable(self, var: Variable | VariableEnum) -> None: loaded_var_arrs: dict[str, NDArray[np.number]] = {} @@ -395,7 +391,7 @@ def _load_variable(self, var: Variable | VariableEnum) -> None: file_name_no_format = self._file_name_stem + date_str + "_" + var.mat_file_prefix if var.mat_has_B: - file_name_no_format += "_n4_4_" + self._mfm.value + file_name_no_format += "_n4_4_" + self._mfm.mfm_name file_name_no_format += "_ver4" else: diff --git a/swvo/io/RBMDataSet/RBMNcDataSet.py b/swvo/io/RBMDataSet/RBMNcDataSet.py index 7eb273b..c6b28db 100644 --- a/swvo/io/RBMDataSet/RBMNcDataSet.py +++ b/swvo/io/RBMDataSet/RBMNcDataSet.py @@ -141,7 +141,7 @@ def _load_variable(self, var: Variable | VariableEnum) -> None: next_month = start_month + relativedelta(months=1, days=-1) date_str = start_month.strftime("%Y%m%d") + "to" + next_month.strftime("%Y%m%d") - file_name = self._file_name_stem + date_str + "_" + self._mfm.value + ".nc" + file_name = self._file_name_stem + date_str + "_" + self._mfm.mfm_name + ".nc" else: raise NotImplementedError @@ -193,7 
+193,7 @@ def _load_variable(self, var: Variable | VariableEnum) -> None: if var_name == "datetime": loaded_var_arrs[var_name] = list(loaded_var_arrs[var_name]) # type: ignore - rbm_var_name = RBMNcDataSet._get_rbm_name(var_name, self._mfm.value) + rbm_var_name = RBMNcDataSet._get_rbm_name(var_name, self._mfm.mfm_name) if rbm_var_name is not None: setattr(self, rbm_var_name, loaded_var_arrs[var_name]) diff --git a/swvo/io/RBMDataSet/__init__.py b/swvo/io/RBMDataSet/__init__.py index d307c3b..d51d0f8 100644 --- a/swvo/io/RBMDataSet/__init__.py +++ b/swvo/io/RBMDataSet/__init__.py @@ -11,14 +11,14 @@ Satellite as Satellite, SatelliteLike as SatelliteLike, SatelliteEnum as SatelliteEnum, + Instrument as Instrument, InstrumentEnum as InstrumentEnum, InstrumentLike as InstrumentLike, + Mfm as Mfm, MfmEnum as MfmEnum, MfmLike as MfmLike, SatelliteLiteral as SatelliteLiteral, VariableLiteral as VariableLiteral, - DummyLike as DummyLike, - DummyEnum as DummyEnum, ) from swvo.io.RBMDataSet.RBMDataSetManager import RBMDataSetManager as RBMDataSetManager from swvo.io.RBMDataSet.interp_functions import TargetType as TargetType diff --git a/swvo/io/RBMDataSet/bin_and_interpolate_to_model_grid.py b/swvo/io/RBMDataSet/bin_and_interpolate_to_model_grid.py index 2504023..78fc199 100644 --- a/swvo/io/RBMDataSet/bin_and_interpolate_to_model_grid.py +++ b/swvo/io/RBMDataSet/bin_and_interpolate_to_model_grid.py @@ -12,17 +12,18 @@ from pathlib import Path from typing import TYPE_CHECKING, Literal +from icecream import ic import numpy as np from matplotlib import pyplot as plt from numpy.typing import NDArray from tqdm import tqdm if TYPE_CHECKING: - from swvo.io.RBMDataSet import RBMDataSet + from swvo.io.RBMDataSet import RBMDataSet, RBMDataSetElPaso def bin_and_interpolate_to_model_grid( - self: RBMDataSet, + self: RBMDataSet | RBMDataSetElPaso, sim_time: list[datetime], grid_R: NDArray[np.float64], grid_mu_V: NDArray[np.float64], @@ -41,27 +42,35 @@ def 
bin_and_interpolate_to_model_grid( if grid_K.ndim == 3: grid_K = grid_K[np.newaxis, ...] + target_var_init = getattr(self, target_var_name) + # 1. interpolate to V-K - target_var_init = getattr(self, target_var_name) + if grid_R.shape[2] > 1 and grid_R.shape[3] > 1: - if target_var_init.ndim == 1: - target_var_init = target_var_init[:, np.newaxis, np.newaxis] + if target_var_init.ndim == 1: + target_var_init = target_var_init[:, np.newaxis, np.newaxis] - mu_or_V_arr = self.InvMu if mu_or_V == "Mu" else self.InvV - if grid_mu_V.shape[2] > 1: - psd_interp = _interpolate_in_V_K(target_var_init, mu_or_V_arr, self.InvK, grid_mu_V, grid_K) + mu_or_V_arr = self.InvMu if mu_or_V == "Mu" else self.InvV + if grid_mu_V.shape[2] > 1: + psd_interp = _interpolate_in_V_K(target_var_init, mu_or_V_arr, self.InvK, grid_mu_V, grid_K) + else: + psd_interp = target_var_init + + # sanity check + if np.min(target_var_init) > np.min(psd_interp) or np.max(target_var_init) < np.max(psd_interp): + msg = "Found inconsitency in V-K interpolation. Aborting..." + raise (ValueError(msg)) else: - psd_interp = target_var_init - # sanity check - if np.min(target_var_init) > np.min(psd_interp) or np.max(target_var_init) < np.max(psd_interp): - msg = "Found inconsitency in V-K interpolation. Aborting..." - raise (ValueError(msg)) + if target_var_init.ndim == 1: # plasmasphere + target_var_init = target_var_init[:, np.newaxis, np.newaxis] + + psd_interp = target_var_init # 2. 
Bin in space - R_or_Lstar_arr = self.R0 if grid_P else self.Lstar[:, -1] + R_or_Lstar_arr = self.R0 if grid_P is not None else self.Lstar[:, -1] psd_binned_in_space = _bin_in_space(psd_interp, self.P, R_or_Lstar_arr, grid_R, grid_P) # sanity check @@ -77,15 +86,24 @@ def bin_and_interpolate_to_model_grid( raise (ValueError(msg)) if debug_plot_settings: - plot_debug_figures( + if debug_plot_settings.target_K is not None: + plot_debug_figures( + self, + psd_binned_in_time, + sim_time, + grid_P, + grid_R, + grid_mu_V, + grid_K, + mu_or_V, + debug_plot_settings, + ) + else: plot_debug_figures_plasmasphere( self, psd_binned_in_time, sim_time, grid_P, grid_R, - grid_mu_V, - grid_K, - mu_or_V, debug_plot_settings, ) @@ -141,7 +159,7 @@ def _bin_in_time( time_indices = _get_time_indices(data_timestamps, _get_time_bins(sim_timestamps)) for i, _ in tqdm(enumerate(sim_time)): - psd_binned[i, ...] = np.nanmean(data_psd[time_indices == i, ...], axis=0) + psd_binned[i, ...] = np.power(10, np.nanmean(np.log10(data_psd[time_indices == i, ...]), axis=0)) return psd_binned @@ -155,7 +173,7 @@ def _bin_in_space( ) -> NDArray[np.float64]: print("\tBin in space...") - if grid_P: + if grid_P is not None: grid_P_1d = grid_P[:, 0, 0, 0] grid_R_1d = grid_R[0, :, 0, 0] @@ -199,7 +217,7 @@ def _bin_in_space( r_idx = np.argmin(np.abs(R_data[it] - grid_R_1d)) - if grid_P_1d: + if grid_P_1d is not None: raw_difference_p = np.abs(P_data[it] - grid_P_1d) min_difference_p = np.where( raw_difference_p <= np.pi, @@ -210,10 +228,15 @@ def _bin_in_space( number_of_observations[it, p_idx, r_idx, :, :] += np.where(np.isnan(psd_in[it, :, :]), 0, 1) psd_binned[it, p_idx, r_idx, :, :] += np.where(np.isnan(psd_in[it, :, :]), 0, np.log10(psd_in[it, :, :])) + else: number_of_observations[it, 0, r_idx, :, :] += np.where(np.isnan(psd_in[it, :, :]), 0, 1) psd_binned[it, 0, r_idx, :, :] += np.where(np.isnan(psd_in[it, :, :]), 0, np.log10(psd_in[it, :, :])) + # # ic(number_of_observations[it, :, :, 0, 0]) + # 
ic(np.power(10, np.nanmax(psd_binned[it, :, :, 0, 0]))) + # ic(np.power(10, np.nanmax(psd_binned[it, :, :, 0, 0] / number_of_observations[it, :, :, 0, 0]))) + psd_binned = np.where(psd_binned == 0, np.nan, psd_binned) return np.power(10, psd_binned / number_of_observations) @@ -350,15 +373,91 @@ def _parallel_func_VK( class DebugPlotSettings: folder_path: Path satellite_name: str - target_V: float - target_K: float + target_V: float | None = None + target_K: float | None = None + + +def plot_debug_figures_plasmasphere( + data_set: RBMDataSet, + psd_binned: NDArray[np.float64], + sim_time: NDArray[np.object_], + grid_P: NDArray[np.float64] | None, + grid_R: NDArray[np.float64], + debug_plot_settings: DebugPlotSettings, +): + + print("\tPlot debug features...") + + from icecream import ic + + dt = sim_time[1] - sim_time[0] + + fig = plt.figure(figsize=(19.20, 8)) + plt.rcParams["axes.axisbelow"] = False + + R_or_Lstar_arr = data_set.R0 + + for it, sim_time_curr in enumerate(tqdm(sim_time)): + sat_time_idx = np.argwhere(np.abs(np.asarray(data_set.datetime) - sim_time_curr) <= dt / 2) + + R_idx = np.argwhere(np.abs(grid_R[0, :, 0, 0] - R_or_Lstar_arr[sat_time_idx])) + + + ax0 = fig.add_subplot(121, projection="polar") + ax1 = fig.add_subplot(122) + + # plot satellite trajectory on PxR grid + # [x_sat, y_sat] = pol2cart(self.P, self.R) + + # ic(data_set.P[sat_time_idx]) + # ic(R_or_Lstar_arr[sat_time_idx]) + + ax0.scatter(data_set.P[sat_time_idx], R_or_Lstar_arr[sat_time_idx], c=np.log10(data_set.density[sat_time_idx]), marker="D", vmin=0, + vmax=4, + cmap="jet",) + ax0.set_ylim(1, 6.6) + ax0.set_title("Orbit") + ax0.set_rlim([0, 6.6]) + ax0.set_theta_offset(np.pi) + + grid_X = grid_R[:, :, 0, 0] * np.cos(grid_P[:, :, 0, 0]) + grid_Y = grid_R[:, :, 0, 0] * np.sin(grid_P[:, :, 0, 0]) + + pc = ax1.pcolormesh( + grid_X, + grid_Y, + np.squeeze(np.log10(psd_binned[it, :, :, :, :])), + vmin=0, + vmax=4, + cmap="jet", + edgecolors="k", + linewidth=0.1, + ) + 
ax1.set_title("Assimilation input") + ax1.set_xlim(np.max(grid_R), -np.max(grid_R)) + ax1.set_ylim(np.max(grid_R), -np.max(grid_R)) + ax1.set_xlabel("X") + ax1.set_ylabel("Y") + + fig.colorbar(pc, ax=ax1) + + fig.savefig(Path(debug_plot_settings.folder_path) / f"{debug_plot_settings.satellite_name}_{sim_time_curr}.png") + + # ic(np.log10(psd_binned[it,:,:,V_idx,K_idx])) + + fig.clf() + + if np.any(data_set.P[sat_time_idx] < 0.1): + ic(psd_binned[it, 0, :, :, :]) + asdf + def plot_debug_figures( data_set: RBMDataSet, psd_binned: NDArray[np.float64], sim_time: NDArray[np.object_], - grid_P: NDArray[np.float64], + grid_P: NDArray[np.float64] | None, grid_R: NDArray[np.float64], grid_V: NDArray[np.float64], grid_K: NDArray[np.float64], @@ -374,7 +473,7 @@ def plot_debug_figures( data_set_V_or_Mu = data_set.InvMu if mu_or_V == "Mu" else data_set.InvV - R_or_Lstar_arr = data_set.R0 if grid_P else data_set.Lstar[:, -1] + R_or_Lstar_arr = data_set.R0 if grid_P is not None else data_set.Lstar[:, -1] for it, sim_time_curr in enumerate(tqdm(sim_time)): sat_time_idx = np.argwhere(np.abs(np.asarray(data_set.datetime) - sim_time_curr) <= dt / 2) @@ -384,9 +483,6 @@ def plot_debug_figures( K_idx = np.argmin(np.abs(grid_K[0, R_idx, 0, :] - debug_plot_settings.target_K)) V_idx = np.argmin(np.abs(grid_V[0, R_idx, :, K_idx] - debug_plot_settings.target_V)) - K_idx = 45 - V_idx = 51 - V_lim_min = np.log10(0.9 * np.min([np.nanmin(data_set_V_or_Mu), np.min(grid_V)])) V_lim_max = np.log10(1.1 * np.max([np.nanmax(data_set_V_or_Mu), np.max(grid_V)])) diff --git a/swvo/io/RBMDataSet/custom_enums.py b/swvo/io/RBMDataSet/custom_enums.py index 95522a8..dd30b2d 100644 --- a/swvo/io/RBMDataSet/custom_enums.py +++ b/swvo/io/RBMDataSet/custom_enums.py @@ -134,8 +134,11 @@ class SatelliteEnum(Satellite, Enum): ] SatelliteLike: TypeAlias = SatelliteLiteral | SatelliteEnum | Satellite +@dataclass(frozen=True) +class Instrument: + instrument_name: str -class InstrumentEnum(Enum): +class 
InstrumentEnum(Instrument, Enum): # RBSP HOPE = "hope" MAGEIS = "mageis" @@ -171,10 +174,13 @@ class InstrumentEnum(Enum): "orbit", "TED-electron", ] -InstrumentLike: TypeAlias = InstrumentLiteral | InstrumentEnum +InstrumentLike: TypeAlias = InstrumentLiteral | InstrumentEnum | Instrument +@dataclass(frozen=True) +class Mfm: + mfm_name: str -class MfmEnum(Enum): +class MfmEnum(Mfm, Enum): T89 = "T89" T04s = "T04s" T96 = "T96" @@ -184,18 +190,4 @@ class MfmEnum(Enum): MfmEnumLiteral = Literal["T89", "T04s", "TS04", "T96", "OP77", "T04"] -MfmLike: TypeAlias = MfmEnumLiteral | MfmEnum - - -class DummyEnum(Enum): - SATELLITE = Satellite( - sat_name="dummy", - mission="dummy", - folder_type=FolderTypeEnum.NoFolder, - file_cadence=FileCadenceEnum.NoCadence, - ) - INSTRUMENT = "dummy instrument" - MFM = "dummy mfm" - - -DummyLike: TypeAlias = Literal["dummy"] | DummyEnum +MfmLike: TypeAlias = MfmEnumLiteral | MfmEnum | Mfm \ No newline at end of file diff --git a/tests/io/RBMDataSet/test_RBMDataset.py b/tests/io/RBMDataSet/test_RBMDataset.py index 21ec896..23253a6 100644 --- a/tests/io/RBMDataSet/test_RBMDataset.py +++ b/tests/io/RBMDataSet/test_RBMDataset.py @@ -20,7 +20,6 @@ SatelliteLiteral, VariableEnum, ) -from swvo.io.RBMDataSet.custom_enums import DummyEnum @pytest.fixture @@ -316,9 +315,9 @@ def test_enable_dict_loading_mode(mock_dataset: RBMDataSet): def dict_dataset(): """Fixture for dictionary-based loading mode (no file parameters)""" return RBMDataSet( - satellite=DummyEnum.SATELLITE, - instrument=DummyEnum.INSTRUMENT, - mfm=DummyEnum.MFM, + satellite="RBSPA", + instrument="hope", + mfm="T89", ) @@ -639,15 +638,15 @@ def test_eq_file_loading_mode_different_mfm(mock_module_string): def test_eq_dict_mode_identical(): """Test equality for identical dict mode datasets.""" dataset1 = RBMDataSet( - satellite=DummyEnum.SATELLITE, - instrument=DummyEnum.INSTRUMENT, - mfm=DummyEnum.MFM, + satellite="RBSPA", + instrument="hope", + mfm="T89", ) dataset2 = RBMDataSet( - 
satellite=DummyEnum.SATELLITE, - instrument=DummyEnum.INSTRUMENT, - mfm=DummyEnum.MFM, + satellite="RBSPA", + instrument="hope", + mfm="T89", ) test_data = { @@ -666,15 +665,15 @@ def test_eq_dict_mode_identical(): def test_eq_dict_mode_different_variables(): """Test inequality for dict mode datasets with different variables.""" dataset1 = RBMDataSet( - satellite=DummyEnum.SATELLITE, - instrument=DummyEnum.INSTRUMENT, - mfm=DummyEnum.MFM, + satellite="RBSPA", + instrument="hope", + mfm="T89", ) dataset2 = RBMDataSet( - satellite=DummyEnum.SATELLITE, - instrument=DummyEnum.INSTRUMENT, - mfm=DummyEnum.MFM, + satellite="RBSPA", + instrument="hope", + mfm="T89", ) dataset1.update_from_dict({"Flux": np.array([[1.0, 2.0, 3.0]])}) @@ -686,15 +685,15 @@ def test_eq_dict_mode_different_variables(): def test_eq_dict_mode_same_variables_different_values(): """Test inequality for dict mode datasets with same variables but different values.""" dataset1 = RBMDataSet( - satellite=DummyEnum.SATELLITE, - instrument=DummyEnum.INSTRUMENT, - mfm=DummyEnum.MFM, + satellite="RBSPA", + instrument="hope", + mfm="T89", ) dataset2 = RBMDataSet( - satellite=DummyEnum.SATELLITE, - instrument=DummyEnum.INSTRUMENT, - mfm=DummyEnum.MFM, + satellite="RBSPA", + instrument="hope", + mfm="T89", ) dataset1.update_from_dict({"Flux": np.array([[1.0, 2.0, 3.0]])}) @@ -725,9 +724,9 @@ def test_eq_different_modes(mock_module_string): # Dict mode dataset dict_dataset = RBMDataSet( - satellite=DummyEnum.SATELLITE, - instrument=DummyEnum.INSTRUMENT, - mfm=DummyEnum.MFM, + satellite="RBSPA", + instrument="hope", + mfm="T89", ) assert file_dataset != dict_dataset @@ -736,15 +735,15 @@ def test_eq_different_modes(mock_module_string): def test_eq_array_comparison_with_nan(): """Test equality with NaN values in arrays.""" dataset1 = RBMDataSet( - satellite=DummyEnum.SATELLITE, - instrument=DummyEnum.INSTRUMENT, - mfm=DummyEnum.MFM, + satellite="RBSPA", + instrument="hope", + mfm="T89", ) dataset2 = RBMDataSet( - 
satellite=DummyEnum.SATELLITE, - instrument=DummyEnum.INSTRUMENT, - mfm=DummyEnum.MFM, + satellite="RBSPA", + instrument="hope", + mfm="T89", ) test_array = np.array([[1.0, np.nan, 3.0]]) @@ -757,15 +756,15 @@ def test_eq_array_comparison_with_nan(): def test_eq_list_comparison(): """Test equality with list variables.""" dataset1 = RBMDataSet( - satellite=DummyEnum.SATELLITE, - instrument=DummyEnum.INSTRUMENT, - mfm=DummyEnum.MFM, + satellite="RBSPA", + instrument="hope", + mfm="T89", ) dataset2 = RBMDataSet( - satellite=DummyEnum.SATELLITE, - instrument=DummyEnum.INSTRUMENT, - mfm=DummyEnum.MFM, + satellite="RBSPA", + instrument="hope", + mfm="T89", ) test_datetime_list = [dt.datetime(2023, 1, 15, tzinfo=timezone.utc)] @@ -778,15 +777,15 @@ def test_eq_list_comparison(): def test_eq_list_different_lengths(): """Test inequality with lists of different lengths.""" dataset1 = RBMDataSet( - satellite=DummyEnum.SATELLITE, - instrument=DummyEnum.INSTRUMENT, - mfm=DummyEnum.MFM, + satellite="RBSPA", + instrument="hope", + mfm="T89", ) dataset2 = RBMDataSet( - satellite=DummyEnum.SATELLITE, - instrument=DummyEnum.INSTRUMENT, - mfm=DummyEnum.MFM, + satellite="RBSPA", + instrument="hope", + mfm="T89", ) dataset1.datetime = [dt.datetime(2023, 1, 15, tzinfo=timezone.utc)] @@ -801,15 +800,15 @@ def test_eq_list_different_lengths(): def test_eq_array_different_shapes(): """Test inequality with arrays of different shapes.""" dataset1 = RBMDataSet( - satellite=DummyEnum.SATELLITE, - instrument=DummyEnum.INSTRUMENT, - mfm=DummyEnum.MFM, + satellite="RBSPA", + instrument="hope", + mfm="T89", ) dataset2 = RBMDataSet( - satellite=DummyEnum.SATELLITE, - instrument=DummyEnum.INSTRUMENT, - mfm=DummyEnum.MFM, + satellite="RBSPA", + instrument="hope", + mfm="T89", ) dataset1.update_from_dict({"Flux": np.array([[1.0, 2.0]])}) @@ -821,15 +820,15 @@ def test_eq_array_different_shapes(): def test_eq_different_types(): """Test inequality when same variable has different types.""" dataset1 = 
RBMDataSet( - satellite=DummyEnum.SATELLITE, - instrument=DummyEnum.INSTRUMENT, - mfm=DummyEnum.MFM, + satellite="RBSPA", + instrument="hope", + mfm="T89", ) dataset2 = RBMDataSet( - satellite=DummyEnum.SATELLITE, - instrument=DummyEnum.INSTRUMENT, - mfm=DummyEnum.MFM, + satellite="RBSPA", + instrument="hope", + mfm="T89", ) dataset1.time = np.array([738000.0]) From 68a7e703589223bd2bfecd3b4662635dd4b52398 Mon Sep 17 00:00:00 2001 From: Sahil Jhawar Date: Thu, 30 Oct 2025 14:13:31 +0100 Subject: [PATCH 22/24] ruff formatting and fixes --- swvo/io/RBMDataSet/RBMDataSet.py | 2 - .../bin_and_interpolate_to_model_grid.py | 41 +++++++++---------- swvo/io/RBMDataSet/custom_enums.py | 6 ++- 3 files changed, 25 insertions(+), 24 deletions(-) diff --git a/swvo/io/RBMDataSet/RBMDataSet.py b/swvo/io/RBMDataSet/RBMDataSet.py index 2d248cc..acb3856 100644 --- a/swvo/io/RBMDataSet/RBMDataSet.py +++ b/swvo/io/RBMDataSet/RBMDataSet.py @@ -6,7 +6,6 @@ import datetime as dt import typing -from dataclasses import replace from datetime import timedelta, timezone from pathlib import Path from typing import Any, Literal @@ -239,7 +238,6 @@ def find_similar_variable(self, name: str) -> tuple[None | VariableEnum, dict[st levenstein_info: dict[str, Any] = {"min_distance": 10, "var_name": ""} sat_variable = None for var in VariableEnum: - if name == var.var_name: sat_variable = var break diff --git a/swvo/io/RBMDataSet/bin_and_interpolate_to_model_grid.py b/swvo/io/RBMDataSet/bin_and_interpolate_to_model_grid.py index 78fc199..2a1f041 100644 --- a/swvo/io/RBMDataSet/bin_and_interpolate_to_model_grid.py +++ b/swvo/io/RBMDataSet/bin_and_interpolate_to_model_grid.py @@ -12,8 +12,8 @@ from pathlib import Path from typing import TYPE_CHECKING, Literal -from icecream import ic import numpy as np +from icecream import ic from matplotlib import pyplot as plt from numpy.typing import NDArray from tqdm import tqdm @@ -47,7 +47,6 @@ def bin_and_interpolate_to_model_grid( # 1. 
interpolate to V-K if grid_R.shape[2] > 1 and grid_R.shape[3] > 1: - if target_var_init.ndim == 1: target_var_init = target_var_init[:, np.newaxis, np.newaxis] @@ -62,8 +61,7 @@ def bin_and_interpolate_to_model_grid( msg = "Found inconsitency in V-K interpolation. Aborting..." raise (ValueError(msg)) else: - - if target_var_init.ndim == 1: # plasmasphere + if target_var_init.ndim == 1: # plasmasphere target_var_init = target_var_init[:, np.newaxis, np.newaxis] psd_interp = target_var_init @@ -98,14 +96,15 @@ def bin_and_interpolate_to_model_grid( mu_or_V, debug_plot_settings, ) - else: plot_debug_figures_plasmasphere( - self, - psd_binned_in_time, - sim_time, - grid_P, - grid_R, - debug_plot_settings, - ) + else: + plot_debug_figures_plasmasphere( + self, + psd_binned_in_time, + sim_time, + grid_P, + grid_R, + debug_plot_settings, + ) return psd_binned_in_time @@ -385,11 +384,8 @@ def plot_debug_figures_plasmasphere( grid_R: NDArray[np.float64], debug_plot_settings: DebugPlotSettings, ): - print("\tPlot debug features...") - from icecream import ic - dt = sim_time[1] - sim_time[0] fig = plt.figure(figsize=(19.20, 8)) @@ -400,8 +396,7 @@ def plot_debug_figures_plasmasphere( for it, sim_time_curr in enumerate(tqdm(sim_time)): sat_time_idx = np.argwhere(np.abs(np.asarray(data_set.datetime) - sim_time_curr) <= dt / 2) - R_idx = np.argwhere(np.abs(grid_R[0, :, 0, 0] - R_or_Lstar_arr[sat_time_idx])) - + # R_idx = np.argwhere(np.abs(grid_R[0, :, 0, 0] - R_or_Lstar_arr[sat_time_idx])) ax0 = fig.add_subplot(121, projection="polar") ax1 = fig.add_subplot(122) @@ -412,9 +407,15 @@ def plot_debug_figures_plasmasphere( # ic(data_set.P[sat_time_idx]) # ic(R_or_Lstar_arr[sat_time_idx]) - ax0.scatter(data_set.P[sat_time_idx], R_or_Lstar_arr[sat_time_idx], c=np.log10(data_set.density[sat_time_idx]), marker="D", vmin=0, + ax0.scatter( + data_set.P[sat_time_idx], + R_or_Lstar_arr[sat_time_idx], + c=np.log10(data_set.density[sat_time_idx]), + marker="D", + vmin=0, vmax=4, - 
cmap="jet",) + cmap="jet", + ) ax0.set_ylim(1, 6.6) ax0.set_title("Orbit") ax0.set_rlim([0, 6.6]) @@ -449,8 +450,6 @@ def plot_debug_figures_plasmasphere( if np.any(data_set.P[sat_time_idx] < 0.1): ic(psd_binned[it, 0, :, :, :]) - asdf - def plot_debug_figures( diff --git a/swvo/io/RBMDataSet/custom_enums.py b/swvo/io/RBMDataSet/custom_enums.py index dd30b2d..a701106 100644 --- a/swvo/io/RBMDataSet/custom_enums.py +++ b/swvo/io/RBMDataSet/custom_enums.py @@ -134,10 +134,12 @@ class SatelliteEnum(Satellite, Enum): ] SatelliteLike: TypeAlias = SatelliteLiteral | SatelliteEnum | Satellite + @dataclass(frozen=True) class Instrument: instrument_name: str + class InstrumentEnum(Instrument, Enum): # RBSP HOPE = "hope" @@ -176,10 +178,12 @@ class InstrumentEnum(Instrument, Enum): ] InstrumentLike: TypeAlias = InstrumentLiteral | InstrumentEnum | Instrument + @dataclass(frozen=True) class Mfm: mfm_name: str + class MfmEnum(Mfm, Enum): T89 = "T89" T04s = "T04s" @@ -190,4 +194,4 @@ class MfmEnum(Mfm, Enum): MfmEnumLiteral = Literal["T89", "T04s", "TS04", "T96", "OP77", "T04"] -MfmLike: TypeAlias = MfmEnumLiteral | MfmEnum | Mfm \ No newline at end of file +MfmLike: TypeAlias = MfmEnumLiteral | MfmEnum | Mfm From e253b77c26a49e4dc9018bb607aeda2bd37fbeb7 Mon Sep 17 00:00:00 2001 From: Sahil Jhawar Date: Thu, 30 Oct 2025 14:20:31 +0100 Subject: [PATCH 23/24] add icecream to reqs --- requirements.txt | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/requirements.txt b/requirements.txt index aa87c46..568140a 100755 --- a/requirements.txt +++ b/requirements.txt @@ -14,4 +14,6 @@ wget distance tqdm pytest-mock -netcdf4 \ No newline at end of file +netcdf4 +icecream + From 2c8878cf30785b32d5d406d4e188af765903fa9e Mon Sep 17 00:00:00 2001 From: Sahil Jhawar Date: Thu, 30 Oct 2025 15:19:32 +0100 Subject: [PATCH 24/24] source .env when building docs locally --- docs/makefile | 3 +++ 1 file changed, 3 insertions(+) diff --git a/docs/makefile b/docs/makefile index 
c1518e4..2ac3148 100644 --- a/docs/makefile +++ b/docs/makefile @@ -3,6 +3,9 @@ # # You can set these variables from the command line. +ifneq (,$(wildcard $(PWD)/../.env)) +include $(PWD)/../.env +endif SPHINXOPTS = SPHINXBUILD = sphinx-build PAPER =