Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion .github/workflows/run_tests.yml
Original file line number Diff line number Diff line change
Expand Up @@ -63,5 +63,5 @@ jobs:
- name: Install and Test with pytest
run: |
export PATH="$pythonLocation:$PATH"
python -m pip install -e .[Dev,Orso]
python -m pip install -e .[dev,orso]
pytest tests/ --cov=ratapi --cov-report=term
3 changes: 3 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -40,3 +40,6 @@ dist/*

# Jupyter notebook checkpoints
.ipynb_checkpoints/*

# Lock file for uv env
uv.lock
4 changes: 2 additions & 2 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -12,11 +12,11 @@ To install in local directory:

matlabengine is an optional dependency only required for Matlab custom functions. The version of matlabengine should match the version of Matlab installed on the machine. This can be installed as shown below:

pip install -e .[Matlab-2023a]
pip install -e .[matlab-2023a]

Development dependencies can be installed as shown below

pip install -e .[Dev]
pip install -e .[dev]

To build wheel:

Expand Down
2 changes: 1 addition & 1 deletion cpp/RAT
Submodule RAT updated 1 files
+1 −5 groupLayersMod.cpp
50 changes: 49 additions & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,53 @@ requires = [
]
build-backend = 'setuptools.build_meta'

[project]
name = "ratapi"
version = "0.0.0.dev8"
description = "Python extension for the Reflectivity Analysis Toolbox (RAT)"
readme = "README.md"
requires-python = ">=3.10"
dependencies = [
"matplotlib>=3.8.3",
"numpy>=1.20",
"prettytable>=3.9.0",
"pydantic>=2.7.2",
"scipy>=1.13.1",
"strenum>=0.4.15 ; python_full_version < '3.11'",
"tqdm>=4.66.5",
]

[project.optional-dependencies]
dev = [
"pytest>=7.4.0",
"pytest-cov>=4.1.0",
"ruff>=0.4.10"
]
orso = [
"orsopy>=1.2.1",
"pint>=0.24.4"
]
matlab_latest = ["matlabengine"]
matlab_2025b = ["matlabengine == 25.2.*"]
matlab_2025a = ["matlabengine == 25.1.2"]
matlab_2024b = ["matlabengine == 24.2.2"]
matlab_2024a = ["matlabengine == 24.1.4"]
matlab_2023b = ["matlabengine == 23.2.3"]
matlab_2023a = ["matlabengine == 9.14.3"]

[tool.uv]
conflicts = [
[
{ extra = "matlab_latest" },
{ extra = "matlab_2025b" },
{ extra = "matlab_2025a" },
{ extra = "matlab_2024b" },
{ extra = "matlab_2024a" },
{ extra = "matlab_2023b" },
{ extra = "matlab_2023a" },
],
]

[tool.ruff]
line-length = 120
extend-exclude = ["*.ipynb"]
Expand All @@ -24,7 +71,8 @@ ignore = ["SIM103", # needless bool
"D105", # undocumented __init__
"D107", # undocumented magic method
"D203", # blank line before class docstring
"D213"] # multi line summary should start at second line
"D213", # multi line summary should start at second line
"UP038"] # non pep604 isinstance - to be removed

# ignore docstring lints in the tests and install script
[tool.ruff.lint.per-file-ignores]
Expand Down
19 changes: 10 additions & 9 deletions ratapi/classlist.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
import importlib
import warnings
from collections.abc import Sequence
from typing import Any, Generic, TypeVar, Union
from typing import Any, Generic, TypeVar

import numpy as np
import prettytable
Expand Down Expand Up @@ -38,7 +38,7 @@ class ClassList(collections.UserList, Generic[T]):

"""

def __init__(self, init_list: Union[Sequence[T], T] = None, name_field: str = "name") -> None:
def __init__(self, init_list: Sequence[T] | T = None, name_field: str = "name") -> None:
self.name_field = name_field

# Set input as list if necessary
Expand Down Expand Up @@ -114,7 +114,7 @@ def __str__(self):
output = str(self.data)
return output

def __getitem__(self, index: Union[int, slice, str, T]) -> T:
def __getitem__(self, index: int | slice | str | T) -> T:
"""Get an item by its index, name, a slice, or the object itself."""
if isinstance(index, (int, slice)):
return self.data[index]
Expand Down Expand Up @@ -262,12 +262,12 @@ def insert(self, index: int, obj: T = None, **kwargs) -> None:
self._validate_name_field(kwargs)
self.data.insert(index, self._class_handle(**kwargs))

def remove(self, item: Union[T, str]) -> None:
def remove(self, item: T | str) -> None:
"""Remove an object from the ClassList using either the object itself or its ``name_field`` value."""
item = self._get_item_from_name_field(item)
self.data.remove(item)

def count(self, item: Union[T, str]) -> int:
def count(self, item: T | str) -> int:
"""Return the number of times an object appears in the ClassList.

This method can use either the object itself or its ``name_field`` value.
Expand All @@ -276,7 +276,7 @@ def count(self, item: Union[T, str]) -> int:
item = self._get_item_from_name_field(item)
return self.data.count(item)

def index(self, item: Union[T, str], offset: bool = False, *args) -> int:
def index(self, item: T | str, offset: bool = False, *args) -> int:
"""Return the index of a particular object in the ClassList.

This method can use either the object itself or its ``name_field`` value.
Expand Down Expand Up @@ -309,7 +309,7 @@ def union(self, other: Sequence[T]) -> None:
]
)

def set_fields(self, index: Union[int, slice, str, T], **kwargs) -> None:
def set_fields(self, index: int | slice | str | T, **kwargs) -> None:
"""Assign the values of an existing object's attributes using keyword arguments."""
self._validate_name_field(kwargs)
pydantic_object = False
Expand Down Expand Up @@ -519,7 +519,7 @@ def _check_classes(self, input_list: Sequence[T]) -> None:
f"In the input list:\n{newline.join(error for error in error_list)}\n"
)

def _get_item_from_name_field(self, value: Union[T, str]) -> Union[T, str]:
def _get_item_from_name_field(self, value: T | str) -> T | str:
"""Return the object with the given value of the ``name_field`` attribute in the ClassList.

Parameters
Expand Down Expand Up @@ -577,11 +577,12 @@ def _determine_class_handle(input_list: Sequence[T]):
@classmethod
def __get_pydantic_core_schema__(cls, source: Any, handler):
# import here so that the ClassList can be instantiated and used without Pydantic installed
from typing import get_args, get_origin

from pydantic import ValidatorFunctionWrapHandler
from pydantic.types import (
core_schema, # import core_schema through here rather than making pydantic_core a dependency
)
from typing_extensions import get_args, get_origin

# if annotated with a class, get the item type of that class
origin = get_origin(source)
Expand Down
5 changes: 2 additions & 3 deletions ratapi/controls.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,6 @@
import tempfile
import warnings
from pathlib import Path
from typing import Union

import prettytable
from pydantic import (
Expand Down Expand Up @@ -233,7 +232,7 @@ def delete_IPC(self):
os.remove(self._IPCFilePath)
return None

def save(self, filepath: Union[str, Path] = "./controls.json"):
def save(self, filepath: str | Path = "./controls.json"):
"""Save a controls object to a JSON file.

Parameters
Expand All @@ -245,7 +244,7 @@ def save(self, filepath: Union[str, Path] = "./controls.json"):
filepath.write_text(self.model_dump_json())

@classmethod
def load(cls, path: Union[str, Path]) -> "Controls":
def load(cls, path: str | Path) -> "Controls":
"""Load a controls object from file.

Parameters
Expand Down
8 changes: 4 additions & 4 deletions ratapi/events.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,12 @@
"""Hooks for connecting to run callback events."""

import os
from typing import Callable, Union
from collections.abc import Callable

from ratapi.rat_core import EventBridge, EventTypes, PlotEventData, ProgressEventData


def notify(event_type: EventTypes, data: Union[str, PlotEventData, ProgressEventData]) -> None:
def notify(event_type: EventTypes, data: str | PlotEventData | ProgressEventData) -> None:
"""Call registered callbacks with data when event type has been triggered.

Parameters
Expand All @@ -22,7 +22,7 @@ def notify(event_type: EventTypes, data: Union[str, PlotEventData, ProgressEvent
callback(data)


def get_event_callback(event_type: EventTypes) -> list[Callable[[Union[str, PlotEventData, ProgressEventData]], None]]:
def get_event_callback(event_type: EventTypes) -> list[Callable[[str | PlotEventData | ProgressEventData], None]]:
"""Return all callbacks registered for the given event type.

Parameters
Expand All @@ -39,7 +39,7 @@ def get_event_callback(event_type: EventTypes) -> list[Callable[[Union[str, Plot
return list(__event_callbacks[event_type])


def register(event_type: EventTypes, callback: Callable[[Union[str, PlotEventData, ProgressEventData]], None]) -> None:
def register(event_type: EventTypes, callback: Callable[[str | PlotEventData | ProgressEventData], None]) -> None:
"""Register a new callback for the event type.

Parameters
Expand Down
4 changes: 2 additions & 2 deletions ratapi/inputs.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
import importlib
import os
import pathlib
from typing import Callable, Union
from collections.abc import Callable

import numpy as np

Expand All @@ -23,7 +23,7 @@
}


def get_python_handle(file_name: str, function_name: str, path: Union[str, pathlib.Path] = "") -> Callable:
def get_python_handle(file_name: str, function_name: str, path: str | pathlib.Path = "") -> Callable:
"""Get the function handle from a function defined in a python module located anywhere within the filesystem.

Parameters
Expand Down
14 changes: 7 additions & 7 deletions ratapi/outputs.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
import json
from dataclasses import dataclass
from pathlib import Path
from typing import Any, Optional, Union
from typing import Any, Union

import numpy as np

Expand Down Expand Up @@ -244,7 +244,7 @@ def __str__(self):
output += get_field_string(key, value, 100)
return output

def save(self, filepath: Union[str, Path] = "./results.json"):
def save(self, filepath: str | Path = "./results.json"):
"""Save the Results object to a JSON file.

Parameters
Expand All @@ -258,7 +258,7 @@ def save(self, filepath: Union[str, Path] = "./results.json"):
filepath.write_text(json.dumps(json_dict))

@classmethod
def load(cls, path: Union[str, Path]) -> Union["Results", "BayesResults"]:
def load(cls, path: str | Path) -> Union["Results", "BayesResults"]:
"""Load a Results object from file.

Parameters
Expand Down Expand Up @@ -538,7 +538,7 @@ class BayesResults(Results):
nestedSamplerOutput: NestedSamplerOutput
chain: np.ndarray

def save(self, filepath: Union[str, Path] = "./results.json"):
def save(self, filepath: str | Path = "./results.json"):
"""Save the BayesResults object to a JSON file.

Parameters
Expand Down Expand Up @@ -574,7 +574,7 @@ def save(self, filepath: Union[str, Path] = "./results.json"):
filepath.write_text(json.dumps(json_dict))


def write_core_results_fields(results: Union[Results, BayesResults], json_dict: Optional[dict] = None) -> dict:
def write_core_results_fields(results: Results | BayesResults, json_dict: dict | None = None) -> dict:
"""Modify the values of the fields that appear in both Results and BayesResults when saving to a json file.

Parameters
Expand Down Expand Up @@ -684,8 +684,8 @@ def read_bayes_results_fields(results_dict: dict) -> dict:
def make_results(
procedure: Procedures,
output_results: ratapi.rat_core.OutputResult,
bayes_results: Optional[ratapi.rat_core.OutputBayesResult] = None,
) -> Union[Results, BayesResults]:
bayes_results: ratapi.rat_core.OutputBayesResult | None = None,
) -> Results | BayesResults:
"""Initialise a python Results or BayesResults object using the outputs from a RAT calculation.

Parameters
Expand Down
26 changes: 13 additions & 13 deletions ratapi/project.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,10 +5,11 @@
import functools
import json
import warnings
from collections.abc import Callable
from enum import Enum
from pathlib import Path
from textwrap import indent
from typing import Annotated, Any, Callable, Union
from typing import Annotated, Any, get_args, get_origin

import numpy as np
from pydantic import (
Expand All @@ -21,7 +22,6 @@
field_validator,
model_validator,
)
from typing_extensions import get_args, get_origin

import ratapi.models
from ratapi.classlist import ClassList
Expand Down Expand Up @@ -248,10 +248,10 @@ class Project(BaseModel, validate_assignment=True, extra="forbid", use_attribute
data: ClassList[ratapi.models.Data] = ClassList()
"""Experimental data for a model."""

layers: Union[
Annotated[ClassList[ratapi.models.Layer], Tag("no_abs")],
Annotated[ClassList[ratapi.models.AbsorptionLayer], Tag("abs")],
] = Field(
layers: (
Annotated[ClassList[ratapi.models.Layer], Tag("no_abs")]
| Annotated[ClassList[ratapi.models.AbsorptionLayer], Tag("abs")]
) = Field(
default=ClassList(),
discriminator=Discriminator(
discriminate_layers,
Expand All @@ -265,10 +265,10 @@ class Project(BaseModel, validate_assignment=True, extra="forbid", use_attribute
domain_contrasts: ClassList[ratapi.models.DomainContrast] = ClassList()
"""The groups of layers required by each domain in a domains model."""

contrasts: Union[
Annotated[ClassList[ratapi.models.Contrast], Tag("no_ratio")],
Annotated[ClassList[ratapi.models.ContrastWithRatio], Tag("ratio")],
] = Field(
contrasts: (
Annotated[ClassList[ratapi.models.Contrast], Tag("no_ratio")]
| Annotated[ClassList[ratapi.models.ContrastWithRatio], Tag("ratio")]
) = Field(
default=ClassList(),
discriminator=Discriminator(
discriminate_contrasts,
Expand Down Expand Up @@ -577,7 +577,7 @@ def update_renamed_models(self) -> "Project":
old_names = self._all_names[class_list]
new_names = getattr(self, class_list).get_names()
if len(old_names) == len(new_names):
name_diff = [(old, new) for (old, new) in zip(old_names, new_names) if old != new]
name_diff = [(old, new) for (old, new) in zip(old_names, new_names, strict=False) if old != new]
for old_name, new_name in name_diff:
for field in fields_to_update:
project_field = getattr(self, field.attribute)
Expand Down Expand Up @@ -927,7 +927,7 @@ def classlist_script(name, classlist):
+ "\n)"
)

def save(self, filepath: Union[str, Path] = "./project.json"):
def save(self, filepath: str | Path = "./project.json"):
"""Save a project to a JSON file.

Parameters
Expand Down Expand Up @@ -973,7 +973,7 @@ def make_custom_file_dict(item):
filepath.write_text(json.dumps(json_dict))

@classmethod
def load(cls, path: Union[str, Path]) -> "Project":
def load(cls, path: str | Path) -> "Project":
"""Load a project from file.

Parameters
Expand Down
Loading