Orbax Checkpointing in Keras
Author: Samaneh Saadat
Date created: 2025/08/20
Last modified: 2025/08/20
Description: A guide on how to save Orbax checkpoints during model training with the JAX backend.
Introduction
Orbax is the default checkpointing library recommended for JAX ecosystem users. It is a high-level checkpointing library that provides both checkpoint management and composable, extensible serialization. This guide explains how to save Orbax checkpoints when training a Keras model with the JAX backend.
Use Orbax checkpointing for multi-host training with the Keras distribution API, because the default Keras checkpointing does not currently support multi-host setups.
Setup
Let's start by installing the Orbax checkpointing library:
Next, we set the Keras backend to JAX, since this guide targets the JAX backend. The backend has to be chosen before Keras is imported. Then we import Keras and the other libraries we need, including the Orbax checkpointing library.
Orbax Callback
We need to create two main utilities to manage Orbax checkpointing in Keras:
- KerasOrbaxCheckpointManager: A wrapper around orbax.checkpoint.CheckpointManager for Keras models. KerasOrbaxCheckpointManager uses the Model.get_state_tree and Model.set_state_tree APIs to save and restore the model variables.
- OrbaxCheckpointCallback: A Keras callback that uses KerasOrbaxCheckpointManager to automatically save and restore model states during training.
Orbax checkpointing in Keras is as simple as copying these utilities to your own codebase and passing OrbaxCheckpointCallback to the fit method.
An Orbax checkpointing example
Let's look at how we can use OrbaxCheckpointCallback to save Orbax checkpoints during training. To get started, let's define a simple model and a toy training dataset.
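A simple model and toy dataset of this kind might look like the sketch below. The helper name `get_model`, the single `Dense` layer, the SGD learning rate and the data shapes are all arbitrary assumptions. Any compiled Keras model works with checkpointing.

```python
import os

os.environ["KERAS_BACKEND"] = "jax"

import keras
import numpy as np


def get_model():
    # A single Dense layer is enough to exercise checkpointing.
    inputs = keras.Input(shape=(32,))
    outputs = keras.layers.Dense(1)(inputs)
    model = keras.Model(inputs, outputs)
    model.compile(
        optimizer=keras.optimizers.SGD(learning_rate=0.1),
        loss="mean_squared_error",
    )
    return model


model = get_model()

# Random toy data: 128 samples with 32 features and one regression target.
x_train = np.random.random((128, 32)).astype("float32")
y_train = np.random.random((128, 1)).astype("float32")
```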
Then, we create an Orbax checkpointing callback and pass it to the callbacks argument of the fit method.
Now if you look at the Orbax checkpoint directory, you can see all the files saved as part of Orbax checkpointing.