Orbax: Distributed Checkpointing with JAX

Year: 2026
License: CC-BY-4.0 (full text hosted)

Abstract and full text: arxiv.org/abs/2605.23066 (CC-BY-4.0)

Abstract

In a landscape of high-performance distributed ML systems, JAX has emerged as a framework of choice. However, JAX's modular design philosophy leaves it without a standardized checkpointing solution. In this paper, we introduce Orbax, a modular, JAX-native checkpointing library that abstracts the complexities of distributed accelerator systems while also providing flexibility for user-friendly checkpoint manipulations throughout the ML model lifecycle. We demonstrate performance exceeding comparable PyTorch competitors by up to 3.5$\times$ for saving and 2$\times$ for loading. The library is available at https://github.com/google/orbax.