In this tutorial, we’ll convert a Keras model into a PyTorch Lightning model to add another capability to your deep-learning ninja skills.
Keras provides a terrific high-level interface to TensorFlow. Now Keras users can try out PyTorch via a similar high-level interface called PyTorch Lightning.
However, Lightning differs from Keras in that it’s not so much a framework as a style guide for PyTorch, one that gives users (researchers, students, production teams) ultimate flexibility to try crazy ideas without having to learn yet another framework, while automating away all the engineering details.
It’s useful to be able to work with both TensorFlow and PyTorch, depending on your team’s needs. Users of Lightning love the framework’s flexibility: they can keep things very simple or modify the training behavior down to how the backward step is done.
This balance of simplicity and utmost control (for those who need it) makes Lightning a unique framework in which both prototyping and productionizing with PyTorch are very easy.
In fact, it even makes deep learning research more reproducible (check out the NeurIPS reproducibility challenge)!
To see the core of how to convert a project, let’s look at the canonical MNIST Keras example.
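The example was originally embedded as a gist; a condensed sketch of it (assuming tf.keras, with illustrative layer sizes) looks like this:

```python
import tensorflow as tf

# Section 3: model definition -- a simple fully-connected classifier
model = tf.keras.Sequential([
    tf.keras.layers.Flatten(input_shape=(28, 28)),
    tf.keras.layers.Dense(128, activation="relu"),
    tf.keras.layers.Dropout(0.2),
    tf.keras.layers.Dense(10, activation="softmax"),
])
model.compile(optimizer="adam",
              loss="sparse_categorical_crossentropy",
              metrics=["accuracy"])


def train_and_test():
    # Section 2: data loading
    (x_train, y_train), (x_test, y_test) = tf.keras.datasets.mnist.load_data()
    x_train, x_test = x_train / 255.0, x_test / 255.0
    # Section 4: training
    model.fit(x_train, y_train, epochs=5)
    # Section 5: testing
    model.evaluate(x_test, y_test)
```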
This code can be broken down into a few sections.
Section 1: Imports
Section 2: Data Loading
Section 3: Model definition
Section 4: Training
Section 5: Testing
Each section above has an equivalent in PyTorch Lightning, but the code is organized into a specific interface where each method takes care of one of those sections.
Every research idea gets implemented in its own LightningModule. Here is the same Keras example as a LightningModule.
Although it looks more verbose on the surface, the added lines give you deeper control over what’s happening. Notice the following:
1. You can experiment in real time by changing the input in the debugger. In this case we just want to see what happens to the dimensions after the second layer, so in the debugger we feed a tensor with the dimensions we want to test.
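A pure-PyTorch sketch of the idea (layer sizes are illustrative): drop into the debugger right before the layers run, then feed any tensor you like and watch the shapes.

```python
import torch
from torch import nn

layer_1 = nn.Linear(28 * 28, 128)
layer_2 = nn.Linear(128, 10)

# Feed a tensor with the dimensions we want to test
x = torch.randn(4, 1, 28, 28)
x = x.view(x.size(0), -1)       # flattens to (4, 784)

# To poke at shapes interactively, uncomment and step through:
# import pdb; pdb.set_trace()
out = layer_2(torch.relu(layer_1(x)))
print(out.size())               # torch.Size([4, 10])
```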
2. What happens in the training step, validation step and test step are decoupled. For instance, we can calculate the accuracy in the validation step but not the training step.
If we were doing something like machine translation, we could do a beam search in the validation step to generate a sample.
3. Data loading is abstracted nicely behind the dataloader methods.
4. The code is standard! If a project uses Lightning, you can see the core of what’s happening by looking in the training step… of any project! That’s an incredible step towards helping achieve more reproducible deep learning code.
5. This code is pure PyTorch! There’s no abstraction on top… this means you can get as crazy as you need with your code.
In summary, the LightningModule groups the core ingredients we need to build a deep learning system:
Notice that the LightningModule contains nothing about GPUs, 16-bit precision, early stopping, logging, or anything like that… all of that is handled automatically by the Trainer.
That’s all it takes to train this model! The Trainer handles everything for you, including the training and validation loops, checkpointing, early stopping, and logging.
If you run the trainer as is, you’ll notice a folder called lightning_logs, from which you can launch a TensorBoard session :)
tensorboard --logdir=./lightning_logs
All of this is free out of the box!
In fact, Lightning adds a nifty text window to show you what parameters you used for this particular experiment. You get this for free as long as you pass the hparams argument to the LightningModule.
Now each run of your model knows exactly what hyperparameters it used!
Lightning makes GPU and multi-GPU training trivial. For instance, if you want to train the above example on multiple GPUs just add the following flags to the trainer:
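Assuming the gpus flag from the API of the time (newer releases spell this devices=4, accelerator="gpu"):

```python
Trainer(gpus=4)
```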
Using the above flags will run this model on 4 GPUs.
If you want to run on say 16 GPUs, where you have 4 machines each with 4 GPUs, change the trainer flags to this:
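With flag names as in later releases (earlier ones used different spellings, and distributed_backend was eventually renamed strategy):

```python
Trainer(gpus=4, num_nodes=4, distributed_backend='ddp')
```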
And submit the following SLURM job:
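A sketch of the matching submission script (file and environment names are hypothetical):

```shell
#!/bin/bash
#SBATCH --nodes=4              # must match the Trainer's node count
#SBATCH --gres=gpu:4           # 4 GPUs per node -> 16 GPUs total
#SBATCH --ntasks-per-node=4    # one task per GPU
#SBATCH --time=02:00:00

# activate your environment (name is hypothetical)
source activate my_env

srun python mnist_lightning.py
```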
Lightning is very easy to use for students and people who are somewhat familiar with deep learning. However, for advanced users like researchers and production teams, Lightning gives you even more control.
For instance, you can do things like:
1. Gradient clipping
Trainer(gradient_clip_val=2.0)
2. Accumulated gradients
Trainer(accumulate_grad_batches=12)
3. 16-bit precision
Trainer(use_amp=True)
4. Truncated back-propagation through time
Trainer(truncated_bptt_steps=3)
and about 42 more advanced features…
But maybe you need even MORE flexibility. In that case, you can override the training internals themselves: how the backward pass runs, how the optimizers step, how the data is loaded, and much more. Under the hood, everything in Lightning is implemented as hooks that can be overridden by the user. This makes EVERY single aspect of training highly configurable, which is exactly the flexibility a research or production team needs.
This tutorial explained how to convert a Keras model into PyTorch Lightning for users who want to try PyTorch or are looking for more flexibility.
The key things about Lightning are that it’s pure PyTorch with no abstraction on top, it standardizes your code so deep learning research is more reproducible, and it automates away the engineering details while leaving every aspect of training configurable.