🎨 Pretty printing API#

sepes.tree_diagram(tree, *, depth=inf, is_leaf=None, tabwidth=4)[source]#

Pretty print arbitrary pytrees tree with tree structure diagram.

Parameters:
  • tree (Any) – arbitrary pytree.

  • depth (int | float) – depth of the tree to print. default is max depth.

  • is_leaf (Callable[[Any], None] | None) – function to determine if a node is a leaf. default is None.

  • tabwidth (int) – tab width of the repr string. default is 4.

Example

>>> import sepes as sp
>>> @sp.autoinit
... class A(sp.TreeClass):
...     x: int = 10
...     y: int = (20,30)
...     z: int = 40
>>> @sp.autoinit
... class B(sp.TreeClass):
...     a: int = 10
...     b: tuple = (20,30, A())
>>> print(sp.tree_diagram(B(), depth=0))
B(...)
>>> print(sp.tree_diagram(B(), depth=1))
B
β”œβ”€β”€ .a=10
└── .b=(...)
>>> print(sp.tree_diagram(B(), depth=2))
B
β”œβ”€β”€ .a=10
└── .b:tuple
    β”œβ”€β”€ [0]=20
    β”œβ”€β”€ [1]=30
    └── [2]=A(...)
sepes.tree_repr(tree, *, width=80, tabwidth=2, depth=inf)[source]#
Parameters:
  • tree (PyTree) –

  • width (int) –

  • tabwidth (int) –

  • depth (int | float) –

sepes.tree_str(tree, *, width=80, tabwidth=2, depth=inf)[source]#
Parameters:
  • tree (PyTree) –

  • width (int) –

  • tabwidth (int) –

  • depth (int | float) –

sepes.tree_summary(tree, *, depth=inf, is_leaf=None)[source]#

Print a summary of an arbitrary pytree.

Parameters:
  • tree (PyTree) – A pytree.

  • depth (int | float) – max depth to display the tree. Defaults to maximum depth.

  • is_leaf (Callable[[Any], None] | None) – function to determine if a node is a leaf. Defaults to None

Returns:

  • First column: path to the node.

  • Second column: type of the node. to control the displayed type use tree_summary.def_type(type, func) to define a custom type display function.

  • Third column: number of leaves in the node. for arrays the number of leaves is the number of elements in the array, otherwise its 1. to control the number of leaves of a node use tree_summary.def_count(type,func)

  • Fourth column: size of the node in bytes. if the node is array the size is the size of the array in bytes, otherwise the size is not displayed. to control the size of a node use tree_summary.def_size(type,func)

  • Last row: type of parent, number of leaves of the parent

Return type:

String summary of the tree structure

Example

>>> import sepes as sp
>>> import jax.numpy as jnp
>>> print(sp.tree_summary([1, [2, [3]], jnp.array([1, 2, 3])]))
β”Œβ”€β”€β”€β”€β”€β”€β”€β”€β”€β”¬β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”¬β”€β”€β”€β”€β”€β”¬β”€β”€β”€β”€β”€β”€β”
β”‚Name     β”‚Type                                β”‚Countβ”‚Size  β”‚
β”œβ”€β”€β”€β”€β”€β”€β”€β”€β”€β”Όβ”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”Όβ”€β”€β”€β”€β”€β”Όβ”€β”€β”€β”€β”€β”€β”€
β”‚[0]      β”‚int                                 β”‚1    β”‚      β”‚
β”œβ”€β”€β”€β”€β”€β”€β”€β”€β”€β”Όβ”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”Όβ”€β”€β”€β”€β”€β”Όβ”€β”€β”€β”€β”€β”€β”€
β”‚[1][0]   β”‚int                                 β”‚1    β”‚      β”‚
β”œβ”€β”€β”€β”€β”€β”€β”€β”€β”€β”Όβ”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”Όβ”€β”€β”€β”€β”€β”Όβ”€β”€β”€β”€β”€β”€β”€
β”‚[1][1][0]β”‚int                                 β”‚1    β”‚      β”‚
β”œβ”€β”€β”€β”€β”€β”€β”€β”€β”€β”Όβ”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”Όβ”€β”€β”€β”€β”€β”Όβ”€β”€β”€β”€β”€β”€β”€
β”‚[2]      β”‚i32[3]                              β”‚3    β”‚12.00Bβ”‚
β”œβ”€β”€β”€β”€β”€β”€β”€β”€β”€β”Όβ”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”Όβ”€β”€β”€β”€β”€β”Όβ”€β”€β”€β”€β”€β”€β”€
β”‚Ξ£        β”‚list[int,list[int,list[int]],i32[3]]β”‚6    β”‚12.00Bβ”‚
β””β”€β”€β”€β”€β”€β”€β”€β”€β”€β”΄β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”΄β”€β”€β”€β”€β”€β”΄β”€β”€β”€β”€β”€β”€β”˜

