Introduction to various Keras Callbacks

Introduction to various Keras Callbacks
Photo by Vincent Yuan @USA / Unsplash

1 Introduction

In tensorflow.keras, callbacks can run along with the model's life cycle during fit, evaluate and predict processes. At present, tensorflow.keras has built many types of callbacks available for users to prevent overfitting, visualize the training process, debug, save model checkpoints, and generate TensorBoard, etc. Through this article, we will learn how to use various callbacks in tensorflow.keras and how to customize callbacks.

2 Using Callbacks

Using callbacks is very simple. First, define the callbacks, then pass the defined callbacks to the callbacks parameter in model.fit(), model.evaluate(), and model.predict().

Take the most common ModelCheckpoint as an example. The usage process is as follows:

...
model_checkpoint_callback = tf.keras.callbacks.ModelCheckpoint(
    filepath=filePath,
    save_weights_only=True,
    monitor='val_accuracy',
    mode='max')

model.fit(x, y, callbacks=model_checkpoint_callback)

In this way, when the model is trained, the model checkpoints will be stored in the corresponding position for later use. In addition to ModelCheckpoint, there are many other types of callbacks available for use in TensorFlow 2.0. Let's explore them.

2.1 EarlyStopping

This callback can monitor the specified evaluation metric. During the training process, when the evaluation metric stops increasing, the training will end early to prevent overfitting. Its default parameters are as follows:

tf.keras.callbacks.EarlyStopping(monitor='val_loss',
        min_delta=0,
        patience=0,
        verbose=0,
        mode='auto',
        baseline=None,
        restore_best_weights=False)

The parameters are as follows:

  • monitor: the evaluation metric monitored by callbacks.
  • min_delta: the smallest metric improvement that will be counted.
  • patience: the number of epochs to wait before stopping training when the evaluation metric stops improving.
  • verbose: whether to print logs.
  • mode: the mode of monitoring metrics, such as whether to monitor if the metric is decreasing, increasing, or automatically inferred based on the metric name.
  • baseline: the baseline of the monitored metric. When the result of the model training is below the standard line, the training will stop.
  • restore_best_weights: whether to restore the model from the epoch with the best training effect. If set to False, the model weights will be restored from the last step.

2.2 LearningRateScheduler

This callback can adjust the learning rate during the model training process. Generally, the learning rate can be appropriately reduced as the number of training increases to help the model converge to the global optimum. Therefore, this callback needs to be used together with a learning rate scheduler. At the beginning of each epoch, the schedule function will obtain the latest learning rate and use it in the current epoch:

tf.keras.callbacks.LearningRateScheduler(
    schedule, verbose=0
)

# The scheduling function calls the initial learning rate before 10 epochs, and then the learning rate decreases exponentially
def scheduler(epoch, lr):
  if epoch < 10:
    return lr
  else:
    return lr * tf.math.exp(-0.1)

model = tf.keras.models.Sequential([tf.keras.layers.Dense(10)])
model.compile(tf.keras.optimizers.SGD(), loss='mse')
callback = tf.keras.callbacks.LearningRateScheduler(scheduler)
history = model.fit(np.arange(100).reshape(5, 20), np.zeros(5),
                    epochs=15, callbacks=[callback], verbose=0)

2.3 ReduceLROnPlateau

Compared with LearningRateScheduler, ReduceLROnPlateau does not adjust the learning rate according to the pre-set schedule. It reduces the learning rate when the evaluation metric stops improving.

tf.keras.callbacks.ReduceLROnPlateau(
    monitor='val_loss', factor=0.1, patience=10, verbose=0,
    mode='auto', min_delta=0.0001, cooldown=0, min_lr=0, **kwargs
)

Important parameters are:

  • factor: the degree of learning rate reduction, new_lr = lr * factor.
  • cooldown: the number of epochs to wait before monitoring the evaluation metric again.
  • min_lr: the minimum allowed learning rate.

2.4 TensorBoard

TensorBoard can conveniently display the model architecture and the training process. This callback can generate TensorBoard logs, and you can view the visualization results in TensorBoard after the training is completed.

tf.keras.callbacks.TensorBoard(
    log_dir='logs', histogram_freq=0, write_graph=True,
    write_images=False, write_steps_per_second=False, update_freq='epoch',
    profile_batch=2, embeddings_freq=0, embeddings_metadata=None, **kwargs
)

Important parameters are:

  • log_dir: the path to the log output.
  • histogram_freq: the frequency of calculating the activation function and weight histograms. If set to 0, the histograms will not be calculated.
  • write_graph: whether to visualize the graph in TensorBoard.
  • update_freq: a string that can be 'batch', 'epoch', or an integer. Loss and evaluation metrics will be written to TensorBoard after the specified processing completes. If set to an integer, it means that the loss and evaluation metrics will be written to TensorBoard after the specified number of samples have been trained.
  • write_images: whether to write weight histograms and other variables as images.
  • write_steps_per_second: whether to write the number of steps per second during processing.
  • profile_batch: the batch for profiling. The default value is 2, and -1 means that all batches will be profiled.

2.5 CSVLogger

As the name suggests, this callback can write the training process to a CSV file.

tf.keras.callbacks.CSVLogger(
    filename, separator=',', append=False
)

Important parameters:

  • append: whether to continue writing logs to existing files.

2.6 TerminateOnNaN

Stop training when the loss becomes NaN.

tf.keras.callbacks.TerminateOnNaN()

2.7 Custom Callback

In addition to the above callbacks, there are other callbacks available on the TensorFlow official website. When using multiple callbacks, you can pass multiple callbacks into a list, or use tf.keras.callbacks.CallbackList. In addition, you can also customize callbacks, which requires inheriting keras.callbacks.Callback and then overriding methods at different training stages.

training_finished = False

class MyCallback(tf.keras.callbacks.Callback):
  def on_train_end(self, logs=None):
    global training_finished
    training_finished = True

model = tf.keras.Sequential([tf.keras.layers.Dense(1, input_shape=(1,))])
model.compile(loss='mean_squared_error')
model.fit(tf.constant([[1.0]]), tf.constant([[1.0]]),
          callbacks=[MyCallback()])

assert training_finished == True

3 Conclusion

This article summarizes some commonly used tf.keras.callbacks. In actual work, please use them as needed and check the official documentation of tf.keras.callbacks to confirm the parameter values.

Hope this sharing is helpful to you, and welcome to leave a comment in the comment section to discuss!