May 10, 2020

Getting Started with PyTorch Lightning

Neelabh Madan (IIT Delhi)

Imagine that one day you have an amazing idea for a machine learning project. You write down all the details on a piece of paper: the model architecture, the optimizer, the dataset. Now you just have to code it up and do some hyperparameter tuning to put it into action.

So you fire up your machine and start coding. But suddenly it hits you: you need to go through the hard work of creating batches out of the data, writing loops to iterate over batches and epochs, debugging any issues that arise along the way, repeating the same for the validation set, and the list goes on. The project turns into a headache before it has even started.

Illustration by @Neelabh

But not anymore. PyTorch Lightning is here to save the day. Not only does it do the hard work for you automatically, but it also structures your code to make it more scalable. It comes fully packed with awesome features that will enhance your machine learning experience. Beginners should definitely give it a go.

Throughout this blog post we will learn how Lightning can be used along with PyTorch to make development easy and reproducible.

Roadmap

With this blog post, I aim to help people get to know PyTorch Lightning. From now on I will be referring to PyTorch Lightning as Lightning.

Photo by Bruno Bergher on Unsplash

I will begin with a brief introduction to the new library and its underlying principles so that you can build research-friendly neural network models from scratch.

This tutorial assumes that you have prior knowledge of how a neural network works and that you are familiar with the PyTorch framework, but even if you are not, you will be fine. For PyTorch users, this tutorial may serve as encouragement to include Lightning in their PyTorch code.

Let us start with some basic introduction.

What is PyTorch?

Source: VentureBeat.com

Based on the Torch library, PyTorch is an open-source machine learning library. PyTorch is imperative, which means computations run immediately: you do not have to write the full program before checking whether a piece of it works. We can run a part of the code and inspect it in real time. The library is Python-based and built to provide flexibility as a deep learning development platform.

PyTorch is extremely “pythonic” in nature. It is essentially a NumPy substitute that takes advantage of the computational power of GPUs.

PyTorch supports dynamic computational graphs, which allow us to change the network on the fly.
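As a minimal illustration (a hypothetical toy model, not from the original post), a dynamic graph means ordinary Python control flow can reshape the computation on every forward pass:

import torch

class DynamicNet(torch.nn.Module):
    # A toy model: how many times the hidden layer is applied is
    # decided per input at run time, something a static graph
    # cannot express directly.
    def __init__(self):
        super().__init__()
        self.linear = torch.nn.Linear(4, 4)

    def forward(self, x):
        # ordinary Python control flow builds the graph on the fly
        for _ in range(int(x.abs().sum().item()) % 3 + 1):
            x = torch.relu(self.linear(x))
        return x

net = DynamicNet()
out = net(torch.randn(2, 4))  # the graph is constructed during this call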

The Catch

PyTorch is an excellent framework, great for researchers. But after a certain point, it involves more engineering than researching.

As I mentioned in the introduction, the hard work starts taking over the research work. The focus shifts from training and tuning the model to correctly implementing engineering features such as batching, checkpointing, logging, and the training loop itself.

Even though these features may be simple to implement, they still cost precious time, and a mistake while coding them up means even more time lost in debugging.

Source : The Independent

Consider an example. We are training a model, and we want it to stop after 100 epochs and save the trained weights to a .pth file. But we made a mistake in the model-saving code. The thing about Python is that it does not show an error until it runs into one. So after 10 hours of training we hit the error, and the model does not save. Just like that, the 10 hours go down the drain. How frustrating would that be?

Enter Lightning

Lightning is a very lightweight wrapper on PyTorch. This means you don’t have to learn a new library. It defers the core training and validation logic to you and automates the rest. It guarantees tested and correct code with the best modern practices for the automated parts.

In PyTorch, you must write the model-saving code yourself; in Lightning, models are saved by default.

So we can actually save those 10 hours by carefully organizing our code in Lightning modules.
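To make the contrast concrete, here is a hedged sketch (the file name and the stand-in model are illustrative):

import torch

# a stand-in model for illustration
model = torch.nn.Linear(10, 2)

# plain PyTorch: you write (and must debug) the saving and loading code yourself
torch.save(model.state_dict(), "model.pth")
model.load_state_dict(torch.load("model.pth"))

# in Lightning there is nothing to write: trainer.fit(model) checkpoints
# the model automatically, so a bug in hand-written saving code cannot
# cost you 10 hours of training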

As the name suggests, Lightning is closely related to PyTorch: not only do they share their roots at Facebook, but Lightning is also a wrapper built on top of PyTorch itself.

In its truest sense, Lightning is a structuring tool for your PyTorch code. You just provide the bare minimum of details (e.g., the number of epochs, the optimizer), and Lightning automates the rest.

Lightning reduces the amount of work that needs to be done (by @neelabh)

By using Lightning, you make sure that all the tricky pieces of code work for you and you can focus on the real research:

Lightning ensures that when your network becomes complex your code doesn’t

It lets you focus on the real work instead of worrying about how to run your model on multiple GPUs or how to speed up the code. Lightning handles that for you.

But what does this mean for you? It means that the framework is designed to be extremely extensible while making state-of-the-art AI research techniques (like multi-GPU training) trivial.

Quick MNIST Classifier on Google Colab

I will show you exactly how to build an MNIST classifier using Lightning: a very small network that reaches 99.4% accuracy on the MNIST validation set with under 8k trainable parameters. I re-implemented the code using PyTorch Lightning and added my own intuitions and explanations.

We shall do this as quickly as possible so that we can move on to the even more interesting details of Lightning.

Source: Wikipedia

The Main Aspects of a Lightning Model

Illustration by @neelabh

The basic and essential chunks of a neural network in Lightning are the following:

  1. Model architecture— Restructuring
  2. Data — Restructuring
  3. Forward pass — Restructuring
  4. Optimizer — Restructuring
  5. Training Step — Restructuring
  6. Training and Validation Loops (Lightning Trainer) — Abstraction

We can clearly see that they fall into two categories: Restructuring and Abstraction.

Restructuring

Source: PyTorch Lightning Docs

Restructuring refers to keeping code in its respective place in the Lightning Module. The code is simply arranged into methods of the LightningModule known as callbacks. These method names have a special meaning to Lightning: they tell it what each function is for.

Note that the PyTorch code itself does not change during restructuring.

Abstraction

The boilerplate code is abstracted by the Lightning trainer. It automates most of the code for us.

There is no longer any need to write separate code for saving your model or iterating over batches; it is all abstracted into the Trainer.

What does it contain?

Lightning provides us with several methods on its pl.LightningModule class that help in structuring the code; the docs refer to them as callbacks, and we will meet each of them in the coding section below.


Coding an MNIST Classifier

Now let’s dive right into coding so that we can get hands-on experience with Lightning.

Installing Lightning

Run the following to install Lightning on Google Colab

!pip install pytorch_lightning

You may have to restart the runtime for the changes to take effect.

Do not forget to select the GPU: in your Google Colab notebook, go to Edit -> Notebook Settings -> Hardware Accelerator and select GPU.

Import Libraries

import os  # used by the data hooks below (os.getcwd())
import torch
from torch.nn import functional as F
from torch.utils.data import DataLoader, random_split
from torchvision.datasets import MNIST
from torchvision import transforms
import pytorch_lightning as pl

1. The Model

We will define our own class called smallAndSmartModel, inheriting from Lightning's pl.LightningModule.

Let’s start building the model

Illustration by @neelabh

class smallAndSmartModel(pl.LightningModule):
    def __init__(self):
        super(smallAndSmartModel, self).__init__()
        # 28x28 input -> 24x24 after the conv -> 12x12 after pooling
        self.layer1 = torch.nn.Sequential(
            torch.nn.Conv2d(1, 28, kernel_size=5),
            torch.nn.ReLU(),
            torch.nn.MaxPool2d(kernel_size=2))
        # 12x12 -> 11x11 after the conv -> 5x5 after pooling
        # (10 channels x 5 x 5 = 250 features for the first linear layer)
        self.layer2 = torch.nn.Sequential(
            torch.nn.Conv2d(28, 10, kernel_size=2),
            torch.nn.ReLU(),
            torch.nn.MaxPool2d(kernel_size=2))
        self.dropout1 = torch.nn.Dropout(0.25)
        self.fc1 = torch.nn.Linear(250, 18)
        self.dropout2 = torch.nn.Dropout(0.08)
        self.fc2 = torch.nn.Linear(18, 10)

2. Data Loading

class smallAndSmartModel(pl.LightningModule):

    # prepare_data handles manipulations that need to be done only once,
    # such as downloading the data
    def prepare_data(self):
        MNIST(os.getcwd(), train=True, download=True)
        MNIST(os.getcwd(), train=False, download=True)

    def train_dataloader(self):
        # This is an essential function and must be included in the code.
        # download is set to False here because prepare_data has already
        # downloaded the data.
        mnist_train = MNIST(os.getcwd(), train=True, download=False,
                            transform=transforms.ToTensor())

        # dividing into training and validation sets
        self.train_set, self.val_set = random_split(mnist_train, [55000, 5000])

        return DataLoader(self.train_set, batch_size=128)

    def val_dataloader(self):
        # OPTIONAL
        return DataLoader(self.val_set, batch_size=128)

    def test_dataloader(self):
        # OPTIONAL
        return DataLoader(MNIST(os.getcwd(), train=False, download=False,
                                transform=transforms.ToTensor()),
                          batch_size=128)

train_dataloader, val_dataloader, and test_dataloader are reserved methods of pl.LightningModule. We use them as wrappers for loading our data.

The code must go in these particular functions because they have a special meaning in Lightning, just as forward does in nn.Module.

Each of these is responsible for returning the appropriate data split, and Lightning structures them so that it is always clear how the data is being manipulated. When you read code that is not structured this way (as is the case in many GitHub repositories), it can be hard to figure out how the data was handled.

Lightning even allows multiple data loaders for testing or validating.
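For example, val_dataloader may return a list of loaders, and Lightning will run validation over each of them (a sketch building on the val_set defined above; the second loader is purely illustrative):

class smallAndSmartModel(pl.LightningModule):
    def val_dataloader(self):
        # Lightning accepts a list and validates on every loader;
        # validation_step then also receives a dataloader index
        return [
            DataLoader(self.val_set, batch_size=128),
            DataLoader(self.val_set, batch_size=256),
        ]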

3. Forward Pass

class smallAndSmartModel(pl.LightningModule):
    def forward(self, x):
        x = self.layer1(x)
        x = self.layer2(x)
        x = self.dropout1(x)
        x = torch.relu(self.fc1(x.view(x.size(0), -1)))
        x = F.leaky_relu(self.dropout2(x))

        # log_softmax (rather than softmax) because F.nll_loss in the
        # training step expects log-probabilities
        return F.log_softmax(self.fc2(x), dim=1)

This is the forward pass, where the calculation takes place and we generate the values of the output layer from the input data.

PyTorch users will notice that the implementation is unchanged (apart from returning log-probabilities to match the loss function).

4. Optimizer

class smallAndSmartModel(pl.LightningModule):
    def configure_optimizers(self):
        # essential function: we are using the Adam optimizer for our model
        return torch.optim.Adam(self.parameters())

This required function returns the optimizer we want to use. Interestingly, the configure_optimizers wrapper even lets us return multiple optimizers with ease (for example, in GANs).
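For instance, a GAN-style module might return one optimizer per sub-network (a sketch; the class name and the generator and discriminator attributes are hypothetical):

class myGAN(pl.LightningModule):
    def configure_optimizers(self):
        # one optimizer per sub-network; Lightning passes an optimizer
        # index to training_step so each network can be updated in turn
        opt_g = torch.optim.Adam(self.generator.parameters(), lr=2e-4)
        opt_d = torch.optim.Adam(self.discriminator.parameters(), lr=2e-4)
        return [opt_g, opt_d]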

5. Training Step (The interesting part)

class smallAndSmartModel(pl.LightningModule):
    def training_step(self, batch, batch_idx):
        # extracting inputs and labels from the batch
        x, labels = batch

        # doing a forward pass
        pred = self.forward(x)

        # calculating the loss
        loss = F.nll_loss(pred, labels)

        # logs
        logs = {"train_loss": loss}

        output = {
            # REQUIRED: we must return "loss"
            "loss": loss,
            # optional, for logging purposes
            "log": logs
        }

        return output

This step is called for every batch in our dataset. The key operations are extracting the inputs and labels from the batch, running the forward pass, and computing the loss. It is essential for training_step to return a dictionary containing the loss; any other returned data is optional.

6. The Lightning Trainer ( Where Magic Happens)

Photo by Dollar Gill on Unsplash

Obviously, there is no magic. But once I tell you what the Lightning Trainer is capable of, you will agree that it is charming and exquisite.

# abstracts the training, validation, and test loops

# using the single GPU that Google Colab gives us, for at most 100 epochs
myTrainer = pl.Trainer(gpus=1, max_epochs=100)

model = smallAndSmartModel()
myTrainer.fit(model)

The Trainer is the heart of PyTorch Lightning. This is where all the abstraction takes place: it handles the most obvious pieces of boilerplate, such as the training, validation, and test loops, iterating over batches, and saving the model.

Now you don’t have to worry about engineering these steps. The Trainer does that for you. You just have to make sure that your code is well structured as explained in the above sections.


Lightning Trainer Flags

The trainer provides some very helpful flags. We can assign values to these flags to configure our classifier’s behavior.
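A few commonly used flags, as a sketch (flag names can change between Lightning releases, so check the documentation for the version you have installed):

myTrainer = pl.Trainer(
    gpus=1,             # train on a single GPU
    max_epochs=40,      # stop after at most 40 epochs
    precision=16,       # use 16-bit precision
    fast_dev_run=True,  # push one batch through train/val as a quick smoke test
)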

Perks of Lightning Trainer

Photo by Todd Quackenbush on Unsplash

By using the Trainer, you automatically get the following tools and features (a configuration sketch follows the list):

  1. Training and validation loop
  2. Tensorboard logging
  3. Early-stopping
  4. Model checkpointing
  5. The ability to resume training from wherever you left off
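As a sketch of the last three items (argument names vary between Lightning versions, and monitoring "val_loss" assumes that metric is logged during validation):

from pytorch_lightning.callbacks import EarlyStopping, ModelCheckpoint

trainer = pl.Trainer(
    gpus=1,
    max_epochs=100,
    callbacks=[
        EarlyStopping(monitor="val_loss", patience=3),  # stop when val_loss stalls
        ModelCheckpoint(monitor="val_loss"),            # keep the best model
    ],
)

# resuming later from a saved checkpoint:
# trainer = pl.Trainer(resume_from_checkpoint="path/to/checkpoint.ckpt")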

Why should I start using PyTorch Lightning?

That’s the question you should be asking after everything I have told you about PyTorch Lightning. I will answer it by letting you in on why I love Lightning.

1. Peace of Mind (Structured Code)

Photo by Dingzeyu Li on Unsplash

When I look at how code is structured in Lightning, putting each piece in its place feels almost natural and intuitive. The structure gives me a step-by-step strategy for developing a classifier from scratch, and it makes me more confident in developing my models.

2. Simplicity

Photo by Lindsay Henwood on Unsplash

The steps for building a machine learning solution are now very simple and intuitive.

Now, to come up with a solution using Lightning, I know that I need to prepare the data, add the optimizer, add the training step, and so on. This helps me move along with the flow of ideas in my mind.

3. Grouping the relevant together

Source: Alamy

The best thing about Lightning is that each process is separated from the others in the LightningModule. That's the benefit of structuring.

training_step contains information about the training step only, not about the validation step or the optimizer. It makes things clearer for me.

4. No Real Change in Code Required

Illustration by @neelabh

Since Lightning is a wrapper for PyTorch, I did not have to learn a new framework. And if I want to write a very complex training step, I can easily do that without compromising the flexibility of PyTorch.

Those who are familiar with PyTorch will find the transition to be extremely smooth.

5. The Lightning Trainer — Automation

Photo by Giorgio Trovato on Unsplash

The Trainer just wins it all. It automates most of the complex tasks for me.

In the case of GPUs, I don't have to worry about moving my tensors with tensor.to(device). Lightning automatically figures out the details; I just have to set a few flags. With this, I can even enable 16-bit precision, auto cluster saving, an automatic learning-rate finder, Tensorboard visualization, and so on.
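For comparison, the device bookkeeping that plain PyTorch requires looks like this sketch; inside a LightningModule none of it is needed, because batches arrive on the right device already:

import torch

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# plain PyTorch: every module and tensor must be moved by hand
model = torch.nn.Linear(10, 2).to(device)
x = torch.randn(4, 10).to(device)
out = model(x)

# in Lightning, the batch passed to training_step is already on the
# correct device; choosing hardware is just a Trainer flag (gpus=1)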

By using the Trainer, I’m not only getting some very neat algorithms but I am also getting the guarantee that they will work correctly. Now that’s one less thing for me to worry about. And I can focus on my real research.

My personal favorites are Tensorboard logging and resuming training from where I left off.

Whom does Lightning cater to?

Photo by Chris Liverani on Unsplash

Lightning is best for scholars and researchers who are working on developing the best strategies to tackle a problem. Lightning takes the unnecessary engineering off their plate and provides them with a clean environment for the relevant research.

I also believe that PyTorch beginners should start using Lightning so that their thinking becomes structured and more intuitive. They might also find it amazing to have so many perks at their disposal, ready to be used.

Congratulations

Now that you are acquainted with PyTorch Lightning, I hope you will start using Lightning (especially if you are a researcher) and fall in love with its amazing features.

That’s all from me. If you liked my little introduction to Lightning, do share your feedback.

Keep learning and have fun!!
