I’ve been waiting for PyStan for a little while, and it has arrived. The Stan team has posted great examples of the Stan modeling language and very brief examples of how to run PyStan. However, those examples stop just after model fitting. They do not include the full run-through of plotting predictions, traceplots of specific parameters, etc. So I thought I’d do a blog post on basic linear regression in PyStan, detailing how to code the model, fit the model, check the model, and plot the results. The whole shebang. So without further ado:

```python
# module import
import pystan  # uses the PyStan 2.x interface; PyStan 3 has a different API
import numpy as np
import pylab as py
import pandas as pd

## data simulation
x = np.arange(1, 100, 5)
y = 2.5 + .5 * x + np.random.randn(len(x)) * 10

# get number of observations
N = len(x)

# plot the data
py.plot(x, y, 'o')
py.show()

# Stan model (this is the most important part)
regress_code = """
data {
    int<lower = 0> N;     // number of observations
    real y[N];            // response variable
    real x[N];            // predictor variable
}
parameters {
    real a;               // intercept
    real b;               // slope
    real<lower=0> sigma;  // standard deviation
}
transformed parameters {
    real mu[N];           // fitted values
    for (i in 1:N)
        mu[i] = a + b * x[i];
}
model {
    y ~ normal(mu, sigma);
}
"""

# make a dictionary containing all data to be passed to Stan
regress_dat = {'x': x, 'y': y, 'N': N}

# fit the model
fit = pystan.stan(model_code=regress_code, data=regress_dat,
                  iter=1000, chains=4)

# model summary
print(fit)

# show a traceplot of ALL parameters. This is a bear if you have many
fit.traceplot()
py.show()

# instead, show a traceplot for a single parameter
fit.plot(['a'])
py.show()

##### PREDICTION ####

# make a dataframe of parameter estimates for all posterior draws
# (extract returns a dict of arrays, one entry per parameter)
samples = fit.extract(permuted=True)
params = pd.DataFrame({'a': samples['a'], 'b': samples['b']})

# make a prediction vector (the values of X for which you want to predict)
predX = np.arange(1, 100)

# next, make a prediction function. Making a function makes every step following this 10 times easier
def stanPred(p):
    fitted = p['a'] + p['b'] * predX
    return pd.Series({'fitted': fitted})

# get the median parameter estimates
medParam = params.median()

# predict
yhat = stanPred(medParam)

# get the predicted values for each posterior draw. This is super convenient in pandas because
# it is possible to have a single column where each element is an array
chainPreds = params.apply(stanPred, axis=1)

## PLOTTING

# pick 50 distinct posterior draws at random
idx = np.random.choice(len(params), 50, replace=False)

# plot each sampled draw. chainPreds.iloc[i, 0] gets predicted values from the ith set of parameter estimates
for i in idx:
    py.plot(predX, chainPreds.iloc[i, 0], color='lightgrey')

# original data
py.plot(x, y, 'ko')
# fitted values
py.plot(predX, yhat['fitted'], 'k')

# supplementals
py.xlabel('X')
py.ylabel('Y')
py.show()
```

This yields the following:

Instead of showing individual posterior draws (which can make a messy, hard-to-read plot in some cases), we can show a shaded credible interval region:

```python
# make a function that iterates over every predicted value in every posterior draw and returns the quantiles
def quantileGet(q):
    # make a list to store the quantiles
    quants = []
    # for every predicted value
    for i in range(len(predX)):
        # make a vector to store the predictions from each draw
        val = []
        # next go down the rows and store the values
        for j in range(chainPreds.shape[0]):
            val.append(chainPreds.iloc[j, 0][i])
        # store the quantile for the predictions
        quants.append(np.percentile(val, q))
    return quants

# NOTE THAT NUMPY DOES PERCENTILES, SO MULTIPLY QUANTILE BY 100
# 2.5% quantile
lower = quantileGet(2.5)
# 97.5% quantile
upper = quantileGet(97.5)

# plot this
fig = py.figure()
ax = fig.add_subplot(111)

# shade the credible interval
ax.fill_between(predX, lower, upper, facecolor='lightgrey', edgecolor='none')
# plot the data
ax.plot(x, y, 'ko')
# plot the fitted line
ax.plot(predX, yhat['fitted'], 'k')

# supplementals
ax.set_xlabel('X')
ax.set_ylabel('Y')
ax.grid()
py.show()
```
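The nested loops in `quantileGet` visit every (draw, x) pair in pure Python, which gets slow with many draws. Below is a minimal sketch of a vectorized alternative, assuming the posterior draws of `a` and `b` are available as 1-D NumPy arrays (for example `samples['a']` and `samples['b']` from `fit.extract(permuted=True)`). NumPy broadcasting builds the whole draws-by-x prediction matrix in one step, and `np.percentile(..., axis=0)` takes the percentiles across draws at each x. The function name `credible_band` and the stand-in draws in the usage lines are hypothetical, there only so the snippet runs without Stan.

```python
import numpy as np

def credible_band(a_draws, b_draws, pred_x, lower_q=2.5, upper_q=97.5):
    """Pointwise percentile band for the fitted line a + b * x.

    a_draws, b_draws: 1-D arrays of posterior draws (same length).
    pred_x: 1-D array of x values at which to predict.
    Percentiles are on NumPy's 0-100 scale, as with np.percentile.
    """
    a_draws = np.asarray(a_draws, dtype=float)
    b_draws = np.asarray(b_draws, dtype=float)
    pred_x = np.asarray(pred_x, dtype=float)
    # Broadcasting (n_draws, 1) with (1, n_x) gives an (n_draws, n_x) matrix;
    # row j is the regression line implied by posterior draw j.
    preds = a_draws[:, None] + b_draws[:, None] * pred_x[None, :]
    # Percentiles down each column, i.e. across draws at each x value.
    lower, upper = np.percentile(preds, [lower_q, upper_q], axis=0)
    return lower, upper

# Usage with stand-in draws (hypothetical; with a real fit these would be
# samples['a'] and samples['b'] from fit.extract(permuted=True)).
rng = np.random.default_rng(0)
a_fake = rng.normal(2.5, 1.0, size=2000)
b_fake = rng.normal(0.5, 0.05, size=2000)
predX = np.arange(1, 100)
lower, upper = credible_band(a_fake, b_fake, predX)
print(lower.shape, upper.shape)  # (99,) (99,)
```

The two returned arrays can be passed directly to `ax.fill_between(predX, lower, upper, ...)` in place of the `quantileGet` results.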

And there you have it! Data analysis in Python just got a whole lot better.

Great, man!

Thank you for this post!

I had two errors trying to run your code with PyStan. Here are the changes I had to make:

The line that builds the `params` DataFrame reads:

```python
params = pd.DataFrame({'a': fit.extract('a', permuted=True), 'b': fit.extract('b', permuted=True)})
```

It should read:

```python
samples = fit.extract(permuted=True)
params = pd.DataFrame({'a': samples['a'], 'b': samples['b']})
```

In the second code block (inside `quantileGet`), this line:

```python
val.append(chainPreds['fitted'][j][i])
```

It should read:

```python
val.append(chainPreds.iloc[j, 0][i])
```

With these 2 changes, your code worked nicely!

Thanks for the good code.

One comment: the terminology is off. You are using “chains” when you should be saying “runs” or MCMC runs. There aren’t 2,000 chains being created. There are 4 chains, each chain has 1,000 runs, and there are 2,000 posterior draws (each of length 99).
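The counts in this comment can be checked with a little arithmetic. In PyStan 2, `pystan.stan` uses a warmup of `iter // 2` iterations per chain unless `warmup` is given, and warmup iterations are dropped from `fit.extract()`. So with `iter=1000` and `chains=4`, each chain keeps 500 draws and the total is 2,000. The helper name `n_posterior_draws` below is hypothetical, and the sketch assumes no thinning.

```python
def n_posterior_draws(iter_, chains, warmup=None):
    """Post-warmup draws kept across all chains (assumes no thinning).

    Mirrors PyStan 2's default of warmup = iter // 2 when warmup is not given.
    """
    if warmup is None:
        warmup = iter_ // 2
    # each chain keeps only its post-warmup iterations
    return chains * (iter_ - warmup)

draws = n_posterior_draws(1000, 4)
n_x = len(range(1, 100))  # predX = np.arange(1, 100) has 99 values
print(draws, n_x)  # 2000 99
```

So the prediction matrix in the post has one row per posterior draw (2,000 rows) and one column per prediction point (99 columns).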