Let us suppose that you are developing a machine learning model that has to generalize as well as possible to unseen data. In general, sophisticated models such as a random forest with many trees, a very deep neural network or an ensemble of models tend to generalize well. However, these models require a lot of disk space and may be slow at inference.
In this article, we will see how knowledge distillation helps to build lighter models, i.e. models with fewer parameters, that mimic the performance of a cumbersome model.
In machine learning, it is generally accepted that the objective function should reflect the interest of the user as closely as possible. However, training algorithms minimize the cost function on the training data, while the real interest is to generalize well to new data.
It would clearly be better to train models to generalize well directly, but this requires knowing how to do so. When we distill the knowledge of a large model into a small model, or of a specialist into a generalist, we can train the student to generalize in the same way as the teacher.
In general, the teacher is well suited to transfer this kind of information, as it is either a cumbersome model (an ensemble of multiple networks, for example) or an expert model in its domain, which generalizes better within that domain.
The objective of knowledge distillation is to bridge the gap between the interest of the user, which is good generalization on unseen data, and the cost function used during training. One way to transfer the generalization ability of the cumbersome model (or of the ensemble of models) is to use its class probabilities as soft targets. Instead of trying to match the ground-truth labels, we optimize the student on the softened targets provided by the teacher.
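To make the difference between hard and soft targets concrete, here is a minimal NumPy sketch. The teacher logits are hypothetical values for a single example over K = 3 classes, and the soft target is assumed to be the plain softmax of those logits.

```python
import numpy as np

def softmax(logits):
    """Numerically stable softmax over the last axis."""
    z = logits - np.max(logits, axis=-1, keepdims=True)
    e = np.exp(z)
    return e / np.sum(e, axis=-1, keepdims=True)

# Hypothetical teacher logits for one example, K = 3 classes.
teacher_logits = np.array([4.0, 2.0, -1.0])
true_class = 0

# Hard target: all the probability mass on the ground-truth class.
hard_target = np.eye(len(teacher_logits))[true_class]

# Soft target: the teacher's full class distribution.
soft_target = softmax(teacher_logits)

print("hard:", hard_target)
print("soft:", soft_target.round(3))
```

Both targets agree on the most likely class, but only the soft target tells the student that class 1 is a much more plausible confusion than class 2. This relative information is what the teacher transfers beyond the label itself.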
In other words, for a classification problem, the negative log-likelihood cost function is replaced by the Kullback-Leibler divergence between the teacher’s distribution and the student’s distribution.
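Traditionally, the loss function for a training data point x is the negative log-likelihood:

$$\mathcal{L}_{\text{NLL}}(x) = -\sum_{k=1}^{K} \mathbb{1}[k = y] \, \log p(k \mid x) = -\log p(y \mid x)$$

where y is the true class, K the number of classes and p(k | x) the probability the model assigns to class k.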
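As a minimal NumPy sketch of this loss (the predicted probabilities are hypothetical), the indicator-sum form and the shortcut -log p(y | x) give the same value:

```python
import numpy as np

def nll_loss(probs, y):
    """Negative log-likelihood of the true class y.

    probs: predicted class probabilities p(k | x), shape (K,)
    y:     index of the true class
    """
    return -np.log(probs[y])

def nll_loss_sum_form(probs, y):
    """Same loss written as -sum_k 1[k == y] * log p(k | x).

    Assumes every entry of probs is strictly positive (0 * log 0 would give nan).
    """
    indicator = np.eye(len(probs))[y]      # 1 for k == y, 0 elsewhere
    return -np.sum(indicator * np.log(probs))

probs = np.array([0.7, 0.2, 0.1])   # hypothetical model output, K = 3
y = 0

print(nll_loss(probs, y), nll_loss_sum_form(probs, y))
```

Only the probability of the true class enters the loss; how the model spreads the remaining mass over the other classes is ignored.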
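Replacing the cost function by the KL divergence results in:

$$\mathcal{L}_{\text{KD}}(x) = \mathrm{KL}\left(p^T \,\|\, p^S\right) = \sum_{k=1}^{K} p^T(k \mid x) \, \log \frac{p^T(k \mid x)}{p^S(k \mid x)}$$

where p^T and p^S are the probability distributions of the teacher and the student respectively.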
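A minimal NumPy sketch of this divergence, with hypothetical teacher and student distributions. It also checks numerically that KL(p^T || p^S) equals the cross-entropy between p^T and p^S minus the entropy of p^T.

```python
import numpy as np

def kl_divergence(p_teacher, p_student, eps=1e-12):
    """KL(p^T || p^S) = sum_k p^T(k|x) * log(p^T(k|x) / p^S(k|x)).

    Classes where the teacher puts zero mass contribute nothing (0 * log 0 = 0).
    Student probabilities are clipped at eps to avoid log(0).
    """
    p_teacher = np.asarray(p_teacher, dtype=float)
    p_student = np.clip(np.asarray(p_student, dtype=float), eps, 1.0)
    mask = p_teacher > 0
    return float(np.sum(p_teacher[mask] * np.log(p_teacher[mask] / p_student[mask])))

def cross_entropy(p_teacher, p_student, eps=1e-12):
    """Cross-entropy H(p^T, p^S) = -sum_k p^T(k|x) * log p^S(k|x)."""
    p_student = np.clip(np.asarray(p_student, dtype=float), eps, 1.0)
    return float(-np.sum(np.asarray(p_teacher, dtype=float) * np.log(p_student)))

def entropy(p):
    """Entropy H(p) = -sum_k p_k * log p_k, with 0 * log 0 = 0."""
    p = np.asarray(p, dtype=float)
    p = p[p > 0]
    return float(-np.sum(p * np.log(p)))

p_t = np.array([0.88, 0.11, 0.01])   # hypothetical teacher distribution
p_s = np.array([0.70, 0.20, 0.10])   # hypothetical student distribution

kl = kl_divergence(p_t, p_s)
# KL = cross-entropy - teacher entropy; the entropy term does not depend on the student.
print(kl, cross_entropy(p_t, p_s) - entropy(p_t))
```

Since the teacher's entropy is constant with respect to the student, minimizing this KL divergence is equivalent to minimizing the cross-entropy between the teacher's soft targets and the student's predictions.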
This is the simplest form of distillation, and it works without true labels for the transfer set, which can be, for instance, the teacher’s training set. When the correct labels are known for the transfer set or a subset of it, we can incorporate them and also train the student to produce the correct labels.
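The label-free setting can be sketched with a toy example in NumPy, assuming a linear softmax teacher and a linear softmax student (all data and weights are randomly generated and hypothetical). No ground-truth label appears anywhere: the student is fitted to the teacher's distribution by plain gradient descent on the KL divergence.

```python
import numpy as np

def softmax(logits):
    """Row-wise numerically stable softmax."""
    z = logits - logits.max(axis=1, keepdims=True)
    e = np.exp(z)
    return e / e.sum(axis=1, keepdims=True)

def mean_kl(p_teacher, p_student, eps=1e-12):
    """Average KL(p^T || p^S) over the rows of the transfer set."""
    p_student = np.clip(p_student, eps, 1.0)
    p_teacher_safe = np.clip(p_teacher, eps, 1.0)
    return float(np.mean(np.sum(p_teacher * np.log(p_teacher_safe / p_student), axis=1)))

rng = np.random.default_rng(0)
n_samples, n_features, n_classes = 500, 5, 3

# Hypothetical setup: a fixed linear softmax "teacher" and an UNLABELED transfer set.
X_transfer = rng.normal(size=(n_samples, n_features))
W_teacher = rng.normal(size=(n_features, n_classes))
p_teacher = softmax(X_transfer @ W_teacher)      # soft targets; no true labels used

# Student: a linear softmax model trained only on the teacher's distribution.
W_student = np.zeros((n_features, n_classes))
initial_kl = mean_kl(p_teacher, softmax(X_transfer @ W_student))
learning_rate = 0.5
for step in range(500):
    p_student = softmax(X_transfer @ W_student)
    # The gradient of KL(p^T || p^S) w.r.t. the student's logits is p^S - p^T.
    grad = X_transfer.T @ (p_student - p_teacher) / n_samples
    W_student -= learning_rate * grad

p_student = softmax(X_transfer @ W_student)
final_kl = mean_kl(p_teacher, p_student)
agreement = np.mean(p_student.argmax(axis=1) == p_teacher.argmax(axis=1))
print(f"KL: {initial_kl:.3f} -> {final_kl:.3f}, argmax agreement: {agreement:.2%}")
```

To keep the example small, the student here has the same capacity as the teacher; in practice the student would be the lighter model.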
The combination used in our case results in the following cost function, a weighted sum of the negative log-likelihood and the Kullback-Leibler divergence:

$$\mathcal{L}(x) = (1 - \lambda) \left( -\log p^S(y \mid x) \right) + \lambda \, \mathrm{KL}\left(p^T \,\|\, p^S\right)$$

where \lambda is a hyper-parameter of the method that controls how much weight is given to the teacher’s soft targets. If \lambda = 0 we recover the cross-entropy loss, and if \lambda = 1 we get the simplest form of distillation described above, in which the ground-truth labels are not exploited.
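A minimal NumPy sketch of this weighted loss for a single example. It returns both the loss and its gradient with respect to the student's logits. `distillation_loss` and its arguments are hypothetical names, and the student is assumed to produce its probabilities with a softmax.

```python
import numpy as np

def softmax(logits):
    """Numerically stable softmax for a 1-D array of logits."""
    z = logits - np.max(logits)
    e = np.exp(z)
    return e / e.sum()

def distillation_loss(student_logits, teacher_probs, y, lam):
    """(1 - lam) * (-log p^S(y|x)) + lam * KL(p^T || p^S) for one example.

    student_logits: student's logits, shape (K,)
    teacher_probs:  teacher's distribution p^T(. | x), shape (K,)
    y:              index of the true class
    lam:            weight of the distillation term, in [0, 1]
    Returns (loss, gradient of the loss w.r.t. student_logits).
    """
    p_s = softmax(student_logits)
    nll = -np.log(p_s[y])
    mask = teacher_probs > 0                      # 0 * log 0 = 0 convention
    kl = np.sum(teacher_probs[mask] * np.log(teacher_probs[mask] / p_s[mask]))
    loss = (1.0 - lam) * nll + lam * kl

    one_hot = np.eye(len(p_s))[y]
    # d(NLL)/d(logits) = p^S - one_hot and d(KL)/d(logits) = p^S - p^T.
    grad = (1.0 - lam) * (p_s - one_hot) + lam * (p_s - teacher_probs)
    return loss, grad

student_logits = np.array([1.0, 0.5, -0.5])   # hypothetical values
teacher_probs = np.array([0.6, 0.3, 0.1])
loss, grad = distillation_loss(student_logits, teacher_probs, y=0, lam=0.5)
print(loss, grad)
```

The gradient simplifies to p^S - ((1 - \lambda) one_hot(y) + \lambda p^T). Up to a constant, the combined loss is therefore a cross-entropy against a target that mixes the true label and the teacher's soft target, with \lambda setting the mix.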
Knowledge distillation as a multi-task technique
Knowledge distillation can also be used to produce a multi-task or multi-domain student, which may or may not be the same size as the teachers (K \leq N).
This strategy enables us to use a single model to translate sentences from different domains, instead of storing the different expert models, identifying the domain of each sentence at inference time and routing it to the corresponding teacher. It is time and memory efficient, which is particularly convenient in a production environment.
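The training-side data flow of this strategy can be sketched in NumPy, assuming one toy linear softmax teacher per domain. The domain names, `make_linear_teacher` and `build_multidomain_transfer_set` are hypothetical. Domain information is used once, offline, to pick which teacher produces the soft targets for each example. The merged transfer set does not contain domain tags, so a student trained on it has nothing to route at inference.

```python
import numpy as np

def softmax(logits):
    """Row-wise numerically stable softmax."""
    z = logits - logits.max(axis=1, keepdims=True)
    e = np.exp(z)
    return e / e.sum(axis=1, keepdims=True)

def make_linear_teacher(W):
    """Hypothetical expert: a linear softmax classifier with weights W."""
    return lambda X: softmax(X @ W)

def build_multidomain_transfer_set(domain_data, teachers):
    """Merge per-domain inputs into one transfer set with soft targets.

    domain_data: dict mapping domain name -> inputs of shape (n_d, n_features)
    teachers:    dict mapping domain name -> callable returning class probabilities
    Each domain's expert produces the soft targets for its own examples only.
    """
    X_parts, T_parts = [], []
    for domain, X_d in domain_data.items():
        X_parts.append(X_d)
        T_parts.append(teachers[domain](X_d))
    return np.vstack(X_parts), np.vstack(T_parts)

rng = np.random.default_rng(1)
n_features, n_classes = 4, 3
domains = ["news", "medical", "legal"]            # hypothetical domains
teachers = {d: make_linear_teacher(rng.normal(size=(n_features, n_classes))) for d in domains}
domain_data = {d: rng.normal(size=(200, n_features)) for d in domains}

X_transfer, soft_targets = build_multidomain_transfer_set(domain_data, teachers)
print(X_transfer.shape, soft_targets.shape)   # (600, 4) (600, 3)
```

The single student would then be trained on (X_transfer, soft_targets) with the same distillation loss as in the single-teacher case.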
The use of specialists trained on different domains somewhat resembles a mixture of experts, which learns to assign each example to the most likely expert by computing a probability for each expert. During training, two learning processes happen simultaneously: the experts learn to be accurate on the examples assigned to them, and the gating network, responsible for assigning examples, learns how to do so.
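For comparison, here is a minimal NumPy sketch of the forward pass of a soft mixture of experts. It assumes linear softmax experts and a linear softmax gating network; `moe_forward` and the shapes are hypothetical. The gate gives each example a probability for each expert, and the output is the gate-weighted average of the experts' predictions. During training, the gate and the experts would be updated jointly by backpropagating through this combination, which is what couples the two learning processes.

```python
import numpy as np

def softmax(logits, axis=-1):
    """Numerically stable softmax along the given axis."""
    z = logits - logits.max(axis=axis, keepdims=True)
    e = np.exp(z)
    return e / e.sum(axis=axis, keepdims=True)

def moe_forward(x, W_gate, W_experts):
    """Soft mixture of experts for a batch x of shape (n, d).

    W_gate:    (d, E) gating weights
    W_experts: (E, d, K) one linear softmax classifier per expert
    Returns gate probabilities (n, E) and mixed class probabilities (n, K).
    """
    gate = softmax(x @ W_gate)                                       # p(expert e | x)
    expert_probs = softmax(np.einsum("nd,edk->nek", x, W_experts))   # (n, E, K)
    mixed = np.einsum("ne,nek->nk", gate, expert_probs)              # gate-weighted average
    return gate, mixed

rng = np.random.default_rng(2)
n, d, n_experts, n_classes = 8, 4, 3, 5
x = rng.normal(size=(n, d))
W_gate = rng.normal(size=(d, n_experts))
W_experts = rng.normal(size=(n_experts, d, n_classes))

gate, mixed = moe_forward(x, W_gate, W_experts)
print(gate.argmax(axis=1))   # expert each example is mostly assigned to
```

Because the gate and the experts share one loss, they cannot be trained independently. Distillation from domain teachers avoids this coupling, since each teacher is trained on its own domain.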
The major difference between the two methods is parallelization: it is easier to train several teachers at the same time, each on its own domain, than to parallelize the training of a mixture of experts.