🏗️ Constructor utils API
- sepes.field(*, default=NULL, init=True, repr=True, kind='POS_OR_KW', metadata=None, on_setattr=(), on_getattr=(), alias=None, doc='')
Field placeholder for type hinted attributes.
- Parameters:
default – The default value of the field.
init – Whether the field is included in the object’s __init__ function.
repr – Whether the field is included in the object’s __repr__ function.
kind – Argument kind used in the constructor synthesis with autoinit():
  POS_ONLY: positional-only argument (e.g. x in def f(x, /):)
  VAR_POS: variable positional argument (e.g. *x in def f(*x):)
  POS_OR_KW: positional or keyword argument (e.g. x in def f(x):)
  KW_ONLY: keyword-only argument (e.g. x in def f(*, x):)
  VAR_KW: variable keyword argument (e.g. **x in def f(**x):)
  CLASS_VAR: non-constructor class variable (e.g. x in class C: x = 1)
metadata – A mapping of user-defined data for the field.
on_setattr – A sequence of functions to be called on __setattr__.
on_getattr – A sequence of functions to be called on __getattr__.
alias – An alias for the field name in the constructor, e.g. name=x, alias=y will allow obj = Class(y=1) to be equivalent to obj = Class(x=1).
doc – Extra documentation for the field. The complete documentation of the field includes the field name, the field doc, the default value, and the function callbacks applied on the field value. Mainly used for documenting the field callbacks.
>>> import sepes as sp
>>> @sp.autoinit
... class Tree:
...     leaf: int = sp.field(
...         default=1,
...         doc="Leaf node of the tree.",
...         on_setattr=[lambda x: x],
...     )
>>> print(Tree.leaf.__doc__)
Field Information:
    Name:       ``leaf``
    Default:    ``1``
Description:
    Leaf node of the tree.
Callbacks:
    - On setting attribute:
        - ``<function Tree.<lambda> at 0x11c53dc60>``
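The kind values listed above appear to mirror the parameter kinds that Python itself distinguishes in a function signature. As a rough illustration using only the standard library's inspect module (the function f below is hypothetical and is not part of sepes), each kind can be read back from a signature:

```python
import inspect


# Hypothetical function that uses all five argument kinds at once.
def f(pos_only, /, pos_or_kw, *var_pos, kw_only, **var_kw):
    return None


# Map every parameter name to the name of its kind as reported by inspect.
kinds = {name: p.kind.name for name, p in inspect.signature(f).parameters.items()}
for name, kind in kinds.items():
    print(f"{name}: {kind}")
```

Under this reading, POS_ONLY corresponds to POSITIONAL_ONLY, VAR_POS to VAR_POSITIONAL, and so on. CLASS_VAR has no counterpart in inspect, since it describes an attribute that stays out of the constructor.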
Example
Type and range validation using on_setattr:
>>> import sepes as sp
>>> @sp.autoinit
... class IsInstance(sp.TreeClass):
...     klass: type
...     def __call__(self, x):
...         assert isinstance(x, self.klass)
...         return x
>>> @sp.autoinit
... class Range(sp.TreeClass):
...     start: int|float = float("-inf")
...     stop: int|float = float("inf")
...     def __call__(self, x):
...         assert self.start <= x <= self.stop
...         return x
>>> @sp.autoinit
... class Employee(sp.TreeClass):
...     # assert employee ``name`` is str
...     name: str = sp.field(on_setattr=[IsInstance(str)])
...     # use callback composition to assert employee ``age`` is int and positive
...     age: int = sp.field(on_setattr=[IsInstance(int), Range(1)])
>>> employee = Employee(name="Asem", age=10)
>>> print(employee)
Employee(name=Asem, age=10)
Example
Private attribute using alias:
>>> import sepes as sp
>>> @sp.autoinit
... class Employee(sp.TreeClass):
...     # `alias` is the name used in the constructor
...     _name: str = sp.field(alias="name")
>>> employee = Employee(name="Asem")  # use `name` in the constructor
>>> print(employee)  # `_name` is the private attribute name
Employee(_name=Asem)
Example
Buffer creation using on_getattr:
>>> import sepes as sp
>>> import jax
>>> import jax.numpy as jnp
>>> @sp.autoinit
... class Tree(sp.TreeClass):
...     buffer: jax.Array = sp.field(on_getattr=[jax.lax.stop_gradient])
>>> tree = Tree(buffer=jnp.array((1.0, 2.0)))
>>> def sum_buffer(tree):
...     return tree.buffer.sum()
>>> print(jax.grad(sum_buffer)(tree))  # no gradient on `buffer`
Tree(buffer=[0. 0.])
Example
Parameterization using on_getattr:
>>> import sepes as sp
>>> import jax
>>> import jax.numpy as jnp
>>> def symmetric(array: jax.Array) -> jax.Array:
...     triangle = jnp.triu(array)  # upper triangle
...     return triangle + triangle.transpose(-1, -2)
>>> @sp.autoinit
... class Tree(sp.TreeClass):
...     symmetric_matrix: jax.Array = sp.field(on_getattr=[symmetric])
>>> tree = Tree(symmetric_matrix=jnp.arange(9).reshape(3, 3))
>>> print(tree.symmetric_matrix)
[[ 0  1  2]
 [ 1  8  5]
 [ 2  5 16]]
Note
field() is commonly used to annotate the class attributes to be used by the autoinit() decorator to generate the __init__ method, similar to dataclasses.dataclass.
field() can be used without autoinit() as a descriptor to apply functions on the field values during initialization using the on_setattr / on_getattr arguments.
>>> import sepes as sp
>>> def print_and_return(x):
...     print(f"Setting {x}")
...     return x
>>> class Tree:
...     # `a` must be defined as a class attribute for the descriptor to work
...     a: int = sp.field(on_setattr=[print_and_return])
...     def __init__(self, a):
...         self.a = a
>>> tree = Tree(1)
Setting 1
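The descriptor behavior described in this note can be approximated in plain Python. The sketch below is only an illustration of the general descriptor technique; CallbackField is a hypothetical class, not the sepes implementation, and it handles only on_setattr-style callbacks:

```python
class CallbackField:
    """Hypothetical descriptor that pipes assigned values through callbacks."""

    def __init__(self, on_setattr=()):
        self.on_setattr = tuple(on_setattr)

    def __set_name__(self, owner, name):
        # Remember the attribute name this descriptor was bound to.
        self.name = name

    def __get__(self, instance, owner=None):
        if instance is None:
            return self  # accessed on the class itself
        return vars(instance)[self.name]

    def __set__(self, instance, value):
        # Apply the callbacks in order, then store the final value.
        for func in self.on_setattr:
            value = func(value)
        vars(instance)[self.name] = value


class Tree:
    # Must be a class attribute so that assignment goes through __set__.
    a = CallbackField(on_setattr=[int, abs])

    def __init__(self, a):
        self.a = a  # triggers CallbackField.__set__


tree = Tree("-3")
print(tree.a)  # 3
```

Because the descriptor defines __set__, it takes precedence over the instance dictionary, which is why every assignment to the attribute passes through the callbacks.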
- sepes.fields(x)
Returns a tuple of Field objects for the given instance or class. Field objects are generated from the class type hints and contain the field information supplied when the user annotates an attribute with sepes.field.
Note
If the class is not annotated, an empty tuple is returned.
The Field generation is cached for the class and its bases.
- sepes.autoinit(klass)
A class decorator that generates the __init__ method from type hints.
Using the autoinit decorator, the user can define the class attributes using type hints and the __init__ method will be generated automatically.
>>> import sepes as sp
>>> @sp.autoinit
... class Tree:
...     x: int
...     y: int
Is equivalent to:
>>> class Tree:
...     def __init__(self, x: int, y: int):
...         self.x = x
...         self.y = y
Example
>>> import sepes as sp
>>> import inspect
>>> @sp.autoinit
... class Tree:
...     x: int
...     y: int
>>> inspect.signature(Tree.__init__)
<Signature (self, x: int, y: int) -> None>
>>> tree = Tree(1, 2)
>>> tree.x, tree.y
(1, 2)
Example
Define fields with different argument kinds
>>> import sepes as sp
>>> import inspect
>>> @sp.autoinit
... class Tree:
...     kw_only_field: int = sp.field(default=1, kind="KW_ONLY")
...     pos_only_field: int = sp.field(default=2, kind="POS_ONLY")
>>> inspect.signature(Tree.__init__)
<Signature (self, pos_only_field: int = 2, /, *, kw_only_field: int = 1) -> None>
Example
Define a converter to apply abs on the field value
>>> import sepes as sp
>>> @sp.autoinit
... class Tree:
...     a: int = sp.field(on_setattr=[abs])
>>> Tree(a=-1).a
1
Warning
The autoinit decorator will raise TypeError if the user defines an __init__ method in the decorated class.
Note
In case of inheritance, the __init__ method is generated from the type hints of the current class and any base classes that are decorated with autoinit.
>>> import sepes as sp
>>> import inspect
>>> @sp.autoinit
... class Base:
...     x: int
>>> @sp.autoinit
... class Derived(Base):
...     y: int
>>> obj = Derived(x=1, y=2)
>>> inspect.signature(obj.__init__)
<Signature (x: int, y: int) -> None>
Base classes that are not decorated with autoinit are ignored during synthesis of the __init__ method.
>>> import sepes as sp
>>> import inspect
>>> class Base:
...     x: int
>>> @sp.autoinit
... class Derived(Base):
...     y: int
>>> obj = Derived(y=2)
>>> inspect.signature(obj.__init__)
<Signature (y: int) -> None>
Note
Use autoinit instead of dataclasses.dataclass if you want to use jax.Array as a field default value, because dataclasses.dataclass will incorrectly raise an error starting from Python 3.11, complaining that jax.Array is not immutable.
Note
By default, autoinit will raise an error if the user uses mutable defaults. To register an additional type to be excluded from autoinit, use autoinit.register_excluded_type(), with an optional reason for excluding the type.
>>> import sepes as sp
>>> class T:
...     pass
>>> sp.autoinit.register_excluded_type(T, reason="not allowed")
>>> @sp.autoinit
... class Tree:
...     x: T = sp.field(default=T())
Traceback (most recent call last):
    ...
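The motivation for rejecting mutable defaults is the same one behind the standard library's dataclasses check: a default object would be shared by every instance. The sketch below shows the stdlib behavior only, not autoinit's; the Bad and Good class names are hypothetical:

```python
import dataclasses

try:

    @dataclasses.dataclass
    class Bad:
        items: list = []  # shared mutable default, rejected at class creation

except ValueError as error:
    message = str(error)
    print(message)


@dataclasses.dataclass
class Good:
    # default_factory builds a fresh list for every instance.
    items: list = dataclasses.field(default_factory=list)


a, b = Good(), Good()
a.items.append(1)
print(b.items)  # []
```

From Python 3.11 onward, dataclasses rejects any unhashable default rather than only list, dict, and set, which is the change the note above refers to.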
- sepes.leafwise(klass)
A class decorator that adds leafwise operators to a class.
Leafwise operators are operators that are applied to the leaves of a pytree. For example, leafwise __add__ is equivalent to:
- tree_map(lambda x: x + rhs, tree) if rhs is a scalar.
- tree_map(lambda x, y: x + y, tree, rhs) if rhs is a pytree with the same structure as tree.
- Parameters:
klass – The class to be decorated.
- Returns:
The decorated class.
Example
Use numpy functions on TreeClass classes decorated with leafwise()
>>> import sepes as sp
>>> import jax.numpy as jnp
>>> @sp.leafwise
... @sp.autoinit
... class Point(sp.TreeClass):
...     x: float = 0.5
...     y: float = 1.0
...     description: str = "point coordinates"
>>> # use :func:`tree_mask` to mask the non-inexact part of the tree
>>> # i.e. mask the string leaf ``description`` so that ``Point`` works
>>> # with ``jax.numpy`` functions
>>> co = sp.tree_mask(Point())
>>> print(sp.bcmap(jnp.where)(co > 0.5, co, 1000))
Point(x=1000.0, y=1.0, description=#point coordinates)
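The two tree_map forms given in the description of leafwise __add__ can be sketched in plain Python. The tree_map below is a hypothetical, minimal version that only understands nested dicts, lists, and tuples; it stands in for a real pytree library and is not the function sepes uses:

```python
def tree_map(func, tree, *rest):
    """Hypothetical minimal tree_map over nested dicts, lists, and tuples."""
    if isinstance(tree, dict):
        return {k: tree_map(func, v, *(r[k] for r in rest)) for k, v in tree.items()}
    if isinstance(tree, (list, tuple)):
        mapped = (tree_map(func, *items) for items in zip(tree, *rest))
        return type(tree)(mapped)
    return func(tree, *rest)  # reached a leaf


tree = {"a": 1, "b": [2, 3]}

# rhs is a scalar: it is closed over and added to every leaf.
scalar_sum = tree_map(lambda x: x + 10, tree)

# rhs is a tree with the same structure: matching leaves are paired up.
rhs = {"a": 100, "b": [200, 300]}
tree_sum = tree_map(lambda x, y: x + y, tree, rhs)

print(scalar_sum)  # {'a': 11, 'b': [12, 13]}
print(tree_sum)  # {'a': 101, 'b': [202, 303]}
```

A leafwise operator then amounts to choosing between these two calls depending on whether the right-hand side is a scalar or a tree.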
Note
If a mathematically equivalent operator is already defined on the class, then it is not overridden.
Method          Operator
__add__         +
__and__         &
__ceil__        math.ceil
__divmod__      divmod
__eq__          ==
__floor__       math.floor
__floordiv__    //
__ge__          >=
__gt__          >
__invert__      ~
__le__          <=
__lshift__      <<
__lt__          <
__matmul__      @
__mod__         %
__mul__         *
__ne__          !=
__neg__         -
__or__          |
__pos__         +
__pow__         **
__round__       round
__sub__         -
__truediv__     /
__trunc__       math.trunc
__xor__         ^