coordax.tag
- coordax.tag(tree: PyTree, *dims: str | Coordinate | ellipsis | None) → PyTree
Tag the positional dimensions of every field in a PyTree with the same dims.
- Parameters:
tree – The PyTree of fields.
*dims – Names or coordinates to tag the positional axes with.
- Returns:
A new PyTree with all fields tagged.
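The idea of applying one tagging to every field leaf while keeping the container structure intact can be illustrated with a minimal pure-Python sketch. This is not coordax's implementation. `ToyField`, `tag_leaf`, and `tag_tree` are hypothetical names; the sketch treats only dicts, lists, and tuples as containers, does not model `Coordinate` or `ellipsis` arguments, and, as an assumption of this sketch, passes non-field leaves through unchanged.

```python
from dataclasses import dataclass, replace
from typing import Any, Optional, Tuple


@dataclass(frozen=True)
class ToyField:
    """Hypothetical stand-in for a field: a shape plus one name (or None) per axis."""
    shape: Tuple[int, ...]
    dims: Tuple[Optional[str], ...]


def tag_leaf(field: ToyField, *dims: Optional[str]) -> ToyField:
    # Each positional axis receives exactly one name (None leaves it untagged).
    if len(dims) != len(field.shape):
        raise ValueError(f"expected {len(field.shape)} dims, got {len(dims)}")
    return replace(field, dims=tuple(dims))  # returns a new field; input is untouched


def tag_tree(tree: Any, *dims: Optional[str]) -> Any:
    # Rebuild the container with the same structure, tagging every ToyField leaf.
    if isinstance(tree, ToyField):
        return tag_leaf(tree, *dims)
    if isinstance(tree, dict):
        return {k: tag_tree(v, *dims) for k, v in tree.items()}
    if isinstance(tree, list):
        return [tag_tree(v, *dims) for v in tree]
    if type(tree) is tuple:
        return tuple(tag_tree(v, *dims) for v in tree)
    return tree  # assumption of this sketch: non-field leaves pass through


tree = {'a': ToyField(shape=(2,), dims=(None,)),
        'b': [ToyField(shape=(3,), dims=(None,))]}
tagged = tag_tree(tree, 'x')
print(tagged['a'].dims)     # ('x',)
print(tagged['b'][0].dims)  # ('x',)
print(tree['a'].dims)       # (None,): the original tree is unchanged
```

Because a single `dims` tuple is applied to every leaf, this sketch requires all fields in the tree to have the same number of positional axes; a mismatch raises `ValueError`.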
Examples
>>> import coordax as cx
>>> import jax.numpy as jnp
>>> tree = {'a': cx.Field(jnp.zeros((2,)))}
>>> tree
{'a': <Field dims=(None,) shape=(2,) axes={} >}
>>> cx.tag(tree, 'x')
{'a': <Field dims=('x',) shape=(2,) axes={} >}
See also