Coordax: Coordinate Axes for JAX

Coordax is a Python library for working with labeled axes in JAX. Our approach is reminiscent of Xarray, but tailored to the needs of modern physics- and AI-based simulation codes written in JAX, such as NeuralGCM.

Compared to other libraries for labeled arrays, Coordax provides a handful of key features:

  1. First-class integration with JAX, including support for arbitrary JAX transformations

  2. Easy wrapping, with cmap, of code not written for labeled arrays (inspired by Penzai)

  3. Optional Coordinate objects, for advanced use cases

  4. Lossless conversion to and from Xarray, for serialization and data analysis
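This overview does not show Coordax's own API, so the following sketch does not use it. It is a toy illustration in plain JAX of the ideas behind items 1 and 2, assuming only `jax` is installed. The first idea is a labeled array registered as a pytree, so JAX transformations pass through it. The second is a cmap-style wrapper that applies label-unaware code by vmapping over every named axis. The names `Labeled`, `toy_cmap` and `normalize` are hypothetical and invented for this example. They are not Coordax classes or functions.

```python
from dataclasses import dataclass
from typing import Optional, Tuple

import jax
import jax.numpy as jnp


# Hypothetical toy labeled array, NOT Coordax's own class.
# A dimension name of None marks a positional (unlabeled) axis.
@jax.tree_util.register_pytree_node_class
@dataclass(frozen=True)
class Labeled:
    data: jax.Array
    dims: Tuple[Optional[str], ...]

    def tree_flatten(self):
        # The array is a leaf that JAX traces; the names are static metadata.
        return (self.data,), self.dims

    @classmethod
    def tree_unflatten(cls, dims, children):
        return cls(children[0], dims)


def toy_cmap(fn):
    """Hypothetical cmap-style wrapper: vmaps `fn` over every named axis.

    `fn` is ordinary JAX code that only sees the positional axes. Named axes
    come first in the result, followed by whatever axes `fn` returns.
    """
    def wrapped(x: Labeled) -> Labeled:
        named = [i for i, d in enumerate(x.dims) if d is not None]
        positional = [i for i, d in enumerate(x.dims) if d is None]
        # Move named axes to the front so nested vmaps peel them off in order.
        data = jnp.transpose(x.data, named + positional)
        f = fn
        for _ in named:
            f = jax.vmap(f)
        out = f(data)
        names = tuple(x.dims[i] for i in named)
        return Labeled(out, names + (None,) * (out.ndim - len(named)))
    return wrapped


def normalize(v):
    # Label-unaware code: scales a 1-D vector to unit length.
    return v / jnp.linalg.norm(v)


x = Labeled(jnp.arange(24.0).reshape(2, 3, 4), ("time", "level", None))

# Because Labeled is a pytree, jit and grad pass straight through it.
y = jax.jit(toy_cmap(normalize))(x)
g = jax.grad(lambda a: toy_cmap(normalize)(a).data.sum())(x)
print(y.dims)  # ('time', 'level', None)
```

In this toy, the axis names are stored as static pytree metadata rather than as traced values. That choice lets `jax.jit` and `jax.grad` work without special handling, and the gradient `g` comes back carrying the same labels as `x`. Coordax's actual implementation may make different choices.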

Read on for more details!


Questions?

The best place to ask for help or report bugs is in the issue tracker on GitHub.