are simply linear interpolations (β * x₀ + (1 − β) * x₁) of the gradients and their variances, which gives us a moving average of each.

The higher the beta, the less we update the moving average for each new sample, thus smoothing our estimate of the mean and variance of the gradient across batches.

Here’s a visualization of how much smoothing we get on a noisy dataset for different betas.

Linear interpolation for various strengths of beta (from fast.ai)

If we have a small batch, our estimates of the gradient at each step might be noisier, so we’ll need a higher beta.

If we’re using a huge batch size with consistent data, we can probably get away with a lower beta.
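To make the smoothing concrete, here’s a minimal sketch in plain Python. The alternating `noisy` signal and the `swing` measurement are made up for illustration; they are not the data behind the fast.ai plot.

```python
def moving_average(samples, beta):
    """Linear interpolation beta * old + (1 - beta) * new, applied step by step."""
    avg = 0.0  # the average starts at zero
    history = []
    for x in samples:
        avg = beta * avg + (1 - beta) * x
        history.append(avg)
    return history


def swing(values):
    """How far the estimate still wobbles: max minus min."""
    return max(values) - min(values)


# Hypothetical noisy signal: alternates between 0 and 2, so the true mean is 1.
noisy = [0.0, 2.0] * 50

# Measure the wobble over the last 10 steps for each beta.
swings = {beta: swing(moving_average(noisy, beta)[-10:]) for beta in (0.5, 0.9, 0.99)}

for beta, s in swings.items():
    print(f"beta={beta}: swing over the last 10 steps = {s:.3f}")
```

The higher the beta, the smaller the swing should be. With beta = 0.99 the estimate is also still climbing up from its zero start after 100 steps, which is the initialization problem discussed next.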

The problem with the moving averages above is that when the algorithm first initializes, the moving averages are 0.

This pulls the summary statistics closer to 0 than they should be for the first several timesteps. The effect is strongest when beta is close to 1, because each update takes most of its mass from the previous, zero-initialized value.

This effect is particularly visible in the beta=0.99 graph above.

We fix this by debiasing: at step t we divide the moving average by (1 − βᵗ), which results in:

Linear interpolation with debiasing (from fast.ai)

The problem doesn’t go away, but it’s much better.

To plug some numbers in, if β = 0.9, on the first iteration the debias will multiply the value by 1 / (1 − 0.9¹) = 10. Then when we linearly interpolate, βx₀ + (1 − β)x₁, the first term βx₀ = 0 because x₀ is initialized to 0. The debias factor, 10, will cancel out (1 − β) = 0.1 in the second term, so we entirely use the new value, x₁.

After enough steps, the debias factor converges to 1; the closer β is to 1, the more steps it takes.
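Here’s a small sketch of the debiasing step in plain Python, checking the β = 0.9 arithmetic above. The function names and the 1.01 threshold for “the debias factor has basically disappeared” are my own choices for illustration.

```python
def debiased_average(samples, beta):
    """Moving average starting at zero, divided by (1 - beta**t) at step t."""
    avg = 0.0
    out = []
    for t, x in enumerate(samples, start=1):
        avg = beta * avg + (1 - beta) * x
        out.append(avg / (1 - beta ** t))
    return out


def steps_until_debias_below(beta, threshold=1.01):
    """First step t at which the debias factor 1 / (1 - beta**t) is <= threshold."""
    t = 1
    while 1 / (1 - beta ** t) > threshold:
        t += 1
    return t


# On the first step the debiased value is just the new sample.
first = debiased_average([5.0, 3.0, 4.0], beta=0.9)[0]
print(first)

for beta in (0.9, 0.99):
    print(f"beta={beta}: debias factor within 1% of 1 after {steps_until_debias_below(beta)} steps")
```

The step count grows quickly as beta approaches 1, which is why the two debias plots need such different axes.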

The graphs below show how many steps it takes for the debias term to disappear (note the difference in the y axis):

Visualizing debias factor over time

Now the final parameter update is

wₜ = wₜ₋₁ − η · m̂ₜ / (√v̂ₜ + ε)

where m̂ₜ and v̂ₜ are the debiased moving averages of the gradient and of the squared gradient, and η is the learning rate.

The numerator says “for every parameter, take a step in the direction of the gradient for that parameter.” The denominator says “normalize the step by its standard deviation.”

The intuitive interpretation is that when we first start updating parameters, we’ll probably be way off.

If the gradients are all pointing in different directions (high variance), we’ll take a small, cautious step.

Conversely, if all the gradients are telling us to move in the same direction, the variance will be small, so we’ll take a bigger step in that direction.

Either way, if the scale of all the gradients is large, the constant factor will cancel out when we divide since we’re using the uncentered variance.

As training stabilizes and loss gets closer to 0, the mean will approach 0, so updates will automatically get finer.

The ε in the denominator says “ok, we may think we’ve got no noise at all, but let’s not go too crazy here and just take one step at a time.” This effectively sets an upper bound on the size of the step you take as the noise variance approaches zero.

This ratio m/sqrt(v) might look like μ/σ, which is the signal to noise ratio, but that interpretation only applies to scalars.
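Putting the pieces together, here’s a rough sketch of one Adam step for a flat list of parameters in plain Python. The `adam_step` name and its signature are hypothetical, and the defaults (learning rate 0.01, betas 0.9 and 0.99) are just the values I use in the experiments further down, not anything canonical.

```python
import math


def adam_step(w, grad, m, v, t, lr=0.01, beta1=0.9, beta2=0.99, eps=1e-8):
    """One Adam update for a flat list of parameters; t is the 1-based step count."""
    new_w, new_m, new_v = [], [], []
    for wi, gi, mi, vi in zip(w, grad, m, v):
        mi = beta1 * mi + (1 - beta1) * gi       # moving average of the gradient
        vi = beta2 * vi + (1 - beta2) * gi * gi  # moving average of the squared gradient
        m_hat = mi / (1 - beta1 ** t)            # debias both averages
        v_hat = vi / (1 - beta2 ** t)
        new_w.append(wi - lr * m_hat / (math.sqrt(v_hat) + eps))
        new_m.append(mi)
        new_v.append(vi)
    return new_w, new_m, new_v


# Two parameters whose gradients differ in scale by a factor of 100,000.
w, m, v = [1.0, 1.0], [0.0, 0.0], [0.0, 0.0]
w, m, v = adam_step(w, grad=[100.0, -0.001], m=m, v=v, t=1)
print(w)
```

On the first step both parameters move by roughly the learning rate, despite the huge difference in gradient scale. That’s the “constant factor cancels out” behaviour described above.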

LARS

As batch size grows, the number of iterations per epoch decreases.

To converge in the same number of dataset iterations, we can compensate by increasing the learning rate.

However, as the learning rate increases, training becomes more unstable.

The state of the art was to use a learning rate warmup, but this only helped up to a certain point, after which training would start diverging anyway.

The warmup was a patch over the real issue: the gradients must be noisy.

The authors of Layerwise Adaptive Rate Scaling (LARS) explain their trick to solve this problem:

“To analyze the training stability with large LRs we measured the ratio between the norm of the layer weights and norm of gradients update. We observed that if this ratio is too high, the training may become unstable. On the other hand, if the ratio is too small, then weights don’t change fast enough.”

They call this ratio the “trust ratio”.

When it’s higher, the layer’s weights change faster, and vice versa.

Since we can now be more confident of each step, the cautionary warm-up often used in learning rate schedules is no longer necessary and we can scale to much bigger batch sizes without diverging.

In English: the layer-wise learning rate λ is the global learning rate η times the ratio of the norm of the layer weights to the norm of the layer gradients.

If we use weight decay β, we can just add β times the norm of the weights to the denominator.

When we plug this into SGD, the denominator ends up normalizing the gradients to have unit norm, which helps avoid divergence.
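As a sanity check on that claim, here’s a minimal sketch of a LARS-style step for a single layer in plain Python. It leaves out momentum, assumes the gradient isn’t all zeros, and the `lars_step` name and signature are my own. Applying weight decay inside the update itself (not just in the denominator) is my reading of the method, so treat that line as an assumption.

```python
import math


def l2_norm(values):
    return math.sqrt(sum(x * x for x in values))


def lars_step(w, grad, lr=0.1, weight_decay=0.0):
    """One SGD step for a single layer, using a LARS-style layer-wise learning rate.

    Sketch only: no momentum, and it assumes the denominator is nonzero.
    """
    w_norm = l2_norm(w)
    g_norm = l2_norm(grad)
    # Layer-wise rate: global rate times ||w|| / (||grad|| + decay * ||w||).
    layer_lr = lr * w_norm / (g_norm + weight_decay * w_norm)
    # Assumption: weight decay is also added to the gradient in the update.
    new_w = [wi - layer_lr * (gi + weight_decay * wi) for wi, gi in zip(w, grad)]
    return new_w, layer_lr


# Same gradient direction at two very different scales.
small, _ = lars_step([3.0, 4.0], [0.0, 10.0])
large, _ = lars_step([3.0, 4.0], [0.0, 10000.0])
print(small, large)
```

Without weight decay, both calls should land on the same weights: the gradient’s scale is divided out, and the step length is the global learning rate times the norm of the layer’s weights.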

The numerator is the norm of the weights because as networks deepen, it becomes important to have zero mean, unit variance (ZMUV) weights.

This is because at each layer, these weights are multiplied together, so if they diverge from ZMUV, the values may explode or vanish.

When weights are small, we take a small step.

When weights are large we take a bigger step.

Combined with weight decay this helps us stably step towards ZMUV weights.

Let’s get a sense of what’s going on here.

At the beginning of training, layers are supposed to output ZMUV, so the numerator above will be 0 or close to it.

Any steps we take are likely to be small.

In contrast, the denominator will probably be large since when everything is wrong, the gradients are large.

In this way we naturally warm up as the weights increase.

As we approach 0 loss, the gradients will be small, so the trust ratio will keep the learning rate up to 10x (due to clipping) higher than without the trust ratio, keeping us from giving up on reaching the optimum too early.

LAMB

LAMB stands for “Layer-wise Adaptive Moments optimizer for Batch training.” It makes a few small changes to LARS:

1. If the numerator (r₁ below) or denominator (r₂ below) of the trust ratio is 0, then use 1 instead. This section of the paper was hard to read, so I based this on some code.
2. Fixing weight decay: in LARS, the denominator of the trust ratio is |∇L| + β |w|, whereas in LAMB it’s |∇L + β w|. This preserves more information.
3. Instead of using the SGD update rule, they use the Adam update rule.
4. Clip the trust ratio at 10.

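So in full, the trust ratio in LAMB is

r₁ = ‖w‖
r₂ = ‖m̂ / (√v̂ + ε) + βw‖
r = r₁ / r₂
ηᴸ = r · η
w ← w − ηᴸ · (m̂ / (√v̂ + ε) + βw)

The final line is the layer-wise LAMB update rule. Here m̂ and v̂ are Adam’s debiased moving averages, r₂ is the norm of the Adam update rule with weight decay, and ηᴸ is the layer-wise learning rate adjusted by the trust ratio.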

So overall this method can be summarized as LARS applied to Adam, since it’s just multiplying the old update step by the trust ratio.
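Here’s a rough sketch of that idea for one layer in plain Python, taking an already-computed Adam step as input. This is not my actual PyTorch implementation. The `lamb_update` name is hypothetical, I’m assuming “use 1 instead” means the whole ratio becomes 1, and I clip the ratio itself at 10 as described in the list above.

```python
import math


def l2_norm(values):
    return math.sqrt(sum(x * x for x in values))


def lamb_update(w, adam_step, lr=0.01, weight_decay=0.0, max_ratio=10.0):
    """Scale an Adam step for one layer by a clipped trust ratio.

    adam_step is the per-parameter Adam direction m_hat / (sqrt(v_hat) + eps).
    """
    # Weight decay is added to the step before taking the norm: |step + decay * w|.
    update = [s + weight_decay * wi for s, wi in zip(adam_step, w)]
    r1 = l2_norm(w)
    r2 = l2_norm(update)
    # Assumption: if either norm is 0, fall back to a trust ratio of 1.
    ratio = 1.0 if r1 == 0 or r2 == 0 else min(r1 / r2, max_ratio)
    new_w = [wi - lr * ratio * ui for wi, ui in zip(w, update)]
    return new_w, ratio


# Large weights with a small step: the raw ratio is 100, so it is clipped to 10.
_, clipped = lamb_update([30.0, 40.0], [0.5, 0.0])
# Freshly zeroed weights: the ratio falls back to 1.
_, fallback = lamb_update([0.0, 0.0], [0.5, 0.0])
print(clipped, fallback)
```

Everything Adam-specific lives in `adam_step`; the trust ratio only rescales it per layer, which is why I describe this as LARS applied to Adam.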

The authors don’t report whether LAMB improves ImageNet training performance over LARS, and they don’t compare LARS to LAMB for BERT, so it’s a bit hard to say how much difference these changes make, but implementation is pretty simple.

Experiments

To get a better sense of what’s going on, I implemented LAMB in PyTorch.

I ran a bunch of experiments on MNIST and found that where Adam diverges, LAMB keeps chugging.

I chose MNIST because it’s tiny enough to try on CPU, but it means we can’t see any convergence improvements.

I’ll publish another post exploring LAMB applied to big transformers soon!

I visualized some of the experiments. Below, I compare Adam (blue below) and LAMB (red below) with learning rate 0.01 and betas 0.9, 0.99.

They’re pretty similar, but LAMB does better on test accuracy.

MNIST training loss and test accuracy over time for Adam (blue) vs LAMB (red)

To find out what’s going on under the hood, I wanted to look at the layer-wise components of the trust ratio, so I logged every value of r, r₁, and r₂ as a histogram after every few batches.

For Adam, I calculate the values, but don’t use them anywhere.

How to interpret the chart below: the Y axis shows what timestep, with the first at the top, X axis is histogram buckets, Z axis is histogram frequency.

LAMB parameter histograms on Adam (blue) vs LAMB (red) on MNIST

You can see that on both sides, r starts significantly below 1.

On the LAMB side, this creates a natural warm up period across all layers.

Then, as some of the layers start gaining larger weights and stabler gradients, r encourages them to take larger steps.

This exaggerates the norms relative to the Adam baseline.

For the next experiment, I compared LAMB to itself across learning rates 0.1 and 0.01.

Adam converges normally at learning rate 0.01 and at 0.1 doesn’t learn at all, so I won’t compare it here.

On the left (blue) learning rate = 0.01, on the right (green) learning rate = 0.1.

On the right, it converges almost instantly during the warmup, but then a few layer weights start to explode (see difference in X axis scale) and it diverges.

To address the weights running away, I added weight decay 0.01 below right.

Training didn’t diverge! Generally the trust ratio kept learning slow at less than 1, whereas in the more comfortable regime above left, it got as high as 4.5.

Summary

Vanilla SGD becomes unstable as learning rate increases.

LARS adjusts the SGD learning rate by a layer-wise trust ratio that normalizes the gradients and weights.

Adam modulates updates with debiased means normalized by debiased variances.

LAMB adjusts Adam’s learning rate by a more accurate layer-wise, clipped trust ratio.

Combining all these techniques allows us to train on large batches with a high learning rate, decreasing wall time by 100X for BERT!

Thanks to Yaroslav Bulatov and Sarah Jane Hong for edits and Jeremy Howard/fast.ai part 2 for inspiration.