Distillation (machine learning)

In machine learning, knowledge distillation is the process of transferring knowledge from a large model to a smaller one. While large models (such as very deep neural networks or ensembles of many models) have higher knowledge capacity than small models, this capacity might not be fully utilized. It can be just as computationally expensive to evaluate a model even if it utilizes little of its knowledge capacity. Knowledge distillation transfers knowledge from a large model to a smaller model without loss of validity. As smaller models are less expensive to evaluate, they can be deployed on less powerful hardware (such as a mobile device). Knowledge distillation has been successfully used in several applications of machine learning such as object detection, acoustic models, and natural language processing. Recently, it has also been introduced to graph neural networks applicable to non-grid data.


Concept of distillation

Transferring knowledge from a large model to a small one requires somehow teaching the latter without loss of validity. If both models are trained on the same data, the small model may have insufficient capacity to learn a concise knowledge representation given the same computational resources and the same data as the large model. However, some information about a concise knowledge representation is encoded in the pseudolikelihoods assigned to its output: when a model correctly predicts a class, it assigns a large value to the output variable corresponding to that class, and smaller values to the other output variables. The distribution of values among the outputs for a record provides information on how the large model represents knowledge. Therefore, the goal of economical deployment of a valid model can be achieved by training only the large model on the data, exploiting its better ability to learn concise knowledge representations, and then distilling such knowledge into the smaller model, which would not be able to learn it on its own, by training it to learn the soft output of the large model.
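For illustration, the following minimal sketch (with made-up logit values, not taken from any particular trained model) shows how a model's soft output assigns a large pseudo-probability to the predicted class while its smaller values still describe how plausible the other classes are, information that a hard label discards:

```python
import numpy as np

def softmax(z):
    """Convert a vector of logits into pseudo-probabilities."""
    e = np.exp(z - np.max(z))   # subtract the max for numerical stability
    return e / e.sum()

# Hypothetical logits of a large model for one record, over three classes.
large_model_logits = np.array([4.0, 2.5, -1.0])

soft_output = softmax(large_model_logits)
print(soft_output)           # roughly [0.81, 0.18, 0.01]
print(soft_output.argmax())  # 0 -- the hard prediction keeps only this index,
                             # discarding how the remaining mass is distributed
```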
Model compression, a methodology to compress the knowledge of multiple models into a single neural network, was introduced in 2006. Compression was achieved by training a smaller model on large amounts of pseudo-data labelled by a higher-performing ensemble, optimising to match the logits of the compressed model to the logits of the ensemble. Knowledge distillation is a generalisation of this approach, introduced by Geoffrey Hinton et al. in 2015, in a preprint that formulated the concept and showed some results achieved in the task of image classification.


Formulation

Given a large model as a function of the vector variable \mathbf{x}, trained for a specific classification task, typically the final layer of the network is a softmax in the form

: y_i(\mathbf{x}, t) = \frac{e^{z_i(\mathbf{x})/t}}{\sum_j e^{z_j(\mathbf{x})/t}}

where t is a parameter called ''temperature'', which for a standard softmax is normally set to 1. The softmax operator converts the logit values z_i(\mathbf{x}) to pseudo-probabilities, and higher values of temperature generate a softer distribution of pseudo-probabilities among the output classes.
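A minimal sketch of the temperature-scaled softmax above, using made-up logit values, shows how raising t softens the distribution of pseudo-probabilities:

```python
import numpy as np

def softmax_with_temperature(z, t=1.0):
    """y_i(x, t) = exp(z_i(x) / t) / sum_j exp(z_j(x) / t)."""
    z = np.asarray(z, dtype=float) / t
    e = np.exp(z - np.max(z))    # shift logits for numerical stability
    return e / e.sum()

logits = np.array([4.0, 2.5, -1.0])              # hypothetical logit values z_i(x)

print(softmax_with_temperature(logits, t=1.0))   # ~[0.81, 0.18, 0.01]  (standard softmax)
print(softmax_with_temperature(logits, t=4.0))   # ~[0.51, 0.35, 0.15]  (softer)
print(softmax_with_temperature(logits, t=20.0))  # ~[0.37, 0.34, 0.29]  (nearly uniform)
```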
Knowledge distillation consists of training a smaller network, called the ''distilled model'', on a dataset called the transfer set (different from the dataset used to train the large model), using the cross-entropy as loss function between the output of the distilled model \mathbf{y}(\mathbf{x}, t) and the output \hat{\mathbf{y}}(\mathbf{x}, t) produced by the large model on the same record (or the average of the individual outputs, if the large model is an ensemble), using a high value of the softmax temperature t for both models

: E(\mathbf{x}, t) = -\sum_i \hat{y}_i(\mathbf{x}, t) \log y_i(\mathbf{x}, t) .

In this context, a high temperature increases the entropy of the output, and therefore provides more information to learn for the distilled model compared to hard targets, while at the same time reducing the variance of the gradient between different records and therefore allowing higher learning rates.
If ground truth is available for the transfer set, the process can be strengthened by adding to the loss the cross-entropy between the output of the distilled model (computed with t = 1) and the known label \bar{y}

: E(\mathbf{x}, t) = -t^2 \sum_i \hat{y}_i(\mathbf{x}, t) \log y_i(\mathbf{x}, t) - \sum_i \bar{y}_i \log y_i(\mathbf{x}, 1)

where the component of the loss with respect to the large model is weighted by a factor of t^2 since, as the temperature increases, the gradient of the loss with respect to the model weights scales by a factor of \frac{1}{t^2}.
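Both losses can be sketched in a short, self-contained example with hypothetical logits and a one-hot label (the function names are illustrative and not taken from the original preprint):

```python
import numpy as np

def softmax_t(z, t):
    """Temperature-scaled softmax, as in the formulation above."""
    z = np.asarray(z, dtype=float) / t
    e = np.exp(z - np.max(z))
    return e / e.sum()

def soft_cross_entropy(target, output):
    """Cross-entropy -sum_i target_i * log(output_i)."""
    return -np.sum(target * np.log(output))

def distillation_loss(student_logits, teacher_logits, t, hard_label=None):
    """E(x, t): soft-target term alone, or, when ground truth is given,
    the t^2-weighted soft-target term plus the hard-label term at t = 1."""
    y_hat = softmax_t(teacher_logits, t)      # soft targets of the large model
    y_soft = softmax_t(student_logits, t)     # distilled model output at temperature t
    loss = soft_cross_entropy(y_hat, y_soft)
    if hard_label is None:
        return loss
    y_hard = softmax_t(student_logits, 1.0)   # distilled model output at t = 1
    return t**2 * loss + soft_cross_entropy(hard_label, y_hard)

teacher_logits = np.array([4.0, 2.5, -1.0])   # hypothetical values
student_logits = np.array([3.0, 2.0, 0.0])
label = np.array([1.0, 0.0, 0.0])             # one-hot ground truth

print(distillation_loss(student_logits, teacher_logits, t=4.0))
print(distillation_loss(student_logits, teacher_logits, t=4.0, hard_label=label))
```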


Relationship with model compression

Under the assumption that the logits have zero mean, it is possible to show that model compression is a special case of knowledge distillation. The gradient of the knowledge distillation loss E with respect to the logit z_i of the distilled model is given by

: \begin{align}
\frac{\partial}{\partial z_i} E
&= -\frac{\partial}{\partial z_i} \sum_j \hat{y}_j \log y_j \\
&= -\sum_j \hat{y}_j \frac{1}{y_j} \frac{\partial y_j}{\partial z_i} \\
&= -\sum_j \hat{y}_j \frac{1}{y_j} \cdot \frac{1}{t} \, y_j \left( \delta_{ij} - y_i \right) \\
&= -\frac{1}{t} \left( \hat{y}_i - y_i \sum_j \hat{y}_j \right) \\
&= \frac{1}{t} \left( y_i - \hat{y}_i \right) \\
&= \frac{1}{t} \left( \frac{e^{\frac{z_i}{t}}}{\sum_j e^{\frac{z_j}{t}}} - \frac{e^{\frac{\hat{z}_i}{t}}}{\sum_j e^{\frac{\hat{z}_j}{t}}} \right)
\end{align}

where \hat{z}_i are the logits of the large model and \delta_{ij} is the Kronecker delta. For large values of t this can be approximated as

: \frac{1}{t} \left( \frac{1 + \frac{z_i}{t}}{N + \sum_j \frac{z_j}{t}} - \frac{1 + \frac{\hat{z}_i}{t}}{N + \sum_j \frac{\hat{z}_j}{t}} \right)

where N is the number of output classes, and under the zero-mean hypothesis \sum_j z_j = \sum_j \hat{z}_j = 0 it becomes

: \frac{z_i - \hat{z}_i}{N t^2}

which is the derivative of \frac{1}{2 N t^2} \left( z_i - \hat{z}_i \right)^2, i.e. the loss is equivalent to matching the logits of the two models, as done in model compression.
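This equivalence can be checked numerically. The sketch below, using arbitrary zero-mean logits, compares the exact gradient \frac{1}{t}(y_i - \hat{y}_i) with the logit-matching gradient \frac{z_i - \hat{z}_i}{N t^2} at a large temperature:

```python
import numpy as np

def softmax_t(z, t):
    e = np.exp(np.asarray(z, dtype=float) / t)
    return e / e.sum()

rng = np.random.default_rng(0)
N = 5                            # number of output classes

z = rng.normal(size=N)           # distilled-model logits (arbitrary)
z -= z.mean()                    # enforce the zero-mean assumption
z_hat = rng.normal(size=N)       # large-model logits (arbitrary)
z_hat -= z_hat.mean()

t = 100.0                        # large temperature
exact_grad = (softmax_t(z, t) - softmax_t(z_hat, t)) / t   # dE/dz_i = (y_i - y_hat_i) / t
approx_grad = (z - z_hat) / (N * t**2)                     # gradient of the logit-matching loss

print(np.max(np.abs(exact_grad - approx_grad)))  # very small: the two gradients agree for large t
```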


External links


Distilling the knowledge in a neural network – Google AI