062_DLbyABr_06-GenerativeNetworks(Python)

SDS-2.x, Scalable Data Engineering Science

This is a 2019 augmentation and update of Adam Breindel's initial notebooks.

Generative Networks

Concept:

If a set of network weights can convert an image of the numeral 8 or a cat
into the classification "8" or "cat" ...

Does it contain enough information to do the reverse?

I.e., can we ask a network what "8" looks like and get a picture?

Let's think about this for a second. Clearly the classifications have far fewer bits of entropy than the source images' theoretical limit.

  • Cat (in cat-vs-dog) has just 1 bit, where perhaps a 256x256 grayscale image has up to 512k bits.
  • 8 (in MNIST) has $\log_2 10$, or a little over 3 bits, where a 28x28 grayscale image has over 6000 bits (a quick arithmetic check follows below).
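As a quick back-of-the-envelope check (assuming 8-bit grayscale pixels, which is where the "512k" and "6000" figures come from):

import math

print(math.log(2, 2))     # cat vs. dog: 1 bit
print(256 * 256 * 8)      # 256x256 8-bit grayscale image: 524,288 bits (~512k)
print(math.log(10, 2))    # one of 10 MNIST digits: ~3.32 bits
print(28 * 28 * 8)        # 28x28 8-bit grayscale image: 6,272 bits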

So at first, this would seem difficult or impossible.

But ... let's do a thought experiment.

  • Children can do this easily
  • We could create a lookup table of, say, digit -> image trivially, and use that as a first approximation

Those approaches seem like cheating. But let's think about how they work.

If a child (or adult) draws a cat (or the number 8), they are probably not drawing any specific cat (or 8). They are drawing a general approximation of a cat based on

  1. All of the cats they've seen
  2. What they remember as the key elements of a cat (4 legs, tail, pointy ears)

A lookup table, for its part, simply substitutes one specific cat or 8 ... and, especially in the case of the 8, that may be fine. The only thing we "lose" is the diversity of things that all mapped to cat (or 8) -- and discarding that information was exactly our goal when building a classifier.

The "8" is even simpler: we learn that a number is an idea, not a specific instance, so anything that another human recognizes as 8 is good enough. We are not even trying to make a particular shape, just one that represents our encoded information that distinguishes 8 from other possible symbols in the context and the eye of the viewer.

This should remind you a bit of the KL Divergence we talked about at the start: we are providing just enough information (entropy or surprise) to distinguish the cat or "8" from the other items that a human receiver might be expecting to see.

And where do our handwritten "pixels" come from? No image in particular -- they are totally synthetic based on a probability distribution.

These considerations make it sound more likely that a computer could perform the same task.

But what would the weights really represent? What could they generate?

The weights we learn in classification represent the distinguishing features of a class, across all training examples of the class, modified to overlap minimally with features of other classes. KL divergence again.

To be concrete, if we trained a model on just a few dozen MNIST images using pixels, it would probably learn the 3 or 4 "magic" pixels that happened to distinguish the 8s in that dataset. Trying to generate from that information would yield strong confidence about those magic pixels, but would look like dots to us humans.

On the other hand, if we trained on a very large number of MNIST images -- say we use the convnet this time -- the model's weights should represent general filters or masks for the features that distinguish an 8. And if we try to reverse the process by amplifying just those filters, we should get a blurry statistical distribution of those very features. The approximate shape of the Platonic "8"!

Mechanically, How Could This Work?

Let's start with a simpler model called an auto-encoder.

An autoencoder's job is to take a large representation of a record and find weights that represent that record in a smaller encoding, subject to the constraint that the decoded version should match the original as closely as possible.

A bit like training a JPEG encoder to compress images by scoring it with the loss between the original image and the decompressed version of the lossy compressed image.

One nice aspect of this is that it is unsupervised -- i.e., we do not need any ground-truth or human-generated labels in order to compute the error and train. The error is always the difference between the output and the input, and the goal is to minimize this over many examples, and thus to minimize it in the general case.

We can do this with a simple multilayer perceptron network. Or, we can get fancier and do this with a convolutional network. In reverse, the convolution (typically called "transposed convolution" or "deconvolution") is an upsampling operation across space (in images) or space & time (in audio/video).
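For illustration only -- the example below sticks to a plain MLP -- here is a minimal sketch of what a convolutional decoder might look like, assuming a 30-dimensional code and a 28x28 MNIST output; the layer sizes here are arbitrary choices for the sketch, not part of the original example:

from keras.models import Sequential
from keras.layers import Dense, Reshape, Conv2DTranspose

conv_decoder = Sequential()
conv_decoder.add(Dense(7 * 7 * 8, activation='relu', input_dim=30))  # expand the 30-dim code
conv_decoder.add(Reshape((7, 7, 8)))                                 # into a small 7x7 feature map
conv_decoder.add(Conv2DTranspose(16, (3, 3), strides=2, padding='same', activation='relu'))    # upsample 7x7 -> 14x14
conv_decoder.add(Conv2DTranspose(1, (3, 3), strides=2, padding='same', activation='sigmoid'))  # upsample 14x14 -> 28x28
conv_decoder.summary()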

from keras.models import Sequential
from keras.layers import Dense
from keras.utils import to_categorical
import sklearn.datasets
import datetime
import matplotlib.pyplot as plt
import numpy as np

train_libsvm = "/dbfs/databricks-datasets/mnist-digits/data-001/mnist-digits-train.txt"
test_libsvm = "/dbfs/databricks-datasets/mnist-digits/data-001/mnist-digits-test.txt"

X_train, y_train = sklearn.datasets.load_svmlight_file(train_libsvm, n_features=784)
X_train = X_train.toarray()
X_test, y_test = sklearn.datasets.load_svmlight_file(test_libsvm, n_features=784)
X_test = X_test.toarray()

# A simple MLP autoencoder: 784 pixels -> 30-unit bottleneck -> 784 pixels
model = Sequential()
model.add(Dense(30, input_dim=784, kernel_initializer='normal', activation='relu'))   # encoder
model.add(Dense(784, kernel_initializer='normal', activation='relu'))                 # decoder
model.compile(loss='mean_squared_error', optimizer='adam', metrics=['mean_squared_error', 'binary_crossentropy'])

start = datetime.datetime.today()

# Note the target is the input itself: we train the network to reconstruct what it is fed
history = model.fit(X_train, X_train, epochs=5, batch_size=100, validation_split=0.1, verbose=2)

scores = model.evaluate(X_test, X_test)

print("")
for i in range(len(model.metrics_names)):
    print("%s: %f" % (model.metrics_names[i], scores[i]))

print ("Start: " + str(start))
end = datetime.datetime.today()
print ("End: " + str(end))
print ("Elapse: " + str(end-start))

fig, ax = plt.subplots()
fig.set_size_inches((4,4))
plt.plot(history.history['loss'])
plt.plot(history.history['val_loss'])
plt.title('model loss')
plt.ylabel('loss')
plt.xlabel('epoch')
plt.legend(['train', 'val'], loc='upper left')
display(fig)
ax.clear()
fig.set_size_inches((5,5))
ax.imshow(np.reshape(X_test[61], (28,28)), cmap='gray')
display(fig)
encode_decode = model.predict(np.reshape(X_test[61], (1, 784)))
ax.clear()
ax.imshow(np.reshape(encode_decode, (28,28)), cmap='gray')
display(fig)

Any ideas about those black dots in the upper right?

Pretty cool. So we're all done now, right? Not quite...

The problem with the autoencoder is that it's "too good" at its task.

It is optimized to compress exactly the input record set, so it is trained only to create records it has seen. If the middle layer, or information bottleneck, is tight enough, the coded records use all of the information space in the middle layer.

So any value in the middle layer decodes to exactly one already-seen exemplar.

In our example, and in most autoencoders, there is more space in the middle layer, but the coded values are not distributed in any sensible way. So we can decode a random vector and we'll probably just get garbage.

# Decode a random 30-dim latent vector by feeding it directly to the decoder layer
v = np.random.randn(30)
v = np.array(np.reshape(v, (1, 30)))
import tensorflow as tf
t = tf.convert_to_tensor(v, dtype='float32')
out = model.layers[1](t)  # model.layers[1] is the 784-unit decoder Dense layer
from keras import backend as K

with K.get_session().as_default():
    output = out.eval()  # evaluate the symbolic tensor to a NumPy array
    ax.clear()
    ax.imshow(np.reshape(output, (28,28)), cmap='gray')
    display(fig)
    

The Goal is to Generate a Variety of New Output From a Variety of New Inputs

... Where the Class/Category is Common (i.e., all 8s or Cats)

Some considerations:

  • Is "generative content" something new? Or something true?

    • In a Platonic sense, maybe, but in reality it's literally a probabilistic guess based on the training data!
    • E.g., law enforcement photo enhancement
  • How do we train?

    • If we score directly against the training data (like in the autoencoder), the network will be very conservative, generating only examples that it has seen.
    • In extreme cases, it will always generate a single example (or a small number of examples), since those score well. This is known as mode collapse, since the network learns to locate the modes in the input distribution.

Two principal approaches / architectures (2015-)

Generative Adversarial Networks (GAN) and Variational Autoencoders (VAE)

Variational Autoencoder (VAE)

Our autoencoder was able to generate images, but the problem was that arbitrary input vectors don't map to anything meaningful. As discussed, this is partly by design -- training the plain autoencoder effectively amounts to compressing one specific input dataset.

What we would like is that if we start with a valid input vector and move a bit in some direction, we get a plausible output that has also changed in some corresponding way.


ASIDE: Manifold Hypothesis

The manifold hypothesis is that the interesting, relevant, or critical subspaces in the space of all vector inputs are actually low(er) dimensional manifolds. A manifold is a space where each point has a neighborhood that behaves like (is homeomorphic to) $\mathbb{R}^n$. So we would like to be able to move a small amount and have only a small amount of change, not a sudden discontinuous change.


The key feature of Variational Autoencoders is that we add a constraint on the encoded representation of our data: namely, that it follows a Gaussian distribution. Since a Gaussian is determined by its mean and variance (or standard deviation), we can model the code as a k-variate Gaussian with these two parameters ($\mu$ and $\sigma$) for each of the k dimensions.

One challenge is how to balance accurate reproduction of the input (traditional autoencoder loss) with the requirement that we match a Gaussian distribution. We can force the network to optimize both of these goals by creating a custom error function that sums up two components:

  • How well we match the input, calculated as binary crossentropy or MSE loss
  • How well we match a Gaussian, calculated as KL divergence from the Gaussian distribution (written out in closed form below)
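For a diagonal Gaussian encoder measured against a standard normal prior, the KL term has a simple closed form (this is exactly what the kl_loss line in the code below computes, with z_log_var playing the role of $\log \sigma_k^2$):

$$D_{KL}\left(\mathcal{N}(\mu, \sigma^2)\,\|\,\mathcal{N}(0, I)\right) = -\frac{1}{2}\sum_{k}\left(1 + \log \sigma_k^2 - \mu_k^2 - \sigma_k^2\right)$$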

We can easily implement a custom loss function and pass it to Keras when we compile the model.

The Keras source examples folder contains an elegant simple implementation, which we'll discuss below. It's a little more complex than the code we've seen so far, but we'll clarify the innovations:

  • A custom loss function that combines KL divergence and (reconstruction) cross-entropy loss
  • A custom "Lambda" layer that performs the sampling from the encoded distribution

Overall it's probably simpler than you might expect. Let's start it (since it takes a few minutes to train) and discuss the code:

import numpy as np
import matplotlib.pyplot as plt
from scipy.stats import norm

from keras.layers import Input, Dense, Lambda
from keras.models import Model
from keras import backend as K
from keras import objectives
from keras.datasets import mnist
import sklearn.datasets

batch_size = 100
original_dim = 784
latent_dim = 2
intermediate_dim = 256
nb_epoch = 50
epsilon_std = 1.0

# Encoder: map each input to the parameters (mean and log-variance) of a Gaussian in the latent space
x = Input(batch_shape=(batch_size, original_dim))
h = Dense(intermediate_dim, activation='relu')(x)
z_mean = Dense(latent_dim)(h)
z_log_var = Dense(latent_dim)(h)

def sampling(args):
    # The "reparameterization trick": sample z = mu + sigma * epsilon,
    # so gradients can flow back through z_mean and z_log_var
    z_mean, z_log_var = args
    epsilon = K.random_normal(shape=(batch_size, latent_dim), mean=0.,
                              stddev=epsilon_std)
    return z_mean + K.exp(z_log_var / 2) * epsilon

# note that "output_shape" isn't necessary with the TensorFlow backend
z = Lambda(sampling, output_shape=(latent_dim,))([z_mean, z_log_var])

# we instantiate these layers separately so as to reuse them later
decoder_h = Dense(intermediate_dim, activation='relu')
decoder_mean = Dense(original_dim, activation='sigmoid')
h_decoded = decoder_h(z)
x_decoded_mean = decoder_mean(h_decoded)

def vae_loss(x, x_decoded_mean):
    # reconstruction term (per-pixel binary cross-entropy) ...
    xent_loss = original_dim * objectives.binary_crossentropy(x, x_decoded_mean)
    # ... plus KL divergence of the encoded Gaussian from the unit Gaussian prior
    kl_loss = - 0.5 * K.sum(1 + z_log_var - K.square(z_mean) - K.exp(z_log_var), axis=-1)
    return xent_loss + kl_loss

vae = Model(x, x_decoded_mean)
vae.compile(optimizer='rmsprop', loss=vae_loss)

train_libsvm = "/dbfs/databricks-datasets/mnist-digits/data-001/mnist-digits-train.txt"
test_libsvm = "/dbfs/databricks-datasets/mnist-digits/data-001/mnist-digits-test.txt"

x_train, y_train = sklearn.datasets.load_svmlight_file(train_libsvm, n_features=784)
x_train = x_train.toarray()
x_test, y_test = sklearn.datasets.load_svmlight_file(test_libsvm, n_features=784)
x_test = x_test.toarray()

x_train = x_train.astype('float32') / 255.
x_test = x_test.astype('float32') / 255.
x_train = x_train.reshape((len(x_train), np.prod(x_train.shape[1:])))
x_test = x_test.reshape((len(x_test), np.prod(x_test.shape[1:])))

vae.fit(x_train, x_train,
        shuffle=True,
        epochs=nb_epoch,
        batch_size=batch_size,
        validation_data=(x_test, x_test), verbose=2)

# build a model to project inputs on the latent space
encoder = Model(x, z_mean)
Train on 60000 samples, validate on 10000 samples
Epoch 1/50 - 7s - loss: 190.9768 - val_loss: 173.1796
Epoch 2/50 - 7s - loss: 170.5004 - val_loss: 168.3274
...
Epoch 49/50 - 7s - loss: 149.5619 - val_loss: 151.6842
Epoch 50/50 - 7s - loss: 149.4900 - val_loss: 151.6131
# display a 2D plot of the digit classes in the latent space
x_test_encoded = encoder.predict(x_test, batch_size=batch_size)
fig, ax = plt.subplots()
fig.set_size_inches((8,7))
plt.scatter(x_test_encoded[:, 0], x_test_encoded[:, 1], c=y_test)
plt.colorbar()
display(fig)
# build a digit generator that can sample from the learned distribution
decoder_input = Input(shape=(latent_dim,))
_h_decoded = decoder_h(decoder_input)
_x_decoded_mean = decoder_mean(_h_decoded)
generator = Model(decoder_input, _x_decoded_mean)

# display a 2D manifold of the digits
n = 15  # figure with 15x15 digits
digit_size = 28
figure = np.zeros((digit_size * n, digit_size * n))
# linearly spaced coordinates on the unit square were transformed through the inverse CDF (ppf) of the Gaussian
# to produce values of the latent variables z, since the prior of the latent space is Gaussian
grid_x = norm.ppf(np.linspace(0.05, 0.95, n))
grid_y = norm.ppf(np.linspace(0.05, 0.95, n))

for i, yi in enumerate(grid_x):
    for j, xi in enumerate(grid_y):
        z_sample = np.array([[xi, yi]])
        x_decoded = generator.predict(z_sample)
        digit = x_decoded[0].reshape(digit_size, digit_size)
        figure[i * digit_size: (i + 1) * digit_size,
               j * digit_size: (j + 1) * digit_size] = digit

fig, ax = plt.subplots()
fig.set_size_inches((7,7))
ax.imshow(figure, cmap='Greys_r')
display(fig)

Note that it is blurry, and "manipulable" by moving through the latent space!
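As a small illustration of that manipulability, here is a sketch that interpolates between the latent codes of two test digits, reusing the encoder and generator models defined above (the test indices 0 and 1 are arbitrary picks, not part of the original example):

# Interpolate between the latent codes of two test digits and decode each intermediate point
codes = encoder.predict(x_test[:batch_size], batch_size=batch_size)
a, b = codes[0], codes[1]                  # 2-D latent codes of two arbitrary test digits

fig, axes = plt.subplots(1, 10)
fig.set_size_inches((10, 1))
for k, t in enumerate(np.linspace(0.0, 1.0, 10)):
    z_sample = (1 - t) * a + t * b         # take a linear step through the latent space
    digit = generator.predict(np.array([z_sample]))[0].reshape(28, 28)
    axes[k].imshow(digit, cmap='Greys_r')
    axes[k].axis('off')
display(fig)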


It is not intuitively obvious where the calculation of the KL divergence comes from, and in general there is no simple analytic way to derive the KL divergence for arbitrary distributions. Because we assume Gaussians here, this is a special case -- the derivation is included in the Auto-Encoding Variational Bayes paper (2014; https://arxiv.org/pdf/1312.6114.pdf).


Generative Adversarial Network (GAN)

The GAN, popularized recently by Ian Goodfellow's work, consists of two networks:

  1. Generator network (that initially generates output from noise)
  2. Discriminator network (trained with real and generated data, simply to distinguish 2 classes: real and fake)
    • The discriminator is also sometimes called the "A" or adversarial network

The basic approach to building a GAN is to train both networks in tandem according to the following simple procedure:

  1. Generate bogus output from "G"
  2. Train "D" with real and bogus data, labeled properly
  3. Train "G" to target the "real/true/1" label by
    • taking the "stacked" G + D model
    • feeding noise in at the start (G) end
    • and backpropagating from the real/true/1 distribution at the output (D) end

As always, there are lots of variants! But this is the core idea, as illustrated in the following code.
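Condensed into a single function, one round of training looks roughly like this (a sketch: the function name and arguments are mine, but the logic mirrors the train() loop in the full example below):

import numpy as np

def gan_training_step(generator, discriminator, gan, realBatch, batchSize=128, randomDim=100):
    # 1. Generate bogus output from G
    noise = np.random.normal(0, 1, size=[batchSize, randomDim])
    fake = generator.predict(noise)

    # 2. Train D on real + bogus data, labeled properly (0.9 = one-sided label smoothing)
    X = np.concatenate([realBatch, fake])
    y = np.zeros(2 * batchSize)
    y[:batchSize] = 0.9
    discriminator.trainable = True
    dloss = discriminator.train_on_batch(X, y)

    # 3. Train G through the (frozen) stacked G + D model, targeting the "real" label
    noise = np.random.normal(0, 1, size=[batchSize, randomDim])
    discriminator.trainable = False
    gloss = gan.train_on_batch(noise, np.ones(batchSize))
    return dloss, gloss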

Zackory Erickson's example is so elegant and clear that I've included it from https://github.com/Zackory/Keras-MNIST-GAN

Once again, we'll start it running first, since it takes a while to train.

import os
import numpy as np
import matplotlib.pyplot as plt

from keras.layers import Input
from keras.models import Model, Sequential
from keras.layers.core import Reshape, Dense, Dropout, Flatten
from keras.layers.advanced_activations import LeakyReLU
from keras.layers.convolutional import Convolution2D, UpSampling2D
from keras.layers.normalization import BatchNormalization
from keras.regularizers import l1, l1_l2
from keras.optimizers import Adam
from keras import backend as K
from keras import initializers
import sklearn.datasets

K.set_image_data_format('channels_last')

# Deterministic output.
# Tired of seeing the same results every time? Remove the line below.
np.random.seed(1000)

# The results are a little better when the dimensionality of the random vector is only 10.
# The dimensionality has been left at 100 for consistency with other GAN implementations.
randomDim = 100

train_libsvm = "/dbfs/databricks-datasets/mnist-digits/data-001/mnist-digits-train.txt"
test_libsvm = "/dbfs/databricks-datasets/mnist-digits/data-001/mnist-digits-test.txt"

X_train, y_train = sklearn.datasets.load_svmlight_file(train_libsvm, n_features=784)
X_train = X_train.toarray()
X_test, y_test = sklearn.datasets.load_svmlight_file(test_libsvm, n_features=784)
X_test = X_test.toarray()

X_train = (X_train.astype(np.float32) - 127.5)/127.5
X_train = X_train.reshape(60000, 784)

# Function for initializing network weights
def initNormal():
    return initializers.normal(stddev=0.02)

# Optimizer
adam = Adam(lr=0.0002, beta_1=0.5)

generator = Sequential()
generator.add(Dense(256, input_dim=randomDim, kernel_initializer=initializers.normal(stddev=0.02)))
generator.add(LeakyReLU(0.2))
generator.add(Dense(512))
generator.add(LeakyReLU(0.2))
generator.add(Dense(1024))
generator.add(LeakyReLU(0.2))
generator.add(Dense(784, activation='tanh'))
generator.compile(loss='binary_crossentropy', optimizer=adam)

discriminator = Sequential()
discriminator.add(Dense(1024, input_dim=784, kernel_initializer=initializers.normal(stddev=0.02)))
discriminator.add(LeakyReLU(0.2))
discriminator.add(Dropout(0.3))
discriminator.add(Dense(512))
discriminator.add(LeakyReLU(0.2))
discriminator.add(Dropout(0.3))
discriminator.add(Dense(256))
discriminator.add(LeakyReLU(0.2))
discriminator.add(Dropout(0.3))
discriminator.add(Dense(1, activation='sigmoid'))
discriminator.compile(loss='binary_crossentropy', optimizer=adam)

# Combined network
discriminator.trainable = False
ganInput = Input(shape=(randomDim,))
x = generator(ganInput)
ganOutput = discriminator(x)
gan = Model(inputs=ganInput, outputs=ganOutput)
gan.compile(loss='binary_crossentropy', optimizer=adam)

dLosses = []
gLosses = []

# Plot the loss from each batch
def plotLoss(epoch):
    plt.figure(figsize=(10, 8))
    plt.plot(dLosses, label='Discriminitive loss')
    plt.plot(gLosses, label='Generative loss')
    plt.xlabel('Epoch')
    plt.ylabel('Loss')
    plt.legend()
    plt.savefig('/dbfs/FileStore/gan_loss_epoch_%d.png' % epoch)

# Create a wall of generated MNIST images
def plotGeneratedImages(epoch, examples=100, dim=(10, 10), figsize=(10, 10)):
    noise = np.random.normal(0, 1, size=[examples, randomDim])
    generatedImages = generator.predict(noise)
    generatedImages = generatedImages.reshape(examples, 28, 28)

    plt.figure(figsize=figsize)
    for i in range(generatedImages.shape[0]):
        plt.subplot(dim[0], dim[1], i+1)
        plt.imshow(generatedImages[i], interpolation='nearest', cmap='gray_r')
        plt.axis('off')
    plt.tight_layout()
    plt.savefig('/dbfs/FileStore/gan_generated_image_epoch_%d.png' % epoch)

# Save the generator and discriminator networks (and weights) for later use
def saveModels(epoch):
    generator.save('/tmp/gan_generator_epoch_%d.h5' % epoch)
    discriminator.save('/tmp/gan_discriminator_epoch_%d.h5' % epoch)

def train(epochs=1, batchSize=128):
    batchCount = X_train.shape[0] // batchSize
    print('Epochs:', epochs)
    print('Batch size:', batchSize)
    print('Batches per epoch:', batchCount)

    for e in range(1, epochs+1):
        print('-'*15, 'Epoch %d' % e, '-'*15)
        for _ in range(batchCount):
            # Get a random set of input noise and images
            noise = np.random.normal(0, 1, size=[batchSize, randomDim])
            imageBatch = X_train[np.random.randint(0, X_train.shape[0], size=batchSize)]

            # Generate fake MNIST images
            generatedImages = generator.predict(noise)
            # print np.shape(imageBatch), np.shape(generatedImages)
            X = np.concatenate([imageBatch, generatedImages])

            # Labels for generated and real data
            yDis = np.zeros(2*batchSize)
            # One-sided label smoothing
            yDis[:batchSize] = 0.9

            # Train discriminator
            discriminator.trainable = True
            dloss = discriminator.train_on_batch(X, yDis)

            # Train generator
            noise = np.random.normal(0, 1, size=[batchSize, randomDim])
            yGen = np.ones(batchSize)
            discriminator.trainable = False
            gloss = gan.train_on_batch(noise, yGen)

        # Store loss of most recent batch from this epoch
        dLosses.append(dloss)
        gLosses.append(gloss)

        if e == 1 or e % 10 == 0:
            plotGeneratedImages(e)
            saveModels(e)

    # Plot losses from every epoch
    plotLoss(e)

train(10, 128)
Epochs: 10
Batch size: 128
Batches per epoch: 468
(per-epoch progress markers for Epochs 1 through 10 elided)

Sample generated digits: epoch 1

Sample generated digits: epoch 10

Generator/Discriminator Loss

Which Strategy to Use?

This is definitely an area of active research, so you'll want to experiment with both of these approaches.

GANs typically produce "sharper" pictures -- the adversarial loss does a better job of rewarding realistic detail than the combined MSE/cross-entropy + KL loss used in VAEs, but then again, that's partly by design.

VAEs are -- as seen above -- blurrier but more manipulable. One way of thinking about the multivariate Gaussian representation is that VAEs are trained to find some "meaning" in the variation along each dimension. And, in fact, with specific training it is possible to get them to associate specific meanings, like color, translation, rotation, etc., with those dimensions.

Where Next?

Keep an eye on Medium articles and others at sites like towardsdatascience:

Here is one that caught my eye in 2019: