Mixed-precision arithmetic is a form of floating-point arithmetic that uses numbers with varying widths in a single operation.
Overview
A common usage of mixed-precision arithmetic is for operating on inaccurate numbers with a small width and expanding them to a larger, more accurate representation. For example, two half-precision or bfloat16 (16-bit) floating-point numbers may be multiplied together to result in a more accurate single-precision (32-bit) float. In this way, mixed-precision arithmetic approximates arbitrary-precision arithmetic, albeit with a low number of possible precisions.
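As a brief numeric illustration (not part of the original text; NumPy is used only to obtain explicit 16- and 32-bit float types), multiplying two half-precision values and keeping the product in FP32 avoids the extra rounding that storing it back in FP16 would introduce:
<syntaxhighlight lang="python">
import numpy as np

a = np.float16(0.1)                           # 0.1 rounded to the nearest half-precision value
b = np.float16(3.0)

product_fp16 = a * b                          # product rounded back down to FP16
product_fp32 = np.float32(a) * np.float32(b)  # operands widened, product kept in FP32

print(product_fp16)   # roughly 0.2998    (an extra rounding step to half precision)
print(product_fp32)   # roughly 0.2999268 (limited only by the rounding of the FP16 inputs)
</syntaxhighlight>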
Iterative algorithms (like gradient descent) are good candidates for mixed-precision arithmetic. In an iterative algorithm like computing a square root, a coarse initial guess can be made and refined over many iterations until the error in precision makes it such that the smallest addition or subtraction to the guess is still too coarse to be an acceptable answer. When this happens, the precision can be increased to something more precise, which allows for smaller increments to be used for the approximation.
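As an illustrative sketch (the text does not name a specific algorithm; Newton's method and NumPy's float types are choices made here), the function below iterates in 32-bit precision until its updates fall below FP32 resolution, then switches to 64-bit precision for the final refinements:
<syntaxhighlight lang="python">
import numpy as np

def mixed_precision_sqrt(x, tol=1e-12):
    """Newton's method for sqrt(x): a coarse FP32 phase, then an FP64 refinement phase."""
    y = np.float32(max(x, 1.0))                        # coarse initial guess
    for _ in range(50):                                # phase 1: cheap float32 iterations
        y_new = np.float32(0.5) * (y + np.float32(x) / y)
        if y_new == y:                                 # update too small for FP32 to express
            break
        y = y_new
    z = np.float64(y)                                  # phase 2: continue in float64
    for _ in range(50):
        z_new = 0.5 * (z + np.float64(x) / z)
        if abs(z_new - z) <= tol * z:
            break
        z = z_new
    return z

print(mixed_precision_sqrt(2.0))   # ~1.4142135623730951
</syntaxhighlight>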
Supercomputers such as Summit utilize mixed-precision arithmetic to be more efficient with regard to memory and processing time, as well as power consumption.
Floating point format
A floating-point number is typically packed into a single bit-string, as the sign bit, the exponent field, and the significand or mantissa, from left to right. As an example, an IEEE 754 standard 32-bit float ("FP32", "float32", or "binary32") is packed as a 1-bit sign, an 8-bit exponent, and a 23-bit significand. The IEEE 754 binary formats are binary16 (half precision, 16 bits), binary32 (single precision, 32 bits), binary64 (double precision, 64 bits), binary128 (quadruple precision, 128 bits), and binary256 (octuple precision, 256 bits).
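As an illustration (not part of the original article), the following Python snippet uses the standard struct module to extract the three fields from a binary32 value:
<syntaxhighlight lang="python">
import struct

def fp32_fields(value):
    """Split an IEEE 754 binary32 value into its sign, exponent, and fraction fields."""
    (bits,) = struct.unpack(">I", struct.pack(">f", value))  # reinterpret the 32 bits as an integer
    sign     = bits >> 31              # 1 bit
    exponent = (bits >> 23) & 0xFF     # 8 bits, stored with a bias of 127
    fraction = bits & 0x7FFFFF         # 23 bits of the significand (the leading 1 is implicit)
    return sign, exponent, fraction

print(fp32_fields(1.0))    # (0, 127, 0)
print(fp32_fields(-2.5))   # (1, 128, 2097152)  ->  -1.25 * 2**(128 - 127)
</syntaxhighlight>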
Machine learning
Mixed-precision arithmetic is used in the field of machine learning, since gradient descent algorithms can use coarse and efficient half-precision floats for certain tasks, but can be more accurate if they use more precise but slower single-precision floats. Some platforms, including Nvidia, Intel, and AMD CPUs and GPUs, provide mixed-precision arithmetic for this purpose, using coarse floats when possible, but expanding them to higher precision when necessary.
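As a minimal sketch of the idea (an illustration using NumPy, not tied to any particular platform's hardware support; the problem and numbers are made up), the following performs one gradient-descent step on a one-parameter least-squares problem, doing the forward and gradient arithmetic in FP16 while keeping a master copy of the weight in FP32:
<syntaxhighlight lang="python">
import numpy as np

w_master = np.float32(0.5)                 # master weight kept in single precision
x, y = np.float16(2.0), np.float16(3.0)    # one training sample, in half precision
lr = np.float32(0.1)

w16 = np.float16(w_master)                 # cast the weight down for the cheap math
residual = w16 * x - y                     # FP16 forward pass
grad16 = np.float16(2.0) * residual * x    # FP16 gradient of the squared error (w*x - y)**2
w_master -= lr * np.float32(grad16)        # expand the gradient and update in FP32

print(w_master)                            # 1.3: moved from 0.5 toward the solution w = 1.5
</syntaxhighlight>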
Automatic mixed precision
PyTorch implements automatic mixed-precision (AMP), which performs autocasting, gradient scaling, and loss scaling.
* The weights are stored in a master copy at a high precision, usually in FP32.
* Autocasting means automatically converting a floating-point number between different precisions, such as from FP32 to FP16, during training. For example, matrix multiplications can often be performed in FP16 without loss of accuracy, even if the master copy weights are stored in FP32. Low-precision weights are used during the forward pass.
* Gradient scaling means multiplying gradients by a constant factor during training, typically before the weight optimizer update. This is done to prevent the gradients from underflowing to zero when using low-precision data types like FP16. Mathematically, if the unscaled gradient is g, the scaled gradient is λg, where λ is the scaling factor (see the numeric sketch after this list). Within the optimizer update, the scaled gradient is cast to a higher precision before it is scaled down (no longer underflowing, as it is in a higher precision) to update the weights.
* Loss scaling means multiplying the loss function by a constant factor during training, typically before backpropagation. This is done to prevent the gradients from underflowing to zero when using low-precision data types. If the unscaled loss is L, the scaled loss is λL, where λ is the scaling factor. Since gradient scaling and loss scaling are mathematically equivalent, by ∇(λL) = λ∇L, loss scaling is an implementation of gradient scaling.
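As a numeric illustration (not from the original text; the factor 2**16 is simply a commonly used default), the NumPy sketch below shows a gradient that underflows to zero when cast directly to FP16 but survives when scaled before the cast and unscaled again in higher precision:
<syntaxhighlight lang="python">
import numpy as np

grad = np.float32(2e-8)                      # small but nonzero FP32 gradient
print(np.float16(grad))                      # 0.0 -- below FP16's smallest subnormal (~6e-8)

scale = np.float32(2.0 ** 16)                # the scaling factor λ
scaled_fp16 = np.float16(grad * scale)       # ~1.3e-3, comfortably representable in FP16
restored = np.float32(scaled_fp16) / scale   # cast up first, then divide the scale back out
print(restored)                              # ~2e-8 -- the gradient is preserved
</syntaxhighlight>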
PyTorch AMP uses exponential backoff to automatically adjust the scale factor for loss scaling. That is, it periodically increases the scale factor; whenever the gradients contain an infinity or NaN (indicating overflow), the weight update is skipped and the scale factor is decreased.
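A typical training loop using this machinery might look like the following sketch. The model, loss, optimizer, and data are placeholder choices made up for illustration; the torch.autocast context manager and the GradScaler object are PyTorch's AMP interface.
<syntaxhighlight lang="python">
import torch

# Placeholder model, loss, optimizer, and synthetic data (assumes a CUDA device).
model = torch.nn.Linear(1024, 10).cuda()
loss_fn = torch.nn.CrossEntropyLoss()
optimizer = torch.optim.SGD(model.parameters(), lr=1e-2)
scaler = torch.cuda.amp.GradScaler()      # manages loss scaling and its dynamic adjustment
loader = [(torch.randn(32, 1024, device="cuda"),
           torch.randint(0, 10, (32,), device="cuda")) for _ in range(8)]

for inputs, targets in loader:
    optimizer.zero_grad()
    with torch.autocast(device_type="cuda", dtype=torch.float16):
        outputs = model(inputs)           # eligible ops (e.g. matmuls) run in FP16
        loss = loss_fn(outputs, targets)  # loss is computed under autocast as well
    scaler.scale(loss).backward()         # backpropagate the scaled loss
    scaler.step(optimizer)                # unscales gradients; skips the step on inf/NaN
    scaler.update()                       # grows the scale periodically, shrinks it after overflow
</syntaxhighlight>
Here scaler.step() skips the weight update when it finds non-finite gradients, and scaler.update() performs the scale-factor adjustment described above.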