diff --git a/README.md b/README.md index 9c0e049..f994ee0 100644 --- a/README.md +++ b/README.md @@ -79,6 +79,12 @@ You can also create a movie of a single field quantity (can be custom): (data.fields.Ex * data.fields.Bx).sel(x=slice(None, 0.2)).movie.plot(name="ExBx", vmin=-0.01, vmax=0.01, cmap="BrBG") ``` +For particles, one can also make 2D phase-space plots: + +```python +data.particles[1].sel(t=1.0, method="nearest").particles.phaseplot(x="x", y="uy", xnbins=100, ynbins=200, xlims=(0, 100), cmap="inferno") +``` + You may also combine different quantities and plots (e.g., fields & particles) to produce a more customized movie: ```python diff --git a/dist/nt2py-1.1.0-py3-none-any.whl b/dist/nt2py-1.1.0-py3-none-any.whl new file mode 100644 index 0000000..580236c Binary files /dev/null and b/dist/nt2py-1.1.0-py3-none-any.whl differ diff --git a/dist/nt2py-1.1.0.tar.gz b/dist/nt2py-1.1.0.tar.gz new file mode 100644 index 0000000..f41e835 Binary files /dev/null and b/dist/nt2py-1.1.0.tar.gz differ diff --git a/nt2/__init__.py b/nt2/__init__.py index 4cce7aa..5a81e39 100644 --- a/nt2/__init__.py +++ b/nt2/__init__.py @@ -1,4 +1,4 @@ -__version__ = "1.0.1" +__version__ = "1.1.0" import nt2.containers.data as nt2_data diff --git a/nt2/containers/data.py b/nt2/containers/data.py index cd0bec9..d706cee 100644 --- a/nt2/containers/data.py +++ b/nt2/containers/data.py @@ -1,12 +1,15 @@ from typing import Callable, Any import sys + if sys.version_info >= (3, 12): from typing import override else: + def override(method): return method + from collections.abc import KeysView from nt2.utils import ToHumanReadable @@ -25,7 +28,7 @@ def override(method): from nt2.containers.particles import Particles import nt2.plotters.polar as acc_polar - +import nt2.plotters.particles as acc_particles import nt2.plotters.inspect as acc_inspect import nt2.plotters.movie as acc_movie from nt2.plotters.export import makeFramesAndMovie @@ -37,6 +40,12 @@ class DatasetPolarPlotAccessor(acc_polar.ds_accessor): pass +@xr.register_dataset_accessor("particles") +@InheritClassDocstring +class DatasetParticlesPlotAccessor(acc_particles.ds_accessor): + pass + + @xr.register_dataarray_accessor("polar") @InheritClassDocstring class PolarPlotAccessor(acc_polar.accessor): diff --git a/nt2/plotters/inspect.py b/nt2/plotters/inspect.py index ea5d91b..d9d9ab7 100644 --- a/nt2/plotters/inspect.py +++ b/nt2/plotters/inspect.py @@ -180,7 +180,7 @@ def plot( ) def plot_func(ti: int, _): - if len(self._obj.dims) == 1: + if len(self._obj.isel(t=ti).dims) == 1: _ = self.plot_frame_1d( self._obj.isel(t=ti), None, @@ -189,7 +189,7 @@ def plot_func(ti: int, _): fig_kwargs, plot_kwargs, ) - elif len(self._obj.dims) == 2: + elif len(self._obj.isel(t=ti).dims) == 2: _ = self.plot_frame_2d( self._obj.isel(t=ti), None, diff --git a/nt2/plotters/particles.py b/nt2/plotters/particles.py new file mode 100644 index 0000000..b4bbe1b --- /dev/null +++ b/nt2/plotters/particles.py @@ -0,0 +1,102 @@ +import xarray as xr +import numpy as np + + +class ds_accessor: + def __init__(self, xarray_obj: xr.Dataset): + self._obj: xr.Dataset = xarray_obj + + def phaseplot( + self, + x: str = "x", + y: str = "ux", + xbins: None | np.ndarray = None, + ybins: None | np.ndarray = None, + xlims: None | tuple[float] = None, + ylims: None | tuple[float] = None, + xnbins: int = 100, + ynbins: int = 100, + **kwargs, + ): + """ + Create a 2D histogram (phase plot) of two variables in the dataset. + + Parameters + ---------- + x : str + The variable name for the x-axis (default: "x"). + y : str + The variable name for the y-axis (default: "ux"). + xbins : np.ndarray, optional + The bin edges for the x-axis. If None, 100 bins between min and max of x are used. + ybins : np.ndarray, optional + The bin edges for the y-axis. If None, 100 bins between min and max of y are used. + xlims : tuple[float], optional + The limits for the x-axis. If None, the limits are determined from the data. + ylims : tuple[float], optional + The limits for the y-axis. If None, the limits are determined from the data. + xnbins : int, optional + The number of bins for the x-axis if xbins is None (default: 100). + ynbins : int, optional + The number of bins for the y-axis if ybins is None (default: 100). + **kwargs + Additional keyword arguments passed to matplotlib's pcolormesh. + + Raises + ------ + AssertionError + If x or y are not valid variable names in the dataset, or if the dataset has a time dimension. + + Returns + ------- + None + + Examples + -------- + >>> ds.phaseplot(x='x', y='ux', xbins=np.linspace(0, 1000, 100), ybins=np.linspace(-5, 5, 50)) + """ + assert x in list(self._obj.keys()) and y in list( + self._obj.keys() + ), "x and y must be valid variable names in the dataset" + assert ( + len(self._obj[x].dims) == 1 and len(self._obj[y].dims) == 1 + ), "x and y must be 1D variables" + assert "t" not in self._obj.dims, "Dataset must not have time dimension" + + import matplotlib.pyplot as plt + + if xbins is None: + if xlims is not None: + xbins_ = np.linspace(xlims[0], xlims[1], xnbins) + else: + xbins_ = np.linspace( + self._obj[x].values.min(), self._obj[x].values.max(), xnbins + ) + else: + xbins_ = xbins + if ybins is None: + if ylims is not None: + ybins_ = np.linspace(ylims[0], ylims[1], ynbins) + else: + ybins_ = np.linspace( + self._obj[y].values.min(), self._obj[y].values.max(), ynbins + ) + else: + ybins_ = ybins + + cnt, _, _ = np.histogram2d( + self._obj[x].values, self._obj[y].values, bins=[xbins_, ybins_] + ) + xbins_ = 0.5 * (xbins_[1:] + xbins_[:-1]) + ybins_ = 0.5 * (ybins_[1:] + ybins_[:-1]) + + ax = kwargs.pop("ax", plt.gca()) + ax.pcolormesh( + xbins_, + ybins_, + cnt.T, + rasterized=True, + **kwargs, + ) + ax.set_xlabel(x) + ax.set_ylabel(y) diff --git a/nt2/utils.py b/nt2/utils.py index 5e5d233..d85cece 100644 --- a/nt2/utils.py +++ b/nt2/utils.py @@ -93,9 +93,7 @@ def DataIs2DPolar(ds: xr.Dataset) -> bool: bool True if the dataset is 2D polar, False otherwise. """ - return ("r" in ds.dims and ("θ" in ds.dims or "th" in ds.dims)) and len( - ds.dims - ) == 2 + return ("r" in ds.dims and "th" in ds.dims) and len(ds.dims) == 2 def InheritClassDocstring(cls: type) -> type: diff --git a/shell.nix b/shell.nix index f8148ca..d7ee1a0 100644 --- a/shell.nix +++ b/shell.nix @@ -9,7 +9,7 @@ pkgs.mkShell { pkgs."python${py}" pkgs."python${py}Packages".pip black - basedpyright + pyright taplo vscode-langservers-extracted zlib