December 1, 2019

Converting From Keras To PyTorch Lightning

William Falcon
Photo By: Nicole Crank

In this tutorial, we’ll convert a Keras model into a PyTorch Lightning model to add another capability to your deep-learning ninja skills.

Keras provides a terrific high-level interface to TensorFlow. Now Keras users can try out PyTorch via a similar high-level interface called PyTorch Lightning.

However, Lightning differs from Keras: it is not so much a framework as a style guide for PyTorch. It gives users (researchers, students, production teams) ultimate flexibility to try crazy ideas, without having to learn yet another framework, while automating away all the engineering details.

Lightning differs from Keras in that:

  1. It is not a framework but more of a style guide.
  2. Lightning does not hide away details of network and optimization design.
  3. Lightning automates away all the engineering like early stopping, multi-GPU training, multi-node training, etc…
  4. Lightning gives ultimate flexibility to researchers and production teams who might need to do things like negative sampling across multiple GPUs or set up their own distributed environments.

Why learn PyTorch Lightning?

It’s useful to be able to work with both TensorFlow and PyTorch depending on what your team’s needs are. Users of Lightning love the flexibility of the framework. They can keep things very simple or modify the training behavior down to how the backward step is done.

This balance of simplicity and fine-grained control (for those who need it) makes Lightning a unique framework that makes prototyping and productionizing with PyTorch very easy.

In fact, it even makes deep learning research more reproducible (check out the NeurIPS reproducibility challenge)!

Hello MNIST in Keras


To see the core of how to convert a project, let’s look at the canonical MNIST Keras example.
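For reference, here is a minimal sketch of that example. Random tensors stand in for the `mnist.load_data()` download so the snippet stays self-contained, and the layer sizes are illustrative:

```python
import numpy as np
from tensorflow import keras
from tensorflow.keras import layers

# Section 2: data loading (random tensors stand in for mnist.load_data()
# so this sketch runs without downloading anything)
x_train = np.random.rand(256, 28 * 28).astype("float32")
y_train = keras.utils.to_categorical(np.random.randint(0, 10, 256), 10)

# Section 3: model definition
model = keras.Sequential([
    layers.Dense(128, activation="relu", input_shape=(28 * 28,)),
    layers.Dropout(0.2),
    layers.Dense(10, activation="softmax"),
])

# Section 4: training
model.compile(loss="categorical_crossentropy", optimizer="adam",
              metrics=["accuracy"])
model.fit(x_train, y_train, batch_size=32, epochs=1, verbose=0)

# Section 5: testing (evaluating on the training set only for brevity)
loss, acc = model.evaluate(x_train, y_train, verbose=0)
```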

This code can be broken down into a few sections.

Section 1: Imports

Section 2: Data Loading

Section 3: Model definition

Section 4: Training

Section 5: Testing

Each section above has an equivalent in PyTorch Lightning. But in Lightning, the code is organized into a specific interface where each function takes care of one of those sections.

LightningModule


Every research idea is implemented into a different LightningModule. Here is the same Keras example as a LightningModule.

LightningModule - PyTorch-Lightning 0.6.0 documentation (pytorch-lightning.readthedocs.io): a LightningModule is a strict superclass of torch.nn.Module but provides an interface to standardize the "ingredients"…

Although it looks more verbose on the surface, the added lines give you deeper control over what’s happening. Notice the following:

  1. Through the beauty of PyTorch, you can now debug the flow of the model in the forward step.

In fact, you can even try experimenting in real-time by changing the input in the debugger. In this case we just want to see what happens to the dimensions after the second layer, so in the debugger we feed a tensor with the dimensions we want to test.
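Outside the debugger, the same probe looks like this; the layer size and batch size here are made up for illustration:

```python
import torch
from torch import nn

# At a breakpoint inside forward(), you can feed any tensor you like.
# Here we probe the second layer with a hypothetical batch of 2.
layer_2 = nn.Linear(128, 10)
probe = torch.rand(2, 128)
out = layer_2(probe)
print(out.shape)  # torch.Size([2, 10])
```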


2. What happens in the training step, validation step, and test step is decoupled. For instance, we can calculate the accuracy in the validation step but not the training step.

If we were doing something like machine translation, we could do a beam search in the validation step to generate a sample.

3. The dataloading is abstracted nicely behind the dataloaders.

4. The code is standard! If a project uses Lightning, you can see the core of what’s happening by looking in the training step… of any project! That’s an incredible step towards helping achieve more reproducible deep learning code.

5. This code is pure PyTorch! There’s no abstraction on top… this means you can get as crazy as you need with your code.

In summary, the LightningModule groups the core ingredients we need to build a deep learning system:

  1. The computations (init, forward).
  2. What happens in the training loop (training_step).
  3. What happens in the validation loop (validation_step).
  4. What happens in the testing loop (test_step).
  5. The optimizer(s) to use (configure_optimizers).
  6. The data to use (train, test, val dataloaders).

Trainer

Notice that the LightningModule has nothing about GPUs or 16-bit precision or early stopping or logging or anything like that… all of that is handled automatically by the trainer.

That’s all it takes to train this model! The trainer handles everything for you including:

  1. Early stopping
  2. Automatic logging to tensorboard (or comet, mlflow, etc…)
  3. Auto checkpointing

If you run the trainer as is, you’ll notice a folder called lightning_logs which you can run a tensorboard session from :)

tensorboard --logdir=./lightning_logs

All of this is free out of the box!

In fact, Lightning adds a nifty text window to show you what parameters you used for this particular experiment. You get this for free as long as you pass the hparams argument to the LightningModule.
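The hparams object is typically a plain argparse Namespace passed into your LightningModule's constructor; the names and values here are illustrative:

```python
from argparse import Namespace

# Hyperparameters collected into a single hparams object. Pass this to
# your LightningModule's __init__ and Lightning logs it with each run.
hparams = Namespace(learning_rate=1e-3, batch_size=32, hidden_dim=128)
print(hparams.learning_rate)  # 0.001
```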


Now each run of your model knows exactly what hyperparameters it used!

GPU Training

Lightning makes GPU and multi-GPU training trivial. For instance, if you want to train the above example on multiple GPUs just add the following flags to the trainer:
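With the 0.6.0-era API (flag names have changed in later versions), that looks like:

```python
from pytorch_lightning import Trainer

# Requires a machine with 4 GPUs; Lightning handles device placement.
trainer = Trainer(gpus=4)
```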

Using the above flags will run this model on 4 GPUs.

If you want to run on say 16 GPUs, where you have 4 machines each with 4 GPUs, change the trainer flags to this:
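Again using the 0.6.0-era flag names as an assumption:

```python
from pytorch_lightning import Trainer

# 16 GPUs total = 4 nodes x 4 GPUs each, with distributed data parallel.
trainer = Trainer(gpus=4, num_nodes=4, distributed_backend='ddp')
```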

And submit the following SLURM job:
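A SLURM submission script along these lines (environment name and script name are placeholders, and resource flags depend on your cluster):

```shell
#!/bin/bash
#SBATCH --nodes=4             # 4 machines
#SBATCH --gres=gpu:4          # 4 GPUs per machine
#SBATCH --ntasks-per-node=4   # one task per GPU
#SBATCH --time=02:00:00

# Activate your environment, then launch the training script on every task.
source activate my_env        # environment name is illustrative
srun python mnist_lightning.py
```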

Advanced Use

Lightning is very easy to use for students and people who are somewhat familiar with deep learning. However, for advanced users like researchers and production teams, Lightning gives you even more control.

For instance, you can do things like:

  1. Gradient clipping: Trainer(gradient_clip_val=2.0)
  2. Accumulated gradients: Trainer(accumulate_grad_batches=12)
  3. 16-bit precision: Trainer(use_amp=True)
  4. Truncated back-propagation through time: Trainer(truncated_bptt_steps=3)

and about 42 more advanced features

Advanced use ++

But maybe you need even MORE flexibility. In this case, you can do things like:

  1. Change how the backward step is done.
  2. Change how 16-bit is initialized.
  3. Add your own way of doing distributed training.
  4. Add Learning rate schedulers.
  5. Use multiple optimizers.
  6. Change the frequency of optimizer updates.

And many, many more things. Under the hood, everything in Lightning is implemented as hooks that can be overridden by the user. This makes EVERY single aspect of training highly configurable, which is exactly the flexibility a research or production team needs.

Summary

This tutorial explained how to convert a Keras model into PyTorch Lightning for users who want to try PyTorch or are looking for more flexibility.

The key things about Lightning are that:

  1. It is not a framework but more of a style-guide for PyTorch. This is because Lightning exposes core PyTorch and does not abstract it away.
  2. By using the Lightning format, you get rid of the engineering complexities and allow your code to be reproducible.
  3. Lightning can be very simple for new users but extremely flexible for even the most advanced research teams.
  4. Although we showed an MNIST example, Lightning can implement arbitrarily complex approaches. Check out this colab!