diff --git a/.gitignore b/.gitignore index e6070e0..a88fd58 100644 --- a/.gitignore +++ b/.gitignore @@ -22,4 +22,5 @@ cron_check.sh swace_commented_code docs/_build/* docs/api/* -.nfs* \ No newline at end of file +.nfs* +.env diff --git a/docs/makefile b/docs/makefile index c1518e4..2ac3148 100644 --- a/docs/makefile +++ b/docs/makefile @@ -3,6 +3,9 @@ # # You can set these variables from the command line. +ifneq (,$(wildcard $(PWD)/../.env)) +include $(PWD)/../.env +endif SPHINXOPTS = SPHINXBUILD = sphinx-build PAPER = diff --git a/requirements.txt b/requirements.txt index aa87c46..568140a 100755 --- a/requirements.txt +++ b/requirements.txt @@ -14,4 +14,6 @@ wget distance tqdm pytest-mock -netcdf4 \ No newline at end of file +netcdf4 +icecream + diff --git a/swvo/io/RBMDataSet/RBMDataSet.py b/swvo/io/RBMDataSet/RBMDataSet.py index c72e5e7..acb3856 100644 --- a/swvo/io/RBMDataSet/RBMDataSet.py +++ b/swvo/io/RBMDataSet/RBMDataSet.py @@ -15,17 +15,22 @@ from dateutil.relativedelta import relativedelta from numpy.typing import NDArray +from swvo.io.exceptions import VariableNotFoundError from swvo.io.RBMDataSet import ( FileCadenceEnum, FolderTypeEnum, + Instrument, InstrumentEnum, InstrumentLike, + Mfm, MfmEnum, MfmLike, + Satellite, SatelliteEnum, SatelliteLike, Variable, VariableEnum, + VariableLiteral, ) from swvo.io.RBMDataSet.utils import ( get_file_path_any_format, @@ -38,6 +43,32 @@ class RBMDataSet: """RBMDataSet class for loading and managing data. + This class can load data either from files or from a dictionary. + + For file-based loading, provide `start_time`, `end_time`, and `folder_path`. + For dictionary-based loading, initialize without these parameters and use `update_from_dict()`. + + Parameters + ---------- + satellite : Union[:class:`SatelliteLike`, :class:`DummyLike`] + Satellite identifier as enum or string. + instrument : Union[:class:`InstrumentLike`, :class:`DummyLike`] + Instrument enumeration or string. + mfm : Union[:class:`MfmLike`, :class:`DummyLike`] + Magnetic field model enum or string. + start_time : dt.datetime, optional + Start time for file-based loading. + end_time : dt.datetime, optional + End time for file-based loading. + folder_path : Path, optional + Base folder path for file-based loading. + preferred_extension : Literal["mat", "pickle"], optional + Preferred file extension for file-based loading. Default is "pickle". + verbose : bool, optional + Whether to print verbose output. Default is True. + enable_dict_loading : bool, optional + Enable dictionary-based loading even in file mode. Default is False. 
+ Attributes ---------- datetime : list[dt.datetime] @@ -90,47 +121,66 @@ class RBMDataSet: def __init__( self, - start_time: dt.datetime, - end_time: dt.datetime, - folder_path: Path, satellite: SatelliteLike, instrument: InstrumentLike, mfm: MfmLike, + start_time: dt.datetime | None = None, + end_time: dt.datetime | None = None, + folder_path: Path | None = None, preferred_extension: Literal["mat", "pickle"] = "pickle", *, verbose: bool = True, + enable_dict_loading: bool = False, ) -> None: - if start_time.tzinfo is None: - start_time = start_time.replace(tzinfo=timezone.utc) - - if end_time.tzinfo is None: - end_time = end_time.replace(tzinfo=timezone.utc) - - self._start_time = start_time - self._end_time = end_time - + self.possible_variables: list[str] = list(VariableLiteral.__args__) + # Handle satellite conversion with special cases for GOES if isinstance(satellite, str): - satellite = SatelliteEnum[satellite.upper()] - self._satellite = satellite + if satellite.lower() == "goesprimary": + satellite = SatelliteEnum["GOESPrimary"] + elif satellite.lower() == "goessecondary": + satellite = SatelliteEnum["GOESSecondary"] + else: + satellite = SatelliteEnum[satellite.upper()] if isinstance(instrument, str): instrument = InstrumentEnum[instrument.upper()] - self._instrument = instrument if isinstance(mfm, str): mfm = MfmEnum[mfm.upper()] - self._mfm = mfm - self._folder_path = Path(folder_path) - - self._preferred_ext = preferred_extension - self._folder_type = self._satellite.folder_type + # Store the original satellite enum for properties and other attributes + self._satellite = satellite + self._instrument = instrument + self._mfm = mfm self._verbose = verbose - self._file_path_stem = self._create_file_path_stem() - self._file_name_stem = self._create_file_name_stem() - self._file_cadence = self._satellite.file_cadence - self._date_of_files = self._create_date_list() + # For dict-based loading, modify satellite properties + if start_time is None and end_time is None and folder_path is None: + self._file_loading_mode = False + else: + # File loading mode: need all parameters + if start_time is None or end_time is None or folder_path is None: + msg = "For file-based loading, start_time, end_time, and folder_path must all be provided" + raise ValueError(msg) + + if start_time.tzinfo is None: + start_time = start_time.replace(tzinfo=timezone.utc) + + if end_time.tzinfo is None: + end_time = end_time.replace(tzinfo=timezone.utc) + + self._start_time = start_time + self._end_time = end_time + self._satellite = satellite + self._folder_path = Path(folder_path) + self._preferred_ext = preferred_extension + self._folder_type = self._satellite.folder_type + self._file_path_stem = self._create_file_path_stem() + self._file_name_stem = self._create_file_name_stem() + self._file_cadence = self._satellite.file_cadence + self._date_of_files = self._create_date_list() + self._file_loading_mode = True + self._enable_dict_loading = enable_dict_loading def __repr__(self): return f"{self.__class__.__name__}({self._satellite}, {self._instrument}, {self._mfm})" @@ -142,41 +192,116 @@ def __dir__(self): return list(super().__dir__()) + [var.var_name for var in VariableEnum] def __getattr__(self, name: str): + # Avoid recursion for internal attributes + if name.startswith("_"): + raise AttributeError(f"'{self.__class__.__name__}' object has no attribute '{name}'") + + # Handle computed properties for both modes + if name == "P": + if not hasattr(self, "MLT") or getattr(self, "MLT") is None or not 
isinstance(self.MLT, np.ndarray): # type: ignore[reportUnnecessaryIsInstance] + raise AttributeError("Cannot compute `P` because `MLT` is missing, not loaded or is not valid array.") + return ((self.MLT + 12) / 12 * np.pi) % (2 * np.pi) + + if name == "InvV": + if not all(hasattr(self, attr) for attr in ("InvK", "InvMu")): + raise AttributeError("Cannot compute `InvV` because `InvK` or `InvMu` is missing.") + if not isinstance(self.InvK, np.ndarray) or not isinstance(self.InvMu, np.ndarray): # type: ignore[reportUnnecessaryIsInstance] + raise AttributeError("Cannot compute `InvV` because required arrays are invalid or not loaded.") + if self.InvK.ndim < 1 or self.InvMu.ndim < 2: + raise AttributeError("Cannot compute `InvV` because array dimensions are insufficient.") + inv_K_repeated = np.repeat(self.InvK[:, np.newaxis, :], self.InvMu.shape[1], axis=1) + return self.InvMu * (inv_K_repeated + 0.5) ** 2 + # check if a sat variable is requested # if we find a similar word, suggest that to the user sat_variable = None + sat_variable, levenstein_info = self.find_similar_variable(name) + + if sat_variable is not None and self._file_loading_mode: + self._load_variable(sat_variable) + return getattr(self, name) + + if not self._file_loading_mode and name in self.possible_variables: + raise AttributeError( + f"Attribute '{name}' exists in `VariableLiteral` but has not been set. " + "Call `update_from_dict()` before accessing it." + ) + + if levenstein_info["min_distance"] <= 2: + msg = f"{self.__class__.__name__} object has no attribute {name}. Maybe you meant {levenstein_info['var_name']}?" + else: + msg = f"{self.__class__.__name__} object has no attribute {name}" + + raise AttributeError(msg) + + def find_similar_variable(self, name: str) -> tuple[None | VariableEnum, dict[str, Any]]: levenstein_info: dict[str, Any] = {"min_distance": 10, "var_name": ""} + sat_variable = None for var in VariableEnum: if name == var.var_name: sat_variable = var break else: dist = distance.levenshtein(name, var.var_name) - if name.lower() in var.name.lower(): + if name.lower() in var.var_name.lower(): dist = 1 if dist < levenstein_info["min_distance"]: levenstein_info["min_distance"] = dist levenstein_info["var_name"] = var.var_name - # if yes, load it - if sat_variable is not None: - self._load_variable(sat_variable) - - return getattr(self, name) - - if levenstein_info["min_distance"] <= 2: - msg = f"{self.__class__.__name__} object has no attribute {name}. Maybe you meant {levenstein_info['var_name']}?" - else: - msg = f"{self.__class__.__name__} object has no attribute {name}" - - raise AttributeError(msg) - - # def __getitem__(self, key:str): - # return getattr(self, key:str) - - # def __setitem__(self, key, value): - # setattr(self, key, value) + return sat_variable, levenstein_info + + @property + def satellite(self) -> SatelliteEnum | Satellite: + """Returns the satellite enum.""" + return self._satellite + + @property + def instrument(self) -> InstrumentEnum | Instrument: + """Returns the instrument enum.""" + return self._instrument + + @property + def mfm(self) -> MfmEnum | Mfm: + """Returns the MFM enum.""" + return self._mfm + + def update_from_dict(self, source_dict: dict[str, VariableLiteral]) -> RBMDataSet: + """Get data from data dictionary and update the object. + + Parameters + ---------- + source_dict : dict[str, VariableLiteral] + Dictionary containing the data to be loaded into the object. + + Returns + ------- + RBMDataSet + The updated RBMDataSet object. 
+ + Raises + ------ + VariableNotFoundError + If a key in the `source_dict` is not a valid `VariableLiteral`. + RuntimeError + If the `RBMDataSet` is in file loading mode and dictionary loading is not enabled. + + """ + if self._file_loading_mode and not self._enable_dict_loading: + msg = "RBMDataSet is in file loading mode. Cannot update from dictionary. To use dictionary-based loading, set `enable_dict_loading=True` during initialization." + raise RuntimeError(msg) + for key, value in source_dict.items(): + _, levenstein_info = self.find_similar_variable(key) + if key in self.possible_variables: + setattr(self, key, value) + elif levenstein_info["min_distance"] <= 2: + msg = f"Key '{key}' is not a valid `VariableLiteral`. Maybe you meant '{levenstein_info['var_name']}'?" + raise VariableNotFoundError(msg) + else: + msg = f"Key '{key}' is not a valid `VariableLiteral`." + raise VariableNotFoundError(msg) + return self def get_var(self, var: VariableEnum): return getattr(self, var.var_name) @@ -216,13 +341,13 @@ def _create_file_name_stem(self) -> str: # if self._satellite == SatelliteEnum.THEMIS: # pass - return self._satellite.sat_name + "_" + self._instrument.value + "_" + return self._satellite.sat_name + "_" + self._instrument.instrument_name + "_" def get_satellite_name(self) -> str: return self._satellite.sat_name def get_satellite_and_instrument_name(self) -> str: - return self._satellite.sat_name + "_" + self._instrument.value + return self._satellite.sat_name + "_" + self._instrument.instrument_name def set_file_path_stem(self, file_path_stem: Path): self._file_path_stem = file_path_stem @@ -238,7 +363,7 @@ def set_file_cadence(self, file_cadence: FileCadenceEnum): return self def get_print_name(self) -> str: - return self._satellite.sat_name + " " + self._instrument.value + return self._satellite.sat_name + " " + self._instrument.instrument_name def _load_variable(self, var: Variable | VariableEnum) -> None: loaded_var_arrs: dict[str, NDArray[np.number]] = {} @@ -264,7 +389,7 @@ def _load_variable(self, var: Variable | VariableEnum) -> None: file_name_no_format = self._file_name_stem + date_str + "_" + var.mat_file_prefix if var.mat_has_B: - file_name_no_format += "_n4_4_" + self._mfm.value + file_name_no_format += "_n4_4_" + self._mfm.mfm_name file_name_no_format += "_ver4" else: @@ -327,28 +452,41 @@ def _load_variable(self, var: Variable | VariableEnum) -> None: setattr(self, var_name, loaded_var_arrs[var_name]) + def get_loaded_variables(self) -> list[str]: + """Get a list of currently loaded variable names.""" + loaded_vars = [] + for var in VariableEnum: + if hasattr(self, var.var_name): + loaded_vars.append(var.var_name) + return loaded_vars + def __eq__(self, other: RBMDataSet) -> bool: - if self._satellite != other._satellite: - return False - if self._instrument != other._instrument: + if ( + self._file_loading_mode != other._file_loading_mode + or self._satellite != other._satellite + or self._instrument != other._instrument + or self._mfm != other._mfm + ): return False - if self._mfm != other._mfm: + + self_vars = self.get_loaded_variables() + other_vars = other.get_loaded_variables() + if self_vars != other_vars: return False + variables = self_vars - for var in VariableEnum: - self_var = getattr(self, var.var_name) - other_var = getattr(other, var.var_name) + for var in variables: + self_var = getattr(self, var) + other_var = getattr(other, var) - if isinstance(self_var, list) and isinstance(other_var, list): - if len(self_var) != len(other_var): - return False - 
for dt1, dt2 in zip(self_var, other_var): - if dt1 != dt2: - return False - elif isinstance(self_var, np.ndarray) and isinstance(other_var, np.ndarray): - if self_var.shape != other_var.shape: + if not isinstance(other_var, type(self_var)): + return False + + if isinstance(self_var, list): + if len(self_var) != len(other_var) or any(a != b for a, b in zip(self_var, other_var)): return False - if not np.allclose(self_var, other_var, equal_nan=True): + elif isinstance(self_var, np.ndarray): + if self_var.shape != other_var.shape or not np.allclose(self_var, other_var, equal_nan=True): return False elif self_var != other_var: return False diff --git a/swvo/io/RBMDataSet/RBMDataSetElPaso.py b/swvo/io/RBMDataSet/RBMDataSetElPaso.py deleted file mode 100644 index d5a4fac..0000000 --- a/swvo/io/RBMDataSet/RBMDataSetElPaso.py +++ /dev/null @@ -1,239 +0,0 @@ -# SPDX-FileCopyrightText: 2025 GFZ Helmholtz Centre for Geosciences -# -# SPDX-License-Identifier: Apache-2.0 - -from __future__ import annotations - -import datetime as dt -from dataclasses import replace -from datetime import timezone -from typing import Any, ClassVar, TypeVar - -import numpy as np -from numpy.typing import NDArray - -from swvo.io.RBMDataSet import ( - FileCadenceEnum, - FolderTypeEnum, - InstrumentEnum, - MfmEnum, - SatelliteEnum, - SatelliteLike, - VariableEnum, -) -from swvo.io.RBMDataSet.custom_enums import ElPasoMFMEnum -from swvo.io.RBMDataSet.utils import python2matlab - -Variable = TypeVar( - "Variable" -) # this is a placeholder for the actual Variable class from elpaso and not the one in RBMDataSet - - -class RBMDataSetElPaso: - """RBMDataSetElPaso class for loading ElPaso data to RBMDataSet. - - Parameters - ---------- - satellite : :class:`SatelliteLike` - Satellite identifier as enum or string. - instrument : :class:`InstrumentEnum` - Instrument enumeration. - mfm : :class:`MfmEnum` - Magnetic field model enum. 
- - - Attributes - ---------- - datetime : list[dt.datetime] - time : NDArray[np.float64] - energy_channels : NDArray[np.float64] - alpha_local : NDArray[np.float64] - alpha_eq_model : NDArray[np.float64] - alpha_eq_real : NDArray[np.float64] - InvMu : NDArray[np.float64] - InvMu_real : NDArray[np.float64] - InvK : NDArray[np.float64] - InvV : NDArray[np.float64] - Lstar : NDArray[np.float64] - Flux : NDArray[np.float64] - PSD : NDArray[np.float64] - MLT : NDArray[np.float64] - B_SM : NDArray[np.float64] - B_total : NDArray[np.float64] - B_sat : NDArray[np.float64] - xGEO : NDArray[np.float64] - P : NDArray[np.float64] - R0 : NDArray[np.float64] - density : NDArray[np.float64] - - """ - - _variable_mapping: ClassVar[dict[str, str]] = { - "Epoch_posixtime": "time", - "Energy_FEDU": "energy_channels", - "Energy_FPDU": "energy_channels", - "Energy_FEIU": "energy_channels", - "Energy_FEDO": "energy_channels", - "PA_local": "alpha_local", - "PA_eq_": "alpha_eq_model", - "alpha_eq_real": "alpha_eq_real", - "invMu_": "InvMu", - "InvMu_real": "InvMu_real", - "invK_": "InvK", - # "InvV": "InvV",# computed property - "Lstar_": "Lstar", - "FEDU": "Flux", - "FPDU": "Flux", - "FEIU": "Flux", - "FEDO": "Flux", - "PSD_FEDU": "PSD", - "PSD_FPDU": "PSD", - "PSD_FEIU": "PSD", - "PSD_FEDO": "PSD", - "MLT_": "MLT", - "B_SM": "B_SM", - "B_eq_": "B_total", - "B_local_": "B_sat", - "xGEO": "xGEO", - # "P": "P",# computed property - "R_eq_": "R0", - "density": "density", - } - - datetime: list[dt.datetime] - time: NDArray[np.float64] - energy_channels: NDArray[np.float64] - alpha_local: NDArray[np.float64] - alpha_eq_model: NDArray[np.float64] - alpha_eq_real: NDArray[np.float64] - InvMu: NDArray[np.float64] - InvMu_real: NDArray[np.float64] - InvK: NDArray[np.float64] - InvV: NDArray[np.float64] - Lstar: NDArray[np.float64] - Flux: NDArray[np.float64] - PSD: NDArray[np.float64] - MLT: NDArray[np.float64] - B_SM: NDArray[np.float64] - B_total: NDArray[np.float64] - B_sat: NDArray[np.float64] - xGEO: NDArray[np.float64] - P: NDArray[np.float64] - R0: NDArray[np.float64] - density: NDArray[np.float64] - - def __init__( - self, - satellite: SatelliteLike, - instrument: InstrumentEnum, - mfm: MfmEnum, - ) -> None: - if isinstance(satellite, str): - if satellite.lower() == "goesprimary": - satellite_enum = SatelliteEnum["GOESPrimary"] - elif satellite.lower() == "goessecondary": - satellite_enum = SatelliteEnum["GOESSecondary"] - else: - satellite_enum = SatelliteEnum[satellite.upper()] - else: - satellite_enum = satellite - - if isinstance(instrument, str): - instrument = InstrumentEnum[instrument.upper()] - satellite_obj = replace( - satellite_enum.value, - folder_type=FolderTypeEnum.NoFolder, - file_cadence=FileCadenceEnum.NoCadence, - ) - - self._satellite = satellite_obj - self._instrument = instrument - self._mfm = mfm - self._mfm_prefix = ElPasoMFMEnum[self._mfm.name].value - - @property - def satellite(self) -> SatelliteEnum: - """Returns the satellite enum.""" - return self._satellite - - @property - def instrument(self) -> InstrumentEnum: - """Returns the instrument enum.""" - return self._instrument - - @property - def mfm(self) -> MfmEnum: - """Returns the MFM enum.""" - return self._mfm - - @property - def variable_mapping(self) -> dict[str, str]: - """Returns the variable mapping dictionary.""" - return self._variable_mapping - - def __dir__(self): - return super().__dir__() + [var.var_name for var in VariableEnum] - - def update_from_dict(self, source_dict: dict[str, Variable]) -> None: - """Get data 
from ElPaso data dictionary and update the object. - - Parameters - ---------- - source_dict : dict[str, Any] - Dictionary containing the data to be loaded into the object. - - """ - for _, value in source_dict.items(): - if value.standard_name in self._variable_mapping: - target_attr = self._variable_mapping[value.standard_name] - - if value.standard_name == "Epoch_posixtime" and target_attr == "time": - datetimes = [dt.datetime.fromtimestamp(ts, tz=timezone.utc) for ts in value.data] - setattr(self, "datetime", datetimes) - setattr(self, "time", [python2matlab(i) for i in datetimes]) - else: - setattr(self, target_attr, value.data) - - elif value.standard_name.endswith(self._mfm_prefix): - base_key = value.standard_name.replace(self._mfm_prefix, "") - if base_key in self._variable_mapping: - target_attr = self._variable_mapping[base_key] - setattr(self, target_attr, value.data) - - @property - def P(self) -> NDArray[np.float64]: - """Calculate P. - - Returns - ------- - NDArray[np.float64] - The P value calculated from the MLT. - """ - return ((self.MLT + 12) / 12 * np.pi) % (2 * np.pi) - - @property - def InvV(self) -> NDArray[np.float64]: - """Calculate InvV. - - Returns - ------- - NDArray[np.float64] - The InvV value calculated from InvMu and InvK. - """ - inv_K_repeated = np.repeat(self.InvK[:, np.newaxis, :], self.InvMu.shape[1], axis=1) - InvV = self.InvMu * (inv_K_repeated + 0.5) ** 2 - return InvV - - def __getattr__(self, name: str) -> Any: - if name in self._variable_mapping.values(): - raise AttributeError( - f"Attribute '{name}' is mapped but has not been set. " - "Make sure data is loaded or that this attribute is properly assigned." - ) - raise AttributeError(f"'{self.__class__.__name__}' object has no attribute '{name}'") - - def __repr__(self): - return f"{self.__class__.__name__}({self.satellite}, {self.instrument}, {self.mfm})" - - def __str__(self): - return self.__repr__() diff --git a/swvo/io/RBMDataSet/RBMDataSetManager.py b/swvo/io/RBMDataSet/RBMDataSetManager.py index 5a08c66..7533a48 100644 --- a/swvo/io/RBMDataSet/RBMDataSetManager.py +++ b/swvo/io/RBMDataSet/RBMDataSetManager.py @@ -110,7 +110,7 @@ def load( Satellite identifier(s) as enum or string. If a single satellite is provided, it can be a string or an enum. instrument : :class:`InstrumentEnum` Instrument enumeration, e.g., :class:`InstrumentEnum.HOPE`. - mfm : MfmEnum + mfm : :class:`MfmEnum` Magnetic field model enum, e.g., :class:`MfmEnum.T89`. folder_type : :class:`FolderTypeEnum`, optional Type of folder where the data is stored, by default :class:`FolderTypeEnum.DataServer`. 
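For reference, a minimal usage sketch (not part of the diff) of the reordered, keyword-based construction that the next hunk switches RBMDataSetManager to. It assumes the public names exported from swvo.io.RBMDataSet as shown elsewhere in this patch; the dates and folder path are hypothetical placeholders.

import datetime as dt
from pathlib import Path

from swvo.io.RBMDataSet import InstrumentEnum, MfmEnum, RBMDataSet, SatelliteEnum

# File-based loading: satellite/instrument/mfm come first, time range and
# folder path are now optional keyword arguments (all three must be given together).
dataset = RBMDataSet(
    satellite=SatelliteEnum.RBSPA,
    instrument=InstrumentEnum.MAGEIS,
    mfm=MfmEnum.T89,
    start_time=dt.datetime(2023, 1, 1, tzinfo=dt.timezone.utc),
    end_time=dt.datetime(2023, 1, 31, tzinfo=dt.timezone.utc),
    folder_path=Path("/data/rbm"),  # hypothetical path
    preferred_extension="pickle",
    verbose=False,
)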
@@ -152,12 +152,12 @@ def load( return_list.append(cls._instance.data_set_dict[key_tuple]) else: cls._instance.data_set_dict[key_tuple] = RBMDataSet( - start_time, - end_time, - folder_path, - sat, - instrument, - mfm, + satellite=sat, + instrument=instrument, + mfm=mfm, + start_time=start_time, + end_time=end_time, + folder_path=folder_path, verbose=verbose, preferred_extension=preferred_extension, ) diff --git a/swvo/io/RBMDataSet/RBMNcDataSet.py b/swvo/io/RBMDataSet/RBMNcDataSet.py index ee5e211..c6b28db 100644 --- a/swvo/io/RBMDataSet/RBMNcDataSet.py +++ b/swvo/io/RBMDataSet/RBMNcDataSet.py @@ -98,12 +98,12 @@ def __init__( verbose: bool = True, ) -> None: super().__init__( - start_time, - end_time, - folder_path, - satellite, - instrument, + satellite=satellite, + instrument=instrument, mfm=mfm, + start_time=start_time, + end_time=end_time, + folder_path=folder_path, verbose=verbose, ) @@ -141,7 +141,7 @@ def _load_variable(self, var: Variable | VariableEnum) -> None: next_month = start_month + relativedelta(months=1, days=-1) date_str = start_month.strftime("%Y%m%d") + "to" + next_month.strftime("%Y%m%d") - file_name = self._file_name_stem + date_str + "_" + self._mfm.value + ".nc" + file_name = self._file_name_stem + date_str + "_" + self._mfm.mfm_name + ".nc" else: raise NotImplementedError @@ -193,7 +193,7 @@ def _load_variable(self, var: Variable | VariableEnum) -> None: if var_name == "datetime": loaded_var_arrs[var_name] = list(loaded_var_arrs[var_name]) # type: ignore - rbm_var_name = RBMNcDataSet._get_rbm_name(var_name, self._mfm.value) + rbm_var_name = RBMNcDataSet._get_rbm_name(var_name, self._mfm.mfm_name) if rbm_var_name is not None: setattr(self, rbm_var_name, loaded_var_arrs[var_name]) diff --git a/swvo/io/RBMDataSet/__init__.py b/swvo/io/RBMDataSet/__init__.py index f77d064..d51d0f8 100644 --- a/swvo/io/RBMDataSet/__init__.py +++ b/swvo/io/RBMDataSet/__init__.py @@ -11,16 +11,17 @@ Satellite as Satellite, SatelliteLike as SatelliteLike, SatelliteEnum as SatelliteEnum, + Instrument as Instrument, InstrumentEnum as InstrumentEnum, InstrumentLike as InstrumentLike, + Mfm as Mfm, MfmEnum as MfmEnum, MfmLike as MfmLike, - ElPasoMFMEnum as ElPasoMFMEnum, SatelliteLiteral as SatelliteLiteral, + VariableLiteral as VariableLiteral, ) from swvo.io.RBMDataSet.RBMDataSetManager import RBMDataSetManager as RBMDataSetManager from swvo.io.RBMDataSet.interp_functions import TargetType as TargetType from swvo.io.RBMDataSet.scripts.create_RBSP_line_data import create_RBSP_line_data as create_RBSP_line_data from swvo.io.RBMDataSet.RBMDataSet import RBMDataSet as RBMDataSet from swvo.io.RBMDataSet.RBMNcDataSet import RBMNcDataSet as RBMNcDataSet -from swvo.io.RBMDataSet.RBMDataSetElPaso import RBMDataSetElPaso as RBMDataSetElPaso diff --git a/swvo/io/RBMDataSet/bin_and_interpolate_to_model_grid.py b/swvo/io/RBMDataSet/bin_and_interpolate_to_model_grid.py index 7d80409..2a1f041 100644 --- a/swvo/io/RBMDataSet/bin_and_interpolate_to_model_grid.py +++ b/swvo/io/RBMDataSet/bin_and_interpolate_to_model_grid.py @@ -13,6 +13,7 @@ from typing import TYPE_CHECKING, Literal import numpy as np +from icecream import ic from matplotlib import pyplot as plt from numpy.typing import NDArray from tqdm import tqdm @@ -41,27 +42,33 @@ def bin_and_interpolate_to_model_grid( if grid_K.ndim == 3: grid_K = grid_K[np.newaxis, ...] + target_var_init = getattr(self, target_var_name) + # 1. 
interpolate to V-K - target_var_init = getattr(self, target_var_name) + if grid_R.shape[2] > 1 and grid_R.shape[3] > 1: + if target_var_init.ndim == 1: + target_var_init = target_var_init[:, np.newaxis, np.newaxis] - if target_var_init.ndim == 1: - target_var_init = target_var_init[:, np.newaxis, np.newaxis] + mu_or_V_arr = self.InvMu if mu_or_V == "Mu" else self.InvV + if grid_mu_V.shape[2] > 1: + psd_interp = _interpolate_in_V_K(target_var_init, mu_or_V_arr, self.InvK, grid_mu_V, grid_K) + else: + psd_interp = target_var_init - mu_or_V_arr = self.InvMu if mu_or_V == "Mu" else self.InvV - if grid_mu_V.shape[2] > 1: - psd_interp = _interpolate_in_V_K(target_var_init, mu_or_V_arr, self.InvK, grid_mu_V, grid_K) + # sanity check + if np.min(target_var_init) > np.min(psd_interp) or np.max(target_var_init) < np.max(psd_interp): + msg = "Found inconsistency in V-K interpolation. Aborting..." + raise (ValueError(msg)) else: - psd_interp = target_var_init + if target_var_init.ndim == 1: # plasmasphere + target_var_init = target_var_init[:, np.newaxis, np.newaxis] - # sanity check - if np.min(target_var_init) > np.min(psd_interp) or np.max(target_var_init) < np.max(psd_interp): - msg = "Found inconsitency in V-K interpolation. Aborting..." - raise (ValueError(msg)) + psd_interp = target_var_init # 2. Bin in space - R_or_Lstar_arr = self.R0 if grid_P else self.Lstar[:, -1] + R_or_Lstar_arr = self.R0 if grid_P is not None else self.Lstar[:, -1] psd_binned_in_space = _bin_in_space(psd_interp, self.P, R_or_Lstar_arr, grid_R, grid_P) # sanity check @@ -77,17 +84,27 @@ raise (ValueError(msg)) if debug_plot_settings: - plot_debug_figures( - self, - psd_binned_in_time, - sim_time, - grid_P, - grid_R, - grid_mu_V, - grid_K, - mu_or_V, - debug_plot_settings, - ) + if debug_plot_settings.target_K is not None: + plot_debug_figures( + self, + psd_binned_in_time, + sim_time, + grid_P, + grid_R, + grid_mu_V, + grid_K, + mu_or_V, + debug_plot_settings, + ) + else: + plot_debug_figures_plasmasphere( + self, + psd_binned_in_time, + sim_time, + grid_P, + grid_R, + debug_plot_settings, + ) return psd_binned_in_time @@ -141,7 +158,7 @@ def _bin_in_time( time_indices = _get_time_indices(data_timestamps, _get_time_bins(sim_timestamps)) for i, _ in tqdm(enumerate(sim_time)): - psd_binned[i, ...] = np.nanmean(data_psd[time_indices == i, ...], axis=0) + psd_binned[i, ...] 
= np.power(10, np.nanmean(np.log10(data_psd[time_indices == i, ...]), axis=0)) return psd_binned @@ -155,7 +172,7 @@ def _bin_in_space( ) -> NDArray[np.float64]: print("\tBin in space...") - if grid_P: + if grid_P is not None: grid_P_1d = grid_P[:, 0, 0, 0] grid_R_1d = grid_R[0, :, 0, 0] @@ -199,7 +216,7 @@ def _bin_in_space( r_idx = np.argmin(np.abs(R_data[it] - grid_R_1d)) - if grid_P_1d: + if grid_P_1d is not None: raw_difference_p = np.abs(P_data[it] - grid_P_1d) min_difference_p = np.where( raw_difference_p <= np.pi, @@ -210,10 +227,15 @@ def _bin_in_space( number_of_observations[it, p_idx, r_idx, :, :] += np.where(np.isnan(psd_in[it, :, :]), 0, 1) psd_binned[it, p_idx, r_idx, :, :] += np.where(np.isnan(psd_in[it, :, :]), 0, np.log10(psd_in[it, :, :])) + else: number_of_observations[it, 0, r_idx, :, :] += np.where(np.isnan(psd_in[it, :, :]), 0, 1) psd_binned[it, 0, r_idx, :, :] += np.where(np.isnan(psd_in[it, :, :]), 0, np.log10(psd_in[it, :, :])) + # # ic(number_of_observations[it, :, :, 0, 0]) + # ic(np.power(10, np.nanmax(psd_binned[it, :, :, 0, 0]))) + # ic(np.power(10, np.nanmax(psd_binned[it, :, :, 0, 0] / number_of_observations[it, :, :, 0, 0]))) + psd_binned = np.where(psd_binned == 0, np.nan, psd_binned) return np.power(10, psd_binned / number_of_observations) @@ -350,15 +372,91 @@ def _parallel_func_VK( class DebugPlotSettings: folder_path: Path satellite_name: str - target_V: float - target_K: float + target_V: float | None = None + target_K: float | None = None + + +def plot_debug_figures_plasmasphere( + data_set: RBMDataSet, + psd_binned: NDArray[np.float64], + sim_time: NDArray[np.object_], + grid_P: NDArray[np.float64] | None, + grid_R: NDArray[np.float64], + debug_plot_settings: DebugPlotSettings, +): + print("\tPlot debug features...") + + dt = sim_time[1] - sim_time[0] + + fig = plt.figure(figsize=(19.20, 8)) + plt.rcParams["axes.axisbelow"] = False + + R_or_Lstar_arr = data_set.R0 + + for it, sim_time_curr in enumerate(tqdm(sim_time)): + sat_time_idx = np.argwhere(np.abs(np.asarray(data_set.datetime) - sim_time_curr) <= dt / 2) + + # R_idx = np.argwhere(np.abs(grid_R[0, :, 0, 0] - R_or_Lstar_arr[sat_time_idx])) + + ax0 = fig.add_subplot(121, projection="polar") + ax1 = fig.add_subplot(122) + + # plot satellite trajectory on PxR grid + # [x_sat, y_sat] = pol2cart(self.P, self.R) + + # ic(data_set.P[sat_time_idx]) + # ic(R_or_Lstar_arr[sat_time_idx]) + + ax0.scatter( + data_set.P[sat_time_idx], + R_or_Lstar_arr[sat_time_idx], + c=np.log10(data_set.density[sat_time_idx]), + marker="D", + vmin=0, + vmax=4, + cmap="jet", + ) + ax0.set_ylim(1, 6.6) + ax0.set_title("Orbit") + ax0.set_rlim([0, 6.6]) + ax0.set_theta_offset(np.pi) + + grid_X = grid_R[:, :, 0, 0] * np.cos(grid_P[:, :, 0, 0]) + grid_Y = grid_R[:, :, 0, 0] * np.sin(grid_P[:, :, 0, 0]) + + pc = ax1.pcolormesh( + grid_X, + grid_Y, + np.squeeze(np.log10(psd_binned[it, :, :, :, :])), + vmin=0, + vmax=4, + cmap="jet", + edgecolors="k", + linewidth=0.1, + ) + ax1.set_title("Assimilation input") + ax1.set_xlim(np.max(grid_R), -np.max(grid_R)) + ax1.set_ylim(np.max(grid_R), -np.max(grid_R)) + ax1.set_xlabel("X") + ax1.set_ylabel("Y") + + fig.colorbar(pc, ax=ax1) + + fig.savefig(Path(debug_plot_settings.folder_path) / f"{debug_plot_settings.satellite_name}_{sim_time_curr}.png") + + # ic(np.log10(psd_binned[it,:,:,V_idx,K_idx])) + + fig.clf() + + if np.any(data_set.P[sat_time_idx] < 0.1): + ic(psd_binned[it, 0, :, :, :]) def plot_debug_figures( data_set: RBMDataSet, psd_binned: NDArray[np.float64], sim_time: 
NDArray[np.object_], - grid_P: NDArray[np.float64], + grid_P: NDArray[np.float64] | None, grid_R: NDArray[np.float64], grid_V: NDArray[np.float64], grid_K: NDArray[np.float64], @@ -374,7 +472,7 @@ def plot_debug_figures( data_set_V_or_Mu = data_set.InvMu if mu_or_V == "Mu" else data_set.InvV - R_or_Lstar_arr = data_set.R0 if grid_P else data_set.Lstar[:, -1] + R_or_Lstar_arr = data_set.R0 if grid_P is not None else data_set.Lstar[:, -1] for it, sim_time_curr in enumerate(tqdm(sim_time)): sat_time_idx = np.argwhere(np.abs(np.asarray(data_set.datetime) - sim_time_curr) <= dt / 2) @@ -384,9 +482,6 @@ def plot_debug_figures( K_idx = np.argmin(np.abs(grid_K[0, R_idx, 0, :] - debug_plot_settings.target_K)) V_idx = np.argmin(np.abs(grid_V[0, R_idx, :, K_idx] - debug_plot_settings.target_V)) - K_idx = 45 - V_idx = 51 - V_lim_min = np.log10(0.9 * np.min([np.nanmin(data_set_V_or_Mu), np.min(grid_V)])) V_lim_max = np.log10(1.1 * np.max([np.nanmax(data_set_V_or_Mu), np.max(grid_V)])) diff --git a/swvo/io/RBMDataSet/custom_enums.py b/swvo/io/RBMDataSet/custom_enums.py index edb24cd..a701106 100644 --- a/swvo/io/RBMDataSet/custom_enums.py +++ b/swvo/io/RBMDataSet/custom_enums.py @@ -135,7 +135,12 @@ class SatelliteEnum(Satellite, Enum): SatelliteLike: TypeAlias = SatelliteLiteral | SatelliteEnum | Satellite -class InstrumentEnum(Enum): +@dataclass(frozen=True) +class Instrument: + instrument_name: str + + +class InstrumentEnum(Instrument, Enum): # RBSP HOPE = "hope" MAGEIS = "mageis" @@ -171,22 +176,22 @@ class InstrumentEnum(Enum): "orbit", "TED-electron", ] -InstrumentLike: TypeAlias = InstrumentLiteral | InstrumentEnum +InstrumentLike: TypeAlias = InstrumentLiteral | InstrumentEnum | Instrument + + +@dataclass(frozen=True) +class Mfm: + mfm_name: str -class MfmEnum(Enum): +class MfmEnum(Mfm, Enum): T89 = "T89" T04s = "T04s" T96 = "T96" TS04 = "T04s" OP77 = "OP77" + T04 = "T04" -MfmEnumLiteral = Literal["T89", "T04s", "TS04", "T96", "OP77"] -MfmLike: TypeAlias = MfmEnumLiteral | MfmEnum - - -class ElPasoMFMEnum(Enum): - T89 = "T89" - T04s = "T04s" - T04 = "T04" +MfmEnumLiteral = Literal["T89", "T04s", "TS04", "T96", "OP77", "T04"] +MfmLike: TypeAlias = MfmEnumLiteral | MfmEnum | Mfm diff --git a/swvo/io/exceptions.py b/swvo/io/exceptions.py index c44563f..8e275e0 100644 --- a/swvo/io/exceptions.py +++ b/swvo/io/exceptions.py @@ -5,3 +5,7 @@ class ModelError(Exception): pass + + +class VariableNotFoundError(Exception): + pass diff --git a/tests/io/RBMDataSet/test_RBMDataset.py b/tests/io/RBMDataSet/test_RBMDataset.py index 118cbef..23253a6 100644 --- a/tests/io/RBMDataSet/test_RBMDataset.py +++ b/tests/io/RBMDataSet/test_RBMDataset.py @@ -5,6 +5,7 @@ import datetime as dt from datetime import timezone from pathlib import Path +from typing import get_args from unittest import mock import numpy as np @@ -16,6 +17,7 @@ MfmEnum, RBMDataSet, SatelliteEnum, + SatelliteLiteral, VariableEnum, ) @@ -42,12 +44,12 @@ def mock_dataset(): } dataset = RBMDataSet( - start_time=start_time, - end_time=end_time, - folder_path=Path("/mock/path"), satellite=SatelliteEnum.RBSPA, instrument=InstrumentEnum.MAGEIS, mfm=MfmEnum.T89, + start_time=start_time, + end_time=end_time, + folder_path=Path("/mock/path"), preferred_extension="pickle", verbose=False, ) @@ -67,12 +69,12 @@ def test_init_datetime_timezone(mock_module_string): mock.patch(f"{mock_module_string}._create_file_name_stem"), ): dataset = RBMDataSet( - start_time=start_time, - end_time=end_time, - folder_path=Path("/mock/path"), satellite=SatelliteEnum.RBSPA, 
instrument=InstrumentEnum.MAGEIS, mfm=MfmEnum.T89, + start_time=start_time, + end_time=end_time, + folder_path=Path("/mock/path"), preferred_extension="pickle", ) @@ -101,12 +103,12 @@ def test_satellite_string_input(mock_module_string): with mock.patch(f"{mock_module_string}._create_file_path_stem"): with mock.patch(f"{mock_module_string}._create_file_name_stem"): dataset = RBMDataSet( - start_time=dt.datetime(2023, 1, 1, tzinfo=timezone.utc), - end_time=dt.datetime(2023, 1, 31, tzinfo=timezone.utc), - folder_path=Path("/mock/path"), satellite="RBSPA", instrument=InstrumentEnum.MAGEIS, mfm=MfmEnum.T89, + start_time=dt.datetime(2023, 1, 1, tzinfo=timezone.utc), + end_time=dt.datetime(2023, 1, 31, tzinfo=timezone.utc), + folder_path=Path("/mock/path"), preferred_extension="pickle", ) @@ -169,12 +171,12 @@ def test_all_satellites_work(satellite, mock_module_string): with mock.patch(f"{mock_module_string}._create_file_path_stem"): with mock.patch(f"{mock_module_string}._create_file_name_stem"): dataset = RBMDataSet( - start_time=dt.datetime(2023, 1, 1, tzinfo=timezone.utc), - end_time=dt.datetime(2023, 1, 31, tzinfo=timezone.utc), - folder_path=Path("/mock/path"), satellite=satellite, instrument=InstrumentEnum.HOPE, mfm=MfmEnum.T89, + start_time=dt.datetime(2023, 1, 1, tzinfo=timezone.utc), + end_time=dt.datetime(2023, 1, 31, tzinfo=timezone.utc), + folder_path=Path("/mock/path"), ) assert dataset._satellite == satellite @@ -186,12 +188,12 @@ def test_all_instruments_work(instrument, mock_module_string): with mock.patch(f"{mock_module_string}._create_file_path_stem"): with mock.patch(f"{mock_module_string}._create_file_name_stem"): dataset = RBMDataSet( - start_time=dt.datetime(2023, 1, 1, tzinfo=timezone.utc), - end_time=dt.datetime(2023, 1, 31, tzinfo=timezone.utc), - folder_path=Path("/mock/path"), satellite=SatelliteEnum.RBSPA, instrument=instrument, mfm=MfmEnum.T89, + start_time=dt.datetime(2023, 1, 1, tzinfo=timezone.utc), + end_time=dt.datetime(2023, 1, 31, tzinfo=timezone.utc), + folder_path=Path("/mock/path"), ) assert dataset._instrument == instrument @@ -250,12 +252,12 @@ def test_load_variable_real_file(): end_time = dt.datetime(2025, 4, 30, tzinfo=dt.timezone.utc) dataset = RBMDataSet( - start_time=start_time, - end_time=end_time, - folder_path=Path("path/to/real/files"), # this does not matter for the test satellite=SatelliteEnum.GOESSecondary, instrument=InstrumentEnum.MAGED, mfm=MfmEnum.T89, + start_time=start_time, + end_time=end_time, + folder_path=Path("path/to/real/files"), # this does not matter for the test preferred_extension="pickle", verbose=True, ) @@ -293,3 +295,543 @@ def test_all_variables_in_dir(mock_dataset: RBMDataSet): for var in vars: assert var in mock_dataset.__dir__() + + +def test_disable_dict_loading_mode(mock_dataset: RBMDataSet): + """Test that default file loading mode raises RuntimeError on update_from_dict.""" + with pytest.raises(RuntimeError): + mock_dataset.update_from_dict({"Flux": np.array([[1.0, 2.0, 3.0]])}) + + +def test_enable_dict_loading_mode(mock_dataset: RBMDataSet): + """Test that enabling dict loading allows update_from_dict in file loading mode.""" + mock_dataset._enable_dict_loading = True + mock_dataset.update_from_dict({"Flux": np.array([[1.0, 2.0, 3.0]])}) + + assert hasattr(mock_dataset, "Flux") + + +@pytest.fixture +def dict_dataset(): + """Fixture for dictionary-based loading mode (no file parameters)""" + return RBMDataSet( + satellite="RBSPA", + instrument="hope", + mfm="T89", + ) + + +def 
test_dict_mode_init_accepts_string_inputs(): + """Test that the class can be initialized with string inputs in dict mode.""" + ds = RBMDataSet(satellite="GOESSecondary", instrument="MAGED", mfm=MfmEnum.T89) + assert ds.satellite.sat_name == "secondary" + assert ds.instrument == InstrumentEnum.MAGED + assert ds.mfm == MfmEnum.T89 + + +def test_dict_mode_repr_and_str(dict_dataset): + """Test repr and str for dict mode""" + assert "RBMDataSet" in repr(dict_dataset) + assert str(dict_dataset.satellite) in repr(dict_dataset) + assert str(dict_dataset.instrument) in repr(dict_dataset) + assert str(dict_dataset.mfm) in repr(dict_dataset) + + assert "RBMDataSet" in str(dict_dataset) + assert str(dict_dataset.satellite) in str(dict_dataset) + assert str(dict_dataset.instrument) in str(dict_dataset) + assert str(dict_dataset.mfm) in str(dict_dataset) + + +def test_update_from_dict_returns_self(dict_dataset): + """Test that update_from_dict returns self for method chaining""" + source_dict = {"Flux": np.array([[1.0, 2.0]])} + + result = dict_dataset.update_from_dict(source_dict) + assert isinstance(result, RBMDataSet) + + +def test_update_from_dict_sets_variables(dict_dataset): + """Test that the correct variable is set with direct key""" + flux_data = np.array([[1.0, 2.0]]) + + source_dict = {"Flux": flux_data} + + dict_dataset.update_from_dict(source_dict) + np.testing.assert_array_equal(dict_dataset.Flux, flux_data) + + +def test_update_from_dict_sets_time(dict_dataset): + """Test that the correct variable is set with direct time key""" + time_data = np.array([738000.0]) + + source_dict = {"time": time_data} + + dict_dataset.update_from_dict(source_dict) + assert hasattr(dict_dataset, "time") + np.testing.assert_array_equal(dict_dataset.time, time_data) + + +def test_update_from_dict_with_multiple_variables(dict_dataset): + """Test that multiple variables can be set at once""" + lstar_data = np.array([4.5, 5.0, 5.5]) + energy_data = np.array([100.0, 200.0, 300.0]) + + source_dict = {"Lstar": lstar_data, "energy_channels": energy_data} + + dict_dataset.update_from_dict(source_dict) + np.testing.assert_array_equal(dict_dataset.Lstar, lstar_data) + np.testing.assert_array_equal(dict_dataset.energy_channels, energy_data) + + +def test_dict_mode_computed_p_property(dict_dataset): + """Test P property with correct dimensions in dict mode""" + dict_dataset.MLT = np.array([0.0, 6.0, 12.0]) + + expected_p = ((dict_dataset.MLT + 12) / 12 * np.pi) % (2 * np.pi) + np.testing.assert_allclose(dict_dataset.P, expected_p) + + +def test_dict_mode_computed_invv_property(dict_dataset): + """Test InvV property with correct dimensions in dict mode""" + dict_dataset.InvMu = np.array([[0.1, 0.2]]) + dict_dataset.InvK = np.array([[1.0]]) + + inv_K_repeated = np.repeat(dict_dataset.InvK[:, np.newaxis, :], dict_dataset.InvMu.shape[1], axis=1) + expected_invv = dict_dataset.InvMu * (inv_K_repeated + 0.5) ** 2 + + np.testing.assert_allclose(dict_dataset.InvV, expected_invv) + + +def test_dict_mode_getattr_errors(dict_dataset): + """Test error handling for unset attributes in dict mode""" + with pytest.raises(AttributeError, match="exists in `VariableLiteral` but has not been set"): + _ = dict_dataset.Flux + + with pytest.raises(AttributeError, match="no attribute"): + _ = dict_dataset.NonExistent + + +def test_dict_mode_dir_contains_variable_names(dict_dataset): + """Test that dir() includes variable names in dict mode""" + variable_names = [var.var_name for var in VariableEnum] + for name in variable_names: + assert name in 
dir(dict_dataset) + + +def test_update_from_dict_invalid_key(dict_dataset): + """Test that invalid keys raise VariableNotFoundError""" + from swvo.io.exceptions import VariableNotFoundError + + source_dict = {"InvalidKey": np.array([1.0, 2.0])} + + with pytest.raises(VariableNotFoundError, match="not a valid `VariableLiteral`"): + dict_dataset.update_from_dict(source_dict) + + +def test_update_from_dict_similar_key(dict_dataset): + """Test that similar keys suggest the correct variable""" + from swvo.io.exceptions import VariableNotFoundError + + source_dict = {"Flx": np.array([1.0, 2.0])} # Typo: should be "Flux" + + with pytest.raises(VariableNotFoundError, match="Maybe you meant 'Flux'"): + dict_dataset.update_from_dict(source_dict) + + +def test_all_variable_literals(dict_dataset): + """Test that all VariableLiteral values can be set and retrieved.""" + + # Test with a subset of common variables + test_variables = { + "time": np.array([738000.0, 738001.0]), + "energy_channels": np.array([100.0, 200.0, 300.0]), + "alpha_local": np.array([0.1, 0.2, 0.3]), + "alpha_eq_model": np.array([45.0, 60.0, 90.0]), + "InvMu": np.array([[0.1, 0.2]]), + "InvK": np.array([[1.0, 2.0]]), + "Lstar": np.array([4.5, 5.0, 5.5]), + "Flux": np.array([[1.0, 2.0, 3.0]]), + "PSD": np.array([[0.1, 0.2, 0.3]]), + "MLT": np.array([0.0, 6.0, 12.0]), + "B_SM": np.array([100.0, 200.0, 300.0]), + "B_total": np.array([50.0, 60.0, 70.0]), + "B_sat": np.array([45.0, 55.0, 65.0]), + "xGEO": np.array([6.6, 6.7, 6.8]), + "R0": np.array([5.0, 5.5, 6.0]), + "density": np.array([100.0, 200.0, 300.0]), + } + + dict_dataset.update_from_dict(test_variables) + + for var_name, expected_data in test_variables.items(): + assert hasattr(dict_dataset, var_name), f"Attribute {var_name} not set" + np.testing.assert_array_equal( + getattr(dict_dataset, var_name), + expected_data, + err_msg=f"Data mismatch for {var_name}", + ) + + +@pytest.mark.parametrize("satellite, expected", [("goessecondary", "secondary"), ("goesprimary", "primary")]) +def test_dict_mode_goes_lowercase(satellite, expected): + """Test GOES satellite lowercase string handling in dict mode""" + goes_dataset = RBMDataSet( + satellite=satellite, + instrument=InstrumentEnum.MAGED, + mfm=MfmEnum.T89, + ) + assert goes_dataset.satellite.sat_name == expected + + +uppercase_satellites = set(get_args(SatelliteLiteral)) - set(["GOESPrimary", "GOESSecondary"]) + + +@pytest.mark.parametrize("satellite", [i.lower() for i in uppercase_satellites]) +def test_dict_mode_satellite_lowercase(satellite): + """Test satellite lowercase string handling in dict mode""" + dataset = RBMDataSet( + satellite=satellite, + instrument=InstrumentEnum.MAGED, + mfm=MfmEnum.T89, + ) + assert dataset.satellite.sat_name == satellite.lower() + + +def test_eq_file_loading_mode_identical(mock_module_string): + """Test equality for identical file loading mode datasets.""" + start_time = dt.datetime(2023, 1, 1, tzinfo=timezone.utc) + end_time = dt.datetime(2023, 1, 31, tzinfo=timezone.utc) + folder_path = Path("/mock/path") + + with mock.patch(f"{mock_module_string}._create_date_list"): + with mock.patch(f"{mock_module_string}._create_file_path_stem"): + with mock.patch(f"{mock_module_string}._create_file_name_stem"): + dataset1 = RBMDataSet( + satellite=SatelliteEnum.RBSPA, + instrument=InstrumentEnum.MAGEIS, + mfm=MfmEnum.T89, + start_time=start_time, + end_time=end_time, + folder_path=folder_path, + verbose=False, + ) + + dataset2 = RBMDataSet( + satellite=SatelliteEnum.RBSPA, + 
instrument=InstrumentEnum.MAGEIS, + mfm=MfmEnum.T89, + start_time=start_time, + end_time=end_time, + folder_path=folder_path, + verbose=False, + ) + + dataset1.Flux = np.array([[1.0, 2.0, 3.0]]) + dataset1.time = np.array([738000.0]) + dataset1.datetime = [dt.datetime(2023, 1, 15, tzinfo=timezone.utc)] + + dataset2.Flux = np.array([[1.0, 2.0, 3.0]]) + dataset2.time = np.array([738000.0]) + dataset2.datetime = [dt.datetime(2023, 1, 15, tzinfo=timezone.utc)] + + assert dataset1 == dataset2 + + +def test_eq_file_loading_mode_different_satellite(mock_module_string): + """Test inequality for file loading mode datasets with different satellites.""" + start_time = dt.datetime(2023, 1, 1, tzinfo=timezone.utc) + end_time = dt.datetime(2023, 1, 31, tzinfo=timezone.utc) + folder_path = Path("/mock/path") + + with mock.patch(f"{mock_module_string}._create_date_list"): + with mock.patch(f"{mock_module_string}._create_file_path_stem"): + with mock.patch(f"{mock_module_string}._create_file_name_stem"): + dataset1 = RBMDataSet( + satellite=SatelliteEnum.RBSPA, + instrument=InstrumentEnum.MAGEIS, + mfm=MfmEnum.T89, + start_time=start_time, + end_time=end_time, + folder_path=folder_path, + verbose=False, + ) + + dataset2 = RBMDataSet( + satellite=SatelliteEnum.RBSPB, # Different satellite + instrument=InstrumentEnum.MAGEIS, + mfm=MfmEnum.T89, + start_time=start_time, + end_time=end_time, + folder_path=folder_path, + verbose=False, + ) + + assert dataset1 != dataset2 + + +def test_eq_file_loading_mode_different_instrument(mock_module_string): + """Test inequality for file loading mode datasets with different instruments.""" + start_time = dt.datetime(2023, 1, 1, tzinfo=timezone.utc) + end_time = dt.datetime(2023, 1, 31, tzinfo=timezone.utc) + folder_path = Path("/mock/path") + + with mock.patch(f"{mock_module_string}._create_date_list"): + with mock.patch(f"{mock_module_string}._create_file_path_stem"): + with mock.patch(f"{mock_module_string}._create_file_name_stem"): + dataset1 = RBMDataSet( + satellite=SatelliteEnum.RBSPA, + instrument=InstrumentEnum.MAGEIS, + mfm=MfmEnum.T89, + start_time=start_time, + end_time=end_time, + folder_path=folder_path, + verbose=False, + ) + + dataset2 = RBMDataSet( + satellite=SatelliteEnum.RBSPA, + instrument=InstrumentEnum.HOPE, # Different instrument + mfm=MfmEnum.T89, + start_time=start_time, + end_time=end_time, + folder_path=folder_path, + verbose=False, + ) + + assert dataset1 != dataset2 + + +def test_eq_file_loading_mode_different_mfm(mock_module_string): + """Test inequality for file loading mode datasets with different MFM.""" + start_time = dt.datetime(2023, 1, 1, tzinfo=timezone.utc) + end_time = dt.datetime(2023, 1, 31, tzinfo=timezone.utc) + folder_path = Path("/mock/path") + + with mock.patch(f"{mock_module_string}._create_date_list"): + with mock.patch(f"{mock_module_string}._create_file_path_stem"): + with mock.patch(f"{mock_module_string}._create_file_name_stem"): + dataset1 = RBMDataSet( + satellite=SatelliteEnum.RBSPA, + instrument=InstrumentEnum.MAGEIS, + mfm=MfmEnum.T89, + start_time=start_time, + end_time=end_time, + folder_path=folder_path, + verbose=False, + ) + + dataset2 = RBMDataSet( + satellite=SatelliteEnum.RBSPA, + instrument=InstrumentEnum.MAGEIS, + mfm=MfmEnum.T96, # Different MFM + start_time=start_time, + end_time=end_time, + folder_path=folder_path, + verbose=False, + ) + + assert dataset1 != dataset2 + + +def test_eq_dict_mode_identical(): + """Test equality for identical dict mode datasets.""" + dataset1 = RBMDataSet( 
+ satellite="RBSPA", + instrument="hope", + mfm="T89", + ) + + dataset2 = RBMDataSet( + satellite="RBSPA", + instrument="hope", + mfm="T89", + ) + + test_data = { + "Flux": np.array([[1.0, 2.0, 3.0]]), + "time": np.array([738000.0]), + "energy_channels": np.array([100.0, 200.0, 300.0]), + "Lstar": np.array([4.5, 5.0, 5.5]), + } + + dataset1.update_from_dict(test_data) + dataset2.update_from_dict(test_data.copy()) + + assert dataset1 == dataset2 + + +def test_eq_dict_mode_different_variables(): + """Test inequality for dict mode datasets with different variables.""" + dataset1 = RBMDataSet( + satellite="RBSPA", + instrument="hope", + mfm="T89", + ) + + dataset2 = RBMDataSet( + satellite="RBSPA", + instrument="hope", + mfm="T89", + ) + + dataset1.update_from_dict({"Flux": np.array([[1.0, 2.0, 3.0]])}) + dataset2.update_from_dict({"time": np.array([738000.0])}) + + assert dataset1 != dataset2 + + +def test_eq_dict_mode_same_variables_different_values(): + """Test inequality for dict mode datasets with same variables but different values.""" + dataset1 = RBMDataSet( + satellite="RBSPA", + instrument="hope", + mfm="T89", + ) + + dataset2 = RBMDataSet( + satellite="RBSPA", + instrument="hope", + mfm="T89", + ) + + dataset1.update_from_dict({"Flux": np.array([[1.0, 2.0, 3.0]])}) + dataset2.update_from_dict({"Flux": np.array([[4.0, 5.0, 6.0]])}) + + assert dataset1 != dataset2 + + +def test_eq_different_modes(mock_module_string): + """Test inequality between file loading and dict mode datasets.""" + # File loading mode dataset + start_time = dt.datetime(2023, 1, 1, tzinfo=timezone.utc) + end_time = dt.datetime(2023, 1, 31, tzinfo=timezone.utc) + folder_path = Path("/mock/path") + + with mock.patch(f"{mock_module_string}._create_date_list"): + with mock.patch(f"{mock_module_string}._create_file_path_stem"): + with mock.patch(f"{mock_module_string}._create_file_name_stem"): + file_dataset = RBMDataSet( + satellite=SatelliteEnum.RBSPA, + instrument=InstrumentEnum.MAGEIS, + mfm=MfmEnum.T89, + start_time=start_time, + end_time=end_time, + folder_path=folder_path, + verbose=False, + ) + + # Dict mode dataset + dict_dataset = RBMDataSet( + satellite="RBSPA", + instrument="hope", + mfm="T89", + ) + + assert file_dataset != dict_dataset + + +def test_eq_array_comparison_with_nan(): + """Test equality with NaN values in arrays.""" + dataset1 = RBMDataSet( + satellite="RBSPA", + instrument="hope", + mfm="T89", + ) + + dataset2 = RBMDataSet( + satellite="RBSPA", + instrument="hope", + mfm="T89", + ) + + test_array = np.array([[1.0, np.nan, 3.0]]) + dataset1.update_from_dict({"Flux": test_array}) + dataset2.update_from_dict({"Flux": test_array.copy()}) + + assert dataset1 == dataset2 + + +def test_eq_list_comparison(): + """Test equality with list variables.""" + dataset1 = RBMDataSet( + satellite="RBSPA", + instrument="hope", + mfm="T89", + ) + + dataset2 = RBMDataSet( + satellite="RBSPA", + instrument="hope", + mfm="T89", + ) + + test_datetime_list = [dt.datetime(2023, 1, 15, tzinfo=timezone.utc)] + dataset1.datetime = test_datetime_list + dataset2.datetime = test_datetime_list.copy() + + assert dataset1 == dataset2 + + +def test_eq_list_different_lengths(): + """Test inequality with lists of different lengths.""" + dataset1 = RBMDataSet( + satellite="RBSPA", + instrument="hope", + mfm="T89", + ) + + dataset2 = RBMDataSet( + satellite="RBSPA", + instrument="hope", + mfm="T89", + ) + + dataset1.datetime = [dt.datetime(2023, 1, 15, tzinfo=timezone.utc)] + dataset2.datetime = [ + dt.datetime(2023, 1, 15, 
tzinfo=timezone.utc), + dt.datetime(2023, 1, 16, tzinfo=timezone.utc), + ] + + assert dataset1 != dataset2 + + +def test_eq_array_different_shapes(): + """Test inequality with arrays of different shapes.""" + dataset1 = RBMDataSet( + satellite="RBSPA", + instrument="hope", + mfm="T89", + ) + + dataset2 = RBMDataSet( + satellite="RBSPA", + instrument="hope", + mfm="T89", + ) + + dataset1.update_from_dict({"Flux": np.array([[1.0, 2.0]])}) + dataset2.update_from_dict({"Flux": np.array([1.0, 2.0])}) + + assert dataset1 != dataset2 + + +def test_eq_different_types(): + """Test inequality when same variable has different types.""" + dataset1 = RBMDataSet( + satellite="RBSPA", + instrument="hope", + mfm="T89", + ) + + dataset2 = RBMDataSet( + satellite="RBSPA", + instrument="hope", + mfm="T89", + ) + + dataset1.time = np.array([738000.0]) + dataset2.time = [738000.0] + + assert dataset1 != dataset2 diff --git a/tests/io/RBMDataSet/test_RBMDatasetElPaso.py b/tests/io/RBMDataSet/test_RBMDatasetElPaso.py deleted file mode 100644 index efa1244..0000000 --- a/tests/io/RBMDataSet/test_RBMDatasetElPaso.py +++ /dev/null @@ -1,242 +0,0 @@ -# SPDX-FileCopyrightText: 2025 GFZ Helmholtz Centre for Geosciences -# -# SPDX-License-Identifier: Apache-2.0 - -from datetime import datetime, timezone -from typing import get_args - -import numpy as np -import pytest - -from swvo.io.RBMDataSet import ( - InstrumentEnum, - MfmEnum, - RBMDataSetElPaso, - SatelliteEnum, - SatelliteLiteral, - VariableEnum, -) -from swvo.io.RBMDataSet.utils import python2matlab - - -class MockVariable: - """Create a mock Elpaso Variable class for testing""" - - def __init__(self, standard_name, data=None): - self.standard_name = standard_name - self.data = data - - -@pytest.fixture -def dataset(): - return RBMDataSetElPaso( - satellite=SatelliteEnum.GOESSecondary, - instrument=InstrumentEnum.MAGED, - mfm=MfmEnum.T89, - ) - - -def test_init_accepts_string_inputs(): - """Test that the class can be initialized with string inputs.""" - ds = RBMDataSetElPaso(satellite="GOESSecondary", instrument="MAGED", mfm=MfmEnum.T89) - assert ds.satellite.sat_name == "secondary" - assert ds.instrument == InstrumentEnum.MAGED - assert ds.mfm == MfmEnum.T89 - - -def test_variable_mapping_exposed(dataset): - assert isinstance(dataset.variable_mapping, dict) - assert "FEDU" in dataset.variable_mapping - assert dataset.variable_mapping["FEDU"] == "Flux" - - -def test_repr_and_str(dataset): - assert "RBMDataSetElPaso" in repr(dataset) - assert str(dataset.satellite) in repr(dataset) - assert str(dataset.instrument) in repr(dataset) - assert str(dataset.mfm) in repr(dataset) - - assert "RBMDataSetElPaso" in str(dataset) - assert str(dataset.satellite) in str(dataset) - assert str(dataset.instrument) in str(dataset) - assert str(dataset.mfm) in str(dataset) - - -def test_update_from_dict_sets_variables(dataset): - """Test that the correct variable is set with the standard name""" - fedu_data = np.array([[1.0, 2.0]]) - - source_dict = {"FEDU": MockVariable(standard_name="FEDU", data=fedu_data)} - - dataset.update_from_dict(source_dict) - np.testing.assert_array_equal(dataset.Flux, fedu_data) - - -def test_update_from_dict_sets_time(dataset): - """Test that the correct variable is set with the time standard name""" - ts = [datetime(2025, 4, 1, tzinfo=timezone.utc).timestamp()] - - source_dict = {"Epoch": MockVariable(standard_name="Epoch_posixtime", data=ts)} - - dataset.update_from_dict(source_dict) - assert hasattr(dataset, "time") - assert hasattr(dataset, 
"datetime") - assert dataset.time[0] == python2matlab(datetime(2025, 4, 1, tzinfo=timezone.utc)) - - -def test_update_from_dict_with_mfm_suffix(dataset): - """Test that the correct variable is set with the MFM suffix""" - mfm_suffix = "_" + dataset._mfm_prefix - lstar_data = np.array([4.5, 5.0, 5.5]) - - source_dict = {"Lstar": MockVariable(standard_name=f"Lstar{mfm_suffix}", data=lstar_data)} - - dataset.update_from_dict(source_dict) - np.testing.assert_array_equal(dataset.Lstar, lstar_data) - - -def test_computed_p_property(dataset): - """Test P property with correct dimensions""" - dataset.MLT = np.array([0.0, 6.0, 12.0]) - - expected_p = ((dataset.MLT + 12) / 12 * np.pi) % (2 * np.pi) - np.testing.assert_allclose(dataset.P, expected_p) - - -def test_computed_invv_property(dataset): - """Test InvV property with correct dimensions""" - dataset.InvMu = np.array([[0.1, 0.2]]) - dataset.InvK = np.array([[1.0]]) - - inv_K_repeated = np.repeat(dataset.InvK[:, np.newaxis, :], dataset.InvMu.shape[1], axis=1) - expected_invv = dataset.InvMu * (inv_K_repeated + 0.5) ** 2 - - np.testing.assert_allclose(dataset.InvV, expected_invv) - - -def test_getattr_errors(dataset): - with pytest.raises(AttributeError, match="mapped but has not been set"): - _ = dataset.Flux - - with pytest.raises(AttributeError, match="no attribute"): - _ = dataset.NonExistent - - -def test_dir_contains_variable_names(dataset): - variable_names = [var.var_name for var in VariableEnum] - for name in variable_names: - assert name in dir(dataset) - - -def test_all_variable_mappings(dataset): - """Test that all variable mappings work correctly.""" - - expected_mappings = { - "Epoch_posixtime": "time", - "Energy_FEDU": "energy_channels", - "PA_local": "alpha_local", - "PA_eq_": "alpha_eq_model", - "alpha_eq_real": "alpha_eq_real", - "invMu_": "InvMu", - "InvMu_real": "InvMu_real", - "invK_": "InvK", - "Lstar_": "Lstar", - "FEDU": "Flux", - "PSD_FEDU": "PSD", - "MLT_": "MLT", - "B_SM": "B_SM", - "B_eq_": "B_total", - "B_local_": "B_sat", - "xGEO": "xGEO", - "R_eq_": "R0", - "density": "density", - } - - for source, target in expected_mappings.items(): - assert source in dataset.variable_mapping - assert dataset.variable_mapping[source] == target - - test_data = {} - for source, target in expected_mappings.items(): - if source == "Epoch_posixtime": - data = [datetime(2025, 4, 1, tzinfo=timezone.utc).timestamp()] - else: - data = np.array([float(hash(source) % 1000) / 10.0]) - - test_data[source] = MockVariable(standard_name=source, data=data) - - dataset.update_from_dict(test_data) - - for source, target in expected_mappings.items(): - if source == "Epoch_posixtime": - assert hasattr(dataset, "time") - assert hasattr(dataset, "datetime") - assert isinstance(dataset.datetime[0], datetime) - assert dataset.time[0] == python2matlab(dataset.datetime[0]) - elif source in ["P", "InvV"]: - pass - else: - assert hasattr(dataset, target), f"Attribute {target} not set from {source}" - - np.testing.assert_array_equal( - dataset.__getattribute__(target), - test_data[source].data, - err_msg=f"Data mismatch for {target} from {source}", - ) - - -def test_all_mfm_specific_mappings(dataset): - """Test that all MFM-specific variable mappings work correctly.""" - - mfm_variables = [ - ("PA_eq", "alpha_eq_model"), - ("invMu", "InvMu"), - ("invK", "InvK"), - ("Lstar", "Lstar"), - ("B_eq", "B_total"), - ("B_local", "B_sat"), - ("R_eq", "R0"), - ] - - mfm_suffix = "_" + dataset._mfm_prefix - - test_data = {} - for source_base, target in mfm_variables: - 
source = f"{source_base}{mfm_suffix}" - - data = np.array([float(hash(source) % 1000) / 10.0]) - test_data[source] = MockVariable(standard_name=source, data=data) - - dataset.update_from_dict(test_data) - - for source_base, target in mfm_variables: - source = f"{source_base}{mfm_suffix}" - assert hasattr(dataset, target), f"Attribute {target} not set from {source}" - np.testing.assert_array_equal( - dataset.__getattribute__(target), - test_data[source].data, - err_msg=f"Data mismatch for {target} from {source}", - ) - - -@pytest.mark.parametrize("satellite, expected", [("goessecondary", "secondary"), ("goesprimary", "primary")]) -def test_goes_lowercase(satellite, expected): - goes_dataset = RBMDataSetElPaso( - satellite=satellite, - instrument=InstrumentEnum.MAGED, - mfm=MfmEnum.T89, - ) - assert goes_dataset.satellite.sat_name == expected - - -uppercase_satellites = set(get_args(SatelliteLiteral)) - set(["GOESPrimary", "GOESSecondary"]) - - -@pytest.mark.parametrize("satellite", [i.lower() for i in uppercase_satellites]) -def test_satellite_lowercase(satellite): - dataset = RBMDataSetElPaso( - satellite=satellite, - instrument=InstrumentEnum.MAGED, - mfm=MfmEnum.T89, - ) - assert dataset.satellite.sat_name == satellite.lower()
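For reference, a minimal usage sketch (not part of the diff) of the dictionary-based loading that replaces RBMDataSetElPaso in this change. It assumes the constructor, update_from_dict(), and computed attributes shown in the RBMDataSet hunks above; the array values are hypothetical placeholders.

import numpy as np

from swvo.io.RBMDataSet import RBMDataSet

# Dict mode: omit start_time, end_time, and folder_path entirely.
ds = RBMDataSet(satellite="RBSPA", instrument="hope", mfm="T89")

# Keys must be valid VariableLiteral names; unknown keys raise VariableNotFoundError,
# with a suggestion when a close match exists (e.g. "Flx" -> "Flux").
ds.update_from_dict(
    {
        "Flux": np.array([[1.0, 2.0, 3.0]]),   # placeholder values
        "MLT": np.array([0.0, 6.0, 12.0]),
    }
)

print(ds.P)  # derived from MLT in __getattr__, as in the branch added above

# A file-mode instance only accepts update_from_dict() when constructed with
# enable_dict_loading=True; otherwise it raises RuntimeError, as the new tests check.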