π¨ Pretty printing API#
- sepes.tree_diagram(tree, *, depth=inf, is_leaf=None, tabwidth=4)[source]#
Pretty print arbitrary pytrees tree with tree structure diagram.
- Parameters:
tree (Any) β arbitrary pytree.
depth (int | float) β depth of the tree to print. default is max depth.
is_leaf (Callable[[Any], None] | None) β function to determine if a node is a leaf. default is None.
tabwidth (int) β tab width of the repr string. default is 4.
Example
>>> import sepes as sp >>> @sp.autoinit ... class A(sp.TreeClass): ... x: int = 10 ... y: int = (20,30) ... z: int = 40
>>> @sp.autoinit ... class B(sp.TreeClass): ... a: int = 10 ... b: tuple = (20,30, A())
>>> print(sp.tree_diagram(B(), depth=0)) B(...)
>>> print(sp.tree_diagram(B(), depth=1)) B βββ .a=10 βββ .b=(...)
>>> print(sp.tree_diagram(B(), depth=2)) B βββ .a=10 βββ .b:tuple βββ [0]=20 βββ [1]=30 βββ [2]=A(...)
- sepes.tree_repr(tree, *, width=80, tabwidth=2, depth=inf)[source]#
- Parameters:
tree (PyTree) β
width (int) β
tabwidth (int) β
depth (int | float) β
- sepes.tree_str(tree, *, width=80, tabwidth=2, depth=inf)[source]#
- Parameters:
tree (PyTree) β
width (int) β
tabwidth (int) β
depth (int | float) β
- sepes.tree_summary(tree, *, depth=inf, is_leaf=None)[source]#
Print a summary of an arbitrary pytree.
- Parameters:
tree (PyTree) β A pytree.
depth (int | float) β max depth to display the tree. Defaults to maximum depth.
is_leaf (Callable[[Any], None] | None) β function to determine if a node is a leaf. Defaults to
None
- Returns:
First column: path to the node.
Second column: type of the node. to control the displayed type use
tree_summary.def_type(type, func)
to define a custom type display function.Third column: number of leaves in the node. for arrays the number of leaves is the number of elements in the array, otherwise its 1. to control the number of leaves of a node use
tree_summary.def_count(type,func)
Fourth column: size of the node in bytes. if the node is array the size is the size of the array in bytes, otherwise the size is not displayed. to control the size of a node use
tree_summary.def_size(type,func)
Last row: type of parent, number of leaves of the parent
- Return type:
String summary of the tree structure
Example
>>> import sepes as sp >>> import jax.numpy as jnp >>> print(sp.tree_summary([1, [2, [3]], jnp.array([1, 2, 3])])) βββββββββββ¬βββββββββββββββββββββββββββββββββββββ¬ββββββ¬βββββββ βName βType βCountβSize β βββββββββββΌβββββββββββββββββββββββββββββββββββββΌββββββΌβββββββ€ β[0] βint β1 β β βββββββββββΌβββββββββββββββββββββββββββββββββββββΌββββββΌβββββββ€ β[1][0] βint β1 β β βββββββββββΌβββββββββββββββββββββββββββββββββββββΌββββββΌβββββββ€ β[1][1][0]βint β1 β β βββββββββββΌβββββββββββββββββββββββββββββββββββββΌββββββΌβββββββ€ β[2] βi32[3] β3 β12.00Bβ βββββββββββΌβββββββββββββββββββββββββββββββββββββΌββββββΌβββββββ€ βΞ£ βlist[int,list[int,list[int]],i32[3]]β6 β12.00Bβ βββββββββββ΄βββββββββββββββββββββββββββββββββββββ΄ββββββ΄βββββββ
Example
Display flops of a function in tree summary
>>> import jax >>> import functools as ft >>> import sepes as sp >>> def count_flops(func, *args, **kwargs) -> int: ... cost_analysis = jax.jit(func).lower(*args, **kwargs).cost_analysis() ... return cost_analysis["flops"] if "flops" in cost_analysis else 0 >>> class Flops: ... def __init__(self, func, *args, **kwargs): ... self.func = ft.partial(func, *args, **kwargs) >>> @sp.tree_summary.def_count(Flops) ... def _(node: Flops) -> int: ... return count_flops(node.func) >>> @sp.tree_summary.def_type(Flops) ... def _(node: Flops) -> str: ... return f"Flops({sp.tree_repr(node.func.func)})" >>> tree = dict(a=1, b=Flops(jax.nn.relu, jax.numpy.ones((10, 1)))) >>> print(sp.tree_summary(tree)) βββββββ¬ββββββββββββββββββββ¬ββββββ¬βββββ βName βType βCountβSizeβ βββββββΌββββββββββββββββββββΌββββββΌβββββ€ β['a']βint β1 β β βββββββΌββββββββββββββββββββΌββββββΌβββββ€ β['b']βFlops(jit(relu(x)))β10.0 β β βββββββΌββββββββββββββββββββΌββββββΌβββββ€ βΞ£ βdict β11.0 β β βββββββ΄ββββββββββββββββββββ΄ββββββ΄βββββ
Example
Register custom type size rule
>>> import jax >>> import sepes as sp >>> def func(x): ... print(sp.tree_summary(x)) ... return x >>> class AbstractZero: ... >>> @sp.tree_summary.def_size(AbstractZero) ... def _(node: AbstractZero) -> int: ... return 0 >>> print(sp.tree_summary(AbstractZero())) ββββββ¬βββββββββββββ¬ββββββ¬βββββ βNameβType βCountβSizeβ ββββββΌβββββββββββββΌββββββΌβββββ€ βΞ£ βAbstractZeroβ1 β β ββββββ΄βββββββββββββ΄ββββββ΄βββββ