August 12, 2020

Pytorch Lightning Machine Learning Zero To Hero In 75 Lines Of Code

Sandro Luck (DGuyAI)

Pytorch Lightning is taking the world by storm. Don’t miss out on these 75 lines of code that kick-start your machine learning road to mastery. We will cover Early Stopping, Auto Batch Scaling, Auto Learning Rate finding, Dynamic Batch Sizes, Datasets in Pytorch, Saving your Model, and Visualization. All in under 75 lines.

[Figure: Our Road to Mastery, Image by Author]

The Basics

Pytorch Lightning has features that will drastically reduce your development time. Once you understand the basics, your efficiency will increase magically. Under the hood, Lightning is still Pytorch, but way easier and faster to work with. You do not have to worry about most things. Sit back, relax, and let Lightning do its work.

NOTE: I will show the CPU-based code in this article, but will add a GPU-based version at the bottom (the differences are minor)

As we can see, we simply import all the cool features and we are ready to go. We will look at each of them once we use them.
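
The original import cell is not reproduced here; a plausible reconstruction, based on what the article uses later, would be something like this:

```python
import torch
from torch import nn
from torch.utils.data import Dataset, DataLoader

import pytorch_lightning as pl
from pytorch_lightning import Trainer, seed_everything
from pytorch_lightning.callbacks import EarlyStopping
```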

1. The Dataset

We will now define our own dataset; Lightning handles it from there. We could also load one from the internet, but knowing how to get our own data into the model is key!

Nine lines for your dataset, and this is basically still plain Pytorch. A dataset is a class from PyTorch that allows us to leverage many benefits. All we have to do is make our class inherit from “Dataset”. A dataset class minimally needs these three methods.

Today we will build a very simple model that will learn to predict whether the sum of all input elements is >0 or <0. This allows us to learn about Lightning with minimal clutter around our code.

__init__: Here we initialize the dataset; we basically just create a random tensor of shape (samples, dimension). You can imagine this as a table with 42 rows and 21000 columns, where each cell is between -1 and 1.

__getitem__: Pytorch needs a way to retrieve the item at “index”; here we basically just return the “index-th” item from our dataset. We then generate a label, which is 1 if the sum of the input elements is >0 and 0 otherwise. Finally, we return a tuple (features, label) so Lightning can do its job.

__len__: Pytorch wants to know how many samples we have in the dataset. self.dataset.size() returns the dimensions of our table (42, 21000), and [0] takes the 42 out of that tuple; this is our length.
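
Based on the three method descriptions above, a minimal sketch of such a dataset might look like the following (the class name RandomSumDataset and the default sizes are my own choices, not necessarily those of the original code):

```python
import torch
from torch.utils.data import Dataset


class RandomSumDataset(Dataset):
    """Random features in [-1, 1); the label is 1 if the row sums to > 0, else 0."""

    def __init__(self, samples=42, dimension=21000):
        # A (samples x dimension) table of random values between -1 and 1.
        self.dataset = torch.rand(samples, dimension) * 2 - 1

    def __getitem__(self, index):
        features = self.dataset[index]
        label = 1.0 if features.sum() > 0 else 0.0
        return features, torch.tensor([label])

    def __len__(self):
        # size() returns (samples, dimension); [0] picks out the number of rows.
        return self.dataset.size()[0]
```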

2. The Model

Let the games begin. We will now define a really simple but powerful model! A model in Lightning needs to inherit from “LightningModule”. For a Lightning model to work we need to define at least the __init__(), forward(), train_dataloader(), and training_step() functions. We will even define a few more to get into some shiny features that are brighter than Zeus himself.

__init__(): Here we basically just define a model of Linear layers with 2048 nodes each, which we will use later in the forward step. Additionally, we define self.lr and self.batch_size so we can use Lightning’s auto-learning features later on. Yes, you heard right: we will learn them while training!

forward(): This is the heart and soul of our model. Here we define how the information flows through our model. We basically just call our layers in order and do a simple reshape at the end to keep things consistent. The input here is one of the rows of our spreadsheet; we throw it through our pipeline, and in the end the output will be between 0 and 1.

configure_optimizers(): We define an Adam optimizer; this is the thing that helps us learn. We make the learning rate tuneable so that we can learn that one too.

train_dataloader(): This function has to return a data loader. A data loader is an object that helps Pytorch feed the training samples into the model; it handles the batch size for us and saves tons of code. We simply throw the dataset we defined in 1. into it and say we will generate 43210 samples. Ah, life can sometimes be simpler than being bored in quarantine.

The next four functions are basically a bit of repetition work for us. We define a loss function so that Adam knows what to optimize. The validation data loader is the same thing we just programmed; it simply gives us a good comparison set so we can implement early stopping later. The rest is simply aggregating statistics for our super cool visualization.
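
Putting these pieces together, the LightningModule might look roughly like the sketch below. It is a hedged reconstruction, not the original code: the layer sizes, the binary-cross-entropy loss, the validation set size, and the dict-style logging follow the descriptions above and the Lightning API from around the time this article was written; the name SumClassifier is my own. It reuses the RandomSumDataset sketched in section 1.

```python
import torch
from torch import nn
from torch.utils.data import DataLoader
import pytorch_lightning as pl


class SumClassifier(pl.LightningModule):
    def __init__(self):
        super().__init__()
        # Attributes the auto-tuning features will look for later on.
        self.lr = 1e-3
        self.batch_size = 32
        # Linear layers with 2048 nodes each, as described above.
        self.layer_1 = nn.Linear(21000, 2048)
        self.layer_2 = nn.Linear(2048, 2048)
        self.layer_3 = nn.Linear(2048, 1)
        self.loss_fn = nn.BCELoss()  # assumption: binary cross-entropy as the loss

    def forward(self, x):
        # One spreadsheet row goes in, a value between 0 and 1 comes out.
        x = torch.relu(self.layer_1(x))
        x = torch.relu(self.layer_2(x))
        return torch.sigmoid(self.layer_3(x)).view(-1, 1)

    def configure_optimizers(self):
        # self.lr stays tuneable so the learning-rate finder can overwrite it.
        return torch.optim.Adam(self.parameters(), lr=self.lr)

    def train_dataloader(self):
        # 43210 training samples, as mentioned in the article.
        return DataLoader(RandomSumDataset(samples=43210), batch_size=self.batch_size)

    def val_dataloader(self):
        # The size of the comparison set is my own choice.
        return DataLoader(RandomSumDataset(samples=4321), batch_size=self.batch_size)

    def training_step(self, batch, batch_idx):
        x, y = batch
        loss = self.loss_fn(self(x), y)
        return {"loss": loss, "log": {"train_loss": loss}}

    def validation_step(self, batch, batch_idx):
        x, y = batch
        return {"val_loss": self.loss_fn(self(x), y)}
```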

3. Auto Learn Your Learning Rate

Congratulations, the last piece of code will be your reward. We will define a simple main function to start our killer pipeline.

Let’s bring this together! First, we use the cool function seed_everything(42); this will greatly help us make our runs more reproducible. While it doesn’t make Machine Learning completely predictable, it definitely helps to eliminate most randomness.

In this example we will use our CPU, so we don’t have to spin up our machines. But running it with a GPU is just as easy; I will leave the GPU code at the bottom (4 lines are different)

Early stopping is a great mechanism to avoid overfitting: with this easy call we can monitor the validation loss. As soon as the training loss and the validation loss diverge, the model will stop training. We can then assume that we neither overfit nor underfit, two of the main problems in machine learning.

After we cast the model to our CPU, we can start up our Trainer! This is basically all; we can now start to learn and fit our model. After this, we simply save our model and we are done. Let’s get into the two big time-saving features; I love those two more than my beloved cat, and they probably save me a full day of work per week.
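
A sketch of what such a main function might look like is below. Exact Trainer arguments vary between Lightning versions (the 0.x releases, for example, took the early-stopping callback as early_stop_callback rather than via callbacks), so treat this as a hedged reconstruction rather than the article’s exact code.

```python
from pytorch_lightning import Trainer, seed_everything
from pytorch_lightning.callbacks import EarlyStopping
import torch


def main():
    seed_everything(42)               # makes the run (mostly) reproducible
    model = SumClassifier()           # the LightningModule sketched above

    early_stopping = EarlyStopping(monitor="val_loss")

    trainer = Trainer(
        max_epochs=100,               # an upper bound; early stopping usually ends training sooner
        callbacks=[early_stopping],   # older Lightning versions: early_stop_callback=early_stopping
        # the GPU version would simply add something like gpus=1 here
    )
    trainer.fit(model)
    torch.save(model.state_dict(), "model.pt")  # one simple way to save the trained model


if __name__ == "__main__":
    main()
```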

[Figure: Training process]

Early Stopping

As we can see, early stopping kicks in at around 21 epochs; this is the point where further training no longer reduces the validation loss. The Trainer simply stops there and keeps that model.

Auto Learn Your Learning Rate

This is probably one of the most useful Machine Learning features ever discovered. No more tweaking the learning rate, no more grid searching, power at the tip of your fingers. So how does it work? We simply turn on the flag “auto_lr_find=True” in our Trainer and we are ready to go. What this basically does is test several learning rates and check which one is king. It does so by measuring how much the loss decreases on a selection of batches. For more details, I can recommend the excellent paper “Cyclical Learning Rates for Training Neural Networks”, or for a simpler, hands-on explanation, https://github.com/davidtvs/pytorch-lr-finder walks through a very similar process in a pretty straightforward way.
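
In code, enabling the finder is a one-flag change to the Trainer from the main function above. How the search is triggered depends on the Lightning version, as noted in the comments of this sketch:

```python
# Enable the learning-rate finder with a single Trainer flag.
trainer = Trainer(max_epochs=100, auto_lr_find=True)

# In more recent Lightning versions the search is triggered explicitly and
# overwrites model.lr with the suggested value; in the versions from around
# this article's time it ran as part of fit().
# trainer.tune(model)

trainer.fit(model)
```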

[Figure: Learning rate is learned]

4. Auto Learn Your Batch Size

Another killer feature that can really save a lot of thinking capacity. It helped me a lot in figuring out which GPU can handle how many batches when working on a cluster with several different GPU types; a godsend of an invention. To turn it on we simply set “auto_scale_batch_size=True”.

DANGER WARNING: When you run the auto batch-size and auto learning-rate finders together, these processes will take ages (but it will work). So first find the ideal batch size, then use the learning rate finder.

[Figure: Learning the batch size]

Auto batch scaling automatically tries to find the largest batch size that fits into memory, before any training. While we won't hit the memory ceiling with this simple example, once you run your serious models this will converge really fast.
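
As with the learning-rate finder, this is a single Trainer flag; heeding the warning above, the sketch below enables only the batch-size search (version-dependent behavior is noted in the comments):

```python
# Enable the batch-size finder; it looks for a model.batch_size attribute
# and replaces it with the largest value that still fits into memory.
trainer = Trainer(max_epochs=100, auto_scale_batch_size=True)

# As with the learning-rate finder, newer Lightning versions trigger the
# search via trainer.tune(model); older ones ran it right before training.
trainer.fit(model)
```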

5. Visualization

One more thing. What we coded in these 75 lines is enough for a simple visualization using TensorBoard. After you run the code, a folder named lightning_logs will be created and populated with the results of our run. All we have to do is run the tensorboard command with the logdir flag and we are good to analyze what we did!

tensorboard --logdir=lightning_logs/

As we can see, we recorded both the validation loss and the training loss. We did this simply by returning them inside the validation_step/training_step functions. The plots are made with the batches on the X-axis; we could change that by returning only the values aggregated over epochs inside the validation_epoch_end function.
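
With the dict-style API used by the Lightning versions from around when this article was written, such an epoch-level aggregation could look roughly like this (a sketch; newer versions log via self.log instead):

```python
def validation_epoch_end(self, outputs):
    # Average the per-batch validation losses so TensorBoard gets one point per epoch.
    avg_val_loss = torch.stack([o["val_loss"] for o in outputs]).mean()
    return {"val_loss": avg_val_loss, "log": {"val_loss": avg_val_loss}}
```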

Conclusion

We saw how easy life can be. Pytorch Lightning is probably one of the lowest-effort-per-feature modules out there. While we did not cover all features, we did look at the most important ones from my perspective. A few other honorable mentions are half-precision training, multi-GPU support, and various logging and dataset-downscaling features.

You now have the skills and knowledge to train your first model, carried by the speed of lightning under your wings. Make sure to stay tuned and follow me for more productivity and performance hacks for your Machine Learning models and pipelines.

If you enjoyed this article, I would be excited to connect on Twitter or LinkedIn.

Make sure to check out my YouTube channel, where I will be publishing new videos every week.


Full Code