U-Net model for image segmentation
William Anzén (LinkedIn), Christian von Koch (LinkedIn)
2021, Stockholm, Sweden
This project was supported by Combient Mix AB under the industrial supervision of Raazesh Sainudiin and Max Fischer.
This is a modified version of TensorFlow's tutorial on image segmentation, which can be found here. In this notebook, a modified U-Net is used.
Importing the required packages...
import tensorflow as tf
import tensorflow_datasets as tfds
import matplotlib.pyplot as plt
from IPython.display import clear_output
from tensorflow.keras import Input
from tensorflow.keras.layers import Conv2D, BatchNormalization, Activation, MaxPooling2D, Concatenate, ReLU, Reshape, Conv2DTranspose
from tensorflow.keras import Model
from tensorflow.keras.applications import VGG16
from typing import Union, Tuple
First, the Oxford-IIIT Pet Dataset from TensorFlow Datasets is loaded. The images are then transformed as desired, and the datasets used for training and inference are created. Finally, an example image is displayed.
dataset, info = tfds.load('oxford_iiit_pet:3.*.*', with_info=True)
def normalize(input_image, input_mask):
    input_image = tf.cast(input_image, tf.float32) / 255.0
    input_mask -= 1
    return input_image, input_mask
@tf.function
def load_image_train(datapoint):
    input_image = tf.image.resize(datapoint['image'], (128, 128))
    input_mask = tf.image.resize(datapoint['segmentation_mask'], (128, 128))
    if tf.random.uniform(()) > 0.5:
        input_image = tf.image.flip_left_right(input_image)
        input_mask = tf.image.flip_left_right(input_mask)
    input_image, input_mask = normalize(input_image, input_mask)
    return input_image, input_mask
def load_image_test(datapoint):
    input_image = tf.image.resize(datapoint['image'], (128, 128))
    input_mask = tf.image.resize(datapoint['segmentation_mask'], (128, 128))
    input_image, input_mask = normalize(input_image, input_mask)
    return input_image, input_mask
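As a quick sanity check (a minimal sketch, assuming the pipeline above), the snippet below inspects one preprocessed example. Note that tf.image.resize casts the mask to float32 and bilinearly interpolates it, so values at class boundaries may be fractional, but they stay within the shifted label range.
for img, msk in dataset['train'].map(load_image_test).take(1):
    # Expected shapes: (128, 128, 3) for the image, (128, 128, 1) for the mask.
    print(img.shape, msk.shape)
    # After the -1 shift in normalize, mask values should lie in [0.0, 2.0].
    print(float(tf.reduce_min(msk)), float(tf.reduce_max(msk)))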
TRAIN_LENGTH = info.splits['train'].num_examples
BATCH_SIZE = 64
BUFFER_SIZE = 1000
STEPS_PER_EPOCH = TRAIN_LENGTH // BATCH_SIZE
train = dataset['train'].map(load_image_train, num_parallel_calls=tf.data.experimental.AUTOTUNE)
test = dataset['test'].map(load_image_test)
train_dataset = train.cache().shuffle(BUFFER_SIZE).batch(BATCH_SIZE).repeat()
train_dataset = train_dataset.prefetch(buffer_size=tf.data.experimental.AUTOTUNE)
test_dataset = test.batch(BATCH_SIZE)
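To confirm what the input pipeline yields (a minimal check, not part of the original notebook), its element_spec can be inspected:
# The pipeline should yield (image, mask) batches with shapes
# (None, 128, 128, 3) and (None, 128, 128, 1), both float32.
print(train_dataset.element_spec)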
def display(display_list):
    plt.figure(figsize=(15, 15))
    title = ['Input Image', 'True Mask', 'Predicted Mask']
    for i in range(len(display_list)):
        plt.subplot(1, len(display_list), i+1)
        plt.title(title[i])
        plt.imshow(tf.keras.preprocessing.image.array_to_img(display_list[i]))
        plt.axis('off')
    plt.show()
for image, mask in train.take(1):
    sample_image, sample_mask = image, mask
display([sample_image, sample_mask])
Now that the dataset has been loaded into memory, the functions needed for the modified U-Net model are defined. Note that both a VGG16 encoder and a regular U-Net encoder are provided as alternatives.
def conv_block(input_layer: tf.Tensor,
               filter_size: int,
               kernel_size: Union[int, tuple],
               activation: str = 'relu'
               ) -> tf.Tensor:
    """
    Creates one convolutional block of the U-Net structure.

    Parameters
    ----------
    input_layer : tf.Tensor
        Input tensor.
    filter_size : int
        Number of filters/kernels.
    kernel_size : int or tuple
        Size of the kernel.
    activation : str, default='relu'
        Activation function to be applied after each convolution.

    Returns
    -------
    tf.Tensor
        Output tensor of the 2D convolutional block.
    """
    x = Conv2D(filters=filter_size, kernel_size=kernel_size, padding='same')(input_layer)
    x = BatchNormalization()(x)
    x = Activation(activation)(x)
    x = Conv2D(filters=filter_size, kernel_size=kernel_size, padding='same')(x)
    x = BatchNormalization()(x)
    x = Activation(activation)(x)
    return x
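As a quick illustration (a hedged sketch, not part of the original notebook), conv_block can be probed on its own: with 'same' padding the spatial size is preserved and the channel count becomes filter_size.
# Build a tiny throwaway model around conv_block to inspect its output shape.
probe_input = Input(shape=(128, 128, 3))
probe_output = conv_block(probe_input, filter_size=4, kernel_size=3)
print(Model(probe_input, probe_output).output_shape)  # (None, 128, 128, 4)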
def encoder_VGG16(input_shape: list
                  ) -> Tuple[tf.keras.Model, list]:
    """
    Creates the encoder as a VGG16 network.

    Parameters
    ----------
    input_shape : list
        Input shape to initialize model.

    Returns
    -------
    tf.keras.Model
        Model instance from tf.keras.
    list
        List containing the layers to be concatenated in the upsampling phase.
    """
    base_model = VGG16(include_top=False, weights='imagenet', input_shape=input_shape)
    layers = [layer.output for layer in base_model.layers]
    base_model = tf.keras.Model(inputs=base_model.input, outputs=layers[-2])
    base_model.summary()
    # Collect the outputs of the layers used as skip connections in the decoder.
    skip_names = ['block1_conv1', 'block2_conv2', 'block3_conv3',
                  'block4_conv3', 'block5_conv3']
    x = [base_model.get_layer(name).output for name in skip_names]
    return base_model, x
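To see which feature maps the VGG16 encoder exposes for the skip connections, the sketch below (assuming ImageNet weights can be downloaded in this environment) prints their shapes for a 128x128 RGB input.
# Each skip connection halves the spatial resolution of the previous one:
# (None, 128, 128, 64), (None, 64, 64, 128), (None, 32, 32, 256),
# (None, 16, 16, 512), (None, 8, 8, 512)
vgg_encoder, vgg_skips = encoder_VGG16([128, 128, 3])
for skip in vgg_skips:
    print(skip.shape)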
def encoder_unet(input_shape: list
                 ) -> Tuple[tf.keras.Input, list]:
    """
    Creates the encoder as the one described in the U-Net paper, with slight modifications.

    Parameters
    ----------
    input_shape : list
        Shape of the image inputted to the model.

    Returns
    -------
    tf.keras.Input
        Input layer for the inputted image.
    list
        List containing the layers to be concatenated in the upsampling phase.
    """
    input_layer = tf.keras.Input(shape=input_shape)
    conv1 = conv_block(input_layer, 4, 3, 'relu')
    pool1 = MaxPooling2D((2, 2))(conv1)
    conv2 = conv_block(pool1, 8, 3, 'relu')
    pool2 = MaxPooling2D((2, 2))(conv2)
    conv3 = conv_block(pool2, 16, 3, 'relu')
    pool3 = MaxPooling2D((2, 2))(conv3)
    conv4 = conv_block(pool3, 32, 3, 'relu')
    pool4 = MaxPooling2D((2, 2))(conv4)
    conv5 = conv_block(pool4, 64, 3, 'relu')
    x = [conv1, conv2, conv3, conv4, conv5]
    return input_layer, x
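The plain encoder mirrors the same pyramid with far fewer filters; a quick check (a sketch under the same assumptions as above):
# Expected shapes: (None, 128, 128, 4), (None, 64, 64, 8), (None, 32, 32, 16),
# (None, 16, 16, 32), (None, 8, 8, 64)
unet_input, unet_skips = encoder_unet([128, 128, 3])
for skip in unet_skips:
    print(skip.shape)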
def unet(image_width: int,
         image_height: int,
         n_channels: int,
         n_depth: int,
         n_classes: int,
         vgg16: bool = False,
         transfer_learning: bool = False
         ) -> tf.keras.Model:
    """
    Creates the U-Net architecture with slight modifications, in particular using fewer filters.

    Parameters
    ----------
    image_width : int
        Desired width of the inputted image.
    image_height : int
        Desired height of the inputted image.
    n_channels : int
        Number of channels of the inputted image.
    n_depth : int
        The desired depth level of the resulting U-Net architecture.
    n_classes : int
        The number of classes to be predicted/number of filters for the final prediction layer.
    vgg16 : bool, default=False
        Whether to use the VGG16 architecture in the encoder part of the model.
    transfer_learning : bool, default=False
        Whether to use transfer learning with pre-trained weights from ImageNet.

    Returns
    -------
    tf.keras.Model
        The produced U-Net model.
    """
    if n_depth < 1 or n_depth > 5:
        raise ValueError("Unsupported number of layers/upsamples")
    input_shape = [image_height, image_width, n_channels]
    if vgg16:
        encoded_model, x = encoder_VGG16(input_shape)
        if transfer_learning:
            encoded_model.trainable = False
        model_input = encoded_model.input
    else:
        model_input, x = encoder_unet(input_shape)
    intermediate_model = x[n_depth-1]
    # Decoder: upsample, concatenate with the matching skip connection, then convolve.
    for i in reversed(range(0, n_depth-1)):
        next_filters = x[i+1].shape[3] // 2
        intermediate_model = Conv2DTranspose(filters=next_filters, kernel_size=3, strides=2, padding='same')(intermediate_model)
        intermediate_model = Concatenate()([intermediate_model, x[i]])
        intermediate_model = BatchNormalization()(intermediate_model)
        intermediate_model = ReLU()(intermediate_model)
        intermediate_model = conv_block(intermediate_model, next_filters, kernel_size=3, activation='relu')
    intermediate_model = Conv2D(filters=n_classes, kernel_size=(1, 1), strides=(1, 1), padding='same')(intermediate_model)
    intermediate_model = Reshape((image_height*image_width, n_classes))(intermediate_model)
    intermediate_model = Activation(tf.nn.softmax)(intermediate_model)
    intermediate_model = Reshape((image_height, image_width, n_classes))(intermediate_model)
    final_model = tf.keras.models.Model(inputs=model_input, outputs=intermediate_model)
    return final_model
Let's then create the model.
model = unet(128, 128, 3, 5, 3)
And here is a summary of the model created...
model.summary()
The model is then compiled and its prediction before training is shown.
# Note: the network ends with a softmax activation, so the loss is computed
# on probabilities (from_logits=False) rather than on raw logits.
model.compile(optimizer='adam',
              loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=False),
              metrics=['accuracy'])
def create_mask(pred_mask):
    pred_mask = tf.argmax(pred_mask, axis=-1)
    pred_mask = pred_mask[..., tf.newaxis]
    return pred_mask[0]
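To make the argmax behaviour concrete, here is a toy example (hypothetical values) with a single 2x2 'image' over three classes; create_mask picks the most probable class per pixel.
# Shape (batch=1, height=2, width=2, classes=3); argmax runs over the last axis.
toy_pred = tf.constant([[[[0.1, 0.7, 0.2], [0.8, 0.1, 0.1]],
                         [[0.2, 0.2, 0.6], [0.3, 0.4, 0.3]]]])
print(create_mask(toy_pred)[..., 0])  # [[1 0] [2 1]]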
def show_predictions(dataset=None, num=1):
    if dataset:
        for image, mask in dataset.take(num):
            pred_mask = model.predict(image)
            display([image[0], mask[0], create_mask(pred_mask)])
    else:
        display([sample_image, sample_mask,
                 create_mask(model.predict(sample_image[tf.newaxis, ...]))])
show_predictions()
Below, the model is fitted to the training data and validated on the validation set after each epoch. A validation accuracy of 84.5% is achieved after one epoch. A custom callback is added to trace the learning of the model throughout its training.
class MyCustomCallback(tf.keras.callbacks.Callback):
    def on_epoch_end(self, epoch, logs=None):
        show_predictions()
EPOCHS = 50
VAL_SUBSPLITS = 5
VALIDATION_STEPS = info.splits['test'].num_examples//BATCH_SIZE//VAL_SUBSPLITS
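For reference (split sizes taken from the TFDS catalog, so treat them as an assumption about this dataset version): the train split has 3680 examples, giving 3680 // 64 = 57 steps per epoch, and the test split has 3669 examples, giving 3669 // 64 // 5 = 11 validation steps.
print(STEPS_PER_EPOCH, VALIDATION_STEPS)  # expected: 57 11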
model_history = model.fit(train_dataset, epochs=EPOCHS,
                          steps_per_epoch=STEPS_PER_EPOCH,
                          validation_steps=VALIDATION_STEPS,
                          validation_data=test_dataset,
                          callbacks=[MyCustomCallback()])
Some predictions on the test_dataset are shown to showcase the performance of the model on images it has not been trained on.
show_predictions(test_dataset, num=10)
Finally, we plot the learning curves of the model over its 50 epochs of training. Both the loss curves and the accuracy curves are presented.
loss = model_history.history['loss']
val_loss = model_history.history['val_loss']
acc = model_history.history['accuracy']
val_acc = model_history.history['val_accuracy']
epochs = range(EPOCHS)
plt.figure(figsize=(10,5))
plt.subplot(1,2,1)
plt.plot(epochs, loss, 'r', label='Training loss')
plt.plot(epochs, val_loss, 'bo', label='Validation loss')
plt.title('Training and Validation Loss')
plt.xlabel('Epoch')
plt.ylabel('Loss Value')
plt.legend()
plt.subplot(1,2,2)
plt.plot(epochs, acc, 'r', label='Training accuracy')
plt.plot(epochs, val_acc, 'bo', label='Validation accuracy')
plt.title('Training and Validation Accuracy')
plt.xlabel('Epoch')
plt.ylabel('Accuracy')
plt.legend()
plt.show()