PyStan: A Basic Tutorial of Bayesian Data Analysis in Python

I’ve been waiting for PyStan for a little while, and it has arrived. The STAN team has posted great examples of the STAN modeling language and very brief examples of how to run PyStan. However, their examples stop just after model fitting. They do not include the full run-through of plotting predictions, traceplots of specific parameters, etc. So I thought I’d do a blog post on basic linear regression in PyStan, detailing how to code the model, fit the model, check the model, and plot the results. The whole shebang. So without further ado:

# module import
import pystan
import numpy as np
import pylab as py
import pandas as pd

## data simulation
x = np.arange(1, 100, 5)
y = 2.5 + .5 * x + np.random.randn(20) * 10

# get number of observations
N = len(x)

# plot the data
py.plot(x,y, 'o')
py.show()

# STAN model (this is the most important part)
regress_code = """
data {
 int<lower = 0> N; // number of observations
 real y[N]; // response variable
 real x[N]; // predictor variable
}
parameters {
 real a; // intercept
 real b; // slope
 real<lower=0> sigma; // standard deviation
}
transformed parameters {
 real mu[N]; // fitted values

 for (i in 1:N)
  mu[i] = a + b*x[i];
}
model {
 y ~ normal(mu, sigma);
}
"""

# make a dictionary containing all data to be passed to STAN
regress_dat = {'x': x,
 'y': y,
 'N': N}

# Fit the model
fit = pystan.stan(model_code=regress_code, data=regress_dat,
 iter=1000, chains=4)

# model summary
print(fit)

# show a traceplot of ALL parameters. This is a bear if you have many
fit.traceplot()
py.show()

# Instead, show a traceplot for single parameter
fit.plot(['a'])
py.show()

##### PREDICTION ####

# make a dataframe of parameter estimates, one row per posterior draw (all chains pooled)
samples = fit.extract(permuted=True)
params = pd.DataFrame({'a': samples['a'], 'b': samples['b']})

# next, make a prediction function. Making a function makes every step following this 10 times easier
def stanPred(p):
 fitted = p['a'] + p['b'] * predX
 return pd.Series({'fitted': fitted})

# make a prediction vector (the values of X for which you want to predict)
predX = np.arange(1, 100)

# get the median parameter estimates
medParam = params.median()
# predict
yhat = stanPred(medParam)

# get the predicted values for each posterior draw. This is super convenient in pandas because
# it is possible to have a single column where each element is a list
chainPreds = params.apply(stanPred, axis = 1)

## PLOTTING

# create a random index for sampling posterior draws
idx = np.random.choice(len(chainPreds), 50)
# plot each sampled draw. chainPreds.iloc[i, 0] gets predicted values from the ith set of parameter estimates
for i in idx:
 py.plot(predX, chainPreds.iloc[i, 0], color='lightgrey')

# original data
py.plot(x, y, 'ko')
# fitted values
py.plot(predX, yhat['fitted'], 'k')

# supplementals
py.xlabel('X')
py.ylabel('Y')

py.show()

This yields the following:

[Figure: stan_chains]
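The random index used for the grey lines depends on how many posterior draws the fit holds. A minimal sketch of that arithmetic, assuming the warmup was left at PyStan 2's default of half the iterations (the names `n_iter`, `warmup`, `n_draws`, and `rng` are illustrative and not part of the code above):

```python
import numpy as np

# Sampler settings used in the call to pystan.stan above.
chains = 4
n_iter = 1000
# Assumption: warmup was left at PyStan 2's default of iter // 2.
warmup = n_iter // 2

# Post-warmup draws pooled across all chains.
n_draws = chains * (n_iter - warmup)

# Pick 50 distinct draws to plot; seeding the generator keeps the
# figure reproducible from run to run.
rng = np.random.default_rng(0)
idx = rng.choice(n_draws, size=50, replace=False)

print(n_draws)  # 2000
```

Deriving the count from the settings (or simply from the number of rows in `params`) avoids a hardcoded upper bound that silently goes stale when `iter` or `chains` changes.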

Instead of showing individual draws (which can make a messy, hard-to-read plot in some cases), we can show a shaded credible interval region:

# make a function that iterates over every predicted value in every posterior draw and returns the quantiles:

def quantileGet(q):
    # make a list to store the quantiles
    quants = []

    # for every predicted value
    for i in range(len(predX)):
        # make a vector to store the predictions from each posterior draw
        val = []

        # next go down the rows and store the values
        for j in range(chainPreds.shape[0]):
            val.append(chainPreds.iloc[j, 0][i])

        # return the quantile for the predictions.
        quants.append(np.percentile(val, q))

    return quants

# NOTE THAT NUMPY DOES PERCENTILES, SO MULTIPLY THE QUANTILE BY 100
# 2.5% quantile
lower = quantileGet(2.5)
# 97.5% quantile
upper = quantileGet(97.5)

# plot this
fig = py.figure()
ax = fig.add_subplot(111)

# shade the credible interval
ax.fill_between(predX, lower, upper, facecolor = 'lightgrey', edgecolor = 'none')
# plot the data
ax.plot(x, y, 'ko')
# plot the fitted line
ax.plot(predX, yhat['fitted'], 'k')

# supplementals
ax.set_xlabel('X')
ax.set_ylabel('Y')
ax.grid()

py.show()

[Figure: stan_quants]
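The same band can also be computed without explicit loops. The sketch below is one possible vectorised alternative, not the method used above: it stands in simulated draws for the real posterior (the values passed to `rng.normal` are made up purely so the example runs without a fitted model), and the names `a_draws`, `b_draws`, `fitted`, and `median_line` are illustrative.

```python
import numpy as np

# Stand-in posterior draws (hypothetical values, used only so the sketch
# runs without a fitted model); in the tutorial these would be
# params['a'] and params['b'].
rng = np.random.default_rng(1)
a_draws = rng.normal(2.5, 1.0, size=2000)
b_draws = rng.normal(0.5, 0.05, size=2000)

predX = np.arange(1, 100)

# Broadcasting builds an (n_draws, n_points) matrix: one row of fitted
# values per posterior draw.
fitted = a_draws[:, None] + b_draws[:, None] * predX

# Percentiles taken down the rows give the band at each value of predX.
lower, upper = np.percentile(fitted, [2.5, 97.5], axis=0)

# Line through the median intercept and median slope.
median_line = np.median(a_draws) + np.median(b_draws) * predX
```

With real draws, `lower` and `upper` can be passed straight to `ax.fill_between(predX, lower, upper, ...)` as in the plotting code above.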

And there you have it! Data analysis in Python just got a whole lot better.


3 thoughts on “PyStan: A Basic Tutorial of Bayesian Data Analysis in Python”

  1. Thank you for this post!
    I had two errors trying to run your code with PyStan. Here are the changes I had to make:

    Line 64 reads:
    params = pd.DataFrame({'a': fit.extract('a', permuted=True), 'b': fit.extract('b', permuted=True)})

    It should read:
    samples = fit.extract(permuted=True)
    params = pd.DataFrame({'a': samples['a'], 'b': samples['b']})

    In line 14 of the second part:
    val.append(chainPreds['fitted'][j][i])

    It should read:
    val.append(chainPreds.iloc[j, 0][i])

    With these 2 changes, your code worked nicely!

  2. Thanks for the good code.

    One comment: the terminology is off. You are using “chains” when you should be saying “runs” or MCMC runs. There aren’t 2,000 chains being created. There are 4 chains, and each chain has 1,000 runs, and there are 2,000 posterior draws (each of length 100).
