PyTorch Geometric Temporal: What Is It & Your In-Depth Guide

PyTorch Geometric Temporal

PyTorch Geometric Temporal is a temporal extension of the PyTorch Geometric (PyG) framework, which we covered in our previous article. The central idea of this open-source Python library is much the same as PyTorch Geometric's, but applied to temporal data. Like PyG, PyTorch Geometric Temporal is released under the MIT license. It implements many state-of-the-art dynamic and temporal geometric deep learning algorithms for spatio-temporal signals, and it provides an easy-to-use mini-batch data loader, multi-GPU support, benchmark datasets and iterators for dynamic and temporal graphs. The key points of the framework are:

  • It offers discrete-time graph neural networks on both dynamic and static graphs.
  • When time is represented continuously, it allows spatio-temporal learning without the use of discrete snapshots.

The methods included in PyTorch Geometric Temporal fall into three groups:

  • Discrete Recurrent Graph Convolutions
  • Temporal Graph Convolutions
  • Auxiliary Graph Convolutions

The full list of supported algorithms is available in the official documentation.

Without further ado, let’s begin with the code part!

Requirements & Installation

Install all the requirements of PyTorch Geometric Temporal and then install it via PyPI.

  • PyTorch >= 1.4.0

    To check the installed version of PyTorch, run:

!python -c "import torch; print(torch.__version__)"

  • Check the version of CUDA installed with PyTorch.

!python -c "import torch; print(torch.version.cuda)"

  • Install the dependencies:

The commands below are for PyTorch 1.7.0; replace ${CUDA} with the CUDA version you are using (for example cpu, cu101, cu102 or cu110).

 %%bash
 CUDA=cu102  # replace with your CUDA version, e.g. cpu, cu101 or cu110
 pip install torch-scatter==latest+${CUDA} -f https://pytorch-geometric.com/whl/torch-1.7.0.html
 pip install torch-sparse==latest+${CUDA} -f https://pytorch-geometric.com/whl/torch-1.7.0.html
 pip install torch-cluster==latest+${CUDA} -f https://pytorch-geometric.com/whl/torch-1.7.0.html
 pip install torch-spline-conv==latest+${CUDA} -f https://pytorch-geometric.com/whl/torch-1.7.0.html
  • Install PyTorch Geometric:

!pip install torch-geometric

  • Install PyTorch Geometric Temporal itself:

!pip install torch-geometric-temporal
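If you are unsure what to put in place of ${CUDA}, it can be derived from the version string printed by torch.version.cuda. The helpers below are a hypothetical sketch, not part of PyTorch Geometric or PyTorch Geometric Temporal. They assume the cu<major><minor> naming (for example 10.2 becomes cu102) and the 1.7.0 wheel index used in the commands above.

```python
from typing import Optional


# Hypothetical helper (not part of PyG or PyG Temporal): convert the string
# reported by torch.version.cuda, e.g. "10.2" or None on CPU-only builds,
# into the wheel tag used in place of ${CUDA}, e.g. "cu102" or "cpu".
def cuda_tag(cuda_version: Optional[str]) -> str:
    if cuda_version is None:
        return "cpu"
    major, minor = cuda_version.split(".")[:2]
    return f"cu{major}{minor}"


# Hypothetical helper: rebuild one of the pip commands above for a given
# package, assuming the same "==latest+<tag>" pattern and wheel index URL.
def pip_command(package: str, cuda_version: Optional[str],
                torch_version: str = "1.7.0") -> str:
    index_url = f"https://pytorch-geometric.com/whl/torch-{torch_version}.html"
    return f"pip install {package}==latest+{cuda_tag(cuda_version)} -f {index_url}"


for pkg in ["torch-scatter", "torch-sparse", "torch-cluster", "torch-spline-conv"]:
    print(pip_command(pkg, "10.2"))
```

In a notebook you could call cuda_tag(torch.version.cuda) to get the tag for your own installation.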

Data Structures

This section briefly describes the data structures provided by the PyTorch Geometric Temporal framework.

Discrete Dataset Iterators

PyTorch Geometric Temporal provides two types of discrete dataset iterators:

  1. Static Graphs with Discrete Signal: Used for discrete spatio-temporal signals on a static graph. The constructor is StaticGraphDiscreteSignal, which takes the following parameters:
  • edge_index – A NumPy array to hold the edge indices.
  • edge_weight – A NumPy array to hold the edge weights.
  • features – A list of NumPy arrays to hold the vertex features for each time period.
  • targets – A list of NumPy arrays to hold the vertex level targets for each time period.
  2. Dynamic Graphs with Discrete Signal: Used for discrete spatio-temporal signals on a dynamic graph. The constructor is DynamicGraphDiscreteSignal, which takes the following parameters:
  • edge_indices – A list of NumPy arrays to hold the edge indices.
  • edge_weights – A list of NumPy arrays to hold the edge weights.
  • features – A list of NumPy arrays to hold the vertex features for each time period.
  • targets – A list of NumPy arrays to hold the vertex level targets for each time period.

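Temporal Snapshots: A temporal snapshot is a discrete-time Data object that works the same way as the Data object in PyTorch Geometric. Iterating over either dataset iterator yields one snapshot per time period, holding that period's vertex features (x), edges (edge_index), edge weights (edge_attr) and targets (y).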
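To make the two iterators and their snapshots concrete, here is a plain-Python sketch of the idea. The classes ToyStaticGraphSignal, ToyDynamicGraphSignal and Snapshot are hypothetical stand-ins, not the library's implementation; the real iterators yield PyTorch Geometric Data objects, whose x, edge_index, edge_attr and y attributes the training code later in this article reads. The only difference between the two iterators here is where each time period's edges come from.

```python
from dataclasses import dataclass
from typing import Any, List, Tuple


@dataclass
class Snapshot:
    """Plain stand-in for one time period (the library yields PyG Data objects)."""
    x: Any           # vertex features for this time period
    edge_index: Any  # [source_nodes, target_nodes]
    edge_attr: Any   # edge weights
    y: Any           # vertex-level targets for this time period


class _ToySignal:
    """Shared iteration logic: one Snapshot per time period."""

    def __init__(self, features: List[Any], targets: List[Any]):
        if len(features) != len(targets):
            raise ValueError("need one target array per time period")
        self.features = features
        self.targets = targets

    def __len__(self) -> int:
        return len(self.features)

    def _edges_at(self, t: int) -> Tuple[Any, Any]:
        raise NotImplementedError

    def __iter__(self):
        for t in range(len(self)):
            edge_index, edge_weight = self._edges_at(t)
            yield Snapshot(self.features[t], edge_index, edge_weight, self.targets[t])


class ToyStaticGraphSignal(_ToySignal):
    """Hypothetical: the same edges and weights at every time period."""

    def __init__(self, edge_index, edge_weight, features, targets):
        super().__init__(features, targets)
        self.edge_index = edge_index
        self.edge_weight = edge_weight

    def _edges_at(self, t):
        return self.edge_index, self.edge_weight


class ToyDynamicGraphSignal(_ToySignal):
    """Hypothetical: a separate edge list and weight list per time period."""

    def __init__(self, edge_indices, edge_weights, features, targets):
        super().__init__(features, targets)
        if not len(edge_indices) == len(edge_weights) == len(features):
            raise ValueError("need one edge set per time period")
        self.edge_indices = edge_indices
        self.edge_weights = edge_weights

    def _edges_at(self, t):
        return self.edge_indices[t], self.edge_weights[t]


# Two nodes, two time periods, one feature per node.
static = ToyStaticGraphSignal(
    edge_index=[[0, 1], [1, 0]],
    edge_weight=[1.0, 1.0],
    features=[[[0.1], [0.2]], [[0.3], [0.4]]],
    targets=[[1.0, 2.0], [3.0, 4.0]],
)
dynamic = ToyDynamicGraphSignal(
    edge_indices=[[[0], [1]], [[1], [0]]],
    edge_weights=[[1.0], [0.5]],
    features=[[[0.1], [0.2]], [[0.3], [0.4]]],
    targets=[[1.0, 2.0], [3.0, 4.0]],
)
for time, snapshot in enumerate(dynamic):
    print(time, snapshot.edge_index, snapshot.edge_attr, snapshot.y)
```

The loop has the same shape as the training loops later in the article: enumerate the iterator and read each snapshot's attributes.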

Benchmark Datasets 

The PyTorch Geometric Temporal framework provides many benchmark datasets for comparing the performance of GNN algorithms.

Discrete-Time Datasets: The Hungarian Chickenpox dataset, a discrete-time dataset, can be loaded as follows:

 from torch_geometric_temporal.data.dataset import ChickenpoxDatasetLoader
 loader = ChickenpoxDatasetLoader()
 # get_dataset() is a method of the loader; it returns a StaticGraphDiscreteSignal iterator
 dataset = loader.get_dataset()
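The model trained later in this article expects four input features per node. Assuming get_dataset() uses a default of four lags, each snapshot's features are a county's reported cases for the four previous weeks and its target is the count for the following week. The sketch below reproduces that sliding-window construction in plain Python; make_lagged_snapshots is a hypothetical helper, not a function exported by the library.

```python
from typing import List, Tuple


def make_lagged_snapshots(series: List[List[float]],
                          lags: int = 4) -> Tuple[list, list]:
    """Hypothetical sketch of lagged feature construction.

    series[t][n] is the value of node n at time t (e.g. weekly case counts).
    Returns (features, targets), one entry per snapshot:
      features[i][n] = node n's values at times i .. i+lags-1
      targets[i][n]  = node n's value at time i+lags
    """
    num_steps = len(series)
    num_nodes = len(series[0])
    features, targets = [], []
    for i in range(num_steps - lags):
        window = series[i:i + lags]  # lags rows, num_nodes columns
        # transpose the window so each node gets its own row of lagged values
        x = [[window[k][n] for k in range(lags)] for n in range(num_nodes)]
        y = [series[i + lags][n] for n in range(num_nodes)]
        features.append(x)
        targets.append(y)
    return features, targets


# Toy series: 6 weeks, 2 nodes.
weekly = [[1, 10], [2, 20], [3, 30], [4, 40], [5, 50], [6, 60]]
feats, targs = make_lagged_snapshots(weekly, lags=4)
print(len(feats))  # 2 snapshots
print(feats[0])    # [[1, 2, 3, 4], [10, 20, 30, 40]]
print(targs[0])    # [5, 50]
```

With T weeks of data and lags=4, this yields T - 4 snapshots, each with four features per node.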

Train-Test Splitter

The discrete train-test splitter divides a temporal signal into train and test iterators using a fixed ratio. discrete_train_test_split plays the same role as scikit-learn's train_test_split, but it splits along the time axis instead of shuffling: given a StaticGraphDiscreteSignal or a DynamicGraphDiscreteSignal, it returns two iterators, with the first train_ratio share of the snapshots in the training iterator and the rest in the test iterator.

 from torch_geometric_temporal.data.dataset import ChickenpoxDatasetLoader
 from torch_geometric_temporal.data.splitter import discrete_train_test_split
 loader = ChickenpoxDatasetLoader()
 dataset = loader.get_dataset()
 train_dataset, test_dataset = discrete_train_test_split(dataset, train_ratio=0.8) 
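Unlike scikit-learn's train_test_split, which shuffles by default, a temporal split keeps the test snapshots strictly after the training snapshots in time. Below is a minimal sketch of that behaviour on a plain list, assuming the split point is int(train_ratio * number_of_snapshots). chronological_split is a hypothetical name, not the library function.

```python
from typing import List, Sequence, Tuple, TypeVar

T = TypeVar("T")


def chronological_split(snapshots: Sequence[T],
                        train_ratio: float = 0.8) -> Tuple[List[T], List[T]]:
    """Hypothetical stand-in for a temporal train-test split.

    The first train_ratio share of time periods goes to training and the
    remainder to testing. No shuffling is done, so order is preserved.
    """
    if not 0.0 < train_ratio < 1.0:
        raise ValueError("train_ratio must be between 0 and 1")
    n_train = int(train_ratio * len(snapshots))
    return list(snapshots[:n_train]), list(snapshots[n_train:])


weeks = list(range(10))  # stand-ins for 10 consecutive snapshots
train, test = chronological_split(weeks, train_ratio=0.8)
print(train)  # [0, 1, 2, 3, 4, 5, 6, 7]
print(test)   # [8, 9]
```

Keeping the order matters for forecasting: a shuffled split would let the model train on weeks that come after the ones it is tested on.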

Applications of PyTorch Geometric Temporal

This section gives an overview of how PyTorch Geometric Temporal can be applied to a real-world forecasting task.

Learning from a Discrete Temporal Signal

The following demo trains a regressor to predict the weekly number of chickenpox cases reported in each Hungarian county, using the Hungarian Chickenpox dataset.

  1. Load the dataset and split the iterator into a train set and a test set (here train_ratio=0.2, so the first 20% of the snapshots are used for training).
 from torch_geometric_temporal.data.dataset import ChickenpoxDatasetLoader
 from torch_geometric_temporal.data.splitter import discrete_train_test_split
 loader = ChickenpoxDatasetLoader()
 dataset = loader.get_dataset()
 train_dataset, test_dataset = discrete_train_test_split(dataset, train_ratio=0.2) 
  2. Next, create a recurrent graph neural network for supervised learning. The model is built the same way as in PyTorch Geometric, as a torch.nn.Module. The RecurrentGCN constructor creates a DCRNN layer (node_features inputs, 32 hidden channels and a diffusion filter size K=1) and a feedforward linear layer, and a ReLU activation is applied manually to add non-linearity between the recurrent and linear layers.
 import torch
 import torch.nn.functional as F
 from torch_geometric_temporal.nn.recurrent import DCRNN
 class RecurrentGCN(torch.nn.Module):
     def __init__(self, node_features):
         super(RecurrentGCN, self).__init__()
         self.recurrent = DCRNN(node_features, 32, 1)
         self.linear = torch.nn.Linear(32, 1)
     def forward(self, x, edge_index, edge_weight):
         h = self.recurrent(x, edge_index, edge_weight)
         h = F.relu(h)
         h = self.linear(h)
         return h 
  3. Train the model on train_dataset for 200 epochs. In each epoch, the squared-error loss is accumulated over every training snapshot, averaged, and backpropagated once. The Adam optimizer is used with a learning rate of 0.01.
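 from tqdm import tqdm
 model = RecurrentGCN(node_features = 4)  # number of input features per node in each snapshot
 optimizer = torch.optim.Adam(model.parameters(), lr=0.01)
 model.train()
 for epoch in tqdm(range(200)):
     cost = 0
     for time, snapshot in enumerate(train_dataset):
         y_hat = model(snapshot.x, snapshot.edge_index, snapshot.edge_attr)
         # reshape predictions ([N, 1]) to the target's shape so the error is element-wise, not broadcast
         cost = cost + torch.mean((y_hat.view_as(snapshot.y)-snapshot.y)**2)
     cost = cost / (time+1)  # average the loss over all training snapshots
     cost.backward()
     optimizer.step()
     optimizer.zero_grad()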
  4. Finally, evaluate the model on the test dataset and compute the Mean Squared Error (MSE) over all spatial units and time periods.
 model.eval()
 cost = 0
 with torch.no_grad():  # gradients are not needed for evaluation
     for time, snapshot in enumerate(test_dataset):
         y_hat = model(snapshot.x, snapshot.edge_index, snapshot.edge_attr)
         cost = cost + torch.mean((y_hat.view_as(snapshot.y)-snapshot.y)**2)
 cost = cost / (time+1)  # average over test snapshots
 cost = cost.item()
 print("MSE: {:.4f}".format(cost))
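Both loops compute a mean squared error per snapshot (averaged over counties) and then average it over time periods. When every snapshot has the same number of nodes, as with the static Chickenpox graph, this equals the MSE pooled over all node-week pairs. The plain-Python check below uses made-up numbers to illustrate that step; snapshot_mse and temporal_mse are hypothetical names.

```python
def snapshot_mse(pred, target):
    """Mean squared error over the nodes of one snapshot.

    pred and target must be aligned element-wise (one value per node).
    """
    return sum((p - t) ** 2 for p, t in zip(pred, target)) / len(target)


def temporal_mse(preds, targets):
    """Average of per-snapshot MSEs, mirroring `cost / (time+1)` above."""
    costs = [snapshot_mse(p, t) for p, t in zip(preds, targets)]
    return sum(costs) / len(costs)


# Made-up predictions and targets: 2 time periods x 3 nodes.
preds = [[1.0, 2.0, 3.0], [2.0, 2.0, 2.0]]
targets = [[1.0, 1.0, 1.0], [0.0, 2.0, 4.0]]

# MSE pooled over all 6 node-period pairs at once.
pooled = sum((p - t) ** 2
             for ps, ts in zip(preds, targets)
             for p, t in zip(ps, ts)) / 6

print(temporal_mse(preds, targets), pooled)  # both equal 13/6 up to floating-point rounding
```

If snapshots had different numbers of nodes, the two quantities would differ, and the pooled version would weight larger snapshots more heavily.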

Colab Notebook: PyTorch Geometric Temporal Demo

References:

Official code, documentation and tutorials are available at:


Aishwarya Verma
A data science enthusiast and a post-graduate in Big Data Analytics. Creative and organized with an analytical bent of mind.
