🌲 Tree utils API
- class sepes.at(tree, where=None)
Operate on a pytree at a given location, selected by a path or a mask, in an out-of-place manner.
- Parameters:
tree – pytree to operate on.
where – one of the following:
  - str: for mapping keys or class attributes.
  - int: for positional indexing of sequences.
  - ...: to select all leaves.
  - a boolean mask of the same structure as the tree.
  - re.Pattern: to match a leaf-level path with a regex pattern.
  - a tuple of the above: to match multiple keys at the same level.
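To make the selector kinds above concrete, here is a minimal pure-Python sketch of how one `where` entry could be tested against one key at a single tree level. The helper `key_matches` is hypothetical and is not part of sepes; in particular, it assumes a full regex match against `str(key)`, which may differ from the library's actual matching rule.

```python
import re
from typing import Any


def key_matches(key: Any, where: Any) -> bool:
    """Return True if one path entry `key` matches one `where` selector.

    Hypothetical helper covering the selector kinds listed above for a
    single tree level. It assumes a full regex match on str(key); sepes
    may use a different matching rule.
    """
    if where is Ellipsis:
        return True  # `...` selects everything
    if isinstance(where, tuple):
        return any(key_matches(key, w) for w in where)  # any member may match
    if isinstance(where, re.Pattern):
        return where.fullmatch(str(key)) is not None
    if isinstance(where, (str, int)):
        return key == where  # mapping key / attribute name or sequence index
    raise TypeError(f"unsupported selector: {where!r}")


tree = {"weight_1": 1, "weight_2": 2, "bias": 3}
print([k for k in tree if key_matches(k, re.compile(r"weight_\d"))])
# ['weight_1', 'weight_2']
```

A boolean mask is deliberately left out of this helper: a mask selects leaves by position in the tree rather than by key.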
Example
>>> import jax
>>> import sepes as sp
>>> tree = {"a": 1, "b": [1, 2, 3]}
>>> sp.at(tree)["a"].set(100)
{'a': 100, 'b': [1, 2, 3]}
>>> sp.at(tree)["b"][0].set(100)
{'a': 1, 'b': [100, 2, 3]}
>>> mask = jax.tree_util.tree_map(lambda x: x > 1, tree)
>>> sp.at(tree)[mask].set(100)
{'a': 1, 'b': [1, 100, 100]}
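The mask example above relies on the mask mirroring the tree's structure. The following pure-Python sketch shows the out-of-place semantics on nested dicts and lists. `masked_set` is a hypothetical helper written for illustration, not sepes' implementation, and it treats anything that is not a dict or list as a leaf.

```python
from typing import Any


def masked_set(tree: Any, mask: Any, value: Any) -> Any:
    """Hypothetical sketch: return a NEW tree where leaves with a True mask become `value`.

    Trees here are nested dicts/lists; anything else is a leaf. `mask` must
    have the same structure as `tree` with bool leaves.
    """
    if isinstance(tree, dict):
        return {k: masked_set(tree[k], mask[k], value) for k in tree}
    if isinstance(tree, list):
        return [masked_set(t, m, value) for t, m in zip(tree, mask)]
    return value if mask else tree


tree = {"a": 1, "b": [1, 2, 3]}
mask = {"a": False, "b": [False, True, True]}  # same as `x > 1` on each leaf
new_tree = masked_set(tree, mask, 100)
print(new_tree)  # {'a': 1, 'b': [1, 100, 100]}
print(tree)      # {'a': 1, 'b': [1, 2, 3]}  (input left unchanged)
```

Because new containers are built on the way back up, the caller's tree is never modified.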
- get(*, is_leaf=None, is_parallel=False, fill_value=<object object>)
Get the leaf values at the specified location.
- Parameters:
is_leaf (Callable[[Any], bool] | None) – a predicate function to determine if a value is a leaf.
is_parallel (bool | ParallelConfig) – accepts the following:
  - bool: run the operation in parallel if True, otherwise serially.
  - dict: a dict of:
    - max_workers: maximum number of workers to use.
    - kind: kind of pool to use, either thread or process.
fill_value (Any) – the value to fill the non-selected leaves with. Useful with jax.jit to avoid errors related to variable-size array leaves.
- Returns:
A new pytree of leaf values at the specified location. Non-selected leaves that are not arrays are set to None, or to fill_value if it is given.
Example
>>> import sepes as sp
>>> tree = {"a": 1, "b": [1, 2, 3]}
>>> sp.at(tree)["b"][0].get()
{'a': None, 'b': [1, None, None]}
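The following pure-Python sketch illustrates the `get` semantics for non-array leaves: selected leaves keep their value and the rest become None, or a caller-supplied fill value. The helper `masked_get` is hypothetical, and partial selection inside array leaves is intentionally not modeled.

```python
from typing import Any


def masked_get(tree: Any, mask: Any, fill_value: Any = None) -> Any:
    """Hypothetical sketch of `get` on nested dicts/lists with a bool mask.

    Selected leaves keep their value; the others become `fill_value`
    (None by default). Partial selection inside array leaves is not modeled.
    """
    if isinstance(tree, dict):
        return {k: masked_get(tree[k], mask[k], fill_value) for k in tree}
    if isinstance(tree, list):
        return [masked_get(t, m, fill_value) for t, m in zip(tree, mask)]
    return tree if mask else fill_value


tree = {"a": 1, "b": [1, 2, 3]}
mask = {"a": False, "b": [True, False, False]}  # what ["b"][0] selects
print(masked_get(tree, mask))                # {'a': None, 'b': [1, None, None]}
print(masked_get(tree, mask, fill_value=0))  # {'a': 0, 'b': [1, 0, 0]}
```

A fixed fill value keeps every leaf present with a known value, which is the property that makes fill_value useful under jax.jit.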
- set(set_value, *, is_leaf=None, is_parallel=False)
Set the leaf values at the specified location.
- Parameters:
set_value (Any) – the value to set at the specified location.
is_leaf (Callable[[Any], bool] | None) – a predicate function to determine if a value is a leaf.
is_parallel (bool | ParallelConfig) – accepts the following:
  - bool: run the operation in parallel if True, otherwise serially.
  - dict: a dict of:
    - max_workers: maximum number of workers to use.
    - kind: kind of pool to use, either thread or process.
- Returns:
A pytree with the leaf values at the specified location set to set_value.
Example
>>> import sepes as sp
>>> tree = {"a": 1, "b": [1, 2, 3]}
>>> sp.at(tree)["b"][0].set(100)
{'a': 1, 'b': [100, 2, 3]}
- apply(func, *, is_leaf=None, is_parallel=False)
Apply a function to the leaf values at the specified location.
- Parameters:
func (Callable[[Any], Any]) – the function to apply to the leaf values.
is_leaf (Callable[[Any], bool] | None) – a predicate function to determine if a value is a leaf.
is_parallel (bool | ParallelConfig) – accepts the following:
  - bool: apply func in parallel if True, otherwise serially.
  - dict: a dict of:
    - max_workers: maximum number of workers to use.
    - kind: kind of pool to use, either thread or process.
- Returns:
A pytree with the leaf values at the specified location set to the result of applying func to the leaf values.
Example
>>> import sepes as sp
>>> tree = {"a": 1, "b": [1, 2, 3]}
>>> sp.at(tree)["b"][0].apply(lambda x: x + 100)
{'a': 1, 'b': [101, 2, 3]}
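Chained indexing such as `["b"][0]` describes a path from the root to one node. This sketch, using a hypothetical helper `apply_at_path` rather than sepes' implementation, applies a function at such a path out of place. It rebuilds only the containers along the path and shares every other subtree with the input.

```python
from typing import Any, Callable, Sequence


def apply_at_path(tree: Any, path: Sequence[Any], func: Callable[[Any], Any]) -> Any:
    """Hypothetical sketch: apply `func` to the node at `path`, out of place.

    Only the containers along `path` are rebuilt; sibling subtrees are shared
    with the input rather than copied.
    """
    if not path:
        return func(tree)
    head, *rest = path
    if isinstance(tree, dict):
        new = dict(tree)  # shallow copy of this level only
    elif isinstance(tree, list):
        new = list(tree)
    else:
        raise TypeError(f"cannot index leaf {tree!r} with {head!r}")
    new[head] = apply_at_path(tree[head], rest, func)
    return new


tree = {"a": 1, "b": [1, 2, 3]}
new_tree = apply_at_path(tree, ("b", 0), lambda x: x + 100)
print(new_tree)  # {'a': 1, 'b': [101, 2, 3]}
print(tree)      # {'a': 1, 'b': [1, 2, 3]}
```

Under this model, `set(value)` is the special case `func = lambda _: value`.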
Example
Read images in parallel
>>> import sepes as sp
>>> from matplotlib.pyplot import imread
>>> path = {"img1": "path1.png", "img2": "path2.png"}
>>> is_parallel = dict(max_workers=2)
>>> images = sp.at(path)[...].apply(imread, is_parallel=is_parallel)
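The `is_parallel=dict(max_workers=..., kind=...)` option resembles handing the selected leaves to a standard-library executor. Below is a minimal sketch under that assumption, using `concurrent.futures` on a flat dict. The helper `parallel_apply` and its mapping of `kind` to an executor class are hypothetical, not sepes' internals. With `kind="process"`, `func` must also be picklable (so no lambdas).

```python
from concurrent.futures import ProcessPoolExecutor, ThreadPoolExecutor
from typing import Any, Callable


def parallel_apply(
    tree: dict,
    func: Callable[[Any], Any],
    max_workers: int = 2,
    kind: str = "thread",
) -> dict:
    """Hypothetical sketch of `is_parallel=dict(max_workers=..., kind=...)` on a flat dict."""
    pool_cls = {"thread": ThreadPoolExecutor, "process": ProcessPoolExecutor}[kind]
    keys = list(tree)
    with pool_cls(max_workers=max_workers) as pool:
        # pool.map preserves input order, so results line up with `keys`
        results = list(pool.map(func, (tree[k] for k in keys)))
    return dict(zip(keys, results))


paths = {"img1": "path1.png", "img2": "path2.png"}
lengths = parallel_apply(paths, len, max_workers=2)  # `len` stands in for an I/O-bound reader
print(lengths)  # {'img1': 9, 'img2': 9}
```

Threads suit I/O-bound work such as reading files, while processes are the usual choice for CPU-bound Python code.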
- scan(func, state, *, is_leaf=None)
Apply a function while carrying a state.
- Parameters:
func – the function to apply to the leaf values. It accepts a leaf value and the running state, and returns a tuple of the new leaf value and the new state.
state – the initial state to carry.
is_leaf – a predicate function to determine if a value is a leaf. For example, lambda x: isinstance(x, list) treats all lists as leaves and does not recurse into list items.
- Returns:
A tuple of the pytree, with the leaf values at the specified location set to the result of applying func to them, and the final state.
Example
>>> import sepes as sp
>>> tree = {"a": 1, "b": [1, 2, 3]}
>>> def scan_func(leaf, running_max):
...     cur_max = max(leaf, running_max)
...     return leaf, cur_max
>>> running_max = float("-inf")
>>> _, running_max = sp.at(tree)["b"][0, 1].scan(scan_func, state=running_max)
>>> running_max  # max of b[0] and b[1]
2
Note
scan applies a binary func to the leaf values while carrying a state, and returns the tree with func applied to the selected leaves together with the final state. reduce, by contrast, applies a binary func to the leaf values while carrying a state and returns a single value.
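The carry-a-state behavior can be sketched on a flat list of leaves with a parallel selection mask. `scan_leaves` is a hypothetical helper: it visits selected leaves left to right, threads the state through `func(leaf, state)`, and returns `(new_leaves, final_state)`, matching the unpacking order used in the example above. sepes works on whole pytrees, which this simplification does not model.

```python
from typing import Any, Callable, List, Tuple


def scan_leaves(
    leaves: List[Any],
    mask: List[bool],
    func: Callable[[Any, Any], Tuple[Any, Any]],
    state: Any,
) -> Tuple[List[Any], Any]:
    """Hypothetical sketch of `scan` over a flat list of leaves.

    `func(leaf, state) -> (new_leaf, new_state)` runs left to right on the
    selected leaves only; unselected leaves pass through untouched.
    """
    out = []
    for leaf, selected in zip(leaves, mask):
        if selected:
            leaf, state = func(leaf, state)
        out.append(leaf)
    return out, state


def running_max(leaf, state):
    return leaf, max(leaf, state)  # keep the leaf, update the carried maximum


new_leaves, final = scan_leaves([1, 2, 3], [True, True, False], running_max, float("-inf"))
print(new_leaves, final)  # [1, 2, 3] 2
```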
- reduce(func, *, initializer=<object object>, is_leaf=None)
Reduce the leaf values at the specified location.
- Parameters:
func (Callable[[Any, Any], Any]) – the function to reduce the leaf values.
initializer (Any) – the initializer value for the reduction.
is_leaf (Callable[[Any], bool] | None) – a predicate function to determine if a value is a leaf.
- Return type:
Any
- Returns:
The result of reducing the leaf values at the specified location.
Note
If initializer is not specified, the first leaf value is used as the initializer. reduce applies a binary func to each leaf value while accumulating a state and returns the final result, whereas scan applies func to each leaf value while carrying a state and returns the tree, with func applied to each selected leaf, together with the final state.
Example
>>> import sepes as sp
>>> tree = {"a": 1, "b": [1, 2, 3]}
>>> sp.at(tree)["b"].reduce(lambda x, y: x + y)
6
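The initializer rule in the note above maps directly onto `functools.reduce`. Here is a sketch over a flat leaf list with a selection mask. `reduce_selected` and the `_MISSING` sentinel are hypothetical; the sentinel is only assumed to play the role that the `<object object>` default in the signature suggests.

```python
import functools
from typing import Any, Callable, List

_MISSING = object()  # hypothetical sentinel meaning "no initializer given"


def reduce_selected(
    leaves: List[Any],
    mask: List[bool],
    func: Callable[[Any, Any], Any],
    initializer: Any = _MISSING,
) -> Any:
    """Hypothetical sketch of `reduce` over the selected leaves of a flat list."""
    selected = [leaf for leaf, m in zip(leaves, mask) if m]
    if initializer is _MISSING:
        # the first selected leaf becomes the starting value
        return functools.reduce(func, selected)
    return functools.reduce(func, selected, initializer)


# leaves of {"a": 1, "b": [1, 2, 3]} in order, selecting only the "b" subtree
leaves = [1, 1, 2, 3]
mask = [False, True, True, True]
print(reduce_selected(leaves, mask, lambda x, y: x + y))                  # 6
print(reduce_selected(leaves, mask, lambda x, y: x + y, initializer=10))  # 16
```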
- pluck(count=None, *, is_leaf=None, is_parallel=False)
Extract subtrees at the specified location.
Note
pluck first applies get to the specified location and then extracts the immediate subtrees of the selected leaves. is_leaf and is_parallel are passed to get.
- Parameters:
count (int | None) – number of subtrees to extract. Defaults to None, which extracts all subtrees.
is_leaf (Callable[[Any], bool] | None) – a predicate function to determine if a value is a leaf.
is_parallel (bool | ParallelConfig) – accepts the following:
  - bool: run the operation in parallel if True, otherwise serially.
  - dict: a dict of:
    - max_workers: maximum number of workers to use.
    - kind: kind of pool to use, either thread or process.
- Return type:
list[Any]
- Returns:
A list of subtrees at the specified location.
Note
Compared to get, pluck extracts subtrees at the specified location and returns them as a list, whereas get returns a pytree with the leaf values at the specified location and sets the non-selected leaf values to None.
Example
>>> import sepes as sp
>>> tree = {"a": 1, "b": [1, 2, 3]}
>>> # `pluck` returns a list of selected subtrees
>>> sp.at(tree)["b"].pluck()
[[1, 2, 3]]
>>> # `get` returns a pytree of the same structure
>>> sp.at(tree)["b"].get()
{'a': None, 'b': [1, 2, 3]}
Example
pluck with a mask:
>>> import sepes as sp
>>> tree = {"a": 1, "b": [2, 3, 4]}
>>> mask = {"a": True, "b": [False, True, False]}
>>> sp.at(tree)[mask].pluck()
[1, 3]
This is equivalent to the following:
>>> [tree["a"], tree["b"][1]]
[1, 3]
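The contrast between pluck and get comes down to collecting the selected subtrees into a flat list instead of rebuilding the tree. The following sketch uses a hypothetical `pluck` helper in plain Python. As an assumption of the sketch, a single bool in the mask may stand for a whole subtree, which is how `{"a": False, "b": True}` models selecting `["b"]`.

```python
from typing import Any, List


def pluck(tree: Any, mask: Any) -> List[Any]:
    """Hypothetical sketch: collect selected subtrees into a list, depth-first.

    A bool in `mask` selects (True) or drops (False) the whole subtree at
    that position; otherwise `mask` mirrors the dict/list structure of `tree`.
    """
    if isinstance(mask, bool):
        return [tree] if mask else []
    if isinstance(tree, dict):
        return [x for k in tree for x in pluck(tree[k], mask[k])]
    return [x for t, m in zip(tree, mask) for x in pluck(t, m)]


print(pluck({"a": 1, "b": [1, 2, 3]}, {"a": False, "b": True}))                  # [[1, 2, 3]]
print(pluck({"a": 1, "b": [2, 3, 4]}, {"a": True, "b": [False, True, False]}))  # [1, 3]
```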
- sepes.value_and_tree(func, argnums=0)
Call a function on copied input arguments and return the value and the tree.
Input arguments are copied before calling the function, and the arguments specified by argnums are returned as a tree.
- Parameters:
func (Callable[..., T]) – A function.
argnums (int | Sequence[int]) – The index of the argument to return as a tree. If multiple arguments are specified, the trees are returned as a tuple.
- Returns:
A function that returns the value and the tree.
Example
Usage with mutable types:
>>> import sepes as sp
>>> mutable_tree = [1, 2, 3]
>>> def mutating_func(tree):
...     tree[0] += 100
...     return tree
>>> new_tree = mutating_func(mutable_tree)
>>> assert new_tree is mutable_tree
>>> # now with `value_and_tree` the function does not mutate the tree
>>> new_tree, _ = sp.value_and_tree(mutating_func)(mutable_tree)
>>> assert new_tree is not mutable_tree
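The copy-then-call behavior can be approximated in plain Python. The following sketch assumes that `copy.deepcopy` is an adequate stand-in for the library's copying step, which may differ (for example, by copying only pytree containers). `value_and_tree_sketch` is hypothetical, and it handles positional arguments only.

```python
import copy
from typing import Any, Callable, Sequence, Union


def value_and_tree_sketch(
    func: Callable[..., Any], argnums: Union[int, Sequence[int]] = 0
) -> Callable[..., Any]:
    """Hypothetical sketch: call `func` on deep copies, return (value, copied arg(s))."""

    def wrapper(*args: Any):
        copied = copy.deepcopy(args)  # deep copy so in-place edits never reach the caller
        value = func(*copied)
        if isinstance(argnums, int):
            return value, copied[argnums]
        return value, tuple(copied[i] for i in argnums)

    return wrapper


def mutating_func(tree):
    tree[0] += 100
    return tree


original = [1, 2, 3]
value, tree = value_and_tree_sketch(mutating_func)(original)
print(original)  # [1, 2, 3]
print(tree)      # [101, 2, 3]
```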
Example
Usage with immutable types (TreeClass) with support for in-place mutation via custom behavior registered using value_and_tree.def_mutator() and value_and_tree.def_immutator():
>>> import sepes as sp
>>> class Counter(sp.TreeClass):
...     def __init__(self, count: int):
...         self.count = count
...     def increment(self, value):
...         self.count += value
...         return self.count
>>> counter = Counter(0)
>>> counter.increment(1)
Traceback (most recent call last):
    ...
AttributeError: Cannot set attribute value=1 to `key='count'` on an immutable instance of `Counter`.
>>> sp.value_and_tree(lambda counter: counter.increment(1))(counter)
(1, Counter(count=1))
Note
Use this function on functions that:
  - mutate input arguments of mutable types (e.g. lists, dicts);
  - mutate input arguments of immutable types that do not support in-place mutation and need special handling, which can be registered (e.g. for TreeClass) using value_and_tree.def_mutator() and value_and_tree.def_immutator().
Note
The default behavior of value_and_tree() is to copy the input arguments and then call the function on the copy. However, if the function mutates an input argument that does not support in-place mutation, the call fails. In this case, value_and_tree() lets you register a custom function that modifies the copied input argument to allow in-place mutation, and a custom function that restores the copied argument to its original state after the call. The following example shows how to register such functions for a simple class that allows in-place mutation only when its immutable flag is set to False.
>>> import jax
>>> import sepes as sp
>>> @jax.tree_util.register_pytree_node_class
... class MyNode:
...     def __init__(self):
...         self.counter = 0
...         self.immutable = True
...     def tree_flatten(self):
...         keys, values = tuple(vars(self)), tuple(vars(self).values())
...         return values, keys
...     @classmethod
...     def tree_unflatten(cls, keys, values):
...         self = object.__new__(cls)
...         vars(self).update(dict(zip(keys, values)))
...         return self
...     def __setattr__(self, name, value):
...         if getattr(self, "immutable", False) is True:
...             raise AttributeError("MyNode is immutable")
...         object.__setattr__(self, name, value)
...     def __repr__(self):
...         params = ", ".join(f"{k}={v}" for k, v in vars(self).items())
...         return f"MyNode({params})"
...     def increment(self) -> None:
...         self.counter += 1
>>> @sp.value_and_tree.def_mutator(MyNode)
... def mutable(node) -> None:
...     vars(node)["immutable"] = False
>>> @sp.value_and_tree.def_immutator(MyNode)
... def immutable(node) -> None:
...     vars(node)["immutable"] = True
>>> node = MyNode()
>>> sp.value_and_tree(lambda node: node.increment())(node)
(None, MyNode(counter=1, immutable=True))
- sepes.bcmap(func, broadcast_to=None, *, is_leaf=None)
Map a function over pytree leaves with automatic broadcasting for scalar arguments.
- Parameters:
func (Callable[P, T]) – the function to be mapped over the pytree.
broadcast_to (int | str | None) – an integer to broadcast to a specific positional argument, or a string to broadcast to a specific keyword argument. If None, the function is broadcast to the first argument, or to the first keyword argument if no positional arguments are provided. Defaults to None.
is_leaf (Callable[[Any], bool] | None) – a predicate function that returns True if the node is a leaf.
- Return type:
Callable[P, T]
Example
Transform numpy functions to work with pytrees:
>>> import sepes as sp
>>> import jax.numpy as jnp
>>> tree_of_arrays = {"a": jnp.array([1, 2, 3]), "b": jnp.array([4, 5, 6])}
>>> tree_add = sp.bcmap(jnp.add)
>>> # both lhs and rhs are pytrees
>>> print(sp.tree_str(tree_add(tree_of_arrays, tree_of_arrays)))
dict(a=[2 4 6], b=[ 8 10 12])
>>> # rhs is a scalar
>>> print(sp.tree_str(tree_add(tree_of_arrays, 1)))
dict(a=[2 3 4], b=[5 6 7])
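The broadcasting rule in the example can be sketched in plain Python for flat dicts. The first positional argument fixes the structure (as with broadcast_to=None). Every other argument is either a dict with the same keys or a scalar reused for each leaf. `bcmap_sketch` is hypothetical and does not model broadcast_to, is_leaf, nested trees, or keyword arguments.

```python
import operator
from typing import Any, Callable


def bcmap_sketch(func: Callable[..., Any]) -> Callable[..., Any]:
    """Hypothetical sketch: map `func` over the leaves of a flat dict.

    The first positional argument fixes the structure; every other argument
    is either a dict with the same keys or a scalar reused for each leaf.
    """

    def wrapper(tree: dict, *args: Any) -> dict:
        out = {}
        for key, leaf in tree.items():
            # take the matching leaf from dict arguments, broadcast everything else
            per_leaf = [arg[key] if isinstance(arg, dict) else arg for arg in args]
            out[key] = func(leaf, *per_leaf)
        return out

    return wrapper


tree_add = bcmap_sketch(operator.add)
tree = {"a": 1, "b": 4}
print(tree_add(tree, tree))  # {'a': 2, 'b': 8}
print(tree_add(tree, 1))     # {'a': 2, 'b': 5}
```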