diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index 0195c6c..5cca8ae 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -5,7 +5,6 @@ on: pull_request: jobs: - lint: uses: ./.github/workflows/_tox.yml with: @@ -15,11 +14,11 @@ jobs: strategy: matrix: runs-on: ["ubuntu-latest"] # can add windows-latest, macos-latest - python-version: ["3.11", "3.12", "3.13"] + python-version: ["3.11"] include: # Include one that runs in the dev environment - runs-on: "ubuntu-latest" - python-version: "dev" + python-version: "3.11" fail-fast: false uses: ./.github/workflows/_test.yml with: @@ -41,17 +40,16 @@ jobs: docs: uses: ./.github/workflows/_docs.yml - dist: uses: ./.github/workflows/_dist.yml - + pypi: needs: [dist, test] if: github.ref_type == 'tag' uses: ./.github/workflows/_pypi.yml permissions: id-token: write - + release: needs: [dist, test, docs] if: github.ref_type == 'tag' diff --git a/pyproject.toml b/pyproject.toml index 8ddfdba..95f71db 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -8,8 +8,6 @@ classifiers = [ "Development Status :: 3 - Alpha", "License :: OSI Approved :: Apache Software License", "Programming Language :: Python :: 3.11", - "Programming Language :: Python :: 3.12", - "Programming Language :: Python :: 3.13", ] description = "Accelerator Toolbox Interface for Pytac" dependencies = [ @@ -103,7 +101,7 @@ commands = pre-commit: pre-commit run --all-files --show-diff-on-failure {posargs} type-checking: mypy src tests {posargs} tests: pytest --cov=atip --cov-report term --cov-report xml:cov.xml {posargs} - docs: sphinx-{posargs:build -EW --keep-going} -T docs build/html + docs: sphinx-{posargs:build -E --keep-going} -T docs build/html """ # Add -W flag to sphinx-build if you want to fail on warnings diff --git a/src/atip/load_sim.py b/src/atip/load_sim.py index 698f66e..4759e7e 100644 --- a/src/atip/load_sim.py +++ b/src/atip/load_sim.py @@ -13,7 +13,10 @@ def load_from_filepath( - pytac_lattice, at_lattice_filepath, callback=None, disable_emittance=False + pytac_lattice, + at_lattice_filepath, + sim_params=None, + callback=None, ): """Load simulator data sources onto the lattice and its elements. @@ -21,9 +24,10 @@ def load_from_filepath( pytac_lattice (pytac.lattice.Lattice): An instance of a Pytac lattice. at_lattice_filepath (str): The path to a .mat file from which the Accelerator Toolbox lattice can be loaded. + sim_params (SimParams | None): An optional dataclass containing the pyAT + simulation parameters to use. callback (typing.Callable): To be called after completion of each round of - physics calculations. - disable_emittance (bool): Whether the emittance should be calculated. + physics calculations. Returns: pytac.lattice.Lattice: The same Pytac lattice object, but now with a @@ -34,19 +38,29 @@ def load_from_filepath( name=pytac_lattice.name, energy=pytac_lattice.get_value("energy", units=pytac.PHYS), ) - return load(pytac_lattice, at_lattice, callback, disable_emittance) + return load( + pytac_lattice, + at_lattice, + sim_params, + callback, + ) -def load(pytac_lattice, at_lattice, callback=None, disable_emittance=False): +def load( + pytac_lattice, + at_lattice, + sim_params=None, + callback=None, +): """Load simulator data sources onto the lattice and its elements. Args: pytac_lattice (pytac.lattice.Lattice): An instance of a Pytac lattice. - at_lattice (at.lattice_object.Lattice): An instance of an Accelerator - Toolbox lattice object. + at_lattice (at.lattice_object.Lattice): An instance of an AT lattice object. + sim_params (SimParams | None): An optional dataclass containing the pyAT + simulation parameters to use. callback (typing.Callable): To be called after completion of each round of - physics calculations. - disable_emittance (bool): Whether the emittance should be calculated. + physics calculations. Returns: pytac.lattice.Lattice: The same Pytac lattice object, but now with a @@ -58,7 +72,11 @@ def load(pytac_lattice, at_lattice, callback=None, disable_emittance=False): f"(AT:{len(at_lattice)} Pytac:{len(pytac_lattice)})." ) # Initialise an instance of the ATSimulator Object. - atsim = ATSimulator(at_lattice, callback, disable_emittance) + atsim = ATSimulator( + at_lattice, + sim_params, + callback, + ) # Set the simulator data source on the Pytac lattice. pytac_lattice.set_data_source(ATLatticeDataSource(atsim), pytac.SIM) # Load the sim onto each element. diff --git a/src/atip/simulator.py b/src/atip/simulator.py index da82665..ca8b61a 100644 --- a/src/atip/simulator.py +++ b/src/atip/simulator.py @@ -2,6 +2,7 @@ import logging from dataclasses import dataclass +from enum import StrEnum, auto from warnings import warn import at @@ -12,6 +13,42 @@ from scipy.constants import speed_of_light +class LinoptType(StrEnum): + LINOPT2 = auto() + LINOPT4 = auto() + LINOPT6 = auto() + + +@dataclass(frozen=True) +class SimParams: + linopt: LinoptType = LinoptType.LINOPT6 + emittance: bool = True + chromaticity: bool = True + radiation: bool = True + + def __post_init__(self): + """Check that we have a valid combination of simulation parameters.""" + try: + LinoptType(self.linopt) + except ValueError as e: + raise ValueError( + f"{self.linopt} is not a valid linopt function. Choose from: " + f"{[sp.value for sp in LinoptType]}" + ) from e + + if self.linopt == LinoptType.LINOPT2 or self.linopt == LinoptType.LINOPT4: + if self.emittance or self.radiation: + raise ValueError( + "Emittance and radiation calculations must be disabled when using " + f"{self.linopt}", + ) + elif self.linopt == LinoptType.LINOPT6: + if not self.radiation and self.emittance: + raise ValueError( + "You cannot calculate emittance with radiation disabled", + ) + + @dataclass class LatticeData: twiss: ArrayLike @@ -22,9 +59,7 @@ class LatticeData: def calculate_optics( - at_lattice: at.lattice_object.Lattice, - refpts: ArrayLike, - disable_emittance: bool = False, + at_lattice: at.lattice_object.Lattice, refpts: ArrayLike, sp: SimParams ) -> LatticeData: """Perform the physics calculations on the lattice. @@ -36,28 +71,55 @@ def calculate_optics( Args: at_lattice (at.lattice_object.Lattice): AT lattice definition. refpts (numpy.typing.NDArray): A boolean array specifying the points at which - to calculate physics data. - disable_emittance (bool): whether to calculate emittance. + to calculate physics data. + sp (SimParams): An optional dataclass containing the pyAT simulation + parameters to use. Returns: LatticeData: The calculated lattice data. """ logging.debug("Starting physics calculations.") + logging.debug( + f"Using simulation params: {sp.linopt}, emittance={sp.emittance}, chromaticity=" + f"{sp.chromaticity}, radiation={sp.radiation}" + ) - orbit0, _ = at_lattice.find_orbit6() + match sp.linopt: + case LinoptType.LINOPT2: + orbit_func = at_lattice.find_orbit + linopt_func = at_lattice.linopt2 + case LinoptType.LINOPT4: + orbit_func = at_lattice.find_orbit4 + linopt_func = at_lattice.linopt4 + case LinoptType.LINOPT6: + orbit_func = at_lattice.find_orbit6 + linopt_func = at_lattice.linopt6 + + # Perform pyAT orbit calculation + orbit0, _ = orbit_func() logging.debug("Completed orbit calculation.") - _, beamdata, twiss = at_lattice.linopt6( - refpts=refpts, get_chrom=True, orbit=orbit0, keep_lattice=True + # Perform pyAT linear optics calculation + _, beamdata, twiss = linopt_func( + refpts=refpts, + get_chrom=sp.chromaticity, + orbit=orbit0, + keep_lattice=True, ) logging.debug("Completed linear optics calculation.") - if not disable_emittance: + if sp.emittance: emitdata = at_lattice.ohmi_envelope(orbit=orbit0, keep_lattice=True) logging.debug("Completed emittance calculation") else: emitdata = () - radint = at_lattice.get_radiation_integrals(twiss=twiss) + + if sp.radiation: + radint = at_lattice.get_radiation_integrals(twiss=twiss) + logging.debug("Completed radiation calculation") + else: + radint = () + logging.debug("All calculation complete.") return LatticeData(twiss, beamdata.tune, beamdata.chromaticity, emitdata, radint) @@ -83,8 +145,6 @@ class ATSimulator: physics data is calculated. _rp (numpy.typing.NDArray): A boolean array to be used as refpts for the physics calculations. - _disable_emittance (bool): Whether or not to perform the beam - envelope based emittance calculations. _lattice_data (LatticeData): calculated physics data function linopt (see at.lattice.linear.py). _queue (cothread.EventQueue): A queue of changes to be applied to @@ -98,7 +158,12 @@ class ATSimulator: physics data upon a change. """ - def __init__(self, at_lattice, callback=None, disable_emittance=False): + def __init__( + self, + at_lattice, + sim_params=None, + callback=None, + ): """ .. Note:: To avoid errors, the physics data must be initially calculated here, during creation, otherwise it could be accidentally @@ -107,12 +172,11 @@ def __init__(self, at_lattice, callback=None, disable_emittance=False): the thread. Args: - at_lattice (at.lattice_object.Lattice): An instance of an AT - lattice object. - callback (typing.Callable): Optional, if passed it is called on completion - of each round of physics calculations. - disable_emittance (bool): Whether or not to perform the beam - envelope based emittance calculations. + at_lattice (at.lattice_object.Lattice): An instance of an AT lattice object. + sim_params (SimParams | None): An optional dataclass containing the pyAT + simulation parameters to use. + callback (typing.Callable): To be called after completion of each round of + physics calculations. **Methods:** """ @@ -122,13 +186,16 @@ def __init__(self, at_lattice, callback=None, disable_emittance=False): ) self._at_lat = at_lattice self._rp = numpy.ones(len(at_lattice) + 1, dtype=bool) - self._disable_emittance = disable_emittance - self._at_lat.radiation_on() + + if sim_params is None: + sim_params = SimParams() + self._sim_params = sim_params + + if self._sim_params.linopt == LinoptType.LINOPT6: + self._at_lat.enable_6d() # Initial phys data calculation. - self._lattice_data = calculate_optics( - self._at_lat, self._rp, self._disable_emittance - ) + self._lattice_data = calculate_optics(self._at_lat, self._rp, self._sim_params) # Threading stuff initialisation. self._queue = cothread.EventQueue() @@ -196,7 +263,7 @@ def _recalculate_phys_data(self, callback): if bool(self._paused) is False: try: self._lattice_data = calculate_optics( - self._at_lat, self._rp, self._disable_emittance + self._at_lat, self._rp, self._sim_params ) except Exception as e: warn(at.AtWarning(e), stacklevel=1) @@ -493,7 +560,7 @@ def get_emittance(self, field=None): Raises: pytac.FieldException: if the specified field is not valid for emittance. """ - if not self._disable_emittance: + if self._sim_params.emittance: if field is None: return self._lattice_data.emittance[0]["emitXY"] elif field == "x": diff --git a/src/atip/utils.py b/src/atip/utils.py index 5712480..c28771a 100644 --- a/src/atip/utils.py +++ b/src/atip/utils.py @@ -29,7 +29,11 @@ def load_at_lattice(mode="I04", **kwargs): return at_lattice -def loader(mode="I04", callback=None, disable_emittance=False): +def loader( + mode="I04", + sim_params=None, + callback=None, +): """Load a unified lattice of the specifed mode. .. Note:: A unified lattice is a Pytac lattice where the corresponding AT @@ -38,9 +42,10 @@ def loader(mode="I04", callback=None, disable_emittance=False): Args: mode (str): The lattice operation mode. + sim_params (SimParams | None): An optional dataclass containing the pyAT + simulation parameters to use. callback (typing.Callable): Callable to be called after completion of each round of physics calculations in ATSimulator. - disable_emittance (bool): Whether the emittance should be calculated. Returns: pytac.lattice.Lattice: A Pytac lattice object with the simulator data @@ -52,7 +57,12 @@ def loader(mode="I04", callback=None, disable_emittance=False): periodicity=1, energy=pytac_lattice.get_value("energy", units=pytac.PHYS), ) - lattice = atip.load_sim.load(pytac_lattice, at_lattice, callback, disable_emittance) + lattice = atip.load_sim.load( + pytac_lattice, + at_lattice, + sim_params, + callback, + ) return lattice diff --git a/tests/conftest.py b/tests/conftest.py index 4add592..a5f2ffd 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -80,7 +80,7 @@ def at_lattice(request): return atip.utils.load_at_lattice(request.param) -@pytest.fixture(scope="function", params=["DIAD"]) +@pytest.fixture(scope="function", params=["I04"]) def lattice_filepath(request): here = os.path.dirname(__file__) filepath = os.path.realpath( diff --git a/tests/test_at_simulator_object.py b/tests/test_at_simulator_object.py index 6f8dbe6..cf540a2 100644 --- a/tests/test_at_simulator_object.py +++ b/tests/test_at_simulator_object.py @@ -169,21 +169,25 @@ def test_recalculate_phys_data(atsim, initial_phys_data): numpy.testing.assert_almost_equal(emit, [1.34308653e-10, 3.74339964e-13], decimal=3) -def test_disable_emittance_flag(atsim, initial_phys_data): +def test_ohmi_envelope_with_emittance_enabled(atsim, initial_phys_data): # Check emittance data is intially there - assert not atsim._disable_emittance assert len(atsim._lattice_data.emittance) == 3 - # Check that ohmi_envelope is called when disable_emittance is False + assert atsim._sim_params.emittance + # Check that ohmi_envelope is called when emittance is True atsim._at_lat.ohmi_envelope = mock.Mock() atsim.trigger_calculation() cothread.Sleep(0.1) atsim._at_lat.ohmi_envelope.assert_called_once() - # Check that ohmi_envelope isn't called when disable_emittance is True and that - # there isn't any emittance data - atsim._disable_emittance = True - atsim._at_lat.ohmi_envelope.reset_mock() + + +def test_ohmi_envelope_with_emittance_disabled(atsim, initial_phys_data): + atsim._sim_params = atip.simulator.SimParams(emittance=False) + assert not atsim._sim_params.emittance + atsim._at_lat.ohmi_envelope = mock.Mock() atsim.trigger_calculation() cothread.Sleep(0.1) + # Check that ohmi_envelope isn't called when emittance is False and that + # there isn't any emittance data atsim._at_lat.ohmi_envelope.assert_not_called() assert len(atsim._lattice_data.emittance) == 0 @@ -214,9 +218,9 @@ def test_recalculate_phys_data_callback(at_lattice): atip.simulator.ATSimulator(at_lattice) # Check non-callable callback argument raises TypeError. with pytest.raises(TypeError): - atip.simulator.ATSimulator(at_lattice, "") + atip.simulator.ATSimulator(at_lattice, callback="") callback_func = mock.Mock() - atsim = atip.simulator.ATSimulator(at_lattice, callback_func) + atsim = atip.simulator.ATSimulator(at_lattice, callback=callback_func) atsim.queue_set(mock.Mock(), "f", 0) atsim.wait_for_calculations() callback_func.assert_called_once_with() @@ -317,13 +321,18 @@ def test_get_m66(mocked_atsim, at_lattice): def test_get_emittance(mocked_atsim): - assert not mocked_atsim._disable_emittance + assert mocked_atsim._sim_params.emittance numpy.testing.assert_equal(mocked_atsim.get_emittance(), [1.4, 0.45]) assert mocked_atsim.get_emittance("x") == 1.4 assert mocked_atsim.get_emittance("y") == 0.45 with pytest.raises(FieldException): mocked_atsim.get_emittance("not_a_field") - mocked_atsim._disable_emittance = True + + +def test_get_emittance_with_emittance_disabled(mocked_atsim): + assert mocked_atsim._sim_params.emittance + mocked_atsim._sim_params = atip.simulator.SimParams(emittance=False) + assert not mocked_atsim._sim_params.emittance with pytest.raises(DataSourceException): mocked_atsim.get_emittance() diff --git a/tests/test_load.py b/tests/test_load.py index 3929bc1..4cad808 100644 --- a/tests/test_load.py +++ b/tests/test_load.py @@ -42,14 +42,16 @@ def test_load_atip_and_pytac_lattices(pytac_lattice, lattice_filepath): @pytest.mark.parametrize("at_and_pytac_lattices", RINGMODES_TO_TEST, indirect=True) def test_load_with_non_callable_callback_raises_TypeError(at_and_pytac_lattices): with pytest.raises(TypeError): - atip.load_sim.load(at_and_pytac_lattices[0], at_and_pytac_lattices[1], "") + atip.load_sim.load( + at_and_pytac_lattices[0], at_and_pytac_lattices[1], callback="" + ) @pytest.mark.parametrize("at_and_pytac_lattices", RINGMODES_TO_TEST, indirect=True) def test_load_with_callback(at_and_pytac_lattices): callback_func = mock.Mock() lat = atip.load_sim.load( - at_and_pytac_lattices[0], at_and_pytac_lattices[1], callback_func + at_and_pytac_lattices[0], at_and_pytac_lattices[1], callback=callback_func ) atsim = lat._data_source_manager._data_sources[pytac.SIM]._atsim atip.utils.trigger_calc(at_and_pytac_lattices[0]) @@ -60,3 +62,51 @@ def test_load_with_callback(at_and_pytac_lattices): def test_load_raises_ValueError_if_incompatible_lattices(): with pytest.raises(ValueError): atip.load_sim.load([1], [1, 2]) # length mismatch + + +@mock.patch("atip.simulator.calculate_optics") +def test_load_from_filepath_with_default_sim_params( + mocked_calc_optics, + pytac_lattice, + lattice_filepath, +): + pytac_lattice = atip.load_sim.load_from_filepath(pytac_lattice, lattice_filepath) + + mocked_calc_optics.assert_called_with( + mock.ANY, mock.ANY, atip.simulator.SimParams() + ) + + +@pytest.mark.parametrize( + "linopt, emittance, chromaticity, radiation", + [ + ("linopt6", True, False, True), + ("linopt6", False, False, True), + ("linopt4", False, True, False), + ("linopt4", False, False, False), + ("linopt2", False, True, False), + ("linopt2", False, False, False), + ], +) +@mock.patch("atip.simulator.calculate_optics") +def test_load_with_non_default_sim_params( + mocked_calc_optics, + pytac_lattice, + lattice_filepath, + linopt, + emittance, + chromaticity, + radiation, +): + sim_params = atip.simulator.SimParams( + linopt, + emittance, + chromaticity, + radiation, + ) + + pytac_lattice = atip.load_sim.load_from_filepath( + pytac_lattice, lattice_filepath, sim_params + ) + + mocked_calc_optics.assert_called_with(mock.ANY, mock.ANY, sim_params)