coordax.cmap

coordax.cmap(fun: Callable[..., Any], out_axes: dict[str, int] | Literal['leading', 'trailing', 'same_as_input'] = 'trailing', *, vmap: Callable = jax.vmap) -> Callable[..., Any]

Vectorizes fun over coordinate dimensions of Field inputs.

cmap is a “coordinate vectorizing map”. It wraps an ordinary positional-axis-based function so that it accepts Field objects as input and produces Field objects as output, and vectorizes over all named dimensions using jax.vmap.

Unlike jax.vmap, the axes to vectorize over are inferred automatically from the named dimensions in the Field inputs, rather than being specified as part of the mapping transformation. Specifically, each dimension name that appears in any of the arguments is vectorized over jointly across all arguments that include that dimension, and is then included as a named dimension in the output. To make an axis visible to fun, you can call untag on the argument and pass the axis name(s) of interest; fun will then see those axes as positional axes instead of mapping over them.
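
To illustrate what "vectorized jointly" means, the sketch below hand-writes the nested jax.vmap calls for two arrays that stand in for Fields with dims ('x', 'y') and ('y', 'z'). This is not coordax's implementation, only an equivalent written with plain jax.vmap under that assumption: a dimension present on one argument is mapped with in_axes=None for the other, and the shared 'y' dimension is mapped on both at once.

```python
import jax
import jax.numpy as jnp


def add(u, v):
    # After every named dim is mapped away, the wrapped function sees scalars.
    return u + v


# Stand-ins: `a` plays a Field with dims ('x', 'y'), `b` one with dims ('y', 'z').
a = jnp.arange(6.0).reshape(2, 3)   # x=2, y=3
b = jnp.arange(12.0).reshape(3, 4)  # y=3, z=4

add_z = jax.vmap(add, in_axes=(None, 0))      # 'z': only b has it
add_yz = jax.vmap(add_z, in_axes=(0, 0))      # 'y': shared, mapped jointly on both
add_xyz = jax.vmap(add_yz, in_axes=(0, None))  # 'x': only a has it

out = add_xyz(a, b)  # shape (2, 3, 4); the axes correspond to ('x', 'y', 'z')
```

Because each vmap uses the default out_axes=0, the outermost mapped dimension ends up first, so the nesting order here fixes the output order ('x', 'y', 'z').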

Together, untag and cmap are the primary ways to apply individual operations to axes of a Field. tag can then be used on the result to bind names to its positional axes again.
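
The untag/cmap/tag pattern can be mimicked with plain jax.vmap: "untagging" 'y' amounts to not mapping over that axis, so the function sees it positionally. In the sketch below, a 2-D array stands in for a Field with dims ('x', 'y'); the normalize function is hypothetical, and out_axes=-1 is chosen to mirror cmap's default 'trailing' placement of the mapped 'x' axis.

```python
import jax
import jax.numpy as jnp


def normalize(v):
    # `v` is the 1-D untagged axis ('y'), seen positionally by the function.
    return v / jnp.sum(v)


arr = jnp.arange(1.0, 7.0).reshape(2, 3)  # stands in for dims ('x', 'y')

# Map only over 'x' (axis 0); out_axes=-1 puts the mapped axis last,
# like out_axes='trailing', giving an array laid out as (None, 'x').
out = jax.vmap(normalize, in_axes=0, out_axes=-1)(arr)  # shape (3, 2)

# Re-binding the positional axis to 'y' is bookkeeping in coordax; positionally,
# a transpose restores the original ('x', 'y') layout.
restored = out.T  # shape (2, 3)
```
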

Within fun, any mapped-over axes are accessible through standard JAX collective operations such as psum, although this is rarely necessary.
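
A minimal sketch of such a collective, using plain jax.vmap with axis_name='x'. It assumes, as the paragraph above implies, that the mapped axis can be referred to by its dimension name; the array stands in for a Field with a single dimension 'x'. The more usual route is to untag 'x' and call jnp.sum directly.

```python
import jax
import jax.numpy as jnp

x = jnp.array([1.0, 2.0, 3.0, 4.0])  # stands in for a Field with dims ('x',)


def normalize_over_x(v):
    # Inside the map, `v` is a scalar; psum sums it across the axis named 'x'.
    return v / jax.lax.psum(v, axis_name='x')


out = jax.vmap(normalize_over_x, axis_name='x')(x)
```
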

Parameters:
  • fun – Function to vectorize by name. This can take arbitrary arguments (even non-JAX-arraylike arguments or “static” axis sizes), but must produce a PyTree of JAX ArrayLike outputs.

  • out_axes – Specifies the strategy for choosing labeled axis positions in the outputs. Options include:

    • dict[str, int]: mapping from dimension name to axis position. Keys must include all named dimensions present in the inputs. Axis positions must be unique and either all non-negative or all negative.

    • 'leading': dimension names will appear as the leading axes on every output, in order of their appearance on the inputs.

    • 'trailing': dimension names will appear as the trailing axes on every output, in order of their appearance on the inputs.

    • 'same_as_input': dimension names will appear at the same positions as in the inputs. This requires all inputs to have the same named axes and the same number of dimensions as the outputs.

  • vmap – Vectorizing transformation to use when mapping over named axes. Defaults to jax.vmap. A different implementation can be used to make coordax compatible with custom objects (e.g. neural net modules).
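
The out_axes strategies above are pure bookkeeping over dimension names, so they can be modeled without JAX. The sketch below uses hypothetical helpers (named_dims_in_order, resolve_out_dims) that are not part of coordax; they compute the dims of one output from the input dims and the output's number of dimensions. The 'same_as_input' branch is a simplification that requires identical dims on every input, and the error messages are illustrative only.

```python
from typing import Optional, Sequence, Union

Dims = tuple[Optional[str], ...]


def named_dims_in_order(input_dims: Sequence[Dims]) -> list[str]:
    """Unique dimension names, in order of first appearance across inputs."""
    names: list[str] = []
    for dims in input_dims:
        for d in dims:
            if d is not None and d not in names:
                names.append(d)
    return names


def resolve_out_dims(
    input_dims: Sequence[Dims],
    out_ndim: int,
    out_axes: Union[dict[str, int], str] = 'trailing',
) -> Dims:
    """Hypothetical helper: dims of one output that has `out_ndim` axes."""
    names = named_dims_in_order(input_dims)
    n_positional = out_ndim - len(names)  # axes produced positionally by fun
    if out_axes == 'trailing':
        return (None,) * n_positional + tuple(names)
    if out_axes == 'leading':
        return tuple(names) + (None,) * n_positional
    if out_axes == 'same_as_input':
        first = tuple(input_dims[0])
        # Simplification: every input must have identical dims and the same ndim.
        if any(tuple(d) != first for d in input_dims) or len(first) != out_ndim:
            raise ValueError('same_as_input needs matching inputs and output ndim')
        return first
    if isinstance(out_axes, dict):
        missing = set(names) - set(out_axes)
        if missing:
            raise ValueError(f'out_axes is missing dims: {sorted(missing)}')
        positions = [out_axes[n] for n in names]
        if not (all(p >= 0 for p in positions) or all(p < 0 for p in positions)):
            raise ValueError('positions must be all non-negative or all negative')
        positions = [p + out_ndim if p < 0 else p for p in positions]
        if any(not 0 <= p < out_ndim for p in positions):
            raise ValueError('position out of range')
        if len(set(positions)) != len(positions):
            raise ValueError('positions must be unique')
        result: list[Optional[str]] = [None] * out_ndim
        for name, p in zip(names, positions):
            result[p] = name
        return tuple(result)
    raise ValueError(f'unsupported out_axes: {out_axes!r}')
```

For the single-input case, resolve_out_dims([('x', None, 'y')], 3) gives (None, 'x', 'y'), and with 'leading' it gives ('x', 'y', None), matching the examples for jnp.sin below.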

Returns:

A vectorized version of fun that applies the original fun to the positional (unnamed) dimensions of its inputs, while vectorizing over all coordinate dimensions. Every dimension over which fun is vectorized is present in every output.

Examples

>>> import coordax as cx
>>> import jax.numpy as jnp

Named axes are trailing by default:

>>> field = cx.field(jnp.ones((2, 3, 1)), 'x', None, 'y')
>>> cx.cmap(jnp.sin)(field).dims
(None, 'x', 'y')
>>> cx.cmap(jnp.sin, out_axes='leading')(field).dims
('x', 'y', None)
>>> cx.cmap(jnp.sin, out_axes='same_as_input')(field).dims
('x', None, 'y')

With multiple Field arguments, every named axis of the inputs appears in the outputs, in order of appearance:

>>> a = cx.field(jnp.ones((2, 3)), 'x', 'y')
>>> b = cx.field(jnp.ones((3, 4)), 'y', 'z')
>>> cx.cmap(jnp.add)(a, b).dims
('x', 'y', 'z')

cmap leverages JAX’s pytree machinery, so arbitrarily nested inputs and outputs are supported, as well as keyword arguments:

>>> z2 = cx.field(jnp.ones((2, 4)))
>>> z3 = cx.field(jnp.ones((3, 4)))
>>> cx.cmap(jnp.concat)([z2, z3], axis=0)
<Field dims=(None, None) shape=(5, 4) axes={} >
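
The pytree support builds on the fact that jax.vmap itself walks nested containers and maps every leaf. The sketch below uses plain JAX with two hypothetical arrays that share a leading axis standing in for a named dimension 'b', concatenated along their first positional axis. The axis option is bound in a closure because jax.vmap maps keyword arguments over their leading axis; how coordax itself handles keyword arguments is not shown here.

```python
import jax
import jax.numpy as jnp

# Two arrays sharing a leading axis that plays a named dimension 'b' (size 5).
p = jnp.ones((5, 2, 4))
q = jnp.zeros((5, 3, 4))


def concat0(xs):
    # `axis` is fixed here rather than passed as a keyword through vmap.
    return jnp.concatenate(xs, axis=0)


# vmap maps axis 0 of every leaf in the list jointly.
out = jax.vmap(concat0)([p, q])  # shape (5, 5, 4)
```
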