diff --git a/.gitattributes b/.gitattributes
new file mode 100644
index 0000000..433420c
--- /dev/null
+++ b/.gitattributes
@@ -0,0 +1 @@
+nt2/tests/testdata.tar.gz filter=lfs diff=lfs merge=lfs -text
diff --git a/.github/workflows/publish.yml b/.github/workflows/publish.yml
index 1eaf532..a6fbaa9 100644
--- a/.github/workflows/publish.yml
+++ b/.github/workflows/publish.yml
@@ -8,8 +8,22 @@ jobs:
     steps:
       - uses: actions/checkout@v3
+        with:
+          lfs: true
+      - name: Set up Python 3.12
+        uses: actions/setup-python@v4
+        with:
+          python-version: "3.12"
+      - name: Install dependencies
+        run: |
+          python -m pip install --upgrade pip
+          pip install pytest
+          if [ -f requirements.txt ]; then pip install -r requirements.txt; fi
+      - name: Test with `pytest`
+        run: |
+          pytest
       - name: Publish package
         if: github.event_name == 'push' && startsWith(github.ref, 'refs/tags')
         uses: pypa/gh-action-pypi-publish@release/v1
         with:
-          password: ${{ secrets.PYPI_API_TOKEN }}
\ No newline at end of file
+          password: ${{ secrets.PYPI_API_TOKEN }}
diff --git a/.gitignore b/.gitignore
index ae1febe..17eb961 100644
--- a/.gitignore
+++ b/.gitignore
@@ -151,6 +151,7 @@ dmypy.json
 # Cython debug symbols
 cython_debug/
 
+nt2/tests/testdata
 test/
 temp/
 *.bak
diff --git a/.vscode/extensions.json b/.vscode/extensions.json
new file mode 100644
index 0000000..00563da
--- /dev/null
+++ b/.vscode/extensions.json
@@ -0,0 +1,6 @@
+{
+  "recommendations": [
+    "ms-python.python",
+    "ms-python.black-formatter"
+  ]
+}
\ No newline at end of file
diff --git a/README.md b/README.md
index 0e1a092..9c0e049 100644
--- a/README.md
+++ b/README.md
@@ -8,15 +8,14 @@ pip install nt2py
 
 ### Usage
 
-The Library works both with single-file output as well as with separate files. In either case, the location of the data is passed via `path` keyword argument.
+Simply pass the location of the data when initializing the main `Data` object:
 
 ```python
 import nt2
 
-data = nt2.Data(path="path/to/data")
+data = nt2.Data("path/to/data")
 # example:
-# data = nt2.Data(path="path/to/shock.h5") : for single-file
-# data = nt2.Data(path="path/to/shock") : for multi-file
+# data = nt2.Data("path/to/shock")
 ```
 
 The data is stored in specialized containers which can be accessed via corresponding attributes:
@@ -146,16 +145,51 @@ nt2.Dashboard()
 ```
 
 This will output the port where the dashboard server is running, e.g., `Dashboard: http://127.0.0.1:8787/status`. Click on it (or enter in your browser) to open the dashboard.
 
+### CLI
+
+Since version 1.0.0, `nt2py` also offers a command-line interface, accessed via the `nt2` command. To view all the options, simply run:
+
+```sh
+nt2 --help
+```
+
+For instance, if the data is located in `myrun/mysimulation`, you can inspect the contents of the data structure using:
+
+```sh
+nt2 show myrun/mysimulation
+```
+
+The plotting routine is quite customizable. If you want to make a quick plot (à la `inspect` discussed above) of specific quantities, you may simply run:
+
+```sh
+nt2 plot myrun/mysimulation --fields "E.*;B.*" --isel "t=5" --sel "x=slice(-5, None); z=0.5"
+```
+
+This plots the 6th snapshot (`t=5`) of all the `E` and `B` field components, sliced for `x > -5` and taken at `z = 0.5` (note that you can use both `--isel` and `--sel`). If you prefer to make a movie instead, simply do not specify the time:
+
+```sh
+nt2 plot myrun/mysimulation --fields "E.*;B.*" --sel "x=slice(-5, None); z=0.5"
+```
+
+> If you want to install only the CLI, without adding the library to your current environment, you may do that via `pipx`: `pipx install nt2py`.
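+
+For single-snapshot plots like the `--isel "t=5"` example above, the `plot` command saves the figure as `<data-dirname>.png` in the current working directory (here, `mysimulation.png`; see `plot` in `nt2/cli/main.py`):
+
+```sh
+# writes mysimulation.png next to where the command is run
+nt2 plot myrun/mysimulation --fields "E.*" --isel "t=5"
+```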
 
 ### Features
 
 1. Lazy loading and parallel processing of the simulation data with [`dask`](https://dask.org/).
 2. Context-aware data manipulation with [`xarray`](http://xarray.pydata.org/en/stable/).
-3. Parellel plotting and movie generation with [`multiprocessing`](https://docs.python.org/3/library/multiprocessing.html) and [`ffmpeg`](https://ffmpeg.org/).
+3. Parallel plotting and movie generation with [`multiprocessing`](https://docs.python.org/3/library/multiprocessing.html) and [`ffmpeg`](https://ffmpeg.org/).
+4. Command-line interface, the `nt2` command, for quick plotting (both movies and snapshots).
+
+### Testing
+
+The code includes unit tests, which require downloading the test data with [`git lfs`](https://git-lfs.com/) (installed separately from `git`). You can download the data simply by running `git lfs pull`.
 
 ### TODO
 
-- [ ] Unit tests
-- [ ] Plugins for other simulation data formats
+- [x] Unit tests
+- [x] Plugins for other simulation data formats
+- [ ] Support for sparse arrays for particles via the `Sparse` library
+- [x] Command-line interface
 - [ ] Support for multiple runs
 - [ ] Interactive regime (`hvplot`, `bokeh`, `panel`)
 - [x] Ghost cells support
diff --git a/nt2/__init__.py b/nt2/__init__.py
index 587afdd..9808bca 100644
--- a/nt2/__init__.py
+++ b/nt2/__init__.py
@@ -1,4 +1,7 @@
-__version__ = "0.6.0"
+__version__ = "1.0.0"
 
-from nt2.data import Data as Data
-from nt2.dashboard import Dashboard as Dashboard
+import nt2.containers.data as nt2_data
+
+
+class Data(nt2_data.Data):
+    pass
diff --git a/nt2/cli/__init__.py b/nt2/cli/__init__.py
new file mode 100644
index 0000000..e69de29
diff --git a/nt2/cli/__main__.py b/nt2/cli/__main__.py
new file mode 100644
index 0000000..ae32d87
--- /dev/null
+++ b/nt2/cli/__main__.py
@@ -0,0 +1,3 @@
+from .main import app
+
+app(prog_name="nt2")
diff --git a/nt2/cli/main.py b/nt2/cli/main.py
new file mode 100644
index 0000000..d84c141
--- /dev/null
+++ b/nt2/cli/main.py
@@ -0,0 +1,159 @@
+import os
+
+import typer
+from typing_extensions import Annotated
+import matplotlib.pyplot as plt
+
+import nt2
+
+app = typer.Typer()
+
+
+@app.command(help="Print the package version")
+def version():
+    print(nt2.__version__)
+
+
+def check_path(path: str) -> str:
+    if not os.path.exists(path) or not (
+        os.path.exists(os.path.join(path, "fields"))
+        or os.path.exists(os.path.join(path, "particles"))
+        or os.path.exists(os.path.join(path, "spectra"))
+    ):
+        raise typer.BadParameter(
+            f"Path {path} does not exist or is not a valid nt2 data directory."
+        )
+    return path
+
+
+def check_sel(sel: str) -> dict[str, int | float | slice]:
+    if sel == "":
+        return {}
+    sel_list = sel.strip().split(";")
+    sel_dict: dict[str, int | float | slice] = {}
+    for s in sel_list:
+        coord, arg = s.strip().split("=", 1)
+        coord = coord.strip()
+        arg_exec = eval(arg.strip())
+        assert isinstance(
+            arg_exec, (int, float, slice)
+        ), f"Invalid selection argument for '{coord}': {arg_exec}. Must be int, float, or slice."
+        sel_dict[coord] = arg_exec
+    return sel_dict
+
+
+def check_species(species: int) -> int:
+    if species < 0:
+        raise typer.BadParameter(
+            f"Species index must be a non-negative integer, got {species}."
+        )
+    return species
+
+
+def check_what(what: str) -> str:
+    valid_options = ["fields", "particles", "spectra"]
+    if what not in valid_options:
+        raise typer.BadParameter(
+            f"Invalid option '{what}'. Valid options are: {', '.join(valid_options)}."
+ ) + return what + + +@app.command(help="Print the data info") +def show( + path: Annotated[ + str, + typer.Argument( + callback=check_path, + help="Path to the data", + ), + ] = "", +): + data = nt2.Data(path) + print(data.to_str()) + + +@app.command(help="Plot the data") +def plot( + path: Annotated[ + str, + typer.Argument( + callback=check_path, + help="Path to the data", + ), + ] = "", + what: Annotated[ + Annotated[ + str, + typer.Option( + callback=check_what, + help="Which data to plot [fields, particles, spectra]", + ), + ], + str, + ] = "fields", + fields: Annotated[ + str, + typer.Option( + help="Which fields to plot (only when `what` is `fields`). Separate multiple fields with ';'. May contain regex. Empty = all fields. Example: `--fields \"E.*;B.*\"`", + ), + ] = "", + # species: Annotated[ + # Annotated[ + # int, + # typer.Option( + # callback=check_species, + # help="Which species to take (only when `what` is `particles`). 0 = all species", + # ), + # ], + # str, + # ] = 0, + sel: Annotated[ + str, + typer.Option( + callback=check_sel, + help="Select a subset of the data with xarray.sel. Separate multiple selections with ';'. Example: `--sel \"t=23;z=slice(0, None)\"`", + ), + ] = "", + isel: Annotated[ + str, + typer.Option( + callback=check_sel, + help="Select a subset of the data with xarray.isel. Separate multiple selections with ';'. Example: `--isel \"t=slice(None, 5);z=5\"`", + ), + ] = "", +): + fname = os.path.basename(path.strip("/")) + data = nt2.Data(path) + assert isinstance( + sel, dict + ), f"Invalid selection format: {sel}. Must be a dictionary." + assert isinstance(isel, dict), f"Invalid isel format: {isel}. Must be a dictionary." + if what == "fields": + d = data.fields + if sel != {}: + slices = {} + sels = {} + slices: dict[str, slice | float | int] = { + k: v for k, v in sel.items() if isinstance(v, slice) + } + sels: dict[str, slice | float | int] = { + k: v for k, v in sel.items() if not isinstance(v, slice) + } + d = d.sel(**sels, method="nearest") + d = d.sel(**slices) + if isel != {}: + d = d.isel(**isel) + if fields != "": + ret = d.inspect.plot( + name=fname, only_fields=fields.split(";"), fig_kwargs={"dpi": 200} + ) + else: + ret = d.inspect.plot(name=fname, fig_kwargs={"dpi": 200}) + if not isinstance(ret, bool): + plt.savefig(fname=f"{fname}.png") + + elif what == "particles": + raise NotImplementedError("Particles plotting is not implemented yet.") + elif what == "spectra": + raise NotImplementedError("Spectra plotting is not implemented yet.") + else: + raise typer.BadParameter( + f"Invalid option '{what}'. Valid options are: fields, particles, spectra." + ) diff --git a/nt2/containers/container.py b/nt2/containers/container.py index 616f4b6..da0da3b 100644 --- a/nt2/containers/container.py +++ b/nt2/containers/container.py @@ -1,178 +1,45 @@ -import os -import h5py -import numpy as np -from typing import Any - - -def _read_attribs_SingleFile(file: h5py.File): - attribs = {} - for k in file.attrs.keys(): - attr = file.attrs[k] - if type(attr) is bytes or type(attr) is np.bytes_: - attribs[k] = attr.decode("UTF-8") - else: - attribs[k] = attr - return attribs - - -class Container: - """ - * * * * Container * * * * - - Parent class for all data containers. - - Args - ---- - path : str - The path to the data. - - Kwargs - ------ - pickle : bool, optional - Whether to use pickle for reading the data. Default is True. - - greek : bool, optional - Whether to use Greek letters for the spherical coordinates. Default is False. 
- - dask_props : dict, optional - Additional properties for Dask [NOT IMPLEMENTED]. Default is {}. - - Attributes - ---------- - path : str - The path to the data. - - configs : dict - The configuration settings for the data. - - metadata : dict - The metadata for the data. - - mesh : dict - Coordinate grid of the domain (cell-centered & edges). - - master_file : h5py.File - The master file for the data (from which the main attributes are read). - - attrs : dict - The attributes of the master file. - - Methods - ------- - plotGrid(ax, **kwargs) - Plots the gridlines of the domain. - - """ - - def __init__(self, path, pickle=True, greek=False, dask_props={}): - super(Container, self).__init__() - - self.path = path - self.configs: dict[str, Any] = { - "single_file": self.path.endswith(".h5"), - "use_pickle": pickle, - "use_greek": greek, - } - self.metadata = {} - self.mesh = None - if self.configs["single_file"]: - try: - self.master_file: h5py.File | None = h5py.File(self.path, "r") - except OSError: - raise OSError(f"Could not open file {self.path}") - else: - field_path = os.path.join(self.path, "fields") - file = os.path.join(field_path, os.listdir(field_path)[0]) - try: - self.master_file: h5py.File | None = h5py.File(file, "r") - except OSError: - raise OSError(f"Could not open file {file}") - - self.attrs = _read_attribs_SingleFile(self.master_file) - - assert self.master_file is not None, "Master file not found" - - self.configs["ngh"] = int(self.master_file.attrs.get("NGhosts", 0)) - self.configs["layout"] = ( - "right" if self.master_file.attrs.get("LayoutRight", 1) == 1 else "left" - ) - self.configs["dimension"] = int(self.master_file.attrs.get("Dimension", 1)) - self.configs["coordinates"] = self.master_file.attrs.get( - "Coordinates", b"cart" - ).decode("UTF-8") - if self.configs["coordinates"] == "qsph": - self.configs["coordinates"] = "sph" - - if self.isDebug(): - self.configs["coordinates"] = "cart" - - if not self.configs["single_file"]: - self.master_file.close() - self.master_file = None - - def isDebug(self): - return self.configs["ngh"] > 0 - - def __del__(self): - if self.master_file is not None: - self.master_file.close() - - def plotGrid(self, ax, **kwargs): - try: - assert self.mesh is not None, "Mesh not found" - except AttributeError: - raise AttributeError("Mesh not found") - except AssertionError: - raise AssertionError("Mesh not found") - - assert len(self.mesh["xc"]) == 2, "Data must be 2D for plotGrid to work" - - from matplotlib import patches as mpatches - - xlim, ylim = ax.get_xlim(), ax.get_ylim() - options = { - "lw": 1, - "color": "k", - "ls": "-", - } - options.update(kwargs) - - x1_emin, x2_emin = list(self.mesh["xe_min"].keys()) - x1_emax, x2_emax = list(self.mesh["xe_max"].keys()) - x1_e = list(self.mesh["xe_min"][x1_emin][1]) + [ - self.mesh["xe_max"][x1_emax][1][-1] - ] - x2_e = list(self.mesh["xe_min"][x2_emin][1]) + [ - self.mesh["xe_max"][x2_emax][1][-1] - ] - if self.configs["coordinates"] == "cart": - for x1 in x1_e: - ax.plot([x1, x1], [x2_e[0], x2_e[-1]], **options) - for x2 in x2_e: - ax.plot([x1_e[0], x1_e[-1]], [x2, x2], **options) - else: - for r in x1_e: - ax.add_patch( - mpatches.Arc( - (0, 0), - 2 * r, - 2 * r, - theta1=-90, - theta2=90, - fill=False, - **options, - ) - ) - for th in x2_e: - ax.plot( - [ - x1_e[0] * np.sin(th), - x1_e[-1] * np.sin(th), - ], - [ - x1_e[0] * np.cos(th), - x1_e[-1] * np.cos(th), - ], - **options, - ) - ax.set(xlim=xlim, ylim=ylim) +from typing import Callable + +from nt2.readers.base import 
BaseReader + + +class BaseContainer: + """Parent container class for holding any category data.""" + + def __init__( + self, + path: str, + reader: BaseReader, + remap: dict[str, Callable[[str], str]] | None = None, + ): + """Initializer for the BaseContainer class. + + Parameters + ---------- + path : str + The path to the data. + reader : BaseReader + The reader to be used for reading the data. + remap : dict[str, Callable[[str], str]] | None + Remap dictionary to use to remap the data names (coords, fields, etc.). + + """ + super(BaseContainer, self).__init__() + self.__path = path + self.__reader = reader + self.__remap = remap + + @property + def path(self) -> str: + """str: The main path of the data.""" + return self.__path + + @property + def reader(self) -> BaseReader: + """BaseReader: The reader used to read the data.""" + return self.__reader + + @property + def remap(self) -> dict[str, Callable[[str], str]] | None: + """dict[str, Callable[[str], str]]: The coordinate/field remap dictionary.""" + return self.__remap diff --git a/nt2/containers/data.py b/nt2/containers/data.py new file mode 100644 index 0000000..cd0bec9 --- /dev/null +++ b/nt2/containers/data.py @@ -0,0 +1,324 @@ +from typing import Callable, Any + +import sys +if sys.version_info >= (3, 12): + from typing import override +else: + def override(method): + return method + +from collections.abc import KeysView +from nt2.utils import ToHumanReadable + +import xarray as xr + +from nt2.utils import ( + DetermineDataFormat, + InheritClassDocstring, + Format, + CoordinateSystem, +) +from nt2.readers.base import BaseReader +from nt2.readers.hdf5 import Reader as HDF5Reader +from nt2.readers.adios2 import Reader as BP5Reader +from nt2.containers.fields import Fields +from nt2.containers.particles import Particles + +import nt2.plotters.polar as acc_polar + +import nt2.plotters.inspect as acc_inspect +import nt2.plotters.movie as acc_movie +from nt2.plotters.export import makeFramesAndMovie + + +@xr.register_dataset_accessor("polar") +@InheritClassDocstring +class DatasetPolarPlotAccessor(acc_polar.ds_accessor): + pass + + +@xr.register_dataarray_accessor("polar") +@InheritClassDocstring +class PolarPlotAccessor(acc_polar.accessor): + pass + + +@xr.register_dataset_accessor("inspect") +@InheritClassDocstring +class DatasetInspectPlotAccessor(acc_inspect.ds_accessor): + pass + + +@xr.register_dataarray_accessor("movie") +@InheritClassDocstring +class MoviePlotAccessor(acc_movie.accessor): + pass + + +class Data(Fields, Particles): # pyright: ignore[reportUnsafeMultipleInheritance] + """Main class to manage all the data containers. + + Inherits from all category-specific containers. + + """ + + def __init__( + self, + path: str, + reader: BaseReader | None = None, + remap: dict[str, Callable[[str], str]] | None = None, + coord_system: CoordinateSystem | None = None, + ): + """Initializer for the Data class. + + Parameters + ---------- + path : str + Main path to the data + reader : BaseReader | None + Reader to use to read the data. If None, it will be determined + based on the file format. + remap : dict[str, Callable[[str], str]] | None + Remap dictionary to use to remap the data names (coords, fields, etc.). + coord_system : CoordinateSystem | None + Coordinate system of the data. If None, it will be determined + based on the data attrs (if remap is also None). + + Raises + ------ + NotImplementedError + If the data format or coordinate system support is not implemented. 
+ ValueError + If the reader format does not match the data format or if coordinate system cannot be inferred. + """ + # determine the reader from the format + fmt = DetermineDataFormat(path) + if reader is None: + if fmt == Format.HDF5: + self.__reader = HDF5Reader() + elif fmt == Format.BP5: + self.__reader = BP5Reader() + else: + raise NotImplementedError( + "Only HDF5 & BP5 formats are supported at the moment." + ) + else: + if fmt != reader.format: + raise ValueError( + f"Reader format {reader.format} does not match data format {fmt}." + ) + self.__reader = reader + + # determine the coordinate system and remapping + self.__attrs: dict[str, Any] = {} + for category in ["fields", "particles", "spectra"]: + if self.__reader.DefinesCategory(path, category): + valid_steps = self.__reader.GetValidSteps(path, category) + if len(valid_steps) == 0: + raise ValueError(f"No valid steps found for category {category}.") + first_step = valid_steps[0] + attrs = self.__reader.ReadAttrsAtTimestep(path, category, first_step) + self.__attrs.update(**attrs) + if "Coordinates" not in attrs: + raise ValueError( + f"Coordinates not found in attributes for category {category}." + ) + else: + if attrs["Coordinates"] in [b"cart", "cart"]: + + def remap_fields(name: str) -> str: + name = name[1:] + fieldname = name.split("_")[0] + fieldname = fieldname.replace("0", "t") + fieldname = fieldname.replace("1", "x") + fieldname = fieldname.replace("2", "y") + fieldname = fieldname.replace("3", "z") + suffix = "_".join(name.split("_")[1:]) + return f"{fieldname}{'_' + suffix if suffix != '' else ''}" + + def remap_coords(name: str) -> str: + return { + "X1": "x", + "X2": "y", + "X3": "z", + }.get(name, name) + + def remap_prtl_quantities(name: str) -> str: + shortname = name[1:] + return { + "X1": "x", + "X2": "y", + "X3": "z", + "U1": "ux", + "U2": "uy", + "U3": "uz", + "W": "w", + }.get(shortname, shortname) + + coord_system = CoordinateSystem.XYZ + + elif attrs["Coordinates"] in [b"sph", "sph", b"qsph", "qsph"]: + + def remap_fields(name: str) -> str: + name = name[1:] + fieldname = name.split("_")[0] + fieldname = fieldname.replace("0", "t") + fieldname = fieldname.replace("1", "r") + fieldname = fieldname.replace("2", "th") + fieldname = fieldname.replace("3", "ph") + suffix = "_".join(name.split("_")[1:]) + return f"{fieldname}{'_' + suffix if suffix != '' else ''}" + + def remap_coords(name: str) -> str: + return { + "X1": "r", + "X2": "th", + "X3": "ph", + }.get(name, name) + + def remap_prtl_quantities(name: str) -> str: + shortname = name[1:] + return { + "X1": "r", + "X2": "th", + "X3": "ph", + "U1": "ur", + "U2": "uth", + "U3": "uph", + "W": "w", + }.get(shortname, shortname) + + coord_system = CoordinateSystem.SPH + + else: + raise NotImplementedError( + f"Coordinate system {attrs['Coordinates']} not supported." + ) + if remap is None: + remap = { + "coords": remap_coords, + "fields": remap_fields, + "particles": remap_prtl_quantities, + } + break + + if coord_system is None: + raise ValueError("No coordinate system found in the data.") + + self.__coordinate_system = coord_system + + super(Data, self).__init__(path=path, reader=self.__reader, remap=remap) + + def makeMovie( + self, + plot: Callable, # pyright: ignore[reportMissingTypeArgument, reportUnknownParameterType] + time: list[float] | None = None, + num_cpus: int | None = None, + **movie_kwargs: Any, + ) -> bool: + f"""Create animation with provided plot function. 
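+    # Note on the remap above (illustrative, hypothetical raw names): under the
+    # cartesian remap, a raw field name "fB1" becomes "Bx" and "fT0_avg" becomes
+    # "Tt_avg"; the coordinate "X2" becomes "y"; the particle quantity "pU3"
+    # becomes "uz". Under the spherical remap, the same names map to "Br",
+    # "Tt_avg", "th", and "uph", respectively.
+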
+    def makeMovie(
+        self,
+        plot: Callable,  # pyright: ignore[reportMissingTypeArgument, reportUnknownParameterType]
+        time: list[float] | None = None,
+        num_cpus: int | None = None,
+        **movie_kwargs: Any,
+    ) -> bool:
+        """Create animation with the provided plot function.
+
+        Parameters
+        ----------
+        plot : callable
+            A function that takes a single argument (time in physical units) and produces a plot.
+        time : array_like, optional
+            An array of time values to use for the animation. If not provided, the entire time range will be used.
+        num_cpus : int, optional
+            Number of CPUs to use for rendering the frames in parallel (passed through to `makeFramesAndMovie`).
+        **movie_kwargs : Any
+            Additional keyword arguments passed to `makeFramesAndMovie`.
+
+        Returns
+        -------
+        bool
+            True if the movie was created successfully, False otherwise.
+        """
+        if time is None:
+            if self.fields_defined:
+                time = self.fields.t.values
+            elif self.particles_defined:
+                species = sorted(list(self.particles.keys()))
+                time = self.particles[species[0]].t.values
+            else:
+                raise ValueError("No time values found.")
+        assert time is not None, "Time values must be provided."
+        name: str = ""
+        if self.attrs.get("simulation.name", None) is None:
+            name = movie_kwargs.pop("name", "movie")
+        else:
+            name_b = self.attrs.get("simulation.name")
+            if isinstance(name_b, bytes):
+                name = name_b.decode("utf-8")
+            else:
+                name = str(name_b)
+        return makeFramesAndMovie(
+            name=name,
+            data=self,
+            plot=plot,
+            times=time,
+            num_cpus=num_cpus,
+            **movie_kwargs,
+        )
+
+    @property
+    def coordinate_system(self) -> CoordinateSystem:
+        """CoordinateSystem: The coordinate system of the data."""
+        return self.__coordinate_system
+
+    @property
+    def attrs(self) -> dict[str, Any]:
+        """dict[str, Any]: The attributes of the data."""
+        return self.__attrs
+
+    def to_str(self) -> str:
+        """str: String representation of all the enclosed dataframes."""
+
+        def compactify(lst: list[Any] | KeysView[Any]) -> str:
+            c = ""
+            cntr = 0
+            for l_ in lst:
+                if cntr > 5:
+                    c += "\n "
+                    cntr = 0
+                c += f"{l_}, "
+                cntr += 1
+            return c[:-2]
+
+        string = ""
+        if self.fields_defined:
+            string += "Fields:\n"
+            string += f" - coordinates: {self.coordinate_system.value}\n"
+            string += f" - data axes: {compactify(self.fields.indexes.keys())}\n"
+            delta_t = (
+                self.fields.coords["t"].values[1] - self.fields.coords["t"].values[0]
+            ) / (self.fields.coords["s"].values[1] - self.fields.coords["s"].values[0])
+            string += f" - dt: {delta_t:.2e}\n"
+            for key in self.fields.coords.keys():
+                crd = self.fields.coords[key].values
+                fmt = ""
+                if key != "s":
+                    fmt = ".2f"
+                string += f" - {key}: {crd.min():{fmt}} -> {crd.max():{fmt}} [{len(crd)}]\n"
+            string += (
+                f" - quantities: {compactify(sorted(self.fields.data_vars.keys()))}\n"
+            )
+            string += f" - total size: {ToHumanReadable(self.fields.nbytes)}\n\n"
+        else:
+            string += "Fields: empty\n\n"
+        if self.particles_defined:
+            species = sorted(list(self.particles.keys()))
+            string += "Particle species:\n"
+            string += f" - species: {compactify(species)}\n"
+            string += f" - timesteps: {len(self.particles[species[0]].t)}\n"
+            string += f" - quantities: {compactify(self.particles[species[0]].data_vars.keys())}\n"
+            string += f" - max # per species: {[self.particles[sp].idx.shape[0] for sp in species]}\n"
+            string += f" - total size: {ToHumanReadable(sum([self.particles[sp].nbytes for sp in species]))}\n\n"
+        else:
+            string += "Particles: empty\n\n"
+
+        return string
+
+    @override
+    def __str__(self) -> str:
+        return self.to_str()
+
+    @override
+    def __repr__(self) -> str:
+        return self.to_str()
diff --git a/nt2/containers/fields.py b/nt2/containers/fields.py
index 513e0fe..ebc9914 100644
--- a/nt2/containers/fields.py
+++ b/nt2/containers/fields.py
@@ -1,293 +1,159 @@
-import os
-import h5py
-import xarray as xr
-
-from nt2.containers.container import Container
-from nt2.containers.utils import (
-    _read_category_metadata,
-    _read_coordinates,
-    _preload_domain_shapes,
-    _preload_field,
-    
_preload_field_with_ghosts, -) - -from nt2.plotters.polar import ( - _datasetPolarPlotAccessor, - _polarPlotAccessor, -) - -from nt2.plotters.inspect import _datasetInspectPlotAccessor -from nt2.plotters.movie import _moviePlotAccessor - -from nt2.containers.utils import InheritClassDocstring - - -@xr.register_dataset_accessor("polar") -@InheritClassDocstring -class DatasetPolarPlotAccessor(_datasetPolarPlotAccessor): - pass - - -@xr.register_dataarray_accessor("polar") -@InheritClassDocstring -class PolarPlotAccessor(_polarPlotAccessor): - pass - - -@xr.register_dataset_accessor("inspect") -@InheritClassDocstring -class DatasetInspectPlotAccessor(_datasetInspectPlotAccessor): - pass - - -@xr.register_dataarray_accessor("movie") -@InheritClassDocstring -class MoviePlotAccessor(_moviePlotAccessor): - pass - - -class FieldsContainer(Container): - """ - * * * * FieldsContainer : Container * * * * - - Class for hodling the field (grid-based) data. - - Attributes - ---------- - fields : xarray.Dataset - The xarray dataset for all the field quantities. - - fields_files : list - The list of opened fields files. - - Methods - ------- - print_fields() - Prints the basic information about the field data. - - """ - - def __init__(self, **kwargs): - super(FieldsContainer, self).__init__(**kwargs) - QuantityDict = { - "Ttt": "E", - "Ttx": "Px", - "Tty": "Py", - "Ttz": "Pz", - } - CoordinateDict = { - "cart": {"x": "x", "y": "y", "z": "z", "1": "x", "2": "y", "3": "z"}, - "sph": { - "r": "r", - "theta": "θ" if self.configs["use_greek"] else "th", - "phi": "φ" if self.configs["use_greek"] else "ph", - "1": "r", - "2": "θ" if self.configs["use_greek"] else "th", - "3": "φ" if self.configs["use_greek"] else "ph", - }, - } - if self.configs["single_file"]: - assert self.master_file is not None, "Master file not found" - self.metadata["fields"] = _read_category_metadata( - True, "f", self.master_file - ) - else: - field_path = os.path.join(self.path, "fields") - if os.path.isdir(field_path): - files = sorted(os.listdir(field_path)) - try: - self.fields_files = [ - h5py.File(os.path.join(field_path, f), "r") for f in files - ] - except OSError: - raise OSError(f"Could not open file in {field_path}") - self.metadata["fields"] = _read_category_metadata( - False, "f", self.fields_files - ) +from typing import Any - if not self.isDebug(): - coords = list(CoordinateDict[self.configs["coordinates"]].values())[::-1][ - -self.configs["dimension"] : - ] - else: - coords = ["i3", "i2", "i1"][-self.configs["dimension"] :] +import dask +import dask.array as da +import xarray as xr - if self.configs["single_file"]: - assert self.master_file is not None, "Master file not found" - self.mesh = _read_coordinates(coords, self.master_file) +from nt2.containers.container import BaseContainer +from nt2.readers.base import BaseReader +from nt2.utils import Layout + + +class Fields(BaseContainer): + """Parent class to manage the fields dataframe.""" + + @staticmethod + @dask.delayed + def __read_field(path: str, reader: BaseReader, field: str, step: int) -> Any: + """Reads a field from the data. + + This is a dask-delayed function used further to build the dataset. + + Parameters + ---------- + path : str + Main path to the data. + reader : BaseReader + Reader to use to read the data. + field : str + Field to read. + step : int + Step to read. + + Returns + ------- + Any + Field data. 
+ + """ + return reader.ReadArrayAtTimestep(path, "fields", field, step) + + def __init__( + self, + **kwargs: Any, + ) -> None: + """Initializer for the Fields class. + + Parameters + ---------- + **kwargs : dict + Keyword arguments to be passed to the parent BaseContainer class. + + """ + super(Fields, self).__init__(**kwargs) + if self.reader.DefinesCategory(self.path, "fields"): + self.__fields_defined = True + self.__fields = self.__read_fields() else: - self.mesh = _read_coordinates(coords, self.fields_files[0]) - - self._fields = xr.Dataset() - - if "fields" in self.metadata and len(self.metadata["fields"]["outsteps"]) > 0: - self.domains = xr.Dataset() - for i in range(self.configs["dimension"]): - self.domains[f"x{i+1}"], self.domains[f"sx{i+1}"] = ( - _preload_domain_shapes( - single_file=self.configs["single_file"], - k=f"N{i+1}l", - outsteps=self.metadata["fields"]["outsteps"], - times=self.metadata["fields"]["times"], - steps=self.metadata["fields"]["steps"], - file=( - self.master_file - if self.configs["single_file"] - and self.master_file is not None - else self.fields_files - ), - ) - ) - - for k in self.metadata["fields"]["quantities"]: - if not self.isDebug(): - name, dset = _preload_field( - single_file=self.configs["single_file"], - k=k, - outsteps=self.metadata["fields"]["outsteps"], - times=self.metadata["fields"]["times"], - steps=self.metadata["fields"]["steps"], - coords=coords, - xc_coords=self.mesh["xc"], - xe_min_coords=self.mesh["xe_min"], - xe_max_coords=self.mesh["xe_max"], - coord_replacements=list( - CoordinateDict[self.configs["coordinates"]].items() - ), - field_replacements=list(QuantityDict.items()), - layout=self.configs["layout"], - file=( - self.master_file - if self.configs["single_file"] - and self.master_file is not None - else self.fields_files - ), - ) - else: - ( - name, - dset, - self.mesh["xc"], - self.mesh["xe_min"], - self.mesh["xe_max"], - ) = _preload_field_with_ghosts( - single_file=self.configs["single_file"], - k=k, - outsteps=self.metadata["fields"]["outsteps"], - times=self.metadata["fields"]["times"], - steps=self.metadata["fields"]["steps"], - coords=coords, - coord_replacements=list( - CoordinateDict[self.configs["coordinates"]].items() - ), - field_replacements=list(QuantityDict.items()), - layout=self.configs["layout"], - file=( - self.master_file - if self.configs["single_file"] - and self.master_file is not None - else self.fields_files - ), - ) - self.fields[name] = dset + self.__fields_defined = False + self.__fields = xr.Dataset() @property - def fields(self): - return self._fields - - def __del__(self): - if not self.configs["single_file"]: - for f in self.fields_files: - f.close() - - def __enter__(self): - return self - - def __exit__(self, exc_type, exc_value, traceback): - self.__del__() - - def print_fields(self) -> str: - def sizeof_fmt(num, suffix="B"): - for unit in ("", "K", "M", "G", "T", "P", "E", "Z"): - if abs(num) < 1e3: - return f"{num:3.1f} {unit}{suffix}" - num /= 1e3 - return f"{num:.1f} Y{suffix}" - - def compactify(lst): - c = "" - cntr = 0 - for l_ in lst: - if cntr > 5: - c += "\n " - cntr = 0 - c += l_ + ", " - cntr += 1 - return c[:-2] - - string = "" - field_keys = list(self.fields.data_vars.keys()) + def fields_defined(self) -> bool: + """bool: Whether the fields category is defined.""" + return self.__fields_defined - if len(field_keys) > 0: - string += "Fields:\n" - string += f" - data axes: {compactify(self.fields.indexes.keys())}\n" - string += f" - timesteps: 
{self.fields[field_keys[0]].shape[0]}\n" - string += f" - shape: {self.fields[field_keys[0]].shape[1:]}\n" - string += f" - quantities: {compactify(self.fields.data_vars.keys())}\n" - string += f" - total size: {sizeof_fmt(self.fields.nbytes)}\n" + @property + def fields(self) -> xr.Dataset: + """xr.Dataset: The fields dataframe.""" + return self.__fields + + def __read_fields(self) -> xr.Dataset: + """Helper function to read the fields dataframe.""" + self.reader.VerifySameCategoryNames(self.path, "fields", "f") + self.reader.VerifySameFieldShapes(self.path) + self.reader.VerifySameFieldLayouts(self.path) + + valid_steps = self.reader.GetValidSteps(self.path, "fields") + field_names = self.reader.ReadCategoryNamesAtTimestep( + self.path, "fields", "f", valid_steps[0] + ) + + first_step = valid_steps[0] + first_name = next(iter(field_names)) + layout = self.reader.ReadFieldLayoutAtTimestep(self.path, first_step) + shape = self.reader.ReadArrayShapeAtTimestep( + self.path, "fields", first_name, first_step + ) + coords = self.reader.ReadFieldCoordsAtTimestep(self.path, first_step) + coords = {k: coords[k] for k in sorted(coords.keys())[::-1]} + # rename coordinates if remap is provided + if self.remap is not None and "coords" in self.remap: + new_coords = {} + for coord in coords.keys(): + new_coords[self.remap["coords"](coord)] = coords[coord] + coords = new_coords + + times = self.reader.ReadPerTimestepVariable(self.path, "fields", "Time", "t") + steps = self.reader.ReadPerTimestepVariable(self.path, "fields", "Step", "s") + + edge_coords = self.reader.ReadEdgeCoordsAtTimestep(self.path, first_step) + if self.remap is None or "coords" not in self.remap: + + def remap(x: str) -> str: + return x + + coord_remap = remap else: - string += "Fields: empty\n" - - return string - - def plotDomains(self, ax, ti=None, t=None, **kwargs): - if self.domains is None: - raise AttributeError("Domains not found") - - assert len(self.domains.data_vars) == 4, "Data must be 2D for plotGrid to work" - - import matplotlib.patches as mpatches - - ngh = self.configs["ngh"] - - xlim, ylim = ax.get_xlim(), ax.get_ylim() - options = { - "lw": 2, - "color": "r", - "ls": "-", - } - options.update(kwargs) - - for dom in self.domains.dom: - selection = self.domains.sel(dom=dom) - if ti is not None: - selection = selection.sel(t=ti) - elif t is not None: - selection = selection.sel(t=t, method="nearest") + coord_remap = self.remap["coords"] + new_edge_coords = {} + for coord in edge_coords.keys(): + assoc_x = coord_remap(coord[:-1]) + new_edge_coords[assoc_x + "_min"] = (assoc_x, edge_coords[coord][:-1]) + new_edge_coords[assoc_x + "_max"] = (assoc_x, edge_coords[coord][1:]) + edge_coords = new_edge_coords + + all_dims = {**times, **coords}.keys() + all_coords = {**times, **coords, "s": ("t", steps["s"]), **edge_coords} + + def remap_name(name: str) -> str: + """ + Remaps the field name if remap is provided + """ + if self.remap is not None and "fields" in self.remap: + return self.remap["fields"](name) + return name + + def get_field(name: str, step: int) -> Any: + """ + Reads a field from the data + """ + if layout == Layout.L: + return Fields.__read_field(self.path, self.reader, name, step) else: - selection = selection.isel(t=0) - - x1c, sx1 = selection.x1.values[()], selection.sx1.values[()] - x2c, sx2 = selection.x2.values[()], selection.sx2.values[()] - - # add rectangle - ax.add_patch( - mpatches.Rectangle( - (x1c + ngh, x2c + ngh), - sx1 - 2 * ngh, - sx2 - 2 * ngh, - fill=None, - **options, + return 
Fields.__read_field(self.path, self.reader, name, step).T + + return xr.Dataset( + { + remap_name(name): xr.DataArray( + da.stack( + [ + da.from_delayed( + get_field(name, step), + shape=shape[:: -1 if layout == Layout.R else 1], + dtype="float", + ) + for step in valid_steps + ], + axis=0, + ), + name=remap_name(name), + dims=all_dims, + coords=all_coords, ) - ) - - # ax.plot( - # self.domains[x1][j], - # self.domains[x2][j], - # **options, - # ) - # ax.plot( - # self.domains[x1_e][j], - # self.domains[x2_e][j], - # **options, - # ) + for name in field_names + }, + attrs=self.reader.ReadAttrsAtTimestep( + path=self.path, category="fields", step=first_step + ), + ) diff --git a/nt2/containers/particles.py b/nt2/containers/particles.py index 61543e6..bb17a33 100644 --- a/nt2/containers/particles.py +++ b/nt2/containers/particles.py @@ -1,147 +1,189 @@ -import os -import h5py -from nt2.containers.container import Container -from nt2.containers.utils import ( - _read_category_metadata, - _read_particle_species, - _preload_particle_species, -) - - -class ParticleContainer(Container): - """ - * * * * ParticleContainer : Container * * * * - - Class for holding the particle data. - - Attributes - ---------- - particles : dict - The dictionary of particle species. - - particle_files : list - The list of opened particle files. - - Methods - ------- - print_particles() - Prints the basic information about the particle data. - - """ - - def __init__(self, **kwargs): - super(ParticleContainer, self).__init__(**kwargs) - PrtlDict = { - "cart": { - "X1": "x", - "X2": "y", - "X3": "z", - "U1": "ux", - "U2": "uy", - "U3": "uz", - }, - "sph": { - "X1": "r", - "X2": "θ" if self.configs["use_greek"] else "th", - "X3": "φ" if self.configs["use_greek"] else "ph", - "U1": "ur", - "U2": "uΘ" if self.configs["use_greek"] else "uth", - "U3": "uφ" if self.configs["use_greek"] else "uph", - }, - } +from typing import Any - if self.configs["single_file"]: - assert self.master_file is not None, "Master file not found" - self.metadata["particles"] = _read_category_metadata( - True, "p", self.master_file - ) - else: - particle_path = os.path.join(self.path, "particles") - if os.path.isdir(particle_path): - files = sorted(os.listdir(particle_path)) - try: - self.particle_files = [ - h5py.File(os.path.join(particle_path, f), "r") for f in files - ] - except OSError: - raise OSError(f"Could not open file in {particle_path}") - self.metadata["particles"] = _read_category_metadata( - False, "p", self.particle_files - ) - self._particles = {} +import dask +import dask.array as da +import xarray as xr +import numpy as np + +from nt2.containers.container import BaseContainer +from nt2.readers.base import BaseReader + + +class Particles(BaseContainer): + """Parent class to manage the particles dataframe.""" + + @staticmethod + @dask.delayed + def __read_species_quantity( + path: str, reader: BaseReader, species_quantity: str, step: int, pad: int + ) -> Any: + """Reads a species from the data. + + This is a dask-delayed function used further to build the dataset. + + Parameters + ---------- + path : str + Main path to the data. + reader : BaseReader + Reader to use to read the data. + species_quantity : str + Quantity specific to a species to read. + step : int + Step to read. + pad : int + Length to pad the array to. + + Returns + ------- + Any + Species data. 
+ """ + arr = reader.ReadArrayAtTimestep(path, "particles", species_quantity, step) + shape = reader.ReadArrayShapeAtTimestep( + path, "particles", species_quantity, step + )[0] + return da.pad(arr, ((0, pad - shape),), mode="constant", constant_values=np.nan) + + def __init__(self, **kwargs: Any) -> None: + """Initializer for the Particles class. + + Parameters + ---------- + **kwargs : dict + Keyword arguments to be passed to the parent BaseContainer class. + + """ + super(Particles, self).__init__(**kwargs) if ( - "particles" in self.metadata - and len(self.metadata["particles"]["outsteps"]) > 0 + self.reader.DefinesCategory(self.path, "particles") + and self.particles_present ): - if self.configs["single_file"]: - assert self.master_file is not None, "Master file not found" - species = _read_particle_species( - self.metadata["particles"]["outsteps"][0], self.master_file - ) - else: - species = _read_particle_species("Step0", self.particle_files[0]) - self.metadata["particles"]["species"] = species - for s in species: - self._particles[s] = _preload_particle_species( - self.configs["single_file"], - s=s, - quantities=self.metadata["particles"]["quantities"], - coord_type=self.configs["coordinates"], - outsteps=self.metadata["particles"]["outsteps"], - times=self.metadata["particles"]["times"], - steps=self.metadata["particles"]["steps"], - coord_replacements=PrtlDict[self.configs["coordinates"]], - file=( - self.master_file - if self.configs["single_file"] and self.master_file is not None - else self.particle_files - ), + self.__particles_defined = True + self.__particles = self.__read_particles() + else: + self.__particles_defined = False + self.__particles = {} + + @property + def particles_present(self) -> bool: + """bool: Whether the particles are present in any of the timesteps.""" + return len(self.nonempty_steps) > 0 + + @property + def nonempty_steps(self) -> list[int]: + """list[int]: List of timesteps that contain particles data.""" + valid_steps = self.reader.GetValidSteps(self.path, "particles") + return [ + step + for step in valid_steps + if len( + set( + q.split("_")[0] + for q in self.reader.ReadCategoryNamesAtTimestep( + self.path, "particles", "p", step + ) + if q.startswith("p") ) + ) + > 0 + ] @property - def particles(self): - return self._particles - - def __del__(self): - if not self.configs["single_file"]: - for f in self.particle_files: - f.close() - - def print_particles(self) -> str: - def sizeof_fmt(num, suffix="B"): - for unit in ("", "K", "M", "G", "T", "P", "E", "Z"): - if abs(num) < 1e3: - return f"{num:3.1f} {unit}{suffix}" - num /= 1e3 - return f"{num:.1f} Y{suffix}" - - def compactify(lst): - c = "" - cntr = 0 - for l_ in lst: - if cntr > 5: - c += "\n " - cntr = 0 - c += l_ + ", " - cntr += 1 - return c[:-2] - - string = "" - if self.particles != {}: - species = [int(i) for i in self.particles.keys()] - string += "Particles:\n" - string += f" - species: {species}\n" - string += f" - data axes: {compactify(self.particles[species[0]].indexes.keys())}\n" - string += f" - timesteps: {self.particles[species[0]][list(self.particles[species[0]].data_vars.keys())[0]].shape[0]}\n" - string += f" - quantities: {compactify(self.particles[species[0]].data_vars.keys())}\n" - size = 0 - for s in species: - keys = list(self.particles[s].data_vars.keys()) - string += f" - species [{s}]:\n" - string += f" - number: {self.particles[s][keys[0]].shape[1]}\n" - size += self.particles[s].nbytes - string += f" - total size: {sizeof_fmt(size)}\n" - else: - string += 
"Particles: empty\n" - return string + def particles_defined(self) -> bool: + """bool: Whether the particles category is defined.""" + return self.__particles_defined + + @property + def particles(self) -> dict[int, xr.Dataset]: + """Returns the particles data. + + Returns + ------- + dict[int, xr.Dataset] + Dictionary of datasets for each step. + + """ + return self.__particles + + def __read_particles(self) -> dict[int, xr.Dataset]: + """Helper function to read all particles data.""" + self.reader.VerifySameCategoryNames(self.path, "particles", "p") + self.reader.VerifySameParticleShapes(self.path) + + valid_steps = self.nonempty_steps + prtl_species = self.reader.ReadParticleSpeciesAtTimestep( + self.path, valid_steps[0] + ) + prtl_quantities = set( + q.split("_")[0] + for q in self.reader.ReadCategoryNamesAtTimestep( + self.path, "particles", "p", valid_steps[0] + ) + if q.startswith("p") + ) + prtl_quantities = sorted(prtl_quantities) + + first_quantity = next(iter(prtl_quantities)) + maxlens = { + sp: np.max( + [ + self.reader.ReadArrayShapeAtTimestep( + self.path, "particles", f"{first_quantity}_{sp}", st + ) + for st in valid_steps + ] + ) + for sp in prtl_species + } + + times = self.reader.ReadPerTimestepVariable(self.path, "particles", "Time", "t") + steps = self.reader.ReadPerTimestepVariable(self.path, "particles", "Step", "s") + + idxs: dict[int, dict[str, np.ndarray]] = { + sp: {"idx": np.arange(maxlens[sp])} for sp in prtl_species + } + + all_dims = {sp: {**times, **(idxs[sp])}.keys() for sp in prtl_species} + all_coords = { + sp: {**times, **(idxs[sp]), "s": ("t", steps["s"])} for sp in prtl_species + } + + def remap_quantity(name: str) -> str: + """ + Remaps the particle quantity name if remap is provided + """ + if self.remap is not None and "particles" in self.remap: + return self.remap["particles"](name) + return name + + def get_quantity(species: int, quantity: str, step: int, maxlen: int) -> Any: + return Particles.__read_species_quantity( + self.path, self.reader, f"{quantity}_{species}", step, maxlen + ) + + return { + sp: xr.Dataset( + { + remap_quantity(quantity): xr.DataArray( + da.stack( + [ + da.from_delayed( + get_quantity(sp, quantity, step, maxlens[sp]), + shape=(maxlens[sp],), + dtype="float", + ) + for step in valid_steps + ], + axis=0, + ), + name=remap_quantity(quantity), + dims=all_dims[sp], + coords=all_coords[sp], + ) + for quantity in prtl_quantities + } + ) + for sp in prtl_species + } diff --git a/nt2/containers/spectra.py b/nt2/containers/spectra.py deleted file mode 100644 index 27d43ad..0000000 --- a/nt2/containers/spectra.py +++ /dev/null @@ -1,139 +0,0 @@ -import os -import h5py -import xarray as xr - -from nt2.containers.container import Container -from nt2.containers.utils import ( - _read_category_metadata, - _read_spectra_species, - _read_spectra_bins, - _preload_spectra, -) - - -class SpectraContainer(Container): - """ - * * * * SpectraContainer : Container * * * * - - Class for holding the spectra (energy distribution) data. - - Attributes - ---------- - spectra : xarray.Dataset - The xarray dataset of particle distributions. - - spectra_files : list - The list of opened spectra files. - - Methods - ------- - print_spectra() - Prints the basic information about the spectra data. 
- - """ - - def __init__(self, **kwargs): - super(SpectraContainer, self).__init__(**kwargs) - assert "single_file" in self.configs - assert "use_pickle" in self.configs - assert "use_greek" in self.configs - assert "path" in self.__dict__ - assert "metadata" in self.__dict__ - assert "mesh" in self.__dict__ - assert "attrs" in self.__dict__ - - if self.configs["single_file"]: - assert self.master_file is not None, "Master file not found" - self.metadata["spectra"] = _read_category_metadata( - True, "s", self.master_file - ) - else: - spectra_path = os.path.join(self.path, "spectra") - if os.path.isdir(spectra_path): - files = sorted(os.listdir(spectra_path)) - try: - self.spectra_files = [ - h5py.File(os.path.join(spectra_path, f), "r") for f in files - ] - except OSError: - raise OSError(f"Could not open file {spectra_path}") - self.metadata["spectra"] = _read_category_metadata( - False, "s", self.spectra_files - ) - self._spectra = xr.Dataset() - log_bins = self.attrs["output.spectra.log_bins"] - - if "spectra" in self.metadata and len(self.metadata["spectra"]["outsteps"]) > 0: - if self.configs["single_file"]: - assert self.master_file is not None, "Master file not found" - species = _read_spectra_species( - f'Step{self.metadata["spectra"]["outsteps"][0]}', self.master_file - ) - e_bins = _read_spectra_bins( - f'Step{self.metadata["spectra"]["outsteps"][0]}', - log_bins, - self.master_file, - ) - else: - species = _read_spectra_species("Step0", self.spectra_files[0]) - e_bins = _read_spectra_bins("Step0", log_bins, self.spectra_files[0]) - - self.metadata["spectra"]["species"] = species - - for sp in species: - self._spectra[f"n_{sp}"] = _preload_spectra( - self.configs["single_file"], - sp, - e_bins=e_bins, - outsteps=self.metadata["spectra"]["outsteps"], - times=self.metadata["spectra"]["times"], - steps=self.metadata["spectra"]["steps"], - file=( - self.master_file - if self.configs["single_file"] and self.master_file is not None - else self.spectra_files - ), - ) - - def __del__(self): - if not self.configs["single_file"]: - for f in self.spectra_files: - f.close() - - @property - def spectra(self): - return self._spectra - - def print_spectra(self) -> str: - def sizeof_fmt(num, suffix="B"): - for unit in ("", "K", "M", "G", "T", "P", "E", "Z"): - if abs(num) < 1e3: - return f"{num:3.1f} {unit}{suffix}" - num /= 1e3 - return f"{num:.1f} Y{suffix}" - - def compactify(lst): - c = "" - cntr = 0 - for l_ in lst: - if cntr > 5: - c += "\n " - cntr = 0 - c += l_ + ", " - cntr += 1 - return c[:-2] - - string = "" - spec_keys = list(self.spectra.data_vars.keys()) - - if len(spec_keys) > 0: - string += "Spectra:\n" - string += f" - data axes: {compactify(self.spectra.indexes.keys())}\n" - string += f" - timesteps: {self.spectra[spec_keys[0]].shape[0]}\n" - string += f" - # of bins: {self.spectra[spec_keys[0]].shape[1]}\n" - string += f" - quantities: {compactify(self.spectra.data_vars.keys())}\n" - string += f" - total size: {sizeof_fmt(self.spectra.nbytes)}\n" - else: - string += "Spectra: empty\n" - - return string diff --git a/nt2/containers/utils.py b/nt2/containers/utils.py deleted file mode 100644 index 524f166..0000000 --- a/nt2/containers/utils.py +++ /dev/null @@ -1,445 +0,0 @@ -from typing import List, Union, Tuple, Dict -import h5py -import numpy as np -import xarray as xr -from dask.array.core import from_array -from dask.array.core import stack -from dask.array.core import Array as dArray -import inspect - - -def InheritClassDocstring(cls): - if cls.__doc__ is None: - 
cls.__doc__ = "" - for base in inspect.getmro(cls): - if base.__doc__ is not None: - cls.__doc__ += base.__doc__ - return cls - - -def _dataIs2DPolar(ds): - return ("r" in ds.dims and ("θ" in ds.dims or "th" in ds.dims)) and len( - ds.dims - ) == 2 - - -def _read_category_metadata( - single_file: bool, prefix: str, file: Union[h5py.File, List[h5py.File]] -): - outsteps = [] - steps = [] - times = [] - quantities = None - for i, st in enumerate(file): - if single_file: - assert isinstance(file, h5py.File) - group = file[st] - else: - assert isinstance(file[i], h5py.File) - group = st["Step0"] - assert isinstance(group, h5py.Group), f"Unexpected type {type(group)}" - time_ds = group["Time"] - step_ds = group["Step"] - assert isinstance(time_ds, h5py.Dataset), f"Unexpected type {type(time_ds)}" - assert isinstance(step_ds, h5py.Dataset), f"Unexpected type {type(step_ds)}" - if single_file and not any( - [k.startswith(prefix) for k in group if k is not None] - ): - continue - if quantities is None or len(quantities) == 0: - quantities = [k for k in group.keys() if k.startswith(prefix)] - outsteps.append(st if single_file else f"Step{i}") - times.append(time_ds[()]) - steps.append(int(step_ds[()])) - outsteps = sorted(outsteps, key=lambda x: int(x.replace("Step", ""))) - steps = sorted(steps) - times = np.array(sorted(times), dtype=np.float64) - return { - "quantities": quantities, - "outsteps": outsteps, - "steps": steps, - "times": times, - } - - -# fields -def _read_coordinates(coords: List[str], file: h5py.File): - for st in file: - group = file[st] - if isinstance(group, h5py.Group): - if any([k.startswith("X") for k in group if k is not None]): - # cell-centered coords - xc = { - c: ( - np.asarray(xi[:]) - if isinstance(xi := group[f"X{i+1}"], h5py.Dataset) and xi - else None - ) - for i, c in enumerate(coords[::-1]) - } - # cell edges - xe_min = { - f"{c}_1": ( - c, - ( - np.asarray(xi[:-1]) - if isinstance((xi := group[f"X{i+1}e"]), h5py.Dataset) - else None - ), - ) - for i, c in enumerate(coords[::-1]) - } - xe_max = { - f"{c}_2": ( - c, - ( - np.asarray(xi[1:]) - if isinstance((xi := group[f"X{i+1}e"]), h5py.Dataset) - else None - ), - ) - for i, c in enumerate(coords[::-1]) - } - return {"xc": xc, "xe_min": xe_min, "xe_max": xe_max} - else: - raise ValueError(f"Unexpected type {type(file[st])}") - raise ValueError("Could not find coordinates in file") - - -def _preload_dask_arrays( - single_file: bool, - k: str, - outsteps: List[int], - coord_replacements: List[Tuple[str, str]], - field_replacements: List[Tuple[str, str]], - layout: str, - file: Union[h5py.File, List[h5py.File]], -) -> Tuple[str, List[dArray]]: - dask_arrays = [] - if single_file: - for s in outsteps: - assert isinstance(file, h5py.File) - dset = file[f"{s}/{k}"] - if isinstance(dset, h5py.Dataset): - array = from_array(np.transpose(dset) if layout == "right" else dset) - dask_arrays.append(array) - else: - raise ValueError(f"Unexpected type {type(dset)}") - else: - for f in file: - assert isinstance(f, h5py.File) - dset = f[f"Step0/{k}"] - if isinstance(dset, h5py.Dataset): - array = from_array(np.transpose(dset) if layout == "right" else dset) - dask_arrays.append(array) - else: - raise ValueError(f"Unexpected type {type(dset)}") - - k_ = k[1:] - for c in coord_replacements: - if "_" not in k_: - k_ = k_.replace(c[0], c[1]) - else: - k_ = "_".join([k_.split("_")[0].replace(c[0], c[1])] + k_.split("_")[1:]) - for f in field_replacements: - k_ = k_.replace(*f) - - return k_, dask_arrays - - -def 
_preload_domain_shapes( - single_file: bool, - k: str, - outsteps: List[int], - times: List[float], - steps: List[int], - file: Union[h5py.File, List[h5py.File]], -) -> Tuple[xr.DataArray, xr.DataArray]: - dask_corners = [] - dask_sizes = [] - ndomains = None - if single_file: - for s in outsteps: - assert isinstance(file, h5py.File) - dset = file[f"{s}/{k}"] - if isinstance(dset, h5py.Dataset): - dask_corners.append(from_array(dset[::2])) - dask_sizes.append(from_array(dset[1::2])) - ndomains = ndomains or (dset.shape[0] // 2) - else: - raise ValueError(f"Unexpected type {type(dset)}") - else: - for f in file: - assert isinstance(f, h5py.File) - dset = f[f"Step0/{k}"] - if isinstance(dset, h5py.Dataset): - dask_corners.append(from_array(dset[::2])) - dask_sizes.append(from_array(dset[1::2])) - ndomains = ndomains or (dset.shape[0] // 2) - else: - raise ValueError(f"Unexpected type {type(dset)}") - assert ndomains is not None, "Could not find domain shapes" - domains = np.arange(ndomains) - return xr.DataArray( - stack(dask_corners, axis=0), - dims=["t", "dom"], - name=k, - coords={ - "t": times, - "s": ("t", steps), - "dom": domains, - }, - ), xr.DataArray( - stack(dask_sizes, axis=0), - dims=["t", "dom"], - name=k, - coords={ - "t": times, - "s": ("t", steps), - "dom": domains, - }, - ) - - -def _preload_field_with_ghosts( - single_file: bool, - k: str, - outsteps: List[int], - times: List[float], - steps: List[int], - coords: List[str], - coord_replacements: List[Tuple[str, str]], - field_replacements: List[Tuple[str, str]], - layout: str, - file: Union[h5py.File, List[h5py.File]], -) -> Tuple[str, xr.DataArray, Dict, Dict, Dict]: - k_, dask_arrays = _preload_dask_arrays( - single_file=single_file, - k=k, - outsteps=outsteps, - coord_replacements=coord_replacements, - field_replacements=field_replacements, - layout=layout, - file=file, - ) - - sizes = dask_arrays[0].shape[::-1] - assert len(sizes) == len(coords), "Mismatch in number of dimensions" - xc = {} - xe_min = {} - xe_max = {} - for i, ci in enumerate(coords[::-1]): - xc[ci] = np.arange(sizes[i]) + 0.5 - xe_min[ci + "_1"] = (ci, np.arange(sizes[i])) - xe_max[ci + "_2"] = (ci, np.arange(sizes[i]) + 1) - - return ( - k_, - xr.DataArray( - stack(dask_arrays, axis=0), - dims=["t", *coords], - name=k_, - coords={ - "t": times, - "s": ("t", steps), - **xc, - **xe_min, - **xe_max, - }, - ), - xc, - xe_min, - xe_max, - ) - - -def _preload_field( - single_file: bool, - k: str, - outsteps: List[int], - times: List[float], - steps: List[int], - coords: List[str], - xc_coords: Dict[str, str], - xe_min_coords: Dict[str, str], - xe_max_coords: Dict[str, str], - coord_replacements: List[Tuple[str, str]], - field_replacements: List[Tuple[str, str]], - layout: str, - file: Union[h5py.File, List[h5py.File]], -) -> Tuple[str, xr.DataArray]: - k_, dask_arrays = _preload_dask_arrays( - single_file=single_file, - k=k, - outsteps=outsteps, - coord_replacements=coord_replacements, - field_replacements=field_replacements, - layout=layout, - file=file, - ) - - return k_, xr.DataArray( - stack(dask_arrays, axis=0), - dims=["t", *coords], - name=k_, - coords={ - "t": times, - "s": ("t", steps), - **xc_coords, - **xe_min_coords, - **xe_max_coords, - }, - ) - - -# particles -def _list_to_ragged(arr): - max_len = np.max([len(a) for a in arr]) - return map( - lambda a: np.concatenate([a, np.full(max_len - len(a), np.nan)]), - arr, - ) - - -def _read_particle_species(first_step: str, file: h5py.File): - group = file[first_step] - if not isinstance(group, 
h5py.Group): - raise ValueError(f"Unexpected type {type(group)}") - species = np.unique( - [int(pq.split("_")[1]) for pq in group.keys() if pq.startswith("p")] - ) - return species - - -def _preload_particle_species( - single_file: bool, - s: int, - quantities: List[str], - coord_type: str, - outsteps: List[int], - times: List[float], - steps: List[int], - coord_replacements: Dict[str, str], - file: Union[h5py.File, List[h5py.File]], -): - prtl_data = {} - for q in [ - f"X1_{s}", - f"X2_{s}", - f"X3_{s}", - f"U1_{s}", - f"U2_{s}", - f"U3_{s}", - f"W_{s}", - ]: - if q[0] in ["X", "U"]: - q_ = coord_replacements[q.split("_")[0]] - else: - q_ = q.split("_")[0] - if "p" + q not in quantities: - continue - if q not in prtl_data.keys(): - prtl_data[q_] = [] - if single_file: - assert isinstance(file, h5py.File) - for step_k in outsteps: - group = file[step_k] - if isinstance(group, h5py.Group): - if "p" + q in group.keys(): - prtl_data[q_].append(group["p" + q]) - else: - prtl_data[q_].append(np.full_like(prtl_data[q_][-1], np.nan)) - else: - raise ValueError(f"Unexpected type {type(file[step_k])}") - else: - for f in file: - assert isinstance(f, h5py.File) - group = f["Step0"] - if isinstance(group, h5py.Group): - if "p" + q in group.keys(): - prtl_data[q_].append(group["p" + q]) - else: - prtl_data[q_].append(np.full_like(prtl_data[q_][-1], np.nan)) - else: - raise ValueError(f"Unexpected type {type(group)}") - prtl_data[q_] = _list_to_ragged(prtl_data[q_]) - prtl_data[q_] = from_array(list(prtl_data[q_])) - prtl_data[q_] = xr.DataArray( - prtl_data[q_], - dims=["t", "id"], - name=q_, - coords={"t": times, "s": ("t", steps)}, - ) - if coord_type == "sph": - prtl_data["x"] = ( - prtl_data[coord_replacements["X1"]] - * np.sin(prtl_data[coord_replacements["X2"]]) - * np.cos(prtl_data[coord_replacements["X3"]]) - ) - prtl_data["y"] = ( - prtl_data[coord_replacements["X1"]] - * np.sin(prtl_data[coord_replacements["X2"]]) - * np.sin(prtl_data[coord_replacements["X3"]]) - ) - prtl_data["z"] = prtl_data[coord_replacements["X1"]] * np.cos( - prtl_data[coord_replacements["X2"]] - ) - return xr.Dataset(prtl_data) - - -# spectra -def _read_spectra_species(first_step: str, file: h5py.File): - group = file[first_step] - if not isinstance(group, h5py.Group): - raise ValueError(f"Unexpected type {type(group)}") - species = np.unique( - [int(pq.split("_")[1]) for pq in group.keys() if pq.startswith("sN")] - ) - return species - - -def _read_spectra_bins(first_step: str, log_bins: bool, file: h5py.File): - group = file[first_step] - if not isinstance(group, h5py.Group): - raise ValueError(f"Unexpected type {type(group)}") - e_bins = group["sEbn"] - if not isinstance(e_bins, h5py.Dataset): - raise ValueError(f"Unexpected type {type(e_bins)}") - if log_bins: - e_bins = np.sqrt(e_bins[1:] * e_bins[:-1]) - else: - e_bins = (e_bins[1:] + e_bins[:-1]) / 2 - return e_bins - - -def _preload_spectra( - single_file: bool, - sp: int, - e_bins: np.ndarray, - outsteps: List[int], - times: List[float], - steps: List[int], - file: Union[h5py.File, List[h5py.File]], -): - dask_arrays = [] - if single_file: - assert isinstance(file, h5py.File) - for st in outsteps: - array = from_array(file[f"{st}/sN_{sp}"]) - dask_arrays.append(array) - else: - for f in file: - assert isinstance(f, h5py.File) - array = from_array(f[f"Step0/sN_{sp}"]) - dask_arrays.append(array) - - return xr.DataArray( - stack(dask_arrays, axis=0), - dims=["t", "e"], - name=f"n_{sp}", - coords={ - "t": times, - "s": ("t", steps), - "e": e_bins, - }, - ) diff 
--git a/nt2/dashboard.py b/nt2/dashboard.py deleted file mode 100644 index d5390d7..0000000 --- a/nt2/dashboard.py +++ /dev/null @@ -1,18 +0,0 @@ -class Dashboard: - def __init__(self, **kwargs): - from dask.distributed import Client - - self._client = Client(**kwargs) - - def restart(self): - self._client.restart() - - def close(self): - self._client.close() - - @property - def client(self): - return self._client - - def _repr_html_(self): - return self.client._repr_html_() diff --git a/nt2/data.py b/nt2/data.py deleted file mode 100644 index 5e78dc1..0000000 --- a/nt2/data.py +++ /dev/null @@ -1,101 +0,0 @@ -from nt2.containers.fields import FieldsContainer -from nt2.containers.particles import ParticleContainer -from nt2.containers.spectra import SpectraContainer - -from nt2.containers.utils import InheritClassDocstring -from nt2.export import makeFramesAndMovie -from nt2.utils import FutureDeprecationWarning - -import warnings - - -@InheritClassDocstring -class Data(FieldsContainer, ParticleContainer, SpectraContainer): - """ - * * * * Data : FieldsContainer, ParticleContainer, SpectraContainer * * * * - - Master class for holding the whole simulation data. - Inherits attributes & methods from more specialized classes. - - """ - - def __init__(self, **kwargs): - """ - Kwargs - ------ - pickle : bool, optional - Whether to use pickle for reading the data. Default is True. - - greek : bool, optional - Whether to use Greek letters for the spherical coordinates. Default is False. - - dask_props : dict, optional - Additional properties for Dask [NOT IMPLEMENTED]. Default is {}. - - """ - if "single_file" in kwargs: - kwargs.pop("single_file") - DeprecationWarning() - warnings.warn( - """ --= Deprecation Warning =- -The argument `single_file` is no longer needed to be specified, -as the code derives this from the extension passed in the `path`. -The argument is ignored now, but will throw an error in future releases.""", - FutureDeprecationWarning, - ) - super(Data, self).__init__(**kwargs) - if "path" not in kwargs: - raise ValueError('Usage example: data = nt2.Data(path="...", ...)') - - def __repr__(self) -> str: - help = "Usage: \n" - help += ' data = Data(path="...", ...)\n' - help += " data.fields\n" - help += " data.particles\n" - help += " data.spectra\n" - return ( - help - + "\n" - + self.print_fields() - + "\n" - + self.print_particles() - + "\n" - + self.print_spectra() - ) - - def __str__(self) -> str: - return self.__repr__() - - def __del__(self): - super().__del__() - - def makeMovie(self, plot, times=None, **kwargs): - """ - Makes a movie from a plot function - - Parameters - ---------- - plot : function - The plot function to use; accepts output timestep indices or timestamps and, optionally, - the dataset as arguments. - - times : array_like, optional - Either time indices or timestamps to use for generating the movie. Default is None. - If None, will use timestamps (not steps) from the fields, - which might not coincide with values from other quantities. - - **kwargs : - Additional keyword arguments passed to `ffmpeg`. 
- - """ - - if times is None: - times = self.fields.t.values - return makeFramesAndMovie( - name=self.attrs["simulation.name"], - data=self, - plot=plot, - times=times, - **kwargs, - ) diff --git a/nt2/plotters/annotations.py b/nt2/plotters/annotations.py index 7843fd1..4a768b4 100644 --- a/nt2/plotters/annotations.py +++ b/nt2/plotters/annotations.py @@ -1,3 +1,6 @@ +import logging + + def annotatePulsar( ax, data, rmax, rstar=1.1, ti=None, time=None, attrs={}, ax_props={}, star_props={} ): @@ -18,8 +21,8 @@ def annotatePulsar( ) ) ) is None: - print( - "WARNING: No spinup time or spin period found, please specify explicitly as `attrs = {'psr_omega': ..., 'psr_spinup_time': ...}`" + logging.warning( + "No spinup time or spin period found, please specify explicitly as `attrs = {'psr_omega': ..., 'psr_spinup_time': ...}`" ) demo_rotation = False phase = 0 diff --git a/nt2/export.py b/nt2/plotters/export.py similarity index 83% rename from nt2/export.py rename to nt2/plotters/export.py index d5cc202..db0b19f 100644 --- a/nt2/export.py +++ b/nt2/plotters/export.py @@ -1,4 +1,13 @@ -def makeFramesAndMovie(name, plot, times, data=None, **kwargs): +from typing import Any, Callable + + +def makeFramesAndMovie( + name: str, + plot: Callable, # pyright: ignore[reportUnknownParameterType,reportMissingTypeArgument] + times: list[float], + data: Any = None, + **kwargs: Any, +) -> bool: num_cpus = kwargs.pop("num_cpus", None) if all( makeFrames( @@ -10,7 +19,7 @@ def makeFramesAndMovie(name, plot, times, data=None, **kwargs): ) ): print(f"Frames saved in {name}/frames") - output = kwargs.pop("output", f"{name}.mp4") + output: str = kwargs.pop("output", f"{name}.mp4") if makeMovie( input=f"{name}/frames/", overwrite=True, @@ -26,7 +35,7 @@ def makeFramesAndMovie(name, plot, times, data=None, **kwargs): raise ValueError("Failed to make frames") -def makeMovie(**ffmpeg_kwargs): +def makeMovie(**ffmpeg_kwargs: str | int | float) -> bool: """ Create a movie from frames using the `ffmpeg` command-line tool. @@ -52,6 +61,10 @@ def makeMovie(**ffmpeg_kwargs): """ import subprocess + input_pattern: str = ( + f"{ffmpeg_kwargs.get('input', 'step_')}%0{ffmpeg_kwargs.get('number', 3)}d.{ffmpeg_kwargs.get('extension', 'png')}" + ) + command = [ ffmpeg_kwargs.get("ffmpeg", "ffmpeg"), "-nostdin", @@ -60,8 +73,7 @@ def makeMovie(**ffmpeg_kwargs): "-start_number", str(ffmpeg_kwargs.get("start", 0)), "-i", - ffmpeg_kwargs.get("input", "step_") - + f"%0{ffmpeg_kwargs.get('number', 3)}d.{ffmpeg_kwargs.get('extension', 'png')}", + input_pattern, "-c:v", "libx264", "-crf", @@ -83,7 +95,13 @@ def makeMovie(**ffmpeg_kwargs): return False -def makeFrames(plot, times, fpath, data=None, num_cpus=None): +def makeFrames( + plot: Callable, # pyright: ignore[reportUnknownParameterType,reportMissingTypeArgument] + times: list[float], + fpath: str, + data: Any = None, + num_cpus: int | None = None, +) -> list[bool]: """ Create plot frames from a set of timesteps of the same dataset. 
@@ -135,7 +153,7 @@ def makeFrames(plot, times, fpath, data=None, num_cpus=None): global plotAndSave - def plotAndSave(ti, t, fpath): + def plotAndSave(ti: int, t: float, fpath: str) -> bool: try: if data is None: plot(t) diff --git a/nt2/plotters/inspect.py b/nt2/plotters/inspect.py index a986352..ea5d91b 100644 --- a/nt2/plotters/inspect.py +++ b/nt2/plotters/inspect.py @@ -1,30 +1,139 @@ -from nt2.containers.utils import _dataIs2DPolar -from nt2.export import makeFramesAndMovie +# pyright: reportMissingTypeStubs=false +from typing import Any, Callable +import matplotlib.pyplot as plt +import matplotlib.figure as mfigure +import xarray as xr +from nt2.utils import DataIs2DPolar +from nt2.plotters.export import makeFramesAndMovie -class _datasetInspectPlotAccessor: - def __init__(self, xarray_obj): - self._obj = xarray_obj + +class ds_accessor: + def __init__(self, xarray_obj: xr.Dataset): + self._obj: xr.Dataset = xarray_obj + + def __axes_grid( + self, + grouped_fields: dict[str, list[str]], + makeplot: Callable, # pyright: ignore[reportUnknownParameterType,reportMissingTypeArgument] + nrows: int, + ncols: int, + nfields: int, + size: float, + aspect: float, + pad: float, + **fig_kwargs: Any, + ) -> tuple[mfigure.Figure, list[plt.Axes]]: + if aspect > 1: + axw = size / aspect + axh = size + else: + axw = size + axh = size * aspect + + fig_w = ncols * (axw + pad) + pad + fig_h = nrows * axh + (nrows + 1) * pad + fig = plt.figure(figsize=(fig_w, fig_h), **fig_kwargs) + + gs = fig.add_gridspec(nrows, ncols, wspace=pad / axw, hspace=pad / axh) + axes = [ + fig.add_subplot(gs[i, j]) + for i in range(nrows) + for j in range(ncols) + if (i * ncols + j) < nfields + ] + for ax, (g, fields) in zip(axes, grouped_fields.items()): + for field in fields: + makeplot(ax, field) + _ = ax.set_ylabel(g) + _ = ax.set_title(None) + + return fig, axes + + @staticmethod + def _fixed_axes_grid_with_cbars( + fields: list[str], + makeplot: Callable, # pyright: ignore[reportUnknownParameterType,reportMissingTypeArgument] + makecbar: Callable, # pyright: ignore[reportUnknownParameterType,reportMissingTypeArgument] + nrows: int, + ncols: int, + nfields: int, + size: float, + aspect: float, + pad: float, + cbar_w: float, + **fig_kwargs: Any, + ) -> tuple[mfigure.Figure, list[plt.Axes]]: + from mpl_toolkits.axes_grid1 import Divider, Size + + if aspect > 1: + axw = size / aspect + axh = size + else: + axw = size + axh = size * aspect + + fig_w = ncols * (axw + cbar_w + pad) + pad + fig_h = nrows * axh + (nrows + 1) * pad + fig = plt.figure(figsize=(fig_w, fig_h), **fig_kwargs) + + h = [] + for _ in range(ncols): + h += [Size.Fixed(pad), Size.Fixed(axw), Size.Fixed(cbar_w)] + h += [Size.Fixed(pad)] + + v = [] + for _ in range(nrows): + v += [Size.Fixed(pad), Size.Fixed(axh)] + v += [Size.Fixed(pad)] + + divider = Divider(fig, (0, 0, 1, 1), h, v, aspect=False) + axes: list[plt.Axes] = [] + + cntr = 0 + for i in range(nrows): + for j in range(ncols): + cntr += 1 + if cntr > nfields: + break + nx = 3 * j + 1 + ny = 2 * (nrows - 1 - i) + 1 + + ax = fig.add_axes( + divider.get_position(), + axes_locator=divider.new_locator(nx=nx, ny=ny), + ) + field = fields[cntr - 1] + im = makeplot(ax, field) + cax = fig.add_axes( + divider.get_position(), + axes_locator=divider.new_locator(nx=nx + 1, ny=ny), + ) + _ = fig.colorbar(im, cax=cax) + makecbar(ax, cax, field) + axes.append(ax) + return fig, axes def plot( self, - fig=None, - name=None, - skip_fields=[], - only_fields=[], - fig_kwargs={}, - plot_kwargs={}, - movie_kwargs={}, - 
): + fig: mfigure.Figure | None = None, + name: str | None = None, + skip_fields: list[str] | None = None, + only_fields: list[str] | None = None, + fig_kwargs: dict[str, Any] | None = None, + plot_kwargs: dict[str, Any] | None = None, + movie_kwargs: dict[str, Any] | None = None, + set_aspect: str | None = "equal", + ) -> mfigure.Figure | bool: """ Plots the overview plot for fields at a given time or step (or as a movie). Kwargs ------ - fig : matplotlib.figure.Figure, optional + fig : matplotlib.figure.Figure | None, optional The figure to plot the data (if None, a new figure is created). Default is None. - name : string, optional + name : string | None, optional Used when saving the frames and the movie. Default is None. skip_fields : list, optional @@ -45,27 +154,55 @@ def plot( movie_kwargs : dict, optional Additional keyword arguments for makeMovie. Default is {}. + set_aspect : str | None, optional + If None, the aspect ratio will not be enforced. Otherwise, this value is passed to `set_aspect` method of the axes. Default is 'equal'. + Returns ------- figure : matplotlib.figure.Figure | boolean The figure with the plotted data (if single timestep) or True/False. """ + if skip_fields is None: + skip_fields = [] + if only_fields is None: + only_fields = [] + if fig_kwargs is None: + fig_kwargs = {} + if plot_kwargs is None: + plot_kwargs = {} + if movie_kwargs is None: + movie_kwargs = {} if "t" in self._obj.dims: if name is None: raise ValueError( "Please provide a name for saving the frames and movie" ) - def plot_func(ti, _): - self.plot_frame( - self._obj.isel(t=ti), - None, - skip_fields, - only_fields, - fig_kwargs, - plot_kwargs, - ) + def plot_func(ti: int, _): + if len(self._obj.dims) == 1: + _ = self.plot_frame_1d( + self._obj.isel(t=ti), + None, + skip_fields, + only_fields, + fig_kwargs, + plot_kwargs, + ) + elif len(self._obj.dims) == 2: + _ = self.plot_frame_2d( + self._obj.isel(t=ti), + None, + skip_fields, + only_fields, + fig_kwargs, + plot_kwargs, + set_aspect, + ) + else: + raise ValueError( + "Data has more than 2 dimensions; use .sel or .isel to reduce dimension." + ) return makeFramesAndMovie( name=name, @@ -75,37 +212,46 @@ def plot_func(ti, _): **movie_kwargs, ) else: - return self.plot_frame( - self._obj, fig, skip_fields, only_fields, fig_kwargs, plot_kwargs - ) - - def plot_frame(self, data, fig, skip_fields, only_fields, fig_kwargs, plot_kwargs): - if len(data.dims) != 2: - raise ValueError("Pass 2D data; use .sel or .isel to reduce dimension.") - - x1, x2 = data.dims + if len(self._obj.dims) == 1: + return self.plot_frame_1d( + self._obj, + fig, + skip_fields, + only_fields, + fig_kwargs, + plot_kwargs, + ) + elif len(self._obj.dims) == 2: + return self.plot_frame_2d( + self._obj, + fig, + skip_fields, + only_fields, + fig_kwargs, + plot_kwargs, + set_aspect, + ) + else: + raise ValueError( + "Data has more than 2 dimensions; use .sel or .isel to reduce dimension." 
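# In usage terms, the dispatch above means: with a "t" dimension present the accessor
# renders a movie (so `name` is required); without it, 1D data goes to `plot_frame_1d`
# and 2D data to `plot_frame_2d`. A sketch, assuming the accessor is registered under
# `.inspect` as exercised in the tests:
import nt2

data = nt2.Data("path/to/shock")
# single 2D frame at the 6th output step
fig = data.fields.isel(t=5).inspect.plot(fig_kwargs={"dpi": 200})
# movie over all steps, after slicing down to 2D
ok = data.fields.sel(z=0.5, method="nearest").inspect.plot(name="overview")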
+ ) - import matplotlib.pyplot as plt - from matplotlib.gridspec import GridSpec, GridSpecFromSubplotSpec - import matplotlib.colors as mcolors - import numpy as np + @staticmethod + def _get_fields_to_plot( + data: xr.Dataset, skip_fields: list[str], only_fields: list[str] + ) -> list[str]: import re - import math - # count the number of subplots nfields = len(data.data_vars) if nfields > 0: + keys: list[str] = [str(k) for k in data.keys()] if len(only_fields) == 0: fields_to_plot = [ - f - for f in list(data.keys()) - if not any([re.match(sf, f) for sf in skip_fields]) + f for f in keys if not any([re.match(sf, f) for sf in skip_fields]) ] else: fields_to_plot = [ - f - for f in list(data.keys()) - if any([re.match(sf, f) for sf in only_fields]) + f for f in keys if any([re.match(sf, f) for sf in only_fields]) ] else: fields_to_plot = [] @@ -113,77 +259,21 @@ def plot_frame(self, data, fig, skip_fields, only_fields, fig_kwargs, plot_kwarg if fields_to_plot == []: raise ValueError("No fields to plot.") - nfields = len(fields_to_plot) + fields_to_plot = sorted(fields_to_plot) + return fields_to_plot - aspect = 1 - if _dataIs2DPolar(data): - aspect = 0.5 - else: - aspect = len(data[x1]) / len(data[x2]) - - ncols = 3 if aspect <= 1.15 else int(math.ceil(nfields / 3)) - nrows = 3 if aspect > 1.15 else int(math.ceil(nfields / 3)) - - figsize0 = 3 - - if fig is None: - dpi = fig_kwargs.pop("dpi", 200) - fig = plt.figure( - figsize=( - figsize0 * ncols * aspect * (1 + 0.2 / aspect), - figsize0 * nrows, - ), - dpi=dpi, - **fig_kwargs, - ) - - gs = GridSpec(nrows, ncols, wspace=0.2 / aspect) - gs_for_axes = [ - GridSpecFromSubplotSpec( - 1, - 2, - subplot_spec=gs[i], - width_ratios=[1, max(0.025 / aspect, 0.025)], - wspace=0.01, - ) - for i in range(nfields) - ] - if aspect <= 1.15: - axes = [ - fig.add_subplot(gs_for_axes[i * ncols + j][0]) - for i in range(nrows) - for j in range(ncols) - if i * ncols + j < nfields - ] - cbars = [ - fig.add_subplot(gs_for_axes[i * ncols + j][1]) - for i in range(nrows) - for j in range(ncols) - if i * ncols + j < nfields - ] - else: - axes = [ - fig.add_subplot(gs_for_axes[i * ncols + j][0]) - for j in range(ncols) - for i in range(nrows) - if i * ncols + j < nfields - ] - cbars = [ - fig.add_subplot(gs_for_axes[i * ncols + j][1]) - for j in range(ncols) - for i in range(nrows) - if i * ncols + j < nfields - ] - - # find minmax for all components - minmax: dict[str, None | tuple] = { + @staticmethod + def _get_fields_minmax( + data: xr.Dataset, fields: list[str] + ) -> dict[str, None | tuple[float, float]]: + minmax: dict[str, None | tuple[float, float]] = { "E": None, "B": None, "J": None, "N": None, "T": None, } - for fld in fields_to_plot: + for fld in fields: vmin, vmax = ( data[fld].min().values[()], data[fld].max().values[()], @@ -207,6 +297,122 @@ def plot_frame(self, data, fig, skip_fields, only_fields, fig_kwargs, plot_kwarg vmin = -vmax minmax[f] = (vmin, vmax) + return minmax + + def plot_frame_1d( + self, + data: xr.Dataset, + fig: mfigure.Figure | None, + skip_fields: list[str], + only_fields: list[str], + fig_kwargs: dict[str, Any], + plot_kwargs: dict[str, Any], + ) -> mfigure.Figure: + if len(data.dims) != 1: + raise ValueError("Pass 1D data; use .sel or .isel to reduce dimension.") + + import math, re + + # count the number of subplots + fields_to_plot = self._get_fields_to_plot(data, skip_fields, only_fields) + + # group fields by their first letter + grouped_fields: dict[str, list[str]] = {} + for f in fields_to_plot: + key = f[0] + if key not 
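# `skip_fields`/`only_fields` above are regex patterns applied with `re.match` to the
# sorted variable names, e.g. with illustrative field names:
#   only_fields=["E.*"]       -> keeps Ex, Ey, Ez
#   skip_fields=["B.", "N_1"] -> drops Bx, By, Bz and N_1
# `re.match` anchors only at the start, so a bare "N" would also drop "N_2".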
in grouped_fields: + grouped_fields[key] = [] + grouped_fields[key].append(f) + + nplots = len(grouped_fields) + + aspect = 0.5 + ncols = max(1, int(math.floor(nplots * 1.5 * aspect / (1 + 1.5 * aspect)))) + nrows = max(1, int(math.ceil(nplots / ncols))) + + figsize0 = 3.0 + + minmax = self._get_fields_minmax(data, fields_to_plot) + kwargs = {} + for fld in fields_to_plot: + kwargs[fld] = {} + for fld_kwargs in plot_kwargs: + if re.match(fld_kwargs, fld): + kwargs[fld] = {**plot_kwargs[fld_kwargs]} + break + + def make_plot(ax: plt.Axes, fld: str): + data[fld].plot(ax=ax, label=fld, **kwargs[fld]) + _ = ax.set(ylim=minmax[fld[0]]) + + fig, axes = self.__axes_grid( + grouped_fields=grouped_fields, + makeplot=make_plot, + nrows=nrows, + ncols=ncols, + nfields=nplots, + size=figsize0, + aspect=aspect, + pad=0.5, + **fig_kwargs, + ) + for n, ax in enumerate(axes): + i = n // ncols + j = n % ncols + + if j != 0: + _ = ax.set( + ylabel=None, + yticklabels=[], + ) + if (nplots - i * ncols - j) > ncols: + _ = ax.set( + xlabel=None, + xticklabels=[], + ) + _ = ax.legend(loc="best", fontsize="small") + _ = fig.suptitle(f"t = {data.t.values[()]:.2f}", y=0.95) + return fig + + def plot_frame_2d( + self, + data: xr.Dataset, + fig: mfigure.Figure | None, + skip_fields: list[str], + only_fields: list[str], + fig_kwargs: dict[str, Any], + plot_kwargs: dict[str, Any], + set_aspect: str | None, + ) -> mfigure.Figure: + if len(data.dims) != 2: + raise ValueError("Pass 2D data; use .sel or .isel to reduce dimension.") + + x1, x2 = data.dims + + import matplotlib.colors as mcolors + import numpy as np + import math, re + + # count the number of subplots + fields_to_plot = self._get_fields_to_plot(data, skip_fields, only_fields) + nfields = len(fields_to_plot) + + aspect = 1 + if not DataIs2DPolar(data): + aspect = (data[x1].values.max() - data[x1].values.min()) / ( + data[x2].values.max() - data[x2].values.min() + ) + aspect = aspect[()] + else: + aspect = 1.5 + + ncols = max(1, int(math.floor(nfields * 1.5 * aspect / (1 + 1.5 * aspect)))) + nrows = max(1, int(math.ceil(nfields / ncols))) + + figsize0 = 3.0 + + minmax = self._get_fields_minmax(data, fields_to_plot) + kwargs = {} for fld in fields_to_plot: cmap = "viridis" @@ -246,17 +452,30 @@ def plot_frame(self, data, fig, skip_fields, only_fields, fig_kwargs, plot_kwarg kwargs[fld] = {**default_kwargs, **plot_kwargs[fld_kwargs]} break if "norm" in kwargs[fld]: - kwargs[fld].pop("vmin") - kwargs[fld].pop("vmax") + vmin = kwargs[fld].pop("vmin") + vmax = kwargs[fld].pop("vmax") + norm_str: str = kwargs[fld].pop("norm") + if norm_str == "linear": + kwargs[fld]["vmin"] = vmin + kwargs[fld]["vmax"] = vmax + elif norm_str == "log": + if vmin <= 0: + vmin = 1e-3 * vmax + kwargs[fld]["norm"] = mcolors.LogNorm(vmin=vmin, vmax=vmax) + elif norm_str == "symlog": + linthresh = kwargs[fld].pop("linthresh", 1e-3 * vmax) + kwargs[fld]["norm"] = mcolors.SymLogNorm( + linthresh=linthresh, vmin=vmin, vmax=vmax, linscale=1 + ) - if _dataIs2DPolar(data): - raise NotImplementedError("Polar plots for inspect not implemented yet.") - else: - for fld, ax in zip(fields_to_plot, axes): + def make_plot(ax: plt.Axes, fld: str): + if DataIs2DPolar(data): + data[fld].polar.pcolor(ax=ax, cbar_position=None, **kwargs[fld]) + else: data[fld].plot(ax=ax, add_colorbar=False, **kwargs[fld]) - for i, (ax, cbar, fld) in enumerate(zip(axes, cbars, fields_to_plot)): - cbar.set(xticks=[], xlabel=None, ylabel=None) + def make_cbar(ax: plt.Axes, cbar: plt.Axes, fld: str): + _ = cbar.set(xticks=[], 
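# For 1D data the fields are grouped by their leading letter so that components share
# an axis, e.g. ["Bx", "By", "Ex", "Ey"] -> {"B": ["Bx", "By"], "E": ["Ex", "Ey"]},
# with each group's ylim taken from the shared minmax computed above.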
xlabel=None, ylabel=None) cbar.yaxis.tick_right() vmin, vmax = ax.collections[0].get_clim() if vmin == vmax: @@ -270,10 +489,10 @@ def plot_frame(self, data, fig, skip_fields, only_fields, fig_kwargs, plot_kwarg vmin /= coeff vmax /= coeff if isinstance(ax.collections[0].norm, mcolors.LogNorm): - cbar.set(ylim=(vmin, vmax), yscale="log") data_norm = mcolors.LogNorm(vmin=vmin, vmax=vmax) + _ = cbar.set(ylim=(vmin, vmax), yscale="log") ys = np.logspace(np.log10(vmin), np.log10(vmax)) - cbar.pcolor( + _ = cbar.pcolor( [0, 1], ys, np.transpose([ys] * 2), @@ -282,11 +501,46 @@ def plot_frame(self, data, fig, skip_fields, only_fields, fig_kwargs, plot_kwarg norm=data_norm, ) elif isinstance(ax.collections[0].norm, mcolors.SymLogNorm): - raise NotImplementedError("SymLogNorm not implemented yet.") + data_norm = ax.collections[0].norm + _ = cbar.set_ylim(vmin, vmax) + _ = cbar.set_yscale( + "symlog", + linthresh=data_norm.linthresh, + linscale=1, + ) + ys = np.concatenate( + ( + -np.logspace( + np.log10(-vmin), + np.log10(data_norm.linthresh), + num=100, + endpoint=False, + ), + np.linspace( + -data_norm.linthresh, + data_norm.linthresh, + num=10, + endpoint=False, + ), + np.logspace( + np.log10(data_norm.linthresh), + np.log10(vmax), + num=100, + ), + ) + ) + _ = cbar.pcolor( + [0, 1], + ys, + np.transpose([ys] * 2), + cmap=kwargs[fld]["cmap"], + rasterized=True, + norm=data_norm, + ) else: - cbar.set(ylim=(vmin, vmax)) + _ = cbar.set(ylim=(vmin, vmax)) ys = np.linspace(vmin, vmax) - cbar.pcolor( + _ = cbar.pcolor( [0, 1], ys, np.transpose([ys] * 2), @@ -294,30 +548,41 @@ def plot_frame(self, data, fig, skip_fields, only_fields, fig_kwargs, plot_kwarg rasterized=True, norm=mcolors.Normalize(vmin=vmin, vmax=vmax), ) - ax.set( + _ = ax.set( title=f"{fld}" + ("" if coeff_pow == 0 else f" [$\\cdot 10^{-coeff_pow}$]") ) + fig, axes = self._fixed_axes_grid_with_cbars( + fields=fields_to_plot, + makeplot=make_plot, + makecbar=make_cbar, + nrows=nrows, + ncols=ncols, + nfields=nfields, + size=figsize0, + aspect=aspect, + pad=0.5, + cbar_w=0.1, + **fig_kwargs, + ) + for n, ax in enumerate(axes): - if aspect > 1.15: - i = n % nrows - j = n // nrows - else: - i = n // ncols - j = n % ncols + i = n // ncols + j = n % ncols if j != 0: - ax.set( + _ = ax.set( ylabel=None, yticklabels=[], ) if (nfields - i * ncols - j) > ncols: - ax.set( + _ = ax.set( xlabel=None, xticklabels=[], ) - ax.set(aspect=1) + if set_aspect is not None: + _ = ax.set(aspect=set_aspect) - fig.suptitle(f"t = {data.t.values[()]:.2f}", y=0.95) + _ = fig.suptitle(f"t = {data.t.values[()]:.2f}", y=1.0) return fig diff --git a/nt2/plotters/movie.py b/nt2/plotters/movie.py index 1704e84..3eec8ad 100644 --- a/nt2/plotters/movie.py +++ b/nt2/plotters/movie.py @@ -1,28 +1,40 @@ -from nt2.export import makeFramesAndMovie +from typing import Any +from nt2.plotters.export import ( + makeFramesAndMovie, +) +import xarray as xr -class _moviePlotAccessor: - def __init__(self, xarray_obj): - self._obj = xarray_obj +class accessor: + def __init__(self, xarray_obj: xr.DataArray) -> None: + self._obj: xr.DataArray = xarray_obj - def plot(self, name, movie_kwargs={}, *args, **kwargs): + def plot( + self, + name: str, + movie_kwargs: dict[str, Any] | None = None, + *args: Any, + **kwargs: Any, + ) -> bool: + if movie_kwargs is None: + movie_kwargs = {} if "t" not in self._obj.dims: raise ValueError("The dataset does not have a time dimension.") import matplotlib.pyplot as plt - def plot_func(ti, _): + def plot_func(ti: int, _: Any) -> None: if 
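# With the norm handling above, `plot_kwargs` may carry a plain string under "norm"
# that is resolved per field pattern into a matplotlib norm: "linear" -> vmin/vmax,
# "log" -> mcolors.LogNorm, "symlog" -> mcolors.SymLogNorm (optional "linthresh").
# Illustrative usage, again assuming the `.inspect` accessor:
fig = data.fields.isel(t=0).inspect.plot(
    plot_kwargs={
        "N.*": {"norm": "log"},
        "E.*": {"norm": "symlog", "linthresh": 1e-2},
    },
)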
len(self._obj.isel(t=ti).dims) == 2: x1, x2 = self._obj.isel(t=ti).dims nx1, nx2 = len(self._obj.isel(t=ti)[x1]), len(self._obj.isel(t=ti)[x2]) aspect = nx1 / nx2 - plt.figure(figsize=(6, 4 * aspect)) + _ = plt.figure(figsize=(6, 4 * aspect)) self._obj.isel(t=ti).plot(*args, **kwargs) if len(self._obj.isel(t=ti).dims) == 2: plt.gca().set_aspect("equal") plt.tight_layout() - num_cpus = movie_kwargs.pop("num_cpus", None) + num_cpus: int | None = movie_kwargs.pop("num_cpus", None) return makeFramesAndMovie( name=name, data=self._obj, diff --git a/nt2/plotters/polar.py b/nt2/plotters/polar.py index c5a9b5d..f25ca95 100644 --- a/nt2/plotters/polar.py +++ b/nt2/plotters/polar.py @@ -1,7 +1,7 @@ import numpy as np from typing import Any -from nt2.containers.utils import _dataIs2DPolar +from nt2.utils import DataIs2DPolar def DipoleSampling(**kwargs): @@ -53,13 +53,13 @@ def MonopoleSampling(**kwargs): return np.linspace(0, np.pi, nth + 2)[1:-1] -class _datasetPolarPlotAccessor: +class ds_accessor: def __init__(self, xarray_obj): self._obj = xarray_obj def pcolor(self, value, **kwargs): assert "t" not in self._obj[value].dims, "Time must be specified" - assert _dataIs2DPolar(self._obj), "Data must be 2D polar" + assert DataIs2DPolar(self._obj), "Data must be 2D polar" self._obj[value].polar.pcolor(**kwargs) def fieldplot( @@ -166,13 +166,11 @@ def fieldlines(self, fr, fth, start_points, **kwargs): assert "t" not in self._obj[fr].dims, "Time must be specified" assert "t" not in self._obj[fth].dims, "Time must be specified" - assert _dataIs2DPolar(self._obj), "Data must be 2D polar" - - useGreek = "θ" in self._obj.coords.keys() + assert DataIs2DPolar(self._obj), "Data must be 2D polar" r, th = ( self._obj.coords["r"].values, - self._obj.coords["θ" if useGreek else "th"].values, + self._obj.coords["th"].values, ) _, ths = np.meshgrid(r, th) fxs = self._obj[fr] * np.sin(ths) + self._obj[fth] * np.cos(ths) @@ -212,23 +210,23 @@ def stop(xy, rth): def integrate(delta, counter): r0, th0 = copy(r_th_start) - XY = np.array([r0 * np.sin(th0), r0 * np.cos(th0)]) - RTH = [r0, th0] - fieldline = np.array([XY]) + xy = np.array([r0 * np.sin(th0), r0 * np.cos(th0)]) + rth = [r0, th0] + fieldline = np.array([xy]) with np.errstate(divide="ignore", invalid="ignore"): while range(counter, maxsteps): - x, y = XY + x, y = xy r = np.sqrt(x**2 + y**2) th = np.arctan2(-y, x) + np.pi / 2 - RTH = [r, th] + rth = [r, th] vx = interp_fx((th, r))[()] vy = interp_fy((th, r))[()] vmag = np.sqrt(vx**2 + vy**2) - XY = XY + delta * np.array([vx, vy]) / vmag - if stop(XY, RTH) or np.isnan(XY).any() or np.isinf(XY).any(): + xy = xy + delta * np.array([vx, vy]) / vmag + if stop(xy, rth) or np.isnan(xy).any() or np.isinf(xy).any(): break else: - fieldline = np.append(fieldline, [XY], axis=0) + fieldline = np.append(fieldline, [xy], axis=0) return fieldline if direction == "forward": @@ -242,11 +240,11 @@ def integrate(delta, counter): return np.append(f2[::-1], f1, axis=0) -class _polarPlotAccessor: - def __init__(self, xarray_obj): +class accessor: + def __init__(self, xarray_obj) -> None: self._obj = xarray_obj - def pcolor(self, **kwargs): + def pcolor(self, **kwargs) -> Any: """ Plots a pseudocolor plot of 2D polar data on a rectilinear projection. 
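# A caution on `integrate()` in `fieldlines` above (context lines, untouched by this
# diff): `while range(counter, maxsteps):` is always truthy while counter < maxsteps
# and `counter` never advances, so termination relies entirely on the break
# conditions; a bounded equivalent would be `for _ in range(maxsteps): ...`.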
@@ -300,8 +298,6 @@ def pcolor(self, **kwargs): import matplotlib as mpl from mpl_toolkits.axes_grid1 import make_axes_locatable - useGreek = "θ" in self._obj.coords.keys() - ax = kwargs.pop("ax", plt.gca()) cbar_size = kwargs.pop("cbar_size", "5%") cbar_pad = kwargs.pop("cbar_pad", 0.05) @@ -321,7 +317,7 @@ def pcolor(self, **kwargs): assert ax.name != "polar", "`ax` must be a rectilinear projection" assert "t" not in self._obj.dims, "Time must be specified" - assert _dataIs2DPolar(self._obj), "Data must be 2D polar" + assert DataIs2DPolar(self._obj), "Data must be 2D polar" ax.grid(False) if type(kwargs.get("norm", None)) is colors.LogNorm: cm = kwargs.get("cmap", "viridis") @@ -332,19 +328,12 @@ def pcolor(self, **kwargs): vals = self._obj.values.flatten() vals = np.concatenate((vals, vals)) if not cell_centered: - drs = self._obj.coords["r_2"] - self._obj.coords["r_1"] - dths = ( - self._obj.coords["θ_2" if useGreek else "th_2"] - - self._obj.coords["θ_1" if useGreek else "th_1"] - ) - r1s = self._obj.coords["r_1"] - drs * cell_size / 2 - r2s = self._obj.coords["r_1"] + drs * cell_size / 2 - th1s = ( - self._obj.coords["θ_1" if useGreek else "th_1"] - dths * cell_size / 2 - ) - th2s = ( - self._obj.coords["θ_1" if useGreek else "th_1"] + dths * cell_size / 2 - ) + drs = self._obj.coords["r_max"] - self._obj.coords["r_min"] + dths = self._obj.coords["th_max"] - self._obj.coords["th_min"] + r1s = self._obj.coords["r_min"] - drs * cell_size / 2 + r2s = self._obj.coords["r_min"] + drs * cell_size / 2 + th1s = self._obj.coords["th_min"] - dths * cell_size / 2 + th2s = self._obj.coords["th_min"] + dths * cell_size / 2 rs = np.ravel(np.column_stack((r1s, r2s))) ths = np.ravel(np.column_stack((th1s, th2s))) nr = len(rs) @@ -358,10 +347,10 @@ def pcolor(self, **kwargs): points_4 = np.arange(nth * nr).reshape(nth, -1)[1::2, :-1:2].flatten() else: - rs = np.append(self._obj.coords["r_1"], self._obj.coords["r_2"][-1]) + rs = np.append(self._obj.coords["r_min"], self._obj.coords["r_max"][-1]) ths = np.append( - self._obj.coords["θ_1" if useGreek else "th_1"], - self._obj.coords["θ_2" if useGreek else "th_2"][-1], + self._obj.coords["th_min"], + self._obj.coords["th_max"][-1], ) nr = len(rs) nth = len(ths) @@ -447,8 +436,6 @@ def contour(self, **kwargs): import warnings import matplotlib.pyplot as plt - useGreek = "θ" in self._obj.coords.keys() - ax = kwargs.pop("ax", plt.gca()) title = kwargs.pop("title", None) invert_x = kwargs.pop("invert_x", False) @@ -456,11 +443,9 @@ def contour(self, **kwargs): assert ax.name != "polar", "`ax` must be a rectilinear projection" assert "t" not in self._obj.dims, "Time must be specified" - assert _dataIs2DPolar(self._obj), "Data must be 2D polar" + assert DataIs2DPolar(self._obj), "Data must be 2D polar" ax.grid(False) - r, th = np.meshgrid( - self._obj.coords["r"], self._obj.coords["θ" if useGreek else "th"] - ) + r, th = np.meshgrid(self._obj.coords["r"], self._obj.coords["th"]) x, y = r * np.sin(th), r * np.cos(th) if invert_x: x = -x diff --git a/nt2/readers/__init__.py b/nt2/readers/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/nt2/readers/adios2.py b/nt2/readers/adios2.py new file mode 100644 index 0000000..14f9aaf --- /dev/null +++ b/nt2/readers/adios2.py @@ -0,0 +1,167 @@ +# pyright: reportMissingTypeStubs=false + +from typing import Any + +import sys +if sys.version_info >= (3, 12): + from typing import override +else: + def override(method): + return method + +import re +import os +import numpy as np +import numpy.typing as npt + 
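# The reader below relies on the modern high-level ADIOS2 Python bindings:
# `bp.FileReader` is a context manager exposing `available_variables()`,
# `inquire_variable()`, `read()`, `read_attribute()` and `available_attributes()`,
# exactly as used in the methods that follow. A minimal standalone read, with an
# illustrative path:
#
#     with bp.FileReader("myrun/fields/fields.00000000.bp") as f:
#         var = f.inquire_variable("fE1")
#         if var is not None:
#             arr = f.read(var)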
+import adios2 as bp + +from nt2.utils import Format, Layout +from nt2.readers.base import BaseReader + + +class Reader(BaseReader): + @property + @override + def format(self) -> Format: + return Format.BP5 + + @staticmethod + @override + def EnterFile( + filename: str, + ) -> bp.FileReader: + return bp.FileReader(filename) + + @override + def ReadPerTimestepVariable( + self, + path: str, + category: str, + varname: str, + newname: str, + ) -> dict[str, npt.NDArray[Any]]: + variables: list[float] = [] + for filename in self.GetValidFiles( + path=path, + category=category, + ): + with bp.FileReader(os.path.join(path, category, filename)) as f: + avail: dict[str, Any] = f.available_variables() + vars: list[str] = list(avail.keys()) + if varname in vars: + var = f.inquire_variable(varname) + if var is not None: + variables.append(f.read(var)) + else: + raise ValueError( + f"{varname} is not a variable in the BP file {filename}" + ) + else: + raise ValueError(f"{varname} not found in the BP file {filename}") + return {newname: np.array(variables)} + + @override + def ReadEdgeCoordsAtTimestep( + self, + path: str, + step: int, + ) -> dict[str, Any]: + dct: dict[str, npt.NDArray[Any]] = {} + with bp.FileReader(self.FullPath(path, "fields", step)) as f: + avail: dict[str, Any] = f.available_variables() + vars: list[str] = list(avail.keys()) + for var in vars: + if var.startswith("X") and var.endswith("e"): + var_obj = f.inquire_variable(var) + if var_obj is not None: + dct[var] = f.read(var_obj) + return dct + + @override + def ReadAttrsAtTimestep( + self, + path: str, + category: str, + step: int, + ) -> dict[str, Any]: + with bp.FileReader(self.FullPath(path, category, step)) as f: + return {k: f.read_attribute(k) for k in f.available_attributes()} + + @override + def ReadArrayAtTimestep( + self, + path: str, + category: str, + quantity: str, + step: int, + ) -> npt.NDArray[Any]: + with bp.FileReader(filename := self.FullPath(path, category, step)) as f: + if quantity in f.available_variables(): + var = f.inquire_variable(quantity) + if var is not None: + return np.array(f.read(var)) + else: + raise ValueError(f"{quantity} not found in the {filename}") + else: + raise ValueError(f"{quantity} not found in the {filename}") + + @override + def ReadCategoryNamesAtTimestep( + self, + path: str, + category: str, + prefix: str, + step: int, + ) -> set[str]: + with bp.FileReader(self.FullPath(path, category, step)) as f: + keys: list[str] = f.available_variables() + return set( + filter( + lambda c: c.startswith(prefix), + keys, + ) + ) + + @override + def ReadArrayShapeAtTimestep( + self, path: str, category: str, quantity: str, step: int + ) -> tuple[int]: + with bp.FileReader(filename := self.FullPath(path, category, step)) as f: + if quantity in f.available_variables(): + var = f.inquire_variable(quantity) + if var is not None: + return var.shape() + else: + raise ValueError( + f"{category.capitalize()} {quantity} is not a group in the {filename}" + ) + else: + raise ValueError( + f"{category.capitalize()} {quantity} not found in the {filename}" + ) + + @override + def ReadFieldCoordsAtTimestep( + self, path: str, step: int + ) -> dict[str, npt.NDArray[Any]]: + with bp.FileReader(filename := self.FullPath(path, "fields", step)) as f: + + def get_coord(c: str) -> npt.NDArray[Any]: + f_c = f.inquire_variable(c) + if f_c is not None: + return np.array(f.read(f_c)) + else: + raise ValueError(f"Field {c} is not a group in the {filename}") + + keys: list[str] = list(f.available_variables()) + return {c: 
get_coord(c) for c in keys if re.match(r"^X[1|2|3]$", c)} + + @override + def ReadFieldLayoutAtTimestep(self, path: str, step: int) -> Layout: + with bp.FileReader(filename := self.FullPath(path, "fields", step)) as f: + attrs: dict[str, Any] = f.available_attributes() + keys = list(attrs.keys()) + if "LayoutRight" not in keys: + raise ValueError(f"LayoutRight attribute not found in the {filename}") + return Layout.R if f.read_attribute("LayoutRight") else Layout.L diff --git a/nt2/readers/base.py b/nt2/readers/base.py new file mode 100644 index 0000000..d547c81 --- /dev/null +++ b/nt2/readers/base.py @@ -0,0 +1,599 @@ +# pyright: reportUnusedParameter=false + +from typing import Any +import numpy.typing as npt +import os, re, logging + +from nt2.utils import Format, Layout + + +class BaseReader: + """Base virtual class for arbitrary format readers. + + Implements common methods for reading files in different formats and declares virtual methods to be implemented in subclasses. + + """ + + skipped_files: list[str] + + def __init__(self) -> None: + """Initializer for the BaseReader class.""" + self.skipped_files = [] + + # # # # # # # # # # # # # # # # # # # # # # # # + # Virtual methods (to be implemented in subclasses) + # # # # # # # # # # # # # # # # # # # # # # # # + + @property + def format(self) -> Format: + """Format: the format of the reader.""" + raise NotImplementedError("format is not implemented") + + @staticmethod + def EnterFile( + filename: str, + ) -> Any: + """Open a file and return the file object. + + Parameters + ---------- + filename: str + The full path to the file to be opened. + + Returns + ------- + Any + A file object. + + """ + raise NotImplementedError("EnterFile is not implemented") + + def ReadPerTimestepVariable( + self, + path: str, + category: str, + varname: str, + newname: str, + ) -> dict[str, npt.NDArray[Any]]: + """Read a variable at each timestep and return a dictionary with the new name. + + Parameters + ---------- + path : str + The path to the files. + category : str + The category of the files. + varname : str + The name of the variable to be read. + newname : str + The new name of the variable to be returned. + + Returns + ------- + dict[str, NDArray[Any]] + A dictionary with the new name and the variable at each timestep. + + """ + raise NotImplementedError("ReadPerTimestepVariable is not implemented") + + def ReadAttrsAtTimestep( + self, + path: str, + category: str, + step: int, + ) -> dict[str, Any]: + """Read the attributes of a given timestep. + + Parameters + ---------- + path : str + The path to the files. + category : str + The category of the files. + step : int + The timestep to be read. + + Returns + ------- + dict[str, Any] + A dictionary with the attributes of the timestep. + + """ + raise NotImplementedError("ReadAttrsAtTimestep is not implemented") + + def ReadEdgeCoordsAtTimestep( + self, + path: str, + step: int, + ) -> dict[str, npt.NDArray[Any]]: + """Read the coordinates of cell edges at a given timestep. + + Parameters + ---------- + path : str + The path to the files. + step : int + The timestep to be read. + + Returns + ------- + dict[str, NDArray[Any]] + A dictionary with the coordinates of the cell edges. + + """ + raise NotImplementedError("ReadEdgeCoordsAtTimestep is not implemented") + + def ReadArrayAtTimestep( + self, + path: str, + category: str, + quantity: str, + step: int, + ) -> npt.NDArray[Any]: + """Read an array at a given timestep. + + Parameters + ---------- + path : str + The path to the files. 
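# The contract above in brief: a concrete backend implements `format`, `EnterFile`,
# and the Read*AtTimestep hooks, while all traversal, validation, and file-skipping
# logic stays in BaseReader. A hypothetical skeleton (names illustrative, not a
# shipped backend):
#
#     class MyReader(BaseReader):
#         @property
#         def format(self) -> Format:
#             return Format.HDF5  # a real backend would register its own member
#
#         @staticmethod
#         def EnterFile(filename: str):
#             return h5py.File(filename, "r")
#
#         # ...plus ReadArrayAtTimestep, ReadCategoryNamesAtTimestep, etc.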
+ category : str + The category of the files. + quantity : str + The name of the array to be read. + step : int + The timestep to be read. + + Returns + ------- + NDArray[Any] + The array at a given timestep. + + """ + raise NotImplementedError("ReadArrayAtTimestep is not implemented") + + def ReadCategoryNamesAtTimestep( + self, + path: str, + category: str, + prefix: str, + step: int, + ) -> set[str]: + """Read the names of the variables in a given category and timestep. + + Parameters + ---------- + path : str + The path to the files. + category : str + The category of the files. + prefix : str + The prefix of the variables to be read. + step : int + The timestep to be read. + + Returns + ------- + set[str] + The names of the variables in the category. + + """ + raise NotImplementedError("ReadCategoryNamesAtTimestep is not implemented") + + def ReadParticleSpeciesAtTimestep(self, path: str, step: int) -> set[int]: + """Read the particle species indices at a given timestep. + + Parameters + ---------- + path : str + The path to the files. + step : int + The timestep to be read. + + Returns + ------- + set[int] + A set of particle species indices at a given timestep. + + """ + return set( + int(f.split("_")[1]) + for f in self.ReadCategoryNamesAtTimestep(path, "particles", "p", step) + ) + + def ReadArrayShapeAtTimestep( + self, + path: str, + category: str, + quantity: str, + step: int, + ) -> tuple[int]: + """Read the shape of an array at a given timestep. + + Parameters + ---------- + path : str + The path to the files. + category : str + The category of the files. + quantity : str + The name of the quantity to be read. + step : int + The timestep to be read. + + Returns + ------- + tuple[int] + The shape of the array at a given timestep. + + """ + raise NotImplementedError("ReadArrayShapeAtTimestep is not implemented") + + def ReadFieldCoordsAtTimestep( + self, + path: str, + step: int, + ) -> dict[str, npt.NDArray[Any]]: + """Read the coordinates of the fields at a given timestep. + + Parameters + ---------- + path : str + The path to the files. + step : int + The timestep to be read. + + Returns + ------- + dict[str, NDArray[Any]] + A dictionary with the coordinates of the fields, where the keys are the names of the coordinates and the values are the corresponding coordinate arrays. + + """ + raise NotImplementedError("ReadFieldCoordsAtTimestep is not implemented") + + def ReadFieldLayoutAtTimestep(self, path: str, step: int) -> Layout: + """Read the layout of the fields at a given timestep. + + Parameters + ---------- + path : str + The path to the files. + step : int + The timestep to be read. + + Returns + ------- + Layout + The layout of the fields at a given timestep (R or L). + + """ + raise NotImplementedError("ReadFieldLayoutAtTimestep is not implemented") + + # # # # # # # # # # # # # # # # # # # # # # # # + # Common methods + # # # # # # # # # # # # # # # # # # # # # # # # + + @staticmethod + def CategoryFiles(path: str, category: str, format: str) -> list[str]: + """Get the list of files in a given category and format. + + Parameters + ---------- + path : str + The path to the files. + category : str + The category of the files. + format : str + The format of the files. + + Returns + ------- + list[str] + A list of files in the given category and format. + + Raises + ------ + ValueError + If no files are found.
+ + """ + files = [ + f + for f in os.listdir(os.path.join(path, category)) + if re.match(rf"^{category}\.\d{{{8}}}\.{format}", f) + ] + files.sort(key=lambda x: int(x.split(".")[1])) + if len(files) == 0: + raise ValueError(f"No {category} files found in the specified path") + return files + + def FullPath(self, path: str, category: str, step: int) -> str: + """Get the full path to a file. + + Parameters + ---------- + path : str + The path to the files. + category : str + The category of the files. + step : int + The timestep to be read. + + Returns + ------- + str + The full path to the file. + + """ + return os.path.join( + path, category, f"{category}.{step:08d}.{self.format.value}" + ) + + def GetValidSteps( + self, + path: str, + category: str, + ) -> list[int]: + """Get valid timesteps (sorted) in a given path and category. + + Parameters + ---------- + path : str + The path to the files. + category : str + The category of the files. + + Returns + ------- + list[int] + A list of valid timesteps in the given path and category. + + """ + steps: list[int] = [] + for filename in BaseReader.CategoryFiles( + path=path, + category=category, + format=self.format.value, + ): + try: + with self.EnterFile(os.path.join(path, category, filename)): + step = int(filename.split(".")[1]) + steps.append(step) + except OSError: + if filename not in self.skipped_files: + self.skipped_files.append(filename) + logging.warning(f"Could not read {filename}, skipping it") + except Exception as e: + raise e + steps.sort() + return steps + + def GetValidFiles( + self, + path: str, + category: str, + ) -> list[str]: + """Get valid files (sorted by timestep) in a given path and category. + + Parameters + ---------- + path : str + The path to the files. + category : str + The category of the files. + + Returns + ------- + list[str] + A list of valid files in the given path and category. + + """ + files: list[str] = [] + for filename in BaseReader.CategoryFiles( + path=path, + category=category, + format=self.format.value, + ): + try: + with self.EnterFile(os.path.join(path, category, filename)): + files.append(filename) + except OSError: + if filename not in self.skipped_files: + self.skipped_files.append(filename) + logging.warning(f"Could not read {filename}, skipping it") + except Exception as e: + raise e + files.sort(key=lambda x: int(x.split(".")[1])) + return files + + def VerifySameCategoryNames( + self, + path: str, + category: str, + prefix: str, + ): + """Verify that all files in a given category have the same names. + + Parameters + ---------- + path : str + The path to the files. + category : str + The category of the files. + prefix : str + The prefix of the variables to be read. + + Raises + ------ + ValueError + If different names are found. + + """ + names = None + for step in self.GetValidSteps( + path=path, + category=category, + ): + if names is None: + names = self.ReadCategoryNamesAtTimestep( + path=path, + category=category, + prefix=prefix, + step=step, + ) + else: + if names != self.ReadCategoryNamesAtTimestep( + path=path, + category=category, + prefix=prefix, + step=step, + ): + raise ValueError( + f"Different field names found in the {self.format.value} files for step {step}" + ) + + def VerifySameFieldShapes( + self, + path: str, + ): + """Verify that all fields in a given path have the same shape. + + Parameters + ---------- + path : str + The path to the files. + + Raises + ------ + ValueError + If different shapes are found. 
+ + """ + shape = None + for step in self.GetValidSteps( + path=path, + category="fields", + ): + names = self.ReadCategoryNamesAtTimestep( + path=path, + category="fields", + prefix="f", + step=step, + ) + if shape is None: + name = names.pop() + shape = self.ReadArrayShapeAtTimestep( + path=path, + category="fields", + quantity=name, + step=step, + ) + for name in names: + if shape != self.ReadArrayShapeAtTimestep( + path=path, + category="fields", + quantity=name, + step=step, + ): + raise ValueError( + f"Different field shapes found in the {self.format.value} files for field {name} in step {step}" + ) + + def VerifySameFieldLayouts(self, path: str): + """Verify that all timesteps in a given path have the same layout. + + Parameters + ---------- + path : str + The path to the files. + + Raises + ------ + ValueError + If different layouts are found. + + """ + layout = None + for step in self.GetValidSteps( + path=path, + category="fields", + ): + if layout is None: + layout = self.ReadFieldLayoutAtTimestep( + path=path, + step=step, + ) + else: + if layout != self.ReadFieldLayoutAtTimestep( + path=path, + step=step, + ): + raise ValueError( + f"Different field layouts found in the {self.format.value} files for step {step}" + ) + + def VerifySameParticleShapes(self, path: str): + """Verify that all particle quantities in a given path have the same shape at specific timesteps. + + Parameters + ---------- + path : str + The path to the files. + + Raises + ------ + ValueError + If different shapes are found. + + """ + for step in self.GetValidSteps( + path=path, + category="particles", + ): + prtl_species = self.ReadParticleSpeciesAtTimestep(path=path, step=step) + quantities = self.ReadCategoryNamesAtTimestep( + path=path, + category="particles", + prefix="p", + step=step, + ) + quantities = set(q.split("_")[0] for q in quantities if q.startswith("p")) + for sp in prtl_species: + shape = None + for q in quantities: + if shape is None: + shape = self.ReadArrayShapeAtTimestep( + path=path, + category="particles", + quantity=f"{q}_{sp}", + step=step, + ) + else: + if shape != self.ReadArrayShapeAtTimestep( + path=path, + category="particles", + quantity=f"{q}_{sp}", + step=step, + ): + raise ValueError( + f"Different particle shapes found in the {self.format.value} files for species {sp} and quantity {q} in step {step}" + ) + + def DefinesCategory(self, path: str, category: str) -> bool: + """Check whether a given category is defined in the path. + + Parameters + ---------- + path : str + The path to the files. + category : str + The category to be checked. + + Returns + ------- + bool + True if the category is defined, False otherwise. 
+ + """ + return os.path.exists(os.path.join(path, category)) and ( + len(self.GetValidFiles(path=path, category=category)) > 0 + ) diff --git a/nt2/readers/hdf5.py b/nt2/readers/hdf5.py new file mode 100644 index 0000000..8ec7708 --- /dev/null +++ b/nt2/readers/hdf5.py @@ -0,0 +1,168 @@ +# pyright: reportMissingTypeStubs=false + +from typing import Any + +import sys +if sys.version_info >= (3, 12): + from typing import override +else: + def override(method): + return method + +import re +import os +import numpy as np +import numpy.typing as npt + +import h5py + +from nt2.utils import Format, Layout +from nt2.readers.base import BaseReader + + +class Reader(BaseReader): + @staticmethod + def __extract_step0(f: h5py.File) -> h5py.Group: + if "Step0" in f.keys(): + f0 = f["Step0"] + if isinstance(f0, h5py.Group): + return f0 + else: + raise ValueError(f"Step0 is not a group in the HDF5 file") + else: + raise ValueError(f"Wrong structure of the hdf5 file") + + @property + @override + def format(self) -> Format: + return Format.HDF5 + + @staticmethod + @override + def EnterFile( + filename: str, + ) -> h5py.File: + return h5py.File(filename, "r") + + @override + def ReadPerTimestepVariable( + self, + path: str, + category: str, + varname: str, + newname: str, + ) -> dict[str, npt.NDArray[Any]]: + variables: list[Any] = [] + for filename in self.GetValidFiles( + path=path, + category=category, + ): + with h5py.File(os.path.join(path, category, filename), "r") as f: + f0 = Reader.__extract_step0(f) + if varname in f0.keys(): + var = f0[varname] + if isinstance(var, h5py.Dataset): + variables.append(var[()]) + else: + raise ValueError( + f"{varname} is not a group in the HDF5 file {filename}" + ) + else: + raise ValueError(f"{varname} not found in the HDF5 file {filename}") + + return {newname: np.array(variables)} + + @override + def ReadAttrsAtTimestep( + self, + path: str, + category: str, + step: int, + ) -> dict[str, Any]: + with h5py.File(self.FullPath(path, category, step), "r") as f: + return {k: v for k, v in f.attrs.items()} + + @override + def ReadEdgeCoordsAtTimestep( + self, + path: str, + step: int, + ) -> dict[str, npt.NDArray[Any]]: + with h5py.File(self.FullPath(path, "fields", step), "r") as f: + f0 = Reader.__extract_step0(f) + return {k: v[:] for k, v in f0.items() if k[0] == "X" and k[-1] == "e"} + + @override + def ReadArrayAtTimestep( + self, + path: str, + category: str, + quantity: str, + step: int, + ) -> npt.NDArray[Any]: + with h5py.File(filename := self.FullPath(path, category, step), "r") as f: + f0 = Reader.__extract_step0(f) + if quantity in f0.keys(): + var = f0[quantity] + if isinstance(var, h5py.Dataset): + return np.array(var[:]) + else: + raise ValueError(f"{quantity} is not a group in the {filename}") + else: + raise ValueError(f"{quantity} not found in the {filename}") + + @override + def ReadCategoryNamesAtTimestep( + self, + path: str, + category: str, + prefix: str, + step: int, + ) -> set[str]: + with h5py.File(self.FullPath(path, category, step), "r") as f: + f0 = Reader.__extract_step0(f) + keys: list[str] = list(f0.keys()) + return set(c for c in keys if c.startswith(prefix)) + + @override + def ReadArrayShapeAtTimestep( + self, path: str, category: str, quantity: str, step: int + ) -> tuple[int]: + with h5py.File(filename := self.FullPath(path, category, step), "r") as f: + f0 = Reader.__extract_step0(f) + if quantity in f0.keys(): + var = f0[quantity] + if isinstance(var, h5py.Dataset): + return var.shape + else: + raise ValueError( + 
f"{category.capitalize()} {quantity} is not a group in the {filename}" + ) + else: + raise ValueError( + f"{category.capitalize()} {quantity} not found in the {filename}" + ) + + @override + def ReadFieldCoordsAtTimestep( + self, path: str, step: int + ) -> dict[str, npt.NDArray[Any]]: + with h5py.File(filename := self.FullPath(path, "fields", step), "r") as f: + f0 = Reader.__extract_step0(f) + + def get_coord(c: str) -> Any: + f0_c = f0[c] + if isinstance(f0_c, h5py.Dataset): + return f0_c[:] + else: + raise ValueError(f"Field {c} is not a group in the {filename}") + + keys: list[str] = list(f0.keys()) + return {c: get_coord(c) for c in keys if re.match(r"^X[1|2|3]$", c)} + + @override + def ReadFieldLayoutAtTimestep(self, path: str, step: int) -> Layout: + with h5py.File(filename := self.FullPath(path, "fields", step), "r") as f: + if "LayoutRight" not in f.attrs: + raise ValueError(f"LayoutRight attribute not found in the {filename}") + return Layout.R if f.attrs["LayoutRight"] else Layout.L diff --git a/nt2/tests/.gitattributes b/nt2/tests/.gitattributes new file mode 100644 index 0000000..3b62c44 --- /dev/null +++ b/nt2/tests/.gitattributes @@ -0,0 +1 @@ +testdata.tar.gz filter=lfs diff=lfs merge=lfs -text diff --git a/nt2/tests/__init__.py b/nt2/tests/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/nt2/tests/cases.py b/nt2/tests/cases.py new file mode 100644 index 0000000..9bde822 --- /dev/null +++ b/nt2/tests/cases.py @@ -0,0 +1,155 @@ +import nt2.readers.hdf5 as hdf5 +import nt2.readers.adios2 as adios2 + +cwd = __file__.rsplit("/", 1)[0] + +TESTS = [ + { + "dim": "2D", + "reader": hdf5.Reader, + "path": f"{cwd}/testdata/h5_2d_cart_cpu/", + "invalid_tstep": 71, + "fields": { + "sx1": 10, + "sx2": 20, + "nx1": 64, + "nx2": 128, + }, + "particles": { + "num": [325, 325, 324, 325], + }, + }, + { + "dim": "2D", + "reader": hdf5.Reader, + "path": f"{cwd}/testdata/h5_2d_cart_gpu/", + "invalid_tstep": 81, + "fields": { + "sx1": 10, + "sx2": 20, + "nx1": 64, + "nx2": 128, + }, + "particles": { + "num": 327, + }, + }, + { + "dim": "3D", + "reader": hdf5.Reader, + "path": f"{cwd}/testdata/h5_3d_cart_cpu/", + "invalid_tstep": None, + "fields": { + "sx1": 10, + "sx2": 16, + "sx3": 12.5, + "nx1": 20, + "nx2": 32, + "nx3": 25, + }, + "particles": { + "num": 639, + }, + }, + { + "dim": "3D", + "reader": hdf5.Reader, + "path": f"{cwd}/testdata/h5_3d_cart_gpu/", + "invalid_tstep": None, + "dt": 0.1443375647, + "fields": {}, + "particles": { + "num": 640, + }, + }, + { + "dim": "2D", + "reader": adios2.Reader, + "path": f"{cwd}/testdata/adios2_2d_cart_cpu/", + "invalid_tstep": None, + "fields": { + "sx1": 10, + "sx2": 20, + "nx1": 64, + "nx2": 128, + }, + "particles": { + "num": [323, 324, 323, 324], + }, + }, + { + "dim": "3D", + "reader": adios2.Reader, + "path": f"{cwd}/testdata/adios2_3d_cart_cpu/", + "invalid_tstep": None, + "fields": { + "sx1": 10, + "sx2": 16, + "sx3": 12.5, + "nx1": 20, + "nx2": 32, + "nx3": 25, + }, + "particles": { + "num": 637, + }, + }, + { + "dim": "2D", + "coords": "qsph", + "reader": hdf5.Reader, + "path": f"{cwd}/testdata/h5_2d_qsph_cpu/", + "invalid_tstep": None, + "dt": 2.93069752e-03, + "fields": { + "sx1": 1, + "sx2": 30, + "nx1": 512, + "nx2": 256, + "quantities": [ + "E1", + "E2", + "E3", + "B1", + "B2", + "B3", + "N_1", + "N_2", + "T00", + ], + }, + "particles": { + "nspec": 2, + "num": [2939, 1222], + }, + }, + { + "dim": "2D", + "coords": "sph", + "reader": adios2.Reader, + "path": f"{cwd}/testdata/adios2_2d_sph_cpu/", + 
"invalid_tstep": None, + "dt": 5.99678652e-03, + "fields": { + "sx1": 1, + "sx2": 30, + "nx1": 512, + "nx2": 256, + "quantities": [ + "E1", + "E2", + "E3", + "B1", + "B2", + "B3", + "N_1", + "N_2", + "T00", + ], + }, + "particles": { + "nspec": 2, + "num": [253, 157], + }, + }, +] diff --git a/nt2/tests/conftest.py b/nt2/tests/conftest.py new file mode 100644 index 0000000..bb13a39 --- /dev/null +++ b/nt2/tests/conftest.py @@ -0,0 +1,15 @@ +import os +import tarfile +import pytest + + +@pytest.fixture(scope="session", autouse=True) +def unpack_testdata(): + base_dir = os.path.dirname(__file__) + tar_path = os.path.join(base_dir, "testdata.tar.gz") + extract_dir = os.path.join(base_dir, "testdata") + + if not os.path.exists(os.path.join(base_dir, "testdata")): + os.makedirs(extract_dir, exist_ok=True) + with tarfile.open(tar_path, "r:gz") as tar: + tar.extractall(path=extract_dir, filter="data") diff --git a/nt2/tests/test_cli.py b/nt2/tests/test_cli.py new file mode 100644 index 0000000..272262f --- /dev/null +++ b/nt2/tests/test_cli.py @@ -0,0 +1,87 @@ +import pytest +from typer.testing import CliRunner +import matplotlib.pyplot as plt + +import os +import nt2 +from nt2.cli.main import app +from nt2.tests.cases import TESTS + +runner = CliRunner() + + +def test_version(): + result = runner.invoke(app, ["version"]) + assert result.exit_code == 0, f"Expected exit code 0, got {result.exit_code}" + assert ( + nt2.__version__ in result.output + ), f"Expected version {nt2.__version__} in output, got {result.output}" + + +@pytest.mark.parametrize( + "test", + [test for test in TESTS], +) +def test_show(test): + PATH = test["path"] + data = nt2.Data(PATH) + result = runner.invoke(app, ["show", PATH]) + assert result.exit_code == 0, f"Expected exit code 0, got {result.exit_code}" + assert ( + data.to_str() in result.output + ), f"Expected data info in output, got {result.output}" + + +@pytest.mark.parametrize( + "test", + [test for test in TESTS], +) +def test_plot_png(test): + PATH = test["path"] + if test["fields"] == {}: + return + if test.get("coords", "cart") == "cart": + result = runner.invoke( + app, + [ + "plot", + PATH, + "--sel", + "x=slice(None, 5);y=slice(-5.0, 5.0)", + "--isel", + f"t=0{';z=0' if test['dim'] == '3D' else ''}", + ], + ) + assert result.exit_code == 0, f"Expected exit code 0, got {result.exit_code}" + + data = nt2.Data(PATH) + fname = os.path.basename(PATH.strip("/")) + + d = data.fields.sel(x=slice(None, 5), y=slice(-5, 5)).isel(t=0) + if test["dim"] == "3D": + d = d.isel(z=0) + d.inspect.plot(fig_kwargs={"dpi": 200}) + plt.savefig(fname=f"{fname}-2.png") + + def files_are_identical(path1, path2): + with open(path1, "rb") as f1, open(path2, "rb") as f2: + return f1.read() == f2.read() + + assert files_are_identical( + f"{fname}-2.png", f"{fname}.png" + ), f"Files {fname}-2.png and {fname}.png are not identical." 
+ + os.remove(f"{fname}-2.png") + os.remove(f"{fname}.png") + # else: + # result = runner.invoke( + # app, + # [ + # "plot", + # PATH, + # "--sel", + # "r=slice(None, 5);th=slice(1.5, 2.5)", + # "--isel", + # "t=0", + # ], + # ) diff --git a/nt2/tests/test_containers.py b/nt2/tests/test_containers.py new file mode 100644 index 0000000..77302c4 --- /dev/null +++ b/nt2/tests/test_containers.py @@ -0,0 +1,162 @@ +import pytest + +from nt2.readers.base import BaseReader +from nt2.containers.fields import Fields +from nt2.containers.particles import Particles +from nt2.containers.data import Data +from nt2.tests.cases import TESTS + + +def check_shape(shape1, shape2): + """ + Check if two shapes are equal + """ + assert shape1 == shape2, f"Shape {shape1} is not equal to {shape2}" + + +@pytest.mark.parametrize( + "test,field_container", [[test, fc] for test in TESTS for fc in [Data, Fields]] +) +def test_fields(test, field_container: type[Data] | type[Fields]): + reader: BaseReader = test["reader"]() + PATH = test["path"] + if test["fields"] == {}: + return + + coords: list[str] = ["x", "y", "z"] + flds: list[str] = ["Ex", "Ey", "Ez", "Bx", "By", "Bz"] + + def coord_remap(Xold: str) -> str: + return { + "X1": "x", + "X2": "y", + "X3": "z", + }.get(Xold, Xold) + + if test.get("coords", "cart") != "cart": + coords = ["r", "th", "ph"] + flds = ["Er", "Eth", "Eph", "Br", "Bth", "Bph"] + coord_remap = lambda Xold: { + "X1": "r", + "X2": "th", + "X3": "ph", + }.get(Xold, Xold) + + def field_remap(Fold: str): + return { + f"f{F}{i+1}": f"{F}{x}" for i, x in enumerate(coords) for F in "EB" + }.get(Fold, Fold) + + fields = field_container( + path=PATH, + reader=reader, + remap={"coords": coord_remap, "fields": field_remap}, + ) + + steps = reader.GetValidSteps(path=PATH, category="fields") + nx1 = test["fields"]["nx1"] + nx2 = test["fields"]["nx2"] + assert fields.fields is not None, "Fields are None" + for f in flds: + assert f in fields.fields, f"{f} is not in fields" + if test["dim"] == "2D": + xyzshape = (nx2, nx1) + yzshape = (nx2,) + xzshape = (nx1,) + xyshape = () + else: + nx3 = test["fields"]["nx3"] + xyzshape = (nx3, nx2, nx1) + yzshape = (nx3, nx2) + xzshape = (nx3, nx1) + xyshape = (nx2, nx1) + + check_shape( + fields.fields[f].shape, + tuple([len(steps), *xyzshape]), + ) + check_shape( + fields.fields[f].isel(t=0).shape, + tuple([*xyzshape]), + ) + if test.get("coords", "cart") == "cart": + check_shape( + fields.fields[f].isel(x=0).shape, + tuple([len(steps), *yzshape]), + ) + check_shape( + fields.fields[f].isel(y=0).shape, + tuple([len(steps), *xzshape]), + ) + + if test["dim"] == "3D": + check_shape( + fields.fields[f].isel(z=0).shape, + tuple([len(steps), *xyshape]), + ) + else: + check_shape( + fields.fields[f].isel(r=0).shape, + tuple([len(steps), *yzshape]), + ) + check_shape( + fields.fields[f].isel(th=0).shape, + tuple([len(steps), *xzshape]), + ) + + +@pytest.mark.parametrize( + "test,particle_container", + [[test, fc] for test in TESTS for fc in [Data, Particles]], +) +def test_particles(test, particle_container: type[Data] | type[Particles]): + reader: BaseReader = test["reader"]() + PATH = test["path"] + if test["particles"] == {}: + return + + prtl_coords: list[str] = ["x", "y", "z", "ux", "uy", "uz", "w"] + + def prtl_remap(Xold: str) -> str: + return { + "pX1": "x", + "pX2": "y", + "pX3": "z", + "pU1": "ux", + "pU2": "uy", + "pU3": "uz", + "pW": "w", + }.get(Xold, Xold) + + if test.get("coords", "cart") != "cart": + prtl_coords = ["r", "th", "ph", "ur", "uth", "uph", "w"] + 
+
+        def prtl_remap(Xold: str) -> str:
+            return {
+                "pX1": "r",
+                "pX2": "th",
+                "pX3": "ph",
+                "pU1": "ur",
+                "pU2": "uth",
+                "pU3": "uph",
+                "pW": "w",
+            }.get(Xold, Xold)
+
+    particles = particle_container(
+        path=PATH,
+        reader=reader,
+        remap={"particles": prtl_remap},
+    )
+    steps = reader.GetValidSteps(path=PATH, category="particles")
+    assert particles.particles is not None, "Particles are None"
+    for p in prtl_coords:
+        if p == "z" and test["dim"] == "2D" and test.get("coords", "cart") == "cart":
+            continue
+        for i, (_, parts) in enumerate(particles.particles.items()):
+            if isinstance(test["particles"]["num"], list):
+                num = test["particles"]["num"][i]
+            else:
+                num = test["particles"]["num"]
+            check_shape(parts[p].shape, (len(steps), num))
+            for j, st in enumerate(steps):
+                assert (
+                    parts[p].isel(t=j).s.values[()] == st
+                ), f"Step {st} does not match in particle {p}"
+                check_shape(parts[p].isel(t=j).shape, (num,))
diff --git a/nt2/tests/test_reader.py b/nt2/tests/test_reader.py
new file mode 100644
index 0000000..28bae7e
--- /dev/null
+++ b/nt2/tests/test_reader.py
@@ -0,0 +1,227 @@
+import numpy as np
+from nt2.tests.cases import TESTS
+
+from nt2.readers.base import BaseReader
+from nt2.utils import Layout
+
+
+def pytest_generate_tests(metafunc):
+    if "test" in metafunc.fixturenames:
+        metafunc.parametrize("test", TESTS)
+
+
+def check_equal_arrays(arr1, arr2):
+    if isinstance(arr1, set):
+        assert len(arr1) == len(
+            arr2
+        ), f"Set lengths do not match: {len(arr1)} != {len(arr2)}"
+        assert arr1 == arr2, f"Sets do not match: {arr1} != {arr2}"
+    else:
+        arr1 = np.array(arr1)
+        arr2 = np.array(arr2)
+        assert (
+            arr1.shape == arr2.shape
+        ), f"Shapes do not match: {arr1.shape} != {arr2.shape}"
+        assert np.all(np.isclose(arr1, arr2)), f"Arrays do not match: {arr1} != {arr2}"
+
+
+def check_raises(func, exception):
+    """Assert that calling `func` raises `exception`; anything else propagates."""
+    try:
+        func()
+    except exception:
+        return
+    raise AssertionError(f"{func} should raise {exception}")
+
+
+def test_reader(test):
+    PATH = test["path"]
+    invalid_tstep = test["invalid_tstep"]
+    reader: BaseReader = test["reader"]()
+
+    if test["fields"] != {}:
+        # reconstruct grid spacing, extents, and the CFL-based timestep from
+        # the test case parameters; square cells are assumed (dx == dy)
+        nx1 = test["fields"]["nx1"]
+        nx2 = test["fields"]["nx2"]
+        sx1 = test["fields"]["sx1"]
+        sx2 = test["fields"]["sx2"]
+        dx = sx1 / nx1
+        x1min = -sx1 / 2 + dx / 2
+        x2min = -sx2 / 2 + dx / 2
+        if test["dim"] == "3D":
+            dt = dx / np.sqrt(3) / 2
+        else:
+            dt = dx / np.sqrt(2) / 2
+
+        # an explicit "dt" in the test case overrides the default
+        timestep = test.get("dt", None)
+        if timestep is not None:
+            dt = timestep
+
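+        # when the test case does not list "quantities" explicitly, fall back
+        # to the raw dump convention: E1..E3/B1..B3 field components, paired
+        # densities (N_1_2, N_3_4), and energy-momentum components T01..T03
+        # for four species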
+        field_names = test["fields"].get(
+            "quantities",
+            [f"{f}{i+1}" for i in range(3) for f in "BE"]
+            + [f"N_{i}" for i in ["1_2", "3_4"]]
+            + [f"T0{c+1}_{i+1}" for i in range(4) for c in range(3)],
+        )
+        field_names = set(f"f{f}" for f in field_names)
+
+        # Check that invalid_tstep raises OSError in fields
+        if invalid_tstep is not None:
+            check_raises(
+                lambda: reader.ReadArrayAtTimestep(
+                    path=PATH, category="fields", quantity="Foo", step=invalid_tstep
+                ),
+                OSError,
+            )
+
+        # Check that timesteps are read correctly from fields
+        times = reader.ReadPerTimestepVariable(
+            path=PATH, category="fields", varname="Time", newname="t"
+        )["t"]
+        steps = reader.ReadPerTimestepVariable(
+            path=PATH, category="fields", varname="Step", newname="s"
+        )["s"]
+        check_equal_arrays(
+            times,
+            np.array([s * dt for s in steps]),
+        )
+
+        # Check that the names of the fields are read correctly
+        names = reader.ReadCategoryNamesAtTimestep(
+            path=PATH, category="fields", prefix="f", step=1
+        )
+        check_equal_arrays(names, field_names)
+
+        # Check coords (cell centers, offset by dx / 2 from the box edges)
+        if test.get("coords", "cart") == "cart":
+            coords = reader.ReadFieldCoordsAtTimestep(path=PATH, step=1)
+            x1 = np.array([x1min + i * dx for i in range(int(nx1))])
+            x2 = np.array([x2min + i * dx for i in range(int(nx2))])
+            check_equal_arrays(coords["X1"], x1)
+            check_equal_arrays(coords["X2"], x2)
+
+            if test["dim"] == "3D":
+                sx3 = test["fields"]["sx3"]
+                x3min = dx / 2
+                nx3 = test["fields"]["nx3"]
+                x3 = np.array([x3min + i * dx for i in range(int(nx3))])
+                check_equal_arrays(coords["X3"], x3)
+
+        # Check field shapes
+        field = next(iter(field_names))
+        layout = reader.ReadFieldLayoutAtTimestep(path=PATH, step=1)
+        shape = reader.ReadArrayShapeAtTimestep(
+            path=PATH, category="fields", quantity=field, step=1
+        )
+        if test["dim"] == "2D":
+            check_equal_arrays(shape, (nx1, nx2) if layout == Layout.R else (nx2, nx1))
+        else:
+            nx3 = test["fields"]["nx3"]
+            check_equal_arrays(
+                shape, (nx1, nx2, nx3) if layout == Layout.R else (nx3, nx2, nx1)
+            )
+
+        for step in reader.GetValidSteps(path=PATH, category="fields"):
+            for f in field_names:
+                field = reader.ReadArrayAtTimestep(
+                    path=PATH, category="fields", quantity=f, step=step
+                )
+                check_equal_arrays(field.shape, shape)
+
+        # Check that all timesteps share the same names, layouts, and shapes
+        reader.VerifySameCategoryNames(path=PATH, category="fields", prefix="f")
+        reader.VerifySameFieldLayouts(path=PATH)
+        reader.VerifySameFieldShapes(path=PATH)
+
+    if test["particles"] != {}:
+        # without field data or an explicit "dt", the timestep is unknown
+        dt = None
+        if test["fields"] != {}:
+            nx1 = test["fields"]["nx1"]
+            sx1 = test["fields"]["sx1"]
+            dx = sx1 / nx1
+            if test["dim"] == "3D":
+                dt = dx / np.sqrt(3) / 2
+            else:
+                dt = dx / np.sqrt(2) / 2
+
+        timestep = test.get("dt", None)
+        if timestep is not None:
+            dt = timestep
+
+        nspec: int = test["particles"].get("nspec", 4)
+
+        prtl_names = (
+            [f"U{i+1}_{j+1}" for i in range(3) for j in range(nspec)]
+            + [
+                f"X{i+1}_{j+1}"
+                for i in range(
+                    2
+                    if test["dim"] == "2D" and test.get("coords", "cart") == "cart"
+                    else 3
+                )
+                for j in range(nspec)
+            ]
+            + [f"W_{i+1}" for i in range(nspec)]
+        )
+        prtl_names = set(f"p{p}" for p in prtl_names)
+
+        # Check that timesteps are read correctly from particles
+        times = reader.ReadPerTimestepVariable(
+            path=PATH, category="particles", varname="Time", newname="t"
+        )["t"]
+        steps = reader.ReadPerTimestepVariable(
+            path=PATH, category="particles", varname="Step", newname="s"
+        )["s"]
+
+        # skip the time check when dt could not be determined
+        if dt is not None:
+            check_equal_arrays(
+                times,
+                np.array([s * dt for s in steps]),
+            )
+
+        # Check that the names of the particle quantities are read correctly
+        names = reader.ReadCategoryNamesAtTimestep(
+            path=PATH, category="particles", prefix="p", step=1
+        )
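+        # names follow the p<quantity>_<species> convention, e.g. pX1_1 or pW_2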
+        check_equal_arrays(names, prtl_names)
+
+        # Check prtl shapes: the weight array pW_<species> sets the reference
+        # shape, and every other per-species quantity must match it
+        for step in reader.GetValidSteps(path=PATH, category="particles"):
+            for sp in range(nspec):
+                shape = reader.ReadArrayShapeAtTimestep(
+                    path=PATH,
+                    category="particles",
+                    quantity=f"pW_{sp+1}",
+                    step=step,
+                )
+                for p in prtl_names:
+                    if not p.endswith(f"_{sp+1}"):
+                        continue
+                    prtl_shape = reader.ReadArrayShapeAtTimestep(
+                        path=PATH, category="particles", quantity=p, step=step
+                    )
+                    check_equal_arrays(prtl_shape, shape)
+
+        # Check that all timesteps have the same names
+        reader.VerifySameCategoryNames(path=PATH, category="particles", prefix="p")
diff --git a/nt2/tests/testdata.tar.gz b/nt2/tests/testdata.tar.gz
new file mode 100644
index 0000000..6c7e538
--- /dev/null
+++ b/nt2/tests/testdata.tar.gz
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:e8991b217a7fa943096e06ff0606561114bfd0a429ad32642b9ff65afffc8e04
+size 101862478
diff --git a/nt2/utils.py b/nt2/utils.py
index 1ce4ea9..5e5d233 100644
--- a/nt2/utils.py
+++ b/nt2/utils.py
@@ -1,5 +1,110 @@
-import warnings
+from enum import Enum
+import os
+import re
+import inspect
+
+import xarray as xr
 
 
 class FutureDeprecationWarning(Warning):
     pass
+
+
+class Format(Enum):
+    HDF5 = "h5"
+    BP5 = "bp"
+
+
+class Layout(Enum):
+    L = "Left"
+    R = "Right"
+
+
+class CoordinateSystem(Enum):
+    XYZ = "Cartesian"
+    SPH = "Spherical"
+
+
+def DetermineDataFormat(path: str) -> Format:
+    """Determine the data format for the files in the given path.
+
+    Parameters
+    ----------
+    path : str
+        The path to the simulation output directory.
+
+    Returns
+    -------
+    Format
+        The data format of the files.
+
+    Raises
+    ------
+    ValueError
+        If the file format is unknown or cannot be determined.
+    """
+    categories = ["fields", "particles", "spectra"]
+    for category in categories:
+        category_path = os.path.join(path, category)
+        if os.path.exists(category_path):
+            # match, e.g., `fields.00012345.h5` or `particles.00000001.bp`
+            files = [
+                f
+                for f in os.listdir(category_path)
+                if re.match(rf"^{category}\.\d{{8}}\.", f)
+            ]
+            if len(files) > 0:
+                filename = files[0]
+                if filename.endswith(".h5"):
+                    return Format.HDF5
+                elif filename.endswith(".bp"):
+                    return Format.BP5
+                else:
+                    raise ValueError(f"Unknown file format: {filename}.")
+    raise ValueError("Could not determine file format.")
+
+
+def ToHumanReadable(num: float | int, suffix: str = "B") -> str:
+    """Convert a number to a human-readable string with SI prefixes.
+
+    Parameters
+    ----------
+    num : float | int
+        The number to convert.
+    suffix : str
+        The suffix to append to the number (default: "B").
+
+    Returns
+    -------
+    str
+        The number in human-readable format with SI prefixes.
+    """
+    for unit in ("", "K", "M", "G", "T", "P", "E", "Z"):
+        if abs(num) < 1e3:
+            return f"{num:3.1f} {unit}{suffix}"
+        num /= 1e3
+    return f"{num:.1f} Y{suffix}"
+
+
+def DataIs2DPolar(ds: xr.Dataset) -> bool:
+    """Check if the dataset is 2D polar.
+
+    A dataset is considered 2D polar if it has exactly two dimensions:
+    "r" and either "θ" or "th".
+
+    Parameters
+    ----------
+    ds : xr.Dataset
+        The dataset to check.
+
+    Returns
+    -------
+    bool
+        True if the dataset is 2D polar, False otherwise.
+    """
+    return ("r" in ds.dims and ("θ" in ds.dims or "th" in ds.dims)) and len(
+        ds.dims
+    ) == 2
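+
+
+# Illustrative usage (paths and values assumed for the sake of example):
+#   >>> ToHumanReadable(101862478)
+#   '101.9 MB'
+#   >>> DetermineDataFormat("path/to/data")  # with .../fields/fields.00000000.h5
+#   <Format.HDF5: 'h5'>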
+
+
+def InheritClassDocstring(cls: type) -> type:
+    """Decorator to inherit docstrings from parent classes.
+
+    Appends the docstrings of all parent classes (excluding `object`) to the
+    docstring of the class itself.
+    """
+    doc = cls.__doc__ or ""
+    # skip the first entry of the MRO (the class itself), so that its own
+    # docstring is not duplicated, and skip `object` to avoid appending its
+    # generic docstring
+    for base in inspect.getmro(cls)[1:]:
+        if base is not object and base.__doc__ is not None:
+            doc += base.__doc__
+    cls.__doc__ = doc
+    return cls
diff --git a/pyproject.toml b/pyproject.toml
index 47551ed..f5e1685 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -7,7 +7,8 @@ name = "nt2py"
 dynamic = ["version"]
 dependencies = [
   "types-setuptools",
-  "dask[distributed]",
+  "dask[complete]",
+  "adios2",
   "bokeh",
   "xarray",
   "numpy",
@@ -16,6 +17,7 @@ dependencies = [
   "matplotlib",
   "tqdm",
   "contourpy",
+  "typer",
 ]
 requires-python = ">=3.8"
 authors = [{ name = "Hayk", email = "haykh.astro@gmail.com" }]
@@ -42,6 +44,9 @@ classifiers = [
 [project.urls]
 Repository = "https://github.com/entity-toolkit/nt2py"
 
+[project.scripts]
+nt2 = "nt2.cli.main:app"
+
 [tool.hatch.version]
 path = "nt2/__init__.py"
diff --git a/pyrightconfig.json b/pyrightconfig.json
index bdfd610..cb36191 100644
--- a/pyrightconfig.json
+++ b/pyrightconfig.json
@@ -1,3 +1,12 @@
 {
-  "extraPath": ["./"],
+  "extraPath": [
+    "./"
+  ],
+  "reportAny": false,
+  "reportExplicitAny": false,
+  "reportUnknownVariableType": false,
+  "reportUnknownMemberType": false,
+  "reportUnknownArgumentType": false,
+  "reportArgumentType": false,
+  "reportPrivateImportUsage": false
 }
diff --git a/requirements.txt b/requirements.txt
index 0ece1c7..d8c67e5 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -1,61 +1,141 @@
-bleach>=6.2.0
-bokeh>=3.6.3
-build>=1.2.2.post1
-cachetools>=5.5.2
-certifi>=2025.1.31
-charset-normalizer>=3.4.1
-click>=8.1.8
-cloudpickle>=3.1.1
-colorcet>=3.1.0
-contourpy>=1.3.1
-cycler>=0.12.1
-dask>=2025.2.0
-fonttools>=4.56.0
-fsspec>=2025.2.0
-h5pickle>=0.4.2
-h5py>=3.13.0
-hatchling>=1.27.0
-holoviews>=1.20.1
-hvplot>=0.11.2
-idna>=3.10
-importlib_metadata>=8.6.1
-Jinja2>=3.1.5
-kiwisolver>=1.4.8
-linkify-it-py>=2.0.3
-locket>=1.0.0
-Markdown>=3.7
-markdown-it-py>=3.0.0
-MarkupSafe>=3.0.2
-matplotlib>=3.10.1
-mdit-py-plugins>=0.4.2
-mdurl>=0.1.2
-numpy>=2.2.3
-packaging>=24.2
-pandas>=2.2.3
-panel>=1.6.1
-param>=2.2.0
-partd>=1.4.2
-pathspec>=0.12.1
-pillow>=11.1.0
-pluggy>=1.5.0
-pyct>=0.5.0
-pyparsing>=3.2.1
-pyproject_hooks>=1.2.0
-python-dateutil>=2.9.0.post0
-pytz>=2025.1
-pyviz_comms>=3.0.4
-PyYAML>=6.0.2
-requests>=2.32.3
-six>=1.17.0
-toolz>=1.0.0
-tornado>=6.4.2
-tqdm>=4.67.1
-trove-classifiers>=2025.3.3.18
-typing_extensions>=4.12.2
-tzdata>=2025.1
-uc-micro-py>=1.0.3
-urllib3>=2.3.0
-webencodings>=0.5.1
-xarray>=2025.1.2
-xyzservices>=2025.1.0
-zipp>=3.21.0
+adios2==2.10.1.100042
+anyio==4.9.0
+argon2-cffi==25.1.0
+argon2-cffi-bindings==21.2.0
+arrow==1.3.0
+asttokens==3.0.0
+async-lru==2.0.5
+attrs==25.3.0
+babel==2.17.0
+beautifulsoup4==4.13.4
+black==25.1.0
+bleach==6.2.0
+bokeh==3.7.3
+build==1.2.2.post1
+cachetools==6.1.0
+certifi==2025.6.15
+cffi==1.17.1
+charset-normalizer==3.4.2
+click==8.1.8
+cloudpickle==3.1.1
+colorcet==3.1.0
+comm==0.2.2
+contourpy==1.3.2
+cycler==0.12.1
+dask==2025.5.1
+debugpy==1.8.14
+decorator==5.2.1
+defusedxml==0.7.1
+distributed==2025.5.1
+executing==2.2.0
+fastjsonschema==2.21.1
+flit_core==3.12.0
+fonttools==4.58.4
+fqdn==1.5.1
+fsspec==2025.5.1
+h11==0.16.0
+h5pickle==0.4.2
+h5py==3.14.0
+hatchling==1.27.0
+holoviews==1.21.0
+httpcore==1.0.9
+httpx==0.28.1
+hvplot==0.11.3
+idna==3.10
+importlib_metadata==8.7.0
+iniconfig==2.1.0
+isoduration==20.11.0
+jedi==0.19.2
+Jinja2==3.1.6
+json5==0.12.0
+jsonpointer==3.0.0
+jsonschema==4.24.0
+jsonschema-specifications==2025.4.1
+kiwisolver==1.4.8
+linkify-it-py==2.0.3
+locket==1.0.0
+lz4==4.4.4
+Markdown==3.8.2
+markdown-it-py==3.0.0 +MarkupSafe==3.0.2 +matplotlib==3.10.3 +matplotlib-inline==0.1.7 +mdit-py-plugins==0.4.2 +mdurl==0.1.2 +mistune==3.1.3 +msgpack==1.1.1 +mypy_extensions==1.1.0 +narwhals==1.44.0 +nbclient==0.10.2 +nbconvert==7.16.6 +nbformat==5.10.4 +nest-asyncio==1.6.0 +numpy==2.3.1 +overrides==7.7.0 +packaging==24.2 +pandas==2.3.0 +pandocfilters==1.5.1 +panel==1.7.2 +param==2.2.1 +parso==0.8.4 +partd==1.4.2 +pathspec==0.12.1 +pexpect==4.9.0 +pillow==11.2.1 +platformdirs==4.3.7 +pluggy==1.6.0 +prometheus_client==0.22.1 +prompt_toolkit==3.0.51 +psutil==7.0.0 +ptyprocess==0.7.0 +pure_eval==0.2.3 +pyarrow==20.0.0 +pycparser==2.22 +pyct==0.5.0 +Pygments==2.19.2 +pyparsing==3.2.3 +pyproject_hooks==1.2.0 +pytest==8.4.1 +python-dateutil==2.9.0.post0 +python-json-logger==3.3.0 +pytz==2025.2 +pyviz_comms==3.0.6 +PyYAML==6.0.2 +pyzmq==27.0.0 +referencing==0.36.2 +requests==2.32.4 +rfc3339-validator==0.1.4 +rfc3986-validator==0.1.1 +rpds-py==0.25.1 +scipy==1.16.0 +Send2Trash==1.8.3 +setuptools==80.9.0 +six==1.17.0 +sniffio==1.3.1 +sortedcontainers==2.4.0 +soupsieve==2.7 +stack-data==0.6.3 +tblib==3.1.0 +terminado==0.18.1 +tinycss2==1.4.0 +toolz==1.0.0 +tornado==6.5.1 +tqdm==4.67.1 +traitlets==5.14.3 +trove-classifiers==2025.5.9.12 +types-python-dateutil==2.9.0.20250516 +types-setuptools==80.9.0.20250529 +typing_extensions==4.14.0 +typer==0.16.0 +tzdata==2025.2 +uc-micro-py==1.0.3 +uri-template==1.3.0 +urllib3==2.5.0 +wcwidth==0.2.13 +webcolors==24.11.1 +webencodings==0.5.1 +websocket-client==1.8.0 +xarray==2025.6.1 +xyzservices==2025.4.0 +zict==3.0.0 +zipp==3.23.0 diff --git a/shell.nix b/shell.nix index 696216d..f8148ca 100644 --- a/shell.nix +++ b/shell.nix @@ -9,7 +9,7 @@ pkgs.mkShell { pkgs."python${py}" pkgs."python${py}Packages".pip black - pyright + basedpyright taplo vscode-langservers-extracted zlib @@ -20,126 +20,15 @@ pkgs.mkShell { ]; shellHook = '' + if [ ! -d ".venv" ]; then + python3 -m venv .venv + source .venv/bin/activate + pip3 install -r requirements.txt + pip3 install pytest + pip3 install -e . 
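+    # editable install, so the venv picks up local changes to the nt2 sources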
+ else + source .venv/bin/activate + fi echo "nt2dev nix-shell activated: $(which python)" ''; } -# { pkgs, env }: -# -# let -# webPkgs = with pkgs; [ -# nodejs_23 -# vscode-langservers-extracted -# emmet-ls -# typescript-language-server -# taplo -# yaml-language-server -# markdown-oxide -# prettierd -# eslint_d -# mdformat -# ]; -# -# goPkgs = with pkgs; [ -# go -# gopls -# gotools -# hugo -# prettier-plugin-go-template -# ]; -# -# cppPkgs = with pkgs; [ -# zlib -# llvmPackages_19.libcxxClang -# clang-tools -# cmake -# neocmakelsp -# cmake-format -# ]; -# -# glPkgs = with pkgs; [ -# zlib -# glslls -# clang-tools -# ]; -# -# pythonPkgs = with pkgs; [ -# python312 -# black -# pyright -# taplo -# vscode-langservers-extracted -# ]; -# -# rocmPkgs = with pkgs; [ -# rocmPackages.hipcc -# rocmPackages.rocminfo -# rocmPackages.rocm-smi -# ]; -# -# asm = with pkgs; [ -# nasm -# (import ../derivations/asm-lsp.nix { inherit pkgs; }) -# ]; -# -# in -# let -# nativeBuildInputs = -# [ ] -# ++ (if builtins.elem "web" env then webPkgs else [ ]) -# ++ (if builtins.elem "go" env then goPkgs else [ ]) -# ++ (if builtins.elem "cpp" env then cppPkgs else [ ]) -# ++ (if builtins.elem "gl" env then glPkgs else [ ]) -# ++ (if builtins.elem "python" env then pythonPkgs else [ ]) -# ++ (if builtins.elem "rocm" env then rocmPkgs else [ ]) -# ++ (if builtins.elem "asm" env then asm else [ ]); -# _vars = { -# LD_LIBRARY_PATH = pkgs.lib.makeLibraryPath ( -# if ((builtins.elem "cpp" env) || (builtins.elem "gl" env)) then -# [ -# pkgs.stdenv.cc.cc -# pkgs.zlib -# ] -# else -# [ ] -# ); -# }; -# envVars = builtins.listToAttrs ( -# builtins.map (varName: { -# name = varName; -# value = _vars.${varName}; -# }) (builtins.attrNames _vars) -# ); -# preShellHook = -# '' -# RED='\033[0;31m' -# GREEN='\033[0;32m' -# BLUE='\033[0;34m' -# NC='\033[0m' -# export SHELL=$(which zsh) -# '' -# + ( -# if builtins.elem "web" env then -# '' -# npm set prefix $HOME/.npm -# export PATH=$HOME/.npm/bin:$PATH -# '' -# else -# "" -# ); -# postShellHook = -# { name, cmd }: -# '' -# echo "" -# echo -e "${name} nix-shell activated: ''\${BLUE}$(which ${cmd})''\${NC}" -# exec $SHELL -# ''; -# -# in -# { -# inherit -# nativeBuildInputs -# envVars -# preShellHook -# postShellHook -# ; -# }