$$ \newcommand{\problemdivider}{\begin{center}\large \bf\ldots\ldots\ldots\ldots\ldots\ldots\end{center}} \newcommand{\subproblemdivider}{\begin{center}\large \bf\ldots\ldots\end{center}} \newcommand{\pdiv}{\problemdivider} \newcommand{\spdiv}{\subproblemdivider} \newcommand{\ba}{\begin{align*}} \newcommand{\ea}{\end{align*}} \newcommand{\rt}{\right} \newcommand{\lt}{\left} \newcommand{\bp}{\begin{problem}} \newcommand{\ep}{\end{problem}} \newcommand{\bsp}{\begin{subproblem}} \newcommand{\esp}{\end{subproblem}} \newcommand{\bssp}{\begin{subsubproblem}} \newcommand{\essp}{\end{subsubproblem}} \newcommand{\atag}[1]{\addtocounter{equation}{1}\label{#1}\tag{\arabic{section}.\alph{subsection}.\alph{equation}}} \newcommand{\btag}[1]{\addtocounter{equation}{1}\label{#1}\tag{\arabic{section}.\alph{equation}}} \newcommand{\ctag}[1]{\addtocounter{equation}{1}\label{#1}\tag{\arabic{equation}}} \newcommand{\dtag}[1]{\addtocounter{equation}{1}\label{#1}\tag{\Alph{chapter}.\arabic{section}.\arabic{equation}}} \newcommand{\unts}[1]{\ \text{#1}} \newcommand{\textop}[1]{\operatorname{#1}} \newcommand{\textopl}[1]{\operatornamewithlimits{#1}} \newcommand{\prt}{\partial} \newcommand{\pderi}[3]{\frac{\prt^{#3}#1}{\prt #2^{#3}}} \newcommand{\deri}[3]{\frac{d^{#3}#1}{d #2^{#3}}} \newcommand{\del}{\vec\nabla} \newcommand{\exval}[1]{\langle #1\rangle} \newcommand{\bra}[1]{\langle #1|} \newcommand{\ket}[1]{|#1\rangle} \newcommand{\ham}{\mathcal{H}} \newcommand{\arr}{\mathfrak{r}} \newcommand{\conv}{\mathop{\scalebox{2}{\raisebox{-0.2ex}{$\ast$}}}} \newcommand{\bsm}{\lt(\begin{smallmatrix}} \newcommand{\esm}{\end{smallmatrix}\rt)} \newcommand{\bpm}{\begin{pmatrix}} \newcommand{\epm}{\end{pmatrix}} \newcommand{\bdet}{\lt|\begin{smallmatrix}} \newcommand{\edet}{\end{smallmatrix}\rt|} \newcommand{\bs}[1]{\boldsymbol{#1}} \newcommand{\uvec}[1]{\bs{\hat{#1}}} \newcommand{\qed}{\hfill$\Box$} $$
Tags:
  • statistics

    For posterior inference

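    In general, variational inference (VI) approximates the posterior distribution $p(\theta|x)$ with some other, more tractable distribution $q(\theta|x,\phi)$ whose parameters $\phi$ we tune. We need this because the posterior is typically intractable: the integral in its normalizing constant (the marginal likelihood, or evidence, $p(x)$) has no closed form: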

    \[p(\theta|x) = \frac{p(x|\theta)p(\theta)}{\int d\theta\, p(x|\theta)p(\theta)} = \frac{p(x|\theta)p(\theta)}{p(x)}.\]

    However, something clever happens when you take the KL divergence between $q(\theta|x,\phi)$ and $p(\theta|x)$:

    \[\begin{align*} \operatorname{KL}(q\parallel p) &= \int d\theta\, q(\theta|x,\phi)\left[\log q(\theta|x,\phi) - \log p(\theta|x)\right]\\ &= \int d\theta\, q(\theta|x,\phi)\left[\log q(\theta|x,\phi) - \log\frac{p(x,\theta)}{p(x)}\right]\\ &= \int d\theta\, \left[q(\theta|x,\phi)\log q(\theta|x,\phi) - q(\theta|x,\phi)\log p(x,\theta) + q(\theta|x,\phi)\log p(x)\right]\\ &= \log p(x) + \int d\theta\, \left[q(\theta|x,\phi)\log q(\theta|x,\phi) - q(\theta|x,\phi)\log p(x,\theta) \right]. \end{align*}\]

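    Since $\int d\theta\, q(\theta|x,\phi) = 1$, the log marginal likelihood term pops out of the integral. Because $\log p(x)$ does not depend on $q$ or its parameters $\phi$, it is a constant. Thus, minimizing the remaining integral (equivalently, maximizing its negative, $\int d\theta\, q(\theta|x,\phi)\left[\log p(x,\theta) - \log q(\theta|x,\phi)\right]$) minimizes the KL divergence between $q$ and $p$, thereby a) maximizing the goodness of approximation and b) tightening a lower bound on the log marginal likelihood. Since $\operatorname{KL}(q\parallel p) \geq 0$, this negated integral satisfies $\log p(x) \geq \int d\theta\, q(\theta|x,\phi)\left[\log p(x,\theta) - \log q(\theta|x,\phi)\right]$, which is why it is called the evidence lower bound (ELBO).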
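    As a numerical sanity check of the identity $\log p(x) = \text{ELBO} + \operatorname{KL}(q\parallel p)$, here is a minimal NumPy sketch on a conjugate model where every term has a closed form. The model (prior $\theta\sim\mathcal{N}(0,\tau_0^2)$, likelihood $x_i|\theta\sim\mathcal{N}(\theta,\sigma^2)$), the Gaussian family $q=\mathcal{N}(m,s^2)$, and all function names are assumptions chosen for illustration.

```python
import numpy as np

# Hypothetical conjugate model (chosen for illustration, not from the text):
#   theta ~ N(0, tau0^2),  x_i | theta ~ N(theta, sigma^2),  q(theta) = N(m, s^2)

def elbo_gaussian(x, m, s, sigma=1.0, tau0=1.0):
    """Closed-form ELBO = E_q[log p(x, theta)] - E_q[log q(theta)]."""
    n = len(x)
    # E_q[log p(x|theta)], using E_q[(x_i - theta)^2] = (x_i - m)^2 + s^2
    exp_loglik = (-0.5 * n * np.log(2 * np.pi * sigma**2)
                  - np.sum((x - m) ** 2 + s**2) / (2 * sigma**2))
    # E_q[log p(theta)]
    exp_logprior = -0.5 * np.log(2 * np.pi * tau0**2) - (m**2 + s**2) / (2 * tau0**2)
    # Entropy of q, i.e. -E_q[log q(theta)]
    entropy = 0.5 * np.log(2 * np.pi * np.e * s**2)
    return exp_loglik + exp_logprior + entropy

def exact_posterior(x, sigma=1.0, tau0=1.0):
    """Mean and variance of the exact (Gaussian) posterior p(theta | x)."""
    prec = 1 / tau0**2 + len(x) / sigma**2
    return np.sum(x) / sigma**2 / prec, 1 / prec

def log_evidence(x, sigma=1.0, tau0=1.0):
    """log p(x), using that marginally x ~ N(0, sigma^2 I + tau0^2 1 1^T)."""
    n = len(x)
    cov = sigma**2 * np.eye(n) + tau0**2 * np.ones((n, n))
    _, logdet = np.linalg.slogdet(cov)
    return -0.5 * (n * np.log(2 * np.pi) + logdet + x @ np.linalg.solve(cov, x))

def kl_normal(m, s, mu, var):
    """KL(N(m, s^2) || N(mu, var))."""
    return 0.5 * np.log(var / s**2) + (s**2 + (m - mu) ** 2) / (2 * var) - 0.5

rng = np.random.default_rng(0)
x = rng.normal(1.5, 1.0, size=20)  # synthetic data
mu, var = exact_posterior(x)
for m, s in [(0.0, 1.0), (1.0, 0.3), (mu, np.sqrt(var))]:
    gap = log_evidence(x) - elbo_gaussian(x, m, s)
    print(f"log p(x) - ELBO = {gap:.6f}   KL(q || posterior) = {kl_normal(m, s, mu, var):.6f}")
```

    For any $(m, s)$ the two printed numbers should agree, and both drop to (numerically) zero when $q$ is set to the exact posterior, where the bound is tight.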

    Mean-field approximation

    If our posterior is over multiple parameters $\theta_1\dots\theta_N$, i.e. $p(\{\theta_i\}_{i=1}^N|\{x\})$, and we assume that $q$ factors disjointly over the model parameters (the “mean-field approximation”), i.e.

    \[q(\{\theta\}|\{\phi\},\{x\}) = \prod_i q_i(\theta_i|\phi_i,\{x\})\]

    then the optimal form of each factor is

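    \[\log q_i^*(\theta_i|\phi_i) = \operatorname{E}_{q_{\setminus i}}\left[\log p(\{\theta\},\{x\})\right] + \text{const} = \int d\theta_{\setminus i}\,\Big(\prod_{j\neq i} q_j(\theta_j|\phi_j)\Big) \log p(\theta_i,\{\theta_{\setminus i}\},\{x\}) + \text{const},\tag{1}\label{expect}\]

    where the expectation is taken over every factor except $q_i$, the constant absorbs everything that does not depend on $\theta_i$ (including normalization), and we have dropped the data $\{x\}$ from the conditioning of each $q_j$ for brevity. Equivalently, $q_i^*(\theta_i|\phi_i) \propto \exp\left\{\operatorname{E}_{q_{\setminus i}}\left[\log p(\{\theta\},\{x\})\right]\right\}$.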

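    Here $\phi_i$ denotes the parameters of the variational factor $q_i$ (e.g. its mean and variance, if $q_i$ turns out to be Gaussian). This is very Gibbs-like! Since $\log p(\theta_i|\theta_{\setminus i},\{x\})$ differs from $\log p(\{\theta\},\{x\})$ only by a term that does not depend on $\theta_i$, \eqref{expect} sets each factor from the expected log full conditional, where Gibbs sampling would instead draw $\theta_i$ from the full conditional itself. Cycling through these updates one factor at a time is known as coordinate ascent variational inference (CAVI).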

    But how do we actually use this?

    Generally, we do not presuppose a specific form for each $q_i$. Instead, as with Gibbs sampling, we hope that the model is structured (e.g. conditionally conjugate) such that, for each (block of) variable(s), the right-hand side of \eqref{expect} is the log-density of a recognizable closed-form distribution. Unlike with Gibbs sampling, where we want those distributions to be easy to sample from, here we want the relevant expectations under them to be easy to compute. In other words, we want whatever integrals fall out of \eqref{expect} to be tractable. Ideally, they reduce to simple moments of the other factors, such as $\operatorname{E}[\theta_j]$ and $\operatorname{E}[\theta_j^2]$.
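    To make this concrete, here is a sketch of CAVI on the standard univariate-Gaussian example with unknown mean $\mu$ and precision $\tau$ (treated in, e.g., Bishop's Pattern Recognition and Machine Learning, Ch. 10): $x_n\sim\mathcal{N}(\mu,\tau^{-1})$, $\mu|\tau\sim\mathcal{N}(\mu_0,(\lambda_0\tau)^{-1})$, $\tau\sim\text{Gamma}(a_0,b_0)$, with $q(\mu,\tau)=q(\mu)q(\tau)$. Applying the mean-field update to each factor yields a Gaussian $q(\mu)$ and a Gamma $q(\tau)$. The hyperparameter defaults, the synthetic data, and the function name are assumptions for illustration.

```python
import numpy as np

def cavi_normal_gamma(x, mu0=0.0, lam0=1.0, a0=1.0, b0=1.0, iters=50):
    """Mean-field q(mu) q(tau) for the (assumed) model
        x_n ~ N(mu, 1/tau),  mu | tau ~ N(mu0, 1/(lam0 tau)),  tau ~ Gamma(a0, rate=b0).
    Returns q(mu) = N(mu_n, 1/lam_n) and q(tau) = Gamma(a_n, rate=b_n). Needs iters >= 1."""
    x = np.asarray(x, dtype=float)
    n, xbar = len(x), x.mean()
    mu_n = (lam0 * mu0 + n * xbar) / (lam0 + n)  # mean of q(mu): fixed across iterations
    a_n = a0 + (n + 1) / 2                       # shape of q(tau): also fixed
    e_tau = a0 / b0                              # initial guess for E_q[tau]
    for _ in range(iters):
        # Update q(mu): only needs the moment E_q[tau]
        lam_n = (lam0 + n) * e_tau
        # Update q(tau): only needs E_q[mu] = mu_n and Var_q[mu] = 1/lam_n
        e_sq_resid = np.sum((x - mu_n) ** 2) + n / lam_n
        e_prior_sq = lam0 * ((mu_n - mu0) ** 2 + 1 / lam_n)
        b_n = b0 + 0.5 * (e_sq_resid + e_prior_sq)
        e_tau = a_n / b_n
    return mu_n, lam_n, a_n, b_n

rng = np.random.default_rng(1)
x = rng.normal(2.0, 0.5, size=500)  # synthetic data: true mu = 2, true tau = 1/0.5^2 = 4
mu_n, lam_n, a_n, b_n = cavi_normal_gamma(x)
print("E[mu] =", mu_n, " E[tau] =", a_n / b_n)
```

    Only $\operatorname{E}[\tau]$ enters the update for $q(\mu)$, and only the first two moments of $q(\mu)$ enter the update for $q(\tau)$, so each step is a handful of arithmetic operations rather than an integral.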

    Suppose we are considering a bivariate posterior, $p(\theta_1,\theta_2|\{x\})$. Suppose $q(\theta_1|\{x\},\phi_1)\sim p(\phi_1)$

    A worked example:

    Suppose we have a mixture model with a fixed number $M$ of components $\{m_1,\dots,m_j,\dots,m_M\}$. Also suppose that for each datapoint in $\{x_1,\dots,x_i,\dots,x_N\}$, we have some likelihood $\ell_{i,j}$ of $x_i$ under component $m_j$.
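    A sketch of how this example can be carried through, under extra assumptions not stated above: the likelihoods $\ell_{i,j}$ are fixed and known, and we infer the assignments $z_i$ and the mixing weights $\pi$, with $\pi\sim\text{Dirichlet}(\alpha_0)$ and $z_i|\pi\sim\text{Categorical}(\pi)$. Under $q(z,\pi)=\prod_i q(z_i)\,q(\pi)$, the mean-field update gives $q(z_i=j)\propto \exp\{\operatorname{E}[\log\pi_j]\}\,\ell_{i,j}$ and $q(\pi)=\text{Dirichlet}\big(\alpha_0+\sum_i q(z_i=\cdot)\big)$, where $\operatorname{E}[\log\pi_j]=\psi(\alpha_j)-\psi(\sum_k\alpha_k)$ and $\psi$ is the digamma function. Function names and data are hypothetical; the small digamma helper only avoids a SciPy dependency.

```python
import math
import numpy as np

def digamma(x):
    """psi(x) for x > 0: shift x up to >= 6 via psi(x) = psi(x + 1) - 1/x, then use the asymptotic series."""
    result = 0.0
    while x < 6.0:
        result -= 1.0 / x
        x += 1.0
    f = 1.0 / (x * x)
    series = f * (1/12 - f * (1/120 - f * (1/252 - f * (1/240 - f / 132))))
    return result + math.log(x) - 0.5 / x - series

def cavi_mixture_weights(lik, alpha0=1.0, iters=100):
    """Mean-field q(z) q(pi) with pi ~ Dirichlet(alpha0), z_i | pi ~ Categorical(pi),
    and FIXED component likelihoods lik[i, j] = ell_{i,j} (all > 0).
    Returns responsibilities r[i, j] = q(z_i = j) and the Dirichlet parameters of q(pi)."""
    lik = np.asarray(lik, dtype=float)
    n, m = lik.shape
    alpha = np.full(m, alpha0 + n / m)  # start q(pi) from uniform expected counts
    for _ in range(iters):
        # E_q[log pi_j] = psi(alpha_j) - psi(sum_k alpha_k)
        e_log_pi = np.array([digamma(a) for a in alpha]) - digamma(alpha.sum())
        # q(z_i = j) is proportional to exp(E_q[log pi_j]) * ell_{i,j}; normalize in log space
        log_r = np.log(lik) + e_log_pi
        log_r -= log_r.max(axis=1, keepdims=True)
        r = np.exp(log_r)
        r /= r.sum(axis=1, keepdims=True)
        # q(pi) = Dirichlet(alpha0 + expected component counts)
        alpha = alpha0 + r.sum(axis=0)
    return r, alpha

# Hypothetical data: a 70/30 mixture of N(-2, 1) and N(2, 1), component densities taken as known
rng = np.random.default_rng(0)
z = rng.random(1000) < 0.7
x = np.where(z, rng.normal(-2.0, 1.0, 1000), rng.normal(2.0, 1.0, 1000))
lik = np.exp(-0.5 * (x[:, None] - np.array([-2.0, 2.0])) ** 2) / np.sqrt(2 * np.pi)
r, alpha = cavi_mixture_weights(lik)
print("E_q[pi] =", alpha / alpha.sum())
```

    The only expectation the updates need is the Dirichlet moment $\operatorname{E}[\log\pi_j]$, which is the kind of closed-form quantity the mean-field recipe hopes for.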

    For an example with slightly more algebra, this works through the posterior on a normal distribution nicely: https://arxiv.org/pdf/2103.01327

    Another nice example, similar to what we did here: https://chrischoy.github.io/research/Expectation-Maximization-and-Variational-Inference-2/

    (Unfinished) proof of the mean-field approximation

    TODO: finish

    Factor the joint

    \[p(\theta_1,\dots,\theta_N,\{x\}) = p(\theta_1|\theta_2\dots\theta_N,\{x\})p(\theta_2|\theta_3\dots\theta_N,\{x\})\dots p(\{x\}).\]

    The ELBO is thus

    \[\mathcal{L} = \int d\theta_1\dots d\theta_N\,\prod_i q_i(\theta_i)\left[\log p(\theta_1|\theta_2\dots\theta_N,\{x\}) + \log p(\theta_2|\theta_3\dots\theta_N,\{x\}) + \dots + \log p(\{x\}) - \sum_i \log q_i(\theta_i)\right].\]

    Optimizing with respect to $q_1$ (and dropping all terms that do not depend on $q_1$), note that only the first factor of the joint involves $\theta_1$; every other term, once $q_1$ is integrated out, is constant with respect to $q_1$. Hence

    \[\mathcal{L} = \int d\theta_1\, q_1(\theta_1)\left[\int d\theta_2\dots d\theta_N\,\Big(\prod_{i=2}^N q_i(\theta_i)\Big)\log p(\theta_1|\theta_{\setminus 1},\{x\}) - \log q_1(\theta_1)\right] + \text{const}.\]
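    One standard way to close the argument (a sketch of the usual textbook route, which avoids functional derivatives): since $p(\theta_1|\theta_2\dots\theta_N,\{x\})$ is the only factor of the joint that involves $\theta_1$, define a normalized distribution $\tilde q_1$ by $\log\tilde q_1(\theta_1) = \operatorname{E}_{q_2\cdots q_N}\left[\log p(\theta_1|\theta_{\setminus 1},\{x\})\right] + \text{const}$. The $q_1$-dependent part of the ELBO is then a negative KL divergence:

    \[\begin{align*} \mathcal{L} &= \int d\theta_1\, q_1(\theta_1)\left[\log\tilde q_1(\theta_1) - \log q_1(\theta_1)\right] + \text{const}\\ &= -\operatorname{KL}(q_1\parallel\tilde q_1) + \text{const}, \end{align*}\]

    which is maximized, at $\operatorname{KL} = 0$, by $q_1 = \tilde q_1$. Because $\log p(\theta_1|\theta_{\setminus 1},\{x\})$ and $\log p(\{\theta\},\{x\})$ differ only by terms constant in $\theta_1$, this gives $q_1^*(\theta_1)\propto\exp\left\{\operatorname{E}_{q_2\cdots q_N}\left[\log p(\{\theta\},\{x\})\right]\right\}$, and the same argument applies to every other factor.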

    For latent variable marginalization

    TODO

    Variational autoencoders

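    Variational autoencoders instead parameterize $q$ with a deep neural network (the encoder), which maps each datapoint $x$ to the parameters of $q(\theta|x,\phi)$, and optimize that network's weights. Note that $p(x,\theta)$ can be expanded into the likelihood $p(x|\theta)$ times the prior $p(\theta)$, so the integral left over in the KL derivation above can be rewritten in terms of the KL divergence between $q$ and $p(\theta)$: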

    \[\begin{align*} &\phantom{{}={}}\int d\theta\, \left[q(\theta|x,\phi)\log q(\theta|x,\phi) - q(\theta|x,\phi)\big(\log p(x|\theta) + \log p(\theta)\big) \right]\\ &= \operatorname{KL}(q(\theta|x,\phi)\parallel p(\theta)) - \int d\theta\, q(\theta|x,\phi)\log p(x|\theta) \end{align*}\]

    since

    \[\begin{align*} \operatorname{KL}(q(\theta|x,\phi)\parallel p(\theta)) = \int d\theta\, q(\theta|x,\phi) \left[ \log q(\theta|x,\phi) - \log p(\theta)\right]. \end{align*}\]

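    Now, $\phi$ are the parameters (weights) of the encoder network, trained by minimizing the expression above, i.e. the negative ELBO, with stochastic gradients. The KL term keeps $q$ close to the prior, while the second term rewards $q$ for putting mass on values of $\theta$ that explain $x$ well. (In a full VAE, the likelihood $p(x|\theta)$ is itself a decoder network whose parameters are learned jointly.)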
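    A minimal NumPy sketch of this loss for a single datapoint, assuming (as is common, but not stated above) a diagonal Gaussian $q(\theta|x,\phi)=\mathcal{N}(\mu,\operatorname{diag}\sigma^2)$ produced by the encoder, a standard normal prior $p(\theta)$, and a Gaussian likelihood with variance `obs_var` around a hypothetical `decoder_mean` callable. The KL term uses its closed form; $\operatorname{E}_q[\log p(x|\theta)]$ is estimated by Monte Carlo with the reparameterization $\theta=\mu+\sigma\epsilon$, which in a real VAE built with an autodiff library is what lets gradients reach $\phi$. NumPy has no autodiff, so this only evaluates the loss.

```python
import numpy as np

def gaussian_kl_to_std_normal(mu, log_var):
    """Closed-form KL(N(mu, diag(exp(log_var))) || N(0, I)), summed over dimensions."""
    return 0.5 * np.sum(mu**2 + np.exp(log_var) - 1.0 - log_var, axis=-1)

def vae_negative_elbo(x, mu, log_var, decoder_mean, rng, n_samples=1, obs_var=1.0):
    """KL(q(theta|x,phi) || p(theta)) - E_q[log p(x|theta)] for a single datapoint x.

    mu, log_var: what a (hypothetical) encoder outputs for x.
    decoder_mean: hypothetical callable mapping theta samples (n_samples, d) to the mean of a
                  Gaussian likelihood p(x|theta) with variance obs_var, shape (n_samples, len(x)).
    """
    std = np.exp(0.5 * log_var)
    eps = rng.standard_normal((n_samples, mu.shape[-1]))
    theta = mu + std * eps  # reparameterization: theta = mu + sigma * eps, eps ~ N(0, I)
    x_hat = decoder_mean(theta)
    log_lik = -0.5 * np.sum((x - x_hat) ** 2 / obs_var + np.log(2 * np.pi * obs_var), axis=-1)
    # Closed-form KL term minus the Monte Carlo estimate of E_q[log p(x|theta)]
    return gaussian_kl_to_std_normal(mu, log_var) - log_lik.mean()

rng = np.random.default_rng(0)
mu, log_var = np.array([0.5, -1.0]), np.array([-0.5, 0.2])
x = np.array([1.0, 0.0])
print(vae_negative_elbo(x, mu, log_var, decoder_mean=lambda t: t, rng=rng, n_samples=1000))
```

    With an identity (or any linear) `decoder_mean`, the Monte Carlo term also has a closed form, which makes a convenient check that the two pieces of the decomposition are wired up with the right signs.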

    What if traditional variational Bayes were applied to VAEs?