ScaDaMaLe Course site and book

Implementation of ICNet

In this notebook, an implementation of ICNet is presented which is an architecture whitch uses a trade-off between complexity and inference time efficiently. The architecture is evaluated against the Oxford pets dataset.

Below, functions for data manipulation are defined, ensuring that the images inputted to the model is of appropriate format.

import tensorflow as tf
import tensorflow_datasets as tfds
import matplotlib.pyplot as plt
from tensorflow.keras.layers import *
from tensorflow.keras.models import *
import numpy as np

# Rescale the image intensities and shift the mask labels down by one
def normalize(input_image, input_mask):
  """Return the rescaled image and the shifted mask.

  The image is cast to float32 and divided by 255 (which maps 8-bit
  intensities into [0, 1]); 1 is subtracted from every mask value.
  """
  input_mask -= 1
  return tf.cast(input_image, tf.float32) / 255.0, input_mask

# Resize a training example to 128x128, randomly mirror it and normalize it
@tf.function
def load_image_train(datapoint):
  """Build one (image, mask) training pair from a dataset record.

  Reads datapoint['image'] and datapoint['segmentation_mask'], resizes both
  to 128x128, flips both left-right when a uniform random draw exceeds 0.5,
  normalizes them and rounds the mask values.
  """
  target_size = (128, 128)
  image = tf.image.resize(datapoint['image'], target_size)
  mask = tf.image.resize(datapoint['segmentation_mask'], target_size)

  # A single draw decides the flip so image and mask stay aligned.
  should_flip = tf.random.uniform(()) > 0.5
  if should_flip:
    image = tf.image.flip_left_right(image)
    mask = tf.image.flip_left_right(mask)

  image, mask = normalize(image, mask)
  return image, tf.math.round(mask)

# Function for resizing the test images to the desired input shape (no augmentation)
def load_image_test(datapoint):
  """Build one (image, mask) evaluation pair from a dataset record.

  Reads datapoint['image'] and datapoint['segmentation_mask'], resizes both
  to 128x128 and normalizes them. No random augmentation is applied.

  Returns:
    (input_image, input_mask) with the mask values rounded.
  """
  input_image = tf.image.resize(datapoint['image'], (128, 128))
  input_mask = tf.image.resize(datapoint['segmentation_mask'], (128, 128))

  input_image, input_mask = normalize(input_image, input_mask)
  # The default resize interpolates between neighbouring labels, so round
  # back to whole label values exactly as the training loaders do.
  input_mask = tf.math.round(input_mask)

  return input_image, input_mask

def load_image_train_noTf(datapoint):
  """Same preprocessing as load_image_train, without the tf.function decorator.

  Resizes image and mask to 128x128, flips both left-right when a uniform
  random draw exceeds 0.5, normalizes them and rounds the mask values.
  """
  image, mask = (
    tf.image.resize(datapoint[key], (128, 128))
    for key in ('image', 'segmentation_mask')
  )

  if tf.random.uniform(()) > 0.5:
    # Mirror both tensors so the mask keeps matching the image.
    image, mask = (tf.image.flip_left_right(t) for t in (image, mask))

  image, mask = normalize(image, mask)
  mask = tf.math.round(mask)
  return image, mask

