Releases: ASEM000/pytreeclass
0.11
Changelog
v0.11.0
Breaking Changes:
Due to changes in the JAX tree API, TreeClass is no longer treated as a named tuple when indexing with AtIndexer/.at; only indexing by name is valid.
import pytreeclass as tc
class Tree(tc.TreeClass):
    def __init__(self, a, b):
        self.a = a
        self.b = b
tree = Tree(1, 2)
print(tree.at["a"].get()) # 1
print(tree.at[0].get()) # 1 -> no longer valid
v0.9.2
v0.9.2
Changes:
- Change threads_count in apply parallel kwargs to max_workers
Full Changelog: v0.9.1...v0.9.2
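The new name matches the max_workers argument of concurrent.futures.ThreadPoolExecutor in the Python standard library. As a rough illustration of what a thread-based parallel map over leaves looks like, here is a standard-library sketch. It assumes a thread pool is the underlying mechanism, which this changelog does not confirm, and parallel_map is a hypothetical helper, not part of pytreeclass.

```python
from concurrent.futures import ThreadPoolExecutor


def parallel_map(func, leaves, max_workers=None):
    # Hypothetical helper: apply `func` to every leaf using a thread pool.
    # `max_workers` is forwarded unchanged to ThreadPoolExecutor.
    with ThreadPoolExecutor(max_workers=max_workers) as executor:
        # executor.map yields results in the same order as `leaves`
        return list(executor.map(func, leaves))


leaves = ["a.png", "b.png", "c.png"]
lengths = parallel_map(len, leaves, max_workers=2)
print(lengths)
```

Threads mostly help when the mapped function waits on I/O, such as reading files, which is the use case the parallel option was introduced for.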
v0.9.1
v0.9.1
Additions:
- Add a parallel mapping option in AtIndexer. This enables a myriad of tasks, such as reading a pytree of image file names.
# benchmarking parallel vs serial image read
# on mac m1 cpu with image of size 512x512x3
import pytreeclass as tc
from matplotlib.pyplot import imread
paths = ["lenna.png"] * 10
indexer = tc.AtIndexer(paths)
%timeit indexer[...].apply(imread,parallel=True) # parallel
# 24.9 ms ± 938 µs per loop (mean ± std. dev. of 7 runs, 10 loops each)
%timeit indexer[...].apply(imread) # not parallel
# 84.8 ms ± 453 µs per loop (mean ± std. dev. of 7 runs, 10 loops each)
v0.9
v0.8.0
V0.8
Additions:
- Add on_getattr in field to apply a function on __getattr__
Breaking changes:
- Rename callbacks in field to on_setattr to match attrs and better reflect its functionality.
These changes enable:
- Stricter data validation on instance values, as in the following example:
on_setattr ensures the value is of a certain type (e.g. integer) during initialization, and on_getattr ensures the value is of a certain type (e.g. integer) whenever it is accessed.
import pytreeclass as pytc
import jax
def assert_int(x):
    assert isinstance(x, int), "must be an int"
    return x
@pytc.autoinit
class Tree(pytc.TreeClass):
    a: int = pytc.field(on_getattr=[assert_int], on_setattr=[assert_int])
    def __call__(self, x):
        # ensure `a` is an int before using it in computation by calling `assert_int`
        a: int = self.a
        return a + x
tree = Tree(a=1)
print(tree(1.0)) # 2.0
tree = jax.tree_map(lambda x: x + 0.0, tree) # make `a` a float
tree(1.0) # AssertionError: must be an int
- Frozen field without using tree_mask/tree_unmask
The following shows a pattern where the value is frozen on __setattr__ and unfrozen whenever accessed. This ensures that jax transformations do not see the value. The following example showcases this functionality:
import pytreeclass as pytc
import jax
@pytc.autoinit
class Tree(pytc.TreeClass):
    frozen_a: int = pytc.field(on_getattr=[pytc.unfreeze], on_setattr=[pytc.freeze])
    def __call__(self, x):
        return self.frozen_a + x
tree = Tree(frozen_a=1) # 1 is non-jaxtype
# can be used in jax transformations
@jax.jit
def f(tree, x):
    return tree(x)
f(tree, 1.0) # 2.0
grads = jax.grad(f)(tree, 1.0) # Tree(frozen_a=#1)
Compared with other libraries that implement static_field, this pattern has lower overhead and does not alter the tree_flatten/tree_unflatten methods of the tree.
- Easier way to create a buffer (non-trainable array)
Just use jax.lax.stop_gradient in on_getattr:
import pytreeclass as pytc
import jax
import jax.numpy as jnp
def assert_array(x):
    assert isinstance(x, jax.Array)
    return x
@pytc.autoinit
class Tree(pytc.TreeClass):
    buffer: jax.Array = pytc.field(on_getattr=[jax.lax.stop_gradient], on_setattr=[assert_array])
    def __call__(self, x):
        return self.buffer**x
tree = Tree(buffer=jnp.array([1.0, 2.0, 3.0]))
tree(2.0) # Array([1., 4., 9.], dtype=float32)
@jax.jit
def f(tree, x):
    return jnp.sum(tree(x))
f(tree, 1.0) # Array(6., dtype=float32)
print(jax.grad(f)(tree, 1.0)) # Tree(buffer=[0. 0. 0.])
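The idea behind on_setattr and on_getattr can be sketched in plain Python with a descriptor that runs a list of callbacks when a value is stored and again when it is read. This is only an illustration of the pattern under stated assumptions: CallbackField and Point are hypothetical names, callbacks are assumed to run in list order, and none of this is pytreeclass's actual implementation.

```python
class CallbackField:
    # Hypothetical descriptor, not part of pytreeclass.
    def __init__(self, on_setattr=(), on_getattr=()):
        self.on_setattr = tuple(on_setattr)
        self.on_getattr = tuple(on_getattr)

    def __set_name__(self, owner, name):
        self.name = name

    def __set__(self, instance, value):
        # Run each on_setattr callback in order before storing the value.
        for func in self.on_setattr:
            value = func(value)
        vars(instance)[self.name] = value

    def __get__(self, instance, owner=None):
        if instance is None:
            return self
        value = vars(instance)[self.name]
        # Run each on_getattr callback in order on every access.
        for func in self.on_getattr:
            value = func(value)
        return value


def assert_int(x):
    assert isinstance(x, int), "must be an int"
    return x


def double(x):
    return x * 2


class Point:
    # Validate on write, transform on read.
    x = CallbackField(on_setattr=[assert_int], on_getattr=[double])

    def __init__(self, x):
        self.x = x


point = Point(3)
print(point.x)  # 6: the stored value is 3, doubled on access
```

In this sketch the stored value stays untouched and the read-side callbacks only affect what the caller sees, which is the same shape as the freeze-on-set and unfreeze-on-get pattern.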
v0.7.0
Changelog
v0.7
- Remove .at as an alias for __getitem__ when specifying a path entry for where in AtIndexer. This leads to a less verbose style.
Example:
>>> tree = {"level1_0": {"level2_0": 100, "level2_1": 200}, "level1_1": 300}
>>> tree = pytc.AtIndexer(tree)
>>> # Before:
>>> # style 1 (with at):
>>> tree.at["level1_0"].at["level2_0", "level2_1"].get()
{'level1_0': {'level2_0': 100, 'level2_1': 200}, 'level1_1': None}
>>> # style 2 (no at):
>>> tree["level1_0"]["level2_0", "level2_1"].get()
>>> # After
>>> # only style 2 is valid
>>> tree["level1_0"]["level2_0", "level2_1"].get()
For TreeClass, at is specified once for each change:
@pytc.autoinit
class Tree(pytc.TreeClass):
    a: float = 1.0
    b: tuple[float, float] = (2.0, 3.0)
    c: jax.Array = jnp.array([4.0, 5.0, 6.0])
    def __call__(self, x):
        return self.a + self.b[0] + self.c + x
tree = Tree()
mask = jax.tree_map(lambda x: x > 5, tree)
tree = tree\
.at["a"].set(100.0)\
- .at["b"].at[0].set(10.0)\
+ .at["b"][0].set(10.0)\
.at[mask].set(100.0)
v0.6.0post0
Changelog
v0.6.0post0
- Using tree_{repr,str} with an object containing cyclic references will raise RecursionError instead of displaying cyclicref.
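For readers unfamiliar with this failure mode, the following standard-library sketch shows how a recursive formatter without cycle detection ends in RecursionError on a self-referencing container. naive_repr is a hypothetical stand-in and is not how tree_repr or tree_str are implemented.

```python
def naive_repr(node):
    # Hypothetical recursive formatter with no cycle detection.
    if isinstance(node, list):
        parts = []
        for item in node:
            parts.append(naive_repr(item))
        return "[" + ", ".join(parts) + "]"
    return repr(node)


cyclic = [1, 2]
cyclic.append(cyclic)  # the list now contains itself

try:
    naive_repr(cyclic)
    outcome = "no error"
except RecursionError:
    outcome = "RecursionError"
print(outcome)
```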
v0.6.0
v0.6.0
- Allow nested mutations using .at[method](*args, **kwargs).
After the change, inner methods can mutate copied new instances at any level, not just the top level.
A motivation for this is to experiment with a lazy initialization scheme, where inner layers need to mutate their inner state. See the example below for flax-like lazy initialization as described here:
import pytreeclass as pytc
import jax.random as jr
from typing import Any
import jax
import jax.numpy as jnp
from typing import Callable, TypeVar
T = TypeVar("T")
@pytc.autoinit
class LazyLinear(pytc.TreeClass):
    outdim: int
    weight_init: Callable[..., T] = jax.nn.initializers.glorot_normal()
    bias_init: Callable[..., T] = jax.nn.initializers.zeros
    def param(self, name: str, init_func: Callable[..., T], *args) -> T:
        if name not in vars(self):
            setattr(self, name, init_func(*args))
        return vars(self)[name]
    def __call__(self, x: jax.Array, *, key: jr.KeyArray = jr.PRNGKey(0)):
        w = self.param("weight", self.weight_init, key, (x.shape[-1], self.outdim))
        y = x @ w
        if self.bias_init is not None:
            b = self.param("bias", self.bias_init, key, (self.outdim,))
            return y + b
        return y
@pytc.autoinit
class StackedLinear(pytc.TreeClass):
    l1: LazyLinear = LazyLinear(outdim=10)
    l2: LazyLinear = LazyLinear(outdim=1)
    def call(self, x: jax.Array):
        return self.l2(jax.nn.relu(self.l1(x)))
lazy_layer = StackedLinear()
print(repr(lazy_layer))
# StackedLinear(
#   l1=LazyLinear(
#     outdim=10,
#     weight_init=init(key, shape, dtype),
#     bias_init=zeros(key, shape, dtype)
#   ),
#   l2=LazyLinear(
#     outdim=1,
#     weight_init=init(key, shape, dtype),
#     bias_init=zeros(key, shape, dtype)
#   )
# )
_, materialized_layer = lazy_layer.at["call"](jnp.ones((1, 5)))
materialized_layer
# StackedLinear(
#   l1=LazyLinear(
#     outdim=10,
#     weight_init=init(key, shape, dtype),
#     bias_init=zeros(key, shape, dtype),
#     weight=f32[5,10](μ=-0.04, σ=0.32, ∈[-0.74,0.63]),
#     bias=f32[10](μ=0.00, σ=0.00, ∈[0.00,0.00])
#   ),
#   l2=LazyLinear(
#     outdim=1,
#     weight_init=init(key, shape, dtype),
#     bias_init=zeros(key, shape, dtype),
#     weight=f32[10,1](μ=-0.07, σ=0.23, ∈[-0.34,0.34]),
#     bias=f32[1](μ=0.00, σ=0.00, ∈[0.00,0.00])
#   )
# )
materialized_layer(jnp.ones((1, 5))) # Array([[0.16712935]], dtype=float32)
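The calling convention used above, where the method runs on a copy and the call returns both the method's result and the mutated copy, can be sketched without JAX. This is a plain-Python illustration that assumes a deep copy is taken before the call; call_on_copy, Counter, and Outer are hypothetical names and the real mechanism in pytreeclass may differ.

```python
import copy


class Counter:
    def __init__(self):
        self.count = 0

    def increment(self):
        # Mutates inner state in place.
        self.count += 1
        return self.count


class Outer:
    def __init__(self):
        self.inner = Counter()

    def step(self):
        # The mutation happens one level down, inside `inner`.
        return self.inner.increment()


def call_on_copy(tree, method_name, *args, **kwargs):
    # Hypothetical helper: copy the whole object, call the method on the copy,
    # and return (method result, mutated copy). The original is left untouched.
    new_tree = copy.deepcopy(tree)
    value = getattr(new_tree, method_name)(*args, **kwargs)
    return value, new_tree


outer = Outer()
value, new_outer = call_on_copy(outer, "step")
print(value, outer.inner.count, new_outer.inner.count)  # 1 0 1
```

Because the nested Counter is copied along with its parent, the inner mutation is visible only in the returned instance, which is the property the lazy initialization example relies on.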
Full Changelog: v0.5...v0.6.0
v0.5.0post0
Fix __init_subclass__ not accepting arguments. The bug was introduced in v0.5.
Full Changelog: v0.5...v0.5.0post0
v0.5.0
Changelog
PyTreeClass v0.5
Breaking changes
Auto generation of __init__ method from type hints is decoupled from TreeClass
Alternatives
Use:
- Preferably, decorate with pytreeclass.autoinit, using pytreeclass.field as the field specifier, since pytreeclass.field has more features (e.g. callbacks, multiple argument kind selection) and the init generation is cached, unlike dataclasses.
- Alternatively, decorate with dataclasses.dataclass, using dataclasses.field as the field specifier. However:
  - Must set frozen=False because __setattr__ and __delattr__ are handled by TreeClass
  - Optionally set repr=False so that repr is handled by TreeClass
  - Optionally set eq=hash=False as they are handled by TreeClass
Before:
import jax.tree_util as jtu
import pytreeclass as pytc
import dataclasses as dc
class Tree(pytc.TreeClass):
    a: int = 1
jtu.tree_leaves(Tree())
# [1]
After (equivalent behavior when decorating with either of the alternatives above):
import jax.tree_util as jtu
import pytreeclass as pytc
@pytc.autoinit
class Tree(pytc.TreeClass):
    a: int = 1
jtu.tree_leaves(Tree())
# [1]
This change aims to fix the ambiguity of using the dataclass mental model in the following situations:
- Subclassing. Previously, using TreeClass as a base class was equivalent to decorating the class with dataclasses.dataclass. However, this is a bit challenging to understand, as demonstrated in the next example:
import pytreeclass as pytc
import dataclasses as dc
class A(pytc.TreeClass):
    def __init__(self, a: int):
        self.a = a
class B(A):
    ...
When instantiating B(a=...), an error will be raised, because using TreeClass is equivalent to decorating all classes with @dataclass, which synthesizes the __init__ method based on the fields.
Since there are no fields (e.g. type-hinted values), the synthesized __init__ method accepts no arguments. The previous code is equivalent to this code:
@dc.dataclass
class A:
    def __init__(self, a: int):
        self.a = a
@dc.dataclass
class B(A):
    ...
- dataclass_transform does not play nicely with user-created __init__ (see 1, 2)
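The subclassing pitfall described in the first situation can be reproduced with the standard library alone. The sketch below uses only dataclasses and shows that the subclass receives a synthesized, argument-free __init__ that shadows the one written by hand in the parent. The class names mirror the example in the text; the outcome variable is only there for illustration.

```python
import dataclasses as dc


@dc.dataclass
class A:
    # dataclass does not overwrite an __init__ defined in the class body
    def __init__(self, a: int):
        self.a = a


@dc.dataclass
class B(A):
    # B defines no __init__ and has no fields, so dataclass synthesizes
    # an __init__ that takes no arguments and shadows A.__init__
    ...


a = A(a=1)  # works: A keeps its hand-written __init__

try:
    B(a=1)
    outcome = "ok"
except TypeError:
    outcome = "TypeError"
print(outcome)
```

Decoupling init generation from the base class removes this surprise: a subclass that is not decorated simply inherits the parent's __init__.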
leafwise_transform is decoupled from TreeClass.
Instead, decorate the class with pytreeclass.leafwise.