JAX Transformations
Key concepts:
- Field is a pytree and is compatible with shape-preserving JAX transformations.
- Fields with leading positional axes are compatible with shape-altering transformations.
Because Field is implemented as a pytree, transformations such as automatic differentiation (jax.grad) and JIT compilation (jax.jit) work with it out of the box.
import coordax as cx
import jax
import jax.numpy as jnp
import numpy as np
def loss(x, y):
    # Sum of squared differences; `.data` unwraps the scalar result into a JAX array.
    return cx.cmap(jnp.sum)((x - y)**2).data
grad = jax.grad(loss)(cx.field(np.ones(5)), cx.field(np.zeros(5)))
also_grad = jax.jit(jax.grad(loss))(cx.field(np.ones(5)), cx.field(np.zeros(5)))
also_grad
<Field dims=(None,) shape=(5,) axes={} >
Transformations that change the shape of the pytree leaves (e.g. jax.vmap, jax.lax.scan) are also allowed, as long as they do not violate the coordinate metadata. In practice, this limits them to adding or removing leading positional (unlabeled) axes of a Field, which covers the most common use cases:
- jax.vmap(fn_on_field): explicit batching
- jax.lax.scan(fn, fields_to_scan_over): scanning over Field entries
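To see why only leading unlabeled axes can be added or removed, consider a toy pytree. The class `Labeled` below is hypothetical and is NOT Coordax's implementation: it only mimics the idea that the trailing axes carry names while extra leading axes are unlabeled. Under jax.vmap, JAX unflattens the pytree with per-slice leaves that have lost their leading axis, so if that axis was named, the labels no longer fit the data. Coordax's real check is stricter (per its error message it also requires sizes and positions of named dims to be preserved).

```python
import jax
import numpy as np


@jax.tree_util.register_pytree_node_class
class Labeled:
    """Hypothetical toy container (NOT Coordax's Field).

    The trailing axes of `data` carry `names`; any extra leading axes are unlabeled.
    """

    def __init__(self, data, names):
        # JAX sometimes unflattens with placeholder leaves, so only check real arrays.
        if hasattr(data, "ndim") and data.ndim < len(names):
            raise ValueError(
                f"cannot trim named dimensions {names} from data of shape {data.shape}"
            )
        self.data = data
        self.names = tuple(names)

    def tree_flatten(self):
        # The array is the only leaf; the names travel as static auxiliary data.
        return (self.data,), self.names

    @classmethod
    def tree_unflatten(cls, names, children):
        return cls(children[0], names)

    @property
    def dims(self):
        # Leading unlabeled axes are reported as None, mirroring Field's repr.
        return (None,) * (self.data.ndim - len(self.names)) + self.names


# Mapping over an unlabeled leading axis: each slice still has its 'y' axis.
batched = jax.vmap(lambda t: t)(Labeled(np.zeros((2, 5)), ("y",)))
print(batched.dims, batched.data.shape)  # (None, 'y') (2, 5)

# Mapping over a labeled leading axis would strip the 'x' label from each slice.
try:
    jax.vmap(lambda t: t)(Labeled(np.zeros((2, 5)), ("x", "y")))
    trim_rejected = False
except ValueError as err:
    trim_rejected = True
    print("ValueError:", err)
```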
Explicit vectorization with jax.vmap
cmap() is the most convenient way to vectorize functions in Coordax,
but Field is also compatible with jax.vmap, as long as the leading dimension is unlabeled:
data = np.arange(10).reshape((2, 5))
f = cx.field(data, None, 'y')
def identity_with_checks(x: cx.Field) -> cx.Field:
    assert x.dims == ('y',)  # under vmap `x` will be slices of `f` above.
    return x
same_as_f = jax.vmap(identity_with_checks)(f)
same_as_f
<Field dims=(None, 'y') shape=(2, 5) axes={} >
Applying jax.vmap over a labeled leading dimension raises an error:
try:
    jax.vmap(identity_with_checks)(f.tag('x'))
except Exception as e:
    print(f'{type(e).__name__}: {e}')
ValueError: cannot trim named dimensions when unflattening to a NamedArray: ('x',). JAX pytree operations on NamedArray objects are only valid when they insert new leading dimensions, or trim unnamed leading dimensions. The sizes and positions (from the end) of all named dimensions must be preserved. If you are using vmap or scan, the first dimension must be unnamed.
Scanning over Field entries
Similarly, Field supports jax.lax.scan, as long as the dimension being scanned over (the leading one) is unlabeled:
data = np.arange(10).reshape((2, 5))
f = cx.field(data, None, 'y')
def identity_body_with_checks(unused_c, x: cx.Field) -> tuple[None, cx.Field]:
    assert x.dims == ('y',)  # under scan `x` will be slices of `f` above.
    return None, x
_, same_as_f = jax.lax.scan(identity_body_with_checks, init=None, xs=f)
same_as_f
<Field dims=(None, 'y') shape=(2, 5) axes={} >
Likewise, scanning over a labeled leading dimension raises an error:
try:
    jax.lax.scan(identity_body_with_checks, init=None, xs=f.tag('x'))
except Exception as e:
    print(f'{type(e).__name__}: {e}')
ValueError: cannot trim named dimensions when unflattening to a NamedArray: ('x',). JAX pytree operations on NamedArray objects are only valid when they insert new leading dimensions, or trim unnamed leading dimensions. The sizes and positions (from the end) of all named dimensions must be preserved. If you are using vmap or scan, the first dimension must be unnamed.
jax.tree.map and Field
jax.tree.map can interact with Field instances in two main ways:
- Mapping over Field leaves, by passing is_leaf=cx.is_field to jax.tree.map
- Mapping over the underlying data, using the default jax.tree.map behavior
The former is the safer approach, because the mapped function receives each Field and is explicitly responsible for its metadata.
The latter, especially when it changes array shapes, should be used with great care.
Coordax implements shape checks that catch cases where the underlying data is no longer compatible with the coordinate labels, but in general it cannot identify erroneous transformations that produce arrays of compatible shape.
As with jax.vmap and jax.lax.scan, functions that insert or trim leading positional dimensions are supported.
Adding a new leading positional axis using jax.tree.map:
data = np.arange(10).reshape((2, 5))
f = cx.field(data, 'x', 'y')
# note that here tree.map operates on the underlying Array values.
with_leading_axis = jax.tree.map(lambda x: x[np.newaxis, ...], f)
with_leading_axis
<Field dims=(None, 'x', 'y') shape=(1, 2, 5) axes={} >
Trimming a leading positional axis using jax.tree.map:
data = np.arange(10).reshape((1, 2, 5))
f = cx.field(data, None, 'x', 'y')
without_leading_axis = jax.tree.map(lambda x: x[0, ...], f)
without_leading_axis
<Field dims=('x', 'y') shape=(2, 5) axes={} >
Modifying data without changing shape using jax.tree.map:
data = np.arange(4).reshape((1, 2, 2))
f = cx.field(data, 'b', 'x', 'y')
double_f = jax.tree.map(lambda x: x * 2, f)
double_f
<Field dims=('b', 'x', 'y') shape=(1, 2, 2) axes={} >
This function, however, cannot be distinguished from an accidental transpose of the last two axes. The transpose below passes the shape checks but leaves the data out of sync with its coordinate labels (unless swapping the axes is what the computation intends):
data = np.arange(4).reshape((2, 2))
f = cx.field(data, 'x', 'y')
yx_f_labeled_as_xy = jax.tree.map(lambda x: x.transpose(), f)
yx_f_labeled_as_xy
<Field dims=('x', 'y') shape=(2, 2) axes={} >