PyStan: A Second Intermediate Tutorial of Bayesian Analysis in Python

I promised a while ago that I’d give a more advanced tutorial on using PyStan and Python to fit a Bayesian hierarchical model. I held off because the paper was in review and then in print. Now it’s out, and I’m super excited! It’s my first pure Python paper, using Python for all data manipulation, analysis, and plotting.

The question was whether temperature affects herbivory by insects in any predictable way. I gathered as many insect species as I could and fed them whatever they ate at multiple temperatures. Check the article for more detail, but the idea was to fit a curve to all 21 herbivore-plant pairs as well as to estimate the overall effect of temperature. We also suspected (incorrectly as it turns out) that plant nutritional quality might be a good predictor of the shape of these curves, so we included that as a group-level predictor.

Anyway, here’s the code, complete with the Stan model, posterior manipulations, and some plotting. First, here’s the actual Stan model. NOTE: a lot of data manipulation and whatnot is missing. The point is not to show that but to describe how to fit a Stan model and work with the output. Anyone who wants the full code and data can find them on my website or in Dryad (see the article for a link).


stanMod = """
data{
    int<lower = 1> N;
    int<lower = 1> J;
    vector[N] Temp;
    vector[N] y;
    int<lower = 1> Curve[N];
    vector[J] pctN;
    vector[J] pctP;
    vector[J] pctH20;
    matrix[3, 3] R;
}

parameters{
    vector[3] beta[J];
    real mu_a;
    real mu_b;
    real mu_c;
    real g_a1;
    real g_a2;
    real g_a3;
    real g_b1;
    real g_b2;
    real g_b3;
    real g_c1;
    real g_c2;
    real g_c3;
    real<lower = 0> sigma[J];
    cov_matrix[3] Tau;
}

transformed parameters{
    vector[N] y_hat;
    vector[N] sd_y;
    vector[3] beta_hat[J];
    // First, get the predicted value as an exponential curve
    // Also make a dummy variable for SD so it can be vectorized
    for (n in 1:N){
        y_hat[n] <- exp( beta[Curve[n], 1] + beta[Curve[n], 2]*Temp[n] + beta[Curve[n], 3]*pow(Temp[n], 2) );
        sd_y[n] <- sigma[Curve[n]];
    }
    // Next, for each group-level coefficient, include the group-level predictors
    for (j in 1:J){
        beta_hat[j, 1] <- mu_a + g_a1*pctN[j] + g_a2*pctP[j] + g_a3*pctH20[j];
        beta_hat[j, 2] <- mu_b + g_b1*pctN[j] + g_b2*pctP[j] + g_b3*pctH20[j];
        beta_hat[j, 3] <- mu_c + g_c1*pctN[j] + g_c2*pctP[j] + g_c3*pctH20[j];
    }
}

model{
    y ~ normal(y_hat, sd_y);
    for (j in 1:J){
        beta[j] ~ multi_normal_prec(beta_hat[j], Tau);
    }
    // PRIORS
    mu_a ~ normal(0, 1);
    mu_b ~ normal(0, 1);
    mu_c ~ normal(0, 1);
    g_a1 ~ normal(0, 1);
    g_a2 ~ normal(0, 1);
    g_a3 ~ normal(0, 1);
    g_b1 ~ normal(0, 1);
    g_b2 ~ normal(0, 1);
    g_b3 ~ normal(0, 1);
    g_c1 ~ normal(0, 1);
    g_c2 ~ normal(0, 1);
    g_c3 ~ normal(0, 1);
    sigma ~ uniform(0, 100);
    Tau ~ wishart(4, R);
}
"""

import pystan

# fit the model! (dat is the data dictionary built in the full script)
fit = pystan.stan(model_code=stanMod, data=dat,
                  iter=10000, chains=4, thin=20)

Not so bad, was it? It’s actually pretty straightforward.
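Since the data preparation is skipped here, a rough sketch of what the `dat` dictionary could look like may help. The keys mirror the names in the `data` block of the model above. Everything else is an assumption for illustration: the helper name `make_stan_data`, the synthetic numbers, standardizing temperature inside the helper, and the identity matrix used for `R`. None of it is the data or the setup from the paper.

```python
import numpy as np

def make_stan_data(temp, y, curve, pctN, pctP, pctH20):
    """Hypothetical helper: build a dict whose keys match the Stan data block."""
    temp = np.asarray(temp, dtype=float)
    y = np.asarray(y, dtype=float)
    curve = np.asarray(curve, dtype=int)  # 1-based curve index for each observation
    J = int(curve.max())
    predictors = {"pctN": pctN, "pctP": pctP, "pctH20": pctH20}
    # each group-level predictor needs exactly one value per curve
    for name, values in predictors.items():
        if len(values) != J:
            raise ValueError(f"{name} needs one value per curve (J = {J})")
    data = {
        "N": len(y),
        "J": J,
        # assumption: temperature is standardized before fitting
        "Temp": (temp - temp.mean()) / temp.std(ddof=1),
        "y": y,
        "Curve": curve,
        # assumption: identity scale matrix for the Wishart prior on Tau
        "R": np.eye(3),
    }
    data.update({k: np.asarray(v, dtype=float) for k, v in predictors.items()})
    return data

# synthetic inputs, only to show the expected shapes
rng = np.random.default_rng(0)
J_demo, n_per_curve = 3, 5
curve = np.repeat(np.arange(1, J_demo + 1), n_per_curve)
temp = np.tile(np.linspace(20.0, 40.0, n_per_curve), J_demo)
y = rng.gamma(2.0, 1.0, size=curve.size)
dat_demo = make_stan_data(temp, y, curve,
                          pctN=rng.normal(size=J_demo),
                          pctP=rng.normal(size=J_demo),
                          pctH20=rng.normal(size=J_demo))
print({k: np.shape(v) for k, v in dat_demo.items()})
```

The length check is there because a mismatch between `J` and the group-level predictors is an easy mistake to make, and it is simpler to catch in Python than to decode from a Stan error message.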

After the model has been run, we work with the output. We can check traceplots of various parameters:

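import matplotlib.pyplot as py  # assumed alias; the post's imports are not shown

# traceplots for the overall means and the group-level slopes
fit.plot(['mu_a', 'mu_b', 'mu_c'])
fit.plot(['g_a1', 'g_a2', 'g_a3'])
fit.plot(['g_b1', 'g_b2', 'g_b3'])
fit.plot(['g_c1', 'g_c2', 'g_c3'])
py.show()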
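Traceplots are a visual check. A numerical companion is the potential scale reduction factor (R-hat), which compares between-chain and within-chain variance. The sketch below computes the classic, non-split Gelman-Rubin version with NumPy on simulated chains. The function name `gelman_rubin` and the simulated draws are illustrative assumptions. It is not how PyStan computes its own diagnostics, and it assumes the draws are already arranged as a (chains, draws) array.

```python
import numpy as np

def gelman_rubin(chains):
    """Classic (non-split) R-hat for an array shaped (n_chains, n_draws)."""
    chains = np.asarray(chains, dtype=float)
    m, n = chains.shape
    chain_means = chains.mean(axis=1)
    B = n * chain_means.var(ddof=1)           # between-chain variance
    W = chains.var(axis=1, ddof=1).mean()     # mean within-chain variance
    var_hat = (n - 1) / n * W + B / n         # pooled variance estimate
    return np.sqrt(var_hat / W)

rng = np.random.default_rng(1)
# four chains sampling the same distribution: R-hat should sit near 1
mixed = rng.normal(0.0, 1.0, size=(4, 500))
# two of the four chains shifted to a different location: R-hat well above 1
stuck = mixed + np.array([[0.0], [0.0], [3.0], [3.0]])

rhat_mixed = gelman_rubin(mixed)
rhat_stuck = gelman_rubin(stuck)
print(round(rhat_mixed, 3), round(rhat_stuck, 3))
```

Values close to 1 suggest the chains agree with each other. Values clearly above 1 are a sign to run longer or rethink the model before trusting the posterior summaries.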

As a brief example, we can extract the overall coefficients and plot them:


import pandas as pd

# pull the posterior draws of the overall coefficients into a DataFrame
mus = fit.extract(['mu_a', 'mu_b', 'mu_c'])
mus = pd.DataFrame({'Intercept' : mus['mu_a'], 'Linear' : mus['mu_b'], 'Quadratic' : mus['mu_c']})

# medians as points, 95% intervals as thin lines, 80% intervals as thick lines
py.plot(mus.median(), range(3), 'ko', ms = 10)
py.hlines(range(3), mus.quantile(0.025), mus.quantile(0.975), 'k')
py.hlines(range(3), mus.quantile(0.1), mus.quantile(0.9), 'k', linewidth = 3)
py.axvline(0, linestyle = 'dashed', color = 'k')
py.xlabel('Median Coefficient Estimate (80 and 95% CI)')
# tick labels follow the column order: Intercept, Linear, Quadratic
py.yticks(range(3), ['Intercept', 'Exponential', 'Gaussian'])
py.ylim([-0.5, 2.5])
py.title('Overall Coefficients')
py.gca().invert_yaxis()
py.show()

The resulting plot:

[Figure 3: overall coefficients, medians with 80% and 95% intervals]
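The numbers behind a plot like this are just column-wise quantiles of the posterior draws. Here is a minimal sketch, using simulated draws in place of `fit.extract(...)`. The means and spreads below are made up for illustration and are not results from the paper, and `mus_demo` is a stand-in name.

```python
import numpy as np
import pandas as pd

rng = np.random.default_rng(2)
n_draws = 1000
# stand-in for the DataFrame of posterior draws; values are simulated
mus_demo = pd.DataFrame({
    "Intercept": rng.normal(0.5, 0.2, n_draws),
    "Linear": rng.normal(0.3, 0.1, n_draws),
    "Quadratic": rng.normal(-0.1, 0.05, n_draws),
})

# rows = coefficients, columns = the quantiles drawn in the plot
probs = [0.025, 0.1, 0.5, 0.9, 0.975]
summary = mus_demo.quantile(probs).T
summary.columns = ["2.5%", "10%", "50%", "90%", "97.5%"]

# flag coefficients whose 95% interval does not overlap zero
summary["excludes_zero"] = (summary["2.5%"] > 0) | (summary["97.5%"] < 0)
print(summary.round(3))
```

A table like this is handy next to the figure, since it is easier to quote an interval from a table than to read it off an axis.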

We can also make a prediction line with credible intervals:

import numpy as np

# first, define a prediction function
# (relies on the global xPred defined below; v scales the quadratic term)
def predFunc(x, v = 1):
    yhat = np.exp( x.iloc[0] + x.iloc[1]*xPred + v*x.iloc[2]*xPred**2 )
    return pd.Series({'yhat' : yhat})

# next, define a function to return the quantiles at each predicted value
def quantGet(data, q):
    quant = []
    for i in range(len(xPred)):
        # collect the i-th predicted value from every posterior draw
        val = []
        for j in range(len(data)):
            val.append( data[j][i] )
        quant.append( np.percentile(val, q) )
    return quant

# make a vector of temperatures to predict (and convert to the real temperature scale)
xPred = np.linspace(feeding_Final['Temp_Scale'].min(), feeding_Final['Temp_Scale'].max(), 100)
realTemp = xPred * feeding_Final['Temperature'].std() + feeding_Final['Temperature'].mean()

# make predictions for every posterior draw (in overall effects)
ovPred = mus.apply(predFunc, axis = 1)

# get lower and upper quantiles
ovLower = quantGet(ovPred['yhat'], 2.5)
ovLower80 = quantGet(ovPred['yhat'], 10)
ovUpper80 = quantGet(ovPred['yhat'], 90)
ovUpper = quantGet(ovPred['yhat'], 97.5)

# get median predictions
ovPred = predFunc(mus.median())
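As an aside, the nested loops in `quantGet` work, but the same pointwise quantiles can be had in one step with NumPy broadcasting. This is a sketch, again with simulated draws standing in for `mus` and an arbitrary grid standing in for `xPred`. Both are assumptions for illustration, as are the names `draws`, `x_grid`, and `pred`.

```python
import numpy as np

rng = np.random.default_rng(3)
n_draws = 500
# stand-in for the posterior draws: columns are intercept, linear, quadratic
draws = np.column_stack([
    rng.normal(0.5, 0.2, n_draws),
    rng.normal(0.3, 0.1, n_draws),
    rng.normal(-0.1, 0.05, n_draws),
])
x_grid = np.linspace(-2.0, 2.0, 100)  # stand-in for the scaled temperatures

# design matrix for the curve: columns are 1, x, x^2
X = np.column_stack([np.ones_like(x_grid), x_grid, x_grid**2])

# (n_draws, 3) @ (3, 100) gives one predicted curve per posterior draw
pred = np.exp(draws @ X.T)

# pointwise quantiles across draws, one value per grid point
lower95, lower80, upper80, upper95 = np.percentile(
    pred, [2.5, 10, 90, 97.5], axis=0)

# curve evaluated at the median coefficients, as in the code above
median_curve = np.exp(np.median(draws, axis=0) @ X.T)
```

Note that the curve evaluated at the median coefficients is not, in general, the pointwise median of the curves. If the pointwise median is what you want, `np.percentile(pred, 50, axis=0)` gives it directly.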

Then, just plot the median (ovPred['yhat']) and the quantiles against temperature (realTemp). With just a little effort, you can wind up with something that looks pretty good:

[Figure 2: median prediction with 80% and 95% quantile bands against temperature]
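One way to draw a figure like this is a line for the median plus two shaded bands via `fill_between`. The sketch below uses placeholder arrays. With the code above, they would be `realTemp`, `ovPred['yhat']`, and the four quantile lists. The function name `plot_band`, the axis labels, and the grey styling are assumptions, not the styling used in the paper.

```python
import matplotlib
matplotlib.use("Agg")  # non-interactive backend so the sketch runs without a display
import matplotlib.pyplot as plt
import numpy as np

def plot_band(x, median, lo95, lo80, hi80, hi95, ax=None):
    """Hypothetical helper: median line with 80% and 95% shaded bands."""
    if ax is None:
        _, ax = plt.subplots()
    # draw the wide band first so the narrow band sits on top of it
    ax.fill_between(x, lo95, hi95, color="0.85", label="95% interval")
    ax.fill_between(x, lo80, hi80, color="0.65", label="80% interval")
    ax.plot(x, median, "k-", label="median")
    ax.set_xlabel("Temperature")
    ax.set_ylabel("Predicted value")
    ax.legend()
    return ax

# placeholder curve and bands, only to exercise the function
x = np.linspace(15.0, 35.0, 100)
med = np.exp(0.05 * (x - 25.0))
ax = plot_band(x, med, 0.6 * med, 0.8 * med, 1.25 * med, 1.6 * med)
# ax.figure.savefig("prediction_band.png")  # uncomment to write the figure to disk
```

Passing an existing `ax` makes it easy to reuse the same function for each herbivore-plant pair in a grid of subplots.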

I apologize for only posting part of the code, but the full script is really long. This should serve as a pretty good start for anyone looking to use Python as their Bayesian platform of choice. Anyone interested can get the data and full script from my article or website and give it a try! It’s all publicly available.


5 thoughts on “PyStan: A Second Intermediate Tutorial of Bayesian Analysis in Python”

  1. Pingback: Repeatable and transparent data analysis: making the leap from Excel to Python (with tutorial) | EcoPress

  2. Thanks for blogging about Stan! I’m really excited about all these ecology models. I spent part of my time in Melbourne earlier this year hanging with a bunch of statistical population ecologists and am still working on soil carbon models in my spare time. We’re working on improving our diff eq solver, in case there are models like this that are most naturally formulated in terms of system dynamics and fit with noisy measurements.

    There’s a small bug in your model as written. Parameters need to have support over their declared constraints. Because sigma is declared with

    real<lower = 0> sigma[J];
    

    this

    sigma ~ uniform(0, 100);
    

    is ill-formed because sigma is allowed to be greater than 100 in its declared constraint but not in its support. This won’t cause you to get the wrong answer if the correct answer isn’t concentrated around a boundary, but it can cause computational issues with effective sample size. In any case, we recommend providing a more informative prior on sigma declared with no constraints. Alternatively, you could add , upper=100 to the declared constraint if you want to stick with the uniform-on-an-interval prior.

    We also have a very nice multivariate prior that’s more interpretable than the Wishart — see the regression chapter of the manual on multivariate priors. It captures the intent of what Gelman and Hill recommend in their regression book with a scaled inverse Wishart prior on covariance, but does it more directly with a natural prior on the correlation structure of the covariance matrix that’s independent (in the prior) from the scales.

    For efficiency, you really want to vectorize this

    for (j in 1:J)
      beta[j] ~ multi_normal_prec(beta_hat[j], Tau);
    

    to

    beta ~ multi_normal_prec(beta_hat, Tau);
    

    You already have the right data structure for beta and beta_hat here.

    It helps to read code if it’s indented. We have style recommendations in the back of the manual. Wrestling with WordPress and other tools is another matter.
