SciTech-BigDataAIML-Tensorflow-Writing your own callbacks

Published: 2023-12-31 14:19:14  Author: abaelhe

Introduction
A callback is a powerful tool for customizing the behavior of a Keras model during training, evaluation, or inference.
Examples include tf.keras.callbacks.TensorBoard to visualize training progress and results with TensorBoard,
or tf.keras.callbacks.ModelCheckpoint to periodically save your model during training.
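As a minimal sketch of the two built-in callbacks just mentioned, the snippet below trains a throwaway model with both attached. It assumes Keras with the TensorFlow backend is installed and uses the keras namespace, like the code later in this guide. The random data, the one-layer model, and the temporary log_dir and ckpt_path locations are illustrative choices, not part of the original guide.

```python
import os
import tempfile

import numpy as np
import keras

# Illustrative toy data and model; sizes are arbitrary.
x = np.random.rand(64, 4).astype("float32")
y = np.random.rand(64, 1).astype("float32")
model = keras.Sequential([keras.Input(shape=(4,)), keras.layers.Dense(1)])
model.compile(optimizer="sgd", loss="mse")

log_dir = tempfile.mkdtemp()  # TensorBoard event files are written here
ckpt_path = os.path.join(tempfile.mkdtemp(), "model.keras")  # checkpoint target

callbacks = [
    # Writes summaries you can inspect with: tensorboard --logdir <log_dir>
    keras.callbacks.TensorBoard(log_dir=log_dir),
    # Saves the whole model at the end of every epoch (the default save_freq).
    keras.callbacks.ModelCheckpoint(filepath=ckpt_path),
]

history = model.fit(x, y, batch_size=16, epochs=2, callbacks=callbacks, verbose=0)
print(sorted(os.listdir(log_dir)), os.path.exists(ckpt_path))
```

Because ModelCheckpoint saves the full model here, the file name ends in .keras; Keras 3 expects that extension for whole-model saves.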

In this guide, you will learn what a Keras callback is, what it can do, and how you can build your own.
We provide a few demos of simple callback applications to get you started.

Keras callbacks overview
All callbacks subclass the keras.callbacks.Callback class and override a set of methods that are called at various stages of training, testing, and predicting. Callbacks are useful for inspecting the internal state and statistics of the model during training.

You can pass a list of callbacks (as the keyword argument callbacks) to the following model methods:

  • keras.Model.fit()
  • keras.Model.evaluate()
  • keras.Model.predict()
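To illustrate the callbacks argument, here is a sketch with a hypothetical BatchCounter callback (a name made up for this example) that overrides only the three *_batch_end hooks. The same instance is passed to fit(), evaluate(), and predict(). The toy data and model are assumptions chosen only so the counts are easy to predict.

```python
import numpy as np
import keras


class BatchCounter(keras.callbacks.Callback):
    """Counts how many batches fit/evaluate/predict process (illustrative)."""

    def __init__(self):
        super().__init__()
        self.counts = {"train": 0, "test": 0, "predict": 0}

    def on_train_batch_end(self, batch, logs=None):
        self.counts["train"] += 1

    def on_test_batch_end(self, batch, logs=None):
        self.counts["test"] += 1

    def on_predict_batch_end(self, batch, logs=None):
        self.counts["predict"] += 1


x = np.random.rand(32, 4).astype("float32")
y = np.random.rand(32, 1).astype("float32")
model = keras.Sequential([keras.Input(shape=(4,)), keras.layers.Dense(1)])
model.compile(optimizer="sgd", loss="mse")

counter = BatchCounter()
# The same callback instance goes to all three methods via `callbacks=`.
model.fit(x, y, batch_size=8, epochs=2, callbacks=[counter], verbose=0)
model.evaluate(x, y, batch_size=8, callbacks=[counter], verbose=0)
model.predict(x, batch_size=8, callbacks=[counter], verbose=0)
print(counter.counts)
```

With 32 samples and batch_size=8, each pass covers 4 batches, so two training epochs give 8 train batches while evaluate() and predict() give 4 each. Hooks you do not override fall back to the base class's no-op defaults.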

An overview of callback methods

  • Global methods

    • on_(train|test|predict)_begin(self, logs=None): Called at the beginning of fit/evaluate/predict.
    • on_(train|test|predict)_end(self, logs=None): Called at the end of fit/evaluate/predict.
  • Batch-level methods for training/testing/predicting

    • on_(train|test|predict)_batch_begin(self, batch, logs=None): Called right before processing a batch during training/testing/predicting.
    • on_(train|test|predict)_batch_end(self, batch, logs=None): Called at the end of training/testing/predicting a batch. Within this method, logs is a dict containing the metric results.
  • Epoch-level methods (training only)

    • on_epoch_begin(self, epoch, logs=None): Called at the beginning of an epoch during training.
    • on_epoch_end(self, epoch, logs=None): Called at the end of an epoch during training.
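The nesting of these hooks is easiest to see by recording them. The sketch below uses a hypothetical HookRecorder callback (not part of Keras) that logs only the global and epoch-level hooks during a single-epoch fit() with validation_data. The validation pass fires on_test_begin/on_test_end inside the epoch, before on_epoch_end, which is why the epoch-end logs can already contain val_loss.

```python
import numpy as np
import keras


class HookRecorder(keras.callbacks.Callback):
    """Records the order in which fit() calls global and epoch-level hooks."""

    def __init__(self):
        super().__init__()
        self.events = []
        self.epoch_end_keys = []

    def on_train_begin(self, logs=None):
        self.events.append("train_begin")

    def on_epoch_begin(self, epoch, logs=None):
        self.events.append("epoch_begin")

    def on_test_begin(self, logs=None):
        # Fired by the validation pass that fit() runs when validation_data is set.
        self.events.append("test_begin")

    def on_test_end(self, logs=None):
        self.events.append("test_end")

    def on_epoch_end(self, epoch, logs=None):
        self.events.append("epoch_end")
        self.epoch_end_keys = sorted((logs or {}).keys())

    def on_train_end(self, logs=None):
        self.events.append("train_end")


x = np.random.rand(32, 4).astype("float32")
y = np.random.rand(32, 1).astype("float32")
model = keras.Sequential([keras.Input(shape=(4,)), keras.layers.Dense(1)])
model.compile(optimizer="sgd", loss="mse")

recorder = HookRecorder()
model.fit(x, y, batch_size=8, epochs=1, validation_data=(x, y),
          callbacks=[recorder], verbose=0)
print(recorder.events)
print(recorder.epoch_end_keys)  # includes "loss" and "val_loss"
```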
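
The following callback overrides each of these methods and prints the keys of the logs dict it receives:

import tensorflow as tf
import keras

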
class CustomCallback(keras.callbacks.Callback):
    # Every hook declares logs=None, so each one falls back to an empty dict
    # before reading the keys.

    def on_train_begin(self, logs=None):
        keys = list((logs or {}).keys())
        print("Starting training; got log keys: {}".format(keys))

    def on_train_end(self, logs=None):
        keys = list((logs or {}).keys())
        print("Stop training; got log keys: {}".format(keys))

    def on_epoch_begin(self, epoch, logs=None):
        keys = list((logs or {}).keys())
        print("Start epoch {} of training; got log keys: {}".format(epoch, keys))

    def on_epoch_end(self, epoch, logs=None):
        keys = list((logs or {}).keys())
        print("End epoch {} of training; got log keys: {}".format(epoch, keys))

    def on_test_begin(self, logs=None):
        keys = list((logs or {}).keys())
        print("Start testing; got log keys: {}".format(keys))

    def on_test_end(self, logs=None):
        keys = list((logs or {}).keys())
        print("Stop testing; got log keys: {}".format(keys))

    def on_predict_begin(self, logs=None):
        keys = list((logs or {}).keys())
        print("Start predicting; got log keys: {}".format(keys))

    def on_predict_end(self, logs=None):
        keys = list((logs or {}).keys())
        print("Stop predicting; got log keys: {}".format(keys))

    def on_train_batch_begin(self, batch, logs=None):
        keys = list((logs or {}).keys())
        print("...Training: start of batch {}; got log keys: {}".format(batch, keys))

    def on_train_batch_end(self, batch, logs=None):
        keys = list((logs or {}).keys())
        print("...Training: end of batch {}; got log keys: {}".format(batch, keys))

    def on_test_batch_begin(self, batch, logs=None):
        keys = list((logs or {}).keys())
        print("...Evaluating: start of batch {}; got log keys: {}".format(batch, keys))

    def on_test_batch_end(self, batch, logs=None):
        keys = list((logs or {}).keys())
        print("...Evaluating: end of batch {}; got log keys: {}".format(batch, keys))

    def on_predict_batch_begin(self, batch, logs=None):
        keys = list((logs or {}).keys())
        print("...Predicting: start of batch {}; got log keys: {}".format(batch, keys))

    def on_predict_batch_end(self, batch, logs=None):
        keys = list((logs or {}).keys())
        print("...Predicting: end of batch {}; got log keys: {}".format(batch, keys))
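Printing keys is only a starting point; the logs dict also carries the metric values themselves. A minimal sketch, assuming a toy regression model trained on random data, with LossHistory as a hypothetical name: it reads logs["loss"] at the end of every training batch and every epoch.

```python
import numpy as np
import keras


class LossHistory(keras.callbacks.Callback):
    """Collects the "loss" entry of the logs dict per batch and per epoch."""

    def on_train_begin(self, logs=None):
        self.batch_losses = []
        self.epoch_losses = []

    def on_train_batch_end(self, batch, logs=None):
        # float() also accepts numpy scalars, which some Keras versions pass here.
        self.batch_losses.append(float(logs["loss"]))

    def on_epoch_end(self, epoch, logs=None):
        self.epoch_losses.append(float(logs["loss"]))


x = np.random.rand(32, 4).astype("float32")
y = np.random.rand(32, 1).astype("float32")
model = keras.Sequential([keras.Input(shape=(4,)), keras.layers.Dense(1)])
model.compile(optimizer="sgd", loss="mse")

history_cb = LossHistory()
model.fit(x, y, batch_size=8, epochs=3, callbacks=[history_cb], verbose=0)
print(len(history_cb.batch_losses), len(history_cb.epoch_losses))
```

With 32 samples, batch_size=8, and 3 epochs, this collects 12 batch-level and 3 epoch-level values. The batch-level loss is typically a running average over the epoch so far rather than the loss of that single batch.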