Stochastic gradient descent (often abbreviated SGD) is an iterative method for optimizing an objective function with suitable smoothness properties (e.g. differentiable or subdifferentiable). It can be regarded as a stochastic approximation of gradient descent optimization, since it replaces the actual gradient (calculated from the entire data set) by an estimate thereof (calculated from a randomly selected subset of the data). Especially in high-dimensional optimization problems, this reduces the very high computational burden, achieving faster iterations in exchange for a lower convergence rate.
While the basic idea behind stochastic approximation can be traced back to the Robbins–Monro algorithm of the 1950s, stochastic gradient descent has become an important optimization method in machine learning.
Background
Both statistical estimation and machine learning consider the problem of minimizing an objective function that has the form of a sum:
: $Q(w) = \frac{1}{n} \sum_{i=1}^{n} Q_i(w),$
where the parameter $w$ that minimizes $Q(w)$ is to be estimated. Each summand function $Q_i$ is typically associated with the $i$-th observation in the data set (used for training).
In classical statistics, sum-minimization problems arise in least squares and in maximum-likelihood estimation (for independent observations). The general class of estimators that arise as minimizers of sums is called M-estimators. However, in statistics, it has long been recognized that requiring even local minimization is too restrictive for some problems of maximum-likelihood estimation. Therefore, contemporary statistical theorists often consider stationary points of the likelihood function (or zeros of its derivative, the score function, and other estimating equations).
The sum-minimization problem also arises for empirical risk minimization. In this case, $Q_i(w)$ is the value of the loss function at the $i$-th example, and $Q(w)$ is the empirical risk.
When used to minimize the above function, a standard (or "batch") gradient descent method would perform the following iterations:
: $w := w - \eta \nabla Q(w) = w - \frac{\eta}{n} \sum_{i=1}^{n} \nabla Q_i(w),$
where $\eta$ is a step size (sometimes called the ''learning rate'' in machine learning).
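The batch iteration can be sketched in a few lines of Python. This is a minimal illustration under stated assumptions, not a reference implementation: the parameter is a single scalar, the function name `batch_gradient_descent` is hypothetical, and the toy summands $Q_i(w) = (w - a_i)^2$ are chosen only because their minimizer (the mean of the $a_i$) is easy to check.

```python
def batch_gradient_descent(grad_fns, w, eta, n_iters):
    """Minimize Q(w) = (1/n) * sum_i Q_i(w) for a scalar parameter w.

    grad_fns[i](w) returns the derivative of the i-th summand Q_i at w.
    """
    n = len(grad_fns)
    for _ in range(n_iters):
        # Full gradient: every summand is evaluated at every iteration.
        full_grad = sum(g(w) for g in grad_fns) / n
        w = w - eta * full_grad
    return w


# Toy problem (illustrative): Q_i(w) = (w - a_i)^2, minimized at mean(a).
data = [1.0, 2.0, 3.0, 6.0]
grad_fns = [lambda w, a=a: 2.0 * (w - a) for a in data]
w_hat = batch_gradient_descent(grad_fns, w=0.0, eta=0.1, n_iters=200)
```

Each iteration touches all $n$ summands, which is the per-step cost that the stochastic variant avoids.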
In many cases, the summand functions have a simple form that enables inexpensive evaluations of the sum-function and the sum gradient. For example, in statistics, one-parameter exponential families allow economical function-evaluations and gradient-evaluations.
However, in other cases, evaluating the sum-gradient may require expensive evaluations of the gradients from all summand functions. When the training set is enormous and no simple formulas exist, evaluating the sum of gradients becomes very expensive, because it requires evaluating the gradient of every summand function. To economize on the computational cost at every iteration, stochastic gradient descent samples a subset of summand functions at every step. This is very effective in the case of large-scale machine learning problems.
Iterative method
In stochastic (or "on-line") gradient descent, the true gradient of $Q(w)$ is approximated by a gradient at a single sample:
: $w := w - \eta \nabla Q_i(w).$
As the algorithm sweeps through the training set, it performs the above update for each training sample. Several passes can be made over the training set until the algorithm converges. If this is done, the data can be shuffled for each pass to prevent cycles. Typical implementations may use an adaptive learning rate so that the algorithm converges.
In pseudocode, stochastic gradient descent can be presented as:
* Choose an initial vector of parameters $w$ and learning rate $\eta$.
* Repeat until an approximate minimum is obtained:
** Randomly shuffle samples in the training set.
** For $i = 1, 2, \ldots, n$, do:
*** $w := w - \eta \nabla Q_i(w).$
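The pseudocode translates fairly directly into Python. The sketch below assumes a scalar parameter, a fixed number of passes in place of a convergence test, and a constant learning rate; the function name `sgd` and the toy summands $Q_i(w) = (w - a_i)^2$ are illustrative choices, not part of the description above.

```python
import random


def sgd(grad_fns, w, eta, n_epochs, seed=0):
    """Stochastic gradient descent for a scalar parameter w.

    grad_fns[i](w) returns the gradient of the i-th summand Q_i at w.
    """
    rng = random.Random(seed)
    order = list(range(len(grad_fns)))
    for _ in range(n_epochs):
        rng.shuffle(order)                  # reshuffle each pass to avoid cycles
        for i in order:
            w = w - eta * grad_fns[i](w)    # update from a single sample
    return w


# Toy problem (illustrative): Q_i(w) = (w - a_i)^2, whose average is minimized at 3.0.
data = [1.0, 2.0, 3.0, 6.0]
grad_fns = [lambda w, a=a: 2.0 * (w - a) for a in data]
w_sgd = sgd(grad_fns, w=0.0, eta=0.01, n_epochs=500)
```

With a constant learning rate the iterate does not settle exactly on the minimizer; it keeps fluctuating in a small neighbourhood of it whose size shrinks with the learning rate.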
A compromise between computing the true gradient and the gradient at a single sample is to compute the gradient against more than one training sample (called a "mini-batch") at each step. This can perform significantly better than the "true" stochastic gradient descent described above, because the code can make use of vectorization libraries rather than computing each step separately, as was first shown in work that called it "the bunch-mode back-propagation algorithm". It may also result in smoother convergence, as the gradient computed at each step is averaged over more training samples.
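A mini-batch variant might look like the following sketch. It is written in plain Python so that it is self-contained; the inner `sum` over the batch is the part that an array library would typically compute in one vectorized call. The function name `minibatch_sgd`, the batch size, and the toy data are assumptions made for illustration.

```python
import random


def minibatch_sgd(grad_fn, n, w, eta, batch_size, n_epochs, seed=0):
    """Mini-batch SGD for a scalar parameter w.

    grad_fn(i, w) returns the gradient of the i-th summand Q_i at w.
    """
    rng = random.Random(seed)
    order = list(range(n))
    for _ in range(n_epochs):
        rng.shuffle(order)
        for start in range(0, n, batch_size):
            batch = order[start:start + batch_size]
            # Average the per-sample gradients over the mini-batch.
            g = sum(grad_fn(i, w) for i in batch) / len(batch)
            w = w - eta * g
    return w


# Toy problem (illustrative): Q_i(w) = (w - a_i)^2 with mean(a) = 3.0.
data = [1.0, 2.0, 3.0, 6.0, 4.0, 2.0]
w_mb = minibatch_sgd(lambda i, w: 2.0 * (w - data[i]), n=len(data),
                     w=0.0, eta=0.05, batch_size=3, n_epochs=300)
```

Setting `batch_size=1` recovers the single-sample update, and `batch_size=n` recovers batch gradient descent.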
The convergence of stochastic gradient descent has been analyzed using the theories of convex minimization and of stochastic approximation. Briefly, when the learning rates decrease with an appropriate rate, and subject to relatively mild assumptions, stochastic gradient descent converges almost surely to a global minimum when the objective function is convex or pseudoconvex, and otherwise converges almost surely to a local minimum. This is in fact a consequence of the Robbins–Siegmund theorem.
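One common way to make "an appropriate rate" concrete is the pair of step-size conditions usually associated with Robbins–Monro-style stochastic approximation. They are given here as a standard textbook assumption, not as the exact hypotheses of the theorem cited above; $\eta_t$ denotes the learning rate at iteration $t$ and $\eta_0$ is an illustrative initial value:

```latex
\sum_{t=1}^{\infty} \eta_t = \infty,
\qquad
\sum_{t=1}^{\infty} \eta_t^2 < \infty,
\qquad
\text{for example } \eta_t = \frac{\eta_0}{t}.
```

The first condition lets the iterates travel arbitrarily far from the starting point if needed; the second makes the accumulated gradient noise summable.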
Example
Suppose we want to fit a straight line $\hat{y} = w_1 + w_2 x$ to a training set with observations $((x_1, y_1), (x_2, y_2), \ldots, (x_n, y_n))$ and corresponding estimated responses $(\hat{y}_1, \hat{y}_2, \ldots, \hat{y}_n)$ using least squares. The objective function to be minimized is:
: $Q(w) = \frac{1}{n} \sum_{i=1}^{n} Q_i(w) = \frac{1}{n} \sum_{i=1}^{n} (\hat{y}_i - y_i)^2 = \frac{1}{n} \sum_{i=1}^{n} (w_1 + w_2 x_i - y_i)^2.$
The last line in the above pseudocode for this specific problem will become:
: $\begin{bmatrix} w_1 \\ w_2 \end{bmatrix} := \begin{bmatrix} w_1 \\ w_2 \end{bmatrix} - \eta \begin{bmatrix} 2 (w_1 + w_2 x_i - y_i) \\ 2 x_i (w_1 + w_2 x_i - y_i) \end{bmatrix}.$
Note that in each iteration (also called update), the gradient is only evaluated at a single point $x_i$ instead of at the set of all samples.
The key difference compared to standard (batch) gradient descent is that only one piece of data from the dataset is used to calculate the step, and the piece of data is picked randomly at each step.
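The straight-line example can be sketched as follows. The data, learning rate, and number of passes are illustrative assumptions, and `fit_line_sgd` is a hypothetical helper name; the per-sample gradient is that of the squared residual $(w_1 + w_2 x_i - y_i)^2$.

```python
import random


def fit_line_sgd(xs, ys, eta=0.01, n_epochs=2000, seed=0):
    """Fit y ~ w1 + w2 * x by SGD on per-sample squared residuals."""
    rng = random.Random(seed)
    w1, w2 = 0.0, 0.0
    order = list(range(len(xs)))
    for _ in range(n_epochs):
        rng.shuffle(order)
        for i in order:
            residual = w1 + w2 * xs[i] - ys[i]
            # Gradient of (w1 + w2*x_i - y_i)^2 with respect to (w1, w2).
            g1 = 2.0 * residual
            g2 = 2.0 * xs[i] * residual
            w1, w2 = w1 - eta * g1, w2 - eta * g2
    return w1, w2


# Noise-free illustrative data lying exactly on the line y = 1 + 2x.
xs = [0.0, 1.0, 2.0, 3.0, 4.0]
ys = [1.0 + 2.0 * x for x in xs]
w1, w2 = fit_line_sgd(xs, ys)
```

Because the data here are noise-free, the true coefficients are a fixed point of every single-sample update, so even a constant learning rate recovers them closely; with noisy data the estimates would instead fluctuate around the least-squares solution.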
Notable applications
Stochastic gradient descent is a popular algorithm for training a wide range of models in machine learning, including (linear) support vector machines, logistic regression (see, e.g., Vowpal Wabbit) and graphical models. When combined with the backpropagation algorithm, it is the ''de facto'' standard algorithm for training artificial neural networks. Its use has also been reported in the geophysics community, specifically in applications of Full Waveform Inversion (FWI).
Stochastic gradient descent competes with the L-BFGS algorithm, which is also widely used. Stochastic gradient descent has been used since at least 1960 for training linear regression models, originally under the name ADALINE.
Another stochastic gradient descent algorithm is the least mean squares (LMS) adaptive filter.
Extensions and variants
Many improvements on the basic stochastic gradient descent algorithm have been proposed and used. In particular, in machine learning, the need to set a learning rate (step size) has been recognized as problematic. Setting this parameter too high can cause the algorithm to diverge; setting it too low makes it slow to converge. A conceptually simple extension of stochastic gradient descent makes the learning rate a decreasing function $\eta_t$ of the iteration number $t$, giving a ''learning rate schedule'', so that the first iterations cause large changes in the parameters, while the later ones do only fine-tuning. Such schedules have been known since the work of MacQueen on $k$-means clustering. Practical guidance on choosing the step size in several variants of SGD is given by Spall.
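A learning rate schedule can be illustrated with a simple inverse-time decay. The schedule form $\eta_t = \eta_0 / (1 + d\,t)$, the function names, and the toy objective are assumptions chosen for this sketch. With $\eta_0 = 1/2$ and $d = 1$ on the summands $Q_i(w) = (w - a_i)^2$, a single pass reduces to a running mean of the data, which makes the result easy to verify.

```python
def inverse_time_decay(eta0, decay, t):
    """Step size at iteration t (t = 0, 1, 2, ...): eta0 / (1 + decay * t)."""
    return eta0 / (1.0 + decay * t)


def sgd_with_schedule(data, w=0.0, eta0=0.5, decay=1.0):
    """One pass of SGD on Q_i(w) = (w - a_i)^2 with a decaying step size."""
    for t, a in enumerate(data):
        eta_t = inverse_time_decay(eta0, decay, t)
        w = w - eta_t * 2.0 * (w - a)   # gradient of (w - a)^2 is 2 * (w - a)
    return w


data = [1.0, 2.0, 3.0, 6.0]
w_final = sgd_with_schedule(data)       # equals the running mean of the data
```

Early samples move the parameter a long way, while later samples only nudge it, which is the behaviour the schedule is meant to produce.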
Implicit updates (ISGD)
As mentioned earlier, classical stochastic gradient descent is generally sensitive to the learning rate $\eta$. Fast convergence requires large learning rates but this may induce numerical instability. The problem can be largely solved by considering ''implicit updates'' whereby the stochastic gradient is evaluated at the next iterate rather than the current one:
: $w^{\text{new}} := w^{\text{old}} - \eta \nabla Q_i(w^{\text{new}}).$
This equation is implicit since $w^{\text{new}}$ appears on both sides of the equation. It is a stochastic form of the proximal gradient method since the update can also be written as:
: $w^{\text{new}} := \arg\min_w \left\{ Q_i(w) + \frac{1}{2\eta} \left\| w - w^{\text{old}} \right\|^2 \right\}.$
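The link between the implicit update and the proximal form can be checked with a short first-order argument. The derivation below assumes that $Q_i$ is convex and differentiable, so that the minimizer of the proximal objective is characterized by a vanishing gradient; $w^{\text{old}}$ and $w^{\text{new}}$ denote the iterate before and after the update:

```latex
\begin{aligned}
0 &= \nabla_w \left[ Q_i(w) + \frac{1}{2\eta} \left\| w - w^{\text{old}} \right\|^2 \right]_{w = w^{\text{new}}}
   = \nabla Q_i(w^{\text{new}}) + \frac{1}{\eta} \left( w^{\text{new}} - w^{\text{old}} \right) \\
\Longrightarrow \quad w^{\text{new}} &= w^{\text{old}} - \eta \nabla Q_i(w^{\text{new}}).
\end{aligned}
```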
As an example, consider least squares with features $x_1, \ldots, x_n \in \mathbb{R}^p$ and observations $y_1, \ldots, y_n \in \mathbb{R}$. We wish to solve:
: $\min_w \sum_{j=1}^{n} \frac{1}{2} \left( y_j - x_j' w \right)^2,$
where $x_j' w = x_{j1} w_1 + x_{j2} w_2 + \cdots + x_{jp} w_p$ indicates the inner product. Note that $x$ could have "1" as the first element to include an intercept. Classical stochastic gradient descent proceeds as follows:
: $w^{\text{new}} = w^{\text{old}} + \eta \left( y_i - x_i' w^{\text{old}} \right) x_i,$
where $i$ is uniformly sampled between 1 and $n$. Although theoretical convergence of this procedure happens under relatively mild assumptions, in practice the procedure can be quite unstable. In particular, when $\eta$ is misspecified so that $I - \eta x_i x_i'$ has large absolute eigenvalues with high probability, the procedure may diverge numerically within a few iterations. In contrast, ''implicit stochastic gradient descent'' (shortened as ISGD) can be solved in closed-form as:
: $w^{\text{new}} = w^{\text{old}} + \frac{\eta}{1 + \eta \|x_i\|^2} \left( y_i - x_i' w^{\text{old}} \right) x_i.$
This procedure will remain numerically stable virtually for all $\eta$, as the learning rate is now normalized. Such comparison between classical and implicit stochastic gradient descent in the least squares problem is very similar to the comparison between least mean squares (LMS) and normalized least mean squares filter (NLMS).
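The difference in stability can be seen in a small experiment. The sketch below is illustrative only: the data, the deliberately oversized learning rate, and the function names are assumptions, and samples are visited in a fixed cyclic order (rather than sampled uniformly) so that the run is deterministic. The classical step uses the learning rate directly, while the implicit step uses the normalized factor $\eta / (1 + \eta \|x_i\|^2)$.

```python
def sgd_step(w, x, y, eta):
    """Classical update: w + eta * (y - x.w) * x."""
    r = y - sum(xj * wj for xj, wj in zip(x, w))
    return [wj + eta * r * xj for wj, xj in zip(w, x)]


def isgd_step(w, x, y, eta):
    """Implicit update in closed form: step size eta / (1 + eta * ||x||^2)."""
    r = y - sum(xj * wj for xj, wj in zip(x, w))
    scale = eta / (1.0 + eta * sum(xj * xj for xj in x))
    return [wj + scale * r * xj for wj, xj in zip(w, x)]


# Illustrative noise-free data generated from w_true = [1, 2] (intercept, slope).
xs = [[1.0, 0.0], [1.0, 1.0], [1.0, 2.0], [1.0, 3.0]]
ys = [1.0, 3.0, 5.0, 7.0]
eta = 1.0                       # deliberately too large for the classical update

w_sgd, w_isgd = [0.0, 0.0], [0.0, 0.0]
for k in range(200):
    i = k % len(xs)             # cyclic order keeps the run deterministic
    w_sgd = sgd_step(w_sgd, xs[i], ys[i], eta)
    w_isgd = isgd_step(w_isgd, xs[i], ys[i], eta)
```

In this run the classical iterate grows without bound, whereas the implicit iterate approaches the generating coefficients, since each implicit step shrinks the residual on its sample by the factor $1 / (1 + \eta \|x_i\|^2)$.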
Even though a closed-form solution for ISGD is only possible in least squares, the procedure can be efficiently implemented in a wide range of models. Specifically, suppose that $Q_i(w)$ depends on $w$ only through a linear combination with features $x_i$, so that we can write $\nabla_w Q_i(w) = -q(x_i' w) x_i$, where $q() \in \mathbb{R}$ may depend on $x_i, y_i$ as well but not on $w$ except through $x_i' w$. Least squares obeys this rule, and so do logistic regression and most generalized linear models. For instance, in least squares, $q(x_i' w) = y_i - x_i' w$, and in logistic regression $q(x_i' w) = y_i - S(x_i' w)$, where $S(u) = e^u / (1 + e^u)$ is the logistic function. In Poisson regression, $q(x_i' w) = y_i - e^{x_i' w}$, and so on.
In such settings, ISGD is simply implemented as follows. Let $f(\xi) = \eta q(x_i' w^{\text{old}} + \xi \|x_i\|^2)$, where $\xi$ is scalar. Then, ISGD is equivalent to:
: $w^{\text{new}} = w^{\text{old}} + \xi^* x_i, \quad \text{where } \xi^* = f(\xi^*).$
The scaling factor $\xi^* \in \mathbb{R}$ can be found through the bisection method since in most regular models, such as the aforementioned generalized linear models, function $q()$ is decreasing, and thus the search bounds for $\xi^*$ are $[\min(0, f(0)), \max(0, f(0))]$.
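A single ISGD step for logistic regression, with the scaling factor found by bisection, might be sketched as follows. The function names, tolerance, and input values are illustrative assumptions; the search interval is $[\min(0, f(0)), \max(0, f(0))]$, and the bisection looks for the root of $g(\xi) = \xi - f(\xi)$, which is increasing because $f$ is decreasing.

```python
import math


def sigmoid(u):
    return 1.0 / (1.0 + math.exp(-u))


def isgd_logistic_step(w, x, y, eta, tol=1e-12):
    """One implicit SGD step for logistic regression via bisection on xi."""
    xw = sum(xj * wj for xj, wj in zip(x, w))
    norm2 = sum(xj * xj for xj in x)

    def f(xi):
        # f(xi) = eta * q(x.w_old + xi * ||x||^2), with q(u) = y - sigmoid(u)
        return eta * (y - sigmoid(xw + xi * norm2))

    lo, hi = min(0.0, f(0.0)), max(0.0, f(0.0))
    # g(xi) = xi - f(xi) is increasing, so bisection on its sign is valid.
    while hi - lo > tol:
        mid = 0.5 * (lo + hi)
        if mid - f(mid) > 0.0:
            hi = mid
        else:
            lo = mid
    xi = 0.5 * (lo + hi)
    return [wj + xi * xj for wj, xj in zip(w, x)], xi


# Illustrative call: one sample with label y = 1 and a large learning rate.
w_new, xi_star = isgd_logistic_step([0.0, 0.0], [1.0, 2.0], 1.0, eta=5.0)
```

The returned `xi_star` satisfies the fixed-point equation $\xi^* = \eta\, q(x_i' w^{\text{new}})$ up to the bisection tolerance, which is the implicit update written in terms of a single scalar unknown.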