🥽 Masking API
- sepes.is_masked(value)
Returns True if the value is wrapped in a mask (frozen) wrapper.
- Parameters:
value (Any) – The value to check.
- Return type:
bool
- sepes.tree_mask(tree, cond=<function is_nondiff>, *, is_leaf=None)
Mask leaves of a pytree based on a mask boolean pytree or callable.
Masked leaves are wrapped with a wrapper that yields no leaves when tree_flatten is called on it.
- Parameters:
tree (T) – A pytree of values.
cond (Callable[[Any], bool]) – A callable that accepts a leaf and returns a boolean to mark the leaf for masking. Defaults to masking non-differentiable leaf nodes, that is, leaves that are not instances of Python float, Python complex, or inexact array types.
is_leaf (Callable[[Any], None] | None) – A callable that accepts a leaf and returns a boolean. If provided, it is used to determine if a value is a leaf. For example, is_leaf=lambda x: isinstance(x, list) will treat lists as leaves and will not recurse into them.
Example
>>> import sepes as sp
>>> import jax
>>> tree = [1, 2, {"a": 3, "b": 4.}]
>>> # mask all non-differentiable nodes by default
>>> masked_tree = sp.tree_mask(tree)
>>> masked_tree
[#1, #2, {'a': #3, 'b': 4.0}]
>>> jax.tree_util.tree_leaves(masked_tree)
[4.0]
>>> sp.tree_unmask(masked_tree)
[1, 2, {'a': 3, 'b': 4.0}]
Example
Pass non-differentiable values to jax.grad
>>> import sepes as sp
>>> import jax
>>> @jax.grad
... def square(tree):
...     tree = sp.tree_unmask(tree)
...     return tree[0] ** 2
>>> tree = (1., 2)  # contains a non-differentiable node
>>> square(sp.tree_mask(tree))
(Array(2., dtype=float32, weak_type=True), #2)
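The description above says that a masked leaf is wrapped so that it yields no leaves when the tree is flattened. The following is a minimal pure-Python sketch of that idea, not sepes's implementation. The names Masked, mask, unmask, and leaves are hypothetical, the tree walker only understands lists, tuples, and dicts, and the default condition is simplified to "anything that is not a Python float or complex".

```python
from dataclasses import dataclass
from typing import Any, Callable, Optional


@dataclass(frozen=True)
class Masked:
    """Hypothetical stand-in for a mask wrapper; not sepes's actual class."""

    value: Any

    def __repr__(self) -> str:
        return f"#{self.value!r}"


def is_nondiff(leaf: Any) -> bool:
    # Simplified assumption: only Python float and complex are differentiable.
    return not isinstance(leaf, (float, complex))


def tree_map(func: Callable, tree: Any, is_leaf: Optional[Callable] = None) -> Any:
    # Recurse through lists, tuples, and dicts; everything else is a leaf.
    if is_leaf is not None and is_leaf(tree):
        return func(tree)
    if isinstance(tree, (list, tuple)):
        return type(tree)(tree_map(func, item, is_leaf) for item in tree)
    if isinstance(tree, dict):
        return {key: tree_map(func, value, is_leaf) for key, value in tree.items()}
    return func(tree)


def mask(tree: Any, cond: Callable = is_nondiff, is_leaf: Optional[Callable] = None) -> Any:
    # Wrap every leaf for which cond returns True.
    return tree_map(lambda x: Masked(x) if cond(x) else x, tree, is_leaf)


def unmask(tree: Any) -> Any:
    # Unwrap every Masked leaf; other leaves pass through unchanged.
    return tree_map(lambda x: x.value if isinstance(x, Masked) else x, tree)


def leaves(tree: Any) -> list:
    # A Masked wrapper contributes no leaves, mirroring the behavior described above.
    if isinstance(tree, Masked):
        return []
    if isinstance(tree, (list, tuple)):
        return [leaf for item in tree for leaf in leaves(item)]
    if isinstance(tree, dict):
        return [leaf for value in tree.values() for leaf in leaves(value)]
    return [tree]


tree = [1, 2, {"a": 3, "b": 4.0}]
masked = mask(tree)
print(masked)          # [#1, #2, {'a': #3, 'b': 4.0}]
print(leaves(masked))  # [4.0]
print(unmask(masked))  # [1, 2, {'a': 3, 'b': 4.0}]

# With is_leaf, the whole list is treated as one leaf and masked as a unit.
whole = mask(tree, is_leaf=lambda x: isinstance(x, list))
print(leaves(whole))   # []
```

Because the wrapper hides its content from flattening, any transformation that only visits leaves would skip the masked values, which is consistent with the jax.grad example above returning the masked integer untouched.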
- sepes.tree_unmask(tree, cond=<function <lambda>>)
Undo the masking of tree leaves according to cond. Defaults to unmasking all leaves.
- Parameters:
tree (T) – A pytree of values.
cond (Callable[[Any], bool]) – A callable that accepts a leaf and returns a boolean to mark the leaf to be unmasked. Defaults to always unmask.
Example
>>> import sepes as sp
>>> import jax
>>> tree = [1, 2, {"a": 3, "b": 4.}]
>>> # mask all non-differentiable nodes by default
>>> masked_tree = sp.tree_mask(tree)
>>> masked_tree
[#1, #2, {'a': #3, 'b': 4.0}]
>>> jax.tree_util.tree_leaves(masked_tree)
[4.0]
>>> sp.tree_unmask(masked_tree)
[1, 2, {'a': 3, 'b': 4.0}]
Example
Pass non-differentiable values to jax.grad
>>> import sepes as sp
>>> import jax
>>> @jax.grad
... def square(tree):
...     tree = sp.tree_unmask(tree)
...     return tree[0] ** 2
>>> tree = (1., 2)  # contains a non-differentiable node
>>> square(sp.tree_mask(tree))
(Array(2., dtype=float32, weak_type=True), #2)
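The cond parameter of tree_unmask selects which masked leaves get unwrapped. The sketch below illustrates that selection in plain Python. It is not sepes's implementation: Masked and unmask_where are hypothetical names, and the sketch assumes cond is called with the value inside the wrapper, which the reference above does not specify.

```python
from typing import Any, Callable


class Masked:
    """Hypothetical mask wrapper for this sketch; not sepes's actual class."""

    def __init__(self, value: Any) -> None:
        self.value = value

    def __repr__(self) -> str:
        return f"#{self.value!r}"


def unmask_where(tree: Any, cond: Callable[[Any], bool] = lambda _: True) -> Any:
    # Assumption: cond is called with the value stored inside the wrapper.
    if isinstance(tree, Masked):
        return tree.value if cond(tree.value) else tree
    if isinstance(tree, (list, tuple)):
        return type(tree)(unmask_where(item, cond) for item in tree)
    if isinstance(tree, dict):
        return {key: unmask_where(value, cond) for key, value in tree.items()}
    return tree


masked = [Masked(1), Masked("label"), {"a": Masked(3), "b": 4.0}]

# Unwrap only the integer leaves; the string stays masked.
partial = unmask_where(masked, cond=lambda v: isinstance(v, int))
print(partial)  # [1, #'label', {'a': 3, 'b': 4.0}]

# The default cond unwraps everything.
print(unmask_where(masked))  # [1, 'label', {'a': 3, 'b': 4.0}]
```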