coordax.cpmap
- coordax.cpmap(fun: Callable[..., Any], *, vmap: Callable = jax.vmap) -> Callable[..., Any]
Coordinate-preserving cmap.
cpmap(fun) is an alias for cmap(fun, out_axes='same_as_input'). Its primary use case is applying a function over positional axes while preserving the dimensionality and the coordinate order of the inputs.
- Parameters:
fun – Function to apply over positional axes of the inputs.
vmap – Vectorizing transformation to use when mapping over named axes. Defaults to jax.vmap. A different implementation can be used to make coordax compatible with custom objects (e.g., neural network modules).
- Returns:
A function that applies fun to positional axes of the inputs while vectorizing over all coordinate dimensions. The coordinate order of the inputs is preserved in the outputs.
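The behaviour described above can be illustrated with a small NumPy sketch that does not use coordax at all. Here a "field" is modelled as a plain array plus a dims tuple (an axis name, or None for a positional axis), and toy_cpmap and loop_vmap are hypothetical helpers, not part of the coordax API. Named axes are collapsed into one batch axis, fun is applied to each positional-only slice through a pluggable vmap argument (mirroring the role of the vmap parameter above), and the axes are then transposed back so the output keeps the input's dims. The sketch assumes a single input and a fun that preserves the number of positional axes.

```python
from typing import Callable, Optional, Tuple

import numpy as np

# Hypothetical stand-in for a field: each entry is an axis name, or None for a positional axis.
Dims = Tuple[Optional[str], ...]
ArrayFn = Callable[[np.ndarray], np.ndarray]


def loop_vmap(f: ArrayFn) -> ArrayFn:
    """Hypothetical vectorizer: maps f over axis 0 with a Python loop."""

    def batched(x: np.ndarray) -> np.ndarray:
        return np.stack([np.asarray(f(row)) for row in x])

    return batched


def toy_cpmap(fun: ArrayFn, *, vmap: Callable[[ArrayFn], ArrayFn] = loop_vmap):
    """Hypothetical single-input sketch of coordinate-preserving mapping."""

    def wrapped(data: np.ndarray, dims: Dims) -> Tuple[np.ndarray, Dims]:
        named = [i for i, d in enumerate(dims) if d is not None]
        positional = [i for i, d in enumerate(dims) if d is None]
        perm = named + positional
        # Bring the named axes to the front, keeping their relative order.
        moved = np.transpose(data, perm)
        batch_shape = moved.shape[: len(named)]
        # Collapse all named axes into one batch axis and vectorize fun over it.
        flat = moved.reshape((-1,) + moved.shape[len(named):])
        result = vmap(fun)(flat)
        out = result.reshape(batch_shape + result.shape[1:])
        if out.ndim != len(dims):
            raise ValueError("this sketch assumes fun keeps the number of positional axes")
        # Invert the permutation so every axis returns to its original slot.
        return np.transpose(out, np.argsort(perm)), dims

    return wrapped


data = np.arange(24.0).reshape(2, 3, 4)
out, out_dims = toy_cpmap(np.cumsum)(data, ('x', None, 'y'))
```

Passing vmap=lambda f: f is only valid when fun already broadcasts over a leading batch axis (as elementwise NumPy functions do); it illustrates swapping in a different vectorizer, not any specific coordax-compatible transform.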
Examples
>>> import coordax as cx
>>> import jax.numpy as jnp
>>> field = cx.field(jnp.ones((2, 3, 4)), 'x', None, 'y')
>>> cx.cpmap(lambda x: x**2)(field).dims
('x', None, 'y')
cpmap requires all inputs to have the same named axes ordering:
>>> a = cx.field(jnp.ones((2, 3)), 'x', 'y')
>>> b = cx.field(jnp.ones((3, 2)), 'y', 'x')
>>> cx.cpmap(jnp.add)(a, b)
Traceback (most recent call last):
...
ValueError: 'same_as_input' for out_axes requires all NamedArray inputs with named axes to have the same named_axes. Found multiple distinct named_axes on inputs: [{'x': 0, 'y': 1}, {'y': 0, 'x': 1}]
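The ordering rule in this example can be mimicked outside coordax with a short, hypothetical sketch. Each input's dims tuple is turned into a name-to-position mapping, counting positions among the named axes only (an assumption chosen to match the mappings printed in the error message). Inputs without named axes are ignored, and more than one distinct mapping is rejected. The helpers named_axis_positions, check_same_named_axes and reorder are illustrative only; reorder shows the usual remedy of transposing one input so its named axes follow the other's order before mapping.

```python
from typing import Dict, List, Optional, Sequence, Tuple

import numpy as np

Dims = Tuple[Optional[str], ...]  # axis name, or None for a positional axis


def named_axis_positions(dims: Dims) -> Dict[str, int]:
    """Hypothetical helper: index of each name among the named axes only."""
    names = [d for d in dims if d is not None]
    return {name: i for i, name in enumerate(names)}


def check_same_named_axes(all_dims: Sequence[Dims]) -> None:
    """Raise ValueError if inputs with named axes disagree on their order."""
    distinct: List[Dict[str, int]] = []
    for dims in all_dims:
        positions = named_axis_positions(dims)
        # Inputs without any named axes do not constrain the ordering.
        if positions and positions not in distinct:
            distinct.append(positions)
    if len(distinct) > 1:
        raise ValueError(f"found multiple distinct named-axis orders: {distinct}")


def reorder(data: np.ndarray, dims: Dims, target: Dims) -> np.ndarray:
    """Hypothetical helper: transpose fully named data from dims to target order."""
    return np.transpose(data, [dims.index(name) for name in target])
```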
See also