diff --git a/.gitignore b/.gitignore index 8680bb6b3..531776fe8 100644 --- a/.gitignore +++ b/.gitignore @@ -1,3 +1,5 @@ +trial*.ipynb + **/.ipynb_checkpoints **/__pycache__ deeptrack-app/* @@ -25,4 +27,6 @@ examples/**/*/models/ *.jpg *.jpeg *.npy -*.db \ No newline at end of file +*.db + +.DS_Store \ No newline at end of file diff --git a/deeptrack/backend/citations.py b/deeptrack/backend/citations.py deleted file mode 100644 index baee70f35..000000000 --- a/deeptrack/backend/citations.py +++ /dev/null @@ -1,35 +0,0 @@ -deeptrack_bibtex = """ -@article{Midtvet2021DeepTrack, - author = {Midtvedt,Benjamin and - Helgadottir,Saga and - Argun,Aykut and - Pineda,Jesús and - Midtvedt,Daniel and - Volpe,Giovanni}, - title = {Quantitative digital microscopy with deep learning}, - journal = {Applied Physics Reviews}, - volume = {8}, - number = {1}, - pages = {011310}, - year = {2021}, - doi = {10.1063/5.0034891} -} -""" - -unet_bibtex = """ -@article{DBLP:journals/corr/RonnebergerFB15, - author = {Olaf Ronneberger and - Philipp Fischer and - Thomas Brox}, - title = {U-Net: Convolutional Networks for Biomedical Image Segmentation}, - journal = {CoRR}, - volume = {abs/1505.04597}, - year = {2015}, - url = {http://arxiv.org/abs/1505.04597}, - archivePrefix = {arXiv}, - eprint = {1505.04597}, - timestamp = {Mon, 13 Aug 2018 16:46:52 +0200}, - biburl = {https://dblp.org/rec/journals/corr/RonnebergerFB15.bib}, - bibsource = {dblp computer science bibliography, https://dblp.org} -} -""" \ No newline at end of file diff --git a/deeptrack/backend/core.py b/deeptrack/backend/core.py index 44b542c29..276b1dc8d 100644 --- a/deeptrack/backend/core.py +++ b/deeptrack/backend/core.py @@ -1,410 +1,1638 @@ -from copy import copy, deepcopy -import re -from weakref import WeakSet -import numpy as np -import operator +"""Core data structures of the DeepTrack2 package. -from .. 
import utils, image -from .citations import deeptrack_bibtex +This package provides the core DeepTrack2 classes to manage and process data. +In particular, it enables users to: +- Construct flexible and efficient computational pipelines. +- Manage data and dependencies in a hierarchical structure. +- Perform lazy evaluations for performance optimization. + +Main Features +------------- +Data Management: `DeepTrackDataObject` and `DeepTrackDataDict` provide tools +to store, validate, and manage data with dependency tracking. They enable +nested data structures and flexible indexing for complex data hierarchies. + +Computational Graphs: `DeepTrackNode` forms the backbone of DeepTrack2 +computation pipelines, representing computation nodes in a computation graph. +Nodes support lazy evaluation, dependency tracking, and caching for improved +computational performance. They implement mathematical operators for easy +composition of computational graphs. + +Citations: Supports citing the relevant publication to ensure proper +attribution (e.g., `Midtvedt et al., 2021`). + +Package Structure +----------------- +Data Containers: +- `DeepTrackDataObject`: A basic container for data with validation status. +- `DeepTrackDataDict`: A data contained to store multiple data objects + (DeepTrackDataObject) indexed by unique access IDs + (consisting of tuples of integers), enabling nested data + storage. + +Computation Nodes: +- `DeepTrackNode`: Represents a node in a computation graph, capable of lazy + evaluation, caching, and dependency management. + +Example +------- +Create a DeepTrackNode: + +>>> node = DeepTrackNode(lambda x: x**2) +>>> node.store(5) + +Retrieve the stored value: + +>>> print(node.current_value()) # Output: 25 + +""" + +import operator # Operator overloading for computation nodes. +from weakref import WeakSet # Manages relationships between nodes without + # creating circular dependencies. 
+ +from typing import ( + Any, Callable, Dict, Iterator, List, Optional, Set, Tuple, Union +) + +from .. import utils + + +citation_midtvet2021quantitative = """ +@article{Midtvet2021Quantitative, + author = {Midtvedt, Benjamin and Helgadottir, Saga and Argun, Aykut and + Pineda, Jesús and Midtvedt, Daniel and Volpe, Giovanni}, + title = {Quantitative digital microscopy with deep learning}, + journal = {Applied Physics Reviews}, + volume = {8}, + number = {1}, + pages = {011310}, + year = {2021}, + doi = {10.1063/5.0034891} +} +""" class DeepTrackDataObject: + """Basic data container for DeepTrack2. + + `DeepTrackDataObject` is a simple data container to store data and track + its validity. - """Atomic data container for deeptrack. + Attributes + ---------- + data : Any + The stored data. Default is `None`. + valid : bool + A flag indicating whether the stored data is valid. Default is `False`. + + Methods + ------- + store(data : Any) + Stores data in the container and marks it as valid. + current_value() -> Any + Returns the currently stored data. + is_valid() -> bool + Returns whether the stored data is valid. + invalidate() + Marks the data as invalid. + validate() + Marks the data as valid. + + Example + ------- + Create a `DeepTrackDataObject`: + + >>> data_obj = core.DeepTrackDataObject() + + Store a value in this container: + + >>> data_obj.store(42) + >>> print(data_obj.current_value()) + 42 + + Check if the stored data is valid: + + >>> print(data_obj.is_valid()) + True + + Invalidate the stored data: + + >>> data_obj.invalidate() + >>> print(data_obj.is_valid()) + False + + Validate the data again to restore its status: + + >>> data_obj.validate() + >>> print(data_obj.is_valid()) + True - The purpose of this is to store some data, and if that data is valid. - Data is not valid, if some dependency of the data has been changed or otherwise made invalid - since the last time the data was validated. """ + # Attributes. 
+ data: Any + valid: bool + def __init__(self): + """Initialize the container without data. + + The `data` and `valid` attributes are set to their default values + `None` and `False`. + + """ + self.data = None self.valid = False - def store(self, data): - self.valid = True + def store(self, data: Any) -> None: + """Store data and mark it as valid. + + Parameters + ---------- + data : Any + The data to be stored in the container. + + """ + self.data = data + self.valid = True + + def current_value(self) -> Any: + """Retrieve the stored data. - def current_value(self): + Returns + ------- + Any + The data stored in the container. + + """ + return self.data - def is_valid(self): + def is_valid(self) -> bool: + """Return whether the stored data is valid. + + Returns + ------- + bool + `True` if the data is valid, `False` otherwise. + + """ + return self.valid - def invalidate(self): + def invalidate(self) -> None: + """Mark the stored data as invalid.""" + self.valid = False - def validate(self): + def validate(self) -> None: + """Mark the stored data as valid.""" + self.valid = True class DeepTrackDataDict: + """Stores multiple data objects indexed by a tuple of integers (ID). + + `DeepTrackDataDict` can store multiple `DeepTrackDataObject` instances, + each associated with a unique tuple of integers (its ID). This is + particularly useful to handle sequences of data or nested structures. + + The default ID is an empty tuple, `()`. Once the first entry is created, + all IDs must match the established key length: + - If an ID longer than the set length is requested, it is trimmed. + - If an ID shorter than the set length is requested, a dictionary slice + containing all matching entries is returned. + + Attributes + ---------- + keylength : int or None + The length of the IDs currently stored. Set when the first entry is + created. If `None`, no entries have been created yet, and any ID length + is valid. 
+ dict : Dict[Tuple[int, ...], DeepTrackDataObject] + A dictionary mapping tuples of integers (IDs) to `DeepTrackDataObject` + instances. + + Methods + ------- + invalidate() + Marks all stored data objects as invalid. + validate() + Marks all stored data objects as valid. + valid_index(_ID : Tuple[int, ...]) -> bool + Checks if the given ID is valid for the current configuration. + create_index(_ID : Tuple[int, ...] = ()) + Creates an entry for the given ID if it does not exist. + __getitem__(_ID : tuple) -> Union[ + DeepTrackDataObject, + Dict[Tuple[int, ...], DeepTrackDataObject] + ] + Retrieves data associated with the ID. Can return a + `DeepTrackDataObject` or a dict of matching entries if `_ID` is shorter + than `keylength`. + __contains__(_ID : Tuple[int, ...]) -> bool + Checks if the given ID exists in the dictionary. + + Example + ------- + Imagine to have a structure that generates multiple instances of data: + + >>> data_dict = DeepTrackDataDict() + + # Create two top-level entries + >>> data_dict.create_index((0,)) + >>> data_dict.create_index((1,)) + + # Add nested entries + >>> data_dict.create_index((0, 0)) + >>> data_dict.create_index((0, 1)) + >>> data_dict.create_index((1, 0)) + >>> data_dict.create_index((1, 1)) + + Now, store and access values associated with each ID: + + >>> data_dict[(0, 0)].store("Data at (0, 0)") + >>> data_dict[(0, 1)].store("Data at (0, 1)") + >>> data_dict[(1, 0)].store("Data at (1, 0)") + >>> data_dict[(1, 1)].store("Data at (1, 1)") + + Retrieve values based on their IDs: + + >>> print(data_dict[(0, 0)].current_value()) + Data at (0, 0) + + >>> print(data_dict[(1, 1)].current_value()) + Data at (1, 1) + + If requesting a shorter ID, it returns all matching nested entries: + + >>> print(data_dict[(0,)]) + {(0, 0): , (0, 1): } + + """ - """Stores multiple data objects indexed by an access id. - - The purpose of this class is to allow a single object to store multiple - data objects at once. 
This is necessary for sequences and the feature `Repeat`. - - The access id is a tuple of integers. Consider the following example:: + # Attributes. + keylength: Optional[int] + dict: Dict[Tuple[int, ...], DeepTrackDataObject] - F = Repeat( - Repeat(DummyFeature(prop = np.random.rand), 2), - 2 - ) + def __init__(self): + """Initialize the data dictionary. - `F` contains 2*2=4 instances of the feature prop. They would be accessed using the IDs - (0, 0), (0, 1), (1, 0), and (1, 1). In this way nested structures are resolved. + Initializes `keylength` to `None` and `dict` to an empty dictionary, + indicating no data objects are currently stored. + + """ - The default is an empty tuple. + self.keylength = None + self.dict = {} - All IDs of a DataDict need to be of the same length. If a DataDict has value stored to the None ID, it can not store any other IDs. - If a longer ID is requested than what is stored, the request is trimmed to the length of the stored IDs. This is important to - correctly handle dependencies to lower depths of nested structures. + def invalidate(self) -> None: + """Mark all stored data objects as invalid. - If a shorter ID is requested than what is stored, the a slice of the DataDict is returned with all IDs matching the request. + Calls `invalidate()` on every `DeepTrackDataObject` in the dictionary. + + """ + for dataobject in self.dict.values(): + dataobject.invalidate() + def validate(self) -> None: + """Mark all stored data objects as valid. - """ + Calls `validate()` on every `DeepTrackDataObject` in the dictionary. + + """ - def __init__(self): - self.keylength = None - self.dict = {} + for dataobject in self.dict.values(): + dataobject.validate() - def invalidate(self): - # self.dict = {} - for d in self.dict.values(): - d.invalidate() + def valid_index(self, _ID: Tuple[int, ...]) -> bool: + """Check if a given ID is valid for this data dictionary. 
- def validate(self): - for d in self.dict.values(): - d.validate() + If `keylength` is `None`, any tuple ID is considered valid since no + entries have been created yet. If `_ID` already exists in `dict`, it is + automatically valid. Otherwise, `_ID` must have the same length as + `keylength` to be considered valid. + + Parameters + ---------- + _ID : Tuple[int, ...] + The index to check, consisting of a tuple of integers. + + Returns + ------- + bool + `True` if the ID is valid given the current configuration, `False` + otherwise. + + Raises + ------ + AssertionError + If `_ID` is not a tuple of integers. + + """ - def valid_index(self, _ID): - assert isinstance(_ID, tuple), f"Data index {_ID} is not a tuple" + # Ensure `_ID` is a tuple of integers. + assert isinstance(_ID, tuple), ( + f"Data index {_ID} is not a tuple. Got: {type(_ID).__name__}." + ) + assert all(isinstance(i, int) for i in _ID), ( + f"Data index {_ID} is not a tuple of integers. " + f"Got a tuple of types: {[type(i).__name__ for i in _ID]}." + ) + # If keylength has not yet been set, all indexes are valid. if self.keylength is None: - # If keylength has not yet been set, all indexes are valid return True + # If index is already stored, always valid. if _ID in self.dict: - # If index is a key, always valid return True - # Otherwise, check key is correct length + # Otherwise, the ID length must match the established keylength. return len(_ID) == self.keylength - def create_index(self, _ID=()): + def create_index(self, _ID: Tuple[int, ...] = ()) -> None: + """Create a new data entry for the given ID if not already existing. - assert isinstance(_ID, tuple), f"Data index {_ID} is not a tuple" + Each newly created index is associated with a new + `DeepTrackDataObject`. If `_ID` is already in `dict`, no new entry is + created. + + If `keylength` is `None`, it is set to the length of `_ID`. Once + established, all subsequently created IDs must have this same length. 
+ + Parameters + ---------- + _ID : Tuple[int, ...], optional + A tuple of integers representing the ID for the data entry. + Default is `()`, which represents a root-level data entry with no + nesting. + + Raises + ------ + AssertionError + - If `_ID` is not a tuple of integers. + - If `_ID` is not valid for the current configuration. + + """ + + # Check if the given `_ID` is valid. + # (Also: Ensure `_ID` is a tuple of integers.) + assert self.valid_index(_ID), ( + f"{_ID} is not a valid index for current dictionary configuration." + ) + # If `_ID` already exists, do nothing. if _ID in self.dict: return - assert self.valid_index(_ID), f"{_ID} is not a valid index for dict {self.dict}" + # Create a new DeepTrackDataObject for this ID. + self.dict[_ID] = DeepTrackDataObject() + # If `keylength` is not set, initialize it with current ID's length. if self.keylength is None: self.keylength = len(_ID) - self.dict[_ID] = DeepTrackDataObject() + def __getitem__( + self, + _ID: Tuple[int, ...], + ) -> Union[ + DeepTrackDataObject, + Dict[Tuple[int, ...], DeepTrackDataObject] + ]: + """Retrieve data associated with a given ID. + + Parameters + ---------- + _ID : Tuple[int, ...] + The ID for the requested data. + + Returns + ------- + DeepTrackDataObject or Dict[Tuple[int, ...], DeepTrackDataObject] + If `_ID` matches `keylength`, returns the corresponding + `DeepTrackDataObject`. + If `_ID` is longer than `keylength`, the request is trimmed to + match `keylength`. + If `_ID` is shorter than `keylength`, returns a dict of all entries + whose IDs match the given `_ID` prefix. + + Raises + ------ + AssertionError + If `_ID` is not a tuple of integers. + KeyError + If the dictionary is empty (`keylength` is `None`). + + """ - def __getitem__(self, _ID): - assert isinstance(_ID, tuple), f"Data index {_ID} is not a tuple" + # Ensure `_ID` is a tuple of integers. + assert isinstance(_ID, tuple), ( + f"Data index {_ID} is not a tuple. Got: {type(_ID).__name__}." 
+ ) + assert all(isinstance(i, int) for i in _ID), ( + f"Data index {_ID} is not a tuple of integers. " + f"Got a tuple of types: {[type(i).__name__ for i in _ID]}." + ) if self.keylength is None: - raise KeyError("Indexing an empty dict") - + raise KeyError("Attempting to index an empty dict.") + + # If ID matches keylength, returns corresponding DeepTrackDataObject. if len(_ID) == self.keylength: return self.dict[_ID] - elif len(_ID) > self.keylength: + # If ID longer than keylength, trim the requested ID. + if len(_ID) > self.keylength: return self[_ID[: self.keylength]] - else: - return {k: v for k, v in self.dict.items() if k[: len(_ID)] == _ID} + # If ID shorter than keylength, return a slice of all matching items. + return {k: v for k, v in self.dict.items() if k[: len(_ID)] == _ID} + + def __contains__(self, _ID: Tuple[int, ...]) -> bool: + """Check if a given ID exists in the dictionary. + + Parameters + ---------- + _ID : Tuple[int, ...] + The ID to check. + + Returns + ------- + bool + `True` if the ID exists, `False` otherwise. + + Raises + ------ + AssertionError + If `_ID` is not a tuple of integers. + + """ + + # Ensure `_ID` is a tuple of integers. + assert isinstance(_ID, tuple), ( + f"Data index {_ID} is not a tuple. Got: {type(_ID).__name__}." + ) + assert all(isinstance(i, int) for i in _ID), ( + f"Data index {_ID} is not a tuple of integers. " + f"Got a tuple of types: {[type(i).__name__ for i in _ID]}." + ) - def __contains__(self, _ID): return _ID in self.dict class DeepTrackNode: """Object corresponding to a node in a computation graph. - This is a base class for all nodes in a computation graph. It is used to store and compute data. - When evaluated, the node will call the `action` method. The action method defines a way to calculate the next data. - If the data is already present, it will not be recalculated. + `DeepTrackNode` represents a node within a DeepTrack2 computation graph. 
+ In the DeepTrack2 computation graph, each node can store data and compute + new values based on its dependencies. The value of a node is computed by + calling its `action` method. + + Attributes + ---------- + data : DeepTrackDataDict + Dictionary-like object for storing data, indexed by tuples of integers. + children : WeakSet[DeepTrackNode] + Nodes that depend on this node (its parents, grandparents, etc.). + dependencies : WeakSet[DeepTrackNode] + Nodes on which this node depends (its children, grandchildren, etc.). + _action : Callable + The function or lambda-function to compute the node value. + _accepts_ID : bool + Whether `action` accepts an input ID. + _all_subchildren : Set[DeepTrackNode] + All nodes in the subtree rooted at this node, including itself. + citations : List[str] + Citations associated with this node. + + Methods + ------- + action : property + Gets or sets the computation function for the node. + add_child(child: DeepTrackNode) -> DeepTrackNode + Adds a child node that depends on this node. + Also adds the dependency between the two nodes. + add_dependency(other: DeepTrackNode) -> DeepTrackNode + Adds a dependency, making this node depend on the given node. + store(data: Any, _ID: Tuple[int, ...] = ()) -> DeepTrackNode + Stores computed data for the given `_ID`. + is_valid(_ID: Tuple[int, ...] = ()) -> bool + Checks if the data for the given `_ID` is valid. + valid_index(_ID: Tuple[int, ...]) -> bool + Checks if the given `_ID` is valid for this node. + invalidate(_ID: Tuple[int, ...] = ()) -> DeepTrackNode + Invalidates the data for the given `_ID` and all child nodes. + validate(_ID: Tuple[int, ...] = ()) -> DeepTrackNode + Validates the data for the given `_ID`, marking it as up-to-date, but + not its children. + _update() -> DeepTrackNode + Internal method to reset data. + set_value(value: Any, _ID: Tuple[int, ...] = ()) -> DeepTrackNode + Sets a value for the given `_ID`. 
If the new value differs from the + current value, the node is invalidated to ensure dependencies are + recomputed. + previous(_ID: Tuple[int, ...] = ()) -> Any + Returns the previously stored value for the given `_ID` without + recomputing it. + recurse_children( + memory: Optional[Set[DeepTrackNode]] = None + ) -> Set[DeepTrackNode] + Returns all child nodes in the dependency tree rooted at this node. + recurse_dependencies( + memory: Optional[List[DeepTrackNode]] = None + ) -> Iterator[DeepTrackNode] + Yields all nodes that this node depends on, traversing dependencies. + get_citations() -> Set[str] + Returns a set of citations for this node and its dependencies. + __call__(_ID: Tuple[int, ...] = ()) -> Any + Evaluates the node's computation for the given `_ID`, recomputing if + necessary. + current_value(_ID: Tuple[int, ...] = ()) -> Any + Returns the currently stored value for the given `_ID` without + recomputation. + __hash__() -> int + Returns a unique hash for this node. + __getitem__(idx: Any) -> DeepTrackNode + Creates a new node that indexes into this node’s computed data. 
+ + Example + ------- + Create two `DeepTrackNode` objects: + + >>> parent = DeepTrackNode(action=lambda: 10) + >>> child = DeepTrackNode(action=lambda _ID=None: parent(_ID) * 2) + + First, establish the dependency between `parent` and `child`: + + >>> parent.add_child(child) + + Store values in the parent node for specific IDs: + + >>> parent.store(15, _ID=(0,)) + >>> parent.store(20, _ID=(1,)) + + Compute the values for the child node based on these parent values: + + >>> child_value_0 = child(_ID=(0,)) + >>> child_value_1 = child(_ID=(1,)) + >>> print(child_value_0, child_value_1) + 30 40 + + Invalidate the parent data for a specific ID: + + >>> parent.invalidate((0,)) + >>> print(parent.is_valid((0,))) + False + >>> print(child.is_valid((0,))) + False + + Update the parent value and recompute the child value: + + >>> parent.store(25, _ID=(0,)) + >>> child_value_recomputed = child(_ID=(0,)) + >>> print(child_value_recomputed) + 50 + """ - __nonelike_default = object() + # Attributes. + data: DeepTrackDataDict + children: WeakSet #TODO + # From Python 3.9, change to WeakSet['DeepTrackNode'] + dependencies: WeakSet #TODO + # From Python 3.9, change to WeakSet['DeepTrackNode'] + _action: Callable[..., Any] + _accepts_ID: bool + _all_subchildren: Set['DeepTrackNode'] - citation = deeptrack_bibtex + # Citations associated with this DeepTrack2. + citations: List[str] = [citation_midtvet2021quantitative] @property - def action(self): + def action(self) -> Callable[..., Any]: + """Callable: The function that computes this node’s value. + + When accessed, returns the current action. This is often a function or + lambda-function that takes `_ID` as an optional parameter if + `_accepts_ID` is True. + + """ + return self._action - + @action.setter - def action(self, value): + def action(self, value: Callable[..., Any]) -> None: + """Set the action used to compute this node’s value. 
+ + Parameters + ---------- + value : Callable[..., Any] + A function or lambda to be used for computing the node’s value. If + the function’s signature includes `_ID`, this node will pass `_ID` + when calling `action`. + """ self._action = value - self._accepts_ID = utils.get_kwarg_names(value).__contains__("_ID") + self._accepts_ID = "_ID" in utils.get_kwarg_names(value) + + def __init__( + self, + action: Optional[Callable[..., Any]] = None, + **kwargs: Any, + ): + """Initialize a new DeepTrackNode. + + Parameters + ---------- + action : Callable or Any, optional + Action to compute this node’s value. If not provided, uses a no-op + action (lambda: None). + + **kwargs : dict + Additional arguments for subclasses or extended functionality. + + """ - def __init__(self, action=__nonelike_default, **kwargs): self.data = DeepTrackDataDict() self.children = WeakSet() self.dependencies = WeakSet() - self._action = lambda: None + self._action = lambda: None # Default no-op action. - if action is not self.__nonelike_default: + # If action is provided, set it. + # If it's callable, use it directly; + # otherwise, wrap it in a lambda. + if action is not None: if callable(action): self.action = action else: self.action = lambda: action - self._accepts_ID = utils.get_kwarg_names(self.action).__contains__("_ID") + # Check if action accepts `_ID`. + self._accepts_ID = "_ID" in utils.get_kwarg_names(self.action) + + # Call super init in case of multiple inheritance. super().__init__(**kwargs) + # Keep track of all subchildren, including this node. self._all_subchildren = set() self._all_subchildren.add(self) - def add_child(self, other): - self.children.add(other) - if not self in other.dependencies: - other.add_dependency(self) + def add_child(self, child: 'DeepTrackNode') -> 'DeepTrackNode': + """Add a child node to the current node. + + Adding a child also updates `_all_subchildren` for this node and all + its dependencies. 
It also ensures that dependency and child + relationships remain consistent. + + Parameters + ---------- + child : DeepTrackNode + The child node that depends on this node. - subchildren = other._all_subchildren.copy() - subchildren.add(other) + Returns + ------- + self : DeepTrackNode + Returns the current node for chaining. + + """ + + self.children.add(child) + if self not in child.dependencies: + child.add_dependency(self) # Ensure bidirectional relationship. + + # Get all subchildren of `child` and add `child` itself. + subchildren = child._all_subchildren.copy() + subchildren.add(child) + # Merge all these subchildren into this node’s subtree. self._all_subchildren = self._all_subchildren.union(subchildren) for parent in self.recurse_dependencies(): - parent._all_subchildren = parent._all_subchildren.union(subchildren) + parent._all_subchildren = \ + parent._all_subchildren.union(subchildren) return self - def add_dependency(self, other): - self.dependencies.add(other) - other.add_child(self) + def add_dependency(self, parent: 'DeepTrackNode') -> 'DeepTrackNode': + """Adds a dependency, making this node depend on a parent node. + + Parameters + ---------- + parent : DeepTrackNode + The parent node that this node depends on. If `parent` changes, + this node’s data may become invalid. + + Returns + ------- + self : DeepTrackNode + Returns the current node for chaining. + + """ + + self.dependencies.add(parent) + + parent.add_child(self) # Ensure the child relationship is also set. return self - def store(self, data, _ID=()): + def store(self, data: Any, _ID: Tuple[int, ...] = ()) -> 'DeepTrackNode': + """Store computed data in this node. + + Parameters + ---------- + data : Any + The data to be stored. + _ID : Tuple[int, ...], optional + The index for this data. Default is the empty tuple (), indicating + a root-level entry. + + Returns + ------- + self : DeepTrackNode + Returns the current node for chaining. 
+ + """ + # Create the index if necessary, then store data in it. self.data.create_index(_ID) self.data[_ID].store(data) return self - def is_valid(self, _ID=()): + def is_valid(self, _ID: Tuple[int, ...] = ()) -> bool: + """Check if data for the given ID is valid. + + Parameters + ---------- + _ID : Tuple[int, ...], optional + The ID to check validity for. + + Returns + ------- + bool + `True` if data at `_ID` is valid, otherwise `False`. + + """ + try: return self.data[_ID].is_valid() except (KeyError, AttributeError): return False - def valid_index(self, _ID): + def valid_index(self, _ID: Tuple[int, ...]) -> bool: + """Check if ID is a valid index for this node’s data. + + Parameters + ---------- + _ID : Tuple[int, ...] + The ID to validate. + + Returns + ------- + bool + `True` if `_ID` is valid, otherwise `False`. + + """ return self.data.valid_index(_ID) - def invalidate(self, _ID=()): + def invalidate(self, _ID: Tuple[int, ...] = ()) -> 'DeepTrackNode': + """Mark this node’s data and all its children’s data as invalid. + + Parameters + ---------- + _ID : Tuple[int, ...], optional + The ID to invalidate. Default is empty tuple, indicating + potentially the full dataset. + + Returns + ------- + self : DeepTrackNode + Returns the current node for chaining. + + Note + ---- + At the moment, the code to invalidate specific IDs is not implemented, + so the _ID parameter is not effectively used. + + """ + + # Invalidate data for all children of this node. + for child in self.recurse_children(): child.data.invalidate() return self - def validate(self, _ID=()): - for child in self.recurse_children(): - try: - child.data[_ID].validate() - except KeyError: - pass + def validate(self, _ID: Tuple[int, ...] = ()) -> 'DeepTrackNode': + """Mark this node’s data as valid. + + Parameters + ---------- + _ID : Tuple[int, ...], optional + The ID to validate. Default is empty tuple. 
+ + Returns + ------- + self : DeepTrackNode + + """ + + self.data[_ID].validate() return self - def _update(self): - # Pre-instantiate memory for optimization + def _update(self) -> 'DeepTrackNode': + """Internal method to reset data in all dependent children. + + This method resets `data` for all children of each dependency, + effectively clearing cached values to force a recomputation on the next + evaluation. + + Returns + ------- + self : DeepTrackNode + Returns the current node for chaining. + + """ + + # Pre-instantiate memory for optimization used to avoid repeated + # processing of the same nodes. child_memory = [] + # For each dependency, reset data in all of its children. for dependency in self.recurse_dependencies(): for dep_child in dependency.recurse_children(memory=child_memory): dep_child.data = DeepTrackDataDict() return self - def set_value(self, value, _ID=()): + def set_value(self, value, _ID: Tuple[int, ...] = ()) -> 'DeepTrackNode': + """Set a value for this node’s data at ID. - # If set to same value, no need to invalidate + If the value is different from the currently stored one (or if it is + invalid), it will invalidate the old data before storing the new one. + Parameters + ---------- + value : Any + The value to store. + _ID : Tuple[int, ...], optional + The ID at which to store the value. + + Returns + ------- + self : DeepTrackNode + Returns the current node for chaining. + + """ + + # Check if current value is equivalent. If not, invalidate and store + # the new value. If set to same value, no need to invalidate. if not ( - self.is_valid(_ID=_ID) and equivalent(value, self.data[_ID].current_value()) + self.is_valid(_ID=_ID) + and _equivalent(value, self.data[_ID].current_value()) ): self.invalidate(_ID=_ID) self.store(value, _ID=_ID) return self - def previous(self, _ID=()): + def previous(self, _ID: Tuple[int, ...] = ()) -> Any: + """Retrieve the previously stored value at ID without recomputing. 
+ + Parameters + ---------- + _ID : Tuple[int, ...], optional + The ID for which to retrieve the previous value. + + Returns + ------- + Any + The previously stored value if `_ID` is valid. + Returns `[]` if `_ID` is not a valid index. + + """ + if self.data.valid_index(_ID): return self.data[_ID].current_value() else: return [] - def recurse_children(self, memory=set()): + def recurse_children( + self, + memory: Optional[Set['DeepTrackNode']] = None, + ) -> Set['DeepTrackNode']: + """Return all subchildren of this node. + + Parameters + ---------- + memory : set, optional + Memory set to track visited nodes (not used directly here). + + Returns + ------- + set + All nodes in the subtree rooted at this node, including itself. + """ + + # Simply return `_all_subchildren` since it's maintained incrementally. return self._all_subchildren - def old_recurse_children(self, memory=None): - # On first call, instantiate memory + def old_recurse_children( + self, + memory: Optional[List['DeepTrackNode']] = None, + ) -> Iterator['DeepTrackNode']: + """Legacy recursive method for traversing children. + + Parameters + ---------- + memory : list, optional + A list to remember visited nodes, ensuring that each node is + yielded only once. + + Yields + ------ + DeepTrackNode + Yields each node in a depth-first traversal. + + Notes + ----- + This method is kept for backward compatibility or debugging purposes. + + """ + + # On first call, instantiate memory. if memory is None: memory = [] - # Make sure each DeepTrackNode is only yielded once + # Make sure each DeepTrackNode is only yielded once. if self in memory: return - # Remember self + # Remember self. memory.append(self) - # Yield self and recurse children + # Yield self and recurse children. yield self + # Recursively traverse children. 
for child in self.children: yield from child.recurse_children(memory=memory) - def recurse_dependencies(self, memory=None): - # On first call, instantiate memory + def recurse_dependencies( + self, + memory: Optional[List['DeepTrackNode']] = None, + ) -> Iterator['DeepTrackNode']: + """Yield all dependencies of this node, ensuring each is visited once. + + Parameters + ---------- + memory : list, optional + A list of visited nodes to avoid repeated visits or infinite loops. + + Yields + ------ + DeepTrackNode + Yields this node and all nodes it depends on. + + """ + + # On first call, instantiate memory. if memory is None: memory = [] - # Make sure each DeepTrackNode is only yielded once + # Make sure each DeepTrackNode is only yielded once. if self in memory: return - # Remember self + # Remember self. memory.append(self) - # Yield self and recurse dependencies + # Yield self and recurse dependencies. yield self + # Recursively yield dependencies. for dependency in self.dependencies: yield from dependency.recurse_dependencies(memory=memory) - def get_citations(self): - """Gets a set of citations for all objects in a pipeline.""" - cites = {self.citation} - for dep in self.recurse_dependencies(): - for obj in type(dep).mro(): - if hasattr(obj, "citation"): - cites.add(obj.citation) + def get_citations(self) -> Set[str]: + """Get citations from this node and all its dependencies. - return cites + It gathers citations from this node and all nodes that it depends on. + Citations are stored as a class attribute `citations`. - def __call__(self, _ID=()): + Returns + ------- + Set[str] + Set of all citations relevant to this node and its dependency tree. + + """ + + # Initialize citations as a set of elements from self.citations. + citations = set(self.citations) if self.citations else set() + + # Recurse through dependencies to collect all citations. 
+ for dep in self.recurse_dependencies(): + for obj in type(dep).mro(): + if hasattr(obj, "citations"): + # Add the citations of the current object. + citations.update( + obj.citations if isinstance(obj.citations, list) + else [obj.citations] + ) + + return citations + + def __call__(self, _ID: Tuple[int, ...] = ()) -> Any: + """Evaluate this node at ID. + + If the data at `_ID` is valid, it returns the stored value. Otherwise, + it calls `action` to compute a new value, stores it, and returns it. + + Parameters + ---------- + _ID : Tuple[int, ...], optional + The ID at which to evaluate the node’s action. + + Returns + ------- + Any + The computed or retrieved data for the given `_ID`. + + """ if self.is_valid(_ID): try: return self.current_value(_ID) except KeyError: - pass - + pass # Data might have been invalidated or removed. + + # Call action with or without `_ID` depending on `_accepts_ID`. if self._accepts_ID: new_value = self.action(_ID=_ID) else: new_value = self.action() + # Store the newly computed value. self.store(new_value, _ID=_ID) + return self.current_value(_ID) - def current_value(self, _ID=()): + def current_value(self, _ID: Tuple[int, ...] = ()) -> Any: + """Retrieve the currently stored value at ID. + + Parameters + ---------- + _ID : Tuple[int, ...], optional + The ID at which to retrieve the current value. + + Returns + ------- + Any + The currently stored value for `_ID`. + + """ + return self.data[_ID].current_value() - def __hash__(self): + def __hash__(self) -> int: + """Return a unique hash for this node. + + Uses the node’s `id` to ensure uniqueness. + + """ + return id(self) - def __getitem__(self, idx): + def __getitem__(self, idx: Any) -> 'DeepTrackNode': + """Allow indexing into the node’s computed data. + + Parameters + ---------- + idx : Any + The index applied to the result of evaluating this node. + + Returns + ------- + DeepTrackNode + A new node that, when evaluated, applies `idx` to the result of + `self`. 
+ + Notes + ----- + This effectively creates a node that corresponds to `self(...)[idx]`, + allowing you to select parts of the computed data dynamically. + """ + + # Create a new node whose action indexes into this node’s result. node = DeepTrackNode(lambda _ID=None: self(_ID=_ID)[idx]) - node.add_dependency(self) + self.add_child(node) + # node.add_dependency(self) # Already executed by add_child. + return node - # node-node operators - def __add__(self, other): - return create_node_with_operator(operator.__add__, self, other) + # Node-node operators. + # These methods define arithmetic and comparison operations for + # DeepTrackNode objects. Each operation creates a new DeepTrackNode that + # represents the result of applying the corresponding operator to `self` + # and `other`. The operators are applied lazily and will be computed only + # when the resulting node is evaluated. + + def __add__(self, other: Union['DeepTrackNode', Any]) -> 'DeepTrackNode': + """Add node to another node or value. - def __radd__(self, other): - return create_node_with_operator(operator.__add__, other, self) + Creates a new `DeepTrackNode` representing the addition of the values + produced by this node (`self`) and another node or value (`other`). - def __sub__(self, other): - return create_node_with_operator(operator.__sub__, self, other) + Parameters + ---------- + other : DeepTrackNode or Any + The node or value to add. - def __rsub__(self, other): - return create_node_with_operator(operator.__sub__, other, self) + Returns + ------- + DeepTrackNode + A new node that represents the addition operation (`self + other`). 
+ + """ - def __mul__(self, other): - return create_node_with_operator(operator.__mul__, self, other) + return _create_node_with_operator(operator.__add__, self, other) - def __rmul__(self, other): - return create_node_with_operator(operator.__mul__, other, self) + def __radd__(self, other: Union['DeepTrackNode', Any]) -> 'DeepTrackNode': + """Add other value to node (right-hand). - def __truediv__(self, other): - return create_node_with_operator(operator.__truediv__, self, other) + Creates a new `DeepTrackNode` representing the addition of another + node or value (`other`) to the value produced by this node (`self`). - def __rtruediv__(self, other): - return create_node_with_operator(operator.__truediv__, other, self) + Parameters + ---------- + other : DeepTrackNode or Any + The value or node to add. - def __floordiv__(self, other): - return create_node_with_operator(operator.__floordiv__, self, other) + Returns + ------- + DeepTrackNode + A new node that represents the addition operation (`other + self`). + + """ - def __rfloordiv__(self, other): - return create_node_with_operator(operator.__floordiv__, other, self) + return _create_node_with_operator(operator.__add__, other, self) - def __lt__(self, other): - return create_node_with_operator(operator.__lt__, self, other) + def __sub__(self, other: Union['DeepTrackNode', Any]) -> 'DeepTrackNode': + """Subtract another node or value from node. - def __rlt__(self, other): - return create_node_with_operator(operator.__lt__, other, self) + Creates a new `DeepTrackNode` representing the subtraction of the + values produced by another node or value (`other`) from this node + (`self`). - def __gt__(self, other): - return create_node_with_operator(operator.__gt__, self, other) + Parameters + ---------- + other : DeepTrackNode or Any + The node or value to subtract. 
- def __rgt__(self, other): - return create_node_with_operator(operator.__gt__, other, self) + Returns + ------- + DeepTrackNode + A new node that represents the subtraction operation + (`self - other`). + + """ - def __le__(self, other): - return create_node_with_operator(operator.__le__, self, other) + return _create_node_with_operator(operator.__sub__, self, other) - def __rle__(self, other): - return create_node_with_operator(operator.__le__, other, self) + def __rsub__(self, other: Union['DeepTrackNode', Any]) -> 'DeepTrackNode': + """Subtract node from other value (right-hand). - def __ge__(self, other): - return create_node_with_operator(operator.__ge__, self, other) + Creates a new `DeepTrackNode` representing the subtraction of the value + produced by this node (`self`) from another node or value (`other`). - def __rge__(self, other): - return create_node_with_operator(operator.__ge__, other, self) + Parameters + ---------- + other : DeepTrackNode or Any + The value or node to subtract from. + Returns + ------- + DeepTrackNode + A new node that represents the subtraction operation + `other - self`). + + """ + + return _create_node_with_operator(operator.__sub__, other, self) + + def __mul__(self, other: Union['DeepTrackNode', Any]) -> 'DeepTrackNode': + """Multiply node by another node or value. + + Creates a new `DeepTrackNode` representing the multiplication of the + values produced by this node (`self`) and another node or value + (`other`). + + Parameters + ---------- + other : DeepTrackNode or Any + The node or value to multiply by. + + Returns + ------- + DeepTrackNode + A new node that represents the multiplication operation + (`self * other`). + + """ + + return _create_node_with_operator(operator.__mul__, self, other) + + def __rmul__(self, other: Union['DeepTrackNode', Any]) -> 'DeepTrackNode': + """Multiply other value by node (right-hand). 
+ + Creates a new `DeepTrackNode` representing the multiplication of + another node or value (`other`) by the value produced by this node + (`self`). + + Parameters + ---------- + other : DeepTrackNode or Any + The value or node to multiply. + + Returns + ------- + DeepTrackNode + A new node that represents the multiplication operation + (`other * self`). + """ + return _create_node_with_operator(operator.__mul__, other, self) + + def __truediv__( + self, + other: Union['DeepTrackNode', Any], + ) -> 'DeepTrackNode': + """Divide node by another node or value. + + Creates a new `DeepTrackNode` representing the division of the value + produced by this node (`self`) by another node or value (`other`). + + Parameters + ---------- + other : DeepTrackNode or Any + The node or value to divide by. + + Returns + ------- + DeepTrackNode + A new node that represents the division operation (`self / other`). + + """ + + return _create_node_with_operator(operator.__truediv__, self, other) + + def __rtruediv__( + self, + other: Union['DeepTrackNode', Any], + ) -> 'DeepTrackNode': + """Divide other value by node (right-hand). + + Creates a new `DeepTrackNode` representing the division of another + node or value (`other`) by the value produced by this node (`self`). + + Parameters + ---------- + other : DeepTrackNode or Any + The value or node to divide. + + Returns + ------- + DeepTrackNode + A new node that represents the division operation (`other / self`). + + """ + + return _create_node_with_operator(operator.__truediv__, other, self) + + def __floordiv__( + self, + other: Union['DeepTrackNode', Any], + ) -> 'DeepTrackNode': + """Perform floor division of node by another node or value. + + Creates a new `DeepTrackNode` representing the floor division of the + value produced by this node (`self`) by another node or value + (`other`). + + Parameters + ---------- + other : DeepTrackNode or Any + The node or value to divide by. 
+ + Returns + ------- + DeepTrackNode + A new node that represents the floor division operation + (`self // other`). + + """ + + return _create_node_with_operator(operator.__floordiv__, self, other) + + def __rfloordiv__( + self, + other: Union['DeepTrackNode', Any], + ) -> 'DeepTrackNode': + """Perform floor division of other value by node (right-hand). + + Creates a new `DeepTrackNode` representing the floor division of + another node or value (`other`) by the value produced by this node + (`self`). + + Parameters + ---------- + other : DeepTrackNode or Any + The value or node to divide. + + Returns + ------- + DeepTrackNode + A new node that represents the floor division operation + (`other // self`). + + """ -def equivalent(a, b): - # This is a bare-bones implementation to check if two objects are equivalent. - # We can implement more cases to reduce updates. + return _create_node_with_operator(operator.__floordiv__, other, self) + def __lt__(self, other: Union['DeepTrackNode', Any]) -> 'DeepTrackNode': + """Check if node is less than another node or value. + + Creates a new `DeepTrackNode` representing the comparison of this node + (`self`) being less than another node or value (`other`). + + Parameters + ---------- + other : DeepTrackNode or Any + The node or value to compare with. + + Returns + ------- + DeepTrackNode + A new node that represents the comparison operation (`self < other`). + + """ + + return _create_node_with_operator(operator.__lt__, self, other) + + def __rlt__(self, other: Union['DeepTrackNode', Any]) -> 'DeepTrackNode': + """Check if other value is less than node (right-hand). + + Creates a new `DeepTrackNode` representing the comparison of another + node or value (`other`) being less than this node (`self`). + + Parameters + ---------- + other : DeepTrackNode or Any + The value or node to compare. + + Returns + ------- + DeepTrackNode + A new node that represents the comparison operation + (`other < self`). 
+ + """ + + return _create_node_with_operator(operator.__lt__, other, self) + + def __gt__(self, other: Union['DeepTrackNode', Any]) -> 'DeepTrackNode': + """Check if node is greater than another node or value. + + Creates a new `DeepTrackNode` representing the comparison of this node + (`self`) being greater than another node or value (`other`). + + Parameters + ---------- + other : DeepTrackNode or Any + The node or value to compare with. + + Returns + ------- + DeepTrackNode + A new node that represents the comparison operation + (`self > other`). + + """ + + return _create_node_with_operator(operator.__gt__, self, other) + + def __rgt__(self, other: Union['DeepTrackNode', Any]) -> 'DeepTrackNode': + """Check if other value is greater than node (right-hand). + + Creates a new `DeepTrackNode` representing the comparison of another + node or value (`other`) being greater than this node (`self`). + + Parameters + ---------- + other : DeepTrackNode or Any + The value or node to compare. + + Returns + ------- + DeepTrackNode + A new node that represents the comparison operation + (`other > self`). + + """ + + return _create_node_with_operator(operator.__gt__, other, self) + + def __le__(self, other: Union['DeepTrackNode', Any]) -> 'DeepTrackNode': + """Check if node is less than or equal to another node or value. + + Creates a new `DeepTrackNode` representing the comparison of this node + (`self`) being less than or equal to another node or value (`other`). + + Parameters + ---------- + other : DeepTrackNode or Any + The node or value to compare with. + + Returns + ------- + DeepTrackNode + A new node that represents the comparison operation + (`self <= other`). + + """ + + return _create_node_with_operator(operator.__le__, self, other) + + def __rle__(self, other: Union['DeepTrackNode', Any]) -> 'DeepTrackNode': + """Check if other value is less than or equal to node (right-hand). 
+ + Creates a new `DeepTrackNode` representing the comparison of another + node or value (`other`) being less than or equal to this node (`self`). + + Parameters + ---------- + other : DeepTrackNode or Any + The value or node to compare. + + Returns + ------- + DeepTrackNode + A new node that represents the comparison operation + (`other <= self`). + + """ + + return _create_node_with_operator(operator.__le__, other, self) + + def __ge__(self, other: Union['DeepTrackNode', Any]) -> 'DeepTrackNode': + """Check if node is greater than or equal to another node or value. + + Creates a new `DeepTrackNode` representing the comparison of this node + (`self`) being greater than or equal to another node or value + (`other`). + + Parameters + ---------- + other : DeepTrackNode or Any + The node or value to compare with. + + Returns + ------- + DeepTrackNode + A new node that represents the comparison operation + (`self >= other`). + + """ + + return _create_node_with_operator(operator.__ge__, self, other) + + def __rge__(self, other: Union['DeepTrackNode', Any]) -> 'DeepTrackNode': + """Check if other value is greater than or equal to node (right-hand). + + Creates a new `DeepTrackNode` representing the comparison of another + node or value (`other`) being greater than or equal to this node + (`self`). + + Parameters + ---------- + other : DeepTrackNode or Any + The value or node to compare. + + Returns + ------- + DeepTrackNode + A new node that represents the comparison operation + (`other >= self`). + + """ + + return _create_node_with_operator(operator.__ge__, other, self) + + +def _equivalent(a, b): + """Check if two objects are equivalent. + + This internal helper function provides a basic implementation to determine + equivalence between two objects: + - If `a` and `b` are the same object (identity check), they are considered + equivalent. + - If both `a` and `b` are empty lists, they are considered equivalent. 
+ Additional cases can be implemented as needed to refine this behavior. + + Parameters + ---------- + a : Any + The first object to compare. + b : Any + The second object to compare. + + Returns + ------- + bool + `True` if the objects are equivalent, `False` otherwise. + + """ + + # If a and b are the same object, return True. if a is b: return True + # If a and b are empty lists, consider them identical. if isinstance(a, list) and isinstance(b, list): return len(a) == 0 and len(b) == 0 + # Otherwise, return False. return False -def create_node_with_operator(op, a, b): - """Creates a new node with the given operator and operands.""" +def _create_node_with_operator(op, a, b): + """Create a new computation node using a given operator and operands. + + This internal helper function constructs a `DeepTrackNode` obtained from + the application of the specified operator to two operands. If the operands + are not already `DeepTrackNode` instances, they are converted to nodes. + + This function also establishes bidirectional relationships between the new + node and its operands: + - The new node is added as a child of the operands `a` and `b`. + - The operands `a` and `b` are added as dependencies of the new node. + - The operator `op` is applied lazily, meaning it will be evaluated when + the new node is called, for computational efficiency. + + Parameters + ---------- + op : Callable + The operator function. + a : Any + First operand. If not a `DeepTrackNode`, it will be wrapped in one. + b : Any + Second operand. If not a `DeepTrackNode`, it will be wrapped in one. + + Returns + ------- + DeepTrackNode + A new `DeepTrackNode` containing the result of applying the operator + `op` to the values of nodes `a` and `b`. + + Raises + ------ + TypeError + If any of the operand is not a `DeepTrackNode` or a callable. + + """ + # Ensure `a` is a `DeepTrackNode`. Wrap it if necessary. 
if not isinstance(a, DeepTrackNode): - a = DeepTrackNode(a) + if callable(a): + a = DeepTrackNode(a) + else: + raise TypeError("Operand 'a' must be callable or a DeepTrackNode, " + f"got {type(a).__name__}.") + + # Ensure `b` is a `DeepTrackNode`. Wrap it if necessary. + if not isinstance(b, DeepTrackNode): + if callable(b): + b = DeepTrackNode(b) + else: + raise TypeError("Operand 'b' must be callable or a DeepTrackNode, " + f"got {type(b).__name__}.") + + # New node that applies the operator `op` to the values of `a` and `b`. + new_node = DeepTrackNode(lambda _ID=(): op(a(_ID=_ID), b(_ID=_ID))) - if not isinstance(b, DeepTrackNode): - b = DeepTrackNode(b) + # Set the new node as a child of both `a` and `b`. + # (Also: Establish dependency relationships between the nodes.) + a.add_child(new_node) + b.add_child(new_node) - new = DeepTrackNode(lambda _ID=(): op(a(_ID=_ID), b(_ID=_ID))) - new.add_dependency(a) - new.add_dependency(b) - a.add_child(new) - b.add_child(new) + # Establish dependency relationships between the nodes. + # (Not needed because already done implicitly above.) 
+ # new_node.add_dependency(a) + # new_node.add_dependency(b) - return new + return new_node diff --git a/deeptrack/benchmarks/test_arithmetics.py b/deeptrack/benchmarks/test_arithmetics.py deleted file mode 100644 index 53185dcca..000000000 --- a/deeptrack/benchmarks/test_arithmetics.py +++ /dev/null @@ -1,41 +0,0 @@ -import sys -import numpy as np -import itertools -import deeptrack as dt -import pytest -import itertools - - -u = dt.units - - -def create_pipeline(elements=1024): - - value = dt.Value(np.zeros((elements,))) - - value = value + 14 - - value = value * (np.ones((elements,)) * 2) - - value = value / 1.5 - - value = value ** 2 - - return value - - -@pytest.mark.parametrize( - "elements,gpu", - [*itertools.product((1000, 5000, 10000, 50000, 100000, 500000), [True, False])], -) -def test_arithmetic(elements, gpu, benchmark): - benchmark.group = "arithm_{}_elements".format(elements) - benchmark.name = "test_arithmetic_{}".format("gpu" if gpu else "cpu") - if gpu: - dt.config.enable_gpu() - else: - dt.config.disable_gpu() - pipeline = create_pipeline(elements=elements) - benchmark( - lambda: pipeline.update()(), - ) diff --git a/deeptrack/benchmarks/test_fluorescence.py b/deeptrack/benchmarks/test_fluorescence.py deleted file mode 100644 index 0b8766629..000000000 --- a/deeptrack/benchmarks/test_fluorescence.py +++ /dev/null @@ -1,47 +0,0 @@ -import sys -import numpy as np -import itertools -import deeptrack as dt -import pytest - - -u = dt.units - - -def create_pipeline(output_region=(0, 0, 128, 128), num_particles=1): - - optics = dt.Fluorescence(output_region=output_region) - - mie = dt.Sphere( - radius=2e-6, - refractive_index=1.45, - z=10, - position=lambda: output_region[2:] * np.random.randn(2), - ) - - field = optics(mie ^ num_particles) - return field - - -@pytest.mark.parametrize( - "size,gpu", - [ - *itertools.product( - (64, 256, 512), - [True, False], - ) - ], -) -def test_simulate_mie(size, gpu, benchmark): - benchmark.group = 
f"fluorescence_{size}_px_image" - benchmark.name = f"test_fluorescence_{'gpu' if gpu else 'cpu'}" - if gpu: - dt.config.enable_gpu() - else: - dt.config.disable_gpu() - pipeline = create_pipeline(output_region=(0, 0, size, size), num_particles=1) - # One cold run for performance - pipeline.update()() - benchmark( - lambda: pipeline.update()(), - ) diff --git a/deeptrack/benchmarks/test_image.py b/deeptrack/benchmarks/test_image.py deleted file mode 100644 index 34c96a2a3..000000000 --- a/deeptrack/benchmarks/test_image.py +++ /dev/null @@ -1,43 +0,0 @@ -import sys -import numpy as np -import itertools -import deeptrack as dt -import pytest -import itertools -from deeptrack.backend._config import cupy as cp - -u = dt.units - - -def create_pipeline(elements=1024): - value = dt.Value(np.zeros((elements,))) - value = value + 14 - value = value * (np.ones((elements,)) * 2) - value = value / 1.5 - value = value ** 2 - return value - - -@pytest.mark.parametrize( - "elements,gpu,image", - [*itertools.product((1000, 10000, 100000, 1000000), [True, False], [True, False])], -) -def test_arithmetic(elements, gpu, image, benchmark): - benchmark.group = "add_{}_elements".format(elements) - benchmark.name = "test_{}_{}".format( - "Image" if image else "array", "gpu" if gpu else "cpu" - ) - - a = np.random.randn(elements) - b = np.random.randn(elements) - - if gpu: - a = cp.array(a) - b = cp.array(b) - if image: - a = dt.image.Image(a) - b = dt.image.Image(b) - - benchmark( - lambda: a + b, - ) diff --git a/deeptrack/benchmarks/test_simulate_mie.py b/deeptrack/benchmarks/test_simulate_mie.py deleted file mode 100644 index 51f60ee89..000000000 --- a/deeptrack/benchmarks/test_simulate_mie.py +++ /dev/null @@ -1,41 +0,0 @@ -import sys -import numpy as np -import itertools -import deeptrack as dt -import pytest - - -u = dt.units - - -def create_pipeline(output_region=(0, 0, 128, 128), num_particles=1): - - optics = dt.Brightfield(output_region=output_region) - - mie = dt.MieSphere( - 
radius=0.5e-6, - refractive_index=1.45, - z=lambda: np.random.randn() * 10, - position=lambda: output_region[2:] * np.random.randn(2), - L=20, - ) - - field = optics(mie ^ num_particles) - return field - - -@pytest.mark.parametrize( - "size,gpu", - [*itertools.product((64, 128, 256, 512, 728), [True, False])], -) -def test_simulate_mie(size, gpu, benchmark): - benchmark.group = "mie_{}_px_image".format(size) - benchmark.name = "test_simulate_mie_{}".format("gpu" if gpu else "cpu") - if gpu: - dt.config.enable_gpu() - else: - dt.config.disable_gpu() - pipeline = create_pipeline(output_region=(0, 0, size, size), num_particles=2) - benchmark( - lambda: pipeline.update()(), - ) diff --git a/deeptrack/test/__init__.py b/deeptrack/test/__init__.py index 07578d9d9..a4bdf38b8 100644 --- a/deeptrack/test/__init__.py +++ b/deeptrack/test/__init__.py @@ -1,3 +1,5 @@ +from .backend import * + from .test_aberrations import * from .test_augmentations import * from .test_elementwise import * diff --git a/deeptrack/test/backend/__init__.py b/deeptrack/test/backend/__init__.py new file mode 100644 index 000000000..d25beb5a6 --- /dev/null +++ b/deeptrack/test/backend/__init__.py @@ -0,0 +1 @@ +from .test_core import * \ No newline at end of file diff --git a/deeptrack/test/backend/test_core.py b/deeptrack/test/backend/test_core.py new file mode 100644 index 000000000..4521d097d --- /dev/null +++ b/deeptrack/test/backend/test_core.py @@ -0,0 +1,330 @@ +# pylint: disable=C0115:missing-class-docstring +# pylint: disable=C0116:missing-function-docstring +# pylint: disable=C0103:invalid-name + +# Use this only when running the test locally. +# import sys +# sys.path.append(".") # Adds the module to path. + +import unittest + +from deeptrack.backend import core + + +class TestCore(unittest.TestCase): + + def test_DeepTrackDataObject(self): + dataobj = core.DeepTrackDataObject() + + # Test storing and validating data. 
+ dataobj.store(1) + self.assertEqual(dataobj.current_value(), 1) + self.assertEqual(dataobj.is_valid(), True) + + # Test invalidating data. + dataobj.invalidate() + self.assertEqual(dataobj.current_value(), 1) + self.assertEqual(dataobj.is_valid(), False) + + # Test validating data. + dataobj.validate() + self.assertEqual(dataobj.current_value(), 1) + self.assertEqual(dataobj.is_valid(), True) + + + def test_DeepTrackDataDict(self): + dataset = core.DeepTrackDataDict() + + # Test initial state. + self.assertEqual(dataset.keylength, None) + self.assertFalse(dataset.dict) + + # Create indices and store data. + dataset.create_index((0,)) + dataset[(0,)].store({"image": [1, 2, 3], "label": 0}) + + dataset.create_index((1,)) + dataset[(1,)].store({"image": [4, 5, 6], "label": 1}) + + self.assertEqual(dataset.keylength, 1) + self.assertEqual(len(dataset.dict), 2) + self.assertIn((0,), dataset.dict) + self.assertIn((1,), dataset.dict) + + # Test retrieving stored data. + self.assertEqual(dataset[(0,)].current_value(), + {"image": [1, 2, 3], "label": 0}) + self.assertEqual(dataset[(1,)].current_value(), + {"image": [4, 5, 6], "label": 1}) + + # Test validation and invalidation - all. + self.assertTrue(dataset[(0,)].is_valid()) + self.assertTrue(dataset[(1,)].is_valid()) + + dataset.invalidate() + self.assertFalse(dataset[(0,)].is_valid()) + self.assertFalse(dataset[(1,)].is_valid()) + + dataset.validate() + self.assertTrue(dataset[(0,)].is_valid()) + self.assertTrue(dataset[(1,)].is_valid()) + + # Test validation and invalidation - single node. 
+        self.assertTrue(dataset[(0,)].is_valid())
+
+        dataset[(0,)].invalidate()
+        self.assertFalse(dataset[(0,)].is_valid())
+        self.assertTrue(dataset[(1,)].is_valid())
+
+        dataset[(1,)].invalidate()
+        self.assertFalse(dataset[(0,)].is_valid())
+        self.assertFalse(dataset[(1,)].is_valid())
+
+        dataset[(0,)].validate()
+        self.assertTrue(dataset[(0,)].is_valid())
+        self.assertFalse(dataset[(1,)].is_valid())
+
+        dataset[(1,)].validate()
+        self.assertTrue(dataset[(0,)].is_valid())
+        self.assertTrue(dataset[(1,)].is_valid())
+
+        # Test iteration over entries.
+        for key, value in dataset.dict.items():
+            self.assertIn(key, {(0,), (1,)})
+            self.assertIsInstance(value, core.DeepTrackDataObject)
+
+
+    def test_DeepTrackNode_basics(self):
+        node = core.DeepTrackNode(action=lambda: 42)
+
+        # Evaluate the node.
+        result = node()  # Value is calculated and stored.
+        self.assertEqual(result, 42)
+
+        # Store a value.
+        node.store(100)  # Value is stored.
+        self.assertEqual(node.current_value(), 100)
+        self.assertTrue(node.is_valid())
+
+        # Invalidate the node and check the value.
+        node.invalidate()
+        self.assertFalse(node.is_valid())
+
+        self.assertEqual(node.current_value(), 100)  # Value is retrieved.
+        self.assertFalse(node.is_valid())
+
+        self.assertEqual(node(), 42)  # Value is calculated and stored.
+        self.assertTrue(node.is_valid())
+
+
+    def test_DeepTrackNode_dependencies(self):
+        parent = core.DeepTrackNode(action=lambda: 10)
+        child = core.DeepTrackNode(action=lambda _ID=None: parent() * 2)
+        parent.add_child(child)  # Establish dependency.
+
+        # Check that the just-created nodes are invalid, as not calculated.
+        self.assertFalse(parent.is_valid())
+        self.assertFalse(child.is_valid())
+
+        # Calculate child, and therefore parent.
+        result = child()
+        self.assertEqual(result, 20)
+        self.assertTrue(parent.is_valid())
+        self.assertTrue(child.is_valid())
+
+        # Invalidate parent and check child validity.
+        parent.invalidate()
+        self.assertFalse(parent.is_valid())
+        self.assertFalse(child.is_valid())
+
+        # Validate parent and ensure child is invalid until recomputation.
+        parent.validate()
+        self.assertTrue(parent.is_valid())
+        self.assertFalse(child.is_valid())
+
+        # Recompute child and check its validity.
+        child()
+        self.assertTrue(parent.is_valid())
+        self.assertTrue(child.is_valid())
+
+    def test_DeepTrackNode_nested_dependencies(self):
+        parent = core.DeepTrackNode(action=lambda: 5)
+        middle = core.DeepTrackNode(action=lambda: parent() + 5)
+        child = core.DeepTrackNode(action=lambda: middle() * 2)
+
+        parent.add_child(middle)
+        middle.add_child(child)
+
+        result = child()
+        self.assertEqual(result, 20, "Nested computation failed.")
+
+        # Invalidate the middle and check propagation.
+        middle.invalidate()
+        self.assertTrue(parent.is_valid())
+        self.assertFalse(middle.is_valid())
+        self.assertFalse(child.is_valid())
+
+
+    def test_DeepTrackNode_overloading(self):
+        node1 = core.DeepTrackNode(action=lambda: 5)
+        node2 = core.DeepTrackNode(action=lambda: 10)
+
+        sum_node = node1 + node2
+        self.assertEqual(sum_node(), 15)
+
+        diff_node = node2 - node1
+        self.assertEqual(diff_node(), 5)
+
+        prod_node = node1 * node2
+        self.assertEqual(prod_node(), 50)
+
+        div_node = node2 / node1
+        self.assertEqual(div_node(), 2)
+
+
+    def test_DeepTrackNode_citations(self):
+        node = core.DeepTrackNode(action=lambda: 42)
+        citations = node.get_citations()
+        self.assertIn(core.citation_midtvet2021quantitative, citations)
+
+
+    def test_DeepTrackNode_single_id(self):
+        """Test a single _ID on a simple parent-child relationship."""
+
+        parent = core.DeepTrackNode(action=lambda: 10)
+        child = core.DeepTrackNode(action=lambda _ID=None: parent(_ID) * 2)
+        parent.add_child(child)
+
+        # Store a value for each specific _ID.
+        for id, value in enumerate(range(10)):
+            parent.store(id, _ID=(id,))
+
+        # Retrieve the values stored in children and parents.
+        for id, value in enumerate(range(10)):
+            self.assertEqual(child(_ID=(id,)), value * 2)
+            self.assertEqual(parent.previous((id,)), value)
+
+    def test_DeepTrackNode_nested_ids(self):
+        """Test nested IDs for parent-child relationships."""
+
+        parent = core.DeepTrackNode(action=lambda: 10)
+        child = core.DeepTrackNode(
+            action=lambda _ID=None: parent(_ID[:1]) * _ID[1]
+        )
+        parent.add_child(child)
+
+        # Store values for parent at different IDs.
+        parent.store(5, _ID=(0,))
+        parent.store(10, _ID=(1,))
+
+        # Compute child values for nested IDs.
+        child_value_0_0 = child(_ID=(0, 0))  # Uses parent(_ID=(0,)).
+        self.assertEqual(child_value_0_0, 0)
+
+        child_value_0_1 = child(_ID=(0, 1))  # Uses parent(_ID=(0,)).
+        self.assertEqual(child_value_0_1, 5)
+
+        child_value_1_0 = child(_ID=(1, 0))  # Uses parent(_ID=(1,)).
+        self.assertEqual(child_value_1_0, 0)
+
+        child_value_1_1 = child(_ID=(1, 1))  # Uses parent(_ID=(1,)).
+        self.assertEqual(child_value_1_1, 10)
+
+
+    def test_DeepTrackNode_replicated_behavior(self):
+        """Test replicated behavior where IDs expand."""
+
+        particle = core.DeepTrackNode(action=lambda _ID=None: _ID[0] + 1)
+
+        # Replicate node logic.
+        cluster = core.DeepTrackNode(
+            action=lambda _ID=None: particle(_ID=(0,)) + particle(_ID=(1,))
+        )
+
+        cluster_value = cluster()
+        self.assertEqual(cluster_value, 3)
+
+    def test_DeepTrackNode_parent_id_inheritance(self):
+
+        # Children with IDs matching their parents.
+        parent_matching = core.DeepTrackNode(action=lambda: 10)
+        child_matching = core.DeepTrackNode(
+            action=lambda _ID=None: parent_matching(_ID[:1]) * 2
+        )
+        parent_matching.add_child(child_matching)
+
+        parent_matching.store(7, _ID=(0,))
+        parent_matching.store(5, _ID=(1,))
+
+        self.assertEqual(child_matching(_ID=(0,)), 14)
+        self.assertEqual(child_matching(_ID=(1,)), 10)
+
+        # Children with IDs deeper than parents.
+ parent_deeper = core.DeepTrackNode(action=lambda: 10) + child_deeper = core.DeepTrackNode( + action=lambda _ID=None: parent_deeper(_ID[:1]) * 2 + ) + parent_deeper.add_child(child_deeper) + + parent_deeper.store(7, _ID=(0,)) + parent_deeper.store(5, _ID=(1,)) + + self.assertEqual(child_deeper(_ID=(0, 0)), 14) + self.assertEqual(child_deeper(_ID=(0, 1)), 14) + self.assertEqual(child_deeper(_ID=(0, 2)), 14) + + self.assertEqual(child_deeper(_ID=(1, 0)), 10) + self.assertEqual(child_deeper(_ID=(1, 1)), 10) + self.assertEqual(child_deeper(_ID=(1, 2)), 10) + + def test_DeepTrackNode_invalidation_and_ids(self): + """Test that invalidating a parent affects specific IDs of children.""" + + parent = core.DeepTrackNode(action=lambda: 10) + child = core.DeepTrackNode(action=lambda _ID=None: parent(_ID[:1]) * 2) + parent.add_child(child) + + # Store and compute values. + parent.store(0, _ID=(0,)) + parent.store(1, _ID=(1,)) + child(_ID=(0, 0)) + child(_ID=(0, 1)) + child(_ID=(1, 0)) + child(_ID=(1, 1)) + + # Invalidate the parent at _ID=(0,). + parent.invalidate((0,)) + + self.assertFalse(parent.is_valid((0,))) + self.assertFalse(parent.is_valid((1,))) + self.assertFalse(child.is_valid((0, 0))) + self.assertFalse(child.is_valid((0, 1))) + self.assertFalse(child.is_valid((1, 0))) + self.assertFalse(child.is_valid((1, 1))) + + + def test_DeepTrackNode_dependency_graph_with_ids(self): + """Test a multi-level dependency graph with nested IDs.""" + + A = core.DeepTrackNode(action=lambda: 10) + B = core.DeepTrackNode(action=lambda _ID=None: A(_ID[:-1]) + 5) + C = core.DeepTrackNode( + action=lambda _ID=None: B(_ID[:-1]) * (_ID[-1] + 1) + ) + A.add_child(B) + B.add_child(C) + + # Store values for A at different IDs. + A.store(3, _ID=(0,)) + A.store(4, _ID=(1,)) + + # Compute values for C at nested IDs. 
+        C_0_1_2 = C(_ID=(0, 1, 2))  # B((0, 1)) * (2 + 1)
+                                    # (A((0,)) + 5) * (2 + 1)
+                                    # (3 + 5) * (2 + 1)
+                                    # 24
+        self.assertEqual(C_0_1_2, 24)
+
+
+if __name__ == "__main__":
+    unittest.main()
diff --git a/deeptrack/test/test_utils.py b/deeptrack/test/test_utils.py
index d45c7508f..f16285bd2 100644
--- a/deeptrack/test/test_utils.py
+++ b/deeptrack/test/test_utils.py
@@ -1,16 +1,25 @@
+# pylint: disable=C0115:missing-class-docstring
+# pylint: disable=C0116:missing-function-docstring
+# pylint: disable=C0103:invalid-name
+
+# Use this only when running the test locally.
+# import sys
+# sys.path.append(".") # Adds the module to path.
+
 import unittest
 
 from .. import utils
 
 
 class TestUtils(unittest.TestCase):
-    
+
+
     def test_hasmethod(self):
         self.assertTrue(utils.hasmethod(utils, "hasmethod"))
         self.assertFalse(
             utils.hasmethod(utils, "this_is_definetely_not_a_method_of_utils")
         )
-    
+
 
     def test_as_list(self):
         obj = 1
@@ -98,4 +107,4 @@ def func6(key1, key2=1, key3=3, **kwargs):
 
 
 if __name__ == "__main__":
-    unittest.main()
\ No newline at end of file
+    unittest.main()
diff --git a/deeptrack/types.py b/deeptrack/types.py
index 389f410d0..fd9ad1314 100644
--- a/deeptrack/types.py
+++ b/deeptrack/types.py
@@ -5,11 +5,14 @@
 maintainability, and reduces redundancy in type annotations. These types
 are particularly useful for properties and array-like structures used within
 the library.
+
 """
 