Example

Display flops of a function in tree summary

>>> import jax
>>> import functools as ft
>>> import sepes as sp
>>> def count_flops(func, *args, **kwargs) -> int:
...     cost_analysis = jax.jit(func).lower(*args, **kwargs).cost_analysis()
...     return cost_analysis["flops"] if "flops" in cost_analysis else 0
>>> class Flops:
...     def __init__(self, func, *args, **kwargs):
...         self.func = ft.partial(func, *args, **kwargs)
>>> @sp.tree_summary.def_count(Flops)
... def _(node: Flops) -> int:
...     return count_flops(node.func)
>>> @sp.tree_summary.def_type(Flops)
... def _(node: Flops) -> str:
...     return f"Flops({sp.tree_repr(node.func.func)})"
>>> tree = dict(a=1, b=Flops(jax.nn.relu, jax.numpy.ones((10, 1))))
>>> print(sp.tree_summary(tree))
β”Œβ”€β”€β”€β”€β”€β”¬β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”¬β”€β”€β”€β”€β”€β”¬β”€β”€β”€β”€β”
β”‚Name β”‚Type               β”‚Countβ”‚Sizeβ”‚
β”œβ”€β”€β”€β”€β”€β”Όβ”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”Όβ”€β”€β”€β”€β”€β”Όβ”€β”€β”€β”€β”€
β”‚['a']β”‚int                β”‚1    β”‚    β”‚
β”œβ”€β”€β”€β”€β”€β”Όβ”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”Όβ”€β”€β”€β”€β”€β”Όβ”€β”€β”€β”€β”€
β”‚['b']β”‚Flops(jit(relu(x)))β”‚10.0 β”‚    β”‚
β”œβ”€β”€β”€β”€β”€β”Όβ”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”Όβ”€β”€β”€β”€β”€β”Όβ”€β”€β”€β”€β”€
β”‚Ξ£    β”‚dict               β”‚11.0 β”‚    β”‚
β””β”€β”€β”€β”€β”€β”΄β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”΄β”€β”€β”€β”€β”€β”΄β”€β”€β”€β”€β”˜

Example

Register custom type size rule

>>> import jax
>>> import sepes as sp
>>> def func(x):
...     print(sp.tree_summary(x))
...     return x
>>> class AbstractZero: ...
>>> @sp.tree_summary.def_size(AbstractZero)
... def _(node: AbstractZero) -> int:
...     return 0
>>> print(sp.tree_summary(AbstractZero()))
β”Œβ”€β”€β”€β”€β”¬β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”¬β”€β”€β”€β”€β”€β”¬β”€β”€β”€β”€β”
β”‚Nameβ”‚Type        β”‚Countβ”‚Sizeβ”‚
β”œβ”€β”€β”€β”€β”Όβ”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”Όβ”€β”€β”€β”€β”€β”Όβ”€β”€β”€β”€β”€
β”‚Ξ£   β”‚AbstractZeroβ”‚1    β”‚    β”‚
β””β”€β”€β”€β”€β”΄β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”΄β”€β”€β”€β”€β”€β”΄β”€β”€β”€β”€β”˜