Stochastic gradient descent (often abbreviated SGD) is an iterative method for optimizing an objective function with suitable smoothness properties (e.g. differentiable or subdifferentiable). It can be regarded as a stochastic approximation of gradient descent optimization, since it replaces the actual gradient (calculated from the entire data set) by an estimate thereof (calculated from a randomly selected subset of the data). Especially in high-dimensional optimization problems this reduces the very high computational burden, achieving faster iterations in exchange for a lower convergence rate.
While the basic idea behind stochastic approximation can be traced back to the Robbins–Monro algorithm of the 1950s, stochastic gradient descent has become an important optimization method in machine learning.
Background
Both statistical estimation and machine learning consider the problem of minimizing an objective function that has the form of a sum:
: $Q(w) = \frac{1}{n} \sum_{i=1}^{n} Q_i(w),$
where the parameter $w$ that minimizes $Q(w)$ is to be estimated. Each summand function $Q_i$ is typically associated with the $i$-th observation in the data set (used for training).
In classical statistics, sum-minimization problems arise in least squares and in maximum-likelihood estimation (for independent observations). The estimators that arise as minimizers of sums are called M-estimators. However, in statistics, it has long been recognized that requiring even local minimization is too restrictive for some problems of maximum-likelihood estimation. Therefore, contemporary statistical theorists often consider stationary points of the likelihood function (or zeros of its derivative, the score function, and other estimating equations).
The sum-minimization problem also arises for empirical risk minimization. In this case, $Q_i(w)$ is the value of the loss function at the $i$-th example, and $Q(w)$ is the empirical risk.
When used to minimize the above function, a standard (or "batch") gradient descent method would perform the following iterations:
: $w := w - \eta \nabla Q(w) = w - \frac{\eta}{n} \sum_{i=1}^{n} \nabla Q_i(w),$
where $\eta$ is a step size (sometimes called the ''learning rate'' in machine learning).
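As a rough illustration of the batch iteration, the sketch below averages the gradients of all summand functions before taking a single step. The helper name `batch_gradient_step` and the toy summands $Q_i(w) = (w - t_i)^2$ are assumptions made only for this example; they do not come from the text.

```python
def batch_gradient_step(w, grads, eta):
    """One batch step: w - (eta / n) * sum_i grad_i(w)."""
    n = len(grads)
    dim = len(w)
    total = [0.0] * dim
    for grad_i in grads:          # every summand gradient is evaluated
        g = grad_i(w)
        for k in range(dim):
            total[k] += g[k]
    return [w[k] - eta * total[k] / n for k in range(dim)]


# Hypothetical toy problem: Q_i(w) = (w - t_i)^2, whose average is
# minimized at the mean of the targets (here 3.0).
targets = [1.0, 2.0, 3.0, 6.0]
grads = [lambda w, t=t: [2.0 * (w[0] - t)] for t in targets]

w = [0.0]
for _ in range(200):
    w = batch_gradient_step(w, grads, eta=0.1)
```

Each call touches all $n$ summand gradients, which is the per-iteration cost that the stochastic variant tries to avoid.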
In many cases, the summand functions have a simple form that enables inexpensive evaluations of the sum-function and the sum gradient. For example, in statistics, one-parameter exponential families allow economical function-evaluations and gradient-evaluations.
However, in other cases, evaluating the sum-gradient may require expensive evaluations of the gradients from all summand functions. When the training set is enormous and no simple formulas exist, evaluating the sum of gradients becomes very expensive, because it requires evaluating the gradient of every summand function. To economize on the computational cost at every iteration, stochastic gradient descent samples a subset of summand functions at every step. This is very effective in the case of large-scale machine learning problems.
Iterative method
In stochastic (or "on-line") gradient descent, the true gradient of $Q(w)$ is approximated by a gradient at a single sample:
: $w := w - \eta \nabla Q_i(w).$
As the algorithm sweeps through the training set, it performs the above update for each training sample. Several passes can be made over the training set until the algorithm converges. If this is done, the data can be shuffled for each pass to prevent cycles. Typical implementations may use an adaptive learning rate so that the algorithm converges.
In pseudocode, stochastic gradient descent can be presented as:
* Choose an initial vector of parameters $w$ and learning rate $\eta$.
* Repeat until an approximate minimum is obtained:
** Randomly shuffle samples in the training set.
** For $i = 1, 2, \ldots, n$, do:
*** $w := w - \eta \nabla Q_i(w).$
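The pseudocode can be sketched in Python roughly as follows. This is a minimal illustration, not a reference implementation: the function name `sgd`, the callback signature `grad_i(i, w)`, the fixed number of epochs used in place of a convergence test, and the toy problem are all assumptions introduced here.

```python
import random


def sgd(grad_i, n, w0, eta, epochs, seed=0):
    """Plain SGD: shuffle each pass, then update on one sample at a time."""
    rng = random.Random(seed)
    w = list(w0)
    order = list(range(n))
    for _ in range(epochs):       # stand-in for "repeat until approximate minimum"
        rng.shuffle(order)        # reshuffle the training set on every pass
        for i in order:
            g = grad_i(i, w)      # gradient of the single summand Q_i
            w = [wk - eta * gk for wk, gk in zip(w, g)]
    return w


# Hypothetical toy problem: Q_i(w) = (w - t_i)^2, average minimized at 3.0.
targets = [1.0, 2.0, 3.0, 6.0]


def grad_i(i, w):
    return [2.0 * (w[0] - targets[i])]


w = sgd(grad_i, n=len(targets), w0=[0.0], eta=0.01, epochs=500)
```

With a constant step size the iterate does not settle exactly on the minimizer; it keeps fluctuating in a small neighbourhood of it, which is one reason decreasing or adaptive learning rates are often used.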
A compromise between computing the true gradient and the gradient at a single sample is to compute the gradient against more than one training sample (called a "mini-batch") at each step. This can perform significantly better than the "true" stochastic gradient descent described above, because the code can make use of vectorization libraries rather than computing each step separately, as was first shown in work that called it "the bunch-mode back-propagation algorithm". It may also result in smoother convergence, as the gradient computed at each step is averaged over more training samples.
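A mini-batch variant might look like the sketch below, which averages the summand gradients over each batch before updating. The names `minibatch_sgd` and `grad_i`, the fixed epoch count, and the toy data are assumptions for illustration. The inner loop is written in pure Python for self-containment; in practice this is the part that a vectorized array library would compute in one call.

```python
import random


def minibatch_sgd(grad_i, n, w0, eta, batch_size, epochs, seed=0):
    """SGD where each step uses the average gradient over a mini-batch."""
    rng = random.Random(seed)
    w = list(w0)
    order = list(range(n))
    for _ in range(epochs):
        rng.shuffle(order)
        for start in range(0, n, batch_size):
            batch = order[start:start + batch_size]
            avg = [0.0] * len(w)
            for i in batch:                   # candidate for vectorization
                g = grad_i(i, w)
                for k in range(len(w)):
                    avg[k] += g[k] / len(batch)
            w = [wk - eta * ak for wk, ak in zip(w, avg)]
    return w


# Hypothetical toy problem: Q_i(w) = (w - t_i)^2, average minimized at 3.0.
targets = [1.0, 2.0, 3.0, 6.0, 4.0, 2.0]


def grad_i(i, w):
    return [2.0 * (w[0] - targets[i])]


w = minibatch_sgd(grad_i, n=len(targets), w0=[0.0], eta=0.05,
                  batch_size=2, epochs=300)
```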
The convergence of stochastic gradient descent has been analyzed using the theories of convex minimization and of stochastic approximation. Briefly, when the learning rates $\eta$ decrease with an appropriate rate, and subject to relatively mild assumptions, stochastic gradient descent converges almost surely to a global minimum when the objective function is convex or pseudoconvex, and otherwise converges almost surely to a local minimum. This is in fact a consequence of the Robbins–Siegmund theorem.
Example
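Suppose we want to fit a straight line $\hat{y} = w_1 + w_2 x$ to a training set with observations $(x_1, x_2, \ldots, x_n)$ and corresponding responses $(y_1, y_2, \ldots, y_n)$ using least squares, where $\hat{y}_i = w_1 + w_2 x_i$ is the estimated response for the $i$-th observation. The objective function to be minimized is:
: $Q(w) = \frac{1}{n} \sum_{i=1}^{n} Q_i(w) = \frac{1}{n} \sum_{i=1}^{n} (\hat{y}_i - y_i)^2 = \frac{1}{n} \sum_{i=1}^{n} (w_1 + w_2 x_i - y_i)^2.$
The last line in the above pseudocode for this specific problem will become:
: $\begin{bmatrix} w_1 \\ w_2 \end{bmatrix} := \begin{bmatrix} w_1 \\ w_2 \end{bmatrix} - \eta \begin{bmatrix} \frac{\partial}{\partial w_1} (w_1 + w_2 x_i - y_i)^2 \\ \frac{\partial}{\partial w_2} (w_1 + w_2 x_i - y_i)^2 \end{bmatrix} = \begin{bmatrix} w_1 \\ w_2 \end{bmatrix} - \eta \begin{bmatrix} 2 (w_1 + w_2 x_i - y_i) \\ 2 x_i (w_1 + w_2 x_i - y_i) \end{bmatrix}.$
Note that in each iteration (also called update), the gradient is only evaluated at a single point $x_i$ instead of at the set of all samples.
The key difference compared to standard (batch) gradient descent is that only one piece of data from the dataset is used to calculate the step, and the piece of data is picked randomly at each step.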
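The straight-line example can be tried out with a short script such as the one below. It is only a sketch: the function name `fit_line_sgd`, the step size, the number of passes, and the noise-free data generated from $y = 1 + 2x$ are assumptions chosen so that the result is easy to check.

```python
import random


def fit_line_sgd(xs, ys, eta=0.01, epochs=2000, seed=0):
    """Fit y ~ w1 + w2 * x by SGD on the squared error of one sample at a time."""
    rng = random.Random(seed)
    w1, w2 = 0.0, 0.0
    idx = list(range(len(xs)))
    for _ in range(epochs):
        rng.shuffle(idx)                      # random order on each pass
        for i in idx:
            r = w1 + w2 * xs[i] - ys[i]       # residual at the single point x_i
            # gradient of (w1 + w2*x_i - y_i)^2 is (2*r, 2*x_i*r)
            w1, w2 = w1 - eta * 2.0 * r, w2 - eta * 2.0 * xs[i] * r
    return w1, w2


# Hypothetical noise-free data lying exactly on y = 1 + 2x.
xs = [0.0, 1.0, 2.0, 3.0, 4.0]
ys = [1.0, 3.0, 5.0, 7.0, 9.0]
w1, w2 = fit_line_sgd(xs, ys)
```

Because the data here contain no noise, every per-sample gradient vanishes at the true line, so the iterates can converge to it even with a constant step size; with noisy data they would instead hover around the least-squares solution.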
Notable applications
Stochastic gradient descent is a popular algorithm for training a wide range of models in machine learning, including (linear) support vector machines, logistic regression (see, e.g., Vowpal Wabbit) and graphical models. When combined with the backpropagation algorithm, it is the ''de facto'' standard algorithm for training artificial neural networks. Its use has also been reported in the geophysics community, specifically in applications of Full Waveform Inversion (FWI).
Stochastic gradient descent competes with the L-BFGS algorithm, which is also widely used. Stochastic gradient descent has been used since at least 1960 for training linear regression models, originally under the name ADALINE.
Another stochastic gradient descent algorithm is the least mean squares (LMS) adaptive filter.
Extensions and variants
Many improvements on the basic stochastic gradient descent algorithm have been proposed and used. In particular, in machine learning, the need to set a learning rate (step size) has been recognized as problematic. Setting this parameter too high can cause the algorithm to diverge; setting it too low makes it slow to converge. A conceptually simple extension of stochastic gradient descent makes the learning rate a decreasing function $\eta_t$ of the iteration number $t$, giving a ''learning rate schedule'', so that the first iterations cause large changes in the parameters, while the later ones do only fine-tuning. Such schedules have been known since the work of MacQueen on $k$-means clustering. Practical guidance on choosing the step size in several variants of SGD is given by Spall.
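A learning rate schedule can be as simple as a function of the iteration number. The inverse-time form below, $\eta_t = \eta_0 / (1 + d\,t)$, is one common choice and is used here purely as an assumption; the text does not prescribe a specific schedule, and the names `inverse_time_schedule`, `eta0` and `decay` are hypothetical.

```python
def inverse_time_schedule(eta0, decay):
    """Return eta(t) = eta0 / (1 + decay * t), a decreasing step size."""
    def eta(t):
        return eta0 / (1.0 + decay * t)
    return eta


eta = inverse_time_schedule(eta0=0.5, decay=0.1)

# Early iterations take large steps, later ones only fine-tune.
rates = [eta(t) for t in range(5)]
```

Inside an SGD loop, the constant step size would be replaced by `eta(t)`, with `t` counting the updates performed so far.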
Implicit updates (ISGD)
As mentioned earlier, classical stochastic gradient descent is generally sensitive to the learning rate $\eta$. Fast convergence requires large learning rates but this may induce numerical instability. The problem can be largely solved by considering ''implicit updates'' whereby the stochastic gradient is evaluated at the next iterate rather than the current one:
: $w^{\text{new}} := w^{\text{old}} - \eta \nabla Q_i(w^{\text{new}}).$
This equation is implicit since $w^{\text{new}}$ appears on both sides of the equation. It is a stochastic form of the proximal gradient method since the update can also be written as:
: $w^{\text{new}} := \arg\min_w \left\{ Q_i(w) + \frac{1}{2\eta} \left\| w - w^{\text{old}} \right\|^2 \right\}.$
As an example, consider least squares with features $x_1, \ldots, x_n \in \mathbb{R}^p$ and observations $y_1, \ldots, y_n \in \mathbb{R}$. We wish to solve:
: $\min_w \sum_{j=1}^{n} \frac{1}{2} \left( y_j - x_j'w \right)^2,$
where $x_j'w = x_{j1} w_1 + x_{j2} w_2 + \cdots + x_{jp} w_p$ indicates the inner product.
Note that $x$ could have "1" as the first element to include an intercept. Classical stochastic gradient descent proceeds as follows:
: $w^{\text{new}} = w^{\text{old}} + \eta \left( y_i - x_i'w^{\text{old}} \right) x_i,$
where $i$ is uniformly sampled between 1 and $n$. Although theoretical convergence of this procedure happens under relatively mild assumptions, in practice the procedure can be quite unstable. In particular, when $\eta$ is misspecified so that $I - \eta x_i x_i'$ has large absolute eigenvalues with high probability, the procedure may diverge numerically within a few iterations. In contrast, ''implicit stochastic gradient descent'' (shortened as ISGD) can be solved in closed form as:
: $w^{\text{new}} = w^{\text{old}} + \frac{\eta}{1 + \eta \left\| x_i \right\|^2} \left( y_i - x_i'w^{\text{old}} \right) x_i.$
This procedure will remain numerically stable for virtually all $\eta$, as the learning rate is now normalized. This comparison between classical and implicit stochastic gradient descent in the least squares problem is very similar to the comparison between least mean squares (LMS) and the normalized least mean squares filter (NLMS).
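The difference in stability can be seen in a small experiment. The sketch below assumes the per-sample loss $\frac{1}{2}(y_i - x_i'w)^2$, so that the classical update is $w + \eta\,(y_i - x_i'w)\,x_i$ and the implicit update replaces $\eta$ by $\eta / (1 + \eta \|x_i\|^2)$. The function names, the single repeated sample, and the deliberately oversized step $\eta = 10$ are assumptions chosen to make the effect visible.

```python
def dot(a, b):
    return sum(ak * bk for ak, bk in zip(a, b))


def sgd_step(w, x, y, eta):
    """Classical update: gradient evaluated at the current iterate."""
    r = y - dot(x, w)
    return [wk + eta * r * xk for wk, xk in zip(w, x)]


def isgd_step(w, x, y, eta):
    """Implicit update in closed form: the step size is normalized."""
    r = y - dot(x, w)
    scale = eta / (1.0 + eta * dot(x, x))
    return [wk + scale * r * xk for wk, xk in zip(w, x)]


# Hypothetical single sample and a deliberately oversized learning rate.
x, y, eta = [1.0, 3.0], 2.0, 10.0
w_sgd = [0.0, 0.0]
w_isgd = [0.0, 0.0]
for _ in range(20):
    w_sgd = sgd_step(w_sgd, x, y, eta)
    w_isgd = isgd_step(w_isgd, x, y, eta)

res_sgd = abs(y - dot(x, w_sgd))     # grows by a factor |1 - eta*||x||^2| per step
res_isgd = abs(y - dot(x, w_isgd))   # shrinks by 1 / (1 + eta*||x||^2) per step
```

In this toy setting the classical residual is multiplied by $|1 - \eta \|x\|^2| = 99$ at every step, while the implicit residual is divided by $1 + \eta \|x\|^2 = 101$.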
Even though a closed-form solution for ISGD is only possible in least squares, the procedure can be efficiently implemented in a wide range of models. Specifically, suppose that $Q_i(w)$ depends on $w$ only through a linear combination with features $x_i$, so that we can write $\nabla_w Q_i(w) = -q(x_i'w)\, x_i$, where $q() \in \mathbb{R}$ may depend on $x_i, y_i$ as well but not on $w$ except through $x_i'w$. Least squares obeys this rule, and so do logistic regression and most generalized linear models. For instance, in least squares, $q(x_i'w) = y_i - x_i'w$, and in logistic regression $q(x_i'w) = y_i - S(x_i'w)$, where $S(u) = e^u / (1 + e^u)$ is the logistic function. In Poisson regression, $q(x_i'w) = y_i - e^{x_i'w}$, and so on.
In such settings, ISGD is simply implemented as follows. Let $f(\xi) = \eta\, q(x_i'w^{\text{old}} + \xi \left\| x_i \right\|^2)$, where $\xi$ is scalar.
Then, ISGD is equivalent to:
: $\xi^* = f(\xi^*), \qquad w^{\text{new}} = w^{\text{old}} + \xi^* x_i.$
The scaling factor $\xi^* \in \mathbb{R}$ can be found through the bisection method since in most regular models, such as the aforementioned generalized linear models, the function $q()$ is decreasing, and thus the search bounds for $\xi^*$ are $[\min(0, f(0)), \max(0, f(0))]$.
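For a concrete picture, the sketch below applies this idea to logistic regression. It assumes $q(u) = y_i - S(u)$ with $S$ the logistic function, defines $f(\xi) = \eta\, q(x_i'w^{\text{old}} + \xi \|x_i\|^2)$, and searches for the fixed point $\xi^* = f(\xi^*)$ by bisection on the interval between $0$ and $f(0)$. The function names and the fixed count of 100 halvings are illustrative choices, not part of the text.

```python
import math


def logistic(u):
    return 1.0 / (1.0 + math.exp(-u))


def dot(a, b):
    return sum(ak * bk for ak, bk in zip(a, b))


def isgd_step_logistic(w, x, y, eta, halvings=100):
    """One implicit SGD step for logistic regression via bisection on xi."""
    a = dot(x, w)
    norm2 = dot(x, x)

    def f(xi):
        # eta * q(x'w_old + xi * ||x||^2) with q(u) = y - S(u), decreasing in xi
        return eta * (y - logistic(a + xi * norm2))

    lo, hi = min(0.0, f(0.0)), max(0.0, f(0.0))
    for _ in range(halvings):
        mid = 0.5 * (lo + hi)
        if mid - f(mid) > 0.0:    # xi - f(xi) is increasing: root lies to the left
            hi = mid
        else:
            lo = mid
    xi = 0.5 * (lo + hi)
    return [wk + xi * xk for wk, xk in zip(w, x)], xi


# Hypothetical single observation.
w_new, xi = isgd_step_logistic([0.0, 0.0], [1.0, 2.0], y=1.0, eta=1.0)
```

Since $f$ is decreasing, $\xi - f(\xi)$ is increasing and changes sign exactly once on the search interval, which is what makes bisection applicable here.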