Guide To Pyro – A Deep Probabilistic Programming Language



Pyro is a state-of-the-art probabilistic programming language (PPL) for deep probabilistic modelling. It is flexible and scalable, and it unifies modern deep learning with Bayesian modelling. It is written in Python and built on top of PyTorch. Uber AI Labs introduced it in 2017, and it is now maintained by a team at the Broad Institute in collaboration with its developer community.

Are you unfamiliar with the term ‘probabilistic programming’? Refer to the ‘probabilistic programming’ section of this article before proceeding!

Before moving on to Pyro’s details, we will briefly discuss its base library, PyTorch, which Pyro uses as its underlying tensor computation engine.

PyTorch – the backbone of Pyro

PyTorch is an open-source machine learning library based on the Torch framework and designed for scientific computing. Developed by Facebook AI Research (FAIR), it is used extensively in applications such as Natural Language Processing (NLP) and Computer Vision (CV). It provides fast, GPU-accelerated tensor math together with automatic differentiation, and it builds its computation graphs, and hence its gradients, dynamically at run time.
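As a small illustration of dynamic graph construction, the hypothetical function f below uses an ordinary Python if statement whose branch depends on the input value; autograd then differentiates whichever branch actually ran on that call.

```python
import torch

def f(x):
    # Ordinary Python control flow: the branch taken depends on the value of x,
    # so a fresh computation graph is recorded on every call
    if x.item() > 0:
        return x ** 2
    return -3 * x

x = torch.tensor(2.0, requires_grad=True)
f(x).backward()   # differentiates the x ** 2 branch: d/dx = 2x
print(x.grad)     # tensor(4.)

x2 = torch.tensor(-1.0, requires_grad=True)
f(x2).backward()  # differentiates the -3 * x branch: d/dx = -3
print(x2.grad)    # tensor(-3.)
```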

Because Pyro models are ordinary Python code running on PyTorch’s dynamic computation graphs, Pyro programs can include stochastic control structure: random choices in a Pyro program can determine which other random choices are made (for example, an if statement or a loop whose behaviour depends on a sampled value). Stochastic control structure makes a PPL universal. Pyro can therefore represent any computable probabilistic model while still providing automatic, optimization-based inference, which is why it is called a ‘universal PPL’.

Key features of Pyro

  • Universality – Pyro can represent any computable probability distribution.
  • Scalability – Pyro scales to large datasets with little overhead compared to hand-written code.
  • Minimality – Pyro is a lightweight, easy-to-maintain library implemented with a small core of powerful, composable abstractions.
  • Flexibility – Both generative models and inference guides in Pyro can include deep neural networks as components. Both can be expressed using high-level abstractions, and inference is easily customizable. Pyro lets users choose between automation and fine-grained control, depending on their needs.

Installation of Pyro

NOTE: Pyro supports Python 3.6 and later.

Pyro can be installed with pip as follows:

pip install pyro-ppl

Practical implementation of Pyro

Problem statement:

Suppose we have a weighing scale that reports the weight of the object placed on it. The scale is not accurate, however, and gives a different measurement each time we weigh the same object. Assume that the scale’s errors are normally distributed around the object’s true weight, with a standard deviation of 0.1 kg. In the implementation below, we model this weighing process and infer the true weight using Pyro.
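Written out formally, this is a simple Bayesian model. In the notation below (ours, not the article’s), w is the object’s true weight, y_i is the i-th reading and ε_i is the scale’s error on that reading; the prior p(w) is the one chosen in step (6).

```latex
\begin{aligned}
y_i &= w + \varepsilon_i, \qquad \varepsilon_i \sim \mathcal{N}(0,\ 0.1^2), \qquad i = 1, \dots, n \\
p(w \mid y_1, \dots, y_n) &\propto p(w) \prod_{i=1}^{n} \mathcal{N}\!\left(y_i \mid w,\ 0.1^2\right)
\end{aligned}
```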

This demonstration uses Python 3.6.9, PyTorch 1.7.0 and Pyro 1.5.2. A step-by-step explanation of the Google Colab code follows:

  1. Import the required libraries and modules

# Import PyTorch

import torch

# Import Pyro

import pyro

# Use assert to make sure the expected Pyro version is installed

assert pyro.__version__.startswith('1.5.2')

# Seed the random number generators for reproducibility (arbitrary example seed)

pyro.set_rng_seed(0)

# The torch.distributions package contains parameterizable probability
# distributions and sampling functions

import torch.distributions as dist

# Import NumPy and pandas for basic data manipulation, and time for time-related functions

import numpy as np
import pandas as pd
import time

# Import seaborn and Matplotlib for visualization

import seaborn as sns
import matplotlib.pyplot as plt

# Display the output of plotting commands inline, i.e. within the notebook frontend
%matplotlib inline

# Import scipy.stats.norm for a normal continuous random variable

from scipy.stats import norm

# Import Pyro's distributions

import pyro.distributions as pdist

  2. Define the method that produces measurement observations

def measure(wt):
    # torch.distributions.normal.Normal() creates a normal (Gaussian) distribution;
    # 'wt' is its mean and 0.1 is its standard deviation
    distribution = dist.Normal(wt, 0.1)
    # Randomly sample the normal distribution using the sample() method
    result = distribution.sample()
    return result
  3. Test the results by placing an object of, say, 0.6 kg on the scale multiple times
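The code for this step is not shown in the article. A minimal sketch, assuming five weighings and an arbitrary seed (both our choices), could look like this; it re-defines measure() from step 2 so that it runs on its own.

```python
import torch
import torch.distributions as dist

def measure(wt):
    # Same scale model as step 2: a reading drawn from Normal(wt, 0.1)
    return dist.Normal(wt, 0.1).sample()

torch.manual_seed(0)  # arbitrary seed, only for reproducibility
readings = [measure(0.6).item() for _ in range(5)]  # five weighings of a 0.6 kg object
print(readings)
```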

Sample output:


Note: The output will vary each time you execute the code.

The output shows that we do not get the same measurement every time, and the readings are rarely exactly 0.6.

  4. Now suppose we do not insist on an exact measurement but want to predict the probability of an observation. For instance, what is the probability that the observed measurement will be above 0.66?

# Monte Carlo estimates: the fraction of simulated measurements above 0.66
rough_measure = np.mean([measure(0.6).item() > 0.66 for i in range(1000)])
print(f'Rough Estimate: {rough_measure}')
reasonable_measure = np.mean([measure(0.6).item() > 0.66 for i in range(10000)])
print(f'Reasonable Estimate: {reasonable_measure}')
good_measure = np.mean([measure(0.6).item() > 0.66 for i in range(100000)])
print(f'Good Estimate: {good_measure}')
# Exact value: 0.6 is the mean and 0.1 the standard deviation of the normal distribution.
# 'cdf' in scipy.stats.norm.cdf stands for cumulative distribution function;
# cdf(x) is the probability that a random sample is less than or equal to x
true_measure = 1.0 - norm(0.6, 0.1).cdf(0.66)
print(f'True Estimate: {true_measure}')

Sample output:

 Rough Estimate: 0.26
 Reasonable Estimate: 0.269
 Good Estimate: 0.27278
 True Estimate: 0.27425311775007344 

Note: The estimates will vary from run to run because measure() draws random samples.

This sampling approach gives satisfactory results, but it is tedious: the estimate only becomes accurate after a large number of simulated measurements.
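To see why so many samples are needed, the sketch below repeats the experiment with Python’s standard library (random.gauss instead of torch) and reports the standard error of each estimate, sqrt(p(1 − p)/n). The standard error shrinks only as 1/sqrt(n), so 100 times more samples buy roughly one more correct decimal place. The helper names normal_sf and mc_estimate are hypothetical.

```python
import math
import random

def normal_sf(x, mean, sd):
    # P(Y > x) for Y ~ Normal(mean, sd), computed with the error function
    return 0.5 * (1.0 - math.erf((x - mean) / (sd * math.sqrt(2.0))))

def mc_estimate(n, wt=0.6, sd=0.1, threshold=0.66, seed=0):
    # Fraction of n simulated scale readings that exceed the threshold
    rng = random.Random(seed)
    hits = sum(rng.gauss(wt, sd) > threshold for _ in range(n))
    return hits / n

p_true = normal_sf(0.66, 0.6, 0.1)
for n in (1_000, 10_000, 100_000):
    est = mc_estimate(n)
    # Standard error of a proportion estimated from n independent samples
    se = math.sqrt(p_true * (1.0 - p_true) / n)
    print(f"n={n:>7}  estimate={est:.4f}  std. error={se:.4f}")
print(f"exact={p_true:.4f}")
```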

  5. Now suppose we have to handle more complex queries, where we cannot simply evaluate a known normal distribution. For instance, suppose we have the following observations of an object:

0.77, 0.88, 0.67, 0.77, 0.82, 0.71

The task is to find the most likely true weight of the object, given the above observations from our scale.

This is the kind of question where probabilistic programming comes into the picture. We will use Pyro to handle the query as follows:


First, store your observations in a torch tensor (a multi-dimensional array whose elements share a common data type).

NOTE: Pyro works on torch tensors, so observed data must be passed as tensors.

 results = torch.tensor([0.77, 0.88, 0.67, 0.77, 0.82, 0.71])
 print(f'Mean = {torch.mean(results)}') 

Output: Mean = 0.7699999809265137

  6. Define the Pyro model

We are attempting to find a distribution over possible weight values, given the observations we recorded. First, the function builds a prior distribution that takes the mean computed in step (5) as the object’s probable weight: a normal distribution called ‘prior_wt’ with mean ~0.769 and a standard deviation of 1.0.

def my_model(results):
    # Prior belief about the weight: Normal(mean=0.769, std=1.0)
    prior_wt = pdist.Normal(0.769, 1.0)
    # Sample a value from prior_wt and name the site 'wt_1'. During inference,
    # Pyro updates the distribution of wt_1 to agree with the observations.
    wt = pyro.sample("wt_1", prior_wt)
    # How our scale measures: readings come from a normal distribution
    # with 'wt' at the centre and 0.1 as the standard deviation
    distribution = pdist.Normal(wt, 0.1)
    # Condition one sample site per observed measurement on the recorded value
    for i, result in enumerate(results):
        pyro.sample(f'observation_{i}', distribution, obs=result)
  7. Infer the actual weight of the object for which we got the observations.

We use the Hamiltonian Monte Carlo (HMC) algorithm, which belongs to the Markov Chain Monte Carlo (MCMC) family of algorithms.

# Import the built-in MCMC and HMC classes from pyro.infer

from pyro.infer import MCMC, HMC

# Instantiate the HMC kernel, which samples using the my_model() function defined in step (6)

kernel = HMC(my_model)

# Define the MCMC object that infers the most likely distribution of wt_1

mcmc = MCMC(kernel, num_samples=30000, warmup_steps=150)

# warmup_steps is the number of warmup iterations; samples generated during warmup are discarded.
# num_samples is the number of samples kept after warmup, so the chain runs for 30,150 iterations in total.

Note: The warmup phase (also called the burn-in phase) in MCMC refers to the initial iterations of the chain, during which it moves toward the high-probability region of the target distribution. Pyro’s HMC also uses this phase to tune its step size and mass matrix.

# Pass our observations to the model and run MCMC

mcmc.run(results)

  8. The samples drawn for wt_1 are stored as a torch tensor. Convert them into a NumPy array so that they can be plotted as a histogram.
# get_samples() returns a dict of tensors keyed by sample-site name;
# the numpy() method converts a torch tensor into an ndarray object
wt_samples = mcmc.get_samples()['wt_1'].numpy()
  9. Plot the samples
plt.figure(figsize=(15, 5))  # Size of the plot
# histplot replaces distplot, which is deprecated since seaborn 0.11
sns.histplot(mcmc.get_samples()['wt_1'].numpy(), kde=False)
plt.xlabel("Weight of object in kg")  # X-axis label
plt.ylabel("No. of observed samples")  # Y-axis label
plt.show()  # Show the plot

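Sample output: [histogram of the posterior samples of wt_1; x-axis: Weight of object in kg, y-axis: No. of observed samples]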

  10. Read the predicted most likely weight of the object from the model’s summary.
# 'prob' sets the probability mass of the reported credible interval (here 95%)
mcmc.summary(prob=0.95)

Sample output:

           Mean   std   median   2.5%   97.5%   n_eff       r_hat
 wt_1      0.77   0.04   0.77    0.69   0.85    27563.94    1.00
 Number of divergences: 0 

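The output shows that both the mean and the median of the posterior distribution are 0.77, so the object’s actual weight is most likely very close to 0.77 kg. The 2.5% and 97.5% columns give a 95% credible interval: given the model and the observations, there is a 95% probability that the object’s weight lies between 0.69 kg and 0.85 kg. r_hat is a convergence diagnostic for the Markov chain; values close to 1.00 indicate that the chain has converged. n_eff is the effective sample size, i.e. the number of independent samples that the correlated MCMC draws are equivalent to.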
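Because this model pairs a normal prior with normal measurement noise of known standard deviation, the posterior is also normal and can be computed in closed form, which gives an independent check on the MCMC summary. The posterior precision (inverse variance) is the prior precision plus n/0.1², and the posterior mean is the precision-weighted average of the prior mean and the sample mean. The helper normal_posterior is a hypothetical name.

```python
import math

def normal_posterior(prior_mean, prior_sd, obs, noise_sd):
    # Normal prior + Normal likelihood with known noise -> Normal posterior.
    # Precisions add; the posterior mean is a precision-weighted average.
    n = len(obs)
    prior_prec = 1.0 / prior_sd ** 2
    data_prec = n / noise_sd ** 2
    post_prec = prior_prec + data_prec
    post_mean = (prior_prec * prior_mean + data_prec * (sum(obs) / n)) / post_prec
    return post_mean, math.sqrt(1.0 / post_prec)

obs = [0.77, 0.88, 0.67, 0.77, 0.82, 0.71]
mean, sd = normal_posterior(0.769, 1.0, obs, 0.1)
lo, hi = mean - 1.96 * sd, mean + 1.96 * sd  # central 95% interval
print(f"mean={mean:.2f} sd={sd:.2f} 95% interval=({lo:.2f}, {hi:.2f})")
```

It prints mean=0.77 sd=0.04 95% interval=(0.69, 0.85), matching the MCMC summary above at two decimal places.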

  • The Google Colab notebook with the above implementation code can be found here.




Copyright Analytics India Magazine Pvt Ltd
