🏗️ Constructor utils API

sepes.field(*, default=NULL, init=True, repr=True, kind='POS_OR_KW', metadata=None, on_setattr=(), on_getattr=(), alias=None, doc='')

Field placeholder for type hinted attributes.

Parameters:
  • default – The default value of the field.

  • init – Whether the field is included in the object’s __init__ function.

  • repr – Whether the field is included in the object’s __repr__ function.

  • kind

    Argument kind used in the constructor synthesis with autoinit():

    • POS_ONLY: positional only argument (e.g. x in def f(x, /):)

    • VAR_POS: variable positional argument (e.g. *x in def f(*x):)

    • POS_OR_KW: positional or keyword argument (e.g. x in def f(x):)

    • KW_ONLY: keyword only argument (e.g. x in def f(*, x):)

    • VAR_KW: variable keyword argument (e.g. **x in def f(**x):)

    • CLASS_VAR: Non-constructor class variable (e.g. x in class C: x = 1)

  • metadata – A mapping of user-defined data for the field.

  • on_setattr – A sequence of functions to be called on __setattr__.

  • on_getattr – A sequence of functions to be called on __getattr__.

  • alias – An alias for the field name in the constructor. For example, for a field x declared with alias="y", obj = Class(y=1) is equivalent to obj = Class(x=1).

  • doc

    Extra documentation for the field(). The complete documentation of the field includes the field name, the field doc, the default value, and the function callbacks applied on the field value. It is mainly used to document the field callbacks.

    >>> import sepes as sp
    >>> @sp.autoinit
    ... class Tree:
    ...     leaf: int = sp.field(
    ...         default=1,
    ...         doc="Leaf node of the tree.",
    ...         on_setattr=[lambda x: x],
    ... )
    >>> print(Tree.leaf.__doc__)  # doctest: +SKIP
    Field Information:
            Name:           ``leaf``
            Default:        ``1``
    Description:
            Leaf node of the tree.
    Callbacks:
            - On setting attribute:
    
                    - ``<function Tree.<lambda> at 0x11c53dc60>``
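The kind and alias options map naturally onto the parameter kinds of Python's inspect module. The sketch below is not sepes's implementation: FieldSpec, KIND_MAP, and build_signature are hypothetical names, and ordering parameters by kind is an assumption made so the resulting signature is always valid. It only shows how such options could be turned into a constructor signature, with CLASS_VAR fields left out.

```python
import inspect
from typing import Any, NamedTuple, Optional

# Hypothetical mapping from the `kind` strings to inspect.Parameter kinds.
KIND_MAP = {
    "POS_ONLY": inspect.Parameter.POSITIONAL_ONLY,
    "VAR_POS": inspect.Parameter.VAR_POSITIONAL,
    "POS_OR_KW": inspect.Parameter.POSITIONAL_OR_KEYWORD,
    "KW_ONLY": inspect.Parameter.KEYWORD_ONLY,
    "VAR_KW": inspect.Parameter.VAR_KEYWORD,
}


class FieldSpec(NamedTuple):
    """Hypothetical stand-in for a declared field and its options."""

    name: str
    annotation: Any
    default: Any = inspect.Parameter.empty
    kind: str = "POS_OR_KW"
    alias: Optional[str] = None


def build_signature(specs):
    """Build an __init__ signature from field specs (illustrative only)."""
    params = []
    for spec in specs:
        if spec.kind == "CLASS_VAR":
            continue  # class variables are not constructor arguments
        params.append(
            inspect.Parameter(
                spec.alias or spec.name,  # the alias, if given, is the constructor name
                KIND_MAP[spec.kind],
                default=spec.default,
                annotation=spec.annotation,
            )
        )
    # Assumption: order parameters by kind so the signature is valid.
    params.sort(key=lambda p: p.kind)
    # `self` must be positional-only if any field is positional-only.
    has_pos_only = any(p.kind == inspect.Parameter.POSITIONAL_ONLY for p in params)
    self_kind = (
        inspect.Parameter.POSITIONAL_ONLY
        if has_pos_only
        else inspect.Parameter.POSITIONAL_OR_KEYWORD
    )
    params.insert(0, inspect.Parameter("self", self_kind))
    return inspect.Signature(params, return_annotation=None)


sig = build_signature(
    [
        FieldSpec("kw_only_field", int, default=1, kind="KW_ONLY"),
        FieldSpec("pos_only_field", int, default=2, kind="POS_ONLY"),
        FieldSpec("_name", str, default="Asem", alias="name"),
        FieldSpec("counter", int, default=0, kind="CLASS_VAR"),
    ]
)
print(sig)
```

With these specs the printed signature is (self, pos_only_field: int = 2, /, name: str = 'Asem', *, kw_only_field: int = 1) -> None: the aliased field appears as name, and the CLASS_VAR field is absent.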
    

Example

Type and range validation using on_setattr:

>>> import sepes as sp
>>> @sp.autoinit
... class IsInstance(sp.TreeClass):
...    klass: type
...    def __call__(self, x):
...        assert isinstance(x, self.klass)
...        return x

>>> @sp.autoinit
... class Range(sp.TreeClass):
...    start: int|float = float("-inf")
...    stop: int|float = float("inf")
...    def __call__(self, x):
...        assert self.start <= x <= self.stop
...        return x

>>> @sp.autoinit
... class Employee(sp.TreeClass):
...    # assert employee ``name`` is str
...    name: str = sp.field(on_setattr=[IsInstance(str)])
...    # use callback composition to assert employee ``age`` is int and positive
...    age: int = sp.field(on_setattr=[IsInstance(int), Range(1)])
>>> employee = Employee(name="Asem", age=10)
>>> print(employee)
Employee(name=Asem, age=10)

Example

Private attribute using alias:

>>> import sepes as sp
>>> @sp.autoinit
... class Employee(sp.TreeClass):
...    # `alias` is the name used in the constructor
...    _name: str = sp.field(alias="name")
>>> employee = Employee(name="Asem")  # use `name` in the constructor
>>> print(employee)  # `_name` is the private attribute name
Employee(_name=Asem)

Example

Buffer creation using on_getattr:

>>> import sepes as sp
>>> import jax
>>> import jax.numpy as jnp
>>> @sp.autoinit
... class Tree(sp.TreeClass):
...     buffer: jax.Array = sp.field(on_getattr=[jax.lax.stop_gradient])
>>> tree = Tree(buffer=jnp.array((1.0, 2.0)))
>>> def sum_buffer(tree):
...     return tree.buffer.sum()
>>> print(jax.grad(sum_buffer)(tree))  # no gradient on `buffer`
Tree(buffer=[0. 0.])

Example

Parameterization using on_getattr:

>>> import sepes as sp
>>> import jax
>>> import jax.numpy as jnp
>>> def symmetric(array: jax.Array) -> jax.Array:
...    triangle = jnp.triu(array)  # upper triangle
...    return triangle + triangle.transpose(-1, -2)
>>> @sp.autoinit
... class Tree(sp.TreeClass):
...    symmetric_matrix: jax.Array = sp.field(on_getattr=[symmetric])
>>> tree = Tree(symmetric_matrix=jnp.arange(9).reshape(3, 3))
>>> print(tree.symmetric_matrix)
[[ 0  1  2]
 [ 1  8  5]
 [ 2  5 16]]

Note

  • field() is commonly used to annotate class attributes that the autoinit() decorator then uses to generate the __init__ method, similar to dataclasses.dataclass.

  • field() can be used without autoinit() as a descriptor that applies functions to the field value on assignment (on_setattr) or on access (on_getattr).

    >>> import sepes as sp
    >>> def print_and_return(x):
    ...    print(f"Setting {x}")
    ...    return x
    >>> class Tree:
    ...    # `a` must be defined as a class attribute for the descriptor to work
    ...    a: int = sp.field(on_setattr=[print_and_return])
    ...    def __init__(self, a):
    ...        self.a = a
    >>> tree = Tree(1)
    Setting 1
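The descriptor behavior described in the note above can be approximated with a plain Python data descriptor. This is a minimal sketch, not sepes's descriptor: CallbackField and positive are hypothetical names, and running callbacks left to right is an assumption consistent with the callback-composition example earlier. on_setattr callbacks transform the value before it is stored, while on_getattr callbacks transform it on every read and leave the stored value untouched.

```python
class CallbackField:
    """Hypothetical descriptor that runs callbacks on assignment and on access."""

    def __init__(self, on_setattr=(), on_getattr=()):
        self.on_setattr = tuple(on_setattr)
        self.on_getattr = tuple(on_getattr)

    def __set_name__(self, owner, name):
        self.name = name  # attribute name the descriptor is bound to

    def __get__(self, instance, owner=None):
        if instance is None:
            return self  # accessed on the class, not an instance
        value = instance.__dict__[self.name]
        for func in self.on_getattr:  # applied on every read, left to right
            value = func(value)
        return value

    def __set__(self, instance, value):
        for func in self.on_setattr:  # applied on every write, left to right
            value = func(value)
        instance.__dict__[self.name] = value  # store the processed value


def positive(x):
    if x <= 0:
        raise ValueError(f"expected a positive value, got {x}")
    return x


class Tree:
    # convert to int, then validate; reads return the stored value doubled
    a = CallbackField(on_setattr=[int, positive], on_getattr=[lambda x: x * 2])

    def __init__(self, a):
        self.a = a


tree = Tree("3")
print(tree.a)           # 6: the stored 3 passed through the on_getattr callback
print(vars(tree)["a"])  # 3: the stored value is unchanged by reads
```

Because CallbackField defines __set__, it is a data descriptor and takes precedence over the instance __dict__, which is why the processed value can be stored under the same attribute name.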
    
sepes.fields(x)

Returns a tuple of Field objects for the given instance or class.

Field objects are generated from the class type hints and contain information about each field, including the options passed to sepes.field() when it is used to annotate the attribute.

Note

  • If the class is not annotated, an empty tuple is returned.

  • Field generation is cached for the class and its bases.

sepes.autoinit(klass)

A class decorator that generates the __init__ method from type hints.

With the autoinit decorator, the user defines the class attributes using type hints, and the __init__ method is generated automatically.

>>> import sepes as sp
>>> @sp.autoinit
... class Tree:
...     x: int
...     y: int

Is equivalent to:

>>> class Tree:
...     def __init__(self, x: int, y: int):
...         self.x = x
...         self.y = y

Example

>>> import sepes as sp
>>> import inspect
>>> @sp.autoinit
... class Tree:
...     x: int
...     y: int
>>> inspect.signature(Tree.__init__)
<Signature (self, x: int, y: int) -> None>
>>> tree = Tree(1, 2)
>>> tree.x, tree.y
(1, 2)

Example

Define fields with different argument kinds:

>>> import sepes as sp
>>> import inspect
>>> @sp.autoinit
... class Tree:
...     kw_only_field: int = sp.field(default=1, kind="KW_ONLY")
...     pos_only_field: int = sp.field(default=2, kind="POS_ONLY")
>>> inspect.signature(Tree.__init__)
<Signature (self, pos_only_field: int = 2, /, *, kw_only_field: int = 1) -> None>

Example

Define a converter to apply abs on the field value:

>>> import sepes as sp
>>> @sp.autoinit
... class Tree:
...     a: int = sp.field(on_setattr=[abs])
>>> Tree(a=-1).a
1

Warning

The autoinit decorator raises TypeError if the decorated class defines its own __init__ method.

Note

  • In case of inheritance, the __init__ method is generated from the type hints of the current class and of any base classes that are decorated with autoinit.

>>> import sepes as sp
>>> import inspect
>>> @sp.autoinit
... class Base:
...     x: int
>>> @sp.autoinit
... class Derived(Base):
...     y: int
>>> obj = Derived(x=1, y=2)
>>> inspect.signature(obj.__init__)
<Signature (x: int, y: int) -> None>

  • Base classes that are not decorated with autoinit are ignored during synthesis of the __init__ method.

>>> import sepes as sp
>>> import inspect
>>> class Base:
...     x: int
>>> @sp.autoinit
... class Derived(Base):
...     y: int
>>> obj = Derived(y=2)
>>> inspect.signature(obj.__init__)
<Signature (y: int) -> None>

Note

Use autoinit instead of dataclasses.dataclass if you want to use a jax.Array as a field default value: starting from Python 3.11, dataclasses.dataclass incorrectly raises an error complaining that jax.Array is mutable.
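The Python 3.11 change behind this note is that dataclasses began treating any default whose class is unhashable (its __hash__ is None) as mutable, instead of only rejecting list, dict, and set. The sketch below reproduces the check with a hypothetical Unhashable stand-in so it runs without JAX; it assumes the array type in question is unhashable in the same way.

```python
import dataclasses
import sys


class Unhashable:
    # Defining __eq__ without __hash__ sets __hash__ to None, making
    # instances unhashable (a stand-in for an unhashable array type).
    def __eq__(self, other):
        return self is other


def dataclass_accepts_default():
    try:
        @dataclasses.dataclass
        class Tree:
            x: Unhashable = Unhashable()
    except ValueError as err:
        print(f"ValueError: {err}")
        return False
    return True


accepted = dataclass_accepts_default()
print(sys.version_info >= (3, 11), accepted)  # on Python 3.11+: True False
```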

Note

By default, autoinit raises an error if a field uses a mutable default. To register an additional type to be excluded as a default value, use autoinit.register_excluded_type(), optionally with a reason for excluding the type.

>>> import sepes as sp
>>> class T:
...     pass
>>> sp.autoinit.register_excluded_type(T, reason="not allowed")
>>> @sp.autoinit
... class Tree:
...     x: T = sp.field(default=T())  
Traceback (most recent call last):
    ...

sepes.leafwise(klass)

A class decorator that adds leafwise operators to a class.

Leafwise operators are operators that are applied to the leaves of a pytree. For example, leafwise __add__ is equivalent to:

  • tree_map(lambda x: x + rhs, tree) if rhs is a scalar.

  • tree_map(lambda x, y: x + y, tree, rhs) if rhs is a pytree with the same structure as tree.

Parameters:

klass – The class to be decorated.

Returns:

The decorated class.

Example

Use jax.numpy functions on TreeClass classes decorated with leafwise():

>>> import sepes as sp
>>> import jax.numpy as jnp
>>> @sp.leafwise
... @sp.autoinit
... class Point(sp.TreeClass):
...    x: float = 0.5
...    y: float = 1.0
...    description: str = "point coordinates"
>>> # use tree_mask() to mask the non-inexact part of the tree,
>>> # i.e. mask the string leaf ``description`` so that ``Point`` works
>>> # with ``jax.numpy`` functions
>>> co = sp.tree_mask(Point())
>>> print(sp.bcmap(jnp.where)(co > 0.5, co, 1000))
Point(x=1000.0, y=1.0, description=#point coordinates)

Note

If a mathematically equivalent operator is already defined on the class, then it is not overridden.

Method          Operator
__add__         +
__and__         &
__ceil__        math.ceil
__divmod__      divmod
__eq__          ==
__floor__       math.floor
__floordiv__    //
__ge__          >=
__gt__          >
__invert__      ~
__le__          <=
__lshift__      <<
__lt__          <
__matmul__      @
__mod__         %
__mul__         *
__ne__          !=
__neg__         - (unary)
__or__          |
__pos__         + (unary)
__pow__         **
__round__       round
__sub__         -
__truediv__     /
__trunc__       math.trunc
__xor__         ^
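A table like the one above lends itself to a table-driven decorator. The sketch below is hypothetical (leafwise_sketch, LEAFWISE_OPS, and the helpers are illustrative names), covers only a subset of the table, and works on dataclasses whose fields are the leaves. It simplifies the "not overridden" rule to skipping any method name the class defines itself, so a user-defined __sub__ survives.

```python
import dataclasses
import operator

# Hypothetical subset of the table: method name -> (function, arity).
LEAFWISE_OPS = {
    "__add__": (operator.add, 2),
    "__mul__": (operator.mul, 2),
    "__sub__": (operator.sub, 2),
    "__gt__": (operator.gt, 2),
    "__neg__": (operator.neg, 1),
}


def _leaves(obj):
    # treat each dataclass field as a leaf
    return {f.name: getattr(obj, f.name) for f in dataclasses.fields(obj)}


def _make_method(func, arity):
    if arity == 1:
        def method(self):
            return dataclasses.replace(self, **{k: func(v) for k, v in _leaves(self).items()})
    else:
        def method(self, rhs):
            if isinstance(rhs, type(self)):  # same structure: pair up leaves
                other = _leaves(rhs)
                return dataclasses.replace(
                    self, **{k: func(v, other[k]) for k, v in _leaves(self).items()}
                )
            # otherwise broadcast rhs as a scalar to every leaf
            return dataclasses.replace(self, **{k: func(v, rhs) for k, v in _leaves(self).items()})
    return method


def leafwise_sketch(klass):
    """Hypothetical: add leafwise operators, keeping ones the class defines."""
    for name, (func, arity) in LEAFWISE_OPS.items():
        if name in vars(klass):
            continue  # an operator defined on the class is not overridden
        setattr(klass, name, _make_method(func, arity))
    return klass


@leafwise_sketch
@dataclasses.dataclass
class Point:
    x: float = 0.5
    y: float = 1.0

    def __sub__(self, rhs):  # user-defined: kept as-is by leafwise_sketch
        return "custom sub"


p = Point()
print(p + 1)    # Point(x=1.5, y=2.0)
print(p * p)    # Point(x=0.25, y=1.0)
print(-p)       # Point(x=-0.5, y=-1.0)
print(p > 0.5)  # Point(x=False, y=True)
print(p - 1)    # custom sub
```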