Keras callbacks

TensorFlow
Published

March 25, 2022

import warnings

from tensorflow.keras import (datasets, callbacks, models, layers, optimizers, losses)

Load the Fashion-MNIST data and scale pixel values to the range [0, 1]

# Fetch the Fashion-MNIST train/test splits through the Keras datasets API.
f_mnist = datasets.fashion_mnist
(x_train, y_train), (x_test, y_test) = f_mnist.load_data()

# Divide by 255 so pixel values (presumably 8-bit intensities) fall in [0, 1].
x_train, x_test = x_train / 255.0, x_test / 255.0

Create a callback class that stops training early once accuracy passes a threshold

class MyCallback(callbacks.Callback):
    """Stop training once a monitored metric exceeds a threshold.

    By default the callback reads the ``'acc'`` entry of the per-epoch logs
    (the metric name used in ``model.compile(metrics=['acc'])``) and stops
    training as soon as it is greater than 0.75.

    Args:
        threshold: Value the monitored metric must exceed before training is
            stopped. Defaults to 0.75.
        monitor: Key looked up in the ``logs`` dict passed to
            ``on_epoch_end``. Defaults to ``'acc'``. The printed message
            always says "accuracy", so this is meant for accuracy-like
            metrics.
    """

    def __init__(self, threshold=0.75, monitor='acc'):
        super().__init__()
        self.threshold = threshold
        self.monitor = monitor

    def on_epoch_end(self, epoch, logs=None):
        """Request a stop when the monitored metric exceeds the threshold.

        Args:
            epoch: Index of the epoch that just finished (not used here).
            logs: Metric values for the finished epoch. May be ``None`` or
                may lack ``self.monitor``.
        """
        # Avoid a mutable default argument; treat "no logs" as empty logs.
        logs = logs or {}
        value = logs.get(self.monitor)
        if value is None:
            # The key is absent when compile() used another metric name
            # (e.g. 'accuracy'). Warn instead of crashing on `None > float`.
            warnings.warn(
                f"MyCallback: metric {self.monitor!r} not found in logs "
                f"(available: {sorted(logs)}); early stopping check skipped.",
                RuntimeWarning,
            )
            return
        if value > self.threshold:
            # With the default threshold this prints "Reached 75% accuracy. ..."
            print(f'\nReached {self.threshold:.0%} accuracy. Terminating training ...')
            # Keras checks this flag after each epoch and ends fit() when set.
            self.model.stop_training = True

Build the model

# Small fully connected classifier: flatten each 28x28 input, one hidden
# ReLU layer of 128 units, and a 10-way softmax output layer.
model_layers = [
    layers.Input((28, 28)),
    layers.Flatten(),
    layers.Dense(128, activation='relu'),
    layers.Dense(10, activation='softmax'),
]
model = models.Sequential(model_layers)

# sparse_categorical_crossentropy takes integer class labels directly.
# Accuracy is reported under the short key 'acc', which is the key
# MyCallback looks up in its epoch logs.
model.compile(
    optimizer='adam',
    loss='sparse_categorical_crossentropy',
    metrics=['acc'],
)

model.summary()
Model: "sequential"
_________________________________________________________________
 Layer (type)                Output Shape              Param #   
=================================================================
 flatten (Flatten)           (None, 784)               0         
                                                                 
 dense (Dense)               (None, 128)               100480    
                                                                 
 dense_1 (Dense)             (None, 10)                1290      
                                                                 
=================================================================
Total params: 101,770
Trainable params: 101,770
Non-trainable params: 0
_________________________________________________________________

Train the model

# Allow up to 100 epochs; the callback is expected to end training earlier
# once its accuracy threshold is passed.
callback = MyCallback()

model.fit(
    x=x_train,
    y=y_train,
    epochs=100,
    callbacks=[callback],
)
Epoch 1/100
1856/1875 [============================>.] - ETA: 0s - loss: 0.4953 - acc: 0.8254
Reached 75% accuracy. Terminating training ...
1875/1875 [==============================] - 6s 3ms/step - loss: 0.4945 - acc: 0.8257
<keras.callbacks.History at 0x7fe21dc819d0>