🥽 Masking API

sepes.is_masked(value)

Returns True if the value is a frozen (masked) wrapper, such as a leaf wrapped by tree_mask.

Parameters:

value (Any) – The value to check.

Return type:

bool
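To illustrate what a "frozen wrapper" means for flattening, here is a minimal sketch in plain JAX, independent of sepes. `Frozen` and `is_frozen` are hypothetical names standing in for the masking wrapper and `is_masked`; they are not sepes internals. The sketch stores the wrapped value as static pytree metadata (aux_data) with no children, which assumes the wrapped value is hashable.

```python
import jax


@jax.tree_util.register_pytree_node_class
class Frozen:
    """Hypothetical stand-in for a masking wrapper (not the sepes class)."""

    def __init__(self, value):
        self.value = value

    def tree_flatten(self):
        # No children: the wrapped value is stored as static aux_data,
        # so flattening this node yields no leaves.
        return (), self.value

    @classmethod
    def tree_unflatten(cls, aux_data, children):
        return cls(aux_data)

    def __repr__(self):
        return f"#{self.value!r}"


def is_frozen(value) -> bool:
    """Return True if value is a Frozen wrapper (analogous to is_masked)."""
    return isinstance(value, Frozen)


tree = [Frozen(1), 2.0]
print(is_frozen(tree[0]), is_frozen(tree[1]))  # True False
print(jax.tree_util.tree_leaves(tree))  # [2.0]

# Flattening and unflattening restores the wrapper from its aux_data.
leaves, treedef = jax.tree_util.tree_flatten(tree)
print(jax.tree_util.tree_unflatten(treedef, leaves))  # [#1, 2.0]
```

Because `Frozen` declares no children, `tree_leaves` skips it entirely while the tree structure still remembers it.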

sepes.tree_mask(tree, cond=<function is_nondiff>, *, is_leaf=None)

Mask the leaves of a pytree based on a boolean mask pytree or a callable.

Masked leaves are wrapped in a wrapper that yields no leaves when tree_flatten is called on it, so they are hidden from tree-level transformations such as jax.grad.

Parameters:
  • tree (T) – A pytree of values.

  • cond (Callable[[Any], bool]) – A callable that accepts a leaf and returns a boolean to mark the leaf for masking. Defaults to masking non-differentiable leaves, i.e. leaves that are not instances of Python float, Python complex, or inexact array types.

  • is_leaf (Callable[[Any], bool] | None) – A callable that accepts a value and returns a boolean. If provided, it is used to determine whether a value is a leaf. For example, is_leaf=lambda x: isinstance(x, list) treats lists as leaves and does not recurse into them.
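The masking rule described above can be sketched as a single `jax.tree_util.tree_map` that wraps each leaf for which `cond` returns True. This is a hypothetical re-implementation for illustration, not the sepes source: `Frozen`, `is_nondiff_like`, and `mask_leaves` are made-up names, and `is_nondiff_like` only approximates the documented default (keep Python float, Python complex, and inexact-dtype arrays; mask everything else). The last lines show how `is_leaf` changes which nodes `cond` sees.

```python
import jax
import jax.numpy as jnp


@jax.tree_util.register_pytree_node_class
class Frozen:
    """Hypothetical masking wrapper with no pytree children."""

    def __init__(self, value):
        self.value = value

    def tree_flatten(self):
        return (), self.value  # value kept as static aux_data, no leaves

    @classmethod
    def tree_unflatten(cls, aux_data, children):
        return cls(aux_data)

    def __repr__(self):
        return f"#{self.value!r}"


def is_nondiff_like(x) -> bool:
    """Approximate default cond: False for float, complex, or inexact arrays."""
    if isinstance(x, (float, complex)):
        return False
    if hasattr(x, "dtype") and jnp.issubdtype(x.dtype, jnp.inexact):
        return False
    return True


def mask_leaves(tree, cond=is_nondiff_like, *, is_leaf=None):
    """Wrap every leaf for which cond(leaf) is True in a Frozen node."""
    return jax.tree_util.tree_map(
        lambda x: Frozen(x) if cond(x) else x, tree, is_leaf=is_leaf
    )


tree = [1, 2, {"a": 3, "b": 4.0}]
print(mask_leaves(tree))  # [#1, #2, {'a': #3, 'b': 4.0}]
print(jax.tree_util.tree_leaves(mask_leaves(tree)))  # [4.0]

# is_leaf decides what counts as a leaf before cond is applied.
config = {"shape": (3, 4), "w": 1.0}
print(mask_leaves(config))  # {'shape': (#3, #4), 'w': 1.0}
print(mask_leaves(config, is_leaf=lambda x: isinstance(x, tuple)))  # {'shape': #(3, 4), 'w': 1.0}
```

With `is_leaf`, the whole tuple is treated as one leaf and masked as a single unit instead of element by element.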

Example

>>> import sepes as sp
>>> import jax
>>> tree = [1, 2, {"a": 3, "b": 4.}]
>>> # mask all non-differentiable nodes by default
>>> masked_tree = sp.tree_mask(tree)
>>> masked_tree
[#1, #2, {'a': #3, 'b': 4.0}]
>>> jax.tree_util.tree_leaves(masked_tree)
[4.0]
>>> sp.tree_unmask(masked_tree)
[1, 2, {'a': 3, 'b': 4.0}]

Example

Pass non-differentiable values to jax.grad

>>> import sepes as sp
>>> import jax
>>> @jax.grad
... def square(tree):
...     tree = sp.tree_unmask(tree)
...     return tree[0] ** 2
>>> tree = (1., 2)  # contains a non-differentiable node
>>> square(sp.tree_mask(tree))
(Array(2., dtype=float32, weak_type=True), #2)

sepes.tree_unmask(tree, cond=<function <lambda>>)

Undo the masking of tree leaves according to cond. Defaults to unmasking all leaves.

Parameters:
  • tree (T) – A pytree of values.

  • cond (Callable[[Any], bool]) – A callable that accepts a leaf and returns a boolean to mark the leaf to be unmasked. Defaults to always unmask.
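Unmasking can be sketched the same way: map over the tree with `is_leaf` set so that wrapper nodes are visited, then unwrap the ones selected by `cond`. Again this is a hypothetical sketch (`Frozen` and `unmask_leaves` are illustrative names, not sepes internals); in particular, passing the *wrapped* value to `cond` is a design choice of this sketch, not documented sepes behaviour.

```python
import jax


@jax.tree_util.register_pytree_node_class
class Frozen:
    """Hypothetical masking wrapper with no pytree children."""

    def __init__(self, value):
        self.value = value

    def tree_flatten(self):
        return (), self.value  # value kept as static aux_data

    @classmethod
    def tree_unflatten(cls, aux_data, children):
        return cls(aux_data)

    def __repr__(self):
        return f"#{self.value!r}"


def unmask_leaves(tree, cond=lambda _: True):
    """Unwrap Frozen nodes whose wrapped value satisfies cond (default: all)."""

    def unwrap(x):
        if isinstance(x, Frozen) and cond(x.value):
            return x.value
        return x

    # is_leaf stops traversal at Frozen nodes so unwrap receives them;
    # without it they would be skipped, since they contain no leaves.
    return jax.tree_util.tree_map(
        unwrap, tree, is_leaf=lambda x: isinstance(x, Frozen)
    )


masked = [Frozen(1), Frozen("a"), 2.0]
print(unmask_leaves(masked))  # [1, 'a', 2.0]
print(unmask_leaves(masked, cond=lambda v: isinstance(v, int)))  # [1, #'a', 2.0]
```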
