"""
Title: Image classification from scratch
Author: [fchollet](https://twitter.com/fchollet)
Date created: 2020/04/27
Last modified: 2023/11/09
Description: Training an image classifier from scratch on the Kaggle Cats vs Dogs dataset.
Accelerator: GPU
"""

"""
## Introduction

This example shows how to do image classification from scratch, starting from JPEG
image files on disk, without leveraging pre-trained weights or a pre-made Keras
Application model. We demonstrate the workflow on the Kaggle Cats vs Dogs binary
classification dataset.

We use the `image_dataset_from_directory` utility to generate the datasets, and
Keras image preprocessing layers for image standardization and data augmentation.
"""

"""
## Setup
"""

import os
import numpy as np
import keras
from keras import layers
from tensorflow import data as tf_data
import matplotlib.pyplot as plt

"""
## Load the data: the Cats vs Dogs dataset

### Raw data download

First, let's download the 786M ZIP archive of the raw data:
"""

"""shell
curl -O https://download.microsoft.com/download/3/E/1/3E1C3F21-ECDB-4869-8368-6DEBA77B919F/kagglecatsanddogs_5340.zip
"""

"""shell
unzip -q kagglecatsanddogs_5340.zip
ls
"""

"""
Now we have a `PetImages` folder which contains two subfolders, `Cat` and `Dog`. Each
subfolder contains the image files for one category.
"""

"""shell
ls PetImages
"""

"""
### Filter out corrupted images

When working with lots of real-world image data, corrupted images are a common
occurrence. Let's filter out badly-encoded images that do not feature the string "JFIF"
in their header.
"""

num_skipped = 0
for folder_name in ("Cat", "Dog"):
    folder_path = os.path.join("PetImages", folder_name)
    for fname in os.listdir(folder_path):
        fpath = os.path.join(folder_path, fname)
        try:
            fobj = open(fpath, "rb")
            is_jfif = b"JFIF" in fobj.peek(10)
        finally:
            fobj.close()

        if not is_jfif:
            num_skipped += 1
            # Delete corrupted image
            os.remove(fpath)

print(f"Deleted {num_skipped} images.")

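"""
The "JFIF" check above is fast but shallow, since it only inspects the file header. A
more thorough (and slower) alternative is to let PIL attempt to decode each file and
drop the ones that fail. This is a minimal sketch, not part of the original recipe,
and it assumes Pillow is installed:

```python
from PIL import Image, UnidentifiedImageError


def is_valid_image(fpath):
    # verify() tries to determine whether the file is broken,
    # without decoding the full image data.
    try:
        with Image.open(fpath) as img:
            img.verify()
        return True
    except (UnidentifiedImageError, OSError):
        return False
```
"""
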
"""
## Generate a `Dataset`
"""

image_size = (180, 180)
batch_size = 128

train_ds, val_ds = keras.utils.image_dataset_from_directory(
    "PetImages",
    validation_split=0.2,
    subset="both",
    seed=1337,
    image_size=image_size,
    batch_size=batch_size,
)

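"""
Class labels are inferred from the subfolder names, sorted alphanumerically, so `Cat`
maps to label 0 and `Dog` to label 1, and `subset="both"` makes the utility return the
training and validation datasets as a pair. As a quick sanity check (a sketch; the
shapes assume the `image_size` and `batch_size` values set above):

```python
print(train_ds.class_names)  # ['Cat', 'Dog']
for images, labels in train_ds.take(1):
    print(images.shape)  # (128, 180, 180, 3) -- one batch of resized RGB images
    print(labels.shape)  # (128,)
```
"""
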
"""
## Visualize the data

Here are the first 9 images in the training dataset.
"""

plt.figure(figsize=(10, 10))
for images, labels in train_ds.take(1):
    for i in range(9):
        ax = plt.subplot(3, 3, i + 1)
        plt.imshow(np.array(images[i]).astype("uint8"))
        plt.title(int(labels[i]))
        plt.axis("off")

"""
## Using image data augmentation

When you don't have a large image dataset, it's a good practice to artificially
introduce sample diversity by applying random yet realistic transformations to the
training images, such as random horizontal flipping or small random rotations. This
helps expose the model to different aspects of the training data while slowing down
overfitting.
"""

data_augmentation_layers = [
    layers.RandomFlip("horizontal"),
    layers.RandomRotation(0.1),
]


def data_augmentation(images):
    for layer in data_augmentation_layers:
        images = layer(images)
    return images

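"""
Wrapping the augmentation layers in a plain function keeps the pipeline explicit. An
equivalent formulation (a sketch, not used below) bundles the same layers into a
`keras.Sequential` model; passing `training=True` ensures the random transformations
are applied even outside of `fit()`:

```python
data_augmentation_model = keras.Sequential(
    [layers.RandomFlip("horizontal"), layers.RandomRotation(0.1)]
)
augmented = data_augmentation_model(images, training=True)
```
"""
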
"""
Let's visualize what the augmented samples look like, by applying `data_augmentation`
repeatedly to the first image of the first batch:
"""

plt.figure(figsize=(10, 10))
for images, _ in train_ds.take(1):
    for i in range(9):
        augmented_images = data_augmentation(images)
        ax = plt.subplot(3, 3, i + 1)
        plt.imshow(np.array(augmented_images[0]).astype("uint8"))
        plt.axis("off")

"""
## Standardizing the data

Our images are already of a standard size (180x180), as they are yielded as
contiguous `float32` batches by our dataset. However, their RGB channel values are in
the `[0, 255]` range. This is not ideal for a neural network;
in general you should seek to make your input values small. Here, we will
standardize values to be in the `[0, 1]` range by using a `Rescaling` layer at the
start of our model.
"""

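"""
Concretely, `Rescaling` just multiplies each input value by a constant scale (and can
optionally add an offset). A minimal sketch of the transformation we rely on below:

```python
rescale = layers.Rescaling(scale=1.0 / 255)
batch = np.random.randint(0, 256, size=(4, 180, 180, 3)).astype("float32")
out = rescale(batch)  # values now lie in [0, 1]
```
"""
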
"""
## Two options to preprocess the data

There are two ways you can use the `data_augmentation` preprocessor:

**Option 1: Make it part of the model**, like this:

```python
inputs = keras.Input(shape=input_shape)
x = data_augmentation(inputs)
x = layers.Rescaling(1.0 / 255)(x)
...  # Rest of the model
```

With this option, your data augmentation will happen *on device*, synchronously
with the rest of the model execution, meaning that it will benefit from GPU
acceleration.

Note that data augmentation is inactive at test time, so the input samples will only be
augmented during `fit()`, not when calling `evaluate()` or `predict()`.

If you're training on GPU, this may be a good option.

**Option 2: Apply it to the dataset**, so as to obtain a dataset that yields batches of
augmented images, like this:

```python
augmented_train_ds = train_ds.map(lambda x, y: (data_augmentation(x), y))
```

With this option, your data augmentation will happen **on CPU**, asynchronously, and will
be buffered before going into the model.

If you're training on CPU, this is the better option, since it makes data augmentation
asynchronous and non-blocking.

In our case, we'll go with the second option. If you're not sure
which one to pick, this second option (asynchronous preprocessing) is always a solid choice.
"""

"""
## Configure the dataset for performance

Let's apply data augmentation to our training dataset,
and let's make sure to use buffered prefetching so we can yield data from disk without
I/O becoming blocking:
"""

# Apply `data_augmentation` to the training images.
train_ds = train_ds.map(
    lambda img, label: (data_augmentation(img), label),
    num_parallel_calls=tf_data.AUTOTUNE,
)
# Prefetching samples in GPU memory helps maximize GPU utilization.
train_ds = train_ds.prefetch(tf_data.AUTOTUNE)
val_ds = val_ds.prefetch(tf_data.AUTOTUNE)

"""
## Build a model

We'll build a small version of the Xception network. We haven't particularly tried to
optimize the architecture; if you want to do a systematic search for the best model
configuration, consider using
[KerasTuner](https://github.com/keras-team/keras-tuner).

Note that:

- We start the model with a `Rescaling` layer (data augmentation was already applied
to the dataset in the previous step).
- We include a `Dropout` layer before the final classification layer.
"""


def make_model(input_shape, num_classes):
    inputs = keras.Input(shape=input_shape)

    # Entry block
    x = layers.Rescaling(1.0 / 255)(inputs)
    x = layers.Conv2D(128, 3, strides=2, padding="same")(x)
    x = layers.BatchNormalization()(x)
    x = layers.Activation("relu")(x)

    previous_block_activation = x  # Set aside residual

    for size in [256, 512, 728]:
        x = layers.Activation("relu")(x)
        x = layers.SeparableConv2D(size, 3, padding="same")(x)
        x = layers.BatchNormalization()(x)

        x = layers.Activation("relu")(x)
        x = layers.SeparableConv2D(size, 3, padding="same")(x)
        x = layers.BatchNormalization()(x)

        x = layers.MaxPooling2D(3, strides=2, padding="same")(x)

        # Project residual
        residual = layers.Conv2D(size, 1, strides=2, padding="same")(
            previous_block_activation
        )
        x = layers.add([x, residual])  # Add back residual
        previous_block_activation = x  # Set aside next residual

    x = layers.SeparableConv2D(1024, 3, padding="same")(x)
    x = layers.BatchNormalization()(x)
    x = layers.Activation("relu")(x)

    x = layers.GlobalAveragePooling2D()(x)
    if num_classes == 2:
        units = 1
    else:
        units = num_classes

    x = layers.Dropout(0.25)(x)
    # We specify activation=None so as to return logits
    outputs = layers.Dense(units, activation=None)(x)
    return keras.Model(inputs, outputs)


model = make_model(input_shape=image_size + (3,), num_classes=2)
keras.utils.plot_model(model, show_shapes=True)

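"""
`plot_model` renders the architecture diagram, and requires the `pydot` and `graphviz`
packages to be installed. If they aren't available, a text summary works as a fallback
(a sketch):

```python
model.summary()  # prints layer names, output shapes, and parameter counts
```
"""
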
"""
## Train the model
"""

epochs = 25

callbacks = [
    keras.callbacks.ModelCheckpoint("save_at_{epoch}.keras"),
]
model.compile(
    optimizer=keras.optimizers.Adam(3e-4),
    loss=keras.losses.BinaryCrossentropy(from_logits=True),
    metrics=[keras.metrics.BinaryAccuracy(name="acc")],
)
model.fit(
    train_ds,
    epochs=epochs,
    callbacks=callbacks,
    validation_data=val_ds,
)

"""
We get to >90% validation accuracy after training for 25 epochs on the full dataset
(in practice, you can train for 50+ epochs before validation performance starts degrading).
"""

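"""
If you do train for longer, it is convenient to stop automatically once validation
performance stops improving. A minimal sketch of an early-stopping variant of the
callbacks list above (an addition, not part of the original recipe; `val_acc` refers
to the `"acc"` metric configured in `compile()`):

```python
callbacks = [
    keras.callbacks.ModelCheckpoint("save_at_{epoch}.keras"),
    keras.callbacks.EarlyStopping(
        monitor="val_acc",  # validation binary accuracy, named "acc" above
        patience=5,  # stop after 5 epochs without improvement
        restore_best_weights=True,
    ),
]
```
"""
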
"""
## Run inference on new data

Note that data augmentation and dropout are inactive at inference time.
"""

img = keras.utils.load_img("PetImages/Cat/6779.jpg", target_size=image_size)
plt.imshow(img)

img_array = keras.utils.img_to_array(img)
img_array = keras.ops.expand_dims(img_array, 0)  # Create batch axis

predictions = model.predict(img_array)
score = float(keras.ops.sigmoid(predictions[0][0]))
print(f"This image is {100 * (1 - score):.2f}% cat and {100 * score:.2f}% dog.")