ScaDaMaLe Course site and book

Introduction to Distributed Deep Learning (DDL) with Horovod over Tensorflow/keras and Pytorch

Raazesh Sainudiin and Tilo Wiklund

This is just a set of links and notebooks from the databricks blog on the topic.

It is meant to help you take the fastest steps into DDL, including instructions to install and set up the right libraries and environments in databricks.

ScaDaMaLe Course site and book

The following is from the databricks blog, with minor adaptations made with help from Tilo Wiklund.

Distributed deep learning training using TensorFlow and Keras with HorovodRunner

This notebook demonstrates how to train a model for the MNIST dataset using the tensorflow.keras API. It first shows how to train the model on a single node, and then shows how to adapt the code using HorovodRunner for distributed training.

Requirements

  • This notebook runs on CPU or GPU clusters.
  • To run the notebook, create a cluster with:
      • Two workers
      • Databricks Runtime 6.3 ML or above

Cluster Specs on databricks

Run on tiny-debug-cluster-(no)gpu or on another cluster matching one of the following runtime specifications (CPU/non-GPU and GPU, respectively):

  • Runs on non-GPU cluster with 3 (or more) nodes on 7.4 ML runtime (nodes are 1+2 x m4.xlarge)
  • Runs on GPU cluster with 3 (or more) nodes on 7.4 ML GPU runtime (nodes are 1+2 x g4dn.xlarge)

You do not need to "install" anything else in databricks as everything needed is pre-installed in the runtime environment on the right nodes.

Set up checkpoint location

The next cell creates a directory for saved checkpoint models.

import os
import time

checkpoint_dir = '/dbfs/ml/MNISTDemo/train/{}/'.format(time.time())

os.makedirs(checkpoint_dir)

Create function to prepare data

The following cell creates a function that prepares the data for training. This function takes in rank and size arguments so it can be used for both single-node and distributed training. In Horovod, rank is a unique process ID and size is the total number of processes.

This function downloads the data from keras.datasets, distributes the data across the available nodes, and converts the data to shapes and types needed for training.

def get_dataset(num_classes, rank=0, size=1):
  from tensorflow import keras
  
  (x_train, y_train), (x_test, y_test) = keras.datasets.mnist.load_data('MNIST-data-%d' % rank)
  x_train = x_train[rank::size]
  y_train = y_train[rank::size]
  x_test = x_test[rank::size]
  y_test = y_test[rank::size]
  x_train = x_train.reshape(x_train.shape[0], 28, 28, 1)
  x_test = x_test.reshape(x_test.shape[0], 28, 28, 1)
  x_train = x_train.astype('float32')
  x_test = x_test.astype('float32')
  x_train /= 255
  x_test /= 255
  y_train = keras.utils.to_categorical(y_train, num_classes)
  y_test = keras.utils.to_categorical(y_test, num_classes)
  return (x_train, y_train), (x_test, y_test)
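
For example, in a hypothetical two-process run, the rank-0 and rank-1 shards partition the 60,000 MNIST training images between them:

# Single-node usage: one process sees the full training set
(x_train, y_train), _ = get_dataset(10)

# Hypothetical two-process usage: rank 0 takes the even-indexed examples, rank 1 the odd ones
(x0, y0), _ = get_dataset(10, rank=0, size=2)
(x1, y1), _ = get_dataset(10, rank=1, size=2)
assert len(x0) + len(x1) == len(x_train)  # the two shards partition the training set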

Create function to train model

The following cell defines the model using the tensorflow.keras API. This code is adapted from the Keras MNIST convnet example. The model consists of two convolutional layers, a max-pooling layer, two dropout layers, and two dense layers.

def get_model(num_classes):
  from tensorflow.keras import models
  from tensorflow.keras import layers
  
  model = models.Sequential()
  model.add(layers.Conv2D(32, kernel_size=(3, 3),
                   activation='relu',
                   input_shape=(28, 28, 1)))
  model.add(layers.Conv2D(64, (3, 3), activation='relu'))
  model.add(layers.MaxPooling2D(pool_size=(2, 2)))
  model.add(layers.Dropout(0.25))
  model.add(layers.Flatten())
  model.add(layers.Dense(128, activation='relu'))
  model.add(layers.Dropout(0.5))
  model.add(layers.Dense(num_classes, activation='softmax'))
  return model

At this point, you have created functions to load and preprocess the dataset and to create the model. This section illustrates single-node training code using tensorflow.keras.

# Specify training parameters
batch_size = 128
epochs = 5
num_classes = 10        


def train(learning_rate=1.0):
  from tensorflow import keras
  
  (x_train, y_train), (x_test, y_test) = get_dataset(num_classes)
  model = get_model(num_classes)

  # Specify the optimizer (Adadelta in this example), using the learning rate input parameter of the function so that Horovod can adjust the learning rate during training
  optimizer = keras.optimizers.Adadelta(lr=learning_rate)

  model.compile(optimizer=optimizer,
                loss='categorical_crossentropy',
                metrics=['accuracy'])

  model.fit(x_train, y_train,
            batch_size=batch_size,
            epochs=epochs,
            verbose=2,
            validation_data=(x_test, y_test))

Run the train function you just created to train a model on the driver node. The process takes several minutes. The accuracy improves with each epoch.

# Runs in  23.67 seconds on 3-node     GPU
# Runs in 418.8  seconds on 3-node non-GPU
train(learning_rate=0.1)
Downloading data from https://storage.googleapis.com/tensorflow/tf-keras-datasets/mnist.npz
[...download progress bars truncated...]
11493376/11490434 [==============================] - 1s 0us/step
Epoch 1/5
469/469 - 3s - loss: 0.6257 - accuracy: 0.8091 - val_loss: 0.2157 - val_accuracy: 0.9345
Epoch 2/5
469/469 - 3s - loss: 0.2950 - accuracy: 0.9127 - val_loss: 0.1450 - val_accuracy: 0.9569
Epoch 3/5
469/469 - 3s - loss: 0.2145 - accuracy: 0.9373 - val_loss: 0.1035 - val_accuracy: 0.9695
Epoch 4/5
469/469 - 3s - loss: 0.1688 - accuracy: 0.9512 - val_loss: 0.0856 - val_accuracy: 0.9738
Epoch 5/5
469/469 - 3s - loss: 0.1379 - accuracy: 0.9600 - val_loss: 0.0701 - val_accuracy: 0.9788

Migrate to HorovodRunner for distributed training

This section shows how to modify the single-node code to use Horovod. For more information about Horovod, see the Horovod documentation.

def train_hvd(learning_rate=1.0):
  # Import tensorflow modules to each worker
  from tensorflow.keras import backend as K
  from tensorflow.keras.models import Sequential
  import tensorflow as tf
  from tensorflow import keras
  import horovod.tensorflow.keras as hvd
  
  # Initialize Horovod
  hvd.init()

  # Pin GPU to be used to process local rank (one GPU per process)
  # These steps are skipped on a CPU cluster
  gpus = tf.config.experimental.list_physical_devices('GPU')
  for gpu in gpus:
    tf.config.experimental.set_memory_growth(gpu, True)
  if gpus:
    tf.config.experimental.set_visible_devices(gpus[hvd.local_rank()], 'GPU')

  # Call the get_dataset function you created, this time with the Horovod rank and size
  (x_train, y_train), (x_test, y_test) = get_dataset(num_classes, hvd.rank(), hvd.size())
  model = get_model(num_classes)

  # Adjust learning rate based on number of GPUs
  optimizer = keras.optimizers.Adadelta(lr=learning_rate * hvd.size())

  # Use the Horovod Distributed Optimizer
  optimizer = hvd.DistributedOptimizer(optimizer)

  model.compile(optimizer=optimizer,
                loss='categorical_crossentropy',
                metrics=['accuracy'])

  # Create a callback to broadcast the initial variable states from rank 0 to all other processes.
  # This is required to ensure consistent initialization of all workers when training is started with random weights or restored from a checkpoint.
  callbacks = [
      hvd.callbacks.BroadcastGlobalVariablesCallback(0),
  ]

  # Save checkpoints only on worker 0 to prevent conflicts between workers
  if hvd.rank() == 0:
      callbacks.append(keras.callbacks.ModelCheckpoint(checkpoint_dir + '/checkpoint-{epoch}.ckpt', save_weights_only = True))

  model.fit(x_train, y_train,
            batch_size=batch_size,
            callbacks=callbacks,
            epochs=epochs,
            verbose=2,
            validation_data=(x_test, y_test))

Now that you have defined a training function with Horovod, you can use HorovodRunner to distribute the work of training the model.

The HorovodRunner parameter np sets the number of processes. This example uses a cluster with two workers, each with a single GPU, so set np=2. (If you use np=-1, HorovodRunner trains using a single process on the driver node.)

# runs in  47.84 seconds on 3-node     GPU cluster
# Runs in 316.8  seconds on 3-node non-GPU cluster
from sparkdl import HorovodRunner

hr = HorovodRunner(np=2)
hr.run(train_hvd, learning_rate=0.1)
HorovodRunner will stream all training logs to notebook cell output. If there are too many logs, you
can adjust the log level in your train method. Or you can set driver_log_verbosity to
'log_callback_only' and use a HorovodRunner log  callback on the first worker to get concise
progress updates.
The global names read or written to by the pickled function are {'checkpoint_dir', 'num_classes', 'batch_size', 'epochs', 'get_model', 'get_dataset'}.
The pickled object size is 3560 bytes.

### How to enable Horovod Timeline? ###
HorovodRunner has the ability to record the timeline of its activity with Horovod Timeline. To
record a Horovod Timeline, set the `HOROVOD_TIMELINE` environment variable to the location of the
timeline file to be created. You can then open the timeline file using the chrome://tracing
facility of the Chrome browser.
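
For example, a timeline could be recorded by setting the variable before calling hr.run (the DBFS path below is a hypothetical choice; HOROVOD_TIMELINE itself is Horovod's documented variable):

import os
# Hypothetical location for the timeline file; open it later via chrome://tracing
os.environ['HOROVOD_TIMELINE'] = '/dbfs/ml/MNISTDemo/horovod_timeline.json'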

Start training.
Warning: Permanently added '10.149.233.216' (ECDSA) to the list of known hosts.
[1,0]<stderr>:pciBusID: 0000:00:1e.0 name: Tesla T4 computeCapability: 7.5
[1,1]<stderr>:pciBusID: 0000:00:1e.0 name: Tesla T4 computeCapability: 7.5
[...per-rank TensorFlow/CUDA library-loading, XLA-initialization, and MNIST-download progress logs truncated...]
[1,0]<stdout>:NCCL version 2.7.3+cuda10.1
[...NCCL ring/channel setup logs truncated; both ranks report NCCL INFO ... Init COMPLETE...]
[1,1]<stdout>:Epoch 1/5
[1,0]<stdout>:Epoch 1/5
[1,1]<stdout>:235/235 - 3s - loss: 0.5233 - accuracy: 0.8414 - val_loss: 0.1912 - val_accuracy: 0.9434
[1,1]<stdout>:Epoch 2/5
[1,0]<stdout>:235/235 - 5s - loss: 0.6732 - accuracy: 0.7913 - val_loss: 0.1872 - val_accuracy: 0.9434
[1,0]<stdout>:Epoch 2/5
[1,1]<stdout>:235/235 - 3s - loss: 0.1892 - accuracy: 0.9455 - val_loss: 0.1172 - val_accuracy: 0.9650
[1,1]<stdout>:Epoch 3/5
[1,0]<stdout>:235/235 - 5s - loss: 0.3207 - accuracy: 0.9024 - val_loss: 0.1168 - val_accuracy: 0.9650
[1,0]<stdout>:Epoch 3/5
[1,1]<stdout>:235/235 - 3s - loss: 0.1225 - accuracy: 0.9651 - val_loss: 0.0827 - val_accuracy: 0.9754
[1,1]<stdout>:Epoch 4/5
[1,0]<stdout>:235/235 - 5s - loss: 0.2330 - accuracy: 0.9303 - val_loss: 0.0795 - val_accuracy: 0.9750
[1,0]<stdout>:Epoch 4/5
[1,1]<stdout>:235/235 - 3s - loss: 0.0916 - accuracy: 0.9744 - val_loss: 0.0655 - val_accuracy: 0.9790
[1,1]<stdout>:Epoch 5/5
[1,0]<stdout>:235/235 - 5s - loss: 0.1812 - accuracy: 0.9448 - val_loss: 0.0624 - val_accuracy: 0.9794
[1,0]<stdout>:Epoch 5/5
[1,1]<stdout>:235/235 - 4s - loss: 0.0738 - accuracy: 0.9798 - val_loss: 0.0571 - val_accuracy: 0.9830
[1,0]<stdout>:235/235 - 6s - loss: 0.1536 - accuracy: 0.9528 - val_loss: 0.0544 - val_accuracy: 0.9822

Under the hood, HorovodRunner takes a Python method that contains deep learning training code with Horovod hooks. HorovodRunner pickles the method on the driver and distributes it to Spark workers. A Horovod MPI job is embedded as a Spark job using the barrier execution mode. The first executor collects the IP addresses of all task executors using BarrierTaskContext and triggers a Horovod job using mpirun. Each Python MPI process loads the pickled user program, deserializes it, and runs it.
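
To make the barrier-mode mechanics concrete, here is a minimal sketch of the pattern described above (illustrative only, not HorovodRunner's actual source; the task body is hypothetical):

from pyspark import BarrierTaskContext

def barrier_task(_):
  ctx = BarrierTaskContext.get()
  ctx.barrier()                                    # wait until every task in the stage has started
  hosts = [t.address for t in ctx.getTaskInfos()]  # the first executor can collect all task addresses
  # rank 0 would launch `mpirun` across `hosts`, and each MPI process would run the pickled function
  yield hosts

# sc.parallelize(range(2), 2).barrier().mapPartitions(barrier_task).collect()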

For more information, see HorovodRunner API documentation.

ScaDaMaLe Course site and book

The following is from the databricks blog, with minor adaptations made with help from Tilo Wiklund.

Distributed deep learning training using PyTorch with HorovodRunner for MNIST

This notebook demonstrates how to train a model for the MNIST dataset using PyTorch. It first shows how to train the model on a single node, and then shows how to adapt the code using HorovodRunner for distributed training.

Requirements

  • This notebook runs on CPU or GPU clusters.
  • To run the notebook, create a cluster with:
      • Two workers

Cluster Specs on databricks

Run on tiny-debug-cluster-(no)gpu or on another cluster matching one of the following runtime specifications (CPU/non-GPU and GPU, respectively):

  • Runs on non-GPU cluster with 3 (or more) nodes on 7.4 ML runtime (nodes are 1+2 x m4.xlarge)
  • Runs on GPU cluster with 3 (or more) nodes on 7.4 ML GPU runtime (nodes are 1+2 x g4dn.xlarge)

You do not need to "install" anything else in databricks as everything needed is pre-installed in the runtime environment on the right nodes.

Set up checkpoint location

The next cell sets the base directory for saved checkpoint models. Databricks recommends saving training data under dbfs:/ml, which maps to file:/dbfs/ml on driver and worker nodes.

PYTORCH_DIR = '/dbfs/ml/horovod_pytorch'
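
Because dbfs:/ml is FUSE-mounted at /dbfs/ml, ordinary Python file APIs work against it; for instance, the base directory can be created up front (an optional convenience step; exist_ok is our assumption to keep re-runs idempotent):

import os
# dbfs:/ml is FUSE-mounted at /dbfs/ml on driver and workers, so plain file APIs work.
# exist_ok=True keeps this idempotent across re-runs; the training code below
# creates its own timestamped subdirectory under this path.
os.makedirs(PYTORCH_DIR, exist_ok=True)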

Prepare single-node code

First you need to have working single-node PyTorch code. This is modified from Horovod's PyTorch MNIST Example.

import torch
import torch.nn as nn
import torch.nn.functional as F

class Net(nn.Module):
    def __init__(self):
        super(Net, self).__init__()
        self.conv1 = nn.Conv2d(1, 10, kernel_size=5)
        self.conv2 = nn.Conv2d(10, 20, kernel_size=5)
        self.conv2_drop = nn.Dropout2d()
        self.fc1 = nn.Linear(320, 50)
        self.fc2 = nn.Linear(50, 10)

    def forward(self, x):
        x = F.relu(F.max_pool2d(self.conv1(x), 2))
        x = F.relu(F.max_pool2d(self.conv2_drop(self.conv2(x)), 2))
        x = x.view(-1, 320)
        x = F.relu(self.fc1(x))
        x = F.dropout(x, training=self.training)
        x = self.fc2(x)
        return F.log_softmax(x, dim=1)
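
As a quick sanity check of the forward pass (a minimal sketch; the shapes follow MNIST's 1x28x28 inputs):

# A batch of 4 MNIST-shaped inputs yields 4 log-probability vectors over the 10 digit classes
net = Net()
out = net(torch.zeros(4, 1, 28, 28))
print(out.shape)  # torch.Size([4, 10])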
# Specify training parameters
batch_size = 100
num_epochs = 5
momentum = 0.5
log_interval = 100
def train_one_epoch(model, device, data_loader, optimizer, epoch):
    model.train()
    for batch_idx, (data, target) in enumerate(data_loader):
        data, target = data.to(device), target.to(device)
        optimizer.zero_grad()
        output = model(data)
        loss = F.nll_loss(output, target)
        loss.backward()
        optimizer.step()
        if batch_idx % log_interval == 0:
            print('Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format(
                epoch, batch_idx * len(data), len(data_loader) * len(data),
                100. * batch_idx / len(data_loader), loss.item()))
from time import time
import os

LOG_DIR = os.path.join(PYTORCH_DIR, str(time()), 'MNISTDemo')
os.makedirs(LOG_DIR)
def save_checkpoint(model, optimizer, epoch):
  filepath = LOG_DIR + '/checkpoint-{epoch}.pth.tar'.format(epoch=epoch)
  state = {
    'model': model.state_dict(),
    'optimizer': optimizer.state_dict(),
  }
  torch.save(state, filepath)
import torch.optim as optim
from torchvision import datasets, transforms

def train(learning_rate):
  device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

  train_dataset = datasets.MNIST(
    'data', 
    train=True,
    download=True,
    transform=transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.1307,), (0.3081,))]))
  data_loader = torch.utils.data.DataLoader(train_dataset, batch_size=batch_size, shuffle=True)

  model = Net().to(device)

  optimizer = optim.SGD(model.parameters(), lr=learning_rate, momentum=momentum)

  for epoch in range(1, num_epochs + 1):
    train_one_epoch(model, device, data_loader, optimizer, epoch)
    save_checkpoint(model, optimizer, epoch)

Run the train function you just created to train a model on the driver node.

# Runs in  49.65 seconds on 3 node    GPU cluster
# Runs in 118.2 seconds on 3 node non-GPU cluster
train(learning_rate = 0.001)
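
A checkpoint written by save_checkpoint above can be restored later, for instance (a minimal sketch reusing LOG_DIR from the current session):

# Minimal sketch: restore the final checkpoint written by save_checkpoint above
checkpoint = torch.load(LOG_DIR + '/checkpoint-{epoch}.pth.tar'.format(epoch=num_epochs))
model = Net()
model.load_state_dict(checkpoint['model'])
optimizer = optim.SGD(model.parameters(), lr=0.001, momentum=momentum)
optimizer.load_state_dict(checkpoint['optimizer'])
model.eval()  # switch to evaluation mode before inference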

Migrate to HorovodRunner

HorovodRunner takes a Python method that contains deep learning training code with Horovod hooks. HorovodRunner pickles the method on the driver and distributes it to Spark workers. A Horovod MPI job is embedded as a Spark job using barrier execution mode.

import horovod.torch as hvd
from sparkdl import HorovodRunner
def train_hvd(learning_rate):
  
  # Initialize Horovod
  hvd.init()  
  device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
  
  if device.type == 'cuda':
    # Pin GPU to local rank
    torch.cuda.set_device(hvd.local_rank())

  train_dataset = datasets.MNIST(
    # Use different root directory for each worker to avoid conflicts
    root='data-%d'% hvd.rank(),  
    train=True, 
    download=True,
    transform=transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.1307,), (0.3081,))])
  )

  from torch.utils.data.distributed import DistributedSampler
  
  # Configure the sampler so that each worker gets a distinct sample of the input dataset
  train_sampler = DistributedSampler(train_dataset, num_replicas=hvd.size(), rank=hvd.rank())
  # Use train_sampler to load a different sample of data on each worker
  train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=batch_size, sampler=train_sampler)

  model = Net().to(device)
  
  # The effective batch size in synchronous distributed training is scaled by the number of workers
  # Increase learning_rate to compensate for the increased batch size
  optimizer = optim.SGD(model.parameters(), lr=learning_rate * hvd.size(), momentum=momentum)

  # Wrap the local optimizer with hvd.DistributedOptimizer so that Horovod handles the distributed optimization
  optimizer = hvd.DistributedOptimizer(optimizer, named_parameters=model.named_parameters())
  
  # Broadcast initial parameters so all workers start with the same parameters
  hvd.broadcast_parameters(model.state_dict(), root_rank=0)

  for epoch in range(1, num_epochs + 1):
    train_one_epoch(model, device, train_loader, optimizer, epoch)
    # Save checkpoints only on worker 0 to prevent conflicts between workers
    if hvd.rank() == 0:
      save_checkpoint(model, optimizer, epoch)

Now that you have defined a training function with Horovod, you can use HorovodRunner to distribute the work of training the model.

The HorovodRunner parameter np sets the number of processes. This example uses a cluster with two workers, each with a single GPU, so set np=2. (If you use np=-1, HorovodRunner trains using a single process on the driver node.)

# Runs in 51.63 seconds on 3 node     GPU cluster
# Runs in 96.6  seconds on 3 node non-GPU cluster
hr = HorovodRunner(np=2) 
hr.run(train_hvd, learning_rate = 0.001)

Under the hood, HorovodRunner takes a Python method that contains deep learning training code with Horovod hooks. HorovodRunner pickles the method on the driver and distributes it to Spark workers. A Horovod MPI job is embedded as a Spark job using the barrier execution mode. The first executor collects the IP addresses of all task executors using BarrierTaskContext and triggers a Horovod job using mpirun. Each Python MPI process loads the pickled user program, deserializes it, and runs it.

For more information, see HorovodRunner API documentation.

Distributed Deep Learning

CNN's with horovod, MLFlow and hypertuning through SparkTrials

William Anzén (Linkedin), Christian von Koch (Linkedin)

2021, Stockholm, Sweden

This project was supported by Combient Mix AB under the industrial supervision of Raazesh Sainudiin and Max Fischer.

However, all the neuromusculature credit goes to William and Christian; they absorbed the WASP PhD course over their X-mas holidays.

Caveat: These notebooks were done on a databricks shard on Azure, as opposed to AWS.

So one has to take some care with databricks' terraform pipes. Loading data should be independent of the underlying cloud provider, as the data is loaded through Tensorflow Datasets, although the following notebooks have not been tested on AWS infrastructure. This work was kindly supported by a total of USD 7,500 in AWS credits through The databricks University Alliance, which waived the DBU-units on a professional enterprise-grade shard for the WASP ScaDaMaLe/sds-3-x course with voluntary research students at any Swedish university who go through the curriculum first. Raazesh Sainudiin is most grateful to Rob Reed for the most admirable administration of The databricks University Alliance.

**Resources:**

These notebooks were inspired by Tensorflow's tutorial on Image Segmentation.

01aimagesegmentation_unet

In this chapter a simple U-Net architecture is implemented and evaluated against the Oxford Pets Data set. The model achieves a validation accuracy of 81.51% and a validation loss of 0.7251 after 38/50 epochs (3.96 min full 50 epochs).

exjobbsOfCombientMix202102aimagesegmenationpspnet

In this chapter a PSPNet architecture is implemented and evaluated against the Oxford Pets Data set. The model achieves a validation accuracy of 90.30% and a validation loss of 0.3145 after 42/50 epochs (39.64 min full 50 epochs).

exjobbsOfCombientMix202103aimagesegmenationicnet

In this chapter the ICNet architecture is implemented and evaluated against the Oxford Pets Data set. The model achieves a validation accuracy of 86.64% and a validation loss of 0.3750 after 31/50 epochs (12.50 min full 50 epochs).

exjobbsOfCombientMix202104apspnettuningparallel

In this chapter we run hyperparameter tuning with hyperopt & SparkTrials, allowing the hyperparameter tuning to be performed in parallel across multiple workers. Achieved a loss of 0.56 with parameters {'batchsize': 16, 'learningrate': 0.0001437661898681224} (1.56 hours - 4 workers).

exjobbsOfCombientMix202105pspnet_horovod

In this chapter we add horovod to the notebook, allowing distributed training of the model. Achieved a validation accuracy of 89.87% and a validation loss of 0.2861 after 49/50 epochs (33.93 min - 4 workers).

U-Net model for image segmentation

William Anzén (Linkedin), Christian von Koch (Linkedin)

2021, Stockholm, Sweden

This project was supported by Combient Mix AB under the industrial supervision of Raazesh Sainudiin and Max Fischer.

This is a modified version of Tensorflow's tutorial on image segmentation, which can be found here. In this notebook a modified U-Net was used.

Importing the required packages...

import tensorflow as tf
import tensorflow_datasets as tfds
import matplotlib.pyplot as plt
from IPython.display import clear_output
from tensorflow.keras import Input
from tensorflow.keras.layers import Conv2D, BatchNormalization, Activation, MaxPooling2D, Concatenate, ReLU, Reshape, Conv2DTranspose
from tensorflow.keras import Model
from tensorflow.keras.applications import VGG16
from typing import Union, Tuple

First, the Oxford-IIIT Pet Dataset from the TensorFlow datasets is loaded, then the images are transformed in the desired way and the datasets used for training and inference are created. Finally, an example image is displayed.

dataset, info = tfds.load('oxford_iiit_pet:3.*.*', with_info=True)

def normalize(input_image, input_mask):
  input_image = tf.cast(input_image, tf.float32) / 255.0
  input_mask -= 1
  return input_image, input_mask

@tf.function
def load_image_train(datapoint):
  input_image = tf.image.resize(datapoint['image'], (128, 128))
  input_mask = tf.image.resize(datapoint['segmentation_mask'], (128, 128))

  if tf.random.uniform(()) > 0.5:
    input_image = tf.image.flip_left_right(input_image)
    input_mask = tf.image.flip_left_right(input_mask)

  input_image, input_mask = normalize(input_image, input_mask)

  return input_image, input_mask

def load_image_test(datapoint):
  input_image = tf.image.resize(datapoint['image'], (128, 128))
  input_mask = tf.image.resize(datapoint['segmentation_mask'], (128, 128))

  input_image, input_mask = normalize(input_image, input_mask)

  return input_image, input_mask

TRAIN_LENGTH = info.splits['train'].num_examples
BATCH_SIZE = 64
BUFFER_SIZE = 1000
STEPS_PER_EPOCH = TRAIN_LENGTH // BATCH_SIZE

train = dataset['train'].map(load_image_train, num_parallel_calls=tf.data.experimental.AUTOTUNE)
test = dataset['test'].map(load_image_test)

train_dataset = train.cache().shuffle(BUFFER_SIZE).batch(BATCH_SIZE).repeat()
train_dataset = train_dataset.prefetch(buffer_size=tf.data.experimental.AUTOTUNE)
test_dataset = test.batch(BATCH_SIZE)

def display(display_list):
  plt.figure(figsize=(15, 15))

  title = ['Input Image', 'True Mask', 'Predicted Mask']

  for i in range(len(display_list)):
    plt.subplot(1, len(display_list), i+1)
    plt.title(title[i])
    plt.imshow(tf.keras.preprocessing.image.array_to_img(display_list[i]))
    plt.axis('off')
  plt.show()

for image, mask in train.take(1):
  sample_image, sample_mask = image, mask
display([sample_image, sample_mask])

Now that the dataset has been loaded into memory, the functions needed for the modified U-Net model are defined. Note that two encoder alternatives are provided: a VGG16 encoder and a regular one built from scratch.

def conv_block(input_layer: tf.Tensor,
               filter_size: int, 
               kernel_size: tuple, 
               activation: str = 'relu'
              ) -> tf.Tensor:
  """
  Creates one convolutional block of the U-Net structure.
  
  Parameters
  ----------
  input_layer : tf.Tensor
      Input tensor.

  filter_size : int
      Number of filters/kernels.

  kernel_size : tuple
      Size of the kernel.

  activation : str, default='relu'
      Activation function to be applied after each convolution.

  Returns
  -------
  tf.Tensor
      2D convolutional block.
  """
  x = Conv2D(filters=filter_size, kernel_size=kernel_size, padding='same')(input_layer)
  x = BatchNormalization()(x)
  x = Activation(activation)(x)
  x = Conv2D(filters=filter_size, kernel_size=kernel_size,padding='same')(x)
  x = BatchNormalization()(x)
  x = Activation(activation)(x)
  return x

def encoder_VGG16(input_shape: list
                 ) -> Tuple[tf.keras.Model, list]:
  """
  Creates the encoder as a VGG16 network.
  
  Parameters
  ----------
  input_shape : list
      Input shape to initialize model.

  Returns
  -------
  tf.keras.Model
      Model instance from tf.keras.

  list
      List containing the layers to be concatenated in the upsampling phase.
  """
  base_model=VGG16(include_top=False, weights='imagenet', input_shape=input_shape)
  layers=[layer.output for layer in base_model.layers]
  base_model = tf.keras.Model(inputs=base_model.input, outputs=layers[-2])
  base_model.summary()
  
  x = []
  y = base_model.get_layer('block1_conv1').output
  x.append(y)
  y = base_model.get_layer('block2_conv2').output
  x.append(y)
  y = base_model.get_layer('block3_conv3').output
  x.append(y)
  y = base_model.get_layer('block4_conv3').output
  x.append(y)
  y =  base_model.get_layer('block5_conv3').output
  x.append(y)
  return base_model, x

def encoder_unet(input_shape: list
                ) -> Tuple[tf.keras.Input, list]:
  """
  Creates the encoder as the one described in the U-Net paper with slight modifications.
  
  Parameters
  ----------
  input_shape : list
      Shape of the input image to the model.

  Returns
  -------
  tf.keras.Input
      Input layer for the input image.

  list
      List containing the layers to be concatenated in the upsampling phase.
  """
  input_layer = tf.keras.Input(shape=input_shape)
  
  conv1 = conv_block(input_layer,4,3,'relu')
  pool1 = MaxPooling2D((2,2))(conv1)
  
  conv2 = conv_block(pool1,8,3,'relu')
  pool2 = MaxPooling2D((2,2))(conv2)
  
  conv3 = conv_block(pool2,16,3,'relu')
  pool3 = MaxPooling2D((2,2))(conv3)
  
  conv4 = conv_block(pool3,32,3,'relu')
  pool4 = MaxPooling2D((2,2))(conv4)
  
  conv5 = conv_block(pool4,64,3,'relu')  
  
  x = [conv1,conv2,conv3,conv4,conv5]
  return input_layer, x

def unet(image_width: int,
         image_height: int,
         n_channels: int,
         n_depth: int,
         n_classes: int,
         vgg16: bool = False,
         transfer_learning: bool = False
        ) -> tf.keras.Model:
  """
  Creates the U-Net architecture with slight modifications, in particular using fewer filters.
  
  Parameters
  ----------
  image_width : int
      Desired width of the input image.
      
  image_height : int
      Desired height of the input image.
      
  n_channels : int
      Number of channels of the input image.
      
  n_depth : int
      The desired depth level of the resulting U-Net architecture. 
      
  n_classes : int
      The number of classes to be predicted/number of filters for the final prediction layer.
      
  vgg16 : bool, default = False
      Whether to use the VGG16 architecture in the encoder part of the model.
      
  transfer_learning : bool, default = False
      Whether to use transfer learning with pre-trained weights from ImageNet (freezes the encoder).

  Returns
  -------
  tf.keras.Model
      The produced U-Net model.
  """
  if n_depth < 1 or n_depth > 5:
    raise Exception("Unsupported number of layers/upsamples")
  input_shape = [image_height, image_width, n_channels]
  if vgg16:
    encoded_model, x = encoder_VGG16(input_shape)
    if transfer_learning:
      encoded_model.trainable = False
    inputs = encoded_model.input
  else:
    inputs, x = encoder_unet(input_shape)
  intermediate_model = x[n_depth-1]
  for i in reversed(range(0, n_depth-1)):
    # Halve the number of filters at each upsampling step (integer division keeps the filter count an int)
    next_filters = x[i+1].shape[3] // 2
    intermediate_model = Conv2DTranspose(filters=next_filters, kernel_size=3, strides=2, padding='same')(intermediate_model)
    intermediate_model = tf.keras.layers.Concatenate()([intermediate_model, x[i]])
    intermediate_model = tf.keras.layers.BatchNormalization()(intermediate_model)
    intermediate_model = tf.keras.layers.ReLU()(intermediate_model)
    intermediate_model = conv_block(intermediate_model, next_filters, kernel_size=3, activation='relu')
    
  intermediate_model = Conv2D(filters=n_classes, kernel_size=(1, 1), strides=1, padding='same')(intermediate_model)
  intermediate_model = Reshape((image_height*image_width, n_classes))(intermediate_model)
  intermediate_model = Activation(tf.nn.softmax)(intermediate_model)
  intermediate_model = Reshape((image_height, image_width, n_classes))(intermediate_model)
  
  final_model = tf.keras.models.Model(inputs=inputs, outputs=intermediate_model)
  return final_model

Let's then create the model.

model = unet(128,128,3,5,3)
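
The VGG16-encoder variant with frozen ImageNet weights can be created analogously (a hedged example, not trained in this notebook):

# Hypothetical alternative: VGG16 encoder with pre-trained, frozen ImageNet weights
vgg_model = unet(128, 128, 3, 5, 3, vgg16=True, transfer_learning=True)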

And here is a summary of the model created...

model.summary()

The model is then compiled and its prediction before training is shown.

model.compile(optimizer='adam',
              loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
              metrics=['accuracy'])

def create_mask(pred_mask):
  pred_mask = tf.argmax(pred_mask, axis=-1)
  pred_mask = pred_mask[..., tf.newaxis]
  return pred_mask[0]

def show_predictions(dataset=None, num=1):
  if dataset:
    for image, mask in dataset.take(num):
      pred_mask = model.predict(image)
      display([image[0], mask[0], create_mask(pred_mask)])
  else:
    display([sample_image, sample_mask,
             create_mask(model.predict(sample_image[tf.newaxis, ...]))])
    
show_predictions()

Below, the model is fitted to the training data and validated on the validation set after each epoch. A validation accuracy of 84.5% is achieved after one epoch. A custom callback is added to trace the learning of the model throughout its training.

class MyCustomCallback(tf.keras.callbacks.Callback):
  def on_epoch_end(self, epoch, logs=None):
    show_predictions()

EPOCHS = 50
VAL_SUBSPLITS = 5
VALIDATION_STEPS = info.splits['test'].num_examples//BATCH_SIZE//VAL_SUBSPLITS

model_history = model.fit(train_dataset, epochs=EPOCHS,
                          steps_per_epoch=STEPS_PER_EPOCH,
                          validation_steps=VALIDATION_STEPS,
                          validation_data=test_dataset,
                          callbacks = [MyCustomCallback()])

Some predictions on the test_dataset are shown to showcase the performance of the model on images it has not been trained on.

show_predictions(test_dataset,num=10)

Finally we plot the learning curves of the model in its 50 epochs of training. Both the loss curves as well as the accuracy curves are presented.

loss = model_history.history['loss']
val_loss = model_history.history['val_loss']
acc = model_history.history['accuracy']
val_acc = model_history.history['val_accuracy']

epochs = range(EPOCHS)

plt.figure(figsize=(10,5))
plt.subplot(1,2,1)
plt.plot(epochs, loss, 'r', label='Training loss')
plt.plot(epochs, val_loss, 'bo', label='Validation loss')
plt.title('Training and Validation Loss')
plt.xlabel('Epoch')
plt.ylabel('Loss Value')
plt.legend()
plt.subplot(1,2,2)
plt.plot(epochs, acc, 'r', label='Training accuracy')
plt.plot(epochs, val_acc, 'bo', label='Validation accuracy')
plt.title('Training and Validation Accuracy')
plt.xlabel('Epoch')
plt.ylabel('Accuracy')
plt.legend()
plt.show()

Implementation of PSPNet

William Anzén (Linkedin), Christian von Koch (Linkedin)

2021, Stockholm, Sweden

This project was supported by Combient Mix AB under the industrial supervision of Raazesh Sainudiin and Max Fischer.

In this notebook, an implementation of PSPNet is presented. PSPNet is an architecture that uses scene parsing, evaluating the images at different scales and finally combining the different results to form a final prediction. The architecture is evaluated against the Oxford-IIIT Pet Dataset. This notebook reuses material from the Image Segmentation Tutorial on Tensorflow for loading the dataset and showing predictions.

Importing the required packages.

import tensorflow as tf
import tensorflow_datasets as tfds
import matplotlib.pyplot as plt
from tensorflow.keras.layers import AveragePooling2D, Conv2D, BatchNormalization, Activation, Concatenate, UpSampling2D, Reshape, TimeDistributed, ConvLSTM2D
from tensorflow.keras import Model
import numpy as np
from tensorflow.keras.applications.resnet50 import ResNet50

Setting memory growth on the GPUs is recommended as this model is quite memory intensive.

gpus = tf.config.list_physical_devices('GPU')
if gpus:
  try:
    # Currently, memory growth needs to be the same across GPUs
    for gpu in gpus:
      tf.config.experimental.set_memory_growth(gpu, True)
    logical_gpus = tf.config.experimental.list_logical_devices('GPU')
    print(len(gpus), "Physical GPUs,", len(logical_gpus), "Logical GPUs")
  except RuntimeError as e:
    # Memory growth must be set before GPUs have been initialized
    print(e)

Defining functions for normalizing and transforming the images.

# Function for normalizing image_size so that pixel intensity is between 0 and 1
def normalize(input_image, input_mask):
  input_image = tf.cast(input_image, tf.float32) / 255.0
  input_mask -= 1
  return input_image, input_mask

# Function for resizing the train images to the desired input shape of 128x128 as well as augmenting the training images
@tf.function
def load_image_train(datapoint):
  input_image = tf.image.resize(datapoint['image'], (128, 128))
  input_mask = tf.image.resize(datapoint['segmentation_mask'], (128, 128))

  if tf.random.uniform(()) > 0.5:
    input_image = tf.image.flip_left_right(input_image)
    input_mask = tf.image.flip_left_right(input_mask)

  input_image, input_mask = normalize(input_image, input_mask)
  input_mask = tf.math.round(input_mask)

  return input_image, input_mask

# Function for resizing the test images to the desired output shape (no augmentation)
def load_image_test(datapoint):
  input_image = tf.image.resize(datapoint['image'], (128, 128))
  input_mask = tf.image.resize(datapoint['segmentation_mask'], (128, 128))

  input_image, input_mask = normalize(input_image, input_mask)

  return input_image, input_mask

The Oxford-IIIT Pet Dataset from TensorFlow Datasets is loaded, the images are transformed in the desired way, and the datasets used for training and inference are created.

dataset, info = tfds.load('oxford_iiit_pet:3.*.*', with_info=True)

TRAIN_LENGTH = info.splits['train'].num_examples
BATCH_SIZE = 64
BUFFER_SIZE = 1000
STEPS_PER_EPOCH = TRAIN_LENGTH // BATCH_SIZE

train = dataset['train'].map(load_image_train, num_parallel_calls=tf.data.experimental.AUTOTUNE)
test = dataset['test'].map(load_image_test)

train_dataset = train.shuffle(BUFFER_SIZE).cache().batch(BATCH_SIZE).repeat()
train_dataset = train_dataset.prefetch(buffer_size=tf.data.experimental.AUTOTUNE)
test_dataset = test.batch(BATCH_SIZE)

An example image is displayed.

def display(display_list):
  plt.figure(figsize=(15, 15))

  title = ['Input Image', 'True Mask', 'Predicted Mask']

  for i in range(len(display_list)):
    plt.subplot(1, len(display_list), i+1)
    plt.title(title[i])
    plt.imshow(tf.keras.preprocessing.image.array_to_img(display_list[i]))
    plt.axis('off')
  plt.show()

for image, mask in train.take(1):
  sample_image, sample_mask = image, mask
display([sample_image, sample_mask])

Defining the functions needed for the PSPNet.

def pool_block(cur_tensor: tf.Tensor,
               image_width: int,
               image_height: int,
               pooling_factor: int,
               activation: str = 'relu'
              ) -> tf.Tensor:
  """
  Parameters
  ----------
  cur_tensor : tf.Tensor
      Incoming tensor.

  image_width : int
      Width of the image.

  image_height : int
      Height of the image.

  pooling_factor : int
      Pooling factor to scale image.

  activation : str, default = 'relu'
      Activation function to be applied after the pooling operations.

  Returns
  -------
  tf.Tensor
      2D convolutional block.
  """
  
  #Calculate the strides with image size and pooling factor
  strides = [int(np.round(float(image_width)/pooling_factor)),
            int(np.round(float(image_height)/pooling_factor))]
  
  pooling_size = strides
  
  x = AveragePooling2D(pooling_size, strides=strides, padding='same')(cur_tensor)
  
  x = Conv2D(filters = 512, 
             kernel_size = (1,1),
             padding = 'same')(x)
  
  x = BatchNormalization()(x)
  x = Activation(activation)(x)
  
  # Resizing images to correct shape for future concat
  x = tf.keras.layers.experimental.preprocessing.Resizing(
    image_height, 
    image_width, 
    interpolation="bilinear")(x) 
  
  return x

def modify_ResNet_Dilation(
  model: tf.keras.Model
) -> tf.keras.Model:
  """
  Modifies the ResNet to fit the PSPNet Paper with the described dilated strategy.
  
    Parameters
    ----------
    model : tf.keras.Model
        The ResNet-50 model to be modified.
        
    Returns
    -------
    tf.keras.Model
        Modified model.
  """
  for i in range(0,4):
    model.get_layer('conv4_block1_{}_conv'.format(i)).strides = 1
    model.get_layer('conv4_block1_{}_conv'.format(i)).dilation_rate = 2
    model.get_layer('conv5_block1_{}_conv'.format(i)).strides = 1
    model.get_layer('conv5_block1_{}_conv'.format(i)).dilation_rate = 4
  model.save('/tmp/my_model')
  new_model = tf.keras.models.load_model('/tmp/my_model')
  return new_model

def PSPNet(image_width: int,
           image_height: int,
           n_classes: int,
           kernel_size: tuple = (3,3),
           activation: str = 'relu',
           weights: str = 'imagenet',
           shallow_tuning: bool = False,
           isICNet: bool = False
          ) -> tf.keras.Model:
  """
  Setting up the PSPNet model
  
    Parameters
    ----------
    image_width : int
        Width of the image.
        
    image_height : int
        Height of the image.
        
    n_classes : int
        Number of classes.
        
    kernel_size : tuple, default = (3, 3)
        Size of the kernel.    
        
    activation : str, default = 'relu'
        Activation function to be applied after the pooling operations.
        
    weights: str, default = 'imagenet'
        String defining which weights to use for the backbone, 'imagenet' for ImageNet weights or None for normalized initialization.
        
    shallow_tuning : bool, default = False
        Boolean for using shallow tuning with pre-trained weights from ImageNet or not.
        
    isICNet : bool,  default = False   
        Boolean to determine if the PSPNet will be part of an ICNet.
        
    Returns
    -------
    tf.keras.Model
        The finished keras model for PSPNet.
  """
  if shallow_tuning and not weights:
    raise ValueError("Shallow tuning cannot be performed without loading pre-trained weights. Please pass 'imagenet' to the weights argument.")
  #If the function is used for the ICNet input_shape is set to none as ICNet takes 3 inputs
  if isICNet:
    input_shape=(None, None, 3)
  else:
    input_shape=(image_height,image_width,3)
    
  #Initializing the ResNet50 Backbone  
  y=ResNet50(include_top=False, weights=weights, input_shape=input_shape)

  y=modify_ResNet_Dilation(y)
  if shallow_tuning:
    y.trainable=False
  
  pooling_layer=[]
  output=y.output

  pooling_layer.append(output)
  
  h = image_height//8
  w = image_width//8
  
  #Loop for calling the pool block functions for pooling factors [1,2,3,6]
  for i in [1,2,3,6]:
    pool = pool_block(output, h, w, i, activation)
    pooling_layer.append(pool)
    
  x=Concatenate()(pooling_layer)
  
  x=Conv2D(filters=n_classes, kernel_size=(1,1), padding='same')(x)
  
  x=UpSampling2D(size=(8,8), data_format='channels_last', interpolation='bilinear')(x)
  x=Reshape((image_height*image_width, n_classes))(x)
  x=Activation(tf.nn.softmax)(x)
  x=Reshape((image_height,image_width,n_classes))(x)

  final_model=tf.keras.Model(inputs=y.input, outputs=x)
                              
  return final_model

Creating the PSPNet model with image height and width of 128 pixels, three classes, and the default kernel size of (3,3).

model = PSPNet(image_height = 128, image_width = 128, n_classes = 3)
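
As a quick sanity check (a minimal sketch, assuming the standard Keras ResNet-50 layer names used in modify_ResNet_Dilation above), we can confirm that the reloaded backbone inside the model picked up the modified strides and dilation rates:

# Hypothetical check of one of the layers modified above
layer = model.get_layer('conv4_block1_1_conv')
print(layer.strides, layer.dilation_rate)  # expected: (1, 1) (2, 2)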

And here is the model summary.

model.summary()

Compiling the model.

model.compile(optimizer='adam',
              loss=tf.keras.losses.SparseCategoricalCrossentropy(),
              metrics=['accuracy'])

Below, functions needed to show the model's predictions against the true mask are defined.

def create_mask(pred_mask):
  pred_mask = tf.argmax(pred_mask, axis=-1)
  pred_mask = pred_mask[..., tf.newaxis]
  return pred_mask[0]

def show_predictions(dataset=None, num=1):
  if dataset:
    for image, mask in dataset.take(num):
      pred_mask = model.predict(image)
      display([image[0], mask[0], create_mask(pred_mask)])
  else:
    display([sample_image, sample_mask,
             create_mask(model.predict(sample_image[tf.newaxis, ...]))])
    
show_predictions()

A custom callback function is defined for showing how the model learns to predict while training.

class DisplayCallback(tf.keras.callbacks.Callback):
  def on_epoch_end(self, epoch, logs=None):
    show_predictions()

And finally, the model is fitted on the training dataset and validated on the test dataset.

EPOCHS = 50
VAL_SUBSPLITS = 5
VALIDATION_STEPS = info.splits['test'].num_examples//BATCH_SIZE//VAL_SUBSPLITS

model_history = model.fit(train_dataset, epochs=EPOCHS,
                          steps_per_epoch=STEPS_PER_EPOCH,
                          validation_steps=VALIDATION_STEPS,
                          validation_data=test_dataset,
                          callbacks=[DisplayCallback()])                    

Some predictions on the test_dataset are shown to showcase the performance of the model on images it has not been trained on.

show_predictions(test_dataset, num=10)

Finally, we plot the learning curves of the model over its 50 epochs of training; both the loss and the accuracy curves are presented.

loss = model_history.history['loss']
val_loss = model_history.history['val_loss']
acc = model_history.history['accuracy']
val_acc = model_history.history['val_accuracy']

epochs = range(EPOCHS)

plt.figure(figsize=(10,5))
plt.subplot(1,2,1)
plt.plot(epochs, loss, 'r', label='Training loss')
plt.plot(epochs, val_loss, 'bo', label='Validation loss')
plt.title('Training and Validation Loss')
plt.xlabel('Epoch')
plt.ylabel('Loss Value')
plt.legend()
plt.subplot(1,2,2)
plt.plot(epochs, acc, 'r', label='Training accuracy')
plt.plot(epochs, val_acc, 'bo', label='Validation accuracy')
plt.title('Training and Validation Accuracy')
plt.xlabel('Epoch')
plt.ylabel('Accuracy')
plt.legend()
plt.show()

Implementation of ICNet

William Anzén (Linkedin), Christian von Koch (Linkedin)

2021, Stockholm, Sweden

This project was supported by Combient Mix AB under the industrial supervision of Raazesh Sainudiin and Max Fischer.

In this notebook, an implementation of ICNet is presented. ICNet is an architecture that efficiently trades off prediction accuracy against inference time by processing the image at multiple resolutions in a cascade. The architecture is evaluated on the Oxford-IIIT Pet Dataset. This notebook reuses material from the Image Segmentation Tutorial on TensorFlow for loading the dataset and showing predictions.

Importing the required packages.

import tensorflow as tf
import tensorflow_datasets as tfds
import matplotlib.pyplot as plt
from tensorflow.keras.layers import AveragePooling2D, Conv2D, BatchNormalization, Activation, Concatenate, UpSampling2D, Reshape, Add 
from tensorflow.keras import Model
import numpy as np
from tensorflow.keras.applications.resnet50 import ResNet50
from tensorflow.keras.optimizers import Adam
import matplotlib.image as mpimg
from typing import Tuple

Setting memory growth on the GPUs is recommended as this model is quite memory intensive.

gpus = tf.config.list_physical_devices('GPU')
if gpus:
  try:
    # Currently, memory growth needs to be the same across GPUs
    for gpu in gpus:
      tf.config.experimental.set_memory_growth(gpu, True)
    logical_gpus = tf.config.experimental.list_logical_devices('GPU')
    print(len(gpus), "Physical GPUs,", len(logical_gpus), "Logical GPUs")
  except RuntimeError as e:
    # Memory growth must be set before GPUs have been initialized
    print(e)

Defining functions for normalizing and transforming the images.

def normalize(input_image, input_mask):
  input_image = tf.cast(input_image, tf.float32) / 255.0
  input_mask -= 1
  return input_image, input_mask

# Function for resizing the train images to the desired input shape of HxW as well as augmenting the training images.
@tf.function
def load_image_train(datapoint):
  input_image = tf.image.resize(datapoint['image'], (128, 128))
  input_mask = tf.image.resize(datapoint['segmentation_mask'], (128, 128))

  if tf.random.uniform(()) > 0.5:
    input_image = tf.image.flip_left_right(input_image)
    input_mask = tf.image.flip_left_right(input_mask)

  input_image, input_mask = normalize(input_image, input_mask)
  input_mask = tf.math.round(input_mask)

  return input_image, input_mask

# Function for resizing the test images to the desired output shape (no augmentation).
def load_image_test(datapoint):
  input_image = tf.image.resize(datapoint['image'], (128, 128))
  input_mask = tf.image.resize(datapoint['segmentation_mask'], (128, 128))

  input_image, input_mask = normalize(input_image, input_mask)

  return input_image, input_mask



# Function for resizing the mask by a factor of 4, 8 or 16, to be used in the cascade losses.
def resize_image(mask, wanted_height: int, wanted_width: int, resize_factor : int):
  input_mask=tf.image.resize(mask, (wanted_height//resize_factor, wanted_width//resize_factor))
  input_mask = tf.math.round(input_mask)
  return input_mask

# Function for resizing the masks used as cascade losses in the ICNet architecture.
def preprocess_icnet(image, 
                     mask,
                     image_height : int = 128,
                     image_width : int = 128):
  mask4 = resize_image(mask, image_height, image_width, 4)
  mask8 = resize_image(mask, image_height, image_width, 8)
  mask16 = resize_image(mask, image_height, image_width, 16)
  return image, {'CC_1': mask16, 'CC_2': mask8, 'CC_fin': mask4, 'final_output': mask}

Separating dataset into input and multiple outputs of different sizes.

dataset, info = tfds.load('oxford_iiit_pet:3.*.*', with_info=True)
n_train = info.splits['train'].num_examples
n_test = info.splits['test'].num_examples

TRAIN_LENGTH = n_train
BATCH_SIZE = 64
BUFFER_SIZE = 1000
STEPS_PER_EPOCH = TRAIN_LENGTH // BATCH_SIZE

train = dataset['train'].map(load_image_train, num_parallel_calls=tf.data.experimental.AUTOTUNE)
test = dataset['test'].map(load_image_test)

train = train.map(preprocess_icnet, num_parallel_calls=tf.data.experimental.AUTOTUNE)
test = test.map(preprocess_icnet)

train_dataset = train.cache().shuffle(BUFFER_SIZE).batch(BATCH_SIZE).repeat()
train_dataset = train_dataset.prefetch(buffer_size=tf.data.experimental.AUTOTUNE)
test_dataset = test.batch(BATCH_SIZE)
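
As a quick check (a minimal sketch), the element structure of the batched training dataset should now be an (image, mask-dictionary) pair whose keys match the model's output names:

print(train_dataset.element_spec)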

Defining the function for displaying images and the model's predictions jointly.

def display(display_list):
  plt.figure(figsize=(15, 15))

  title = ['Input Image', 'True Mask', 'Predicted Mask']

  for i in range(len(display_list)):
    plt.subplot(1, len(display_list), i+1)
    plt.title(title[i])
    plt.imshow(tf.keras.preprocessing.image.array_to_img(display_list[i]))
    plt.axis('off')
  plt.show()

for image, mask in train.take(1):
  sample_image, sample_mask = image, mask['final_output']
display([sample_image, sample_mask])

Defining the functions needed for the PSPNet module.

def pool_block(cur_tensor: tf.Tensor,
               image_width: int,
               image_height: int,
               pooling_factor: int,
               activation: str = 'relu'
              ) -> tf.Tensor:
  """
  Parameters
  ----------
  cur_tensor : tf.Tensor
      Incoming tensor.

  image_width : int
      Width of the image.

  image_height : int
      Height of the image.

  pooling_factor : int
      Pooling factor to scale image.

  activation : str, default = 'relu'
      Activation function to be applied after the pooling operations.

  Returns
  -------
  tf.Tensor
      2D convolutional block.
  """
  
  #Calculate the strides with image size and pooling factor
  strides = [int(np.round(float(image_width)/pooling_factor)),
            int(np.round(float(image_height)/pooling_factor))]
  
  pooling_size = strides
  
  x = AveragePooling2D(pooling_size, strides=strides, padding='same')(cur_tensor)
  
  x = Conv2D(filters = 512, 
             kernel_size = (1,1),
             padding = 'same')(x)
  
  x = BatchNormalization()(x)
  
  x = Activation(activation)(x)
  
  # Resizing images to correct shape for future concat
  x = tf.keras.layers.experimental.preprocessing.Resizing(
    image_height, 
    image_width, 
    interpolation="bilinear")(x) 
  
  return x

# Function for formatting the resnet model to a modified one which takes advantage of dilation rates instead of strides in the final blocks
def modify_ResNet_Dilation(
  model: tf.keras.Model
) -> tf.keras.Model:
  """
  Modifies the ResNet to fit the PSPNet Paper with the described dilated strategy.
  
    Parameters
    ----------
    model : tf.keras.Model
        The ResNet-50 model to be modified.
        
    Returns
    -------
    tf.keras.Model
        Modified model.
  """
  for i in range(0,4):
    model.get_layer('conv4_block1_{}_conv'.format(i)).strides = 1
    model.get_layer('conv4_block1_{}_conv'.format(i)).dilation_rate = 2
    model.get_layer('conv5_block1_{}_conv'.format(i)).strides = 1
    model.get_layer('conv5_block1_{}_conv'.format(i)).dilation_rate = 4
  model.save('/tmp/my_model')
  new_model = tf.keras.models.load_model('/tmp/my_model')
  return new_model
  
def PSPNet(image_width: int,
           image_height: int,
           n_classes: int,
           kernel_size: tuple = (3,3),
           activation: str = 'relu',
           weights: str = 'imagenet',
           shallow_tuning: bool = False,
           isICNet: bool = False
          ) -> tf.keras.Model:
  """
  Setting up the PSPNet model.
  
    Parameters
    ----------
    image_width : int
        Width of the image.
        
    image_height : int
        Height of the image.
        
    n_classes : int
        Number of classes.
        
    kernel_size : tuple, default = (3, 3)
        Size of the kernel.    
        
    activation : str, default = 'relu'
        Activation function to be applied after the pooling operations.
        
    weights: str, default = 'imagenet'
        String defining which weights to use for the backbone, 'imagenet' for ImageNet weights or None for normalized initialization.
        
    shallow_tuning : bool, default = False
        Boolean for using shallow tuning with pre-trained weights from ImageNet or not.
        
    isICNet : bool,  default = False   
        Boolean to determine if the PSPNet will be part of an ICNet.
        
    Returns
    -------
    tf.keras.Model
        The finished keras model for PSPNet.
  """
  if shallow_tuning and not weights:
    raise ValueError("Shallow tuning cannot be performed without loading pre-trained weights. Please pass 'imagenet' to the weights argument.")
  #If the function is used for the ICNet input_shape is set to none as ICNet takes 3 inputs
  if isICNet:
    input_shape=(None, None, 3)
  else:
    input_shape=(image_height,image_width,3)
    
  #Initializing the ResNet50 Backbone  
  y=ResNet50(include_top=False, weights=weights, input_shape=input_shape)

  y=modify_ResNet_Dilation(y)
  if shallow_tuning:
    y.trainable=False
  
  pooling_layer=[]
  output=y.output

  pooling_layer.append(output)
  
  h = image_height//8
  w = image_width//8
  
  #Loop for calling the pool block functions for pooling factors [1,2,3,6]
  for i in [1,2,3,6]:
    pool = pool_block(output, h, w, i, activation)
    pooling_layer.append(pool)
    
  x=Concatenate()(pooling_layer)
  
  x=Conv2D(filters=n_classes, kernel_size=(1,1), padding='same')(x)
  
  x=UpSampling2D(size=(8,8), data_format='channels_last', interpolation='bilinear')(x)
  x=Reshape((image_height*image_width, n_classes))(x)
  x=Activation(tf.nn.softmax)(x)
  x=Reshape((image_height,image_width,n_classes))(x)

  final_model=tf.keras.Model(inputs=y.input, outputs=x)
                              
  return final_model

Defining the functions needed for the ICNet.

def PSP_rest(input_prev: tf.Tensor
            ) -> tf.Tensor:
  """
  Function for adding stage 4 and 5 of ResNet50 to the 1/4 image size branch of the ICNet.
  
    Parameters
    ----------
    input_prev: tf.Tensor
        The incoming tensor from the output of the joint medium and low resolution branch.
    
    Returns
    -------
    tf.Tensor
        The tensor used to continue the 1/4th branch.
  """
  
  
  y_ = input_prev
  #Stage 4
  #Conv_Block
  y = Conv2D(256, 1, dilation_rate=2, padding='same', name='C4_block1_conv1')(y_)
  y = BatchNormalization(name='C4_block1_bn1')(y)
  y = Activation('relu', name='C4_block1_act1')(y)
  y = Conv2D(256, 3, dilation_rate=2, padding='same', name='C4_block1_conv2')(y)
  y = BatchNormalization(name='C4_block1_bn2')(y)
  y = Activation('relu', name='C4_block1_act2')(y)
  y_ = Conv2D(1024, 1, dilation_rate=2, padding='same', name='C4_block1_conv0')(y_)
  y = Conv2D(1024, 1, dilation_rate=2, padding='same', name='C4_block1_conv3')(y)
  y_ = BatchNormalization(name='C4_block1_bn0')(y_)
  y = BatchNormalization(name='C4_block1_bn3')(y)
  y = Add(name='C4_skip1')([y_,y])
  y_ = Activation('relu', name='C4_block1_act3')(y)
  #IDBLOCK1
  y = Conv2D(256, 1, dilation_rate=2, padding='same', name='C4_block2_conv1')(y_)
  y = BatchNormalization(name='C4_block2_bn1')(y)
  y = Activation('relu', name='C4_block2_act1')(y)
  y = Conv2D(256, 3, dilation_rate=2, padding='same', name='C4_block2_conv2')(y)
  y = BatchNormalization(name='C4_block2_bn2')(y)
  y = Activation('relu', name='C4_block2_act2')(y)
  y = Conv2D(1024,1, dilation_rate=2, padding='same', name='C4_block2_conv3')(y)
  y = BatchNormalization(name='C4_block2_bn3')(y)
  y = Add(name='C4_skip2')([y_,y])
  y_ = Activation('relu', name='C4_block2_act3')(y)
  #IDBLOCK2
  y = Conv2D(256, 1, dilation_rate=2, padding='same', name='C4_block3_conv1')(y_)
  y = BatchNormalization(name='C4_block3_bn1')(y)
  y = Activation('relu', name='C4_block3_act1')(y)
  y = Conv2D(256, 3, dilation_rate=2, padding='same', name='C4_block3_conv2')(y)
  y = BatchNormalization(name='C4_block3_bn2')(y)
  y = Activation('relu', name='C4_block3_act2')(y)
  y = Conv2D(1024,1, dilation_rate=2, padding='same', name='C4_block3_conv3')(y)
  y = BatchNormalization(name='C4_block3_bn3')(y)
  y = Add(name='C4_skip3')([y_,y])
  y_ = Activation('relu', name='C4_block3_act3')(y)
  #IDBlock3
  y = Conv2D(256, 1, dilation_rate=2, padding='same', name='C4_block4_conv1')(y_)
  y = BatchNormalization(name='C4_block4_bn1')(y)
  y = Activation('relu', name='C4_block4_act1')(y)
  y = Conv2D(256, 3, dilation_rate=2, padding='same', name='C4_block4_conv2')(y)
  y = BatchNormalization(name='C4_block4_bn2')(y)
  y = Activation('relu', name='C4_block4_act2')(y)
  y = Conv2D(1024,1, dilation_rate=2, padding='same', name='C4_block4_conv3')(y)
  y = BatchNormalization(name='C4_block4_bn3')(y)
  y = Add(name='C4_skip4')([y_,y])
  y_ = Activation('relu', name='C4_block4_act3')(y)
  #ID4
  y = Conv2D(256, 1, dilation_rate=2, padding='same', name='C4_block5_conv1')(y_)
  y = BatchNormalization(name='C4_block5_bn1')(y)
  y = Activation('relu', name='C4_block5_act1')(y)
  y = Conv2D(256, 3, dilation_rate=2, padding='same', name='C4_block5_conv2')(y)
  y = BatchNormalization(name='C4_block5_bn2')(y)
  y = Activation('relu', name='C4_block5_act2')(y)
  y = Conv2D(1024,1, dilation_rate=2, padding='same', name='C4_block5_conv3')(y)
  y = BatchNormalization(name='C4_block5_bn3')(y)
  y = Add(name='C4_skip5')([y_,y])
  y_ = Activation('relu', name='C4_block5_act3')(y)
  #ID5
  y = Conv2D(256, 1, dilation_rate=2, padding='same', name='C4_block6_conv1')(y_)
  y = BatchNormalization(name='C4_block6_bn1')(y)
  y = Activation('relu', name='C4_block6_act1')(y)
  y = Conv2D(256, 3, dilation_rate=2, padding='same', name='C4_block6_conv2')(y)
  y = BatchNormalization(name='C4_block6_bn2')(y)
  y = Activation('relu', name='C4_block6_act2')(y)
  y = Conv2D(1024,1, dilation_rate=2, padding='same', name='C4_block6_conv3')(y)
  y = BatchNormalization(name='C4_block6_bn3')(y)
  y = Add(name='C4_skip6')([y_,y])
  y_ = Activation('relu', name='C4_block6_act3')(y)
  
  #Stage 5
  #Conv
  y = Conv2D(512, 1, dilation_rate=4,padding='same', name='C5_block1_conv1')(y_)
  y = BatchNormalization(name='C5_block1_bn1')(y)
  y = Activation('relu', name='C5_block1_act1')(y)
  y = Conv2D(512, 3, dilation_rate=4,padding='same', name='C5_block1_conv2')(y)
  y = BatchNormalization(name='C5_block1_bn2')(y)
  y = Activation('relu', name='C5_block1_act2')(y)
  y_ = Conv2D(2048, 1, dilation_rate=4,padding='same', name='C5_block1_conv0')(y_)
  y = Conv2D(2048, 1, dilation_rate=4,padding='same', name='C5_block1_conv3')(y)
  y_ = BatchNormalization(name='C5_block1_bn0')(y_)
  y = BatchNormalization(name='C5_block1_bn3')(y)
  y = Add(name='C5_skip1')([y_,y])
  y_ = Activation('relu', name='C5_block1_act3')(y)
  
  #ID
  y = Conv2D(512, 1, dilation_rate=4,padding='same', name='C5_block2_conv1')(y_)
  y = BatchNormalization(name='C5_block2_bn1')(y)
  y = Activation('relu', name='C5_block2_act1')(y)
  y = Conv2D(512, 3, dilation_rate=4,padding='same', name='C5_block2_conv2')(y)
  y = BatchNormalization(name='C5_block2_bn2')(y)
  y = Activation('relu', name='C5_block2_act2')(y)
  y = Conv2D(2048, 1, dilation_rate=4,padding='same', name='C5_block2_conv3')(y)
  y = BatchNormalization(name='C5_block2_bn3')(y)
  y = Add(name='C5_skip2')([y_,y])
  y_ = Activation('relu', name='C5_block2_act3')(y)
  
  #ID
  y = Conv2D(512, 1, dilation_rate=4,padding='same', name='C5_block3_conv1')(y_)
  y = BatchNormalization(name='C5_block3_bn1')(y)
  y = Activation('relu', name='C5_block3_act1')(y)
  y = Conv2D(512, 3, dilation_rate=4,padding='same', name='C5_block3_conv2')(y)
  y = BatchNormalization(name='C5_block3_bn2')(y)
  y = Activation('relu', name='C5_block3_act2')(y)
  y = Conv2D(2048, 1, dilation_rate=4,padding='same', name='C5_block3_conv3')(y)
  y = BatchNormalization(name='C5_block3_bn3')(y)
  y = Add(name='C5_skip3')([y_,y])
  y_ = Activation('relu', name='C5_block3_act3')(y)
  
  return y_

# Function for the CFF module in the ICNet architecture. The inputs are which stage (1 or 2), the output from the smaller branch, the output from the
# larger branch, n_classes and the width and height of the output of the smaller branch.
def CFF(stage : int, 
        F_small : tf.Tensor, 
        F_large : tf.Tensor, 
        n_classes: int,
        input_height_small: int,
        input_width_small: int
        ) -> Tuple[tf.Tensor, tf.Tensor]:
  """
  Function for creating the cascade feature fusion (CFF) inside the ICNet model.
  
    Parameters
    ----------
    stage : int
        Integer determining the stage of the CFF in order to name the layers correctly.
        
     F_small : tf.Tensor
         The smaller tensor to be used in the CFF.
        
     F_large : tf.Tensor
         The larger tensor to be used in the CFF.
    
    n_classes : int
        Number of classes.
        
    input_height_small: int
        The height of the smaller tensor.
    
    input_width_small: int
        The width of the smaller tensor.

    Returns
    -------
    tf.Tensor
        The tensor used to calculate the auxiliary loss.
        
    tf.Tensor    
        The tensor used to continue the model.
  """
  
  # Resizing expects (height, width); the inputs here are square so the order does not change the result
  F_small = tf.keras.layers.experimental.preprocessing.Resizing(int(input_height_small*2), int(input_width_small*2), interpolation="bilinear", name="Upsample_x2_small_{}".format(stage))(F_small)
  F_aux = Conv2D(n_classes, 1, name="CC_{}".format(stage), activation='softmax')(F_small)
  
  F_small = Conv2D(128, 3, dilation_rate=2, padding='same', name="intermediate_f_small_{}".format(stage))(F_small)
  F_small = BatchNormalization(name="intermediate_f_small_bn_{}".format(stage))(F_small)
  
  F_large = Conv2D(128, 1, padding='same', name="intermediate_f_large_{}".format(stage))(F_large)
  F_large = BatchNormalization(name="intermediate_f_large_bn_{}".format(stage))(F_large)
  
  F_small = Add(name="add_intermediates_{}".format(stage))([F_small,F_large])
  F_small = Activation('relu', name="activation_CFF_{}".format(stage))(F_small)
  return F_aux, F_small

def ICNet_1(input_obj : tf.keras.Input,
            n_filters: int,
            kernel_size: tuple,
            activation: str
           ) -> tf.Tensor:
  """
  Function for the high-res branch of ICNet where image is in scale 1:1.
  
    Parameters
    ----------
    input_obj : tf.keras.Input
        The input object inputted to the branch.
        
    n_filters : int
        Number of filters.
        
    kernel_size : tuple
        Size of the kernel.    
        
    activation : str
        Activation function to be applied.
        
    Returns
    -------
    tf.Tensor
        The output of the original size branch as a tf.Tensor.
  """
  for i in range(1,4):
    input_obj=Conv2D(filters=n_filters*2*i, kernel_size=kernel_size, strides=(2,2), padding='same')(input_obj)
    input_obj=BatchNormalization()(input_obj)
    input_obj=Activation(activation)(input_obj)
  return input_obj  

def ICNet(image_height: int,
         image_width: int,
         n_classes: int,
         n_filters: int = 16,
         kernel_size: tuple = (3,3),
         activation: str = 'relu'
         ) -> tf.keras.Model:
  """
  Function for creating the ICNet model.
  
    Parameters
    ----------
    image_height : int
        Height of the image.
        
    image_width : int
        Width of the image.
        
    n_classes : int
        Number of classes.
        
    n_filters : int, default = 16
        Number of filters in the ICNet original size image branch.
        
    kernel_size : tuple, default = (3, 3)
        Size of the kernel.    
        
    activation : str, default = 'relu'
        Activation function to be applied after the pooling operations.
        
    Returns
    -------
    tf.keras.Model
        The finished keras model for the ICNet.
  """
  input_shape=[image_height,image_width,3]
  input_obj = tf.keras.Input(shape=input_shape, name="input_img_1")
  input_obj_4 = tf.keras.layers.experimental.preprocessing.Resizing(
    image_height//4, image_width//4, interpolation="bilinear", name="input_img_4")(input_obj)
  input_obj_2 = tf.keras.layers.experimental.preprocessing.Resizing(
    image_height//2, image_width//2, interpolation="bilinear", name="input_img_2")(input_obj)
  ICNet_Model1=ICNet_1(input_obj, n_filters, kernel_size, activation)
  PSPModel = PSPNet(image_height//4, image_width//4, n_classes, isICNet=True)
  PSPModel_2_4 = tf.keras.models.Model(inputs=PSPModel.input, outputs=PSPModel.get_layer('conv4_block3_out').output, name="JointResNet_2_4")
  
  ICNet_Model4 = PSPModel_2_4(input_obj_4)
  ICNet_Model2 = PSPModel_2_4(input_obj_2) 
  ICNet_4_rest = PSP_rest(ICNet_Model4)
  
  out1, last_layer = CFF(1, ICNet_4_rest, ICNet_Model2, n_classes, image_height//32, image_width//32)
  out2, last_layer = CFF(2, last_layer, ICNet_Model1, n_classes, image_height//16, image_width//16)
  upsample_2 = UpSampling2D(2, interpolation='bilinear', name="Upsampling_final_prediction")(last_layer)
  output = Conv2D(n_classes, 1, name="CC_fin", activation='softmax')(upsample_2)
  final_output = UpSampling2D(4, interpolation='bilinear', name='final_output')(output)
  final_model = tf.keras.models.Model(inputs=input_obj, outputs=[out1, out2, output, final_output])
  return final_model

Let's call the ICNet function to create the model with input shape (128, 128, 3) and 3 classes, using the default values for the number of filters, kernel size and activation function.

model=ICNet(128,128,3)

Here is the summary of the model.

model.summary() 

Compiling the model with the Adam optimizer, the SparseCategoricalCrossentropy loss function and accuracy as the metric. We also add loss weights of 0.4, 0.4, 1 and 0 to the low-resolution output, the medium-resolution output, the high-resolution output and the final output (only evaluated in the testing phase) respectively, as was done in the original ICNet paper.

model.compile(optimizer='adam',
              loss=tf.keras.losses.SparseCategoricalCrossentropy(), loss_weights=[0.4,0.4,1,0],
              metrics="acc")

Below, the functions for displaying the predictions from the model against the true image are defined.

# Function for creating the predicted image. It takes the max value between the classes and assigns the correct class label to the image, thus creating a predicted mask. 
def create_mask(pred_mask):
  pred_mask = tf.argmax(pred_mask, axis=-1)
  pred_mask = pred_mask[..., tf.newaxis]
  return pred_mask[0]

# Function for showing the model prediction. Output can be 0, 1, 2 or 3 depending on whether you want to see the low resolution, medium resolution, high resolution or final upsampled prediction respectively.
def show_predictions(dataset=None, num=1, output=3):
  if dataset:
    for image, mask in dataset.take(num):
      pred_mask = model.predict(image[tf.newaxis,...])[output]
      display([image, mask['final_output'], create_mask(pred_mask)])
  else:
    display([sample_image, sample_mask,
             create_mask(model.predict(sample_image[tf.newaxis, ...])[output])])
    
show_predictions()

Let's define the variables needed for training the model.

EPOCHS = 50
VAL_SUBSPLITS = 5
VALIDATION_STEPS = n_test//BATCH_SIZE//VAL_SUBSPLITS

And here we define the custom callback function for showing how the model improves its predictions.

class MyCustomCallback(tf.keras.callbacks.Callback):
  def on_epoch_end(self, epoch, logs=None):
    show_predictions()
    

Finally, we fit the model to the Oxford dataset.

model_history =  model.fit(train_dataset, epochs=EPOCHS, steps_per_epoch=STEPS_PER_EPOCH, 
                           validation_steps=VALIDATION_STEPS, validation_data=test_dataset, callbacks=[MyCustomCallback()]) 

Finally, we visualize some predictions on the test dataset.

show_predictions(test, 10)

We also visualize the accuracies and losses using matplotlib.

loss = model_history.history['loss']
acc = model_history.history['final_output_acc']
val_loss = model_history.history['val_loss']
val_loss1 = model_history.history['val_CC_1_loss']
val_loss2 = model_history.history['val_CC_2_loss']
val_loss3 = model_history.history['val_CC_fin_loss']
val_loss4 = model_history.history['val_final_output_loss']
val_acc1 = model_history.history['val_CC_1_acc']
val_acc2 = model_history.history['val_CC_2_acc']
val_acc3 = model_history.history['val_CC_fin_acc']
val_acc4 = model_history.history['val_final_output_acc']

epochs = range(EPOCHS)

plt.figure(figsize=(20,3))
plt.subplot(1,4,1)
plt.plot(epochs, loss, 'r', label='Training loss')
plt.plot(epochs, val_loss, 'bo', label='Validation loss')
plt.title('Training and Validation Loss')
plt.xlabel('Epoch')
plt.ylabel('Loss Value')
plt.legend()
plt.subplot(1,4,2)
plt.plot(epochs, acc, 'r', label="Training accuracy")
plt.plot(epochs, val_acc4, 'bo', label="Validation accuracy")
plt.xlabel('Epoch')
plt.ylabel('Accuracy')
plt.title('Training and Validation Accuracy')
plt.legend()
plt.subplot(1,4,3)
plt.plot(epochs, val_loss1, 'bo', label="CC_1")
plt.plot(epochs, val_loss2, 'go', label="CC_2")
plt.plot(epochs, val_loss3, 'yo', label="CC_fin")
plt.plot(epochs, val_loss4, 'ko', label="final_output")
plt.xlabel('Epoch')
plt.ylabel('Loss Value')
plt.title('Validation Loss for the Different Outputs')
plt.legend()
plt.subplot(1,4,4)
plt.plot(epochs, val_acc1, 'bo', label="CC_1")
plt.plot(epochs, val_acc2, 'go', label="CC_2")
plt.plot(epochs, val_acc3, 'yo', label="CC_fin")
plt.plot(epochs, val_acc4, 'ko', label="final_output")
plt.xlabel('Epoch')
plt.ylabel('Accuracy')
plt.title('Validation Accuracy for the Different Outputs')
plt.legend()
plt.show()

Implementation of PSPNet with Hyperparameter Tuning in Parallel

William Anzén (Linkedin), Christian von Koch (Linkedin)

2021, Stockholm, Sweden

This project was supported by Combient Mix AB under the industrial supervision of Raazesh Sainudiin and Max Fischer.

In this notebook, an implementation of PSPNet is presented. PSPNet is a scene-parsing architecture that evaluates the image at several different scales and combines the results into a final prediction. The architecture is evaluated on the Oxford-IIIT Pet Dataset while some hyperparameters of the model are tuned in parallel through Spark Trials.

Importing the required packages.

import tensorflow as tf
import tensorflow_datasets as tfds
import matplotlib.pyplot as plt
from tensorflow.keras.layers import AveragePooling2D, Conv2D, BatchNormalization, Activation, Concatenate, UpSampling2D, Reshape, TimeDistributed, ConvLSTM2D
from tensorflow.keras import Model
import numpy as np
from tensorflow.keras.applications.resnet50 import ResNet50
from hyperopt import fmin, tpe, hp, Trials, STATUS_OK, SparkTrials

Defining functions for normalizing and transforming the images.

# Function for normalizing image_size so that pixel intensity is between 0 and 1
def normalize(input_image, input_mask):
  input_image = tf.cast(input_image, tf.float32) / 255.0
  input_mask -= 1
  return input_image, input_mask

# Function for resizing the train images to the desired input shape of 128x128 as well as augmenting the training images
def load_image_train(datapoint):
  input_image = tf.image.resize(datapoint['image'], (128, 128))
  input_mask = tf.image.resize(datapoint['segmentation_mask'], (128, 128))

  if tf.random.uniform(()) > 0.5:
    input_image = tf.image.flip_left_right(input_image)
    input_mask = tf.image.flip_left_right(input_mask)

  input_image, input_mask = normalize(input_image, input_mask)
  input_mask = tf.math.round(input_mask)

  return input_image, input_mask

# Function for resizing the test images to the desired output shape (no augmentation)
def load_image_test(datapoint):
  input_image = tf.image.resize(datapoint['image'], (128, 128))
  input_mask = tf.image.resize(datapoint['segmentation_mask'], (128, 128))

  input_image, input_mask = normalize(input_image, input_mask)

  return input_image, input_mask

The Oxford-IIIT Pet Dataset from TensorFlow Datasets is loaded, the images are transformed in the desired way, and the datasets used for training and inference are created.

dataset, info = tfds.load('oxford_iiit_pet:3.*.*', with_info=True)

TRAIN_LENGTH = info.splits['train'].num_examples
TEST_LENGTH = info.splits['test'].num_examples
BATCH_SIZE = 64
BUFFER_SIZE = 1000
STEPS_PER_EPOCH = TRAIN_LENGTH // BATCH_SIZE

train = dataset['train'].map(load_image_train)
test = dataset['test'].map(load_image_test)

The dataset is then converted to numpy format since Spark Trials does not yet handle the TensorFlow Dataset class.

train_numpy = tfds.as_numpy(train)
test_numpy = tfds.as_numpy(test)
X_train = np.array(list(map(lambda x: x[0], train_numpy)))
Y_train = np.array(list(map(lambda x: x[1], train_numpy)))

X_test = np.array(list(map(lambda x: x[0], test_numpy)))
Y_test = np.array(list(map(lambda x: x[1], test_numpy)))
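
As a small check (a minimal sketch; the sizes below assume the standard train/test splits of the Oxford-IIIT Pet Dataset), the resulting arrays can be inspected:

print(X_train.shape, Y_train.shape)  # e.g. (3680, 128, 128, 3) (3680, 128, 128, 1)
print(X_test.shape, Y_test.shape)    # e.g. (3669, 128, 128, 3) (3669, 128, 128, 1)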

An example image is displayed.

def display(display_list):
  plt.figure(figsize=(15, 15))

  title = ['Input Image', 'True Mask', 'Predicted Mask']

  for i in range(len(display_list)):
    plt.subplot(1, len(display_list), i+1)
    plt.title(title[i])
    plt.imshow(tf.keras.preprocessing.image.array_to_img(display_list[i]))
    plt.axis('off')
  plt.show()

it = iter(train)
sample_image, sample_mask = next(it)
display([sample_image, sample_mask])

Defining the functions needed for the PSPNet.

def pool_block(cur_tensor: tf.Tensor,
               image_width: int,
               image_height: int,
               pooling_factor: int,
               activation: str = 'relu'
              ) -> tf.Tensor:
  """
  Parameters
  ----------
  cur_tensor : tf.Tensor
      Incoming tensor.

  image_width : int
      Width of the image.

  image_height : int
      Height of the image.

  pooling_factor : int
      Pooling factor to scale image.

  activation : str, default = 'relu'
      Activation function to be applied after the pooling operations.

  Returns
  -------
  tf.Tensor
      2D convolutional block.
  """
  
  #Calculate the strides with image size and pooling factor
  strides = [int(np.round(float(image_width)/pooling_factor)),
            int(np.round(float(image_height)/pooling_factor))]
  
  pooling_size = strides
  
  x = AveragePooling2D(pooling_size, strides=strides, padding='same')(cur_tensor)
  
  x = Conv2D(filters = 512, 
             kernel_size = (1,1),
             padding = 'same')(x)
  
  x = BatchNormalization()(x)
  x = Activation(activation)(x)
  
  # Resizing images to correct shape for future concat
  x = tf.keras.layers.experimental.preprocessing.Resizing(
    image_height, 
    image_width, 
    interpolation="bilinear")(x) 
  
  return x

def modify_ResNet_Dilation(
  model: tf.keras.Model
) -> tf.keras.Model:
  """
  Modifies the ResNet to fit the PSPNet Paper with the described dilated strategy.
  
    Parameters
    ----------
    model : tf.keras.Model
        The ResNet-50 model to be modified.
        
    Returns
    -------
    tf.keras.Model
        Modified model.
  """
  for i in range(0,4):
    model.get_layer('conv4_block1_{}_conv'.format(i)).strides = 1
    model.get_layer('conv4_block1_{}_conv'.format(i)).dilation_rate = 2
    model.get_layer('conv5_block1_{}_conv'.format(i)).strides = 1
    model.get_layer('conv5_block1_{}_conv'.format(i)).dilation_rate = 4
  model.save('/tmp/my_model')
  new_model = tf.keras.models.load_model('/tmp/my_model')
  return new_model

def PSPNet(image_width: int,
           image_height: int,
           n_classes: int,
           kernel_size: tuple = (3,3),
           activation: str = 'relu',
           weights: str = 'imagenet',
           shallow_tuning: bool = False,
           isICNet: bool = False
          ) -> tf.keras.Model:
  """
  Setting up the PSPNet model
  
    Parameters
    ----------
    image_width : int
        Width of the image.
        
    image_height : int
        Height of the image.
        
    n_classes : int
        Number of classes.
        
    kernel_size : tuple, default = (3, 3)
        Size of the kernel.    
        
    activation : str, default = 'relu'
        Activation function to be applied after the pooling operations.
        
    weights: str, default = 'imagenet'
        String defining which weights to use for the backbone, 'imagenet' for ImageNet weights or None for normalized initialization.
        
    shallow_tuning : bool, default = False
        Boolean for using shallow tuning with pre-trained weights from ImageNet or not.
        
    isICNet : bool,  default = False   
        Boolean to determine if the PSPNet will be part of an ICNet.
        
    Returns
    -------
    tf.keras.Model
        The finished keras model for PSPNet.
  """
  if shallow_tuning and not weights:
    raise ValueError("Shallow tuning cannot be performed without loading pre-trained weights. Please pass 'imagenet' to the weights argument.")
  #If the function is used for the ICNet input_shape is set to none as ICNet takes 3 inputs
  if isICNet:
    input_shape=(None, None, 3)
  else:
    input_shape=(image_height,image_width,3)
    
  #Initializing the ResNet50 Backbone  
  y=ResNet50(include_top=False, weights=weights, input_shape=input_shape)

  y=modify_ResNet_Dilation(y)
  if shallow_tuning:
    y.trainable=False
  
  pooling_layer=[]
  output=y.output

  pooling_layer.append(output)
  
  h = image_height//8
  w = image_width//8
  
  #Loop for calling the pool block functions for pooling factors [1,2,3,6]
  for i in [1,2,3,6]:
    pool = pool_block(output, h, w, i, activation)
    pooling_layer.append(pool)
    
  x=Concatenate()(pooling_layer)
  
  x=Conv2D(filters=n_classes, kernel_size=(1,1), padding='same')(x)
  
  x=UpSampling2D(size=(8,8), data_format='channels_last', interpolation='bilinear')(x)
  x=Reshape((image_height*image_width, n_classes))(x)
  x=Activation(tf.nn.softmax)(x)
  x=Reshape((image_height,image_width,n_classes))(x)

  final_model=tf.keras.Model(inputs=y.input, outputs=x)
                              
  return final_model

Defining the custom data generators used for training the model.

def batch_generator(batch_size):
  indices = np.arange(len(X_train)) 
  batch=[]
  while True:
  # it might be a good idea to shuffle your data before each epoch
    np.random.shuffle(indices) 
    for i in indices:
      batch.append(i)
      if len(batch)==batch_size:
        yield X_train[batch], Y_train[batch]
        batch=[]
        
def batch_generator_eval(batch_size):
  indices = np.arange(len(X_test)) 
  batch=[]
  while True:
    for i in indices:
      batch.append(i)
      if len(batch)==batch_size:
        yield X_test[batch], Y_test[batch]
        batch=[]

To exploit parallelism, Spark Trials requires that you define a single function that loads the dataset, trains the model and evaluates it.

def train_spark(params):
  
  gpus = tf.config.list_physical_devices('GPU')
  if gpus:
    try:
      # Currently, memory growth needs to be the same across GPUs
      for gpu in gpus:
        tf.config.experimental.set_memory_growth(gpu, True)
      logical_gpus = tf.config.experimental.list_logical_devices('GPU')
      print(len(gpus), "Physical GPUs,", len(logical_gpus), "Logical GPUs")
    except RuntimeError as e:
      # Memory growth must be set before GPUs have been initialized
      print(e)
  VAL_SUBSPLITS=5
  EPOCHS=50
  VALIDATION_STEPS = TEST_LENGTH//params['batch_size']//VAL_SUBSPLITS
  STEPS_PER_EPOCH = TRAIN_LENGTH // params['batch_size']
  BATCH_SIZE = params['batch_size']
  print(BATCH_SIZE)
  BUFFER_SIZE = 1000
  """
  An example train method that is passed to hyperopt.fmin().
  
  :param params: hyperparameters. Its structure is consistent with how search space is defined. See below.
  :return: dict with fields 'loss' (scalar loss) and 'status' (success/failure status of run)
  """
  train_dataset = batch_generator(BATCH_SIZE)
  test_dataset = batch_generator_eval(BATCH_SIZE)
  model=PSPNet(128,128,3)
  model.compile(optimizer=tf.keras.optimizers.Adam(learning_rate=params['learning_rate']),
              loss=tf.keras.losses.SparseCategoricalCrossentropy(),
              metrics=tf.keras.metrics.SparseCategoricalAccuracy())
  
  model_history =  model.fit(train_dataset, epochs=EPOCHS, steps_per_epoch=STEPS_PER_EPOCH)
  loss = model.evaluate(test_dataset, steps=VALIDATION_STEPS)[0]
  return {'loss': loss, 'status': STATUS_OK}

In this experiment, the hyperparameters learning_rate and batch_size are explored through random search in combination with an adaptive algorithm called Tree of Parzen Estimators (TPE). Since parallelism is set to 4, the hyperparameters are chosen randomly for the first 4 runs executed in parallel and then adaptively for the second pass of 4 runs.

space = {
  'learning_rate': hp.loguniform('learning_rate', np.log(1e-4), np.log(1e-1)),
  'batch_size': hp.choice('batch_size', [16, 32, 64]),
}
algo=tpe.suggest

spark_trials = SparkTrials(parallelism=4)
best_param = fmin(
  fn=train_spark,
  space=space,
  algo=algo,
  max_evals=8,
  return_argmin=False,
  trials = spark_trials,
)

print(best_param)
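
Since return_argmin=False was passed to fmin, best_param should hold the actual parameter values of the best run rather than hp.choice indices. A minimal sketch (hypothetical variable names, assuming the search space above) of retraining a final model with the winning configuration:

best_lr = best_param['learning_rate']
best_bs = best_param['batch_size']

# Retrain with the best hyperparameters found by the search
final_model = PSPNet(128, 128, 3)
final_model.compile(optimizer=tf.keras.optimizers.Adam(learning_rate=best_lr),
                    loss=tf.keras.losses.SparseCategoricalCrossentropy(),
                    metrics=tf.keras.metrics.SparseCategoricalAccuracy())
final_model.fit(batch_generator(best_bs), epochs=50, steps_per_epoch=TRAIN_LENGTH // best_bs)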

Implementation of PSPNet with distributed training using Horovod

William Anzén (Linkedin), Christian von Koch (Linkedin)

2021, Stockholm, Sweden

This project was supported by Combient Mix AB under the industrial supervision of Raazesh Sainudiin and Max Fischer.

In this notebook, an implementation of PSPNet is presented. PSPNet is a scene-parsing architecture that evaluates the image at several different scales and combines the results into a final prediction. The architecture is evaluated on the Oxford-IIIT Pet Dataset and is trained in a distributed manner using Horovod.

Importing the required packages.

import tensorflow as tf
import tensorflow_datasets as tfds
import matplotlib.pyplot as plt
from tensorflow.keras.layers import AveragePooling2D, Conv2D, BatchNormalization, Activation, Concatenate, UpSampling2D, Reshape, TimeDistributed, ConvLSTM2D
from tensorflow.keras import Model
import numpy as np
from tensorflow.keras.applications.resnet50 import ResNet50
import horovod.tensorflow.keras as hvd

Setting memory growth on the GPUs is recommended as the model is quite memory intensive.

gpus = tf.config.list_physical_devices('GPU')
if gpus:
  try:
    # Currently, memory growth needs to be the same across GPUs
    for gpu in gpus:
      tf.config.experimental.set_memory_growth(gpu, True)
    logical_gpus = tf.config.experimental.list_logical_devices('GPU')
    print(len(gpus), "Physical GPUs,", len(logical_gpus), "Logical GPUs")
  except RuntimeError as e:
    # Memory growth must be set before GPUs have been initialized
    print(e)

Setting up checkpoint location... The next cell creates a directory for saved checkpoint models.

import os
import time

checkpoint_dir = '/dbfs/ml/OxfordDemo_Horovod/train/'

os.makedirs(checkpoint_dir, exist_ok=True)

Defining functions for normalizing and transforming the images.

# Function for normalizing image_size so that pixel intensity is between 0 and 1
def normalize(input_image, input_mask):
  input_image = tf.cast(input_image, tf.float32) / 255.0
  input_mask -= 1
  return input_image, input_mask

# Function for resizing the train images to the desired input shape of 128x128 as well as augmenting the training images
def load_image_train(datapoint):
  input_image = tf.image.resize(datapoint['image'], (128, 128))
  input_mask = tf.image.resize(datapoint['segmentation_mask'], (128, 128))

  if tf.random.uniform(()) > 0.5:
    input_image = tf.image.flip_left_right(input_image)
    input_mask = tf.image.flip_left_right(input_mask)

  input_image, input_mask = normalize(input_image, input_mask)
  input_mask = tf.math.round(input_mask)

  return input_image, input_mask


# Function for resizing the test images to the desired output shape (no augmentation)
def load_image_test(datapoint):
  input_image = tf.image.resize(datapoint['image'], (128, 128))
  input_mask = tf.image.resize(datapoint['segmentation_mask'], (128, 128))

  input_image, input_mask = normalize(input_image, input_mask)

  return input_image, input_mask

Defining the functions needed for the PSPNet.

def pool_block(cur_tensor: tf.Tensor,
               image_width: int,
               image_height: int,
               pooling_factor: int,
               activation: str = 'relu'
              ) -> tf.Tensor:
  """
  Parameters
  ----------
  cur_tensor : tf.Tensor
      Incoming tensor.

  image_width : int
      Width of the image.

  image_height : int
      Height of the image.

  pooling_factor : int
      Pooling factor to scale image.

  activation : str, default = 'relu'
      Activation function to be applied after the pooling operations.

  Returns
  -------
  tf.Tensor
      2D convolutional block.
  """
  
  #Calculate the strides with image size and pooling factor
  strides = [int(np.round(float(image_width)/pooling_factor)),
            int(np.round(float(image_height)/pooling_factor))]
  
  pooling_size = strides
  
  x = AveragePooling2D(pooling_size, strides=strides, padding='same')(cur_tensor)
  
  x = Conv2D(filters = 512, 
             kernel_size = (1,1),
             padding = 'same')(x)
  
  x = BatchNormalization()(x)
  x = Activation(activation)(x)
  
  # Resizing images to correct shape for future concat
  x = tf.keras.layers.experimental.preprocessing.Resizing(
    image_height, 
    image_width, 
    interpolation="bilinear")(x) 
  
  return x

def modify_ResNet_Dilation(
  model: tf.keras.Model
) -> tf.keras.Model:
  """
  Modifies the ResNet to fit the PSPNet Paper with the described dilated strategy.
  
    Parameters
    ----------
    model : tf.keras.Model
        The ResNet-50 model to be modified.
        
    Returns
    -------
    tf.keras.Model
        Modified model.
  """
  for i in range(0,4):
    model.get_layer('conv4_block1_{}_conv'.format(i)).strides = 1
    model.get_layer('conv4_block1_{}_conv'.format(i)).dilation_rate = 2
    model.get_layer('conv5_block1_{}_conv'.format(i)).strides = 1
    model.get_layer('conv5_block1_{}_conv'.format(i)).dilation_rate = 4
  model.save('/tmp/my_model')
  new_model = tf.keras.models.load_model('/tmp/my_model')
  return new_model

def PSPNet(image_width: int,
           image_height: int,
           n_classes: int,
           kernel_size: tuple = (3,3),
           activation: str = 'relu',
           weights: str = 'imagenet',
           shallow_tuning: bool = False,
           isICNet: bool = False
          ) -> tf.keras.Model:
  """
  Setting up the PSPNet model
  
    Parameters
    ----------
    image_width : int
        Width of the image.
        
    image_height : int
        Height of the image.
        
    n_classes : int
        Number of classes.
        
    kernel_size : tuple, default = (3, 3)
        Size of the kernel.    
        
    activation : str, default = 'relu'
        Activation function to be applied after the pooling operations.
        
    weights: str, default = 'imagenet'
        String defining which weights to use for the backbone, 'imagenet' for ImageNet weights or None for normalized initialization.
        
    shallow_tuning : bool, default = False
        Boolean for using shallow tuning with pre-trained weights from ImageNet or not.
        
    isICNet : bool,  default = False   
        Boolean to determine if the PSPNet will be part of an ICNet.
        
    Returns
    -------
    tf.keras.Model
        The finished keras model for PSPNet.
  """
  if shallow_tuning and not weights:
    raise ValueError("Shallow tuning cannot be performed without loading pre-trained weights. Please pass 'imagenet' to the weights argument.")
  #If the function is used for the ICNet input_shape is set to none as ICNet takes 3 inputs
  if isICNet:
    input_shape=(None, None, 3)
  else:
    input_shape=(image_height,image_width,3)
    
  #Initializing the ResNet50 Backbone  
  y=ResNet50(include_top=False, weights=weights, input_shape=input_shape)

  y=modify_ResNet_Dilation(y)
  if shallow_tuning:
    y.trainable=False
  
  pooling_layer=[]
  output=y.output

  pooling_layer.append(output)
  
  h = image_height//8
  w = image_width//8
  
  #Loop for calling the pool block functions for pooling factors [1,2,3,6]
  for i in [1,2,3,6]:
    pool = pool_block(output, h, w, i, activation)
    pooling_layer.append(pool)
    
  x=Concatenate()(pooling_layer)
  
  x=Conv2D(filters=n_classes, kernel_size=(1,1), padding='same')(x)
  
  x=UpSampling2D(size=(8,8), data_format='channels_last', interpolation='bilinear')(x)
  x=Reshape((image_height*image_width, n_classes))(x)
  x=Activation(tf.nn.softmax)(x)
  x=Reshape((image_height,image_width,n_classes))(x)

  final_model=tf.keras.Model(inputs=y.input, outputs=x)
                              
  return final_model

A function called by the Horovod runner to create the datasets for each worker. The dataset is split according to the number of processes initialized in the Horovod runner and distributed across the workers. The dataset is transformed to numpy.ndarray to enable the splitting, then transformed back into a tensorflow.data.Dataset object and batched for training purposes.

def create_datasets_hvd_loop(BATCH_SIZE:int = 64, BUFFER_SIZE:int = 1000, rank=0, size=1):
  dataset, info = tfds.load('oxford_iiit_pet:3.*.*', data_dir='Oxford-%d' % rank, with_info=True)
  TRAIN_LENGTH = info.splits['train'].num_examples
  TEST_LENGTH = info.splits['test'].num_examples
  
  #Creating the ndarray in the correct shapes for training data
  train_original_img = np.ndarray(shape=(TRAIN_LENGTH, 128, 128, 3))
  train_original_mask = np.ndarray(shape=(TRAIN_LENGTH, 128, 128, 1))

  #Loading the data into the arrays 
  count = 0
  for datapoint in dataset['train']:
    img_orig, mask_orig = load_image_train(datapoint)
    train_original_img[count]=img_orig
    train_original_mask[count]=mask_orig

    count+=1
  
  #Creating the ndarrays in the correct shapes for test data  
  test_original_img = np.ndarray(shape=(TEST_LENGTH,128,128,3))
  test_original_mask = np.ndarray(shape=(TEST_LENGTH,128,128,1))
  
  #Loading the data into the arrays
  count=0
  for datapoint in dataset['test']:
    img_orig, mask_orig = load_image_test(datapoint)
    test_original_img[count]=img_orig
    test_original_mask[count]=mask_orig
  
    count+=1
    
  train_dataset = tf.data.Dataset.from_tensor_slices((train_original_img[rank::size], train_original_mask[rank::size]))
  orig_test_dataset = tf.data.Dataset.from_tensor_slices((test_original_img[rank::size], test_original_mask[rank::size]))
  
  train_dataset = train_dataset.shuffle(BUFFER_SIZE).cache().batch(BATCH_SIZE).repeat()
  # prefetch returns a new dataset, so the result must be reassigned for it to take effect
  train_dataset = train_dataset.prefetch(buffer_size=tf.data.experimental.AUTOTUNE)
  test_dataset = orig_test_dataset.batch(BATCH_SIZE)
  
  n_train = train_original_img[rank::size].shape[0]
  n_test = test_original_img[rank::size].shape[0]
  print('Training examples on this worker: %d' % n_train)
  print('Test examples on this worker: %d' % n_test)
  
  return train_dataset, test_dataset, n_train, n_test
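
Outside of Horovod, the same function can be smoke-tested in a single process by leaving rank and size at their defaults, so that no sharding takes place. A hedged usage sketch, assuming the load_image_train/load_image_test helpers from earlier in the notebook are defined:

# Single-process smoke test: rank=0, size=1 means the full dataset is used
train_ds, test_ds, n_train, n_test = create_datasets_hvd_loop(BATCH_SIZE=16, rank=0, size=1)
for images, masks in train_ds.take(1):
  print(images.shape, masks.shape)  # expected: (16, 128, 128, 3) (16, 128, 128, 1)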

The training function executed by each worker in a distributed fashion via the Horovod runner.

def train_hvd(learning_rate):
  import tensorflow as tf
  import tensorflow_datasets as tfds
  import horovod.tensorflow.keras as hvd

  # Initialize Horovod
  hvd.init()

  # Optimal batch size from previous notebooks hyperparameter search
  BATCH_SIZE = 16
  BUFFER_SIZE = 1000
  EPOCHS = 50

  # Pin GPU to be used to process local rank (one GPU per process)
  # These steps are skipped on a CPU cluster
  gpus = tf.config.experimental.list_physical_devices('GPU')
  for gpu in gpus:
    tf.config.experimental.set_memory_growth(gpu, True)
  if gpus:
    tf.config.experimental.set_visible_devices(gpus[hvd.local_rank()], 'GPU')
  
  train_dataset, test_dataset, n_train, n_test = create_datasets_hvd_loop(BATCH_SIZE=BATCH_SIZE, rank=hvd.rank(), size=hvd.size())

  STEPS_PER_EPOCH = n_train // BATCH_SIZE
  VALIDATION_STEPS = n_test // BATCH_SIZE
  
  model = PSPNet(128,128,3)
  
  # Adjust learning rate based on number of GPUs 
  optimizer = tf.keras.optimizers.Adam(learning_rate=learning_rate*hvd.size())
  # Use the Horovod Distributed Optimizer
  optimizer = hvd.DistributedOptimizer(optimizer)
  
  model.compile(optimizer=optimizer,
                loss=tf.keras.losses.SparseCategoricalCrossentropy(),
                metrics=[tf.keras.metrics.SparseCategoricalAccuracy()])
  
  # Create a callback to broadcast the initial variable states from rank 0 to all other processes.
  # This is required to ensure consistent initialization of all workers when training is started with random weights or restored from a checkpoint.
  callbacks = [
      hvd.callbacks.BroadcastGlobalVariablesCallback(0)
  ]
  
  if hvd.rank() == 0:
      callbacks.append(tf.keras.callbacks.ModelCheckpoint(checkpoint_dir + 'checkpoint-{epoch}.ckpt',
                                                          save_weights_only=True, monitor='val_loss',
                                                          save_best_only=True))
  
  
  model_history = model.fit(train_dataset, epochs=EPOCHS, steps_per_epoch=STEPS_PER_EPOCH,
                            verbose=1 if hvd.rank() == 0 else 0,
                            validation_steps=VALIDATION_STEPS, validation_data=test_dataset,
                            callbacks=callbacks)
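
If you would rather have validation metrics averaged across all workers than reported per worker, Horovod also ships a MetricAverageCallback that can be appended to the callback list inside train_hvd, right after the broadcast callback. A sketch of that optional addition (not in the original notebook):

# Optional: average metrics across all Horovod processes before they are reported
callbacks.append(hvd.callbacks.MetricAverageCallback())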

Initialization of the Horovod runner. Make sure to set np to the number of worker processes (here, GPUs) available in your cluster.

from sparkdl import HorovodRunner

hr = HorovodRunner(np=4, driver_log_verbosity = "all")
# Optimal learning rate from previous notebooks hyperparameter search
hr.run(train_hvd, learning_rate=0.0001437661898681224)
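
If you want to debug train_hvd without occupying the workers, Databricks' HorovodRunner also supports a local mode: passing a negative np runs that many processes on the driver instead (an assumption based on the HorovodRunner documentation, so verify against your runtime):

# Hypothetical debug run: np=-1 runs a single local process on the driver
hr_local = HorovodRunner(np=-1)
hr_local.run(train_hvd, learning_rate=0.0001437661898681224)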

Reloading the test dataset to perform inference...

dataset, info = tfds.load('oxford_iiit_pet:3.*.*', with_info=True)

BATCH_SIZE = 16
test = dataset['test'].map(load_image_test)
test_dataset = test.batch(BATCH_SIZE)

TEST_LENGTH = info.splits['test'].num_examples
VALIDATION_STEPS = TEST_LENGTH//BATCH_SIZE

Checking the latest checkpoint saved in the previous training.

%sh ls /dbfs/ml/OxfordDemo_Horovod/train/

Finally, we evaluate the best-performing model on the test_dataset and note that the model has learned to segment the images quite successfully.

model = PSPNet(128, 128, 3)
model_path = checkpoint_dir + 'checkpoint-49.ckpt'
model.load_weights(model_path)
model.compile(loss=tf.keras.losses.SparseCategoricalCrossentropy(), metrics=['acc'])
model.evaluate(test_dataset, steps=VALIDATION_STEPS)
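
To inspect the segmentations rather than just the accuracy figure, one can predict on a batch and take the per-pixel argmax over the class scores. A minimal visualization sketch, assuming matplotlib is available on the cluster:

import matplotlib.pyplot as plt

# Predict on one batch and show the first image with its true and predicted masks
for images, masks in test_dataset.take(1):
  pred = model.predict(images)         # (batch, 128, 128, n_classes) softmax scores
  pred_mask = tf.argmax(pred, axis=-1) # per-pixel class index, (batch, 128, 128)
  fig, axes = plt.subplots(1, 3, figsize=(12, 4))
  axes[0].imshow(images[0]); axes[0].set_title('Input')
  axes[1].imshow(tf.squeeze(masks[0]), cmap='gray'); axes[1].set_title('True mask')
  axes[2].imshow(pred_mask[0], cmap='gray'); axes[2].set_title('Predicted mask')
  plt.show()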