coordax.untag
- coordax.untag(tree: Any, *dims: str | Coordinate, allow_missing: bool = False) → Any
Untag the given dimensions from every field in a PyTree.
- Parameters:
tree – The PyTree of fields.
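*dims – The dimensions to untag, given as dimension names or Coordinate objects.
allow_missing – If True, untag only those dims that are present on each field, skipping any that a field lacks. Defaults to False.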
- Returns:
A new PyTree in which the given dimensions are untagged on every field.
Examples
>>> import coordax as cx
>>> import jax.numpy as jnp
>>> tree = {'a': cx.field(jnp.zeros((2,)), 'x')}
>>> tree
{'a': <Field dims=('x',) shape=(2,) axes={} >}
>>> cx.untag(tree, 'x')
{'a': <Field dims=(None,) shape=(2,) axes={} >}
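To make the `allow_missing` semantics concrete, the untag step can be modeled in plain Python. This is a hypothetical sketch, not coordax code: `ToyField`, `_untag_field`, and `toy_untag` are illustrative stand-ins, the tree walk covers only dicts, lists, and tuples (instead of a full JAX PyTree traversal), and it assumes that with `allow_missing=False` a dimension missing from any field is an error, raised here as `ValueError` (the exception coordax actually raises may differ). Untagged axes are represented by `None`, matching `dims=(None,)` in the example above.

```python
from dataclasses import dataclass
from typing import Any, Optional, Tuple


@dataclass(frozen=True)
class ToyField:
    """Hypothetical stand-in for a coordax Field: a shape plus per-axis dim names."""
    shape: Tuple[int, ...]
    dims: Tuple[Optional[str], ...]  # None marks an untagged (positional) axis


def _untag_field(f: ToyField, dims: Tuple[str, ...], allow_missing: bool) -> ToyField:
    # Assumption: without allow_missing, every requested dim must be on the field.
    missing = [d for d in dims if d not in f.dims]
    if missing and not allow_missing:
        raise ValueError(f"dims {missing} not found on field with dims {f.dims}")
    # Replace each requested dim name with None; other axes keep their names.
    new_dims = tuple(None if d in dims else d for d in f.dims)
    return ToyField(f.shape, new_dims)


def toy_untag(tree: Any, *dims: str, allow_missing: bool = False) -> Any:
    """Return a new tree with `dims` untagged on every ToyField leaf."""
    if isinstance(tree, ToyField):
        return _untag_field(tree, dims, allow_missing)
    if isinstance(tree, dict):
        return {k: toy_untag(v, *dims, allow_missing=allow_missing) for k, v in tree.items()}
    if isinstance(tree, (list, tuple)):
        return type(tree)(toy_untag(v, *dims, allow_missing=allow_missing) for v in tree)
    return tree  # non-field leaves pass through unchanged


tree = {"a": ToyField((2,), ("x",)), "b": ToyField((2, 3), ("x", "y"))}
# 'a' has no 'y' axis, so it is left alone; 'b' loses its 'y' tag.
print(toy_untag(tree, "y", allow_missing=True))
# {'a': ToyField(shape=(2,), dims=('x',)), 'b': ToyField(shape=(2, 3), dims=('x', None))}
```

Calling `toy_untag(tree, "y")` without `allow_missing=True` raises in this sketch, because field `'a'` does not carry a `'y'` axis.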
See also