-import numpy as np
 import typing
 
+import numpy as np
+
+
 # T is a generic type variable defining generic types for reusability.
 T = typing.TypeVar("T")
 
diff --git a/deeptrack/utils.py b/deeptrack/utils.py
index cfd9d1567..c6504b494 100644
--- a/deeptrack/utils.py
+++ b/deeptrack/utils.py
@@ -6,17 +6,18 @@
 Functions
 ---------
 hasmethod(obj: any, method_name: str) -> bool
-    Return True if the object has a field named `function_name` that is
-    callable. Otherwise, return False.
+    Returns True if the object has a field named `method_name` that is
+    callable. Otherwise, returns False.
as_list(obj: any) -> list - If the input is iterable, convert it to list. - Otherwise, wrap the input in a list. + If the input is iterable, converts it to list. + Otherwise, wraps the input in a list. get_kwarg_names(function: Callable) -> List[str] Retrieves the names of the keyword arguments accepted by a function. kwarg_has_default(function: Callable, argument: str) -> bool Checks if a specific argument of a function has a default value. safe_call(function, positional_args=[], **kwargs) - Calls a function, passing only valid arguments from the provided kwargs. + Calls a function, passing only valid arguments from the provided keyword + arguments (kwargs). """ @@ -42,7 +43,7 @@ def hasmethod(obj: any, method_name: str) -> bool: """ - return (hasattr(obj, method_name) + return (hasattr(obj, method_name) and callable(getattr(obj, method_name, None))) @@ -73,7 +74,7 @@ def as_list(obj: any) -> list: def get_kwarg_names(function: Callable) -> List[str]: """Retrieve the names of the keyword arguments accepted by a function. - Retrieves the names of the keyword arguments accepted by `function` as a + It retrieves the names of the keyword arguments accepted by `function` as a list of strings. Parameters @@ -129,8 +130,8 @@ def kwarg_has_default(function: Callable, argument: str) -> bool: def safe_call(function, positional_args=[], **kwargs) -> Any: """Calls a function with valid arguments from a dictionary of arguments. - Filters `kwargs` to include only arguments accepted by the function, - ensuring that no invalid arguments are passed. This function also supports + It filters `kwargs` to include only arguments accepted by the function, + ensuring that no invalid arguments are passed. This function also supports positional arguments. 
Parameters diff --git a/setup.py b/setup.py index 7e7302b6b..f998beecb 100644 --- a/setup.py +++ b/setup.py @@ -12,7 +12,7 @@ setup( name="deeptrack", - version="2.0.0rc0", + version="2.0.0", license="MIT", packages=find_packages(), author=(