coordax.cpmap

coordax.cpmap(fun: Callable[..., Any], *, vmap: Callable = jax.vmap) -> Callable[..., Any]

Coordinate-preserving cmap.

cpmap(fun) is an alias for cmap(fun, out_axes='same_as_input').

The primary use case is applying a function over positional (unnamed) axes while preserving both the dimensionality of the inputs and the order of their coordinates.
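
To make this behavior concrete, here is a minimal pure-JAX sketch of what cpmap does on a bare array. It assumes each axis is labeled by an entry in a tuple, with a string for a named axis and None for a positional axis (mirroring the dims in the Examples section). The helper positional_map is hypothetical and not part of coordax: it moves the named axes to the front, wraps fun in one jax.vmap per named axis, and transposes the result back so the original axis order is preserved.

```python
import jax
import jax.numpy as jnp


def positional_map(fun, array, dims):
    """Hypothetical sketch of cpmap-like behavior on a plain array.

    `dims` labels each axis with a name (str) or None for a positional axis.
    `fun` sees only the positional axes; named axes are vectorized with
    jax.vmap, and the original axis order is restored afterwards.
    """
    if len(dims) != array.ndim:
        raise ValueError(f"expected {array.ndim} dims, got {len(dims)}")
    named = [i for i, d in enumerate(dims) if d is not None]
    positional = [i for i, d in enumerate(dims) if d is None]
    # Move named axes to the front; each vmap then maps over leading axis 0.
    perm = named + positional
    batched = fun
    for _ in named:
        batched = jax.vmap(batched)
    out = batched(jnp.transpose(array, perm))
    # Like cpmap, require `fun` to keep the number of positional axes, so the
    # output layout is (named..., positional...) and can be permuted back.
    if out.ndim != array.ndim:
        raise ValueError("fun must preserve the number of positional axes")
    inverse = [perm.index(i) for i in range(len(perm))]
    return jnp.transpose(out, inverse)


x = jnp.arange(24.0).reshape(2, 3, 4)
dims = ("x", None, "y")
# cumsum runs along the positional (length-3) axis only.
result = positional_map(jnp.cumsum, x, dims)
print(result.shape)  # (2, 3, 4)
```

The positional axis may change size (for example, concatenating a vector with itself), but the number of positional axes must stay the same, which is what lets the sketch put every axis back where it started.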

Parameters:
  • fun – Function to apply over positional axes of the inputs.

  • vmap – Vectorizing transformation to use when mapping over named axes. Defaults to jax.vmap. Passing a different implementation lets coordax work with custom objects, such as neural network modules.
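
The vmap parameter makes the vectorizing transform pluggable. The sketch below illustrates that idea outside coordax: a hypothetical helper map_leading_axes accepts any callable with the shape vmap(fun) -> batched_fun, and a hypothetical loop_vmap (a Python loop plus jnp.stack, supporting only the default of mapping axis 0 of every argument and a single array output) stands in for a custom transform. This page does not document which keyword arguments coordax passes to vmap, so treat this only as an illustration of the contract, not as a replacement you can pass to cpmap unchanged.

```python
import jax
import jax.numpy as jnp


def loop_vmap(fun):
    """Hypothetical stand-in for jax.vmap (in_axes=0, out_axes=0 only).

    Maps `fun` over the leading axis of every argument with a Python loop
    and stacks the single-array results along a new leading axis.
    """
    def batched(*args):
        n = args[0].shape[0]
        outs = [fun(*(a[i] for a in args)) for i in range(n)]
        return jnp.stack(outs)
    return batched


def map_leading_axes(fun, array, num_named, vmap=jax.vmap):
    """Hypothetical helper: vectorize `fun` over `num_named` leading axes
    using whichever `vmap` implementation is passed in."""
    for _ in range(num_named):
        fun = vmap(fun)
    return fun(array)


x = jnp.arange(24.0).reshape(2, 4, 3)
with_jax = map_leading_axes(jnp.cumsum, x, num_named=2)
with_loop = map_leading_axes(jnp.cumsum, x, num_named=2, vmap=loop_vmap)
print(bool(jnp.allclose(with_jax, with_loop)))  # True
```

Both transforms produce the same result here; a real custom transform would typically exist because it understands objects (such as module state) that jax.vmap does not.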

Returns:

A function that applies fun to positional axes of the inputs while vectorizing over all coordinate dimensions. The coordinate order of the inputs is preserved in the outputs.

Examples

>>> import coordax as cx
>>> import jax.numpy as jnp
>>> field = cx.field(jnp.ones((2, 3, 4)), 'x', None, 'y')
>>> cx.cpmap(lambda x: x**2)(field).dims
('x', None, 'y')

cpmap requires all inputs that have named axes to use the same named-axis order:

>>> a = cx.field(jnp.ones((2, 3)), 'x', 'y')
>>> b = cx.field(jnp.ones((3, 2)), 'y', 'x')
>>> cx.cpmap(jnp.add)(a, b)
Traceback (most recent call last):
...
ValueError: 'same_as_input' for out_axes requires all NamedArray inputs with
named axes to have the same named_axes. Found multiple distinct
named_axes on inputs:
[{'x': 0, 'y': 1}, {'y': 0, 'x': 1}]
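
The error goes away once every input uses the same named-axis order. The pure-JAX sketch below shows that alignment step on plain arrays, with axis names tracked in hypothetical dims tuples; align_to is a hypothetical helper, not a coordax function. It assumes names are unique and that only named axes appear in the tuples. How you reorder an actual coordax field depends on the methods your version provides.

```python
import jax.numpy as jnp


def align_to(array, dims, target_dims):
    """Hypothetical helper: transpose `array` (axes labeled by `dims`) so its
    axes follow `target_dims`. Both tuples must hold the same unique names."""
    if sorted(dims) != sorted(target_dims):
        raise ValueError(f"cannot align {dims} to {target_dims}")
    # Output axis k is the input axis whose name is target_dims[k].
    return jnp.transpose(array, [dims.index(d) for d in target_dims])


a = jnp.ones((2, 3))                # axes ('x', 'y')
b = jnp.arange(6.0).reshape(3, 2)   # axes ('y', 'x')
b_aligned = align_to(b, ("y", "x"), ("x", "y"))
print(b_aligned.shape)              # (2, 3)
print((a + b_aligned).shape)        # (2, 3)
```

After alignment both inputs share the order ('x', 'y'), which is the condition the error message above asks for.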

See also

coordax.cmap()