Goals for today (and the future)

  1. Learn how to get R to compute WAIC and PSIS LOO-CV for our models.

    • Bonus: Learn how to deal with multiple packages that define functions with the same name.
  2. Start learning how to make use of this information to compare models.

  3. Learn two ways to improve our models

    1. Regularization – modifying our priors
    2. “Robust” regression – modifying our likelihood

Side note on packages

Many packages have functions with the same or similar names. In either case, you might not get the function you mean.

When a function with the same name exists in more than one package, R has a systematic way of deciding which one to use: it uses the one from the most recently loaded package. You can see the path R searches along using search():

search()
##  [1] ".GlobalEnv"          "package:ggdag"       "package:CalvinBayes"
##  [4] "package:bayesplot"   "package:tidybayes"   "package:rethinking" 
##  [7] "package:parallel"    "package:rstan"       "package:StanHeaders"
## [10] "package:loo"         "package:brms"        "package:Rcpp"       
## [13] "package:purrr"       "package:tidyr"       "package:patchwork"  
## [16] "package:knitr"       "package:pander"      "package:fastR2"     
## [19] "package:mosaic"      "package:ggridges"    "package:mosaicData" 
## [22] "package:ggformula"   "package:ggstance"    "package:dplyr"      
## [25] "package:Matrix"      "package:ggplot2"     "package:lattice"    
## [28] "package:dagitty"     "package:stats"       "package:graphics"   
## [31] "package:grDevices"   "package:utils"       "package:datasets"   
## [34] "package:methods"     "SciViews:TempEnv"    "Autoloads"          
## [37] "package:base"

And you can see which of these contain your function of interest using find().

find('compare')
## [1] "package:rethinking" "package:loo"        "package:mosaic"
find('WAIC')
## [1] "package:rethinking" "package:brms"
find('filter')
## [1] "package:ggdag" "package:dplyr" "package:stats"

When you want to be certain you are getting what you mean to get, use :: to say which package you want.
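For example, a minimal illustration using the filter() clash shown above (mtcars is just a built-in data set used for demonstration):

library(dplyr)                      # dplyr::filter now masks stats::filter
stats::filter(1:10, rep(1/3, 3))    # the moving-average filter from stats
dplyr::filter(mtcars, cyl == 6)     # the row-selection verb from dplyr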

Note: If the packages in your console environment were loaded in a different order than they are loaded in your R Markdown document, you can get different results when you knit than when you run a chunk in the console.

Computing WAIC and PSIS/LOO in R

We don’t ever want to compute these by hand. They are involved computations that require careful programming to avoid accumulating rounding error.

Fortunately, R will do all the work for us. Our job will be learning how to make use of these computations.
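Just for intuition, here is a rough sketch of what WAIC() computes from an S-by-N matrix of pointwise log-likelihoods (S posterior draws, N observations). The function name waic_sketch is ours, and this naive version is not numerically stable (the packages use log-sum-exp tricks and more careful bookkeeping), so treat it as illustration only.

# log_lik[s, i] = log-likelihood of observation i under posterior draw s
waic_sketch <- function(log_lik) {
  lppd   <- sum(log(colMeans(exp(log_lik))))   # log pointwise predictive density
  p_waic <- sum(apply(log_lik, 2, var))        # penalty: sum of pointwise variances
  c(WAIC = -2 * (lppd - p_waic), lppd = lppd, penalty = p_waic)
}
# e.g., waic_sketch(loo::extract_log_lik(stanfit(u6))) for the model fit below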

Adding log-likelihood info to a Stan model

Note that we need to tell Stan to keep track of the log-likelihood information it needs to compute these estimates of out-of-sample deviance.

Here’s how.

Pallets2 <-
  fastR2::Pallets %>%
  mutate(
    emp_idx = as.numeric(factor(employee)),
    day_idx = as.numeric(factor(day))
  )
set.seed(666)
u6 <- ulam(
  data = Pallets2,
  alist(
    pallets ~ dnorm(mu, sigma),
    mu <- b[emp_idx] + d[day_idx], 
    b[emp_idx] ~ dnorm(125, 20),
    d[day_idx] ~ dnorm(0, 10),
    sigma ~ dexp(1)
  ),
  chains = 4, iter = 4000, warmup = 1000, cores = 4, refresh = 0,
  log_lik = TRUE,            # NB: need to tell Stan to keep track of log-likelihood
  file = "fits/test2-u6")

One downside of this: The posterior distribution now has 20 additional things it is tracking – a log-likelihood for each row of our data set. We will want to remove those when they are in our way.

u6 %>% 
  stanfit() %>%
  mcmc_areas()

u6 %>% 
  stanfit() %>%
  mcmc_areas(pars = vars(-matches("lp"), -matches("log_lik")))

Using loo and rethinking to compute WAIC and PSIS/LOO

If we have included the log-likelihood information, then we can use the loo package or the rethinking package to get our out-of-sample deviance estimates. Using rethinking is a bit easier because it knows about ulam objects. For loo, we need to give it some help finding the information it needs. But loo will work with any models fit with Stan, not just the ones created using ulam().

u6 %>% stanfit() %>% loo::loo()
## Warning: Some Pareto k diagnostic values are too high. See help('pareto-k-diagnostic') for details.
## 
## Computed from 12000 by 20 log-likelihood matrix
## 
##          Estimate  SE
## elpd_loo    -46.9 4.2
## p_loo         9.1 2.7
## looic        93.9 8.4
## ------
## Monte Carlo SE of elpd_loo is NA.
## 
## Pareto k diagnostic values:
##                          Count Pct.    Min. n_eff
## (-Inf, 0.5]   (good)     14    70.0%   1940      
##  (0.5, 0.7]   (ok)        3    15.0%   1282      
##    (0.7, 1]   (bad)       3    15.0%   23        
##    (1, Inf)   (very bad)  0     0.0%   <NA>      
## See help('pareto-k-diagnostic') for details.
u6 %>% stanfit() %>% loo::extract_log_lik() %>% loo::waic()
## Warning: 
## 5 (25.0%) p_waic estimates greater than 0.4. We recommend trying loo instead.
## 
## Computed from 12000 by 20 log-likelihood matrix
## 
##           Estimate  SE
## elpd_waic    -45.3 3.5
## p_waic         7.5 2.0
## waic          90.7 7.0
## 
## 5 (25.0%) p_waic estimates greater than 0.4. We recommend trying loo instead.
u6 %>% rethinking::WAIC()
##       WAIC      lppd  penalty  std_err
## 1 90.69338 -37.85348 7.493209 6.831297
u6 %>% rethinking::PSIS()
## Some Pareto k values are high (>0.5). Set pointwise=TRUE to inspect individual points.
##       PSIS      lppd  penalty  std_err
## 1 93.89309 -46.94654 9.093065 8.425538

There are some warning messages we should investigate a bit.

u6 %>% rethinking::PSIS(pointwise = TRUE)
## Some Pareto k values are high (>0.5). Set pointwise=TRUE to inspect individual points.
##         PSIS      lppd   penalty  std_err         k
## 1   6.035924 -3.017962 0.8068309 8.425538 0.4551421
## 2   5.579056 -2.789528 0.6774241 8.425538 0.6385923
## 3   4.471797 -2.235898 0.3568349 8.425538 0.5091443
## 4   3.783746 -1.891873 0.1794347 8.425538 0.4499811
## 5  11.280805 -5.640402 2.6226388 8.425538 0.9401211
## 6   7.784532 -3.892266 1.3829496 8.425538 0.7870311
## 7   3.580419 -1.790209 0.1274257 8.425538 0.4199218
## 8   3.941095 -1.970548 0.2234272 8.425538 0.4817584
## 9   3.972583 -1.986291 0.2329946 8.425538 0.5578727
## 10  3.887821 -1.943910 0.2070843 8.425538 0.4774834
## 11  3.563163 -1.781582 0.1217214 8.425538 0.3500579
## 12  4.141036 -2.070518 0.2658781 8.425538 0.4181905
## 13  3.797467 -1.898733 0.1851048 8.425538 0.4572735
## 14  3.633286 -1.816643 0.1414083 8.425538 0.4113050
## 15  3.718076 -1.859038 0.1589375 8.425538 0.3603478
## 16  3.601481 -1.800740 0.1309903 8.425538 0.3635936
## 17  3.957038 -1.978519 0.2208342 8.425538 0.4708673
## 18  3.588353 -1.794177 0.1305099 8.425538 0.3759778
## 19  4.230522 -2.115261 0.2923391 8.425538 0.3443411
## 20  5.344891 -2.672445 0.6282963 8.425538 0.7224930
u6 %>% rethinking::PSIS(pointwise = TRUE) %>%
  apply(2, sum)
## Some Pareto k values are high (>0.5). Set pointwise=TRUE to inspect individual points.
##       PSIS       lppd    penalty    std_err          k 
##  93.893089 -46.946544   9.093065 168.510756   9.991496
u6 %>% stanfit() %>% loo::loo() %>% plot()
## Warning: Some Pareto k diagnostic values are too high. See help('pareto-k-diagnostic') for details.

u6 %>% rethinking::PSIS(pointwise = TRUE) %>%
  as.data.frame() %>%
  mutate(row = 1:n()) %>%
  bind_cols(Pallets2) %>%
  gf_point(k ~ row, color = ~employee)
## Some Pareto k values are high (>0.5). Set pointwise=TRUE to inspect individual points.

u6 %>% rethinking::PSIS(pointwise = TRUE) %>%
  as.data.frame() %>%
  mutate(row = 1:n()) %>%
  bind_cols(Pallets2) %>%
  gf_line(pallets ~ day, color = ~employee, group = ~ employee) %>%
  gf_point(pallets ~ day, color = ~employee, size = ~ k)
## Some Pareto k values are high (>0.5). Set pointwise=TRUE to inspect individual points.

Pallets2 %>% 
  mutate(row = 1:n()) %>%
  gf_line(pallets ~ day, color = ~employee, group = ~employee) %>%
  gf_text(pallets ~ day, color = ~employee, label = ~ row) 

Pallets2 <-
  fastR2::Pallets %>%
  mutate(
    emp_idx = as.numeric(factor(employee)),
    day_idx = as.numeric(factor(day))
  )
set.seed(666)
u6t <- ulam(
  data = Pallets2,
  alist(
    pallets ~ dstudent(2, mu, sigma),
    mu <- b[emp_idx] + d[day_idx], 
    b[emp_idx] ~ dnorm(125, 20),
    d[day_idx] ~ dnorm(0, 10),
    sigma ~ dexp(1)
  ),
  chains = 4, iter = 4000, warmup = 1000, cores = 4, refresh = 0,
  log_lik = TRUE,            # NB: need to tell Stan to keep track of log-likelihood
  file = "fits/test2-u6t")
u6t %>% PSIS()
##       PSIS      lppd  penalty std_err
## 1 89.95509 -44.97755 11.77763   9.638
compare(u6, u6t)
##         WAIC       SE     dWAIC      dSE     pWAIC    weight
## u6t 89.70633 9.546275 0.0000000       NA 11.653247 0.6209361
## u6  90.69338 6.831297 0.9870471 4.439826  7.493209 0.3790639
compare(u6, u6t) %>% plot()

coeftab(u6, u6t) %>% plot()

WAIC or PSIS/LOO?

WAIC and PSIS-LOO-CV are both trying to estimate out-of-sample deviance, so they often give very similar results. Which should we prefer?

The authors of the loo package (who are on the Stan development team) have this recommendation:

The waic() methods can be used to compute WAIC from the pointwise log-likelihood. However, we recommend LOO-CV using PSIS (as implemented by the loo() function) because PSIS provides useful diagnostics as well as effective sample size and Monte Carlo estimates.

McElreath seems to sometimes use one and sometimes the other.

Example Model Comparisons

Deviance doesn’t mean anything on its own

The numbers we get from WAIC() and PSIS() don’t mean much in isolation; they are only useful for comparing models fit to the same data. In that context, a value closer to 0 indicates a model that we expect to make better out-of-sample predictions.

But remember that these are approximations, so we shouldn’t read too much into small differences unless we know how good the approximations are. Fortunately, both WAIC and PSIS come with standard errors that estimate their precision. We will use these standard errors alongside the point estimates when comparing models.

Waffle Houses

A while back we fit three models to predict divorce rate.

model   predictors
m5.1    average age at marriage (A)
m5.2    marriage rate (M)
m5.3    both (additive)
data(WaffleDivorce)
Waffles <-
  WaffleDivorce %>%
  mutate(
    WaffleHousesPerCap = WaffleHouses / Population,
    D = standardize(Divorce),
    W = standardize(WaffleHousesPerCap),
    A = standardize(MedianAgeMarriage),
    M = standardize(Marriage)
  )
m5.1 <- quap(
  data = Waffles,
  alist(
    D ~ dnorm(mu, sigma),
    mu <- b_0 + b_A * A,
    b_0 ~ dnorm(0, 0.2),
    b_A ~ dnorm(0, 0.5),
    sigma ~  dexp(1)
  )
)

m5.2 <- quap(
  data = Waffles,
  alist(
    D ~ dnorm(mu, sigma),
    mu <- b_0 + b_M * M,
    b_0 ~ dnorm(0, 0.2),
    b_M ~ dnorm(0, 0.5),
    sigma ~  dexp(1)
  )
)

m5.3 <- quap(
  data = Waffles,
  alist(
    D ~ dnorm(mu, sigma),
    mu <- b_0 + b_A * A + b_M * M,
    b_0 ~ dnorm(0, 0.2),
    b_A ~ dnorm(0, 0.5),
    b_M ~ dnorm(0, 0.5),
    sigma ~  dexp(1)
  )
)

The rethinking package includes a compare function that makes it easy to compare models based on WAIC or PSIS.

compare(m5.1, m5.2, m5.3) %>% pander()
        WAIC     SE      dWAIC   dSE     pWAIC   weight
m5.1    126.8    14.27   0       NA      4.291   0.8279
m5.3    129.9    15.04   3.153   1.206   6.211   0.1711
m5.2    140.3    10.97   13.51   10.56   3.574   0.000965
compare(m5.1, m5.2, m5.3) %>% plot()

compare(m5.1, m5.2, m5.3)@dSE
##            m5.1     m5.2       m5.3
## m5.1         NA 10.46700  0.9718292
## m5.2 10.4670023       NA 10.8524181
## m5.3  0.9718292 10.85242         NA
  • black dot: in-sample deviance
  • open circles: WAIC (estimated out-of-sample deviance)
  • bars: WAIC plus and minus one standard error
  • triangles and bars show SE for the difference between a model and the best model.
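As noted above, we shouldn’t read too much into differences that are small relative to their standard errors. A rough sanity check (assuming the difference is approximately normal) uses the dWAIC and dSE values from the table above; for m5.3 versus m5.1:

# values taken from the compare() table above
dWAIC <- 3.153
dSE   <- 1.206
dWAIC + c(-2, 2) * dSE    # rough 95% interval for the expected difference in WAIC

The interval stays above 0, so m5.1’s advantage over m5.3 is small but probably not just noise; by the same reasoning, the large dSE for m5.2 means its much larger dWAIC is measured quite imprecisely.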

We can do the same thing for PSIS:

compare(m5.1, m5.2, m5.3, func = PSIS) %>% pander()
## Some Pareto k values are very high (>1). Set pointwise=TRUE to inspect individual points.
## Some Pareto k values are high (>0.5). Set pointwise=TRUE to inspect individual points.
## Some Pareto k values are very high (>1). Set pointwise=TRUE to inspect individual points.
        PSIS     SE      dPSIS   dSE     pPSIS   weight
m5.1    127.8    15.06   0       NA      4.752   0.8587
m5.3    131.4    16.61   3.628   1.826   6.988   0.14
m5.2    140.6    11.39   12.88   11.28   3.77    0.001372
compare(m5.1, m5.2, m5.3, func = PSIS) %>% plot()
## Some Pareto k values are very high (>1). Set pointwise=TRUE to inspect individual points.
## Some Pareto k values are high (>0.5). Set pointwise=TRUE to inspect individual points.
## Some Pareto k values are high (>0.5). Set pointwise=TRUE to inspect individual points.

compare(m5.1, m5.2, m5.3)@dSE
##            m5.1     m5.2       m5.3
## m5.1         NA 10.49201  0.9263976
## m5.2 10.4920089       NA 11.0069926
## m5.3  0.9263976 11.00699         NA

But what about those warnings?

Let’s see which states have large values of \(k\) for Pareto smoothing or contribute a large amount to the penalty for WAIC.

set.seed(12345)
bind_cols(
  PSIS(m5.3, pointwise = TRUE) %>% rename(penalty_P = penalty),
  WAIC(m5.3, pointwise = TRUE) %>% rename(penalty_W = penalty),
  Waffles
) %>%
  gf_point( penalty_W ~ k ) %>%
  gf_text(label = ~ Loc, alpha = 0.7, size = 10, color = "steelblue",
          data = . %>% filter(k > 0.5))
## Some Pareto k values are high (>0.5). Set pointwise=TRUE to inspect individual points.
## New names:
## * lppd -> lppd...2
## * std_err -> std_err...4
## * lppd -> lppd...7
## * std_err -> std_err...9

Idaho stands out pretty dramatically for both. It is an unusual (surprising, hard to predict) observation. If we collected new data, they might well not include a value like that, so there is a risk that we are overfitting in order to match this one observation. What can we do about this?

  • Bad idea: throw out Idaho.

    Removing outliers simply because they are outliers is almost always a bad idea.

  • Better idea: improve the model.

    In this case, let’s create a model that isn’t so surprised by divorce rates that are quite a bit higher or lower than predicted. That is, we want to replace the normal distribution with a distribution that has heavier tails. A commonly used distribution that is similar to the normal distribution but with heavier tails is Student’s t distribution.
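To see what “heavier tails” means, here is a quick picture comparing a standard normal distribution to a Student’s t distribution with 2 degrees of freedom (the same t distribution used in the model below); this is just a sketch using gf_dist():

gf_dist("norm", color = ~ "Normal(0, 1)") %>%
  gf_dist("t", df = 2, color = ~ "Student t (df = 2)") %>%
  gf_labs(color = "Distribution")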

m5.3t <- quap(
  data = Waffles,
  alist(
    D ~ dstudent(2, mu, sigma),
    mu <- b_0 + b_A * A + b_M * M,
    b_0 ~ dnorm(0, 0.2),
    b_A ~ dnorm(0, 0.5),
    b_M ~ dnorm(0, 0.5),
    sigma ~  dexp(1)
  )
)

No more warnings – all the Pareto \(k\) values are now well below 0.5:

m5.3t %>% PSIS()
##       PSIS      lppd  penalty  std_err
## 1 133.4585 -66.72925 6.698328 11.89951
m5.3t %>% PSIS(pointwise = TRUE) %>% 
  data.frame() %>%
  mutate(row = 1:n()) %>%
  gf_point(k ~ row)

Comparing the two models

compare(m5.3, m5.3t, func = PSIS) %>% pander()
## Some Pareto k values are very high (>1). Set pointwise=TRUE to inspect individual points.
        PSIS     SE      dPSIS   dSE     pPSIS   weight
m5.3    132.7    17.14   0       NA      7.54    0.6258
m5.3t   133.8    11.94   1.028   9.225   6.91    0.3742
compare(m5.3, m5.3t, func = PSIS) %>% plot()
## Some Pareto k values are very high (>1). Set pointwise=TRUE to inspect individual points.

coeftab(m5.3, m5.3t) %>% plot()

Regularizing to make your model less prone to overfitting

The figures below (and in the text) are based on a simulation where

  • There are 4 predictors: \(x_1, x_2, x_3, x_4\).

  • The response is generated as \(y \sim \operatorname{Norm}(0.15 \cdot x_1 - 0.4 \cdot x_2,\ 1)\) (see the sketch after this list)

    • So the intercept is 0,
    • Slopes on \(x_1\) and \(x_2\) are \(0.15\) and \(-0.4\),
    • \(x_3\) and \(x_4\) are unrelated to \(y\)
  • Models are fit using 1 to 5 parameters (an intercept plus 0 to 4 slopes), in order. So the best model (for prediction) should be the one with 3 parameters (2 predictors). Any improvement from adding the last two predictors would be unlikely to persist in new data.
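Here is a sketch of that data-generating process in R. The sample size and the distributions of the predictors are our assumptions; the text only specifies the coefficients.

n  <- 20                                              # sample size (our choice)
x1 <- rnorm(n); x2 <- rnorm(n); x3 <- rnorm(n); x4 <- rnorm(n)
y  <- rnorm(n, mean = 0.15 * x1 - 0.4 * x2, sd = 1)   # x3 and x4 play no role
SimData <- data.frame(y, x1, x2, x3, x4)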

Since it is a simulation, we can simulate a new data set generated the same way to compute the actual out-of-sample deviance. Here are the results.

Now let’s consider regularizing our model. This means using narrower priors for the coefficients. Using narrower priors makes the model less sensitive to the data, so less prone to overfitting (but more susceptible to underfitting).

We’ll do this with three versions of the priors (after standardizing in the usual way).

gf_dist("norm", mea = 0, sd = 1, color = ~ "Norm(0,1)") %>%
  gf_dist("norm", mea = 0, sd = 0.5, color = ~ "Norm(0,0.5)") %>%
  gf_dist("norm", mea = 0, sd = 0.2, color = ~ "Norm(0,0.2)") %>%
  gf_theme(legend.position = "top") %>%
  gf_labs(color = "Prior distribution")

As the prior becomes narrower, the effect of overfitting becomes less pronounced.

Things to notice:

  • In the smaller data set, the narrower priors fit the data less well (larger in-sample deviance) but fit new data better (smaller out-of-sample deviance)

  • Choice of prior matters less in a larger data set.

  • In each of these cases, deviance is picking up that 3 is the right number of parameters.

We’ll return to regularizing in the context of other examples. For now, keep in mind that this is another effect of the choice of prior.
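Concretely, regularizing one of our own models is just a matter of tightening the priors on the coefficients. For example, a more regularized version of m5.3 might look like this (a sketch; the name m5.3r and the choice of Norm(0, 0.2) slope priors are ours):

m5.3r <- quap(
  data = Waffles,
  alist(
    D ~ dnorm(mu, sigma),
    mu <- b_0 + b_A * A + b_M * M,
    b_0 ~ dnorm(0, 0.2),
    b_A ~ dnorm(0, 0.2),   # narrower than the Norm(0, 0.5) priors used in m5.3
    b_M ~ dnorm(0, 0.2),
    sigma ~ dexp(1)
  )
)
compare(m5.3, m5.3r)       # then check whether the extra regularization helps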

Plants again