Consider a natural image of size 100×100 with a single channel. This image is a point in 10,000-dimensional space. Natural images are usually not uniformly distributed in this space but reside on a much lower-dimensional manifold within this high-dimensional space. The lower dimensionality of the manifold is related to the limited degrees of freedom in these images; for example, only a limited number of pixel value combinations are actually perceived as natural images.
Modeling natural images with latent variable models whose continuous latent variables represent locations on this manifold can be a useful approach and is the approach discussed here. As in part 1, a model with one latent variable t_i per observation x_i is used, but now the latent variables are continuous rather than discrete. Therefore, summations over latent variable states are replaced by integrals, and these are often intractable for more complex models.
Observations, i.e. images X = {x_1, …, x_N}, are again described with a probabilistic model p(x∣θ). The goal is to maximize the data likelihood p(X∣θ) w.r.t. θ and to obtain approximate posterior distributions over the continuous latent variables. The joint distribution over an observed variable x and a latent variable t is defined as the product of the conditional distribution over x given t and the prior distribution over t.
$$p(x, t \mid \theta) = p(x \mid t, \theta)\, p(t \mid \theta) \tag{1}$$
We obtain the marginal distribution over x by integrating over t.
$$p(x \mid \theta) = \int p(x \mid t, \theta)\, p(t \mid \theta)\, dt \tag{2}$$
This integral is usually intractable for even moderately complex conditional probabilities p(x∣t,θ), and consequently so is the true posterior.
$$p(t \mid x, \theta) = \frac{p(x \mid t, \theta)\, p(t \mid \theta)}{p(x \mid \theta)} \tag{3}$$
This means that the E-step of the EM algorithm becomes intractable. Recall from part 1 that the lower bound of the log marginal likelihood is given by
$$\mathcal{L}(\theta, q) = \log p(X \mid \theta) - \mathrm{KL}(q(T \mid X) \parallel p(T \mid X, \theta)) \tag{4}$$
In the E-step, the lower bound is maximized w.r.t. q while θ is held fixed. If the true posterior is tractable, we can set q to the true posterior so that the KL divergence becomes 0, which maximizes the lower bound for the current value of θ. If the true posterior is intractable, approximations must be used.
Here, we will use stochastic variational inference, a Bayesian inference method that also scales to large datasets [1]. Numerous other approximate inference approaches exist but these are not discussed here to keep the article focused.
Stochastic variational inference
The field of mathematics that covers the optimization of a functional w.r.t. a function, like argmax_q L(θ, q) in our example, is the calculus of variations, hence the name variational inference. In this context, q is called a variational distribution and L(θ, q) a variational lower bound.
We will approximate the true posterior with a parametric variational distribution q(t∣x,ϕ) and try to find a value of ϕ that minimizes the KL divergence between this distribution and the true posterior. Using q(t∣x,ϕ), we can formulate the variational lower bound for a single observation x_i as

$$\mathcal{L}(\theta, q; x_i) = \mathbb{E}_{q(t_i \mid x_i, \phi)}\left[\log p(x_i \mid t_i, \theta)\right] - \mathrm{KL}(q(t_i \mid x_i, \phi) \parallel p(t_i \mid \theta)) \tag{5}$$
We assume that the integral ∫ q(t_i∣x_i,ϕ) log p(x_i∣t_i,θ) dt_i is intractable, but we can choose a functional form of q(t_i∣x_i,ϕ) from which we can easily sample, so that the expectation of log p(x_i∣t_i,θ) w.r.t. q(t_i∣x_i,ϕ) can be approximated with L samples from q:

$$\mathbb{E}_{q(t_i \mid x_i, \phi)}\left[\log p(x_i \mid t_i, \theta)\right] \approx \frac{1}{L} \sum_{l=1}^{L} \log p(x_i \mid t_{i,l}, \theta) \tag{6}$$
where t_{i,l} ∼ q(t_i∣x_i,ϕ). We will also choose the functional forms of q(t_i∣x_i,ϕ) and p(t_i∣θ) such that the KL divergence can be integrated analytically, hence no samples are needed to evaluate it. With these choices, an approximate evaluation of the variational lower bound is possible. But in order to optimize the lower bound w.r.t. θ and ϕ we need to approximate the gradients w.r.t. these parameters.
We first assume that the analytical expression of the KL divergence, the second term on the RHS of Eq. (5), is differentiable w.r.t. ϕ and θ so that deterministic gradients can be computed. The gradient of the first term on the RHS of Eq. (5) w.r.t. θ is

$$\nabla_\theta \mathbb{E}_{q(t_i \mid x_i, \phi)}\left[\log p(x_i \mid t_i, \theta)\right] = \mathbb{E}_{q(t_i \mid x_i, \phi)}\left[\nabla_\theta \log p(x_i \mid t_i, \theta)\right] \tag{7}$$

$$\approx \frac{1}{L} \sum_{l=1}^{L} \nabla_\theta \log p(x_i \mid t_{i,l}, \theta) \tag{8}$$
Here, ∇_θ can be moved inside the expectation as q(t_i∣x_i,ϕ) doesn't depend on θ. Assuming that p(x_i∣t_i,θ) is differentiable w.r.t. θ, unbiased estimates of the gradient can be obtained by sampling from q(t_i∣x_i,ϕ).
We will later implement p(x_i∣t_i,θ) as a neural network and use TensorFlow to compute ∇_θ log p(x_i∣t_{i,l},θ). The gradient w.r.t. ϕ is a bit trickier as ∇_ϕ cannot be moved inside the expectation because q(t_i∣x_i,ϕ) depends on ϕ. But if we can decompose q(t_i∣x_i,ϕ) into an auxiliary distribution p(ϵ) that doesn't depend on ϕ and a deterministic, differentiable function g(ϵ,x,ϕ), where t_i = g(ϵ,x_i,ϕ) and ϵ ∼ p(ϵ), then we can re-formulate the gradient w.r.t. ϕ as

$$\nabla_\phi \mathbb{E}_{q(t_i \mid x_i, \phi)}\left[\log p(x_i \mid t_i, \theta)\right] = \mathbb{E}_{p(\epsilon)}\left[\nabla_\phi \log p(x_i \mid g(\epsilon, x_i, \phi), \theta)\right] \tag{9}$$

$$\approx \frac{1}{L} \sum_{l=1}^{L} \nabla_\phi \log p(x_i \mid t_{i,l}, \theta) \tag{10}$$
where t_{i,l} = g(ϵ_l, x_i, ϕ) and ϵ_l ∼ p(ϵ). This so-called reparameterization trick can be applied to a wide range of probability distributions, including Gaussian distributions. Furthermore, stochastic gradients w.r.t. ϕ obtained with this trick have much smaller variance than those obtained with alternative approaches (not shown here).
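To illustrate the trick in isolation, here is a tiny sketch (not part of the model implemented below) with made-up variational parameters mu and log_var for a two-dimensional latent space. It shows that gradients flow through a reparameterized sample because all randomness lives in ϵ:

import tensorflow as tf

# Hypothetical variational parameters phi = (mu, log_var) of a 2D Gaussian q
mu = tf.Variable([0.0, 0.0])
log_var = tf.Variable([0.0, 0.0])

with tf.GradientTape() as tape:
    # Auxiliary noise epsilon ~ p(epsilon) = N(0, I), independent of phi
    eps = tf.random.normal(shape=mu.shape)
    # Deterministic, differentiable reparameterization t = g(eps, phi)
    t = mu + tf.exp(0.5 * log_var) * eps
    # Stand-in for log p(x|t, theta); any differentiable function of t works here
    loss = tf.reduce_sum(tf.square(t))

# Gradients w.r.t. mu and log_var are well defined because sampling happens in eps
print(tape.gradient(loss, [mu, log_var]))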
The above approximations of the variational lower bound and its gradients have been formulated for a single training example x_i, but this can be easily extended to mini-batches X^M = {x_1, …, x_M} of M random samples from a dataset X of N i.i.d. observations. The lower bound of the full dataset L(θ,q;X) can then be approximated as
$$\mathcal{L}(\theta, q; X) \approx \frac{N}{M} \sum_{i=1}^{M} \mathcal{L}(\theta, q; x_i) = \mathcal{L}^M(\theta, q; X^M) \tag{11}$$
Gradients of L^M(θ,q;X^M) can be obtained as described above, together with averaging over the mini-batch, and used in combination with optimizers like Adam, for example, to update the parameters of the latent variable model. Sampling from the variational distribution q and the use of mini-batches lead to noisy gradients, hence the term stochastic variational inference.
If M is sufficiently large, for example M=100, then L can even be set to 1, i.e. a single sample from the variational distribution per training example is sufficient to get a good gradient estimate on average.
Variational autoencoder
From the perspective of a generative model, q(t_i∣x_i,ϕ) is a probabilistic encoder because it generates a latent code t_i for input image x_i, and p(x_i∣t_i,θ) is a probabilistic decoder because it generates or reconstructs an image x_i from latent code t_i. Optimizing the variational lower bound w.r.t. parameters θ and ϕ can therefore be regarded as training a probabilistic autoencoder or variational autoencoder (VAE) [1].
In this context, the first term on the RHS of Eq. (5) can be interpreted as the expected negative reconstruction error. The second term is a regularization term that encourages the variational distribution to be close to the prior over latent variables. If the regularization term were omitted, the variational distribution would collapse to a delta function and the variational autoencoder would degenerate to a "usual" deterministic autoencoder.
Implementation
For implementing a variational autoencoder, we make the following choices:
The variational distribution q(t_i∣x_i,ϕ) is a multivariate Gaussian N(t_i∣μ(x_i,ϕ), σ²(x_i,ϕ)) with a diagonal covariance matrix, where the mean vector μ and the covariance diagonal σ² are functions of x_i and ϕ. These functions are implemented as a neural network and learned during optimization of the lower bound w.r.t. ϕ. After reparameterization, samples from q(t_i∣x_i,ϕ) are obtained via the deterministic function g(ϵ,x_i,ϕ) = μ(x_i,ϕ) + σ(x_i,ϕ) ⊙ ϵ and an auxiliary distribution p(ϵ) = N(ϵ∣0,I).
The conditional distribution p(x_i∣t_i,θ) is a multivariate Bernoulli distribution Ber(x_i∣k(t_i,θ)) where parameter k is a function of t_i and θ. This distribution models the binary training data, i.e. monochrome (= binarized) MNIST images in our example. Function k computes the expected value of each pixel. It is also implemented as a neural network and learned during optimization of the lower bound w.r.t. θ. Taking the (negative) logarithm of Ber(x_i∣k(t_i,θ)) gives a sum over pixel-wise binary cross entropies, as shown in Eq. (12).
The prior p(t_i∣θ) is a multivariate Gaussian distribution N(t_i∣0,I) with zero mean and unit covariance matrix. With the chosen functional forms of the prior and the variational distribution q, KL(q(t_i∣x_i,ϕ) ∥ p(t_i∣θ)) can be integrated analytically to $-\frac{1}{2}\sum_{d=1}^{D}\left(1 + \log\sigma_{i,d}^2 - \mu_{i,d}^2 - \sigma_{i,d}^2\right)$ where D is the dimensionality of the latent space and μ_{i,d} and σ_{i,d} are the d-th elements of μ(x_i,ϕ) and σ(x_i,ϕ), respectively.
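As a small illustration, this analytic expression translates directly into a few lines of TensorFlow. The helper name kl_divergence is an assumption, and the sketch follows the convention, used in the implementation below, that the encoder outputs the mean and the logarithm of the variance of q:

import tensorflow as tf

def kl_divergence(q_mean, q_log_var):
    """Analytic KL(q(t_i|x_i, phi) || p(t_i|theta)) for a diagonal Gaussian q
    and a standard normal prior, computed per example.

    q_mean and q_log_var are assumed to have shape (M, D).
    """
    return -0.5 * tf.reduce_sum(
        1 + q_log_var - tf.square(q_mean) - tf.exp(q_log_var), axis=-1)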
Using these choices and setting L=1, the variational lower bound for a single image x_i can be approximated as

$$\mathcal{L}(\theta, q; x_i) \approx \sum_{c} \left( x_{i,c} \log k_{i,c} + (1 - x_{i,c}) \log(1 - k_{i,c}) \right) + \frac{1}{2}\sum_{d=1}^{D}\left(1 + \log\sigma_{i,d}^2 - \mu_{i,d}^2 - \sigma_{i,d}^2\right) \tag{12}$$

where x_{i,c} is the value of pixel c in image x_i and k_{i,c} its expected value. The negative value of the lower bound is used as loss during training. The following figure outlines the architecture of the variational autoencoder.
The definitions of the encoder and decoder neural networks were taken from [2]. Here, the encoder computes the logarithm of the variance, instead of the variance directly, for reasons of numerical stability.
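The create_encoder and create_decoder functions used by the model below are not reproduced here. The following is only a rough sketch of what such networks could look like (a small convolutional encoder with two outputs and a transposed-convolutional decoder with sigmoid outputs) and does not necessarily match the architectures in [2]:

import tensorflow as tf
from tensorflow.keras import layers

def create_encoder(latent_dim):
    """Hypothetical encoder: maps a (28, 28, 1) image to the mean and
    log variance of the variational distribution q."""
    x = layers.Input(shape=(28, 28, 1))
    h = layers.Conv2D(32, 3, strides=2, padding='same', activation='relu')(x)
    h = layers.Conv2D(64, 3, strides=2, padding='same', activation='relu')(h)
    h = layers.Flatten()(h)
    q_mean = layers.Dense(latent_dim)(h)
    q_log_var = layers.Dense(latent_dim)(h)
    return tf.keras.Model(x, [q_mean, q_log_var])

def create_decoder(latent_dim):
    """Hypothetical decoder: maps a latent code to Bernoulli means k,
    one probability per pixel."""
    t = layers.Input(shape=(latent_dim,))
    h = layers.Dense(7 * 7 * 64, activation='relu')(t)
    h = layers.Reshape((7, 7, 64))(h)
    h = layers.Conv2DTranspose(64, 3, strides=2, padding='same', activation='relu')(h)
    h = layers.Conv2DTranspose(32, 3, strides=2, padding='same', activation='relu')(h)
    k = layers.Conv2DTranspose(1, 3, strides=1, padding='same', activation='sigmoid')(h)
    return tf.keras.Model(t, k)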
import tensorflow as tf
from tensorflow.keras import Model


class VariationalAutoencoder(Model):
    def __init__(self, latent_dim=2):
        """
        Creates a variational autoencoder Keras model.

        Args:
            latent_dim: dimensionality of latent space.
        """
        super().__init__()
        self.latent_dim = latent_dim
        self.encoder = create_encoder(latent_dim)
        self.decoder = create_decoder(latent_dim)

    def encode(self, x):
        """
        Computes variational distribution q statistics from input image x.

        Args:
            x: input image, shape (M, 28, 28, 1).

        Returns:
            Mean, shape (M, latent_dim), and log variance, shape (M, latent_dim),
            of multivariate Gaussian distribution q.
        """
        q_mean, q_log_var = self.encoder(x)
        return q_mean, q_log_var

    def sample(self, q_mean, q_log_var):
        """
        Samples latent code from variational distribution q.

        Args:
            q_mean: mean of q, shape (M, latent_dim).
            q_log_var: log variance of q, shape (M, latent_dim).

        Returns:
            Latent code sample, shape (M, latent_dim).
        """
        eps = tf.random.normal(shape=q_mean.shape)
        return q_mean + tf.exp(q_log_var * .5) * eps

    def decode(self, t):
        """
        Computes expected pixel values (= probabilities k) from latent code t.

        Args:
            t: latent code, shape (M, latent_dim).

        Returns:
            Probabilities k of multivariate Bernoulli distribution p, shape (M, 28, 28, 1).
        """
        k = self.decoder(t)
        return k

    def call(self, x):
        """
        Computes expected pixel values (= probabilities k) of a reconstruction
        of input image x.

        Args:
            x: input image, shape (M, 28, 28, 1).

        Returns:
            Probabilities k of multivariate Bernoulli distribution p, shape (M, 28, 28, 1).
        """
        q_mean, q_log_var = self.encode(x)
        t = self.sample(q_mean, q_log_var)
        return self.decode(t)
The variational_lower_bound function is implemented using Eq. (12) and Eq. (11), except that the estimate of the lower bound of the full dataset is normalized by the dataset size N, i.e. the mean over the mini-batch is returned.
    # Average over mini-batch (of size M)
    return tf.reduce_mean(neg_rc_error + neg_kl_div)
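For reference, one possible way to write the complete variational_lower_bound function, following Eq. (12), is sketched below. The intermediate names neg_rc_error and neg_kl_div lead into the averaging step shown above, and the small constant inside the logarithms is an assumption added for numerical stability:

def variational_lower_bound(model, x):
    """Computes the variational lower bound of Eq. (12), averaged over the
    images in mini-batch x (i.e. normalized by the number of examples)."""
    q_mean, q_log_var = model.encode(x)
    t = model.sample(q_mean, q_log_var)
    k = model.decode(t)

    # First term of Eq. (12): negative reconstruction error, i.e. the negative
    # pixel-wise binary cross entropy summed over all pixels of an image.
    # The constant 1e-7 avoids log(0); it is an assumption, not part of Eq. (12).
    neg_rc_error = tf.reduce_sum(
        x * tf.math.log(k + 1e-7) + (1 - x) * tf.math.log(1 - k + 1e-7),
        axis=[1, 2, 3])

    # Second term of Eq. (12): negative analytic KL divergence per example.
    neg_kl_div = 0.5 * tf.reduce_sum(
        1 + q_log_var - tf.square(q_mean) - tf.exp(q_log_var),
        axis=-1)

    # Average over mini-batch (of size M)
    return tf.reduce_mean(neg_rc_error + neg_kl_div)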
The training procedure uses the negative value of the variational lower bound as loss to compute stochastic gradient estimates. These are used by the optimizer to update model parameters θ and ϕ. The normalized variational lower bound of the test set is computed at the end of each epoch and printed.
@tf.function
def train_step(model, optimizer, x):
    """Trains VAE on mini-batch x using optimizer."""
    with tf.GradientTape() as tape:
        # Compute neg. variational lower bound as loss
        loss = -variational_lower_bound(model, x)

    # Compute gradients from neg. variational lower bound
    gradients = tape.gradient(loss, model.trainable_variables)

    # Apply gradients to model parameters theta and phi
    optimizer.apply_gradients(zip(gradients, model.trainable_variables))

    return loss
def train(model, optimizer, ds_train, ds_test, epochs):
    """Trains VAE on training dataset ds_train using optimizer for given number of epochs."""
    for epoch in range(1, epochs + 1):
        for x in ds_train:
            train_step(model, optimizer, x)

        vlb_mean = tf.keras.metrics.Mean()
        for x in ds_test:
            vlb_mean(variational_lower_bound(model, x))
        vlb = vlb_mean.result()
        print(f'Epoch: {epoch:02d}, Test set VLB: {vlb:.2f}')
Since the data are modelled with a multivariate Bernoulli distribution, the MNIST images are first binarized to monochrome images so that their pixel values are either 0 or 1. The training batch size is set to 100 to get reliable stochastic gradient estimates.
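A possible input pipeline for this setup is sketched below. The binarization threshold of 0.5, the shuffle buffer size and the use of drop_remainder=True (which keeps the batch dimension static, as the sample method above relies on inside tf.function) are assumptions, while the names ds_train, ds_test, x_test and y_test match their use in the training and plotting code:

import numpy as np
import tensorflow as tf

# Load MNIST and scale pixel values to [0, 1]
(x_train, y_train), (x_test, y_test) = tf.keras.datasets.mnist.load_data()
x_train = (x_train.reshape(-1, 28, 28, 1) / 255.).astype(np.float32)
x_test = (x_test.reshape(-1, 28, 28, 1) / 255.).astype(np.float32)

# Binarize to monochrome images so that pixel values are either 0 or 1
x_train = (x_train > 0.5).astype(np.float32)
x_test = (x_test > 0.5).astype(np.float32)

# Mini-batches of size M = 100
batch_size = 100
ds_train = tf.data.Dataset.from_tensor_slices(x_train) \
    .shuffle(buffer_size=x_train.shape[0]) \
    .batch(batch_size, drop_remainder=True)
ds_test = tf.data.Dataset.from_tensor_slices(x_test) \
    .batch(batch_size, drop_remainder=True)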
We choose a two-dimensional latent space so that it can be easily visualized. Training the variational autoencoder with RMSprop as optimizer at a learning rate of 1e-3 for 20 epochs already gives reasonable results. This takes a few minutes on a single GPU.
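Putting the pieces together, training could be started roughly as follows; the variable name vae matches its later use in the plotting code, everything else follows the settings stated above:

# Create the VAE with a 2-dimensional latent space and train it for 20 epochs
vae = VariationalAutoencoder(latent_dim=2)
optimizer = tf.keras.optimizers.RMSprop(learning_rate=1e-3)
train(vae, optimizer, ds_train, ds_test, epochs=20)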
Epoch: 01, Test set VLB: -166.56
Epoch: 02, Test set VLB: -158.25
Epoch: 03, Test set VLB: -154.44
Epoch: 04, Test set VLB: -152.20
Epoch: 05, Test set VLB: -150.47
Epoch: 06, Test set VLB: -148.30
Epoch: 07, Test set VLB: -148.63
Epoch: 08, Test set VLB: -146.66
Epoch: 09, Test set VLB: -145.61
Epoch: 10, Test set VLB: -147.64
Epoch: 11, Test set VLB: -148.42
Epoch: 12, Test set VLB: -143.86
Epoch: 13, Test set VLB: -143.31
Epoch: 14, Test set VLB: -145.67
Epoch: 15, Test set VLB: -143.78
Epoch: 16, Test set VLB: -143.29
Epoch: 17, Test set VLB: -142.25
Epoch: 18, Test set VLB: -142.99
Epoch: 19, Test set VLB: -143.39
Epoch: 20, Test set VLB: -143.31
The following figure shows the locations of the test set images in latent space. Here, the mean vectors of the variational distributions are plotted. The latent space is organized by structural similarity of digits, i.e. structurally similar digits have a smaller distance in latent space than structurally dissimilar digits. For example, digits 4 and 9 usually differ only by a horizontal bar or curve at the top of the image and are therefore close to each other in latent space.
import matplotlib.pyplot as plt
%matplotlib inline
# Compute mean vectors of variational distributions (= latent code locations)
q_test_mean, _ = vae.encode(x_test)

# Use a discrete colormap
cmap = plt.get_cmap('viridis', 10)

# Plot latent code locations colored by the digit value on input images
im = plt.scatter(q_test_mean[:, 0], q_test_mean[:, 1],
                 c=y_test, cmap=cmap, vmin=-0.5, vmax=9.5,
                 marker='x', s=0.2)
plt.colorbar(im, ticks=range(10));
When we sample locations in latent space (with density proportional to the prior density over latent variables) and decode these locations, we get a nice overview of how MNIST digits are organized by structural similarity in latent space. Each digit is plotted with its expected pixel values k instead of using a sample from the corresponding multivariate Bernoulli distribution.
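One way to produce such an overview, assuming scipy is available, is to place latent codes on a grid of standard normal quantiles (so that the grid density follows the prior) and decode each grid location into its expected pixel values. The grid size and quantile range below are arbitrary choices:

import numpy as np
import matplotlib.pyplot as plt
from scipy.stats import norm

n = 15           # grid points per latent dimension (arbitrary choice)
digit_size = 28

# Equally spaced quantiles of the standard normal prior, so that the
# density of grid points is proportional to the prior density
grid = norm.ppf(np.linspace(0.05, 0.95, n))

figure = np.zeros((digit_size * n, digit_size * n))
for i, t1 in enumerate(grid):
    for j, t2 in enumerate(grid):
        # Decode a single 2D latent code into expected pixel values k
        t = np.array([[t1, t2]], dtype=np.float32)
        k = vae.decode(t).numpy().reshape(digit_size, digit_size)
        figure[i * digit_size:(i + 1) * digit_size,
               j * digit_size:(j + 1) * digit_size] = k

plt.figure(figsize=(10, 10))
plt.imshow(figure, cmap='gray');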