Auto-Encoding Variational Bayes

 

Main Contributions

  1. A reparameterization of the variational lower bound yields a lower bound estimator that can be straightforwardly optimized using standard stochastic gradient method.
  2. For i.i.d. datasets with continuous latent variables per datapoint, posterior inference can be made especially efficient by fitting an approximate inference model (also called a recognition model) to the intractable posterior using the proposed lower bound estimator.

vae_graph

Problem scenario

  • $X=\{ x^{(i)} \}_{i=1}^N $ consisting of $N$ i.i.d. samples of some continuous or discrete variable $x$.
  • $p_\theta(z), p_\theta(x|z)$ come frome parametric families of distributions.
  • $\theta$와 $z^{(i)}$은 모름

어려운 점들

  1. Intractibility(보통 likelihood인 $p_\theta(x|z)$가 복잡한 함수인 경우)
    • $p_\theta(x)=\int p_\theta(z)p_\theta(x|z)$ 적분, or $\theta$에 대한 미분이 힘든 경우
    • posterior $p_\theta(z|x)$ 가 intractable
    • mean-field VB algorithm에서 필요한 적분이 intractable한 경우
  2. Large dataset
    • Sampling-based solution(e.g. Monte Carlo EM) 너무 느림(데이터하나 마다 샘플링을 하기 위한 반복루프필요)

3가지 문제를 풀자!

  1. Efficient approximate ML or MAP estimation for the parameters θ. 이로 인해 위 그림의 Graph를 따르는 생성과정으로 좋은 결과(그림)을 만들기를 바람!
  2. $x, \theta$가 주어졌을 때($x$같은 경우는 observed 되었을 때), 그에 대응되는 $z$를 효율적으로 잘 샘플링하거나 그 $z$에 대한 적분을 잘하고 싶다!(이를 Efficient approximate posterior inference $p(z|x,\theta)$라고 함.
  3. 마찬가지로 $x$에 관한 sampling 혹은 적분을 잘하고 싶음(Efficient approximate marginal inference). This allows us to perform all kinds of inference tasks where a prior over x is required. Common applications in computer vision include image denoising, inpainting and super-resolution.

$q_\phi(z\lvert x), p_\theta(x\lvert z)$ 각각을 probabilistic encoder, probabilistic decoder라고 함.

The variationial bound

$\log p_\theta(x^{(i)}) = D_KL()+\mathcal{L}(\theta,\phi;x^{(i)})$에서 $\mathcal{L}$ term 구체적으로

\[\mathcal{L}(\theta,\phi;x^{(i)}) = \mathbb{E}_{q_\phi(z\lvert x)}[ -\log q_\phi(z\lvert x) + \log p_\theta(x,z) ]\]

Usual (naive) Monte Carlo gradient estimator

\[\nabla_\phi \mathbb{E}_{q_\phi(z)}[f(z)] = \mathbb{E}_{q_\phi(z)}[f(Z)\nabla_\phi \log q_\phi(z) ]\approx \frac{1}{L} \sum_{l=1}^L f(z^{(l)}) \nabla_\phi \log q_\phi (z^{(l)})\]

위 Estimator는 high variance -> 그래서 나온게 SGVB estimator!

  • 본문의 $\tilde{\mathcal{L}}^A$는 $\mathcal{L}$의 unbiased estimator인데, monte carlo를 곁들인
  • $\tilde{\mathcal{L}}^B$는 $\tilde{\mathcal{L}}^A$를 두 term으로 분석해서 특정한 경우(예를 들어 $q_\phi(z\lvert x), p_\theta(z)$가 Gaussian)에 두 term중 KL term에 해당하는 부분이 closed form으로 deterministic하게 계산할 수 있기 때문에 그 부분은 직접계산하고 나머지 term에 대해서만 monte carlo estimate를 함
  • 직관적으로 $\tilde{\mathcal{L}}^A$는 어떤 두 기댓값의 Monte Carlo 근사의 합이고 $\tilde{\mathcal{L}}^B$는 하나의 Monte Carlo 근사만 사용하기 때문에 우리가 추정(estimate)하고자 하는 $\mathcal{L}(\theta,\phi;x^{(i)})$에 대한 variance of estimator가 B가 A보다 작을 것으로 기대됨(본문에도 그렇게 적혀있음..)

Reparametrization trick

  1. Tractable inverse cdf
    • Examples: Exponential, Cauchy, Logistic, Rayleigh, Pareto, Weibull, Reciprocal, Gompertz, Gumbel and Erlang distributions
  2. “Location-scale” family of distributions
    • Examples: Laplace, Elliptical, Student’s t, Logistic, Uniform, Triangular and Gaussian distributions
  3. Composition : 쉽게 sampling할 수 있는 random variable(auxiliary r.v.)들의 transformation으로 표현되는 distributions
    • Examples: Log-Normal (exponentiation of normally distributed variable), Gamma (a sum over exponentially distributed variables), Dirichlet (weighted sum of Gamma variates), Beta, Chi-Squared, and F distribution

Conclusion

  • ELBO를 잘 분석해서 reparametrization trick을 사용하여 미분가능한(backprop가능한) Stochastic Gradient VB(a.k.a SGVB)라는 novel estimator를 제안함.
  • 당시에 유행했던 모델, 알고리즘들 보다 large dataset에 대하여 efficient inference와 learning이 가능했다.
  • SGVB를 통한 학습을 통해 얻어딘 Auto-Encoding VB(AEVB) 모델은 inference model을 잘 근사했다.(그당시 기준)

Appendix

Monte Carlo EM

  • Monte Carlo EM 알고리즘은 encoder를 학습하지 않음. 다시 말해 $p_\theta(z\lvert x)$을 approximate하거나 $q_\phi$등 으로 대체하려는 시도자체를 안함.
  • 대신 inference시에 $z~\sim p_\theta(z\lvert x)$인 $z$를 샘플링하기 위해 다음과 같이 $p_\theta(z\lvert x)$의 score function을 계산함.
\[\nabla_z\log p_\theta(z\lvert x) = \nabla_z \log p_\theta(z) + \nabla_z \log p_\theta (x\lvert z)\]
  • MCEM은 Hamiltonian Monte Carlo leapfrog를 10회 반복하여 $z\sim p_\theta(z\lvert x)$인 $z$를 샘플링한다.
  • 이렇게 sampling된 $z\sim p_\theta(z\lvert x)$를 활용하여 marginal log likelihood $\log p_\theta (x) \approx \displaystyle\sum_{i=1}^M \log p_\theta(g(z_i))$ (여기서 $g$는 학습이 필요한 생성함수)
  • Inference시에는 marginal likelihood $p_\theta (x)$가 위와 같이 계산되는데 구체적으로는 dataset의 처음 1000개의 데이터에 대해서 각 데이터 하나씩마다 50개의 $z_i$들을 뽑고 각 $z_i$를 뽑기위해 또 HMC leapfrog를 4 step씩 실시한다.
  • 다시 말하면 위에서 $M=50$이고 $z_i\sim p_\theta(z\lvert x)$를 뽑기 위해 HMC leapfrog 4step씩 이니까 한 datapoint의 inference를 위한 marginal $p_\theta(x_i)$를 추정하기 위해 총 50*4=200번의 HMC leapfrog step을 해야됨.

If you like this post, don’t forget to give me a star. :star2:

Star This Project