import tensorflow as tf
import tensorflow_datasets as tfds
import matplotlib.pyplot as plt
from tensorflow.keras.layers import AveragePooling2D, Conv2D, BatchNormalization, Activation, Concatenate, UpSampling2D, Reshape, TimeDistributed, ConvLSTM2D
from tensorflow.keras import Model
import numpy as np
from tensorflow.keras.applications.resnet50 import ResNet50
# Enable GPU memory growth so TensorFlow allocates GPU memory as needed.
gpus = tf.config.list_physical_devices('GPU')
if gpus:
    try:
        # Memory growth needs to be the same across GPUs, and it must be set
        # before the GPUs have been initialized (TensorFlow raises
        # RuntimeError otherwise).
        for gpu in gpus:
            tf.config.experimental.set_memory_growth(gpu, True)
    except RuntimeError as e:
        print(e)
    else:
        # Only report the device counts once the configuration succeeded.
        logical_gpus = tf.config.list_logical_devices('GPU')
        print(len(gpus), "Physical GPUs,", len(logical_gpus), "Logical GPUs")
1 Physical GPUs, 1 Logical GPUs
# Function for normalizing an (image, mask) pair: pixel intensities are scaled
# into [0, 1] and mask labels are shifted down by one.
def normalize(input_image, input_mask):
    """Normalize one image/mask pair.

    Parameters
    ----------
    input_image : tensor-like
        Image whose pixel values are assumed to lie in [0, 255].
    input_mask : tensor-like
        Segmentation mask. Every label is decreased by one. This presumably
        maps labels starting at 1 onto 0-based class indices; confirm against
        the dataset.

    Returns
    -------
    tuple
        ``(image, mask)``: the image cast to float32 and divided by 255, and
        the mask minus one.
    """
    input_image = tf.cast(input_image, tf.float32) / 255.0
    # Out-of-place subtraction: `input_mask -= 1` would mutate a caller's
    # NumPy array in place.
    input_mask = input_mask - 1
    return input_image, input_mask
# Function for resizing the train images to the desired input shape of 128x128
# as well as augmenting them with random horizontal flips.
@tf.function
def load_image_train(datapoint):
    """Prepare one training example: resize, randomly flip, and normalize.

    Parameters
    ----------
    datapoint : dict
        Dataset element with ``'image'`` and ``'segmentation_mask'`` entries.

    Returns
    -------
    tuple
        ``(image, mask)`` at 128x128, both float32, after ``normalize``.
    """
    input_image = tf.image.resize(datapoint['image'], (128, 128))
    # Nearest-neighbour resizing keeps the mask made of the original discrete
    # labels. Bilinear interpolation blends neighbouring labels, so even after
    # rounding a pixel between labels 1 and 3 would become label 2.
    input_mask = tf.image.resize(datapoint['segmentation_mask'], (128, 128),
                                 method='nearest')
    # Nearest resizing keeps the input dtype; cast so the output dtype matches
    # the image branch (float32).
    input_mask = tf.cast(input_mask, tf.float32)
    # Flip image and mask together (50% chance) so they stay aligned.
    if tf.random.uniform(()) > 0.5:
        input_image = tf.image.flip_left_right(input_image)
        input_mask = tf.image.flip_left_right(input_mask)
    input_image, input_mask = normalize(input_image, input_mask)
    return input_image, input_mask
# Function for resizing the test images to the desired input shape of 128x128
# (no augmentation).
def load_image_test(datapoint):
    """Prepare one evaluation example: resize and normalize.

    Parameters
    ----------
    datapoint : dict
        Dataset element with ``'image'`` and ``'segmentation_mask'`` entries.

    Returns
    -------
    tuple
        ``(image, mask)`` at 128x128, both float32, after ``normalize``.
    """
    input_image = tf.image.resize(datapoint['image'], (128, 128))
    # Nearest-neighbour resizing keeps integer class labels. The previous
    # bilinear resize left fractional label values in the test masks.
    input_mask = tf.image.resize(datapoint['segmentation_mask'], (128, 128),
                                 method='nearest')
    input_mask = tf.cast(input_mask, tf.float32)
    input_image, input_mask = normalize(input_image, input_mask)
    return input_image, input_mask
# Load the Oxford-IIIT Pet dataset together with its metadata (split sizes).
dataset, info = tfds.load('oxford_iiit_pet:3.*.*', with_info=True)
TRAIN_LENGTH = info.splits['train'].num_examples
BATCH_SIZE = 64
BUFFER_SIZE = 1000
STEPS_PER_EPOCH = TRAIN_LENGTH // BATCH_SIZE

train = dataset['train'].map(load_image_train,
                             num_parallel_calls=tf.data.experimental.AUTOTUNE)
