Learn how to get R to compute WAIC and PSIS LOO-CV for our models.
Start learning how to make use of this information to compare models.
Learn two ways to improve our models.
Many packages have functions with the same or similar names. In either case, you might not get the function you mean.
When a function with the same name exists in more than one package, R has a systematic way of deciding which one to use: it uses the one from the most recently loaded package. You can see the path R searches along using search():
search()
## [1] ".GlobalEnv" "package:ggdag" "package:CalvinBayes"
## [4] "package:bayesplot" "package:tidybayes" "package:rethinking"
## [7] "package:parallel" "package:rstan" "package:StanHeaders"
## [10] "package:loo" "package:brms" "package:Rcpp"
## [13] "package:purrr" "package:tidyr" "package:patchwork"
## [16] "package:knitr" "package:pander" "package:fastR2"
## [19] "package:mosaic" "package:ggridges" "package:mosaicData"
## [22] "package:ggformula" "package:ggstance" "package:dplyr"
## [25] "package:Matrix" "package:ggplot2" "package:lattice"
## [28] "package:dagitty" "package:stats" "package:graphics"
## [31] "package:grDevices" "package:utils" "package:datasets"
## [34] "package:methods" "SciViews:TempEnv" "Autoloads"
## [37] "package:base"
And you can see which of these contain your function of interest using find().
find('compare')
## [1] "package:rethinking" "package:loo" "package:mosaic"
find('WAIC')
## [1] "package:rethinking" "package:brms"
find('filter')
## [1] "package:ggdag" "package:dplyr" "package:stats"
When you want to be certain you are getting what you mean to get, use :: to say which package you want.
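For example, find('filter') showed several loaded packages that provide a filter() function, so we can be explicit about which one we want (mtcars here is just a convenient built-in data set for illustration):
dplyr::filter(mtcars, cyl == 6)   # the data-frame filter() from dplyr
# stats::filter() is a different function (linear filtering of time series)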
Note: If the packages were loaded into your environment in a different order than they were loaded in your R Markdown, you can get different results when you knit than when you run a chunk in the console.
We don’t ever want to compute these by hand. They are involved computations that require some careful programming to avoid accumulating rounding errors.
Fortunately, R will do all the work for us. Our job will be learning how to make use of these computations.
Note that we need to tell Stan to keep track of the log-likelihood information it needs to compute these estimates of out-of-sample deviance.
Here’s how.
Pallets2 <-
fastR2::Pallets %>%
mutate(
emp_idx = as.numeric(factor(employee)),
day_idx = as.numeric(factor(day))
)
set.seed(666)
u6 <- ulam(
data = Pallets2,
alist(
pallets ~ dnorm(mu, sigma),
mu <- b[emp_idx] + d[day_idx],
b[emp_idx] ~ dnorm(125, 20),
d[day_idx] ~ dnorm(0, 10),
sigma ~ dexp(1)
),
chains = 4, iter = 4000, warmup = 1000, cores = 4, refresh = 0,
log_lik = TRUE, # NB: need to tell Stan to keep track of log-likelihood
file = "fits/test2-u6")
One downside of this: The posterior distribution now has 20 additional things it is tracking – a log-likelihood for each row of our data set. We will want to remove those when they are in our way.
u6 %>%
stanfit() %>%
mcmc_areas()
u6 %>%
stanfit() %>%
mcmc_areas(pars = vars(-matches("lp"), -matches("log_lik")))
If we have included the log-likelihood information, then we can use the loo package or the rethinking package to get our out-of-sample deviance estimates. Using rethinking is a bit easier because it knows about ulam objects. For loo, we need to give it some help finding the information it needs. But loo will work with any model fit with Stan, not just the ones created using ulam().
u6 %>% stanfit() %>% loo::loo()
## Warning: Some Pareto k diagnostic values are too high. See help('pareto-k-diagnostic') for details.
##
## Computed from 12000 by 20 log-likelihood matrix
##
## Estimate SE
## elpd_loo -46.9 4.2
## p_loo 9.1 2.7
## looic 93.9 8.4
## ------
## Monte Carlo SE of elpd_loo is NA.
##
## Pareto k diagnostic values:
## Count Pct. Min. n_eff
## (-Inf, 0.5] (good) 14 70.0% 1940
## (0.5, 0.7] (ok) 3 15.0% 1282
## (0.7, 1] (bad) 3 15.0% 23
## (1, Inf) (very bad) 0 0.0% <NA>
## See help('pareto-k-diagnostic') for details.
u6 %>% stanfit() %>% loo::extract_log_lik() %>% loo::waic()
## Warning:
## 5 (25.0%) p_waic estimates greater than 0.4. We recommend trying loo instead.
##
## Computed from 12000 by 20 log-likelihood matrix
##
## Estimate SE
## elpd_waic -45.3 3.5
## p_waic 7.5 2.0
## waic 90.7 7.0
##
## 5 (25.0%) p_waic estimates greater than 0.4. We recommend trying loo instead.
u6 %>% rethinking::WAIC()
## WAIC lppd penalty std_err
## 1 90.69338 -37.85348 7.493209 6.831297
u6 %>% rethinking::PSIS()
## Some Pareto k values are high (>0.5). Set pointwise=TRUE to inspect individual points.
## PSIS lppd penalty std_err
## 1 93.89309 -46.94654 9.093065 8.425538
There are some warning messages we should investigate a bit.
u6 %>% rethinking::PSIS(pointwise = TRUE)
## Some Pareto k values are high (>0.5). Set pointwise=TRUE to inspect individual points.
## PSIS lppd penalty std_err k
## 1 6.035924 -3.017962 0.8068309 8.425538 0.4551421
## 2 5.579056 -2.789528 0.6774241 8.425538 0.6385923
## 3 4.471797 -2.235898 0.3568349 8.425538 0.5091443
## 4 3.783746 -1.891873 0.1794347 8.425538 0.4499811
## 5 11.280805 -5.640402 2.6226388 8.425538 0.9401211
## 6 7.784532 -3.892266 1.3829496 8.425538 0.7870311
## 7 3.580419 -1.790209 0.1274257 8.425538 0.4199218
## 8 3.941095 -1.970548 0.2234272 8.425538 0.4817584
## 9 3.972583 -1.986291 0.2329946 8.425538 0.5578727
## 10 3.887821 -1.943910 0.2070843 8.425538 0.4774834
## 11 3.563163 -1.781582 0.1217214 8.425538 0.3500579
## 12 4.141036 -2.070518 0.2658781 8.425538 0.4181905
## 13 3.797467 -1.898733 0.1851048 8.425538 0.4572735
## 14 3.633286 -1.816643 0.1414083 8.425538 0.4113050
## 15 3.718076 -1.859038 0.1589375 8.425538 0.3603478
## 16 3.601481 -1.800740 0.1309903 8.425538 0.3635936
## 17 3.957038 -1.978519 0.2208342 8.425538 0.4708673
## 18 3.588353 -1.794177 0.1305099 8.425538 0.3759778
## 19 4.230522 -2.115261 0.2923391 8.425538 0.3443411
## 20 5.344891 -2.672445 0.6282963 8.425538 0.7224930
# The PSIS, lppd, and penalty columns sum to the overall values reported above;
# the std_err column just repeats the overall standard error, so its sum (and
# the sum of the k column) is not meaningful.
u6 %>% rethinking::PSIS(pointwise = TRUE) %>%
  apply(2, sum)
## Some Pareto k values are high (>0.5). Set pointwise=TRUE to inspect individual points.
## PSIS lppd penalty std_err k
## 93.893089 -46.946544 9.093065 168.510756 9.991496
u6 %>% stanfit() %>% loo::loo() %>% plot()
## Warning: Some Pareto k diagnostic values are too high. See help('pareto-k-diagnostic') for details.
u6 %>% rethinking::PSIS(pointwise = TRUE) %>%
as.data.frame() %>%
mutate(row = 1:n()) %>%
bind_cols(Pallets2) %>%
gf_point(k ~ row, color = ~employee)
## Some Pareto k values are high (>0.5). Set pointwise=TRUE to inspect individual points.
u6 %>% rethinking::PSIS(pointwise = TRUE) %>%
as.data.frame() %>%
mutate(row = 1:n()) %>%
bind_cols(Pallets2) %>%
gf_line(pallets ~ day, color = ~employee, group = ~ employee) %>%
gf_point(pallets ~ day, color = ~employee, size = ~ k)
## Some Pareto k values are high (>0.5). Set pointwise=TRUE to inspect individual points.
Pallets2 %>%
mutate(row = 1:n()) %>%
gf_line(pallets ~ day, color = ~employee, group = ~employee) %>%
gf_text(pallets ~ day, color = ~employee, label = ~ row)
A few observations have uncomfortably large Pareto k values, so the model is being surprised by some of the data. One way to make it less sensitive to surprising observations is to replace the normal likelihood with a heavier-tailed Student t distribution (we return to this idea below). Let's refit the model that way.
Pallets2 <-
fastR2::Pallets %>%
mutate(
emp_idx = as.numeric(factor(employee)),
day_idx = as.numeric(factor(day))
)
set.seed(666)
u6t <- ulam(
data = Pallets2,
alist(
pallets ~ dstudent(2, mu, sigma),
mu <- b[emp_idx] + d[day_idx],
b[emp_idx] ~ dnorm(125, 20),
d[day_idx] ~ dnorm(0, 10),
sigma ~ dexp(1)
),
chains = 4, iter = 4000, warmup = 1000, cores = 4, refresh = 0,
log_lik = TRUE, # NB: need to tell Stan to keep track of log-likelihood
file = "fits/test2-u6t")
u6t %>% PSIS()
## PSIS lppd penalty std_err
## 1 89.95509 -44.97755 11.77763 9.638
compare(u6, u6t)
## WAIC SE dWAIC dSE pWAIC weight
## u6t 89.70633 9.546275 0.0000000 NA 11.653247 0.6209361
## u6 90.69338 6.831297 0.9870471 4.439826 7.493209 0.3790639
compare(u6, u6t) %>% plot()
coeftab(u6, u6t) %>% plot()
WAIC and PSIS-LOO-CV are both trying to estimate out-of-sample deviance, so they will often give very similar results. So which do we prefer?
The authors of the loo package (who are on the Stan development team) have this recommendation:

> The waic() methods can be used to compute WAIC from the pointwise log-likelihood. However, we recommend LOO-CV using PSIS (as implemented by the loo() function) because PSIS provides useful diagnostics as well as effective sample size and Monte Carlo estimates.
McElreath seems to sometimes use one and sometimes the other.
The numbers we get from WAIC() and PSIS() don't mean much in isolation; they are only useful as a way to compare models fit to the same data. In that context, a value closer to 0 indicates a model that we expect to make better out-of-sample predictions.
But remember that these are approximations, so we shouldn't read too much into small differences unless we know how good the approximations are. Fortunately, both WAIC and PSIS provide estimates of their precision. We will use these standard errors along with the point estimates when comparing models.
A while back we fit three models to predict divorce rate.
model | predictors
---|---
m5.1 | average age at marriage (A)
m5.2 | marriage rate (M)
m5.3 | both (additive)
data(WaffleDivorce)
Waffles <-
WaffleDivorce %>%
mutate(
WaffleHousesPerCap = WaffleHouses / Population,
D = standardize(Divorce),
W = standardize(WaffleHousesPerCap),
A = standardize(MedianAgeMarriage),
M = standardize(Marriage)
)
m5.1 <- quap(
data = Waffles,
alist(
D ~ dnorm(mu, sigma),
mu <- b_0 + b_A * A,
b_0 ~ dnorm(0, 0.2),
b_A ~ dnorm(0, 0.5),
sigma ~ dexp(1)
)
)
m5.2 <- quap(
data = Waffles,
alist(
D ~ dnorm(mu, sigma),
mu <- b_0 + b_M * M,
b_0 ~ dnorm(0, 0.2),
b_M ~ dnorm(0, 0.5),
sigma ~ dexp(1)
)
)
m5.3 <- quap(
data = Waffles,
alist(
D ~ dnorm(mu, sigma),
mu <- b_0 + b_A * A + b_M * M,
b_0 ~ dnorm(0, 0.2),
b_A ~ dnorm(0, 0.5),
b_M ~ dnorm(0, 0.5),
sigma ~ dexp(1)
)
)
The rethinking package includes a compare() function that makes it easy to compare models based on WAIC or PSIS.
compare(m5.1, m5.2, m5.3) %>% pander()
&nbsp; | WAIC | SE | dWAIC | dSE | pWAIC | weight
---|---|---|---|---|---|---
m5.1 | 126.8 | 14.27 | 0 | NA | 4.291 | 0.8279
m5.3 | 129.9 | 15.04 | 3.153 | 1.206 | 6.211 | 0.1711
m5.2 | 140.3 | 10.97 | 13.51 | 10.56 | 3.574 | 0.000965
compare(m5.1, m5.2, m5.3) %>% plot()
compare(m5.1, m5.2, m5.3)@dSE
## m5.1 m5.2 m5.3
## m5.1 NA 10.46700 0.9718292
## m5.2 10.4670023 NA 10.8524181
## m5.3 0.9718292 10.85242 NA
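As a rough sketch of how we might use these standard errors, here is the comparison of m5.3 to m5.1 using the dWAIC and dSE values from the compare() table above (dWAIC = 3.153, dSE = 1.206):
# difference in WAIC between m5.3 and m5.1, and the standard error of that difference
dWAIC <- 3.153
dSE   <- 1.206
dWAIC + c(-2, 2) * dSE   # roughly (0.7, 5.6)
# The interval does not include 0, but the difference is modest; the weights in
# the table (0.83 vs 0.17) tell a similar story in m5.1's favor.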
We can do the same thing for PSIS:
compare(m5.1, m5.2, m5.3, func = PSIS) %>% pander()
## Some Pareto k values are very high (>1). Set pointwise=TRUE to inspect individual points.
## Some Pareto k values are high (>0.5). Set pointwise=TRUE to inspect individual points.
## Some Pareto k values are very high (>1). Set pointwise=TRUE to inspect individual points.
&nbsp; | PSIS | SE | dPSIS | dSE | pPSIS | weight
---|---|---|---|---|---|---
m5.1 | 127.8 | 15.06 | 0 | NA | 4.752 | 0.8587
m5.3 | 131.4 | 16.61 | 3.628 | 1.826 | 6.988 | 0.14
m5.2 | 140.6 | 11.39 | 12.88 | 11.28 | 3.77 | 0.001372
compare(m5.1, m5.2, m5.3, func = PSIS) %>% plot()
## Some Pareto k values are very high (>1). Set pointwise=TRUE to inspect individual points.
## Some Pareto k values are high (>0.5). Set pointwise=TRUE to inspect individual points.
## Some Pareto k values are high (>0.5). Set pointwise=TRUE to inspect individual points.
compare(m5.1, m5.2, m5.3)@dSE
## m5.1 m5.2 m5.3
## m5.1 NA 10.49201 0.9263976
## m5.2 10.4920089 NA 11.0069926
## m5.3 0.9263976 11.00699 NA
Let’s see which states have large values of \(k\) for Pareto smoothing or contribute a large amount to the penalty for WAIC.
set.seed(12345)
bind_cols(
PSIS(m5.3, pointwise = TRUE) %>% rename(penalty_P = penalty),
WAIC(m5.3, pointwise = TRUE) %>% rename(penalty_W = penalty),
Waffles
) %>%
gf_point( penalty_W ~ k ) %>%
gf_text(label = ~ Loc, alpha = 0.7, size = 10, color = "steelblue",
data = . %>% filter(k > 0.5))
## Some Pareto k values are high (>0.5). Set pointwise=TRUE to inspect individual points.
## New names:
## * lppd -> lppd...2
## * std_err -> std_err...4
## * lppd -> lppd...7
## * std_err -> std_err...9
Idaho stands out pretty dramatically for both. It is an unusual (surprising, hard to predict) observation. If we had new data, they might not include a value like that, so there is a risk that we are overfitting in order to match this one observation better. What can we do about this?
Bad idea: throw out Idaho.
Removing outliers simply because they are outliers is almost always a bad idea.
Better idea: improve the model.
In this case, let’s create a model that isn’t so surprised by divorce rates that are quite a bit higher or lower than predicted. That is, we want to replace the normal distribution with a distribution that has heavier tails. A commonly used distribution that is similar to the normal distribution but with heavier tails is Student’s t distribution.
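To see what this buys us, here is a quick sketch (using gf_dist(), as we do below for the priors) comparing a standard normal distribution to a Student t distribution with 2 degrees of freedom. The t distribution puts much more probability in its tails, so extreme observations are less surprising to it:
gf_dist("norm", mean = 0, sd = 1, color = ~ "Norm(0, 1)") %>%
  gf_dist("t", df = 2, color = ~ "Student t (df = 2)") %>%
  gf_labs(color = "distribution")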
m5.3t <- quap(
data = Waffles,
alist(
D ~ dstudent(2, mu, sigma),
mu <- b_0 + b_A * A + b_M * M,
b_0 ~ dnorm(0, 0.2),
b_A ~ dnorm(0, 0.5),
b_M ~ dnorm(0, 0.5),
sigma ~ dexp(1)
)
)
No more warnings – all the Pareto \(k\) values are now well below 0.5:
m5.3t %>% PSIS()
## PSIS lppd penalty std_err
## 1 133.4585 -66.72925 6.698328 11.89951
m5.3t %>% PSIS(pointwise = TRUE) %>%
data.frame() %>%
mutate(row = 1:n()) %>%
gf_point(k ~ row)
compare(m5.3, m5.3t, func = PSIS) %>% pander()
## Some Pareto k values are very high (>1). Set pointwise=TRUE to inspect individual points.
&nbsp; | PSIS | SE | dPSIS | dSE | pPSIS | weight
---|---|---|---|---|---|---
m5.3 | 132.7 | 17.14 | 0 | NA | 7.54 | 0.6258
m5.3t | 133.8 | 11.94 | 1.028 | 9.225 | 6.91 | 0.3742
compare(m5.3, m5.3t, func = PSIS) %>% plot()
## Some Pareto k values are very high (>1). Set pointwise=TRUE to inspect individual points.
coeftab(m5.3, m5.3t) %>% plot()
The figures below (and in the text) are based on a simulation where
There are 4 predictors: \(x_1, x_2, x_3, x_4\).
The response is generated as \(y \sim \operatorname{Norm}(0.15 x_1 - 0.4 x_2,\ 1)\).
Models are fit using 1 to 5 parameters (an intercept plus 0 to 4 slopes), in order. So the best model (for prediction) should be the one with 3 parameters (2 predictors); any improvement from adding the last two predictors would be unlikely to persist with new data.
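The figures come from the text, but a minimal sketch of how one such training set could be simulated looks like this (the predictor distribution and sample size here are illustrative assumptions, not taken from the text):
set.seed(1)
n <- 20   # assumed sample size for illustration
SimData <-
  data.frame(x1 = rnorm(n), x2 = rnorm(n), x3 = rnorm(n), x4 = rnorm(n)) %>%
  mutate(y = rnorm(n, mean = 0.15 * x1 - 0.4 * x2, sd = 1))
# Only x1 and x2 affect the mean of y; x3 and x4 are pure noise, so including
# them can only improve the in-sample fit, not out-of-sample predictions.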
Since it is a simulation, we can simulate a new data set generated the same way to compute the actual out-of-sample deviance. Here are the results.
Now let’s consider regularizing our model. This means using narrower priors for the coefficients. Using narrower priors makes the model less sensitive to the data, so less prone to overfitting (but more susceptible to underfitting).
We’ll do this with three versions of the priors (after standardizing in the usual way).
gf_dist("norm", mea = 0, sd = 1, color = ~ "Norm(0,1)") %>%
gf_dist("norm", mea = 0, sd = 0.5, color = ~ "Norm(0,0.5)") %>%
gf_dist("norm", mea = 0, sd = 0.2, color = ~ "Norm(0,0.2)") %>%
gf_theme(legend.position = "top") %>%
gf_labs(color = "Prior distribution")
As the prior becomes narrower, the effect of overfitting becomes less pronounced.
Things to notice:
In the smaller data set, the narrower priors fit the data less well (larger in-sample deviance) but fit new data better (smaller out-of-sample deviance).
Choice of prior matters less in a larger data set.
We'll return to regularizing in the context of other examples. For now, keep in mind that this is another effect of the choice of prior.
In each of these cases, deviance is picking up that 3 is the right number of parameters.
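To make regularizing concrete, here is a sketch (not a model from the text) of what a regularized version of m5.3 would look like; the only change is narrower priors on the slopes:
m5.3r <- quap(
  data = Waffles,
  alist(
    D ~ dnorm(mu, sigma),
    mu <- b_0 + b_A * A + b_M * M,
    b_0 ~ dnorm(0, 0.2),
    b_A ~ dnorm(0, 0.2),   # narrower than the dnorm(0, 0.5) used in m5.3
    b_M ~ dnorm(0, 0.2),   # narrower than the dnorm(0, 0.5) used in m5.3
    sigma ~ dexp(1)
  )
)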