February 27, 2020

From PyTorch to PyTorch Lightning — A gentle introduction

William Falcon

This post answers the most frequently asked question: why do you need Lightning if you’re using PyTorch?

PyTorch is extremely easy to use to build complex AI models. But once the research gets complicated and things like multi-GPU training, 16-bit precision and TPU training get mixed in, users are likely to introduce bugs.

PyTorch Lightning solves exactly this problem. Lightning structures your PyTorch code so it can abstract the details of training. This makes AI research scalable and fast to iterate on.

Who is PyTorch Lightning For?

PyTorch Lightning was created while doing PhD research at both NYU and FAIR

PyTorch Lightning was created for professional researchers and PhD students working on AI research.

Lightning was born out of my Ph.D. AI research at NYU CILVR and Facebook AI Research. As a result, the framework is designed to be extremely extensible while making state of the art AI research techniques (like TPU training) trivial.

Now the core contributors are all pushing the state of the art in AI using Lightning and continue to add new cool features.

However, the simple interface gives professional production teams and newcomers access to the latest state of the art techniques developed by the PyTorch and PyTorch Lightning community.

Lightning has over 320 contributors and a core team of 11 research scientists, PhD students, and professional deep learning engineers.

It is rigorously tested and thoroughly documented.

Outline

This tutorial will walk you through building a simple MNIST classifier, showing PyTorch and PyTorch Lightning code side by side. While Lightning can be used to build arbitrarily complex systems, we use MNIST here to illustrate how to refactor PyTorch code into PyTorch Lightning.

The full code is available at this Colab Notebook.

The Typical AI Research Project

In a research project, we normally want to identify the following key components:

The Model

Let’s design a 3-layer fully-connected neural network that takes as input an image that is 28x28 and outputs a probability distribution over 10 possible labels.

First, let’s define the model in PyTorch

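Here’s a sketch of what that model might look like (the class name and hidden-layer sizes are illustrative choices, not the only way to do it):

    from torch import nn
    from torch.nn import functional as F

    class MNISTClassifier(nn.Module):
        def __init__(self):
            super().__init__()
            # three fully-connected layers: 28*28 -> 128 -> 256 -> 10
            self.layer_1 = nn.Linear(28 * 28, 128)
            self.layer_2 = nn.Linear(128, 256)
            self.layer_3 = nn.Linear(256, 10)

        def forward(self, x):
            batch_size = x.size(0)
            x = x.view(batch_size, -1)      # flatten the 28x28 image
            x = F.relu(self.layer_1(x))
            x = F.relu(self.layer_2(x))
            x = self.layer_3(x)
            return F.log_softmax(x, dim=1)  # log-probabilities over the 10 digits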

This model defines the computational graph to take as input an MNIST image and convert it to a probability distribution over 10 classes for digits 0–9.

3-layer network (illustration by: William Falcon)

To convert this model to PyTorch Lightning, we simply replace the nn.Module with the pl.LightningModule:

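A sketch of the same model as a LightningModule (we’ll keep using the illustrative name LitMNIST below):

    import pytorch_lightning as pl

    class LitMNIST(pl.LightningModule):
        def __init__(self):
            super().__init__()
            self.layer_1 = nn.Linear(28 * 28, 128)
            self.layer_2 = nn.Linear(128, 256)
            self.layer_3 = nn.Linear(256, 10)

        def forward(self, x):
            batch_size = x.size(0)
            x = x.view(batch_size, -1)
            x = F.relu(self.layer_1(x))
            x = F.relu(self.layer_2(x))
            x = self.layer_3(x)
            return F.log_softmax(x, dim=1)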

The new PyTorch Lightning class is EXACTLY the same as the PyTorch version, except that the LightningModule provides a structure for the research code.

Lightning provides structure to PyTorch code

See? The code is EXACTLY the same for both!

This means you can use a LightningModule exactly as you would a PyTorch module, such as for prediction:

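For example (the input is just a random, MNIST-shaped tensor):

    import torch

    net = LitMNIST()
    x = torch.randn(1, 1, 28, 28)   # a fake 28x28 image
    out = net(x)                    # log-probabilities over the 10 digits, shape (1, 10)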

Or use it as a pretrained model

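A sketch using Lightning’s checkpoint loading (the checkpoint path is illustrative):

    # load trained weights from a saved checkpoint
    model = LitMNIST.load_from_checkpoint('path/to/checkpoint.ckpt')
    model.freeze()   # disable gradients and switch to eval mode
    out = model(x)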

The Data

For this tutorial we’re using MNIST.

Let’s generate three splits of MNIST: a training, a validation, and a test split.

This, again, is the same code in PyTorch as it is in Lightning.

The dataset is added to the DataLoader, which handles the loading, shuffling, and batching of the dataset.

In short, data preparation has 4 steps:

  1. Download the images.
  2. Apply image transforms (these are highly subjective).
  3. Generate training, validation, and test dataset splits.
  4. Wrap each dataset split in a DataLoader.
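
A sketch of those four functions on the LightningModule (the batch size and the 55,000/5,000 train/val split are illustrative; on recent Lightning versions the dataset construction would usually live in a setup() hook rather than prepare_data()):

    from torch.utils.data import DataLoader, random_split
    from torchvision import datasets, transforms

    class LitMNIST(pl.LightningModule):
        # ... __init__ and forward as above ...

        def prepare_data(self):
            # steps 1-3: download, transform (MNIST's published mean/std), split
            transform = transforms.Compose([transforms.ToTensor(),
                                            transforms.Normalize((0.1307,), (0.3081,))])
            mnist_full = datasets.MNIST('data', train=True, download=True, transform=transform)
            self.mnist_train, self.mnist_val = random_split(mnist_full, [55000, 5000])
            self.mnist_test = datasets.MNIST('data', train=False, download=True, transform=transform)

        # step 4: wrap each split in a DataLoader
        def train_dataloader(self):
            return DataLoader(self.mnist_train, batch_size=64, shuffle=True)

        def val_dataloader(self):
            return DataLoader(self.mnist_val, batch_size=64)

        def test_dataloader(self):
            return DataLoader(self.mnist_test, batch_size=64)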

Again, the code is exactly the same except that we’ve organized the PyTorch code into 4 functions:

prepare_data

This function handles downloads and any data processing. It makes sure that when you use multiple GPUs, you don’t download the dataset multiple times or apply duplicate manipulations to the data.

This is because each GPU would otherwise execute the same PyTorch code, causing duplication. Lightning makes sure these critical parts are called from ONLY one GPU.

train_dataloader, val_dataloader, test_dataloader

Each of these is responsible for returning the appropriate data split. Lightning structures it this way so that it is VERY clear HOW the data are being manipulated. If you’ve ever read random GitHub code written in PyTorch, it’s nearly impossible to see how the data are handled.

Lightning even allows multiple dataloaders for testing or validating.

This code is organized under what we call a DataModule. Although this is 100% optional and Lightning can use DataLoaders directly, a DataModule makes your data reusable and easy to share.

The Optimizer

Now we choose how we’re going to do the optimization. We’ll use Adam instead of SGD because it is a good default in most DL research.

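A sketch of both versions (the learning rate is an arbitrary example value):

    import torch

    # plain PyTorch
    optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)

    # the same thing in Lightning, as a method on the LightningModule
    class LitMNIST(pl.LightningModule):
        # ...

        def configure_optimizers(self):
            return torch.optim.Adam(self.parameters(), lr=1e-3)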

Again, this is exactly the same in both, except in Lightning it is organized into the configure_optimizers function.

Lightning is extremely extensible. For instance, if you wanted to use multiple optimizers (e.g., for a GAN), you could just return both here.

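A sketch, assuming the module has hypothetical self.generator and self.discriminator sub-networks:

    def configure_optimizers(self):
        # one optimizer per sub-network, e.g. for a GAN
        opt_g = torch.optim.Adam(self.generator.parameters(), lr=2e-4)
        opt_d = torch.optim.Adam(self.discriminator.parameters(), lr=2e-4)
        return opt_g, opt_d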

You’ll also notice that in Lightning we pass in self.parameters() and not a model, because the LightningModule IS the model.

The Loss

For n-way classification we want to compute the cross-entropy loss. Cross-entropy is the same as negative log-likelihood applied to log_softmax outputs, which is what we’ll use since our model already ends in log_softmax.

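A sketch of computing that loss (the batch of images x and labels y are made up for illustration):

    import torch
    from torch.nn import functional as F

    x = torch.randn(8, 1, 28, 28)    # a batch of fake images
    y = torch.randint(0, 10, (8,))   # fake integer labels
    logits = net(x)                  # log-probabilities, since forward() ends in log_softmax
    loss = F.nll_loss(logits, y)     # equivalent to cross-entropy on log-probabilities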

Again… code is exactly the same!

Training and Validation Loop

We’ve assembled all the key ingredients needed for training:

  1. The model (3-layer NN)
  2. The dataset (MNIST)
  3. An optimizer
  4. A loss

Now we implement a full training routine which does the following:

For each epoch, and for each batch of training data:

  1. Perform a forward pass to get the model’s predictions.
  2. Compute the loss between the predictions and the labels.
  3. Clear the gradients left over from the previous step.
  4. Backpropagate to compute fresh gradients.
  5. Update the parameters with the optimizer.

In both PyTorch and Lightning, the pseudocode looks like this:

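A sketch of that pseudocode (model, optimizer, train_loader, and num_epochs are assumed to be defined as above):

    # the hand-written PyTorch loop
    for epoch in range(num_epochs):
        for x, y in train_loader:
            logits = model(x)
            loss = F.nll_loss(logits, y)

            optimizer.zero_grad()   # clear stale gradients
            loss.backward()         # backpropagate
            optimizer.step()        # update the parameters

    # in Lightning, the body of that loop moves into training_step
    class LitMNIST(pl.LightningModule):
        # ...

        def training_step(self, batch, batch_idx):
            x, y = batch
            logits = self(x)
            loss = F.nll_loss(logits, y)
            return loss   # older Lightning versions expect a dict like {'loss': loss}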

This is where Lightning differs, though. In PyTorch, you write the for loop yourself, which means you have to remember to call the correct things in the right order; this leaves a lot of room for bugs.

Even if your model is simple, it won’t be once you start doing more advanced things like using multiple GPUs, gradient clipping, early stopping, checkpointing, TPU training, 16-bit precision, etc… Your code complexity will quickly explode.

Even if your model is simple, it won’t be once you start doing more advanced things

Here are the training and validation loops for both PyTorch and Lightning.

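On the validation side, a sketch of the corresponding Lightning hook (Lightning runs it with gradients disabled and the model in eval mode for you):

    class LitMNIST(pl.LightningModule):
        # ...

        def validation_step(self, batch, batch_idx):
            x, y = batch
            logits = self(x)
            loss = F.nll_loss(logits, y)
            return loss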

This is the beauty of Lightning. It abstracts the engineering boilerplate but leaves everything else unchanged. This means you are STILL writing PyTorch, except your code has been structured nicely.

This increases readability which helps with reproducibility!

The Lightning Trainer

The trainer is how we abstract the boilerplate code.

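A sketch of using it (the epoch count is arbitrary):

    from pytorch_lightning import Trainer

    model = LitMNIST()
    trainer = Trainer(max_epochs=5)
    trainer.fit(model)   # dataloaders come from the hooks defined on the module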

Again, this is possible because ALL you had to do was organize your PyTorch code into a LightningModule

Full Training Loop for PyTorch

The full MNIST example written in PyTorch is as follows:
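
A compact sketch of what that full script might look like, reusing the MNISTClassifier sketch from above (batch size, split sizes, and epoch count are illustrative):

    import torch
    from torch.nn import functional as F
    from torch.utils.data import DataLoader, random_split
    from torchvision import datasets, transforms

    # model
    model = MNISTClassifier()

    # data
    transform = transforms.Compose([transforms.ToTensor(),
                                    transforms.Normalize((0.1307,), (0.3081,))])
    mnist_full = datasets.MNIST('data', train=True, download=True, transform=transform)
    mnist_train, mnist_val = random_split(mnist_full, [55000, 5000])
    train_loader = DataLoader(mnist_train, batch_size=64, shuffle=True)
    val_loader = DataLoader(mnist_val, batch_size=64)

    # optimizer
    optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)

    # training and validation loops
    for epoch in range(5):
        model.train()
        for x, y in train_loader:
            logits = model(x)
            loss = F.nll_loss(logits, y)
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

        model.eval()
        val_loss = 0.0
        with torch.no_grad():
            for x, y in val_loader:
                logits = model(x)
                val_loss += F.nll_loss(logits, y, reduction='sum').item()
        print(f'epoch {epoch}: val_loss={val_loss / len(mnist_val):.4f}')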

Full Training Loop in Lightning

The Lightning version is EXACTLY the same, except that the hand-written loop boilerplate is gone.

This version does not use a DataModule; instead, it keeps the dataloaders defined directly on the LightningModule, as shown above.

And here is the same code, but with the data grouped under a DataModule to make it more reusable:
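
A sketch of such a DataModule (this requires a Lightning version that ships LightningDataModule; the class name and split sizes are illustrative):

    import pytorch_lightning as pl
    from pytorch_lightning import Trainer
    from torch.utils.data import DataLoader, random_split
    from torchvision import datasets, transforms

    class MNISTDataModule(pl.LightningDataModule):
        def __init__(self, batch_size=64):
            super().__init__()
            self.batch_size = batch_size
            self.transform = transforms.Compose([transforms.ToTensor(),
                                                 transforms.Normalize((0.1307,), (0.3081,))])

        def prepare_data(self):
            # download only; runs on a single process
            datasets.MNIST('data', train=True, download=True)
            datasets.MNIST('data', train=False, download=True)

        def setup(self, stage=None):
            mnist_full = datasets.MNIST('data', train=True, transform=self.transform)
            self.mnist_train, self.mnist_val = random_split(mnist_full, [55000, 5000])
            self.mnist_test = datasets.MNIST('data', train=False, transform=self.transform)

        def train_dataloader(self):
            return DataLoader(self.mnist_train, batch_size=self.batch_size, shuffle=True)

        def val_dataloader(self):
            return DataLoader(self.mnist_val, batch_size=self.batch_size)

        def test_dataloader(self):
            return DataLoader(self.mnist_test, batch_size=self.batch_size)

    # usage
    trainer = Trainer(max_epochs=5)
    trainer.fit(LitMNIST(), datamodule=MNISTDataModule())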

Highlights

Let’s call out a few key points:

  1. Without Lightning, the PyTorch code can be scattered across arbitrary parts of the project. With Lightning, it is structured.
  2. It is the same exact code for both except that it’s structured in Lightning. (worth saying twice lol).
  3. As the project grows in complexity, your code won’t because Lightning abstracts out most of it.
  4. You retain the flexibility of PyTorch because you have full control over the key points in training. For instance, you could have an arbitrarily complex training_step, such as for a seq2seq model.

  5. In Lightning you get a bunch of freebies, such as a sick progress bar, a beautiful weights summary, TensorBoard logs (yup! you had to do nothing to get this), and free checkpointing and early stopping.

All for free!

Additional Features

But Lightning is best known for its out-of-the-box goodies, such as TPU training, etc.

In Lightning, you can train your model on CPUs, GPUs, Multiple GPUs, or TPUs without changing a single line of your PyTorch code.

You can also do 16-bit precision training

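For illustration, here is roughly how those Trainer flags looked in 2020-era Lightning (newer releases have renamed some of them, e.g. to accelerator/devices):

    from pytorch_lightning import Trainer

    Trainer(gpus=1)                 # single GPU
    Trainer(gpus=8)                 # multiple GPUs
    Trainer(tpu_cores=8)            # TPU training
    Trainer(gpus=2, precision=16)   # 16-bit precision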

Log using five other alternatives to TensorBoard:

Logging with Neptune.AI (credits: Neptune.ai)

Logging with Comet.ml

We even have a built-in profiler that can tell you where the bottlenecks are in your training.

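A sketch of turning it on (argument spelling per 2020-era Lightning; newer versions also accept profiler='simple'):

    from pytorch_lightning import Trainer
    from pytorch_lightning.profiler import AdvancedProfiler

    trainer = Trainer(profiler=True)                 # simple per-action timing report
    trainer = Trainer(profiler=AdvancedProfiler())   # detailed per-function report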

Setting this flag prints a report of how much time was spent in each part of training once the run finishes; the AdvancedProfiler gives a more detailed, per-function breakdown.

We can also train on multiple GPUs at once without you doing any extra work (you still have to submit the SLURM job yourself).

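An illustrative configuration, using the argument names from 2020-era Lightning (newer versions use devices, num_nodes, and strategy):

    # 8 GPUs on each of 4 nodes with DistributedDataParallel
    trainer = Trainer(gpus=8, num_nodes=4, distributed_backend='ddp')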

And there are about 40 other features it supports which you can read about in the documentation.

Extensibility With Hooks

You’re probably wondering how it’s possible for Lightning to do all of this for you while still giving you full control over everything.

Unlike Keras or other high-level frameworks, Lightning does not hide any of the necessary details. And if you do need to modify any aspect of training yourself, you have two main options.

The first is extensibility by overriding hooks. Here’s a non-exhaustive list of hooks you can override: on_epoch_start, on_epoch_end, on_batch_start, on_batch_end, on_train_start, on_train_end, backward, optimizer_step, on_after_backward, and on_before_zero_grad.

These overrides happen in the LightningModule

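For example, a minimal sketch of overriding one of these hooks on the module from above:

    class LitMNIST(pl.LightningModule):
        # ...

        def on_epoch_start(self):
            # runs automatically at the start of every epoch
            print('starting a new epoch')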

Extensibility with Callbacks

A callback is a piece of code that you’d like to be executed at various parts of training. In Lightning callbacks are reserved for non-essential code such as logging or something not related to research code. This keeps the research code super clean and organized.

Let’s say you wanted to print something or save something at various parts of training. Here’s what the callback would look like:

PyTorch Lightning Callback

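A minimal sketch of such a callback (the class name and printed messages are just examples):

    from pytorch_lightning.callbacks import Callback

    class MyPrintingCallback(Callback):
        def on_train_start(self, trainer, pl_module):
            print('Training is starting')

        def on_train_end(self, trainer, pl_module):
            print('Training is ending')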

Now you pass this into the Trainer, and this code will be called at the corresponding points in training:

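For example:

    trainer = Trainer(callbacks=[MyPrintingCallback()])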

This paradigm keeps your research code organized into three different buckets

  1. Research code (LightningModule) (this is the science).
  2. Engineering code (Trainer)
  3. Non-research related code (Callbacks)

How to start

Hopefully this guide showed you exactly how to get started. The easiest way to start is to run the Colab notebook with the MNIST example here.

Or install Lightning

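From PyPI:

    pip install pytorch-lightning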

Or check out the GitHub page.