In this case study we show how to implement an estimator for the leave-one-out (LOO) version of the expected Brier score (denoted by LOO-BS) for Bayesian survival models with potentially right-censored event times. We combine the inverse probability of censoring weighted (IPCW) estimator of Gerds and Schumacher with Pareto-smoothed importance sampling (PSIS) by Vehtari et al. to approximate the LOO-BS efficiently. We compare the PSIS approximation to the exact LOO-BS.
As a by-product we show how to implement proportional hazards models with a time-dependent baseline hazard parametrized by M-splines (non-negative spline basis functions whose integrals, the monotone I-splines, give the cumulative baseline hazard).
Useful references:
For PSIS we recommend:
Vehtari, Aki, Andrew Gelman, and Jonah Gabry. “Practical Bayesian model evaluation using leave-one-out cross-validation and WAIC.” Statistics and Computing 27.5 (2017): 1413-1432.
For the inverse probability of censoring weighted estimator see:
Gerds, Thomas A., and Martin Schumacher. “Consistent estimation of the expected Brier score in general survival models with right-censored event times.” Biometrical Journal 48.6 (2006): 1029-1040.
Monotone regression splines (I-splines, obtained by integrating M-splines) were introduced here:
Ramsay, James O. “Monotone regression splines in action.” Statistical Science 3.4 (1988): 425-441.
See also this Stan Discourse discussion.
knitr::opts_chunk$set(echo = TRUE)
library(rstan);rstan_options(auto_write = TRUE)
library(loo)
library(survival)
library(tibble)
library(dplyr)
library(purrr)
library(bayesplot)
library(splines2)
library(pec)
library(cowplot)
library(MCMCpack)
library(RColorBrewer)
qual_col_pals = brewer.pal.info[brewer.pal.info$category == 'qual',]
col_vector = unlist(mapply(brewer.pal, qual_col_pals$maxcolors, rownames(qual_col_pals)))
sm <- stan_model("~/Desktop/Stan/Loo_Pec/survival_parametric_baseline_hazard_simplex_loo.stan")
expose_stan_functions("~/Desktop/Stan/Loo_Pec/survival_parametric_baseline_hazard_simplex_loo.stan")
df <- read.delim("~/Desktop/Stan/Loo_Pec/ncog.txt",sep=" ")
df <- df %>% mutate(event=d,time=t*12/365)
df <- df%>% mutate(isGroupB=as.double(arm=="B")) %>%
dplyr::select(isGroupB, event, time)
N <- nrow(df)
times <- pull(df,time)
is_censored <- pull(df,event)==0
msk_censored <- is_censored == 1
time_range <- range(times)
time_min <- time_range[1]
time_max <- time_range[2]
X <- as.matrix(pull(df, isGroupB))
The table below reports survival data from a randomized controlled trial run by the Northern California Oncology Group that compares two treatments for head and neck cancer: isGroupB == 0 for chemotherapy and isGroupB == 1 for chemotherapy combined with radiation. The response variable time is the survival time in months (days multiplied by 12/365), and the variable event is equal to 0 iff the observation is right-censored, i.e. the patient's survival time is only known to exceed the recorded time. For details see Section 9.2 onwards of “Computer Age Statistical Inference” by B. Efron and T. Hastie.
We build a proportional hazards model with a baseline hazard that is parametrized by M-splines. Below we show prior predictive draws of the time-dependent baseline hazard[1] and survival functions.
For details we refer to the accompanying Stan code in survival_parametric_baseline_hazard_simplex_loo.stan.
The functions surv_t and hazard_t are defined in that file and are exposed above as R functions through expose_stan_functions. They evaluate the survival and the hazard function, respectively, at a fixed time for different draws of the latent parameters and for different covariates. Both functions return a matrix with one row per row of the covariate matrix and one column per parameter draw. For their exact signatures, refer to the definitions in the Stan file.
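Since the Stan file is not reproduced here, the following is only a hedged base-R sketch of what surv_t and hazard_t plausibly compute. It assumes the proportional hazards form \(h(t\vert x) = \exp(\alpha + x^\top\beta)\sum_k \gamma_k M_k(t)\), with cumulative hazard \(H(t\vert x) = \exp(\alpha + x^\top\beta)\sum_k \gamma_k I_k(t)\) and \(S(t\vert x) = \exp(-H(t\vert x))\), where \(\gamma\) is the simplex of spline weights, \(\alpha\) the intercept and \(\beta\) the regression coefficients. The names surv_t_sketch and hazard_t_sketch are hypothetical; only the argument order mirrors the calls to surv_t and hazard_t in this case study.

```r
# Hypothetical base-R sketches of surv_t() / hazard_t(); the actual definitions
# live in survival_parametric_baseline_hazard_simplex_loo.stan.
# Assumed model: h(t | x) = exp(intercept + x' beta) * sum_k gamma_k M_k(t).

# Linear predictor for every covariate row (N) and parameter draw (S): N x S
linear_predictor <- function(X, intercepts, betas) {
  sweep(X %*% betas, 2, intercepts, "+")
}

# basis_row: I-spline basis evaluated at one time t (length m)
# X: N x NC covariates; gammas: m x S simplex draws; intercepts: length S; betas: NC x S
surv_t_sketch <- function(basis_row, X, gammas, intercepts, betas) {
  cum_baseline <- as.vector(crossprod(gammas, basis_row))  # sum_k gamma_k I_k(t), per draw
  cum_hazard <- sweep(exp(linear_predictor(X, intercepts, betas)), 2, cum_baseline, "*")
  exp(-cum_hazard)                                          # N x S survival probabilities
}

# Same layout, but basis_row holds the M-spline basis (derivative of the I-splines)
hazard_t_sketch <- function(basis_row, X, gammas, intercepts, betas) {
  baseline <- as.vector(crossprod(gammas, basis_row))       # sum_k gamma_k M_k(t), per draw
  sweep(exp(linear_predictor(X, intercepts, betas)), 2, baseline, "*")
}

# Toy usage: m = 2 basis functions, S = 3 draws, two covariate rows (control, group B)
gammas_toy <- matrix(c(0.5, 0.5, 0.2, 0.8, 1, 0), nrow = 2)
intercepts_toy <- c(0, 0.1, -0.1)
betas_toy <- matrix(c(0, 0.5, -0.5), nrow = 1)
X_toy <- matrix(c(0, 1), ncol = 1)
S_toy <- surv_t_sketch(c(0.5, 0.5), X_toy, gammas_toy, intercepts_toy, betas_toy)
print(dim(S_toy))  # [1] 2 3
```

Using sweep() keeps covariate rows and parameter draws on separate axes, so the output has the rows-by-draws layout described above.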
dirichlet_alpha<- 2
sigma_intercept <- 1
mspline_degree<-3
nprior_samples <- 32
nplot_points <- 1000
knots <- quantile(times[!msk_censored], probs=c(.05, .275, .5, .725, .95))
nknots <- length(knots)
cbbPalette <- sample(col_vector, nprior_samples)
times_plot <- seq(0,time_max,length.out = nplot_points)
isp_plot <- iSpline(times_plot, knots=knots, degree=mspline_degree,
intercept=FALSE,Boundary.knots = c(0, time_max))
msp_plot <- deriv(isp_plot)
nbasis <- dim(isp_plot)[2]
simplex_samples <- t(MCMCpack::rdirichlet(nprior_samples, rep(dirichlet_alpha,nbasis)))
icpt_samples <- rnorm(nprior_samples, 0, sigma_intercept)
survs_prior <- do.call(rbind,
map(1:length(times_plot),
~surv_t(isp_plot[.,],
as.matrix(c(0)),
simplex_samples,
icpt_samples,
t(as.matrix(rep(0, nprior_samples)))
)
)
)
hazards_prior <- do.call(rbind,
map(1:length(times_plot),
~hazard_t(msp_plot[.,],
as.matrix(c(0)),
simplex_samples,
icpt_samples,
t(as.matrix(rep(0, nprior_samples)))
)
)
)
ggplot(data=mutate(map_dfr(1:nprior_samples,
~tibble(y=hazards_prior[,.], time=times_plot, sample=.)), sample=as.factor(sample)))+
geom_vline(xintercept = c(0,knots,time_max), alpha=.5, color='gray')+
geom_line(mapping = aes(x=time, y=y, color=sample))+
ggtitle("Control Group (Baseline)")+
xlab("Time")+
ylab("Hazard")+
guides(color = "none")+
scale_fill_manual(values=cbbPalette)+
scale_colour_manual(values=cbbPalette)
ggplot(data=mutate(map_dfr(1:nprior_samples,
~tibble(y=survs_prior[,.], time=times_plot, sample=.)), sample=as.factor(sample)))+
geom_vline(xintercept = c(0,knots,time_max), alpha=.5, color='gray')+
geom_line(mapping = aes(x=time, y=y, color=sample))+
ggtitle("Control Group (Baseline)")+
xlab("Time")+
ylab("Survival Probability")+
guides(color = "none")+
scale_fill_manual(values=cbbPalette)+
scale_colour_manual(values=cbbPalette)
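The prior draws above give the spline weights \(\gamma\) a symmetric Dirichlet prior with concentration dirichlet_alpha, sampled with MCMCpack::rdirichlet. As a dependency-free alternative, a Dirichlet draw can be obtained by normalizing independent Gamma(\(\alpha\), 1) variables. The sketch below (the function name rdirichlet_columns is hypothetical) returns one simplex per column, matching the layout of simplex_samples. Under a symmetric Dirichlet prior each of the \(m\) weights has prior mean \(1/m\), so the prior is centred on equal weights for all basis functions.

```r
# Hypothetical base-R alternative to MCMCpack::rdirichlet():
# normalise independent Gamma(alpha_k, 1) draws column by column.
# Returns an m x n matrix with one simplex per column.
rdirichlet_columns <- function(n, alpha) {
  m <- length(alpha)
  # column-major fill: row k of every column uses shape alpha[k]
  g <- matrix(rgamma(n * m, shape = alpha, rate = 1), nrow = m)
  sweep(g, 2, colSums(g), "/")
}

set.seed(1)
dir_draws <- rdirichlet_columns(4000, rep(2, 5))
rowMeans(dir_draws)  # each close to 1/5
```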
Let’s obtain our posterior…
create_stan_data <- function(df, knots, mspline_degree, dirichlet_alpha) {
  N <- nrow(df)
  X <- as.matrix(pull(df, isGroupB))
  times <- pull(df, time)
  is_censored <- pull(df, event) == 0
  msk_censored <- is_censored == 1
  N_censored <- sum(msk_censored)
  # note: time_max is computed from df, so in do_loo() the boundary knot differs
  # from the global time_max when the individual with the largest time is left out
  time_range <- range(times)
  time_min <- time_range[1]
  time_max <- time_range[2]
  # I-splines (cumulative baseline hazard) and their derivatives, the M-splines (baseline hazard)
  i_spline_basis_evals <- iSpline(times,
                                  knots=knots,
                                  degree=mspline_degree,
                                  intercept=FALSE,
                                  Boundary.knots = c(0, time_max))
  m_spline_basis_evals <- deriv(i_spline_basis_evals)
  i_spline_basis_evals_censored <- i_spline_basis_evals[msk_censored,]
  i_spline_basis_evals_uncensored <- i_spline_basis_evals[!msk_censored,]
  m_spline_basis_evals_uncensored <- m_spline_basis_evals[!msk_censored,]
  nbasis <- dim(i_spline_basis_evals_censored)[2]
  list(N_uncensored=N-N_censored,
       N_censored=N_censored,
       X_censored=as.matrix(X[msk_censored,]),
       X_uncensored=as.matrix(X[!msk_censored,]),
       times_censored=times[msk_censored],
       times_uncensored = times[!msk_censored],
       NC=ncol(X),
       m=nbasis,
       m_spline_basis_evals_uncensored=m_spline_basis_evals_uncensored,
       i_spline_basis_evals_uncensored=i_spline_basis_evals_uncensored,
       i_spline_basis_evals_censored=i_spline_basis_evals_censored,
       alpha=dirichlet_alpha
  )
}
stan_data <- create_stan_data(df, knots, mspline_degree, dirichlet_alpha)
fit <- sampling(sm, data=stan_data, seed=42, chains=4, cores=2, iter=4000)
post <- as.array(fit)
Get a summary of the posterior:
# convert the summary matrix to a data frame with a 'parameter' column
summary(fit,
        pars=c("intercept", "betas[1]", sprintf("gammas[%d]", 1:stan_data$m)),
        probs=c(0.025, .5, .975))$summary %>%
  as.data.frame() %>%
  tibble::rownames_to_column("parameter")
quantile_func <- function(x, probs) {
x <- as.data.frame(t(x))
map_dbl(x, ~quantile(., probs=probs))
}
get_post_plot_df <- function(eval_lst) {
funcs <- purrr::map(eval_lst, rowMeans)
funcs_low <- purrr::map(eval_lst, ~quantile_func(.,probs=c(0.025)))
funcs_up <- purrr::map(eval_lst, ~quantile_func(.,probs=c(0.975)))
dplyr::bind_rows(
tibble(
t=times_plot,
isGroupB=F,
func=map_dbl(funcs, ~.[[1]]),
func_low=map_dbl(funcs_low,~.[[1]]),
func_up=map_dbl(funcs_up,~.[[1]])
),
tibble(
t=times_plot,
isGroupB=T,
func=map_dbl(funcs, ~.[[2]]),
func_low=map_dbl(funcs_low,~.[[2]]),
func_up=map_dbl(funcs_up,~.[[2]])
)
)
}
get_post_plot <- function(plot_df, ylab) {
ggplot(data=plot_df) +
geom_ribbon(mapping=aes(x=t,ymin=func_low, ymax=func_up, fill=isGroupB), alpha=.25)+
geom_line(mapping=aes(x=t, y=func, color=isGroupB))+
geom_vline(xintercept = c(0,knots, time_max), alpha=.5, color='gray', linetype='dashed')+
geom_rug(data=tibble(times=times[is_censored]), aes(x=times), col="green",alpha=0.5, size=1.25)+
geom_rug(data=tibble(times=times[!is_censored]), aes(x=times), col="red",alpha=0.5, size=1.25)+
ylab(ylab)+
xlab("Time")
}
intercepts <- as.vector(post[,,"intercept"])
betas <- t(matrix(post[,,sprintf("betas[%d]" ,1:stan_data$NC)], dim(post)[1]*dim(post)[2], stan_data$NC))
gammas <- t(matrix(post[,,sprintf("gammas[%d]", 1:nbasis)], dim(post)[1]*dim(post)[2], nbasis))
X_surv = as.matrix(c(0,1))
times_plot <- seq(0, time_max, length.out = nplot_points)
i_spline_basis_evals <- iSpline(times_plot,
knots=knots,
degree=mspline_degree,
intercept=FALSE,
Boundary.knots = c(0, time_max))
m_spline_basis_evals <- deriv(i_spline_basis_evals)
hazard_df <- get_post_plot_df(purrr::map(1:length(times_plot),
~hazard_t(m_spline_basis_evals[.,],
X_surv,
gammas,
intercepts,
betas
)
)
)
get_post_plot(hazard_df,"Hazard")+
coord_cartesian(xlim=c(0, 50))
Denote by \(\mathbf{T}_{-i}\) the \((N-1)\)-dimensional vector[2] of observed times (censored and uncensored) for all individuals except individual \(i\). Furthermore, denote by \(\mathbf{T}\) the \(N\)-dimensional vector of observed times (censored and uncensored) for all individuals.
Now, write \(S_i(t\vert \mathbf{T}_{-i})\) for the expected survival probability of individual \(i\) at time \(t\), where the expectation is with respect to the posterior over latent parameters \(\boldsymbol{\theta}\), given \(\mathbf{T}_{-i}\).
Suppose \(\{\boldsymbol{\theta}_s\}_{s=1\dots S}\) is a (correlated) sequence of latent model parameters drawn from the posterior, given \(\mathbf{T}\). We approximate \(S_i(t\vert \mathbf{T}_{-i})\) as follows
\[ S_i(t\vert \mathbf{T}_{-i}) \approx \frac{\sum_{s=1}^S r_{s,i} S_i(t\vert \boldsymbol{\theta}_s)}{\sum_{s=1}^S r_{s,i}} \]
where \(S_i(t\vert \boldsymbol{\theta}_s)\) is the probability that individual \(i\) survives beyond time \(t\), given latent parameters \(\boldsymbol{\theta}_s\), and
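\[ r_{s,i} \equiv \frac{1}{p(T_i\vert \boldsymbol{\theta}_s)} \]
Here \(p(T_i\vert \boldsymbol{\theta}_s)\) is the likelihood contribution of individual \(i\): the event density \(h_i(T_i\vert \boldsymbol{\theta}_s) S_i(T_i\vert \boldsymbol{\theta}_s)\), with \(h_i\) the hazard, if the event was observed, and \(S_i(T_i\vert \boldsymbol{\theta}_s)\) if the time is censored. These raw ratios can have very high variance, so in practice PSIS replaces them with Pareto-smoothed weights.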
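To make the weighting concrete, here is a hedged base-R sketch of the estimator above with raw, unsmoothed ratios. For right-censored data the log-likelihood contribution is \(\delta_i \log h_i(T_i\vert\boldsymbol{\theta}) - H_i(T_i\vert\boldsymbol{\theta})\), with \(h_i\) the hazard and \(H_i\) the cumulative hazard, so \(\log r_{s,i}\) is minus the log-likelihood. The function takes a draws-by-individuals log-likelihood matrix (the layout loo::extract_log_lik returns with merge_chains = TRUE) and a matching matrix of survival probabilities at one time \(t\). It deliberately omits Pareto smoothing, which is what loo::E_loo adds; the function names are hypothetical.

```r
# Log-likelihood contribution of a possibly right-censored observation:
# delta * log h(T) - H(T); the hazard term is dropped for censored rows.
censored_loglik <- function(delta, log_hazard, cum_hazard) {
  ifelse(delta == 1, log_hazard, 0) - cum_hazard
}

# Self-normalised importance-sampling estimate of S_i(t | T_{-i}) using raw
# ratios r_{s,i} = 1 / p(T_i | theta_s), i.e. log r = -log_lik (no smoothing).
# log_lik, surv: S x N matrices (draws x individuals); surv evaluated at one t.
is_loo_mean <- function(log_lik, surv) {
  log_r <- -log_lik
  # subtract each column's maximum before exponentiating to avoid overflow
  w <- exp(sweep(log_r, 2, apply(log_r, 2, max), "-"))
  colSums(w * surv) / colSums(w)
}

# Toy usage: 2 draws, 2 individuals
loglik_toy <- matrix(log(c(0.5, 0.25, 1, 1)), nrow = 2)
surv_toy <- matrix(c(0.9, 0.6, 0.3, 0.7), nrow = 2)
is_loo_mean(loglik_toy, surv_toy)
```

Draws under which individual \(i\) is poorly predicted (small likelihood) receive large weights, which mimics removing \(i\) from the posterior.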
Turning to the Brier score, we define the residuals
\[ \eta_{i}(t) \equiv \mathbf{1}_{\{T_i > t\}} - S_i(t\vert \mathbf{T}_{-i}) \]
Based on the residuals \(\eta_{i}(t)\) we define, following Gerds and Schumacher, the inverse probability of censoring weighted (IPCW) estimator \(\Omega(t)\):
\[ \Omega(t)\equiv \frac{1}{N} \sum_{i=1}^N w_i(t)\eta_{i}(t)^2 \]
with
\[ w_i(t) \equiv \frac{\mathbf{1}_{\{T_i \leq t\}} \delta_i}{G_i(T_i)} + \frac{\mathbf{1}_{\{T_i > t\}}}{G_i(t)} \]
Here \(\delta_i\) is the event indicator, i.e. \(\delta_i = 1\) iff individual \(i\) experienced the event and \(0\) otherwise. Moreover, \(G_i(t)\) is the probability that the censoring time[3] of individual \(i\), denoted by \(C_i\), exceeds \(t\). For both the event and the censoring survival functions, the index \(i\) indicates a potential dependence on covariates. \(G_i(t)\) can be obtained, for example, by fitting a Cox model to the data with reversed event indicators. Weighting by \(G_i\) keeps the estimator consistent for the expected Brier score whether or not the survival model is correctly specified, provided the censoring model is; otherwise comparing different survival models would be hard. Finally, it can be shown that under weak conditions \(\Omega(t)\) is a uniformly consistent estimator of the expected squared prediction error over time, i.e. the expected Brier score. For details and proofs we refer the reader to Gerds and Schumacher.
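The weights can be illustrated with a hedged base-R sketch of \(\Omega(t)\). It simplifies the setup: \(G\) is estimated by a reverse Kaplan-Meier estimator without covariates, whereas this case study lets pec() fit a Cox model for the censoring times (cens.model = "cox"). For observed events the sketch uses the left limit \(G(T_i-)\), a common convention that only differs from \(G(T_i)\) when a censoring time coincides with \(T_i\). The functions reverse_km, ipcw_brier and ipcw_brier_curve are hypothetical.

```r
# Reverse Kaplan-Meier estimate of G(t) = P(C > t), ignoring covariates;
# censorings (event == 0) play the role of "events". Returns G(t, left).
reverse_km <- function(time, event) {
  cens_times <- sort(unique(time[event == 0]))
  if (length(cens_times) == 0) {
    return(function(t, left = FALSE) rep(1, length(t)))  # no censoring: G == 1
  }
  factors <- vapply(cens_times, function(u) {
    1 - sum(time == u & event == 0) / sum(time >= u)
  }, numeric(1))
  g_vals <- cumprod(factors)
  g_right <- stepfun(cens_times, c(1, g_vals))                # G(t), right-continuous
  g_left  <- stepfun(cens_times, c(1, g_vals), right = TRUE)  # left limit G(t-)
  function(t, left = FALSE) if (left) g_left(t) else g_right(t)
}

# IPCW Brier score Omega(t) at a single time t.
# surv_at_t: predicted S_i(t | T_{-i}), one value per individual.
ipcw_brier <- function(t, time, event, surv_at_t, G) {
  # events up to t are weighted by 1 / G(T_i-), individuals still at risk by
  # 1 / G(t); individuals censored up to t get weight 0
  w <- ifelse(time <= t, event / G(time, left = TRUE), 1 / G(t))
  resid <- as.numeric(time > t) - surv_at_t   # eta_i(t)
  mean(w * resid^2)
}

# Brier curve over a grid; surv_matrix has one row per individual and one
# column per grid time (the same layout as the matrices passed to pec()).
ipcw_brier_curve <- function(times_grid, time, event, surv_matrix, G) {
  vapply(seq_along(times_grid), function(j) {
    ipcw_brier(times_grid[j], time, event, surv_matrix[, j], G)
  }, numeric(1))
}

# Toy usage: one censored individual at time 2
time_toy  <- c(1, 2, 3, 4)
event_toy <- c(1, 0, 1, 1)
G_toy <- reverse_km(time_toy, event_toy)
ipcw_brier(2.5, time_toy, event_toy, c(0.2, 0.5, 0.9, 0.6), G_toy)
```

A constant prediction of 0.5 yields a squared residual of 0.25 for everyone, so its Brier score is about 0.25; this is the reference level drawn as a dashed line in the Brier score plot.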
In practice, this means we can estimate \(S_i(t\vert \mathbf{T}_{-i})\) for all \(i\) and all required times \(t\) from posterior draws \(S_i(t\vert \boldsymbol{\theta}_s)\) of a single model fitted to all individuals! This matrix \(\{S_i(t\vert \mathbf{T}_{-i})\}_{i,t}\), with one row per individual and one column per time point, can then be passed to the function pec() from the pec R package, which calculates \(\Omega(t)\) and offers several choices for \(G_i\):
nbrier_points <- 100
log_lik <- extract_log_lik(fit,
merge_chains = T)
r_eff <- relative_eff(exp(log_lik),
chain_id=rep(1:dim(post)[2],
each=dim(post)[1]))
loo_object <- loo(log_lik,
r_eff = r_eff,
cores = 2,
save_psis=T)
times_pec <- seq(0, time_max, length.out = nbrier_points)
i_spline_basis_evals <- iSpline(times_pec,
knots=knots,
degree=mspline_degree,
intercept=FALSE,
Boundary.knots = c(0, time_max))
# Posterior mean survival from the full-data fit (no LOO correction)
survs <- do.call(cbind,
purrr::map(1:length(times_pec),
~colMeans(t(surv_t(i_spline_basis_evals[.,],
X,
gammas,
intercepts,
betas)
)
)
)
)
# c.f. https://avehtari.github.io/modelselection/diabetes.html#43_other_predictive_performance_measures
# for another example on how to use E_loo
survs_psis_loo <- do.call(cbind,
purrr::map(1:length(times_pec),
~E_loo(t(surv_t(i_spline_basis_evals[.,],
X,
gammas,
intercepts,
betas)
),
loo_object$psis_object,
type = "mean",
log_ratios = -log_lik)$value
)
)
# Exact LOO: refit the model N times, each time leaving out one individual (slow)
i_spline_basis_evals_ <- iSpline(times_pec,
knots=knots,
degree=mspline_degree,
intercept=FALSE,
Boundary.knots = c(0, time_max)
)
do_loo <- function(idx) {
stan_data_ <- create_stan_data(slice(df, -idx),
knots,
mspline_degree,
dirichlet_alpha)
fit_ <- sampling(sm,
data=stan_data_,
seed=42,
chains=4,
cores=2,
iter=4000,
refresh=0)
post_ <- as.array(fit_)
intercepts_ <- as.vector(post_[,,"intercept"])
betas_ <- t(matrix(post_[,,sprintf("betas[%d]" ,1:stan_data_$NC)], dim(post_)[1]*dim(post_)[2],stan_data_$NC))
gammas_ <- t(matrix(post_[,,sprintf("gammas[%d]", 1:stan_data_$m)], dim(post_)[1]*dim(post_)[2], stan_data_$m))
purrr::map_dbl(1:length(times_pec),
~as.vector(rowMeans(surv_t(i_spline_basis_evals_[.,],
as.matrix(X[idx]),
gammas_,
intercepts_,
betas_
)
)
)
)
}
survs_loo <- do.call(rbind, purrr::map(1:nrow(df), do_loo))
pec_rslt <- pec(list("MSpline-PSIS-Loo"=survs_psis_loo,
"MSpline-Loo"=survs_loo,
"MSpline"=survs
),
formula =Surv(time,event)~isGroupB ,
data=df,
exact = F,
times = times_pec,
cens.model="cox",
splitMethod="none",
B=0,
verbose=TRUE
)
tibble(time=pec_rslt$time,
bs=pec_rslt$AppErr[["MSpline-PSIS-Loo"]],
model="MSpline-PSIS-Loo"
) %>%
bind_rows(tibble(time=pec_rslt$time,
bs=pec_rslt$AppErr[["MSpline-Loo"]],
model="MSpline-Loo"
)
) %>%
bind_rows(tibble(time=pec_rslt$time,
bs=pec_rslt$AppErr[["MSpline"]],
model="MSpline")
) %>%
bind_rows(tibble(time=pec_rslt$time,
bs=pec_rslt$AppErr[["Reference"]],
model="Kaplan-Meier")) -> df_pec
df_pec %>%
ggplot()+
geom_hline(yintercept=.25,alpha=.5, color='gray', linetype='dashed')+  # an uninformative prediction of 0.5 scores about 0.25
geom_vline(xintercept = c(0,knots, time_max), alpha=.5, color='gray', linetype='dashed')+
geom_step(aes(x=time, y=bs, color=model))+
geom_rug(data=tibble(times=times[is_censored]), aes(x=times), col="green",alpha=0.5, size=1.25)+
geom_rug(data=tibble(times=times[!is_censored]), aes(x=times), col="red",alpha=0.5, size=1.25)+
ylab("Brier Score")+
xlab("Time")+
theme(legend.position="bottom")-> p1
p1
trsf_bs <- function(m) {
dplyr::filter(df_pec, model==m) %>%
dplyr::select(time, bs) %>%
dplyr::rename(!!paste("bs",m,sep="_") := bs)
}
purrr::map(unique(df_pec$model),trsf_bs) %>%
purrr::reduce(left_join, by = "time")
##
## Computed from 8000 by 96 log-likelihood matrix
##
## Estimate SE
## elpd_loo -274.7 13.1
## p_loo 4.4 0.4
## looic 549.5 26.3
## ------
## Monte Carlo SE of elpd_loo is 0.0.
##
## All Pareto k estimates are good (k < 0.5).
## See help('pareto-k-diagnostic') for details.
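The printout above reports that all Pareto k estimates are below 0.5, which suggests the PSIS approximation should be reliable for this data set. To quantify how closely the PSIS-LOO Brier curve tracks the exact one, a small helper (hypothetical, not part of any package) can summarize the absolute differences between two curves, or two survival matrices, evaluated on the same time grid.

```r
# Hypothetical helper: summarise the absolute differences between an
# approximate curve (e.g. PSIS-LOO) and a reference (e.g. exact LOO).
# Works for vectors and for matrices of the same shape; NA values are ignored.
compare_curves <- function(approx, exact) {
  stopifnot(length(approx) == length(exact))
  d <- abs(approx - exact)
  c(max_abs_diff = max(d, na.rm = TRUE), mean_abs_diff = mean(d, na.rm = TRUE))
}

compare_curves(c(0.1, 0.2, 0.3), c(0.1, 0.25, 0.3))
```

In this case study one would call, for instance, compare_curves(df_pec$bs[df_pec$model == "MSpline-PSIS-Loo"], df_pec$bs[df_pec$model == "MSpline-Loo"]), or compare_curves(survs_psis_loo, survs_loo) for the survival probability matrices.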
[1] Baseline here means the control group, isGroupB == 0.
[2] Where \(N\) is the number of individuals.
[3] Here we assume that the actual event process competes with a censoring process and that the observed time is the minimum of the two; right censoring occurs precisely when the censoring time is smaller than the event time.