Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 6 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -79,6 +79,12 @@ You can also create a movie of a single field quantity (can be custom):
(data.fields.Ex * data.fields.Bx).sel(x=slice(None, 0.2)).movie.plot(name="ExBx", vmin=-0.01, vmax=0.01, cmap="BrBG")
```

For particles, one can also make 2D phase-space plots:

```python
data.particles[1].sel(t=1.0, method="nearest").particles.phaseplot(x="x", y="uy", xnbins=100, ynbins=200, xlims=(0, 100), cmap="inferno")
```

You may also combine different quantities and plots (e.g., fields & particles) to produce a more customized movie:

```python
Expand Down
Binary file added dist/nt2py-1.1.0-py3-none-any.whl
Binary file not shown.
Binary file added dist/nt2py-1.1.0.tar.gz
Binary file not shown.
2 changes: 1 addition & 1 deletion nt2/__init__.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
__version__ = "1.0.1"
__version__ = "1.1.0"

import nt2.containers.data as nt2_data

Expand Down
11 changes: 10 additions & 1 deletion nt2/containers/data.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,15 @@
from typing import Callable, Any

import sys

if sys.version_info >= (3, 12):
from typing import override
else:

def override(method):
return method


from collections.abc import KeysView
from nt2.utils import ToHumanReadable

Expand All @@ -25,7 +28,7 @@ def override(method):
from nt2.containers.particles import Particles

import nt2.plotters.polar as acc_polar

import nt2.plotters.particles as acc_particles
import nt2.plotters.inspect as acc_inspect
import nt2.plotters.movie as acc_movie
from nt2.plotters.export import makeFramesAndMovie
Expand All @@ -37,6 +40,12 @@ class DatasetPolarPlotAccessor(acc_polar.ds_accessor):
pass


@xr.register_dataset_accessor("particles")
@InheritClassDocstring
class DatasetParticlesPlotAccessor(acc_particles.ds_accessor):
pass


@xr.register_dataarray_accessor("polar")
@InheritClassDocstring
class PolarPlotAccessor(acc_polar.accessor):
Expand Down
4 changes: 2 additions & 2 deletions nt2/plotters/inspect.py
Original file line number Diff line number Diff line change
Expand Up @@ -180,7 +180,7 @@ def plot(
)

def plot_func(ti: int, _):
if len(self._obj.dims) == 1:
if len(self._obj.isel(t=ti).dims) == 1:
_ = self.plot_frame_1d(
self._obj.isel(t=ti),
None,
Expand All @@ -189,7 +189,7 @@ def plot_func(ti: int, _):
fig_kwargs,
plot_kwargs,
)
elif len(self._obj.dims) == 2:
elif len(self._obj.isel(t=ti).dims) == 2:
_ = self.plot_frame_2d(
self._obj.isel(t=ti),
None,
Expand Down
102 changes: 102 additions & 0 deletions nt2/plotters/particles.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,102 @@
import xarray as xr
import numpy as np


class ds_accessor:
def __init__(self, xarray_obj: xr.Dataset):
self._obj: xr.Dataset = xarray_obj

def phaseplot(
self,
x: str = "x",
y: str = "ux",
xbins: None | np.ndarray = None,
ybins: None | np.ndarray = None,
xlims: None | tuple[float] = None,
ylims: None | tuple[float] = None,
xnbins: int = 100,
ynbins: int = 100,
**kwargs,
):
"""
Create a 2D histogram (phase plot) of two variables in the dataset.

Parameters
----------
x : str
The variable name for the x-axis (default: "x").
y : str
The variable name for the y-axis (default: "ux").
xbins : np.ndarray, optional
The bin edges for the x-axis. If None, 100 bins between min and max of x are used.
ybins : np.ndarray, optional
The bin edges for the y-axis. If None, 100 bins between min and max of y are used.
xlims : tuple[float], optional
The limits for the x-axis. If None, the limits are determined from the data.
ylims : tuple[float], optional
The limits for the y-axis. If None, the limits are determined from the data.
xnbins : int, optional
The number of bins for the x-axis if xbins is None (default: 100).
ynbins : int, optional
The number of bins for the y-axis if ybins is None (default: 100).
**kwargs
Additional keyword arguments passed to matplotlib's pcolormesh.

Raises
------
AssertionError
If x or y are not valid variable names in the dataset, or if the dataset has a time dimension.

Returns
-------
None

Examples
--------
>>> ds.phaseplot(x='x', y='ux', xbins=np.linspace(0, 1000, 100), ybins=np.linspace(-5, 5, 50))
"""
assert x in list(self._obj.keys()) and y in list(
self._obj.keys()
), "x and y must be valid variable names in the dataset"
assert (
len(self._obj[x].dims) == 1 and len(self._obj[y].dims) == 1
), "x and y must be 1D variables"
assert "t" not in self._obj.dims, "Dataset must not have time dimension"

import matplotlib.pyplot as plt

if xbins is None:
if xlims is not None:
xbins_ = np.linspace(xlims[0], xlims[1], xnbins)
else:
xbins_ = np.linspace(
self._obj[x].values.min(), self._obj[x].values.max(), xnbins
)
else:
xbins_ = xbins
if ybins is None:
if ylims is not None:
ybins_ = np.linspace(ylims[0], ylims[1], ynbins)
else:
ybins_ = np.linspace(
self._obj[y].values.min(), self._obj[y].values.max(), ynbins
)
else:
ybins_ = ybins

cnt, _, _ = np.histogram2d(
self._obj[x].values, self._obj[y].values, bins=[xbins_, ybins_]
)
xbins_ = 0.5 * (xbins_[1:] + xbins_[:-1])
ybins_ = 0.5 * (ybins_[1:] + ybins_[:-1])

ax = kwargs.pop("ax", plt.gca())
ax.pcolormesh(
xbins_,
ybins_,
cnt.T,
rasterized=True,
**kwargs,
)
ax.set_xlabel(x)
ax.set_ylabel(y)
4 changes: 1 addition & 3 deletions nt2/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -93,9 +93,7 @@ def DataIs2DPolar(ds: xr.Dataset) -> bool:
bool
True if the dataset is 2D polar, False otherwise.
"""
return ("r" in ds.dims and ("θ" in ds.dims or "th" in ds.dims)) and len(
ds.dims
) == 2
return ("r" in ds.dims and "th" in ds.dims) and len(ds.dims) == 2


def InheritClassDocstring(cls: type) -> type:
Expand Down
2 changes: 1 addition & 1 deletion shell.nix
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@ pkgs.mkShell {
pkgs."python${py}"
pkgs."python${py}Packages".pip
black
basedpyright
pyright
taplo
vscode-langservers-extracted
zlib
Expand Down