August 27, 2020

Keeping Up with PyTorch Lightning and Hydra

Peter Yu

How I shrank my training script by 50% using the new features from PyTorch Lightning 0.9 and Hydra 1.0.0rc4


Introduction

Last week, PyTorch Lightning 0.9.0 and Hydra’s fourth release candidate for 1.0.0 were released, chock-full of new features and mostly final APIs. I thought it’d be a good time to revisit my side project, Leela Zero PyTorch, to see how these new versions can be integrated into it. In this post, I’ll talk about some of the new features of the two libraries, and how they helped Leela Zero PyTorch. I won’t go into much detail about Leela Zero PyTorch itself, so if you want more context on my side project, you can read my previous blog post about it here.

PyTorch Lightning 0.9.0

This is a major milestone for the PyTorch Lightning team as they diligently work toward the 1.0.0 release. It introduces a number of new features and an API that is ever closer to the final one. Before we jump in, if you want to read more about this release, check out the official blog post. If you want to learn more about PyTorch Lightning in general, check out the GitHub page as well as the official documentation.

Result

Have you found yourself repeatedly implementing *_epoch_end methods just so that you can aggregate results from your *_step methods? Have you found yourself getting tripped up on how to properly log the metrics calculated in your *_step and *_epoch_end methods? You’re not alone, and PyTorch Lightning 0.9.0 has introduced a new abstraction called Result to solve these very problems.

There are two types of Result: TrainResult and EvalResult. As the names suggest, TrainResult is used for training and EvalResult is used for validation and testing. Their interfaces are simple: you specify the main metric to act on during instantiation (for TrainResult, the metric to minimize; for EvalResult, the metric to checkpoint or early stop on), then you specify additional metrics to log. Let’s take a look at how they’re used in my project:

In training_step(), I specify the overall loss to be minimized, and log the overall loss (which is also specified to be displayed in the progress bar), mean squared error loss, cross entropy loss and finally the accuracy (which is calculated using PyTorch Lightning’s new metrics package, which will be discussed shortly). I don’t need to write the code to aggregate them at the epoch level since TrainResult takes care of that. As a matter of fact, you can specify the level (step, epoch or both) at which each metric should be aggregated and logged with TrainResult, and it will automatically handle everything for you.
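To make the "aggregate at the epoch level" part concrete, here is a minimal, dependency-free sketch of the bookkeeping that a Result object saves you from writing. ResultSketch and its methods are hypothetical names of mine, not PyTorch Lightning's API, and the loss and accuracy values are made up.

```python
from collections import defaultdict
from statistics import mean


class ResultSketch:
    """Toy stand-in for the idea behind Result (not PyTorch Lightning's code)."""

    def __init__(self):
        self.step_values = defaultdict(list)  # metric name -> values logged per step
        self.on_epoch = {}  # metric name -> whether to aggregate at epoch end

    def log(self, name, value, on_epoch=True):
        # Record one step-level value and remember how it should be aggregated.
        self.step_values[name].append(float(value))
        self.on_epoch[name] = on_epoch

    def epoch_end(self):
        # Average every metric that asked for epoch-level aggregation.
        return {
            name: mean(values)
            for name, values in self.step_values.items()
            if self.on_epoch[name]
        }


result = ResultSketch()
for loss, acc in [(0.9, 0.5), (0.7, 0.6), (0.5, 0.7)]:  # three made-up training steps
    result.log("train_loss", loss)
    result.log("train_acc", acc)

epoch_metrics = result.epoch_end()
```

Without something like this, the same averaging has to be repeated by hand in every *_epoch_end method.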

Similarly, in validation_step(), I specify the overall loss to be used for checkpointing, and log the overall loss, mean squared error loss, cross entropy loss and accuracy. Again, I don’t need to write validation_epoch_end(), since aggregation and logging are handled by EvalResult. Furthermore, I don’t need to repeat myself for test_step(): I simply call validation_step() and rename the keys of the metrics to be logged.
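The "call validation_step() and rename the keys" trick can be sketched without any library at all. The function names, the metric names and the values below are assumptions for illustration, not the project's actual code.

```python
def rename_metric_keys(metrics, old_prefix="val_", new_prefix="test_"):
    """Return a copy of `metrics` with `old_prefix` swapped for `new_prefix`."""
    renamed = {}
    for key, value in metrics.items():
        if key.startswith(old_prefix):
            key = new_prefix + key[len(old_prefix):]
        renamed[key] = value
    return renamed


def validation_step_sketch(batch):
    # Made-up metric values; a real step would compute them from the model output.
    return {"val_loss": 0.42, "val_mse_loss": 0.10, "val_ce_loss": 0.32, "val_acc": 0.8}


def evaluate_test_step(batch):
    # Reuse the validation logic and only change how the metrics are named.
    return rename_metric_keys(validation_step_sketch(batch))


renamed_metrics = evaluate_test_step(batch=None)
```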

You can immediately see how my code has become simpler, more readable and more maintainable thanks to Result. You can read more about it here.

Metrics

Continuing their work in 0.8, the PyTorch Lightning team has introduced even more implementations of metrics in 0.9.0. Every metrics implementation in PyTorch Lightning is a PyTorch Module, and has its functional counterpart, making it extremely easy and flexible to use. For my project, I decided to integrate the functional implementation of accuracy, which was just a matter of importing it and calling it in the appropriate *_step methods.
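For readers unfamiliar with what a functional accuracy metric computes, here is a rough plain-Python sketch: take the highest-scoring class per sample and compare it with the target. accuracy_sketch is a hypothetical helper of mine working on lists, whereas the real PyTorch Lightning function operates on tensors.

```python
def accuracy_sketch(logits, targets):
    """Fraction of samples whose highest-scoring class equals the target class."""
    if not targets or len(logits) != len(targets):
        raise ValueError("logits and targets must be non-empty and of equal length")
    correct = 0
    for scores, target in zip(logits, targets):
        predicted = max(range(len(scores)), key=scores.__getitem__)  # argmax
        correct += int(predicted == target)
    return correct / len(targets)


# Four made-up samples with three classes each.
logits = [[0.1, 2.0, -1.0], [1.5, 0.2, 0.3], [0.0, 0.1, 0.9], [0.4, 0.3, 0.2]]
targets = [1, 0, 2, 1]
acc = accuracy_sketch(logits, targets)  # the last sample is misclassified
```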

There are many other metrics implementations included in PyTorch Lightning now, including advanced NLP metrics like the BLEU score. You can read more about it here.

LightningDataModule

Another pain point you may have had with PyTorch Lightning is handling various data sets. Up until 0.9.0, PyTorch Lightning had remained silent on how to organize your data processing code, except that you use PyTorch’s Dataset and DataLoader. This certainly gave you a lot of freedom, but made it hard to keep your data set implementation clean, maintainable and easily sharable with others. In 0.9.0, PyTorch Lightning introduces a new way of organizing data processing code in LightningDataModule, which encapsulates the most common steps in data processing. It has a simple interface with five methods: prepare_data(), setup(), train_dataloader(), val_dataloader() and test_dataloader(). Let’s go over how each of them is implemented in my project to understand its role.
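As a rough, dependency-free illustration of the division of labor between the five methods (this is not the project's code, and ToyDataModule is a plain class rather than a real LightningDataModule), consider the sketch below. A real module would subclass LightningDataModule and return PyTorch DataLoaders instead of lists of batches.

```python
import random


class ToyDataModule:
    """Plain class exposing the same five hooks; the data is made up."""

    def __init__(self, num_samples=100, batch_size=10, seed=0):
        self.num_samples = num_samples
        self.batch_size = batch_size
        self.seed = seed
        self.raw = None
        self.train = self.val = self.test = None

    def prepare_data(self):
        # One-time work such as downloading or parsing; here we just generate numbers.
        rng = random.Random(self.seed)
        self.raw = [rng.random() for _ in range(self.num_samples)]

    def setup(self, stage=None):
        # Split the prepared data 80/10/10 into train, validation and test sets.
        n_train = len(self.raw) * 8 // 10
        n_val = len(self.raw) // 10
        self.train = self.raw[:n_train]
        self.val = self.raw[n_train:n_train + n_val]
        self.test = self.raw[n_train + n_val:]

    def _batches(self, data):
        # Stand-in for a DataLoader: chunk the data into fixed-size batches.
        return [data[i:i + self.batch_size] for i in range(0, len(data), self.batch_size)]

    def train_dataloader(self):
        return self._batches(self.train)

    def val_dataloader(self):
        return self._batches(self.val)

    def test_dataloader(self):
        return self._batches(self.test)


dm = ToyDataModule()
dm.prepare_data()
dm.setup()
```

Because everything data-related sits behind these five methods, swapping in a different data set means swapping in a different class with the same interface.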

Now, it’s just a matter of passing the LightningDataModule into trainer.fit() and trainer.test(). You can also imagine a scenario where I implement another LightningDataModule for a different type of data set such as chess game data, and the trainer will accept it just the same. I can take it further and use Hydra’s object instantiation pattern and easily switch between various data modules.

Hydra 1.0.0rc4

1.0.0rc4 brings Hydra even closer to its official 1.0.0 release. It contains many bug fixes and some crucial API changes that make the library more mature and easier to use. Before we jump in, if you want to learn more about Hydra in general, check out the official website as well as the official documentation!

@hydra.main()

You can add this decorator to any function that accepts OmegaConf’s DictConfig, and Hydra will automatically handle various aspects of your script. This is not a new feature per se, but a feature I originally decided not to use because it takes over the output directory structure as well as the working directory. I actually used Hydra’s experimental Compose API, which I will discuss later, to get around this issue. However, after talking to Omry, the creator of Hydra, I realized that not only is this not the recommended approach, but I also lose a number of cool features provided by Hydra, such as automatic handling of the command line interface, automatic help messages and tab completion. Furthermore, after using it for some time, I’ve found that Hydra’s output directory and working directory management are quite useful, because I do not have to manually set up the logging directory structure on PyTorch Lightning’s side. You can read more about this decorator in Hydra’s basic tutorial.

Package Directive

In Hydra 0.11, there was only one global namespace for the configurations, but in 1.0.0, you can organize your configurations in different namespaces using package directives. This allows you to keep your yaml configuration files flat and clean without unnecessary nesting. Let’s take a look at the network size configuration from Leela Zero PyTorch:

The network size configuration has been added under “network” as specified. Please note that “board_size” and “in_channels” come from the data configuration (composition!)

As you can see, package directives make your configuration more manageable. You can read more about package directives and their more advanced use cases here.
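To illustrate the effect, here is a small Python sketch of what "placing a flat config into a package" amounts to. In Hydra 1.0 the directive itself is a comment at the top of the YAML file (for example `# @package _group_`), and Hydra does the nesting for you. The helper functions, the keys in small_network and all values below are hypothetical.

```python
def place_in_package(config, package):
    """Nest a flat `config` dict under a dotted `package` path such as "network"."""
    nested = dict(config)
    for key in reversed(package.split(".")):
        nested = {key: nested}
    return nested


def merge(base, other):
    """Recursively merge `other` into a copy of `base`."""
    merged = dict(base)
    for key, value in other.items():
        if isinstance(value, dict) and isinstance(merged.get(key), dict):
            merged[key] = merge(merged[key], value)
        else:
            merged[key] = value
    return merged


# The config file itself stays flat ...
small_network = {"residual_channels": 32, "residual_layers": 8}
# ... and only ends up under "network" in the composed configuration.
root = {"train": {"learning_rate": 0.01}}
composed = merge(root, place_in_package(small_network, "network"))
```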

Instantiating Objects

Hydra provides a feature where you can instantiate an object or call a function based on configurations. This is extremely useful when you want your script to have a simple interface to switch between various implementations. This is not a new feature either, but its interface has vastly improved in 1.0.0rc4. In my case, I use it to switch between network sizes, training loggers and datasets. Let’s take the network size configuration as an example.

Instantiate the network based on the selected configuration. Notice that you can pass in additional arguments to instantiate() as I did with cfg.train here.

NetworkLightningModule accepts two arguments for its __init__(), network_conf and train_conf. The former is passed in from the configuration, and the latter is passed in as an extra argument in instantiate() (cfg.train). All you have to do to select different network sizes is to pass in +network={small,big,huge} in the command line. You can even imagine selecting a totally different architecture by creating a new config with a different _target_, and passing in the config name in the command line. No need to pass in all the small details via the command line! You can read more about this pattern here.
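The core of this pattern is small enough to sketch with the standard library: resolve the dotted path in _target_, then call it with the remaining config keys plus any extra arguments. instantiate_sketch is a simplified stand-in of mine, not Hydra's implementation, and I use datetime.timedelta as the target only because it is importable everywhere.

```python
import importlib


def instantiate_sketch(conf, *args, **kwargs):
    """Import the callable named by conf["_target_"] and call it with the other keys."""
    params = {key: value for key, value in conf.items() if key != "_target_"}
    params.update(kwargs)  # call-site keyword arguments override the config
    module_name, _, attr_name = conf["_target_"].rpartition(".")
    target = getattr(importlib.import_module(module_name), attr_name)
    return target(*args, **params)


# Two interchangeable "configs"; picking one by name decides what gets built.
configs = {
    "short": {"_target_": "datetime.timedelta", "minutes": 5},
    "long": {"_target_": "datetime.timedelta", "hours": 2},
}
selected = instantiate_sketch(configs["long"], seconds=30)
```

The calling code never names the class it builds, which is what makes switching implementations from the command line possible.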

Compose API

Although Hydra’s Compose API is not the recommended way for writing scripts, it’s still recommended and useful for writing unit tests. I used it to write unit tests for the main training script. Again, this is not a new feature, but Hydra 1.0.0rc4 does bring in a cleaner interface for the Compose API using Python’s context manager (the with statement).

You can read more about the Compose API here, and how to use it in unit tests here.

Unused Features: Structured Configs and Variable Interpolation

There are many other features in Hydra 1.0.0rc4 I didn’t take advantage of, mostly due to the fact that I haven’t had enough time to integrate them. I’ll go over the biggest one in this section — structured configs.

Structured configs are a major new feature introduced in 1.0.0 that utilizes Python’s dataclasses to provide runtime and static type checking, which can be extremely useful as your application grows in complexity. I’ll probably integrate them in the future when I can find time, so please stay tuned for another blog post!
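To give a flavor of the idea, here is a minimal sketch of a dataclass-based config with a hand-rolled runtime type check. In Hydra the checking is done for you once the dataclass is registered as a config schema. NetworkConfig, its fields and validate() are hypothetical and exist only for illustration.

```python
from dataclasses import dataclass, fields


@dataclass
class NetworkConfig:
    # Hypothetical fields; a static type checker can already verify their usage.
    residual_channels: int = 32
    residual_layers: int = 8
    name: str = "small"


def validate(config):
    """Raise TypeError if any field value does not match its annotated type."""
    for field in fields(config):
        value = getattr(config, field.name)
        if not isinstance(value, field.type):
            raise TypeError(
                f"{field.name} should be {field.type.__name__}, "
                f"got {type(value).__name__}"
            )
    return config


good = validate(NetworkConfig(residual_layers=16))

try:
    validate(NetworkConfig(residual_channels="64"))  # a string where an int belongs
    caught = False
except TypeError:
    caught = True
```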

Conclusion

Since I wrote my first blog post about Leela Zero PyTorch, both Hydra and PyTorch Lightning have introduced a number of new features and abstractions that can help you greatly simplify your PyTorch scripts. As you can see above, my main training script now consists of a mere 28 lines compared to 56 lines before. Moreover, each part of the training pipeline (the neural network architecture, data set and logger) is modular and easily swappable. This enables faster iteration, easier maintenance and better reproducibility, allowing you to focus on the most fun and important parts of your projects. I hope this blog post has been helpful as you “keep up” with these two awesome libraries! You can find the code for Leela Zero PyTorch here.

Written by

Peter Yu

PhD Student at UMich Researching NLP and Cognitive Architectures • Previously Real-time Distributed System Engineer turned NLP Research Engineer at ASAPP

Thanks to the PyTorch Lightning team.