December 10, 2020

PyTorch Lightning 1.1 - Model Parallelism Training and More Logging Options

PyTorch Lightning team

Lightning 1.1 is now available with some exciting new features. Since the launch of the V1.0.0 stable release, we have hit some incredible milestones: 10K GitHub stars, 350 contributors, and many new members in our Slack community! A few highlights include:

Sharded model training [BETA]

We're thrilled to introduce the beta version of our new sharded model training plugin, in collaboration with FairScale by Facebook. Sharded Training utilizes Data-Parallel Training under the hood, but optimizer states and gradients are sharded across GPUs. This means the memory overhead per GPU is lower, as each GPU only has to maintain a partition of your optimizer state and gradients. You can use this plugin to reduce memory requirements by up to 60% (!) by simply adding a single flag to your Lightning trainer, with no performance loss.

# install fairscale
pip install https://github.com/PyTorchLightning/fairscale/archive/pl_1.1.0.zip

# train using Sharded DDP
trainer = Trainer(gpus=8, accelerator='ddp', plugins='ddp_sharded')

Average peak memory when training a Transformer LM (22 layers, hidden size 3072, trained on SST; 2-billion-parameter variant with 32 layers), SwAV Wide ResNet (trained on STL-10), DeepSpeech2 (trained on LibriSpeech 100), and iGPT (trained on MNIST) using 8 A100s. Same hyper-parameters and batch size per model. We increase model capacity to roughly a billion parameters. Lower is better. Image by Author


To learn more about our new sharded training, read this blog.

Pipeline model sharding [BETA]

This release also includes integration for Sequential Model Parallelism from FairScale. Sequential Model Parallelism allows splitting a sequential module onto multiple GPUs according to the preferred balance, reducing peak GPU memory requirements. Furthermore, Model Parallelism supports micro-batches and memory monger for fitting even larger sequential models.

To use Sequential Model Parallelism, you must define a nn.Sequential module containing the layers you wish to parallelize across GPUs. Keep it in the sequential_module attribute of your LightningModule, as in the sketch below.
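Here is a minimal sketch of what that could look like; the module name SequentialModel and the layer sizes are illustrative, and only the sequential_module attribute is what the plugin expects:

import torch
import torch.nn.functional as F
from torch import nn
import pytorch_lightning as pl

class SequentialModel(pl.LightningModule):
    def __init__(self):
        super().__init__()
        # the layers to be split across GPUs must live in `self.sequential_module`
        self.sequential_module = nn.Sequential(
            nn.Linear(32, 128),
            nn.ReLU(),
            nn.Linear(128, 2),
        )

    def forward(self, x):
        return self.sequential_module(x)

    def training_step(self, batch, batch_idx):
        x, y = batch
        loss = F.cross_entropy(self(x), y)
        self.log('train_loss', loss)
        return loss

    def configure_optimizers(self):
        return torch.optim.Adam(self.parameters(), lr=1e-3)

You would then train this module with a Trainer configured for the sequential DDP plugin, as in the example command below.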

Want to give it a try? We provide a minimal example of Sequential Model Parallelism using a convolutional model trained on CIFAR-10 and split across GPUs here. Simply run:

pip install pytorch-lightning-bolts

python pl_examples/basic_examples/conv_sequential_example.py --batch_size 1024 --gpus 2 --accelerator ddp --use_ddp_sequential

Automatic logging everywhere

In 1.0 we introduced a new, easy way to log any scalar in the training or validation step, using the self.log method. It is now available in all LightningModule and Callback hooks (except the *_batch_start hooks, such as on_train_batch_start or on_validation_batch_start; use on_train_batch_end/on_validation_batch_end instead!).

Depending on where self.log is called from, Lightning auto-determines the correct logging mode for you: it logs after every step in training_step, and logs epoch-accumulated metrics in validation and test steps. Of course, you can override the default behavior by manually setting the log() parameters.

self.log('my_loss', loss, on_step=True, on_epoch=True, prog_bar=True, logger=True)
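For example, here's a minimal sketch (the LitClassifier model and the metric names are illustrative, not from the release) showing where the step-level and epoch-level defaults apply:

import torch
import torch.nn.functional as F
from torch import nn
import pytorch_lightning as pl

class LitClassifier(pl.LightningModule):
    def __init__(self):
        super().__init__()
        self.model = nn.Linear(28 * 28, 10)

    def training_step(self, batch, batch_idx):
        x, y = batch
        loss = F.cross_entropy(self.model(x.view(x.size(0), -1)), y)
        # called from training_step -> logged every step by default
        self.log('train_loss', loss)
        return loss

    def validation_step(self, batch, batch_idx):
        x, y = batch
        loss = F.cross_entropy(self.model(x.view(x.size(0), -1)), y)
        # called from validation_step -> accumulated and logged per epoch by default
        self.log('val_loss', loss, prog_bar=True)

    def configure_optimizers(self):
        return torch.optim.Adam(self.parameters(), lr=1e-3)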

Read more about logging in our docs.

More improvements

Plots: ROC, PrecisionRecallCurve, and AveragePrecision, among the new metrics in this release.
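If you want to try these metrics out, here's a rough sketch of a binary example; the import path and the pos_label argument are assumptions on our part rather than something spelled out here, so check the metrics docs for the exact API:

import torch
from pytorch_lightning.metrics.classification import ROC, PrecisionRecallCurve, AveragePrecision

preds = torch.tensor([0.1, 0.4, 0.35, 0.8])
target = torch.tensor([0, 0, 1, 1])

# false positive rate, true positive rate, and the thresholds that produce them
fpr, tpr, thresholds = ROC(pos_label=1)(preds, target)

# precision/recall pairs across thresholds
precision, recall, thresholds = PrecisionRecallCurve(pos_label=1)(preds, target)

# area under the precision-recall curve
ap = AveragePrecision(pos_label=1)(preds, target)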

We’d like to thank all the hard-working contributors who took part in this release. Kudos! If you want to give back to the community, here’s a list of issues for new contributors you can try to solve.

Let’s meet!

Want to learn more about new features and get inspired by community projects? In our next community meetup we’re introducing Lightning Talks: 5 projects in 5 minutes. Join us on December 17th at 1 PM EST to learn more about the new sharded model training, self-supervised learning for object detection, and how a Kaggle Grandmaster is using Lightning in his projects! RSVP here.

Interested in presenting in our next meetup? Fill this out! It’s a great way to make connections, spread the word about your work, and help your fellow researchers.
