Stochastic gradient descent (often abbreviated SGD) is an iterative method for optimizing an objective function with suitable smoothness properties (e.g. differentiable or subdifferentiable). It can be regarded as a stochastic approximation of gradient descent optimization, since it replaces the actual gradient (calculated from the entire data set) by an estimate thereof (calculated from a randomly selected subset of the data). Especially in high-dimensional optimization problems this reduces the very high computational burden, achieving faster iterations in exchange for a lower convergence rate.
The basic idea behind stochastic approximation can be traced back to the Robbins–Monro algorithm of the 1950s. Today, stochastic gradient descent has become an important optimization method in machine learning.
Background
Both statistical estimation and machine learning consider the problem of minimizing an objective function that has the form of a sum:
$$Q(w) = \frac{1}{n} \sum_{i=1}^{n} Q_i(w),$$
where the parameter $w$ that minimizes $Q(w)$ is to be estimated. Each summand function $Q_i$ is typically associated with the $i$-th observation in the data set (used for training).
In classical statistics, sum-minimization problems arise in least squares and in maximum-likelihood estimation (for independent observations). The general class of estimators that arise as minimizers of sums are called M-estimators. However, in statistics, it has long been recognized that requiring even local minimization is too restrictive for some problems of maximum-likelihood estimation. Therefore, contemporary statistical theorists often consider stationary points of the likelihood function (or zeros of its derivative, the score function, and other estimating equations).
The sum-minimization problem also arises for empirical risk minimization. There, $Q_i(w)$ is the value of the loss function at the $i$-th example, and $Q(w)$ is the empirical risk.
When used to minimize the above function, a standard (or "batch") gradient descent method would perform the following iterations:
$$w := w - \eta \nabla Q(w) = w - \frac{\eta}{n} \sum_{i=1}^{n} \nabla Q_i(w).$$
The step size is denoted by $\eta$ (sometimes called the ''learning rate'' in machine learning) and here "$:=$" denotes the update of a variable in the algorithm.
In many cases, the summand functions have a simple form that enables inexpensive evaluations of the sum-function and the sum gradient. For example, in statistics, one-parameter exponential families allow economical function-evaluations and gradient-evaluations.
However, in other cases, evaluating the sum-gradient may require expensive evaluations of the gradients from all summand functions. When the training set is enormous and no simple formulas exist, evaluating the sum of gradients becomes very expensive, because it requires evaluating the gradient of every summand function. To economize on the computational cost at every iteration, stochastic gradient descent samples a subset of summand functions at every step. This is very effective in the case of large-scale machine learning problems.
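The cost difference can be illustrated with a minimal Python sketch. It assumes hypothetical scalar summands of the form (w - a_i)^2, chosen only because their gradients are easy to check by hand; the function names are illustrative and not part of any library.

```python
import random


def full_gradient(w, data):
    """Gradient of the average of (w - a)^2 over all a in data.

    Every summand is visited, so the cost grows linearly with len(data).
    """
    return sum(2.0 * (w - a) for a in data) / len(data)


def sampled_gradient(w, data, rng):
    """Gradient of a single randomly chosen summand (w - a)^2.

    Only one data point is visited, whatever the size of the data set.
    """
    a = rng.choice(data)
    return 2.0 * (w - a)


def batch_gradient_descent(data, w0=0.0, eta=0.1, steps=200):
    """Batch gradient descent: each step uses the full gradient."""
    w = w0
    for _ in range(steps):
        w = w - eta * full_gradient(w, data)
    return w


if __name__ == "__main__":
    data = [1.0, 2.0, 3.0, 4.0]
    print(batch_gradient_descent(data))  # approaches the mean of the data
    print(sampled_gradient(0.0, data, random.Random(0)))  # one noisy estimate
```

For this toy objective the minimizer is the mean of the data, which makes it easy to see that the batch method converges while each sampled gradient is only a noisy estimate of the full one.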
Iterative method
In stochastic (or "on-line") gradient descent, the true gradient of $Q(w)$ is approximated by a gradient at a single sample:
$$w := w - \eta \nabla Q_i(w).$$
As the algorithm sweeps through the training set, it performs the above update for each training sample. Several passes can be made over the training set until the algorithm converges. If this is done, the data can be shuffled for each pass to prevent cycles. Typical implementations may use an
adaptive learning rate so that the algorithm converges.
In pseudocode, stochastic gradient descent can be presented as:
* Choose an initial vector of parameters $w$ and learning rate $\eta$.
* Repeat until an approximate minimum is obtained:
** Randomly shuffle samples in the training set.
** For $i = 1, 2, \ldots, n$, do:
*** $w := w - \eta \nabla Q_i(w).$
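The loop structure of the pseudocode can be sketched in plain Python as follows. This is a sketch under stated assumptions: `grad_i` is a hypothetical user-supplied callable returning the gradient of the i-th summand as a list, and a fixed number of passes stands in for the unspecified stopping test.

```python
import random


def sgd(grad_i, n, w0, eta=0.01, epochs=100, seed=0):
    """Plain stochastic gradient descent.

    grad_i(w, i) must return the gradient of the i-th summand at w (a list).
    A fixed epoch budget replaces "until an approximate minimum is obtained".
    """
    rng = random.Random(seed)
    w = list(w0)
    indices = list(range(n))
    for _ in range(epochs):
        rng.shuffle(indices)  # reshuffle the samples before every pass
        for i in indices:
            g = grad_i(w, i)
            w = [wk - eta * gk for wk, gk in zip(w, g)]
    return w


if __name__ == "__main__":
    # Toy summands (w - a_i)^2, whose average is minimized at the mean of a.
    targets = [1.0, 2.0, 3.0]
    w = sgd(lambda w, i: [2.0 * (w[0] - targets[i])], n=3, w0=[0.0], epochs=300)
    print(w)  # stays close to [2.0] but keeps fluctuating slightly
```

With a constant learning rate the iterate does not settle exactly at the minimizer; it fluctuates in a small neighbourhood whose size shrinks as the learning rate is reduced.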
A compromise between computing the true gradient and the gradient at a single sample is to compute the gradient against more than one training sample (called a "mini-batch") at each step. This can perform significantly better than the "true" stochastic gradient descent described above, because the code can make use of vectorization libraries rather than computing each step separately. The approach was first demonstrated under the name "the bunch-mode back-propagation algorithm". It may also result in smoother convergence, as the gradient computed at each step is averaged over more training samples.
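A mini-batch variant might look like the sketch below. It is written in plain Python for self-containment, so it shows only the averaging logic and not the vectorization benefit, which would require an array library. The `grad_i` callable and the default values are assumptions made for illustration.

```python
import random


def minibatch_sgd(grad_i, n, w0, eta=0.05, batch_size=4, epochs=50, seed=0):
    """Mini-batch SGD: each update averages the gradients of one batch.

    grad_i(w, i) must return the gradient of the i-th summand at w (a list).
    """
    rng = random.Random(seed)
    w = list(w0)
    indices = list(range(n))
    for _ in range(epochs):
        rng.shuffle(indices)
        for start in range(0, n, batch_size):
            batch = indices[start:start + batch_size]
            grads = [grad_i(w, i) for i in batch]
            # Average the per-sample gradients coordinate by coordinate.
            avg = [sum(col) / len(batch) for col in zip(*grads)]
            w = [wk - eta * gk for wk, gk in zip(w, avg)]
    return w


if __name__ == "__main__":
    targets = [float(k) for k in range(8)]  # mean is 3.5
    grad = lambda w, i: [2.0 * (w[0] - targets[i])]
    print(minibatch_sgd(grad, n=8, w0=[0.0], batch_size=4))
    print(minibatch_sgd(grad, n=8, w0=[0.0], batch_size=8, epochs=100))
```

Setting `batch_size` equal to `n` recovers batch gradient descent, and setting it to 1 recovers the single-sample method.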
The convergence of stochastic gradient descent has been analyzed using the theories of convex minimization and of stochastic approximation. Briefly, when the learning rates decrease with an appropriate rate, and subject to relatively mild assumptions, stochastic gradient descent converges almost surely to a global minimum when the objective function is convex or pseudoconvex, and otherwise converges almost surely to a local minimum. This is in fact a consequence of the Robbins–Siegmund theorem.
Linear regression
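Suppose we want to fit a straight line $\hat{y} = w_1 + w_2 x$ to a training set with observations $((x_1, y_1), (x_2, y_2), \ldots, (x_n, y_n))$ and corresponding estimated responses $(\hat{y}_1, \hat{y}_2, \ldots, \hat{y}_n)$ using least squares. The objective function to be minimized is
$$Q(w) = \frac{1}{n} \sum_{i=1}^{n} Q_i(w) = \frac{1}{n} \sum_{i=1}^{n} (\hat{y}_i - y_i)^2 = \frac{1}{n} \sum_{i=1}^{n} (w_1 + w_2 x_i - y_i)^2.$$
The last line in the above pseudocode for this specific problem will become:
$$\begin{bmatrix} w_1 \\ w_2 \end{bmatrix} := \begin{bmatrix} w_1 \\ w_2 \end{bmatrix} - \eta \begin{bmatrix} \frac{\partial}{\partial w_1} (w_1 + w_2 x_i - y_i)^2 \\ \frac{\partial}{\partial w_2} (w_1 + w_2 x_i - y_i)^2 \end{bmatrix} = \begin{bmatrix} w_1 \\ w_2 \end{bmatrix} - \eta \begin{bmatrix} 2 (w_1 + w_2 x_i - y_i) \\ 2 x_i (w_1 + w_2 x_i - y_i) \end{bmatrix}.$$
Note that in each iteration or update step, the gradient is only evaluated at a single $x_i$. This is the key difference between stochastic gradient descent and batched gradient descent.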
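The single-sample update for the straight-line fit can be sketched as follows. The data, learning rate and epoch count are assumptions chosen for illustration; the data lie exactly on a line, so a constant learning rate is enough for the iterates to settle.

```python
import random


def fit_line_sgd(xs, ys, eta=0.02, epochs=2000, seed=0):
    """Fit y ~ w1 + w2 * x by SGD on the summands (w1 + w2*x_i - y_i)^2."""
    rng = random.Random(seed)
    w1, w2 = 0.0, 0.0
    order = list(range(len(xs)))
    for _ in range(epochs):
        rng.shuffle(order)
        for i in order:
            r = w1 + w2 * xs[i] - ys[i]  # residual at the single sample i
            # Gradient of the i-th summand is (2*r, 2*x_i*r).
            w1, w2 = w1 - eta * 2.0 * r, w2 - eta * 2.0 * xs[i] * r
    return w1, w2


if __name__ == "__main__":
    xs = [0.0, 1.0, 2.0, 3.0]
    ys = [1.0, 3.0, 5.0, 7.0]  # generated from y = 1 + 2x without noise
    print(fit_line_sgd(xs, ys))  # close to (1.0, 2.0)
```

With noisy data the same loop would keep fluctuating around the least-squares solution unless the learning rate is decreased over time.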
In general, given a linear regression $\hat{y} = \sum_{k=1}^{m} w_k x_k$ problem, stochastic gradient descent behaves differently when $m < n$ (underparameterized) and $m \geq n$ (overparameterized). In the overparameterized case, stochastic gradient descent converges to $\arg\min_{w : w^\top x_k = y_k \; \forall k \in 1:n} \|w - w_0\|$. That is, SGD converges to the interpolation solution with minimum distance from the starting $w_0$. This is true even when the learning rate remains constant. In the underparameterized case, SGD does not converge if the learning rate remains constant.
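The overparameterized behaviour can be checked numerically in the smallest possible case: one observation and two weights. For a single linear constraint the closest interpolating point to the start is known in closed form (the orthogonal projection onto the constraint), which gives something to compare against. All names and values below are assumptions made for this sketch.

```python
import random


def dot(u, v):
    return sum(a * b for a, b in zip(u, v))


def sgd_constant_rate(xs, ys, w0, eta, steps, seed=0):
    """SGD with a constant learning rate on the summands (w.x_i - y_i)^2."""
    rng = random.Random(seed)
    w = list(w0)
    for _ in range(steps):
        i = rng.randrange(len(xs))
        r = dot(w, xs[i]) - ys[i]
        w = [wk - eta * 2.0 * r * xk for wk, xk in zip(w, xs[i])]
    return w


def nearest_interpolant_single(x, y, w0):
    """Projection of w0 onto {w : w.x = y}; valid for ONE observation only."""
    shift = (y - dot(w0, x)) / dot(x, x)
    return [wk + shift * xk for wk, xk in zip(w0, x)]


if __name__ == "__main__":
    x, y, w0 = [1.0, 2.0], 5.0, [3.0, -1.0]
    print(sgd_constant_rate([x], [y], w0, eta=0.05, steps=60))
    print(nearest_interpolant_single(x, y, w0))  # both print nearly [3.8, 0.6]
```

Each update moves the weights only along the direction of the sampled feature vector, which is why the limit is the interpolating point closest to the starting weights rather than an arbitrary one.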
History
In 1951, Herbert Robbins and Sutton Monro introduced the earliest stochastic approximation methods, preceding stochastic gradient descent. Building on this work one year later, Jack Kiefer and Jacob Wolfowitz published an optimization algorithm very close to stochastic gradient descent, using central differences as an approximation of the gradient. Later in the 1950s, Frank Rosenblatt used SGD to optimize his perceptron model, demonstrating the first applicability of stochastic gradient descent to neural networks.
Backpropagation was first described in 1986, with stochastic gradient descent being used to efficiently optimize parameters across neural networks with multiple hidden layers. Soon after, another improvement was developed: mini-batch gradient descent, where small batches of data are substituted for single samples. In 1997, the practical performance benefits from vectorization achievable with such small batches were first explored, paving the way for efficient optimization in machine learning. As of 2023, this mini-batch approach remains the norm for training neural networks, balancing the benefits of stochastic gradient descent with those of gradient descent.
By the 1980s, momentum had already been introduced, and was added to SGD optimization techniques in 1986. However, these optimization techniques assumed constant hyperparameters, i.e. a fixed learning rate and momentum parameter. In the 2010s, adaptive approaches to applying SGD with a per-parameter learning rate were introduced with AdaGrad (for "Adaptive Gradient") in 2011 and RMSprop (for "Root Mean Square Propagation") in 2012.
In 2014, Adam (for "Adaptive Moment Estimation") was published, applying the adaptive approaches of RMSprop to momentum; many improvements and branches of Adam were then developed such as Adadelta, Adagrad, AdamW, and Adamax.
Within machine learning, approaches to optimization in 2023 are dominated by Adam-derived optimizers. TensorFlow and PyTorch, by far the most popular machine learning libraries, as of 2023 mostly include Adam-derived optimizers, as well as predecessors to Adam such as RMSprop and classic SGD. PyTorch also partially supports Limited-memory BFGS, a line-search method, but only for single-device setups without parameter groups.
Notable applications
Stochastic gradient descent is a popular algorithm for training a wide range of models in machine learning, including (linear) support vector machines, logistic regression (see, e.g., Vowpal Wabbit) and graphical models. When combined with the back propagation algorithm, it is the ''de facto'' standard algorithm for training artificial neural networks. Its use has also been reported in the geophysics community, specifically in applications of Full Waveform Inversion (FWI).
Stochastic gradient descent competes with the L-BFGS algorithm, which is also widely used. Stochastic gradient descent has been used since at least 1960 for training linear regression models, originally under the name ADALINE.
Another stochastic gradient descent algorithm is the least mean squares (LMS) adaptive filter.
Extensions and variants
Many improvements on the basic stochastic gradient descent algorithm have been proposed and used. In particular, in machine learning, the need to set a learning rate (step size) has been recognized as problematic. Setting this parameter too high can cause the algorithm to diverge; setting it too low makes it slow to converge. A conceptually simple extension of stochastic gradient descent makes the learning rate a decreasing function $\eta_t$ of the iteration number $t$, giving a ''learning rate schedule'', so that the first iterations cause large changes in the parameters, while the later ones do only fine-tuning. Such schedules have been known since the work of MacQueen on $k$-means clustering. Practical guidance on choosing the step size in several variants of SGD is given by Spall.
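One simple decreasing schedule is an inverse-time rule. The sketch below assumes a hypothetical schedule equal to eta0 / (t + 1) and toy summands (w - a_t)^2; with eta0 = 0.5 each update then reduces to the familiar running-mean recursion, which is the kind of incremental centroid update used in online k-means. The function names are illustrative.

```python
def inverse_time_schedule(eta0, t):
    """Hypothetical schedule: eta0 / (t + 1), where t counts updates from 0."""
    return eta0 / (t + 1)


def sgd_with_schedule(samples, w0=0.0, eta0=0.5):
    """One pass of SGD on the summands (w - a_t)^2 with a decaying step size."""
    w = w0
    for t, a in enumerate(samples):
        eta_t = inverse_time_schedule(eta0, t)
        w = w - eta_t * 2.0 * (w - a)  # gradient of (w - a)^2 is 2 * (w - a)
    return w


if __name__ == "__main__":
    # With eta0 = 0.5 the update is w <- w - (w - a) / (t + 1),
    # so the final value is the mean of the samples seen so far.
    print(sgd_with_schedule([4.0, 8.0, 6.0]))
```

Early samples move the estimate a lot and later samples only adjust it slightly, which is the behaviour a learning rate schedule is meant to produce.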
Implicit updates (ISGD)
As mentioned earlier, classical stochastic gradient descent is generally sensitive to the learning rate $\eta$. Fast convergence requires large learning rates but this may induce numerical instability. The problem can be largely solved by considering ''implicit updates'' whereby the stochastic gradient is evaluated at the next iterate rather than the current one:
$$w^{\text{new}} := w^{\text{old}} - \eta \nabla Q_i(w^{\text{new}}).$$
This equation is implicit since $w^{\text{new}}$ appears on both sides of the equation. It is a stochastic form of the proximal gradient method since the update can also be written as:
$$w^{\text{new}} := \arg\min_{w} \left\{ Q_i(w) + \frac{1}{2\eta} \left\| w - w^{\text{old}} \right\|^2 \right\}.$$
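As an example, consider least squares with features $x_1, \ldots, x_n \in \mathbb{R}^p$ and observations $y_1, \ldots, y_n \in \mathbb{R}$. We wish to solve:
$$\min_{w} \sum_{j=1}^{n} \tfrac{1}{2} \left( y_j - x_j^\top w \right)^2,$$
where $x_j^\top w = x_{j1} w_1 + x_{j2} w_2 + \cdots + x_{jp} w_p$ indicates the inner product. Note that $x_j$ could have "1" as the first element to include an intercept. Classical stochastic gradient descent proceeds as follows:
$$w^{\text{new}} = w^{\text{old}} + \eta \left( y_i - x_i^\top w^{\text{old}} \right) x_i,$$
where $i$ is uniformly sampled between 1 and $n$. Although theoretical convergence of this procedure happens under relatively mild assumptions, in practice the procedure can be quite unstable. In particular, when $\eta$ is misspecified so that $I - \eta x_i x_i^\top$ has large absolute eigenvalues with high probability, the procedure may diverge numerically within a few iterations. In contrast, ''implicit stochastic gradient descent'' (shortened as ISGD) can be solved in closed-form as:
$$w^{\text{new}} = w^{\text{old}} + \frac{\eta}{1 + \eta \left\| x_i \right\|^2} \left( y_i - x_i^\top w^{\text{old}} \right) x_i.$$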
This procedure will remain numerically stable for virtually all $\eta$, as the learning rate is now normalized. Such comparison between classical and implicit stochastic gradient descent in the least squares problem is very similar to the comparison between least mean squares (LMS) and normalized least mean squares filter (NLMS).
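The stability difference in the least squares case can be seen with a deliberately oversized learning rate. In this sketch the single feature value, the target and the learning rate are assumptions picked so that the classical update multiplies the error by -9 at every step, while the implicit update divides it by 11.

```python
def dot(u, v):
    return sum(a * b for a, b in zip(u, v))


def sgd_step(w, x, y, eta):
    """Classical update: w + eta * (y - x.w) * x."""
    r = y - dot(x, w)
    return [wk + eta * r * xk for wk, xk in zip(w, x)]


def isgd_step(w, x, y, eta):
    """Implicit update in closed form: the step is scaled by 1 / (1 + eta*|x|^2)."""
    r = y - dot(x, w)
    scale = eta / (1.0 + eta * dot(x, x))
    return [wk + scale * r * xk for wk, xk in zip(w, x)]


if __name__ == "__main__":
    x, y, eta = [10.0], 10.0, 0.1  # exact solution is w = 1
    w_classic, w_implicit = [0.0], [0.0]
    for _ in range(20):
        w_classic = sgd_step(w_classic, x, y, eta)
        w_implicit = isgd_step(w_implicit, x, y, eta)
    print(w_classic)   # has blown up
    print(w_implicit)  # very close to [1.0]
```

The only difference between the two functions is the normalizing factor, which is what keeps the implicit step bounded when the learning rate is large relative to the scale of the features.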
Even though a closed-form solution for ISGD is only possible in least squares, the procedure can be efficiently implemented in a wide range of models. Specifically, suppose that $Q_i(w)$ depends on $w$ only through a linear combination with features $x_i$, so that we can write $\nabla_w Q_i(w) = -q(x_i^\top w) x_i$, where $q(\cdot) \in \mathbb{R}$ may depend on $x_i, y_i$ as well but not on $w$ except through $x_i^\top w$. Least squares obeys this rule, and so does logistic regression, and most generalized linear models. For instance, in least squares, $q(x_i^\top w) = y_i - x_i^\top w$, and in logistic regression $q(x_i^\top w) = y_i - S(x_i^\top w)$, where $S(u) = e^u / (1 + e^u)$ is the logistic function. In Poisson regression, $q(x_i^\top w) = y_i - e^{x_i^\top w}$, and so on.
In such settings, ISGD is simply implemented as follows. Let $f(\xi) = \eta \, q(x_i^\top w^{\text{old}} + \xi \|x_i\|^2)$, where $\xi$ is scalar. Then, ISGD is equivalent to:
$$w^{\text{new}} = w^{\text{old}} + \xi^{*} x_i, \quad \text{where } \xi^{*} = f(\xi^{*}).$$
The scaling factor $\xi^{*} \in \mathbb{R}$ can be found through the bisection method since in most regular models, such as the aforementioned generalized linear models, function $q(\cdot)$ is decreasing, and thus the search bounds for $\xi^{*}$ are $[\min(0, f(0)), \max(0, f(0))]$.
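A bisection-based implicit step might be sketched as below. It assumes the gradient of each summand has the form -q(x.w) * x with q decreasing in its first argument, so that the scalar fixed-point equation xi = eta * q(x.w_old + xi * |x|^2) has a single root between 0 and its value at xi = 0. The signature `q(u, y)`, the tolerance and the helper names are choices made for this sketch, not a standard API.

```python
import math


def dot(u, v):
    return sum(a * b for a, b in zip(u, v))


def logistic(u):
    """Numerically stable logistic function."""
    if u >= 0:
        return 1.0 / (1.0 + math.exp(-u))
    z = math.exp(u)
    return z / (1.0 + z)


def implicit_scale(w, x, y, eta, q, tol=1e-12, max_iter=200):
    """Solve xi = eta * q(x.w + xi * |x|^2, y) for xi by bisection."""
    xw, norm2 = dot(x, w), dot(x, x)
    f = lambda xi: eta * q(xw + xi * norm2, y)
    f0 = f(0.0)
    lo, hi = min(0.0, f0), max(0.0, f0)
    for _ in range(max_iter):
        mid = 0.5 * (lo + hi)
        # xi - f(xi) is increasing because q is decreasing.
        if mid - f(mid) > 0.0:
            hi = mid
        else:
            lo = mid
        if hi - lo < tol:
            break
    return 0.5 * (lo + hi)


def isgd_step(w, x, y, eta, q):
    """Implicit SGD step: move w along x by the solved scaling factor."""
    xi = implicit_scale(w, x, y, eta, q)
    return [wk + xi * xk for wk, xk in zip(w, x)]


if __name__ == "__main__":
    q_logistic = lambda u, y: y - logistic(u)
    print(isgd_step([0.0, 0.0], [1.0, 2.0], 1.0, 0.5, q_logistic))
```

Passing `lambda u, y: y - u` as `q` reproduces the closed-form least squares step, which is a convenient way to check the root finder.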