You do the research.
Lightning will do everything else.
The ultimate PyTorch research framework.
Scale your models, without the boilerplate.

Lightning makes coding complex networks simple.

Spend more time on research, less on engineering. Lightning is fully flexible to fit any use case and built on pure PyTorch, so there is no need to learn a new language. A quick refactor gives you, out of the box:
  • Training on any hardware
  • A performance & bottleneck profiler
  • Model checkpointing
  • 16-bit precision
  • Distributed training
  • Logging
  • Metrics
  • Visualization
  • Early stopping
  • ... and many more!
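Many of the features above are a single Trainer argument or callback away. A minimal sketch, using the 1.x Trainer API that the code example below also uses (16-bit precision assumes supported hardware):

```python
import pytorch_lightning as pl
from pytorch_lightning.callbacks import EarlyStopping, ModelCheckpoint

trainer = pl.Trainer(
    precision=16,                           # 16-bit precision
    profiler='simple',                      # performance & bottleneck profiler
    callbacks=[
        EarlyStopping(monitor='val_loss'),  # early stopping
        ModelCheckpoint(monitor='val_loss'),  # model checkpointing
    ],
)
```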
# Lightning Module
import torch
from torch import nn
from torch.nn import functional as F
from torch.utils.data import DataLoader, random_split
from torchvision.datasets import MNIST
from torchvision import transforms
import pytorch_lightning as pl


class LitAutoEncoder(pl.LightningModule):
    def __init__(self):
        super().__init__()
        self.encoder = nn.Sequential(
            nn.Linear(28 * 28, 64),
            nn.ReLU(),
            nn.Linear(64, 3))
        self.decoder = nn.Sequential(
            nn.Linear(3, 64),
            nn.ReLU(),
            nn.Linear(64, 28 * 28))

    def forward(self, x):
        embedding = self.encoder(x)
        return embedding

    def configure_optimizers(self):
        optimizer = torch.optim.Adam(self.parameters(), lr=1e-3)
        return optimizer

    def training_step(self, train_batch, batch_idx):
        x, y = train_batch
        x = x.view(x.size(0), -1)
        z = self.encoder(x)
        x_hat = self.decoder(z)
        loss = F.mse_loss(x_hat, x)
        self.log('train_loss', loss)
        return loss

    def validation_step(self, val_batch, batch_idx):
        x, y = val_batch
        x = x.view(x.size(0), -1)
        z = self.encoder(x)
        x_hat = self.decoder(z)
        loss = F.mse_loss(x_hat, x)
        self.log('val_loss', loss)

# data
dataset = MNIST('', train=True, download=True, transform=transforms.ToTensor())
mnist_train, mnist_val = random_split(dataset, [55000, 5000])
train_loader = DataLoader(mnist_train, batch_size=32)
val_loader = DataLoader(mnist_val, batch_size=32)

# model
model = LitAutoEncoder()

# training
trainer = pl.Trainer(gpus=4, num_nodes=8, precision=16, limit_train_batches=0.5)
trainer.fit(model, train_loader, val_loader)

Ultimate Flexibility for Any Type of Research

TPUs or GPUs, without code changes

Want to train on multiple GPUs? TPUs? Choose your hardware on the fly. Change one Trainer parameter and run!
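As a sketch, switching hardware is one flag on the Trainer (flag names follow the 1.x API used elsewhere on this page); the LightningModule itself is unchanged:

```python
import pytorch_lightning as pl

trainer = pl.Trainer()               # CPU (default)
# trainer = pl.Trainer(gpus=4)       # 4 GPUs
# trainer = pl.Trainer(tpu_cores=8)  # 8 TPU cores
```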


Turn PyTorch into Lightning

Lightning is just plain PyTorch

1. Computational code goes into LightningModule

Model architecture goes into __init__

2. Define the forward hook

In Lightning, forward defines the prediction/inference actions

3. Optimizers go into configure_optimizers LightningModule hook

Pass all optimizers and schedulers
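As a hedged sketch of the ([optimizers], [schedulers]) return shape that configure_optimizers accepts, a plain nn.Module stands in for a LightningModule here; TinyModule and its layer sizes are illustrative, not from the original:

```python
import torch
from torch import nn

class TinyModule(nn.Module):  # stands in for a LightningModule
    def __init__(self):
        super().__init__()
        self.layer = nn.Linear(4, 2)

    def configure_optimizers(self):
        optimizer = torch.optim.Adam(self.parameters(), lr=1e-3)
        # decay the learning rate by 10x every 10 epochs
        scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=10, gamma=0.1)
        # Lightning registers every optimizer and scheduler returned here
        return [optimizer], [scheduler]
```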

4. Training logic into training_step LightningModule hook

Use self.log to send any metric to your preferred logger

5. Validation logic goes to validation_step LightningModule hook

self.log will automatically accumulate and log at the end of the epoch

6. Remove any .cuda() or .to(device) calls

Your LightningModule is hardware agnostic

7. Override LightningModule Hooks as needed

LightningModule has over 20 hooks you can override to keep all the flexibility

8. Init LightningModule

9. Init the Lightning Trainer

The Lightning trainer automates all the engineering (loops, hardware calls, .train(), .eval()…)

10. Pass in any PyTorch DataLoader to trainer.fit

Or you can use the LightningDataModule API for reusability

Train as fast as lightning

You can train on multiple GPUs or TPUs, without changing your model:
  • Train on CPUs
  • Train on GPUs
  • Train on TPUs
PyTorch code

# models
encoder = nn.Sequential(nn.Linear(28 * 28, 64), nn.ReLU(), nn.Linear(64, 3))
decoder = nn.Sequential(nn.Linear(3, 64), nn.ReLU(), nn.Linear(64, 28 * 28))
encoder.cuda(0)
decoder.cuda(0)

# download on rank 0 only
if global_rank == 0:
    mnist_train = MNIST(os.getcwd(), train=True, download=True)

transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize(0.5, 0.5)])
mnist_train = MNIST(os.getcwd(), train=True, download=True, transform=transform)

# train (55,000 images), val split (5,000 images)
mnist_train, mnist_val = random_split(mnist_train, [55000, 5000])

# The dataloaders handle shuffling, batching, etc...
mnist_train = DataLoader(mnist_train, batch_size=64)
mnist_val = DataLoader(mnist_val, batch_size=64)

# optimizer
params = list(encoder.parameters()) + list(decoder.parameters())
optimizer = torch.optim.Adam(params, lr=1e-3)

# TRAIN LOOP
encoder.train()
decoder.train()
num_epochs = 1
for epoch in range(num_epochs):
    for train_batch in mnist_train:
        x, y = train_batch
        x = x.cuda(0)
        x = x.view(x.size(0), -1)
        z = encoder(x)
        x_hat = decoder(z)
        loss = F.mse_loss(x_hat, x)
        print('train loss: ', loss.item())
        loss.backward()
        optimizer.step()
        optimizer.zero_grad()

    # EVAL LOOP
    encoder.eval()
    decoder.eval()
    with torch.no_grad():
        val_loss = []
        for val_batch in mnist_val:
            x, y = val_batch
            x = x.cuda(0)
            x = x.view(x.size(0), -1)
            z = encoder(x)
            x_hat = decoder(z)
            loss = F.mse_loss(x_hat, x)
            val_loss.append(loss)
        val_loss = torch.mean(torch.tensor(val_loss))
    encoder.train()
    decoder.train()
Lightning code

# model
class LitAutoEncoder(pl.LightningModule):
    def __init__(self):
        super().__init__()
        self.encoder = nn.Sequential(nn.Linear(28 * 28, 64), nn.ReLU(), nn.Linear(64, 3))
        self.decoder = nn.Sequential(nn.Linear(3, 64), nn.ReLU(), nn.Linear(64, 28 * 28))

    def forward(self, x):
        embedding = self.encoder(x)
        return embedding

    def configure_optimizers(self):
        optimizer = torch.optim.Adam(self.parameters(), lr=1e-3)
        return optimizer

    def training_step(self, train_batch, batch_idx):
        x, y = train_batch
        x = x.view(x.size(0), -1)
        z = self.encoder(x)
        x_hat = self.decoder(z)
        loss = F.mse_loss(x_hat, x)
        self.log('train_loss', loss)
        return loss

    def validation_step(self, val_batch, batch_idx):
        x, y = val_batch
        x = x.view(x.size(0), -1)
        z = self.encoder(x)
        x_hat = self.decoder(z)
        loss = F.mse_loss(x_hat, x)
        self.log('val_loss', loss)

# train
model = LitAutoEncoder()
trainer = pl.Trainer()
trainer.fit(model, mnist_train, mnist_val)
Lightning code

# train
model = LitAutoEncoder()
trainer = pl.Trainer(tpu_cores=8)  # or gpus=4
trainer.fit(model, mnist_train, mnist_val)

GPU available: True, used: False
TPU available: True, using: 8 TPU cores
training on 8 TPU cores
Epoch 1:
Epoch 2:

Seamlessly train hundreds of models in the cloud from your laptop with Grid.

Use Grid to seamlessly orchestrate training in the cloud and manage artifacts like checkpoints and logs - all from your laptop without changing a line of code.
Learn more about Grid

Join us on Slack!

Our bustling, friendly Slack community has hundreds of experienced deep learning experts of all kinds, and a channel for (almost) everything you can think of, from #ai to #transformers, #questions to #jokes, and everything in between.

Join our Slack