# Functions for downscaling an image/mask pair by a factor of 16, 8 or 4,
# producing targets for the different outputs of the ICNet architecture
def resize_image16(img, mask):
  """Resize img and mask to (128//16, 128//16) and round the mask values."""
  small_size = (128//16, 128//16)
  small_img = tf.image.resize(img, small_size)
  small_mask = tf.math.round(tf.image.resize(mask, small_size))
  return small_img, small_mask

def resize_image8(img, mask):
  """Resize img and mask to (128//8, 128//8) and round the mask values."""
  small_size = (128//8, 128//8)
  small_img = tf.image.resize(img, small_size)
  small_mask = tf.math.round(tf.image.resize(mask, small_size))
  return small_img, small_mask

def resize_image4(img, mask):
  """Resize img and mask to (128//4, 128//4) and round the mask values."""
  small_size = (128//4, 128//4)
  small_img = tf.image.resize(img, small_size)
  small_mask = tf.math.round(tf.image.resize(mask, small_size))
  return small_img, small_mask

Here the data is loaded from Tensorflow datasets.

import tensorflow_datasets as tfds

# Batching / shuffling constants
BATCH_SIZE = 114
BUFFER_SIZE = 1000

# Load the Oxford-IIIT Pet splits together with the dataset metadata
dataset, info = tfds.load('oxford_iiit_pet:3.*.*', with_info=True)

# Number of full batches contained in the training split
TRAIN_LENGTH = info.splits['train'].num_examples
STEPS_PER_EPOCH = TRAIN_LENGTH // BATCH_SIZE

#train = dataset['train'].map(load_image_train, num_parallel_calls=tf.data.experimental.AUTOTUNE)
#train16 = dataset['train'].map(resize_image16, num_parallel_calls=tf.data.experimental.AUTOTUNE)
#train8 = dataset['train'].map(resize_image8, num_parallel_calls=tf.data.experimental.AUTOTUNE)
#train4 = dataset['train'].map(resize_image4, num_parallel_calls=tf.data.experimental.AUTOTUNE)
#test = dataset['test'].map(load_image_test)

#train_dataset = train.cache().shuffle(BUFFER_SIZE).batch(BATCH_SIZE).repeat()
#train_dataset = train_dataset.prefetch(buffer_size=tf.data.experimental.AUTOTUNE)
#test_dataset = test.batch(BATCH_SIZE)

# Materialize the training split into NumPy arrays: the full-resolution
# image and mask plus the masks downscaled by 16, 8 and 4.
# The array length comes from the dataset metadata rather than a hard-coded
# example count, so it follows the split that was actually loaded.
num_train = info.splits['train'].num_examples
train_orig = np.empty(shape=(num_train, 128, 128, 3))
train_orig_mask = np.empty(shape=(num_train, 128, 128, 1))
train16_mask = np.empty(shape=(num_train, 8, 8, 1))
train8_mask = np.empty(shape=(num_train, 16, 16, 1))
train4_mask = np.empty(shape=(num_train, 32, 32, 1))

for count, datapoint in enumerate(dataset['train']):
  img_orig, mask_orig = load_image_train_noTf(datapoint)
  train_orig[count] = img_orig
  train_orig_mask[count] = mask_orig
  # Only the masks of the downscaled versions are kept for training.
  _, mask = resize_image16(img_orig, mask_orig)
  train16_mask[count] = mask
  _, mask = resize_image8(img_orig, mask_orig)
  train8_mask[count] = mask
  _, mask = resize_image4(img_orig, mask_orig)
  train4_mask[count] = mask

# Materialize the test split into NumPy arrays: the full-resolution image
# and mask plus images and masks downscaled by 16, 8 and 4.
# The array length comes from the dataset metadata rather than a hard-coded
# example count.
num_test = info.splits['test'].num_examples
test_orig = np.empty(shape=(num_test, 128, 128, 3))
test_orig_mask = np.empty(shape=(num_test, 128, 128, 1))
test_orig_img = np.empty(shape=(num_test, 128, 128, 3))
test16_mask = np.empty(shape=(num_test, 8, 8, 1))
test16_img = np.empty(shape=(num_test, 8, 8, 3))
test8_mask = np.empty(shape=(num_test, 16, 16, 1))
test8_img = np.empty(shape=(num_test, 16, 16, 3))
test4_mask = np.empty(shape=(num_test, 32, 32, 1))
test4_img = np.empty(shape=(num_test, 32, 32, 3))

for count, datapoint in enumerate(dataset['test']):
  img_orig, mask_orig = load_image_test(datapoint)
  test_orig[count] = img_orig
  test_orig_mask[count] = mask_orig
  # Fill the full-resolution image array as well so it never holds
  # uninitialised memory.
  test_orig_img[count] = img_orig
  img, mask = resize_image16(img_orig, mask_orig)
  test16_mask[count] = mask
  test16_img[count] = img
  img, mask = resize_image8(img_orig, mask_orig)
  test8_mask[count] = mask
  test8_img[count] = img
  img, mask = resize_image4(img_orig, mask_orig)
  test4_mask[count] = mask
  test4_img[count] = img


def display(display_list):
  """Show the given images side by side in one row of a 15x15 figure.

  The panels are titled, in order, 'Input Image', 'True Mask' and
  'Predicted Mask'; each entry is converted with array_to_img before
  being drawn, and the axes are hidden.
  """
  titles = ['Input Image', 'True Mask', 'Predicted Mask']
  n_panels = len(display_list)

  plt.figure(figsize=(15, 15))
  for idx, item in enumerate(display_list):
    plt.subplot(1, n_panels, idx + 1)
    plt.title(titles[idx])
    plt.imshow(tf.keras.preprocessing.image.array_to_img(item))
    plt.axis('off')
  plt.show()

# Keep the first training example around as a fixed sample for visual checks
sample_image = train_orig[0]
sample_mask = train_orig_mask[0]
display([sample_image, sample_mask])
#for image, mask in train.take(1):
#  sample_image, sample_mask = image, mask
#display([sample_image, sample_mask])

#Keep shape with (Batch_SIZE, height,width, channels)
#in either np.array or try datasets.
#train = tfds.as_numpy(dataset['train']['image'])
#train16 = tfds.as_numpy(train16)
#train8 = tfds.as_numpy(train8)
#train4 = tfds.as_numpy(train4)
#truth16 = np.concatenate([y for x, y in train16], axis=0)
#truth8 = np.concatenate([y for x, y in train8], axis=0)
#truth4 = np.concatenate([y for x, y in train4], axis=0)

# tf.data pipeline over the test split: preprocess each record, then batch.
test_dataset = dataset['test'].map(load_image_test)
test_dataset = test_dataset.batch(BATCH_SIZE)
print(test8_mask)
# 3648 = 114 * 32: keep a prefix that divides evenly and cut every array
# into 114 equally sized chunks. np.split returns a Python list of arrays.
# NOTE(review): after this cell the names below are bound to lists, while
# batch_generator indexes its inputs with a list of indices - confirm which
# representation the training cells are meant to receive.
train_orig = np.split(train_orig[0:3648], 114)
print("Finshed 1")
train16_mask = np.split(train16_mask[0:3648], 114)
train8_mask = np.split(train8_mask[0:3648], 114)
train4_mask = np.split(train4_mask[0:3648], 114)

test_orig = np.split(test_orig[0:3648], 114)
print("Finshed 1")
test16_mask = np.split(test16_mask[0:3648], 114)
test8_mask = np.split(test8_mask[0:3648], 114)
test4_mask = np.split(test4_mask[0:3648], 114)
test16_img = np.split(test16_img[0:3648], 114)
test8_img = np.split(test8_img[0:3648], 114)
test4_img = np.split(test4_img[0:3648], 114)
print(test16_mask)
def pool_block(cur_tensor,
               image_width,
               image_height,
               pooling_factor,
               activation):
  """Pyramid-pooling branch: average-pool, 1x1 conv, BN, activation, resize.

  Args:
    cur_tensor: feature map to pool.
    image_width: width of the feature map; also the width of the output.
    image_height: height of the feature map; also the height of the output.
    pooling_factor: divisor applied to the width and height to obtain the
      pooling window (and stride) in each direction.
    activation: activation applied after batch normalization.

  Returns:
    A 128-channel tensor resized to (image_height, image_width).
  """
  # Keras pooling layers take sizes as (rows, cols), i.e. (height, width).
  # Clamp to 1 so a small feature map cannot produce a zero-sized window.
  pool_h = max(1, int(np.round(float(image_height)/pooling_factor)))
  pool_w = max(1, int(np.round(float(image_width)/pooling_factor)))
  pooling_size = (pool_h, pool_w)
  x = AveragePooling2D(pooling_size, strides=pooling_size, padding='same')(cur_tensor)
  x = Conv2D(128,(1,1),padding='same')(x)
  x = BatchNormalization()(x)
  x = Activation(activation)(x)
  x = tf.keras.layers.experimental.preprocessing.Resizing(
    image_height, image_width, interpolation="bilinear")(x) # Resizing images to correct shape for future concat
  return x

import tensorflow as tf
from tensorflow.keras.applications.resnet50 import ResNet50

# Function for converting the ResNet model to a modified one which uses dilation instead of strides in the final stages
def modify_ResNet_Dilation(model):
  """Return a rebuilt copy of `model` with dilated conv4/conv5 entry blocks.

  For the four convolutions `conv{4,5}_block1_{0..3}_conv` the stride is set
  to 1 and the dilation rate to 2 (conv4) or 4 (conv5). These attributes are
  changed on the layers of the model that is passed in; the model is then
  rebuilt from its JSON configuration so the new settings take effect.

  Rebuilding from JSON restores only the architecture, so the weights of the
  input model are copied over explicitly. Stride and dilation do not change
  any kernel shape, hence the weight lists line up one to one.
  """
  for i in range(0,4):
    for stage, rate in (('conv4', 2), ('conv5', 4)):
      layer = model.get_layer('{}_block1_{}_conv'.format(stage, i))
      layer.strides = 1
      layer.dilation_rate = rate
  new_model = model_from_json(model.to_json())
  # Without this the pretrained weights would be replaced by a fresh
  # random initialization.
  new_model.set_weights(model.get_weights())
  return new_model

def PSPNet(num_classes: int,
           n_filters: int,
           kernel_size: tuple,
           activation: str,
           image_width: int,
           image_height: int,
           isICNet: bool = False
          ):
  """Build a PSPNet-style model on a frozen, dilated ResNet50 encoder.

  Args:
    num_classes: number of channels of the final 1x1 classification conv.
    n_filters: accepted for signature compatibility; not used in the body.
    kernel_size: accepted for signature compatibility; not used in the body.
    activation: activation used inside the pyramid-pooling branches.
    image_width: input image width.
    image_height: input image height.
    isICNet: when True the encoder is built with input shape
      (None, None, 3) instead of (image_height, image_width, 3).

  Returns:
    A Keras model mapping the encoder input to the x8-upsampled class map.
  """
  if isICNet:
    input_shape=(None, None, 3)
  else:
    input_shape=(image_height,image_width,3)
  encoder=ResNet50(include_top=False, weights='imagenet', input_shape=input_shape)
  encoder=modify_ResNet_Dilation(encoder)
  encoder.trainable=False
  resnet_output=encoder.output
  pooling_layer=[]
  pooling_layer.append(resnet_output)
  output=Dropout(rate=0.5)(resnet_output)
  h = image_height//8
  w = image_width//8
  for i in [1,2,3,6]:
    # Pass the sizes by keyword: pool_block declares width before height,
    # so a positional (h, w) call would swap them for non-square inputs.
    pool = pool_block(output, image_width=w, image_height=h,
                      pooling_factor=i, activation=activation)
    pooling_layer.append(pool)
  concat=Concatenate()(pooling_layer)
  output_layer=Conv2D(filters=num_classes, kernel_size=(1,1), padding='same')(concat)
  final_layer=UpSampling2D(size=(8,8), data_format='channels_last', interpolation='bilinear')(output_layer)
  final_model=tf.keras.models.Model(inputs=encoder.input, outputs=final_layer)
  return final_model

# Instantiate a stand-alone PSPNet for 128x128 inputs and 3 classes
PSP = PSPNet(num_classes=3,
             n_filters=16,
             kernel_size=(3,3),
             activation='relu',
             image_width=128,
             image_height=128)
PSP.summary()
def PSP_rest(input_prev: tf.Tensor):
  """Apply the remaining dilated residual stages to `input_prev`.

  Stage 'C4' consists of six bottleneck blocks (256/256/1024 filters,
  dilation rate 2) and stage 'C5' of three (512/512/2048 filters, dilation
  rate 4). The first block of each stage also sends the shortcut through a
  1x1 convolution and batch normalization; the others add the unchanged
  block input. Layers are named '<stage>_block<b>_<kind><k>' and
  '<stage>_skip<b>'.

  Returns:
    The output tensor of the last 'C5' block.
  """

  def bottleneck(shortcut, stage, block, filters, dilation, project):
    """One residual bottleneck; `project` adds the conv/bn shortcut path."""
    prefix = '{}_block{}'.format(stage, block)

    def conv(n_out, kernel, idx):
      return Conv2D(n_out, kernel, dilation_rate=dilation, padding='same',
                    name='{}_conv{}'.format(prefix, idx))

    def bn(idx):
      return BatchNormalization(name='{}_bn{}'.format(prefix, idx))

    def relu(idx):
      return Activation('relu', name='{}_act{}'.format(prefix, idx))

    y = conv(filters, 1, 1)(shortcut)
    y = bn(1)(y)
    y = relu(1)(y)
    y = conv(filters, 3, 2)(y)
    y = bn(2)(y)
    y = relu(2)(y)
    if project:
      # Shortcut and main path are interleaved so the layers are created
      # in the order: shortcut conv, main conv, shortcut bn, main bn.
      shortcut = conv(4*filters, 1, 0)(shortcut)
      y = conv(4*filters, 1, 3)(y)
      shortcut = bn(0)(shortcut)
      y = bn(3)(y)
    else:
      y = conv(4*filters, 1, 3)(y)
      y = bn(3)(y)
    y = Add(name='{}_skip{}'.format(stage, block))([shortcut, y])
    return relu(3)(y)

  features = input_prev
  # (stage name, number of blocks, bottleneck filters, dilation rate)
  for stage, n_blocks, filters, dilation in (('C4', 6, 256, 2), ('C5', 3, 512, 4)):
    for block in range(1, n_blocks + 1):
      features = bottleneck(features, stage, block, filters, dilation,
                            project=(block == 1))

  return(features)

Method for performing cascade feature fusion. See https://arxiv.org/pdf/1704.08545.pdf

def CFF(stage: int, F_small, F_large, n_classes: int, input_width_small: int, input_height_small: int):
  """Cascade feature fusion of a low-resolution and a high-resolution branch.

  `F_small` is resized to twice the given small size. From the upsampled
  map an auxiliary softmax classifier output is produced, and a dilated 3x3
  convolution plus batch normalization is summed with a 1x1 convolution
  plus batch normalization of `F_large`, followed by ReLU. Every layer name
  ends in '_<stage>'.

  Returns:
    (F_aux, fused): the auxiliary class map and the fused feature map.
  """
  def tagged(base):
    return "{}_{}".format(base, stage)

  # NOTE(review): Resizing takes (height, width) while the width-derived size
  # is passed first here; identical for square inputs - verify before using
  # non-square images.
  upsample = tf.keras.layers.experimental.preprocessing.Resizing(
    int(input_width_small*2), int(input_height_small*2),
    interpolation="bilinear", name=tagged("Upsample_x2_small"))
  F_up = upsample(F_small)
  F_aux = Conv2D(n_classes, 1, name=tagged("ClassifierConv"), activation='softmax')(F_up)
  #y = ZeroPadding2D(padding=2, name='padding17')(F_up) ?? is this needed?
  small_branch = Conv2D(128, 3, dilation_rate=2, padding='same',
                        name=tagged("intermediate_f_small"))(F_up)
  print(small_branch)
  small_branch_bn = BatchNormalization(name=tagged("intermediate_f_small_bn"))(small_branch)
  large_branch = Conv2D(128, 1, padding='same',
                        name=tagged("intermediate_f_large"))(F_large)
  large_branch_bn = BatchNormalization(name=tagged("intermediate_f_large_bn"))(large_branch)
  fused = Add(name=tagged("add_intermediates"))([small_branch_bn, large_branch_bn])
  fused = Activation('relu', name=tagged("activation_CFF"))(fused)
  return F_aux, fused

#%sh sudo apt-get install -y graphviz
def ICNet_1(input_obj: tf.keras.Input,
           n_filters: int,
           kernel_size: tuple,
           activation: str):
  """Full-resolution branch: three stride-2 conv / batch-norm / activation stages.

  Stage i (1..3) uses n_filters*2*i filters. Dropout with rate 0.5 is
  applied before the second and third stage, i.e. only on hidden units
  and never on the input itself.
  """
  features = input_obj
  for stage in (1, 2, 3):
    if stage > 1:
      features = Dropout(rate=0.5)(features)
    features = Conv2D(filters=n_filters*2*stage, kernel_size=kernel_size,
                      strides=(2,2), padding='same')(features)
    features = BatchNormalization()(features)
    features = Activation(activation)(features)
  return features

def ICNet(image_width: int,
         image_height: int,
         n_classes: int,
         n_filters: int = 16,
         kernel_size: tuple = (3,3),
         activation: str = 'relu'):
  """Assemble the three-branch ICNet model.

  The input is processed at full, 1/2 and 1/4 size. The two downscaled
  copies share a ResNet trunk cut at 'conv4_block3_out'; the 1/4 branch
  continues through PSP_rest. The branches are merged by two CFF units and
  the result is upsampled x2 and classified with a softmax 1x1 convolution.

  Returns:
    A Keras model with outputs [aux output of CFF 1, aux output of CFF 2,
    final prediction].
  """
  Resizing = tf.keras.layers.experimental.preprocessing.Resizing

  # NOTE(review): the width is used as the first spatial dimension throughout
  # this function; equivalent for square inputs - verify for non-square ones.
  full_res = tf.keras.Input(shape=[image_width, image_height, 3], name="input_img_1")
  quarter_res = Resizing(image_width//4, image_height//4,
                         interpolation="bilinear", name="input_img_4")(full_res)
  half_res = Resizing(image_width//2, image_height//2,
                      interpolation="bilinear", name="input_img_2")(full_res)

  branch_full = ICNet_1(full_res, n_filters, kernel_size, activation)

  # Shared trunk for the 1/2 and 1/4 branches.
  psp = PSPNet(n_classes, n_filters, kernel_size, activation,
               image_width//4, image_height//4, True)
  trunk = tf.keras.models.Model(inputs=psp.input,
                                outputs=psp.get_layer('conv4_block3_out').output,
                                name="JointResNet_2_4")
  branch_quarter = trunk(quarter_res)
  branch_half = trunk(half_res)
  branch_quarter = PSP_rest(branch_quarter)

  out1, fused = CFF(1, branch_quarter, branch_half, n_classes,
                    image_width//32, image_height//32)
  out2, fused = CFF(2, fused, branch_full, n_classes,
                    image_width//16, image_height//16)

  upsampled = UpSampling2D(2, interpolation='bilinear',
                           name="Upsampling_final_prediction")(fused)
  prediction = Conv2D(n_classes, 1, name="ClassifierConv_final_prediction",
                      activation='softmax')(upsampled)
  return tf.keras.models.Model(inputs=full_res, outputs=[out1, out2, prediction])

# Build the ICNet for 128x128 inputs and 3 classes
model = ICNet(image_width=128, image_height=128, n_classes=3)
model.summary()
#final_model=tf.keras.models.Model(inputs=input_obj ,outputs=model)
#final_model.summary()
from IPython.display import display as Display, Image
import matplotlib.pyplot as plt
import matplotlib.image as mpimg

# Render the model graph to an image file, then load and show it inline
tf.keras.utils.plot_model(
  model,
  to_file='/dbfs/FileStore/my_model.jpg',
  show_shapes=True,
)
img = mpimg.imread('/dbfs/FileStore/my_model.jpg')
plt.figure(figsize=(200,200))
imgplot = plt.imshow(img)

ls /dbfs/FileStore
#import datetime


# One sparse categorical cross-entropy per output, weighted 0.1 / 0.3 / 0.6
model.compile(
  optimizer='adam',
  loss=tf.keras.losses.SparseCategoricalCrossentropy(),
  loss_weights=[0.1,0.3,0.6],
  metrics=tf.keras.metrics.SparseCategoricalAccuracy(),
)

#log_dir = "logs/fit/" + datetime.datetime.now().strftime("%Y%m%d-%H%M%S")
#tensorboard_callback = tf.keras.callbacks.TensorBoard(log_dir=log_dir, histogram_freq=1)

def create_mask(pred_mask):
  """Return the arg-max class map of the first element of a batch.

  The arg-max is taken over the last axis and a trailing axis of size 1 is
  appended before the first batch entry is selected.
  """
  class_ids = tf.argmax(pred_mask, axis=-1)
  return tf.expand_dims(class_ids, axis=-1)[0]

def show_predictions(dataset=None, num=1):
  """Display predictions of the global `model` next to image and true mask.

  Without a dataset, the global `sample_image` / `sample_mask` pair is used.
  Otherwise `num` batches are taken from `dataset`, each image batch is
  printed and the first element of every batch is shown. In both cases the
  third model output is the one visualized.
  """
  if not dataset:
    sample_pred = model.predict(sample_image[tf.newaxis, ...])[2]
    display([sample_image, sample_mask, create_mask(sample_pred)])
    return

  for image, mask in dataset.take(num):
    print(image)
    pred_mask = model.predict(image)[2]
    display([image[0], mask[0], create_mask(pred_mask)])

# Show the prediction for the stored sample image using the current model
show_predictions()

ls
import mlflow.tensorflow
# Turn on MLflow's automatic logging for TensorFlow/Keras
mlflow.tensorflow.autolog()
def batch_generator(X, Y16, Y8, Y4, batch_size = BATCH_SIZE):
  """Endlessly yield shuffled training batches for the three model outputs.

  The indices of X are reshuffled at the start of every pass. Indices are
  collected until `batch_size` of them are available; a partially filled
  batch at the end of a pass is carried over into the next one.

  Yields:
    (X[batch], targets) where `targets` maps the output layer names
    'ClassifierConv_1', 'ClassifierConv_2' and
    'ClassifierConv_final_prediction' to Y16[batch], Y8[batch], Y4[batch].
  """
  order = np.arange(len(X))
  pending = []
  while True:
    np.random.shuffle(order)
    for sample_idx in order:
      pending.append(sample_idx)
      if len(pending) != batch_size:
        continue
      inputs = X[pending]
      targets = {
        'ClassifierConv_1': Y16[pending],
        'ClassifierConv_2': Y8[pending],
        'ClassifierConv_final_prediction': Y4[pending],
      }
      yield inputs, targets
      pending = []

def batch_generator_eval(X, Y16, Y8, Y4, batch_size = BATCH_SIZE):
  """Endlessly yield evaluation batches in index order (no shuffling).

  Indices are collected until `batch_size` of them are available; a
  partially filled batch at the end of a pass is carried over into the
  next one.

  Yields:
    (X[batch], targets) where `targets` maps the output layer names
    'ClassifierConv_1', 'ClassifierConv_2' and
    'ClassifierConv_final_prediction' to Y16[batch], Y8[batch], Y4[batch].
  """
  order = np.arange(len(X))
  pending = []
  while True:
    for sample_idx in order:
      pending.append(sample_idx)
      if len(pending) != batch_size:
        continue
      inputs = X[pending]
      targets = {
        'ClassifierConv_1': Y16[pending],
        'ClassifierConv_2': Y8[pending],
        'ClassifierConv_final_prediction': Y4[pending],
      }
      yield inputs, targets
      pending = []

class DisplayCallback(tf.keras.callbacks.Callback):
  """Keras callback that displays the sample prediction after each epoch."""
  def on_epoch_end(self, epoch, logs=None):
    """Call show_predictions(); `epoch` and `logs` are not used."""
    show_predictions()

# One evaluate() result is appended here per finished epoch
res_eval_1 = []
class MyCustomCallback(tf.keras.callbacks.Callback):
    """Evaluate on the test arrays and show the sample prediction each epoch."""
    def on_epoch_end(self, epoch, logs=None):
        """Append the evaluation result to `res_eval_1`, then display."""
        targets = [test16_mask, test8_mask, test4_mask]
        epoch_result = self.model.evaluate(test_orig, targets, batch_size=45, verbose=1)
        res_eval_1.append(epoch_result)
        show_predictions()

# Training configuration
BATCH_SIZE = 64
BUFFER_SIZE = 1000
EPOCHS = 10
VAL_SUBSPLITS = 5

TRAIN_LENGTH = info.splits['train'].num_examples
STEPS_PER_EPOCH = TRAIN_LENGTH//BATCH_SIZE
VALIDATION_STEPS = info.splits['test'].num_examples//BATCH_SIZE

# Batch generators over the in-memory arrays
train_generator = batch_generator(
  train_orig, train16_mask, train8_mask, train4_mask, batch_size=BATCH_SIZE)
eval_generator = batch_generator_eval(
  test_orig, test16_mask, test8_mask, test4_mask, batch_size=BATCH_SIZE)

# Train; MyCustomCallback evaluates on the test arrays after every epoch
model_history = model.fit(
  train_generator,
  epochs=EPOCHS,
  steps_per_epoch=STEPS_PER_EPOCH,
  callbacks=[MyCustomCallback()],
  verbose=1,
)
#model_history = model.fit(x=train_orig, y=[train16_mask, train8_mask, train4_mask],
#                          epochs=EPOCHS,
#                          steps_per_epoch=STEPS_PER_EPOCH,
#                          callbacks=[MyCustomCallback()], verbose=1)






print(res_eval_1)
loss = model_history.history['loss']
acc = model_history.history['ClassifierConv_final_prediction_sparse_categorical_accuracy']

# Regroup the per-epoch evaluation results into one list per quantity.
# Presumably index 0 is the total loss, 1-3 the per-output losses and 4-6
# the per-output accuracies - confirm against the evaluate() output order.
epoch_results = [res_eval_1[i] for i in range(EPOCHS)]
val_loss = [result[0] for result in epoch_results]
val_loss1 = [result[1] for result in epoch_results]
val_loss2 = [result[2] for result in epoch_results]
val_loss3 = [result[3] for result in epoch_results]
val_acc1 = [result[4] for result in epoch_results]
val_acc2 = [result[5] for result in epoch_results]
val_acc = [result[6] for result in epoch_results]

epochs = range(EPOCHS)

plt.figure(figsize=(20,3))

# Panel 1: total training loss against total validation loss
plt.subplot(1,4,1)
for series, fmt, label in ((loss, 'r', 'Training loss'),
                           (val_loss, 'bo', 'Validation loss')):
  plt.plot(epochs, series, fmt, label=label)
plt.title('Training and Validation Loss')
plt.xlabel('Epoch')
plt.ylabel('Loss Value')
plt.legend()

# Panel 2: accuracy of the final prediction output
plt.subplot(1,4,2)
for series, fmt, label in ((acc, 'r', "Training accuracy"),
                           (val_acc, 'bo', "Validation accuracy")):
  plt.plot(epochs, series, fmt, label=label)
plt.xlabel('Epoch')
plt.ylabel('Accuracy')
plt.legend()

# Panel 3: validation loss of each output
plt.subplot(1,4,3)
for series, fmt, label in ((val_loss1, 'b', "Loss output 1"),
                           (val_loss2, 'g', "Loss output 2"),
                           (val_loss3, 'y', "Loss output 3")):
  plt.plot(epochs, series, fmt, label=label)
plt.legend()

# Panel 4: validation accuracy of each output
plt.subplot(1,4,4)
for series, fmt, label in ((val_acc1, 'b', "Acc output 1"),
                           (val_acc2, 'g', "Acc output 2"),
                           (val_acc, 'y', "Acc output 3")):
  plt.plot(epochs, series, fmt, label=label)
plt.legend()
plt.show()