Coordinate Map (cmap)
Key concepts:
- cmap() (coordinate map) plays a central role in connecting dimension-aware representation to array functions.
- cmap transforms a function fn operating on Array inputs to support Field args, with fn applied over positional axes and vectorized over named axes.
- Return values have axis order determined by the order of appearance or the out_axes kwarg.
- Alternatively, cpmap() is a convenient shortcut for coordinate-preserving mappings.
- Built-in binary/unary operations on Field are implemented using cmap.
If you’re familiar with Xarray, you can think of cmap() as a super-charged version of xarray.apply_ufunc enabled by jax.vmap.
If we have a function that works with Array inputs and want to apply it to Field data, we can use cmap to make it compatible with Field arguments. Let’s start with the standard jnp.add:
import coordax as cx
import jax
import jax.numpy as jnp
import numpy as np
rng = np.random.RandomState(0)
f_rand = cx.field(rng.uniform(size=3), 'x')
f_arange = cx.field(np.arange(3), 'x')
fs_added = cx.cmap(jnp.add)(f_rand, f_arange)
fs_added
<Field dims=('x',) shape=(3,) axes={} >
In the example above, jnp.add was vectorized over the named dimension ‘x’ and applied to the remaining (empty) positional axes. This is reflected in the ‘x’ annotation of the output: vectorization over named axes preserves their coordinates.
If we pass inputs with the x dimension untagged, the transformed function will be applied to the 1D arrays corresponding to the positional shape. In the case of jnp.add the numerical result is the same, since jnp.add is elementwise and therefore computes the same values as jax.vmap(jnp.add).
added = cx.cmap(jnp.add)(f_rand.untag('x'), f_arange.untag('x'))
assert (added.data == fs_added.data).all()
added
<Field dims=(None,) shape=(3,) axes={} >
Note that added no longer has an ‘x’ coordinate, since the inputs carried no coordinates.
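The equivalence behind this result can be sketched in plain NumPy, with an explicit Python loop standing in for vectorization over ‘x’ (an illustration only, not how cmap or jax.vmap work internally). For an elementwise function like add, looping per element and calling once on the whole vector agree; for a function that looks at the whole vector, such as the hypothetical center helper below, they do not:

```python
import numpy as np

rng = np.random.RandomState(0)
a = rng.uniform(size=3)
b = np.arange(3)

# Vectorized over 'x': call the function once per element, then stack.
per_element = np.stack([np.add(ai, bi) for ai, bi in zip(a, b)])
# 'x' untagged: call the function once on the full 1D vectors.
whole_vector = np.add(a, b)
print(np.allclose(per_element, whole_vector))  # True: add is elementwise


def center(v):
    # Hypothetical helper: subtract the mean of its input.
    return v - np.mean(v)


# Per element, each call sees a single number, so every result is zero.
per_element_centered = np.stack([center(ai) for ai in a])
# On the whole vector, the mean of all three values is subtracted.
whole_vector_centered = center(a)
print(np.allclose(per_element_centered, whole_vector_centered))  # False
```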
Let’s look at a more complex example: computing an FFT over a single axis of multi-dimensional inputs.
def fft(x):
    assert x.ndim == 1  # make sure method is applied to vectors.
    return jnp.fft.fft(x)
xc = cx.SizedAxis('x', 16)
rng = np.random.RandomState(0)
f_x = cx.field(rng.uniform(size=(3, 16, 4)), 'batch', xc, 'z')
f_kx = cx.cmap(fft)(f_x.untag(xc))
f_kx
<Field dims=(None, 'batch', 'z') shape=(16, 3, 4) axes={} >
The returned f_kx contains values where each ‘x’ slice was transformed using our function fft.
Note that f_kx has dimensions (None, 'batch', 'z'). By default, all named dimensions are vectorized and placed at the end of the result in the order in which they appear, with positional dimensions placed at the beginning of the result.
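The default layout can be mimicked in plain NumPy. The toy_cmap function below is a hypothetical, single-input illustration (not coordax’s implementation): it loops over every combination of named-axis indices, applies fn to the positional slice, and returns fn’s positional axes first, followed by the named axes in order of appearance.

```python
import numpy as np


def toy_cmap(fn, data, dims):
    """Hypothetical sketch of cmap's default output layout for one input.

    dims labels each axis of data; None marks a positional axis.
    """
    named = [i for i, d in enumerate(dims) if d is not None]
    positional = [i for i, d in enumerate(dims) if d is None]
    # Move named axes to the front so each index tuple selects one positional slice.
    moved = np.transpose(data, named + positional)
    named_shape = moved.shape[:len(named)]
    results = [fn(moved[idx]) for idx in np.ndindex(*named_shape)]
    out = np.stack(results).reshape(named_shape + np.shape(results[0]))
    # Place fn's positional output axes before the named axes.
    n_pos = np.ndim(results[0])
    order = list(range(len(named), len(named) + n_pos)) + list(range(len(named)))
    out_dims = (None,) * n_pos + tuple(dims[i] for i in named)
    return np.transpose(out, order), out_dims


# Same layout as f_x.untag(xc): dims ('batch', None, 'z').
data = np.random.RandomState(0).uniform(size=(3, 16, 4))
out, out_dims = toy_cmap(np.fft.fft, data, ('batch', None, 'z'))
print(out.shape, out_dims)  # (16, 3, 4) (None, 'batch', 'z')
```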
The resulting axis order can be explicitly controlled via out_axes argument in cmap, which takes a dictionary mapping axis names to their output axis indices.
Let’s look at a few more examples to get a feel for this behavior:
out = cx.cmap(fft)(f_x.order_as('z', ...).untag(xc))
assert out.dims == (None, 'z', 'batch') # 'z', 'batch' is order of appearance.
out = cx.cmap(fft, out_axes={'batch': 0, 'z': 2})(f_x.untag(xc))
assert out.dims == ('batch', None, 'z') # explicitly indicate out_axes order.
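In NumPy terms, changing the input order or passing out_axes only changes where the axes end up; the values are the same. A sketch reproducing both layouts above with np.moveaxis, assuming data holds values laid out like f_x with dims ('batch', 'x', 'z'):

```python
import numpy as np

data = np.random.RandomState(0).uniform(size=(3, 16, 4))  # dims ('batch', 'x', 'z')

# order_as('z', ...): 'z' appears first, so the output dims are (None, 'z', 'batch').
z_first = np.transpose(data, (2, 0, 1))                       # ('z', 'batch', 'x')
out_z_first = np.moveaxis(np.fft.fft(z_first, axis=2), 2, 0)  # (None, 'z', 'batch')

# out_axes={'batch': 0, 'z': 2}: the positional axis fills the free slot at index 1.
default = np.moveaxis(np.fft.fft(data, axis=1), 1, 0)         # (None, 'batch', 'z')
out_explicit = np.moveaxis(default, 0, 1)                     # ('batch', None, 'z')

print(out_z_first.shape, out_explicit.shape)  # (16, 4, 3) (3, 16, 4)
```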
In the examples above, None corresponds to the positional axis returned by fft. A different function could reduce over the input dimension (e.g. if we replace jnp.fft.fft with jnp.sum) or create more positional axes (e.g. lambda x: jnp.outer(x, x)). This would be reflected in the positional shape of the output: no positional axes for jnp.sum and two for the outer product.
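A NumPy sketch of how the choice of function changes the positional part of the output, using a sum (no positional axes) and a per-slice outer product (two positional axes) over the untagged ‘x’ axis. The dimension labels in the comments assume cmap’s default layout of positional axes first:

```python
import numpy as np

data = np.random.RandomState(0).uniform(size=(3, 16, 4))  # dims ('batch', 'x', 'z')

# Reduction: each 16-element slice becomes a scalar -> dims ('batch', 'z').
summed = data.sum(axis=1)
# Outer product of each slice with itself: two positional axes of size 16,
# placed first -> dims (None, None, 'batch', 'z').
outer = np.einsum('bxz,byz->xybz', data, data)

print(summed.shape, outer.shape)  # (3, 4) (16, 16, 3, 4)
```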
This mechanism forms the primary design pattern for computations over locally positional axes:
1. Desired axes are exposed using untag.
2. Computation is performed using cx.cmap(fn, ...)(untagged_inputs).
3. New coordinates are added using tag (if needed).
f_x = f_x.untag(xc)
f_kx = cx.cmap(fft, out_axes=f_x.named_axes)(f_x).tag('kx')
f_kx
<Field dims=('batch', 'kx', 'z') shape=(3, 16, 4) axes={} >
Transforming functions with multiple arguments
cmap supports transforming functions that involve multiple inputs/outputs (in fact, we have already seen this with jnp.add). This provides a convenient mechanism for automatic vectorization based on function inputs, but lack of care can lead to unexpected results.
Let’s consider computing the log likelihood (and log pdf) of samples predicted by a model. Our function needs to accept a predicted sample and the parameters of a distribution.
import jax.scipy.stats as jsp_stats
def log_likelihood(sample, mean, std_dev):
    log_pdf = jsp_stats.norm.logpdf(sample, loc=mean, scale=std_dev)
    return jnp.sum(log_pdf), log_pdf
samples_data = jax.random.normal(jax.random.key(0), shape=(10, 3))
samples_data = samples_data * jnp.arange(10)[:, None] + jnp.arange(10)[:, None]
samples = cx.field(samples_data, 'batch', 'x')
means = jnp.ones(3) * 2 # parameters of distribution
std_dev = jnp.array([1.9, 2.0, 2.1])
inputs = samples.untag('x')
ll, log_pdf = cx.cmap(log_likelihood)(inputs, means, std_dev)
ll, log_pdf
(<Field dims=('batch',) shape=(10,) axes={} >,
<Field dims=(None, 'batch') shape=(3, 10) axes={} >)
In the example above, we passed jax.Array inputs for the parameters of the distribution. We could obtain the same result by wrapping those arguments in Field objects without named dimensions:
cx.cmap(log_likelihood)(inputs, cx.field(means), cx.field(std_dev))
(<Field dims=('batch',) shape=(10,) axes={} >,
<Field dims=(None, 'batch') shape=(3, 10) axes={} >)
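However, if we accidentally pass an argument with a named dimension attached, the whole computation will automatically vectorize over it! Below, std_dev is tagged with ‘x_axis’, so log_likelihood is called once per std_dev value, with a scalar scale broadcast against the full length-3 sample: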
cx.cmap(log_likelihood)(inputs, means, cx.field(std_dev, 'x_axis'))
(<Field dims=('batch', 'x_axis') shape=(10, 3) axes={} >,
<Field dims=(None, 'batch', 'x_axis') shape=(3, 10, 3) axes={} >)
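To see what happened, here is the same computation written with explicit NumPy broadcasting. The norm_logpdf helper is written out by hand for self-containment (it uses the standard normal log-density formula, which is what jax.scipy.stats.norm.logpdf computes), and the random samples are illustrative rather than the ones used above. In the accidental case every scalar std_dev value is paired with every full sample vector, so only the diagonal of the extra axis matches the intended result:

```python
import numpy as np


def norm_logpdf(x, loc, scale):
    # Log density of a normal distribution, written out explicitly.
    z = (x - loc) / scale
    return -0.5 * z**2 - np.log(scale) - 0.5 * np.log(2 * np.pi)


rng = np.random.RandomState(0)
samples = rng.normal(size=(10, 3))  # dims ('batch', 'x')
means = np.ones(3) * 2
std_dev = np.array([1.9, 2.0, 2.1])

# Intended: std_dev lines up with 'x'; vectorized over 'batch' only.
intended_pdf = norm_logpdf(samples, means, std_dev)  # ('batch', 'x')
intended_ll = intended_pdf.sum(axis=-1)              # ('batch',)

# Accidental: std_dev carries its own named axis 'x_axis', so it is
# vectorized over as well and each call sees a scalar scale.
accidental_pdf = norm_logpdf(
    samples[:, None, :], means, std_dev[None, :, None])  # ('batch', 'x_axis', 'x')
accidental_ll = accidental_pdf.sum(axis=-1)              # ('batch', 'x_axis')

print(intended_ll.shape, accidental_ll.shape)  # (10,) (10, 3)
```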
Coordinate Preserving Map (cpmap)
cpmap() is a shorthand for cmap(..., out_axes='same_as_input'). It applies a function over positional axes while preserving the dimensionality and the coordinate order of the inputs.
This is particularly useful when you want the output field to have the same structure as the input field, without having to manually specify out_axes.
x = cx.SizedAxis('x', 3)
y = cx.SizedAxis('y', 2)
# Create a field with mixed named and positional axes
f = cx.field(jnp.ones((3, 4, 2)), x, None, y)
print(f'Input dims: {f.dims}')
# Standard cmap puts named axes at the end
print(f'cmap dims: {cx.cmap(lambda x: x)(f).dims}')
# cpmap preserves the order
print(f'cpmap dims: {cx.cpmap(lambda x: x)(f).dims}')
Input dims: ('x', None, 'y')
cmap dims: (None, 'x', 'y')
cpmap dims: ('x', None, 'y')
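A NumPy sketch of the two layouts, assuming (as with the identity function above) that fn keeps the number of positional axes unchanged; np.cumsum along the positional axis stands in for fn here:

```python
import numpy as np

data = np.arange(24.0).reshape(3, 4, 2)  # dims ('x', None, 'y')
result = np.cumsum(data, axis=1)         # fn applied along the positional axis

# cmap-style layout: positional axis first, named axes after -> (None, 'x', 'y').
cmap_layout = np.moveaxis(result, 1, 0)
# cpmap-style layout: positional axis stays where it was -> ('x', None, 'y').
cpmap_layout = np.moveaxis(cmap_layout, 0, 1)

print(cmap_layout.shape, cpmap_layout.shape)  # (4, 3, 2) (3, 4, 2)
```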
Arithmetic operations on Field
The Field class implements default arithmetic methods such as addition, multiplication, and division. These methods transform the standard array operations using cmap and hence automatically align coordinates and broadcast data where necessary:
x_grid = cx.SizedAxis('x', 5)
y_grid = cx.SizedAxis('y', 3)
zero_field = cx.field(np.zeros((5, 3)), x_grid, y_grid)
ones_field = cx.field(np.ones((5, 3)), x_grid, y_grid)
zero_field + ones_field, zero_field * ones_field
(<Field dims=('x', 'y') shape=(5, 3) axes={'x': SizedAxis, 'y': SizedAxis} >,
<Field dims=('x', 'y') shape=(5, 3) axes={'x': SizedAxis, 'y': SizedAxis} >)
zero_field + ones_field.order_as('y', 'x') # still works due to auto alignment
<Field dims=('x', 'y') shape=(5, 3) axes={'x': SizedAxis, 'y': SizedAxis} >
zero_field.order_as('y', 'x') + ones_field # works, but note y, x result order.
<Field dims=('y', 'x') shape=(3, 5) axes={'y': SizedAxis, 'x': SizedAxis} >
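The alignment rule can be sketched in NumPy with a hypothetical aligned_add helper that tracks dimension names by hand. It transposes the right operand into the left operand’s order, so the result follows the left operand, matching the two cases above. It assumes both inputs have the same set of dimensions; broadcasting over missing dimensions is not handled by this sketch.

```python
import numpy as np


def aligned_add(a, a_dims, b, b_dims):
    """Hypothetical helper: add arrays whose axes are labeled by name.

    Assumes a_dims and b_dims contain the same names, possibly in a different order.
    """
    # For each of a's dims, find where that dim lives in b.
    perm = [b_dims.index(d) for d in a_dims]
    return a + np.transpose(b, perm), a_dims


zeros = np.zeros((5, 3))     # dims ('x', 'y')
ones_yx = np.ones((5, 3)).T  # dims ('y', 'x')

out, dims = aligned_add(zeros, ('x', 'y'), ones_yx, ('y', 'x'))
print(out.shape, dims)  # (5, 3) ('x', 'y')
out_r, dims_r = aligned_add(ones_yx, ('y', 'x'), zeros, ('x', 'y'))
print(out_r.shape, dims_r)  # (3, 5) ('y', 'x')
```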
Vectorization over a named dimension results in effective broadcasting:
zero_field + cx.field(np.arange(zero_field.named_shape['x']), 'x')
<Field dims=('x', 'y') shape=(5, 3) axes={'x': SizedAxis, 'y': SizedAxis} >
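In plain NumPy, the same effect requires spelling out the broadcast axis by hand; here np.arange(5) plays the role of the ‘x’-only field:

```python
import numpy as np

zeros = np.zeros((5, 3))  # dims ('x', 'y')
x_values = np.arange(5)   # dims ('x',)

# Pair each 'x' row with one value: add a trailing axis so it broadcasts over 'y'.
result = zeros + x_values[:, None]
print(result[:, 0])  # [0. 1. 2. 3. 4.]
```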
The same gotcha about output axis order applies to arithmetic operations. For instance, adding Field instances with a trailing positional axis moves that axis to the front of the result, as in the case of cmap(jnp.add):
addition = zero_field.untag('y') + ones_field.untag('y')
addition # Note how 'x' dim got moved to the end as the only vectorized axis.
<Field dims=(None, 'x') shape=(3, 5) axes={'x': SizedAxis} >
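A loop-based NumPy sketch of where the axes end up (the loop is an illustration, not how cmap is implemented): with ‘y’ untagged, add is applied once per ‘x’ index to length-3 vectors, and stacking the per-index results along the vectorized axis places ‘x’ last, as in the (None, 'x') output above.

```python
import numpy as np

zeros = np.zeros((5, 3))  # dims ('x', None) after untag('y')
ones = np.ones((5, 3))

# One call per 'x' index on length-3 positional vectors; 'x' is stacked last.
addition = np.stack([z + o for z, o in zip(zeros, ones)], axis=-1)
print(addition.shape)  # (3, 5)
```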