Main Contributions
- A reparameterization of the variational lower bound yields a lower bound estimator that can be straightforwardly optimized using standard stochastic gradient method.
- For i.i.d. datasets with continuous latent variables per datapoint, posterior inference can be made especially efficient by fitting an approximate inference model (also called a recognition model) to the intractable posterior using the proposed lower bound estimator.
Problem scenario
- $X=\{ x^{(i)} \}_{i=1}^N $ consisting of $N$ i.i.d. samples of some continuous or discrete variable $x$.
- $p_\theta(z), p_\theta(x|z)$ come frome parametric families of distributions.
- $\theta$와 $z^{(i)}$은 모름
어려운 점들
- Intractibility(보통 likelihood인 $p_\theta(x|z)$가 복잡한 함수인 경우)
- $p_\theta(x)=\int p_\theta(z)p_\theta(x|z)$ 적분, or $\theta$에 대한 미분이 힘든 경우
- posterior $p_\theta(z|x)$ 가 intractable
- mean-field VB algorithm에서 필요한 적분이 intractable한 경우
- Large dataset
- Sampling-based solution(e.g. Monte Carlo EM) 너무 느림(데이터하나 마다 샘플링을 하기 위한 반복루프필요)
3가지 문제를 풀자!
- Efficient approximate ML or MAP estimation for the parameters θ. 이로 인해 위 그림의 Graph를 따르는 생성과정으로 좋은 결과(그림)을 만들기를 바람!
- $x, \theta$가 주어졌을 때($x$같은 경우는 observed 되었을 때), 그에 대응되는 $z$를 효율적으로 잘 샘플링하거나 그 $z$에 대한 적분을 잘하고 싶다!(이를 Efficient approximate posterior inference $p(z|x,\theta)$라고 함.
- 마찬가지로 $x$에 관한 sampling 혹은 적분을 잘하고 싶음(Efficient approximate marginal inference). This allows us to perform all kinds of inference tasks where a prior over x is required. Common applications in computer vision include image denoising, inpainting and super-resolution.
$q_\phi(z\lvert x), p_\theta(x\lvert z)$ 각각을 probabilistic encoder, probabilistic decoder라고 함.
The variationial bound
$\log p_\theta(x^{(i)}) = D_KL()+\mathcal{L}(\theta,\phi;x^{(i)})$에서 $\mathcal{L}$ term 구체적으로
\[\mathcal{L}(\theta,\phi;x^{(i)}) = \mathbb{E}_{q_\phi(z\lvert x)}[ -\log q_\phi(z\lvert x) + \log p_\theta(x,z) ]\]Usual (naive) Monte Carlo gradient estimator
\[\nabla_\phi \mathbb{E}_{q_\phi(z)}[f(z)] = \mathbb{E}_{q_\phi(z)}[f(Z)\nabla_\phi \log q_\phi(z) ]\approx \frac{1}{L} \sum_{l=1}^L f(z^{(l)}) \nabla_\phi \log q_\phi (z^{(l)})\]위 Estimator는 high variance -> 그래서 나온게 SGVB estimator!
- 본문의 $\tilde{\mathcal{L}}^A$는 $\mathcal{L}$의 unbiased estimator인데, monte carlo를 곁들인
- $\tilde{\mathcal{L}}^B$는 $\tilde{\mathcal{L}}^A$를 두 term으로 분석해서 특정한 경우(예를 들어 $q_\phi(z\lvert x), p_\theta(z)$가 Gaussian)에 두 term중 KL term에 해당하는 부분이 closed form으로 deterministic하게 계산할 수 있기 때문에 그 부분은 직접계산하고 나머지 term에 대해서만 monte carlo estimate를 함
- 직관적으로 $\tilde{\mathcal{L}}^A$는 어떤 두 기댓값의 Monte Carlo 근사의 합이고 $\tilde{\mathcal{L}}^B$는 하나의 Monte Carlo 근사만 사용하기 때문에 우리가 추정(estimate)하고자 하는 $\mathcal{L}(\theta,\phi;x^{(i)})$에 대한 variance of estimator가 B가 A보다 작을 것으로 기대됨(본문에도 그렇게 적혀있음..)
Reparametrization trick
- Tractable inverse cdf
- Examples: Exponential, Cauchy, Logistic, Rayleigh, Pareto, Weibull, Reciprocal, Gompertz, Gumbel and Erlang distributions
- “Location-scale” family of distributions
- Examples: Laplace, Elliptical, Student’s t, Logistic, Uniform, Triangular and Gaussian distributions
- Composition : 쉽게 sampling할 수 있는 random variable(auxiliary r.v.)들의 transformation으로 표현되는 distributions
- Examples: Log-Normal (exponentiation of normally distributed variable), Gamma (a sum over exponentially distributed variables), Dirichlet (weighted sum of Gamma variates), Beta, Chi-Squared, and F distribution
Conclusion
- ELBO를 잘 분석해서 reparametrization trick을 사용하여 미분가능한(backprop가능한) Stochastic Gradient VB(a.k.a SGVB)라는 novel estimator를 제안함.
- 당시에 유행했던 모델, 알고리즘들 보다 large dataset에 대하여 efficient inference와 learning이 가능했다.
- SGVB를 통한 학습을 통해 얻어딘 Auto-Encoding VB(AEVB) 모델은 inference model을 잘 근사했다.(그당시 기준)
Appendix
Monte Carlo EM
- Monte Carlo EM 알고리즘은 encoder를 학습하지 않음. 다시 말해 $p_\theta(z\lvert x)$을 approximate하거나 $q_\phi$등 으로 대체하려는 시도자체를 안함.
- 대신 inference시에 $z~\sim p_\theta(z\lvert x)$인 $z$를 샘플링하기 위해 다음과 같이 $p_\theta(z\lvert x)$의 score function을 계산함.
- MCEM은 Hamiltonian Monte Carlo leapfrog를 10회 반복하여 $z\sim p_\theta(z\lvert x)$인 $z$를 샘플링한다.
- 이렇게 sampling된 $z\sim p_\theta(z\lvert x)$를 활용하여 marginal log likelihood $\log p_\theta (x) \approx \displaystyle\sum_{i=1}^M \log p_\theta(g(z_i))$ (여기서 $g$는 학습이 필요한 생성함수)
- Inference시에는 marginal likelihood $p_\theta (x)$가 위와 같이 계산되는데 구체적으로는 dataset의 처음 1000개의 데이터에 대해서 각 데이터 하나씩마다 50개의 $z_i$들을 뽑고 각 $z_i$를 뽑기위해 또 HMC leapfrog를 4 step씩 실시한다.
- 다시 말하면 위에서 $M=50$이고 $z_i\sim p_\theta(z\lvert x)$를 뽑기 위해 HMC leapfrog 4step씩 이니까 한 datapoint의 inference를 위한 marginal $p_\theta(x_i)$를 추정하기 위해 총 50*4=200번의 HMC leapfrog step을 해야됨.
If you like this post, don’t forget to give me a star.