
GitHub Repository: better-data-science/TensorFlow
Path: blob/main/007_Custom_Callbacks.ipynb
Kernel: tf
import os
import numpy as np
import pandas as pd
import warnings
from datetime import datetime

import matplotlib.pyplot as plt
plt.rcParams['figure.figsize'] = (24, 6)
plt.rcParams['axes.spines.top'] = False
plt.rcParams['axes.spines.right'] = False

os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2'
warnings.filterwarnings('ignore')

df = pd.read_csv('data/winequalityN.csv')
df.sample(5)
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import StandardScaler

# Prepare the data
df = df.dropna()
df['is_white_wine'] = [1 if typ == 'white' else 0 for typ in df['type']]
df['is_good_wine'] = [1 if quality >= 6 else 0 for quality in df['quality']]
df.drop(['type', 'quality'], axis=1, inplace=True)

# Train/test split
X = df.drop('is_good_wine', axis=1)
y = df['is_good_wine']
X_train, X_test, y_train, y_test = train_test_split(
    X, y, test_size=0.2, random_state=42
)

# Scaling
scaler = StandardScaler()
X_train_scaled = scaler.fit_transform(X_train)
X_test_scaled = scaler.transform(X_test)

Modelling

  • Let's declare a function that builds and trains the model

  • We're doing this because we'll train the exact same model multiple times

import tensorflow as tf

tf.random.set_seed(42)
Init Plugin
Init Graph Optimizer
Init Kernel
def build_and_train(callbacks: list, num_epochs: int = 5) -> tf.keras.Sequential:
    model = tf.keras.Sequential([
        tf.keras.layers.Dense(64, activation='relu'),
        tf.keras.layers.Dense(64, activation='relu'),
        tf.keras.layers.Dense(64, activation='relu'),
        tf.keras.layers.Dense(1, activation='sigmoid')
    ])
    model.compile(
        loss=tf.keras.losses.binary_crossentropy,
        optimizer=tf.keras.optimizers.Adam(),
        metrics=[tf.keras.metrics.BinaryAccuracy(name='accuracy')]
    )
    model.fit(
        X_train_scaled, y_train,
        epochs=num_epochs,
        validation_data=(X_test_scaled, y_test),
        callbacks=callbacks,
        verbose=0
    )
    return model

Basic custom callback

  • We'll define what happens on:

    • Train begin - we'll just print the time at which the training started

    • Train end - we'll print the time at which the training finished, how long it lasted, and the final loss and accuracy on both the training and validation sets (a sketch extending this to precision, recall, and F1 follows the first training run below)

class MyCallback(tf.keras.callbacks.Callback):
    def __init__(self):
        self.time_started = None
        self.time_finished = None

    def on_train_begin(self, logs=None):
        self.time_started = datetime.now()
        print(f'TRAINING STARTED | {self.time_started}\n')

    def on_train_end(self, logs=None):
        self.time_finished = datetime.now()
        train_duration = str(self.time_finished - self.time_started)
        print(f'\nTRAINING FINISHED | {self.time_finished} | Duration: {train_duration}')

        tl = f"Training loss: {logs['loss']:.5f}"
        ta = f"Training accuracy: {logs['accuracy']:.5f}"
        vl = f"Validation loss: {logs['val_loss']:.5f}"
        va = f"Validation accuracy: {logs['val_accuracy']:.5f}"
        print('\n'.join([tl, vl, ta, va]))
  • Pass in the callback like this:

model = build_and_train(
    callbacks=[MyCallback()]
)
TRAINING STARTED | 2021-10-29 15:49:21.494512

TRAINING FINISHED | 2021-10-29 15:49:26.210987 | Duration: 0:00:04.716475
Training loss: 0.46859
Validation loss: 0.46578
Training accuracy: 0.77988
Validation accuracy: 0.77726
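The summary above sticks to the loss and accuracy values Keras passes to on_train_end through logs. If you also want the precision, recall, and F1 mentioned earlier, one possible extension (a sketch, not part of the original notebook; MyEvalCallback and the 0.5 threshold are our own choices) is to predict inside on_train_end. Keras attaches the model being trained to every callback as self.model:

from sklearn.metrics import precision_score, recall_score, f1_score

class MyEvalCallback(MyCallback):
    def on_train_end(self, logs=None):
        super().on_train_end(logs)
        # Keras sets `self.model` on every callback before training starts,
        # so we can run predictions here. Threshold the sigmoid output at 0.5.
        preds = (self.model.predict(X_test_scaled) >= 0.5).astype(int).ravel()
        print(f"Test precision: {precision_score(y_test, preds):.5f}")
        print(f"Test recall: {recall_score(y_test, preds):.5f}")
        print(f"Test F1: {f1_score(y_test, preds):.5f}")

Passing callbacks=[MyEvalCallback()] to build_and_train would then append the three extra metrics to the train-end summary.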

Extending the callback functionality

  • We'll also modify the behavior for a single epoch:

    • Epoch begin - just save the current time to an instance attribute

    • Epoch end - calculate the epoch duration and keep track of the training and validation metrics, printing them in a somewhat more visually appealing way

class MyCallback(tf.keras.callbacks.Callback):
    def __init__(self):
        self.time_started = None
        self.time_finished = None
        self.time_curr_epoch = None

    def on_train_begin(self, logs=None):
        self.time_started = datetime.now()
        print(f'TRAINING STARTED | {self.time_started}\n')

    def on_train_end(self, logs=None):
        self.time_finished = datetime.now()
        train_duration = str(self.time_finished - self.time_started)
        print(f'\nTRAINING FINISHED | {self.time_finished} | Duration: {train_duration}')

        tl = f"Training loss: {logs['loss']:.5f}"
        ta = f"Training accuracy: {logs['accuracy']:.5f}"
        vl = f"Validation loss: {logs['val_loss']:.5f}"
        va = f"Validation accuracy: {logs['val_accuracy']:.5f}"
        print('\n'.join([tl, vl, ta, va]))

    def on_epoch_begin(self, epoch, logs=None):
        self.time_curr_epoch = datetime.now()

    def on_epoch_end(self, epoch, logs=None):
        epoch_dur = (datetime.now() - self.time_curr_epoch).total_seconds()
        tl = logs['loss']
        ta = logs['accuracy']
        vl = logs['val_loss']
        va = logs['val_accuracy']

        train_metrics = f"train_loss: {tl:.5f}, train_accuracy: {ta:.5f}"
        valid_metrics = f"valid_loss: {vl:.5f}, valid_accuracy: {va:.5f}"
        print(f"Epoch: {epoch:4} | Runtime: {epoch_dur:.3f}s | {train_metrics} | {valid_metrics}")
model = build_and_train(
    callbacks=[MyCallback()]
)
TRAINING STARTED | 2021-10-29 15:49:34.048506

Epoch: 0 | Runtime: 1.187s | train_loss: 0.53959, train_accuracy: 0.73095 | valid_loss: 0.49353, valid_accuracy: 0.75793
Epoch: 1 | Runtime: 0.898s | train_loss: 0.49826, train_accuracy: 0.76035 | valid_loss: 0.50824, valid_accuracy: 0.74865
Epoch: 2 | Runtime: 0.888s | train_loss: 0.48819, train_accuracy: 0.76402 | valid_loss: 0.47304, valid_accuracy: 0.76643
Epoch: 3 | Runtime: 0.874s | train_loss: 0.47488, train_accuracy: 0.77350 | valid_loss: 0.47708, valid_accuracy: 0.75870
Epoch: 4 | Runtime: 0.879s | train_loss: 0.46941, train_accuracy: 0.78085 | valid_loss: 0.47244, valid_accuracy: 0.76025

TRAINING FINISHED | 2021-10-29 15:49:38.785216 | Duration: 0:00:04.736710
Training loss: 0.46941
Validation loss: 0.47244
Training accuracy: 0.78085
Validation accuracy: 0.76025
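As an aside, for one-off hooks like this epoch timer, Keras also provides tf.keras.callbacks.LambdaCallback, which takes the same hook points as keyword arguments. A minimal sketch of the timer in that style (the _state dict is an ad-hoc holder we introduce here, not part of the original notebook):

_state = {}

timer = tf.keras.callbacks.LambdaCallback(
    # Store the epoch start time...
    on_epoch_begin=lambda epoch, logs: _state.update(start=datetime.now()),
    # ...and report the elapsed time when the epoch ends
    on_epoch_end=lambda epoch, logs: print(
        f"Epoch {epoch} took {(datetime.now() - _state['start']).total_seconds():.3f}s"
    )
)

model = build_and_train(callbacks=[timer])

Subclassing Callback, as above, stays the better option once you need shared state or more than a couple of hooks.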

Tweaking the functionality even further

  • We'll declare a function (`_plot_model_performance()`) that plots training loss vs. validation loss and training accuracy vs. validation accuracy

  • We'll call it when training ends

class MyCallback(tf.keras.callbacks.Callback):
    def __init__(self):
        self.time_started = None
        self.time_finished = None
        self.time_curr_epoch = None
        self.num_epochs = 0
        self._loss, self._acc, self._val_loss, self._val_acc = [], [], [], []

    def _plot_model_performance(self):
        fig, (ax1, ax2) = plt.subplots(1, 2)
        fig.suptitle('Model performance', size=20)
        ax1.plot(range(self.num_epochs), self._loss, label='Training loss')
        ax1.plot(range(self.num_epochs), self._val_loss, label='Validation loss')
        ax1.set_xlabel('Epoch', size=14)
        ax1.set_ylabel('Loss', size=14)
        ax1.legend()
        ax2.plot(range(self.num_epochs), self._acc, label='Training accuracy')
        ax2.plot(range(self.num_epochs), self._val_acc, label='Validation Accuracy')
        ax2.set_xlabel('Epoch', size=14)
        ax2.set_ylabel('Accuracy', size=14)
        ax2.legend()

    def on_train_begin(self, logs=None):
        self.time_started = datetime.now()
        print(f'TRAINING STARTED | {self.time_started}\n')

    def on_train_end(self, logs=None):
        self.time_finished = datetime.now()
        train_duration = str(self.time_finished - self.time_started)
        print(f'\nTRAINING FINISHED | {self.time_finished} | Duration: {train_duration}')

        tl = f"Training loss: {logs['loss']:.5f}"
        ta = f"Training accuracy: {logs['accuracy']:.5f}"
        vl = f"Validation loss: {logs['val_loss']:.5f}"
        va = f"Validation accuracy: {logs['val_accuracy']:.5f}"
        print('\n'.join([tl, vl, ta, va]))
        self._plot_model_performance()

    def on_epoch_begin(self, epoch, logs=None):
        self.time_curr_epoch = datetime.now()

    def on_epoch_end(self, epoch, logs=None):
        self.num_epochs += 1
        epoch_dur = (datetime.now() - self.time_curr_epoch).total_seconds()
        tl = logs['loss']
        ta = logs['accuracy']
        vl = logs['val_loss']
        va = logs['val_accuracy']
        self._loss.append(tl)
        self._acc.append(ta)
        self._val_loss.append(vl)
        self._val_acc.append(va)

        train_metrics = f"train_loss: {tl:.5f}, train_accuracy: {ta:.5f}"
        valid_metrics = f"valid_loss: {vl:.5f}, valid_accuracy: {va:.5f}"
        print(f"Epoch: {epoch:4} | Runtime: {epoch_dur:.3f}s | {train_metrics} | {valid_metrics}")
model = build_and_train(
    callbacks=[MyCallback()],
    num_epochs=50
)
TRAINING STARTED | 2021-10-29 15:50:07.255394

Epoch: 0 | Runtime: 1.156s | train_loss: 0.54984, train_accuracy: 0.71412 | valid_loss: 0.49403, valid_accuracy: 0.75483
Epoch: 1 | Runtime: 0.863s | train_loss: 0.49869, train_accuracy: 0.75706 | valid_loss: 0.49263, valid_accuracy: 0.74710
Epoch: 2 | Runtime: 0.867s | train_loss: 0.48524, train_accuracy: 0.77234 | valid_loss: 0.46720, valid_accuracy: 0.77185
Epoch: 3 | Runtime: 0.865s | train_loss: 0.47193, train_accuracy: 0.77776 | valid_loss: 0.47483, valid_accuracy: 0.75638
Epoch: 4 | Runtime: 0.877s | train_loss: 0.46571, train_accuracy: 0.78414 | valid_loss: 0.46983, valid_accuracy: 0.76875
Epoch: 5 | Runtime: 0.873s | train_loss: 0.45538, train_accuracy: 0.78665 | valid_loss: 0.46185, valid_accuracy: 0.77108
Epoch: 6 | Runtime: 0.869s | train_loss: 0.45035, train_accuracy: 0.79052 | valid_loss: 0.46432, valid_accuracy: 0.77340
Epoch: 7 | Runtime: 0.866s | train_loss: 0.44260, train_accuracy: 0.79342 | valid_loss: 0.45941, valid_accuracy: 0.77494
Epoch: 8 | Runtime: 0.878s | train_loss: 0.43329, train_accuracy: 0.80116 | valid_loss: 0.46422, valid_accuracy: 0.77262
Epoch: 9 | Runtime: 0.923s | train_loss: 0.43049, train_accuracy: 0.80174 | valid_loss: 0.47019, valid_accuracy: 0.77494
Epoch: 10 | Runtime: 0.867s | train_loss: 0.42353, train_accuracy: 0.80580 | valid_loss: 0.46703, valid_accuracy: 0.78113
Epoch: 11 | Runtime: 0.860s | train_loss: 0.41527, train_accuracy: 0.81296 | valid_loss: 0.47062, valid_accuracy: 0.77185
Epoch: 12 | Runtime: 0.868s | train_loss: 0.41093, train_accuracy: 0.81122 | valid_loss: 0.48872, valid_accuracy: 0.77340
Epoch: 13 | Runtime: 0.861s | train_loss: 0.40597, train_accuracy: 0.80928 | valid_loss: 0.46695, valid_accuracy: 0.77417
Epoch: 14 | Runtime: 0.863s | train_loss: 0.39769, train_accuracy: 0.81857 | valid_loss: 0.48053, valid_accuracy: 0.77726
Epoch: 15 | Runtime: 0.868s | train_loss: 0.39267, train_accuracy: 0.82186 | valid_loss: 0.47627, valid_accuracy: 0.77108
Epoch: 16 | Runtime: 0.870s | train_loss: 0.38474, train_accuracy: 0.83037 | valid_loss: 0.48007, valid_accuracy: 0.77572
Epoch: 17 | Runtime: 0.862s | train_loss: 0.38112, train_accuracy: 0.82708 | valid_loss: 0.47731, valid_accuracy: 0.78268
Epoch: 18 | Runtime: 0.868s | train_loss: 0.37659, train_accuracy: 0.82592 | valid_loss: 0.48525, valid_accuracy: 0.77881
Epoch: 19 | Runtime: 0.868s | train_loss: 0.36965, train_accuracy: 0.83656 | valid_loss: 0.48052, valid_accuracy: 0.77804
Epoch: 20 | Runtime: 0.875s | train_loss: 0.36277, train_accuracy: 0.84217 | valid_loss: 0.48497, valid_accuracy: 0.78809
Epoch: 21 | Runtime: 0.863s | train_loss: 0.36157, train_accuracy: 0.83752 | valid_loss: 0.48956, valid_accuracy: 0.78113
Epoch: 22 | Runtime: 0.874s | train_loss: 0.35574, train_accuracy: 0.83849 | valid_loss: 0.48658, valid_accuracy: 0.77572
Epoch: 23 | Runtime: 0.866s | train_loss: 0.34575, train_accuracy: 0.85203 | valid_loss: 0.49739, valid_accuracy: 0.78036
Epoch: 24 | Runtime: 0.863s | train_loss: 0.34246, train_accuracy: 0.84874 | valid_loss: 0.48785, valid_accuracy: 0.78654
Epoch: 25 | Runtime: 0.870s | train_loss: 0.33659, train_accuracy: 0.85029 | valid_loss: 0.51253, valid_accuracy: 0.78886
Epoch: 26 | Runtime: 0.868s | train_loss: 0.32992, train_accuracy: 0.85300 | valid_loss: 0.52090, valid_accuracy: 0.77726
Epoch: 27 | Runtime: 0.870s | train_loss: 0.32881, train_accuracy: 0.85513 | valid_loss: 0.53420, valid_accuracy: 0.77572
Epoch: 28 | Runtime: 0.878s | train_loss: 0.32202, train_accuracy: 0.86402 | valid_loss: 0.53723, valid_accuracy: 0.77881
Epoch: 29 | Runtime: 0.877s | train_loss: 0.31739, train_accuracy: 0.86035 | valid_loss: 0.51675, valid_accuracy: 0.79273
Epoch: 30 | Runtime: 0.877s | train_loss: 0.31142, train_accuracy: 0.86325 | valid_loss: 0.52563, valid_accuracy: 0.78036
Epoch: 31 | Runtime: 0.888s | train_loss: 0.30740, train_accuracy: 0.86809 | valid_loss: 0.53693, valid_accuracy: 0.79505
Epoch: 32 | Runtime: 0.883s | train_loss: 0.30584, train_accuracy: 0.86151 | valid_loss: 0.52875, valid_accuracy: 0.77417
Epoch: 33 | Runtime: 0.878s | train_loss: 0.30174, train_accuracy: 0.87137 | valid_loss: 0.52646, valid_accuracy: 0.78036
Epoch: 34 | Runtime: 0.874s | train_loss: 0.29382, train_accuracy: 0.86983 | valid_loss: 0.54965, valid_accuracy: 0.77804
Epoch: 35 | Runtime: 0.873s | train_loss: 0.29185, train_accuracy: 0.87350 | valid_loss: 0.52804, valid_accuracy: 0.78113
Epoch: 36 | Runtime: 0.872s | train_loss: 0.28378, train_accuracy: 0.87872 | valid_loss: 0.53974, valid_accuracy: 0.78422
Epoch: 37 | Runtime: 0.865s | train_loss: 0.27826, train_accuracy: 0.88066 | valid_loss: 0.53486, valid_accuracy: 0.78036
Epoch: 38 | Runtime: 0.867s | train_loss: 0.27347, train_accuracy: 0.88607 | valid_loss: 0.57144, valid_accuracy: 0.77881
Epoch: 39 | Runtime: 0.874s | train_loss: 0.27613, train_accuracy: 0.88395 | valid_loss: 0.55132, valid_accuracy: 0.79118
Epoch: 40 | Runtime: 0.878s | train_loss: 0.26357, train_accuracy: 0.89033 | valid_loss: 0.55898, valid_accuracy: 0.78886
Epoch: 41 | Runtime: 0.872s | train_loss: 0.26624, train_accuracy: 0.88298 | valid_loss: 0.55616, valid_accuracy: 0.77958
Epoch: 42 | Runtime: 0.880s | train_loss: 0.26508, train_accuracy: 0.88839 | valid_loss: 0.57052, valid_accuracy: 0.78422
Epoch: 43 | Runtime: 0.895s | train_loss: 0.25255, train_accuracy: 0.89246 | valid_loss: 0.58852, valid_accuracy: 0.78577
Epoch: 44 | Runtime: 0.880s | train_loss: 0.24786, train_accuracy: 0.89845 | valid_loss: 0.58164, valid_accuracy: 0.79196
Epoch: 45 | Runtime: 0.881s | train_loss: 0.24558, train_accuracy: 0.89807 | valid_loss: 0.57115, valid_accuracy: 0.78500
Epoch: 46 | Runtime: 0.877s | train_loss: 0.23848, train_accuracy: 0.90000 | valid_loss: 0.58779, valid_accuracy: 0.77881
Epoch: 47 | Runtime: 0.879s | train_loss: 0.23460, train_accuracy: 0.90058 | valid_loss: 0.59602, valid_accuracy: 0.78422
Epoch: 48 | Runtime: 0.876s | train_loss: 0.23146, train_accuracy: 0.90812 | valid_loss: 0.59239, valid_accuracy: 0.78809
Epoch: 49 | Runtime: 0.885s | train_loss: 0.22946, train_accuracy: 0.90967 | valid_loss: 0.60174, valid_accuracy: 0.78190

TRAINING FINISHED | 2021-10-29 15:50:51.267975 | Duration: 0:00:44.012581
Training loss: 0.22946
Validation loss: 0.60174
Training accuracy: 0.90967
Validation accuracy: 0.78190
[Figure: "Model performance" plot from _plot_model_performance(): training vs. validation loss (left) and training vs. validation accuracy (right) over the 50 epochs]
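Worth noting: model.fit also returns a History object whose history attribute holds the same per-epoch lists we accumulated by hand, so for a purely post-hoc plot a callback isn't strictly necessary. A minimal sketch, assuming model here is a freshly built and compiled Keras model (refitting the trained one returned above would continue its training):

history = model.fit(
    X_train_scaled, y_train,
    epochs=50,
    validation_data=(X_test_scaled, y_test),
    verbose=0
)

# history.history maps each metric name to a list with one entry per epoch
plt.plot(history.history['loss'], label='Training loss')
plt.plot(history.history['val_loss'], label='Validation loss')
plt.xlabel('Epoch')
plt.ylabel('Loss')
plt.legend()

The custom callback still earns its keep when you want live feedback and plots produced during training, which is the point of this notebook.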