Path: blob/master/guides/custom_train_step_in_jax.py
"""1Title: Customizing what happens in `fit()` with JAX2Author: [fchollet](https://twitter.com/fchollet)3Date created: 2023/06/274Last modified: 2023/06/275Description: Overriding the training step of the Model class with JAX.6Accelerator: GPU7"""89"""10## Introduction1112When you're doing supervised learning, you can use `fit()` and everything works13smoothly.1415When you need to take control of every little detail, you can write your own training16loop entirely from scratch.1718But what if you need a custom training algorithm, but you still want to benefit from19the convenient features of `fit()`, such as callbacks, built-in distribution support,20or step fusing?2122A core principle of Keras is **progressive disclosure of complexity**. You should23always be able to get into lower-level workflows in a gradual way. You shouldn't fall24off a cliff if the high-level functionality doesn't exactly match your use case. You25should be able to gain more control over the small details while retaining a26commensurate amount of high-level convenience.2728When you need to customize what `fit()` does, you should **override the training step29function of the `Model` class**. This is the function that is called by `fit()` for30every batch of data. You will then be able to call `fit()` as usual -- and it will be31running your own learning algorithm.3233Note that this pattern does not prevent you from building models with the Functional34API. You can do this whether you're building `Sequential` models, Functional API35models, or subclassed models.3637Let's see how that works.38"""3940"""41## Setup42"""4344import os4546# This guide can only be run with the JAX backend.47os.environ["KERAS_BACKEND"] = "jax"4849import jax50import keras51import numpy as np5253"""54## A first simple example5556Let's start from a simple example:5758- We create a new class that subclasses `keras.Model`.59- We implement a fully-stateless `compute_loss_and_updates()` method60to compute the loss as well as the updated values for the non-trainable61variables of the model. 

"""
## Setup
"""

import os

# This guide can only be run with the JAX backend.
os.environ["KERAS_BACKEND"] = "jax"

import jax
import keras
import numpy as np

"""
## A first simple example

Let's start from a simple example:

- We create a new class that subclasses `keras.Model`.
- We implement a fully-stateless `compute_loss_and_updates()` method
to compute the loss as well as the updated values for the non-trainable
variables of the model. Internally, it calls `stateless_call()` and
the built-in `stateless_compute_loss()`.
- We implement a fully-stateless `train_step()` method to compute current
metric values (including the loss) as well as updated values for the
trainable variables, the optimizer variables, and the metric variables.
"""


class CustomModel(keras.Model):
    def compute_loss_and_updates(
        self,
        trainable_variables,
        non_trainable_variables,
        metrics_variables,
        x,
        y,
        sample_weight,
        training=False,
    ):
        y_pred, non_trainable_variables = self.stateless_call(
            trainable_variables,
            non_trainable_variables,
            x,
            training=training,
        )
        loss, (
            trainable_variables,
            non_trainable_variables,
            metrics_variables,
        ) = self.stateless_compute_loss(
            trainable_variables,
            non_trainable_variables,
            metrics_variables,
            x=x,
            y=y,
            y_pred=y_pred,
            sample_weight=sample_weight,
            training=training,
        )
        return loss, (y_pred, non_trainable_variables, metrics_variables)

    def train_step(self, state, data):
        (
            trainable_variables,
            non_trainable_variables,
            optimizer_variables,
            metrics_variables,
        ) = state
        x, y, sample_weight = keras.utils.unpack_x_y_sample_weight(data)

        # Get the gradient function.
        grad_fn = jax.value_and_grad(self.compute_loss_and_updates, has_aux=True)

        # Compute the gradients.
        (loss, (y_pred, non_trainable_variables, metrics_variables)), grads = grad_fn(
            trainable_variables,
            non_trainable_variables,
            metrics_variables,
            x,
            y,
            sample_weight,
            training=True,
        )

        # Update trainable variables and optimizer variables.
        trainable_variables, optimizer_variables = self.optimizer.stateless_apply(
            optimizer_variables, grads, trainable_variables
        )

        # Update metrics.
        new_metrics_vars = []
        logs = {}
        for metric in self.metrics:
            this_metric_vars = metrics_variables[
                len(new_metrics_vars) : len(new_metrics_vars) + len(metric.variables)
            ]
            if metric.name == "loss":
                this_metric_vars = metric.stateless_update_state(
                    this_metric_vars, loss, sample_weight=sample_weight
                )
            else:
                this_metric_vars = metric.stateless_update_state(
                    this_metric_vars, y, y_pred, sample_weight=sample_weight
                )
            logs[metric.name] = metric.stateless_result(this_metric_vars)
            new_metrics_vars += this_metric_vars

        # Return metric logs and updated state variables.
        state = (
            trainable_variables,
            non_trainable_variables,
            optimizer_variables,
            new_metrics_vars,
        )
        return logs, state


"""
Let's try this out:
"""

# Construct and compile an instance of CustomModel
inputs = keras.Input(shape=(32,))
outputs = keras.layers.Dense(1)(inputs)
model = CustomModel(inputs, outputs)
model.compile(optimizer="adam", loss="mse", metrics=["mae"])

# Just use `fit` as usual
x = np.random.random((1000, 32))
y = np.random.random((1000, 1))
model.fit(x, y, epochs=3)
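
"""
To make the state handling concrete, here is (roughly) how you could drive a single
training step by hand. This is just a sketch to illustrate what `fit()` is threading
through `train_step` for you -- in practice `fit()` also takes care of compiling the
step with `jax.jit`, batching, callbacks, and so on:

```python
# The optimizer has already been built by the `fit()` call above.
# Gather the current values of all state variables, in the order
# `train_step` expects them.
state = (
    [v.value for v in model.trainable_variables],
    [v.value for v in model.non_trainable_variables],
    [v.value for v in model.optimizer.variables],
    [v.value for metric in model.metrics for v in metric.variables],
)

# Run one training step on a single batch of data, then unpack the
# logs and the updated state.
logs, state = model.train_step(state, (x[:32], y[:32]))
```
"""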

"""
## Going lower-level

Naturally, you could just skip passing a loss function in `compile()`, and instead do
everything *manually* in `train_step`. Likewise for metrics.

Here's a lower-level example that only uses `compile()` to configure the optimizer:
"""


class CustomModel(keras.Model):
    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.loss_tracker = keras.metrics.Mean(name="loss")
        self.mae_metric = keras.metrics.MeanAbsoluteError(name="mae")
        self.loss_fn = keras.losses.MeanSquaredError()

    def compute_loss_and_updates(
        self,
        trainable_variables,
        non_trainable_variables,
        x,
        y,
        sample_weight,
        training=False,
    ):
        y_pred, non_trainable_variables = self.stateless_call(
            trainable_variables,
            non_trainable_variables,
            x,
            training=training,
        )
        loss = self.loss_fn(y, y_pred, sample_weight=sample_weight)
        return loss, (y_pred, non_trainable_variables)

    def train_step(self, state, data):
        (
            trainable_variables,
            non_trainable_variables,
            optimizer_variables,
            metrics_variables,
        ) = state
        x, y, sample_weight = keras.utils.unpack_x_y_sample_weight(data)

        # Get the gradient function.
        grad_fn = jax.value_and_grad(self.compute_loss_and_updates, has_aux=True)

        # Compute the gradients.
        (loss, (y_pred, non_trainable_variables)), grads = grad_fn(
            trainable_variables,
            non_trainable_variables,
            x,
            y,
            sample_weight,
            training=True,
        )

        # Update trainable variables and optimizer variables.
        trainable_variables, optimizer_variables = self.optimizer.stateless_apply(
            optimizer_variables, grads, trainable_variables
        )

        # Update metrics.
        loss_tracker_vars = metrics_variables[: len(self.loss_tracker.variables)]
        mae_metric_vars = metrics_variables[len(self.loss_tracker.variables) :]

        loss_tracker_vars = self.loss_tracker.stateless_update_state(
            loss_tracker_vars, loss, sample_weight=sample_weight
        )
        mae_metric_vars = self.mae_metric.stateless_update_state(
            mae_metric_vars, y, y_pred, sample_weight=sample_weight
        )

        logs = {}
        logs[self.loss_tracker.name] = self.loss_tracker.stateless_result(
            loss_tracker_vars
        )
        logs[self.mae_metric.name] = self.mae_metric.stateless_result(mae_metric_vars)

        new_metrics_vars = loss_tracker_vars + mae_metric_vars

        # Return metric logs and updated state variables.
        state = (
            trainable_variables,
            non_trainable_variables,
            optimizer_variables,
            new_metrics_vars,
        )
        return logs, state

    @property
    def metrics(self):
        # We list our `Metric` objects here so that `reset_state()` can be
        # called automatically at the start of each epoch
        # or at the start of `evaluate()`.
        return [self.loss_tracker, self.mae_metric]


# Construct an instance of CustomModel
inputs = keras.Input(shape=(32,))
outputs = keras.layers.Dense(1)(inputs)
model = CustomModel(inputs, outputs)

# We don't pass a loss or metrics here.
model.compile(optimizer="adam")

# Just use `fit` as usual -- you can use callbacks, etc.
x = np.random.random((1000, 32))
y = np.random.random((1000, 1))
model.fit(x, y, epochs=5)
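
"""
Note that the slicing of `metrics_variables` above relies on one assumption: the metric
variable values that `fit()` passes to `train_step` are the concatenation of each
metric's variables, in the order returned by the `metrics` property. Conceptually,
the layout looks like this (just a sketch, not code you need to write):

```python
metrics_variables = [
    v.value
    for metric in model.metrics  # [loss_tracker, mae_metric]
    for v in metric.variables
]
```
"""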

"""
## Providing your own evaluation step

What if you want to do the same for calls to `model.evaluate()`? Then you would
override `test_step` in exactly the same way. Here's what it looks like:
"""


class CustomModel(keras.Model):
    def test_step(self, state, data):
        # Unpack the data.
        x, y, sample_weight = keras.utils.unpack_x_y_sample_weight(data)
        (
            trainable_variables,
            non_trainable_variables,
            metrics_variables,
        ) = state

        # Compute predictions and loss.
        y_pred, non_trainable_variables = self.stateless_call(
            trainable_variables,
            non_trainable_variables,
            x,
            training=False,
        )
        loss, (
            trainable_variables,
            non_trainable_variables,
            metrics_variables,
        ) = self.stateless_compute_loss(
            trainable_variables,
            non_trainable_variables,
            metrics_variables,
            x=x,
            y=y,
            y_pred=y_pred,
            sample_weight=sample_weight,
            training=False,
        )

        # Update metrics.
        new_metrics_vars = []
        logs = {}
        for metric in self.metrics:
            this_metric_vars = metrics_variables[
                len(new_metrics_vars) : len(new_metrics_vars) + len(metric.variables)
            ]
            if metric.name == "loss":
                this_metric_vars = metric.stateless_update_state(
                    this_metric_vars, loss, sample_weight=sample_weight
                )
            else:
                this_metric_vars = metric.stateless_update_state(
                    this_metric_vars, y, y_pred, sample_weight=sample_weight
                )
            logs[metric.name] = metric.stateless_result(this_metric_vars)
            new_metrics_vars += this_metric_vars

        # Return metric logs and updated state variables.
        state = (
            trainable_variables,
            non_trainable_variables,
            new_metrics_vars,
        )
        return logs, state


# Construct an instance of CustomModel
inputs = keras.Input(shape=(32,))
outputs = keras.layers.Dense(1)(inputs)
model = CustomModel(inputs, outputs)
model.compile(loss="mse", metrics=["mae"])

# Evaluate with our custom test_step
x = np.random.random((1000, 32))
y = np.random.random((1000, 1))
model.evaluate(x, y, return_dict=True)


"""
That's it!
"""