My basic understanding of HMC
Hamiltonian Monte Carlo (HMC) and its variants are among the most effective sampling algorithms available today, as long as you have access to the gradient of the target density. I learnt it several years ago but soon forgot it. Now I’m trying again, to understand it the way a physicist would – though I’m not a physicist.
Traditional MCMC
An analogy to statistical mechanics
Suppose we want to sample the distribution \(V(q)\) of a variable \(q\) (named in the Hamiltonian way). One thing we learn from statistical physics is that, at equilibrium, a physical system with potential \(U(q)\) settles into \(p(q)\propto e^{-U(q)}\) (assuming the kinetic energy can be neglected and \(\beta=1\)). It follows that the distribution \(V(q)\) is equivalent to the distribution of particles in a physical system with potential \(U(q) = -\log V(q)\) at equilibrium. If this system contains many particles, their distribution shows us our target distribution.
Temporal series, instead of ensembles
– But how do we get these particles? What we want is an ensemble of single particles, in which the probability of finding a particle at state \(q\) is proportional to \(e^{-U(q)}\). Building such an ensemble directly is non-trivial. But recall why ensembles were introduced in the first place: to replace a time average with an ensemble average. Here we can do the opposite: we evolve a single particle and record its trajectory, whose states over a long time should be distributed like the ensemble.
Well, as in statistical mechanics, “detailed balance” can help us determine how a particle evolves. Detailed balance says that, at equilibrium, the transition probability from \(q_0\) to \(q_1\) and that of the reverse move satisfy: \(\begin{equation} p_t(q_0 \rightarrow q_1) p(q_0) = p_t(q_1 \rightarrow q_0) p(q_1) \label{eqn:detbal} \end{equation}\) where \(p_t\) represents the transition probability. Its meaning is precise and clear: at any time, the number of transitions from \(q_1\) to \(q_0\) equals the number from \(q_0\) to \(q_1\), so the overall distribution is preserved. Knowing this, we can view sampling as a chain of transitions: a particle in this system keeps moving, and its history over time plays the role of the ensemble. Now one can construct a transition probability: \(\begin{equation} p_t(q_0 \rightarrow q_1) = \left\{ \begin{aligned} &1 &p(q_0)\leq p(q_1) \\ &\frac{p(q_1)}{p(q_0)}\ \ &p(q_0)>p(q_1) \end{aligned} \right. \label{eqn:tranprob} \end{equation}\) It’s not hard to check that this choice is consistent with equation \eqref{eqn:detbal}: if \(p(q_0)>p(q_1)\), the left-hand side is \(\frac{p(q_1)}{p(q_0)}p(q_0)=p(q_1)\) and the right-hand side is \(1\cdot p(q_1)\).
Then one can imagine that in a physical system, a particle \(q\) goes through a sequence of states: \(\begin{equation} {\rm State}_0 \rightarrow {\rm State}_1 \rightarrow {\rm State}_2\rightarrow ... \rightarrow {\rm State}_n \label{eqn:chain} \end{equation}\) and, for large \(n\), the distribution of these states follows \(V(q)\).
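To make the chain concrete, here is a minimal sketch of a sampler built on the transition rule above. The uniform random-walk proposal, the step size and the standard-normal target (\(U(q)=q^2/2\)) are my own assumptions for illustration; the rule itself only needs the ratio \(p(q_1)/p(q_0)\), which I compute from log-densities to avoid overflow.

```python
import math
import random

def metropolis(log_p, q0, n_steps, step=1.0, seed=0):
    """Random-walk chain for an unnormalised log-density log_p."""
    rng = random.Random(seed)
    q = q0
    chain = []
    for _ in range(n_steps):
        # Symmetric proposal (an assumption): uniform in [q - step, q + step].
        q_new = q + rng.uniform(-step, step)
        # Accept with probability min(1, p(q_new) / p(q)).
        log_ratio = log_p(q_new) - log_p(q)
        if log_ratio >= 0 or rng.random() < math.exp(log_ratio):
            q = q_new
        # A rejected move repeats the current state in the chain.
        chain.append(q)
    return chain

def log_p(q):
    # Example target (assumption): U(q) = q^2 / 2, a standard normal.
    return -0.5 * q * q

chain = metropolis(log_p, q0=0.0, n_steps=50_000, step=2.0)
mean = sum(chain) / len(chain)
var = sum((q - mean) ** 2 for q in chain) / len(chain)
print(mean, var)  # should be close to 0 and 1
```

Note that a rejected proposal still adds the current state to the chain; dropping those repeats would distort the distribution.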
A quick proof that detailed balance leads to \(V(q)\)
Imagine a two-level system with only two possible states, and run the process \(n\) times. Suppose \(q_0\) has the lower potential energy (so \(p(q_0)>p(q_1)\)) and we start from it. If the move \(q_0 \rightarrow q_1\) is attempted \(x\) times, it succeeds \(x\cdot p(q_1)/p(q_0)\) times, and every success is followed by a move back, which is always accepted. So among all the samples we have \(x\cdot p(q_1)/p(q_0)\) at \(q_1\) and \(x\) at \(q_0\), which is exactly the ratio \(p(q_1)/p(q_0)\) we wanted.
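The two-level argument is easy to check numerically. Below is a small sketch, assuming example probabilities \(p(q_0)=0.8\) and \(p(q_1)=0.2\) (my choice) and a proposal that always suggests the other state.

```python
import random

def two_level_chain(p0, p1, n_steps, seed=0):
    """Chain on states {0, 1}; the proposal is always 'the other state'."""
    rng = random.Random(seed)
    p = (p0, p1)
    state = 0  # start from the more probable (lower-energy) state
    counts = [0, 0]
    for _ in range(n_steps):
        other = 1 - state
        # Accept with probability min(1, p(other) / p(state)).
        if rng.random() < min(1.0, p[other] / p[state]):
            state = other
        counts[state] += 1
    return counts

counts = two_level_chain(p0=0.8, p1=0.2, n_steps=200_000)
ratio = counts[1] / counts[0]
print(ratio)  # should be close to p1 / p0 = 0.25
```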
– But in a real problem we have infinitely many possible states! Does this relation still hold? Yes. Suppose we run a long chain as in \eqref{eqn:chain} and it has reached a steady state, so the distribution of particles \(\pi(q)\) no longer changes over time. Let’s cut the chain into two parts: the first is \({\rm State}_0,{\rm State}_2,{\rm State}_4,...\) and the second is \({\rm State}_1,{\rm State}_3,{\rm State}_5,...\). Since the chain is long enough, each half of the samples should also be in the steady state, so the distributions \(\pi_1(q)\) and \(\pi_2(q)\) of the two parts are the same, and both equal the total distribution. More interestingly, the second part can be viewed as the result of applying one transition to the first part. So we can interpret the two parts as follows: first we have a sample (group 1), then we let every particle in it attempt a transition (successful or not), and we get group 2.
Now pick an arbitrary position \(q_1\). In the sample it appears \(n_1 = n*\pi(q_1)\) times. The total number of successful transitions from \(q_1\) to all other places is:
\[\begin{equation} n_{1,\rm out} = \int n*\pi(q_1)Q(q\vert q_1)p(q_1\rightarrow q) \ dq \end{equation}\]where \(Q(q\vert q_1)\) is the probability of proposing \(q\) as the next destination when the particle is at \(q_1\), and \(p(q_1\rightarrow q)\) is the probability of accepting that move. Similarly, the number of successful transitions from all other places to \(q_1\) is: \(\begin{equation} n_{1,\rm in} = \int n*\pi(q)Q(q_1\vert q)p(q\rightarrow q_1)\ dq \end{equation}\)
As discussed above, a steady state means \(n_{1,\rm out} = n_{1,\rm in}\) at every location, thus we have:
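\[\begin{equation} \int n*\pi(q_1)Q(q\vert q_1)p(q_1\rightarrow q) \ dq = \int n*\pi(q)Q(q_1\vert q)p(q\rightarrow q_1)\ dq \end{equation}\]Equal integrals do not by themselves force equal integrands, but one sufficient way to make the two sides agree is to require the integrands to match for every pair \((q_1, q)\), which is exactly detailed balance: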
\[\begin{equation} \pi(q_1)Q(q\vert q_1)p(q_1\rightarrow q) = \pi(q)Q(q_1\vert q)p(q\rightarrow q_1) \end{equation}\]If we use a proposal function \(Q\) that is symmetric, i.e. that satisfies \(Q(q\vert q_1)=Q(q_1\vert q)\), the above equation becomes:
\[\begin{equation} \pi(q_1)p(q_1\rightarrow q) = \pi(q)p(q\rightarrow q_1) \end{equation}\]Applying equation \eqref{eqn:tranprob}, we get:
\[\begin{equation} \pi(q_1)/p(q_1) = \pi(q)/p(q) \end{equation}\]So \(\pi(q)/p(q)\) is the same constant everywhere, and since \(\pi(q)\) and \(p(q)\) both integrate to 1, \(\pi(q) = p(q)\) must hold. We already knew that the desired distribution is obtained once detailed balance is achieved, and we have just shown that if we construct the chain to follow the detailed balance rule, we get the target distribution. I’ve forgotten how to prove the equivalence between the equilibrium state and detailed balance, but I guess it should be similar.
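For a finite number of states the whole argument can be checked by hand, or by a few lines of code. The sketch below assumes a four-state target and a uniform proposal over the other states (both are my choices for illustration), builds the transition matrix from the acceptance rule, and checks two things: detailed balance holds pair by pair, and one step of the chain leaves \(p\) unchanged.

```python
def metropolis_matrix(p):
    """Transition matrix on len(p) states with a uniform (symmetric) proposal."""
    n = len(p)
    P = [[0.0] * n for _ in range(n)]
    for i in range(n):
        for j in range(n):
            if i != j:
                # Propose j with probability 1/(n-1), accept with min(1, p_j/p_i).
                P[i][j] = min(1.0, p[j] / p[i]) / (n - 1)
        # Rejected proposals stay put (the diagonal is still 0 when summed).
        P[i][i] = 1.0 - sum(P[i])
    return P

p = [0.1, 0.2, 0.3, 0.4]  # example target (assumption)
P = metropolis_matrix(p)
n = len(p)

# Detailed balance: p_i P_ij == p_j P_ji for every pair of states.
balanced = all(abs(p[i] * P[i][j] - p[j] * P[j][i]) < 1e-12
               for i in range(n) for j in range(n))

# Stationarity: pushing p through one step of the chain gives p back.
p_next = [sum(p[i] * P[i][j] for i in range(n)) for j in range(n)]
print(balanced, p_next)
```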
So far so good? But traditional MCMC has some problems, and next I’ll talk about how Hamiltonian Monte Carlo deals with them.
Hamiltonian Monte Carlo
The main drawback of vanilla MCMC is its low acceptance ratio, especially in high-dimensional spaces. Let’s consider an example: imagine your proposal kernel \(Q\) proposes the next position uniformly at random inside an \(n\)-dimensional sphere of radius \(R\) centred on the current position, and the true probability distribution is also a sphere of radius \(R\): \(p(q) = {\rm const}\) inside the sphere and 0 otherwise. Now suppose the starting point lies on the edge of this sphere.
Then in the 1-D case, the probability of proposing a location inside this constant-probability sphere is simply 1/2. In the 2-D case, however, this probability shrinks to \(\frac{2}{3}-\frac{\sqrt{3}}{2\pi}\approx 0.39\); in the 3-D case, it shrinks to \(5/16\approx 0.31\), and it keeps falling as the dimension grows. Thus, if you’re working in a high-dimensional parameter space (for example, \(\Lambda\)CDM has 6 parameters), this method can be very inefficient and sometimes even fails: the chain gets trapped in a small region and never sees the whole distribution.
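These fractions can be estimated by brute force. The sketch below sets \(R=1\), puts the current point at \((1,0,...,0)\) on the edge of the target ball, draws uniform proposals from the unit ball around it, and counts how many land inside the target. The sample size and the list of dimensions are arbitrary choices of mine.

```python
import math
import random

def inside_fraction(dim, n_samples=100_000, seed=0):
    """Fraction of uniform proposals, drawn from a unit ball centred on the
    edge of a unit-ball target, that land inside the target."""
    rng = random.Random(seed)
    hits = 0
    for _ in range(n_samples):
        # Uniform point in the unit ball: Gaussian direction, radius u^(1/dim).
        g = [rng.gauss(0.0, 1.0) for _ in range(dim)]
        norm = math.sqrt(sum(x * x for x in g))
        r = rng.random() ** (1.0 / dim)
        point = [r * x / norm for x in g]
        # Shift so the proposal ball is centred at (1, 0, ..., 0).
        point[0] += 1.0
        # The target is the unit ball centred at the origin.
        if sum(x * x for x in point) <= 1.0:
            hits += 1
    return hits / n_samples

for dim in (1, 2, 3, 6, 20):
    print(dim, inside_fraction(dim))
```

The estimates for 1, 2 and 3 dimensions should sit near 1/2, 0.39 and 5/16, and the fraction keeps dropping as the dimension grows.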
Now we’re clear about the enemy: a low acceptance ratio, which makes the chain explore slowly and can bias the result of any finite run. Someone had a crazy idea about how to solve the problem. Going back to equation \eqref{eqn:tranprob}, the low ratio comes from the low probability of finding a good target with a high \(p(q_1)/p(q_0)\). How can we find a good target every time? How about making a chain with a 100% acceptance probability?
Then the problem becomes: is there any way to keep the energy in the exponent, and hence the probability, the same after proposing a new location? I bet you can think of one theory: Hamiltonian dynamics. So far we’ve been neglecting the kinetic energy since we didn’t want an extra, useless quantity, but soon we’ll see it’s not useless; it’s actually extremely helpful.
If we introduce a momentum \(p\) associated with \(q\), we can write down the full Hamiltonian as \(H(p,q) = T + U = p^2/2m + U(q)\). In this case, the equilibrium state follows \(p(p,q)\propto e^{-H(p,q)}\). Now the interesting thing happens: \(e^{-H(p,q)} = e^{-T(p)}e^{-U(q)}\), so if we integrate over all \(p\) to get the marginal distribution \(p(q)\), it is still \(p(q)\propto e^{-U(q)}\). In other words, the variable \(p\) is statistically independent of \(q\). So introducing \(p\) didn’t change anything about \(q\); we just need to drop all the \(p\)s when summarizing our result.
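Written out for a one-dimensional \(p\) (with \(\beta=1\) as before), the marginalization is just a Gaussian integral that contributes a constant factor:

\[\begin{equation} p(q) = \int_{-\infty}^{\infty} p(p,q)\ dp \propto e^{-U(q)}\int_{-\infty}^{\infty} e^{-p^2/2m}\ dp = \sqrt{2\pi m}\ e^{-U(q)} \propto e^{-U(q)} \end{equation}\]

The factor \(\sqrt{2\pi m}\) does not depend on \(q\), so it disappears into the normalization.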
Now how does this new variable change our life? Hamiltonian physics gives us a way to keep \(H\) fixed while we look for new locations: along a trajectory that follows Hamilton’s equations, \(H\) is automatically conserved. More interestingly, we know the distribution of \(p\): in the probability density, the relevant factor is \(e^{-p^2/2m}\), which is exactly a Gaussian with variance \(m\), where \(m\) is an adjustable parameter. Thus we know very well how to sample \(p\). But how is this related to sampling \(q\)? One guy had a brilliant idea one day: instead of directly sampling a new \((p,q)\) pair, we could sample only a new \(p\) from the Gaussian distribution while holding \(q\) fixed, and then evolve \((p,q)\) with Hamiltonian dynamics for a constant “time”. This looks nice, but how do we justify that it follows detailed balance? Here is a simple proof:
Suppose our current state is \(q_0\) and the target state is \(q_1\). Eqn. \eqref{eqn:detbal} reads \(\begin{equation} p_t(q_0 \rightarrow q_1) p(q_0) = p_t(q_1 \rightarrow q_0) p(q_1) \end{equation}\)
Suppose there is one single \(p_0\) such that \((p_0,q_0)\) goes to \((p_1,q_1)\) under the Hamiltonian evolution. Then the LHS of eqn. \eqref{eqn:detbal} becomes: \(\begin{equation} p_t(q_0 \rightarrow q_1)p(q_0) = p(p_0)p(q_0) \propto e^{-p_0^2/2m-U(q_0)} \end{equation}\) We know the Hamiltonian evolution is reversible, i.e. if \((p_0,q_0)\) goes to \((p_1,q_1)\), then \((-p_1,q_1)\) goes to \((-p_0,q_0)\). Since the Gaussian is symmetric, \(p(-p_1)=p(p_1)\), and the RHS becomes: \(\begin{equation} p_t(q_1 \rightarrow q_0)p(q_1) = p(-p_1)p(q_1) \propto e^{-p_1^2/2m-U(q_1)} \end{equation}\)
By conservation of the Hamiltonian, we simply have \(e^{-p_1^2/2m-U(q_1)} = e^{-p_0^2/2m-U(q_0)}\), so detailed balance still holds.
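Both ingredients of this argument, reversibility and conservation of \(H\), can be checked numerically. The proof assumes exact Hamiltonian dynamics; the sketch below swaps in the leapfrog integrator (my choice, a common numerical approximation) with an example potential \(U(q)=q^2/2\). Leapfrog is reversible up to floating-point rounding, but it conserves \(H\) only approximately.

```python
def leapfrog(q, p, grad_U, eps, n_steps, m=1.0):
    """Leapfrog integration of dq/dt = p/m, dp/dt = -dU/dq."""
    p = p - 0.5 * eps * grad_U(q)        # opening half step in momentum
    for i in range(n_steps):
        q = q + eps * p / m              # full step in position
        if i < n_steps - 1:
            p = p - eps * grad_U(q)      # full step in momentum
    p = p - 0.5 * eps * grad_U(q)        # closing half step in momentum
    return q, p

def U(q):
    return 0.5 * q * q                   # example potential (assumption)

def grad_U(q):
    return q

def H(q, p, m=1.0):
    return p * p / (2.0 * m) + U(q)

q0, p0 = 1.0, 0.5
q1, p1 = leapfrog(q0, p0, grad_U, eps=0.01, n_steps=100)

# Flip the momentum and integrate again: we should come back to (q0, -p0).
q_back, p_back = leapfrog(q1, -p1, grad_U, eps=0.01, n_steps=100)

print(abs(q_back - q0), abs(p_back + p0))  # both tiny: reversible
print(abs(H(q1, p1) - H(q0, p0)))          # small but not exactly zero
```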
If more than one \(p_0\) leads to \(q_1\), say \(p_{0,1}, p_{0,2}, ..., p_{0,n}\), we can simply rewrite the LHS as: \(\begin{equation} p_t(q_0 \rightarrow q_1)p(q_0) = \left(p(p_{0,1})+p(p_{0,2})+p(p_{0,3})+...+p(p_{0,n})\right)p(q_0) \end{equation}\)
For every single \(p(p_{0,i})\), detailed balance holds by the derivation above, so the sum over all the \(p_{0,i}\) must also satisfy detailed balance.
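Putting the pieces together, here is a minimal one-dimensional sketch of the procedure: draw a fresh Gaussian momentum with \(q\) held fixed, evolve \((p,q)\) for a fixed “time”, and keep the new \(q\). Several things here are my assumptions rather than part of the argument above: the leapfrog integrator stands in for exact dynamics, the step size and number of steps are arbitrary, and the target is \(U(q)=q^2/2\). Because leapfrog conserves \(H\) only approximately, the sketch also adds an accept/reject step on the change in \(H\); with exact dynamics that step would always accept.

```python
import math
import random

def leapfrog(q, p, grad_U, eps, n_steps, m=1.0):
    """Leapfrog integration of dq/dt = p/m, dp/dt = -dU/dq."""
    p = p - 0.5 * eps * grad_U(q)
    for i in range(n_steps):
        q = q + eps * p / m
        if i < n_steps - 1:
            p = p - eps * grad_U(q)
    p = p - 0.5 * eps * grad_U(q)
    return q, p

def hmc(U, grad_U, q0, n_samples, eps=0.1, n_steps=20, m=1.0, seed=0):
    rng = random.Random(seed)
    q = q0
    chain = []
    accepted = 0
    for _ in range(n_samples):
        # Fresh momentum from exp(-p^2 / 2m), i.e. a Gaussian with variance m.
        p = rng.gauss(0.0, math.sqrt(m))
        q_new, p_new = leapfrog(q, p, grad_U, eps, n_steps, m)
        h_old = p * p / (2.0 * m) + U(q)
        h_new = p_new * p_new / (2.0 * m) + U(q_new)
        # Correct for the integrator's energy error: accept with
        # probability min(1, exp(h_old - h_new)).
        if rng.random() < math.exp(min(0.0, h_old - h_new)):
            q = q_new
            accepted += 1
        chain.append(q)  # only q is kept; the momentum is dropped
    return chain, accepted / n_samples

def U(q):
    return 0.5 * q * q  # example target (assumption): standard normal

def grad_U(q):
    return q

chain, accept_rate = hmc(U, grad_U, q0=0.0, n_samples=10_000)
mean = sum(chain) / len(chain)
var = sum((q - mean) ** 2 for q in chain) / len(chain)
print(accept_rate, mean, var)  # acceptance near 1, mean near 0, variance near 1
```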
That finishes the explanation of HMC, on two main points. Why is it faster? Because it proposes better new states. Why does it follow detailed balance? Because Hamiltonian dynamics is reversible. So this is a much better method; in practice it can even sample parameter spaces with ~10k dimensions. But what’s the price?