coordax.cmap
- coordax.cmap(fun: Callable[..., Any], out_axes: dict[str, int] | Literal['leading', 'trailing', 'same_as_input'] = 'trailing', *, vmap: Callable = jax.vmap) -> Callable[..., Any]
Vectorizes fun over the coordinate dimensions of Field inputs.

cmap is a "coordinate vectorizing map". It wraps an ordinary function that operates on positional axes so that it accepts Field objects as input and produces Field objects as output, and vectorizes over all named dimensions using jax.vmap (by default).

Unlike jax.vmap, cmap infers the axes to vectorize over from the named dimensions of the Field inputs, rather than requiring them to be specified as part of the mapping transformation. Specifically, each dimension name that appears in any argument is vectorized over jointly across all arguments that include that dimension, and is then included as a named dimension in the output. To make an axis visible to fun, call untag on the argument and pass the axis name(s) of interest; fun will then see those axes as positional axes instead of mapping over them. Together, untag and cmap are the primary ways to apply individual operations to the axes of a Field. tag can then be used on the result to re-bind names to positional axes.

Within fun, any mapped-over axes are accessible using standard JAX collective operations such as psum, although this is usually unnecessary.

- Parameters:
fun – Function to vectorize by name. This can take arbitrary arguments (even non-JAX-arraylike arguments or “static” axis sizes), but must produce a PyTree of JAX ArrayLike outputs.
out_axes – Strategy for choosing the positions of named (labeled) axes in the outputs. Options include:
dict[str, int]: mapping from dimension name to axis position. Keys must include all named dimensions present in the inputs. Axis positions must be unique and either all non-negative or all negative.
'leading': dimension names appear as the leading axes of every output, in the order they first appear in the inputs.
'trailing': dimension names appear as the trailing axes of every output, in the order they first appear in the inputs.
'same_as_input': dimension names appear at the same positions as in the inputs. This requires all inputs to have the same named axes and the same number of dimensions as the outputs.
vmap – Vectorizing transformation to use when mapping over named axes. Defaults to jax.vmap. A different implementation can be used to make coordax compatible with custom objects (e.g., neural network modules).
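The 'leading' and 'trailing' options for out_axes correspond roughly to choosing out_axes=0 or out_axes=-1 for each mapped axis in an underlying jax.vmap call. The sketch below shows this with plain JAX for a single mapped axis. It illustrates the idea only and is not coordax's implementation; treating axis 0 as a named dimension 'x' is an assumption made for the example.

```python
import jax
import jax.numpy as jnp

# Assumption for illustration: axis 0 stands in for a named dimension 'x'
# (size 2); axis 1 is an ordinary positional axis seen by the function.
data = jnp.arange(6.0).reshape(2, 3)

# Mapped axis placed first in the output (analogous to out_axes='leading').
leading = jax.vmap(jnp.sin, in_axes=0, out_axes=0)(data)

# Mapped axis placed last in the output (analogous to out_axes='trailing').
trailing = jax.vmap(jnp.sin, in_axes=0, out_axes=-1)(data)

print(leading.shape)   # (2, 3)
print(trailing.shape)  # (3, 2)
```

Both results hold the same values; only the position of the mapped axis differs, which is exactly what out_axes controls.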
- Returns:
A vectorized version of fun that applies the original fun to the positional (unnamed) dimensions of its inputs, while vectorizing over all coordinate dimensions. Every dimension over which fun is vectorized will be present in every output.
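As a plain-JAX illustration of why every mapped dimension appears in every output: jax.vmap broadcasts outputs that do not depend on the mapped input, so the mapped axis still appears on them. This sketch calls jax.vmap directly and assumes cmap inherits this behavior from its underlying vmap; the function fun here is hypothetical.

```python
import jax
import jax.numpy as jnp

def fun(v):
    # Hypothetical function: the first output depends on the mapped input,
    # the second is a constant that ignores it.
    return v.sum(), jnp.array(1.0)

x = jnp.ones((2, 3))  # axis 0 is the mapped axis
total, const = jax.vmap(fun)(x)

print(total.shape)  # (2,)
print(const.shape)  # (2,): the constant is broadcast along the mapped axis
```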
Examples
>>> import coordax as cx
>>> import jax.numpy as jnp
Named axes are trailing by default:
>>> field = cx.field(jnp.ones((2, 3, 1)), 'x', None, 'y')
>>> cx.cmap(jnp.sin)(field).dims
(None, 'x', 'y')
>>> cx.cmap(jnp.sin, out_axes='leading')(field).dims
('x', 'y', None)
>>> cx.cmap(jnp.sin, out_axes='same_as_input')(field).dims
('x', None, 'y')
With multiple Field arguments, the output contains every named axis from the inputs, in order of first appearance:
>>> a = cx.field(jnp.ones((2, 3)), 'x', 'y')
>>> b = cx.field(jnp.ones((3, 4)), 'y', 'z')
>>> cx.cmap(jnp.add)(a, b).dims
('x', 'y', 'z')
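The joint vectorization in the example above can be pictured as nested jax.vmap calls, one per dimension name, with in_axes=None for any argument that lacks that dimension. This plain-JAX sketch conveys the semantics only; it is not how coordax is implemented internally, and the raw arrays stand in for the Field objects a and b.

```python
import jax
import jax.numpy as jnp

a = jnp.arange(6.0).reshape(2, 3)   # plays the role of dims ('x', 'y')
b = jnp.arange(12.0).reshape(3, 4)  # plays the role of dims ('y', 'z')

# Innermost: map over 'z', which only b has (a is not mapped).
add_z = jax.vmap(jnp.add, in_axes=(None, 0))
# Middle: map over 'y', shared by a and b, so both are mapped jointly.
add_yz = jax.vmap(add_z, in_axes=(0, 0))
# Outermost: map over 'x', which only a has (b is not mapped).
add_xyz = jax.vmap(add_yz, in_axes=(0, None))

out = add_xyz(a, b)
print(out.shape)  # (2, 3, 4): axes ordered as 'x', 'y', 'z'
```

The result matches ordinary broadcasting of a[:, :, None] + b[None, :, :]; cmap derives the equivalent in_axes bookkeeping from dimension names so it does not have to be written by hand.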
cmap leverages JAX's pytree machinery, so arbitrarily nested inputs and outputs are supported, as well as keyword arguments:
>>> z2 = cx.field(jnp.ones((2, 4)))
>>> z3 = cx.field(jnp.ones((3, 4)))
>>> cx.cmap(jnp.concat)([z2, z3], axis=0)
<Field dims=(None, None) shape=(5, 4) axes={} >
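The description notes that mapped-over axes are reachable inside fun through JAX collectives such as psum. With plain jax.vmap, this requires naming the mapped axis via axis_name, as in the sketch below. Whether cmap exposes each coordinate under its dimension name as the collective axis name is an assumption here, so the example uses jax.vmap directly; the function normalize is hypothetical.

```python
import jax
import jax.numpy as jnp

def normalize(v):
    # psum over the named mapped axis 'x' sums v across all mapped slices,
    # so each slice is divided by the global total.
    return v / jax.lax.psum(v, axis_name='x')

values = jnp.array([1.0, 2.0, 3.0])
fractions = jax.vmap(normalize, axis_name='x')(values)
print(fractions)  # approximately [0.1667, 0.3333, 0.5]
```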