
Releases: ASEM000/pytreeclass

0.11

11 Jan 18:59
2e4a548

Changelog

v0.11.0

Breaking Changes:

Due to changes in the JAX tree API, TreeClass is no longer treated as a named tuple when indexing with AtIndexer/.at; only attribute names are valid keys.

import pytreeclass as tc
class Tree(tc.TreeClass):
    def __init__(self, a, b):
        self.a = a
        self.b = b

tree = Tree(1, 2)
print(tree.at["a"].get())  # 1
# tree.at[0].get()  # no longer valid: positional indexing is not supported
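The rule can be illustrated with the standard library alone. A named tuple supports both positional and name-based lookup, while a plain attribute container only supports lookup by name. The helper `at_get` below is hypothetical and only mimics the name-only rule; it is not how pytreeclass implements `.at`.

```python
from collections import namedtuple

# A named tuple supports both positional and name-based access.
Pair = namedtuple("Pair", ["a", "b"])
pair = Pair(1, 2)
assert pair[0] == pair.a == 1


class Tree:
    # A plain attribute container, similar in spirit to a TreeClass instance.
    def __init__(self, a, b):
        self.a = a
        self.b = b


def at_get(obj, key):
    # Hypothetical helper mimicking the name-only rule: only string keys
    # (attribute names) are accepted; integer positions are rejected.
    if not isinstance(key, str):
        raise TypeError(f"expected an attribute name, got {type(key).__name__}")
    return vars(obj)[key]


tree = Tree(1, 2)
print(at_get(tree, "a"))  # 1
try:
    at_get(tree, 0)
except TypeError as err:
    print(err)  # expected an attribute name, got int
```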

v0.9.2

13 Sep 23:45

Changes:

  • Rename threads_count to max_workers in the parallel keyword arguments of apply

Full Changelog: v0.9.1...v0.9.2

v0.9.1

13 Sep 22:09

Additions:

  • Add a parallel mapping option to AtIndexer. This enables a variety of tasks, such as reading images from a pytree of file names.
# benchmarking parallel vs sequential image reads
# on a Mac M1 CPU with 512x512x3 images
# (%timeit is an IPython magic; run in IPython or Jupyter)
import pytreeclass as tc
from matplotlib.pyplot import imread
paths = ["lenna.png"] * 10
indexer = tc.AtIndexer(paths)
%timeit indexer[...].apply(imread, parallel=True)  # parallel
# 24.9 ms ± 938 µs per loop (mean ± std. dev. of 7 runs, 10 loops each)
%timeit indexer[...].apply(imread)  # not parallel
# 84.8 ms ± 453 µs per loop (mean ± std. dev. of 7 runs, 10 loops each)
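Threads help here because image reading is I/O-bound, so blocking reads can overlap. The sketch below shows the same idea with `concurrent.futures.ThreadPoolExecutor`, whose `max_workers` argument shares its name with the key used in v0.9.2. It assumes a thread pool; `fake_read` and `apply_parallel` are hypothetical stand-ins, and this is not pytreeclass's internal implementation.

```python
from concurrent.futures import ThreadPoolExecutor
import time


def fake_read(path):
    # Hypothetical stand-in for an I/O-bound reader such as `imread`:
    # sleep briefly to simulate disk latency, then return the path length.
    time.sleep(0.01)
    return len(path)


def apply_parallel(func, leaves, max_workers=None):
    # Map `func` over a flat list of leaves using a thread pool.
    # `executor.map` preserves input order, so the output lines up
    # with the input leaves exactly as a serial map would.
    with ThreadPoolExecutor(max_workers=max_workers) as executor:
        return list(executor.map(func, leaves))


paths = ["lenna.png"] * 10
serial = [fake_read(p) for p in paths]
parallel = apply_parallel(fake_read, paths, max_workers=4)
print(parallel == serial)  # True
```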

v0.9

10 Sep 13:26
7aa50d3

V0.9

Breaking changes:

  • To simplify the API, the following functions are removed:
    1. tree_repr_with_trace
    2. tree_map_with_trace
    3. tree_flatten_with_trace
    4. tree_leaves_with_trace

Full Changelog: v0.8.0...v0.9

v0.8.0

06 Sep 10:50
e9deb51

V0.8

Additions:

  • Add on_getattr to field to apply functions to a value whenever it is accessed through __getattr__

Breaking changes:

  • Rename callbacks in field to on_setattr to match attrs and better reflect its functionality.

These changes enable:

  1. stricter data validation on instance values, as in the following example:

    on_setattr ensures the value is of a certain type (e.g. integer) when it is set during initialization, and on_getattr ensures the value is of that type whenever it is accessed.

    import pytreeclass as pytc
    import jax
    
    def assert_int(x):
        assert isinstance(x, int), "must be an int"
        return x
    
    @pytc.autoinit
    class Tree(pytc.TreeClass):
        a: int = pytc.field(on_getattr=[assert_int], on_setattr=[assert_int])
    
        def __call__(self, x):
            # accessing `self.a` runs the on_getattr callback `assert_int`
            a: int = self.a
            return a + x
    
    tree = Tree(a=1)
    print(tree(1.0))  # 2.0
    tree = jax.tree_map(lambda x: x + 0.0, tree)  # make `a` a float
    tree(1.0)  # AssertionError: must be an int
  2. Frozen field without using tree_mask/tree_unmask

    The following example shows a pattern where the value is frozen on __setattr__ and unfrozen whenever it is accessed, which ensures that JAX transformations do not see the value:

    import pytreeclass as pytc
    import jax
    
    @pytc.autoinit
    class Tree(pytc.TreeClass):
        frozen_a : int = pytc.field(on_getattr=[pytc.unfreeze], on_setattr=[pytc.freeze])
    
        def __call__(self, x):
            return self.frozen_a + x
    
    tree = Tree(frozen_a=1)  # 1 is non-jaxtype
    # can be used in jax transformations
    
    @jax.jit
    def f(tree, x):
        return tree(x)
    
    f(tree, 1.0)  # 2.0
    
    grads = jax.grad(f)(tree, 1.0)  # Tree(frozen_a=#1)

    Compared with other libraries that implement static_field, this pattern has lower overhead and does not alter the tree_flatten/tree_unflatten methods of the tree.

  3. Easier way to create a buffer (non-trainable array)

    Just use jax.lax.stop_gradient in on_getattr

    import pytreeclass as pytc
    import jax
    import jax.numpy as jnp
    
    def assert_array(x):
        assert isinstance(x, jax.Array)
        return x
    
    @pytc.autoinit
    class Tree(pytc.TreeClass):
        # stop_gradient on access makes `buffer` non-trainable
        buffer: jax.Array = pytc.field(on_getattr=[jax.lax.stop_gradient], on_setattr=[assert_array])

        def __call__(self, x):
            return self.buffer**x

    tree = Tree(buffer=jnp.array([1.0, 2.0, 3.0]))
    tree(2.0)  # Array([1., 4., 9.], dtype=float32)

    @jax.jit
    def f(tree, x):
        return jnp.sum(tree(x))

    f(tree, 1.0)  # Array(6., dtype=float32)
    print(jax.grad(f)(tree, 1.0))  # Tree(buffer=[0. 0. 0.])
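The mechanism behind all three patterns is "run a list of functions when a value is stored and when it is read". A plain Python descriptor can sketch that idea without JAX. `HookedField`, `Frozen`, and `unfreeze` are hypothetical names for this illustration only; pytreeclass's `field`, `freeze`, and `unfreeze` are not implemented this way.

```python
class HookedField:
    # Hypothetical descriptor that runs callbacks on attribute set and get,
    # loosely mirroring the on_setattr / on_getattr idea.
    def __init__(self, on_setattr=(), on_getattr=()):
        self.on_setattr = list(on_setattr)
        self.on_getattr = list(on_getattr)

    def __set_name__(self, owner, name):
        self.name = "_" + name  # where the raw (stored) value lives

    def __get__(self, instance, owner=None):
        if instance is None:
            return self
        value = getattr(instance, self.name)
        for func in self.on_getattr:
            value = func(value)
        return value

    def __set__(self, instance, value):
        for func in self.on_setattr:
            value = func(value)
        setattr(instance, self.name, value)


def assert_int(x):
    assert isinstance(x, int), "must be an int"
    return x


class Tree:
    a = HookedField(on_setattr=[assert_int], on_getattr=[assert_int])

    def __init__(self, a):
        self.a = a  # runs the on_setattr callbacks

    def __call__(self, x):
        return self.a + x  # runs the on_getattr callbacks


class Frozen:
    # Hypothetical wrapper marking a value as "frozen" while stored.
    def __init__(self, value):
        self.value = value


def unfreeze(x):
    return x.value if isinstance(x, Frozen) else x


class Model:
    # Stored wrapped, returned unwrapped: the frozen-field pattern.
    scale = HookedField(on_setattr=[Frozen], on_getattr=[unfreeze])

    def __init__(self, scale):
        self.scale = scale


tree = Tree(a=1)
print(tree(1.0))  # 2.0

# Bypass the setter (as a leaf-wise map might) to store a float directly;
# the on_getattr check then fails when the value is accessed.
tree._a = 1.5
try:
    tree(1.0)
except AssertionError as err:
    print("AssertionError:", err)  # AssertionError: must be an int

model = Model(3)
print(type(model._scale).__name__, model.scale)  # Frozen 3
```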

v0.7.0

02 Sep 18:39

Changelog

v0.7

  • Remove .at as an alias for __getitem__ when specifying a path entry (where) in AtIndexer. This leads to a less verbose style.

Example:

>>> import pytreeclass as pytc
>>> tree = {"level1_0": {"level2_0": 100, "level2_1": 200}, "level1_1": 300}
>>> tree = pytc.AtIndexer(tree)

>>> # Before:
>>> # style 1 (with at):
>>> tree.at["level1_0"].at["level2_0", "level2_1"].get()
{'level1_0': {'level2_0': 100, 'level2_1': 200}, 'level1_1': None}
>>> # style 2 (no at):
>>> tree["level1_0"]["level2_0", "level2_1"].get()

>>> # After
>>> # only style 2 is valid
>>> tree["level1_0"]["level2_0", "level2_1"].get()
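In the indexing style above, each `[...]` supplies one group of keys for the next nesting level. A stdlib-only sketch of that selection semantics on nested dicts, reproducing the output shape shown for style 1 (selected leaves kept, everything else replaced by None). `select` and `_to_none` are hypothetical helpers, not AtIndexer's implementation.

```python
def select(tree, path):
    # Hypothetical sketch of path-based selection on nested dicts.
    # `path` is a sequence of key groups, one group per nesting level.
    # Leaves on the selected path are kept; everything else becomes None.
    if not path:
        return tree
    keys, rest = path[0], path[1:]
    if not isinstance(tree, dict):
        return None  # path goes deeper than the tree
    out = {}
    for key, value in tree.items():
        out[key] = select(value, rest) if key in keys else _to_none(value)
    return out


def _to_none(tree):
    # Replace every leaf of an unselected subtree with None.
    if isinstance(tree, dict):
        return {k: _to_none(v) for k, v in tree.items()}
    return None


tree = {"level1_0": {"level2_0": 100, "level2_1": 200}, "level1_1": 300}
print(select(tree, [("level1_0",), ("level2_0", "level2_1")]))
# {'level1_0': {'level2_0': 100, 'level2_1': 200}, 'level1_1': None}
```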

For TreeClass

at is specified only once per change; nested path entries use plain indexing:

import jax
import jax.numpy as jnp
import pytreeclass as pytc

@pytc.autoinit
class Tree(pytc.TreeClass):
    a: float = 1.0
    b: tuple[float, float] = (2.0, 3.0)
    c: jax.Array = jnp.array([4.0, 5.0, 6.0])

    def __call__(self, x):
        return self.a + self.b[0] + self.c + x


tree = Tree()
mask = jax.tree_map(lambda x: x > 5, tree)
tree = tree\
       .at["a"].set(100.0)\
-       .at["b"].at[0].set(10.0)\
+      .at["b"][0].set(10.0)\
       .at[mask].set(100.0)
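The chained style works because each `.at[...].set(...)` returns a new tree instead of mutating the old one. A minimal stdlib sketch of that functional update, assuming plain dicts and tuples in place of a TreeClass and ignoring mask-based selection; `set_at` is a hypothetical helper.

```python
def set_at(tree, path, value):
    # Hypothetical immutable update: return a copy of `tree` with the leaf
    # at `path` replaced by `value`; the input is never mutated.
    if not path:
        return value
    key, rest = path[0], path[1:]
    if isinstance(tree, dict):
        new = dict(tree)
        new[key] = set_at(tree[key], rest, value)
        return new
    if isinstance(tree, tuple):
        items = list(tree)
        items[key] = set_at(tree[key], rest, value)
        return tuple(items)
    raise TypeError(f"cannot index into {type(tree).__name__}")


tree = {"a": 1.0, "b": (2.0, 3.0)}
# Each call returns a new tree, so updates chain like .at[...].set(...)
new_tree = set_at(set_at(tree, ("a",), 100.0), ("b", 0), 10.0)
print(new_tree)  # {'a': 100.0, 'b': (10.0, 3.0)}
print(tree)      # {'a': 1.0, 'b': (2.0, 3.0)}
```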

v0.6.0post0

31 Aug 09:24

Changelog

  • Using tree_repr/tree_str on an object containing cyclic references now raises RecursionError instead of displaying cyclicref.
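Any recursive printer without cycle detection behaves this way: it keeps descending into the same object until Python's recursion limit is hit. A small stdlib sketch, using a hypothetical `naive_repr` rather than pytreeclass's printer, contrasted with the builtin `repr`, which does detect the cycle:

```python
def naive_repr(node):
    # Hypothetical recursive repr without cycle detection: it descends into
    # every list, so a self-referencing list recurses until the limit.
    if isinstance(node, list):
        return "[" + ", ".join([naive_repr(item) for item in node]) + "]"
    return repr(node)


cyclic = [1]
cyclic.append(cyclic)  # the list now contains itself

try:
    naive_repr(cyclic)
except RecursionError:
    print("RecursionError")
print(repr(cyclic))  # [1, [...]]
```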

v0.6.0

31 Jul 06:53

  • Allow nested mutations using .at[method](*args, **kwargs).
    After this change, inner methods can mutate copied instances at any level, not just the top level.
    One motivation is to experiment with lazy initialization schemes, where inner layers need to mutate their own state. See the example below for Flax-like lazy initialization as described here

    import jax
    import jax.numpy as jnp
    import jax.random as jr
    import pytreeclass as pytc
    from typing import Callable, TypeVar
    
    T = TypeVar("T")
    
    @pytc.autoinit
    class LazyLinear(pytc.TreeClass):
        outdim: int
        weight_init: Callable[..., T] = jax.nn.initializers.glorot_normal()
        bias_init: Callable[..., T] = jax.nn.initializers.zeros
    
        def param(self, name: str, init_func: Callable[..., T], *args) -> T:
            if name not in vars(self):
                setattr(self, name, init_func(*args))
            return vars(self)[name]
    
        def __call__(self, x: jax.Array, *, key: jr.KeyArray = jr.PRNGKey(0)):
            w = self.param("weight", self.weight_init, key, (x.shape[-1], self.outdim))
            y = x @ w
            if self.bias_init is not None:
                b = self.param("bias", self.bias_init, key, (self.outdim,))
                return y + b
            return y
    
    
    @pytc.autoinit
    class StackedLinear(pytc.TreeClass):
        l1: LazyLinear = LazyLinear(outdim=10)
        l2: LazyLinear = LazyLinear(outdim=1)
    
        def call(self, x: jax.Array):
            return self.l2(jax.nn.relu(self.l1(x)))
    
    lazy_layer = StackedLinear()
    print(repr(lazy_layer))
    # StackedLinear(
    #   l1=LazyLinear(
    #     outdim=10, 
    #     weight_init=init(key, shape, dtype), 
    #     bias_init=zeros(key, shape, dtype)
    #   ), 
    #   l2=LazyLinear(
    #     outdim=1, 
    #     weight_init=init(key, shape, dtype), 
    #     bias_init=zeros(key, shape, dtype)
    #   )
    # )
    
    _, materialized_layer = lazy_layer.at["call"](jnp.ones((1, 5)))
    materialized_layer
    # StackedLinear(
    #   l1=LazyLinear(
    #     outdim=10, 
    #     weight_init=init(key, shape, dtype), 
    #     bias_init=zeros(key, shape, dtype), 
    #     weight=f32[5,10](μ=-0.04, σ=0.32, ∈[-0.74,0.63]), 
    #     bias=f32[10](μ=0.00, σ=0.00, ∈[0.00,0.00])
    #   ), 
    #   l2=LazyLinear(
    #     outdim=1, 
    #     weight_init=init(key, shape, dtype), 
    #     bias_init=zeros(key, shape, dtype), 
    #     weight=f32[10,1](μ=-0.07, σ=0.23, ∈[-0.34,0.34]), 
    #     bias=f32[1](μ=0.00, σ=0.00, ∈[0.00,0.00])
    #   )
    # )
    
    materialized_layer.call(jnp.ones((1, 5)))
    # Array([[0.16712935]], dtype=float32)
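The semantics of `.at[method](*args)` in the example above can be sketched as "copy, call the method on the copy, return both the output and the mutated copy". The stdlib sketch below assumes a deep copy and shows a nested layer materializing its state while the original object stays untouched. `call_on_copy`, `LazyScale`, and `Stack` are hypothetical; how pytreeclass actually permits mutation on the copy is not shown.

```python
import copy


def call_on_copy(obj, method_name, *args, **kwargs):
    # Hypothetical sketch: deep-copy the object, call the method on the
    # copy (which may mutate nested state), and return the output together
    # with the mutated copy. The original object is left untouched.
    new_obj = copy.deepcopy(obj)
    output = getattr(new_obj, method_name)(*args, **kwargs)
    return output, new_obj


class LazyScale:
    def __init__(self):
        self.factor = None  # not materialized yet

    def __call__(self, x):
        if self.factor is None:
            self.factor = 2 * len(x)  # "initialize" from the input shape
        return [self.factor * v for v in x]


class Stack:
    def __init__(self):
        self.inner = LazyScale()

    def call(self, x):
        return self.inner(x)  # mutates a nested layer, not the top level


lazy = Stack()
out, materialized = call_on_copy(lazy, "call", [1, 2, 3])
print(out)                        # [6, 12, 18]
print(lazy.inner.factor)          # None
print(materialized.inner.factor)  # 6
```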

Full Changelog: v0.5...v0.6.0

v0.5.0post0

27 Jul 21:43

Fix __init_subclass__ not accepting arguments, a bug introduced in v0.5.
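For context, in Python, keyword arguments written in a class statement, such as `class Child(Base, tag=...)`, are forwarded to the parent's `__init_subclass__`. A generic illustration of that mechanism with hypothetical names `Base` and `tag` (this is plain Python, not TreeClass code):

```python
class Base:
    # Class-statement keyword arguments arrive here; forwarding the rest
    # to super() keeps cooperative inheritance working.
    def __init_subclass__(cls, tag=None, **kwargs):
        super().__init_subclass__(**kwargs)
        cls.tag = tag


class Child(Base, tag="custom"):
    pass


class Plain(Base):
    pass


print(Child.tag)  # custom
print(Plain.tag)  # None
```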

Full Changelog: v0.5...v0.0.5post0

v0.5.0

24 Jul 19:32
88978ab

Changelog

PyTreeClass v0.5

Breaking changes

Automatic generation of the __init__ method from type hints is decoupled from TreeClass.

Alternatives

Use:

  1. Preferably, decorate with pytreeclass.autoinit and use pytreeclass.field as the field specifier, since pytreeclass.field has more features (e.g. callbacks, multiple argument kind selection) and its init generation is cached, unlike dataclasses.
  2. Decorate with dataclasses.dataclass and use dataclasses.field as the field specifier. However:
    1. frozen=False must be set, because __setattr__ and __delattr__ are handled by TreeClass
    2. Optionally set repr=False so that repr is handled by TreeClass
    3. Optionally set eq=False (leaving unsafe_hash=False, the default), as equality and hashing are handled by TreeClass
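The effect of these dataclass options can be checked with the standard library: with repr=False and eq=False, dataclasses does not generate __repr__ or __eq__, so the base class versions are inherited. `Base` below is a hypothetical stand-in for a class like TreeClass that supplies its own dunder methods; it is not TreeClass itself.

```python
import dataclasses as dc


class Base:
    # Hypothetical base class that supplies its own __repr__ and __eq__,
    # which a dataclass decorator would otherwise replace.
    def __repr__(self):
        return f"{type(self).__name__}({vars(self)})"

    def __eq__(self, other):
        return type(self) is type(other) and vars(self) == vars(other)

    __hash__ = None  # mutable objects with a custom __eq__ are unhashable


# frozen=False (the default) avoids installing a raising __setattr__ that
# would shadow any base-class __setattr__; repr=False and eq=False keep
# the inherited __repr__ and __eq__.
@dc.dataclass(frozen=False, repr=False, eq=False)
class Tree(Base):
    a: int = 1
    b: int = dc.field(default=2)


tree = Tree(a=10)
print(tree)  # Tree({'a': 10, 'b': 2})
print(tree == Tree(a=10))  # True
```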

Before

import jax.tree_util as jtu
import pytreeclass as pytc
import dataclasses as dc

class Tree(pytc.TreeClass):
    a: int = 1

jtu.tree_leaves(Tree())
# [1]

After

Equivalent behavior when decorating with either:

  1. @pytreeclass.autoinit
  2. @dataclasses.dataclass
import jax.tree_util as jtu
import pytreeclass as pytc

@pytc.autoinit
class Tree(pytc.TreeClass):
    a: int = 1

jtu.tree_leaves(Tree())
# [1]

This change aims to remove the ambiguity of applying the dataclass mental model in the following situations:

  1. Subclassing. Previously, using TreeClass as a base class was equivalent to decorating the class with dataclasses.dataclass; however, this was hard to reason about, as the next example demonstrates:

    import pytreeclass as pytc
    import dataclasses as dc
    
    class A(pytc.TreeClass):
        def __init__(self, a: int):
            self.a = a
    
    class B(A):
        ...

    Instantiating B(a=...) raises an error, because using TreeClass was equivalent to decorating every class, including subclasses, with @dataclass, which synthesizes the __init__ method from the fields.
    Since B declares no fields (i.e. type-hinted class attributes), the synthesized __init__ takes no arguments and shadows the __init__ inherited from A.

    The previous code is equivalent to this code.

    @dc.dataclass
    class A:
        def __init__(self, a:int):
            self.a = a
    @dc.dataclass
    class B(A):
        ...
  2. dataclass_transform does not play nicely with user-defined __init__ methods; see 1, 2
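Running the dataclass-equivalent version of the subclassing example with only the standard library makes the failure concrete. The error message text varies across Python versions, so it is printed rather than quoted.

```python
import dataclasses as dc


@dc.dataclass
class A:
    # An explicit __init__ is kept: dataclass never overwrites a method
    # defined in the class body.
    def __init__(self, a: int):
        self.a = a


@dc.dataclass
class B(A):
    # No annotated class attributes, so there are no fields; dataclass
    # synthesizes B.__init__(self), which shadows A.__init__.
    ...


print(A(a=1).a)  # 1
try:
    B(a=1)
except TypeError as err:
    print("TypeError:", err)
```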

leafwise_transform is decoupled from TreeClass.

Instead, decorate the class with pytreeclass.leafwise.
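Assuming "leafwise" means that operators act on each leaf and return a new instance, the idea can be sketched as a class decorator in plain Python. `leafwise_ops` is a hypothetical stand-in that treats each attribute as a leaf and only adds `+` and `*`; it is not pytreeclass.leafwise.

```python
import copy
import operator


def leafwise_ops(cls):
    # Hypothetical class decorator: add arithmetic operators that apply to
    # every attribute value (treated as a leaf) and return a new instance.
    def make_op(op):
        def method(self, other):
            new = copy.copy(self)  # new instance with its own attribute dict
            for name, value in vars(self).items():
                setattr(new, name, op(value, other))
            return new
        return method

    cls.__add__ = make_op(operator.add)
    cls.__mul__ = make_op(operator.mul)
    return cls


@leafwise_ops
class Tree:
    def __init__(self, a, b):
        self.a = a
        self.b = b


tree = Tree(1, 2.0)
doubled = tree * 2
shifted = tree + 1
print(vars(doubled))  # {'a': 2, 'b': 4.0}
print(vars(shifted))  # {'a': 2, 'b': 3.0}
```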