July 28, 2020

Automate Your Neural Network Training With PyTorch Lightning

Erfandi Maula Yusnu, Lalu

PyTorch Lightning will automate your neural network training while keeping your code simple, clean, and flexible. If you’re a researcher, you will love this!



Image created by the author; background and logos from the official PyTorch and PyTorch Lightning sites

PyTorch Lightning is a lightweight PyTorch wrapper for ML researchers. It helps you scale your models and write less boilerplate while keeping your code clean and flexible. It lets researchers focus on solving the problem instead of writing engineering code.

I’ve been using PyTorch for two years, starting with version 0.3.0. Before that, I used Keras as my deep learning framework, but I switched to PyTorch for several reasons. If you want to know my reasons, check the article below.

Since switching to PyTorch, I have had to give up the enjoyment of Keras's few-line training calls and write my own training code. That has advantages and disadvantages, but I chose the PyTorch way because it gives me more control over training. The catch is that every time I want to try a new deep learning model, I have to write the training and evaluation code all over again.

So I decided to build my own library, which I called torchwisdom, but I got stuck because I was still building a full OCR pipeline system for my company. I looked for another solution, found PyTorch Lightning, and when I saw the code it was love at first sight.

In this article I will cover installation, a basic code comparison, and a comparison by example that I created myself, partly based on code from the PyTorch Lightning site. I finish with a conclusion.

Installation

Okay, let's start by installing pytorch-lightning so you can follow along. You can install PyTorch Lightning with pip or conda.

pip install

pip install pytorch-lightning

conda install

conda install pytorch-lightning -c conda-forge

I prefer Anaconda as my Python distribution because it is more complete for deep learning and data science work: a fresh installation already ships with many standard machine learning and data processing packages.

Basic Code Comparison

Before we jump into the code, take a look at the two pictures below. They show the difference between the PyTorch way and the PyTorch Lightning way of building and training a model. As you can see on the left, plain PyTorch needs more lines to create and train the model.

With PyTorch Lightning, the code lives inside a LightningModule, and PyTorch Lightning takes care of all the engineering code for training. You still keep a degree of control to customize your training step, as in the example code below.

Image taken from the PyTorch Lightning GitHub repository
Image taken from the PyTorch GitHub repository
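For contrast, here is a minimal sketch of the hand-written loop that a plain PyTorch script typically needs, i.e. the engineering code that Lightning's Trainer takes over. The synthetic regression data, the hypothetical `train_plain_pytorch` helper, and the hyperparameters are illustrative assumptions, not the code shown in the pictures.

```python
import torch
from torch import nn
from torch.utils.data import DataLoader, TensorDataset


def train_plain_pytorch(epochs: int = 20) -> list:
    """Hypothetical helper: a hand-written loop of the kind Lightning automates."""
    torch.manual_seed(0)
    # Synthetic data (an assumption for illustration): y = X @ w + small noise.
    X = torch.randn(256, 3)
    y = X @ torch.tensor([[1.5], [-2.0], [0.5]]) + 0.01 * torch.randn(256, 1)
    loader = DataLoader(TensorDataset(X, y), batch_size=32, shuffle=True)

    model = nn.Linear(3, 1)
    optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
    loss_fn = nn.MSELoss()

    epoch_losses = []
    for _ in range(epochs):
        model.train()
        running = 0.0
        for xb, yb in loader:
            optimizer.zero_grad()          # reset gradients from the previous step
            loss = loss_fn(model(xb), yb)  # forward pass and loss
            loss.backward()                # backpropagation
            optimizer.step()               # parameter update
            running += loss.item() * xb.size(0)
        # Average loss over the whole epoch.
        epoch_losses.append(running / len(loader.dataset))
    return epoch_losses


losses = train_plain_pytorch()
print(f"first epoch loss: {losses[0]:.4f}, last epoch loss: {losses[-1]:.4f}")
```

Every line inside the two `for` loops is something you would otherwise rewrite for each new model, which is exactly the repetition described earlier.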

The training code needs only three lines: the first instantiates the Model class, the second instantiates the Trainer class, and the third trains the model.

This example is one way to train with PyTorch Lightning. Of course, you can also write your code in your own custom PyTorch style, because PyTorch Lightning offers different degrees of flexibility. Want to see it? Let's continue.

Comparison By Example

Okay, once the installation is finished, let’s begin writing the code. The first step is to import all the libraries you need. After that, build the dataset and the data loader that will be used for training.

As you can see in the code above, we use the MNIST dataset from torchvision and create a data loader with torch.utils.data.DataLoader. In the code below, we prepare a network compatible with the MNIST dataset, whose images are 28x28 pixels. The first layer has 128 hidden units, the second layer has 256 hidden units, and the third (output) layer has 10 outputs, one per class.

If you look at lines 27 and 33 in the gist code above, you will see the training_step and configure_optimizers methods, which override methods of the LightningModule class that the model extends on line 2. This is what makes a LightningModule different from a standard PyTorch nn.Module: it has methods that make it compatible with the Trainer class used on line 39.

Now let’s try another way to write your code. Suppose you have to write a library, or want to use someone else's library, that has been written in pure PyTorch. How can you use PyTorch Lightning then?

Okay, the code below has two classes. The first class uses the standard PyTorch nn.Module as its parent class and is written just like a normal PyTorch module. But look at line 30: there is a class named ExtendMNIST that inherits from two classes, StandardMNIST and LightningModule. This is what I love about Python: a single class can have more than one parent.

If you look at the code in the ExtendMNIST class, you will see that it only overrides methods of the LightningModule class. Writing code this way, you can extend any model you have written before without changing it and still use the PyTorch Lightning library.

So, can you show me the result during training? Okay, let’s see what it looks like when it runs.

Screenshot by the author from the author's Google Colab notebook

So there you have it: a screenshot of what training looks like. It has a nice progress bar that shows the loss of the network. Doesn't that make training a model easier?

If you want to see the code in action, click the links below. The first link is for the standard PyTorch Lightning way and the second is for the custom way.

PyTorch Lightning Standard

Standard way

colab.research.google.com

PyTorch Lightning Custom

Custom Way

colab.research.google.com

Conclusion

PyTorch Lightning is developed to a good code standard; it has 229 contributors and is very actively developed. It has even received venture funding since reaching version 0.7. Check the article below if you want more details.

Given this venture funding, I believe PyTorch Lightning will be stable enough to serve as your standard library for writing PyTorch code, without fear that its development will stop in the future.

As for me, I have chosen PyTorch Lightning for my next project. I love its flexibility and the simple, clean way it lets me write deep learning research code.

Okay, that's all for today. Have a nice day, and remember to try it: you have nothing to lose, right?