coordax.tag

coordax.tag(tree: PyTree, *dims: str | Coordinate | ellipsis | None) → PyTree

Tag dimensions on all fields in a PyTree.

Parameters:
  • tree – The PyTree of fields.

  • *dims – Names or coordinates to tag the positional axes with.

Returns:
  A new PyTree with all fields tagged.

Examples

>>> import coordax as cx
>>> import jax.numpy as jnp
>>> tree = {'a': cx.Field(jnp.zeros((2,)))}
>>> tree
{'a': <Field dims=(None,) shape=(2,) axes={} >}
>>> cx.tag(tree, 'x')
{'a': <Field dims=('x',) shape=(2,) axes={} >}
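The following toy sketch illustrates what "tag dimensions on all fields in a PyTree" means: the same dimension names are applied to every field leaf, and a new tree is returned. It is written in plain Python so that it runs without coordax or JAX. `ToyField` and `toy_tag` are hypothetical names and are not part of coordax. The sketch assumes three things: names are assigned to a field's untagged (`None`) axes in positional order, every field receives the same names, and non-field leaves pass through unchanged. It only accepts string names, not coordinates or `...`, and the real `coordax.tag` may differ in any of these details.

```python
from dataclasses import dataclass
from typing import Any, Optional, Tuple


@dataclass(frozen=True)
class ToyField:
    """Hypothetical stand-in for a field: a shape plus per-axis dimension names."""
    shape: Tuple[int, ...]
    dims: Tuple[Optional[str], ...]

    def tag(self, *names: str) -> "ToyField":
        # Assumption: names fill the untagged (None) axes in positional order.
        untagged = [i for i, d in enumerate(self.dims) if d is None]
        if len(names) != len(untagged):
            raise ValueError(f"expected {len(untagged)} names, got {len(names)}")
        new_dims = list(self.dims)
        for i, name in zip(untagged, names):
            new_dims[i] = name
        # Return a new field; the original is left untouched.
        return ToyField(self.shape, tuple(new_dims))


def toy_tag(tree: Any, *names: str) -> Any:
    """Return a new tree in which every ToyField leaf is tagged with `names`."""
    if isinstance(tree, ToyField):
        return tree.tag(*names)
    if isinstance(tree, dict):
        return {k: toy_tag(v, *names) for k, v in tree.items()}
    if isinstance(tree, list):
        return [toy_tag(v, *names) for v in tree]
    if isinstance(tree, tuple):
        return tuple(toy_tag(v, *names) for v in tree)
    # Assumption: leaves that are not fields are returned unchanged.
    return tree


tree = {'a': ToyField((2,), (None,)), 'b': [ToyField((3,), (None,))]}
tagged = toy_tag(tree, 'x')
print(tagged['a'].dims, tagged['b'][0].dims)  # ('x',) ('x',)
```

Because each leaf's `tag` returns a new object, the input tree keeps its untagged fields. This mirrors the "returns a new PyTree" contract described above.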