worksheet09_solutions
COMP90051 Workshop 9
Variational autoencoders (VAEs)
In this worksheet, we’ll implement a variational autoencoder (VAE) as introduced by Kingma & Welling (2013).
We’ll use an independent Bernoulli likelihood for the data $\mathbf{x}$, with a spherical Gaussian prior on the latent variable $\mathbf{z}$.
We’ll adopt a convolutional architecture for the encoder/decoder neural nets, which is appropriate for image data.
In [1]:
import numpy as np
import matplotlib.pyplot as plt
plt.rcParams['figure.dpi'] = 108
import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import layers
from packaging import version
print("TensorFlow version: ", tf.__version__)
assert version.parse(tf.__version__) >= version.parse("2.3"), \
    "This notebook requires TensorFlow 2.3 or above."
TensorFlow version: 2.6.0
1. Binarised MNIST data
We’ll reuse the MNIST data introduced in the last worksheet.
However, we need to binarise the images, since our generative model makes the simplifying assumption that each pixel is binary (black/white) rather than 8-bit greyscale.
In [2]:
(x_train, y_train), (x_test, y_test) = keras.datasets.mnist.load_data()
# Binarise images
x_train = (x_train >= 128).astype('float32')
x_train = np.expand_dims(x_train, -1)
x_test = (x_test >= 128).astype('float32')
x_test = np.expand_dims(x_test, -1)
Below we illustrate a binary image of a “5”.
In [3]:
plt.imshow(x_train[0, :, :, 0], cmap='Greys_r')
plt.axis('off')
plt.show()
2. Specification of the VAE
Following the original VAE paper (Kingma & Welling, 2013), we assume a latent variable model $p_\theta(\mathbf{x}, \mathbf{z})$ where:
$p_\theta(\mathbf{x}|\mathbf{z}) = \prod_{i = 1}^{d} \pi_i^{x_i} (1 - \pi_i)^{1 - x_i}$ (an independent Bernoulli likelihood for each pixel; a small numerical example is given after this list)
$\boldsymbol{\pi} = (\pi_1, \ldots, \pi_d) = \operatorname{DecoderNN}_\theta(\mathbf{z})$ (Bernoulli probabilities are parameterised by a decoder neural network)
$p_\theta(\mathbf{z}) = \mathcal{N}(\mathbf{z}; 0, \mathbf{I})$ (a spherical Gaussian prior on $\mathbf{z}$)
The posterior approximation $q_\phi(\mathbf{z}|\mathbf{x})$ is taken to be a factorised Gaussian:
$q_\phi(\mathbf{z}|\mathbf{x}) = \mathcal{N}(\mathbf{z}; \boldsymbol{\mu}, \operatorname{diag}(\boldsymbol{\sigma}^2))$
$(\boldsymbol{\mu}, \log \boldsymbol{\sigma}^2) = \operatorname{EncoderNN}_\phi(\mathbf{x})$ (mean and covariance are parameterised by an encoder neural network)
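To make the likelihood term concrete, here is a tiny worked example (not from the original worksheet): suppose $d = 3$, the decoder outputs $\boldsymbol{\pi} = (0.9, 0.2, 0.5)$, and the observed image is $\mathbf{x} = (1, 0, 1)$. Then
$$
p_\theta(\mathbf{x}|\mathbf{z}) = 0.9^{1}(1 - 0.9)^{0} \cdot 0.2^{0}(1 - 0.2)^{1} \cdot 0.5^{1}(1 - 0.5)^{0} = 0.9 \times 0.8 \times 0.5 = 0.36 .
$$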
3. Encoder and decoder neural networks
In this section, we’ll implement the encoder and decoder neural networks, which are denoted by $\operatorname{EncoderNN}_\phi$ and $\operatorname{DecoderNN}_\theta$ above.
We’ll use a convolutional architecture, since we’re working with images.
For the encoder network, we’ll use two convolutional layers, followed by separate fully-connected layers for the mean ($\boldsymbol{\mu}$) and variance ($\log \boldsymbol{\sigma}^2$) parameters.
Since the network has more than one output ($\boldsymbol{\mu}$ and $\log \boldsymbol{\sigma}^2$), we can’t use the basic Sequential API in Keras.
Instead, we’ll use the more general Functional API.
With the Functional API, we start with Input layers, chain layer calls to create outputs and then instantiate a Model from the inputs and outputs.
In [4]:
LATENT_DIM = 2 # assume z lives in a 2D space (easy to visualise later)
encoder_input = layers.Input(shape=(28, 28, 1))
x = layers.Conv2D(32, 4, strides=2, activation="relu", padding="same")(encoder_input)
x = layers.Conv2D(64, 4, strides=2, activation="relu", padding="same")(x)
x = layers.Flatten()(x)
z_mean = layers.Dense(LATENT_DIM, name='z_mean')(x)
z_log_var = layers.Dense(LATENT_DIM, name='z_log_var')(x)
encoder = keras.Model(inputs=encoder_input, outputs=[z_mean, z_log_var], name='encoder')
encoder.summary()
Model: "encoder"
__________________________________________________________________________________________________
Layer (type) Output Shape Param # Connected to
==================================================================================================
input_1 (InputLayer) [(None, 28, 28, 1)] 0
__________________________________________________________________________________________________
conv2d (Conv2D) (None, 14, 14, 32) 544 input_1[0][0]
__________________________________________________________________________________________________
conv2d_1 (Conv2D) (None, 7, 7, 64) 32832 conv2d[0][0]
__________________________________________________________________________________________________
flatten (Flatten) (None, 3136) 0 conv2d_1[0][0]
__________________________________________________________________________________________________
z_mean (Dense) (None, 2) 6274 flatten[0][0]
__________________________________________________________________________________________________
z_log_var (Dense) (None, 2) 6274 flatten[0][0]
==================================================================================================
Total params: 45,924
Trainable params: 45,924
Non-trainable params: 0
__________________________________________________________________________________________________
For the decoder network, we use a fully-connected layer, followed by three convolutional transpose (deconvolutional) layers.
Since there’s only one output ($\boldsymbol{\pi}$), we can use the Sequential API.
In [5]:
decoder = keras.Sequential([
    keras.Input(shape=(LATENT_DIM,)),
    layers.Dense(7 * 7 * 64, activation="relu"),
    layers.Reshape((7, 7, 64)),
    layers.Conv2DTranspose(32, 4, activation="relu", strides=2, padding="same"),
    layers.Conv2DTranspose(16, 4, activation="relu", strides=2, padding="same"),
    layers.Conv2DTranspose(1, 4, activation="sigmoid", strides=1, padding="same")
], name="decoder")
decoder.summary()
Model: "decoder"
_________________________________________________________________
Layer (type) Output Shape Param #
=================================================================
dense (Dense) (None, 3136) 9408
_________________________________________________________________
reshape (Reshape) (None, 7, 7, 64) 0
_________________________________________________________________
conv2d_transpose (Conv2DTran (None, 14, 14, 32) 32800
_________________________________________________________________
conv2d_transpose_1 (Conv2DTr (None, 28, 28, 16) 8208
_________________________________________________________________
conv2d_transpose_2 (Conv2DTr (None, 28, 28, 1) 257
=================================================================
Total params: 50,673
Trainable params: 50,673
Non-trainable params: 0
_________________________________________________________________
4. Objective function
We want to maximise the evidence lower bound (ELBO), or equivalently minimise the negative ELBO.
In the workshop slides, we saw that the negative ELBO can be written as:
$$
-\mathcal{L}_{\theta, \phi}(\mathbf{x}) = - \mathbb{E}_{q_\phi(\mathbf{z}|\mathbf{x})}[\log p_\theta(\mathbf{x}|\mathbf{z})] + D_{\mathrm{KL}}[q_\phi(\mathbf{z}|\mathbf{x}) \| p_\theta(\mathbf{z})]
$$
Note: the expectation in the first term can be approximated using a single sample $\mathbf{z}^\star$ drawn from $q_\phi(\mathbf{z}|\mathbf{x})$, in which case the first term simplifies to $-\log p_\theta(\mathbf{x}|\mathbf{z}^\star)$.
The two terms in the expression for the negative ELBO can be written as:
$$
-\log p_\theta(\mathbf{x}|\mathbf{z}^\star) = \sum_{i = 1}^{d} \left\{ - x_i \log (\pi_i^\star) - (1 - x_i) \log (1 - \pi_i^\star) \right\} \ \text{with} \ \boldsymbol{\pi}^\star = \operatorname{DecoderNN}_\theta(\mathbf{z}^\star)
$$
where $d$ is the dimension of $\mathbf{x}$ (in this case $d = 784$) and
$$
D_{\mathrm{KL}}(q_\phi(\mathbf{z}|\mathbf{x}) \| p_\theta(\mathbf{z})) = - \frac{1}{2} \sum_{i = 1}^{m} \left\{ 1 + \log (\sigma_i^2) - \mu_i^2 - \sigma_i^2 \right\}
$$
where $m$ is the dimension of $\mathbf{z}$ (in this case $m = 2$).
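This closed form is stated without proof in the worksheet; as a brief derivation sketch, write out the two Gaussian log-densities (the $\log 2\pi$ terms cancel) and use $\mathbb{E}_{q_\phi}[(z_i - \mu_i)^2] = \sigma_i^2$ and $\mathbb{E}_{q_\phi}[z_i^2] = \mu_i^2 + \sigma_i^2$:
$$
D_{\mathrm{KL}}(q_\phi(\mathbf{z}|\mathbf{x}) \| p_\theta(\mathbf{z})) = \mathbb{E}_{q_\phi}\left[\log q_\phi(\mathbf{z}|\mathbf{x}) - \log p_\theta(\mathbf{z})\right] = \sum_{i = 1}^{m} \mathbb{E}_{q_\phi}\left[-\tfrac{1}{2}\log(\sigma_i^2) - \tfrac{(z_i - \mu_i)^2}{2\sigma_i^2} + \tfrac{z_i^2}{2}\right] = \frac{1}{2} \sum_{i = 1}^{m} \left\{ \mu_i^2 + \sigma_i^2 - 1 - \log(\sigma_i^2) \right\},
$$
which matches the expression above after rearranging the signs.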
We implement these terms as separate loss functions below using TensorFlow ops.
In [6]:
@tf.function()
def loss_kl_term(z_mean, z_log_var):
    """KL divergence contribution to the total loss

    Args:
        z_mean: A tensor with shape [n_samples, n_latent_dim]
        z_log_var: A tensor with shape [n_samples, n_latent_dim]

    Returns:
        A scalar tensor
    """
    loss = 1 + z_log_var - tf.square(z_mean) - tf.exp(z_log_var)
    # Add up contributions from each latent dimension
    loss = -0.5 * tf.reduce_sum(loss, axis=-1)
    # Average over samples
    return tf.reduce_mean(loss)

@tf.function()
def loss_reconstruction_term(x, x_reconstructed):
    """Reconstruction contribution to the total loss

    Args:
        x: A tensor with shape [n_samples, ...]
        x_reconstructed: A tensor with shape [n_samples, ...]

    Returns:
        A scalar tensor
    """
    loss = keras.backend.binary_crossentropy(x, x_reconstructed)
    # Add up contributions from each pixel
    loss = tf.reduce_sum(loss, axis=[1, 2, 3])
    # Average over samples
    return tf.reduce_mean(loss)
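As a quick sanity check on these implementations (not part of the original worksheet), the KL term should vanish when the approximate posterior equals the prior, and the reconstruction term for a single pixel with $x = 1$ and $\pi = 0.8$ should equal $-\log(0.8) \approx 0.223$:
In [ ]:
# Sanity checks (a sketch, not part of the original worksheet)
# 1. If q equals the prior (mu = 0, log sigma^2 = 0), the KL term should be ~0
print(loss_kl_term(tf.zeros((1, LATENT_DIM)), tf.zeros((1, LATENT_DIM))).numpy())
# 2. A single-pixel "image" with x = 1 and pi = 0.8 should give -log(0.8) ~ 0.223
x_toy = tf.ones((1, 1, 1, 1))
pi_toy = 0.8 * tf.ones((1, 1, 1, 1))
print(loss_reconstruction_term(x_toy, pi_toy).numpy())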
5. Training the VAE
We’ve now prepared the main components we need for the VAE: the encoder and decoder neural nets, as well as custom loss functions for training.
In this section, we’ll put the components together.
Below we define a new VAE model class, which will hold the encoder and decoder and coordinate training.
Since training is non-standard, we need to override the train_step method (this is the code that runs for each step of stochastic gradient descent).
There are a couple of important points to note about the code:
We need to direct TensorFlow to record operations that are applied in the forward pass, so that we can do backpropagation. In TensorFlow 2, operations can be recorded by running them in the scope of a GradientTape.
Since we can’t do backpropagation through a sampling step, we apply the reparameterisation trick. Specifically, we compute $\mathbf{z} = \boldsymbol{\mu} + \boldsymbol{\sigma} \odot \boldsymbol{\epsilon}$ where $\boldsymbol{\epsilon} \sim \mathcal{N}(0, \mathbf{I})$ in place of $\mathbf{z} \sim \mathcal{N}(\boldsymbol{\mu}, \operatorname{diag}(\boldsymbol{\sigma}^2))$. A small standalone sketch of this trick is given below.
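The following minimal sketch (not part of the original worksheet) simply confirms that gradients flow through a reparameterised sample with respect to $\boldsymbol{\mu}$ and $\log \boldsymbol{\sigma}^2$, which is what makes the training step below possible:
In [ ]:
# Reparameterisation trick in isolation (a sketch, not part of the original worksheet)
mu = tf.Variable([0.5, -0.5])
log_var = tf.Variable([0.0, 0.0])
with tf.GradientTape() as tape:
    eps = tf.random.normal(shape=mu.shape)
    z = mu + tf.exp(0.5 * log_var) * eps      # z = mu + sigma * eps
    objective = tf.reduce_sum(tf.square(z))   # any differentiable function of z
grads = tape.gradient(objective, [mu, log_var])
print(grads)  # both gradients are defined, so backpropagation through the sample works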
In [7]:
class VAE(keras.Model):
    def __init__(self, encoder, decoder, **kwargs):
        super(VAE, self).__init__(**kwargs)
        self.encoder = encoder
        self.decoder = decoder

    def train_step(self, data):
        # Record operations to tape
        with tf.GradientTape() as tape:
            z_mean, z_log_var = self.encoder(data)
            # Reparameterisation trick to sample z
            epsilon = tf.random.normal(shape=tf.shape(z_mean))
            z = z_mean + tf.exp(0.5 * z_log_var) * epsilon
            pi = self.decoder(z)
            # Compute contributions to the loss
            r_loss = loss_reconstruction_term(data, pi)
            kl_loss = loss_kl_term(z_mean, z_log_var)
            total_loss = r_loss + kl_loss
        # Compute gradients
        grads = tape.gradient(total_loss, self.trainable_weights)
        # Update trainable weights
        self.optimizer.apply_gradients(zip(grads, self.trainable_weights))
        return {
            "loss": total_loss,
            "reconstruction_loss": r_loss,
            "kl_loss": kl_loss,
        }
We’re now ready to instantiate the model and fit it on the training data.
This may take 10-15 min when running on a CPU.
In [8]:
vae = VAE(encoder, decoder)
vae.compile(optimizer='adam')
vae_history = vae.fit(x_train, epochs=10, batch_size=128)
Epoch 1/10
469/469 [==============================] - 34s 72ms/step - loss: 206.6656 - reconstruction_loss: 203.2439 - kl_loss: 3.4218
Epoch 2/10
469/469 [==============================] - 34s 72ms/step - loss: 162.8295 - reconstruction_loss: 157.9034 - kl_loss: 4.9261
Epoch 3/10
469/469 [==============================] - 34s 72ms/step - loss: 155.5948 - reconstruction_loss: 150.1400 - kl_loss: 5.4548
Epoch 4/10
469/469 [==============================] - 34s 72ms/step - loss: 152.8415 - reconstruction_loss: 147.2141 - kl_loss: 5.6274
Epoch 5/10
469/469 [==============================] - 33s 71ms/step - loss: 151.0785 - reconstruction_loss: 145.3281 - kl_loss: 5.7504
Epoch 6/10
469/469 [==============================] - 35s 75ms/step - loss: 149.8327 - reconstruction_loss: 143.9966 - kl_loss: 5.8361
Epoch 7/10
469/469 [==============================] - 35s 75ms/step - loss: 148.9349 - reconstruction_loss: 143.0347 - kl_loss: 5.9002
Epoch 8/10
469/469 [==============================] - 35s 75ms/step - loss: 148.2937 - reconstruction_loss: 142.3139 - kl_loss: 5.9798
Epoch 9/10
469/469 [==============================] - 35s 74ms/step - loss: 147.5455 - reconstruction_loss: 141.5295 - kl_loss: 6.0160
Epoch 10/10
469/469 [==============================] - 35s 75ms/step - loss: 147.0632 - reconstruction_loss: 141.0062 - kl_loss: 6.0569
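To check convergence, it can help to plot the per-epoch loss terms recorded in vae_history (a quick sketch, not part of the original worksheet; the keys match the dictionary returned by train_step above):
In [ ]:
# Plot the loss terms recorded during training
for key in ['loss', 'reconstruction_loss', 'kl_loss']:
    plt.plot(vae_history.history[key], label=key)
plt.xlabel('Epoch')
plt.ylabel('Loss')
plt.legend()
plt.show()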
6. Visualising the latent space
Since one of the main applications of VAEs is representation learning, it is useful to understand the structure of the latent (representation) space.
This is not too difficult in our case, since we chose to use a 2D latent space which is easy to visualise.
In the code block below, we apply the decoder on a rectangular grid of points in the latent space, plotting the expected image obtained at each point (when passed through the decoder).
In [9]:
from scipy.stats import norm
def plot_latent(vae, num_images, figsize=(10, 10)):
    _, image_height, image_width, _ = vae.decoder.output_shape
    # Create equal probability density grid on 2D latent space
    z_grid = norm().ppf(np.linspace(0.01, 0.99, num_images))
    z_0, z_1 = np.meshgrid(z_grid, z_grid[::-1])
    # Apply decoder on grid
    z = np.c_[z_0.ravel(), z_1.ravel()]
    images = vae.decoder(z)
    # Reorganise into num_images x num_images grid
    images = tf.reshape(images, (num_images, num_images, image_height, image_width))
    images = tf.transpose(images, perm=[0, 2, 1, 3])
    images = tf.reshape(images, (num_images * image_height, num_images * image_width))
    plt.figure(figsize=figsize)
    plt.imshow(images, cmap="Greys_r")
    ticks = np.arange(num_images) * image_width + image_width / 2
    labels = ["{:.2f}".format(l) for l in z_grid]
    plt.xticks(ticks=ticks[::2], labels=labels[::2])
    plt.yticks(ticks=ticks[::2], labels=labels[::-2])
    plt.xlabel("$z_0$")
    plt.ylabel("$z_1$")
    plt.show()

plot_latent(vae, 30)
Question: Does the VAE do a good job at representing all types of digits? How would you expect the results to change if the latent dimension $m$ was set to 3 instead of 2?
Answer:
First, we note that the structure of the latent space may vary across different runs, as we haven’t fixed the random seed.
In this instance, the VAE does a poor job of representing 4s and 7s clearly: they are not well separated from 9s in the latent space.
It is likely that the latent space is too small to adequately represent the variation in the images.
We’d expect an improvement by increasing $m$.
Since we have class labels for the training instances, we can generate another visualisation that shows how the classes are organised in the latent space.
Specifically, we apply the encoder to all images in the training set, plot the corresponding latent codes as points, and colour the points according to the class label.
This shows similar information to the previous plot, but it makes the overlap between classes more obvious.
In [10]:
z_mean, _ = vae.encoder(x_train)
z_mean = z_mean.numpy()
plt.figure(figsize=(10,10))
for digit in range(10):
    plt.scatter(z_mean[y_train == digit, 0], z_mean[y_train == digit, 1], label=digit, marker=".")
plt.legend(title="Digit")
plt.xlabel('$z_0$')
plt.ylabel('$z_1$')
plt.show()
7. Reconstructing images
In this section, we interpret the VAE as an autoencoder which is attempting to output an accurate reconstruction of an input image $\mathbf{x}$.
Since the encoder $q_\phi(\mathbf{z}|\mathbf{x})$ and decoder $p_\theta(\mathbf{x}|\mathbf{z})$ both output random variables, we need to decide on a method to collapse the random variable to a point.
For the encoder, we take the mean $\bar{\mathbf{z}} = \mathbb{E}_{q_\phi(\mathbf{z}|\mathbf{x})}(\mathbf{z})$.
For the decoder, we consider three options (see the note after this list):
draw a sample from $p(\mathbf{x}|\bar{\mathbf{z}})$ (2nd row of plot)
take the mode $\arg \max_{\mathbf{x}} p(\mathbf{x}|\bar{\mathbf{z}})$ (3rd row of plot)
take the expected value $\mathbb{E}_{p(\mathbf{x}|\bar{\mathbf{z}})}(\mathbf{x})$ (4th row of plot)
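For reference (this is implied by the independent Bernoulli likelihood rather than stated explicitly in the worksheet), with $\boldsymbol{\pi} = \operatorname{DecoderNN}_\theta(\bar{\mathbf{z}})$ the mode and mean have simple closed forms, which is what the code below exploits:
$$
\left(\arg \max_{\mathbf{x}} p(\mathbf{x}|\bar{\mathbf{z}})\right)_i = \mathbb{1}[\pi_i \geq 0.5], \qquad \mathbb{E}_{p(\mathbf{x}|\bar{\mathbf{z}})}[x_i] = \pi_i .
$$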
In [12]:
num_images = 10
# randomly select some images from the training set
train_ids = np.random.choice(x_train.shape[0], size=num_images, replace=False)
fig, ax = plt.subplots(nrows=4, ncols=num_images + 1, figsize=(num_images, 0.4*num_images))
for i in range(num_images):
    # Plot original image
    x = x_train[train_ids[i]].squeeze()
    ax[0, i].imshow(x, cmap="Greys_r")
    ax[0, i].axis('off')
    z_mean, _ = vae.encoder(np.expand_dims(x, 0))
    # Plot sample from decoder
    pi = vae.decoder(z_mean)
    x_sample = tf.cast(tf.random.uniform(pi.shape) <= pi, tf.float32)
    ax[1, i].imshow(tf.squeeze(x_sample), cmap='Greys_r')
    ax[1, i].axis('off')
    # Plot mode of decoder
    x_mode = vae.decoder(z_mean) >= 0.5
    ax[2, i].imshow(tf.squeeze(x_mode), cmap='Greys_r')
    ax[2, i].axis('off')
    # Plot mean of decoder
    x_mean = vae.decoder(z_mean)
    ax[3, i].imshow(tf.squeeze(x_mean), cmap='Greys_r')
    ax[3, i].axis('off')
ax[0, num_images].text(0, 0.4, 'Original')
ax[0, num_images].axis('off')
ax[1, num_images].text(0, 0.4, 'Reconstructed: sample')
ax[1, num_images].axis('off')
ax[2, num_images].text(0, 0.4, 'Reconstructed: mode')
ax[2, num_images].axis('off')
ax[3, num_images].text(0, 0.4, 'Reconstructed: mean')
ax[3, num_images].axis('off')
plt.show()
8. Bonus: Drawing from the marginal distribution
Exercise: Write code that draws a random image from the marginal $p_\theta(\mathbf{x}) = \int p_\theta(\mathbf{x}|\mathbf{z}) p_\theta(\mathbf{z}) \, \mathrm{d}\mathbf{z}$. Comment on the quality of the samples.
Solution:
Some of the samples are recognisable as images of digits. However, the images are generally noisy. The model could be improved by not treating the pixels as being independent conditional on $\mathbf{z}$.
In [ ]:
# Draw z from prior
z = tf.random.normal((1, LATENT_DIM))
# Get Bernoulli parameters conditional on z
pi = vae.decoder(z)
# Draw image conditional on pi
x = tf.cast(tf.random.uniform(pi.shape) <= pi, tf.float32)
# Plot the result
plt.imshow(x.numpy().squeeze(), cmap='Greys_r')
plt.axis('off')
plt.show()
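As an optional extension (a sketch under the same model, not part of the original solution), drawing a small grid of independent samples gives a better sense of typical sample quality:
In [ ]:
# Draw a 5x5 grid of independent samples from the marginal p(x)
n = 5
z = tf.random.normal((n * n, LATENT_DIM))                    # z ~ p(z)
pi = vae.decoder(z)                                          # Bernoulli parameters for each z
x = tf.cast(tf.random.uniform(pi.shape) <= pi, tf.float32)   # x ~ p(x|z)
fig, axes = plt.subplots(n, n, figsize=(6, 6))
for image, axis in zip(x.numpy().squeeze(-1), axes.ravel()):
    axis.imshow(image, cmap='Greys_r')
    axis.axis('off')
plt.show()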