"""
Title: Image classification from scratch
Author: [fchollet](https://twitter.com/fchollet)
Date created: 2020/04/27
Last modified: 2023/11/09
Description: Training an image classifier from scratch on the Kaggle Cats vs Dogs dataset.
Accelerator: GPU
"""

"""
## Introduction

This example shows how to do image classification from scratch, starting from JPEG
image files on disk, without leveraging pre-trained weights or a pre-made Keras
Application model. We demonstrate the workflow on the Kaggle Cats vs Dogs binary
classification dataset.

We use the `image_dataset_from_directory` utility to generate the datasets, and
we use Keras image preprocessing layers for image standardization and data augmentation.
"""

"""
## Setup
"""

import os
import numpy as np
import keras
from keras import layers
from tensorflow import data as tf_data
import matplotlib.pyplot as plt

"""
## Load the data: the Cats vs Dogs dataset

### Raw data download

First, let's download the 786 MB ZIP archive of the raw data:
"""

"""shell
curl -O https://download.microsoft.com/download/3/E/1/3E1C3F21-ECDB-4869-8368-6DEBA77B919F/kagglecatsanddogs_5340.zip
"""

"""shell
unzip -q kagglecatsanddogs_5340.zip
ls
"""

"""
Now we have a `PetImages` folder which contains two subfolders, `Cat` and `Dog`. Each
subfolder contains the image files for that category.
"""

"""shell
ls PetImages
"""

"""
### Filter out corrupted images

When working with lots of real-world image data, corrupted images are a common
occurrence. Let's filter out badly-encoded images that do not feature the string "JFIF"
in their header.
"""

num_skipped = 0
for folder_name in ("Cat", "Dog"):
    folder_path = os.path.join("PetImages", folder_name)
    for fname in os.listdir(folder_path):
        fpath = os.path.join(folder_path, fname)
        # `with` closes the file even if reading its header fails.
        with open(fpath, "rb") as fobj:
            is_jfif = b"JFIF" in fobj.peek(10)

        if not is_jfif:
            num_skipped += 1
            # Delete corrupted image
            os.remove(fpath)

print(f"Deleted {num_skipped} images.")

"""
## Generate a `Dataset`
"""

image_size = (180, 180)
batch_size = 128

train_ds, val_ds = keras.utils.image_dataset_from_directory(
    "PetImages",
    validation_split=0.2,
    subset="both",
    seed=1337,
    image_size=image_size,
    batch_size=batch_size,
)

"""
## Visualize the data

Here are the first 9 images in the training dataset.
"""


plt.figure(figsize=(10, 10))
for images, labels in train_ds.take(1):
    for i in range(9):
        ax = plt.subplot(3, 3, i + 1)
        plt.imshow(np.array(images[i]).astype("uint8"))
        plt.title(int(labels[i]))
        plt.axis("off")

"""
## Using image data augmentation

When you don't have a large image dataset, it's a good practice to artificially
introduce sample diversity by applying random yet realistic transformations to the
training images, such as random horizontal flipping or small random rotations.
This helps expose the model to different aspects of the training data while slowing down
overfitting.
"""

data_augmentation_layers = [
    layers.RandomFlip("horizontal"),
    layers.RandomRotation(0.1),
]


def data_augmentation(images):
    for layer in data_augmentation_layers:
        images = layer(images)
    return images


"""
Let's visualize what the augmented samples look like by applying `data_augmentation`
repeatedly to the first image in the dataset:
"""

plt.figure(figsize=(10, 10))
for images, _ in train_ds.take(1):
    for i in range(9):
        augmented_images = data_augmentation(images)
        ax = plt.subplot(3, 3, i + 1)
        plt.imshow(np.array(augmented_images[0]).astype("uint8"))
        plt.axis("off")


"""
## Standardizing the data

Our images are already in a standard size (180x180), as they are being yielded as
contiguous `float32` batches by our dataset. However, their RGB channel values are in
the `[0, 255]` range. This is not ideal for a neural network;
in general you should seek to make your input values small. Here, we will
standardize values to be in the `[0, 1]` range by using a `Rescaling` layer at the start of
our model.
"""

"""
## Two options to preprocess the data

There are two ways you could use the `data_augmentation` preprocessor:

**Option 1: Make it part of the model**, like this:

```python
inputs = keras.Input(shape=input_shape)
x = data_augmentation(inputs)
x = layers.Rescaling(1./255)(x)
...  # Rest of the model
```

With this option, your data augmentation will happen *on device*, synchronously
with the rest of the model execution, meaning that it will benefit from GPU
acceleration.

Note that data augmentation is inactive at test time, so the input samples will only be
augmented during `fit()`, not when calling `evaluate()` or `predict()`.

If you're training on GPU, this may be a good option.

**Option 2: apply it to the dataset**, so as to obtain a dataset that yields batches of
augmented images, like this:

```python
augmented_train_ds = train_ds.map(lambda x, y: (data_augmentation(x), y))
```

With this option, your data augmentation will happen **on CPU**, asynchronously, and will
be buffered before going into the model.

If you're training on CPU, this is the better option, since it makes data augmentation
asynchronous and non-blocking.

In our case, we'll go with the second option. If you're not sure
which one to pick, this second option (asynchronous preprocessing) is always a solid choice.
"""

"""
## Configure the dataset for performance

Let's apply data augmentation to our training dataset,
and let's make sure to use buffered prefetching so we can yield data from disk without
I/O becoming blocking:
"""

# Apply `data_augmentation` to the training images.
train_ds = train_ds.map(
    lambda img, label: (data_augmentation(img), label),
    num_parallel_calls=tf_data.AUTOTUNE,
)
# Prefetching lets the pipeline prepare upcoming batches while the model trains
# on the current one, which helps maximize GPU utilization.
train_ds = train_ds.prefetch(tf_data.AUTOTUNE)
val_ds = val_ds.prefetch(tf_data.AUTOTUNE)

"""
## Build a model

We'll build a small version of the Xception network.
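
Xception is built mainly from depthwise separable convolutions (`SeparableConv2D`). These split a standard convolution into two steps: a per-channel spatial (depthwise) step, then a 1x1 (pointwise) step that mixes channels. The plain-Python sketch below gives a rough illustration of why this keeps the model small. It counts the weights of a 3x3 `Conv2D` and a 3x3 `SeparableConv2D` using the channel sizes of the first separable layer in `make_model` (128 in, 256 out). It assumes one bias per output channel and a depth multiplier of 1, which are the Keras defaults for these layers. The helper functions are hypothetical, not Keras APIs.

```python
def conv2d_params(kernel_size: int, in_channels: int, filters: int) -> int:
    # Standard convolution: one (k x k x in_channels) kernel per filter, plus a bias per filter.
    return kernel_size * kernel_size * in_channels * filters + filters


def separable_conv2d_params(kernel_size: int, in_channels: int, filters: int) -> int:
    # Depthwise step: one k x k kernel per input channel (depth multiplier of 1).
    depthwise = kernel_size * kernel_size * in_channels
    # Pointwise step: a 1x1 convolution mapping in_channels to `filters` channels.
    pointwise = in_channels * filters
    # A single bias per output channel, applied after the pointwise step.
    return depthwise + pointwise + filters


# First SeparableConv2D in the model's loop: 128 input channels
# (from the entry block's Conv2D) and 256 filters, with a 3x3 kernel.
standard = conv2d_params(3, 128, 256)
separable = separable_conv2d_params(3, 128, 256)
print(f"Conv2D: {standard:,} params; SeparableConv2D: {separable:,} params")
print(f"Ratio: {standard / separable:.1f}x")
```

For these sizes, the sketch counts 295,168 weights for the standard convolution and 34,176 for the separable one, roughly 8.6x fewer. `BatchNormalization` parameters are not included in either count.
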
We haven't particularly tried to optimize the architecture; if you want to do a
systematic search for the best model configuration, consider using
[KerasTuner](https://github.com/keras-team/keras-tuner).

Note that:

- We start the model with a `Rescaling` layer. Data augmentation is not part of the
model, since we already applied it to `train_ds` (Option 2 above).
- We include a `Dropout` layer before the final classification layer.
"""


def make_model(input_shape, num_classes):
    inputs = keras.Input(shape=input_shape)

    # Entry block
    x = layers.Rescaling(1.0 / 255)(inputs)
    x = layers.Conv2D(128, 3, strides=2, padding="same")(x)
    x = layers.BatchNormalization()(x)
    x = layers.Activation("relu")(x)

    previous_block_activation = x  # Set aside residual

    for size in [256, 512, 728]:
        x = layers.Activation("relu")(x)
        x = layers.SeparableConv2D(size, 3, padding="same")(x)
        x = layers.BatchNormalization()(x)

        x = layers.Activation("relu")(x)
        x = layers.SeparableConv2D(size, 3, padding="same")(x)
        x = layers.BatchNormalization()(x)

        x = layers.MaxPooling2D(3, strides=2, padding="same")(x)

        # Project residual
        residual = layers.Conv2D(size, 1, strides=2, padding="same")(
            previous_block_activation
        )
        x = layers.add([x, residual])  # Add back residual
        previous_block_activation = x  # Set aside next residual

    x = layers.SeparableConv2D(1024, 3, padding="same")(x)
    x = layers.BatchNormalization()(x)
    x = layers.Activation("relu")(x)

    x = layers.GlobalAveragePooling2D()(x)
    if num_classes == 2:
        units = 1
    else:
        units = num_classes

    x = layers.Dropout(0.25)(x)
    # We specify activation=None so as to return logits
    outputs = layers.Dense(units, activation=None)(x)
    return keras.Model(inputs, outputs)


model = make_model(input_shape=image_size + (3,), num_classes=2)
keras.utils.plot_model(model, show_shapes=True)

"""
## Train the model
"""

epochs = 25

callbacks = [
    keras.callbacks.ModelCheckpoint("save_at_{epoch}.keras"),
]
model.compile(
    optimizer=keras.optimizers.Adam(3e-4),
    loss=keras.losses.BinaryCrossentropy(from_logits=True),
    metrics=[keras.metrics.BinaryAccuracy(name="acc")],
)
model.fit(
    train_ds,
    epochs=epochs,
    callbacks=callbacks,
    validation_data=val_ds,
)

"""
We get to >90% validation accuracy after training for 25 epochs on the full dataset
(in practice, you can train for 50+ epochs before validation performance starts degrading).
"""

"""
## Run inference on new data

Note that data augmentation and dropout are inactive at inference time.
"""

img = keras.utils.load_img("PetImages/Cat/6779.jpg", target_size=image_size)
plt.imshow(img)

img_array = keras.utils.img_to_array(img)
img_array = keras.ops.expand_dims(img_array, 0)  # Create batch axis

predictions = model.predict(img_array)
score = float(keras.ops.sigmoid(predictions[0][0]))
print(f"This image is {100 * (1 - score):.2f}% cat and {100 * score:.2f}% dog.")
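
"""
A closer look at the corrupted-image filter: a typical JFIF-encoded JPEG begins with the start-of-image marker `FF D8`. Next come an APP0 marker `FF E0`, a two-byte segment length, and the identifier `JFIF` plus a null byte, which puts `JFIF` at byte offset 6. The sketch below uses a hypothetical helper, `looks_like_jfif`, that reads exactly the first 10 bytes inside a `with` block. This differs slightly from the filter above, because `BufferedReader.peek(10)` may return more bytes than requested and so can scan somewhat further into the file. Either way, the check is a header heuristic rather than a full decode. It also drops valid JPEGs that lack a JFIF header, such as files that start with an EXIF APP1 segment. The byte strings are illustrative stand-ins, not complete images.

```python
import os
import tempfile


def looks_like_jfif(path: str, num_bytes: int = 10) -> bool:
    # Read only the first `num_bytes` bytes; the `with` block closes the file
    # even if reading fails.
    with open(path, "rb") as f:
        header = f.read(num_bytes)
    return b"JFIF" in header


# Illustrative headers only: SOI + APP0 marker + length + "JFIF" + a null byte,
# and SOI + APP1 marker + length + "Exif" + two null bytes.
jfif_header = bytes.fromhex("FFD8FFE00010") + b"JFIF" + bytes(1)
exif_header = bytes.fromhex("FFD8FFE10018") + b"Exif" + bytes(2)

with tempfile.TemporaryDirectory() as tmp:
    good = os.path.join(tmp, "good.jpg")
    bad = os.path.join(tmp, "bad.jpg")
    with open(good, "wb") as f:
        f.write(jfif_header)
    with open(bad, "wb") as f:
        f.write(exif_header)
    results = {"good.jpg": looks_like_jfif(good), "bad.jpg": looks_like_jfif(bad)}

print(results)
```
"""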
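
"""
The `validation_split=0.2`, `subset="both"`, and `seed=1337` arguments in "Generate a `Dataset`" work together. Conceptually, the file list is shuffled with a fixed seed and then cut into a training part and a validation part. Rerunning the script therefore reproduces the same split, and no file lands in both subsets. The plain-Python sketch below illustrates only that idea and is not Keras's internal code. The helper `split_files` and the file names are hypothetical.

```python
import random


def split_files(paths, validation_split, seed):
    # Shuffle a copy with a fixed seed so the split is reproducible across runs.
    shuffled = list(paths)
    random.Random(seed).shuffle(shuffled)
    num_val = int(validation_split * len(shuffled))
    cut = len(shuffled) - num_val
    # Both subsets come from the same shuffled list, so they cannot overlap.
    return shuffled[:cut], shuffled[cut:]


# Hypothetical file names standing in for the PetImages directory.
files = [f"PetImages/{cls}/{i}.jpg" for cls in ("Cat", "Dog") for i in range(5)]
train_a, val_a = split_files(files, validation_split=0.2, seed=1337)
train_b, val_b = split_files(files, validation_split=0.2, seed=1337)
print(len(train_a), len(val_a), train_a == train_b and val_a == val_b)
```
"""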
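
"""
To make the augmentation step more concrete, here is a rough NumPy analogue of `RandomFlip("horizontal")`. Each image in a `(batch, height, width, channels)` array is mirrored left-to-right with probability 0.5, independently of the others. This is a sketch for intuition under those assumptions, not Keras's implementation. The helper `random_horizontal_flip` is hypothetical, and the sketch ignores the `training` flag that Keras uses to switch augmentation off outside of training.

```python
import numpy as np


def random_horizontal_flip(images: np.ndarray, rng: np.random.Generator) -> np.ndarray:
    # images has shape (batch, height, width, channels).
    # Draw one coin flip per image: True means "mirror this image".
    flip = rng.random(images.shape[0]) < 0.5
    # Reversing axis 2 (width) mirrors every image left-to-right.
    mirrored = images[:, :, ::-1, :]
    # Pick the mirrored or the original version per image via broadcasting.
    return np.where(flip[:, None, None, None], mirrored, images)


rng = np.random.default_rng(1337)
# A tiny stand-in batch: 8 single-channel 2x3 "images" with distinct values.
batch = np.arange(8 * 2 * 3).reshape(8, 2, 3, 1).astype("float32")
augmented = random_horizontal_flip(batch, rng)

for original, out in zip(batch, augmented):
    state = "flipped" if np.array_equal(out, original[:, ::-1, :]) else "unchanged"
    print(state)
```
"""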
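
"""
For the standardization step, Keras's `Rescaling` layer computes `inputs * scale + offset`. The NumPy sketch below reproduces that arithmetic for the `Rescaling(1.0 / 255)` layer used in `make_model`. It also shows the `scale=1 / 127.5, offset=-1` variant, which maps pixels to `[-1, 1]` instead. The random batch is a stand-in for real data, and `rescale` is a hypothetical helper; inside the model you would keep using the Keras layer.

```python
import numpy as np


def rescale(images: np.ndarray, scale: float, offset: float = 0.0) -> np.ndarray:
    # Same formula as keras.layers.Rescaling: output = inputs * scale + offset.
    return images.astype("float32") * scale + offset


rng = np.random.default_rng(0)
# Stand-in for one batch from `train_ds`: 4 RGB images of 180x180, values in [0, 255].
batch = rng.integers(0, 256, size=(4, 180, 180, 3)).astype("float32")
batch[0, 0, 0] = [0.0, 255.0, 127.5]  # pin one pixel to known values

unit = rescale(batch, 1.0 / 255)            # [0, 255] -> [0, 1]
signed = rescale(batch, 1.0 / 127.5, -1.0)  # [0, 255] -> [-1, 1]
print(unit[0, 0, 0], signed[0, 0, 0])
```
"""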
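
"""
`prefetch(tf_data.AUTOTUNE)` lets the input pipeline prepare upcoming batches while the model is busy with the current one. The standard-library sketch below illustrates that producer/consumer idea, using a background thread and a bounded `queue.Queue` as the buffer. It is a conceptual analogue, not how `tf.data` is implemented. The names `prefetch` and `load_batches` are hypothetical, and error handling is omitted.

```python
import queue
import threading
import time
from typing import Iterable, Iterator

_END = object()  # sentinel that marks the end of the source stream


def prefetch(source: Iterable[object], buffer_size: int) -> Iterator[object]:
    # Bounded buffer: the producer blocks once `buffer_size` items are waiting,
    # which caps memory use. Error handling is omitted for brevity.
    buf: queue.Queue = queue.Queue(maxsize=buffer_size)

    def producer() -> None:
        for item in source:
            buf.put(item)
        buf.put(_END)

    # The producer keeps loading in the background while the consumer works.
    threading.Thread(target=producer, daemon=True).start()
    while True:
        item = buf.get()
        if item is _END:
            return
        yield item


def load_batches(n: int) -> Iterator[int]:
    # Hypothetical stand-in for reading and decoding one batch from disk.
    for i in range(n):
        time.sleep(0.01)
        yield i


received = list(prefetch(load_batches(5), buffer_size=2))
print(received)
```
"""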
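
"""
The model outputs a single logit per image (`activation=None`) and is trained with `BinaryCrossentropy(from_logits=True)`. The NumPy sketch below shows what that means numerically. `sigmoid(logit)` is the predicted probability of class 1, which is `Dog` here, because `image_dataset_from_directory` orders the inferred classes alphanumerically by subdirectory name (`Cat` = 0, `Dog` = 1). The sketch also compares a numerically stable from-logits loss with the textbook probability-based formula; the stable form is the one TensorFlow documents for `tf.nn.sigmoid_cross_entropy_with_logits`. Both losses are averaged over the batch. The logit values are made up for illustration, and the helper names are hypothetical.

```python
import numpy as np


def sigmoid(z: np.ndarray) -> np.ndarray:
    return 1.0 / (1.0 + np.exp(-z))


def bce_from_logits(labels: np.ndarray, logits: np.ndarray) -> float:
    # Numerically stable form: max(z, 0) - z * y + log(1 + exp(-|z|)).
    per_sample = np.maximum(logits, 0) - logits * labels + np.log1p(np.exp(-np.abs(logits)))
    return float(per_sample.mean())


def bce_from_probs(labels: np.ndarray, probs: np.ndarray) -> float:
    # Textbook form: -[y * log(p) + (1 - y) * log(1 - p)].
    per_sample = -(labels * np.log(probs) + (1 - labels) * np.log(1 - probs))
    return float(per_sample.mean())


# Made-up logits for four images; labels: 0 = Cat, 1 = Dog.
logits = np.array([-3.0, -0.5, 0.5, 4.0])
labels = np.array([0.0, 0.0, 1.0, 1.0])

probs_dog = sigmoid(logits)
for z, p in zip(logits, probs_dog):
    print(f"logit {z:+.1f}: {100 * (1 - p):.2f}% cat, {100 * p:.2f}% dog")

print(bce_from_logits(labels, logits), bce_from_probs(labels, probs_dog))
```

The two loss values agree up to floating-point rounding. The from-logits form avoids taking `log(0)` when a logit is very large in magnitude, which is why passing raw logits with `from_logits=True` is generally preferred.
"""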