In the last post, we talked about floating-point precision and mixed precision training. We discussed how to cast the weights of a neural network to a lower precision to save computation time, along with some clever tricks for preserving accuracy.

In this post, we will put these ideas into practice by fine-tuning a BERT model on a sentiment classification task (and, briefly, training a ResNet18 on the MNIST dataset) using mixed precision. The machine learning framework used will be PyTorch with its AMP (automatic mixed precision) package.

As the name suggests, the AMP package automatically carries out the steps of the mixed precision algorithm (master copy of the weights, loss scaling, choice of arithmetic precision for each operation), so the user does not have to cast tensors between precisions by hand.
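To see why loss scaling is part of the recipe, here is a minimal sketch of what happens to a very small gradient value in half precision. The value 1e-8 and the scale factor are illustrative assumptions, not numbers from our experiments (2**16 does happen to be the default initial scale of GradScaler).

```python
import torch

true_grad = 1e-8   # illustrative small gradient, representable in float32
scale = 65536.0    # 2**16, an assumed scale factor for this demo

# Cast directly to float16: the value is below the smallest float16 subnormal
naive_fp16 = torch.tensor(true_grad, dtype=torch.float16)

# Scale first, then cast: the scaled value fits comfortably in float16
scaled_fp16 = torch.tensor(true_grad * scale, dtype=torch.float16)

# Unscale in float32, as is done before the optimizer step
recovered = scaled_fp16.to(torch.float32) / scale

print(naive_fp16.item())   # 0.0, the gradient has underflowed
print(recovered.item())    # close to 1e-8 again
```

Without scaling, the gradient is silently flushed to zero; with scaling, it survives the round trip through float16 up to half-precision rounding error.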

Here is the basic code format taken from the PyTorch AMP examples page:

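# Packages required for automatic mixed precision
import torch
import torch.optim as optim
from torch.cuda.amp import GradScaler, autocast

# Creates model and optimizer in default precision.
# Net, loss_fn, epochs and data are placeholders for your own model, loss and data loader.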
model = Net().cuda()
optimizer = optim.SGD(model.parameters(), ...)

# Creates a GradScaler once at the beginning of training.
scaler = GradScaler()

for epoch in epochs:
    for input, target in data:
        optimizer.zero_grad()

        # The forward pass and loss function need to be encapsulated under autocast().
        with autocast():
            output = model(input)
            loss = loss_fn(output, target)

        # Instead of the usual loss.backward(), backward() is called on the scaled
        # loss returned by scaler.scale(), so the gradients are scaled as well.
        scaler.scale(loss).backward()

        # If the gradients need to be inspected or modified (e.g. clipping), unscale
        # them first so that the threshold applies to the true gradient values:
        # scaler.unscale_(optimizer)
        # torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=max_norm)

        # scaler.step() first unscales the gradients of the optimizer's assigned params.
        # If these gradients do not contain infs or NaNs, optimizer.step() is then called,
        # otherwise, optimizer.step() is skipped.
        scaler.step(optimizer)

        # Also required: updates the scale factor for the next iteration.
        scaler.update()

Notice that the main differences from a regular training loop are the autocast context around the forward pass and the loss computation, and the loss scaling handled by the GradScaler.
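For readers who want something they can run straight away, here is a minimal sketch of the same loop on a toy problem. The model, the random data, the learning rate and the function name train_toy_model are all hypothetical stand-ins, not the setup from our experiments. AMP is enabled only when a CUDA GPU is present; otherwise the enabled=False flags turn autocast and the scaler into no-ops. Recent PyTorch releases prefer the equivalent classes under torch.amp, but the torch.cuda.amp imports used in this post still work.

```python
import torch
import torch.nn as nn
import torch.optim as optim
from torch.cuda.amp import GradScaler, autocast


def train_toy_model(num_steps=5, max_norm=1.0, seed=0):
    """Run a few AMP training steps on random data and return the losses."""
    torch.manual_seed(seed)
    use_amp = torch.cuda.is_available()  # fall back to regular training on CPU
    device = torch.device("cuda" if use_amp else "cpu")

    # Hypothetical toy classifier and data, standing in for Net() and data
    model = nn.Sequential(nn.Linear(20, 32), nn.ReLU(), nn.Linear(32, 2)).to(device)
    optimizer = optim.SGD(model.parameters(), lr=0.1)
    loss_fn = nn.CrossEntropyLoss()
    inputs = torch.randn(64, 20, device=device)
    targets = torch.randint(0, 2, (64,), device=device)

    scaler = GradScaler(enabled=use_amp)
    losses = []
    for _ in range(num_steps):
        optimizer.zero_grad()
        with autocast(enabled=use_amp):
            outputs = model(inputs)
            loss = loss_fn(outputs, targets)
        scaler.scale(loss).backward()
        # Unscale before clipping so max_norm refers to the true gradients
        scaler.unscale_(optimizer)
        torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=max_norm)
        scaler.step(optimizer)
        scaler.update()
        losses.append(loss.item())
    return losses


losses = train_toy_model()
print(losses)
```

Switching between regular and mixed precision training then comes down to a single boolean, which makes it easy to compare the two on the same code path.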

Experiments

We test out mixed precision training on an NLP classification task, with a pretrained (cased) BERT (Bidirectional Encoder Representations from Transformers) model fine-tuned on the IMDB review data from HuggingFace Datasets. The dataset consists of reviews of films and TV series by IMDB users, and each review is labelled with a positive (1) or negative (0) sentiment. Our aim is to train the BERT model to classify reviews as positive or negative.

The Colab notebook can be found here. Much of the code was taken from the excellent tutorial on Curiousily, as well as from the HuggingFace pages.

Warning: The mixed precision setup used here needs a GPU! If your Colab runtime is not using one, go to Runtime > Change Runtime Type > Hardware Accelerator, and select GPU from the dropdown menu.
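A quick way to confirm that the runtime actually has a GPU before starting a long training run is sketched below. The variable names are our own; the calls are standard PyTorch.

```python
import torch

use_amp = torch.cuda.is_available()
device = torch.device("cuda" if use_amp else "cpu")

if use_amp:
    print("Using GPU:", torch.cuda.get_device_name(0))
else:
    print("No CUDA GPU found; switch the runtime type before using AMP.")
```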

We picked BERT both for its effectiveness and for its large number of parameters (~110 million). With a model of this size, we would expect mixed precision training to save a significant amount of training time.

We took a small sample of 1000 reviews from the dataset and split them into 800 training samples and 200 testing samples. We then trained for 10 epochs with a batch size of 64 and the AdamW optimizer, once with regular training and once with mixed precision training.
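As a rough sketch of what this split looks like in PyTorch, the snippet below uses random tensors as a stand-in for the 1000 tokenized reviews. The feature size and the seed are arbitrary assumptions, and the notebook may well do the split differently.

```python
import torch
from torch.utils.data import DataLoader, TensorDataset, random_split

# Synthetic stand-in for 1000 tokenized reviews with binary labels
features = torch.randn(1000, 8)
labels = torch.randint(0, 2, (1000,))
dataset = TensorDataset(features, labels)

# Fixed seed so the 800/200 split is reproducible
generator = torch.Generator().manual_seed(42)
train_set, test_set = random_split(dataset, [800, 200], generator=generator)

train_loader = DataLoader(train_set, batch_size=64, shuffle=True)
test_loader = DataLoader(test_set, batch_size=64)

print(len(train_set), len(test_set))  # 800 200
print(len(train_loader))              # 13 batches per epoch
```

With 800 training samples and a batch size of 64, each epoch consists of 13 batches (the last one is half full), so 10 epochs amount to 130 optimizer steps.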

The results are summarized as follows:

Training Method  | Time Taken (s) (*) | Test Accuracy (after 10 epochs)
Regular          | 807.8              | 0.88
Mixed Precision  | 412.2              | 0.88

Notice that with mixed precision, we cut the training time almost in half (from 807.8 s to 412.2 s, about 49% less)! Moreover, the test accuracy was the same for the two models.
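The arithmetic behind this comparison, using the two timings from the table, is short enough to spell out:

```python
regular_s = 807.8  # regular training time in seconds, from the table
mixed_s = 412.2    # mixed precision training time in seconds, from the table

speedup = regular_s / mixed_s
reduction = 1 - mixed_s / regular_s

print(f"{speedup:.2f}x speedup, {reduction:.1%} less time")
# prints "1.96x speedup, 49.0% less time"
```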

Therefore, we conclude that for large models such as BERT, mixed precision training is all but essential.

Smaller Models

One small caveat we should mention is that mixed precision pays off most with bigger models. In another experiment, we tried mixed precision on the ResNet18 model, which has around 11 million parameters. The task was to train ResNet18 on the MNIST dataset under the two training methods. Mixed precision did not offer any advantage in this case.


(*) Technically speaking, the running time also includes the time spent on validation after each epoch, but validating on 200 samples takes so little time that it can safely be ignored in the analysis.