Writing your own callbacks
Authors: Rick Chao, Francois Chollet
Date created: 2019/03/20
Last modified: 2023/06/25
Description: Complete guide to writing new Keras callbacks.
Introduction
A callback is a powerful tool to customize the behavior of a Keras model during training, evaluation, or inference. Examples include keras.callbacks.TensorBoard to visualize training progress and results with TensorBoard, or keras.callbacks.ModelCheckpoint to periodically save your model during training.
In this guide, you will learn what a Keras callback is, what it can do, and how you can build your own. We provide a few demos of simple callback applications to get you started.
Setup
Keras callbacks overview
All callbacks subclass the keras.callbacks.Callback class and override a set of methods called at various stages of training, testing, and predicting. Callbacks are useful for getting a view of the model's internal states and statistics during training.
You can pass a list of callbacks (as the keyword argument callbacks) to the following model methods:
keras.Model.fit()
keras.Model.evaluate()
keras.Model.predict()
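A minimal sketch of passing callbacks, assuming Keras 3 imported as `keras`. `GlobalHookCounter` is a hypothetical name, and random NumPy arrays stand in for real data. The same instance is handed to all three methods and simply counts which global "begin" hook each one triggers:

```python
import collections

import numpy as np
import keras


class GlobalHookCounter(keras.callbacks.Callback):
    """Hypothetical callback that counts calls to the global begin hooks."""

    def __init__(self):
        super().__init__()
        self.counts = collections.Counter()

    def on_train_begin(self, logs=None):
        self.counts["train"] += 1

    def on_test_begin(self, logs=None):
        self.counts["test"] += 1

    def on_predict_begin(self, logs=None):
        self.counts["predict"] += 1


# Tiny random regression data, used only to exercise the hooks.
x = np.random.rand(32, 4).astype("float32")
y = np.random.rand(32, 1).astype("float32")

model = keras.Sequential([keras.Input(shape=(4,)), keras.layers.Dense(1)])
model.compile(optimizer="sgd", loss="mse")

counter = GlobalHookCounter()
# The same `callbacks` keyword argument is accepted by all three methods.
model.fit(x, y, epochs=2, batch_size=8, callbacks=[counter], verbose=0)
model.evaluate(x, y, batch_size=8, callbacks=[counter], verbose=0)
model.predict(x, batch_size=8, callbacks=[counter], verbose=0)
print(dict(counter.counts))
```

Each method fires only its own global hooks, and `fit` calls `on_train_begin` once no matter how many epochs it runs, so every counter ends at one.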
An overview of callback methods
Global methods
on_(train|test|predict)_begin(self, logs=None)
Called at the beginning of fit/evaluate/predict.
on_(train|test|predict)_end(self, logs=None)
Called at the end of fit/evaluate/predict.
Batch-level methods for training/testing/predicting
on_(train|test|predict)_batch_begin(self, batch, logs=None)
Called right before processing a batch during training/testing/predicting.
on_(train|test|predict)_batch_end(self, batch, logs=None)
Called at the end of training/testing/predicting a batch. Within this method, logs is a dict containing the metrics results.
Epoch-level methods (training only)
on_epoch_begin(self, epoch, logs=None)
Called at the beginning of an epoch during training.
on_epoch_end(self, epoch, logs=None)
Called at the end of an epoch during training.
A basic example
Let's take a look at a concrete example. To get started, let's import tensorflow and define a simple Sequential Keras model:
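The model definition itself is not included in this copy. As a minimal sketch, assuming a single-unit `Dense` regressor compiled with RMSprop, mean squared error loss and a mean absolute error metric (all illustrative choices; `get_model` is a hypothetical helper name):

```python
import keras


def get_model():
    """Build and compile a tiny regression model (illustrative settings)."""
    model = keras.Sequential([keras.layers.Dense(1)])
    model.compile(
        optimizer=keras.optimizers.RMSprop(learning_rate=0.1),
        loss="mean_squared_error",
        # This metric name becomes a key in the `logs` dicts seen by callbacks.
        metrics=["mean_absolute_error"],
    )
    return model
```

The layer has no input shape, so Keras builds its weights on the first call to `fit`, `evaluate` or `predict`.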
Then, load the MNIST data for training and testing from the Keras datasets API:
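One way to do this, as a sketch: `preprocess` and `load_mnist` are hypothetical helpers, and flattening each 28x28 image to a 784-vector, scaling pixels to [0, 1] and keeping only the first 1000 samples are assumptions chosen to keep the demos fast.

```python
import numpy as np
import keras


def preprocess(x, y, limit=1000):
    """Flatten images to 784-vectors, scale to [0, 1], keep `limit` samples."""
    x = np.asarray(x).reshape(-1, 784).astype("float32") / 255.0
    return x[:limit], y[:limit]


def load_mnist(limit=1000):
    """Download MNIST via the Keras datasets API and preprocess both splits."""
    (x_train, y_train), (x_test, y_test) = keras.datasets.mnist.load_data()
    return preprocess(x_train, y_train, limit), preprocess(x_test, y_test, limit)
```

Calling `(x_train, y_train), (x_test, y_test) = load_mnist()` downloads the dataset on first use and caches it locally afterwards.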
Now, define a simple custom callback that logs:
When fit/evaluate/predict starts & ends
When each epoch starts & ends
When each training batch starts & ends
When each evaluation (test) batch starts & ends
When each inference (prediction) batch starts & ends
Let's try it out:
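A sketch of such a callback, assuming Keras 3. Random NumPy data and a one-unit model replace MNIST and the guide's model here so the run is fast. Every hook prints the keys of the `logs` dict it receives and appends `(hook name, keys)` to `events`, so the call order can be inspected afterwards:

```python
import numpy as np
import keras


class CustomCallback(keras.callbacks.Callback):
    """Print and record every hook call, in the order Keras makes them."""

    def __init__(self):
        super().__init__()
        self.events = []  # (hook name, sorted log keys), in call order

    def _log(self, hook, logs, index=None):
        keys = sorted(logs) if logs else []
        self.events.append((hook, keys))
        where = "" if index is None else f" {index}"
        print(f"{hook}{where}; got log keys: {keys}")

    def on_train_begin(self, logs=None):
        self._log("on_train_begin", logs)

    def on_train_end(self, logs=None):
        self._log("on_train_end", logs)

    def on_epoch_begin(self, epoch, logs=None):
        self._log("on_epoch_begin", logs, epoch)

    def on_epoch_end(self, epoch, logs=None):
        self._log("on_epoch_end", logs, epoch)

    def on_test_begin(self, logs=None):
        self._log("on_test_begin", logs)

    def on_test_end(self, logs=None):
        self._log("on_test_end", logs)

    def on_predict_begin(self, logs=None):
        self._log("on_predict_begin", logs)

    def on_predict_end(self, logs=None):
        self._log("on_predict_end", logs)

    def on_train_batch_begin(self, batch, logs=None):
        self._log("on_train_batch_begin", logs, batch)

    def on_train_batch_end(self, batch, logs=None):
        self._log("on_train_batch_end", logs, batch)

    def on_test_batch_begin(self, batch, logs=None):
        self._log("on_test_batch_begin", logs, batch)

    def on_test_batch_end(self, batch, logs=None):
        self._log("on_test_batch_end", logs, batch)

    def on_predict_batch_begin(self, batch, logs=None):
        self._log("on_predict_batch_begin", logs, batch)

    def on_predict_batch_end(self, batch, logs=None):
        self._log("on_predict_batch_end", logs, batch)


# 4 samples with batch_size=2 gives exactly two batches per pass.
x = np.random.rand(4, 3).astype("float32")
y = np.random.rand(4, 1).astype("float32")
model = keras.Sequential([keras.Input(shape=(3,)), keras.layers.Dense(1)])
model.compile(optimizer="sgd", loss="mse", metrics=["mean_absolute_error"])

fit_trace = CustomCallback()
model.fit(x, y, batch_size=2, epochs=1, verbose=0, callbacks=[fit_trace])
eval_trace = CustomCallback()
model.evaluate(x, y, batch_size=2, verbose=0, callbacks=[eval_trace])
predict_trace = CustomCallback()
model.predict(x, batch_size=2, verbose=0, callbacks=[predict_trace])
```

For the `fit` call, the recorded order should be `on_train_begin`, `on_epoch_begin`, two begin/end pairs of training-batch hooks, `on_epoch_end`, then `on_train_end`; `evaluate` and `predict` produce the analogous test and predict sequences.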
Usage of logs dict
The logs dict contains the loss value and all the metrics at the end of a batch or epoch. In this example, that means the loss and the mean absolute error.
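A sketch of reading `logs`, assuming Keras 3 and a model compiled with `metrics=["mean_absolute_error"]` (which is what makes that key appear in `logs`). `LossAndErrorPrintingCallback` is an illustrative name and the data is random:

```python
import numpy as np
import keras


class LossAndErrorPrintingCallback(keras.callbacks.Callback):
    """Print and keep the loss and mean absolute error reported in `logs`."""

    def __init__(self):
        super().__init__()
        self.batch_losses = []
        self.epoch_maes = []

    def on_train_batch_end(self, batch, logs=None):
        # At batch end, logs["loss"] is averaged over the epoch so far.
        loss = float(logs["loss"])
        self.batch_losses.append(loss)
        print(f"Up to batch {batch}, the average loss is {loss:7.2f}.")

    def on_epoch_end(self, epoch, logs=None):
        loss = float(logs["loss"])
        mae = float(logs["mean_absolute_error"])
        self.epoch_maes.append(mae)
        print(f"Epoch {epoch}: average loss {loss:7.2f}, mean absolute error {mae:7.2f}.")


x = np.random.rand(64, 8).astype("float32")
y = np.random.rand(64, 1).astype("float32")
model = keras.Sequential([keras.Input(shape=(8,)), keras.layers.Dense(1)])
model.compile(optimizer="sgd", loss="mse", metrics=["mean_absolute_error"])

printer = LossAndErrorPrintingCallback()
model.fit(x, y, batch_size=16, epochs=3, verbose=0, callbacks=[printer])
```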
Usage of self.model attribute
In addition to receiving log information when one of their methods is called, callbacks have access to the model associated with the current round of training/evaluation/inference: self.model.
Here are a few of the things you can do with self.model in a callback:
Set self.model.stop_training = True to immediately interrupt training.
Mutate hyperparameters of the optimizer (available as self.model.optimizer), such as self.model.optimizer.learning_rate.
Save the model at periodic intervals.
Record the output of model.predict() on a few test samples at the end of each epoch, to use as a sanity check during training.
Extract visualizations of intermediate features at the end of each epoch, to monitor what the model is learning over time.
etc.
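A minimal sketch combining two of these uses, assuming Keras 3; `PredictionSnapshot`, `samples` and `target_loss` are hypothetical names. At the end of every epoch it stores `self.model.predict()` on a few fixed inputs, and it sets `self.model.stop_training = True` once the epoch loss drops below an optional threshold:

```python
import numpy as np
import keras


class PredictionSnapshot(keras.callbacks.Callback):
    """Hypothetical callback: record predictions on fixed samples each epoch,
    and optionally stop once the training loss falls below `target_loss`."""

    def __init__(self, samples, target_loss=None):
        super().__init__()
        self.samples = samples
        self.target_loss = target_loss
        self.snapshots = []  # one prediction array per finished epoch

    def on_epoch_end(self, epoch, logs=None):
        # self.model is the model currently being trained.
        self.snapshots.append(self.model.predict(self.samples, verbose=0))
        if self.target_loss is not None and logs["loss"] < self.target_loss:
            self.model.stop_training = True


x = np.random.rand(32, 4).astype("float32")
y = np.random.rand(32, 1).astype("float32")
model = keras.Sequential([keras.Input(shape=(4,)), keras.layers.Dense(1)])
model.compile(optimizer="sgd", loss="mse")

snapshot = PredictionSnapshot(samples=x[:2])
model.fit(x, y, batch_size=8, epochs=3, verbose=0, callbacks=[snapshot])
```

Passing `target_loss=float("inf")` makes the very first epoch trigger the stop, which is a quick way to check that the flag is honored.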
Let's see this in action in a couple of examples.
Examples of Keras callback applications
Early stopping at minimum loss
This first example shows the creation of a Callback that stops training when the loss reaches its minimum, by setting the attribute self.model.stop_training (boolean). Optionally, you can provide a patience argument to specify how many epochs to wait after reaching a local minimum before stopping.
keras.callbacks.EarlyStopping provides a more complete and general implementation.
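A sketch of such a callback, assuming Keras 3 and a `loss` entry in the epoch-level `logs`. It tracks the lowest loss seen, counts non-improving epochs, and once `patience` is exhausted sets `stop_training` and restores the best weights. The `best_weights is not None` guard is an addition that avoids restoring when no epoch ever improved (for example, a NaN loss from the first epoch):

```python
import numpy as np
import keras


class EarlyStoppingAtMinLoss(keras.callbacks.Callback):
    """Stop training when the loss stops decreasing.

    Args:
        patience: Number of epochs to wait after the minimum has been hit.
            After this many epochs without improvement, training stops.
    """

    def __init__(self, patience=0):
        super().__init__()
        self.patience = patience
        self.best_weights = None  # weights at the best epoch seen so far

    def on_train_begin(self, logs=None):
        self.wait = 0  # epochs since the last improvement
        self.stopped_epoch = 0  # epoch at which training was stopped
        self.best = np.inf  # lowest loss seen so far

    def on_epoch_end(self, epoch, logs=None):
        current = logs.get("loss")
        if np.less(current, self.best):
            self.best = current
            self.wait = 0
            self.best_weights = self.model.get_weights()
        else:
            self.wait += 1
            if self.wait >= self.patience:
                self.stopped_epoch = epoch
                self.model.stop_training = True
                if self.best_weights is not None:
                    print("Restoring model weights from the end of the best epoch.")
                    self.model.set_weights(self.best_weights)

    def on_train_end(self, logs=None):
        if self.stopped_epoch > 0:
            print(f"Epoch {self.stopped_epoch + 1}: early stopping")
```

Typical use would be `model.fit(x_train, y_train, epochs=30, callbacks=[EarlyStoppingAtMinLoss(patience=2)])`.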
Learning rate scheduling
In this example, we show how a custom Callback can be used to dynamically change the learning rate of the optimizer during the course of training.
See keras.callbacks.LearningRateScheduler for a more general implementation.
Built-in Keras callbacks
Be sure to check out the existing Keras callbacks by reading the API docs. Applications include logging to CSV, saving the model, visualizing metrics in TensorBoard, and a lot more!
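For example, a sketch combining two built-ins, `keras.callbacks.CSVLogger` and `keras.callbacks.EarlyStopping`, assuming Keras 3; the file location, patience and random data are illustrative:

```python
import csv
import os
import tempfile

import numpy as np
import keras

x = np.random.rand(64, 4).astype("float32")
y = np.random.rand(64, 1).astype("float32")
model = keras.Sequential([keras.Input(shape=(4,)), keras.layers.Dense(1)])
model.compile(optimizer="sgd", loss="mse", metrics=["mean_absolute_error"])

csv_path = os.path.join(tempfile.mkdtemp(), "training_log.csv")
callbacks = [
    # Write one row of epoch-level logs per epoch to a CSV file.
    keras.callbacks.CSVLogger(csv_path),
    # Stop once the training loss has not improved for 3 consecutive epochs.
    keras.callbacks.EarlyStopping(monitor="loss", patience=3),
]
history = model.fit(x, y, batch_size=16, epochs=5, verbose=0, callbacks=callbacks)

with open(csv_path, newline="") as f:
    rows = list(csv.DictReader(f))
print(list(rows[0]))
```

The CSV holds one row per completed epoch, with an `epoch` column followed by the epoch-level `logs` keys.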