
Is PyTorch’s Nesterov Momentum Implementation Wrong? | by Jason Vega | Sep 2023

Momentum helps SGD traverse complex loss landscapes more efficiently. Photo by Maxim Berg on Unsplash.

If you look closely at PyTorch’s documentation of SGD, you will find that their implementation of Nesterov momentum differs in a few ways from the formulation found in the original paper. Most notably, PyTorch’s implementation evaluates the gradient at the current parameters, whereas the whole point of Nesterov momentum is to evaluate the gradient at shifted parameters. Unfortunately, discussion of these discrepancies online is scarce. In this post, we will examine and explain the differences between PyTorch’s implementation and the original formulation of Nesterov momentum. Ultimately, we will see that PyTorch’s implementation is not wrong, but rather an approximation, and speculate about the benefit of their implementation.

The original paper describes Nesterov momentum using the following pair of update rules.
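
Writing f for the objective being minimized, the rules are:

$$v_{t+1} = \mu v_t - \varepsilon \nabla f(\theta_t + \mu v_t) \tag{1}$$

$$\theta_{t+1} = \theta_t + v_{t+1} \tag{2}$$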

where v_{t+1} and θ_{t+1} are the velocity vector and model parameters, respectively, at time step t+1, μ is the momentum factor, and ε is the learning rate. The note in PyTorch’s SGD documentation states that they use a different pair of update rules.
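
In the note’s notation, kept consistent with the symbols above, these are:

$$v_{t+1} = \mu v_t + g_{t+1}$$

$$\theta_{t+1} = \theta_t - \varepsilon v_{t+1}$$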

where g_{t+1} represents the gradient used to compute v_{t+1}. We can expand the update rule for θ_{t+1}.
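
Substituting the rule for v_{t+1} into the rule for θ_{t+1} gives:

$$\theta_{t+1} = \theta_t - \varepsilon\mu v_t - \varepsilon g_{t+1}$$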

From this we can infer where the gradient must be evaluated.
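
The momentum portion of the expanded step is −εμv_t, and a Nesterov step evaluates the gradient at the parameters after the momentum portion of the step has been applied, so:

$$g_{t+1} = \nabla f(\theta_t - \varepsilon\mu v_t)$$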

and the update rules can be rewritten without g_{t+1}.
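
These rewritten rules are the ones referred to as (3) and (4) throughout the rest of this post:

$$v_{t+1} = \mu v_t + \nabla f(\theta_t - \varepsilon\mu v_t) \tag{3}$$

$$\theta_{t+1} = \theta_t - \varepsilon v_{t+1} \tag{4}$$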

These are the update rules that PyTorch uses in theory. I mentioned earlier that PyTorch actually evaluates the gradient at the current parameters instead of the shifted parameters. This can be seen by looking at the algorithm description in the PyTorch SGD documentation. We will examine this further later on.

Note that for both the original (1, 2) and PyTorch (3, 4) formulations, if v_0 = 0, then the first update to θ is the same.
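
In both cases the first step reduces to plain gradient descent; this is the update referred to below as equation (5):

$$\theta_1 = \theta_0 - \varepsilon \nabla f(\theta_0) \tag{5}$$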

Although the PyTorch SGD documentation note states that the algorithm initializes the momentum buffer to the gradient at the first step, we will later show that this implies v_0 = 0.

There are two immediate differences when going from the original (1, 2) to the PyTorch (3, 4) formulation:

  1. The learning rate is moved outside of v_{t+1}.
  2. In the update rule for v_{t+1}, the term involving the gradient is added instead of subtracted, and in the update rule for θ_{t+1}, the term involving the velocity vector is subtracted instead of added. The difference in sign inside the gradient term is simply a consequence of this, as shown in the previous section.

To understand these differences, let’s first expand the update rules. As hinted at here, the effect of the first difference is more apparent if we consider learning rate schedules. So, we consider a generalization of the update rules in which ε is no longer fixed but can vary over time, and denote by ε_t the learning rate at time step t. For brevity, we also introduce a shorthand for the gradients.
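
Let g_i denote the gradient used to compute v_{i+1}, which is evaluated at a different point in each formulation:

$$g_i = \nabla f(\theta_i + \mu v_i) \quad \text{(original)}, \qquad g_i = \nabla f(\theta_i - \varepsilon_i \mu v_i) \quad \text{(PyTorch note)}$$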

Assuming v_0 = 0, the original formulation can be unrolled over time.
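
Each application of the velocity update scales the previous velocity by μ and adds a new gradient term, so the parameter update referred to below as (6) is:

$$\theta_{t+1} = \theta_t + v_{t+1} = \theta_t - \sum_{i=0}^{t} \mu^{t-i} \varepsilon_i\, g_i \tag{6}$$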

and the PyTorch formulation can be unrolled similarly.
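
Here the learning rate sits outside the velocity, so it multiplies the entire sum; this is the update referred to below as (7):

$$\theta_{t+1} = \theta_t - \varepsilon_t v_{t+1} = \theta_t - \varepsilon_t \sum_{i=0}^{t} \mu^{t-i} g_i \tag{7}$$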

In the original formulation (6), if the learning rate were to change at time t, then only the magnitude of the term at i = t in the summation would be affected, and the magnitudes of all the other terms would remain the same. As a result, the immediate impact of the learning rate change is quite limited, and we would have to wait for the change to “trickle” down over subsequent time steps to have a stronger influence on the overall step size. In contrast, in the PyTorch formulation (7), if the learning rate were to change at time t, then the magnitude of the entire step would be affected immediately.

For v_0 = 0, it is clear from the expanded rules that the second difference ultimately has no effect; in either formulation, the step works out to a discounted sum of gradients that is subtracted from the current parameters.

Ignoring weight decay and dampening, by inspecting the SGD algorithm in PyTorch’s documentation, we can see the update rules that are actually implemented.
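
Writing θ’ for the parameters the algorithm actually stores, these are the rules referred to below as (8) and (9):

$$v_{t+1} = \mu v_t + \nabla f(\theta'_t) \tag{8}$$

$$\theta'_{t+1} = \theta'_t - \varepsilon\left(\nabla f(\theta'_t) + \mu v_{t+1}\right) \tag{9}$$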

where θ’_t denotes the model parameters at time t and, per the documentation note, the momentum buffer is initialized to the gradient at the first step.
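
In symbols, that initialization (which we will use again below) is:

$$v_1 = \nabla f(\theta'_0)$$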

We will refer to equations 3 and 4 as the PyTorch “note” formulation, and to equations 8 and 9 as the PyTorch “implemented” formulation. We make a distinction between θ and θ’ for a reason that will become apparent shortly. The most glaring difference from the note formulation is that the gradient is evaluated at the current parameters rather than at the shifted parameters. From this alone, it may appear that the update rules the algorithm implements are not a proper implementation of Nesterov momentum.

We will now examine how the PyTorch algorithm ultimately approximates Nesterov momentum. Derivations for an older version of PyTorch can be found here from Ivo Danihelka, referenced in this GitHub issue. Derivations for the current version of PyTorch can be found here, and they are a relatively straightforward adjustment of the earlier derivations. We provide a LaTeX rendering of these (re-derived) derivations here for the reader’s convenience. The implemented formulation is derived through a simple change of variables.
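
Specifically, we let the stored parameters be the shifted version of the note formulation’s parameters:

$$\theta'_t = \theta_t - \varepsilon\mu v_t$$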

It immediately becomes clear that the note update rule for v_{t+1} (3) becomes equivalent to the implemented update rule for v_{t+1} (8) after this change of variables. We now want to derive an update rule for θ’_{t+1} in terms of θ’_t.
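
Starting from the change of variables at time t+1 and using (3) and (4):

$$\begin{aligned}
\theta'_{t+1} &= \theta_{t+1} - \varepsilon\mu v_{t+1} \\
&= \theta_t - \varepsilon v_{t+1} - \varepsilon\mu v_{t+1} \\
&= \theta'_t + \varepsilon\mu v_t - \varepsilon v_{t+1} - \varepsilon\mu v_{t+1} \\
&= \theta'_t + \varepsilon\left(v_{t+1} - \nabla f(\theta'_t)\right) - \varepsilon v_{t+1} - \varepsilon\mu v_{t+1} \\
&= \theta'_t - \varepsilon\left(\nabla f(\theta'_t) + \mu v_{t+1}\right)
\end{aligned}$$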

This is exactly the update rule we saw implemented in PyTorch (9). At a high level, the PyTorch implementation assumes the current parameters θ’_t are already the shifted version of the “actual” parameters θ_t. Hence, at each time step, the “actual” parameters θ_t can be recovered from the current parameters θ’_t.
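
Rearranging the change of variables gives:

$$\theta_t = \theta'_t + \varepsilon\mu v_t$$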

However, it appears from the source code that the PyTorch SGD implementation does not make any correction at the end of the algorithm to retrieve the final “actual” parameters, so the final output is technically an approximation of the “actual” parameters.
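
To sanity-check the implemented rules (8, 9) and the recovery relation above, here is a minimal sketch (my own illustration, not taken from the PyTorch docs or source; the objective, seed, and hyperparameter values are arbitrary). It steps torch.optim.SGD with nesterov=True, tracks the same state manually, and then applies the recovery relation:

```python
import torch

# Arbitrary hyperparameters for the check.
mu, eps = 0.9, 0.1

torch.manual_seed(0)
theta = torch.randn(5, requires_grad=True)
opt = torch.optim.SGD([theta], lr=eps, momentum=mu, nesterov=True)

def loss_fn(p):
    # Arbitrary differentiable objective, standing in for f.
    return (p ** 2).sum()

# Manual copies of the implemented state: v is the momentum buffer (8),
# th tracks the stored parameters θ' updated via (9).
th = theta.detach().clone()
v = torch.zeros_like(th)

for step in range(3):
    opt.zero_grad()
    loss_fn(theta).backward()
    g = theta.grad.detach().clone()

    # Update rules (8) and (9); the buffer is set to the first gradient on step 0.
    v = g.clone() if step == 0 else mu * v + g
    th = th - eps * (g + mu * v)

    opt.step()
    assert torch.allclose(theta.detach(), th, atol=1e-6)

# Approximate recovery of the "actual" parameters: theta_t = theta'_t + eps * mu * v_t.
actual = th + eps * mu * v
print(actual)
```

The final line computes the “actual” parameters from the stored ones, which is exactly the correction the implementation itself never performs.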

Finally, we now show that v_0 must be 0.
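
By the documented initialization, v_1 = ∇f(θ’_0); combining this with update rule (8) at t = 0 (and assuming μ ≠ 0):

$$\nabla f(\theta'_0) = v_1 = \mu v_0 + \nabla f(\theta'_0) \quad\Longrightarrow\quad \mu v_0 = 0 \quad\Longrightarrow\quad v_0 = 0$$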

Furthermore, we can verify that the first update to the “actual” parameters is the same as the first update made in the original formulation when v_0 = 0.
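
Since v_0 = 0 implies θ’_0 = θ_0, one application of (8), (9), and the recovery relation θ_t = θ’_t + εμv_t gives:

$$\begin{aligned}
v_1 &= \nabla f(\theta'_0) = \nabla f(\theta_0) \\
\theta'_1 &= \theta'_0 - \varepsilon\left(\nabla f(\theta'_0) + \mu v_1\right) = \theta_0 - \varepsilon\nabla f(\theta_0) - \varepsilon\mu v_1 \\
\theta_1 &= \theta'_1 + \varepsilon\mu v_1 = \theta_0 - \varepsilon\nabla f(\theta_0)
\end{aligned}$$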

We can see that this is equivalent to equation 5.

Of course, the big remaining question is: why does PyTorch bother to reformulate Nesterov momentum from equations 3 and 4 to equations 8 and 9 at all? One possible explanation is that the reformulation might provide some savings in the number of arithmetic operations required. To evaluate this possible explanation, let’s count the arithmetic operations, starting with the note formulation (3, 4).
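
One counting consistent with the totals below treats the gradient evaluation as a separate, saved computation and counts every remaining product, sum, and difference: the products εμ, (εμ)v_t, μv_t, and εv_{t+1} (four multiplications); the sum μv_t + ∇f(θ_t − εμv_t) (one addition); and the differences θ_t − (εμ)v_t and θ_t − εv_{t+1} (two subtractions).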

Here, there are a total of seven operations. Now consider the implemented formulation (8, 9).
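
Counting the same way: the products μv_t, μv_{t+1}, and ε(∇f(θ’_t) + μv_{t+1}) (three multiplications); the sums μv_t + ∇f(θ’_t) and ∇f(θ’_t) + μv_{t+1} (two additions); and the difference θ’_t − ε(∇f(θ’_t) + μv_{t+1}) (one subtraction).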

Here, there are a total of six operations. The second gradient in the PyTorch implementation simply reuses the saved result of the first gradient computation, so only one gradient computation is performed at each time step. So, one apparent benefit is that the PyTorch implementation saves one multiplication operation at each step.

In summary:

  1. The update rules stated in PyTorch’s SGD documentation note (3, 4) place the learning rate differently than the original Nesterov momentum update rules (1, 2). This allows learning rate schedules to have an immediate effect on the overall step size, whereas under the original formulation the effect of a learning rate change would “trickle” down over subsequent time steps.
  2. The update rules implemented in the PyTorch SGD algorithm (8, 9) are an approximation of the update rules stated in the documentation note (3, 4) after a simple change of variables. Although the “actual” parameters are easily recoverable from the current parameters at each time step, the PyTorch implementation does not make any such correction at the end of the algorithm, and so the final parameters technically remain an approximation of the “actual” final parameters.
  3. An apparent benefit of the PyTorch implementation is that it avoids an additional multiplication operation at each time step.

References

  1. “SGD.” SGD — PyTorch 2.0 Documentation, pytorch.org/docs/stable/generated/torch.optim.SGD.html. Accessed 2 Sept. 2023.
  2. Sutskever, Ilya, et al. “On the importance of initialization and momentum in deep learning.” International Conference on Machine Learning. PMLR, 2013.
  3. Danihelka, Ivo. “Nesterov’s Momentum Made Simple.” 25 Aug. 2012.
  4. Chintala, Soumith. “nesterov momentum is wrong in sgd · Issue #27 · torch/optim.” GitHub, 13 Oct. 2014, github.com/torch/optim/issues/27.
  5. Gross, Sam. “Add a note in the docs about the momentum formulation used in optim · Issue #1099 · pytorch/pytorch.” GitHub, 25 Mar. 2017, github.com/pytorch/pytorch/issues/1099#issuecomment-289190614.
  6. Zhao, Yilong. “fix Nesterov Momentum Bug · Issue #5920 · pytorch/pytorch.” GitHub, 21 Mar. 2018, github.com/pytorch/pytorch/pull/5920#issuecomment-375181908.
