August 1, 2020

Use Pytorch Lightning with Weights & Biases

Lavanya Shukla, Ayush Chaurasia

PyTorch Lightning is a lightweight wrapper for organizing your PyTorch code and easily adding advanced features such as distributed training and 16-bit precision.

Coupled with Weights & Biases integration, you can quickly train and monitor models for full traceability and reproducibility with only 2 extra lines of code:

from pytorch_lightning.loggers import WandbLogger
wandb_logger = WandbLogger()

In this example, we will optimize simple models on the MNIST dataset.

Try Pytorch Lightning →

🚀 Installing

Pytorch-lightning and W&B are easily installable via pip.

pip install pytorch-lightning wandb

We just need to import a few PyTorch Lightning modules as well as the WandbLogger, and we are ready to define our model.

🏗️ Defining our model with LightningModule

Research often involves editing boilerplate code with new experimental variations. Most errors get introduced into the codebase during this tinkering process. PyTorch Lightning significantly reduces the boilerplate by providing a definite code structure for defining and training models.

To create a neural network class in PyTorch, we extend torch.nn.Module. Similarly, when we use PyTorch Lightning, we extend the class pl.LightningModule.

Let’s create the class which we’ll use to train a model for classifying the MNIST dataset. We’ll use the same example as the one in the official documentation in order to compare our results.

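import torch
from pytorch_lightning import LightningModule

class LitMNIST(LightningModule):

   def __init__(self):
       # initialize the parent LightningModule before registering layers
       super().__init__()

       # mnist images are (1, 28, 28) (channels, width, height)
       self.layer_1 = torch.nn.Linear(28 * 28, 128)
       self.layer_2 = torch.nn.Linear(128, 256)
       self.layer_3 = torch.nn.Linear(256, 10)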

In addition, we can also add metrics and save our hyper-parameters.

Full code here →

As you can see above, except for the base class, everything else in the code is pretty much the same as the PyTorch equivalent would be.

We then need to define a few more methods:

   def forward(self, x):
       '''method used for inference input -> output'''

       batch_size, channels, width, height = x.size()

       # (b, 1, 28, 28) -> (b, 1*28*28)
       x = x.view(batch_size, -1)
       x = self.layer_1(x)
       x = F.relu(x)
       x = self.layer_2(x)
       x = F.relu(x)
       x = self.layer_3(x)

       x = F.log_softmax(x, dim=1)
       return x

   def training_step(self, batch, batch_idx):
       '''needs to return a loss from a single batch'''
       x, y = batch
       logits = self(x)
       loss = F.nll_loss(logits, y)

       # Log training loss
       self.log('train_loss', loss)

       # Log metrics
       self.log('train_acc', self.accuracy(logits, y))

       return loss
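
To make the loss in training_step more concrete, here is a small plain-Python sketch of what the log_softmax and nll_loss pair computes. It does not use torch; the helper names log_softmax and nll_loss below are hypothetical re-implementations for illustration, and the logits and targets are made-up values, not MNIST outputs.

```python
import math

def log_softmax(logits):
    """Log-probabilities for one example, computed in a numerically stable way."""
    m = max(logits)
    log_sum = m + math.log(sum(math.exp(v - m) for v in logits))
    return [v - log_sum for v in logits]

def nll_loss(batch_log_probs, targets):
    """Mean negative log-probability of the correct class over the batch."""
    picked = [-row[t] for row, t in zip(batch_log_probs, targets)]
    return sum(picked) / len(picked)

# Two made-up examples with three classes each (MNIST would have ten).
logits = [[2.0, 1.0, 0.1], [0.5, 2.5, 0.3]]
targets = [0, 1]

log_probs = [log_softmax(row) for row in logits]
loss = nll_loss(log_probs, targets)
print(loss)
```

Each row of log_probs exponentiates to a probability distribution, and the loss shrinks as the probability assigned to the correct class grows.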

The definition of optimizers is the same as in PyTorch but needs to be done through configure_optimizers.

   def configure_optimizers(self):
       '''defines model optimizer'''
       # self.lr is assumed to be set in __init__ (the __init__ shown above is abbreviated)
       return Adam(self.parameters(), lr=self.lr)

Finally we can add a validation_step and a test_step for logging losses and metrics.

📊 Data loading

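Data pipelines can be created with plain PyTorch DataLoaders or with PyTorch Lightning DataModules.

DataModules are a more structured definition, which allows for additional optimizations such as automated distribution of workload between CPU & GPU. Using DataModules is recommended whenever possible!

A DataModule is also defined by an interface: prepare_data, setup, train_dataloader, val_dataloader and test_dataloader.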

Here’s how to do this in code.

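from pytorch_lightning import LightningDataModule
from torch.utils.data import DataLoader, random_split
from torchvision import transforms
from torchvision.datasets import MNIST

class MNISTDataModule(LightningDataModule):

   def __init__(self, data_dir='./', batch_size=256):
       super().__init__()
       self.data_dir = data_dir
       self.batch_size = batch_size
       self.transform = transforms.ToTensor()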

   def prepare_data(self):
       '''called only once and on 1 GPU'''
       # download data
       MNIST(self.data_dir, train=True, download=True)
       MNIST(self.data_dir, train=False, download=True)

   def setup(self, stage=None):
       '''called on each GPU separately - stage defines if we are at fit or test step'''
       # we set up only relevant datasets when stage is specified (automatically set by Pytorch-Lightning)
       if stage == 'fit' or stage is None:
           mnist_train = MNIST(self.data_dir, train=True, transform=self.transform)
           self.mnist_train, self.mnist_val = random_split(mnist_train, [55000, 5000])
       if stage == 'test' or stage is None:
           self.mnist_test = MNIST(self.data_dir, train=False, transform=self.transform)

   def train_dataloader(self):
       '''returns training dataloader'''
       mnist_train = DataLoader(self.mnist_train, batch_size=self.batch_size)
       return mnist_train

   def val_dataloader(self):
       '''returns validation dataloader'''
       mnist_val = DataLoader(self.mnist_val, batch_size=self.batch_size)
       return mnist_val

   def test_dataloader(self):
       '''returns test dataloader'''
       mnist_test = DataLoader(self.mnist_test, batch_size=self.batch_size)
       return mnist_test

👟 Training the model

If we consider a traditional PyTorch training pipeline, we need to implement the loop over epochs, iterate over the mini-batches, perform the forward pass for each mini-batch, compute the loss, backpropagate for each batch, and finally update the weights.

In PyTorch Lightning, we have already pulled the main elements of the training logic and data loading into the LightningModule and DataModule defined above.

Using these methods, PyTorch Lightning automates the training part of the pipeline. We’ll get to that, but first let’s see how PyTorch Lightning integrates with Weights & Biases to track experiments and create visualizations you can monitor from anywhere.

Figure: train_loss and val_loss for the sgd-64-0.01 run.


Track Pytorch Lightning Model Performance with WandB

Let’s see how the WandbLogger integrates with Lightning.

from pytorch_lightning.loggers import WandbLogger
wandb_logger = WandbLogger(name='Adam-32-0.001',project='pytorchlightning')

Here, we’ve created a WandbLogger object which holds the details about the project and the run being logged.

Training Loop

Now, let’s jump into the most important part of training any model, the training loop. As we are using Pytorch Lightning, most of the logic is already captured behind the scenes. We just need to specify a few hyper-parameters and the training process will be completed automatically using a Trainer. As an added benefit, you’ll also get a cool progress bar for each iteration.


# setup data
mnist = MNISTDataModule()

# setup model - choose different hyperparameters per experiment
model = LitMNIST(n_layer_1=128, n_layer_2=256, lr=1e-3)

# define a Trainer
trainer = Trainer(
   logger=wandb_logger,    # W&B integration
   gpus=-1,                # use all GPUs
   max_epochs=3            # number of epochs
)

The important part of this code for visualization is where the WandbLogger object is passed as the logger to the PyTorch Lightning Trainer. The Trainer then uses it automatically to log the results.

def train():
   trainer.fit(model, datamodule=mnist)  # fit the model on the data module set up above

This is all you need to do in order to train your PyTorch model using PyTorch Lightning. This one line of code will easily replace your bulky and inefficient vanilla PyTorch code.

Visualizing Performance with Weights & Biases

Let’s have a look at the visualizations generated for this run.

Train loss and validation loss for the particular run are automatically logged in the dashboard in real-time as the model is being trained.

We can repeat the same training step with different hyper-parameters to compare different runs. We’ll change the name of the logger to uniquely identify each run.


wandb_logger = WandbLogger(name='Adam-32-0.001',project='pytorchlightning')
wandb_logger = WandbLogger(name='Adam-64-0.01',project='pytorchlightning')
wandb_logger = WandbLogger(name='sgd-64-0.01',project='pytorchlightning')

Here I’ve used a convention to name the runs. The first part is the optimizer, the second is the mini-batch size and third is the learning rate. For example the name ‘Adam-32-0.001’ means the optimizer being used is Adam with batch size of 32 and the learning rate is 0.001.
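
As an illustration of this naming convention, the sketch below builds run names from a small hyperparameter grid and parses a name back into its parts. The helpers run_name and parse_run_name are hypothetical and not part of wandb or Lightning; the grid values mirror the run names used in this post.

```python
from itertools import product

def run_name(optimizer, batch_size, lr):
    """Build a name of the form <optimizer>-<batch size>-<learning rate>."""
    return f"{optimizer}-{batch_size}-{lr}"

def parse_run_name(name):
    """Split a run name back into its three hyperparameters."""
    # Note: a learning rate printed in scientific notation (e.g. 1e-05)
    # contains a '-' and would break this simple split.
    optimizer, batch_size, lr = name.split("-")
    return {"optimizer": optimizer, "batch_size": int(batch_size), "lr": float(lr)}

optimizers = ["Adam", "sgd"]
batch_sizes = [32, 64]
learning_rates = [0.001, 0.01]

# One name per combination in the grid.
names = [run_name(o, b, lr) for o, b, lr in product(optimizers, batch_sizes, learning_rates)]
print(names)
print(parse_run_name("Adam-32-0.001"))
```

Each generated name could then be passed as the name argument of WandbLogger, one logger per run.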

You can see how each model is performing in the plots above.

These visualizations are stored forever in your project, which makes it much easier to compare the performance of variations with different hyperparameters, restore the best performing model and share results with your team.


Figure (two panels): comparison of the runs Baseline, sgd-64-0.01, sgd-32-0.01, sgd-64-0.001, sgd-32-0.001, Adam-64-0.01, Adam-64-0.001, Adam-32-0.01 and Adam-32-0.001.


PyTorch Lightning provides 2 methods to incorporate early stopping. Here’s how you can use them:


# A) Set early_stop_callback to True. Will look for 'val_loss'
# in validation_end() return dict. If it is not found an error is raised.
trainer = Trainer(early_stop_callback=True)

# B) Or configure your own callback
early_stop_callback = EarlyStopping(monitor='val_loss')  # minimal form; other arguments keep their defaults
trainer = Trainer(early_stop_callback=early_stop_callback)

As we’ve defined a validation function, we can directly set early_stop_callback=True:

trainer = pl.Trainer(max_epochs = 5,logger= wandb_logger, gpus=1, distributed_backend='dp',early_stop_callback=True)
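
To illustrate the idea behind early stopping, here is a plain-Python sketch of a patience-based rule on the validation loss. It is not Lightning's EarlyStopping implementation; the class PatienceEarlyStopping, the helper epochs_run and the loss values are hypothetical and only show the general technique.

```python
class PatienceEarlyStopping:
    """Signal a stop when the loss has not improved for `patience` checks in a row."""

    def __init__(self, patience=3, min_delta=0.0):
        self.patience = patience
        self.min_delta = min_delta
        self.best = float("inf")
        self.bad_checks = 0

    def should_stop(self, val_loss):
        # An improvement must beat the best value so far by more than min_delta.
        if val_loss < self.best - self.min_delta:
            self.best = val_loss
            self.bad_checks = 0
        else:
            self.bad_checks += 1
        return self.bad_checks >= self.patience


def epochs_run(val_losses, patience=3):
    """Return the number of epochs completed before training stops."""
    stopper = PatienceEarlyStopping(patience=patience)
    for epoch, loss in enumerate(val_losses, start=1):
        if stopper.should_stop(loss):
            return epoch
    return len(val_losses)


# Made-up validation losses: the loss stalls after the third epoch.
val_losses = [0.9, 0.7, 0.6, 0.61, 0.62, 0.65, 0.5]
print(epochs_run(val_losses, patience=3))
```

A larger patience trades longer training for a lower chance of stopping on a temporary plateau.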


Figure (two panels): comparison of the runs EarlyStopping-Adam-0.001-32, EarlyStopping and Baseline.


Depending on the requirements of a project, you might need to increase or decrease the precision of the weights of a model. Reducing precision allows you to fit bigger models into your GPU. Let’s see how we can incorporate 16-bit precision in pytorch lightning.
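
As a rough back-of-the-envelope illustration of why precision matters for memory, the sketch below counts the parameters of the three-layer MNIST model defined earlier and compares the storage needed for its weights at 32-bit and 16-bit precision. This is a simplification: it counts weight storage only, and in practice mixed-precision training may keep 32-bit copies of the weights, with much of the saving coming from activations. The helper linear_params is hypothetical.

```python
def linear_params(n_in, n_out):
    """Weights plus biases of one fully connected layer."""
    return n_in * n_out + n_out

# Layer shapes of the LitMNIST model: 784 -> 128 -> 256 -> 10
layer_sizes = [(28 * 28, 128), (128, 256), (256, 10)]
n_params = sum(linear_params(n_in, n_out) for n_in, n_out in layer_sizes)

BYTES_FP32 = 4  # 32-bit float
BYTES_FP16 = 2  # 16-bit float

fp32_bytes = n_params * BYTES_FP32
fp16_bytes = n_params * BYTES_FP16
print(n_params, fp32_bytes, fp16_bytes)
```

For a model this small the difference is negligible, but the same halving applies to models with billions of parameters.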

First, we need to install NVIDIA Apex. To do that, we’ll create a shell script in Colab and execute it.



git clone
pip install -v --no-cache-dir ./apex

You’ll need to restart the runtime after installing apex.

Now we can directly pass in the required value in the precision parameter of the trainer.

trainer = pl.Trainer(max_epochs = 100,logger= wandb_logger, gpus=1, distributed_backend='dp',early_stop_callback=True, amp_level='O1',precision=16)


Figure (three panels): the 16-bit-adam-0.001 run compared with Baseline.


Lightning provides a simple API for performing data parallelism and multi-gpu training. You don’t need to use torch’s data parallelism class in the sampler. You just need to specify the parallelism mode and the number of GPUs you wish to use.

Lightning supports several distributed training modes. We’ll use the data parallel backend ('dp') in this post. Here’s how we can incorporate it in the existing code.

trainer = pl.Trainer(gpus=1, distributed_backend='dp',max_epochs = 5,logger= wandb_logger)

Here I’m using only 1 GPU as I’m working on Google Colab.

As you use more GPUs, you can monitor the difference in memory usage between different configurations in wandb, as in the plots below.


Figure: the MultiGPU run compared with Baseline, including GPU 0 utilization in % (system/gpu.0.gpu) and process memory usage (system/proc.memory.percent).


Often during research, you’ll need to train a model in intervals. This brings up the need to stop the training, save the state, load the saved state later and then resume the training where we stopped.

Being able to save and restore models also allows you to collaborate more effectively with your team and return to experiments from a few weeks ago.

To save pytorch lightning models with wandb, we use:


This creates a checkpoint file in the local runtime, and uploads it to wandb. Now, when we decide to resume training even on a different system, we can simply load the checkpoint file from wandb and load it into our program like so:


Now the checkpoint has been loaded into the model and training can be resumed using the desired training module.

Now that we’ve seen the simple framework that Lightning provides, let’s have a quick look at how it compares with PyTorch. In Lightning, we can train the model with automatic callbacks as well as progress bars by just creating a Trainer and calling its fit() method.

Let’s see how the same can be achieved using vanilla PyTorch.


pytorch_model = MNISTClassifier()
optimizer = torch.optim.Adam(pytorch_model.parameters(), lr=1e-3)

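# ----------------
# loss function
# ----------------
def cross_entropy_loss(logits, labels):
 return F.nll_loss(logits, labels)

# ----------------
# training loop
# ----------------
num_epochs = 1
for epoch in range(num_epochs):

 # iterate over the training mini-batches
 for train_batch in mnist_train:
   x, y = train_batch

   logits = pytorch_model(x)
   loss = cross_entropy_loss(logits, y)
   print('train loss: ', loss.item())

   # backpropagate, update the weights, then reset the gradients
   loss.backward()
   optimizer.step()
   optimizer.zero_grad()

 # validation loop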
 with torch.no_grad():
   val_loss = []
   for val_batch in mnist_val:
     x, y = val_batch
     logits = pytorch_model(x)
     val_loss.append(cross_entropy_loss(logits, y).item())

   val_loss = torch.mean(torch.tensor(val_loss))
   print('val_loss: ', val_loss.item())

You can see how complicated the training code can get and we haven’t even included the modifications to incorporate multi GPU training, early stopping or tracking performance with wandb yet.

For adding distributed training in Pytorch, we need to use DistributedSampler for sampling our dataset.

def train_dataloader(self):
   dataset = MNIST(...)
   sampler = None

   if self.on_tpu:
       sampler = DistributedSampler(dataset)

   return DataLoader(dataset, sampler=sampler)

You’ll also need to write a custom function to incorporate early stopping. But when using Lightning, all of this can be accomplished with one line of code.

#Pytorch Lightning
trainer = pl.Trainer(max_epochs = 5,logger= wandb_logger, gpus=1, distributed_backend='dp',early_stop_callback=True)

That’s all for this post.

Give Pytorch Lightning a try →.

If you have any questions about integrating Lightning with Weights & Biases, we'd love to answer them in our Slack community.