Coordinate Objects

Key concepts:

  • Coordinate objects store coordinate-specific metadata and methods

  • Coordinate classes are static pytree nodes which enable compile-time checks

While strings are convenient for simple labeling, Coordinate objects allow you to attach data (like tick values) to axes and enforce alignment similar to Xarray.

Using coordinates with fields

The primary way to use Coordinate objects is by passing them to cx.field when creating a Field. This associates the coordinate with the corresponding dimension of the array.

import coordax as cx
import jax.numpy as jnp
import numpy as np

# Define a coordinate
x_axis = cx.LabeledAxis('x', np.arange(5))

# Create a field using the coordinate
# The coordinate 'x_axis' is associated with the first dimension (size 5)
f = cx.field(jnp.ones(5), x_axis)

print(f"Field dims: {f.dims}")
print(f"Field axes: {f.axes}")
Field dims: ('x',)
Field axes: {'x': coordax.LabeledAxis('x', ticks=array([0, 1, 2, 3, 4]))}

Like axis names, coordinates are automatically propagated through Coordax operations.

(f + 1).axes['x']
coordax.LabeledAxis('x', ticks=array([0, 1, 2, 3, 4]))

Coordinates have dims and shape attributes, which determine the dims and shape of the Field axes they are associated with.

print(f"{x_axis.dims=}")
print(f"{x_axis.shape=}")
x_axis.dims=('x',)
x_axis.shape=(5,)

Coordinates also have their own fields attribute: a dictionary of any fields associated with the coordinate (e.g., tick values).

x_axis.fields
{'x': <Field dims=('x',) shape=(5,) axes={'x': LabeledAxis} >}

You can pull out these same coordinate fields, merged across all associated coordinates, from Field.coord_fields.

f.coord_fields
{'x': <Field dims=('x',) shape=(5,) axes={'x': LabeledAxis} >}

Coordinate checks

Most manipulations on Field objects require exact coordinate alignment. When you use more complex coordinate objects that carry information beyond name and shape, this provides a powerful check against alignment and coordinate mismatch errors. In this section we use the standard SizedAxis and LabeledAxis coordinates to demonstrate alignment checks; implementing custom coordinates is covered in a later section.

This coordinate equality rule is relaxed for arguments passed to tag, untag, order_as, and similar methods: passing a dimension name (str) is considered sufficient to express user intent, provided a coordinate with a matching name is present on the field.

import coordax as cx
import jax
import numpy as np

xc, yc = cx.SizedAxis('x', 2), cx.SizedAxis('y', 3)
f_xy = cx.field(np.arange(xc.size * yc.size).reshape((xc.size, yc.size)), xc, yc)

f_yx = f_xy.order_as(yc, xc)  # works - we use same coordinates.
also_f_yx = f_xy.order_as('y', 'x')  # also works

x_grid = cx.LabeledAxis('x', np.linspace(0, np.pi, 2))
y_grid = cx.LabeledAxis('y', np.linspace(0, 1, 3))
try:
  f_xy.order_as(y_grid, x_grid)  # raises, coordinates are different
except Exception as e:
  print(f'{type(e).__name__}: {e}')
ValueError: coordinate not equal to the corresponding coordinate on this field:
coordax.LabeledAxis('y', ticks=array([0. , 0.5, 1. ]))
vs
coordax.SizedAxis('y', size=3)
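The relaxed rule for string arguments can be modeled in a few lines of plain Python. This is a hypothetical sketch, not Coordax's implementation: `resolve_axis`, `SizedAxisSketch`, and `LabeledAxisSketch` are illustrative names. A string only has to match a dimension name, while a coordinate argument must compare equal to the coordinate already attached to that dimension.

```python
from dataclasses import dataclass


@dataclass(frozen=True)
class SizedAxisSketch:
  """Hypothetical stand-in for SizedAxis: a dimension name and a size."""
  name: str
  size: int


@dataclass(frozen=True)
class LabeledAxisSketch:
  """Hypothetical stand-in for LabeledAxis, with ticks stored as a tuple."""
  name: str
  ticks: tuple


def resolve_axis(field_axes: dict, arg):
  """Return the coordinate on a field that matches `arg` (hypothetical helper).

  A str only has to match a dimension name; a coordinate object must compare
  equal to the coordinate already attached to that dimension.
  """
  if isinstance(arg, str):
    if arg not in field_axes:
      raise ValueError(f'no dimension named {arg!r}')
    return field_axes[arg]
  existing = field_axes.get(arg.name)
  if existing != arg:
    raise ValueError(f'coordinate not equal: {arg} vs {existing}')
  return existing


axes = {'x': SizedAxisSketch('x', 2), 'y': SizedAxisSketch('y', 3)}
by_name = resolve_axis(axes, 'y')  # a name match is enough
by_coord = resolve_axis(axes, SizedAxisSketch('y', 3))  # equal coordinate
try:
  resolve_axis(axes, LabeledAxisSketch('y', (0.0, 0.5, 1.0)))
  mismatch_raised = False
except ValueError:  # same name and size, but a different coordinate
  mismatch_raised = True
```

Dataclass equality returns False across different classes, so the labeled stand-in is rejected even though its name and length match, mirroring the error shown above.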

The LabeledAxis equality includes a check on coordinate ticks, so if coordinates differ in the exact placement of the tick values, an error is raised. This is particularly relevant for numerical methods where fields could have offsets within computational cells.

x_bounds, dx = np.linspace(0, 2 * np.pi, 10, endpoint=False, retstep=True)
x_centers = np.linspace(dx / 2, 2 * np.pi - dx / 2, 10)
x_grid_bounds = cx.LabeledAxis('x', x_bounds)
x_grid_centers = cx.LabeledAxis('x', x_centers)

f = cx.field(np.ones(10), x_grid_centers)

assert x_grid_bounds.dims == x_grid_centers.dims  # dims are the same.
assert x_grid_bounds.shape == x_grid_centers.shape  # same, compatible shape.
assert x_grid_bounds != x_grid_centers  # not equal, since tick values differ.

try:
  f.untag(x_grid_bounds)  # raises, coordinates are different
except Exception as e:
  print(f'{type(e).__name__}: {e}')
ValueError: coordinate not equal to the corresponding coordinate on this field:
coordax.LabeledAxis('x', ticks=array([0.        , 0.62831853, 1.25663706, 1.88495559, 2.51327412,
       3.14159265, 3.76991118, 4.39822972, 5.02654825, 5.65486678]))
vs
coordax.LabeledAxis('x', ticks=array([0.31415927, 0.9424778 , 1.57079633, 2.19911486, 2.82743339,
       3.45575192, 4.08407045, 4.71238898, 5.34070751, 5.96902604]))
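The two tick arrays in this example differ by a half-cell shift, which is exactly the kind of offset the equality check is meant to catch. A quick NumPy-only check, independent of Coordax, makes the relationship explicit:

```python
import numpy as np

n = 10
# Left cell edges on [0, 2*pi) and the uniform spacing between them.
x_bounds, dx = np.linspace(0, 2 * np.pi, n, endpoint=False, retstep=True)
# Cell centers: the same grid shifted right by half a cell.
x_centers = np.linspace(dx / 2, 2 * np.pi - dx / 2, n)

same_shape = x_bounds.shape == x_centers.shape  # shapes agree
half_cell_shift = np.allclose(x_centers - x_bounds, dx / 2)  # offset is dx/2
identical_ticks = np.array_equal(x_bounds, x_centers)  # tick values differ
print(same_shape, half_cell_shift, identical_ticks)
```

Dims and shape alone cannot distinguish these grids; only the tick values can.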

Standard coordinate types

Coordax provides several built-in coordinate types:

  • SizedAxis: Minimal coordinate, only checks size.

  • LabeledAxis: Stores tick values (e.g. grid points) and checks them for equality.

  • DummyAxis: Placeholder for dimensions without associated coordinate values.

  • Scalar: Zero-dimensional sentinel coordinate for scalars.

SizedAxis

SizedAxis is the simplest coordinate type. It only checks that the dimension size matches. It does not carry any additional data.

import coordax as cx
import jax.numpy as jnp
import numpy as np

x = cx.SizedAxis('x', 5)
print(x)
f = cx.field(jnp.ones(5), x)
print(f.dims)
coordax.SizedAxis('x', size=5)
('x',)

LabeledAxis

LabeledAxis associates a dimension with a 1D array of tick values (e.g. grid points or labels). It checks for equality of these values when aligning fields.

ticks = np.linspace(0, 1, 5)
y = cx.LabeledAxis('y', ticks)
print(y)
print(y.fields['y'])  # Access the coordinate field
coordax.LabeledAxis('y', ticks=array([0.  , 0.25, 0.5 , 0.75, 1.  ]))
<Field dims=('y',) shape=(5,) axes={'y': LabeledAxis} >

Scalar

Scalar is a special coordinate for 0-dimensional data. It has no dimensions and an empty shape (), and is mostly useful to ensure that every Field can have an associated Coordinate object.

scalar = cx.Scalar()
print(scalar)
print(scalar.shape)
Scalar()
()

DummyAxis

DummyAxis is the placeholder coordinate created when a coordinate object is required but the Field has none for that dimension (e.g., because the dimension was only named with a string).

You can construct it explicitly, but generally don’t need to. Note that DummyAxis coordinates are automatically dropped when used on a Field.

dummy = cx.DummyAxis('d', 5)
print(dummy)
f_dummy = cx.field(np.zeros(5), dummy)
print(f_dummy.dims)
coordax.DummyAxis('d', size=5)
('d',)

CartesianProduct

In general, a Coordinate can span multiple dimensions. One special case is CartesianProduct, which simply bundles multiple Coordinate primitives together. The compose() helper is the most convenient way to create a CartesianProduct.

x_axis, y_axis = cx.SizedAxis('x', 6), cx.SizedAxis('y', 7)
xy_coord = cx.coords.compose(x_axis, y_axis)
print(xy_coord)
print(f'{xy_coord.dims=}')
print(f'{xy_coord.shape=}')
CartesianProduct(coordinates=(coordax.SizedAxis('x', size=6), coordax.SizedAxis('y', size=7)))
xy_coord.dims=('x', 'y')
xy_coord.shape=(6, 7)
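Conceptually, a Cartesian product's dims and shape are the concatenation of its components' dims and shapes, in order. The sketch below models that rule with hypothetical plain-Python classes (`AxisSketch` and `ProductSketch` are illustrative, not Coordax types); the duplicate-name check is this sketch's own assumption.

```python
from dataclasses import dataclass


@dataclass(frozen=True)
class AxisSketch:
  """Hypothetical 1D coordinate: one named dimension of a given size."""
  name: str
  size: int

  @property
  def dims(self) -> tuple[str, ...]:
    return (self.name,)

  @property
  def shape(self) -> tuple[int, ...]:
    return (self.size,)


@dataclass(frozen=True)
class ProductSketch:
  """Hypothetical Cartesian product: bundles coordinates in order."""
  coordinates: tuple

  def __post_init__(self):
    # Sketch-level validation (assumption): dimension names must be unique.
    if len(set(self.dims)) != len(self.dims):
      raise ValueError(f'duplicate dims: {self.dims}')

  @property
  def dims(self) -> tuple[str, ...]:
    # Concatenate component dims in order.
    return sum((c.dims for c in self.coordinates), ())

  @property
  def shape(self) -> tuple[int, ...]:
    # Concatenate component shapes in the same order.
    return sum((c.shape for c in self.coordinates), ())


xy = ProductSketch((AxisSketch('x', 6), AxisSketch('y', 7)))
print(xy.dims, xy.shape)
```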

Multi-dimensional coordinates can be used in the same way to wrap, tag, untag, etc.

x_axis, y_axis = cx.SizedAxis('x', 6), cx.SizedAxis('y', 7)
xy_coord = cx.coords.compose(x_axis, y_axis)
f = cx.field(np.ones((3, 6, 7)), 'batch', xy_coord)
print(f'{f.dims=}')
print(f'{f.untag(xy_coord).dims=}')
f.dims=('batch', 'x', 'y')
f.untag(xy_coord).dims=('batch', None, None)

Coordinates for field structure

The dimensions, shape, and axes of a Field always correspond to a single coordinate object, exposed via the computed coordinate attribute. This makes coordinates a convenient way to represent the structure of fields without any data values.

For fields with a single coordinate object, this is that Coordinate object itself:

field = cx.field(np.zeros((6,)), x_axis)
assert field.coordinate == field.axes['x'] == x_axis

For fields with multiple coordinates, this will be a dynamically constructed CartesianProduct:

field = cx.field(np.zeros((6, 7)), x_axis, y_axis)
field.coordinate
CartesianProduct(coordinates=(coordax.SizedAxis('x', size=6), coordax.SizedAxis('y', size=7)))
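The rule behind the computed coordinate attribute (return the lone coordinate, otherwise bundle several into a product) can be sketched as a small helper. `coordinate_for` and the stand-in classes are hypothetical, and falling back to a scalar sentinel for zero coordinates is an assumption based on the Scalar description above; the sketch does not attempt to mirror canonicalize() exactly.

```python
from dataclasses import dataclass


@dataclass(frozen=True)
class ScalarSketch:
  """Hypothetical 0-d coordinate: no dims, empty shape."""
  dims: tuple = ()
  shape: tuple = ()


@dataclass(frozen=True)
class AxisSketch:
  """Hypothetical 1D coordinate."""
  name: str
  size: int


@dataclass(frozen=True)
class ProductSketch:
  """Hypothetical Cartesian product of coordinates."""
  coordinates: tuple


def coordinate_for(coords):
  """Return one coordinate object describing a field's full structure."""
  if not coords:
    return ScalarSketch()  # 0-d field: fall back to a scalar sentinel
  if len(coords) == 1:
    return coords[0]  # single coordinate: return it unchanged
  return ProductSketch(tuple(coords))  # several: bundle them in order


x, y = AxisSketch('x', 6), AxisSketch('y', 7)
print(coordinate_for([x]))
print(coordinate_for([x, y]))
```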

To facilitate using Coordinate objects to represent data shapes, Coordax comes with a handful of utilities for manipulating coordinates on their own, including canonicalize(), insert_axes(), and replace_axes().

Custom coordinates

Users can also define custom coordinates by subclassing Coordinate. This is useful for attaching metadata and custom methods to axes, which then propagate automatically through Coordax operations.

Coordinate objects represent one or more Array axes, specifying their names, shape and potentially providing additional values and methods associated with the coordinate. Subclasses must implement at least dims and shape:

  • dims: tuple of dimension names

  • shape: shape of the coordinate object

Other methods in the Coordinate interface are optional:

  • fields: holds named supporting values

  • from_xarray: construct this coordinate from an Xarray coordinate

Here’s a simple example of a custom coordinate subclass:

import dataclasses

import coordax as cx
import jax
import numpy as np

@jax.tree_util.register_static
@dataclasses.dataclass(frozen=True)
class UniformAxis(cx.Coordinate):
  """Cell centered coordinate with uniform discretization of (0, `length`)."""

  name: str
  size: int
  length: float

  @property
  def dims(self):
    return (self.name,)

  @property
  def shape(self) -> tuple[int, ...]:
    return (self.size,)

  @property
  def fields(self):
    delta = self.length / self.size
    cell_centers = np.linspace(delta / 2, self.length - delta / 2, self.size)
    return {self.name: cx.field(cell_centers, self)}

z_centers = UniformAxis('z', 10, np.pi * 2)
print(z_centers)
z_centers.fields['z']  # access to supporting value from `@fields`.
UniformAxis(name='z', size=10, length=6.283185307179586)
<Field dims=('z',) shape=(10,) axes={'z': UniformAxis} >
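UniformAxis needs no custom __eq__ or __hash__ because all of its fields are hashable scalars, so the methods generated by `@dataclasses.dataclass(frozen=True)` already behave correctly. The plain-Python sketch below (class names are hypothetical) shows what changes once a NumPy array is stored as a field:

```python
import dataclasses

import numpy as np


@dataclasses.dataclass(frozen=True)
class ScalarFieldsAxis:
  """Only hashable fields: dataclass-generated __eq__/__hash__ just work."""
  name: str
  size: int
  length: float


@dataclasses.dataclass(frozen=True)
class ArrayFieldAxis:
  """Stores a NumPy array: the generated __hash__ hashes the array and fails."""
  ticks: np.ndarray


a = ScalarFieldsAxis('z', 10, 2 * np.pi)
b = ScalarFieldsAxis('z', 10, 2 * np.pi)
scalar_ok = (a == b) and hash(a) == hash(b)

try:
  hash(ArrayFieldAxis(np.arange(3)))
  array_hash_failed = False
except TypeError:  # NumPy arrays are unhashable
  array_hash_failed = True
print(scalar_ok, array_hash_failed)
```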

Coordinates must also be registered as “static” JAX pytrees, i.e., pytrees without any array leaves. This is what allows coordinates to work inside arbitrary JAX transformations.

This means that coordinates can’t use JAX arrays internally for storing their state, but they can use NumPy arrays. Because static JAX pytrees must be hashable, __eq__ and __hash__ need some care to meet Python’s expectations: __eq__ must return a boolean, and objects that compare equal must have the same hash. The easiest way to satisfy both is to use cx.coords.ArrayKey inside __eq__ and __hash__, as in this minimal example:

@jax.tree_util.register_static
@dataclasses.dataclass(frozen=True)
class XAxis(cx.Coordinate):
  ticks: np.ndarray

  def __post_init__(self):
    assert self.ticks.ndim == 1

  @property
  def dims(self) -> tuple[str, ...]:
    return ('x',)

  @property
  def shape(self) -> tuple[int, ...]:
    return self.ticks.shape

  def _array_key(self):
    return cx.coords.ArrayKey(self.ticks)

  def __eq__(self, other):
    return (
        isinstance(other, XAxis) and self._array_key() == other._array_key()
    )

  def __hash__(self) -> int:
    return hash(self._array_key())

cx.field(np.zeros(10), XAxis(np.arange(10)))
<Field dims=('x',) shape=(10,) axes={'x': XAxis} >
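The core idea of a content-based array key can be sketched in plain Python. This is a hypothetical illustration, not the implementation of cx.coords.ArrayKey: `ArrayKeySketch` compares dtype, shape, and raw bytes, so `__eq__` always returns a bool and equal arrays always hash identically.

```python
import numpy as np


class ArrayKeySketch:
  """Hypothetical hashable wrapper keyed on an array's dtype, shape and bytes."""

  def __init__(self, value):
    self.value = np.asarray(value)

  def _key(self):
    # Everything needed to decide equality, in hashable form.
    return (self.value.dtype.str, self.value.shape, self.value.tobytes())

  def __eq__(self, other):
    # Always returns a plain bool, unlike ndarray.__eq__.
    return isinstance(other, ArrayKeySketch) and self._key() == other._key()

  def __hash__(self):
    # Equal keys produce equal hashes, as Python requires.
    return hash(self._key())


k1 = ArrayKeySketch(np.arange(10))
k2 = ArrayKeySketch(np.arange(10))
k3 = ArrayKeySketch(np.arange(10) + 0.5)
print(k1 == k2, hash(k1) == hash(k2), k1 == k3)
```

A byte-level comparison like this treats 0.0 and -0.0 as different and a NaN as equal to itself; a production key may make different choices for such edge cases.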

There is also the optional from_xarray method for deserializing a coordinate from Xarray, as described in Serialization of custom coordinates.