test = dataset['test'].map(load_image_test,
                           num_parallel_calls=tf.data.experimental.AUTOTUNE)

# cache() must come before shuffle(). Caching an already-shuffled stream
# records the first epoch's order and replays it unchanged on every epoch.
# NOTE(review): cache() still sits after the map() that applies random flips,
# so each image's flip is drawn once and then reused. Move augmentation after
# cache() if fresh augmentation per epoch is wanted.
train_dataset = (
    train.cache()
    .shuffle(BUFFER_SIZE)
    .batch(BATCH_SIZE)
    .repeat()
    .prefetch(buffer_size=tf.data.experimental.AUTOTUNE)
)
test_dataset = test.batch(BATCH_SIZE)
def display(display_list):
    """Plot the given images side by side in a single row.

    Panels are titled, in order, 'Input Image', 'True Mask' and
    'Predicted Mask'. Passing more than three items raises IndexError.
    """
    titles = ['Input Image', 'True Mask', 'Predicted Mask']
    n_panels = len(display_list)
    plt.figure(figsize=(15, 15))
    for idx, item in enumerate(display_list):
        plt.subplot(1, n_panels, idx + 1)
        plt.title(titles[idx])
        # array_to_img turns the tensor/array into a PIL image for imshow.
        plt.imshow(tf.keras.preprocessing.image.array_to_img(item))
        plt.axis('off')
    plt.show()
# Take a single (image, mask) pair from the un-batched training data and plot
# it as a visual sanity check of the preprocessing.
for image, mask in train.take(1):
    sample_image = image
    sample_mask = mask
display([sample_image, sample_mask])
def pool_block(cur_tensor: tf.Tensor,
               image_width: int,
               image_height: int,
               pooling_factor: int,
               activation: str = 'relu'
               ) -> tf.Tensor:
    """
    Build one branch of the pyramid pooling module.

    The incoming feature map is average-pooled with a window (and stride) of
    roughly ``size / pooling_factor`` along each axis. It is then projected to
    512 channels with a 1x1 convolution, batch-normalized and activated.
    Finally it is resized back to ``(image_height, image_width)`` so all
    branches can be concatenated.

    Parameters
    ----------
    cur_tensor : tf.Tensor
        Incoming feature map (assumed channels-last, the Keras default).
    image_width : int
        Width of ``cur_tensor`` and of the returned tensor.
    image_height : int
        Height of ``cur_tensor`` and of the returned tensor.
    pooling_factor : int
        Pooling factor used to scale the pooling window.
    activation : str, default = 'relu'
        Activation function to be applied after the pooling operations.

    Returns
    -------
    tf.Tensor
        Pooled branch with 512 channels, resized to
        ``(image_height, image_width)``.
    """
    # Keras pool sizes and strides are ordered (rows, cols) == (height, width).
    # max(1, ...) keeps the window valid if pooling_factor exceeds the size.
    strides = (max(1, int(np.round(float(image_height) / pooling_factor))),
               max(1, int(np.round(float(image_width) / pooling_factor))))
    x = AveragePooling2D(pool_size=strides, strides=strides,
                         padding='same')(cur_tensor)
    x = Conv2D(filters=512,
               kernel_size=(1, 1),
               padding='same')(x)
    x = BatchNormalization()(x)
    x = Activation(activation)(x)
    # Resize back to the feature-map size so the branches can be concatenated.
    x = tf.keras.layers.experimental.preprocessing.Resizing(
        image_height,
        image_width,
        interpolation="bilinear")(x)
    return x
def modify_ResNet_Dilation(
        model: tf.keras.Model
) -> tf.keras.Model:
    """
    Modify a Keras ResNet-50 to follow the dilated strategy of the PSPNet paper.

    In the first block of the conv4 and conv5 stages, the convolutions
    ``conv{4,5}_block1_{0..3}_conv`` are set to stride 1. They get dilation
    rate 2 (conv4) or 4 (conv5) instead.
    NOTE(review): only block1 of each stage is changed; the remaining blocks
    keep dilation rate 1. Confirm whether this matches the intended variant.

    Editing ``strides``/``dilation_rate`` on an already-built layer does not
    change the existing graph. The model is therefore rebuilt from its edited
    config, and the original weights are copied in. Kernel shapes do not
    depend on stride or dilation, so the weights fit unchanged. This avoids
    the previous round-trip through a hard-coded ``/tmp/my_model`` directory.
    Side effect: the layers of the passed-in ``model`` are modified in place.

    Parameters
    ----------
    model : tf.keras.Model
        The ResNet-50 model to be modified.

    Returns
    -------
    tf.keras.Model
        A new model with the modified convolutions and the original weights.
    """
    for i in range(4):
        for stage, rate in (('conv4', 2), ('conv5', 4)):
            layer = model.get_layer('{}_block1_{}_conv'.format(stage, i))
            layer.strides = 1
            layer.dilation_rate = rate
    new_model = tf.keras.Model.from_config(model.get_config())
    new_model.set_weights(model.get_weights())
    return new_model
