$$ \newcommand{\problemdivider}{\begin{center}\large \bf\ldots\ldots\ldots\ldots\ldots\ldots\end{center}} \newcommand{\subproblemdivider}{\begin{center}\large \bf\ldots\ldots\end{center}} \newcommand{\pdiv}{\problemdivider} \newcommand{\spdiv}{\subproblemdivider} \newcommand{\ba}{\begin{align*}} \newcommand{\ea}{\end{align*}} \newcommand{\rt}{\right} \newcommand{\lt}{\left} \newcommand{\bp}{\begin{problem}} \newcommand{\ep}{\end{problem}} \newcommand{\bsp}{\begin{subproblem}} \newcommand{\esp}{\end{subproblem}} \newcommand{\bssp}{\begin{subsubproblem}} \newcommand{\essp}{\end{subsubproblem}} \newcommand{\atag}[1]{\addtocounter{equation}{1}\label{#1}\tag{\arabic{section}.\alph{subsection}.\alph{equation}}} \newcommand{\btag}[1]{\addtocounter{equation}{1}\label{#1}\tag{\arabic{section}.\alph{equation}}} \newcommand{\ctag}[1]{\addtocounter{equation}{1}\label{#1}\tag{\arabic{equation}}} \newcommand{\dtag}[1]{\addtocounter{equation}{1}\label{#1}\tag{\Alph{chapter}.\arabic{section}.\arabic{equation}}} \newcommand{\unts}[1]{\ \text{#1}} \newcommand{\textop}[1]{\operatorname{#1}} \newcommand{\textopl}[1]{\operatornamewithlimits{#1}} \newcommand{\prt}{\partial} \newcommand{\pderi}[3]{\frac{\prt^{#3}#1}{\prt #2^{#3}}} \newcommand{\deri}[3]{\frac{d^{#3}#1}{d #2^{#3}}} \newcommand{\del}{\vec\nabla} \newcommand{\exval}[1]{\langle #1\rangle} \newcommand{\bra}[1]{\langle #1|} \newcommand{\ket}[1]{|#1\rangle} \newcommand{\ham}{\mathcal{H}} \newcommand{\arr}{\mathfrak{r}} \newcommand{\conv}{\mathop{\scalebox{2}{\raisebox{-0.2ex}{$\ast$}}}} \newcommand{\bsm}{\lt(\begin{smallmatrix}} \newcommand{\esm}{\end{smallmatrix}\rt)} \newcommand{\bpm}{\begin{pmatrix}} \newcommand{\epm}{\end{pmatrix}} \newcommand{\bdet}{\lt|\begin{smallmatrix}} \newcommand{\edet}{\end{smallmatrix}\rt|} \newcommand{\bs}[1]{\boldsymbol{#1}} \newcommand{\uvec}[1]{\bs{\hat{#1}}} \newcommand{\qed}{\hfill$\Box$} $$
Reparameterization trick in GLMs
22 November 2020
Tags:
  • stats
  • mcmc
  • log-normal-poisson

    The “reparameterization trick” is best known for its role in variational autoencoders and other stochastic neural networks, where it is used to backpropagate gradients through stochastic units. Suppose the output of a stochastic unit is a random sample $y\sim p(\mathbf\Theta(\vec x))$. The sampling operation itself cannot be differentiated with respect to the model parameters $\mathbf\Theta$. However, if we can express $y$ as a deterministic function of $\mathbf\Theta$ and a random sample whose distribution does not depend on $\mathbf\Theta$, i.e. $y = f(\mathbf\Theta(\vec x), \epsilon)$ with $\epsilon\sim p(\text{const.})$, then we can compute $\del_{\mathbf\Theta} y$, since $\epsilon$ is held fixed under differentiation.[1]
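    For the Gaussian case described in footnote 1, the trick fits in a few lines of NumPy. This is a minimal illustration rather than code from the original notebook: the hypothetical function `reparam_grad` estimates the gradient of $E[y^2]$, with $y\sim\mathcal{N}(\mu,\sigma)$, by sampling $\epsilon\sim\mathcal{N}(0,1)$ and differentiating $y=\mu+\sigma\epsilon$ by hand. Since $E[y^2]=\mu^2+\sigma^2$, the exact answer is $(2\mu, 2\sigma)$.

```python
import numpy as np


def reparam_grad(mu, sigma, n_samples=1_000_000, seed=0):
    """Monte Carlo estimate of the gradient of E[y**2], y ~ N(mu, sigma),
    with respect to (mu, sigma), using the reparameterization y = mu + sigma * eps."""
    rng = np.random.default_rng(seed)
    # The noise comes from a fixed distribution that does not depend on (mu, sigma).
    eps = rng.standard_normal(n_samples)
    # Deterministic, differentiable transform of the parameters and the noise.
    y = mu + sigma * eps
    # Chain rule through y: d(y^2)/d(mu) = 2y * 1 and d(y^2)/d(sigma) = 2y * eps.
    return np.mean(2 * y), np.mean(2 * y * eps)


# E[y^2] = mu^2 + sigma^2, so the exact gradient here is (3.0, 1.4).
g_mu, g_sigma = reparam_grad(1.5, 0.7)
```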

    A couple of years ago, while working on MCMC methods for sampling the posterior of a GLM, I discovered another, seemingly unrelated use of the reparameterization trick: dramatically improving Markov chain mixing, and thus sampler efficiency. I found little literature documenting this, so I’m writing up my results here. The specific model I worked on was a log-normal Poisson (LNP) regression, but the trick generalizes to any GLM in the exponential family.

    The log-normal Poisson model

    In its most basic form, the LNP model hierarchy is:

    \[\begin{align*} x &\sim \text{Pois}(\lambda)\\ \lambda &\sim \log\mathcal{N}(\mu,\sigma) \end{align*}\]

    (This can be easily turned into a GLM by making $\mu$ a linear function of model covariates, but for our purposes we’ll stick with a single offset parameter $\mu$.)
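    As a sketch of the GLM version, the hypothetical helper below draws counts whose log-rate has mean $\mu_i$ given by row $i$ of a design matrix `X` times coefficients `beta`. The covariate, coefficients and sample size are made up for illustration; with a single column of ones in `X`, it reduces to the single-offset model above.

```python
import numpy as np


def simulate_lnp_glm(X, beta, sigma, rng):
    """Hypothetical helper: draw counts from a log-normal Poisson GLM with
    log-rate_i ~ N(X[i] @ beta, sigma) and x_i ~ Pois(exp(log-rate_i))."""
    mu = X @ beta                                          # linear predictor per row
    log_rate = mu + sigma * rng.standard_normal(len(mu))  # latent log-rates
    return rng.poisson(np.exp(log_rate))


rng = np.random.default_rng(1)
n = 200_000
# Illustrative design: an intercept column plus one standard-normal covariate.
X = np.column_stack([np.ones(n), rng.standard_normal(n)])
beta = np.array([-1.0, 0.5])
x = simulate_lnp_glm(X, beta, 0.9, rng)
```

    With this design, the sample mean of `x` should be close to $\exp(\beta_0 + \beta_1^2/2 + \sigma^2/2)$.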

    The probability mass function of this model is

    \[p(x|\mu,\sigma) = \int_0^\infty d\lambda\, \textop{Pois}(x|\lambda)\log\mathcal{N}(\lambda| \mu, \sigma)\]

    which has no closed-form solution. Although it is possible to obtain a very good approximate MLE of $\mu$ and $\sigma$,[2] we are usually more interested in the posterior distribution of $\mu$ and $\sigma$,

    \[p(\mu,\sigma|\{x\},\mathcal{H}) = \frac{\prod_{i=1}^n p(x_i|\mu,\sigma)p(\mu,\sigma|\mathcal{H})}{\iint_{M,\Sigma} d\mu\,d\sigma\,\prod_{i=1}^n p(x_i|\mu,\sigma)p(\mu,\sigma|\mathcal{H})}.\]

    This has no easy approximation, since the marginal likelihood in the denominator is even less tractable than each likelihood term in the numerator.[3]

    Posterior sampling via MCMC

    MCMC is ideal for sampling this posterior, since the complete-data likelihood (which includes the latent log-rates $\epsilon_i$) can be computed in closed form, and sampling from the complete posterior and then discarding the samples of the latent variables has the same effect as marginalizing them out. Furthermore, since the latent variables are conditionally independent given $\mu$ and $\sigma$, they can be sampled in parallel, making this efficient to implement even for very high-dimensional likelihoods.

    Our complete posterior is proportional to

    \[p(\mu,\sigma,\{\epsilon\}|\{x\},\mathcal{H}) \propto \prod_{i=1}^n\lt[\textop{Pois}(x_i|\exp\epsilon_i)\mathcal{N}(\epsilon_i|\mu, \sigma)\rt] p(\mu,\sigma|\mathcal{H}).\]

    Using conjugate priors,[4] the explicit form is

    \[\begin{align*} p(\mu,\sigma,\{\epsilon\}|\{x\},\mathcal{H}) &\propto \prod_{i=1}^n \lt[\exp(\epsilon_i)^{x_i}\exp(-\exp\epsilon_i) \lt(\sigma^{-1}\exp\lt(-\frac{(\epsilon_i - \mu)^2}{2\sigma^2}\rt)\rt)\rt]\\ &\phantom{ {}\propto{} }\times\mathcal{N}(\mu|\mu_\mu,\sigma_\mu)\textop{gam}^{-1}(\sigma^2|a_\sigma,b_\sigma).\tag{1}\label{post1} \end{align*}\]
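    For reference, here is a sketch of the log of \eqref{post1} in NumPy, dropping additive constants (including the $\log x_i!$ terms). The function name `log_complete_post` and the hyperparameter defaults ($\mu_\mu=0$, $\sigma_\mu=10$, $a_\sigma=b_\sigma=1$) are assumptions for illustration. Differences of this function in $\mu$ or in $\sigma^2$ should match differences of the closed-form full conditionals derived in the next section, which is a cheap way to check that algebra.

```python
import numpy as np


def log_complete_post(mu, sigma2, eps, x, mu_mu=0.0, sigma_mu=10.0,
                      a_sigma=1.0, b_sigma=1.0):
    """Unnormalized log of the complete-data posterior (1), up to an additive
    constant. Hyperparameter defaults are arbitrary illustrative choices."""
    # Poisson terms with rate exp(eps_i): x_i * eps_i - exp(eps_i)
    ll_pois = np.sum(x * eps - np.exp(eps))
    # Normal terms on the latent log-rates: -n/2 log sigma^2 - sum (eps_i - mu)^2 / (2 sigma^2)
    ll_norm = -0.5 * len(x) * np.log(sigma2) - np.sum((eps - mu) ** 2) / (2 * sigma2)
    # Normal prior on mu
    lp_mu = -((mu - mu_mu) ** 2) / (2 * sigma_mu**2)
    # Inverse-gamma prior on sigma^2: (sigma^2)^(-a-1) exp(-b / sigma^2)
    lp_s2 = -(a_sigma + 1) * np.log(sigma2) - b_sigma / sigma2
    return ll_pois + ll_norm + lp_mu + lp_s2
```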

    My first impulse was to apply a Gibbs sampler directly to $\eqref{post1}$ as written since, as we will see in the next section, the full conditionals for $\mu$ and $\sigma$ have closed forms. In theory, this ought to make sampling extremely efficient.

    Closed form full conditionals

    The full conditional on $\mu$ is

    \[\begin{align*} p(\mu|-) &\propto \prod_{i=1}^n\lt[\exp\lt(-\frac{(\epsilon_i - \mu)^2}{2\sigma^2}\rt)\rt]\exp\lt(-\frac{(\mu-\mu_\mu)^2}{2\sigma_\mu^2}\rt)\\ &= \exp\lt(-\frac{1}{2\sigma^2}\sum_{i=1}^n (\epsilon_i - \mu)^2 \rt)\exp\lt(-\frac{(\mu-\mu_\mu)^2}{2\sigma_\mu^2}\rt) \end{align*}\]

    which, after some algebra, is proportional to a normal distribution:

    \[p(\mu|-)=\mathcal{N}\lt[\mu \big| \lt(\frac{\mu_\mu}{\sigma_\mu^2}+\sigma^{-2}\sum_{i=1}^n \epsilon_i\rt)(\sigma_\mu^{-2} + n\sigma^{-2})^{-1},\lt(\sigma_\mu^{-2} + n\sigma^{-2}\rt)^{-1/2}\rt]\]

    which can be easily sampled from.
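    A minimal sketch of this Gibbs step using NumPy's `Generator` API; the function name `sample_mu` is hypothetical. With a diffuse prior ($\sigma_\mu\to\infty$), the draws center on the mean of the $\epsilon_i$ with standard deviation $\sigma/\sqrt{n}$.

```python
import numpy as np


def sample_mu(eps, sigma2, mu_mu, sigma_mu, rng):
    """One Gibbs draw of mu from its normal full conditional p(mu | -)."""
    n = len(eps)
    prec = 1.0 / sigma_mu**2 + n / sigma2                     # posterior precision
    mean = (mu_mu / sigma_mu**2 + eps.sum() / sigma2) / prec  # precision-weighted mean
    return rng.normal(mean, 1.0 / np.sqrt(prec))              # normal takes a std. dev.
```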

    The full conditional on $\sigma^2$ is

    \[\begin{align*} p(\sigma^2|-) &\propto \prod_{i=1}^n \sigma^{-1}\lt[\exp\lt(-\frac{(\epsilon_i - \mu)^2}{2\sigma^2}\rt)\rt] \textop{gam}^{-1}(\sigma^2|a_\sigma,b_\sigma)\\ &\propto (\sigma^2)^{-a_\sigma-1-n/2}\exp\lt[-\frac{1}{\sigma^2}\lt(\frac{1}{2}\sum_{i=1}^n (\epsilon_i - \mu)^2 + b_\sigma\rt)\rt] \end{align*}\]

    which is proportional to an inverse gamma distribution:

    \[p(\sigma^2|-) = \textop{gam}^{-1}\lt(\sigma^2|a_\sigma + \frac{n}{2}, \frac{1}{2}\sum_{i=1}^n (\epsilon_i - \mu)^2 + b_\sigma\rt),\]

    which is again easily sampled from.
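    NumPy's `Generator` has no inverse-gamma method, but if $G\sim\textop{gam}(a, \text{rate}=b)$ then $1/G\sim\textop{gam}^{-1}(a,b)$, and `Generator.gamma` is parameterized by shape and scale $=1/b$. A sketch of the step, with the hypothetical name `sample_sigma2`:

```python
import numpy as np


def sample_sigma2(eps, mu, a_sigma, b_sigma, rng):
    """One Gibbs draw of sigma^2 from its inverse-gamma full conditional."""
    a_post = a_sigma + 0.5 * len(eps)
    b_post = b_sigma + 0.5 * np.sum((eps - mu) ** 2)
    # If G ~ Gamma(shape=a_post, rate=b_post), then 1/G ~ InvGamma(a_post, b_post).
    # NumPy's gamma takes scale = 1/rate.
    return 1.0 / rng.gamma(a_post, 1.0 / b_post)
```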

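    Note that we can also use a normal-inverse-gamma prior on $(\mu,\sigma^2)$, in which case the joint full conditional $p(\mu,\sigma^2|-)$ is also normal-inverse-gamma, and $\mu$ and $\sigma^2$ can be drawn jointly in a single step.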

    Non-closed form conditional

    The full conditional of $\epsilon_i$ does not have a closed form:

    \[p(\epsilon_i|-) \propto \exp(\epsilon_i)^{x_i}\exp(-\exp\epsilon_i)\exp\lt(-\frac{(\epsilon_i - \mu)^2}{2\sigma^2}\rt).\]

    Thus, we must sample it via Metropolis-Hastings instead. Because the log full conditional is concave (its second derivative, $-\exp\epsilon_i - \sigma^{-2}$, is always negative), and its terms above second order contribute little near the maximum, a proposal distribution $q(\epsilon_i^\ast|\epsilon_i)$ built from a quadratic approximation of the log full conditional at its maximum should work well, i.e.

    \[\log q(\epsilon_i^\ast|\epsilon_i) = -\frac{(\epsilon_i^\ast - \hat\mu)^2}{2\hat\sigma^2} + \text{const.}\]

    We use Newton-Raphson iterations to quickly find the maximum and the curvature there, and propose from a t-distribution with location $\hat\mu$ equal to the maximum, scale $\hat\sigma$ given by $\hat\sigma^2 = \lt(\exp\hat\mu + \sigma^{-2}\rt)^{-1}$ (the inverse of the negative curvature of the log full conditional at $\hat\mu$), and degrees of freedom $\nu$, a tunable hyperparameter for adjusting Metropolis acceptance rates. As $\nu\to\infty$, the t-distribution converges to the (normal) quadratic approximation.
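    The Newton-Raphson fit and the t-proposal Metropolis-Hastings step can be sketched as follows, vectorized over all $i$ since the $\epsilon_i$ are conditionally independent. Because the proposal depends only on $x_i$, $\mu$ and $\sigma$ and not on the current $\epsilon_i$, this is an independence sampler. The function names, the 30 Newton iterations and the starting point are illustrative assumptions, not the original implementation. Starting at $\max(\mu, \log\max(x_i,1))$ guarantees a non-positive gradient; since the gradient is concave and decreasing in $\epsilon_i$, Newton's method then moves monotonically toward the mode without overshooting.

```python
import numpy as np


def laplace_fit(x, mu, sigma2, n_iter=30):
    """Newton-Raphson for the mode of g(e) = x*e - exp(e) - (e - mu)^2 / (2*sigma2),
    elementwise over observations. Returns the mode and -1/g''(mode)."""
    # Start where g'(e) <= 0 (e >= mu and exp(e) >= x). Since g' is concave and
    # decreasing, every Newton iterate then stays at or to the right of the mode.
    e = np.maximum(mu, np.log(np.maximum(x, 1.0)))
    for _ in range(n_iter):
        grad = x - np.exp(e) - (e - mu) / sigma2
        hess = -np.exp(e) - 1.0 / sigma2
        e = e - grad / hess
    return e, 1.0 / (np.exp(e) + 1.0 / sigma2)


def mh_update_eps(eps, x, mu, sigma2, nu, rng):
    """One independence Metropolis-Hastings update of every eps_i in parallel,
    with a Student-t proposal located at the mode of each full conditional."""
    loc, var = laplace_fit(x, mu, sigma2)
    scale = np.sqrt(var)
    prop = loc + scale * rng.standard_t(nu, size=eps.shape)

    def log_target(e):  # log p(eps_i | -), up to a constant
        return x * e - np.exp(e) - (e - mu) ** 2 / (2 * sigma2)

    def log_q(e):  # log t-proposal density, up to a constant shared by all e
        return -0.5 * (nu + 1) * np.log1p(((e - loc) / scale) ** 2 / nu)

    log_alpha = log_target(prop) - log_target(eps) + log_q(eps) - log_q(prop)
    accept = np.log(rng.uniform(size=eps.shape)) < log_alpha
    return np.where(accept, prop, eps), accept
```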

    Evaluating sampler efficiency

    Now that we have a (seemingly) sound sampling scheme, let’s evaluate how well it actually works. Since the log-normal Poisson PMF can be approximated accurately via Gauss-Hermite quadrature, we can obtain a good approximation of the posterior numerator:

    \[\begin{align*} p(\mu,\sigma|\{x\},\mathcal{H})&\propto \prod_{i=1}^n\lt[\int_{-\infty}^\infty \!d\lambda_i\,\textop{Pois}(x_i|\exp\lambda_i)\mathcal{N}(\lambda_i|\mu,\sigma)\rt]\, p(\mu,\sigma|\mathcal{H})\tag{2}\label{hermite}\\ &\approx\prod_{i=1}^n \lt[\sum_{j=1}^m w_j f(r_j;x_i,\mu,\sigma)\rt] p(\mu,\sigma|\mathcal{H}), \end{align*}\]

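    where $r_j$ is the $j$th root of the $m$th-order (physicists') Hermite polynomial, $w_j$ is the corresponding quadrature weight, and $f(r;x_i,\mu,\sigma) = \pi^{-1/2}\,\textop{Pois}(x_i|\exp(\mu+\sqrt{2}\sigma r))$ is the integrand in \eqref{hermite} after the substitution $\lambda_i = \mu + \sqrt{2}\sigma r$, which turns the normal density into the Gauss-Hermite weight function $e^{-r^2}$.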
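    The plotting code below calls a function `log_post_num(mu, sigma, x)` whose definition is not shown in the post. One possible sketch, assuming $m=50$ nodes from NumPy's `hermgauss`, arbitrary prior hyperparameters, and the priors of \eqref{post1} (with the inverse gamma evaluated at $\sigma^2$); `lnp_log_pmf` is a hypothetical helper name.

```python
import math
import numpy as np


def lnp_log_pmf(x, mu, sigma, m=50):
    """log p(x | mu, sigma) for the log-normal Poisson model via m-point
    Gauss-Hermite quadrature; broadcasts over x, mu and sigma."""
    r, w = np.polynomial.hermite.hermgauss(m)  # physicists' Hermite roots, weights
    x = np.asarray(x, dtype=float)[..., None]
    # Substituting lambda = mu + sqrt(2)*sigma*r turns N(lambda | mu, sigma) d lambda
    # into exp(-r^2) dr / sqrt(pi), the Gauss-Hermite weight function.
    lam = np.asarray(mu)[..., None] + np.sqrt(2.0) * np.asarray(sigma)[..., None] * r
    log_pois = x * lam - np.exp(lam) - np.vectorize(math.lgamma)(x + 1)
    terms = log_pois + np.log(w) - 0.5 * np.log(np.pi)
    top = terms.max(axis=-1, keepdims=True)  # log-sum-exp over the nodes
    return top[..., 0] + np.log(np.exp(terms - top).sum(axis=-1))


def log_post_num(mu, sigma, x, m=50, mu_mu=0.0, sigma_mu=10.0,
                 a_sigma=1.0, b_sigma=1.0):
    """Log of the quadrature-approximated posterior numerator on arrays of
    (mu, sigma), e.g. from np.meshgrid, up to an additive constant."""
    mu = np.asarray(mu, dtype=float)
    sigma = np.asarray(sigma, dtype=float)
    # Counts take few distinct values, so evaluate each distinct x only once.
    xu, counts = np.unique(x, return_counts=True)
    loglik = lnp_log_pmf(xu, mu[..., None], sigma[..., None], m) @ counts
    # Priors as in (1): normal on mu, inverse gamma (evaluated at sigma^2).
    log_prior = (-((mu - mu_mu) ** 2) / (2 * sigma_mu**2)
                 - (a_sigma + 1) * np.log(sigma**2) - b_sigma / sigma**2)
    return loglik + log_prior
```

    Grouping the data by distinct count keeps the cost proportional to the number of distinct values rather than to $n$, which matters for 50,000 draws on a 100 by 100 grid.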

    If we simulate 50,000 draws from a log-normal Poisson distribution with $\mu=-3$ and $\sigma = 0.9$,

    import numpy as np
    x = np.random.poisson(np.exp(-3 + 0.9*np.random.randn(50000)))
    

    and plot level sets of the posterior numerator,

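    import numpy as np
    import matplotlib.pyplot as plt
    
    # Grid of (mu, sigma) values around the true parameters
    mu_r = np.linspace(-3.15, -2.85, 100)
    sigma_r = np.linspace(0.70, 1.05, 100)
    mu_g, sigma_g = np.meshgrid(mu_r, sigma_r)
    
    # log_post_num: log of the quadrature-approximated posterior numerator
    # (full version in the notebook)
    z = log_post_num(mu_g, sigma_g, x)
    
    # Contours of the posterior numerator, rescaled so that its maximum is 1
    plt.figure(1); plt.clf()
    plt.contour(
      mu_r,
      sigma_r, 
      np.exp(z - np.max(z)), 
      levels = np.linspace(0.01, 1, 10), 
      cmap = "Greys", 
      vmin = -0.2
    )
    # True parameter values
    plt.scatter(-3, 0.9, color = "r", marker = "x")
    plt.xlabel(r"$\mu$"); plt.ylabel(r"$\sigma$")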
    

    we can overlay MCMC steps onto the contour plot as a quick-and-dirty way of seeing whether our sampler is doing a good job. See the notebook for full details.
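    Putting the full conditionals together, a minimal Metropolis-within-Gibbs sampler for \eqref{post1}, in the parameterization written there, might look like the following. It is a self-contained sketch rather than the notebook's code: the name `gibbs_lnp`, the initialization, $\nu=5$, the number of Newton iterations and the hyperparameter defaults are all assumptions.

```python
import numpy as np


def gibbs_lnp(x, n_iter, nu=5.0, mu_mu=0.0, sigma_mu=10.0,
              a_sigma=1.0, b_sigma=1.0, n_newton=30, seed=0):
    """Metropolis-within-Gibbs for the complete posterior (1), using its original
    (mu, sigma^2, eps) parameterization. Returns traces of mu and sigma."""
    rng = np.random.default_rng(seed)
    x = np.asarray(x, dtype=float)
    n = len(x)
    eps = np.log(x + 0.5)                  # crude initial latent log-rates
    mu, sigma2 = eps.mean(), eps.var() + 0.1
    mu_tr, sigma_tr = np.empty(n_iter), np.empty(n_iter)
    for t in range(n_iter):
        # 1. eps | - : Newton-Raphson mode, then independence MH with a t proposal.
        e = np.maximum(mu, np.log(np.maximum(x, 1.0)))
        for _ in range(n_newton):
            e -= (x - np.exp(e) - (e - mu) / sigma2) / (-np.exp(e) - 1.0 / sigma2)
        scale = 1.0 / np.sqrt(np.exp(e) + 1.0 / sigma2)
        prop = e + scale * rng.standard_t(nu, size=n)

        def log_w(z):  # log target minus log proposal, up to constants
            return (x * z - np.exp(z) - (z - mu) ** 2 / (2 * sigma2)
                    + 0.5 * (nu + 1) * np.log1p(((z - e) / scale) ** 2 / nu))

        accept = np.log(rng.uniform(size=n)) < log_w(prop) - log_w(eps)
        eps = np.where(accept, prop, eps)
        # 2. mu | - : normal full conditional.
        prec = 1.0 / sigma_mu**2 + n / sigma2
        mean = (mu_mu / sigma_mu**2 + eps.sum() / sigma2) / prec
        mu = rng.normal(mean, 1.0 / np.sqrt(prec))
        # 3. sigma^2 | - : inverse-gamma full conditional, via 1 / Gamma(shape, scale).
        b_post = b_sigma + 0.5 * np.sum((eps - mu) ** 2)
        sigma2 = 1.0 / rng.gamma(a_sigma + 0.5 * n, 1.0 / b_post)
        mu_tr[t], sigma_tr[t] = mu, np.sqrt(sigma2)
    return mu_tr, sigma_tr
```

    With `x` from the simulation above, something like `mu_tr, sigma_tr = gibbs_lnp(x, 2000)` followed by `plt.plot(mu_tr, sigma_tr, ".-", lw=0.5)` on the contour figure would overlay the chain's steps.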

    1. For example, there is no way to differentiate the sampling operation $y\sim\mathcal{N}(\mu(\vec x), \sigma(\vec x))$ with respect to the model parameters $\mu(\vec x)$ and $\sigma(\vec x)$, but by instead sampling $\epsilon \sim \mathcal{N}(0, 1)$, which is independent of our model parameters, and computing $y = \mu(\vec x) + \epsilon \sigma(\vec x)$, we obtain an equivalent expression for $y$ that is differentiable with respect to our model parameters. As far as I know, this was first described by <cite> in <year> 

    2. Via derivative-based optimization of a Hermite quadrature approximation of this integral. I may write up my notes on this later. 

    3. For this minimal example, a grid search over $\mu,\sigma$ is feasible, but adding any model covariates quickly makes this impossible too. 

    4. Here, we give $\mu$ and $\sigma^2$ independent priors, $\mu\sim \mathcal{N}(\mu_\mu,\sigma_\mu)$ and $\sigma^2\sim \textop{gam}^{-1}(a_\sigma,b_\sigma)$. We could also model them jointly with a normal-inverse-gamma prior.