Bayesian inference primer: variational inference
We’ve seen the major drawbacks of the direct sampling, local curvature, and random walk approaches to Bayesian inference. There are a huge number of hyperparameters to tune and no guarantees of correctness in any real-world application. Worse, these methods are liable to fail completely, or take forever to compute. What if we had an approach with some reasonable guarantees on correctness that was fast to boot? Enter our heavyweight, Variational Inference. But watch out: he has a big chip on his shoulder, so don’t think the competition is over just yet.
Let’s think again about our likelihood function and what makes it difficult. It could take any form, with massive peaks that could appear anywhere and might be infinitely steep; all the variables could be jumbled together, or the whole thing might require simulations to compute. If only we could make our lives easier. Well, what if we just assumed it was easier, and worked from there?
Like the local curvature methods, variational inference is an approximate solution that simply has better guarantees. Variational inference takes a thorny sampling problem and converts it into an optimization problem. In short, variational inference says “what if we approximated the likelihood function by some other function with well-behaved parameters, and optimized those instead?”
After all, we can already fit curves with optimization, such as finding the best width and height of a multivariate Gaussian to fit a set of data (in Python, a common approach is the leastsq function in scipy, which minimizes the squared residuals). And there are proofs showing that if we minimize a measure of distance between our approximation and the true distribution, one with a horrible name (the Kullback–Leibler divergence), we are equivalently maximizing a lower bound on the evidence. However, what form should the approximation function take? Gaussians are a good bet, but we can actually use anything we want.
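To see why that claim holds, here is the standard identity behind it (a quick sketch; notation: x is the data, p(θ | x) is the true posterior we’re chasing, and q(θ) is our well-behaved approximation):

```latex
\log p(x)
  \;=\;
  \underbrace{\mathbb{E}_{q(\theta)}\!\left[\log p(x,\theta) - \log q(\theta)\right]}_{\text{ELBO}(q)}
  \;+\;
  \underbrace{\mathrm{KL}\!\left(q(\theta)\,\middle\|\,p(\theta \mid x)\right)}_{\ge\, 0}
```

The KL term is never negative, so the first term (the “evidence lower bound”, or ELBO) really is a lower bound on the log evidence. And because log p(x) is a fixed number, pushing the ELBO up with an optimizer is exactly the same as pushing the KL divergence down.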
The most common version of variational inference that you’ll encounter is called the Mean Field Approximation, and it has a genius formulation: we use a one-dimensional function (say, a one-dimensional Gaussian) for each parameter, and multiply them together to produce the approximation. This means we assume that each dimension of our parameter vector θ contributes independently, with no correlation in how the parameters affect the likelihood. In the Gaussian case, you can think of it as a Gaussian with a diagonal covariance matrix, that is, one whose contours can’t tilt.
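Written out as a quick sketch (with θ = (θ₁, …, θ_d) and, for the Gaussian choice, per-dimension pseudo-parameters mᵢ and sᵢ):

```latex
q(\theta) \;=\; \prod_{i=1}^{d} q_i(\theta_i),
\qquad \text{e.g.} \quad
q_i(\theta_i) \;=\; \mathcal{N}\!\left(\theta_i \mid m_i, s_i^2\right)
```

With the Gaussian choice, the pseudo-parameters to optimize are just the d means and d scales, rather than the full d × d covariance matrix a correlated Gaussian would need.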
Just like with the k-means clustering algorithm, we hold all of the factors fixed except one, optimize the pseudo-parameters of that one-dimensional approximation function, move on to the next parameter, and iterate. Easy-peasy! (Note: we call them pseudo-parameters to clarify that they are not the same parameters as those of the underlying, more complex function we are approximating.) Even better, it runs quickly, and in many cases it is guaranteed to converge to some solution (as opposed to just bailing like the other approaches).
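To make the recipe tangible, here is a minimal sketch of that coordinate-ascent loop on a textbook toy problem (the univariate Gaussian example from Chapter 10 of Bishop’s Pattern Recognition and Machine Learning): data drawn from a Gaussian with unknown mean μ and precision τ, approximated by q(μ)·q(τ) with a Gaussian factor for μ and a Gamma factor for τ. The priors, initial values, and variable names are choices made for this sketch, not anything prescribed by the method.

```python
import numpy as np

rng = np.random.default_rng(0)
x = rng.normal(loc=2.0, scale=0.5, size=200)   # synthetic data
N, x_bar = len(x), x.mean()

# Priors: mu ~ Normal(mu0, 1/(lam0 * tau)), tau ~ Gamma(a0, b0)
mu0, lam0, a0, b0 = 0.0, 1.0, 1.0, 1.0

# Pseudo-parameters of the two factors
mu_n, lam_n = x_bar, 1.0   # q(mu)  = Normal(mu_n, 1/lam_n)
a_n, b_n = a0, b0          # q(tau) = Gamma(a_n, b_n)

for _ in range(50):
    # --- Update q(mu), holding q(tau) fixed ---
    e_tau = a_n / b_n                              # E_q[tau]
    mu_n = (lam0 * mu0 + N * x_bar) / (lam0 + N)
    lam_n = (lam0 + N) * e_tau

    # --- Update q(tau), holding q(mu) fixed ---
    a_n = a0 + (N + 1) / 2
    e_sq = np.sum((x - mu_n) ** 2) + N / lam_n     # E_q[sum_i (x_i - mu)^2]
    e_prior = (mu_n - mu0) ** 2 + 1 / lam_n        # E_q[(mu - mu0)^2]
    b_n = b0 + 0.5 * (e_sq + lam0 * e_prior)

print(f"q(mu):  mean={mu_n:.3f}, sd={lam_n ** -0.5:.3f}")
print(f"q(tau): mean={a_n / b_n:.3f} (true precision = {1 / 0.5 ** 2:.1f})")
```

Each pass through the loop optimizes one factor exactly while the other is held fixed, which is why the ELBO never decreases from one iteration to the next.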
For this approach, the practitioner can use:
PyMC3 has an experimental version of Variational Inference, but it hasn’t been vetted yet.
Pyro is built around variational inference, which is why it is the library of choice at Uber. It also has excellent tutorials on how to use it to solve your problem. This is the one you should use; a short sketch of what it looks like follows this list.
TensorFlow Probability also provides methods for this type of inference.
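As promised above, here is a minimal sketch of what mean-field variational inference looks like in Pyro (assuming Pyro ≥ 1.x, installed as pyro-ppl). The toy model, which estimates the mean and scale of some Gaussian observations, is just an illustration, not something taken from Pyro’s tutorials.

```python
import torch
import pyro
import pyro.distributions as dist
from pyro.infer import SVI, Trace_ELBO
from pyro.infer.autoguide import AutoDiagonalNormal
from pyro.optim import Adam

data = torch.randn(200) * 0.5 + 2.0   # synthetic observations

def model(x):
    # Priors on the unknown mean and scale of the observations
    mu = pyro.sample("mu", dist.Normal(0.0, 10.0))
    sigma = pyro.sample("sigma", dist.LogNormal(0.0, 1.0))
    with pyro.plate("data", len(x)):
        pyro.sample("obs", dist.Normal(mu, sigma), obs=x)

# The guide is the approximating family; AutoDiagonalNormal is exactly the
# mean-field assumption: an independent Gaussian factor per latent variable.
guide = AutoDiagonalNormal(model)

# SVI maximizes the ELBO by adjusting the guide's pseudo-parameters
# (the per-variable locations and scales).
svi = SVI(model, guide, Adam({"lr": 0.01}), loss=Trace_ELBO())
for step in range(2000):
    svi.step(data)

print(guide.median())   # approximate posterior medians for mu and sigma
```

Note that Pyro optimizes all of the pseudo-parameters at once with stochastic gradients on the ELBO rather than the one-factor-at-a-time coordinate updates sketched earlier, but the approximating family is the same mean-field Gaussian.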
A word of warning: if you search the literature on variational Bayesian inference, you’ll find a cornucopia of derived solutions for specific problems. These articles will make your head spin, with page after page of proofs. What they’re doing is deriving the optimization equations for a specific set of approximation functions. You’ll probably never need to implement or master them, and they are a form of math-flexing. Use the general-purpose solutions you’ll find in libraries like Pyro instead.