Get Gradients with Keras / TensorFlow 2.0

ghz · 7 months ago · 75 views

I would like to keep track of the gradients in TensorBoard. However, since session run statements are no longer a thing and the write_grads argument of tf.keras.callbacks.TensorBoard has been deprecated, I would like to know how to keep track of gradients during training with Keras / TensorFlow 2.0.

My current approach is to create a custom callback class for this purpose, but so far without success. Maybe someone else knows how to accomplish this kind of advanced stuff.

The test code is shown below. It runs into errors regardless of whether I print a gradient value to the console or write it to TensorBoard.

import tensorflow as tf
from tensorflow.python.keras import backend as K

mnist = tf.keras.datasets.mnist

(x_train, y_train), (x_test, y_test) = mnist.load_data()
x_train, x_test = x_train / 255.0, x_test / 255.0

model = tf.keras.models.Sequential([
  tf.keras.layers.Flatten(input_shape=(28, 28)),
  tf.keras.layers.Dense(128, activation='relu', name='dense128'),
  tf.keras.layers.Dropout(0.2),
  tf.keras.layers.Dense(10, activation='softmax', name='dense10')
])

model.compile(optimizer='adam', loss='sparse_categorical_crossentropy', metrics=['accuracy'])


class GradientCallback(tf.keras.callbacks.Callback):
    console = True

    def on_epoch_end(self, epoch, logs=None):
        weights = [w for w in self.model.trainable_weights if 'dense' in w.name and 'bias' in w.name]
        loss = self.model.total_loss
        optimizer = self.model.optimizer
        gradients = optimizer.get_gradients(loss, weights)
        for t in gradients:
            if self.console:
                print('Tensor: {}'.format(t.name))
                print('{}\n'.format(K.get_value(t)[:10]))
            else:
                tf.summary.histogram(t.name, data=t)


file_writer = tf.summary.create_file_writer("./metrics")
file_writer.set_as_default()

# write_grads has been removed
tensorboard_cb = tf.keras.callbacks.TensorBoard(histogram_freq=1, write_grads=True)
gradient_cb = GradientCallback()

model.fit(x_train, y_train, epochs=5, callbacks=[gradient_cb, tensorboard_cb])
  • Printing bias gradients to the console (console parameter = True) leads to: AttributeError: 'Tensor' object has no attribute 'numpy'
  • Writing to TensorBoard (console parameter = False) creates: TypeError: Using a tf.Tensor as a Python bool is not allowed. Use if t is not None: instead of if t: to test if a tensor is defined, and use TensorFlow ops such as tf.cond to execute subgraphs conditioned on the value of a tensor.

Answers

To track gradients during training with TensorFlow 2.x and Keras, you need either a custom training loop or a custom callback. The main challenge is that in eager execution gradients are only available where a tf.GradientTape has recorded a forward pass, so you have to compute them yourself and then log them with tf.summary operations. The approach below uses a custom callback:

  1. Define a Custom Callback: Create a custom callback to compute and log the gradients.
  2. Use tf.GradientTape: Run a forward pass inside a tf.GradientTape within the custom callback so the gradients of the loss with respect to the trainable weights can be computed (see the short sketch after this list).
  3. Log Gradients to TensorBoard: Use tf.summary to log the gradients to TensorBoard.
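
As a quick aside, here is a minimal, self-contained sketch of how tf.GradientTape produces gradients in eager mode. The tensors and variable here are illustrative dummies, not part of the model above:

import tensorflow as tf

x = tf.constant([[1.0, 2.0]])      # dummy input batch
y_true = tf.constant([[1.0]])      # dummy target
w = tf.Variable([[0.5], [-0.5]])   # trainable weight

with tf.GradientTape() as tape:
    # The forward pass must run inside the tape so it gets recorded.
    y_pred = tf.matmul(x, w)
    loss = tf.reduce_mean(tf.square(y_true - y_pred))

# d(loss)/d(w) is a concrete tensor in eager mode; you can print or log it.
grad = tape.gradient(loss, w)
print(grad.numpy())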

Here's how you can implement this:

import tensorflow as tf

mnist = tf.keras.datasets.mnist
(x_train, y_train), (x_test, y_test) = mnist.load_data()
x_train, x_test = x_train / 255.0, x_test / 255.0

model = tf.keras.models.Sequential([
  tf.keras.layers.Flatten(input_shape=(28, 28)),
  tf.keras.layers.Dense(128, activation='relu', name='dense128'),
  tf.keras.layers.Dropout(0.2),
  tf.keras.layers.Dense(10, activation='softmax', name='dense10')
])

model.compile(optimizer='adam', loss='sparse_categorical_crossentropy', metrics=['accuracy'])

class GradientCallback(tf.keras.callbacks.Callback):
    """Logs gradient histograms and statistics to TensorBoard at the end of each epoch."""

    def __init__(self, x_sample, y_sample, log_dir='./logs'):
        super(GradientCallback, self).__init__()
        # A fixed batch of data used only for the gradient snapshot.
        self.x_sample = tf.convert_to_tensor(x_sample, dtype=tf.float32)
        self.y_sample = tf.convert_to_tensor(y_sample)
        self.file_writer = tf.summary.create_file_writer(log_dir)

    def on_epoch_end(self, epoch, logs=None):
        # The forward pass has to run inside the tape; otherwise
        # tape.gradient() has nothing recorded and returns None.
        with tf.GradientTape() as tape:
            predictions = self.model(self.x_sample, training=True)
            # Same loss the model was compiled with (sparse categorical crossentropy).
            loss = tf.reduce_mean(
                tf.keras.losses.sparse_categorical_crossentropy(self.y_sample, predictions))
        # Gradients of the loss with respect to every trainable variable.
        gradients = tape.gradient(loss, self.model.trainable_weights)

        with self.file_writer.as_default():
            for weight, grad in zip(self.model.trainable_weights, gradients):
                if grad is not None:
                    tf.summary.histogram(f'{weight.name}/gradients', data=grad, step=epoch)
                    # Also log the mean and standard deviation of each gradient.
                    tf.summary.scalar(f'{weight.name}/gradient_mean',
                                      data=tf.reduce_mean(grad), step=epoch)
                    tf.summary.scalar(f'{weight.name}/gradient_stddev',
                                      data=tf.math.reduce_std(grad), step=epoch)
        self.file_writer.flush()

tensorboard_cb = tf.keras.callbacks.TensorBoard(log_dir='./logs')
gradient_cb = GradientCallback(x_train[:1024], y_train[:1024], log_dir='./logs')  # gradient snapshot on a fixed subset

model.fit(x_train, y_train, epochs=5, callbacks=[gradient_cb, tensorboard_cb])

Explanation:

  1. Model Definition: The model is defined as usual.
  2. GradientCallback Class:
    • The GradientCallback class inherits from tf.keras.callbacks.Callback.
    • In the __init__ method, it stores a fixed sample batch and sets up a tf.summary.create_file_writer to log data to TensorBoard.
    • In the on_epoch_end method, it runs a forward pass on the sample batch inside a tf.GradientTape and computes the loss there, so that tape.gradient can return the gradients of the loss with respect to the trainable weights. These gradients are then logged to TensorBoard with tf.summary.histogram, plus tf.summary.scalar for their mean and standard deviation.
  3. Logging to TensorBoard:
    • The tensorboard_cb callback logs the standard training metrics (loss and accuracy).
    • The gradient_cb callback logs the gradients.

Notes:

  • Ensure that you are running TensorFlow 2.x.
  • You can view the logs by running tensorboard --logdir=./logs in your terminal and navigating to localhost:6006 in your web browser.

This approach captures the gradients and logs them to TensorBoard, providing insights into the training process and the behavior of gradients over time.
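
If you want gradients per training step rather than one snapshot per epoch, the other route mentioned at the top of this answer is a custom training loop. Below is a minimal sketch assuming the same model and data as above; the batch size, logging interval, and log directory are arbitrary choices:

optimizer = tf.keras.optimizers.Adam()
loss_fn = tf.keras.losses.SparseCategoricalCrossentropy()
writer = tf.summary.create_file_writer('./logs/train_loop')

dataset = tf.data.Dataset.from_tensor_slices((x_train, y_train)).batch(32)

step = 0
for epoch in range(5):
    for x_batch, y_batch in dataset:
        with tf.GradientTape() as tape:
            predictions = model(x_batch, training=True)
            loss = loss_fn(y_batch, predictions)
        gradients = tape.gradient(loss, model.trainable_weights)
        optimizer.apply_gradients(zip(gradients, model.trainable_weights))

        # Log gradient histograms every 100 steps to keep the event files small.
        if step % 100 == 0:
            with writer.as_default():
                for weight, grad in zip(model.trainable_weights, gradients):
                    tf.summary.histogram(f'{weight.name}/gradients', data=grad, step=step)
        step += 1

The drawback is that you lose the conveniences of model.fit (progress bars, other Keras callbacks), so the callback approach above is usually the simpler choice.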