%md
ScaDaMaLe Course [site](https://lamastex.github.io/scalable-data-science/sds/3/x/) and [book](https://lamastex.github.io/ScaDaMaLe/index.html)
This is a 2019-2021 augmentation and update of [Adam Breindel](https://www.linkedin.com/in/adbreind)'s initial notebooks.
_Thanks to [Christian von Koch](https://www.linkedin.com/in/christianvonkoch/) and [William Anzén](https://www.linkedin.com/in/william-anz%C3%A9n-b52003199/) for their contributions towards making these materials Spark 3.0.1 and Python 3+ compliant._
%md
# Generative Networks
### Concept:
If a set of network weights can convert an image of the numeral 8 or a cat
<br/>into the classification "8" or "cat" ...
### Does it contain enough information to do the reverse?
I.e., can we ask a network what "8" looks like and get a picture?
Let's think about this for a second. Clearly the classifications have far fewer bits of entropy than the source images' theoretical limit.
* Cat (in cat-vs-dog) has just 1 bit, where perhaps a 256x256 grayscale image has up to 512k bits.
* 8 (in MNIST) has \\(\log_2 10\\) or a little over 3 bits, where a 28x28 grayscale image has over 6000 bits.
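A quick check of those numbers, assuming 8-bit grayscale pixels: \\(\log_2 10 \approx 3.32\\) bits for a digit label versus \\(28 \times 28 \times 8 = 6272\\) bits for the image, and 1 bit for cat-vs-dog versus \\(256 \times 256 \times 8 = 524288\\) bits (512k) for the image.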
So at first, this would seem difficult or impossible.
__But__ ... let's do a thought experiment.
* Children can do this easily
* We could create a lookup table of, say, digit -> image trivially, and use that as a first approximation
Those approaches seem like cheating. But let's think about how they work.
If a child (or adult) draws a cat (or number 8), they are probably not drawing any specific cat (or 8). They are drawing a general approximation of a cat based on
1. All of the cats they've seen
2. What they remember as the key elements of a cat (4 legs, tail, pointy ears)

A lookup table, on the other hand, substitutes one specific cat or 8 ... and, especially in the case of the 8, that may be fine. The only thing we "lose" is the diversity of things that all mapped to cat (or 8) -- and discarding that information was exactly our goal when building a classifier.
The "8" is even simpler: we learn that a number is an idea, not a specific instance, so anything that another human recognizes as 8 is good enough. We are not even trying to make a particular shape, just one that represents our encoded information that distinguishes 8 from other possible symbols in the context and the eye of the viewer.
This should remind you a bit of the KL Divergence we talked about at the start: we are providing just enough information (entropy or surprise) to distinguish the cat or "8" from the other items that a human receiver might be expecting to see.
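As a refresher, for a true distribution \\(P\\) and the receiver's expected distribution \\(Q\\), the KL divergence is \\(D_{KL}(P \| Q) = \sum_x P(x) \log \frac{P(x)}{Q(x)}\\) -- the extra bits needed to encode samples from \\(P\\) using a code built for \\(Q\\).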
And where do our handwritten "pixels" come from? No image in particular -- they are totally synthetic based on a probability distribution.
*These considerations make it sound more likely that a computer could perform the same task.*
But what would the weights really represent? What could they generate?
__The weights we learn in classification represent the distinguishing features of a class, across all training examples of the class, modified to overlap minimally with features of other classes.__ KL divergence again.
To be concrete, if we trained a model on just a few dozen MNIST images using pixels, it would probably learn the 3 or 4 "magic" pixels that *happened* to distinguish the 8s in that dataset. Trying to generate from that information would yield strong confidence about those magic pixels, but would look like dots to us humans.
On the other hand, if we trained on a very large number of MNIST images -- say we use the convnet this time -- the model's weights should represent general filters or masks for features that distinguish an 8. And if we try to reverse the process by amplifying just those filters, we should get a blurry statistical distribution of those very features. The approximate shape of the Platonic "8"!
%md
## Mechanically, How Could This Work?
Let's start with a simpler model called an autoencoder.
An autoencoder's job is to take a large representation of a record and find weights that represent that record in a smaller encoding, subject to the constraint that the decoded version should match the original as closely as possible.
A bit like training a JPEG encoder to compress images by scoring it with the loss between the original image and the decompressed version of the lossy compressed image.
<img src="http://i.imgur.com/oTRvlB6.png" width=450>
One nice aspect of this is that it is *unsupervised* -- i.e., we do not need any ground-truth or human-generated labels in order to find the error and train. The error is always the difference between the output and the input, and the goal is to minimize this over many examples, and thus to minimize it in the general case.
We can do this with a simple multilayer perceptron network. Or, we can get fancier and do this with a convolutional network. In reverse, the convolution (typically called "transposed convolution" or "deconvolution") is an upsampling operation across space (in images) or space & time (in audio/video).
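%md
As a minimal sketch of that convolutional variant -- with illustrative, assumed layer sizes, not a model we train here -- the decoder half might use Keras's `Conv2DTranspose` to upsample a small code back into a 28x28 image. The model we actually train below is the simpler MLP version.
from keras.models import Sequential
from keras.layers import Dense, Reshape, Conv2DTranspose

# Hypothetical convolutional decoder: expand a 30-dim code to a 28x28 image
conv_decoder = Sequential()
conv_decoder.add(Dense(7 * 7 * 32, input_dim=30, activation='relu'))  # project the code to a small feature map
conv_decoder.add(Reshape((7, 7, 32)))
conv_decoder.add(Conv2DTranspose(16, kernel_size=3, strides=2, padding='same', activation='relu'))   # 7x7 -> 14x14
conv_decoder.add(Conv2DTranspose(1, kernel_size=3, strides=2, padding='same', activation='sigmoid')) # 14x14 -> 28x28
conv_decoder.summary()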
from keras.models import Sequential
from keras.layers import Dense
from keras.utils import to_categorical
import sklearn.datasets
import datetime
import matplotlib.pyplot as plt
import numpy as np
# Load the MNIST digits (LIBSVM format) from the Databricks datasets mount
train_libsvm = "/dbfs/databricks-datasets/mnist-digits/data-001/mnist-digits-train.txt"
test_libsvm = "/dbfs/databricks-datasets/mnist-digits/data-001/mnist-digits-test.txt"

X_train, y_train = sklearn.datasets.load_svmlight_file(train_libsvm, n_features=784)
X_train = X_train.toarray()  # densify the sparse matrix: one 784-pixel row per image

X_test, y_test = sklearn.datasets.load_svmlight_file(test_libsvm, n_features=784)
X_test = X_test.toarray()
model = Sequential()
model.add(Dense(30, input_dim=784, kernel_initializer='normal', activation='relu'))  # encoder: 784 pixels -> 30-unit bottleneck
model.add(Dense(784, kernel_initializer='normal', activation='relu'))                # decoder: 30 units -> 784 pixels

model.compile(loss='mean_squared_error', optimizer='adam', metrics=['mean_squared_error', 'binary_crossentropy'])
start = datetime.datetime.today()

# Note the target is X_train itself: the network learns to reproduce its input
history = model.fit(X_train, X_train, epochs=5, batch_size=100, validation_split=0.1, verbose=2)

scores = model.evaluate(X_test, X_test)

print()
for i in range(len(model.metrics_names)):
    print("%s: %f" % (model.metrics_names[i], scores[i]))

print("Start: " + str(start))
end = datetime.datetime.today()
print("End: " + str(end))
print("Elapsed: " + str(end - start))
fig, ax = plt.subplots()
fig.set_size_inches((4,4))
plt.plot(history.history['loss'])
plt.plot(history.history['val_loss'])
plt.title('model loss')
plt.ylabel('loss')
plt.xlabel('epoch')
plt.legend(['train', 'val'], loc='upper left')
display(fig)
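%md
The loss curve says reconstruction is improving, but it's more convincing to look. Here is a minimal visualization sketch (reusing `model`, `X_test`, and `plt` from the cells above) comparing a few test digits with their reconstructions.
# Reconstruct a few test digits: originals on the top row, reconstructions below
reconstructed = model.predict(X_test[:5])

fig, axes = plt.subplots(2, 5, figsize=(8, 3.5))
for i in range(5):
    axes[0, i].imshow(X_test[i].reshape(28, 28), cmap='gray')
    axes[0, i].axis('off')
    axes[1, i].imshow(reconstructed[i].reshape(28, 28), cmap='gray')
    axes[1, i].axis('off')
display(fig)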
%md
### Pretty cool. So we're all done now, right? Not quite...
The problem with the autoencoder is that it's "too good" at its task.
It is optimized to compress exactly the input record set, so it is trained only to re-create records it has seen. If the middle layer, or information bottleneck, is tight enough, the coded records use all of the information space in the middle layer.
So any value in the middle layer decodes to exactly one already-seen exemplar.
In our example, as in most autoencoders, there is more space in the middle layer than that, but the coded values are not distributed in any sensible way. So if we decode a random vector, we'll probably just get garbage.
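%md
We can see this directly. Here is a minimal sketch (reusing `model`, `np`, and `plt` from the cells above) that splits off the trained decoder layer and feeds it random codes. Since nothing constrained the bottleneck's geometry, and the sampling scale below is just a guess, the decodings will almost certainly look like garbage rather than digits.
from keras.models import Model
from keras.layers import Input

# Rebuild the decoder half: model.layers[1] is the trained Dense(784) layer
decoder_input = Input(shape=(30,))
decoder = Model(decoder_input, model.layers[1](decoder_input))

# Sample random 30-dim codes; the scale is a rough guess at the
# bottleneck's activation range, since we never constrained it
random_codes = np.random.uniform(0, 10, size=(3, 30))
fakes = decoder.predict(random_codes)

fig, axes = plt.subplots(1, 3, figsize=(6, 2.5))
for ax, fake in zip(axes, fakes):
    ax.imshow(fake.reshape(28, 28), cmap='gray')
    ax.axis('off')
display(fig)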