coordax.untag

coordax.untag(tree: Any, *dims: str | Coordinate, allow_missing: bool = False) -> Any

Untag the specified dimensions from all fields in a PyTree.

Parameters:
  • tree – The PyTree of fields.

  • *dims – The dimensions to untag, given as names or Coordinate objects.

  • allow_missing – If True, each field is untagged only along those of dims that it actually has; dims absent from a field are skipped for that field.

Returns:

A new PyTree in which the specified dims are untagged on every field.

Examples

>>> import coordax as cx
>>> import jax.numpy as jnp
>>> tree = {'a': cx.field(jnp.zeros((2,)), 'x')}
>>> tree
{'a': <Field dims=('x',) shape=(2,) axes={} >}
>>> cx.untag(tree, 'x')
{'a': <Field dims=(None,) shape=(2,) axes={} >}
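The pure-Python model below is a toy sketch of the behavior described above, not coordax's implementation. `ToyField` and `toy_untag` are hypothetical names; a field is reduced to a shape plus a tuple of dimension names, where `None` marks an untagged axis (matching the `dims=(None,)` output in the example). It only handles string dimension names, not `Coordinate` objects, and it assumes that with `allow_missing=False` a field lacking a requested dimension is rejected with a `ValueError`; the exact error coordax raises is not documented here.

```python
from dataclasses import dataclass
from typing import Any, Optional, Tuple


# Hypothetical stand-in for a coordax Field: a shape plus one name per axis.
# A name of None means the axis is untagged.
@dataclass(frozen=True)
class ToyField:
    dims: Tuple[Optional[str], ...]
    shape: Tuple[int, ...]


def toy_untag(tree: Any, *dims: str, allow_missing: bool = False) -> Any:
    """Replace the named dims with None on every ToyField in a nested tree."""
    if isinstance(tree, ToyField):
        missing = [d for d in dims if d not in tree.dims]
        if missing and not allow_missing:
            # Assumption: strict mode rejects fields that lack a requested dim.
            raise ValueError(f"dims {missing} not found on field with dims {tree.dims}")
        # Untag only the requested dims; every other axis keeps its name.
        new_dims = tuple(None if d in dims else d for d in tree.dims)
        return ToyField(dims=new_dims, shape=tree.shape)
    if isinstance(tree, dict):
        return {k: toy_untag(v, *dims, allow_missing=allow_missing) for k, v in tree.items()}
    if isinstance(tree, (list, tuple)):
        return type(tree)(toy_untag(v, *dims, allow_missing=allow_missing) for v in tree)
    return tree  # Non-field leaves pass through unchanged.


tree = {'a': ToyField(('x',), (2,)), 'b': ToyField(('x', 'y'), (2, 3))}

# 'x' is present on both fields, so it is untagged everywhere.
print(toy_untag(tree, 'x'))

# 'y' exists only on 'b'; allow_missing=True leaves 'a' untouched.
print(toy_untag(tree, 'y', allow_missing=True))
```

In this model the first call prints `{'a': ToyField(dims=(None,), shape=(2,)), 'b': ToyField(dims=(None, 'y'), shape=(2, 3))}`, and calling `toy_untag(tree, 'y')` without `allow_missing=True` raises because field `'a'` has no `'y'` dimension.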