September 19, 2020

Introducing PyTorch Forecasting

State-of-the-art forecasting with neural networks made simple

Jan Beitner

I am pleased to announce the open-source Python package PyTorch Forecasting. It makes time series forecasting with neural networks simple both for data science practitioners and researchers.

Why is accurate forecasting so important?

Forecasting time series is important in many contexts and highly relevant to machine learning practitioners. Take, for example, demand forecasting, from which many use cases derive. Almost every manufacturer would benefit from better understanding demand for their products in order to optimise produced quantities. Underproduce and you will lose revenue; overproduce and you will be forced to sell excess produce at a discount. Closely related is pricing, which is essentially a demand forecast with a specific focus on price elasticity. Pricing is relevant to virtually all companies.

For many other machine learning applications, time is of the essence: predictive maintenance, risk scoring, fraud detection, and so on. The order of events and the time between them are crucial for creating a reliable forecast.

In fact, while time series forecasting might not be as shiny as image recognition or language processing, it is more common in industry. This is because image recognition and language processing are relatively new to the field and are often used to power new products, whereas forecasting has been around for decades and sits at the heart of many decision (support) systems. Employing high-accuracy machine learning models such as the ones in PyTorch Forecasting can better support decision making or even automate it, often directly resulting in millions of dollars of additional profit.

Deep learning emerges as a powerful forecasting tool

Deep learning (neural networks) has only recently outperformed traditional methods in time series forecasting, and by a smaller margin than in image and language processing. In fact, in forecasting pure time series (that is, without covariates such as price when forecasting demand), deep learning surpassed traditional statistical methods only two years ago [1]. However, as the field is advancing quickly, the accuracy advantages of neural networks have become significant, which merits their increased use in time series forecasting. For example, the recent N-BEATS architecture achieves an 11% decrease in sMAPE on the M4 competition dataset compared with the next-best non-neural method, an ensemble of statistical methods [2]. This network is also implemented in PyTorch Forecasting.
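For reference, sMAPE (symmetric mean absolute percentage error) as used in the M4 competition averages 2|y - ŷ| / (|y| + |ŷ|) over the forecast horizon and expresses the result as a percentage. A minimal pure-Python sketch of that definition follows; the function name `smape` and the example values are illustrative and not part of PyTorch Forecasting, and defining a term as 0 when both values are 0 is an assumption made here to avoid division by zero.

```python
def smape(actuals, forecasts):
    """Symmetric MAPE in percent, following the M4 definition:
    200 / h * sum(|y - y_hat| / (|y| + |y_hat|)) over the horizon h."""
    if len(actuals) != len(forecasts) or not actuals:
        raise ValueError("actuals and forecasts must be non-empty and equally long")
    total = 0.0
    for y, y_hat in zip(actuals, forecasts):
        denominator = abs(y) + abs(y_hat)
        # assumption: treat the term as 0 when both values are 0
        total += 0.0 if denominator == 0 else abs(y - y_hat) / denominator
    return 200.0 * total / len(actuals)


# a forecast that is off by 10 at every step of a 3-step horizon
print(round(smape([100, 110, 120], [110, 100, 130]), 2))  # 9.02
```

Because the error is divided by the sum of actual and forecast magnitudes, the same absolute error of 10 counts for less on the larger value of 120 than on 100 or 110.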

Furthermore, even compared with other popular machine learning algorithms such as gradient boosted trees, deep learning has two advantages. First, neural network architectures can be designed with an inherent understanding of time, i.e. they automatically make a connection between temporally close data points. As a result, they can capture complex time dependencies. In contrast, traditional machine learning models require time series features, such as the average over the last x days, to be created manually, which limits their ability to model time dependencies. Second, most tree-based models output a step function by design. Therefore, they cannot predict the marginal impact of changes in inputs and, further, are notoriously unreliable in out-of-domain forecasts. For example, if we have observed prices only at 30 EUR and 50 EUR, tree-based models cannot assess the impact on demand of changing the price from 30 EUR to 35 EUR. Consequently, they often cannot be used directly to optimise inputs. Yet this is often the whole point of creating a machine learning model: the value lies in optimising the covariates. Neural networks, on the other hand, employ continuous activation functions and are particularly good at interpolation in high-dimensional spaces, i.e. they can be used to optimise inputs such as price.
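The step-function argument can be made concrete with a toy example. The sketch below uses made-up demand figures observed only at 30 EUR and 50 EUR. It fits a depth-one regression tree (a single split, placed at the midpoint between the two observed prices, a common convention) and, as a simple stand-in for a smooth model, a least-squares straight line. The helper names `fit_stump` and `fit_line` are hypothetical; this illustrates the general behaviour of trees, not any particular library.

```python
# Toy data: demand observed only at prices of 30 EUR and 50 EUR (made-up numbers)
observations = [(30.0, 100.0), (30.0, 104.0), (50.0, 60.0), (50.0, 64.0)]


def fit_stump(points):
    """Depth-1 regression tree for data with exactly two distinct prices."""
    prices = sorted({p for p, _ in points})
    if len(prices) != 2:
        raise ValueError("this toy stump expects exactly two distinct prices")
    threshold = (prices[0] + prices[1]) / 2  # split at the midpoint
    left = [d for p, d in points if p <= threshold]
    right = [d for p, d in points if p > threshold]
    left_mean, right_mean = sum(left) / len(left), sum(right) / len(right)
    # piecewise-constant prediction: one value per side of the split
    return lambda price: left_mean if price <= threshold else right_mean


def fit_line(points):
    """Least-squares straight line, standing in for a smooth model."""
    n = len(points)
    mean_p = sum(p for p, _ in points) / n
    mean_d = sum(d for _, d in points) / n
    slope = sum((p - mean_p) * (d - mean_d) for p, d in points) / sum(
        (p - mean_p) ** 2 for p, _ in points
    )
    return lambda price: mean_d + slope * (price - mean_p)


tree, line = fit_stump(observations), fit_line(observations)
for price in (30.0, 35.0):
    print(price, tree(price), round(line(price), 1))
```

The tree predicts a demand of 102 at both 30 EUR and 35 EUR, so the marginal effect of the price change is zero, while the continuous model predicts a drop from 102 to 92.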

What is PyTorch Forecasting?


PyTorch Forecasting aims to ease time series forecasting with neural networks for real-world cases and research alike. It does so by providing state-of-the-art time series forecasting architectures that can be easily trained with pandas dataframes.

To help you get started, detailed tutorials in the documentation showcase end-to-end workflows. I will also discuss a concrete example later in this article.

Why do we need this package?

PyTorch Forecasting helps overcome important barriers to the usage of deep learning. While deep learning has become dominant in image and language processing, this is less so in time series forecasting. The field remains dominated by traditional statistical methods such as ARIMA and machine learning algorithms such as gradient boosting, with the occasional exception of a Bayesian model. The reasons why deep learning has not yet become mainstream in time series forecasting are twofold, and both can already be overcome:

  1. Training neural networks almost always requires GPUs, which are not always readily available. Hardware requirements are often an important impediment. However, by moving computation into the cloud, this hurdle can be overcome.
  2. Neural networks are comparatively harder to use than traditional methods. This is particularly the case for time series forecasting, where there is a lack of a high-level API that works with popular frameworks such as PyTorch by Facebook or TensorFlow by Google. For traditional machine learning, the scikit-learn ecosystem exists, which provides a standardised interface for practitioners.

This second hurdle is considered crucial in the deep learning community, because the lack of user-friendly tools means that substantial software engineering is required. The following tweet summarises the sentiment of many:

Typical sentiment from a deep learning practitioner

Some even thought the statement was trivial:

In a nutshell, PyTorch Forecasting aims to do what fast.ai has done for image recognition and natural language processing, namely contributing significantly to bringing neural networks from academia into the real world. PyTorch Forecasting seeks to do the equivalent for time series forecasting by providing a high-level API for PyTorch that can directly make use of pandas dataframes. To make it easier to learn, the package, unlike fast.ai, does not create a completely new API but rather builds on the well-established PyTorch and PyTorch Lightning APIs.

How to use PyTorch Forecasting?

This small example showcases the power of the package and its most important abstractions. We will

  1. create a training and validation dataset,
  2. train the Temporal Fusion Transformer [3]. This is an architecture developed by Oxford University and Google that has beaten Amazon’s DeepAR by 36–69% in benchmarks,
  3. inspect results on the validation set and interpret the trained model.

NOTE: The code below works only up to version 0.4.1 of PyTorch Forecasting and 0.9.0 of PyTorch Lightning. Minimal modifications are required to run it with the latest versions. A full tutorial with the latest code can be found in the documentation.

Creating datasets for training and validation

First, we need to transform our time series into a pandas dataframe where each row can be identified by a time step and a time series. Fortunately, most datasets are already in this format. For this tutorial, we will use the Stallion dataset from Kaggle describing sales of various beverages. Our task is to make a six-month forecast of the volume sold per stock keeping unit (SKU), i.e. product, and agency, i.e. store. There are about 21 000 monthly historic sales records. In addition to historic sales, we have information about the sales price, the location of the agency, special days such as holidays, and the volume sold in the entire industry.
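To illustrate the expected layout, here is a tiny made-up dataframe in the same long format, with one row per time series (identified by agency and SKU) and time step. The column names mirror those of the Stallion data, but the values are invented for illustration. A quick check confirms that each (agency, sku, date) combination appears only once.

```python
import pandas as pd

# Hypothetical toy data in "long" format: one row per (agency, sku, month)
toy = pd.DataFrame(
    {
        "agency": ["Agency_01", "Agency_01", "Agency_01", "Agency_02", "Agency_02"],
        "sku": ["SKU_01", "SKU_01", "SKU_02", "SKU_01", "SKU_01"],
        "date": pd.to_datetime(
            ["2013-01-01", "2013-02-01", "2013-01-01", "2013-01-01", "2013-02-01"]
        ),
        "volume": [80.7, 98.1, 0.0, 12.3, 15.0],
    }
)

# every row must be identifiable by its time series (agency, sku) and time step
assert not toy.duplicated(subset=["agency", "sku", "date"]).any()

# number of time steps observed per series
counts = toy.groupby(["agency", "sku"]).size()
print(counts)
```

Series may have different lengths (here one SKU has a single month), which is common in real data.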

import numpy as np
from pytorch_forecasting.data.examples import get_stallion_data

data = get_stallion_data()  # load data as pandas dataframe

The dataset is already in the correct format but lacks some important features. Most importantly, we need to add a time index that is incremented by one for each time step. Further, it is beneficial to add date features, which in this case means extracting the month from the date record.
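The time index in the code that follows is computed as year × 12 + month and then shifted so that the first month in the data gets index 0. Because a year has twelve months, consecutive months always differ by exactly one, including across year boundaries. A quick check on a few made-up dates:

```python
import pandas as pd

# made-up monthly dates spanning a year boundary
dates = pd.Series(pd.to_datetime(["2013-11-01", "2013-12-01", "2014-01-01", "2014-02-01"]))

time_idx = dates.dt.year * 12 + dates.dt.month  # e.g. 2013 * 12 + 11 = 24167
time_idx -= time_idx.min()  # first month becomes 0

print(time_idx.tolist())  # [0, 1, 2, 3]
```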

# add time index
data["time_idx"] = data["date"].dt.year * 12 + data["date"].dt.month
data["time_idx"] -= data["time_idx"].min()

# add additional features
# categories have to be strings
data["month"] = data.date.dt.month.astype(str).astype("category")
data["log_volume"] = np.log(data.volume + 1e-8)
data["avg_volume_by_sku"] = (
   data
   .groupby(["time_idx", "sku"], observed=True)
   .volume.transform("mean")
)
data["avg_volume_by_agency"] = (
   data
   .groupby(["time_idx", "agency"], observed=True)
   .volume.transform("mean")
)

# we want to encode special days as one variable and
# thus need to first reverse one-hot encoding
special_days = [
   "easter_day", "good_friday", "new_year", "christmas",
   "labor_day", "independence_day", "revolution_day_memorial",
   "regional_games", "fifa_u_17_world_cup", "football_gold_cup",
   "beer_capital", "music_fest"
]
data[special_days] = (
   data[special_days]
   .apply(lambda x: x.map({0: "-", 1: x.name}))
   .astype("category")
)

# show sample data
data.sample(10, random_state=521)

Random rows sampled from the dataframe


The next step is to convert the dataframe into a PyTorch Forecasting dataset. Apart from telling the dataset which features are categorical vs continuous and which are static vs varying in time, we also have to decide how we normalise the data. Here, we standard scale each time series separately and indicate that values are always positive.

We also choose to use the last six months as a validation set.
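Conceptually, standard scaling each time series separately means computing a mean and standard deviation per group and scaling the target with them, so that high- and low-volume series end up on a comparable scale. The pandas sketch below shows only that idea on invented numbers. It is not GroupNormalizer's actual implementation, which, per the code comment below, also applies a softplus with beta=1.0 via coerce_positive. The softplus function log(1 + exp(beta * x)) / beta is shown separately as the kind of smooth mapping that keeps values positive; the helper name `softplus` is illustrative.

```python
import math

import pandas as pd

# invented volumes for two series with very different scales
df = pd.DataFrame(
    {
        "sku": ["A", "A", "A", "B", "B", "B"],
        "volume": [10.0, 12.0, 14.0, 1000.0, 1200.0, 1400.0],
    }
)

grouped = df.groupby("sku")["volume"]
center = grouped.transform("mean")
scale = grouped.transform("std")  # sample standard deviation (ddof=1)
df["volume_scaled"] = (df["volume"] - center) / scale

# the transformation is invertible, which is needed to rescale predictions
restored = df["volume_scaled"] * scale + center


def softplus(x, beta=1.0):
    """Smooth, strictly positive mapping: log(1 + exp(beta * x)) / beta."""
    return math.log1p(math.exp(beta * x)) / beta


print(df["volume_scaled"].tolist())  # each series is mapped to [-1, 0, 1]
print(round(softplus(0.0), 4))  # log(2), i.e. 0.6931
```

After scaling, series A and B are indistinguishable in level and spread, so the network does not have to learn the very different magnitudes of individual series.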

from pytorch_forecasting.data import (
   TimeSeriesDataSet,
   GroupNormalizer
)

max_prediction_length = 6  # forecast 6 months
max_encoder_length = 24  # use 24 months of history
training_cutoff = data["time_idx"].max() - max_prediction_length

training = TimeSeriesDataSet(
   data[lambda x: x.time_idx <= training_cutoff],
   time_idx="time_idx",
   target="volume",
   group_ids=["agency", "sku"],
   min_encoder_length=0,  # allow predictions without history
   max_encoder_length=max_encoder_length,
   min_prediction_length=1,
   max_prediction_length=max_prediction_length,
   static_categoricals=["agency", "sku"],
   static_reals=[
       "avg_population_2017",
       "avg_yearly_household_income_2017"
   ],
   time_varying_known_categoricals=["special_days", "month"],
   # group of categorical variables can be treated as
   # one variable
   variable_groups={"special_days": special_days},
   time_varying_known_reals=[
       "time_idx",
       "price_regular",
       "discount_in_percent"
   ],
   time_varying_unknown_categoricals=[],
   time_varying_unknown_reals=[
       "volume",
       "log_volume",
       "industry_volume",
       "soda_volume",
       "avg_max_temp",
       "avg_volume_by_agency",
       "avg_volume_by_sku",
   ],
   target_normalizer=GroupNormalizer(
       groups=["agency", "sku"], coerce_positive=1.0
   ),  # use softplus with beta=1.0 and normalize by group
   add_relative_time_idx=True,  # add as feature
   add_target_scales=True,  # add as feature
   add_encoder_length=True,  # add as feature
)

# create validation set (predict=True) which means to predict the
# last max_prediction_length points in time for each series
validation = TimeSeriesDataSet.from_dataset(
   training, data, predict=True, stop_randomization=True
)

# create dataloaders for model
batch_size = 128
train_dataloader = training.to_dataloader(
   train=True, batch_size=batch_size, num_workers=0
)
val_dataloader = validation.to_dataloader(
   train=False, batch_size=batch_size * 10, num_workers=0
)

Training the Temporal Fusion Transformer

It is now time to create our model. We train the model with PyTorch Lightning. Prior to training, you can identify the optimal learning rate with its learning rate finder (see the documentation for an example).

import pytorch_lightning as pl
from pytorch_lightning.callbacks import (
   EarlyStopping,
   LearningRateLogger
)
from pytorch_lightning.loggers import TensorBoardLogger
from pytorch_forecasting.metrics import QuantileLoss
from pytorch_forecasting.models import TemporalFusionTransformer

# stop training when the loss metric does not improve on the validation set
early_stop_callback = EarlyStopping(
   monitor="val_loss",
   min_delta=1e-4,
   patience=10,
   verbose=False,
   mode="min"
)
lr_logger = LearningRateLogger()  # log the learning rate
logger = TensorBoardLogger("lightning_logs")  # log to tensorboard

# create trainer
trainer = pl.Trainer(
   max_epochs=30,
   gpus=0,  # train on CPU, use gpus = [0] to run on GPU
   gradient_clip_val=0.1,
   early_stop_callback=early_stop_callback,
   limit_train_batches=30,  # 30 batches per epoch, i.e. validate every 30 batches
   # fast_dev_run=True,  # comment in to quickly check for bugs
   callbacks=[lr_logger],
   logger=logger,
)

# initialise model
tft = TemporalFusionTransformer.from_dataset(
   training,
   learning_rate=0.03,
   hidden_size=16,  # most important hyperparameter for network size
   attention_head_size=1,
   dropout=0.1,
   hidden_continuous_size=8,
   output_size=7,  # QuantileLoss has 7 quantiles by default
   loss=QuantileLoss(),
   log_interval=10,  # log example every 10 batches
   reduce_on_plateau_patience=4,  # reduce learning rate automatically
)
tft.size()  # 29.6k parameters in model

# fit network
trainer.fit(
   tft,
   train_dataloader=train_dataloader,
   val_dataloaders=val_dataloader
)
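The model is trained with QuantileLoss, which, according to the code comment, predicts seven quantiles by default. The idea behind a quantile (pinball) loss can be sketched in a few lines: for quantile q, under-prediction is penalised with weight q and over-prediction with weight 1 - q, so minimising the loss pushes each output towards the corresponding quantile of the target. This is a generic illustration with invented values, not the PyTorch Forecasting implementation, and the quantile levels 0.1, 0.5 and 0.9 are an illustrative choice.

```python
def pinball_loss(actual, predicted, quantile):
    """Quantile (pinball) loss for a single observation."""
    error = actual - predicted
    # under-prediction (error > 0) costs quantile * error,
    # over-prediction (error < 0) costs (quantile - 1) * error
    return max(quantile * error, (quantile - 1) * error)


actual = 100.0
for q in (0.1, 0.5, 0.9):
    under = pinball_loss(actual, 90.0, q)  # forecast 10 too low
    over = pinball_loss(actual, 110.0, q)  # forecast 10 too high
    print(q, round(under, 2), round(over, 2))
```

For q = 0.1, over-predicting by 10 costs 9 while under-predicting by 10 costs only 1, so the 0.1 output settles low; q = 0.9 does the reverse, and q = 0.5 treats both errors equally, which corresponds to predicting the median.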

Training takes about three minutes on my MacBook, but for larger networks and datasets it can take hours. During training, we can monitor TensorBoard, which can be started with tensorboard --logdir=lightning_logs. For example, we can monitor example predictions on the training and validation sets. As you can see from the figure below, the forecasts look rather accurate. In case you are wondering, the grey lines denote the amount of attention the model pays to different points in time when making the prediction. This is a special feature of the Temporal Fusion Transformer.

Tensorboard panel showing training examples

Evaluating the trained model

After training, we can evaluate the metrics on the validation dataset and look at a couple of examples to see how well the model is doing. Given that we work with only 21 000 samples, the results are very reassuring and can compete with those of a gradient booster.

import torch
from pytorch_forecasting.metrics import MAE

# load the best model according to the validation loss (given that
# we use early stopping, this is not necessarily the last epoch)
best_model_path = trainer.checkpoint_callback.best_model_path
best_tft = TemporalFusionTransformer.load_from_checkpoint(best_model_path)

# calculate mean absolute error on validation set
actuals = torch.cat([y for x, y in iter(val_dataloader)])
predictions = best_tft.predict(val_dataloader)
MAE()(predictions, actuals)

Looking at the worst performers in terms of sMAPE gives us an idea of where the model has issues forecasting reliably. These examples can provide important pointers on how to improve the model. Plots of actuals vs predictions like these are available for all models.
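The ranking step in the code that follows boils down to averaging each series' per-time-step errors over the forecast horizon and sorting in descending order. A framework-free NumPy sketch of the same idea, using an invented loss matrix in which rows are series and columns are forecast steps:

```python
import numpy as np

# invented per-step sMAPE values: 4 series x 3 forecast steps
losses = np.array(
    [
        [0.10, 0.12, 0.08],
        [0.90, 0.70, 0.80],
        [0.30, 0.20, 0.25],
        [0.05, 0.04, 0.06],
    ]
)

mean_losses = losses.mean(axis=1)  # average over the forecast horizon
worst_first = np.argsort(-mean_losses)  # indices sorted by descending mean loss

print(worst_first[:2].tolist())  # the two worst series: [1, 2]
```

NumPy's `argsort` sorts ascending, so negating the losses is a simple way to get the worst series first; the PyTorch code uses `argsort(descending=True)` for the same purpose.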

from pytorch_forecasting.metrics import SMAPE

# calculate metric by which to display
predictions = best_tft.predict(val_dataloader)
mean_losses = SMAPE(reduction="none")(predictions, actuals).mean(1)
indices = mean_losses.argsort(descending=True)  # sort losses
raw_predictions, x = best_tft.predict(val_dataloader, mode="raw", return_x=True)

# show only two examples for demonstration purposes
for idx in range(2):
   best_tft.plot_prediction(
       x,
       raw_predictions,
       idx=indices[idx],
       add_loss_to_title=SMAPE()
   )

The two worst predictions on the validation set. The white line is how much attention the transformer gives to a given point in time.

Similarly, we could also visualise random examples from our model. Another feature of PyTorch Forecasting is the interpretation of trained models. For example, all models allow us to readily calculate partial dependence plots. However, for brevity, we will show here some of the built-in interpretation capabilities of the Temporal Fusion Transformer, whose network design provides variable importances out of the box.
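For readers unfamiliar with partial dependence: the curve for one input is obtained by setting that input to each value of a grid for all samples, predicting, and averaging the predictions. The sketch below shows this generic computation with a hypothetical stand-in model (`toy_demand_model`) and invented data; it does not use PyTorch Forecasting's own partial dependence API, and the helper `partial_dependence` is illustrative.

```python
import numpy as np

rng = np.random.default_rng(0)

# invented samples with two features: price and temperature
X = np.column_stack([rng.uniform(30, 50, size=200), rng.uniform(10, 30, size=200)])


def toy_demand_model(features):
    """Hypothetical stand-in for a trained model: demand falls with price."""
    price, temperature = features[:, 0], features[:, 1]
    return 200.0 - 2.0 * price + 0.5 * temperature


def partial_dependence(predict, features, column, grid):
    """Average prediction when `column` is set to each grid value for all rows."""
    averages = []
    for value in grid:
        modified = features.copy()
        modified[:, column] = value  # overwrite the feature for every sample
        averages.append(predict(modified).mean())
    return np.array(averages)


price_grid = np.array([30.0, 35.0, 40.0])
curve = partial_dependence(toy_demand_model, X, column=0, grid=price_grid)
print(np.round(np.diff(curve), 6))  # each 5 EUR step lowers average demand by 10
```

Averaging over all samples keeps the other inputs (here temperature) at their observed values, so the curve isolates the average effect of price.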

interpretation = best_tft.interpret_output(
   raw_predictions, reduction="sum"
)
best_tft.plot_interpretation(interpretation)

Unsurprisingly, the past observed volume features as the top variable in the encoder and price related variables are among the top predictors in the decoder. Maybe more interesting is that the agency is ranked only fifth amongst the static variables. However, given that the second and third variable are related to location, we could expect agency to rank far higher if those two were not included in the model.

Summary

It is very easy to train a model and get insights into its inner workings with PyTorch Forecasting. As a practitioner, you can employ the package to train and interpret state-of-the-art models out of the box. Thanks to the PyTorch Lightning integration, training and prediction are scalable. As a researcher, you can leverage the package to get automatic tracking and introspection capabilities for your architecture and apply it seamlessly to multiple datasets.

Code, documentation and how to contribute

The code for this tutorial can be found in this notebook: https://github.com/jdb78/pytorch-forecasting/blob/master/docs/source/tutorials/stallion.ipynb

Install PyTorch Forecasting with

pip install pytorch-forecasting

or

conda install -c conda-forge pytorch-forecasting

GitHub repository: https://github.com/jdb78/pytorch-forecasting

Documentation (including tutorials): https://pytorch-forecasting.readthedocs.io

The package is open source under the MIT Licence which permits commercial use. Contributions are very welcome! Please read the contribution guidelines upfront to ensure your contribution is merged swiftly.

Related Work

Gluon-TS by Amazon aims to provide a similar interface but has two distinct disadvantages compared with PyTorch Forecasting. First, the package’s backend is MXNet, a deep learning framework that trails PyTorch and TensorFlow in popularity. Second, while it is a powerful framework, it can be difficult to master and modify given its complex object inheritance structure and tight coupling of components.

References

[1] S. Smyl, J. Ranganathan and A. Pasqua, M4 Forecasting Competition: Introducing a New Hybrid ES-RNN Model (2018), https://eng.uber.com/m4-forecasting-competition

[2] B. N. Oreshkin et al., N-BEATS: Neural basis expansion analysis for interpretable time series forecasting (2020), International Conference on Learning Representations

[3] B. Lim, S. O. Arik, N. Loeff and T. Pfister, Temporal Fusion Transformers for Interpretable Multi-horizon Time Series Forecasting (2019), arXiv:1912.09363