"""
Title: Customizing what happens in `fit()` with JAX
Author: [fchollet](https://twitter.com/fchollet)
Date created: 2023/06/27
Last modified: 2023/06/27
Description: Overriding the training step of the Model class with JAX.
Accelerator: GPU
"""

"""
## Introduction

When you're doing supervised learning, you can use `fit()` and everything works
smoothly.

When you need to take control of every little detail, you can write your own training
loop entirely from scratch.

But what if you need a custom training algorithm, but you still want to benefit from
the convenient features of `fit()`, such as callbacks, built-in distribution support,
or step fusing?

A core principle of Keras is **progressive disclosure of complexity**. You should
always be able to get into lower-level workflows in a gradual way. You shouldn't fall
off a cliff if the high-level functionality doesn't exactly match your use case. You
should be able to gain more control over the small details while retaining a
commensurate amount of high-level convenience.

When you need to customize what `fit()` does, you should **override the training step
function of the `Model` class**. This is the function that is called by `fit()` for
every batch of data. You will then be able to call `fit()` as usual -- and it will be
running your own learning algorithm.

Note that this pattern does not prevent you from building models with the Functional
API. You can do this whether you're building `Sequential` models, Functional API
models, or subclassed models.

Let's see how that works.
"""

"""
## Setup
"""

import os

# This guide can only be run with the JAX backend.
os.environ["KERAS_BACKEND"] = "jax"

import jax
import keras
import numpy as np
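
"""
Setting the `KERAS_BACKEND` environment variable before importing `keras` is what
selects the backend. As a quick optional sanity check, you can print the active
backend; it should report `"jax"` here:
"""

print("Keras backend:", keras.backend.backend())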
"""
55
## A first simple example
56
57
Let's start from a simple example:
58
59
- We create a new class that subclasses `keras.Model`.
60
- We implement a fully-stateless `compute_loss_and_updates()` method
61
to compute the loss as well as the updated values for the non-trainable
62
variables of the model. Internally, it calls `stateless_call()` and
63
the built-in `stateless_compute_loss()`.
64
- We implement a fully-stateless `train_step()` method to compute current
65
metric values (including the loss) as well as updated values for the
66
trainable variables, the optimizer variables, and the metric variables.
67
"""


class CustomModel(keras.Model):
    def compute_loss_and_updates(
        self,
        trainable_variables,
        non_trainable_variables,
        metrics_variables,
        x,
        y,
        sample_weight,
        training=False,
    ):
        y_pred, non_trainable_variables = self.stateless_call(
            trainable_variables,
            non_trainable_variables,
            x,
            training=training,
        )
        loss, (
            trainable_variables,
            non_trainable_variables,
            metrics_variables,
        ) = self.stateless_compute_loss(
            trainable_variables,
            non_trainable_variables,
            metrics_variables,
            x=x,
            y=y,
            y_pred=y_pred,
            sample_weight=sample_weight,
            training=training,
        )
        return loss, (y_pred, non_trainable_variables, metrics_variables)

    def train_step(self, state, data):
        (
            trainable_variables,
            non_trainable_variables,
            optimizer_variables,
            metrics_variables,
        ) = state
        x, y, sample_weight = keras.utils.unpack_x_y_sample_weight(data)

        # Get the gradient function.
        grad_fn = jax.value_and_grad(self.compute_loss_and_updates, has_aux=True)

        # Compute the gradients.
        (loss, (y_pred, non_trainable_variables, metrics_variables)), grads = grad_fn(
            trainable_variables,
            non_trainable_variables,
            metrics_variables,
            x,
            y,
            sample_weight,
            training=True,
        )

        # Update trainable variables and optimizer variables.
        trainable_variables, optimizer_variables = self.optimizer.stateless_apply(
            optimizer_variables, grads, trainable_variables
        )

        # Update metrics.
        new_metrics_vars = []
        logs = {}
        for metric in self.metrics:
            this_metric_vars = metrics_variables[
                len(new_metrics_vars) : len(new_metrics_vars) + len(metric.variables)
            ]
            if metric.name == "loss":
                this_metric_vars = metric.stateless_update_state(
                    this_metric_vars, loss, sample_weight=sample_weight
                )
            else:
                this_metric_vars = metric.stateless_update_state(
                    this_metric_vars, y, y_pred, sample_weight=sample_weight
                )
            logs[metric.name] = metric.stateless_result(this_metric_vars)
            new_metrics_vars += this_metric_vars

        # Return metric logs and updated state variables.
        state = (
            trainable_variables,
            non_trainable_variables,
            optimizer_variables,
            new_metrics_vars,
        )
        return logs, state


"""
Let's try this out:
"""

# Construct and compile an instance of CustomModel
inputs = keras.Input(shape=(32,))
outputs = keras.layers.Dense(1)(inputs)
model = CustomModel(inputs, outputs)
model.compile(optimizer="adam", loss="mse", metrics=["mae"])

# Just use `fit` as usual
x = np.random.random((1000, 32))
y = np.random.random((1000, 1))
model.fit(x, y, epochs=3)
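
"""
The logs dictionary returned by `train_step()` is what `fit()` uses for its progress
bar and for the `History` object it returns. If you want to look at the recorded
values, you can capture that return value -- a minimal optional check:
"""

history = model.fit(x, y, epochs=3)
print(history.history)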
"""
176
## Going lower-level
177
178
Naturally, you could just skip passing a loss function in `compile()`, and instead do
179
everything *manually* in `train_step`. Likewise for metrics.
180
181
Here's a lower-level example, that only uses `compile()` to configure the optimizer:
182
"""
183


class CustomModel(keras.Model):
    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.loss_tracker = keras.metrics.Mean(name="loss")
        self.mae_metric = keras.metrics.MeanAbsoluteError(name="mae")
        self.loss_fn = keras.losses.MeanSquaredError()

    def compute_loss_and_updates(
        self,
        trainable_variables,
        non_trainable_variables,
        x,
        y,
        sample_weight,
        training=False,
    ):
        y_pred, non_trainable_variables = self.stateless_call(
            trainable_variables,
            non_trainable_variables,
            x,
            training=training,
        )
        loss = self.loss_fn(y, y_pred, sample_weight=sample_weight)
        return loss, (y_pred, non_trainable_variables)

    def train_step(self, state, data):
        (
            trainable_variables,
            non_trainable_variables,
            optimizer_variables,
            metrics_variables,
        ) = state
        x, y, sample_weight = keras.utils.unpack_x_y_sample_weight(data)

        # Get the gradient function.
        grad_fn = jax.value_and_grad(self.compute_loss_and_updates, has_aux=True)

        # Compute the gradients.
        (loss, (y_pred, non_trainable_variables)), grads = grad_fn(
            trainable_variables,
            non_trainable_variables,
            x,
            y,
            sample_weight,
            training=True,
        )

        # Update trainable variables and optimizer variables.
        trainable_variables, optimizer_variables = self.optimizer.stateless_apply(
            optimizer_variables, grads, trainable_variables
        )

        # Update metrics.
        loss_tracker_vars = metrics_variables[: len(self.loss_tracker.variables)]
        mae_metric_vars = metrics_variables[len(self.loss_tracker.variables) :]

        loss_tracker_vars = self.loss_tracker.stateless_update_state(
            loss_tracker_vars, loss, sample_weight=sample_weight
        )
        mae_metric_vars = self.mae_metric.stateless_update_state(
            mae_metric_vars, y, y_pred, sample_weight=sample_weight
        )

        logs = {}
        logs[self.loss_tracker.name] = self.loss_tracker.stateless_result(
            loss_tracker_vars
        )
        logs[self.mae_metric.name] = self.mae_metric.stateless_result(mae_metric_vars)

        new_metrics_vars = loss_tracker_vars + mae_metric_vars

        # Return metric logs and updated state variables.
        state = (
            trainable_variables,
            non_trainable_variables,
            optimizer_variables,
            new_metrics_vars,
        )
        return logs, state

    @property
    def metrics(self):
        # We list our `Metric` objects here so that `reset_states()` can be
        # called automatically at the start of each epoch
        # or at the start of `evaluate()`.
        return [self.loss_tracker, self.mae_metric]


# Construct an instance of CustomModel
inputs = keras.Input(shape=(32,))
outputs = keras.layers.Dense(1)(inputs)
model = CustomModel(inputs, outputs)

# We don't pass a loss or metrics here.
model.compile(optimizer="adam")

# Just use `fit` as usual -- you can use callbacks, etc.
x = np.random.random((1000, 32))
y = np.random.random((1000, 1))
model.fit(x, y, epochs=5)
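
"""
Since training still goes through the regular `fit()` machinery, the usual
conveniences keep working. As a small illustrative sketch (using the built-in
`keras.callbacks.EarlyStopping` callback purely as an example), here is the same
training run with a callback attached:
"""

early_stopping = keras.callbacks.EarlyStopping(monitor="loss", patience=2)
model.fit(x, y, epochs=5, callbacks=[early_stopping])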


"""
## Providing your own evaluation step

What if you want to do the same for calls to `model.evaluate()`? Then you would
override `test_step` in exactly the same way. Here's what it looks like:
"""

class CustomModel(keras.Model):
    def test_step(self, state, data):
        # Unpack the data.
        x, y, sample_weight = keras.utils.unpack_x_y_sample_weight(data)
        (
            trainable_variables,
            non_trainable_variables,
            metrics_variables,
        ) = state

        # Compute predictions and loss.
        y_pred, non_trainable_variables = self.stateless_call(
            trainable_variables,
            non_trainable_variables,
            x,
            training=False,
        )
        loss, (
            trainable_variables,
            non_trainable_variables,
            metrics_variables,
        ) = self.stateless_compute_loss(
            trainable_variables,
            non_trainable_variables,
            metrics_variables,
            x=x,
            y=y,
            y_pred=y_pred,
            sample_weight=sample_weight,
            training=False,
        )

        # Update metrics.
        new_metrics_vars = []
        logs = {}
        for metric in self.metrics:
            this_metric_vars = metrics_variables[
                len(new_metrics_vars) : len(new_metrics_vars) + len(metric.variables)
            ]
            if metric.name == "loss":
                this_metric_vars = metric.stateless_update_state(
                    this_metric_vars, loss, sample_weight=sample_weight
                )
            else:
                this_metric_vars = metric.stateless_update_state(
                    this_metric_vars, y, y_pred, sample_weight=sample_weight
                )
            logs[metric.name] = metric.stateless_result(this_metric_vars)
            new_metrics_vars += this_metric_vars

        # Return metric logs and updated state variables.
        state = (
            trainable_variables,
            non_trainable_variables,
            new_metrics_vars,
        )
        return logs, state


# Construct an instance of CustomModel
inputs = keras.Input(shape=(32,))
outputs = keras.layers.Dense(1)(inputs)
model = CustomModel(inputs, outputs)
model.compile(loss="mse", metrics=["mae"])

# Evaluate with our custom test_step
x = np.random.random((1000, 32))
y = np.random.random((1000, 1))
model.evaluate(x, y, return_dict=True)
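
"""
Passing `return_dict=True` makes `evaluate()` return the logs as a dictionary keyed by
metric name rather than a flat list of values, so you can capture and inspect the
result directly -- for example:
"""

results = model.evaluate(x, y, return_dict=True)
print(results)  # e.g. {"loss": ..., "mae": ...}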
"""
367
That's it!
368
"""
369
370