def PSPNet(image_width: int,
           image_height: int,
           n_classes: int,
           kernel_size: tuple = (3, 3),
           activation: str = 'relu',
           weights: str = 'imagenet',
           shallow_tuning: bool = False,
           isICNet: bool = False
           ) -> tf.keras.Model:
    """
    Set up the PSPNet model on a dilated ResNet-50 backbone.

    Parameters
    ----------
    image_width : int
        Width of the image. Expected to be a multiple of 8.
    image_height : int
        Height of the image. Expected to be a multiple of 8.
    n_classes : int
        Number of classes.
    kernel_size : tuple, default = (3, 3)
        Currently unused by this function; kept for interface compatibility.
    activation : str, default = 'relu'
        Activation function to be applied after the pooling operations.
    weights : str, default = 'imagenet'
        Weights for the backbone: 'imagenet' for ImageNet weights or None for
        random initialization.
    shallow_tuning : bool, default = False
        If True, the backbone is frozen and only the head is trained. This
        requires pre-trained weights.
    isICNet : bool, default = False
        Whether the PSPNet will be part of an ICNet. In that case the input
        spatial size is left unspecified.

    Returns
    -------
    tf.keras.Model
        Model producing per-pixel class probabilities (softmax over the last
        axis).

    Raises
    ------
    ValueError
        If shallow tuning is requested without pre-trained weights.

    NOTE(review): Keras' ImageNet ResNet-50 weights expect
    ``resnet50.preprocess_input``-style inputs. Confirm that the [0, 1]
    scaling used for the data is intended.
    """
    if shallow_tuning and not weights:
        raise ValueError("Shallow tuning can not be performed without loading pre-trained weights. Please input 'imagenet' to argument weights...")
    # ICNet feeds inputs at several resolutions, so the spatial size is left open.
    if isICNet:
        input_shape = (None, None, 3)
    else:
        input_shape = (image_height, image_width, 3)

    # ResNet-50 backbone with dilated conv4/conv5 stages.
    backbone = ResNet50(include_top=False, weights=weights, input_shape=input_shape)
    backbone = modify_ResNet_Dilation(backbone)
    if shallow_tuning:
        backbone.trainable = False

    features = backbone.output
    # Feature-map size after the backbone. This assumes the dilation leaves an
    # overall output stride of 8, matching the 8x upsampling below.
    h = image_height // 8
    w = image_width // 8

    # Pyramid pooling module: the backbone features plus one pooled branch
    # per pooling factor.
    pyramid = [features]
    for factor in (1, 2, 3, 6):
        # Keyword arguments: pool_block's positional order is (width, height).
        pyramid.append(pool_block(features,
                                  image_width=w,
                                  image_height=h,
                                  pooling_factor=factor,
                                  activation=activation))

    x = Concatenate()(pyramid)
    x = Conv2D(filters=n_classes, kernel_size=(1, 1), padding='same')(x)
    x = UpSampling2D(size=(8, 8), data_format='channels_last', interpolation='bilinear')(x)
    # Softmax over the last (class) axis. This is equivalent to the former
    # Reshape -> softmax -> Reshape sequence, but does not hard-code the
    # spatial size, so the isICNet variable-size input also works.
    x = Activation(tf.nn.softmax)(x)
    final_model = tf.keras.Model(inputs=backbone.input, outputs=x)
    return final_model
# Build PSPNet for 128x128 RGB inputs with 3 output classes.
model = PSPNet(image_width=128, image_height=128, n_classes=3)
INFO:tensorflow:Assets written to: /tmp/my_model/assets
INFO:tensorflow:Assets written to: /tmp/my_model/assets
WARNING:tensorflow:No training configuration found in save file, so the model was *not* compiled. Compile it manually.
WARNING:tensorflow:No training configuration found in save file, so the model was *not* compiled. Compile it manually.
Implementation of PSPNet
William Anzén (Linkedin), Christian von Koch (Linkedin)
2021, Stockholm, Sweden
This project was supported by Combient Mix AB under the industrial supervision of Razesh Sainudiin and Max Fischer.
This notebook presents an implementation of PSPNet (Pyramid Scene Parsing Network), a scene-parsing architecture that analyses each image at several scales and combines the results into a final prediction. The architecture is evaluated on the Oxford-IIIT Pet Dataset. Material from the TensorFlow Image Segmentation tutorial has been reused for loading the dataset and displaying predictions.