🌲 Tree utils API

class sepes.at(tree, where=None)

Operate on a pytree at a location specified by a path or mask, in an out-of-place manner.

Parameters:
  • tree – pytree to operate on.

  • where

    one of the following:

    • str for mapping keys or class attributes.

    • int for positional indexing into sequences.

    • ... to select all leaves.

    • a boolean mask with the same structure as the tree.

    • re.Pattern to match a leaf-level path with a regex pattern.

    • a tuple of the above to match multiple keys at the same level.
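The rules above decide, at each level of the tree, which children a `where` entry selects. The pure-Python sketch below spells those rules out for a single level. `key_matches` is a hypothetical helper, not part of sepes. Whether sepes requires a full or a partial regex match is an assumption here (the sketch uses `re.fullmatch`), and boolean masks are not covered because they describe the whole tree rather than one level.

```python
import re
from typing import Any


def key_matches(where: Any, key: Any) -> bool:
    """Hypothetical sketch: does a single `where` entry select the child `key`?"""
    if where is ...:
        # Ellipsis selects every child.
        return True
    if isinstance(where, tuple):
        # A tuple selects a child if any of its entries does.
        return any(key_matches(w, key) for w in where)
    if isinstance(where, re.Pattern):
        # Assumption: the pattern must match the whole key name.
        return isinstance(key, str) and where.fullmatch(key) is not None
    # A str matches mapping keys / attribute names; an int matches sequence positions.
    return where == key


params = {"weight": 1.0, "bias": 0.5, "scale": 2.0}
selected = [k for k in params if key_matches(re.compile("w.*|b.*"), k)]
print(selected)
```

With these rules, `selected` holds `["weight", "bias"]`. A tuple such as `("weight", "bias")` would select the same two keys.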

Example

>>> import jax
>>> import sepes as sp
>>> tree = {"a": 1, "b": [1, 2, 3]}
>>> sp.at(tree)["a"].set(100)
{'a': 100, 'b': [1, 2, 3]}
>>> sp.at(tree)["b"][0].set(100)
{'a': 1, 'b': [100, 2, 3]}
>>> mask = jax.tree_util.tree_map(lambda x: x > 1, tree)
>>> sp.at(tree)[mask].set(100)
{'a': 1, 'b': [1, 100, 100]}
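The mask example depends on two properties. The mask mirrors the tree's structure, and the update returns a new tree instead of mutating the input. The sketch below shows both in pure Python. It assumes trees built only from dicts, lists and tuples. The hypothetical `tree_map` stands in for `jax.tree_util.tree_map`; it is not how sepes is implemented.

```python
from typing import Any, Callable


def tree_map(func: Callable[..., Any], tree: Any, *rest: Any) -> Any:
    """Apply `func` leaf-wise to `tree` and any trees of the same structure."""
    if isinstance(tree, dict):
        return {k: tree_map(func, tree[k], *(r[k] for r in rest)) for k in tree}
    if isinstance(tree, (list, tuple)):
        return type(tree)(tree_map(func, *items) for items in zip(tree, *rest))
    return func(tree, *rest)


tree = {"a": 1, "b": [1, 2, 3]}
mask = tree_map(lambda x: x > 1, tree)
# Build a new tree: masked leaves become 100, the others are kept as-is.
new_tree = tree_map(lambda leaf, selected: 100 if selected else leaf, tree, mask)
print(new_tree)  # the original `tree` is not modified
```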
get(*, is_leaf=None, is_parallel=False, fill_value=<object object>)

Get the leaf values at the specified location.

Parameters:
  • is_leaf (Callable[[Any], bool] | None) – a predicate function to determine if a value is a leaf.

  • is_parallel (bool | ParallelConfig) –

    accepts the following:

    • bool: run in parallel if True, otherwise serially.

    • dict: a dict with the following keys:

      • max_workers: maximum number of workers to use.

      • kind: kind of pool to use, either thread or process.

  • fill_value (Any) – the value to fill the non-selected leaves with. Useful with jax.jit to avoid errors caused by variable-size array leaves.

Returns:

A new pytree of leaf values at the specified location, with the non-selected leaf values set to None if the leaf is not an array.

Example

>>> import sepes as sp
>>> tree = {"a": 1, "b": [1, 2, 3]}
>>> sp.at(tree)["b"][0].get()
{'a': None, 'b': [1, None, None]}
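For non-array leaves, get keeps the tree's shape and blanks out everything that was not selected. The pure-Python sketch below shows this, assuming the selection has already been turned into a boolean mask. `masked_get` is hypothetical. It treats every leaf the same way and does not model how sepes handles array leaves. Filling with a fixed value instead of dropping leaves keeps every position in place, which is the property fill_value provides under jax.jit.

```python
from typing import Any


def masked_get(tree: Any, mask: Any, fill_value: Any = None) -> Any:
    """Keep leaves where `mask` is True; replace the others with `fill_value`."""
    if isinstance(tree, dict):
        return {k: masked_get(tree[k], mask[k], fill_value) for k in tree}
    if isinstance(tree, (list, tuple)):
        return type(tree)(
            masked_get(leaf, leaf_mask, fill_value)
            for leaf, leaf_mask in zip(tree, mask)
        )
    return tree if mask else fill_value


tree = {"a": 1, "b": [1, 2, 3]}
mask = {"a": False, "b": [True, False, False]}  # same selection as ["b"][0]
got = masked_get(tree, mask)                    # non-selected leaves -> None
filled = masked_get(tree, mask, fill_value=0)   # non-selected leaves -> 0
```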
set(set_value, *, is_leaf=None, is_parallel=False)

Set the leaf values at the specified location.

Parameters:
  • set_value (Any) – the value to set at the specified location.

  • is_leaf (Callable[[Any], bool] | None) – a predicate function to determine if a value is a leaf.

  • is_parallel (bool | ParallelConfig) –

    accepts the following:

    • bool: run in parallel if True, otherwise serially.

    • dict: a dict with the following keys:

      • max_workers: maximum number of workers to use.

      • kind: kind of pool to use, either thread or process.

Returns:

A pytree with the leaf values at the specified location set to set_value.

Example

>>> import sepes as sp
>>> tree = {"a": 1, "b": [1, 2, 3]}
>>> sp.at(tree)["b"][0].set(100)
{'a': 1, 'b': [100, 2, 3]}
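One way to picture an out-of-place set along a path like ["b"][0] is path copying. Only the containers along the path are copied, and untouched subtrees are shared with the input. The sketch below uses a hypothetical `set_at_path` function, not a sepes API. Whether sepes shares untouched subtrees this way is an assumption.

```python
from typing import Any, Sequence


def set_at_path(tree: Any, path: Sequence[Any], value: Any) -> Any:
    """Return a copy of `tree` whose node at `path` is replaced by `value`."""
    if not path:
        return value
    key, rest = path[0], path[1:]
    if isinstance(tree, dict):
        new_tree = dict(tree)  # shallow copy of this level only
    elif isinstance(tree, list):
        new_tree = list(tree)
    else:
        raise TypeError(f"cannot index into {type(tree).__name__}")
    new_tree[key] = set_at_path(tree[key], rest, value)
    return new_tree


tree = {"a": [0], "b": [1, 2, 3]}
new_tree = set_at_path(tree, ["b", 0], 100)
```

`tree` is left unchanged, and `new_tree["a"]` is the very same list object as `tree["a"]`, because only the path to `["b"][0]` was copied.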
apply(func, *, is_leaf=None, is_parallel=False)

Apply a function to the leaf values at the specified location.

Parameters:
  • func (Callable[[Any], Any]) – the function to apply to the leaf values.

  • is_leaf (Callable[[Any], bool] | None) – a predicate function to determine if a value is a leaf.

  • is_parallel (bool | ParallelConfig) –

    accepts the following:

    • bool: apply func in parallel if True, otherwise serially.

    • dict: a dict with the following keys:

      • max_workers: maximum number of workers to use.

      • kind: kind of pool to use, either thread or process.

Returns:

A pytree with the leaf values at the specified location set to the result of applying func to the leaf values.

Example

>>> import sepes as sp
>>> tree = {"a": 1, "b": [1, 2, 3]}
>>> sp.at(tree)["b"][0].apply(lambda x: x + 100)
{'a': 1, 'b': [101, 2, 3]}

Example

Read images in parallel

>>> import sepes as sp
>>> from matplotlib.pyplot import imread
>>> path = {"img1": "path1.png", "img2": "path2.png"}
>>> is_parallel = dict(max_workers=2)
>>> images = sp.at(path)[...].apply(imread, is_parallel=is_parallel)  
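The max_workers and kind keys of is_parallel correspond naturally to Python's `concurrent.futures` pools. The sketch below uses a hypothetical `apply_in_parallel` helper and handles only a flat dict like `path` above. It is not sepes' implementation. With kind="process", func must be picklable (for example a module-level function), so lambdas will not work.

```python
from concurrent.futures import ProcessPoolExecutor, ThreadPoolExecutor
from typing import Any, Callable, Dict


def apply_in_parallel(
    func: Callable[[Any], Any],
    tree: Dict[str, Any],
    max_workers: int = 2,
    kind: str = "thread",
) -> Dict[str, Any]:
    """Apply `func` to every value of a flat dict using a worker pool."""
    pools = {"thread": ThreadPoolExecutor, "process": ProcessPoolExecutor}
    if kind not in pools:
        raise ValueError(f"kind must be 'thread' or 'process', got {kind!r}")
    keys = list(tree)
    with pools[kind](max_workers=max_workers) as pool:
        # Executor.map yields results in input order, so keys stay aligned.
        results = list(pool.map(func, [tree[k] for k in keys]))
    return dict(zip(keys, results))


paths = {"img1": "path1.png", "img2": "path2.png"}
upper = apply_in_parallel(str.upper, paths, max_workers=2)
```

Threads usually suit I/O-bound work such as reading files. Processes sidestep the GIL for CPU-bound pure-Python work, but they add pickling overhead.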
scan(func, state, *, is_leaf=None)

Apply a function while carrying a state.

Parameters:
  • func – the function to apply to the leaf values. It accepts a leaf value and the running state, and returns a tuple of the new leaf value and the new state.

  • state – the initial state to carry.

  • is_leaf – a predicate function to determine if a value is a leaf. For example, lambda x: isinstance(x, list) treats all lists as leaves and does not recurse into list items.

Returns:

A tuple of the pytree, with the leaf values at the specified location set to the result of applying func to them, and the final state.

Example

>>> import sepes as sp
>>> tree = {"a": 1, "b": [1, 2, 3]}
>>> def scan_func(leaf, running_max):
...     cur_max = max(leaf, running_max)
...     return leaf, cur_max
>>> running_max = float("-inf")
>>> _, running_max = sp.at(tree)["b"][0, 1].scan(scan_func, state=running_max)
>>> running_max  # max of b[0] and b[1]
2
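Over a flat list of the selected leaves, scan behaves like the loop below. func receives (leaf, state) and returns (new_leaf, new_state), and the result pairs the updated leaves with the final state. This matches the `_, running_max = ...` unpacking in the example. `scan_leaves` is a hypothetical pure-Python helper. sepes also handles selection and tree structure, and left-to-right leaf order is an assumption here.

```python
from typing import Any, Callable, List, Tuple, TypeVar

S = TypeVar("S")


def scan_leaves(
    func: Callable[[Any, S], Tuple[Any, S]], leaves: List[Any], state: S
) -> Tuple[List[Any], S]:
    """Thread `state` through `leaves` left to right, collecting the new leaves."""
    new_leaves = []
    for leaf in leaves:
        leaf, state = func(leaf, state)
        new_leaves.append(leaf)
    return new_leaves, state


def cumsum(leaf: int, total: int) -> Tuple[int, int]:
    # Each leaf becomes the cumulative sum up to and including itself.
    total = total + leaf
    return total, total


sums, total = scan_leaves(cumsum, [1, 2, 3], 0)
```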

Note

scan applies a binary func to the selected leaf values while carrying a state, and returns the tree with func applied to those leaves together with the final state. By contrast, reduce applies a binary func to the leaf values while accumulating a state and returns a single value.

reduce(func, *, initializer=<object object>, is_leaf=None)

Reduce the leaf values at the specified location.

Parameters:
  • func (Callable[[Any, Any], Any]) – the function to reduce the leaf values.

  • initializer (Any) – the initializer value for the reduction.

  • is_leaf (Callable[[Any], bool] | None) – a predicate function to determine if a value is a leaf.

Return type:

Any

Returns:

The result of reducing the leaf values at the specified location.

Note

  • If initializer is not specified, the first leaf value is used as the initializer.

  • reduce applies a binary func to each leaf value while accumulating a state, and returns the final result. scan, by contrast, applies func to each leaf value while carrying a state, and returns the leaves of the tree with func applied to them together with the final state.

Example

>>> import sepes as sp
>>> tree = {"a": 1, "b": [1, 2, 3]}
>>> sp.at(tree)["b"].reduce(lambda x, y: x + y)
6
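Defaulting to the first leaf when no initializer is given matches the behavior of `functools.reduce`. A bare `object()` is a common "no value given" sentinel. Its repr looks like `<object object at 0x...>`, which is presumably where the `<object object>` default in the signature comes from. The `reduce_leaves` helper below is a hypothetical sketch over a flat list of already-selected leaves.

```python
import functools
import operator
from typing import Any, Callable, Iterable

# Sentinel meaning "no initializer was passed" (None could be a real value).
_MISSING = object()


def reduce_leaves(
    func: Callable[[Any, Any], Any], leaves: Iterable[Any], initializer: Any = _MISSING
) -> Any:
    """Fold `leaves` with `func`, starting from the first leaf if no initializer is given."""
    if initializer is _MISSING:
        # Like functools.reduce, an empty `leaves` raises TypeError here.
        return functools.reduce(func, leaves)
    return functools.reduce(func, leaves, initializer)


total = reduce_leaves(operator.add, [1, 2, 3])
shifted = reduce_leaves(operator.add, [1, 2, 3], initializer=10)
```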
pluck(count=None, *, is_leaf=None, is_parallel=False)

Extract subtrees at the specified location.

Note

pluck first applies get to the specified location and then extracts the immediate subtrees of the selected leaves. is_leaf and is_parallel are passed to get.

Parameters:
  • count (int | None) – number of subtrees to extract. Defaults to None, which extracts all subtrees.

  • is_leaf (Callable[[Any], bool] | None) – a predicate function to determine if a value is a leaf.

  • is_parallel (bool | ParallelConfig) –

    accepts the following:

    • bool: run in parallel if True, otherwise serially.

    • dict: a dict with the following keys:

      • max_workers: maximum number of workers to use.

      • kind: kind of pool to use, either thread or process.

Return type:

list[Any]

Returns:

A list of subtrees at the specified location.

Note

Compared to get, pluck extracts the subtrees at the specified location and returns them as a list, whereas get returns a pytree with the leaf values at the specified location and sets the non-selected leaf values to None.

Example

>>> import sepes as sp
>>> tree = {"a": 1, "b": [1, 2, 3]}

>>> # `pluck` returns a list of selected subtrees
>>> sp.at(tree)["b"].pluck()
[[1, 2, 3]]

>>> # `get` returns same pytree
>>> sp.at(tree)["b"].get()
{'a': None, 'b': [1, 2, 3]}

Example

pluck with mask

>>> import sepes as sp
>>> tree = {"a": 1, "b": [2, 3, 4]}
>>> mask = {"a": True, "b": [False, True, False]}
>>> sp.at(tree)[mask].pluck()
[1, 3]

This is equivalent to the following:

>>> [tree["a"], tree["b"][1]]
[1, 3]
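The sketch below shows how collecting selected subtrees works, with an optional count limit. It is a pure-Python illustration with a hypothetical `pluck` function, not the sepes method. A single True at a container position selects that subtree whole, which mimics how selecting ["b"] returns [[1, 2, 3]] in the first example. This mask encoding and the traversal order are assumptions of the sketch.

```python
from typing import Any, List, Optional


def pluck(tree: Any, mask: Any, count: Optional[int] = None) -> List[Any]:
    """Collect the subtrees selected by `mask`, in traversal order."""
    out: List[Any] = []

    def visit(node: Any, node_mask: Any) -> None:
        if count is not None and len(out) >= count:
            return  # stop once `count` subtrees have been collected
        if isinstance(node_mask, bool):
            # A boolean selects (True) or skips (False) the whole subtree.
            if node_mask:
                out.append(node)
        elif isinstance(node, dict):
            for key in node:
                visit(node[key], node_mask[key])
        elif isinstance(node, (list, tuple)):
            for child, child_mask in zip(node, node_mask):
                visit(child, child_mask)

    visit(tree, mask)
    return out


tree = {"a": 1, "b": [2, 3, 4]}
leaves = pluck(tree, {"a": True, "b": [False, True, False]})
subtrees = pluck(tree, {"a": False, "b": True})
```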
sepes.value_and_tree(func, argnums=0)

Call a function on copies of its input arguments and return the value and the tree.

Input arguments are copied before the function is called, and the arguments specified by argnums are returned as the tree.

Parameters:
  • func (Callable[..., T]) – A function.

  • argnums (int | Sequence[int]) – The argument number of the tree that will be returned. If multiple arguments are specified, the tree will be returned as a tuple.

Returns:

A function that returns the value and the tree.

Example

Usage with mutable types:

>>> import sepes as sp
>>> mutable_tree = [1, 2, 3]
>>> def mutating_func(tree):
...     tree[0] += 100
...     return tree
>>> new_tree = mutating_func(mutable_tree)
>>> assert new_tree is mutable_tree
>>> # with `value_and_tree`, the function mutates a copy and the input is left untouched
>>> new_tree, _ = sp.value_and_tree(mutating_func)(mutable_tree)
>>> assert new_tree is not mutable_tree
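The copy-then-call behavior for mutable types can be sketched in a few lines of pure Python. `value_and_copy` is hypothetical. It uses `copy.deepcopy` as a stand-in, and sepes' own copying strategy may differ. Only positional arguments are handled.

```python
import copy
from typing import Any, Callable, Sequence, Tuple, Union


def value_and_copy(
    func: Callable[..., Any], argnums: Union[int, Sequence[int]] = 0
) -> Callable[..., Tuple[Any, Any]]:
    """Call `func` on deep copies of its positional arguments.

    Returns (value, tree), where tree is the possibly mutated copy at `argnums`,
    or a tuple of copies when `argnums` is a sequence.
    """
    def wrapper(*args: Any) -> Tuple[Any, Any]:
        copies = copy.deepcopy(args)
        value = func(*copies)
        if isinstance(argnums, int):
            return value, copies[argnums]
        return value, tuple(copies[i] for i in argnums)

    return wrapper


def append_one(items: list) -> None:
    items.append(1)  # mutates its argument and returns None


original = [0]
value, updated = value_and_copy(append_one)(original)
```

After the call, `updated` is `[0, 1]` while `original` is still `[0]`.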

Example

Usage with immutable types (TreeClass) with support for in-place mutation via custom behavior registration using value_and_tree.def_mutator() and value_and_tree.def_immutator():

>>> import sepes as sp
>>> class Counter(sp.TreeClass):
...     def __init__(self, count: int):
...         self.count = count
...     def increment(self, value):
...         self.count += value
...         return self.count
>>> counter = Counter(0)
>>> counter.increment(1)
Traceback (most recent call last):
    ...
AttributeError: Cannot set attribute value=1 to `key='count'`  on an immutable instance of `Counter`.
>>> sp.value_and_tree(lambda counter: counter.increment(1))(counter)
(1, Counter(count=1))

Note

Use this function on functions that:

  • Mutate input arguments of mutable types (e.g. lists, dicts, etc.).

  • Mutate input arguments of immutable types that do not support in-place mutation and need special handling, which can be registered (e.g. for TreeClass) using value_and_tree.def_mutator() and value_and_tree.def_immutator().

Note

The default behavior of value_and_tree() is to copy the input arguments and then call the function on the copies. However, if the function mutates an input argument that does not support in-place mutation, the call fails. For this case, value_and_tree() lets you register a custom function that modifies the copied argument to allow in-place mutation, and a custom function that restores the copied argument to its original state after the call. The following example registers both functions for a simple class that allows in-place mutation only when its immutable flag is set to False.

>>> import jax
>>> import sepes as sp
>>> @jax.tree_util.register_pytree_node_class
... class MyNode:
...     def __init__(self):
...         self.counter = 0
...         self.immutable = True
...     def tree_flatten(self):
...         keys, values = zip(*vars(self).items())
...         return tuple(values), tuple(keys)
...         return tuple(values), tuple(keys)
...     @classmethod
...     def tree_unflatten(cls, keys, values):
...         self = object.__new__(cls)
...         vars(self).update(dict(zip(keys, values)))
...         return self
...     def __setattr__(self, name, value):
...         if getattr(self, "immutable", False) is True:
...             raise AttributeError("MyNode is immutable")
...         object.__setattr__(self, name, value)
...     def __repr__(self):
...         params = ", ".join(f"{k}={v}" for k, v in vars(self).items())
...         return f"MyNode({params})"
...     def increment(self) -> None:
...         self.counter += 1
>>> @sp.value_and_tree.def_mutator(MyNode)
... def mutable(node) -> None:
...     vars(node)["immutable"] = False
>>> @sp.value_and_tree.def_immutator(MyNode)
... def immutable(node) -> None:
...     vars(node)["immutable"] = True
>>> node = MyNode()
>>> sp.value_and_tree(lambda node: node.increment())(node)
(None, MyNode(counter=1, immutable=True))
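The mutator/immutator pattern above reduces to four steps. Copy the argument, run a registered "unlock" hook on the copy, call the function, then run a "relock" hook. The pure-Python sketch below shows this with a toy class. `register_mutator`, `register_immutator` and `call_on_copy` are hypothetical standalone names, not the sepes functions. Dispatch on the exact type and relocking even when the call raises are assumptions of the sketch.

```python
import copy
from typing import Any, Callable, Dict, Tuple

# Hypothetical registries keyed by exact type; not the sepes registries.
MUTATORS: Dict[type, Callable[[Any], None]] = {}
IMMUTATORS: Dict[type, Callable[[Any], None]] = {}


def register_mutator(cls: type) -> Callable:
    """Register a hook that unlocks instances of `cls` before the call."""
    def register(hook: Callable[[Any], None]) -> Callable[[Any], None]:
        MUTATORS[cls] = hook
        return hook
    return register


def register_immutator(cls: type) -> Callable:
    """Register a hook that relocks instances of `cls` after the call."""
    def register(hook: Callable[[Any], None]) -> Callable[[Any], None]:
        IMMUTATORS[cls] = hook
        return hook
    return register


def call_on_copy(func: Callable[[Any], Any], obj: Any) -> Tuple[Any, Any]:
    """Copy `obj`, unlock the copy, call `func` on it, then relock it."""
    obj_copy = copy.deepcopy(obj)
    MUTATORS.get(type(obj_copy), lambda _: None)(obj_copy)
    try:
        value = func(obj_copy)
    finally:
        # Relock even if `func` raises (an assumption of this sketch).
        IMMUTATORS.get(type(obj_copy), lambda _: None)(obj_copy)
    return value, obj_copy


class Frozen:
    """A toy class whose __setattr__ refuses writes while `locked` is True."""

    def __init__(self) -> None:
        object.__setattr__(self, "count", 0)
        object.__setattr__(self, "locked", True)

    def __setattr__(self, name: str, value: Any) -> None:
        if self.locked:
            raise AttributeError("Frozen is immutable")
        object.__setattr__(self, name, value)

    def increment(self) -> None:
        self.count += 1


@register_mutator(Frozen)
def _unlock(obj: Frozen) -> None:
    vars(obj)["locked"] = False  # writes __dict__ directly, bypassing __setattr__


@register_immutator(Frozen)
def _relock(obj: Frozen) -> None:
    vars(obj)["locked"] = True


frozen = Frozen()
value, updated = call_on_copy(lambda f: f.increment(), frozen)
```

Calling `frozen.increment()` directly raises AttributeError. The call through `call_on_copy` returns `(None, updated)`, where `updated.count` is 1 and `updated` is locked again.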
sepes.bcmap(func, broadcast_to=None, *, is_leaf=None)

Map a function over pytree leaves with automatic broadcasting for scalar arguments.

Parameters:
  • func (Callable[P, T]) – the function to be mapped over the pytree.

  • broadcast_to (int | str | None) – an integer to broadcast to a specific positional argument, or a string to broadcast to a specific keyword argument. If None, the function is broadcast to the first positional argument, or to the first keyword argument if no positional arguments are provided. Defaults to None.

  • is_leaf (Callable[[Any], bool] | None) – a predicate function that returns True if the node is a leaf.

Return type:

Callable[P, T]

Example

Transform numpy functions to work with pytrees:

>>> import sepes as sp
>>> import jax.numpy as jnp
>>> tree_of_arrays = {"a": jnp.array([1, 2, 3]), "b": jnp.array([4, 5, 6])}
>>> tree_add = sp.bcmap(jnp.add)
>>> # both lhs and rhs are pytrees
>>> print(sp.tree_str(tree_add(tree_of_arrays, tree_of_arrays)))
dict(a=[2 4 6], b=[ 8 10 12])
>>> # rhs is a scalar
>>> print(sp.tree_str(tree_add(tree_of_arrays, 1)))
dict(a=[2 3 4], b=[5 6 7])
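The broadcasting idea can be illustrated without JAX. The sketch below maps func over the first argument's leaves. Any other argument that is a dict, list or tuple is indexed in lockstep, and anything else is passed unchanged to every leaf as a "scalar". `tree_bcmap` is a hypothetical stand-in, not sepes' implementation. It models only the default broadcast_to=None behavior for positional arguments, and it uses plain lists instead of jnp arrays.

```python
import operator
from typing import Any, Callable


def _is_container(x: Any) -> bool:
    return isinstance(x, (dict, list, tuple))


def tree_bcmap(func: Callable[..., Any]) -> Callable[..., Any]:
    """Map `func` over the first argument's leaves, broadcasting non-container args."""
    def mapped(tree: Any, *args: Any) -> Any:
        if isinstance(tree, dict):
            return {
                k: mapped(tree[k], *(a[k] if _is_container(a) else a for a in args))
                for k in tree
            }
        if isinstance(tree, (list, tuple)):
            return type(tree)(
                mapped(leaf, *(a[i] if _is_container(a) else a for a in args))
                for i, leaf in enumerate(tree)
            )
        return func(tree, *args)

    return mapped


tree_add = tree_bcmap(operator.add)
tree = {"a": [1, 2, 3], "b": [4, 5, 6]}
both_trees = tree_add(tree, tree)  # lhs and rhs are both pytrees
with_scalar = tree_add(tree, 1)    # rhs is broadcast to every leaf
```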