It’s Only Natural: An Excessively Deep Dive Into Natural Gradient Optimization

The premise behind a small learning rate is that a single local estimate of the gradient may only be valid in a small region around the point where it was computed.

But, parameters can exist on different scales, and can have different degrees of impact on your learned conditional distribution.

And, this degree of impact can fluctuate over the course of training.

From this perspective, it doesn’t look like defining a safety bubble in terms of a fixed, global radius in Euclidean parameter space is a particularly sensible or meaningful thing to do.

Image Credit: Science Magazine, because I don’t have the software to make a cool image like this, and damnit the subtleties of gradient calculation techniques are not an easy thing to visualize.

A counter-proposal, implicitly made by proponents of natural gradient, is that instead of defining our safety window in terms of distance in parameter space, we should define it in terms of distance in distribution space.

So, instead of “I’ll follow my current gradient, subject to keeping the parameter vector within epsilon distance of the current vector,” you’d instead say “I’ll follow my current gradient, subject to keeping the distribution my model is predicting within epsilon distance of the distribution it was previously predicting”.

The notion here is that distances between distributions are invariant to any scaling, shifting, or general re-parameterization.

For example, the same Gaussian can be parameterized using either a variance parameter or a precision parameter (1/variance). If you measured in parameter space, the distance between the same two distributions would differ depending on whether they were parameterized using variance or precision.

But if you defined a distance directly in terms of the probabilities each distribution assigns, it would come out the same under either parameterization.
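Here is a minimal sketch of that point. It assumes two one-dimensional Gaussians with the same mean.

- The Euclidean distance between their parameters depends on whether we store variance or precision.
- The KL divergence (discussed in the next section) is identical under both parameterizations, because it is computed from the distributions themselves.

The function names are purely illustrative.

```python
import math

def gaussian_kl_from_variance(mu_p, var_p, mu_q, var_q):
    """KL(P || Q) for 1-D Gaussians P = N(mu_p, var_p) and Q = N(mu_q, var_q)."""
    return 0.5 * math.log(var_q / var_p) + (var_p + (mu_p - mu_q) ** 2) / (2 * var_q) - 0.5

def gaussian_kl_from_precision(mu_p, prec_p, mu_q, prec_q):
    """Same divergence, but the caller describes each Gaussian by precision = 1 / variance."""
    return gaussian_kl_from_variance(mu_p, 1.0 / prec_p, mu_q, 1.0 / prec_q)

# P = N(0, 1) and Q = N(0, 4), described two different ways.
var_p, var_q = 1.0, 4.0
prec_p, prec_q = 1.0 / var_p, 1.0 / var_q

param_dist_variance = abs(var_q - var_p)     # 3.0
param_dist_precision = abs(prec_q - prec_p)  # 0.75

kl_variance = gaussian_kl_from_variance(0.0, var_p, 0.0, var_q)
kl_precision = gaussian_kl_from_precision(0.0, prec_p, 0.0, prec_q)

print(param_dist_variance, param_dist_precision)  # two different "distances" in parameter space
print(kl_variance, kl_precision)                  # one and the same divergence
```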

The rest of this post will try to build up a stronger, more intuitive understanding of an approach called Natural Gradient learning, a conceptually elegant idea that seeks to correct this arbitrariness of scaling in parameter space.

I’ll dive into how it works, how to build bridges between the different mathematical ideas that make it up, and ultimately discuss if and where it’s actually useful.

But, first: what does it mean to calculate distances between distributions?

Licensed to KL

The KL Divergence, or, more properly, the Kullback-Leibler divergence, is not technically a distance metric between distributions (mathematicians are picky about what gets to be called a metric, or a proper distance), but it’s a pretty close approximation of that idea.

Mathematically, it’s calculated as the expected value of the log of the ratio between the two distributions’ probabilities (equivalently, the difference in their log probabilities), taken over values of x sampled from one distribution or the other.

The fact that the expectation is taken over one distribution rather than the other makes it an asymmetric measure: in general, KL(P||Q) != KL(Q||P).
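To make the definition and the asymmetry concrete, here is a small sketch for discrete distributions, such as two softmax outputs over the same classes. It uses KL(P||Q) = Σ_x P(x) log(P(x)/Q(x)). The function name and the example distributions are illustrative.

```python
import math

def kl_divergence(p, q):
    """KL(P || Q) = sum_x P(x) * log(P(x) / Q(x)) for discrete distributions.

    The expectation is weighted by P, which is what makes the measure asymmetric.
    Terms with P(x) == 0 contribute nothing; Q(x) must be > 0 wherever P(x) > 0.
    """
    return sum(pi * math.log(pi / qi) for pi, qi in zip(p, q) if pi > 0)

p = [0.5, 0.5]
q = [0.9, 0.1]

print(kl_divergence(p, q))  # roughly 0.511
print(kl_divergence(q, p))  # roughly 0.368, so KL(P||Q) != KL(Q||P)
print(kl_divergence(p, p))  # 0.0: identical distributions
```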

But, in many other ways, KL divergence maps to our notion of what a probability distance should look like. It is computed directly from the probability density functions themselves, that is, from differences in (log) density value at a bunch of points over which the distribution is defined.

This has a very practical aspect to it, whereby distributions are seen as more different when they have farther-apart answers to the question of “what’s the probability at this X” for a broad series of X.

In the context of Natural Gradient, KL divergence is deployed as a way of measuring the change in the output distribution our model is predicting.

If we’re solving a multi-way classification problem, then the output of our model will be a softmax. This can be seen as a categorical (single-draw multinomial) distribution, with a different probability placed on each class.

When we talk about the conditional probability function defined by our current parameter values, this is the probability distribution we’re talking about.

If we use a KL divergence as a way of scaling our gradient steps, that means that we see two parameter configurations as “farther apart” in this space if they would induce predicted class distributions that are very different, in terms of a KL divergence, for a given input set of features.
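A small sketch of that idea, assuming a hypothetical binary logistic model p(y=1|x) = sigmoid(w·x) and a single input whose two features live on very different scales. Nudging either weight by the same Euclidean amount moves the predicted distribution by wildly different KL amounts.

```python
import math

def sigmoid(t):
    return 1.0 / (1.0 + math.exp(-t))

def bernoulli_kl(p, q):
    """KL(Bernoulli(p) || Bernoulli(q))."""
    return p * math.log(p / q) + (1 - p) * math.log((1 - p) / (1 - q))

def predict(w, x):
    """Predicted probability of class 1 under the hypothetical logistic model."""
    return sigmoid(sum(wi * xi for wi, xi in zip(w, x)))

x = [10.0, 0.1]   # one large-scale feature, one small-scale feature
w = [0.0, 0.0]    # current parameters; the model predicts p = 0.5
step = 0.5        # identical Euclidean step size for both moves

p_old = predict(w, x)
p_move_w0 = predict([w[0] + step, w[1]], x)  # logit shifts by 5.0
p_move_w1 = predict([w[0], w[1] + step], x)  # logit shifts by 0.05

kl_w0 = bernoulli_kl(p_old, p_move_w0)
kl_w1 = bernoulli_kl(p_old, p_move_w1)
print(kl_w0, kl_w1)  # same parameter distance, very different distribution distance
```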

The Fisher Thing

So far, we’ve discussed why scaling the distance of our update step in parameter space is unsatisfyingly arbitrary. We’ve also suggested a less arbitrary alternative: scaling our steps such that they only go, at maximum, a certain distance in terms of KL Divergence from the class distribution our model had previously been predicting.

For me, the most difficult part of understanding Natural Gradient was this next part: the connection between KL Divergence and the Fisher Information Matrix.

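Starting with the ending of the story, Natural Gradient is defined like this:

$$\tilde{\nabla}_\theta \mathcal{L}(\theta) \overset{\text{def}}{=} \mathbb{E}_{z}\left[\nabla_\theta \log p_\theta(z)\, \nabla_\theta \log p_\theta(z)^{\top}\right]^{-1} \nabla_\theta \mathcal{L}(\theta)$$

The “def” over the equals sign means that what follows on the right is the definition of the symbol on the left.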

The right hand term is composed of two parts.

First, there’s the gradient of your loss function with respect to parameters (this is the same gradient that’d be used in a more normal gradient descent step).

The “natural” bit comes from the second component: the expected value, taken over z, of the squared gradient of the log probability function. More precisely, this is the outer product of that gradient with itself, which produces a matrix.

We take that whole object, which is referred to as the Fisher Information Matrix, and multiply our loss gradient by its inverse.
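Here is a minimal sketch of that recipe for a hypothetical two-parameter logistic regression (not the author’s code). It assumes, as discussed further down, that the expectation over z uses the model’s own predicted class probabilities. The loss gradient is then multiplied by the inverse of the resulting Fisher matrix, here via a closed-form 2x2 solve.

```python
import math

def sigmoid(t):
    return 1.0 / (1.0 + math.exp(-t))

def dot(a, b):
    return sum(ai * bi for ai, bi in zip(a, b))

def loss_gradient(w, xs, ys):
    """Ordinary gradient of the mean binary cross-entropy loss with respect to w."""
    g = [0.0, 0.0]
    for x, y in zip(xs, ys):
        p = sigmoid(dot(w, x))
        for j in range(2):
            g[j] += (p - y) * x[j] / len(xs)
    return g

def fisher_matrix(w, xs):
    """Average over inputs of E_{z ~ p_w(.|x)}[grad log p(z|x) grad log p(z|x)^T]."""
    F = [[0.0, 0.0], [0.0, 0.0]]
    for x in xs:
        p = sigmoid(dot(w, x))
        for z, prob_z in ((1, p), (0, 1.0 - p)):  # weight each class by the model's own belief
            score = [(z - p) * xj for xj in x]     # grad_w log p(z | x)
            for i in range(2):
                for j in range(2):
                    F[i][j] += prob_z * score[i] * score[j] / len(xs)
    return F

def solve_2x2(A, b):
    """Return A^{-1} b for an invertible 2x2 matrix A."""
    det = A[0][0] * A[1][1] - A[0][1] * A[1][0]
    return [(A[1][1] * b[0] - A[0][1] * b[1]) / det,
            (A[0][0] * b[1] - A[1][0] * b[0]) / det]

def natural_gradient(w, xs, ys):
    """Loss gradient multiplied by the inverse Fisher matrix."""
    return solve_2x2(fisher_matrix(w, xs), loss_gradient(w, xs, ys))

w = [0.0, 0.0]
xs = [[4.0, 0.1], [-3.0, 0.2], [2.0, -0.3]]  # features on very different scales
ys = [1, 0, 1]
print("plain gradient:  ", loss_gradient(w, xs, ys))
print("natural gradient:", natural_gradient(w, xs, ys))
```

With real networks nobody forms and inverts this matrix densely. The 2x2 solve is only there to keep the mechanics visible.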

The p-theta(z) term is the conditional probability distribution defined by our model, which is to say: the softmax at the end of a neural net.

We’re looking at the gradient of the log p-theta terms, because we care about the amount that our predicted class probabilities will change as a result of a change in parameters.

The greater the change in predicted probabilities, the greater the KL divergence between our pre-update and post-update predicted distributions.
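The link between these two statements is a standard second-order result. For a small parameter change δ, KL(p_θ || p_{θ+δ}) ≈ ½ δᵀ F δ, where F is the Fisher Information Matrix.

The sketch below checks that numerically for a hypothetical logistic model with a single input. In that setting, F reduces to p(1-p) x xᵀ.

```python
import math

def sigmoid(t):
    return 1.0 / (1.0 + math.exp(-t))

def bernoulli_kl(p, q):
    """KL(Bernoulli(p) || Bernoulli(q))."""
    return p * math.log(p / q) + (1 - p) * math.log((1 - p) / (1 - q))

x = [2.0, -0.5]        # a single (hypothetical) input
w = [0.1, 0.4]         # current parameters
delta = [1e-3, -2e-3]  # a small parameter step

p_old = sigmoid(w[0] * x[0] + w[1] * x[1])
p_new = sigmoid((w[0] + delta[0]) * x[0] + (w[1] + delta[1]) * x[1])

# Fisher matrix for one input of a logistic model: F = p(1 - p) x x^T.
F = [[p_old * (1 - p_old) * xi * xj for xj in x] for xi in x]
quadratic = 0.5 * sum(delta[i] * F[i][j] * delta[j] for i in range(2) for j in range(2))

exact_kl = bernoulli_kl(p_old, p_new)
print(exact_kl, quadratic)  # nearly identical for a small delta
```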

Part of what makes natural gradient optimization confusing is this: when you’re reading or thinking about it, there are two distinct gradient objects you have to understand and contend with, and they mean different things.

As an aside, this gets unavoidably pretty deep into the weeds, particularly when discussing the likelihood, and it’s not really necessary to grasp the overall intuition; feel free to skip to the next section if you disprefer going through all the gory details.

Gradient with respect to loss

Typically, your classification loss is a cross entropy function. More broadly, it’s some function that takes as input your model’s predicted probability distribution and the true target values, and has higher values when your distribution is farther from the target.

The gradient of this object is the core bread and butter of gradient descent learning. It represents how quickly your loss changes as you move each parameter; to first order, it tells you how much the loss would change if you moved that parameter by one unit.
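As a concrete instance, assume a softmax output with a cross-entropy loss. The gradient of the loss with respect to the logits then works out to the predicted probabilities minus the one-hot target.

The sketch below checks that formula against a finite-difference estimate. A finite difference is exactly the “how much does the loss move if I nudge this parameter” reading of a gradient. All names are illustrative.

```python
import math

def softmax(logits):
    m = max(logits)  # subtract the max for numerical stability
    exps = [math.exp(v - m) for v in logits]
    total = sum(exps)
    return [e / total for e in exps]

def cross_entropy(logits, target):
    """Negative log probability that the softmax assigns to the true class index `target`."""
    return -math.log(softmax(logits)[target])

def analytic_grad(logits, target):
    """d(cross-entropy)/d(logits) = softmax(logits) - one_hot(target)."""
    probs = softmax(logits)
    return [p - (1.0 if k == target else 0.0) for k, p in enumerate(probs)]

def numeric_grad(logits, target, eps=1e-6):
    """Central finite differences: nudge each logit up and down and watch the loss move."""
    grad = []
    for k in range(len(logits)):
        up = list(logits)
        up[k] += eps
        down = list(logits)
        down[k] -= eps
        grad.append((cross_entropy(up, target) - cross_entropy(down, target)) / (2 * eps))
    return grad

logits = [2.0, -1.0, 0.5]
target = 2
print(analytic_grad(logits, target))
print(numeric_grad(logits, target))  # agrees to several decimal places
```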

The gradient of the log likelihood

This was hands-down the most confusing part of learning natural gradient for me.

Because, if you read about the Fisher Information Matrix, you’ll get lots of links explaining that it is related to the gradient of the log likelihood of the model.

My previous understanding of the likelihood function was that it represented how probable your model thought some set of data was; in particular, you needed target values to calculate it, because your goal was to calculate the probability that your model assigned to the true target, when you conditioned it on the input features.

In most contexts where likelihood is discussed, such as the very common Maximum Likelihood technique, you care about the gradient of log likelihood for a simple reason. The higher your likelihood, the higher the probability your model is assigning to values sampled from the true distribution, and the happier we all are.

In practice, this looks like calculating the expected value of the gradient of log p(class|x), with the probabilities inside the expectation drawn from the actual class distribution in the data.

However, you can also evaluate likelihood in another way. Normally you calculate your likelihood with respect to the true target values, where you would expect a non-zero gradient, because it’d be possible to push your parameters to increase the probability assigned to the true targets. Instead, you can calculate your expectation using probabilities drawn from your conditional distribution itself.

That is, suppose your network ends in a softmax. Normally you’d take the expectation of the gradient of log p(z) with 0/1 weights based on the true class in the data for a given observation. Instead, use the model’s estimated probability for each class as that class’s weight in the expectation.

This would lead to an overall expected gradient of 0, because we’re feeding in our model’s current belief as the ground truth. But we can still get estimates of the variance of the gradient (i.e. the squared gradient, since its mean is zero). That variance is what’s needed in our Fisher matrix to (implicitly) calculate the KL divergence in predicted class space.

So… Does It Help?

This post has spent a lot of time talking about mechanics: what exactly this thing called a natural gradient estimator is, and what the better intuitions are about how and why it works.

But I feel like I’d be remiss if I didn’t answer the question of: does this thing actually provide value?

The short answer is: practically speaking, it doesn’t provide compelling enough value to be in common use for most deep learning applications.

There is evidence of natural gradient leading to convergence in fewer steps, but, as I’ll discuss later, that’s a bit of a complicated comparison.

The idea of natural gradient is elegant and satisfying to people frustrated by the arbitrariness of scaling update steps in parameter space.

But, other than being elegant, it’s not clear to me that it’s providing value that couldn’t be provided via more heuristic means.

As far as I can tell, natural gradient is providing two key sources of value:

- It’s providing information about curvature.
- It’s providing a way to directly control movement of your model in predicted distribution space, as separate from movement of your model in loss space.

Curvature

One of the great wonders of modern gradient descent is that it’s accomplished with first-order methods.

A first order method is one that only calculates derivatives with respect to the parameters you want to update, and not second derivatives.

With a first derivative, all you know is the (many-dimensional version of a) tangent line to your curve at a specific point.

You don’t know how quickly that tangent line is changing: the second derivative or, more descriptively, the level of curvature that your function has in any given direction.

Curvature is a useful thing to know because in an area of high curvature, where gradients are changing dramatically from point to point, you may want to be cautious taking a large step, lest your local signal of climbing a steep mountain mislead you into jumping off the cliff that lies just beyond.

An (admittedly more heuristic than rigorous) way I like to think about this is as follows. If you’re in a region where gradients from point to point are very variable (that is to say: high variance), then your minibatch estimate of the gradient is in some sense more uncertain.

By contrast, if the gradients are barely changing at a given point, less caution is needed in taking your next step.

Second order derivative information is useful because it lets you scale your steps according to the level of curvature.
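To illustrate, here is a toy sketch (not natural gradient itself) on a hypothetical quadratic loss that is steep along one axis and shallow along the other. A single fixed learning rate has to be small enough for the steep direction, which makes progress along the shallow direction crawl. Dividing each gradient component by the second derivative along that axis fixes both at once.

```python
# Toy loss: f(a, b) = 50 * a**2 + 0.5 * b**2, minimized at (0, 0).
# Second derivatives: d2f/da2 = 100 (high curvature), d2f/db2 = 1 (low curvature).

CURVATURE = (100.0, 1.0)

def grad(a, b):
    return 100.0 * a, 1.0 * b

def plain_step(a, b, lr=0.009):
    """Fixed learning rate; lr must stay below 2/100 or the steep direction diverges."""
    ga, gb = grad(a, b)
    return a - lr * ga, b - lr * gb

def curvature_scaled_step(a, b):
    """Divide each gradient component by the curvature along that axis (a diagonal Newton step)."""
    ga, gb = grad(a, b)
    return a - ga / CURVATURE[0], b - gb / CURVATURE[1]

start = (1.0, 1.0)
plain = start
for _ in range(10):
    plain = plain_step(*plain)
scaled = curvature_scaled_step(*start)

print("plain, 10 steps:", plain)          # a is nearly 0, b has moved less than 10% of the way
print("curvature-scaled, 1 step:", scaled)  # (0.0, 0.0)
```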

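What Natural Gradient is actually, mechanically, doing is dividing your parameter updates by a second derivative. When the expectation is taken under the model’s own predictions, the Fisher Information Matrix equals the expected second derivative (Hessian) of the negative log likelihood. “Dividing” here means multiplying by the matrix inverse.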

The more the gradient is changing with respect to a given parameter direction, the higher the value in the Fisher Information Matrix, and the smaller the update step in that direction will be.
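The “second derivative” reading of the Fisher matrix rests on a standard identity. When the expectation over z is taken under the model’s own distribution, E[∇log p ∇log pᵀ] = −E[∇² log p].

A quick numerical check, for a hypothetical single-input logistic model, compares the outer-product form against a finite-difference Hessian of log p.

```python
import math

def sigmoid(t):
    return 1.0 / (1.0 + math.exp(-t))

x = [1.5, -0.7]  # hypothetical input
w = [0.3, 0.2]   # current parameters

def score(w, z):
    """grad_w log p(z | x) for the logistic model: (z - p) * x."""
    p = sigmoid(w[0] * x[0] + w[1] * x[1])
    return [(z - p) * xj for xj in x]

def hessian_log_p(w, z, eps=1e-5):
    """Finite-difference Hessian of log p(z | x), built from central differences of the score."""
    H = [[0.0, 0.0], [0.0, 0.0]]
    for j in range(2):
        w_up = list(w)
        w_up[j] += eps
        w_down = list(w)
        w_down[j] -= eps
        s_up, s_down = score(w_up, z), score(w_down, z)
        for i in range(2):
            H[i][j] = (s_up[i] - s_down[i]) / (2 * eps)
    return H

p = sigmoid(w[0] * x[0] + w[1] * x[1])
model_probs = {1: p, 0: 1.0 - p}  # the expectation over z uses the model's own predictions

outer_form = [[sum(model_probs[z] * score(w, z)[i] * score(w, z)[j] for z in (0, 1))
               for j in range(2)] for i in range(2)]
hessian_form = [[-sum(model_probs[z] * hessian_log_p(w, z)[i][j] for z in (0, 1))
                 for j in range(2)] for i in range(2)]

print(outer_form)
print(hessian_form)  # matches: the Fisher is the expected curvature of -log p
```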

The gradient in question here is the gradient of the log likelihood for the points in your batches.

That’s not the same thing as the gradient with respect to loss.

But, intuitively, it’s going to be rare that a dramatic change in likelihood doesn’t correspond to a dramatic change in the loss function.

So by capturing information about the curvature of the log-likelihood-derivative space at a given point, Natural Gradient is also giving us a good signal of the curvature in our true, underlying loss space.

There’s a pretty strong argument that this is where the benefit comes from in cases where Natural Gradient has been shown to speed convergence (at least in terms of the number of needed gradient steps).

However.

Notice that I said that Natural Gradient is shown to speed up convergence in terms of gradient steps.

That precision comes from the fact that each individual step of Natural Gradient takes longer, because it requires calculating a Fisher Information Matrix, which, remember, is a matrix with n_parameters² entries.
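To put a rough number on that, here is back-of-the-envelope arithmetic for storing a dense Fisher matrix in 32-bit floats. It assumes a model with one million parameters, an illustrative size rather than one from the post. It also counts the roughly n³ operations that standard dense inversion would need.

```python
n_parameters = 1_000_000
bytes_per_float32 = 4

fisher_entries = n_parameters ** 2  # one entry per pair of parameters
fisher_bytes = fisher_entries * bytes_per_float32

print(f"{fisher_entries:.0e} entries, about {fisher_bytes / 1e12:.0f} TB")  # 1e+12 entries, about 4 TB
print(f"dense inversion: on the order of {n_parameters ** 3:.0e} operations")  # 1e+18
```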

That dramatic slowdown is, in fact, analogous to the slowdown induced by calculating the second-order derivatives of the true loss function.

While it may be the case, I haven’t seen it stated anywhere that calculating the Natural Gradient Fisher matrix is faster than calculating second derivatives with respect to underlying loss.

Taking that as an assumption, it’s hard to see what marginal value Natural Gradient is providing when compared to the (also, possibly equally, costly) approach of doing direct second-order optimization on the loss itself.

A lot of the reason that modern neural networks have been able to succeed where theory would predict that a first-order-only method would fail is that Deep Learning practitioners have found a bunch of clever tricks to essentially empirically approximate the information that would be contained in an analytic second-derivative matrix.

Momentum as an optimization strategy works by keeping a running exponentially weighted average of past gradient values, and biasing any given gradient update towards that past moving average.

This helps solve the problem of being in a part of space where gradient values are varying wildly: if you’re constantly getting contradictory gradient updates, they’ll average out to not having a strong opinion one way or another, analogous to slowing down your learning rate.

And, by contrast, if you’re repeatedly getting gradient estimates that go in the same direction, that’s an indication of a low-curvature region, and suggests an approach of larger steps, which Momentum follows.
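Here is a minimal sketch of that behavior, using the exponentially weighted average described above. β = 0.9 is assumed here as a common default. Contradictory gradients average out to a small update, while consistent ones build up to a large one.

```python
def momentum_updates(gradients, beta=0.9, lr=0.1):
    """Return the parameter updates produced by an exponentially weighted average of gradients."""
    velocity = 0.0
    updates = []
    for g in gradients:
        velocity = beta * velocity + (1 - beta) * g  # running average of past gradients
        updates.append(-lr * velocity)
    return updates

noisy = [1.0, -1.0] * 10  # gradients that keep contradicting each other
consistent = [1.0] * 20   # gradients that keep agreeing

print(momentum_updates(noisy)[-1])       # small: the contradictions cancel out
print(momentum_updates(consistent)[-1])  # approaching -lr: steady progress
```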

RMSProp, which, hilariously, was invented by Geoff Hinton mid-Coursera-course, is a mild modification of a previously existing algorithm called Adagrad.

RMSProp works by taking an exponentially weighted moving average of past squared gradient values (roughly, the past uncentered variance of the gradient) and dividing your update steps by the square root of that value.

This can be roughly thought of as an empirical stand-in for the second derivative (curvature) along each parameter direction.
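Here is a minimal sketch of an RMSProp-style update. It assumes a decay of 0.9 and a small ε for numerical safety; these are common defaults, not values from the post. Each coordinate gets its own effective step size: coordinates with consistently large gradients are scaled down, and coordinates with tiny gradients are scaled up.

```python
import math

def rmsprop_step(params, grads, sq_avg, lr=0.01, decay=0.9, eps=1e-8):
    """One RMSProp step; `sq_avg` holds the running average of squared gradients per coordinate."""
    new_params, new_sq_avg = [], []
    for p, g, s in zip(params, grads, sq_avg):
        s = decay * s + (1 - decay) * g * g                   # running mean of squared gradients
        new_params.append(p - lr * g / (math.sqrt(s) + eps))  # divide by its square root
        new_sq_avg.append(s)
    return new_params, new_sq_avg

params = [0.0, 0.0]
sq_avg = [0.0, 0.0]
for _ in range(50):
    grads = [100.0, 0.01]  # one steep coordinate, one nearly flat coordinate
    params, sq_avg = rmsprop_step(params, grads, sq_avg)

print(params)  # both coordinates moved by nearly the same amount, despite a 10,000x gradient ratio
```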

Adam (Adaptive Moment Estimation) essentially combines both of these approaches, estimating both the running mean and the running (uncentered) variance of the gradient.

It’s one of the most common, and most default-used, optimization strategies today, largely because it has the effect of smoothing out what would otherwise be this noisy, first-order gradient signal.

Something else interesting, and worth mentioning, about all these approaches: in addition to generally scaling update steps in terms of function curvature, they scale different directions of update differently, according to the amount of curvature in those specific directions.
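For completeness, here is a sketch of the standard Adam update with bias correction, assuming the usual defaults (β₁ = 0.9, β₂ = 0.999, ε = 1e-8). The first moment plays the role of momentum, and the second moment plays the role of RMSProp’s squared-gradient average. The toy loss is hypothetical. Its first step shows the per-direction scaling: both coordinates move by about the learning rate, even though their gradients differ by 100x.

```python
import math

def adam(grad_fn, params, steps, lr=0.001, beta1=0.9, beta2=0.999, eps=1e-8):
    """Minimize a function given its gradient, using Adam with bias-corrected moments."""
    params = list(params)
    m = [0.0] * len(params)  # running mean of gradients (the momentum part)
    v = [0.0] * len(params)  # running mean of squared gradients (the RMSProp part)
    for t in range(1, steps + 1):
        grads = grad_fn(params)
        for i, g in enumerate(grads):
            m[i] = beta1 * m[i] + (1 - beta1) * g
            v[i] = beta2 * v[i] + (1 - beta2) * g * g
            m_hat = m[i] / (1 - beta1 ** t)  # undo the bias from starting the averages at zero
            v_hat = v[i] / (1 - beta2 ** t)
            params[i] -= lr * m_hat / (math.sqrt(v_hat) + eps)
    return params

def toy_loss(p):
    """f(a, b) = 50 * a**2 + 0.5 * b**2: steep along a, shallow along b."""
    a, b = p
    return 50.0 * a ** 2 + 0.5 * b ** 2

def toy_grad(p):
    a, b = p
    return [100.0 * a, 1.0 * b]

first = adam(toy_grad, [1.0, 1.0], steps=1, lr=0.01)
print(first)  # about [0.99, 0.99]: equal first steps despite a 100x gradient ratio

later = adam(toy_grad, [1.0, 1.0], steps=2000, lr=0.01)
print(toy_loss([1.0, 1.0]), toy_loss(later))  # the loss drops substantially
```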

This gets to something we discussed earlier, about how scaling all parameters by the same amount might not be a sensible thing to do.

You can even think of this a bit in terms of distance. If curvature in a direction is high, then a step of the same Euclidean length in parameter space would move us farther in terms of expected change in gradient value.

So, while this doesn’t have quite the elegance of natural gradient in terms of defining a coherent direction for parameter updates, it does check most of the same boxes: the ability to adapt your update steps in directions, and at points in time, where curvature differs, and, notionally, where given-sized parameter steps have differing degrees of practical impact.

Direct Distributional Control

Okay, so, the last section argued the following. If our goal is to use analytic curvature estimates of the log likelihood as a stand-in for curvature estimates of the loss, why don’t we just do the latter, or approximate the latter, since both analytic N² calculations appear to be quite time-costly?

But what if you’re in a situation where you actually care about the change in the predicted class distribution for its own sake, and not just as a proxy for change in loss? What would a situation like that even look like?

One example of such a situation is, presumably not coincidentally, one of the major areas of current use for Natural Gradient approaches: Trust Region Policy Optimization (TRPO) in the area of reinforcement learning.

The basic intuition for TRPO is wrapped up in the idea of catastrophic failure, or catastrophic collapse.

In a policy gradients setting, the distribution you’re predicting at the end of your model is the distribution over actions, conditional on some input state.

And, if you’re learning on-policy, where the data for your next round of training is being collected from your model’s current predicted policy, it’s possible to update your policy into a regime where you can no longer collect interesting data to learn your way out of (for example, the policy of spinning around in a circle, which is unlikely to get you useful reward signal to learn from).

This is what it means for a policy to undergo catastrophic collapse.

To avoid this, we want to exercise caution, and not do gradient updates that would dramatically change our policy (in terms of the probabilities we place on different actions in a given scenario).

If we are cautious and gradual in terms of how much we let our predicted probabilities change, that limits our ability to suddenly jump to an unworkable regime.

This is, then, a stronger case for Natural Gradient: here, the actual thing we care about controlling is how much the predicted probabilities of different actions change under a new parameter configuration.

And we care about it in its own right, not just as a proxy for the loss function.
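Here is a minimal sketch of how a KL budget can turn into a step size, in the style of natural policy gradient and TRPO. It uses the quadratic approximation KL ≈ ½ sᵀFs. Under that approximation, the largest step along the natural-gradient direction F⁻¹g that stays within a budget δ is sqrt(2δ / (gᵀF⁻¹g)) · F⁻¹g.

The Fisher matrix, gradient and budget below are made-up numbers. Real TRPO also uses a conjugate-gradient solve and a backtracking line search, which this sketch omits.

```python
import math

def solve_2x2(A, b):
    """Return A^{-1} b for an invertible 2x2 matrix A."""
    det = A[0][0] * A[1][1] - A[0][1] * A[1][0]
    return [(A[1][1] * b[0] - A[0][1] * b[1]) / det,
            (A[0][0] * b[1] - A[1][0] * b[0]) / det]

def kl_constrained_step(F, g, max_kl):
    """Step along F^{-1} g, scaled so the quadratic KL estimate 0.5 * s^T F s equals max_kl."""
    nat = solve_2x2(F, g)                      # natural-gradient direction
    g_dot_nat = g[0] * nat[0] + g[1] * nat[1]  # g^T F^{-1} g
    scale = math.sqrt(2.0 * max_kl / g_dot_nat)
    return [scale * n for n in nat]

def quadratic_kl(F, s):
    """Second-order estimate of the KL divergence caused by the parameter step s."""
    return 0.5 * sum(s[i] * F[i][j] * s[j] for i in range(2) for j in range(2))

F = [[4.0, 1.0], [1.0, 0.5]]  # hypothetical Fisher matrix of the policy
g = [1.0, 2.0]                # hypothetical policy-gradient estimate (ascent direction)
step = kl_constrained_step(F, g, max_kl=0.01)

print(step, quadratic_kl(F, step))  # the estimated KL lands on the 0.01 budget
```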

Open Questions

I like to conclude these posts by letting you know what areas of confusion I still have about a topic. Much though the framing of an explainer post implies a lofty position of complete comprehension, that’s not quite the reality.

As always, if you notice something you think I’ve gotten wrong, please comment/DM and I’ll work to correct it!

- I never conclusively found out whether calculating the log likelihood Fisher matrix is more efficient than just calculating the Hessian of the loss function. If it were, that would be an argument for Natural Gradient being a cheaper way to get curvature information about the loss surface.
- I am relatively but not totally confident that, when we take an expected value over z of the log probabilities, that expectation is being taken over the probabilities predicted by our model. (An expected value has to be defined relative to some set of probabilities.)
