diff --git a/DESCRIPTION b/DESCRIPTION index 7cb9c7a07..7313116f0 100644 --- a/DESCRIPTION +++ b/DESCRIPTION @@ -10,6 +10,7 @@ Authors@R: c(person("Jonah", "Gabry", email = "jsg2201@columbia.edu", role = "au person(given = "Jacqueline Buros", family = "Novik", role = "ctb", comment = "R/stan_jm.R"), person("AstraZeneca", role = "ctb", comment = "R/stan_jm.R"), + person("Eren", "Elci", role = "ctb", comment = "R/stan_surv.R"), person("Trustees of", "Columbia University", role = "cph"), person("Simon", "Wood", role = "cph", comment = "R/stan_gamm4.R"), person("R Core", "Deveopment Team", role = "cph", @@ -45,6 +46,8 @@ Imports: rstan (>= 2.32.0), rstantools (>= 2.1.0), shinystan (>= 2.3.0), + splines, + splines2 (>= 0.2.7), stats, survival (>= 2.40.1), RcppParallel (>= 5.0.1), @@ -61,6 +64,7 @@ Suggests: mgcv (>= 1.8-13), rmarkdown, roxygen2, + simsurv (>= 0.2.2), StanHeaders (>= 2.21.0), testthat (>= 1.0.2), gamm4, diff --git a/NAMESPACE b/NAMESPACE index e1a1d825e..baf720ff3 100644 --- a/NAMESPACE +++ b/NAMESPACE @@ -23,6 +23,8 @@ S3method(fixef,stanmvreg) S3method(fixef,stanreg) S3method(formula,stanmvreg) S3method(formula,stanreg) +S3method(get_surv,stanjm) +S3method(get_surv,stansurv) S3method(get_x,default) S3method(get_x,gamm4) S3method(get_x,lmerMod) @@ -56,12 +58,16 @@ S3method(nsamples,stanreg) S3method(pairs,stanreg) S3method(plot,predict.stanjm) S3method(plot,stanreg) +S3method(plot,stansurv) S3method(plot,survfit.stanjm) +S3method(plot,survfit.stansurv) S3method(posterior_epred,stanreg) S3method(posterior_interval,stanreg) S3method(posterior_linpred,stanreg) S3method(posterior_predict,stanmvreg) S3method(posterior_predict,stanreg) +S3method(posterior_survfit,stanjm) +S3method(posterior_survfit,stansurv) S3method(posterior_vs_prior,stanreg) S3method(pp_check,stanreg) S3method(predict,stanreg) @@ -79,6 +85,7 @@ S3method(print,stanreg_list) S3method(print,summary.stanmvreg) S3method(print,summary.stanreg) S3method(print,survfit.stanjm) +S3method(print,survfit.stansurv) S3method(prior_summary,stanreg) S3method(ranef,stanmvreg) S3method(ranef,stanreg) @@ -115,6 +122,7 @@ export(default_prior_intercept) export(dirichlet) export(exponential) export(fixef) +export(get_surv) export(get_x) export(get_y) export(get_z) @@ -182,10 +190,12 @@ export(stan_mvmer) export(stan_nlmer) export(stan_polr) export(stan_polr.fit) +export(stan_surv) export(stanjm_list) export(stanmvreg_list) export(stanreg_list) export(student_t) +export(tve) export(waic) if(getRversion()>='3.3.0') importFrom(stats, sigma) else importFrom(lme4,sigma) @@ -195,6 +205,7 @@ import(bayesplot) import(methods) import(rstantools) import(shinystan) +import(splines2) import(stats) importFrom(Matrix,Matrix) importFrom(Matrix,t) @@ -272,6 +283,7 @@ importFrom(rstan,stanc) importFrom(rstan,vb) importFrom(rstantools,loo_R2) importFrom(rstantools,nsamples) +importFrom(splines,bs) importFrom(stats,cov2cor) importFrom(stats,getInitial) importFrom(survival,Surv) diff --git a/R/data_block.R b/R/data_block.R index 40eb64895..3c10c5918 100644 --- a/R/data_block.R +++ b/R/data_block.R @@ -78,7 +78,8 @@ handle_glm_prior <- function(prior, nvars, default_scale, link, prior_scale = as.array(rep(1, nvars)), prior_df = as.array(rep(1, nvars)), prior_dist_name = NA, global_prior_scale = 0, global_prior_df = 0, - slab_df = 0, slab_scale = 0, + slab_df = 0, slab_scale = 0, + prior_concentration = as.array(rep(1, nvars)), prior_autoscale = FALSE)) if (!is.list(prior)) @@ -94,6 +95,7 @@ handle_glm_prior <- function(prior, nvars, default_scale, link, global_prior_df <- 0 slab_df <- 0 slab_scale <- 0 + prior_concentration <- 1 if (!prior_dist_name %in% unlist(ok_dists)) { stop("The prior distribution should be one of ", paste(names(ok_dists), collapse = ", ")) @@ -114,6 +116,14 @@ handle_glm_prior <- function(prior, nvars, default_scale, link, slab_scale <- prior$slab_scale } else if (prior_dist_name %in% "exponential") { prior_dist <- 3L # only used for scale parameters so 3 not a conflict with 3 for hs + } else if (prior_dist_name %in% "dirichlet") { + prior_dist <- 4L # only used by stan_surv for baseline hazard coefficients + prior_concentration <- prior$concentration + if (is.null(prior_concentration)) { + prior_concentration <- rep(1, nvars) + } else { + prior_concentration <- maybe_broadcast(prior_concentration, nvars) + } } prior_df <- maybe_broadcast(prior_df, nvars) @@ -121,7 +131,9 @@ handle_glm_prior <- function(prior, nvars, default_scale, link, prior_mean <- maybe_broadcast(prior_mean, nvars) prior_mean <- as.array(prior_mean) prior_scale <- maybe_broadcast(prior_scale, nvars) - + prior_concentration <- maybe_broadcast(prior_concentration, nvars) + prior_concentration <- as.array(prior_concentration) + nlist(prior_dist, prior_mean, prior_scale, @@ -131,5 +143,6 @@ handle_glm_prior <- function(prior, nvars, default_scale, link, global_prior_df, slab_df, slab_scale, + prior_concentration, prior_autoscale = isTRUE(prior$autoscale)) } diff --git a/R/doc-datasets.R b/R/doc-datasets.R index 0db6369ab..14b5a499c 100644 --- a/R/doc-datasets.R +++ b/R/doc-datasets.R @@ -20,7 +20,7 @@ #' Small datasets for use in \pkg{rstanarm} examples and vignettes. #' #' @name rstanarm-datasets -#' @aliases kidiq roaches wells bball1970 bball2006 mortality tumors radon pbcLong pbcSurv +#' @aliases bball1970 bball2006 bcancer frail kidiq mice mortality pbcLong pbcSurv tumors radon roaches wells #' @format #' \describe{ #' \item{\code{bball1970}}{ @@ -50,6 +50,42 @@ #' \item \code{K} Number of at-bats #' } #' } +#' \item{\code{bcancer}}{ +#' The German Breast Cancer Study Group dataset, containing time to death or +#' recurrence for 686 patients with primary node positive breast cancer +#' recruited between 1984-1989. +#' +#' Source: Royston and Parmar (2002) +#' +#' 686 obs. of 4 variables +#' \itemize{ +#' \item \code{recdays} Time to death or censoring (in days) +#' \item \code{recyrs} Time to death or censoring (in years) +#' \item \code{status} Event indicator (0 = right censored, 1 = event) +#' \item \code{group} Prognostic group, based on a regression model developed +#' by Sauerbrei and Royston (1999) (\code{Good}, \code{Medium}, \code{Poor}) +#' } +#' } +#' \item{\code{frail}}{ +#' A simulated dataset of event times (i.e. survival data) for 200 patients +#' clustered within 20 hospital sites (10 patients per hospital site). +#' The event times are simulated from a parametric proportional hazards model +#' under the following assumptions: (i) a constant (i.e. exponential) baseline +#' hazard rate of 0.1; (ii) a fixed treatment effect with log hazard ratio of +#' 0.3; and (iii) a site-specific random intercept (specified on the log +#' hazard scale) drawn from a \eqn{N(0,1)} distribution. +#' +#' 200 obs. of 6 variables +#' \itemize{ +#' \item \code{id} ID unique to each patient +#' \item \code{site} ID unique to each hospital site (i.e. cluster) +#' \item \code{trt} Treatment indicator (0 = untreated, 1 = treated) +#' \item \code{b} Cluster-specific random intercept used to simulate the +#' event times +#' \item \code{eventtime} Event or censoring time +#' \item \code{status} Event indicator (0 = right censored, 1 = event) +#' } +#' } #' \item{\code{kidiq}}{ #' Data from a survey of adult American women and their children #' (a subsample from the National Longitudinal Survey of Youth). @@ -64,6 +100,24 @@ #' \item \code{mom_age} Mother's age #' } #' } +#' \item{\code{mice}}{ +#' Lung tumor development in 144 RFM mice allocated to either a conventional +#' environment or germ-free environment. Mice were sacrificed and examined +#' for presence of a lung tumor. The outcome variables in the dataset +#' (\code{l} and \code{u}) denote a left-censored or right-censored time +#' interval within which the development of the first lung tumor must have +#' occurred. +#' +#' Source: Hoel and Walburg (1972) +#' +#' 144 obs. of 3 variables +#' \itemize{ +#' \item \code{l} Lower limit of the interval. +#' \item \code{u} Upper limit of the interval. +#' \item \code{grp} Experimental group (\code{ce} = conventional environment, +#' \code{ge} = germ-free environment). +#' } +#' } #' \item{\code{mortality}}{ #' Surgical mortality rates in 12 hospitals performing cardiac surgery #' in babies. @@ -86,21 +140,20 @@ #' #' 304 obs. of 8 variables (\code{pbcLong}) and 40 obs. of 7 variables (\code{pbcSurv}) #' \itemize{ -#' \item \code{age} in years -#' \item \code{albumin} serum albumin (g/dl) -#' \item \code{logBili} logarithm of serum bilirubin -#' \item \code{death} indicator of death at endpoint -#' \item \code{futimeYears} time (in years) between baseline and +#' \item \code{age} Age (in years) +#' \item \code{albumin} Serum albumin (g/dl) +#' \item \code{logBili} Logarithm of serum bilirubin +#' \item \code{death} Indicator of death at endpoint +#' \item \code{futimeYears} Time (in years) between baseline and #' the earliest of death, transplantion or censoring -#' \item \code{id} numeric ID unique to each individual -#' \item \code{platelet} platelet count -#' \item \code{sex} gender (m = male, f = female) -#' \item \code{status} status at endpoint (0 = censored, -#' 1 = transplant, 2 = dead) -#' \item \code{trt} binary treatment code (0 = placebo, 1 = -#' D-penicillamine) -#' \item \code{year} time (in years) of the longitudinal measurements, -#' taken as time since baseline) +#' \item \code{id} Numeric ID unique to each individual +#' \item \code{platelet} Platelet count +#' \item \code{sex} Gender (m = male, f = female) +#' \item \code{status} Status at endpoint (0 = censored, 1 = transplant, +#' 2 = dead) +#' \item \code{trt} Binary treatment code (0 = placebo, 1 = D-penicillamine) +#' \item \code{year} Time (in years) of the longitudinal measurements, +#' taken as time since baseline #' } #' } #' @@ -181,7 +234,20 @@ #' @template reference-gelman-hill #' #' @references -#' Spiegelhalter, D., Thomas, A., Best, N., & Gilks, W. (1996) BUGS 0.5 +#' Hoel, D. and Walburg, H. (1972) Statistical analysis of survival experiments. +#' \emph{The Annals of Statistics} \strong{18}:1259--1294. +#' +#' Royston, P. and Parmar, M. (2002) Flexible parametric proportional-hazards +#' and proportional-odds models for censored survival data, with application +#' to prognostic modelling and estimation of treatment effects. +#' \emph{Statistics in Medicine} \strong{21}(1):2175--2197. +#' +#' Sauerbrei, W. and Royston, P. (1999) Building multivariable prognostic and +#' diagnostic models: transformation of the predictors using fractional +#' polynomials. \emph{Journal of the Royal Statistical Society, Series A} +#' \strong{162}:71--94. +#' +#' Spiegelhalter, D., Thomas, A., Best, N., and Gilks, W. (1996) BUGS 0.5 #' Examples. MRC Biostatistics Unit, Institute of Public health, Cambridge, UK. #' #' Tarone, R. E. (1982) The use of historical control information in testing for diff --git a/R/doc-modeling-functions.R b/R/doc-modeling-functions.R index 8a3fa914e..6c2569aed 100644 --- a/R/doc-modeling-functions.R +++ b/R/doc-modeling-functions.R @@ -1,7 +1,7 @@ #' Modeling functions available in \pkg{rstanarm} -#' +#' #' @name available-models -#' +#' #' @section Modeling functions: #' The model estimating functions are described in greater detail in their #' individual help pages and vignettes. Here we provide a very brief @@ -33,9 +33,16 @@ #' appropriate estimates of uncertainty for models that consist of a mix of #' common and group-specific parameters. #' } +#' \item{\code{\link{stan_mvmer}}}{ +#' A multivariate form of \code{\link{stan_glmer}}, whereby the user can +#' specify one or more submodels each consisting of a GLM with group-specific +#' terms. If more than one submodel is specified (i.e. there is more than one +#' outcome variable) then a dependence is induced by assuming that the +#' group-specific terms for each grouping factor are correlated across submodels. +#' } #' \item{\code{\link{stan_nlmer}}}{ -#' Similar to \code{\link[lme4]{nlmer}} in the \pkg{lme4} package for -#' nonlinear "mixed-effects" models, but the group-specific coefficients +#' Similar to \code{\link[lme4]{nlmer}} in the \pkg{lme4} package for +#' nonlinear "mixed-effects" models, but the group-specific coefficients #' have flexible priors on their unknown covariance matrices. #' } #' \item{\code{\link{stan_gamm4}}}{ @@ -68,19 +75,22 @@ #' to \code{\link[survival]{clogit}} that allow \code{stan_clogit} to accept #' group-specific terms as in \code{\link{stan_glmer}}. #' } -#' \item{\code{\link{stan_mvmer}}}{ -#' A multivariate form of \code{\link{stan_glmer}}, whereby the user can -#' specify one or more submodels each consisting of a GLM with group-specific -#' terms. If more than one submodel is specified (i.e. there is more than one -#' outcome variable) then a dependence is induced by assuming that the -#' group-specific terms for each grouping factor are correlated across submodels. +#' \item{\code{\link{stan_surv}}}{ +#' Fits models to survival (i.e. time-to-event) data using either a hazard +#' scale or accelerated failure time (AFT) formulation. The user can choose +#' between either a flexible spline-based approximation for the baseline +#' hazard or a standard parametric distributional assumption for the baseline +#' hazard. Covariate effects can then be accommodated under proportional or +#' non-proportional hazards assumptions (for models on the hazard scale) or +#' using time-fixed or time-varying acceleration factors (for models on the +#' AFT scale). #' } #' \item{\code{\link{stan_jm}}}{ -#' Estimates shared parameter joint models for longitudinal and time-to-event -#' (i.e. survival) data. The joint model can be univariate (i.e. one longitudinal -#' outcome) or multivariate (i.e. more than one longitudinal outcome). A variety -#' of parameterisations are available for linking the longitudinal and event -#' processes (i.e. a variety of association structures). +#' Fits shared parameter joint models for longitudinal and survival (i.e. +#' time-to-event) data. The joint model can be univariate (i.e. one longitudinal +#' outcome) or multivariate (i.e. more than one longitudinal outcome). A variety +#' of parameterisations are available for linking the longitudinal and event +#' processes (i.e. a variety of association structures). #' } #' } #' diff --git a/R/jm_data_block.R b/R/jm_data_block.R index 6dbe0b0f7..4f87c4255 100644 --- a/R/jm_data_block.R +++ b/R/jm_data_block.R @@ -171,7 +171,7 @@ reformulate_rhs <- function(x, subbars = FALSE) { handle_cov_prior <- function(prior, cnms, ok_dists = nlist("decov", "lkj")) { if (!is.list(prior)) stop(sQuote(deparse(substitute(prior))), " should be a named list") - t <- length(unique(cnms)) # num grouping factors + t <- length(cnms) # num grouping factors p <- sapply(cnms, length) # num terms for each grouping factor prior_dist_name <- prior$dist if (!prior_dist_name %in% unlist(ok_dists)) { @@ -468,7 +468,13 @@ summarize_jm_prior <- adjusted_priorEvent_aux_scale else NULL, df = if (!is.na(prior_dist_name) && prior_dist_name %in% "student_t") - prior_df else NULL, + prior_df else NULL, + concentration = if (!is.na(prior_dist_name) && + prior_dist_name %in% "dirichlet") + prior_concentration else NULL, + rate = if (!is.na(prior_dist_name) && + prior_dist_name %in% "exponential") + 1 / prior_scale else NULL, aux_name = e_aux_name )) } @@ -524,9 +530,14 @@ summarize_jm_prior <- # @param basehaz A list with information about the baseline hazard .rename_e_aux <- function(basehaz) { nm <- basehaz$type_name - if (nm == "weibull") "weibull-shape" else - if (nm == "bs") "spline-coefficients" else - if (nm == "piecewise") "piecewise-coefficients" else NA + switch(nm, + "weibull" = "weibull-shape", + "weibull-aft" = "weibull-shape", + "gompertz" = "gompertz-scale", + "bs" = "B-spline-coefficients", + "ms" = "M-spline-coefficients", + "piecewise" = "piecewise-coefficients", + NA) } # Check if priors were autoscaled @@ -1034,7 +1045,7 @@ validate_observation_times <-function(data, eventtimes, id_var, time_var) { # model_frame: The model frame for the fitted Cox model, but with the # subject ID variable also included. # tvc: Logical, if TRUE then a counting type Surv() object was used -# in the fitted Cox model (ie. time varying covariates). +# in the fitted Cox model (ie. time-varying covariates). handle_e_mod <- function(formula, data, qnodes, id_var, y_id_list) { if (!requireNamespace("survival")) stop("the 'survival' package must be installed to use this function") @@ -1259,6 +1270,20 @@ make_basehaz_X <- function(times, basehaz) { X } +# Create a dummy indicator matrix for time intervals defined by 'knots' +# +# @param x A numeric vector with the original data. +# @param knots The cutpoints defining the desired categories of 'x'. +# @return A dummy matrix. +dummy_matrix <- function(x, knots) { + n_intervals <- length(knots) - 1 + interval <- cut(x, knots, include.lowest = TRUE, labels = FALSE) + out <- matrix(NA, length(interval), n_intervals) + for (i in 1:n_intervals) + out[, i] <- ifelse(interval == i, 1, 0) + as.matrix(out) +} + # Function to return standardised GK quadrature points and weights # # @param nodes The required number of quadrature nodes diff --git a/R/jm_make_assoc_parts.R b/R/jm_make_assoc_parts.R index 60c5ff425..e2996fb78 100644 --- a/R/jm_make_assoc_parts.R +++ b/R/jm_make_assoc_parts.R @@ -79,7 +79,7 @@ make_assoc_parts <- function(use_function = make_assoc_parts_for_stan, # observation time preceeding the quadrature point are carried forward to # represent the covariate value(s) at the quadrature point. (To avoid # missingness there is no limit on how far forwards or how far backwards - # covariate values can be carried). If no time varying covariates are + # covariate values can be carried). If no time-varying covariates are # present in the longitudinal submodel (other than the time variable) # then nothing is carried forward or backward. dataQ <- rolling_merge(data = newdata, ids = ids, times = times, grps = grps) diff --git a/R/log_lik.R b/R/log_lik.R index 7d9815d5f..960e902d6 100644 --- a/R/log_lik.R +++ b/R/log_lik.R @@ -73,22 +73,37 @@ log_lik.stanreg <- function(object, newdata = NULL, offset = NULL, ...) { newdata <- validate_newdata(object, newdata, m = NULL) calling_fun <- as.character(sys.call(-1))[1] dots <- list(...) + if (is.stanmvreg(object)) { - m <- dots[["m"]] - if (is.null(m)) - STOP_arg_required_for_stanmvreg(m) - if (!is.null(offset)) - stop2("'offset' cannot be specified for stanmvreg objects.") + m <- dots[["m"]]; if (is.null(m)) STOP_arg_required_for_stanmvreg(m) } else { m <- NULL } - + newdata <- validate_newdata(object, newdata = newdata, m = m) - args <- ll_args.stanreg(object, newdata = newdata, offset = offset, - reloo_or_kfold = calling_fun %in% c("kfold", "reloo"), - ...) + if (is.stansurv(object)) { + args <- ll_args.stansurv(object, newdata = newdata, ...) + } else { + args <- ll_args.stanreg(object, newdata = newdata, offset = offset, + reloo_or_kfold = calling_fun %in% c("kfold", "reloo"), + ...) + } + fun <- ll_fun(object, m = m) - if (is_clogit(object)) { + if (is.stansurv(object)) { + out <- + vapply( + seq_len(args$N), + FUN.VALUE = numeric(length = args$S), + FUN = function(i) { + as.vector(fun( + draws = args$draws, + data_i = args$data[args$data$cids == + unique(args$data$cids)[i], , drop = FALSE] + )) + } + ) + } else if (is_clogit(object)) { out <- vapply( seq_len(args$N), @@ -173,7 +188,9 @@ log_lik.stanjm <- function(object, newdataLong = NULL, newdataEvent = NULL, ...) ll_fun <- function(x, m = NULL) { validate_stanreg_object(x) f <- family(x, m = m) - if (!is(f, "family") || is_scobit(x)) + if (is.stansurv(x)) { + return(.ll_surv_i) + } else if (!is(f, "family") || is_scobit(x)) return(.ll_polr_i) else if (is_clogit(x)) return(.ll_clogit_i) @@ -201,6 +218,8 @@ ll_fun <- function(x, m = NULL) { # @return a named list with elements data, draws, S (posterior sample size) and # N = number of observations ll_args <- function(object, ...) UseMethod("ll_args") + +#--- ll_args for stanreg models ll_args.stanreg <- function(object, newdata = NULL, offset = NULL, m = NULL, reloo_or_kfold = FALSE, ...) { validate_stanreg_object(object) @@ -372,6 +391,90 @@ ll_args.stanreg <- function(object, newdata = NULL, offset = NULL, m = NULL, return(out) } +#--- ll_args for stansurv models +ll_args.stansurv <- function(object, newdata = NULL, ...) { + + validate_stansurv_object(object) + + if (is.null(newdata)) { + newdata <- get_model_data(object) + } + newdata <- as.data.frame(newdata) + + # response, ie. a Surv object + form <- as.formula(formula(object)) + y <- eval(form[[2L]], newdata) + + # outcome, ie. time variables and status indicator + t_beg <- make_t(y, type = "beg") # entry time + t_end <- make_t(y, type = "end") # exit time + t_upp <- make_t(y, type = "upp") # upper time for interval censoring + status <- make_d(y) + if (any(status < 0 | status > 3)) + stop2("Invalid status indicator in Surv object.") + + # delayed entry indicator for each row of data + delayed <- as.logical(!t_beg == 0) + + # we reconstruct the design matrices even if no newdata, since it is + # too much of a pain to store everything in the fitted model object + # (e.g. w/ delayed entry, interval censoring, quadrature points, etc) + pp <- pp_data(object, newdata, times = t_end) + + # returned object depends on quadrature + if (object$has_quadrature) { + pp_qpts_beg <- pp_data(object, newdata, times = t_beg, at_quadpoints = TRUE) + pp_qpts_end <- pp_data(object, newdata, times = t_end, at_quadpoints = TRUE) + pp_qpts_upp <- pp_data(object, newdata, times = t_upp, at_quadpoints = TRUE) + cpts <- c(pp$pts, pp_qpts_beg$pts, pp_qpts_end$pts, pp_qpts_upp$pts) + cwts <- c(pp$wts, pp_qpts_beg$wts, pp_qpts_end$wts, pp_qpts_upp$wts) + cids <- c(pp$ids, pp_qpts_beg$ids, pp_qpts_end$ids, pp_qpts_upp$ids) + x <- rbind(pp$x, pp_qpts_beg$x, pp_qpts_end$x, pp_qpts_upp$x) + s <- rbind(pp$s, pp_qpts_beg$s, pp_qpts_end$s, pp_qpts_upp$s) + x <- append_prefix_to_colnames(x, "x__") + s <- append_prefix_to_colnames(s, "s__") + status <- c(status, rep(NA, length(cids) - length(status))) + delayed <- c(delayed, rep(NA, length(cids) - length(delayed))) + data <- data.frame(cpts, cwts, cids, status, delayed) + data <- cbind(data, x, s) + } else { + x <- append_prefix_to_colnames(pp$x, "x__") + cids <- seq_along(t_end) + data <- data.frame(cids, t_beg, t_end, t_upp, status, delayed) + data <- cbind(data, x) + } + + # also evaluate random effects structure if relevant + if (object$has_bars) { + z <- t(pp$z$Zt) + if (object$has_quadrature) { + z <- rbind(z, + t(pp_qpts_beg$z$Zt), + t(pp_qpts_end$z$Zt), + t(pp_qpts_upp$z$Zt)) + } + z <- append_prefix_to_colnames(as.matrix(z), "z__") + data <- cbind(data, z) + } + + # parameter draws + draws <- list() + pars <- extract_pars(object) + draws$basehaz <- get_basehaz (object) + draws$aux <- pars$aux + draws$alpha <- pars$alpha + draws$beta <- pars$beta + draws$beta_tve <- pars$beta_tve + draws$b <- if (object$has_bars) pp_b_ord(pars$b, pp$z$Z_names) else NULL + draws$has_quadrature <- pp$has_quadrature + draws$has_tve <- pp$has_tve + draws$has_bars <- pp$has_bars + draws$qnodes <- pp$qnodes + + out <- nlist(data, draws, S = NROW(draws$beta), N = n_distinct(cids)) + return(out) +} + # check intercept for polr models ----------------------------------------- # Check if a model fit with stan_polr has an intercept (i.e. if it's actually a @@ -423,6 +526,23 @@ ll_args.stanreg <- function(object, newdata = NULL, offset = NULL, m = NULL, draws$f_phi$linkinv(eta) } +# for stan_surv only +.xdata_surv <- function(data) { + nms <- colnames(data) + sel <- grep("^x__", nms) + data[, sel] +} +.sdata_surv <- function(data) { + nms <- colnames(data) + sel <- grep("^s__", nms) + data[, sel] +} +.zdata_surv <- function(data) { + nms <- colnames(data) + sel <- grep("^z__", nms) + data[, sel] +} + # log-likelihood functions ------------------------------------------------ .ll_gaussian_i <- function(data_i, draws) { val <- dnorm(data_i$y, mean = .mu(data_i, draws), sd = draws$sigma, log = TRUE) @@ -499,6 +619,139 @@ ll_args.stanreg <- function(object, newdata = NULL, offset = NULL, m = NULL, .weighted(val, data_i$weights) } +.ll_surv_i <- function(data_i, draws) { + + # fixed effects (time-fixed) part of linear predictor + eta <- linear_predictor(draws$beta, .xdata_surv(data_i)) + + # fixed effects (time-varying) part of linear predictor + if (draws$has_tve) { + eta <- eta + linear_predictor(draws$beta_tve, .sdata_surv(data_i)) + } + + # random effects part of linear predictor + if (draws$has_bars) { + eta <- eta + linear_predictor(draws$b, .zdata_surv(data_i)) + } + + # convert linear predictor to log acceleration factor for AFT + eta <- switch(get_basehaz_name(draws$basehaz), + "exp-aft" = sweep(eta, 1L, -1, `*`), + "weibull-aft" = sweep(eta, 1L, -as.vector(draws$aux), `*`), + eta) + + if (draws$has_quadrature) { + + qnodes <- draws$qnodes + status <- data_i[1L, "status"] + delayed <- data_i[1L, "delayed"] + + # row indexing of quadrature points in data_i + idx_epts <- 1 + idx_qpts_beg <- 1 + (qnodes * 0) + (1:qnodes) + idx_qpts_end <- 1 + (qnodes * 1) + (1:qnodes) + idx_qpts_upp <- 1 + (qnodes * 2) + (1:qnodes) + + # arguments to be used later in evaluating log baseline hazard + args <- list(times = data_i$cpts, + basehaz = draws$basehaz, + aux = draws$aux, + intercept = draws$alpha) + + # evaluate log hazard + lhaz <- eta + do.call(evaluate_log_basehaz, args) + + # evaluate log likelihood + if (status == 1) { + # uncensored + lhaz_epts <- lhaz[, idx_epts, drop = FALSE] + lhaz_qpts_end <- lhaz[, idx_qpts_end, drop = FALSE] + lsurv <- -quadrature_sum(exp(lhaz_qpts_end), + qnodes = qnodes, + qwts = data_i$cwts[idx_qpts_end]) + ll <- lhaz_epts + lsurv + } else if (status == 0) { + # right censored + lhaz_qpts_end <- lhaz[, idx_qpts_end, drop = FALSE] + lsurv <- -quadrature_sum(exp(lhaz_qpts_end), + qnodes = qnodes, + qwts = data_i$cwts[idx_qpts_end]) + ll <- lsurv + } else if (status == 2) { + # left censored + lhaz_qpts_end <- lhaz[, idx_qpts_end, drop = FALSE] + lsurv <- -quadrature_sum(exp(lhaz_qpts_end), + qnodes = qnodes, + qwts = data_i$cwts[idx_qpts_end]) + ll <- log(1 - exp(lsurv)) # = log CDF + } else if (status == 3) { + # interval censored + lhaz_qpts_end <- lhaz[, idx_qpts_end, drop = FALSE] + lsurv_lower <- -quadrature_sum(exp(lhaz_qpts_end), + qnodes = qnodes, + qwts = data_i$cwts[idx_qpts_end]) + lhaz_qpts_upp <- lhaz[, idx_qpts_upp, drop = FALSE] + lsurv_upper <- -quadrature_sum(exp(lhaz_qpts_upp), + qnodes = qnodes, + qwts = data_i$cwts[idx_qpts_upp]) + ll <- log(exp(lsurv_lower) - exp(lsurv_upper)) + } + if (delayed) { + # delayed entry + lhaz_qpts_beg <- lhaz[, idx_qpts_beg, drop = FALSE] + lsurv_beg <- -quadrature_sum(exp(lhaz_qpts_beg), + qnodes = qnodes, + qwts = data_i$cwts[idx_qpts_beg]) + ll <- ll - lsurv_beg + } + + } else { # no quadrature + + status <- data_i$status + delayed <- data_i$delayed + + # arguments to be used later in evaluating log baseline hazard + args <- list(basehaz = draws$basehaz, + aux = draws$aux, + intercept = draws$alpha) + + # evaluate log likelihood + if (status == 1) { + # uncensored + args$times <- data_i$t_end + lhaz <- do.call(evaluate_log_basehaz, args) + eta + lsurv <- do.call(evaluate_log_basesurv, args) * exp(eta) + ll <- lhaz + lsurv + } else if (status == 0) { + # right censored + args$times <- data_i$t_end + lsurv <- do.call(evaluate_log_basesurv, args) * exp(eta) + ll <- lsurv + } else if (status == 2) { + # left censored + args$times <- data_i$t_end + lsurv <- do.call(evaluate_log_basesurv, args) * exp(eta) + ll <- log(1 - exp(lsurv)) # = log CDF + } else if (status == 3) { + # interval censored + args$times <- data_i$t_end + lsurv_lower <- do.call(evaluate_log_basesurv, args) * exp(eta) + args$times <- data_i$t_upp + lsurv_upper <- do.call(evaluate_log_basesurv, args) * exp(eta) + ll <- log(exp(lsurv_lower) - exp(lsurv_upper)) + } + if (delayed) { + # delayed entry + args$times <- data_i$t_beg + lsurv_beg <- do.call(evaluate_log_basesurv, args) * exp(eta) + ll <- ll - lsurv_beg + } + + } + return(ll) +} + + # log-likelihood functions for stanjm objects only ---------------------- # Alternative ll_args method for stanjm objects that allows data and pars to be @@ -780,9 +1033,9 @@ ll_args.stanjm <- function(object, data, pars, m = 1, } # Log baseline hazard at etimes (if not NULL) and qtimes - log_basehaz <- evaluate_log_basehaz(times = times, - basehaz = basehaz, - coefs = pars$bhcoef) + log_basehaz <- evaluate_log_basehaz2(times = times, + basehaz = basehaz, + coefs = pars$bhcoef) # Log hazard at etimes (if not NULL) and qtimes log_haz <- log_basehaz + e_eta @@ -832,7 +1085,7 @@ ll_args.stanjm <- function(object, data, pars, m = 1, # @param basehaz A list with info about the baseline hazard. # @param coefs A vector or matrix of parameter estimates (MCMC draws). # @return A vector or matrix, depending on the input type of coefs. -evaluate_log_basehaz <- function(times, basehaz, coefs) { +evaluate_log_basehaz2 <- function(times, basehaz, coefs) { type <- basehaz$type_name if (type == "weibull") { X <- log(times) # log times @@ -889,3 +1142,187 @@ evaluate_log_survival.matrix <- function(log_haz, qnodes, qwts) { # return: -cum_haz == log survival probability -cum_haz } + +#------------- + +# Evaluate the log baseline hazard at the specified times given the +# vector or matrix of MCMC draws for the baseline hazard parameters +# +# @param times A vector of times. +# @param basehaz A list with info about the baseline hazard. +# @param aux,intercept A vector or matrix of parameter estimates (MCMC draws). +# @param x Predictor matrix. +# @param s Predictor matrix for time-varying effects. +# @return A vector or matrix, depending on the input type of aux. +evaluate_log_basehaz <- function(times, basehaz, aux, intercept = NULL) { + switch(get_basehaz_name(basehaz), + "exp" = log_basehaz_exponential (times, log_scale = intercept), + "exp-aft" = log_basehaz_exponentialAFT(times, log_scale = intercept), + "weibull" = log_basehaz_weibull (times, shape = aux, log_scale = intercept), + "weibull-aft" = log_basehaz_weibullAFT(times, shape = aux, log_scale = intercept), + "gompertz" = log_basehaz_gompertz(times, scale = aux, log_shape = intercept), + "ms" = log_basehaz_ms(times, coefs = aux, basis = basehaz$basis, intercept = intercept), + "bs" = log_basehaz_bs(times, coefs = aux, basis = basehaz$basis, intercept = intercept), + "piecewise" = log_basehaz_pw(times, coefs = aux, knots = basehaz$knots), + stop2("Bug found: unknown type of baseline hazard.")) +} + +log_basehaz_exponential <- function(x, log_scale) { + linear_predictor(log_scale, rep(1, length(x))) +} +log_basehaz_exponentialAFT <- function(x, log_scale) { + linear_predictor(-log_scale, rep(1, length(x))) +} +log_basehaz_weibull <- function(x, shape, log_scale) { + as.vector(log_scale + log(shape)) + linear_predictor(shape - 1, log(x)) +} +log_basehaz_weibullAFT <- function(x, shape, log_scale) { + as.vector(-log_scale * shape + log(shape)) + linear_predictor(shape - 1, log(x)) +} +log_basehaz_gompertz <- function(x, scale, log_shape) { + as.vector(log_shape) + linear_predictor(scale, x) +} +log_basehaz_ms <- function(x, coefs, basis, intercept) { + as.vector(intercept) + log(linear_predictor(coefs, basis_matrix(x, basis = basis))) +} +log_basehaz_bs <- function(x, coefs, basis, intercept) { + as.vector(intercept) + linear_predictor(coefs, basis_matrix(x, basis = basis)) +} +log_basehaz_pw <- function(x, coefs, knots) { + linear_predictor(coefs, dummy_matrix(x, knots = knots)) +} + +evaluate_log_haz <- function(times, basehaz, betas, betas_tve, b = NULL, aux, + intercept = NULL, x, s = NULL, z = NULL) { + eta <- linear_predictor(betas, x) + if ((!is.null(s)) && ncol(s)) + eta <- eta + linear_predictor(betas_tve, s) + if (!is.null(z$Zt) && ncol(z$Zt)) { + b <- pp_b_ord(b, z$Z_names) + z <- as.matrix(t(z$Zt)) + eta <- eta + linear_predictor(b, z) + } + eta <- switch(get_basehaz_name(basehaz), + "exp-aft" = sweep(eta, 1L, -1, `*`), + "weibull-aft" = sweep(eta, 1L, -as.vector(aux), `*`), + eta) + args <- nlist(times, basehaz, aux, intercept) + do.call(evaluate_log_basehaz, args) + eta +} + +evaluate_basehaz <- function(times, basehaz, aux, intercept = NULL) { + exp(evaluate_log_basehaz(times = times, basehaz = basehaz, + aux = aux, intercept = intercept)) +} + +#------------- + +# Evaluate the log baseline survival at the specified times given the +# vector or matrix of MCMC draws for the baseline hazard parameters +# +# @param times A vector of times. +# @param basehaz A list with info about the baseline hazard. +# @param aux,intercept A vector or matrix of parameter estimates (MCMC draws). +# @return A vector or matrix, depending on the input type of aux. +evaluate_log_basesurv <- function(times, basehaz, aux, intercept = NULL) { + switch(get_basehaz_name(basehaz), + "exp" = log_basesurv_exponential (times, log_scale = intercept), + "exp-aft" = log_basesurv_exponentialAFT(times, log_scale = intercept), + "weibull" = log_basesurv_weibull (times, shape = aux, log_scale = intercept), + "weibull-aft" = log_basesurv_weibullAFT(times, shape = aux, log_scale = intercept), + "gompertz" = log_basesurv_gompertz(times, scale = aux, log_shape = intercept), + "ms" = log_basesurv_ms(times, coefs = aux, basis = basehaz$basis, intercept = intercept), + stop2("Bug found: unknown type of baseline hazard.")) +} + +log_basesurv_exponential <- function(x, log_scale) { + -linear_predictor(exp(log_scale), x) +} +log_basesurv_exponentialAFT <- function(x, log_scale) { + -linear_predictor(exp(-log_scale), x) +} +log_basesurv_weibull <- function(x, shape, log_scale) { + -exp(as.vector(log_scale) + linear_predictor(shape, log(x))) +} +log_basesurv_weibullAFT <- function(x, shape, log_scale) { + -exp(as.vector(-shape * log_scale) + linear_predictor(shape, log(x))) +} +log_basesurv_gompertz <- function(x, scale, log_shape) { + -(as.vector(exp(log_shape) / scale)) * (exp(linear_predictor(scale, x)) - 1) +} +log_basesurv_ms <- function(x, coefs, basis, intercept) { + - exp(as.vector(intercept)) * + linear_predictor(coefs, basis_matrix(x, basis = basis, integrate = TRUE)) +} + +evaluate_log_surv <- function(times, basehaz, betas, b = NULL, aux, + intercept = NULL, x, z = NULL, ...) { + eta <- linear_predictor(betas, x) + if (!is.null(z$Zt) && ncol(z$Zt)) { + b <- pp_b_ord(b, z$Z_names) + z <- as.matrix(t(z$Zt)) + eta <- eta + linear_predictor(b, z) + } + eta <- switch(get_basehaz_name(basehaz), + "exp-aft" = sweep(eta, 1L, -1, `*`), + "weibull-aft" = sweep(eta, 1L, -as.vector(aux), `*`), + eta) + args <- nlist(times, basehaz, aux, intercept) + do.call(evaluate_log_basesurv, args) * exp(eta) +} + +#--------------- + +quadrature_sum <- function(x, qnodes, qwts) { + UseMethod("quadrature_sum") +} + +quadrature_sum.default <- function(x, qnodes, qwts) { + weighted_x <- qwts * x # apply quadrature weights + splitted_x <- split_vector(x, n_segments = qnodes) # split at each quad node + Reduce('+', splitted_x) # sum over the quad nodes +} + +quadrature_sum.matrix <- function(x, qnodes, qwts) { + weighted_x <- sweep_multiply(x, qwts, margin = 2L) # apply quadrature weights + splitted_x <- array2list(weighted_x, nsplits = qnodes) # split at each quad node + Reduce('+', splitted_x) # sum over the quad nodes +} + +# Split a vector or matrix into a specified number of segments and return +# each segment as an element of a list. The matrix method allows splitting +# across the column (bycol = TRUE) or row margin (bycol = FALSE). +# +# @param x A vector or matrix. +# @param n_segments Integer specifying the number of segments. +# @param bycol Logical, should a matrix be split along the column or row margin? +# @return A list with n_segments elements. +split2 <- function(x, n_segments = 1, ...) { + UseMethod("split2") +} + +split2.vector <- function(x, n_segments = 1, ...) { + len <- length(x) + segment_length <- len %/% n_segments + if (!len == (segment_length * n_segments)) + stop("Dividing x by n_segments does not result in an integer.") + split(x, rep(1:n_segments, each = segment_length)) +} + +split2.matrix <- function(x, n_segments = 1, bycol = TRUE) { + len <- if (bycol) ncol(x) else nrow(x) + segment_length <- len %/% n_segments + if (!len == (segment_length * n_segments)) + stop("Dividing x by n_segments does not result in an integer.") + lapply(1:n_segments, function(k) { + if (bycol) x[, (k-1) * segment_length + 1:segment_length, drop = FALSE] else + x[(k-1) * segment_length + 1:segment_length, , drop = FALSE]}) +} + +# Split a vector or matrix into a specified number of segments +# (see rstanarm:::split2) and then reduce it using 'FUN' +split_and_reduce <- function(x, n_segments = 1, bycol = TRUE, FUN = '+') { + splitted_x <- split2(x, n_segments = n_segments, bycol = bycol) + Reduce(FUN, splitted_x) +} + diff --git a/R/loo.R b/R/loo.R index e54be5538..5bdd52729 100644 --- a/R/loo.R +++ b/R/loo.R @@ -233,7 +233,7 @@ loo.stanreg <- )) } else if (is_clogit(x)) { ll <- log_lik.stanreg(x) - cons <- apply(ll,MARGIN = 2, FUN = function(y) sd(y) < 1e-15) + cons <- apply(ll, MARGIN = 2, FUN = function(y) sd(y) < 1e-15) if (any(cons)) { message( "The following strata were dropped from the ", @@ -250,6 +250,16 @@ loo.stanreg <- cores = cores, save_psis = save_psis )) + } else if (is.stansurv(x) && x$has_quadrature) { + ll <- log_lik.stanreg(x) + r_eff <- loo::relative_eff(exp(ll), chain_id = chain_id, cores = cores) + loo_x <- + suppressWarnings(loo.matrix( + ll, + r_eff = r_eff, + cores = cores, + save_psis = save_psis + )) } else { args <- ll_args(x) llfun <- ll_fun(x) @@ -345,6 +355,8 @@ waic.stanreg <- function(x, ...) { out <- waic.matrix(ll) } else if (is_clogit(x)) { out <- waic.matrix(log_lik(x)) + } else if (is.stansurv(x) && x$has_quadrature) { + out <- waic.matrix(log_lik(x)) } else { args <- ll_args(x) out <- waic.function(ll_fun(x), data = args$data, draws = args$draws) @@ -627,7 +639,7 @@ reloo <- function(x, loo_x, obs, ..., refit = TRUE) { open_progress = FALSE ) } - fit_j_call$subset <- eval(fit_j_call$subset) + fit_j_call$subset <- if (!is.stansurv(x)) eval(fit_j_call$subset) else NULL fit_j_call$data <- eval(fit_j_call$data) if (!is.null(getCall(x)$offset)) { fit_j_call$offset <- x$offset[-omitted] @@ -757,6 +769,8 @@ hash_y <- function(x, ...) { is_discrete <- function(object) { if (inherits(object, "polr")) return(TRUE) + if (inherits(object, "stansurv")) + return(FALSE) if (inherits(object, "stanmvreg")) { fams <- fetch(family(object), "family") res <- sapply(fams, function(x) diff --git a/R/misc.R b/R/misc.R index 9d59d3fb3..5e61fb7f5 100644 --- a/R/misc.R +++ b/R/misc.R @@ -101,10 +101,20 @@ default_stan_control <- function(prior, adapt_delta = NULL, nlist(adapt_delta, max_treedepth) } -# Test if an object is a stanreg object +# Test if an object inherits a specific stanreg class # # @param x The object to test. -is.stanreg <- function(x) inherits(x, "stanreg") +is.stanreg <- function(x) inherits(x, "stanreg") +is.stansurv <- function(x) inherits(x, "stansurv") +is.stanmvreg <- function(x) inherits(x, "stanmvreg") +is.stanjm <- function(x) inherits(x, "stanjm") + +# Test if object contains a specific type of submodel +# +# @param x The object to test. +is.jm <- function(x) isTRUE(x$stan_function %in% c("stan_jm")) +is.mvmer <- function(x) isTRUE(x$stan_function %in% c("stan_jm", "stan_mvmer")) +is.surv <- function(x) isTRUE(x$stan_function %in% c("stan_jm", "stan_surv")) # Throw error if object isn't a stanreg object # @@ -114,16 +124,40 @@ validate_stanreg_object <- function(x, call. = FALSE) { stop("Object is not a stanreg object.", call. = call.) } +# Throw error if object isn't a stanmvreg object +# +# @param x The object to test. +validate_stanmvreg_object <- function(x, call. = FALSE) { + if (!is.stanmvreg(x)) + stop("Object is not a stanmvreg object.", call. = call.) +} + +# Throw error if object isn't a stanjm object +# +# @param x The object to test. +validate_stanjm_object <- function(x, call. = FALSE) { + if (!is.stanjm(x)) + stop("Object is not a stanjm object.", call. = call.) +} + +# Throw error if object isn't a stansurv object +# +# @param x The object to test. +validate_stansurv_object <- function(x, call. = FALSE) { + if (!is.stansurv(x)) + stop("Object is not a stansurv object.", call. = call.) +} + # Test for a given family # # @param x A character vector (probably x = family(fit)$family) is.binomial <- function(x) x == "binomial" is.gaussian <- function(x) x == "gaussian" -is.gamma <- function(x) x == "Gamma" -is.ig <- function(x) x == "inverse.gaussian" -is.nb <- function(x) x == "neg_binomial_2" -is.poisson <- function(x) x == "poisson" -is.beta <- function(x) x == "beta" || x == "Beta regression" +is.gamma <- function(x) x == "Gamma" +is.ig <- function(x) x == "inverse.gaussian" +is.nb <- function(x) x == "neg_binomial_2" +is.poisson <- function(x) x == "poisson" +is.beta <- function(x) x == "beta" || x == "Beta regression" # test if a stanreg object has class clogit is_clogit <- function(object) { @@ -161,9 +195,12 @@ used.variational <- function(x) { # @param x A stanreg object. is.mer <- function(x) { stopifnot(is.stanreg(x)) + check0 <- is.stansurv(x) && x$has_bars check1 <- inherits(x, "lmerMod") check2 <- !is.null(x$glmod) - if (check1 && !check2) { + if (check0) { + return(TRUE) + } else if (check1 && !check2) { stop("Bug found. 'x' has class 'lmerMod' but no 'glmod' component.") } else if (!check1 && check2) { stop("Bug found. 'x' has 'glmod' component but not class 'lmerMod'.") @@ -362,17 +399,18 @@ has_outcome_variable <- function(f) { # Check if any variables in a model frame are constants # # exceptions: constant variable of all 1's is allowed and outcomes with all 0s -# or 1s are allowed (e.g., for binomial models) +# or 1s are allowed (e.g., for binomial models) and survival outcomes (e.g., +# incase all event indicators are 1s) # # @param mf A model frame or model matrix # @return If no constant variables are found mf is returned, otherwise an error # is thrown. check_constant_vars <- function(mf) { mf1 <- mf - if (NCOL(mf[, 1]) == 2 || all(mf[, 1] %in% c(0, 1))) { + if (NCOL(mf[, 1]) == 2 || all(mf[, 1] %in% c(0, 1)) || survival::is.Surv(mf[, 1])) { mf1 <- mf[, -1, drop=FALSE] } - + lu1 <- function(x) !all(x == 1) && length(unique(x)) == 1 nocheck <- c("(weights)", "(offset)", "(Intercept)") sel <- !colnames(mf1) %in% nocheck @@ -625,6 +663,23 @@ get_z.stanmvreg <- function(object, m = NULL, ...) { if (!is.null(m)) ret[[m]] else list_nms(ret, stub = stub) } +#' Extract survival response from a stansurv or stanjm object +#' +#' @keywords internal +#' @export +#' @param object A \code{stansurv} or \code{stanjm} object. +#' @param ... Other arguments passed to methods. +#' @return A \code{Surv} object, see \code{?survival::Surv}. +get_surv <- function(object, ...) UseMethod("get_surv") +#' @export +get_surv.stansurv <- function(object, ...) { + model.response(model.frame(object)) %ORifNULL% stop("response not found") +} +#' @export +get_surv.stanjm <- function(object, ...) { + object$survmod$mod$y %ORifNULL% stop("response not found") +} + # Get inverse link function # # @param x A stanreg object, family object, or string. @@ -902,22 +957,37 @@ array2list <- function(x, nsplits, bycol = TRUE) { x[(k-1) * len_k + 1:len_k, , drop = FALSE]}) } +# Use sweep to multiply a vector or array. Note that usually sweep cannot +# handle a vector, whereas this function definition can. +# +# @param x A vector or array. +# @param y The vector or scalar to multiply 'x' by. +# @param margin The margin of 'x' across which to apply 'y' (only relevant +# if 'x' is an array, i.e. not a vector). +# @return An object of the same class as 'x'. +sweep_multiply <- function(x, y, margin = 2L) { + if (is.vector(x)) return(x * y) + sweep(x, margin, y, `*`) +} + # Convert a standardised quadrature node to an unstandardised value based on # the specified integral limits # # @param x An unstandardised quadrature node # @param a The lower limit(s) of the integral, possibly a vector # @param b The upper limit(s) of the integral, possibly a vector -unstandardise_qpts <- function(x, a, b) { +unstandardise_qpts <- function(x, a, b, na.ok = TRUE) { if (!identical(length(x), 1L) || !is.numeric(x)) - stop("'x' should be a single numeric value.", call. = FALSE) - if (!all(is.numeric(a), is.numeric(b))) - stop("'a' and 'b' should be numeric.", call. = FALSE) + stop2("'x' should be a single numeric value.") if (!length(a) %in% c(1L, length(b))) - stop("'a' and 'b' should be vectors of length 1, or, be the same length.", call. = FALSE) - if (any((b - a) < 0)) - stop("The upper limits for the integral ('b' values) should be greater than ", - "the corresponding lower limits for the integral ('a' values).", call. = FALSE) + stop2("'a' and 'b' should be vectors of length 1, or, be the same length.") + if (!na.ok) { + if (!all(is.numeric(a), is.numeric(b))) + stop2("'a' and 'b' should be numeric.") + if (any((b - a) < 0)) + stop2("The upper limits for the integral ('b' values) should be greater than ", + "the corresponding lower limits for the integral ('a' values).") + } ((b - a) / 2) * x + ((b + a) / 2) } @@ -927,70 +997,21 @@ unstandardise_qpts <- function(x, a, b) { # @param x An unstandardised quadrature weight # @param a The lower limit(s) of the integral, possibly a vector # @param b The upper limit(s) of the integral, possibly a vector -unstandardise_qwts <- function(x, a, b) { +unstandardise_qwts <- function(x, a, b, na.ok = TRUE) { if (!identical(length(x), 1L) || !is.numeric(x)) - stop("'x' should be a single numeric value.", call. = FALSE) - if (!all(is.numeric(a), is.numeric(b))) - stop("'a' and 'b' should be numeric.", call. = FALSE) + stop2("'x' should be a single numeric value.") if (!length(a) %in% c(1L, length(b))) - stop("'a' and 'b' should be vectors of length 1, or, be the same length.", call. = FALSE) - if (any((b - a) < 0)) - stop("The upper limits for the integral ('b' values) should be greater than ", - "the corresponding lower limits for the integral ('a' values).", call. = FALSE) + stop2("'a' and 'b' should be vectors of length 1, or, be the same length.") + if (!na.ok) { + if (!all(is.numeric(a), is.numeric(b))) + stop2("'a' and 'b' should be numeric.") + if (any((b - a) < 0)) + stop2("The upper limits for the integral ('b' values) should be greater than ", + "the corresponding lower limits for the integral ('a' values).") + } ((b - a) / 2) * x } -# Test if object is stanmvreg class -# -# @param x An object to be tested. -is.stanmvreg <- function(x) { - inherits(x, "stanmvreg") -} - -# Test if object is stanjm class -# -# @param x An object to be tested. -is.stanjm <- function(x) { - inherits(x, "stanjm") -} - -# Test if object is a joint longitudinal and survival model -# -# @param x An object to be tested. -is.jm <- function(x) { - isTRUE(x$stan_function == "stan_jm") -} - -# Test if object contains a multivariate GLM -# -# @param x An object to be tested. -is.mvmer <- function(x) { - isTRUE(x$stan_function %in% c("stan_mvmer", "stan_jm")) -} - -# Test if object contains a survival model -# -# @param x An object to be tested. -is.surv <- function(x) { - isTRUE(x$stan_function %in% c("stan_jm")) -} - -# Throw error if object isn't a stanmvreg object -# -# @param x The object to test. -validate_stanmvreg_object <- function(x, call. = FALSE) { - if (!is.stanmvreg(x)) - stop("Object is not a stanmvreg object.", call. = call.) -} - -# Throw error if object isn't a stanjm object -# -# @param x The object to test. -validate_stanjm_object <- function(x, call. = FALSE) { - if (!is.stanjm(x)) - stop("Object is not a stanjm object.", call. = call.) -} - # Throw error if parameter isn't a positive scalar # # @param x The object to test. @@ -1010,18 +1031,20 @@ validate_positive_scalar <- function(x, not_greater_than = NULL) { } } -# Return a list with the median and prob% CrI bounds for each column of a -# matrix or 2D array +# Return a matrix or list with the median and prob% CrI bounds for +# each column of a matrix or 2D array # # @param x A matrix or 2D array # @param prob Value between 0 and 1 indicating the desired width of the CrI -median_and_bounds <- function(x, prob, na.rm = FALSE) { +# @param return_matrix Logical, if TRUE then a matrix with three columns is +# returned (med, lb, ub) else if FALSE a list with three elements is returned. +median_and_bounds <- function(x, prob, na.rm = FALSE, return_matrix = FALSE) { if (!any(is.matrix(x), is.array(x))) stop("x should be a matrix or 2D array.") med <- apply(x, 2, median, na.rm = na.rm) lb <- apply(x, 2, quantile, (1 - prob)/2, na.rm = na.rm) ub <- apply(x, 2, quantile, (1 + prob)/2, na.rm = na.rm) - nlist(med, lb, ub) + if (return_matrix) cbind(med, lb, ub) else nlist(med, lb, ub) } # Return the stub for variable names from one submodel of a stan_jm model @@ -1226,6 +1249,14 @@ STOP_arg_required_for_stanmvreg <- function(arg) { stop2(msg) } +# Error message when not specifying 'id_var' for stansurv methods that require it +# +# @param arg The argument +STOP_id_var_required <- function() { + stop2("'id_var' must be specified for models with a start-stop response ", + "or with time-varying effects.") +} + # Error message when a function is not yet implemented for stanmvreg objects # # @param what A character string naming the function not yet implemented @@ -1236,6 +1267,16 @@ STOP_if_stanmvreg <- function(what) { stop2(msg) } +# Error message when a function is not yet implemented for stansurv objects +# +# @param what A character string naming the function not yet implemented +STOP_if_stansurv <- function(what) { + msg <- "not yet implemented for stansurv objects." + if (!missing(what)) + msg <- paste(what, msg) + stop2(msg) +} + # Error message when a function is not yet implemented for stan_mvmer models # # @param what An optional message to prepend to the default message. @@ -1272,6 +1313,13 @@ STOP_no_var <- function(var) { stop2("Variable '", var, "' cannot be found in the data frame.") } +# Error message when values for the time variable are negative +# +# @param var The name of the time variable +STOP_negative_times <- function(var) { + stop2("Values for the time variable (", var, ") should not be negative.") +} + # Error message for dynamic predictions # # @param what A reason why the dynamic predictions are not allowed @@ -1313,9 +1361,11 @@ check_pp_ids <- function(object, ids, m = 1) { # variable must be included in the new data frame # @return A list of validated data frames validate_newdatas <- function(object, newdataLong = NULL, newdataEvent = NULL, - duplicate_ok = FALSE, response = TRUE) { + duplicate_ok = FALSE, response = TRUE, + needs_time_var = TRUE) { validate_stanmvreg_object(object) id_var <- object$id_var + time_var <- object$time_var newdatas <- list() if (!is.null(newdataLong)) { if (!is(newdataLong, "list")) @@ -1323,6 +1373,14 @@ validate_newdatas <- function(object, newdataLong = NULL, newdataEvent = NULL, dfcheck <- sapply(newdataLong, is.data.frame) if (!all(dfcheck)) stop("'newdataLong' must be a data frame or list of data frames.", call. = FALSE) + if (!needs_time_var) { + newdataLong <- lapply(newdataLong, function(m) { + if (!time_var %in% colnames(m)) { + m[[time_var]] <- 0 # hack to pass nacheck below + } + m + }) + } nacheck <- sapply(seq_along(newdataLong), function(m) { if (response) { # newdataLong needs the reponse variable fmL <- formula(object, m = m) @@ -1349,7 +1407,7 @@ validate_newdatas <- function(object, newdataLong = NULL, newdataEvent = NULL, stop("'newdataEvent' cannot contain NAs.", call. = FALSE) if (!duplicate_ok && any(duplicated(newdataEvent[[id_var]]))) stop("'newdataEvent' should only contain one row per individual, since ", - "time varying covariates are not allowed in the prediction data.") + "time-varying covariates are not allowed in the prediction data.") newdatas <- c(newdatas, list(Event = newdataEvent)) } if (length(newdatas)) { @@ -1368,19 +1426,23 @@ validate_newdatas <- function(object, newdataLong = NULL, newdataEvent = NULL, # Return data frames only including the specified subset of individuals # -# @param object A stanmvreg object # @param data A data frame, or a list of data frames # @param ids A vector of ids indicating which individuals to keep +# @param id_var Character string, the name of the ID variable # @return A data frame, or a list of data frames, depending on the input -subset_ids <- function(object, data, ids) { +subset_ids <- function(data, ids, id_var) { + if (is.null(data)) return(NULL) - validate_stanmvreg_object(object) - id_var <- object$id_var + is_list <- is(data, "list") - if (!is_list) data <- list(data) - is_df <- sapply(data, is.data.frame) - if (!all(is_df)) stop("'data' should be a data frame, or list of data frames.") + if (!is_list) + data <- list(data) # convert to list + + is_df <- sapply(data, inherits, "data.frame") + if (!all(is_df)) + stop("'data' should be a data frame, or list of data frames.") + data <- lapply(data, function(x) { if (!id_var %in% colnames(x)) STOP_no_var(id_var) sel <- which(!ids %in% x[[id_var]]) @@ -1389,6 +1451,7 @@ subset_ids <- function(object, data, ids) { paste(ids[[sel]], collapse = ", ")) x[x[[id_var]] %in% ids, , drop = FALSE] }) + if (is_list) return(data) else return(data[[1]]) } @@ -1488,23 +1551,48 @@ get_time_seq <- function(increments, t0, t1, simplify = TRUE) { # Extract parameters from stanmat and return as a list # -# @param object A stanmvreg object +# @param object A stanmvreg or stansurv object # @param stanmat A matrix of posterior draws, may be provided if the desired # stanmat is only a subset of the draws from as.matrix(object$stanfit) # @return A named list -extract_pars <- function(object, stanmat = NULL, means = FALSE) { +extract_pars <- function(object, ...) { + UseMethod("extract_pars") +} + +extract_pars.stansurv <- function(object, stanmat = NULL, means = FALSE) { + validate_stansurv_object(object) + if (is.null(stanmat)) + stanmat <- as.matrix(object$stanfit) + if (means) + stanmat <- t(colMeans(stanmat)) # return posterior means + nms_beta <- colnames(object$x) + nms_tve <- get_smooth_name(object$s_cpts, type = "smooth_coefs") + nms_smth <- get_smooth_name(object$s_cpts, type = "smooth_sd") + nms_int <- get_int_name_basehaz(object$basehaz) + nms_aux <- get_aux_name_basehaz(object$basehaz) + nms_b <- b_names(colnames(stanmat)) + alpha <- stanmat[, nms_int, drop = FALSE] + beta <- stanmat[, nms_beta, drop = FALSE] + beta_tve <- stanmat[, nms_tve, drop = FALSE] + aux <- stanmat[, nms_aux, drop = FALSE] + smooth <- stanmat[, nms_smth, drop = FALSE] + b <- stanmat[, nms_b, drop = FALSE] + nlist(alpha, beta, beta_tve, aux, smooth, b, stanmat) +} + +extract_pars.stanmvreg <- function(object, stanmat = NULL, means = FALSE) { validate_stanmvreg_object(object) M <- get_M(object) if (is.null(stanmat)) stanmat <- as.matrix(object$stanfit) if (means) stanmat <- t(colMeans(stanmat)) # return posterior means - nms <- collect_nms(colnames(stanmat), M, stub = get_stub(object)) - beta <- lapply(1:M, function(m) stanmat[, nms$y[[m]], drop = FALSE]) - ebeta <- stanmat[, nms$e, drop = FALSE] - abeta <- stanmat[, nms$a, drop = FALSE] + nms <- collect_nms(colnames(stanmat), M, stub = get_stub(object)) + beta <- lapply(1:M, function(m) stanmat[, nms$y[[m]], drop = FALSE]) + b <- lapply(1:M, function(m) stanmat[, nms$y_b[[m]], drop = FALSE]) + ebeta <- stanmat[, nms$e, drop = FALSE] + abeta <- stanmat[, nms$a, drop = FALSE] bhcoef <- stanmat[, nms$e_extra, drop = FALSE] - b <- lapply(1:M, function(m) stanmat[, nms$y_b[[m]], drop = FALSE]) nlist(beta, ebeta, abeta, bhcoef, b, stanmat) } @@ -1712,22 +1800,550 @@ pad_matrix <- function(x, cols = NULL, rows = NULL, x } -#------- helpers from brms package +# Return the cutpoints for a specified number of quantiles of 'x' +# +# @param x A numeric vector. +# @param nq Integer specifying the number of quantiles. +# @return A vector of percentiles corresponding to percentages 100*k/m for +# k=1,2,...,nq-1. +qtile <- function(x, nq = 2) { + if (nq > 1) { + probs <- seq(1, nq - 1) / nq + return(quantile(x, probs = probs)) + } else if (nq == 1) { + return(NULL) + } else { + stop("'nq' must be >= 1.") + } +} + +# Return the desired spline basis for the given knot locations +get_basis <- function(x, iknots, bknots = range(x), + degree = 3, intercept = FALSE, + type = c("bs", "is", "ms")) { + type <- match.arg(type) + if (type == "bs") { + out <- splines2::bSpline(x, knots = iknots, Boundary.knots = bknots, + degree = degree, intercept = intercept) + } else if (type == "is") { + out <- splines2::iSpline(x, knots = iknots, Boundary.knots = bknots, + degree = degree, intercept = TRUE) + } else if (type == "ms") { + out <- splines2::mSpline(x, knots = iknots, Boundary.knots = bknots, + degree = degree, intercept = TRUE) + } else { + stop2("'type' is not yet accommodated.") + } + out +} + +# Paste character vector collapsing with a comma +comma <- function(x) { + paste(x, collapse = ", ") +} + +# Select rows of a matrix +# +# @param x A matrix. +# @param rows Logical or numeric vector stating which rows of 'x' to retain. +keep_rows <- function(x, rows = 1:nrow(x)) { + x[rows, , drop = FALSE] +} + +# Drop rows of a matrix +# +# @param x A matrix. +# @param rows Logical or numeric vector stating which rows of 'x' to drop +drop_rows <- function(x, rows = 1:nrow(x)) { + x[!rows, , drop = FALSE] +} -stop2 <- function(...) { - stop(..., call. = FALSE) +# Replicate rows of a matrix or data frame +# +# @param x A matrix or data frame. +# @param ... Arguments passed to 'rep', namely 'each' or 'times'. +rep_rows <- function(x, ...) { + if (is.null(x) || !nrow(x)) { + return(x) + } else if (is.matrix(x) || is.data.frame(x)) { + x <- x[rep(1:nrow(x), ...), , drop = FALSE] + } else { + stop2("'x' must be a matrix or data frame.") + } + x } +# Stop without printing call +stop2 <- function(...) stop(..., call. = FALSE) + +# Immediate warning without printing call +warning2 <- function(...) warning(..., immediate. = TRUE, call. = FALSE) + +# Shorthand for suppress warnings +SW <- function(expr) base::suppressWarnings(expr) + +# Check if an object is NULL is_null <- function(x) { - # check if an object is NULL is.null(x) || ifelse(is.vector(x), all(sapply(x, is.null)), FALSE) } +# Check if all objects are NULL +all_null <- function(...) { + dots <- list(...) + null_check <- uapply(dots, function(x) { + is.null(x) || ifelse(is.vector(x), all(sapply(x, is.null)), FALSE) + }) + all(null_check) +} + +# Check if any objects are NULL +any_null <- function(...) { + dots <- list(...) + null_check <- uapply(dots, function(x) { + is.null(x) || ifelse(is.vector(x), all(sapply(x, is.null)), FALSE) + }) + any(null_check) +} +# Recursively removes NULL entries from an object rm_null <- function(x, recursive = TRUE) { - # recursively removes NULL entries from an object x <- Filter(Negate(is_null), x) if (recursive) { x <- lapply(x, function(x) if (is.list(x)) rm_null(x) else x) } x } + +# Check if all elements are equal allowing NA and NULL +is_equal <- function(x, y, ...) { + isTRUE(all.equal(x, y, ...)) +} + +# Check if x behaves like a factor in design matrices +is_like_factor <- function(x) { + is.factor(x) || is.character(x) || is.logical(x) +} + +# Check if 'x' is FALSE +isFALSE <- function(x) { + identical(FALSE, x) +} + +# @param x numeric vector +log_sum_exp <- function(x) { + max_x <- max(x) + max_x + log(sum(exp(x - max_x))) +} + +# Concatenate (i.e. 'c(...)') but don't demote factors to integers +ulist <- function(...) { unlist(list(...)) } + +# Return the names for the group specific coefficients +# +# @param cnms A named list with the names of the parameters nested within each +# grouping factor. +# @param flevels A named list with the (unique) factor levels nested within each +# grouping factor. +# @return A character vector. +get_ranef_name <- function(cnms, flevels) { + cnms_nms <- names(cnms) + b_nms <- uapply(seq_along(cnms), FUN = function(i) { + nm <- cnms_nms[i] + nms_i <- paste(cnms[[i]], nm) + flevels[[nm]] <- c(gsub(" ", "_", flevels[[nm]]), + paste0("_NEW_", nm)) + if (length(nms_i) == 1) { + paste0(nms_i, ":", flevels[[nm]]) + } else { + c(t(sapply(nms_i, paste0, ":", flevels[[nm]]))) + } + }) + c(paste0("b[", b_nms, "]")) +} + +# Return the name for the mean_PPD +get_ppd_name <- function(x, ...) { + paste0(x$stub, "|mean_PPD") +} + +# Return the name for the intercept parameter +get_int_name_basehaz <- function(x, is_jm = FALSE, ...) { + if (is_jm || has_intercept(x)) "(Intercept)" else NULL +} +get_int_name_ymod <- function(x, ...) { + if (x$intercept_type$number) paste0(x$stub, "|(Intercept)") else NULL +} +get_int_name_emod <- function(x, is_jm = FALSE, ...) { + nm <- get_int_name_basehaz(x$basehaz, is_jm = is_jm) + if (!is.null(nm)) paste0("Event|", nm) else NULL +} + +# Return the names for the auxiliary parameters +get_aux_name_basehaz <- function(x, ...) { + switch(get_basehaz_name(x), + "exp" = NULL, + "exp-aft" = NULL, + "weibull" = "weibull-shape", + "weibull-aft" = "weibull-shape", + "gompertz" = "gompertz-scale", + "ms" = paste0("m-splines-coef", seq(x$nvars)), + "bs" = paste0("b-splines-coef", seq(x$nvars)), + "piecewise" = paste0("piecewise-coef", seq(x$nvars)), + NA) +} +get_aux_name_ymod <- function(x, ...) { + switch(x$family$family, + gaussian = paste0(x$stub, "|sigma"), + Gamma = paste0(x$stub, "|shape"), + inverse.gaussian = paste0(x$stub, "|lambda"), + neg_binomial_2 = paste0(x$stub, "|reciprocal_dispersion"), + NULL) +} +get_aux_name_emod <- function(x, ...) { + nms <- get_aux_name_basehaz(x$basehaz) + if (!is.null(nms)) paste0("Event|", nms) else NULL +} + +# Return the names for the coefficients +get_beta_name_ymod <- function(x) { + nms <- colnames(x$x$xtemp) + if (!is.null(nms)) paste0(x$stub, "|", nms) else NULL +} +get_beta_name_emod <- function(x, ...) { + nms <- colnames(x$x) + if (!is.null(nms)) paste0("Event|", nms) else NULL +} + +# Return the names for the association parameters +get_assoc_name <- function(a_mod, assoc, ...) { + M <- length(a_mod) + a <- assoc + ev <- "etavalue" + es <- "etaslope" + ea <- "etaauc" + mv <- "muvalue" + ms <- "muslope" + ma <- "muauc" + evd <- "etavalue_data" + esd <- "etaslope_data" + mvd <- "muvalue_data" + msd <- "muslope_data" + evev <- "etavalue_etavalue" + evmv <- "etavalue_muvalue" + mvev <- "muvalue_etavalue" + mvmv <- "muvalue_muvalue" + p <- function(...) paste0(...) + indx <- function(x, m) paste0("Long", assoc["which_interactions",][[m]][[x]]) + cnms <- function(x, m) colnames(a_mod[[m]][["X_data"]][[x]]) + nms <- character() + for (m in 1:M) { + stub <- paste0("Assoc|Long", m, "|") + # order matters here! (needs to line up with the monitored stanpars) + if (a[ev, ][[m]]) nms <- c(nms, p(stub, ev )) + if (a[evd, ][[m]]) nms <- c(nms, p(stub, ev, ":", cnms(evd, m) )) + if (a[evev,][[m]]) nms <- c(nms, p(stub, ev, ":", indx(evev, m), "|", ev)) + if (a[evmv,][[m]]) nms <- c(nms, p(stub, ev, ":", indx(evmv, m), "|", mv)) + if (a[es, ][[m]]) nms <- c(nms, p(stub, es )) + if (a[esd, ][[m]]) nms <- c(nms, p(stub, es, ":", cnms(esd, m) )) + if (a[ea, ][[m]]) nms <- c(nms, p(stub, ea )) + if (a[mv, ][[m]]) nms <- c(nms, p(stub, mv )) + if (a[mvd, ][[m]]) nms <- c(nms, p(stub, mv, ":", cnms(mvd, m) )) + if (a[mvev,][[m]]) nms <- c(nms, p(stub, mv, ":", indx(mvev, m), "|", ev)) + if (a[mvmv,][[m]]) nms <- c(nms, p(stub, mv, ":", indx(mvmv, m), "|", mv)) + if (a[ms, ][[m]]) nms <- c(nms, p(stub, ms )) + if (a[msd, ][[m]]) nms <- c(nms, p(stub, ms, ":", cnms(msd, m) )) + if (a[ma, ][[m]]) nms <- c(nms, p(stub, ma )) + } + nms +} + +# Return the list with summary information about the baseline hazard +# +# @return A named list. +get_basehaz <- function(x) { + if (is.stansurv(x)) + return(x$basehaz) + if (is.stanjm(x)) + return(x$survmod$basehaz) + stop("Bug found: could not find basehaz.") +} + +# Return the name of the baseline hazard +# +# @return A character string. +get_basehaz_name <- function(x) { + if (is.character(x)) + return(x) + if (is.stansurv(x)) + return(x$basehaz$type_name) + if (is.stanjm(x)) + return(x$survmod$basehaz$type_name) + if (is.character(x$type_name)) + return(x$type_name) + stop("Bug found: could not resolve basehaz name.") +} + +# Add the variables in ...'s to the RHS of a model formula +# +# @param x A model formula. +# @param ... Character strings, the variable names. +addto_formula <- function(x, ...) { + rhs_terms <- terms(reformulate_rhs(rhs(x))) + intercept <- attr(rhs_terms, "intercept") + term_labels <- attr(rhs_terms, "term.labels") + reformulate(c(term_labels, c(...)), response = lhs(x), intercept = intercept) +} + +# Shorthand for as.integer, as.double, as.matrix, as.array +ai <- function(x, ...) as.integer(x, ...) +ad <- function(x, ...) as.double (x, ...) +am <- function(x, ...) as.matrix (x, ...) +aa <- function(x, ...) as.array (x, ...) + +# Sample rows from a two-dimensional object +# +# @param x The two-dimensional object (e.g. matrix, data frame, array). +# @param size Integer specifying the number of rows to sample. +# @param replace Should the rows be sampled with replacement? +# @return A two-dimensional object with 'size' rows and 'ncol(x)' columns. +sample_rows <- function(x, size, replace = FALSE) { + samp <- sample(nrow(x), size, replace) + x[samp, , drop = FALSE] +} + +# Sample rows from a stanmat object +# +# @param object A stanreg object. +# @param draws The number of draws/rows to sample from the stanmat. +# @param default_draws Integer or NA. If 'draws' is NULL then the number of +# rows sampled from the stanmat is equal to +# min(default_draws, posterior_sample_size, na.rm = TRUE) +# @return A matrix with 'draws' rows and 'ncol(stanmat)' columns. +sample_stanmat <- function(object, draws = NULL, default_draws = NA) { + S <- posterior_sample_size(object) + if (is.null(draws)) + draws <- min(default_draws, S, na.rm = TRUE) + if (draws > S) + stop2("'draws' should be <= posterior sample size (", S, ").") + stanmat <- as.matrix(object$stanfit) + if (isTRUE(draws < S)) { + stanmat <- sample_rows(stanmat, draws) + } + stanmat +} + +# Method to truncate a numeric vector at defined limits +# +# @param con A numeric vector. +# @param lower Scalar, the lower limit for the returned vector. +# @param upper Scalar, the upper limit for the returned vector. +# @return A numeric vector. +truncate.numeric <- function(con, lower = NULL, upper = NULL) { + if (!is.null(lower)) con[con < lower] <- lower + if (!is.null(upper)) con[con > upper] <- upper + con +} + +# Transpose only if 'x' is a vector +transpose_vector <- function(x) { + if (is.vector(x)) return(t(x)) else return(x) +} + +# Simplified conditional for 'if (is.null(...))' +if_null <- function(test, yes, no) { + if (is.null(test)) yes else no +} + +# Replace entries of 'x' based on a (possibly) vectorised condition +# +# @param x The vector, matrix, or array. +# @param condition The logical condition, possibly a logical vector. +# @param replacement The value to replace with, where the condition is TRUE. +# @param margin The margin of 'x' on which to apply the condition. +# @return The same class as 'x' but possibly with some entries replaced. +replace_where <- function(x, condition, replacement, margin = 1L) { + switch(margin, + x[condition] <- replacement, + x[,condition] <- replacement, + stop("Cannot handle 'margin' > 2.")) + x +} + +# Calculate row means, but don't simplify to a vector +row_means <- function(x, na.rm = FALSE) { + mns <- rowMeans(x, na.rm = na.rm) + if (is.matrix(x)) { + return(matrix(mns, ncol = 1)) + } else if (is.array(x)) { + return(array(mns, dim = c(nrow(x), 1))) + } else if (is.data.frame(x)) { + return(data.frame(mns)) + } else { + stop2("Cannot handle objects of class: ", class(x)) + } +} + +# Calculate column means, but don't simplify to a vector +col_means <- function(x, na.rm = FALSE) { + mns <- colMeans(x, na.rm = na.rm) + if (is.matrix(x)) { + return(matrix(mns, nrow = 1)) + } else if (is.array(x)) { + return(array(mns, dim = c(1, ncol(x)))) + } else { + stop2("Cannot handle objects of class: ", class(x)) + } +} + +# Set row or column names on an object +set_rownames <- function(x, names) { rownames(x) <- names; x } +set_colnames <- function(x, names) { colnames(x) <- names; x } + +# Select rows or columns by name or index +select_rows <- function(x, rows) { x[rows, , drop = FALSE] } +select_cols <- function(x, cols) { x[, cols, drop = FALSE] } + +# Add attributes, but only if 'condition' is TRUE +structure2 <- function(.Data, condition, ...) { + if (condition) structure(.Data, ...) else .Data +} + +# Split a vector in a specified number of (equally sized) segments +# +# @param x The vector to split. +# @param n_segments Integer specifying the desired number of segments. +# @return A list of vectors, see `?split`. +split_vector <- function(x, n_segments = 1) { + split(x, rep(1:n_segments, each = length(x) / n_segments)) +} + +# Replace an NA object, or NA entries in a vector +# +# @param x The vector with elements to potentially replace. +# @param replace_with The replacement value. +replace_na <- function(x, replace_with = "0") { + if (is.na(x)) { + x <- replace_with + } else { + x[is.na(x)] <- replace_with + } + x +} + +# Replace an NULL object, or NULL entries in a vector +# +# @param x The vector with elements to potentially replace. +# @param replace_with The replacement value. +replace_null <- function(x, replace_with = "0") { + if (is.null(x)) { + x <- replace_with + } else { + x[is.null(x)] <- replace_with + } + x +} + +# Add an intercept column onto a predictor matrix +add_intercept <- function(x) { + stopifnot(is.matrix(x)) + cbind(rep(1, nrow(x)), x) +} + +# Replace named elements of 'x' with 'y' +replace_named_elements <- function(x, y) { x[names(y)] <- y; x } + +# Invert 'is.null' +not.null <- function(x) { !is.null(x) } + +# Shorthand for as.integer, as.double, as.matrix, as.array +ai <- function(x, ...) as.integer(x, ...) +ad <- function(x, ...) as.double(x, ...) +am <- function(x, ...) as.matrix(x, ...) +aa <- function(x, ...) as.array(x, ...) + +# Return a vector of 0's or 1's +zeros <- function(n) rep(0, times = n) +ones <- function(n) rep(1, times = n) + +# Check if all elements of a vector are zeros +all_zero <- function(x) all(x == 0) + +# Return the maximum integer or double +max_integer <- function() .Machine$integer.max +max_double <- function() .Machine$double.xmax + +# Check for scalar or string +is.scalar <- function(x) { isTRUE(is.numeric(x) && (length(x) == 1)) } +is.string <- function(x) { isTRUE(is.character(x) && (length(x) == 1)) } + +# Safe deparse +safe_deparse <- function(expr) deparse(expr, 500L) + +# Evaluate a character string +eval_string <- function(x) eval(parse(text = x)) + +# Mutate, similar to dplyr (ie. append a new variable(s) to the data frame) +mutate <- function(x, ..., names_eval = FALSE, n = 4) { + dots <- list(...) + if (names_eval) { # evaluate names in parent frame + nms <- sapply(names(dots), function(x) eval.parent(as.name(x), n = n)) + } else { + nms <- names(dots) + } + for (i in seq_along(dots)) + x[[nms[[i]]]] <- dots[[i]] + x +} +mutate_ <- function(x, ...) mutate(x, ..., names_eval = TRUE, n = 5) + +# Sort the rows of a data frame based on the variables specified in dots. +# (For convenience, any variables in ... that are not in the data frame +# are ignored, rather than throwing an error - dangerous but convenient) +# +# @param x A data frame. +# @param ... Character strings; names of the columns of 'x' on which to sort. +# @param skip Logical, if TRUE then any strings in the ...'s that are not +# present as variables in the data frame are ignored, rather than throwing +# an error - somewhat dangerous, but convenient. +# @return A data frame. +row_sort <- function(x, ...) { + stopifnot(is.data.frame(x)) + vars <- lapply(list(...), as.name) # convert string to name + x[with(x, do.call(order, vars)), , drop = FALSE] +} + +# Order the cols of a data frame in the order specified in the dots. Any +# remaining columns of 'x' are retained as is and included after the ... columns. +# +# @param x A data frame. +# @param ... Character strings; the desired order of the columns of 'x' by name. +# @param skip Logical, if TRUE then any strings in the ...'s that are not +# present as variables in the data frame are ignored, rather than throwing +# an error - somewhat dangerous, but convenient. +# @return A data frame. +col_sort <- function(x, ...) { + stopifnot(is.data.frame(x)) + vars1 <- unlist(list(...)) + vars2 <- setdiff(colnames(x), vars1) # select the leftover columns in x + x[, c(vars1, vars2), drop = FALSE] +} + +# Calculate the specified quantiles for each column of an array +col_quantiles <- function(x, probs, na.rm = FALSE, return_matrix = FALSE) { + stopifnot(is.matrix(x) || is.array(x)) + out <- lapply(probs, function(q) apply(x, 2, quantile, q, na.rm = na.rm)) + if (return_matrix) do.call(cbind, out) else out +} +col_quantiles_ <- function(x, probs) { + col_quantiles(x, probs, na.rm = TRUE, return_matrix = TRUE) +} + +# Append a string (prefix) to the column names of a matrix or array +append_prefix_to_colnames <- function(x, str) { + if (ncol(x)) set_colnames(x, paste0(str, colnames(x))) else x +} + +# Return the name of the calling function as a string +get_calling_fun <- function(which = -2) { + fn <- tryCatch(sys.call(which = which)[[1L]], error = function(e) NULL) + if (!is.null(fn)) safe_deparse(fn) else NULL +} diff --git a/R/plots.R b/R/plots.R index a1f8551a8..7feb2f46f 100644 --- a/R/plots.R +++ b/R/plots.R @@ -1,5 +1,6 @@ # Part of the rstanarm package for estimating model parameters # Copyright (C) 2015, 2016, 2017 Trustees of Columbia University +# Copyright (C) 2018 Sam Brilleman # # This program is free software; you can redistribute it and/or # modify it under the terms of the GNU General Public License @@ -30,16 +31,20 @@ #' @template args-stanreg-object #' @template args-pars #' @template args-regex-pars -#' @param plotfun A character string naming the \pkg{bayesplot} -#' \link[bayesplot:MCMC-overview]{MCMC} function to use. The default is to call -#' \code{\link[bayesplot:MCMC-intervals]{mcmc_intervals}}. \code{plotfun} can be specified +#' @param plotfun A character string naming the \pkg{bayesplot} +#' \link[bayesplot]{MCMC} function to use. The default is to call +#' \code{\link[bayesplot]{mcmc_intervals}}. \code{plotfun} can be specified #' either as the full name of a \pkg{bayesplot} plotting function (e.g. #' \code{"mcmc_hist"}) or can be abbreviated to the part of the name following #' the \code{"mcmc_"} prefix (e.g. \code{"hist"}). To get the names of all -#' available MCMC functions see \code{\link[bayesplot:available_ppc]{available_mcmc}}. +#' available MCMC functions see \code{\link[bayesplot]{available_mcmc}}. +#' For the \code{stansurv} method, one can also specify +#' \code{plotfun = "basehaz"} for a plot of the estimated baseline hazard +#' function, or \code{plot = "tve"} for a plot of the time-varying +#' hazard ratio (if time-varying effects were specified in the model). #' #' @param ... Additional arguments to pass to \code{plotfun} for customizing the -#' plot. These are described on the help pages for the individual plotting +#' plot. These are described on the help pages for the individual plotting #' functions. For example, the arguments accepted for the default #' \code{plotfun="intervals"} can be found at #' \code{\link[bayesplot:MCMC-intervals]{mcmc_intervals}}. @@ -49,19 +54,19 @@ #' (e.g. a gtable object created by \code{\link[gridExtra]{arrangeGrob}}). #' #' @seealso -#' \itemize{ +#' \itemize{ #' \item The vignettes in the \pkg{bayesplot} package for many examples. #' \item \code{\link[bayesplot]{MCMC-overview}} (\pkg{bayesplot}) for links to #' the documentation for all the available plotting functions. #' \item \code{\link[bayesplot:bayesplot-colors]{color_scheme_set}} (\pkg{bayesplot}) to change #' the color scheme used for plotting. #' \item \code{\link{pp_check}} for graphical posterior predictive checks. -#' \item \code{\link{plot_nonlinear}} for models with nonlinear smooth +#' \item \code{\link{plot_nonlinear}} for models with nonlinear smooth #' functions fit using \code{\link{stan_gamm4}}. -#' } +#' } #' #' @template reference-bayesvis -#' +#' #' @examples #' if (.Platform$OS.type != "windows" || .Platform$r_arch != "i386") { #' \donttest{ @@ -82,7 +87,7 @@ #' bayesplot::color_scheme_set("brightblue") #' plot(fit, "areas", regex_pars = "period", #' prob = 0.5, prob_outer = 0.9) -#' +#' #' # Make the same plot by extracting posterior draws and calling #' # bayesplot::mcmc_areas directly #' x <- as.array(fit, regex_pars = "period") @@ -114,27 +119,27 @@ #' ### Rhat, effective sample size, autocorrelation ### #' #################################################### #' bayesplot::color_scheme_set("red") -#' +#' #' # rhat #' plot(fit, "rhat") #' plot(fit, "rhat_hist") -#' +#' #' # ratio of effective sample size to total posterior sample size #' plot(fit, "neff") #' plot(fit, "neff_hist") -#' +#' #' # autocorrelation by chain #' plot(fit, "acf", pars = "(Intercept)", regex_pars = "period") #' plot(fit, "acf_bar", pars = "(Intercept)", regex_pars = "period") -#' -#' +#' +#' #' ################## #' ### Traceplots ### #' ################## #' # NOTE: rstanarm doesn't store the warmup draws (to save space because they #' # are not so essential for diagnosing the particular models implemented in #' # rstanarm) so the iterations in the traceplot are post-warmup iterations -#' +#' #' bayesplot::color_scheme_set("pink") #' (trace <- plot(fit, "trace", pars = "(Intercept)")) #' @@ -142,7 +147,7 @@ #' trace + ggplot2::scale_color_discrete() #' trace + ggplot2::scale_color_manual(values = c("maroon", "skyblue2")) #' -#' # changing facet layout +#' # changing facet layout #' plot(fit, "trace", pars = c("(Intercept)", "period2"), #' facet_args = list(nrow = 2)) #' # same plot by calling bayesplot::mcmc_trace directly @@ -168,16 +173,148 @@ #' plot.stanreg <- function(x, plotfun = "intervals", pars = NULL, regex_pars = NULL, ...) { - + if (plotfun %in% c("pairs", "mcmc_pairs")) return(pairs.stanreg(x, pars = pars, regex_pars = regex_pars, ...)) - + fun <- set_plotting_fun(plotfun) args <- set_plotting_args(x, pars, regex_pars, ..., plotfun = plotfun) do.call(fun, args) } +# plot method for stansurv ---------------------------------------------- + +#' @rdname plot.stanreg +#' @method plot stansurv +#' @export +#' @templateVar cigeomArg ci_geom_args +#' @template args-ci-geom-args +#' @param prob A scalar between 0 and 1 specifying the width to use for the +#' plotted posterior uncertainty interval when \code{limit = "ci"}. For +#' example \code{prob = 0.95} (the default) means that the 2.5th and 97.5th +#' percentiles will be provided. +#' @param limits A quoted character string specifying the type of limits to +#' include in the plot. Can be \code{"ci"} for the Bayesian posterior +#' uncertainty interval, or \code{"none"}. This argument is only relevant +#' when \code{plotfun = "basehaz"} or \code{plotfun = "tve"} +#' @param n Integer specifying the number of points to interpolate along +#' when plotting the baseline hazard or time-varying hazard ratio. Each of +#' the points are joined using a line. +#' +plot.stansurv <- function(x, plotfun = "basehaz", pars = NULL, + regex_pars = NULL, ..., prob = 0.95, + limits = c("ci", "none"), + ci_geom_args = NULL, n = 1000) { + + validate_stansurv_object(x) + + limits <- match.arg(limits) + + if (plotfun %in% c("basehaz", "tve")) { + + stanpars <- extract_pars(x) + has_intercept <- check_for_intercept(x$basehaz) + + t_min <- min(x$entrytime) + t_max <- max(x$eventtime) + times <- seq(t_min, t_max, by = (t_max - t_min) / n) + + if (plotfun == "basehaz") { + + if (!is.null(pars)) + warning2("'pars' is ignored when plotting the baseline hazard.") + if (!is.null(regex_pars)) + warning2("'regex_pars' is ignored when plotting the baseline hazard.") + + args <- nlist(times = times, + basehaz = get_basehaz(x), + aux = stanpars$aux, + intercept = stanpars$alpha) + basehaz <- do.call(evaluate_basehaz, args) + basehaz <- median_and_bounds(basehaz, prob, na.rm = TRUE) + plotdat <- data.frame(times, basehaz) + + uses_step_stair <- (get_basehaz_name(x) %in% c("piecewise")) + + ylab <- "Baseline hazard rate" + xlab <- "Time" + + } else if (plotfun == "tve") { + + if (!x$has_tve) + stop2("Model does not have time-varying effects.") + + smooth_map <- get_smooth_name(x$s_cpts, type = "smooth_map") + smooth_vars <- get_smooth_name(x$s_cpts, type = "smooth_vars") + smooth_coefs <- get_smooth_name(x$s_cpts, type = "smooth_coefs") + + if (is.null(pars)) + pars <- smooth_vars + if (length(pars) > 1L) + stop2("Only one variable can be specified in 'pars' .") + if (!pars %in% smooth_vars) + stop2("Cannot find variable '", pars, "' amongst the tve terms.") + + sel1 <- which(smooth_vars == pars) + sel2 <- smooth_coefs[smooth_map == sel1] + + betas_tf <- stanpars$beta [, pars, drop = FALSE] + betas_td <- stanpars$beta_tve[, sel2, drop = FALSE] + betas <- cbind(betas_tf, betas_td) + + tt_varid <- unique(x$formula$tt_map[smooth_map == sel1]) + tt_type <- x$formula$tt_types [[tt_varid]] + tt_degree <- x$formula$tt_degrees[[tt_varid]] + tt_form <- x$formula$tt_forms [[tt_varid]] + tt_data <- data.frame(times__ = times) + tt_x <- model.matrix(tt_form, tt_data) + + coef <- linear_predictor(betas, tt_x) + + is_aft <- get_basehaz_name(x$basehaz) %in% c("exp-aft", "weibull-aft") + + plotdat <- median_and_bounds(exp(coef), prob, na.rm = TRUE) + plotdat <- data.frame(times, plotdat) + + uses_step_stair <- (tt_degree == 0) + + xlab <- "Time" + ylab <- ifelse(is_aft, + paste0("Survival time ratio\n(", pars, ")"), + paste0("Hazard ratio\n(", pars, ")")) + } + + geom_defs <- list(color = "black") # default plot args + geom_args <- set_geom_args(geom_defs, ...) + geom_maps <- list(aes_string(x = "times", y = "med")) + geom_ylab <- ggplot2::ylab(ylab) + geom_xlab <- ggplot2::xlab(xlab) + geom_base <- ggplot(plotdat) + geom_ylab + geom_xlab + ggplot2::theme_bw() + geom_fun <- if (uses_step_stair) ggplot2::geom_step else ggplot2::geom_line + geom_plot <- geom_base + do.call(geom_fun, c(geom_maps, geom_args)) + + if (limits == "ci") { + lim_defs <- list(alpha = 0.3) # default plot args for ci + lim_args <- c(defaults = list(lim_defs), ci_geom_args) + lim_args <- do.call("set_geom_args", lim_args) + lim_maps <- list(mapping = aes_string(x = "times", ymin = "lb", ymax = "ub")) + lim_tmp <- geom_base + + geom_fun(aes_string(x = "times", y = "lb")) + + geom_fun(aes_string(x = "times", y = "ub")) + lim_build<- ggplot2::ggplot_build(lim_tmp) + lim_data <- list(data = data.frame(times = lim_build$data[[1]]$x, + lb = lim_build$data[[1]]$y, + ub = lim_build$data[[2]]$y)) + lim_plot <- do.call(ggplot2::geom_ribbon, c(lim_data, lim_maps, lim_args)) + } else { + lim_plot <- NULL + } + return(geom_plot + lim_plot) + } + NextMethod("plot") +} + # internal for plot.stanreg ---------------------------------------------- @@ -198,7 +335,7 @@ set_plotting_args <- function(x, pars = NULL, regex_pars = NULL, ..., .plotfun_is_type <- function(patt) { grepl(pattern = paste0("_", patt), x = plotfun, fixed = TRUE) } - + if (.plotfun_is_type("nuts")) { nuts_stuff <- list(x = bayesplot::nuts_params(x), ...) if (!.plotfun_is_type("energy")) @@ -217,13 +354,13 @@ set_plotting_args <- function(x, pars = NULL, regex_pars = NULL, ..., pars <- collect_pars(x, pars, regex_pars) pars <- allow_special_parnames(x, pars) } - + if (!used.sampling(x)) { if (!length(pars)) pars <- NULL return(list(x = as.matrix(x, pars = pars), ...)) } - + list(x = as.array(x, pars = pars, regex_pars = regex_pars), ...) } @@ -253,10 +390,10 @@ mcmc_function_name <- function(fun) { if (!identical(substr(fun, 1, 5), "mcmc_")) fun <- paste0("mcmc_", fun) - + if (!fun %in% bayesplot::available_mcmc()) stop( - fun, " is not a valid MCMC function name.", + fun, " is not a valid MCMC function name.", " Use bayesplot::available_mcmc() for a list of available MCMC functions." ) @@ -289,11 +426,11 @@ set_plotting_fun <- function(plotfun = NULL) { stop("'plotfun' should be a string.", call. = FALSE) plotfun <- mcmc_function_name(plotfun) - fun <- try(get(plotfun, pos = asNamespace("bayesplot"), mode = "function"), + fun <- try(get(plotfun, pos = asNamespace("bayesplot"), mode = "function"), silent = TRUE) if (!inherits(fun, "try-error")) return(fun) - + stop( "Plotting function ", plotfun, " not found. ", "A valid plotting function is any function from the ", @@ -305,14 +442,14 @@ set_plotting_fun <- function(plotfun = NULL) { # check if plotfun is ok to use with vb or optimization validate_plotfun_for_opt_or_vb <- function(plotfun) { plotfun <- mcmc_function_name(plotfun) - if (needs_chains(plotfun) || + if (needs_chains(plotfun) || grepl("_rhat|_neff|_nuts_", plotfun)) STOP_sampling_only(plotfun) } - # pairs method ------------------------------------------------------------ + #' Pairs method for stanreg objects #' #' Interface to \pkg{bayesplot}'s @@ -325,7 +462,7 @@ validate_plotfun_for_opt_or_vb <- function(plotfun) { #' @importFrom bayesplot pairs_style_np pairs_condition #' @export pairs_style_np pairs_condition #' @aliases pairs_style_np pairs_condition -#' +#' #' @templateVar stanregArg x #' @template args-stanreg-object #' @template args-regex-pars @@ -333,18 +470,24 @@ validate_plotfun_for_opt_or_vb <- function(plotfun) { #' are included by default, but for models with more than just a few #' parameters it may be far too many to visualize on a small computer screen #' and also may require substantial computing time. -#' @param condition Same as the \code{condition} argument to -#' \code{\link[bayesplot:MCMC-scatterplots]{mcmc_pairs}} except the \emph{default is different} +#' @param condition Same as the \code{condition} argument to +#' \code{\link[bayesplot]{mcmc_pairs}} except the \emph{default is different} #' for \pkg{rstanarm} models. By default, the \code{mcmc_pairs} function in #' the \pkg{bayesplot} package plots some of the Markov chains (half, in the #' case of an even number of chains) in the panels above the diagonal and the -#' other half in the panels below the diagonal. However since we know that -#' \pkg{rstanarm} models were fit using Stan (which \pkg{bayesplot} doesn't -#' assume) we can make the default more useful by splitting the draws -#' according to the \code{accept_stat__} diagnostic. The plots below the -#' diagonal will contain realizations that are below the median -#' \code{accept_stat__} and the plots above the diagonal will contain +#' other half in the panels below the diagonal. However since we know that +#' \pkg{rstanarm} models were fit using Stan (which \pkg{bayesplot} doesn't +#' assume) we can make the default more useful by splitting the draws +#' according to the \code{accept_stat__} diagnostic. The plots below the +#' diagonal will contain realizations that are below the median +#' \code{accept_stat__} and the plots above the diagonal will contain #' realizations that are above the median \code{accept_stat__}. To change this +#' behavior see the documentation of the \code{condition} argument at +#' \code{\link[bayesplot]{mcmc_pairs}}. +#' @param ... Optional arguments passed to \code{\link[bayesplot]{mcmc_pairs}}. +#' The \code{np}, \code{lp}, and \code{max_treedepth} arguments to +#' \code{mcmc_pairs} are handled automatically by \pkg{rstanarm} and do not +#' need to be specified by the user in \code{...}. The arguments that can be #' behavior see the documentation of the \code{condition} argument at #' \code{\link[bayesplot:MCMC-scatterplots]{mcmc_pairs}}. #' @param ... Optional arguments passed to @@ -363,10 +506,10 @@ validate_plotfun_for_opt_or_vb <- function(plotfun) { #' if (.Platform$OS.type != "windows" || .Platform$r_arch != "i386") { #' \donttest{ #' if (!exists("example_model")) example(example_model) -#' +#' #' bayesplot::color_scheme_set("purple") -#' -#' # see 'condition' argument above for details on the plots below and +#' +#' # see 'condition' argument above for details on the plots below and #' # above the diagonal. default is to split by accept_stat__. #' pairs(example_model, pars = c("(Intercept)", "log-posterior")) #' @@ -380,7 +523,7 @@ validate_plotfun_for_opt_or_vb <- function(plotfun) { #' adapt_delta = 0.9, #' refresh = 0 #' ) -#' +#' #' pairs(fit, pars = c("wt", "sigma", "log-posterior")) #' #' # requires hexbin package @@ -393,20 +536,20 @@ validate_plotfun_for_opt_or_vb <- function(plotfun) { #' #' bayesplot::color_scheme_set("brightblue") #' pairs( -#' fit, -#' pars = c("(Intercept)", "wt", "sigma", "log-posterior"), -#' transformations = list(sigma = "log"), +#' fit, +#' pars = c("(Intercept)", "wt", "sigma", "log-posterior"), +#' transformations = list(sigma = "log"), #' off_diag_args = list(size = 3/4, alpha = 1/3), # size and transparency of scatterplot points #' np_style = pairs_style_np(div_color = "black", div_shape = 2) # color and shape of the divergences #' ) -#' -#' # Using the condition argument to show divergences above the diagonal +#' +#' # Using the condition argument to show divergences above the diagonal #' pairs( -#' fit, -#' pars = c("(Intercept)", "wt", "log-posterior"), +#' fit, +#' pars = c("(Intercept)", "wt", "log-posterior"), #' condition = pairs_condition(nuts = "divergent__") #' ) -#' +#' #' } #' } pairs.stanreg <- @@ -415,21 +558,21 @@ pairs.stanreg <- regex_pars = NULL, condition = pairs_condition(nuts = "accept_stat__"), ...) { - + if (!used.sampling(x)) STOP_sampling_only("pairs") - + dots <- list(...) ignored_args <- c("np", "lp", "max_treedepth") specified <- ignored_args %in% names(dots) if (any(specified)) { warning( "The following arguments were ignored because they are ", - "specified automatically by rstanarm: ", + "specified automatically by rstanarm: ", paste(sQuote(ignored_args[specified]), collapse = ", ") ) } - + posterior <- as.array.stanreg(x, pars = pars, regex_pars = regex_pars) if (is.null(pars) && is.null(regex_pars)) { # include log-posterior by default @@ -444,16 +587,16 @@ pairs.stanreg <- posterior <- tmp } posterior <- round(posterior, digits = 12) - + bayesplot::mcmc_pairs( - x = posterior, - np = bayesplot::nuts_params(x), - lp = bayesplot::log_posterior(x), + x = posterior, + np = bayesplot::nuts_params(x), + lp = bayesplot::log_posterior(x), max_treedepth = .max_treedepth(x), condition = condition, ... ) - + } diff --git a/R/posterior_linpred.R b/R/posterior_linpred.R index 21426b743..c79f4485f 100644 --- a/R/posterior_linpred.R +++ b/R/posterior_linpred.R @@ -92,9 +92,10 @@ posterior_linpred.stanreg <- XZ = FALSE, ...) { - if (is.stanmvreg(object)) { + if (is.stanmvreg(object)) STOP_if_stanmvreg("'posterior_linpred'") - } + if (is.stansurv(object)) + STOP_if_stansurv("'poterior_linpred'") newdata <- validate_newdata(object, newdata = newdata, m = NULL) dat <- pp_data(object, diff --git a/R/posterior_predict.R b/R/posterior_predict.R index ce3a51bbe..de1665035 100644 --- a/R/posterior_predict.R +++ b/R/posterior_predict.R @@ -150,6 +150,11 @@ posterior_predict.stanreg <- function(object, newdata = NULL, draws = NULL, re.form = NULL, fun = NULL, seed = NULL, offset = NULL, ...) { + + if (is.stansurv(object)) + stop2("'posterior_predict' is not implemented for stansurv objects. ", + "Use 'posterior_survfit' instead.") + if (!is.null(seed)) set.seed(seed) if (!is.null(fun)) diff --git a/R/posterior_survfit.R b/R/posterior_survfit.R index 11cd39473..a5cb10eef 100644 --- a/R/posterior_survfit.R +++ b/R/posterior_survfit.R @@ -16,79 +16,94 @@ # along with this program; if not, write to the Free Software # Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301, USA. -#' Estimate subject-specific or standardised survival probabilities -#' -#' This function allows us to generate estimated survival probabilities -#' based on draws from the posterior predictive distribution. By default -#' the survival probabilities are conditional on an individual's -#' group-specific coefficients (i.e. their individual-level random -#' effects). If prediction data is provided via the \code{newdataLong} -#' and \code{newdataEvent} arguments, then the default behaviour is to -#' sample new group-specific coefficients for the individuals in the -#' new data using a Monte Carlo scheme that conditions on their -#' longitudinal outcome data provided in \code{newdataLong} -#' (sometimes referred to as "dynamic predictions", see Rizopoulos -#' (2011)). This default behaviour can be stopped by specifying -#' \code{dynamic = FALSE}, in which case the predicted survival -#' probabilities will be marginalised over the distribution of the -#' group-specific coefficients. This has the benefit that the user does -#' not need to provide longitudinal outcome measurements for the new -#' individuals, however, it does mean that the predictions will incorporate -#' all the uncertainty associated with between-individual variation, since -#' the predictions aren't conditional on any observed data for the individual. -#' In addition, by default, the predicted subject-specific survival -#' probabilities are conditional on observed values of the fixed effect -#' covariates (ie, the predictions will be obtained using either the design -#' matrices used in the original \code{\link{stan_jm}} model call, or using the -#' covariate values provided in the \code{newdataLong} and \code{newdataEvent} -#' arguments). However, if you wish to average over the observed distribution -#' of the fixed effect covariates then this is possible -- such predictions -#' are sometimes referred to as standardised survival probabilties -- see the -#' \code{standardise} argument below. -#' +#' Posterior predictions for survival models +#' +#' This function allows us to generate predicted quantities for survival +#' models at specified times. These quantities include the hazard rate, +#' cumulative hazard, survival probability, or failure probability (i.e. CDF). +#' Note that the cumulative hazard, survival probability, or failure +#' probability may be conditional on a last known survival time (see the +#' \code{condition} argument discussed below). Predictions are obtained +#' using unique draws from the posterior distribution of each of the model +#' parameters and then summarised into a median and posterior uncertainty +#' interval. For \code{stan_jm} models "dynamic" predictions are allowed and +#' are in fact the default when new data is provided (see the \code{dynamic} +#' argument discussed below). +#' +#' #' @export -#' @templateVar stanjmArg object -#' @template args-stanjm-object -#' -#' @param newdataLong,newdataEvent Optionally, a data frame (or in the case of -#' \code{newdataLong} this can be a list of data frames) in which to look -#' for variables with which to predict. If omitted, the model matrices are used. -#' If new data is provided, then it should also contain the longitudinal -#' outcome data on which to condition when drawing the new group-specific -#' coefficients for individuals in the new data. Note that there is only -#' allowed to be one row of data for each individual in \code{newdataEvent}, +#' @import splines2 +#' +#' @templateVar stanregArg object +#' @template args-stansurv-stanjm-object +#' +#' @param newdata Optionally, a data frame in which to look for variables with +#' which to predict. If omitted, the model matrix is used. If \code{newdata} +#' is provided and any variables were transformed (e.g. rescaled) in the data +#' used to fit the model, then these variables must also be transformed in +#' \code{newdata}. This only applies if variables were transformed before +#' passing the data to one of the modeling functions and \emph{not} if +#' transformations were specified inside the model formula. Also, +#' \code{newdata} can optionally include a variable with information +#' about the last known survival time for the new individuals -- +#' see the description for the \code{last_time} argument below +#' -- however also note that when generating the survival probabilities it +#' is of course assumed that all individuals in \code{newdata} have not +#' yet experienced the event (that is, any variable in \code{newdataEvent} +#' that corresponds to the event indicator will be ignored). +#' @param newdataLong,newdataEvent An optional data frame (or in the case of +#' \code{newdataLong} this can be a list of data frames) in which to look +#' for variables with which to predict. If omitted, the model matrices are +#' used. If new data is provided, then it should also contain the longitudinal +#' outcome data on which to condition when drawing the new group-specific +#' coefficients for individuals in the new data unless the \code{dynamic} +#' argument is set to \code{FALSE}. Note that there is only +#' allowed to be one row of data for each individual in \code{newdataEvent}, #' that is, time-varying covariates are not allowed in the prediction data for #' the event submodel. Also, \code{newdataEvent} can optionally include a #' variable with information about the last known survival time for the new #' individuals -- see the description for the \code{last_time} argument below -#' -- however also note that when generating the survival probabilities it -#' is of course assumed that all individuals in \code{newdataEvent} have not -#' yet experienced the event (that is, any variable in \code{newdataEvent} that -#' corresponds to the event indicator will be ignored). -#' @param extrapolate A logical specifying whether to extrapolate the estimated +#' -- however also note that when generating the survival probabilities it +#' is of course assumed that all individuals in \code{newdataEvent} have not +#' yet experienced the event (that is, any variable in \code{newdataEvent} +#' that corresponds to the event indicator will be ignored). +#' @param type The type of prediction to return. The following are currently +#' allowed: +#' \itemize{ +#' \item \code{"surv"}: the estimated survival probability. +#' \item \code{"cumhaz"}: the estimated cumulative hazard. +#' \item \code{"haz"}: the estimated hazard rate. +#' \item \code{"cdf"}: the estimated failure probability. +#' \item \code{"logsurv"}: the estimated log survival probability. +#' \item \code{"logcumhaz"}: the estimated log cumulative hazard. +#' \item \code{"loghaz"}: the estimated log hazard rate. +#' \item \code{"logcdf"}: the estimated log failure probability. +#' } +#' @param extrapolate A logical specifying whether to extrapolate the estimated #' survival probabilities beyond the times specified in the \code{times} argument. #' If \code{TRUE} then the extrapolation can be further controlled using #' the \code{control} argument. #' @param control A named list with parameters controlling extrapolation #' of the estimated survival function when \code{extrapolate = TRUE}. The list #' can contain one or more of the following named elements: \cr -#' \describe{ -#' \item{\code{epoints}}{a positive integer specifying the number of -#' discrete time points at which to calculate the forecasted survival -#' probabilities. The default is 10.} -#' \item{\code{edist}}{a positive scalar specifying the amount of time -#' across which to forecast the estimated survival function, represented -#' in units of the time variable \code{time_var} (from fitting the model). -#' The default is to extrapolate between the times specified in the -#' \code{times} argument and the maximum event or censoring time in the -#' original data. If \code{edist} leads to times that are beyond -#' the maximum event or censoring time in the original data then the -#' estimated survival probabilities will be truncated at that point, since -#' the estimate for the baseline hazard is not available beyond that time.} -#' } -#' @param condition A logical specifying whether the estimated -#' subject-specific survival probabilities at time \code{t} should be -#' conditioned on survival up to a fixed time point \code{u}. The default +#' \itemize{ +#' \item \code{epoints}: a positive integer specifying the number of +#' discrete time points at which to calculate the forecasted survival +#' probabilities. The default is 100. +#' \item \code{edist}: a positive scalar specifying the amount of time +#' across which to forecast the estimated survival function, represented +#' in the same units of time as were used for the event times in the fitted +#' model. The default is to extrapolate between the times specified in the +#' \code{times} argument and the maximum event or censoring time found in +#' the original data used to fit the model. If \code{edist} leads to times +#' that are beyond the maximum event or censoring time in the original data +#' then the estimated survival probabilities will be truncated at that +#' point, since an estimate for the baseline hazard is not available +#' beyond that time. +#' } +#' @param condition A logical specifying whether the estimated +#' subject-specific survival probabilities at time \code{t} should be +#' conditioned on survival up to a fixed time point \code{u}. The default #' is for \code{condition} to be set to \code{TRUE}, unless standardised survival #' probabilities have been requested (by specifying \code{standardise = TRUE}), #' in which case \code{condition} must (and will) be set to \code{FALSE}. @@ -115,10 +130,11 @@ #' If standardised survival probabilities are requested (i.e. #' \code{standardise = TRUE}) then conditional survival probabilities are #' not allowed and therefore the \code{last_time} argument is ignored. -#' @param ids An optional vector specifying a subset of IDs for whom the -#' predictions should be obtained. The default is to predict for all individuals -#' who were used in estimating the model or, if \code{newdataLong} and -#' \code{newdataEvent} are specified, then all individuals contained in +#' @param ids For \code{stan_jm} models. An optional vector specifying +#' a subset of IDs for whom the predictions should be obtained. +#' The default is to predict for all individuals +#' who were used in estimating the model or, if \code{newdataLong} and +#' \code{newdataEvent} are specified, then all individuals contained in #' the new data. #' @param prob A scalar between 0 and 1 specifying the width to use for the #' uncertainty interval (sometimes called credible interval) for the predictions. @@ -144,62 +160,139 @@ #' then the \code{times} argument must be specified and it must be constant across #' individuals, that is, the survival probabilities must be calculated at the #' same time for all individuals. -#' @param dynamic A logical that is only relevant if new data is provided -#' via the \code{newdataLong} and \code{newdataEvent} arguments. If -#' \code{dynamic = TRUE}, then new group-specific parameters are drawn for -#' the individuals in the new data, conditional on their longitudinal +#' @param dynamic A logical that is only relevant for \code{stan_jm} models +#' when new data is provided via the \code{newdataLong} and \code{newdataEvent} +#' arguments. If \code{dynamic = TRUE}, then new group-specific parameters are +#' drawn for the individuals in the new data, conditional on their longitudinal #' biomarker data contained in \code{newdataLong}. These group-specific #' parameters are then used to generate individual-specific survival probabilities #' for these individuals. These are often referred to as "dynamic predictions" #' in the joint modelling context, because the predictions can be updated #' each time additional longitudinal biomarker data is collected on the individual. #' On the other hand, if \code{dynamic = FALSE} then the survival probabilities -#' will just be marginalised over the distribution of the group-specific -#' coefficients; this will mean that the predictions will incorporate all -#' uncertainty due to between-individual variation so there will likely be -#' very wide credible intervals on the predicted survival probabilities. -#' @param scale A scalar, specifying how much to multiply the asymptotic +#' will be obtained by marginalising over the distribution of the group-specific +#' coefficients; this has the benefit that the user does not need to provide +#' longitudinal outcome data for the new individuals, but it will also +#' mean that the survival predictions will incorporate all uncertainty due +#' to between-individual variation in the longitudinal trajectories and so +#' there is likely to be very wide credible intervals on the predicted +#' survival probabilities. +#' @param scale Only relevant for \code{stan_jm} models when new data +#' is supplied and \code{dynamic = TRUE}, in which case new random effects +#' are simulated for the individuals in the new data using a +#' Metropolis-Hastings algorithm. The \code{scale} argument should be a +#' scalar. It specifies how much to multiply the asymptotic #' variance-covariance matrix for the random effects by, which is then #' used as the "width" (ie. variance-covariance matrix) of the multivariate -#' Student-t proposal distribution in the Metropolis-Hastings algorithm. This -#' is only relevant when \code{newdataEvent} is supplied and -#' \code{dynamic = TRUE}, in which case new random effects are simulated -#' for the individuals in the new data using the Metropolis-Hastings algorithm. -#' @param draws An integer indicating the number of MCMC draws to return. -#' The default is to set the number of draws equal to 200, or equal to the -#' size of the posterior sample if that is less than 200. +#' Student-t proposal distribution in the Metropolis-Hastings algorithm. +#' @param draws An integer specifying the number of MCMC draws to use when +#' evaluating the predicted quantities. For \code{stan_surv} models, the +#' default number of draws is 400 (or the size of the posterior sample if +#' that is less than 400). For \code{stan_jm} models, the default number +#' of draws is 200 (or the size of the posterior sample if that is less +#' than 200). The smaller default number of draws for \code{stan_jm} +#' models is because dynamic predictions (when \code{dynamic = TRUE}) +#' can be slow. #' @param seed An optional \code{\link[=set.seed]{seed}} to use. +#' @param return_matrix A logical. If \code{TRUE} then a list of \code{draws} by +#' \code{nrow(newdata)} matrices is returned. Each matrix contains the actual +#' simulations or draws from the posterior predictive distribution. Otherwise +#' if \code{return_matrix} is set to \code{FALSE} (the default) then a +#' data frame is returned. See the \strong{Value} section below for more +#' detail. #' @param ... Currently unused. -#' -#' @note -#' Note that if any variables were transformed (e.g. rescaled) in the data -#' used to fit the model, then these variables must also be transformed in -#' \code{newdataLong} and \code{newdataEvent}. This only applies if variables -#' were transformed before passing the data to one of the modeling functions and -#' \emph{not} if transformations were specified inside the model formula. -#' -#' @return A data frame of class \code{survfit.stanjm}. The data frame includes -#' columns for each of the following: -#' (i) the median of the posterior predictions of the estimated survival -#' probabilities (\code{survpred}); -#' (ii) each of the lower and upper limits of the corresponding uncertainty -#' interval for the estimated survival probabilities (\code{ci_lb} and -#' \code{ci_ub}); -#' (iii) a subject identifier (\code{id_var}), unless standardised survival -#' probabilities were estimated; -#' (iv) the time that the estimated survival probability is calculated for -#' (\code{time_var}). +#' +#' @details +#' By default, the predicted quantities are evaluated conditional on observed +#' values of the fixed effect covariates. That is, predictions will be +#' obtained using either: +#' \itemize{ +#' \item the design matrices used in the original \code{\link{stan_surv}} +#' or \code{\link{stan_jm}} model call, or +#' \item the covariate values provided in the \code{newdata} argument +#' (or \code{newdataLong} and \code{newdataEvent} arugments for the +#' \code{stanjm} method). +#' } +#' However, if you wish to average over the observed distribution +#' of the fixed effect covariates then this is possible -- such predictions +#' are sometimes referred to as standardised survival probabilties -- see the +#' \code{standardise} argument. +#' +#' For \code{stansurv} objects, the predicted quantities are calculated for +#' \emph{each row of the prediction data}, at the specified \code{times} as +#' well as any times generated through extrapolation (when +#' \code{extrapolate = TRUE}). +#' +#' For \code{stanjm} objects, the predicted quantities are calculated for +#' \emph{each individual}, at the specified \code{times} as well as any times +#' generated through extrapolation (when \code{extrapolate = TRUE}). +#' +#' \subsection{Dynamic versus marginalised predictions}{ +#' The following also applies for \code{stanjm} objects. +#' By default the survival probabilities are conditional on an individual's +#' group-specific coefficients (i.e. their individual-level random +#' effects). If prediction data is provided via the \code{newdataLong} +#' and \code{newdataEvent} arguments, then the default behaviour is to +#' sample new group-specific coefficients for the individuals in the +#' new data using a Monte Carlo scheme that conditions on their +#' longitudinal outcome data provided in \code{newdataLong} +#' (sometimes referred to as "dynamic predictions", see Rizopoulos +#' (2011)). This default behaviour can be stopped by specifying +#' \code{dynamic = FALSE}, in which case the predicted survival +#' probabilities will be marginalised over the distribution of the +#' group-specific coefficients. This has the benefit that the user does +#' not need to provide longitudinal outcome measurements for the new +#' individuals, however, it does mean that the predictions will incorporate +#' all the uncertainty associated with between-individual variation in the +#' biomarker (longitudinal outcome) values since the predictions aren't +#' conditional on any observed biomarker (longitudinal outcome) data for +#' the individual. +#' } +#' +#' @note +#' Note that if any variables were transformed (e.g. rescaled) in the data +#' used to fit the model, then these variables must also be transformed in +#' \code{newdataLong} and \code{newdataEvent}. This only applies if variables +#' were transformed before passing the data to one of the modeling functions +#' and \emph{not} if transformations were specified inside the model formula. +#' +#' @return When \code{return_matrix = FALSE} (the default), a data frame of +#' class \code{survfit.stansurv} or \code{survfit.stanjm}. The data frame +#' includes columns for each of the following: +#' (i) the median of the posterior predictions (\code{median}); +#' (ii) each of the lower and upper limits of the corresponding uncertainty +#' interval for the posterior predictions (\code{ci_lb} and \code{ci_ub}); +#' (iii) an observation identifier (for \code{stan_surv} models) or an +#' individual identifier (for \code{stan_jm} models), unless standardised +#' predictions were requested; +#' (iv) the time that the prediction corresponds to (\code{time}). +#' (v) the last known survival time on which the prediction is conditional +#' (\code{cond_time}); this will be set to NA if not relevant. #' The returned object also includes a number of additional attributes. -#' -#' @seealso \code{\link{plot.survfit.stanjm}} for plotting the estimated survival -#' probabilities, \code{\link{ps_check}} for for graphical checks of the estimated -#' survival function, and \code{\link{posterior_traj}} for estimating the -#' marginal or subject-specific longitudinal trajectories, and -#' \code{\link{plot_stack_jm}} for combining plots of the estimated subject-specific -#' longitudinal trajectory and survival function. -#' -#' @references -#' Rizopoulos, D. (2011). Dynamic predictions and prospective accuracy in +#' +#' When \code{return_matrix = TRUE} a list of matrices is returned. Each +#' matrix contains the predictions evaluated at one step of the +#' extrapolation time sequence (note that if \code{extrapolate = FALSE} +#' then the list will be of length one, i.e. the predictions are only +#' evaluated at \code{times} which corresponds to just one time point +#' for each individual). Each matrix will have \code{draws} rows and +#' \code{nrow(newdata)} columns, such that each row contains a +#' vector of predictions generated using a single draw of the model +#' parameters from the posterior distribution. The returned +#' list also includes a number of additional attributes. +#' +#' @seealso +#' \code{\link{plot.survfit.stanjm}} for plotting the estimated survival +#' probabilities \cr +#' \code{\link{ps_check}} for for graphical checks of the estimated +#' survival function \cr +#' \code{\link{posterior_traj}} for estimating the +#' marginal or subject-specific longitudinal trajectories \cr +#' \code{\link{plot_stack_jm}} for combining plots of the estimated +#' subject-specific longitudinal trajectory and survival function +#' +#' @references +#' Rizopoulos, D. (2011). Dynamic predictions and prospective accuracy in #' joint models for longitudinal and time-to-event data. \emph{Biometrics} #' \strong{67}, 819. #' @@ -261,26 +354,256 @@ #' times = 0, extrapolate = TRUE) #' plot(ps4) #' } -#' } -posterior_survfit <- function(object, newdataLong = NULL, newdataEvent = NULL, - extrapolate = TRUE, control = list(), - condition = NULL, last_time = NULL, prob = 0.95, - ids, times = NULL, standardise = FALSE, - dynamic = TRUE, scale = 1.5, - draws = NULL, seed = NULL, ...) { +#' +posterior_survfit <- function(object, ...) UseMethod("posterior_survfit") + +#' @rdname posterior_survfit +#' @method posterior_survfit stansurv +#' @export +#' +posterior_survfit.stansurv <- function(object, + newdata = NULL, + type = "surv", + extrapolate = TRUE, + control = list(), + condition = FALSE, + last_time = NULL, + prob = 0.95, + times = NULL, + standardise = FALSE, + draws = NULL, + seed = NULL, + return_matrix = FALSE, + ...) { + + validate_stansurv_object(object) + + basehaz <- object$basehaz + + if (!is.null(seed)) + set.seed(seed) + + if (is.null(newdata) && object$ndelayed) + stop("Prediction data for 'posterior_survfit' cannot include delayed ", + "entry. If you estimated a model with delayed entry, you will ", + "not be able to obtain predictions using the estimation data frame. ", + "You must provide prediction data via the 'newdata' argument, and ", + "indicate delayed entry via the 'last_time' argument.") + + dots <- list(...) + + if ("newdataEvent" %in% names(dots)) + stop("The argument 'newdataEvent' should not be specified when ", + "predicting for stan_surv models. Perhaps you meant to specify ", + "'newdata' instead of 'newdataEvent'.") + + if ("newdataLong" %in% names(dots)) + stop("The argument 'newdataLong' should not be specified when ", + "predicting for stan_surv models.") + + newdata <- validate_newdata(object, newdata = newdata) + has_newdata <- not.null(newdata) + + # Obtain a vector of unique subject ids + if (is.null(newdata)) { + id_list <- seq(nrow(get_model_data(object))) + } else { + id_list <- seq(nrow(newdata)) + } + + # Error checks for conditional predictions + if (condition) { + if (standardise) + stop("'condition' cannot be TRUE for standardised predictions.") + if (type %in% c("haz", "loghaz")) + stop("'condition' cannot be TRUE when 'type = \"", type, "\"'.") + } + + # Last known survival time for each individual + if (is.null(newdata)) { # user did not specify newdata + if (!is.null(last_time)) + stop("'last_time' cannot be provided when newdata is NULL, since times ", + "are taken to be the event or censoring time for each individual.") + last_time <- object$eventtime + } else { # user specified newdata + if (is.null(last_time)) { # assume at risk from time zero + last_time <- rep(0, length(id_list)) + } else if (is.string(last_time)) { + if (!last_time %in% colnames(newdata)) + stop("Cannot find 'last_time' column named in newdata") + last_time <- newdata[[last_time]] + } else if (is.scalar(last_time)) { + last_time <- rep(last_time, nrow(newdata)) + } else if (any(!is.numeric(last_time), !length(last_time) == nrow(newdata))) { + stop("Bug found: could not reconcile 'last_time' argument.") + } + names(last_time) <- as.character(id_list) + } + + # Prediction times + if (standardise) { # standardised survival probs + times <- + if (is.null(times)) { + stop("'times' cannot be NULL for obtaining standardised survival probabilities.") + } else if (is.scalar(times)) { + rep(times, length(id_list)) + } else { + stop("'times' should be a numeric vector of length 1 in order to obtain ", + "standardised survival probabilities (the subject-specific survival ", + "probabilities will be calculated at the specified time point, and ", + "then averaged).") + } + } else if (is.null(newdata)) { # subject-specific survival probs without newdata + times <- + if (is.null(times)) { + object$eventtime + } else if (is.scalar(times)) { + rep(times, length(id_list)) + } else { + stop("If newdata is NULL then 'times' must be NULL or a single number.") + } + } else { # subject-specific survival probs with newdata + times <- + if (is.null(times)) { + times <- last_time + } else if (is.scalar(times)) { + rep(times, length(id_list)) + } else if (is.string(times)) { + if (!times %in% colnames(newdata)) + stop("Variable specified in 'times' argument could not be found in newdata.") + times <- newdata[[times]] + } else { + stop("If newdata is specified then 'times' can only be the name of a ", + "variable in newdata, or a single number.") + } + } + + maxtime <- max(object$eventtime) + if (any(times > maxtime)) + stop("'times' are not allowed to be greater than the last event or ", + "censoring time (since unable to extrapolate the baseline hazard).") + + # User specified extrapolation + if (extrapolate) { + control <- extrapolation_control(control, ok_args = c("epoints", "edist")) + if (not.null(control$edist)) { + endtime <- times + control$edist + } else { + endtime <- maxtime + } + endtime <- truncate(endtime, upper = maxtime) + time_seq <- get_time_seq(control$epoints, times, endtime, simplify = FALSE) + } else { + time_seq <- list(times) # no extrapolation + } + + # Get stanmat parameter matrix for specified number of draws + stanmat <- sample_stanmat(object, draws = draws, default_draws = 400) + pars <- extract_pars(object, stanmat) + + # Calculate survival probability at each increment of extrapolation sequence + surv <- lapply(time_seq, .pp_calculate_surv, + object = object, + newdata = newdata, + pars = pars, + type = type, + standardise = standardise) + + # Calculate survival probability at last known survival time and then + # use that to calculate conditional survival probabilities + if (condition) { + cond_surv <- .pp_calculate_surv(last_time, + object = object, + newdata = newdata, + pars = pars, + type = type) + surv <- lapply(surv, function(x) truncate(x / cond_surv, upper = 1)) + attr(surv, "last_time") <- last_time + } + + # Optionally return draws rather than summarising into median and CI + if (return_matrix) { + return(structure(surv, + type = type, + extrapolate = extrapolate, + control = control, + condition = condition, + standardise = standardise, + last_time = if (condition) last_time else NULL, + ids = id_list, + draws = NROW(stanmat), + seed = seed)) + } + + # Summarise posterior draws to get median and CI + out <- .pp_summarise_surv(surv = surv, + prob = prob, + standardise = standardise) + + # Add attributes + structure(out, + id_var = attr(out, "id_var"), + time_var = attr(out, "time_var"), + type = type, + extrapolate = extrapolate, + control = control, + condition = condition, + standardise = standardise, + last_time = if (condition) last_time else NULL, + ids = id_list, + draws = NROW(stanmat), + seed = seed, + class = c("survfit.stansurv", "data.frame")) +} + +#' @rdname posterior_survfit +#' @method posterior_survfit stanjm +#' @export +#' +posterior_survfit.stanjm <- function(object, + newdataLong = NULL, + newdataEvent = NULL, + type = "surv", + extrapolate = TRUE, + control = list(), + condition = NULL, + last_time = NULL, + prob = 0.95, + ids, + times = NULL, + standardise = FALSE, + dynamic = TRUE, + scale = 1.5, + draws = NULL, + seed = NULL, + return_matrix = FALSE, + ...) { + validate_stanjm_object(object) + M <- object$n_markers id_var <- object$id_var time_var <- object$time_var basehaz <- object$basehaz assoc <- object$assoc family <- family(object) - if (!is.null(seed)) + + if (!is.null(seed)) set.seed(seed) if (missing(ids)) ids <- NULL + dots <- list(...) - + + if ("newdata" %in% names(dots)) + stop("The argument 'newdata' should not be specified when predicting ", + "for stan_jm models. You should specify 'newdataLong' and ", + "'newdataEvent' instead of 'newdata'.") + + # Temporarily only allow survprob for stan_jm until refactoring is done + if (!type == "surv") + stop("Currently only 'type = \"surv\"' is allowed for stanjm models.") + # Temporary stop, until make_assoc_terms can handle it sel_stop <- grep("^shared", rownames(object$assoc)) if (any(unlist(object$assoc[sel_stop,]))) @@ -292,22 +615,21 @@ posterior_survfit <- function(object, newdataLong = NULL, newdataEvent = NULL, # ndE: dataEvent to be used in predictions if (!identical(is.null(newdataLong), is.null(newdataEvent))) stop("Both newdataLong and newdataEvent must be supplied together.") + has_newdata <- not.null(newdataEvent) if (is.null(newdataLong)) { # user did not specify newdata dats <- get_model_data(object) ndL <- dats[1:M] ndE <- dats[["Event"]] } else { # user specified newdata - if (!dynamic) - stop2("Marginalised predictions for the event outcome are ", - "not currently implemented.") - newdatas <- validate_newdatas(object, newdataLong, newdataEvent) + newdatas <- validate_newdatas(object, newdataLong, newdataEvent, + response = dynamic, needs_time_var = dynamic) ndL <- newdatas[1:M] ndE <- newdatas[["Event"]] } if (!is.null(ids)) { # user specified a subset of ids - ndL <- subset_ids(object, ndL, ids) - ndE <- subset_ids(object, ndE, ids) - } + ndL <- subset_ids(ndL, ids, id_var) + ndE <- subset_ids(ndE, ids, id_var) + } id_list <- factor(unique(ndE[[id_var]])) # order of ids from data, not ids arg # Last known survival time for each individual @@ -384,8 +706,8 @@ posterior_survfit <- function(object, newdataLong = NULL, newdataEvent = NULL, # User specified extrapolation if (extrapolate) { - ok_control_args <- c("epoints", "edist") - control <- get_extrapolation_control(control, ok_control_args = ok_control_args) + ok_args <- c("epoints", "edist") + control <- extrapolation_control(control, ok_args = ok_args) endtime <- if (!is.null(control$edist)) times + control$edist else maxtime endtime[endtime > maxtime] <- maxtime # nothing beyond end of baseline hazard time_seq <- get_time_seq(control$epoints, times, endtime, simplify = FALSE) @@ -393,30 +715,25 @@ posterior_survfit <- function(object, newdataLong = NULL, newdataEvent = NULL, # Conditional survival times if (is.null(condition)) { - condition <- !standardise + condition <- ifelse(type == "surv", !standardise, FALSE) } else if (condition && standardise) { stop("'condition' cannot be set to TRUE if standardised survival ", "probabilities are requested.") } # Get stanmat parameter matrix for specified number of draws - S <- posterior_sample_size(object) - if (is.null(draws)) - draws <- if (S > 200L) 200L else S - if (draws > S) - stop("'draws' should be <= posterior sample size (", S, ").") - stanmat <- as.matrix(object$stanfit) - some_draws <- isTRUE(draws < S) - if (some_draws) { - samp <- sample(S, draws) - stanmat <- stanmat[samp, , drop = FALSE] - } + stanmat <- sample_stanmat(object, draws = draws, default_draws = 200) # Draw b pars for new individuals - if (dynamic && !is.null(newdataEvent)) { - stanmat <- simulate_b_pars(object, stanmat = stanmat, ndL = ndL, ndE = ndE, - ids = id_list, times = last_time, scale = scale) - b_new <- attr(stanmat, "b_new") + if (dynamic && has_newdata) { + stanmat <- simulate_b_pars(object, + stanmat = stanmat, + ndL = ndL, + ndE = ndE, + ids = id_list, + times = last_time, + scale = scale) + b_new <- attr(stanmat, "b_new") acceptance_rate <- attr(stanmat, "acceptance_rate") } @@ -424,102 +741,349 @@ posterior_survfit <- function(object, newdataLong = NULL, newdataEvent = NULL, # Matrix of surv probs at each increment of the extrapolation sequence # NB If no extrapolation then length(time_seq) == 1L - surv_t <- lapply(time_seq, function(t) { - if (!identical(length(t), length(id_list))) - stop("Bug found: the vector of prediction times is not the same length ", - "as the number of individuals.") - dat <- .pp_data_jm(object, newdataLong = ndL, newdataEvent = ndE, - ids = id_list, etimes = t, long_parts = FALSE) - surv_t <- .ll_survival(object, data = dat, pars = pars, survprob = TRUE) - if (is.vector(surv_t) == 1L) - surv_t <- t(surv_t) # transform if only one individual - surv_t[, (t == 0)] <- 1 # avoids possible NaN due to numerical inaccuracies - if (standardise) { # standardised survival probs - surv_t <- matrix(rowMeans(surv_t), ncol = 1) - dimnames(surv_t) <- list(iterations = NULL, "standardised_survprob") - } else { - dimnames(surv_t) <- list(iterations = NULL, ids = id_list) - } - surv_t - }) - - # If conditioning, need to obtain matrix of surv probs at last known surv time + surv_t <- lapply(time_seq, .pp_calculate_surv, + object = object, + newdataLong = ndL, + newdataEvent = ndE, + pars = pars, + type = type, + id_list = id_list, + standardise = standardise, + dynamic = dynamic) + + # Calculate survival probability at last known survival time and then + # use that to calculate conditional survival probabilities if (condition) { - cond_dat <- .pp_data_jm(object, newdataLong = ndL, newdataEvent = ndE, - ids = id_list, etimes = last_time, long_parts = FALSE) - # matrix of survival probs at last_time - cond_surv <- .ll_survival(object, data = cond_dat, pars = pars, survprob = TRUE) - if (is.vector(cond_surv) == 1L) - cond_surv <- t(cond_surv) # transform if only one individual - cond_surv[, (last_time == 0)] <- 1 # avoids possible NaN due to numerical inaccuracies - surv <- lapply(surv_t, function(x) { # conditional survival probs - vec <- x / cond_surv - vec[vec > 1] <- 1 # if t was before last_time then surv prob may be > 1 - vec - }) - } else surv <- surv_t - - # Summarise posterior draws to get median and ci - out <- do.call("rbind", lapply( - seq_along(surv), function(x, standardise, id_list, time_seq, prob) { - val <- median_and_bounds(surv[[x]], prob, na.rm = TRUE) - if (standardise) { - data.frame(TIMEVAR = unique(time_seq[[x]]), val$med, val$lb, val$ub) - } else - data.frame(IDVAR = id_list, TIMEVAR = time_seq[[x]], val$med, val$lb, val$ub) - }, standardise, id_list, time_seq, prob)) - out <- data.frame(out) - colnames(out) <- c(if ("IDVAR" %in% colnames(out)) id_var, - time_var, "survpred", "ci_lb", "ci_ub") - if (id_var %in% colnames(out)) { # data has id column -- sort by id and time - out <- out[order(out[, id_var], out[, time_var]), , drop = FALSE] - } else { # data does not have id column -- sort by time only - out <- out[order(out[, time_var]), , drop = FALSE] + if (!type == "surv") + stop("'condition' can only be set to TRUE for survival probabilities.") + cond_surv <- .pp_calculate_surv(last_time, + object = object, + newdataLong = ndL, + newdataEvent = ndE, + pars = pars, + type = type, + id_list = id_list, + dynamic = dynamic) + surv <- lapply(surv_t, function(x) truncate(x / cond_surv, upper = 1)) + attr(surv, "last_time") <- last_time + } else { + surv <- surv_t } - rownames(out) <- NULL - - # temporary hack so that predictive_error can call posterior_survfit + + # Optionally return draws rather than summarising into median and CI + if (return_matrix) { + return(structure(surv, + type = type, + extrapolate = extrapolate, + control = control, + condition = condition, + standardise = standardise, + last_time = if (condition) last_time else NULL, + ids = id_list, + draws = NROW(stanmat), + seed = seed)) + } + + # Summarise posterior draws to get median and CI + out <- .pp_summarise_surv(surv = surv, + prob = prob, + id_var = id_var, + time_var = time_var, + standardise = standardise) + + # Temporary hack so that 'predictive_error' can call 'posterior_survfit' # with two separate conditioning times... - fn <- tryCatch(sys.call(-1)[[1]], error = function(e) NULL) - if (!is.null(fn) && - grepl("predictive_error", deparse(fn), fixed = TRUE) && - "last_time2" %in% names(dots)) { - last_time2 <- ndE[[dots$last_time2]] - cond_dat2 <- .pp_data_jm(object, newdataLong = ndL, newdataEvent = ndE, - ids = id_list, etimes = last_time2, long_parts = FALSE) - cond_surv2 <- .ll_survival(object, data = cond_dat2, pars = pars, survprob = TRUE) - if (is.vector(cond_surv2) == 1L) - cond_surv2 <- t(cond_surv2) # transform if only one individual - cond_surv2[, (last_time2 == 0)] <- 1 # avoids possible NaN due to numerical inaccuracies - surv2 <- lapply(surv_t, function(x) { # conditional survival probs - vec <- x / cond_surv2 - vec[vec > 1] <- 1 # if t was before last_time then surv prob may be > 1 - vec - }) - out2 <- do.call("rbind", lapply( - seq_along(surv2), function(x, standardise, id_list, time_seq, prob) { - val <- median_and_bounds(surv2[[x]], prob, na.rm = TRUE) - data.frame(IDVAR = id_list, TIMEVAR = time_seq[[x]], val$med) - }, standardise, id_list, time_seq, prob)) - out2 <- data.frame(out2) - colnames(out2) <- c(id_var, time_var, "survpred_eventtime") - out2 <- out2[order(out2[, id_var, drop = F], out2[, time_var, drop = F]), , drop = F] - rownames(out2) <- NULL + fun_check <- isTRUE(grepl("predictive_error", get_calling_fun(), fixed = TRUE)) + dot_check <- isTRUE("last_time2" %in% names(dots)) + if (fun_check && dot_check) { + if (!type == "surv") + stop("'last_time2' can only be specified for survival probabilities.") + cond_surv2 <- .pp_calculate_surv(ndE[[dots$last_time2]], + object = object, + newdataLong = ndL, + newdataEvent = ndE, + pars = pars, + type = type, + id_list = id_list, + dynamic = dynamic) + surv2 <- lapply(surv_t, function(x) truncate(x / cond_surv2, upper = 1)) + out2 <- .pp_summarise_surv(surv = surv2, + prob = prob, + id_var = id_var, + time_var = time_var, + standardise = standardise, + colnames = "survprob_eventtime") out <- merge(out, out2) } - - class(out) <- c("survfit.stanjm", "data.frame") - out <- structure(out, id_var = id_var, time_var = time_var, extrapolate = extrapolate, - control = control, standardise = standardise, condition = condition, - last_time = last_time, ids = id_list, draws = draws, seed = seed, - offset = offset) - if (dynamic && !is.null(newdataEvent)) { + + # Return object + out <- structure(out, + id_var = id_var, + time_var = time_var, + type = type, + extrapolate = extrapolate, + control = control, + standardise = standardise, + condition = condition, + last_time = if (condition) last_time else NULL, + ids = id_list, + draws = NROW(stanmat), + seed = seed, + offset = offset, + class = c("survfit.stanjm", "data.frame")) + + if (dynamic && has_newdata) { out <- structure(out, b_new = b_new, acceptance_rate = acceptance_rate) } + out } +# ----------------- internal ------------------------------------------------ + +# Calculate the desired prediction (e.g. hazard, cumulative hazard, survival +# probability) at the specified times +.pp_calculate_surv <- function(times, + object, + newdata = NULL, + newdataLong = NULL, + newdataEvent = NULL, + pars, + type = "surv", + id_list = NULL, + standardise = FALSE, + dynamic = TRUE) { + + if (is.stanjm(object) && !identical(length(times), length(id_list))) + stop("Bug found: vector of ids should be same length as vector of times.") + + # Determine whether prediction type requires quadrature + needs_quadrature <- type %in% c("cumhaz", + "surv", + "cdf", + "logcumhaz", + "logsurv", + "logcdf") + + # Evaluate hazard, cumulative hazard, survival or failure probability + if (is.stansurv(object)) { + ppdat <- .pp_data_surv(object, + newdata = newdata, + times = times, + at_quadpoints = needs_quadrature) + out <- .pp_predict_surv(object, + data = ppdat, + pars = pars, + type = type) + } else if (is.stanjm(object)) { + ppdat <- .pp_data_jm(object, + newdataLong = newdataLong, + newdataEvent = newdataEvent, + ids = id_list, + etimes = times, + long_parts = FALSE, + response = dynamic, + needs_time_var = dynamic) + out <- .ll_survival(object, # refactoring for stanjm not yet finished + data = ppdat, + pars = pars, + survprob = TRUE) + } + + # Transform if only one individual + out <- transpose_vector(out) + + # Set survival probability == 1 if time == 0 (avoids possible NaN) + if (type == "surv") + out <- replace_where(out, times == 0, replacement = 1, margin = 2L) + + # Standardisation: within each iteration, calculate mean across individuals + if (standardise) { + out <- row_means(out) + ids <- "standardised_survprob" + times <- unique(times) + } else { + ids <- if (is.null(id_list)) seq(ncol(out)) else id_list + } + dimnames(out) <- list(iterations = NULL, ids = ids) + + # Add subject ids and prediction times as an attribute + structure(out, ids = ids, times = times) +} + + +# Evaluate hazard, cumulative hazard, survival or failure probability +# +# @param object A stansurv or stanjm object. +# @param data Output from .pp_data_surv or .pp_data_jm. +# @param pars Output from extract_pars. +# @param type The type of prediction quantity to return. +.pp_predict_surv <- function(object, ...) UseMethod(".pp_predict_surv") + +.pp_predict_surv.stansurv <- function(object, + data, + pars, + type = "surv") { + + args <- nlist(basehaz = get_basehaz(object), + intercept = pars$alpha, + betas = pars$beta, + betas_tve = pars$beta_tve, + b = pars$b, + aux = pars$aux, + times = data$pts, + x = data$x, + s = data$s, + z = data$z) + + if (type %in% c("loghaz", "haz")) { + # evaluate hazard; quadrature not relevant + lhaz <- do.call(evaluate_log_haz, args) + } else if (!data$has_quadrature){ + # evaluate survival; without quadrature + lsurv <- do.call(evaluate_log_surv, args) + } else { + # evaluate survival; with quadrature + lhaz <- do.call(evaluate_log_haz, args) + lsurv <- -quadrature_sum(exp(lhaz), qnodes = data$qnodes, qwts = data$wts) + } + + switch(type, + loghaz = lhaz, + logcumhaz = log(-lsurv), + logsurv = lsurv, + logcdf = log(1 - exp(lsurv)), + haz = exp(lhaz), + surv = exp(lsurv), + cumhaz = -lsurv, + cdf = 1 - exp(lsurv), + stop("Invalid input to the 'type' argument.")) +} + + +# Summarise predictions into median, lower CI, upper CI +# +# @details Convert a list of matrices (with each element being a S by N matrix, +# where S is the number of MCMC draws and N the number of individuals) +# and collapse it across the MCMC iterations summarising it into median +# and CI. The result is a data frame with K times N rows, where K was +# the length of the original list. +.pp_summarise_surv <- function(surv, + prob = NULL, + id_var = NULL, + time_var = NULL, + standardise = FALSE, + colnames = NULL) { + + # Default variable names if not provided by the user + if (is.null(id_var)) + id_var <- "id" + if (is.null(time_var)) + time_var <- "time" + + # Define variable name for conditioning time + cond_var <- paste0("cond_", time_var) + + # Extract ids and times for the predictions + ids <- uapply(surv, attr, "ids") + times <- uapply(surv, attr, "times") + + # Extract conditioning time that was used for predictions + last_time <- attr(surv, "last_time") + if (is.null(last_time)) { # if not using conditional survival + last_time <- rep(NA, length(ids)) + } + + # Determine the quantiles corresponding to the median and CI limits + if (is.null(prob)) { + probs <- 0.5 # median only + nms <- c(id_var, cond_var, time_var, "median") + } else { + probs <- c(0.5, (1 - prob)/2, (1 + prob)/2) # median and CI + nms <- c(id_var, cond_var, time_var, "median", "ci_lb", "ci_ub") + } + + # Possibly overide default variable names for the returned data frame + if (!is.null(colnames)) { + nms <- c(id_var, cond_var, time_var, colnames) + } + + # Calculate mean and CI at each prediction time + out <- data.frame(do.call("rbind", lapply(surv, col_quantiles_, probs))) + out <- mutate_(out, id_var = ids, cond_var = last_time, time_var = times) + out <- row_sort(out, id_var, time_var) + out <- col_sort(out, id_var, cond_var, time_var) + out <- set_rownames(out, NULL) + out <- set_colnames(out, nms) + + # Drop excess info if standardised predictions were calculated + if (standardise) { + out[[cond_var]] <- NULL + out[[id_var]] <- NULL + id_var <- NULL + } + + structure(out, + id_var = id_var, + time_var = time_var) +} + + +# ------------ print methods ------------------------------------------------ + +#' Generic print method for \code{survfit.stansurv} and \code{survfit.stanjm} +#' objects +#' +#' @rdname print.survfit.stansurv +#' @method print survfit.stansurv +#' @keywords internal +#' @export +#' @param x An object of class \code{survfit.stansurv} or \code{survfit.stanjm}, +#' returned by a call to \code{\link{posterior_survfit}}. +#' @param digits Number of digits to use for formatting the time variable and +#' the survival probabilities. +#' @param ... Ignored. +#' +print.survfit.stansurv <- function(x, digits = 4, ...) { + + x <- as.data.frame(x) + sel <- c(attr(x, "time_var"), "median", "ci_lb", "ci_ub") + for (i in sel) + x[[i]] <- format(round(x[[i]], digits), nsmall = digits) + + cat("stan_surv predictions\n") + cat(" num. individuals:", length(attr(x, "ids")), "\n") + cat(" prediction type: ", tolower(get_survpred_name(attr(x, "type"))), "\n") + cat(" standardised?: ", yes_no_string(attr(x, "standardise")), "\n") + cat(" conditional?: ", yes_no_string(attr(x, "condition")), "\n\n") + print(x, quote = FALSE) + invisible(x) +} + +#' @rdname print.survfit.stansurv +#' @method print survfit.stanjm +#' @export +#' +print.survfit.stanjm <- function(x, digits = 4, ...) { + + x <- as.data.frame(x) + sel <- c(attr(x, "time_var"), "median", "ci_lb", "ci_ub") + for (i in sel) + x[[i]] <- format(round(x[[i]], digits), nsmall = digits) + + cat("stan_jm predictions\n") + cat(" num. individuals:", length(attr(x, "ids")), "\n") + cat(" prediction type: ", tolower(get_survpred_name(attr(x, "type"))), "\n") + cat(" standardised?: ", yes_no_string(attr(x, "standardise")), "\n") + cat(" conditional?: ", yes_no_string(attr(x, "condition")), "\n\n") + print(x, quote = FALSE) + invisible(x) +} + + +# ----------------- plot methods -------------------------------------------- + #' Plot the estimated subject-specific or marginal survival function #' #' This generic \code{plot} method for \code{survfit.stanjm} objects will @@ -608,30 +1172,30 @@ posterior_survfit <- function(object, newdataLong = NULL, newdataEvent = NULL, #' control = list(epoints = 20)) #' plot(ps2) #' } -#' } -plot.survfit.stanjm <- function(x, ids = NULL, - limits = c("ci", "none"), - xlab = NULL, ylab = NULL, facet_scales = "free", +plot.survfit.stanjm <- function(x, + ids = NULL, + limits = c("ci", "none"), + xlab = NULL, + ylab = NULL, + facet_scales = "free", ci_geom_args = NULL, ...) { - - limits <- match.arg(limits) - ci <- (limits == "ci") + + limits <- match.arg (limits) + ci <- as.logical(limits == "ci") + + type <- attr(x, "type") standardise <- attr(x, "standardise") - id_var <- attr(x, "id_var") - time_var <- attr(x, "time_var") + id_var <- attr(x, "id_var") + time_var <- attr(x, "time_var") + if (is.null(xlab)) xlab <- paste0("Time (", time_var, ")") - if (is.null(ylab)) ylab <- "Event free probability" + if (is.null(ylab)) ylab <- get_survpred_name(type) + if (!is.null(ids)) { if (standardise) stop("'ids' argument cannot be specified when plotting standardised ", "survival probabilities.") - if (!id_var %in% colnames(x)) - stop("Bug found: could not find 'id_var' column in the data frame.") - ids_missing <- which(!ids %in% x[[id_var]]) - if (length(ids_missing)) - stop("The following 'ids' are not present in the survfit.stanjm object: ", - paste(ids[[ids_missing]], collapse = ", "), call. = FALSE) - x <- x[(x[[id_var]] %in% ids), , drop = FALSE] + x <- subset_ids(x, ids, id_var) } else { ids <- if (!standardise) attr(x, "ids") else NULL } @@ -639,38 +1203,62 @@ plot.survfit.stanjm <- function(x, ids = NULL, x$time <- x[[time_var]] geom_defaults <- list(color = "black") - geom_args <- set_geom_args(geom_defaults, ...) - - lim_defaults <- list(alpha = 0.3) - lim_args <- do.call("set_geom_args", c(defaults = list(lim_defaults), ci_geom_args)) - - if ((!standardise) && (length(ids) > 60L)) { + geom_mapp <- list(mapping = aes_string(x = "time", + y = "median")) + geom_args <- do.call("set_geom_args", + c(defaults = list(geom_defaults), list(...))) + + lim_defaults <- list(alpha = 0.3) + lim_mapp <- list(mapping = aes_string(x = "time", + ymin = "ci_lb", + ymax = "ci_ub")) + lim_args <- do.call("set_geom_args", + c(defaults = list(lim_defaults), ci_geom_args)) + + if ((!standardise) && (length(ids) > 60L)) stop("Too many individuals to plot for. Perhaps consider limiting ", "the number of individuals by specifying the 'ids' argument.") - } else if ((!standardise) && (length(ids) > 1L)) { - graph <- ggplot(x, aes_string(x = "time", y = "survpred")) + - theme_bw() + - do.call("geom_line", geom_args) + - coord_cartesian(ylim = c(0, 1)) + + + graph_base <- + ggplot(x) + + theme_bw() + + coord_cartesian(ylim = get_survpred_ylim(type)) + + do.call("geom_line", c(geom_mapp, geom_args)) + + graph_facet <- + if ((!standardise) && (length(ids) > 1L)) { facet_wrap(~ id, scales = facet_scales) + } else NULL + + graph_limits <- if (ci) { - lim_mapp <- list(mapping = aes_string(ymin = "ci_lb", ymax = "ci_ub")) - graph_limits <- do.call("geom_ribbon", c(lim_mapp, lim_args)) - } else graph_limits <- NULL - } else { - graph <- ggplot(x, aes_string(x = "time", y = "survpred")) + - theme_bw() + - do.call("geom_line", geom_args) + - coord_cartesian(ylim = c(0, 1)) - if (ci) { - lim_mapp <- list(mapping = aes_string(ymin = "ci_lb", ymax = "ci_ub")) - graph_limits <- do.call("geom_ribbon", c(lim_mapp, lim_args)) - } else graph_limits <- NULL - } + do.call("geom_ribbon", c(lim_mapp, lim_args)) + } else NULL + + graph_labels <- labs(x = xlab, y = ylab) + + gg <- graph_base + graph_facet + graph_limits + graph_labels + class_gg <- class(gg) + class(gg) <- c("plot.survfit.stanjm", class_gg) + gg +} + - ret <- graph + graph_limits + labs(x = xlab, y = ylab) - class_ret <- class(ret) - class(ret) <- c("plot.survfit.stanjm", class_ret) +#' @rdname plot.survfit.stanjm +#' @method plot survfit.stansurv +#' @export +#' +plot.survfit.stansurv <- function(x, + ids = NULL, + limits = c("ci", "none"), + xlab = NULL, + ylab = NULL, + facet_scales = "free", + ci_geom_args = NULL, ...) { + mc <- match.call(expand.dots = FALSE) + mc[[1L]] <- quote(plot.survfit.stanjm) + ret <- eval(mc) + class(ret)[[1L]] <- "plot.survfit.stansurv" ret } @@ -805,33 +1393,31 @@ plot_stack_jm <- function(yplot, survplot) { } -# ------------------ exported but doc kept internal +# ----------------- helpers ------------------------------------------------ -#' Generic print method for \code{survfit.stanjm} objects -#' -#' @rdname print.survfit.stanjm -#' @method print survfit.stanjm -#' @keywords internal -#' @export -#' @param x An object of class \code{survfit.stanjm}, returned by a call to -#' \code{\link{posterior_survfit}}. -#' @param digits Number of digits to use for formatting the time variable and -#' the survival probabilities. -#' @param ... Ignored. -#' -print.survfit.stanjm <- function(x, digits = 4, ...) { - time_var <- attr(x, "time_var") - x <- as.data.frame(x) - sel <- c(time_var, "survpred", "ci_lb", "ci_ub") - for (i in sel) - x[[i]] <- format(round(x[[i]], digits), nsmall = digits) - print(x, quote = FALSE) - invisible(x) +# Return a user-friendly name for the prediction type +get_survpred_name <- function(x) { + switch(x, + haz = "Hazard rate", + cumhaz = "Cumulative hazard rate", + surv = "Event free probability", + cdf = "Failure probability", + loghaz = "log(Hazard rate)", + logcumhaz = "log(Cumulative hazard rate)", + logsurv = "log(Event free probability)", + logcdf = "log(Failure probability)", + stop("Bug found: invalid input to 'type' argument.")) } -# ------------------ internal +# Return appropriate y-axis limits for the prediction type +get_survpred_ylim <- function(x) { + switch(x, + surv = c(0,1), + cdf = c(0,1), + NULL) +} -# default plotting attributes +# Default plotting attributes .PP_FILL <- "skyblue" .PP_DARK <- "skyblue4" .PP_VLINE_CLR <- "#222222" diff --git a/R/posterior_traj.R b/R/posterior_traj.R index d555fd2e7..0c89d9887 100644 --- a/R/posterior_traj.R +++ b/R/posterior_traj.R @@ -73,13 +73,13 @@ #' event or censoring time if no new data is provided; the time specified #' in the "last_time" column if provided in the new data (see \strong{Details} #' section below); or the time of the last longitudinal measurement if new -#' data is provided but no "last_time" column is included. The default is 15.} +#' data is provided but no "last_time" column is included. The default is 100.} #' \item{\code{epoints}}{a positive integer specifying the number of discrete #' time points at which to calculate the estimated longitudinal response for #' \code{extrapolate = TRUE}. These time points are evenly spaced between the #' last known observation time for each individual and the extrapolation #' distance specifed using either \code{edist} or \code{eprop}. -#' The default is 15.} +#' The default is 100.} #' \item{\code{eprop}}{a positive scalar between 0 and 1 specifying the #' amount of time across which to extrapolate the longitudinal trajectory, #' represented as a proportion of the total observed follow up time for each @@ -273,27 +273,45 @@ #' re.form = NA) #' head(pt8) # note the much narrower ci, compared with pt5 #' } -#' } -posterior_traj <- function(object, m = 1, newdata = NULL, newdataLong = NULL, - newdataEvent = NULL, interpolate = TRUE, extrapolate = FALSE, - control = list(), last_time = NULL, prob = 0.95, ids, - dynamic = TRUE, scale = 1.5, draws = NULL, seed = NULL, - return_matrix = FALSE, ...) { +#' +posterior_traj <- function(object, + m = 1, + newdata = NULL, + newdataLong = NULL, + newdataEvent = NULL, + interpolate = TRUE, + extrapolate = FALSE, + control = list(), + last_time = NULL, + prob = 0.95, + ids, + dynamic = TRUE, + scale = 1.5, + draws = NULL, + seed = NULL, + return_matrix = FALSE, + ...) { + if (!requireNamespace("data.table")) stop("the 'data.table' package must be installed to use this function") + validate_stanjm_object(object) + M <- object$n_markers; validate_positive_scalar(m, M) id_var <- object$id_var time_var <- object$time_var grp_stuff <- object$grp_stuff[[m]] glmod <- object$glmod[[m]] + if (!is.null(seed)) set.seed(seed) + if (missing(ids)) ids <- NULL + dots <- list(...) - # Deal with deprecate newdata argument + # Deal with deprecated newdata argument if (!is.null(newdata)) { warning("The 'newdata' argument is deprecated. Use 'newdataLong' instead.") if (!is.null(newdataLong)) @@ -320,8 +338,8 @@ posterior_traj <- function(object, m = 1, newdata = NULL, newdataLong = NULL, ndE <- dats[["Event"]] } if (!is.null(ids)) { # user specified a subset of ids - ndL <- subset_ids(object, ndL, ids) - ndE <- subset_ids(object, ndE, ids) + ndL <- subset_ids(ndL, ids, id_var) + ndE <- subset_ids(ndE, ids, id_var) } id_list <- factor(unique(ndL[[m]][[id_var]])) # order of ids from data, not ids arg @@ -355,23 +373,18 @@ posterior_traj <- function(object, m = 1, newdata = NULL, newdataLong = NULL, } # Get stanmat parameter matrix for specified number of draws - S <- posterior_sample_size(object) - if (is.null(draws)) - draws <- if (S > 200L) 200L else S - if (draws > S) - stop("'draws' should be <= posterior sample size (", S, ").") - stanmat <- as.matrix(object$stanfit) - some_draws <- isTRUE(draws < S) - if (some_draws) { - samp <- sample(S, draws) - stanmat <- stanmat[samp, , drop = FALSE] - } - + stanmat <- sample_stanmat(object, draws = draws, default_draws = 200) + # Draw b pars for new individuals if (dynamic && !is.null(newdataEvent)) { - stanmat <- simulate_b_pars(object, stanmat = stanmat, ndL = ndL, ndE = ndE, - ids = id_list, times = last_time, scale = scale) - b_new <- attr(stanmat, "b_new") + stanmat <- simulate_b_pars(object, + stanmat = stanmat, + ndL = ndL, + ndE = ndE, + ids = id_list, + times = last_time, + scale = scale) + b_new <- attr(stanmat, "b_new") acceptance_rate <- attr(stanmat, "acceptance_rate") } @@ -379,8 +392,8 @@ posterior_traj <- function(object, m = 1, newdata = NULL, newdataLong = NULL, if (interpolate || extrapolate) { # user specified interpolation or extrapolation if (return_matrix) stop("'return_matrix' cannot be TRUE if 'interpolate' or 'extrapolate' is TRUE.") - ok_control_args <- c("ipoints", "epoints", "edist", "eprop") - control <- get_extrapolation_control(control, ok_control_args = ok_control_args) + ok_args <- c("ipoints", "epoints", "edist", "eprop") + control <- extrapolation_control(control, ok_args = ok_args) dist <- if (!is.null(control$eprop)) control$eprop * (last_time - 0) else control$edist iseq <- if (interpolate) get_time_seq(control$ipoints, 0, last_time) else NULL eseq <- if (extrapolate) get_time_seq(control$epoints, last_time, last_time + dist) else NULL @@ -400,7 +413,7 @@ posterior_traj <- function(object, m = 1, newdata = NULL, newdataLong = NULL, newX <- rolling_merge(newX, time_seq[[id_var]], time_seq[[time_var]]) } } - + if (isTRUE(as.logical(glmod$has_offset))) { # create a temporary data frame with a fake outcome to avoid error response_name <- as.character(formula(object)[[m]])[2] @@ -411,40 +424,72 @@ posterior_traj <- function(object, m = 1, newdata = NULL, newdataLong = NULL, newOffset <- NULL } + # Obtain posterior predictions at specified times ytilde <- posterior_predict(object, newdata = newX, m = m, stanmat = stanmat, offset = newOffset, ...) + + # Optionally return S * N matrix of draws (instead of data frame) if (return_matrix) { attr(ytilde, "mu") <- NULL # remove attribute mu - return(ytilde) # return S * N matrix, instead of data frame - } + return(ytilde) + } + + # Extract draws for the posterior mean mutilde <- attr(ytilde, "mu") if (!is.null(newX) && nrow(newX) == 1L) mutilde <- t(mutilde) + + ytilde_bounds <- median_and_bounds(ytilde, prob) # median and prob% CrI limits mutilde_bounds <- median_and_bounds(mutilde, prob) # median and prob% CrI limits - out <- data.frame(IDVAR = newX[[id_var]], - TIMEVAR = newX[[time_var]], - yfit = mutilde_bounds$med, - ci_lb = mutilde_bounds$lb, ci_ub = mutilde_bounds$ub, - pi_lb = ytilde_bounds$lb, pi_ub = ytilde_bounds$ub) + + # Summarise posterior draws to get median and CI if (grp_stuff$has_grp) { - out$GRPVAR = newX[[grp_var]] # add grp_var and reorder cols - out <- out[, c("IDVAR", "GRPVAR", "TIMEVAR", - "yfit", "ci_lb", "ci_ub", "pi_lb", "pi_ub")] + + nms <- c(id_var, grp_var, time_var, "yfit", "ci_lb", "ci_ub", "pi_lb", "pi_ub") + + out <- data.frame(IDVAR = newX[[id_var]], + TIMEVAR = newX[[time_var]], + GRPVAR = newX[[grp_var]], + yfit = mutilde_bounds$med, + ci_lb = mutilde_bounds$lb, + ci_ub = mutilde_bounds$ub, + pi_lb = ytilde_bounds$lb, + pi_ub = ytilde_bounds$ub) + + } else { + + nms <- c(id_var, time_var, "yfit", "ci_lb", "ci_ub", "pi_lb", "pi_ub") + + out <- data.frame(IDVAR = newX[[id_var]], + TIMEVAR = newX[[time_var]], + yfit = mutilde_bounds$med, + ci_lb = mutilde_bounds$lb, + ci_ub = mutilde_bounds$ub, + pi_lb = ytilde_bounds$lb, + pi_ub = ytilde_bounds$ub) + } - colnames(out) <- c(id_var, if (grp_stuff$has_grp) grp_var, time_var, - "yfit", "ci_lb", "ci_ub", "pi_lb", "pi_ub") - class(out) <- c("predict.stanjm", "data.frame") - Terms <- terms(formula(object, m = m)) - vars <- rownames(attr(Terms, "factors")) - y_var <- vars[[attr(Terms, "response")]] - out <- structure(out, observed_data = ndL[[m]], last_time = last_time, - y_var = y_var, id_var = id_var, time_var = time_var, - grp_var = if (grp_stuff$has_grp) grp_var else NULL, - interpolate = interpolate, extrapolate = extrapolate, - control = control, call = match.call()) + + out <- set_colnames(out, nms) + + # Return object + out <- structure(out, + observed_data = ndL[[m]], + last_time = last_time, + y_var = get_resp_name(object, m = m), + id_var = id_var, + time_var = time_var, + grp_var = if (grp_stuff$has_grp) grp_var else NULL, + interpolate = interpolate, + extrapolate = extrapolate, + control = control, + call = match.call(), + class = c("predict.stanjm", "data.frame")) + if (dynamic && !is.null(newdataEvent)) { out <- structure(out, b_new = b_new, acceptance_rate = acceptance_rate) } + out } @@ -541,193 +586,251 @@ posterior_traj <- function(object, m = 1, newdata = NULL, newdataLong = NULL, #' ggplot2::theme(strip.background = ggplot2::element_blank()) + #' ggplot2::labs(title = "Some plotted longitudinal trajectories") #' } -#' } -plot.predict.stanjm <- function(x, ids = NULL, limits = c("ci", "pi", "none"), - xlab = NULL, ylab = NULL, vline = FALSE, - plot_observed = FALSE, facet_scales = "free_x", - ci_geom_args = NULL, grp_overlay = FALSE, ...) { +plot.predict.stanjm <- function(x, + ids = NULL, + limits = c("ci", "pi", "none"), + xlab = NULL, + ylab = NULL, + vline = FALSE, + plot_observed = FALSE, + facet_scales = "free_x", + ci_geom_args = NULL, + grp_overlay = FALSE, + ...) { limits <- match.arg(limits) - if (!(limits == "none")) ci <- (limits == "ci") - y_var <- attr(x, "y_var") - id_var <- attr(x, "id_var") - time_var <- attr(x, "time_var") - grp_var <- attr(x, "grp_var") - obs_dat <- attr(x, "observed_data") + ci <- as.logical(limits == "ci") + pi <- as.logical(limits == "pi") + + y_var <- attr(x, "y_var") # outcome variable + i_var <- attr(x, "id_var") # id variable + g_var <- attr(x, "grp_var") # cluster variable + t_var <- attr(x, "time_var") # time variable + obs_dat <- attr(x, "observed_data") # observed data + fit_dat <- x # predicted data + if (is.null(ylab)) ylab <- paste0("Long. response (", y_var, ")") - if (is.null(xlab)) xlab <- paste0("Time (", time_var, ")") - if (!id_var %in% colnames(x)) + if (is.null(xlab)) xlab <- paste0("Time (", t_var, ")") + + if (!i_var %in% colnames(x)) stop("Bug found: could not find 'id_var' column in the data frame.") - if (!is.null(grp_var) && (!grp_var %in% colnames(x))) + + if (!is.null(g_var) && (!g_var %in% colnames(x))) stop("Bug found: could not find 'grp_var' column in the data frame.") + + # subset data if only plotting for some individuals if (!is.null(ids)) { - ids_missing <- which(!ids %in% x[[id_var]]) - if (length(ids_missing)) - stop("The following 'ids' are not present in the predict.stanjm object: ", - paste(ids[ids_missing], collapse = ", "), call. = FALSE) - plot_dat <- x[x[[id_var]] %in% ids, , drop = FALSE] - obs_dat <- obs_dat[obs_dat[[id_var]] %in% ids, , drop = FALSE] - } else { - plot_dat <- x + fit_dat <- subset_ids(fit_dat, ids, i_var) + obs_dat <- subset_ids(obs_dat, ids, i_var) + } + + # deal with outcome data if plotting observed data points + if (plot_observed) { + obs_dat <- handle_obs_data(obs_dat, + y_var = y_var, + i_var = i_var, + t_var = t_var, + g_var = g_var) } - # 'id_list' provides unique IDs sorted in the same order as plotting data - id_list <- unique(plot_dat[[id_var]]) - if (!is.null(grp_var)) - grp_list <- unique(plot_dat[[grp_var]]) - - plot_dat$id <- factor(plot_dat[[id_var]]) - plot_dat$time <- plot_dat[[time_var]] - if (!is.null(grp_var)) - plot_dat$grp <- plot_dat[[grp_var]] + # obtain list of ids/clusters in the same order as plotting data + i_list <- if (is.null(i_var)) NULL else unique(fit_dat[[i_var]]) # ids + g_list <- if (is.null(g_var)) NULL else unique(fit_dat[[g_var]]) # clusters - geom_defaults <- list(color = "black", method = "loess", se = FALSE) - geom_args <- set_geom_args(geom_defaults, ...) + # create columns with desired names + fit_dat$id <- if (is.null(i_var)) NULL else fit_dat[[i_var]] # ids + fit_dat$grp <- if (is.null(g_var)) NULL else fit_dat[[g_var]] # clusters + fit_dat$time <- if (is.null(t_var)) NULL else fit_dat[[t_var]] # times - lim_defaults <- list(alpha = 0.3) - lim_args <- do.call("set_geom_args", c(defaults = list(lim_defaults), ci_geom_args)) + # promote ids to factors (same as observed data used to fit the model) + fit_dat$id <- factor(fit_dat$id) - obs_defaults <- list() - obs_args <- set_geom_args(obs_defaults) - - if (is.null(grp_var)) { # no lower level clusters + # determine appropriate variable to use for facets + if (is.null(g_var)) { + # no lower level clusters group_var <- NULL facet_var <- "id" - } else if (grp_overlay) { # overlay lower level clusters + } else if (grp_overlay) { + # overlay lower level clusters group_var <- "grp" facet_var <- "id" - } else { # separate facets for lower level clusters + } else { + # separate facet for each lower level cluster group_var <- NULL facet_var <- "grp" } - n_facets <- if (facet_var == "id") length(id_list) else length(grp_list) - - if (n_facets > 60L) { - stop("Too many facets (ie. individuals) to plot. Perhaps limit the ", - "number of individuals by specifying the 'ids' argument.") - } else if (n_facets > 1L) { - geom_mapp <- list( - mapping = aes_string(x = "time", y = "yfit", group = group_var), - data = plot_dat) - graph <- ggplot() + theme_bw() + - do.call("geom_smooth", c(geom_mapp, geom_args)) + + + # validate the number of facets + n_facets <- if (facet_var == "id") length(i_list) else length(g_list) + + if (n_facets > 60L) + stop("Too many facets (ie. individuals) to plot. Perhaps ", + "limit the number of individuals by specifying the 'ids' argument.") + + # determine which limits to plot (used in aes mapping) + lim_lb <- if (ci) "ci_lb" else "pi_lb" + lim_ub <- if (ci) "ci_ub" else "pi_ub" + + # geom mapping for posterior median + med_mapp <- create_geom_mapp(x = "time", + y = "yfit", + group = group_var) + + # geom mapping for ci limits + lim_mapp <- create_geom_mapp(x = "time", + ymin = lim_lb, + ymax = lim_ub) + + # geom mapping for observed data + obs_mapp <- create_geom_mapp(x = "time", + y = "yobs", + group = group_var) + + # determine default plotting args + med_defaults <- list(color = "black") # for posterior median + lim_defaults <- list(alpha = 0.3) # for ci limits + obs_defaults <- list() # for observed data + + # combine default and user-specified plotting args + med_args <- create_geom_args(med_defaults, list(...)) + lim_args <- create_geom_args(lim_defaults, ci_geom_args) + obs_args <- create_geom_args(obs_defaults, list()) + + # construct each plot component + graph_base <- + ggplot(fit_dat) + theme_bw() + + do.call("geom_line", c(med_mapp, med_args)) + + graph_limits <- + if (ci || pi) { + do.call("geom_ribbon", c(lim_mapp, lim_args)) + } else NULL + + graph_facet <- + if (n_facets > 1L) { facet_wrap(facet_var, scales = facet_scales) - if (!limits == "none") { - graph_smoothlim <- ggplot(plot_dat) + - geom_smooth( - aes_string(x = "time", y = if (ci) "ci_lb" else "pi_lb", group = group_var), - method = "loess", se = FALSE) + - geom_smooth( - aes_string(x = "time", y = if (ci) "ci_ub" else "pi_ub", group = group_var), - method = "loess", se = FALSE) + - facet_wrap(facet_var, scales = facet_scales) - build_smoothlim <- ggplot_build(graph_smoothlim) - df_smoothlim <- data.frame(PANEL = build_smoothlim$data[[1]]$PANEL, - time = build_smoothlim$data[[1]]$x, - lb = build_smoothlim$data[[1]]$y, - ub = build_smoothlim$data[[2]]$y, - group = build_smoothlim$data[[1]]$group) - panel_id_map <- build_smoothlim$layout$layout[, c("PANEL", facet_var), drop = FALSE] - df_smoothlim <- merge(df_smoothlim, panel_id_map) - lim_mapp <- list( - mapping = aes_string(x = "time", ymin = "lb", ymax = "ub", group = "group"), - data = df_smoothlim) - graph_limits <- do.call("geom_ribbon", c(lim_mapp, lim_args)) - } else graph_limits <- NULL - } else { - geom_mapp <- list(mapping = aes_string(x = "time", y = "yfit", group = group_var), - data = plot_dat) - graph <- ggplot() + theme_bw() + - do.call("geom_smooth", c(geom_mapp, geom_args)) - if (!(limits == "none")) { - graph_smoothlim <- ggplot(plot_dat) + - geom_smooth(aes_string(x = "time", y = if (ci) "ci_lb" else "pi_lb"), - method = "loess", se = FALSE) + - geom_smooth(aes_string(x = "time", y = if (ci) "ci_ub" else "pi_ub"), - method = "loess", se = FALSE) - build_smoothlim <- ggplot_build(graph_smoothlim) - df_smoothlim <- data.frame(time = build_smoothlim$data[[1]]$x, - lb = build_smoothlim$data[[1]]$y, - ub = build_smoothlim$data[[2]]$y, - group = build_smoothlim$data[[1]]$group) - lim_mapp <- list( - mapping = aes_string(x = "time", ymin = "lb", ymax = "ub", group = "group"), - data = df_smoothlim) - graph_limits <- do.call("geom_ribbon", c(lim_mapp, lim_args)) - } else graph_limits <- NULL - } - if (plot_observed) { - if (y_var %in% colnames(obs_dat)) { - obs_dat$y <- obs_dat[[y_var]] - } else { - obs_dat$y <- try(eval(parse(text = y_var), obs_dat)) - if (inherits(obs_dat$y, "try-error")) - stop("Could not find ", y_var, "in observed data, nor able to parse ", - y_var, "as an expression.") - } - obs_dat$id <- factor(obs_dat[[id_var]]) - obs_dat$time <- obs_dat[[time_var]] - if (!is.null(grp_var)) - obs_dat$grp <- obs_dat[[grp_var]] - if (is.null(obs_dat[["y"]])) - stop("Cannot find observed outcome data to add to plot.") - obs_mapp <- list( - mapping = aes_string(x = "time", y = "y", group = group_var), - data = obs_dat) - graph_obs <- do.call("geom_point", c(obs_mapp, obs_args)) - } else graph_obs <- NULL - if (vline) { - if (facet_var == "id") { - facet_list <- unique(plot_dat[, id_var]) - last_time <- attr(x, "last_time")[as.character(facet_list)] # potentially reorder last_time to match plot_dat - } else { - facet_list <- unique(plot_dat[, c(id_var, grp_var)]) - last_time <- attr(x, "last_time")[as.character(facet_list[[id_var]])] # potentially reorder last_time to match plot_dat - facet_list <- facet_list[[grp_var]] - } - vline_dat <- data.frame(FACETVAR = facet_list, last_time = last_time) - colnames(vline_dat) <- c(facet_var, "last_time") - graph_vline <- geom_vline( - mapping = aes_string(xintercept = "last_time"), - data = vline_dat, linetype = 2) - } else graph_vline <- NULL - - ret <- graph + graph_limits + graph_obs + graph_vline + labs(x = xlab, y = ylab) - class_ret <- class(ret) - class(ret) <- c("plot.predict.stanjm", class_ret) - ret + } else NULL + + graph_obs <- + if (plot_observed) { + obs_mapp$data <- obs_dat + do.call("geom_point", c(obs_mapp, obs_args)) + } else NULL + + graph_vline <- + if (vline) { + # potentially reorder last_time to match ordering in fit_dat + if (facet_var == "id") { + facet_list <- unique(fit_dat[, i_var]) + last_time <- attr(x, "last_time")[as.character(facet_list)] + } else { + facet_list <- unique(fit_dat[, c(i_var, g_var)]) + last_time <- attr(x, "last_time")[as.character(facet_list[[i_var]])] + facet_list <- facet_list[[g_var]] + } + vline_dat <- data.frame(FACETVAR = facet_list, last_time = last_time) + vline_dat <- set_colnames(vline_dat, c(facet_var, "last_time")) + geom_vline(mapping = aes_string(xintercept = "last_time"), + data = vline_dat, + linetype = 2) + } else NULL + + graph_labels <- labs(x = xlab, y = ylab) + gg <- + graph_base + + graph_facet + + graph_limits + + graph_labels + + graph_obs + + graph_vline + + class_gg <- class(gg) + class(gg) <- c("plot.predict.stanjm", class_gg) + gg } # internal ---------------------------------------------------------------- +# Get the name of the response variable +# +# @param object A stanjm model. +# @param m Integer specifying which submodel. +get_resp_name <- function(object, m) { + Terms <- terms(formula(object, m = m)) + vars <- rownames(attr(Terms, "factors")) + yvar <- vars[[attr(Terms, "response")]] + yvar +} + +# Return a list with the aes mapping +create_geom_mapp <- function(...) { + list(mapping = ggplot2::aes_string(...)) +} + +# Call set_geom_args on the default and user-specified plotting args +create_geom_args <- function(defaults = list(), more_args = list()) { + do.call("set_geom_args", c(defaults = list(defaults), more_args)) +} + +# Construct a data frame for plotting observed biomarker data +handle_obs_data <- function(data, y_var, i_var, t_var, g_var) { + + # add outcome variable + if (y_var %in% colnames(data)) { + data$yobs <- data[[y_var]] + } else { + data$yobs <- try(eval(parse(text = y_var), data)) + if (inherits(data$yobs, "try-error")) + stop("Could not find ", y_var, "in observed data, nor able to parse ", + y_var, "as an expression.") + } + + # add id variable + data$id <- factor(data[[i_var]]) + + # add time variable + data$time <- data[[t_var]] + + # add grp variable (identifier for lower level clusters) + if (!is.null(g_var)) + data$grp <- data[[g_var]] + + # final validation that outcome data exists in data frame + if (is.null(data[["yobs"]])) + stop("Cannot find observed outcome data to add to plot.") + + data +} + + # Return a list with the control arguments for interpolation and/or # extrapolation in posterior_predict.stanmvreg and posterior_survfit.stanjm # # @param control A named list, being the user input to the control argument # in the posterior_predict.stanmvreg or posterior_survfit.stanjm call -# @param ok_control_args A character vector of allowed control arguments +# @param ok_args A character vector of allowed control arguments # @return A named list -get_extrapolation_control <- - function(control = list(), ok_control_args = c("epoints", "edist", "eprop")) { - defaults <- list(ipoints = 15, epoints = 15, edist = NULL, eprop = 0.2, last_time = NULL) +extrapolation_control <- + function(control = list(), ok_args = c("epoints", "edist", "eprop")) { + defaults <- list(ipoints = 100, epoints = 100, edist = NULL, eprop = 0.2, last_time = NULL) if (!is.list(control)) { stop("'control' should be a named list.") } else if (!length(control)) { - control <- defaults[ok_control_args] + control <- defaults[ok_args] } else { # user specified control list nms <- names(control) if (!length(nms)) stop("'control' should be a named list.") - if (any(!nms %in% ok_control_args)) + if (any(!nms %in% ok_args)) stop(paste0("'control' list can only contain the following named arguments: ", - paste(ok_control_args, collapse = ", "))) + paste(ok_args, collapse = ", "))) if (all(c("edist", "eprop") %in% nms)) stop("'control' list cannot include both 'edist' and 'eprop'.") - if (("ipoints" %in% ok_control_args) && is.null(control$ipoints)) + if (("ipoints" %in% ok_args) && is.null(control$ipoints)) control$ipoints <- defaults$ipoints - if (("epoints" %in% ok_control_args) && is.null(control$epoints)) + if (("epoints" %in% ok_args) && is.null(control$epoints)) control$epoints <- defaults$epoints if (is.null(control$edist) && is.null(control$eprop)) control$eprop <- defaults$eprop diff --git a/R/pp_data.R b/R/pp_data.R index e86226cf6..126c5357a 100644 --- a/R/pp_data.R +++ b/R/pp_data.R @@ -24,6 +24,9 @@ pp_data <- m = NULL, ...) { validate_stanreg_object(object) + if (is.stansurv(object)) { + return(.pp_data_surv(object, newdata = newdata, ...)) + } if (is.mer(object)) { if (is.nlmer(object)) out <- .pp_data_nlmer(object, newdata = newdata, re.form = re.form, m = m, ...) @@ -35,7 +38,9 @@ pp_data <- .pp_data(object, newdata = newdata, offset = offset, ...) } -# for models without lme4 structure + +#------------- for models without lme4 structure ------------------------- + .pp_data <- function(object, newdata = NULL, offset = NULL, ...) { if (is(object, "gamm4")) { requireNamespace("mgcv", quietly = TRUE) @@ -78,7 +83,8 @@ pp_data <- } -# for models fit using stan_(g)lmer or stan_gamm4 +#--------- for models fit using stan_(g)lmer or stan_gamm4 ----------------- + .pp_data_mer <- function(object, newdata, re.form, m = NULL, ...) { if (is(object, "gamm4")) { requireNamespace("mgcv", quietly = TRUE) @@ -110,44 +116,6 @@ pp_data <- return(nlist(x, offset = offset, Zt = z$Zt, Z_names = z$Z_names)) } -# for models fit using stan_nlmer -.pp_data_nlmer <- function(object, newdata, re.form, offset = NULL, m = NULL, ...) { - inputs <- parse_nlf_inputs(object$glmod$respMod) - if (is.null(newdata)) { - arg1 <- arg2 <- NULL - } else if (object$family$link == "inv_SSfol") { - arg1 <- newdata[[inputs[2]]] - arg2 <- newdata[[inputs[3]]] - } else { - arg1 <- newdata[[inputs[2]]] - arg2 <- NULL - } - f <- formula(object, m = m) - if (!is.null(re.form) && !is.na(re.form)) { - f <- as.character(f) - f[3] <- as.character(re.form) - f <- as.formula(f[-1]) - } - if (is.null(newdata)) newdata <- model.frame(object) - else { - yname <- names(model.frame(object))[1] - newdata[[yname]] <- 0 - } - mc <- match.call(expand.dots = FALSE) - mc$re.form <- mc$offset <- mc$object <- mc$newdata <- NULL - mc$data <- newdata - mc$formula <- f - mc$start <- fixef(object) - nlf <- nlformula(mc) - offset <- .pp_data_offset(object, newdata, offset) - - group <- with(nlf$reTrms, pad_reTrms(Ztlist, cnms, flist)) - if (!is.null(re.form) && !is(re.form, "formula") && is.na(re.form)) - group$Z@x <- 0 - return(nlist(x = nlf$X, offset = offset, Z = group$Z, - Z_names = make_b_nms(group), arg1, arg2)) -} - # the functions below are heavily based on a combination of # lme4:::predict.merMod and lme4:::mkNewReTrms, although they do also have # substantial modifications @@ -241,40 +209,213 @@ pp_data <- } +#------------- for models fit using stan_nlmer ----------------------------- -# handle offsets ---------------------------------------------------------- -null_or_zero <- function(x) { - isTRUE(is.null(x) || all(x == 0)) -} - -.pp_data_offset <- function(object, newdata = NULL, offset = NULL) { +.pp_data_nlmer <- function(object, newdata, re.form, offset = NULL, m = NULL, ...) { + inputs <- parse_nlf_inputs(object$glmod$respMod) if (is.null(newdata)) { - # get offset from model object (should be null if no offset) - if (is.null(offset)) - offset <- object$offset %ORifNULL% model.offset(model.frame(object)) + arg1 <- arg2 <- NULL + } else if (object$family$link == "inv_SSfol") { + arg1 <- newdata[[inputs[2]]] + arg2 <- newdata[[inputs[3]]] } else { - if (!is.null(offset)) - stopifnot(length(offset) == nrow(newdata)) - else { - # if newdata specified but not offset then confirm that model wasn't fit - # with an offset (warning, not error) - if (!is.null(object$call$offset) || - !null_or_zero(object$offset) || - !null_or_zero(model.offset(model.frame(object)))) { - warning( - "'offset' argument is NULL but it looks like you estimated ", - "the model using an offset term.", - call. = FALSE - ) - } - offset <- rep(0, nrow(newdata)) + arg1 <- newdata[[inputs[2]]] + arg2 <- NULL + } + f <- formula(object, m = m) + if (!is.null(re.form) && !is.na(re.form)) { + f <- as.character(f) + f[3] <- as.character(re.form) + f <- as.formula(f[-1]) + } + if (is.null(newdata)) newdata <- model.frame(object) + else { + yname <- names(model.frame(object))[1] + newdata[[yname]] <- 0 + } + mc <- match.call(expand.dots = FALSE) + mc$re.form <- mc$offset <- mc$object <- mc$newdata <- NULL + mc$data <- newdata + mc$formula <- f + mc$start <- fixef(object) + nlf <- nlformula(mc) + offset <- .pp_data_offset(object, newdata, offset) + + group <- with(nlf$reTrms, pad_reTrms(Ztlist, cnms, flist)) + if (!is.null(re.form) && !is(re.form, "formula") && is.na(re.form)) + group$Z@x <- 0 + return(nlist(x = nlf$X, offset = offset, Z = group$Z, + Z_names = make_b_nms(group), arg1, arg2)) +} + + +#------------------ for models fit using stan_surv ----------------------- + +.pp_data_surv <- function(object, + newdata = NULL, + times = NULL, + at_quadpoints = FALSE, + ...) { + + formula <- object$formula + basehaz <- object$basehaz + + # data with row subsetting etc + if (is.null(newdata)) + newdata <- get_model_data(object) + + # flags + has_tve <- object$has_tve + has_quadrature <- object$has_quadrature + has_bars <- object$has_bars + + #----- dimensions and times + + if (has_quadrature && at_quadpoints) { + + if (is.null(times)) + stop("Bug found: 'times' must be specified.") + + # error check time variables + if (!length(times) == nrow(newdata)) + stop("Bug found: length of 'times' should equal number rows in the data.") + + # number of nodes + qnodes <- object$qnodes + + # standardised weights and nodes for quadrature + qq <- get_quadpoints(nodes = qnodes) + qp <- qq$points + qw <- qq$weights + + # quadrature points & weights, evaluated for each row of data + pts <- uapply(qp, unstandardise_qpts, 0, times) + wts <- uapply(qw, unstandardise_qwts, 0, times) + + # id vector for quadrature points + ids <- rep(seq_along(times), times = qnodes) + + } else { # predictions don't require quadrature + + pts <- times + wts <- rep(NA, length(times)) + ids <- seq_along(times) + + } + + #----- time-fixed predictor matrix + + # check all vars are in newdata + vars <- all.vars(delete.response(terms(object, fixed.only = FALSE))) + miss <- which(!vars %in% colnames(newdata)) + if (length(miss)) + stop2("The following variables are missing from the data: ", comma(vars[miss])) + + # drop response from fixed effect formula + tt <- delete.response(terms(object, fixed.only = TRUE)) + + # make model frame based on time-fixed part of model formula + mf <- make_model_frame(tt, newdata, xlevs = object$xlevs)$mf + + # if using quadrature then expand rows of time-fixed predictor matrix + if (has_quadrature && at_quadpoints) + mf <- rep_rows(mf, times = qnodes) + + # check data classes in the model frame match those used in model fitting + if (!is.null(cl <- attr(tt, "dataClasses"))) + .checkMFClasses(cl, mf) + + # check model frame dimensions are correct (may be errors due to NAs?) + if (!length(pts) == nrow(mf)) + stop("Bug found: length of 'pts' should equal number rows in model frame.") + + # construct time-fixed predictor matrix + x <- make_x(tt, mf, check_constant = FALSE)$x + + #----- time-varying predictor matrix + + if (has_tve) { + + if (all(is.na(pts))) { + # temporary replacement to avoid error in creating spline basis + pts_tmp <- rep(0, length(pts)) + } else { + # else use prediction times or quadrature points + pts_tmp <- pts } + + # generate a model frame with time transformations for tve effects + mf_s <- make_model_frame(formula$tt_frame, data.frame(times__ = pts_tmp))$mf + + # check model frame dimensions are correct + if (!length(pts) == nrow(mf_s)) + stop("Bug found: length of 'pts' should equal number rows in model frame.") + + # NB next line avoids dropping terms attribute from 'mf' + mf[, colnames(mf_s)] <- mf_s + + # construct time-varying predictor matrix + s <- make_s(formula, mf) + + if (all(is.na(pts))) { + # if pts were all NA then replace the time-varying predictor + # matrix with all NA, but retain appropriate dimensions + s[] <- NaN + } + + } else { + + s <- matrix(0, nrow(mf), 0) + } - return(offset) + + #----- random effects predictor matrix + + if (has_bars) { + + # drop response from random effects part of model formula + tt_z <- delete.response(terms(object, random.only = TRUE)) + + # make model frame based on random effects part of model formula + mf_z <- make_model_frame(formula = tt_z, + data = newdata, + xlevs = object$xlevs, + na.action = na.pass)$mf + + # if using quadrature then expand rows + if (has_quadrature && at_quadpoints) + mf_z <- rep_rows(mf_z, times = qnodes) + + # check model frame dimensions are correct + if (!length(pts) == nrow(mf_z)) + stop("Bug found: length of 'pts' should equal number rows in model frame.") + + # construct random effects predictor matrix + ReTrms <- lme4::mkReTrms(formula$bars, mf_z) + z <- nlist(Zt = ReTrms$Zt, Z_names = make_b_nms(ReTrms)) + + } else { + + z <- list() + + } + + # return object + return(nlist(pts, + wts, + ids, + x, + s, + z, + has_quadrature, + has_tve, + has_bars, + at_quadpoints, + qnodes = object$qnodes)) } -#----------------------- pp_data for joint models -------------------------- +#-------------------- for models fit using stan_jm ----------------------- # Return the design matrices required for evaluating the linear predictor or # log-likelihood in post-estimation functions for a \code{stan_jm} model @@ -292,17 +433,31 @@ null_or_zero <- function(x) { # the fitted object (if newdataEvent is NULL) or in newdataEvent. # @param long_parts,event_parts A logical specifying whether to return the # design matrices for the longitudinal and/or event submodels. +# @param response Logical specifying whether the newdataLong requires the +# response variable. # @return A named list (with components M, Npat, ndL, ndE, yX, tZt, # yZnames, eXq, assoc_parts) -.pp_data_jm <- function(object, newdataLong = NULL, newdataEvent = NULL, - ids = NULL, etimes = NULL, long_parts = TRUE, - event_parts = TRUE) { +.pp_data_jm <- function(object, + newdataLong = NULL, + newdataEvent = NULL, + ids = NULL, + etimes = NULL, + long_parts = TRUE, + event_parts = TRUE, + response = TRUE, + needs_time_var = TRUE) { + M <- get_M(object) + id_var <- object$id_var time_var <- object$time_var if (!is.null(newdataLong) || !is.null(newdataEvent)) - newdatas <- validate_newdatas(object, newdataLong, newdataEvent) + newdatas <- validate_newdatas(object, + newdataLong, + newdataEvent, + response = response, + needs_time_var = needs_time_var) # prediction data for longitudinal submodels ndL <- if (is.null(newdataLong)) @@ -314,8 +469,8 @@ null_or_zero <- function(x) { # possibly subset if (!is.null(ids)) { - ndL <- subset_ids(object, ndL, ids) - ndE <- subset_ids(object, ndE, ids) + ndL <- subset_ids(ndL, ids, id_var) + ndE <- subset_ids(ndE, ids, id_var) } id_list <- unique(ndE[[id_var]]) # unique subject id list @@ -344,7 +499,7 @@ null_or_zero <- function(x) { if (long_parts && event_parts) lapply(ndL, function(x) { - if (!time_var %in% colnames(x)) + if (!time_var %in% colnames(x)) STOP_no_var(time_var) if (!id_var %in% colnames(x)) STOP_no_var(id_var) @@ -375,6 +530,8 @@ null_or_zero <- function(x) { qtimes <- uapply(qq$points, unstandardise_qpts, 0, etimes) qwts <- uapply(qq$weights, unstandardise_qwts, 0, etimes) starttime <- deparse(formula(object, m = "Event")[[2L]][[2L]]) + if ((!response) && (!starttime %in% colnames(ndE))) + ndE[[starttime]] <- 0 edat <- prepare_data_table(ndE, id_var, time_var = starttime) id_rep <- rep(id_list, qnodes + 1) times <- c(etimes, qtimes) # times used to design event submodel matrices @@ -418,12 +575,22 @@ null_or_zero <- function(x) { # need to be recalculated at quadrature points etc, for example # in posterior_survfit. # -# @param object A stanmvreg object. +# @param object A stansurv, stanmvreg or stanjm object. # @param m Integer specifying which submodel to get the -# prediction data frame for. +# prediction data frame for (for stanmvreg or stanjm objects). # @return A data frame or list of data frames with all the # (unevaluated) variables required for predictions. -get_model_data <- function(object, m = NULL) { +get_model_data <- function(object, ...) UseMethod("get_model_data") + +get_model_data.stansurv <- function(object, ...) { + validate_stansurv_object(object) + terms <- terms(object, fixed.only = FALSE) + row_nms <- row.names(model.frame(object)) + get_all_vars(terms, object$data)[row_nms, , drop = FALSE] +} + +get_model_data.stanmvreg <- function(object, m = NULL, ...) { + validate_stanmvreg_object(object) M <- get_M(object) terms <- terms(object, fixed.only = FALSE) @@ -471,3 +638,36 @@ get_model_data <- function(object, m = NULL) { mfs <- list_nms(mfs, M, stub = get_stub(object)) if (is.null(m)) mfs else mfs[[m]] } + + +#----------------------- handle offsets ---------------------------------- + +null_or_zero <- function(x) { + isTRUE(is.null(x) || all(x == 0)) +} + +.pp_data_offset <- function(object, newdata = NULL, offset = NULL) { + if (is.null(newdata)) { + # get offset from model object (should be null if no offset) + if (is.null(offset)) + offset <- object$offset %ORifNULL% model.offset(model.frame(object)) + } else { + if (!is.null(offset)) + stopifnot(length(offset) == nrow(newdata)) + else { + # if newdata specified but not offset then confirm that model wasn't fit + # with an offset (warning, not error) + if (!is.null(object$call$offset) || + !null_or_zero(object$offset) || + !null_or_zero(model.offset(model.frame(object)))) { + warning( + "'offset' argument is NULL but it looks like you estimated ", + "the model using an offset term.", + call. = FALSE + ) + } + offset <- rep(0, nrow(newdata)) + } + } + return(offset) +} diff --git a/R/print-and-summary.R b/R/print-and-summary.R index 463754050..b0108cb56 100644 --- a/R/print-and-summary.R +++ b/R/print-and-summary.R @@ -68,14 +68,27 @@ #' @seealso \code{\link{summary.stanreg}}, \code{\link{stanreg-methods}} #' print.stanreg <- function(x, digits = 1, detail = TRUE, ...) { + if (detail) { cat(x$stan_function) - cat("\n family: ", family_plus_link(x)) - cat("\n formula: ", formula_string(formula(x))) - cat("\n observations:", nobs(x)) - if (isTRUE(x$stan_function %in% - c("stan_glm", "stan_glm.nb", "stan_lm", "stan_aov"))) { - cat("\n predictors: ", length(coef(x))) + if (is.surv(x)) { + cat("\n baseline hazard:", basehaz_string(x$basehaz)) + cat("\n formula: ", formula_string(formula(x))) + cat("\n observations: ", x$nobs) + cat("\n events: ", x$nevents, percent_string(x$nevents, x$nobs)) + if (x$nlcens > 0) + cat("\n left censored: ", x$nlcens, percent_string(x$nlcens, x$nobs)) + if (x$nrcens > 0) + cat("\n right censored: ", x$nrcens, percent_string(x$nrcens, x$nobs)) + if (x$nicens > 0) + cat("\n interval cens.: ", x$nicens, percent_string(x$nicens, x$nobs)) + cat("\n delayed entry: ", yes_no_string(x$ndelayed)) + } else { + cat("\n family: ", family_plus_link(x)) + cat("\n formula: ", formula_string(formula(x))) + cat("\n observations:", nobs(x)) + if (isTRUE(x$stan_function %in% c("stan_glm", "stan_glm.nb", "stan_lm", "stan_aov"))) + cat("\n predictors: ", length(coef(x))) } if (!is.null(x$call$subset)) { cat("\n subset: ", deparse(x$call$subset)) @@ -84,9 +97,10 @@ print.stanreg <- function(x, digits = 1, detail = TRUE, ...) { cat("\n------\n") } - mer <- is.mer(x) + surv <- is.surv(x) + mer <- is.mer(x) || (surv && x$has_bars) gamm <- isTRUE(x$stan_function == "stan_gamm4") - ord <- is_polr(x) && !("(Intercept)" %in% rownames(x$stan_summary)) + ord <- is_polr(x) && !("(Intercept)" %in% rownames(x$stan_summary)) aux_nms <- .aux_name(x) @@ -122,6 +136,15 @@ print.stanreg <- function(x, digits = 1, detail = TRUE, ...) { if (mer) { estimates <- estimates[!grepl("^Sigma\\[", rownames(estimates)),, drop=FALSE] } + if (surv) { + nms_int <- get_int_name_basehaz(get_basehaz(x)) + nms_aux <- get_aux_name_basehaz(get_basehaz(x)) + nms_beta <- setdiff(rownames(estimates), c(nms_int, nms_aux)) + estimates <- cbind(estimates, + "exp(Median)" = c(rep(NA, length(nms_int)), + exp(estimates[nms_beta, "Median"]), + rep(NA, length(nms_aux)))) + } .printfr(estimates, digits, ...) if (length(aux_nms)) { @@ -172,32 +195,36 @@ print.stanreg <- function(x, digits = 1, detail = TRUE, ...) { #' @rdname print.stanreg #' @export #' @method print stanmvreg -print.stanmvreg <- function(x, digits = 3, ...) { - M <- x$n_markers +print.stanmvreg <- function(x, digits = 3, detail = TRUE, ...) { + mvmer <- is.mvmer(x) - surv <- is.surv(x) - jm <- is.jm(x) + surv <- is.surv(x) + jm <- is.jm(x) + M <- x$n_markers stubs <- paste0("(", get_stub(x), 1:M, "):") - cat(x$stan_function) - if (mvmer) { - for (m in 1:M) { - cat("\n formula", stubs[m], formula_string(formula(x, m = m))) - cat("\n family ", stubs[m], family_plus_link(x, m = m)) - } - } - if (surv) { - cat("\n formula (Event):", formula_string(formula(x, m = "Event"))) - cat("\n baseline hazard:", x$basehaz$type_name) - } - if (jm) { - sel <- grep("^which", rownames(x$assoc), invert = TRUE, value = TRUE) - assoc <- lapply(1:M, function(m) { - vals <- sel[which(x$assoc[sel,m] == TRUE)] - paste0(vals, " (Long", m, ")") - }) - cat("\n assoc: ", paste(unlist(assoc), collapse = ", ")) + + if (detail) { + cat(x$stan_function) + if (mvmer) { + for (m in 1:M) { + cat("\n formula", stubs[m], formula_string(formula(x, m = m))) + cat("\n family ", stubs[m], family_plus_link(x, m = m)) + } + } + if (surv) { + cat("\n formula (Event):", formula_string(formula(x, m = "Event"))) + cat("\n baseline hazard:", x$basehaz$type_name) + } + if (jm) { + sel <- grep("^which", rownames(x$assoc), invert = TRUE, value = TRUE) + assoc <- lapply(1:M, function(m) { + vals <- sel[which(x$assoc[sel,m] == TRUE)] + paste0(vals, " (Long", m, ")") + }) + cat("\n assoc: ", paste(unlist(assoc), collapse = ", ")) + } + cat("\n------\n") } - cat("\n------\n") mat <- as.matrix(x$stanfit) nms <- collect_nms(rownames(x$stan_summary), M, @@ -258,17 +285,13 @@ print.stanmvreg <- function(x, digits = 3, ...) { print(VarCorr(x), digits = digits + 1, ...) cat("Num. levels:", paste(names(ngrps(x)), unname(ngrps(x)), collapse = ", "), "\n") - - # Sample average of the PPD - ppd_mat <- mat[, nms$ppd, drop = FALSE] - ppd_estimates <- .median_and_madsd(ppd_mat) - cat("\nSample avg. posterior predictive distribution \nof", - if (is.jm(x)) "longitudinal outcomes:\n" else "y:\n") - .printfr(ppd_estimates, digits, ...) } - cat("\n------\n") - cat("For info on the priors used see help('prior_summary.stanreg').") + if (detail) { + cat("\n------\n") + cat("* For help interpreting the printed output see ?print.stanreg\n") + cat("* For info on the priors used see ?prior_summary.stanreg\n") + } invisible(x) } @@ -370,7 +393,8 @@ summary.stanreg <- function(object, probs = c(0.1, 0.5, 0.9), ..., digits = 1) { - mer <- is.mer(object) + surv <- is.surv(object) + mer <- is.mer(object) pars <- collect_pars(object, pars, regex_pars) if (!used.optimizing(object)) { @@ -420,22 +444,30 @@ summary.stanreg <- function(object, out <- object$stan_summary[mark, , drop=FALSE] } + is_glm <- + isTRUE(object$stan_function %in% c("stan_glm", "stan_glm.nb", "stan_lm")) + structure( out, - call = object$call, - algorithm = object$algorithm, + call = object$call, + algorithm = object$algorithm, stan_function = object$stan_function, - family = family_plus_link(object), - formula = formula(object), + family = family_plus_link(object), + formula = formula(object), + basehaz = if (surv) basehaz_string(get_basehaz(object)) else NULL, posterior_sample_size = posterior_sample_size(object), - nobs = nobs(object), - npreds = if (isTRUE(object$stan_function %in% c("stan_glm", "stan_glm.nb", "stan_lm"))) - length(coef(object)) else NULL, - ngrps = if (mer) ngrps(object) else NULL, - print.digits = digits, - priors = object$prior.info, + nobs = nobs(object), + npreds = if (is_glm) length(coef(object)) else NULL, + ngrps = if (mer) ngrps(object) else NULL, + nevents = if (surv) object$nevents else NULL, + nlcens = if (surv) object$nlcens else NULL, + nrcens = if (surv) object$nrcens else NULL, + nicens = if (surv) object$nicens else NULL, + ndelayed = if (surv) object$ndelayed else NULL, + print.digits = digits, + priors = object$prior.info, no_ppd_diagnostic = no_mean_PPD(object), - class = "summary.stanreg" + class = "summary.stanreg" ) } @@ -444,11 +476,29 @@ summary.stanreg <- function(object, #' @method print summary.stanreg #' #' @param x An object of class \code{"summary.stanreg"}. -print.summary.stanreg <- - function(x, digits = max(1, attr(x, "print.digits")), - ...) { - atts <- attributes(x) - cat("\nModel Info:") +print.summary.stanreg <- function(x, digits = max(1, attr(x, "print.digits")), + ...) { + atts <- attributes(x) + cat("\nModel Info:\n") + + if (is.surv(atts)) { # survival models + cat("\n function: ", atts$stan_function) + cat("\n baseline hazard:", atts$basehaz) + cat("\n formula: ", formula_string(atts$formula)) + cat("\n algorithm: ", atts$algorithm) + if (!is.null(atts$posterior_sample_size) && atts$algorithm == "sampling") + cat("\n sample: ", atts$posterior_sample_size, "(posterior sample size)") + cat("\n priors: ", "see help('prior_summary')") + cat("\n observations: ", atts$nobs) + cat("\n events: ", atts$nevents, percent_string(atts$nevents, atts$nobs)) + if (atts$nlcens > 0) + cat("\n left censored: ", atts$nlcens, percent_string(atts$nlcens, atts$nobs)) + if (atts$nrcens > 0) + cat("\n right censored: ", atts$nrcens, percent_string(atts$nrcens, atts$nobs)) + if (atts$nicens > 0) + cat("\n interval cens.: ", atts$nicens, percent_string(atts$nicens, atts$nobs)) + cat("\n delayed entry: ", yes_no_string(atts$ndelayed)) + } else { # anything except survival models cat("\n function: ", atts$stan_function) cat("\n family: ", atts$family) cat("\n formula: ", formula_string(atts$formula)) @@ -458,7 +508,6 @@ print.summary.stanreg <- "(posterior sample size)") } cat("\n priors: ", "see help('prior_summary')") - cat("\n observations:", atts$nobs) if (!is.null(atts$npreds)) { cat("\n predictors: ", atts$npreds) @@ -471,64 +520,65 @@ print.summary.stanreg <- unname(atts$ngrps), ")", collapse = ", ")) } + } - cat("\n\nEstimates:\n") - if (used.optimizing(atts) || used.variational(atts)) { - hat <- "khat" - str_diag <- "Monte Carlo diagnostics" - str1 <- "and khat is the Pareto k diagnostic for importance sampling" - str2 <- " (perfomance is usually good when khat < 0.7).\n" - } else { - hat <- "Rhat" - str_diag <- "MCMC diagnostics" - str1 <- "and Rhat is the potential scale reduction factor on split chains" - str2 <- " (at convergence Rhat=1).\n" - } - sel <- which(colnames(x) %in% c("mcse", "n_eff", hat)) - has_mc_diagnostic <- length(sel) > 0 - if (has_mc_diagnostic) { - xtemp <- x[, -sel, drop = FALSE] - colnames(xtemp) <- paste(" ", colnames(xtemp)) - } else { - xtemp <- x - } - - ppd_nms <- grep("^mean_PPD", rownames(x), value = TRUE) - has_ppd_diagnostic <- !atts$no_ppd_diagnostic && length(ppd_nms) > 0 - - if (has_ppd_diagnostic) { - ppd_estimates <- xtemp[rownames(xtemp) %in% ppd_nms, , drop=FALSE] - } else { - ppd_estimates <- NULL - } - xtemp <- xtemp[!rownames(xtemp) %in% c(ppd_nms, "log-posterior"), , drop=FALSE] - - # print table of parameter stats - .printfr(xtemp, digits) - - if (has_ppd_diagnostic) { - cat("\nFit Diagnostics:\n") - .printfr(ppd_estimates, digits) - cat("\nThe mean_ppd is the sample average posterior predictive ", - "distribution of the outcome variable ", - "(for details see help('summary.stanreg')).\n", - sep = '') - } - - if (has_mc_diagnostic) { - cat("\n", str_diag, "\n", sep = '') - mcse_hat <- format(round(x[, c("mcse", hat), drop = FALSE], digits), - nsmall = digits) - n_eff <- format(x[, "n_eff", drop = FALSE], drop0trailing = TRUE) - print(cbind(mcse_hat, n_eff), quote = FALSE) - cat("\nFor each parameter, mcse is Monte Carlo standard error, ", - "n_eff is a crude measure of effective sample size, ", - str1, - str2, sep = '') - } - - invisible(x) + cat("\n\nEstimates:\n") + if (used.optimizing(atts) || used.variational(atts)) { + hat <- "khat" + str_diag <- "Monte Carlo diagnostics" + str1 <- "and khat is the Pareto k diagnostic for importance sampling" + str2 <- " (perfomance is usually good when khat < 0.7).\n" + } else { + hat <- "Rhat" + str_diag <- "MCMC diagnostics" + str1 <- "and Rhat is the potential scale reduction factor on split chains" + str2 <- " (at convergence Rhat=1).\n" + } + sel <- which(colnames(x) %in% c("mcse", "n_eff", hat)) + has_mc_diagnostic <- length(sel) > 0 + if (has_mc_diagnostic) { + xtemp <- x[, -sel, drop = FALSE] + colnames(xtemp) <- paste(" ", colnames(xtemp)) + } else { + xtemp <- x + } + + ppd_nms <- grep("^mean_PPD", rownames(x), value = TRUE) + has_ppd_diagnostic <- !atts$no_ppd_diagnostic && length(ppd_nms) > 0 + + if (has_ppd_diagnostic) { + ppd_estimates <- xtemp[rownames(xtemp) %in% ppd_nms, , drop=FALSE] + } else { + ppd_estimates <- NULL + } + xtemp <- xtemp[!rownames(xtemp) %in% c(ppd_nms, "log-posterior"), , drop=FALSE] + + # print table of parameter stats + .printfr(xtemp, digits) + + if (has_ppd_diagnostic) { + cat("\nFit Diagnostics:\n") + .printfr(ppd_estimates, digits) + cat("\nThe mean_ppd is the sample average posterior predictive ", + "distribution of the outcome variable ", + "(for details see help('summary.stanreg')).\n", + sep = '') + } + + if (has_mc_diagnostic) { + cat("\n", str_diag, "\n", sep = '') + mcse_hat <- format(round(x[, c("mcse", hat), drop = FALSE], digits), + nsmall = digits) + n_eff <- format(x[, "n_eff", drop = FALSE], drop0trailing = TRUE) + print(cbind(mcse_hat, n_eff), quote = FALSE) + cat("\nFor each parameter, mcse is Monte Carlo standard error, ", + "n_eff is a crude measure of effective sample size, ", + str1, + str2, sep = '') } + + invisible(x) +} #' @rdname summary.stanreg #' @method as.data.frame summary.stanreg @@ -540,13 +590,18 @@ as.data.frame.summary.stanreg <- function(x, ...) { #' @rdname summary.stanreg #' @export #' @method summary stanmvreg -summary.stanmvreg <- function(object, pars = NULL, regex_pars = NULL, - probs = NULL, ..., digits = 3) { - pars <- collect_pars(object, pars, regex_pars) - M <- object$n_markers +summary.stanmvreg <- function(object, + pars = NULL, + regex_pars = NULL, + probs = c(0.1, 0.5, 0.9), + ..., + digits = 3) { + + pars <- collect_pars(object, pars, regex_pars) + M <- object$n_markers mvmer <- is.mvmer(object) - surv <- is.surv(object) - jm <- is.jm(object) + surv <- is.surv(object) + jm <- is.jm(object) if (mvmer) { # Outcome variable for each longitudinal submodel @@ -567,24 +622,27 @@ summary.stanmvreg <- function(object, pars = NULL, regex_pars = NULL, } # Construct summary table - args <- list(object = object$stanfit) - if (!is.null(probs)) - args$probs <- probs + args <- list(object = object$stanfit, probs = probs) out <- do.call("summary", args)$summary nms <- collect_nms(rownames(object$stan_summary), M, stub = get_stub(object), value = TRUE) if (!is.null(pars)) { pars2 <- NA if ("alpha" %in% pars) pars2 <- c(pars2, nms$alpha) - if ("beta" %in% pars) pars2 <- c(pars2, nms$beta) - if ("long" %in% pars) pars2 <- c(pars2, unlist(nms$y), unlist(nms$y_extra)) + if ("beta" %in% pars) pars2 <- c(pars2, nms$beta) + if ("long" %in% pars) pars2 <- c(pars2, unlist(nms$y), unlist(nms$y_extra)) if ("event" %in% pars) pars2 <- c(pars2, nms$e, nms$a, nms$e_extra) if ("assoc" %in% pars) pars2 <- c(pars2, nms$a) if ("fixef" %in% pars) pars2 <- c(pars2, unlist(nms$y), nms$e, nms$a) - if ("b" %in% pars) pars2 <- c(pars2, nms$b) - pars2 <- c(pars2, setdiff(pars, - c("alpha", "beta", "varying", "b", - "long", "event", "assoc", "fixef"))) + if ("b" %in% pars) pars2 <- c(pars2, nms$b) + pars2 <- c(pars2, setdiff(pars, c("alpha", + "beta", + "varying", + "b", + "long", + "event", + "assoc", + "fixef"))) pars <- pars2[!is.na(pars2)] } else { pars <- rownames(object$stan_summary) @@ -601,35 +659,52 @@ summary.stanmvreg <- function(object, pars = NULL, regex_pars = NULL, colnames(out)[stats %in% "se_mean"] <- "mcse" # Reorder rows of output table - nms_tmp <- rownames(out) - nms_tmp_y <- lapply(1:M, function(m) + nms_tmp <- rownames(out) + nms_tmp_y <- uapply(1:M, function(m) grep(paste0("^", get_stub(object), m, "\\|"), nms_tmp, value = TRUE)) - nms_tmp_e <- grep("^Event\\|", nms_tmp, value = TRUE) - nms_tmp_a <- grep("^Assoc\\|", nms_tmp, value = TRUE) - nms_tmp_b <- b_names(nms_tmp, value = TRUE) + nms_tmp_e <- grep("^Event\\|", nms_tmp, value = TRUE) + nms_tmp_a <- grep("^Assoc\\|", nms_tmp, value = TRUE) + nms_tmp_b <- b_names(nms_tmp, value = TRUE) nms_tmp_Sigma <- grep("^Sigma", nms_tmp, value = TRUE) - nms_tmp_lp <- grep("^log-posterior$", nms_tmp, value = TRUE) - out <- out[c(unlist(nms_tmp_y), nms_tmp_e, nms_tmp_a, nms_tmp_b, - nms_tmp_Sigma, nms_tmp_lp), , drop = FALSE] + nms_tmp_lp <- grep("^log-posterior$", nms_tmp, value = TRUE) + + out <- out[c(nms_tmp_y, + nms_tmp_e, + nms_tmp_a, + nms_tmp_b, + nms_tmp_Sigma, + nms_tmp_lp), , drop = FALSE] # Output object if (mvmer) - out <- structure( - out, y_vars = y_vars, family = fam, n_markers = object$n_markers, - n_yobs = object$n_yobs, n_grps = object$n_grps) + out <- structure(out, + y_vars = y_vars, + family = fam, + n_markers = object$n_markers, + n_yobs = object$n_yobs, + n_grps = object$n_grps) + if (surv) - out <- structure( - out, n_subjects = object$n_subjects, n_events = object$n_events, - basehaz = object$basehaz) + out <- structure(out, + n_subjects = object$n_subjects, + n_events = object$n_events, + basehaz = object$basehaz) + if (jm) - out <- structure( - out, id_var = object$id_var, time_var = object$time_var, assoc = assoc) + out <- structure(out, + id_var = object$id_var, + time_var = object$time_var, + assoc = assoc) + structure( - out, formula = object$formula, algorithm = object$algorithm, + out, + formula = object$formula, + algorithm = object$algorithm, stan_function = object$stan_function, posterior_sample_size = posterior_sample_size(object), - runtime = object$runtime, print.digits = digits, - class = c("summary.stanmvreg", "summary.stanreg")) + runtime = object$runtime, + print.digits = digits, + class = c("summary.stanmvreg", "summary.stanreg")) } #' @rdname summary.stanreg @@ -642,42 +717,42 @@ print.summary.stanmvreg <- function(x, digits = max(1, attr(x, "print.digits")), jm <- atts$stan_function == "stan_jm" tab <- if (jm) " " else "" cat("\nModel Info:\n") - cat("\n function: ", tab, atts$stan_function) + cat("\n function: ", tab, atts$stan_function) if (mvmer) { M <- atts$n_markers - stubs <- paste0("(", if (jm) "Long" else "y", 1:M, "):") + stubs <- paste0("(", if (jm) "Long" else "y", 1:M, "): ") for (m in 1:M) { cat("\n formula", stubs[m], formula_string(atts$formula[[m]])) cat("\n family ", stubs[m], atts$family[[m]]) } } if (jm) { - cat("\n formula (Event):", formula_string(atts$formula[["Event"]])) - cat("\n baseline hazard:", atts$basehaz$type_name) + cat("\n formula (Event): ", formula_string(atts$formula[["Event"]])) + cat("\n baseline hazard: ", atts$basehaz$type_name) assoc_fmt <- unlist(lapply(1:M, function(m) paste0(atts$assoc[[m]], " (Long", m, ")"))) - cat("\n assoc: ", paste(assoc_fmt, collapse = ", ")) + cat("\n assoc: ", paste(assoc_fmt, collapse = ", ")) } - cat("\n algorithm: ", tab, atts$algorithm) - cat("\n priors: ", tab, "see help('prior_summary')") + cat("\n algorithm: ", tab, atts$algorithm) if (!is.null(atts$posterior_sample_size) && atts$algorithm == "sampling") - cat("\n sample: ", tab, atts$posterior_sample_size, "(posterior sample size)") + cat("\n sample: ", tab, atts$posterior_sample_size, "(posterior sample size)") + cat("\n priors: ", tab, "see help('prior_summary')") if (mvmer) { obs_vals <- paste0(atts$n_yobs, " (", if (jm) "Long" else "y", 1:M, ")") - cat("\n num obs: ", tab, paste(obs_vals, collapse = ", ")) + cat("\n observations:", tab, paste(obs_vals, collapse = ", ")) } if (jm) { - cat("\n num subjects: ", atts$n_subjects) - cat(paste0("\n num events: ", atts$n_events, " (", + cat("\n subjects: ", atts$n_subjects) + cat(paste0("\n events: ", atts$n_events, " (", round(100 * atts$n_events/atts$n_subjects, 1), "%)")) } if (!is.null(atts$n_grps)) - cat("\n groups: ", tab, + cat("\n groups: ", tab, paste0(names(atts$n_grps), " (", unname(atts$n_grps), ")", collapse = ", ")) if (atts$algorithm == "sampling") { maxtime <- max(atts$runtime[, "total"]) if (maxtime == 0) maxtime <- "<0.1" - cat("\n runtime: ", tab, maxtime, "mins") + cat("\n runtime: ", tab, maxtime, "mins") } cat("\n\nEstimates:\n") @@ -747,6 +822,9 @@ allow_special_parnames <- function(object, pars) { # @param x stanreg object # @param ... Optionally include m to specify which submodel for stanmvreg models family_plus_link <- function(x, ...) { + if (is.stansurv(x)) { + return(NULL) + } fam <- family(x, ...) if (is.character(fam)) { stopifnot(identical(fam, x$method)) @@ -775,7 +853,7 @@ formula_string <- function(formula, break_and_indent = TRUE) { # get name of aux parameter based on family .aux_name <- function(object) { aux <- character() - if (!is_polr(object)) { + if (!is_polr(object) && !is.stansurv(object)) { aux <- .rename_aux(family(object)) if (is.na(aux)) { aux <- character() @@ -784,7 +862,6 @@ formula_string <- function(formula, break_and_indent = TRUE) { return(aux) } - # print anova table for stan_aov models # @param x stanreg object created by stan_aov() print_anova_table <- function(x, digits, ...) { @@ -807,3 +884,29 @@ print_anova_table <- function(x, digits, ...) { cat("\nANOVA-like table:\n") .printfr(anova_table, digits, ...) } + +# @param basehaz A list with info about the baseline hazard +basehaz_string <- function(basehaz, break_and_indent = TRUE) { + nm <- get_basehaz_name(basehaz) + switch(nm, + "exp" = "exponential", + "exp-aft" = "exponential, aft parameterisation", + "weibull" = "weibull", + "weibull-aft" = "weibull, aft parameterisation", + "gompertz" = "gompertz", + "ms" = "M-splines on hazard scale", + "bs" = "B-splines on log hazard scale", + "piecewise" = "piecewise constant on log hazard scale", + NULL) +} + +# @param x A logical (or a scalar to be evaluated as a logical). +yes_no_string <- function(x) { + if (x) "yes" else "no" +} + +# @param numer,denom The numerator and denominator with which to evaluate a %. +percent_string <- function(numer, denom) { + val <- round(100 * numer / denom, 1) + paste0("(", val, "%)") +} diff --git a/R/prior_summary.R b/R/prior_summary.R index e13a8ad9f..f0170e205 100644 --- a/R/prior_summary.R +++ b/R/prior_summary.R @@ -25,7 +25,8 @@ #' correspond to the intercept with the predictors as specified by the user #' (unmodified by \pkg{rstanarm}), but when \emph{specifying} the prior the #' intercept can be thought of as the expected outcome when the predictors are -#' set to their means. The only exception to this is for models fit with the +#' set to their means. The only exceptions to this are for models fit using +#' the \code{stan_surv} modelling function, or models fit with the #' \code{sparse} argument set to \code{TRUE} (which is only possible with a #' subset of the modeling functions and never the default). #' @@ -270,6 +271,42 @@ print.prior_summary.stanreg <- function(x, digits, ...) { ) } + # unique to stan_surv + if (stan_function == "stan_surv") { + if (!is.null(x[["priorEvent_intercept"]])) + .print_scalar_prior( + x[["priorEvent_intercept"]], + txt = paste0("Intercept"), # predictors not currently centered + formatters + ) + has_intercept <- !is.null(x[["priorEvent_intercept"]]) + if (!is.null(x[["priorEvent"]])) + .print_vector_prior( + x[["priorEvent"]], + txt = paste0(if (has_intercept) "\n", "Coefficients"), + formatters = formatters + ) + if (!is.null(x[["priorEvent_aux"]])) { + aux_name <- x[["priorEvent_aux"]][["aux_name"]] + aux_dist <- x[["priorEvent_aux"]][["dist"]] + if (aux_name %in% c("weibull-shape", "gompertz-scale")) { + if (aux_dist %in% c("normal", "student_t", "cauchy")) + x[["priorEvent_aux"]][["dist"]] <- paste0("half-", aux_dist) + .print_scalar_prior( + x[["priorEvent_aux"]], + txt = paste0("\nAuxiliary (", aux_name, ")"), + formatters + ) + } else { # ms, bs, piecewise + .print_vector_prior( + x[["priorEvent_aux"]], + txt = paste0("\nAuxiliary (", aux_name, ")"), + formatters + ) + } + } + } + # unique to stan_(g)lmer, stan_gamm4, stan_mvmer, or stan_jm if (!is.null(x[["prior_covariance"]])) .print_covariance_prior(x[["prior_covariance"]], txt = "\nCovariance", formatters) @@ -369,7 +406,6 @@ used.sparse <- function(x) { cat("\n Adjusted prior:") .cat_scalar_prior(p, adjusted = TRUE, prepend_chars = "\n ~") } - } .print_covariance_prior <- function(p, txt = "Covariance", formatters = list()) { @@ -425,8 +461,11 @@ used.sparse <- function(x) { p$df2 <- .format_pars(p$scale, .f1) } else if (p$dist %in% c("hs")) { p$df <- .format_pars(p$df, .f1) - } else if (p$dist %in% c("product_normal")) + } else if (p$dist %in% c("product_normal")){ p$df <- .format_pars(p$df, .f1) + } else if (p$dist %in% c("dirichlet")) { + p$concentration <- .format_pars(p$concentration, .f1) + } } .cat_vector_prior <- function(p, adjusted = FALSE, prepend_chars = "\n ~") { @@ -452,6 +491,8 @@ used.sparse <- function(x) { paste0("hs(df = ", .f1(p$df), ")") } else if (p$dist %in% c("R2")) { paste0("R2(location = ", .f1(p$location), ", what = '", p$what, "')") + } else if (p$dist %in% c("dirichlet")) { + paste0("dirichlet(concentration = ", .f1(p$concentration), ")") }) } diff --git a/R/ps_check.R b/R/ps_check.R index 1013857b4..d3d8362f9 100644 --- a/R/ps_check.R +++ b/R/ps_check.R @@ -1,118 +1,139 @@ # Part of the rstanarm package for estimating model parameters # Copyright (C) 2015, 2016, 2017 Trustees of Columbia University # Copyright (C) 2016, 2017 Sam Brilleman -# +# # This program is free software; you can redistribute it and/or # modify it under the terms of the GNU General Public License # as published by the Free Software Foundation; either version 3 # of the License, or (at your option) any later version. -# +# # This program is distributed in the hope that it will be useful, # but WITHOUT ANY WARRANTY; without even the implied warranty of # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the # GNU General Public License for more details. -# +# # You should have received a copy of the GNU General Public License # along with this program; if not, write to the Free Software # Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301, USA. # #' Graphical checks of the estimated survival function -#' -#' This function plots the estimated marginal survival function based on draws -#' from the posterior predictive distribution of the fitted joint model, and then -#' overlays the Kaplan-Meier curve based on the observed data. -#' +#' +#' This function plots the estimated standardised survival curve for the +#' estimation sample based on draws from the posterior distribution of the +#' fitted model, and then overlays a Kaplan-Meier survival curve based +#' on the observed data. +#' +#' @importFrom ggplot2 ggplot aes_string geom_step #' @export -#' @templateVar stanjmArg object +#' @templateVar stanregArg object #' @templateVar labsArg xlab,ylab #' @templateVar cigeomArg ci_geom_args -#' @template args-stanjm-object +#' @template args-stansurv-stanjm-object #' @template args-labs #' @template args-ci-geom-args -#' -#' @param check The type of plot to show. Currently only "survival" is -#' allowed, which compares the estimated marginal survival function under -#' the joint model to the estimated Kaplan-Meier curve based on the -#' observed data. +#' +#' @param check The type of plot to show. Currently only "survival" is +#' allowed, which compares the estimated standardised survival curve +#' based on the fitted model to the estimated Kaplan-Meier survival +#' curve based on the observed data. #' @param limits A quoted character string specifying the type of limits to #' include in the plot. Can be one of: \code{"ci"} for the Bayesian #' posterior uncertainty interval (often known as a credible interval); #' or \code{"none"} for no interval limits. -#' @param draws An integer indicating the number of MCMC draws to use to -#' to estimate the survival function. The default and maximum number of -#' draws is the size of the posterior sample. +#' @param draws Passed to the \code{draws} argument of +#' \code{\link{posterior_survfit}}. It must be an integer indicating the +#' number of MCMC draws to use to when evaluating the posterior estimate +#' of the standardised survival curve. The default is 400 for +#' \code{stan_surv} models or 200 for \code{stan_jm} models (or +#' equal to the posterior sample size if it is smaller than that). +#' @param npoints The number of time points at which to predict the survival +#' function. The plot of the survival curve is generated by interpolating +#' along these points using \code{\link[ggplot2]{geom_line}}. #' @param seed An optional \code{\link[=set.seed]{seed}} to use. #' @param ... Optional arguments passed to #' \code{\link[ggplot2:geom_path]{geom_line}} and used to control features #' of the plotted trajectory. -#' +#' #' @return A ggplot object that can be further customized using the #' \pkg{ggplot2} package. -#' -#' @seealso \code{\link{posterior_survfit}} for the estimated marginal or +#' +#' @seealso +#' \code{\link{posterior_survfit}} for the estimated marginal or #' subject-specific survival function based on draws of the model parameters -#' from the posterior distribution, -#' \code{\link{posterior_predict}} for drawing from the posterior -#' predictive distribution for the longitudinal submodel, and -#' \code{\link{pp_check}} for graphical checks of the longitudinal submodel. -#' +#' from the posterior distribution \cr +#' \code{\link{posterior_predict}} for drawing from the posterior +#' predictive distribution for the longitudinal submodel (for +#' \code{\link{stan_jm}} models only) \cr +#' \code{\link{pp_check}} for graphical checks of the longitudinal submodel +#' (for \code{\link{stan_jm}} models only) +#' #' @examples #' if (.Platform$OS.type != "windows" || .Platform$r_arch != "i386") { #' \donttest{ #' if (!exists("example_jm")) example(example_jm) #' # Compare estimated survival function to Kaplan-Meier curve #' ps <- ps_check(example_jm) -#' ps + -#' ggplot2::scale_color_manual(values = c("red", "black")) + # change colors -#' ggplot2::scale_size_manual(values = c(0.5, 3)) + # change line sizes -#' ggplot2::scale_fill_manual(values = c(NA, NA)) # remove fill -#' } +#' ps + +#' ggplot2::scale_color_manual(values = c("red", "black")) + # change colors +#' ggplot2::scale_size_manual (values = c(0.5, 3)) + # change line sizes +#' ggplot2::scale_fill_manual (values = c(NA, NA)) # remove fill #' } #' @importFrom ggplot2 ggplot aes_string geom_step -#' -ps_check <- function(object, check = "survival", - limits = c("ci", "none"), - draws = NULL, seed = NULL, - xlab = NULL, ylab = NULL, - ci_geom_args = NULL, ...) { +ps_check <- function(object, + check = "survival", + limits = c("ci", "none"), + draws = NULL, + npoints = 101, + seed = NULL, + xlab = NULL, + ylab = NULL, + ci_geom_args = NULL, + ...) { + if (!requireNamespace("survival")) stop("the 'survival' package must be installed to use this function") + + if (!any(is.stansurv(object), is.stanjm(object))) + stop("Object is not a 'stansurv' or 'stanjm' object.") + + if (is.stansurv(object) && object$ndelayed) + stop("'ps_check' cannot currently be used on models with delayed entry.") - validate_stanjm_object(object) limits <- match.arg(limits) - # Predictions for plotting the estimated survival function - dat <- posterior_survfit(object, standardise = TRUE, - condition = FALSE, - times = 0, extrapolate = TRUE, - draws = draws, seed = seed) - - # Estimate KM curve based on response from the event submodel - form <- reformulate("1", response = formula(object)$Event[[2]]) - coxdat <- object$survmod$mod$y - if (is.null(coxdat)) - stop("Bug found: no response y found in the 'survmod' component of the ", - "fitted joint model.") - resp <- attr(coxdat, "type") - if (resp == "right") { - form <- formula(survival::Surv(time, status) ~ 1) - } else if (resp == "counting") { - form <- formula(survival::Surv(start, stop, time) ~ 1) - } else { - stop("Bug found: only 'right' or 'counting' survival outcomes should ", - "have been allowed as the response type in the fitted joint model.") - } - km <- survival::survfit(form, data = as.data.frame(unclass(coxdat))) - kmdat <- data.frame(times = km$time, surv = km$surv, - lb = km$lower, ub = km$upper) - - # Plot estimated survival function with KM curve overlaid - graph <- plot.survfit.stanjm(dat, ids = NULL, limits = limits, ...) - kmgraph <- geom_step(data = kmdat, - mapping = aes_string(x = "times", y = "surv")) - graph + kmgraph -} + # Obtain standardised survival probabilities for the fitted model + dat <- posterior_survfit(object, + times = 0, + extrapolate = TRUE, + standardise = TRUE, + condition = FALSE, + draws = draws, + seed = seed, + control = list(epoints = npoints)) + # Obtain the response variable for the fitted model + response <- get_surv(object) + if (is.null(response)) + stop("Bug found: no response variable found in fitted model object.") + # Obtain the formula for KM curve + type <- attr(response, "type") + form <- switch( + type, + right = formula(survival::Surv(time, status, type = type) ~ 1), + counting = formula(survival::Surv(start, stop, status, type = type) ~ 1), + interval = formula(survival::Surv(time1, time2, status, type = 'interval') ~ 1), + interval2 = formula(survival::Surv(time1, time2, status, type = 'interval') ~ 1), + stop("Bug found: invalid type of survival object.")) + + + # Obtain the KM estimates + kmfit <- survival::survfit(form, data = data.frame(unclass(response))) + kmdat <- data.frame(times = kmfit$time, surv = kmfit$surv) + # Plot estimated survival function with KM curve overlaid + psgraph <- plot.survfit.stanjm(dat, ids = NULL, limits = limits, ...) + kmgraph <- geom_step(aes_string(x = "times", y = "surv"), kmdat) + psgraph + kmgraph +} diff --git a/R/stan_surv.R b/R/stan_surv.R new file mode 100644 index 000000000..25c65d191 --- /dev/null +++ b/R/stan_surv.R @@ -0,0 +1,2332 @@ +# Part of the rstanarm package for estimating model parameters +# Copyright (C) 2018 Sam Brilleman +# Copyright (C) 2018 Trustees of Columbia University +# +# This program is free software; you can redistribute it and/or +# modify it under the terms of the GNU General Public License +# as published by the Free Software Foundation; either version 3 +# of the License, or (at your option) any later version. +# +# This program is distributed in the hope that it will be useful, +# but WITHOUT ANY WARRANTY; without even the implied warranty of +# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +# GNU General Public License for more details. +# +# You should have received a copy of the GNU General Public License +# along with this program; if not, write to the Free Software +# Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301, USA. + +#' Bayesian survival models via Stan +#' +#' \if{html}{\figure{stanlogo.png}{options: width="25px" alt="http://mc-stan.org/about/logo/"}} +#' Bayesian inference for survival models (sometimes known as models for +#' time-to-event data). Currently, the command fits: +#' (i) flexible parametric (cubic spline-based) survival +#' models on the hazard scale, with covariates included under assumptions of +#' either proportional or non-proportional hazards; +#' (ii) standard parametric (exponential, Weibull and Gompertz) survival +#' models on the hazard scale, with covariates included under assumptions of +#' either proportional or non-proportional hazards; and +#' (iii) standard parametric (exponential, Weibull) accelerated failure time +#' models, with covariates included under assumptions of either time-fixed or +#' time-varying survival time ratios. Left, right, and interval censored +#' survival data are allowed. Delayed entry is allowed. Both fixed and random +#' effects can be estimated for covariates (i.e. group-specific parameters +#' are allowed). Time-varying covariates and time-varying coefficients are +#' both allowed. For modelling each time-varying coefficient (i.e. time-varying +#' log hazard ratio or time-varying log survival time ratio) the user can +#' choose between either a smooth B-spline function or a piecewise constant +#' function. +#' +#' @export +#' @importFrom splines bs +#' @import splines2 +#' +#' @template args-dots +#' @template args-priors +#' @template args-prior_covariance +#' @template args-prior_PD +#' @template args-algorithm +#' @template args-adapt_delta +#' +#' @param formula A two-sided formula object describing the model. +#' The left hand side of the formula should be a \code{Surv()} +#' object. Left censored, right censored, and interval censored data +#' are allowed, as well as delayed entry (i.e. left truncation). See +#' \code{\link[survival]{Surv}} for how to specify these outcome types. +#' The right hand side of the formula can include fixed and/or random +#' effects of covariates, with random effects specified in the same +#' way as for the \code{\link[lme4]{lmer}} function in the \pkg{lme4} +#' package. If you wish to include time-varying effects (i.e. time-varying +#' coefficients, e.g. non-proportional hazards) in the model +#' then any covariate(s) that you wish to estimate a time-varying +#' coefficient for should be specified as \code{tve(varname)} where +#' \code{varname} is the name of the covariate. For more information on +#' how time-varying effects are formulated see the documentation +#' for the \code{\link{tve}} function as well as the \strong{Details} +#' and \strong{Examples} sections below. +#' @param data A data frame containing the variables specified in +#' \code{formula}. +#' @param basehaz A character string indicating which baseline hazard or +#' baseline survival distribution to use for the event submodel. +#' +#' The following are available under a hazard scale formulation: +#' \itemize{ +#' \item \code{"ms"}: A flexible parametric model using cubic M-splines to +#' model the baseline hazard. The default locations for the internal knots, +#' as well as the basis terms for the splines, are calculated with respect +#' to time. If the model does \emph{not} include any time-dependendent +#' effects then a closed form solution is available for both the hazard +#' and cumulative hazard and so this approach should be relatively fast. +#' On the other hand, if the model does include time-varying effects then +#' quadrature is used to evaluate the cumulative hazard at each MCMC +#' iteration and, therefore, estimation of the model will be slower. +#' \item \code{"bs"}: A flexible parametric model using cubic B-splines to +#' model the \emph{log} baseline hazard. The default locations for the +#' internal knots, as well as the basis terms for the splines, are calculated +#' with respect to time. A closed form solution for the cumulative hazard +#' is \strong{not} available regardless of whether or not the model includes +#' time-varying effects; instead, quadrature is used to evaluate +#' the cumulative hazard at each MCMC iteration. Therefore, if your model +#' does not include any time-varying effects, then estimation using the +#' \code{"ms"} baseline hazard will be faster. +#' \item \code{"exp"}: An exponential distribution for the event times +#' (i.e. a constant baseline hazard). +#' \item \code{"weibull"}: A Weibull distribution for the event times. +#' \item \code{"gompertz"}: A Gompertz distribution for the event times. +#' } +#' +#' The following are available under an accelerated failure time (AFT) +#' formulation: +#' \itemize{ +#' \item \code{"exp-aft"}: an exponential distribution for the event times. +#' \item \code{"weibull-aft"}: a Weibull distribution for the event times. +#' } +#' @param basehaz_ops A named list specifying options related to the baseline +#' hazard. Currently this can include: \cr +#' \itemize{ +#' \item \code{degree}: A positive integer specifying the degree for the +#' M-splines or B-splines. The default is \code{degree = 3}, which +#' corresponds to cubic splines. Note that specifying \code{degree = 0} +#' is also allowed and corresponds to piecewise constant. +#' \item \code{df}: A positive integer specifying the degrees of freedom +#' for the M-splines or B-splines. For M-splines (i.e. when +#' \code{basehaz = "ms"}), two boundary knots and \code{df - degree - 1} +#' internal knots are used to generate the spline basis. For B-splines +#' (i.e. when \code{basehaz = "bs"}), two boundary knots and +#' \code{df - degree} internal knots are used to generate the spline +#' basis. The difference is due to the fact that the M-spline basis +#' includes an intercept, whereas the B-spline basis does not. The +#' default is \code{df = 6} for M-splines and \code{df = 5} for +#' B-splines (i.e. two boundary knots and two internal knots when the +#' default cubic splines are being used). The internal knots are placed +#' at equally spaced percentiles of the distribution of uncensored event +#' times. +#' \item \code{knots}: A numeric vector explicitly specifying internal +#' knot locations for the M-splines or B-splines. Note that \code{knots} +#' cannot be specified if \code{df} is specified. +#' } +#' Note that for the M-splines and B-splines -- in addition to any internal +#' \code{knots} -- a lower boundary knot is placed at the earliest entry time +#' and an upper boundary knot is placed at the latest event or censoring time. +#' These boundary knot locations are the default and cannot be changed by the +#' user. +#' @param qnodes The number of nodes to use for the Gauss-Kronrod quadrature +#' that is used to evaluate the cumulative hazard when \code{basehaz = "bs"} +#' or when time-varying effects are specified in the linear predictor. +#' Options are 15 (the default), 11 or 7. +#' @param prior_intercept The prior distribution for the intercept in the +#' linear predictor. All models include an intercept parameter. +#' \code{prior_intercept} can be a call to \code{normal}, +#' \code{student_t} or \code{cauchy}. See the \link[=priors]{priors help page} +#' for details on these functions. However, note that default scale for +#' \code{prior_intercept} is 20 for \code{stan_surv} models (rather than 10, +#' which is the default scale used for \code{prior_intercept} by most +#' \pkg{rstanarm} modelling functions). To omit a prior on the intercept +#' ---i.e., to use a flat (improper) uniform prior--- \code{prior_intercept} +#' can be set to \code{NULL}. +#' +#' \strong{Note:} The prior distribution for the intercept is set so it +#' applies to the value \emph{when all predictors are centered} and with an +#' adjustment ("constant shift") equal to the \emph{log crude event rate}. +#' However, the reported \emph{estimates} for the intercept always correspond +#' to a parameterization without centered predictors and without the +#' "constant shift". That is, these adjustments are made internally to help +#' with numerical stability and sampling, but the necessary +#' back-transformations are made so that they are not relevant for the +#' estimates returned to the user. +#' @param prior_aux The prior distribution for "auxiliary" parameters related +#' to the baseline hazard. The relevant parameters differ depending +#' on the type of baseline hazard specified in the \code{basehaz} +#' argument. The following applies (for further technical details, +#' refer to the \emph{stan_surv: Survival (Time-to-Event) Models vignette}): +#' \itemize{ +#' \item \code{basehaz = "ms"}: the auxiliary parameters are the +#' coefficients for the M-spline basis terms on the baseline hazard. +#' These coefficients are defined using a simplex; that is, they are +#' all between 0 and 1, and constrained to sum to 1. This constraint +#' is necessary for identifiability of the intercept in the linear +#' predictor. The default prior is a Dirichlet distribution with all +#' concentration parameters set equal to 1. That is, a uniform +#' prior over all points defined within the support of the simplex. +#' Specifying all concentration parameters equal and > 1 supports a more +#' even distribution (i.e. a smoother spline function), while specifying a +#' all concentration parameters equal and < 1 supports a more sparse +#' distribution (i.e. a less smooth spline function). +#' \item \code{basehaz = "bs"}: the auxiliary parameters are the +#' coefficients for the B-spline basis terms on the log baseline hazard. +#' These parameters are unbounded. The default prior is a normal +#' distribution with mean 0 and scale 20. +#' \item \code{basehaz = "exp"} or \code{basehaz = "exp-aft"}: +#' there is \strong{no} auxiliary parameter, +#' since the log scale parameter for the exponential distribution is +#' incorporated as an intercept in the linear predictor. +#' \item \code{basehaz = "weibull"} or \code{basehaz = "weibull-aft"}: +#' the auxiliary parameter is the Weibull +#' shape parameter, while the log scale parameter for the Weibull +#' distribution is incorporated as an intercept in the linear predictor. +#' The auxiliary parameter has a lower bound at zero. The default prior is +#' a half-normal distribution with mean 0 and scale 2. +#' \item \code{basehaz = "gompertz"}: the auxiliary parameter is the Gompertz +#' scale parameter, while the log shape parameter for the Gompertz +#' distribution is incorporated as an intercept in the linear predictor. +#' The auxiliary parameter has a lower bound at zero. The default prior is +#' a half-normal distribution with mean 0 and scale 0.5. +#' } +#' Currently, \code{prior_aux} can be a call to \code{dirichlet}, +#' \code{normal}, \code{student_t}, \code{cauchy} or \code{exponential}. +#' See \code{\link{priors}} for details on these functions. Note that not +#' all prior distributions are allowed with all types of baseline hazard. +#' To omit a prior ---i.e., to use a flat (improper) uniform prior--- set +#' \code{prior_aux} to \code{NULL}. +#' @param prior_smooth This is only relevant when time-varying effects are +#' specified in the model (i.e. the \code{tve()} function is used in the +#' model formula. When that is the case, \code{prior_smooth} determines the +#' prior distribution given to the hyperparameter (standard deviation) +#' contained in a random-walk prior for the parameters of the function +#' used to generate the time-varying coefficient (i.e. the B-spline +#' coefficients when a B-spline function is used to model the time-varying +#' coefficient, or the deviations in the log hazard ratio specific to each +#' time interval when a piecewise constant function is used to model the +#' time-varying coefficient). Lower values for the hyperparameter +#' yield a less flexible function for the time-varying coefficient. +#' Specifically, \code{prior_smooth} can be a call to \code{exponential} to +#' use an exponential distribution, or \code{normal}, \code{student_t} or +#' \code{cauchy}, which results in a half-normal, half-t, or half-Cauchy +#' prior. See \code{\link{priors}} for details on these functions. To omit a +#' prior ---i.e., to use a flat (improper) uniform prior--- set +#' \code{prior_smooth} to \code{NULL}. The number of hyperparameters depends +#' on the model specification (i.e. the number of time-varying effects +#' specified in the model) but a scalar prior will be recycled as necessary +#' to the appropriate length. +#' +#' @details +#' \subsection{Model formulations}{ +#' Let \eqn{h_i(t)} denote the hazard for individual \eqn{i} at time +#' \eqn{t}, \eqn{h_0(t)} the baseline hazard at time \eqn{t}, \eqn{X_i} +#' a vector of covariates for individual \eqn{i}, \eqn{\beta} a vector of +#' coefficients, \eqn{S_i(t)} the survival probability for individual +#' \eqn{i} at time \eqn{t}, and \eqn{S_0(t)} the baseline survival +#' probability at time \eqn{t}. Without time-varying effects in the +#' model formula our linear predictor is \eqn{\eta_i = X_i \beta}, whereas +#' with time-varying effects in the model formula our linear predictor +#' is \eqn{\eta_i(t) = X_i(t) \beta(t)}. Then the following definitions of +#' the hazard function and survival function apply: +#' +#' \tabular{llll}{ +#' \strong{Scale } \tab +#' \strong{TVE } \tab +#' \strong{Hazard } \tab +#' \strong{Survival } \cr +#' \emph{Hazard} \tab +#' \emph{No} \tab +#' \eqn{h_i(t) = h_0(t) \exp(\eta_i)} \tab +#' \eqn{S_i(t) = [S_0(t)]^{\exp(\eta_i)}} \cr +#' \emph{Hazard} \tab +#' \emph{Yes} \tab +#' \eqn{h_i(t) = h_0(t) \exp(\eta_i(t))} \tab +#' \eqn{S_i(t) = \exp(- \int_0^t h_i(u) du )} \cr +#' \emph{AFT} \tab +#' \emph{No} \tab +#' \eqn{h_i(t) = \exp(-\eta_i) h_0 (t \exp(-\eta_i))} \tab +#' \eqn{S_i(t) = S_0 ( t \exp(-\eta_i) )} \cr +#' \emph{AFT} \tab +#' \emph{Yes} \tab +#' \eqn{h_i(t) = \exp(-\eta_i(t)) h_0(\int_0^t \exp(-\eta_i(u)) du)} \tab +#' \eqn{S_i(t) = S_0 (\int_0^t \exp(-\eta_i(u)) du)} \cr +#' } +#' +#' where \emph{AFT} stands for an accelerated failure time formulation, +#' and \emph{TVE} stands for time-varying effects in the model formula. +#' +#' For models without time-varying effects, the value of \eqn{S_i(t)} can +#' be calculated analytically (with the one exception being when B-splines +#' are used to model the log baseline hazard, i.e. \code{basehaz = "bs"}). +#' +#' For models with time-varying effects \eqn{S_i(t)} cannot be calculated +#' analytically and so Gauss-Kronrod quadrature is used to approximate the +#' relevant integral. The number of nodes used in the quadrature can be +#' controlled via the \code{nodes} argument. +#' +#' For models estimated on the hazard scale, a hazard ratio can be calculated +#' as \eqn{\exp(\beta)}. For models estimated on the AFT scale, a survival +#' time ratio can be calculated as \eqn{\exp(\beta)} and an acceleration +#' factor can be calculated as \eqn{\exp(-\beta)}. +#' +#' Note that the \emph{stan_surv: Survival (Time-to-Event) Models} vignette +#' provides more extensive details on the model formulations, including the +#' parameterisations for each of the parametric distributions. +#' } +#' \subsection{Time-varying effects}{ +#' By default, any covariate effects specified in the \code{formula} are +#' included in the model under a proportional hazards assumption (for models +#' estimated using a hazard scale formulation) or under the assumption of +#' time-fixed acceleration factors (for models estimated using an accelerated +#' failure time formulation). +#' +#' To relax this assumption, it is possible to +#' estimate a time-varying effect (i.e. a time-varying coefficient) for a +#' given covariate. A time-varying effect is specified in the model +#' \code{formula} by wrapping the covariate name in the \code{\link{tve}} +#' function. +#' +#' The following applies: +#' +#' \itemize{ +#' \item Estimating a time-varying effect within a hazard scale model +#' formulation (i.e. when \code{basehaz} is set equal to \code{"ms"}, +#' \code{"bs"}, \code{"exp"}, \code{"weibull"} or \code{"gompertz"}) leads +#' to the estimation of a time-varying hazard ratio for the relevant +#' covariate (i.e. non-proportional hazards). +#' \item Estimating a time-varying effect within an accelerated failure +#' time model formulation (i.e. when \code{basehaz} is set equal to +#' \code{"exp-aft"}, or \code{"weibull-aft"}) leads to the estimation of a +#' time-varying survival time ratio -- or equivalently, a time-varying +#' acceleration factor -- for the relevant covariate. +#' } +#' +#' For example, if we wish to estimate a time-varying effect for the +#' covariate \code{sex} then we can specify \code{tve(sex)} in the +#' \code{formula}, e.g. \code{Surv(time, status) ~ tve(sex) + age + trt}. +#' The coefficient for \code{sex} will then be modelled using a flexible +#' smooth function based on a cubic B-spline expansion of time. +#' +#' Additional arguments used to control the modelling of the time-varying +#' effect are explained in the \code{\link{tve}} documentation. +#' Of particular note is the fact that a piecewise constant basis is +#' allowed as a special case of the B-splines. For example, specifying +#' \code{tve(sex, degree = 0)} in the model formula instead of just +#' \code{tve(sex)} would request a piecewise constant time-varying effect. +#' The user can also control the degrees of freedom or knot locations for +#' the B-spline (or piecewise constant) function. +#' +#' It is worth noting that an additional way to control the +#' flexibility of the function used to model the time-varying effect +#' is through priors. A random walk prior is used for the piecewise +#' constant or B-spline coefficients, and the hyperparameter (standard +#' deviation) of the random walk prior can be controlled via the +#' \code{prior_smooth} argument. This is a more indirect way to +#' control the "smoothness" of the function used to model the time-varying +#' effect, but it nonetheless might be useful in some settings. The +#' \emph{stan_surv: Survival (Time-to-Event) Models} vignette provides +#' more explicit details on the formulation of the time-varying effects +#' and the prior distributions used for their coefficients. +#' +#' It is worth noting that reliable estimation of a time-varying effect +#' usually requires a relatively large number of events in the data (e.g. +#' say >1000 depending on the setting). +#' } +#' +#' @examples +#' if (.Platform$OS.type != "windows" || .Platform$r_arch != "i386") { +#' \donttest{ +#' #----- Proportional hazards +#' +#' # Simulated data +#' library(simsurv) +#' covs <- data.frame(id = 1:200, +#' trt = stats::rbinom(200, 1L, 0.5)) +#' d1 <- simsurv(lambdas = 0.1, +#' gammas = 1.5, +#' betas = c(trt = -0.5), +#' x = covs, +#' maxt = 5) +#' d1 <- merge(d1, covs) +#' f1 <- Surv(eventtime, status) ~ trt +#' m1a <- stan_surv(f1, d1, basehaz = "ms", chains=1,refresh=0,iter=600) +#' m1b <- stan_surv(f1, d1, basehaz = "exp", chains=1,refresh=0,iter=600) +#' m1c <- stan_surv(f1, d1, basehaz = "weibull", chains=1,refresh=0,iter=600) +#' m1d <- stan_surv(f1, d1, basehaz = "gompertz", chains=1,refresh=0,iter=600) +#' get_est <- function(x) { fixef(x)["trt"] } +#' do.call(rbind, lapply(list(m1a, m1b, m1c, m1d), get_est)) +#' bayesplot::bayesplot_grid(plot(m1a), # compare baseline hazards +#' plot(m1b), +#' plot(m1c), +#' plot(m1d), +#' ylim = c(0, 0.8)) +#' +#' #----- Left and right censored data +#' +#' # Mice tumor data +#' m2 <- stan_surv(Surv(l, u, type = "interval2") ~ grp, +#' data = mice, chains = 1, refresh = 0, iter = 600) +#' print(m2, 4) +#' +#' #----- Non-proportional hazards - B-spline tve() +#' +#' # Simulated data +#' library(simsurv) +#' covs <- data.frame(id = 1:250, +#' trt = stats::rbinom(250, 1L, 0.5)) +#' d3 <- simsurv(lambdas = 0.1, +#' gammas = 1.5, +#' betas = c(trt = -0.5), +#' tve = c(trt = 0.2), +#' x = covs, +#' maxt = 5) +#' d3 <- merge(d3, covs) +#' m3 <- stan_surv(Surv(eventtime, status) ~ tve(trt), +#' data = d3, chains = 1, refresh = 0, iter = 600) +#' print(m3, 4) +#' plot(m3, "tve") # time-varying hazard ratio +#' +#' #----- Non-proportional hazards - piecewise constant tve() +#' +#' # Simulated data +#' library(simsurv) +#' covs <- data.frame(id = 1:250, +#' trt = stats::rbinom(250, 1L, 0.5)) +#' d4 <- simsurv(lambdas = 0.1, +#' gammas = 1.5, +#' betas = c(trt = -0.5), +#' tve = c(trt = 0.4), +#' tvefun = function(t) { (t > 2.5) }, +#' x = covs, +#' maxt = 5) +#' d4 <- merge(d4, covs) +#' m4 <- stan_surv(Surv(eventtime, status) ~ +#' tve(trt, degree = 0, knots = c(2.5)), +#' data = d4, chains = 1, refresh = 0, iter = 600) +#' print(m4, 4) +#' plot(m4, "tve") # time-varying hazard ratio +#' +#' #---------- Compare PH and AFT parameterisations +#' +#' # Breast cancer data +#' sel <- sample(1:nrow(bcancer), 100) +#' +#' m_ph <- stan_surv(Surv(recyrs, status) ~ group, +#' data = bcancer[sel,], +#' basehaz = "weibull", +#' chains = 1, +#' refresh = 0, +#' iter = 600, +#' seed = 123) +#' m_aft <- stan_surv(Surv(recyrs, status) ~ group, +#' data = bcancer[sel,], +#' basehaz = "weibull-aft", +#' chains = 1, +#' refresh = 0, +#' iter = 600, +#' seed = 123) +#' +#' exp(fixef(m_ph)) [c('groupMedium', 'groupPoor')] # hazard ratios +#' exp(fixef(m_aft))[c('groupMedium', 'groupPoor')] # survival time ratios +#' +#' # same model (...slight differences due to sampling) +#' summary(m_ph, par = "log-posterior")[, 'mean'] +#' summary(m_aft, par = "log-posterior")[, 'mean'] +#' +#' #----- Frailty model, i.e. site-specific intercepts +#' +#' m_frail <- stan_surv( +#' formula = Surv(eventtime, status) ~ trt + (1 | site), +#' data = frail[1:40,], +#' basehaz = "exp", +#' chains = 1, +#' refresh = 0, +#' iter = 600, +#' seed = 123) +#' print(m_frail) # shows SD for frailty +#' VarCorr(m_frail) # extract SD explicitly +#' +#' } +#' } +#' +stan_surv <- function(formula, + data, + basehaz = "ms", + basehaz_ops, + qnodes = 15, + prior = normal(), + prior_intercept = normal(), + prior_aux, + prior_smooth = exponential(autoscale = FALSE), + prior_covariance = decov(), + prior_PD = FALSE, + algorithm = c("sampling", "meanfield", "fullrank"), + adapt_delta = 0.95, ...) { + + #----------------------------- + # Pre-processing of arguments + #----------------------------- + + if (!requireNamespace("survival")) + stop("the 'survival' package must be installed to use this function.") + + if (missing(basehaz_ops)) + basehaz_ops <- NULL + if (missing(data) || !inherits(data, "data.frame")) + stop("'data' must be a data frame.") + + dots <- list(...) + algorithm <- match.arg(algorithm) + + formula <- parse_formula_and_data(formula, data) + data <- formula$data; formula[["data"]] <- NULL + + #---------------- + # Construct data + #---------------- + + #----- model frame stuff + + mf_stuff <- make_model_frame(formula$tf_form, data, drop.unused.levels = TRUE) + + mf <- mf_stuff$mf # model frame + mt <- mf_stuff$mt # model terms + + #----- dimensions and response vectors + + # entry and exit times for each row of data + t_beg <- make_t(mf, type = "beg") # entry time + t_end <- make_t(mf, type = "end") # exit time + t_upp <- make_t(mf, type = "upp") # upper time for interval censoring + + # ensure no event or censoring times are zero (leads to degenerate + # estimate for log hazard for most baseline hazards, due to log(0)) + check1 <- any(t_end <= 0, na.rm = TRUE) + check2 <- any(t_upp <= 0, na.rm = TRUE) + if (check1 || check2) + stop2("All event and censoring times must be greater than 0.") + + # event indicator for each row of data + status <- make_d(mf) + + if (any(is.na(status))) + stop2("Invalid status indicator in Surv object.") + + if (any(status < 0 | status > 3)) + stop2("Invalid status indicator in Surv object.") + + # delayed entry indicator for each row of data + delayed <- as.logical(!t_beg == 0) + + # time variables for stan + t_event <- aa(t_end[status == 1]) # exact event time + t_lcens <- aa(t_end[status == 2]) # left censoring time + t_rcens <- aa(t_end[status == 0]) # right censoring time + t_icenl <- aa(t_end[status == 3]) # lower limit of interval censoring time + t_icenu <- aa(t_upp[status == 3]) # upper limit of interval censoring time + t_delay <- aa(t_beg[delayed]) # delayed entry time + + # calculate log crude event rate + t_tmp <- sum(rowMeans(cbind(t_end, t_upp), na.rm = TRUE) - t_beg) + d_tmp <- sum(!status == 0) + log_crude_event_rate <- log(d_tmp / t_tmp) + if (is.infinite(log_crude_event_rate)) + log_crude_event_rate <- 0 # avoids error when there are zero events + + # dimensions + nevent <- sum(status == 1) + nrcens <- sum(status == 0) + nlcens <- sum(status == 2) + nicens <- sum(status == 3) + ndelay <- sum(delayed) + + #----- baseline hazard + + ok_basehaz <- c("exp", + "exp-aft", + "weibull", + "weibull-aft", + "gompertz", + "ms", + "bs") + basehaz <- handle_basehaz_surv(basehaz = basehaz, + basehaz_ops = basehaz_ops, + ok_basehaz = ok_basehaz, + times = t_end, + status = status, + min_t = min(t_beg), + max_t = max(c(t_end,t_upp), na.rm = TRUE)) + nvars <- basehaz$nvars # number of basehaz aux parameters + + # flag if intercept is required for baseline hazard + has_intercept <- ai(has_intercept(basehaz)) + + # flag if AFT specification + is_aft <- get_basehaz_name(basehaz) %in% c("exp-aft", "weibull-aft") + + #----- define dimensions and times for quadrature + + # flag if formula uses time-varying effects + has_tve <- !is.null(formula$td_form) + + # flag if closed form available for cumulative baseline hazard + has_closed_form <- check_for_closed_form(basehaz) + + # flag for quadrature + has_quadrature <- has_tve || !has_closed_form + + if (has_quadrature) { # model uses quadrature + + # standardised nodes and weights for quadrature + qq <- get_quadpoints(nodes = qnodes) + qp <- qq$points + qw <- qq$weights + + # quadrature points, evaluated for each row of data + qpts_event <- uapply(qp, unstandardise_qpts, 0, t_event) + qpts_lcens <- uapply(qp, unstandardise_qpts, 0, t_lcens) + qpts_rcens <- uapply(qp, unstandardise_qpts, 0, t_rcens) + qpts_icenl <- uapply(qp, unstandardise_qpts, 0, t_icenl) + qpts_icenu <- uapply(qp, unstandardise_qpts, 0, t_icenu) + qpts_delay <- uapply(qp, unstandardise_qpts, 0, t_delay) + + # quadrature weights, evaluated for each row of data + qwts_event <- uapply(qw, unstandardise_qwts, 0, t_event) + qwts_lcens <- uapply(qw, unstandardise_qwts, 0, t_lcens) + qwts_rcens <- uapply(qw, unstandardise_qwts, 0, t_rcens) + qwts_icenl <- uapply(qw, unstandardise_qwts, 0, t_icenl) + qwts_icenu <- uapply(qw, unstandardise_qwts, 0, t_icenu) + qwts_delay <- uapply(qw, unstandardise_qwts, 0, t_delay) + + # times at events and all quadrature points + cpts_list <- list(t_event, + qpts_event, + qpts_lcens, + qpts_rcens, + qpts_icenl, + qpts_icenu, + qpts_delay) + idx_cpts <- get_idx_array(sapply(cpts_list, length)) + cpts <- unlist(cpts_list) # as vector + + # number of quadrature points + qevent <- length(qwts_event) + qlcens <- length(qwts_lcens) + qrcens <- length(qwts_rcens) + qicens <- length(qwts_icenl) + qdelay <- length(qwts_delay) + + } else { + + # times at all different event types + cpts_list <- list(t_event, + t_lcens, + t_rcens, + t_icenl, + t_icenu, + t_delay) + idx_cpts <- get_idx_array(sapply(cpts_list, length)) + cpts <- unlist(cpts_list) # as vector + + # dud entries for stan + qpts_event <- rep(0,0) + qpts_lcens <- rep(0,0) + qpts_rcens <- rep(0,0) + qpts_icenl <- rep(0,0) + qpts_icenu <- rep(0,0) + qpts_delay <- rep(0,0) + + if (!qnodes == 15) # warn user if qnodes is not equal to the default + warning2("There is no quadrature required so 'qnodes' is being ignored.") + + } + + #----- basis terms for baseline hazard + + if (!has_quadrature) { + + basis_event <- make_basis(t_event, basehaz) + ibasis_event <- make_basis(t_event, basehaz, integrate = TRUE) + ibasis_lcens <- make_basis(t_lcens, basehaz, integrate = TRUE) + ibasis_rcens <- make_basis(t_rcens, basehaz, integrate = TRUE) + ibasis_icenl <- make_basis(t_icenl, basehaz, integrate = TRUE) + ibasis_icenu <- make_basis(t_icenu, basehaz, integrate = TRUE) + ibasis_delay <- make_basis(t_delay, basehaz, integrate = TRUE) + + } else { + + basis_epts_event <- make_basis(t_event, basehaz) + basis_qpts_event <- make_basis(qpts_event, basehaz) + basis_qpts_lcens <- make_basis(qpts_lcens, basehaz) + basis_qpts_rcens <- make_basis(qpts_rcens, basehaz) + basis_qpts_icenl <- make_basis(qpts_icenl, basehaz) + basis_qpts_icenu <- make_basis(qpts_icenu, basehaz) + basis_qpts_delay <- make_basis(qpts_delay, basehaz) + + } + + #----- model frames for generating predictor matrices + + mf_event <- keep_rows(mf, status == 1) + mf_lcens <- keep_rows(mf, status == 2) + mf_rcens <- keep_rows(mf, status == 0) + mf_icens <- keep_rows(mf, status == 3) + mf_delay <- keep_rows(mf, delayed) + + if (!has_quadrature) { + + # combined model frame, without quadrature + mf_cpts <- rbind(mf_event, + mf_lcens, + mf_rcens, + mf_icens, + mf_icens, + mf_delay) + + } else { + + # combined model frame, with quadrature + mf_cpts <- rbind(mf_event, + rep_rows(mf_event, times = qnodes), + rep_rows(mf_lcens, times = qnodes), + rep_rows(mf_rcens, times = qnodes), + rep_rows(mf_icens, times = qnodes), + rep_rows(mf_icens, times = qnodes), + rep_rows(mf_delay, times = qnodes)) + + } + + if (has_tve) { + + # generate a model frame with time transformations for tve effects + mf_tve <- make_model_frame(formula$tt_frame, data.frame(times__ = cpts))$mf + + # NB next line avoids dropping terms attribute from 'mf_cpts' + mf_cpts[, colnames(mf_tve)] <- mf_tve + + } + + #----- time-fixed predictor matrices + + ff <- formula$fe_form + x <- make_x(ff, mf )$x + x_cpts <- make_x(ff, mf_cpts)$x + x_centred <- sweep(x_cpts, 2, colMeans(x), FUN = "-") + K <- ncol(x_cpts) + + if (!has_quadrature) { + + # time-fixed predictor matrices, without quadrature + # NB skip index 5 on purpose, since time fixed predictor matrix is + # identical for lower and upper limits of interval censoring time + x_event <- x_centred[idx_cpts[1,1]:idx_cpts[1,2], , drop = FALSE] + x_lcens <- x_centred[idx_cpts[2,1]:idx_cpts[2,2], , drop = FALSE] + x_rcens <- x_centred[idx_cpts[3,1]:idx_cpts[3,2], , drop = FALSE] + x_icens <- x_centred[idx_cpts[4,1]:idx_cpts[4,2], , drop = FALSE] + x_delay <- x_centred[idx_cpts[6,1]:idx_cpts[6,2], , drop = FALSE] + + } else { + + # time-fixed predictor matrices, with quadrature + # NB skip index 6 on purpose, since time fixed predictor matrix is + # identical for lower and upper limits of interval censoring time + x_epts_event <- x_centred[idx_cpts[1,1]:idx_cpts[1,2], , drop = FALSE] + x_qpts_event <- x_centred[idx_cpts[2,1]:idx_cpts[2,2], , drop = FALSE] + x_qpts_lcens <- x_centred[idx_cpts[3,1]:idx_cpts[3,2], , drop = FALSE] + x_qpts_rcens <- x_centred[idx_cpts[4,1]:idx_cpts[4,2], , drop = FALSE] + x_qpts_icens <- x_centred[idx_cpts[5,1]:idx_cpts[5,2], , drop = FALSE] + x_qpts_delay <- x_centred[idx_cpts[7,1]:idx_cpts[7,2], , drop = FALSE] + + } + + #----- time-varying predictor matrices + + if (has_tve) { + + # time-varying predictor matrix + s_cpts <- make_s(formula, mf_cpts) + smooth_map <- get_smooth_name(s_cpts, type = "smooth_map") + smooth_idx <- get_idx_array(table(smooth_map)) + S <- ncol(s_cpts) # number of tve coefficients + + # store some additional information in model formula + # stating how many columns in the predictor matrix + # each tve() term in the model formula corresponds to + formula$tt_ncol <- attr(s_cpts, "tt_ncol") + formula$tt_map <- attr(s_cpts, "tt_map") + + } else { + + # dud entries if no tve() terms in model formula + s_cpts <- matrix(0,length(cpts),0) + smooth_idx <- matrix(0,0,2) + smooth_map <- integer(0) + S <- 0L + + formula$tt_ncol <- integer(0) + formula$tt_map <- integer(0) + + } + + if (has_quadrature) { + + # time-varying predictor matrices, with quadrature + s_epts_event <- s_cpts[idx_cpts[1,1]:idx_cpts[1,2], , drop = FALSE] + s_qpts_event <- s_cpts[idx_cpts[2,1]:idx_cpts[2,2], , drop = FALSE] + s_qpts_lcens <- s_cpts[idx_cpts[3,1]:idx_cpts[3,2], , drop = FALSE] + s_qpts_rcens <- s_cpts[idx_cpts[4,1]:idx_cpts[4,2], , drop = FALSE] + s_qpts_icenl <- s_cpts[idx_cpts[5,1]:idx_cpts[5,2], , drop = FALSE] + s_qpts_icenu <- s_cpts[idx_cpts[6,1]:idx_cpts[6,2], , drop = FALSE] + s_qpts_delay <- s_cpts[idx_cpts[7,1]:idx_cpts[7,2], , drop = FALSE] + + } + + #----- random effects predictor matrices + + has_bars <- as.logical(length(formula$bars)) + + # use 'stan_glmer' approach + if (has_bars) { + + group_unpadded <- lme4::mkReTrms(formula$bars, mf_cpts) + group <- pad_reTrms(Ztlist = group_unpadded$Ztlist, + cnms = group_unpadded$cnms, + flist = group_unpadded$flist) + z_cpts <- group$Z + + } else { + + group <- NULL + z_cpts <- matrix(0,length(cpts),0) + + } + + if (!has_quadrature) { + + # random effects predictor matrices, without quadrature + # NB skip index 5 on purpose, since time fixed predictor matrix is + # identical for lower and upper limits of interval censoring time + z_event <- z_cpts[idx_cpts[1,1]:idx_cpts[1,2], , drop = FALSE] + z_lcens <- z_cpts[idx_cpts[2,1]:idx_cpts[2,2], , drop = FALSE] + z_rcens <- z_cpts[idx_cpts[3,1]:idx_cpts[3,2], , drop = FALSE] + z_icens <- z_cpts[idx_cpts[4,1]:idx_cpts[4,2], , drop = FALSE] + z_delay <- z_cpts[idx_cpts[6,1]:idx_cpts[6,2], , drop = FALSE] + + parts_event <- extract_sparse_parts(z_event) + parts_lcens <- extract_sparse_parts(z_lcens) + parts_rcens <- extract_sparse_parts(z_rcens) + parts_icens <- extract_sparse_parts(z_icens) + parts_delay <- extract_sparse_parts(z_delay) + + } else { + + # random effects predictor matrices, with quadrature + # NB skip index 6 on purpose, since time fixed predictor matrix is + # identical for lower and upper limits of interval censoring time + z_epts_event <- z_cpts[idx_cpts[1,1]:idx_cpts[1,2], , drop = FALSE] + z_qpts_event <- z_cpts[idx_cpts[2,1]:idx_cpts[2,2], , drop = FALSE] + z_qpts_lcens <- z_cpts[idx_cpts[3,1]:idx_cpts[3,2], , drop = FALSE] + z_qpts_rcens <- z_cpts[idx_cpts[4,1]:idx_cpts[4,2], , drop = FALSE] + z_qpts_icens <- z_cpts[idx_cpts[5,1]:idx_cpts[5,2], , drop = FALSE] + z_qpts_delay <- z_cpts[idx_cpts[7,1]:idx_cpts[7,2], , drop = FALSE] + + parts_epts_event <- extract_sparse_parts(z_epts_event) + parts_qpts_event <- extract_sparse_parts(z_qpts_event) + parts_qpts_lcens <- extract_sparse_parts(z_qpts_lcens) + parts_qpts_rcens <- extract_sparse_parts(z_qpts_rcens) + parts_qpts_icens <- extract_sparse_parts(z_qpts_icens) + parts_qpts_delay <- extract_sparse_parts(z_qpts_delay) + + } + + #----- stan data + + standata <- nlist( + K, S, + nvars, + x_bar = aa(colMeans(x)), + has_intercept, + has_quadrature, + smooth_map, + smooth_idx, + type = basehaz$type, + log_crude_event_rate = + ifelse(is_aft, -log_crude_event_rate, log_crude_event_rate), + + nevent = if (has_quadrature) 0L else nevent, + nlcens = if (has_quadrature) 0L else nlcens, + nrcens = if (has_quadrature) 0L else nrcens, + nicens = if (has_quadrature) 0L else nicens, + ndelay = if (has_quadrature) 0L else ndelay, + + t_event = if (has_quadrature) rep(0,0) else t_event, + t_lcens = if (has_quadrature) rep(0,0) else t_lcens, + t_rcens = if (has_quadrature) rep(0,0) else t_rcens, + t_icenl = if (has_quadrature) rep(0,0) else t_icenl, + t_icenu = if (has_quadrature) rep(0,0) else t_icenu, + t_delay = if (has_quadrature) rep(0,0) else t_delay, + + x_event = if (has_quadrature) matrix(0,0,K) else x_event, + x_lcens = if (has_quadrature) matrix(0,0,K) else x_lcens, + x_rcens = if (has_quadrature) matrix(0,0,K) else x_rcens, + x_icens = if (has_quadrature) matrix(0,0,K) else x_icens, + x_delay = if (has_quadrature) matrix(0,0,K) else x_delay, + + w_event = if (has_quadrature || !has_bars || nevent == 0) double(0) else parts_event$w, + w_lcens = if (has_quadrature || !has_bars || nlcens == 0) double(0) else parts_lcens$w, + w_rcens = if (has_quadrature || !has_bars || nrcens == 0) double(0) else parts_rcens$w, + w_icens = if (has_quadrature || !has_bars || nicens == 0) double(0) else parts_icens$w, + w_delay = if (has_quadrature || !has_bars || ndelay == 0) double(0) else parts_delay$w, + + v_event = if (has_quadrature || !has_bars || nevent == 0) integer(0) else parts_event$v, + v_lcens = if (has_quadrature || !has_bars || nlcens == 0) integer(0) else parts_lcens$v, + v_rcens = if (has_quadrature || !has_bars || nrcens == 0) integer(0) else parts_rcens$v, + v_icens = if (has_quadrature || !has_bars || nicens == 0) integer(0) else parts_icens$v, + v_delay = if (has_quadrature || !has_bars || ndelay == 0) integer(0) else parts_delay$v, + + u_event = if (has_quadrature || !has_bars || nevent == 0) integer(0) else parts_event$u, + u_lcens = if (has_quadrature || !has_bars || nlcens == 0) integer(0) else parts_lcens$u, + u_rcens = if (has_quadrature || !has_bars || nrcens == 0) integer(0) else parts_rcens$u, + u_icens = if (has_quadrature || !has_bars || nicens == 0) integer(0) else parts_icens$u, + u_delay = if (has_quadrature || !has_bars || ndelay == 0) integer(0) else parts_delay$u, + + nnz_event = if (has_quadrature || !has_bars || nevent == 0) 0L else length(parts_event$w), + nnz_lcens = if (has_quadrature || !has_bars || nlcens == 0) 0L else length(parts_lcens$w), + nnz_rcens = if (has_quadrature || !has_bars || nrcens == 0) 0L else length(parts_rcens$w), + nnz_icens = if (has_quadrature || !has_bars || nicens == 0) 0L else length(parts_icens$w), + nnz_delay = if (has_quadrature || !has_bars || ndelay == 0) 0L else length(parts_delay$w), + + basis_event = if (has_quadrature) matrix(0,0,nvars) else basis_event, + ibasis_event = if (has_quadrature) matrix(0,0,nvars) else ibasis_event, + ibasis_lcens = if (has_quadrature) matrix(0,0,nvars) else ibasis_lcens, + ibasis_rcens = if (has_quadrature) matrix(0,0,nvars) else ibasis_rcens, + ibasis_icenl = if (has_quadrature) matrix(0,0,nvars) else ibasis_icenl, + ibasis_icenu = if (has_quadrature) matrix(0,0,nvars) else ibasis_icenu, + ibasis_delay = if (has_quadrature) matrix(0,0,nvars) else ibasis_delay, + + qnodes = if (!has_quadrature) 0L else qnodes, + + Nevent = if (!has_quadrature) 0L else nevent, + Nlcens = if (!has_quadrature) 0L else nlcens, + Nrcens = if (!has_quadrature) 0L else nrcens, + Nicens = if (!has_quadrature) 0L else nicens, + Ndelay = if (!has_quadrature) 0L else ndelay, + + qevent = if (!has_quadrature) 0L else qevent, + qlcens = if (!has_quadrature) 0L else qlcens, + qrcens = if (!has_quadrature) 0L else qrcens, + qicens = if (!has_quadrature) 0L else qicens, + qdelay = if (!has_quadrature) 0L else qdelay, + + epts_event = if (!has_quadrature) rep(0,0) else t_event, + qpts_event = if (!has_quadrature) rep(0,0) else qpts_event, + qpts_lcens = if (!has_quadrature) rep(0,0) else qpts_lcens, + qpts_rcens = if (!has_quadrature) rep(0,0) else qpts_rcens, + qpts_icenl = if (!has_quadrature) rep(0,0) else qpts_icenl, + qpts_icenu = if (!has_quadrature) rep(0,0) else qpts_icenu, + qpts_delay = if (!has_quadrature) rep(0,0) else qpts_delay, + + qwts_event = if (!has_quadrature) rep(0,0) else qwts_event, + qwts_lcens = if (!has_quadrature) rep(0,0) else qwts_lcens, + qwts_rcens = if (!has_quadrature) rep(0,0) else qwts_rcens, + qwts_icenl = if (!has_quadrature) rep(0,0) else qwts_icenl, + qwts_icenu = if (!has_quadrature) rep(0,0) else qwts_icenu, + qwts_delay = if (!has_quadrature) rep(0,0) else qwts_delay, + + x_epts_event = if (!has_quadrature) matrix(0,0,K) else x_epts_event, + x_qpts_event = if (!has_quadrature) matrix(0,0,K) else x_qpts_event, + x_qpts_lcens = if (!has_quadrature) matrix(0,0,K) else x_qpts_lcens, + x_qpts_rcens = if (!has_quadrature) matrix(0,0,K) else x_qpts_rcens, + x_qpts_icens = if (!has_quadrature) matrix(0,0,K) else x_qpts_icens, + x_qpts_delay = if (!has_quadrature) matrix(0,0,K) else x_qpts_delay, + + s_epts_event = if (!has_quadrature) matrix(0,0,S) else s_epts_event, + s_qpts_event = if (!has_quadrature) matrix(0,0,S) else s_qpts_event, + s_qpts_lcens = if (!has_quadrature) matrix(0,0,S) else s_qpts_lcens, + s_qpts_rcens = if (!has_quadrature) matrix(0,0,S) else s_qpts_rcens, + s_qpts_icenl = if (!has_quadrature) matrix(0,0,S) else s_qpts_icenl, + s_qpts_icenu = if (!has_quadrature) matrix(0,0,S) else s_qpts_icenu, + s_qpts_delay = if (!has_quadrature) matrix(0,0,S) else s_qpts_delay, + + w_epts_event = if (!has_quadrature || !has_bars || qevent == 0) double(0) else parts_epts_event$w, + w_qpts_event = if (!has_quadrature || !has_bars || qevent == 0) double(0) else parts_qpts_event$w, + w_qpts_lcens = if (!has_quadrature || !has_bars || qlcens == 0) double(0) else parts_qpts_lcens$w, + w_qpts_rcens = if (!has_quadrature || !has_bars || qrcens == 0) double(0) else parts_qpts_rcens$w, + w_qpts_icens = if (!has_quadrature || !has_bars || qicens == 0) double(0) else parts_qpts_icens$w, + w_qpts_delay = if (!has_quadrature || !has_bars || qdelay == 0) double(0) else parts_qpts_delay$w, + + v_epts_event = if (!has_quadrature || !has_bars || qevent == 0) integer(0) else parts_epts_event$v, + v_qpts_event = if (!has_quadrature || !has_bars || qevent == 0) integer(0) else parts_qpts_event$v, + v_qpts_lcens = if (!has_quadrature || !has_bars || qlcens == 0) integer(0) else parts_qpts_lcens$v, + v_qpts_rcens = if (!has_quadrature || !has_bars || qrcens == 0) integer(0) else parts_qpts_rcens$v, + v_qpts_icens = if (!has_quadrature || !has_bars || qicens == 0) integer(0) else parts_qpts_icens$v, + v_qpts_delay = if (!has_quadrature || !has_bars || qdelay == 0) integer(0) else parts_qpts_delay$v, + + u_epts_event = if (!has_quadrature || !has_bars || qevent == 0) integer(0) else parts_epts_event$u, + u_qpts_event = if (!has_quadrature || !has_bars || qevent == 0) integer(0) else parts_qpts_event$u, + u_qpts_lcens = if (!has_quadrature || !has_bars || qlcens == 0) integer(0) else parts_qpts_lcens$u, + u_qpts_rcens = if (!has_quadrature || !has_bars || qrcens == 0) integer(0) else parts_qpts_rcens$u, + u_qpts_icens = if (!has_quadrature || !has_bars || qicens == 0) integer(0) else parts_qpts_icens$u, + u_qpts_delay = if (!has_quadrature || !has_bars || qdelay == 0) integer(0) else parts_qpts_delay$u, + + nnz_epts_event = if (!has_quadrature || !has_bars || qevent == 0) 0L else length(parts_epts_event$w), + nnz_qpts_event = if (!has_quadrature || !has_bars || qevent == 0) 0L else length(parts_qpts_event$w), + nnz_qpts_lcens = if (!has_quadrature || !has_bars || qlcens == 0) 0L else length(parts_qpts_lcens$w), + nnz_qpts_rcens = if (!has_quadrature || !has_bars || qrcens == 0) 0L else length(parts_qpts_rcens$w), + nnz_qpts_icens = if (!has_quadrature || !has_bars || qicens == 0) 0L else length(parts_qpts_icens$w), + nnz_qpts_delay = if (!has_quadrature || !has_bars || qdelay == 0) 0L else length(parts_qpts_delay$w), + + basis_epts_event = if (!has_quadrature) matrix(0,0,nvars) else basis_epts_event, + basis_qpts_event = if (!has_quadrature) matrix(0,0,nvars) else basis_qpts_event, + basis_qpts_lcens = if (!has_quadrature) matrix(0,0,nvars) else basis_qpts_lcens, + basis_qpts_rcens = if (!has_quadrature) matrix(0,0,nvars) else basis_qpts_rcens, + basis_qpts_icenl = if (!has_quadrature) matrix(0,0,nvars) else basis_qpts_icenl, + basis_qpts_icenu = if (!has_quadrature) matrix(0,0,nvars) else basis_qpts_icenu, + basis_qpts_delay = if (!has_quadrature) matrix(0,0,nvars) else basis_qpts_delay + ) + + #----- random-effects structure + + if (has_bars) { + + fl <- group$flist + p <- sapply(group$cnms, FUN = length) + l <- sapply(attr(fl, "assign"), function(i) nlevels(fl[[i]])) + t <- length(l) + standata$p <- as.array(p) # num ranefs for each grouping factor + standata$l <- as.array(l) # num levels for each grouping factor + standata$t <- t # num of grouping factors + standata$q <- ncol(group$Z) # p * l + standata$special_case <- all(sapply(group$cnms, intercept_only)) + + } else { # no random effects structure + + standata$p <- integer(0) + standata$l <- integer(0) + standata$t <- 0L + standata$q <- 0L + standata$special_case <- 0L + + } + + #----- priors and hyperparameters + + # valid priors + ok_dists <- nlist("normal", + student_t = "t", + "cauchy", + "hs", + "hs_plus", + "laplace", + "lasso") # disallow product normal + ok_intercept_dists <- ok_dists[1:3] + ok_aux_dists <- get_ok_priors_for_aux(basehaz) + ok_smooth_dists <- c(ok_dists[1:3], "exponential") + ok_covariance_dists <- c("decov") + + if (missing(prior_aux)) + prior_aux <- get_default_prior_for_aux(basehaz) + + # priors + user_prior_stuff <- prior_stuff <- + handle_glm_prior(prior, + nvars = K, + default_scale = 2.5, + link = NULL, + ok_dists = ok_dists) + + user_prior_intercept_stuff <- prior_intercept_stuff <- + handle_glm_prior(prior_intercept, + nvars = 1, + default_scale = 20, + link = NULL, + ok_dists = ok_intercept_dists) + + user_prior_aux_stuff <- prior_aux_stuff <- + handle_glm_prior(prior_aux, + nvars = basehaz$nvars, + default_scale = get_default_aux_scale(basehaz), + link = NULL, + ok_dists = ok_aux_dists) + + user_prior_smooth_stuff <- prior_smooth_stuff <- + handle_glm_prior(prior_smooth, + nvars = if (S) max(smooth_map) else 0, + default_scale = 1, + link = NULL, + ok_dists = ok_smooth_dists) + + # stop null priors when prior_PD is true + if (prior_PD) { + if (is.null(prior)) + stop("'prior' cannot be NULL if 'prior_PD' is TRUE.") + if (is.null(prior_intercept) && has_intercept) + stop("'prior_intercept' cannot be NULL if 'prior_PD' is TRUE.") + if (is.null(prior_aux)) + stop("'prior_aux' cannot be NULL if 'prior_PD' is TRUE.") + if (is.null(prior_smooth) && (S > 0)) + stop("'prior_smooth' cannot be NULL if 'prior_PD' is TRUE.") + } + + # handle prior for random effects structure + if (has_bars) { + + user_prior_b_stuff <- prior_b_stuff <- + handle_cov_prior(prior_covariance, + cnms = group$cnms, + ok_dists = ok_covariance_dists) + + if (is.null(prior_covariance)) + stop("'prior_covariance' cannot be NULL.") + + } else { + user_prior_b_stuff <- NULL + prior_b_stuff <- NULL + prior_covariance <- NULL + } + + # autoscaling of priors + prior_stuff <- autoscale_prior(prior_stuff, predictors = x) + prior_intercept_stuff <- autoscale_prior(prior_intercept_stuff) + prior_aux_stuff <- autoscale_prior(prior_aux_stuff) + prior_smooth_stuff <- autoscale_prior(prior_smooth_stuff) + + # priors + standata$prior_dist <- prior_stuff$prior_dist + standata$prior_dist_for_intercept<- prior_intercept_stuff$prior_dist + standata$prior_dist_for_aux <- prior_aux_stuff$prior_dist + standata$prior_dist_for_smooth <- prior_smooth_stuff$prior_dist + standata$prior_dist_for_cov <- prior_b_stuff$prior_dist + + # hyperparameters + standata$prior_mean <- prior_stuff$prior_mean + standata$prior_scale <- prior_stuff$prior_scale + standata$prior_df <- prior_stuff$prior_df + standata$prior_mean_for_intercept <- c(prior_intercept_stuff$prior_mean) + standata$prior_scale_for_intercept<- c(prior_intercept_stuff$prior_scale) + standata$prior_df_for_intercept <- c(prior_intercept_stuff$prior_df) + standata$prior_scale_for_aux <- prior_aux_stuff$prior_scale + standata$prior_df_for_aux <- prior_aux_stuff$prior_df + standata$prior_conc_for_aux <- prior_aux_stuff$prior_concentration + standata$prior_mean_for_smooth <- prior_smooth_stuff$prior_mean + standata$prior_scale_for_smooth <- prior_smooth_stuff$prior_scale + standata$prior_df_for_smooth <- prior_smooth_stuff$prior_df + standata$global_prior_scale <- prior_stuff$global_prior_scale + standata$global_prior_df <- prior_stuff$global_prior_df + standata$slab_df <- prior_stuff$slab_df + standata$slab_scale <- prior_stuff$slab_scale + + # hyperparameters for covariance + if (has_bars) { + standata$b_prior_shape <- prior_b_stuff$prior_shape + standata$b_prior_scale <- prior_b_stuff$prior_scale + standata$concentration <- prior_b_stuff$prior_concentration + standata$regularization <- prior_b_stuff$prior_regularization + standata$len_concentration <- length(standata$concentration) + standata$len_regularization <- length(standata$regularization) + standata$len_theta_L <- sum(choose(standata$p, 2), standata$p) + } else { # no random effects structure + standata$b_prior_shape <- rep(0, 0) + standata$b_prior_scale <- rep(0, 0) + standata$concentration <- rep(0, 0) + standata$regularization <- rep(0, 0) + standata$len_concentration <- 0L + standata$len_regularization <- 0L + standata$len_theta_L <- 0L + } + + # any additional flags + standata$prior_PD <- ai(prior_PD) + + #--------------- + # Prior summary + #--------------- + + prior_info <- summarize_jm_prior( + user_priorEvent = user_prior_stuff, + user_priorEvent_intercept = user_prior_intercept_stuff, + user_priorEvent_aux = user_prior_aux_stuff, + adjusted_priorEvent_scale = prior_stuff$prior_scale, + adjusted_priorEvent_intercept_scale = prior_intercept_stuff$prior_scale, + adjusted_priorEvent_aux_scale = prior_aux_stuff$prior_scale, + e_has_intercept = has_intercept, + e_has_predictors = K > 0, + basehaz = basehaz, + user_prior_covariance = prior_covariance, + b_user_prior_stuff = user_prior_b_stuff, + b_prior_stuff = prior_b_stuff + ) + + #----------- + # Fit model + #----------- + + # obtain stan model code + stanfit <- stanmodels$surv + + # specify parameters for stan to monitor + stanpars <- c(if (standata$has_intercept) "alpha", + if (standata$K) "beta", + if (standata$S) "beta_tve", + if (standata$S) "smooth_sd", + if (standata$nvars) "aux", + if (standata$t) "b", + if (standata$t) "theta_L") + + # fit model using stan + if (algorithm == "sampling") { # mcmc + args <- set_sampling_args( + object = stanfit, + data = standata, + pars = stanpars, + prior = prior, + user_dots = list(...), + user_adapt_delta = adapt_delta, + show_messages = FALSE) + stanfit <- do.call(rstan::sampling, args) + } else { # meanfield or fullrank vb + args <- nlist( + object = stanfit, + data = standata, + pars = stanpars, + algorithm + ) + args[names(dots)] <- dots + stanfit <- do.call(rstan::vb, args) + } + check_stanfit(stanfit) + + # replace 'theta_L' with the variance-covariance matrix + if (has_bars) + stanfit <- evaluate_Sigma(stanfit, group$cnms) + + # define new parameter names + nms_beta <- colnames(x_cpts) # may be NULL + nms_tve <- get_smooth_name(s_cpts, type = "smooth_coefs") # may be NULL + nms_smooth <- get_smooth_name(s_cpts, type = "smooth_sd") # may be NULL + nms_int <- get_int_name_basehaz(basehaz) + nms_aux <- get_aux_name_basehaz(basehaz) + nms_b <- get_b_names(group) # may be NULL + nms_vc <- get_varcov_names(group) # may be NULL + nms_all <- c(nms_int, + nms_beta, + nms_tve, + nms_smooth, + nms_aux, + nms_b, + nms_vc, + "log-posterior") + + # substitute new parameter names into 'stanfit' object + stanfit <- replace_stanfit_nms(stanfit, nms_all) + + # return an object of class 'stansurv' + fit <- nlist(stanfit, + formula, + has_tve, + has_quadrature, + has_bars, + data, + model_frame = mf, + terms = mt, + xlevels = .getXlevels(mt, mf), + x, + x_cpts, + s_cpts = if (has_tve) s_cpts else NULL, + z_cpts = if (has_bars) z_cpts else NULL, + cnms = if (has_bars) group_unpadded$cnms else NULL, + flist = if (has_bars) group_unpadded$flist else NULL, + t_beg, + t_end, + status, + event = as.logical(status == 1), + delayed, + basehaz, + nobs = nrow(mf), + nevents = nevent, + nlcens, + nrcens, + nicens, + ncensor = nlcens + nrcens + nicens, + ndelayed = ndelay, + prior_info, + qnodes = if (has_quadrature) qnodes else NULL, + algorithm, + stan_function = "stan_surv", + rstanarm_version = utils::packageVersion("rstanarm"), + call = match.call(expand.dots = TRUE)) + stansurv(fit) +} + + +#' Time-varying effects in Bayesian survival models +#' +#' This is a special function that can be used in the formula of a Bayesian +#' survival model estimated using \code{\link{stan_surv}}. It specifies that a +#' time-varying coefficient should be estimated for the covariate \code{x}. +#' The time-varying coefficient is currently modelled using B-splines (with +#' piecewise constant included as a special case). Note that the \code{tve} +#' function only has meaning when evaluated within the formula of a +#' \code{\link{stan_surv}} call and does not have meaning outside of that +#' context. The exported function documented here just returns \code{x}. +#' However when called internally the \code{tve} function returns several +#' other pieces of useful information used in the model fitting. +#' +#' @export +#' +#' @param x The covariate for which a time-varying coefficient should be +#' estimated. +#' @param type The type of function used to model the time-varying coefficient. +#' Currently only \code{type = "bs"} is allowed. This corresponds to a +#' B-spline function. Note that \emph{cubic} B-splines are used by default +#' but this can be changed by the user via the \code{degree} argument +#' described below. Of particular note is that \code{degree = 0} is +#' is treated as a special case corresponding to a piecewise constant basis. +#' @param df A positive integer specifying the degrees of freedom +#' for the B-spline function. Two boundary knots and \code{df - degree} +#' internal knots are used to generate the B-spline function. +#' The internal knots are placed at equally spaced percentiles of the +#' distribution of the uncensored event times. The default is to use +#' \code{df = 3} unless \code{df} or \code{knots} is explicitly +#' specified by the user. +#' @param knots A numeric vector explicitly specifying internal knot +#' locations for the B-spline function. Note that \code{knots} cannot be +#' specified if \code{df} is specified. Also note that this argument only +#' controls the \emph{internal} knot locations. In addition, boundary +#' knots are placed at the earliest entry time and latest event or +#' censoring time and these cannot be changed by the user. +#' @param degree A positive integer specifying the degree for the B-spline +#' function. The order of the B-spline is equal to \code{degree + 1}. +#' Note that \code{degree = 0} is allowed and is treated as a special +#' case corresponding to a piecewise constant basis. +#' +#' @return The exported \code{tve} function documented here just returns +#' \code{x}. However, when called internally the \code{tve} function returns +#' several other pieces of useful information. For the most part, these are +#' added to the formula element of the returned \code{\link{stanreg-objects}} +#' (that is \code{object[["formula"]]} where \code{object} is the fitted +#' model). Information added to the formula element of the \code{stanreg} +#' object includes the following: +#' \itemize{ +#' \item \code{tt_vars}: A list with the names of variables in the model +#' formula that were wrapped in the \code{tve} function. +#' \item \code{tt_types}: A list with the \code{type} (e.g. \code{"bs"}) +#' of \code{tve} function corresponding to each variable in \code{tt_vars}. +#' \item \code{tt_degrees}: A list with the \code{degree} for the +#' B-spline function corresponding to each variable in \code{tt_vars}. +#' \item \code{tt_calls}: A list with the call required to construct the +#' transformation of time for each variable in \code{tt_vars}. +#' \item \code{tt_forms}: Same as \code{tt_calls} but expressed as formulas. +#' \item \code{tt_frame}: A single formula that can be used to generate a +#' model frame that contains the unique set of transformations of time +#' (i.e. the basis terms) that are required to build all time-varying +#' coefficients in the model. In other words a single formula with the +#' unique element(s) contained in \code{tt_forms}. +#' } +#' +#' @examples +#' # Exported function just returns the input variable +#' identical(pbcSurv$trt, tve(pbcSurv$trt)) # returns TRUE +#' +#' # Internally the function returns and stores information +#' # used to form the time-varying coefficients in the model +#' m1 <- stan_surv(Surv(futimeYears, death) ~ tve(trt) + tve(sex, degree = 0), +#' data = pbcSurv, chains = 1, iter = 50) +#' m1$formula[["tt_vars"]] +#' m1$formula[["tt_forms"]] +#' +tve <- function(x, + type = "bs", + df = NULL, + knots = NULL, + degree = 3L) { + + type <- match.arg(type) + + if (!is.null(df) && !is.null(knots)) + stop("Cannot specify both 'df' and 'knots' in the 'tve' function.") + + if (degree < 0) + stop("In 'tve' function, 'degree' must be non-negative.") + + if (is.null(df) && is.null(knots)) + df <- 3L + + x +} + + +#---------- internal + +# Construct a list with information about the baseline hazard +# +# @param basehaz A string specifying the type of baseline hazard +# @param basehaz_ops A named list with elements: df, knots, degree +# @param ok_basehaz A list of admissible baseline hazards +# @param times A numeric vector with eventtimes for each individual +# @param status A numeric vector with event indicators for each individual +# @param min_t Scalar, the minimum entry time across all individuals +# @param max_t Scalar, the maximum event or censoring time across all individuals +# @return A named list with the following elements: +# type: integer specifying the type of baseline hazard, 1L = weibull, +# 2L = b-splines, 3L = piecewise. +# type_name: character string specifying the type of baseline hazard. +# user_df: integer specifying the input to the df argument +# df: integer specifying the number of parameters to use for the +# baseline hazard. +# knots: the knot locations for the baseline hazard. +# bs_basis: The basis terms for the B-splines. This is passed to Stan +# as the "model matrix" for the baseline hazard. It is also used in +# post-estimation when evaluating the baseline hazard for posterior +# predictions since it contains information about the knot locations +# for the baseline hazard (this is implemented via splines::predict.bs). +handle_basehaz_surv <- function(basehaz, + basehaz_ops, + ok_basehaz, + times, + status, + min_t, max_t) { + + if (!basehaz %in% ok_basehaz) + stop2("'basehaz' should be one of: ", comma(ok_basehaz)) + + ok_basehaz_ops <- get_ok_basehaz_ops(basehaz) + if (!all(names(basehaz_ops) %in% ok_basehaz_ops)) + stop2("'basehaz_ops' can only include: ", comma(ok_basehaz_ops)) + + if (basehaz %in% c("ms", "bs", "piecewise")) { + + df <- basehaz_ops$df + knots <- basehaz_ops$knots + degree <- basehaz_ops$degree + + if (!is.null(df) && !is.null(knots)) + stop2("Cannot specify both 'df' and 'knots' for the baseline hazard.") + + if (is.null(df)) + df <- switch(basehaz, + "ms" = 6L, # assumes intercept + "bs" = 5L, # assumes no intercept + "piecewise" = 5L, # assumes no intercept + df) # NB this is ignored if the user specified knots + + if (is.null(degree)) + degree <- 3L # cubic splines + + tt <- times[status == 1] # uncensored event times + if (is.null(knots) && !length(tt)) { + warning2("No observed events found in the data. Censoring times will ", + "be used to evaluate default knot locations for splines.") + tt <- times + } + + if (!is.null(knots)) { + if (any(knots < min_t)) + stop2("'knots' cannot be placed before the earliest entry time.") + if (any(knots > max_t)) + stop2("'knots' cannot be placed beyond the latest event time.") + } + + } + + if (basehaz %in% c("exp", "exp-aft")) { + + degree <- NULL # degree for splines + bknots <- NULL # boundary knot locations + iknots <- NULL # internal knot locations + basis <- NULL # spline basis + nvars <- 0L # number of aux parameters, none + + } else if (basehaz %in% c("weibull", "weibull-aft")) { + + degree <- NULL # degree for splines + bknots <- NULL # boundary knot locations + iknots <- NULL # internal knot locations + basis <- NULL # spline basis + nvars <- 1L # number of aux parameters, Weibull shape + + } else if (basehaz == "gompertz") { + + degree <- NULL # degree for splines + bknots <- NULL # boundary knot locations + iknots <- NULL # internal knot locations + basis <- NULL # spline basis + nvars <- 1L # number of aux parameters, Gompertz scale + + } else if (basehaz == "bs") { + + bknots <- c(min_t, max_t) + iknots <- get_iknots(tt, df = df, iknots = knots, degree = degree) + basis <- get_basis(tt, iknots = iknots, bknots = bknots, degree = degree, type = "bs") + nvars <- ncol(basis) # number of aux parameters, basis terms + + } else if (basehaz == "ms") { + + bknots <- c(min_t, max_t) + iknots <- get_iknots(tt, df = df, iknots = knots, degree = degree, intercept = TRUE) + basis <- get_basis(tt, iknots = iknots, bknots = bknots, degree = degree, type = "ms") + nvars <- ncol(basis) # number of aux parameters, basis terms + + } else if (basehaz == "piecewise") { + + degree <- NULL # degree for splines + bknots <- c(min_t, max_t) + iknots <- get_iknots(tt, df = df, iknots = knots, degree = 0) + basis <- NULL # spline basis + nvars <- length(iknots) + 1 # number of aux parameters, dummy indicators + + } + + nlist(type_name = basehaz, + type = basehaz_for_stan(basehaz), + nvars, + iknots, + bknots, + degree, + basis, + df = nvars, + user_df = nvars, + knots = if (basehaz == "bs") iknots else c(bknots[1], iknots, bknots[2]), + bs_basis = basis) +} + +# Return a vector with valid names for elements in the list passed to the +# 'basehaz_ops' argument of a 'stan_jm' or 'stan_surv' call +# +# @param basehaz_name A character string, the type of baseline hazard. +# @return A character vector, or NA if unmatched. +get_ok_basehaz_ops <- function(basehaz_name) { + switch(basehaz_name, + "bs" = c("df", "knots", "degree"), + "ms" = c("df", "knots", "degree"), + "piecewise" = c("df", "knots"), + NA) +} + +# Return the integer representation for the baseline hazard, used by Stan +# +# @param basehaz_name A character string, the type of baseline hazard. +# @return An integer, or NA if unmatched. +basehaz_for_stan <- function(basehaz_name) { + switch(basehaz_name, + "weibull" = 1L, + "bs" = 2L, + "piecewise" = 3L, + "ms" = 4L, + "exp" = 5L, + "gompertz" = 6L, + "exp-aft" = 7L, + "weibull-aft" = 8L, + NA) +} + +# Return a vector with internal knots for 'x', based on evenly spaced quantiles +# +# @param x A numeric vector. +# @param df The degrees of freedom. If specified, then 'df - degree - intercept'. +# knots are placed at evenly spaced percentiles of 'x'. If 'iknots' is +# specified then 'df' is ignored. +# @param degree Non-negative integer. The degree for the spline basis. +# @param iknots Optional vector of internal knots. +# @return A numeric vector of internal knot locations, or NULL if there are +# no internal knots. +get_iknots <- function(x, df = 5L, degree = 3L, iknots = NULL, intercept = FALSE) { + + # obtain number of internal knots + if (is.null(iknots)) { + nk <- df - degree - intercept + } else { + nk <- length(iknots) + } + + # validate number of internal knots + if (nk < 0) { + stop2("Number of internal knots cannot be negative.") + } + + # if no internal knots then return empty vector + if (nk == 0) { + return(numeric(0)) + } + + # obtain default knot locations if necessary + if (is.null(iknots)) { + iknots <- qtile(x, nq = nk + 1) # evenly spaced percentiles + } + + # return internal knot locations, ensuring they are positive + validate_positive_scalar(iknots) + + return(iknots) +} + +# Identify whether the type of baseline hazard requires an intercept in +# the linear predictor (NB splines incorporate the intercept into the basis). +# +# @param basehaz A list with info about the baseline hazard; see 'handle_basehaz'. +# @return A Logical. +has_intercept <- function(basehaz) { + nm <- get_basehaz_name(basehaz) + (nm %in% c("exp", + "exp-aft", + "weibull", + "weibull-aft", + "gompertz", + "ms", + "bs")) +} + +# Return the name of the tve spline coefs or smoothing parameters. +# +# @param x The predictor matrix for the time-varying effects, with column names. +# @param type The type of information about the smoothing parameters to return. +# @return A character or numeric vector, depending on 'type'. +get_smooth_name <- function(x, type = "smooth_coefs") { + + if (is.null(x) || !ncol(x)) + return(NULL) + + nms <- colnames(x) + nms <- gsub(":splines2::bSpline\\(times__.*\\)[0-9]*$", ":tve-bs-coef", nms) + + nms_trim <- gsub(":tve-[a-z][a-z]-coef[0-9]*$", "", nms) + tally <- table(nms_trim) + indices <- uapply(tally, seq_len) + + switch(type, + "smooth_coefs" = paste0(nms, indices), + "smooth_sd" = paste0("smooth_sd[", unique(nms_trim), "]"), + "smooth_map" = aa(rep(seq_along(tally), tally)), + "smooth_vars" = unique(nms_trim), + stop2("Bug found: invalid input to 'type' argument.")) +} + +# Return the valid prior distributions for 'prior_aux'. +# +# @param basehaz A list with info about the baseline hazard; see 'handle_basehaz'. +# @return A named list. +get_ok_priors_for_aux <- function(basehaz) { + nm <- get_basehaz_name(basehaz) + switch(nm, + "exp" = nlist(), + "exp-aft" = nlist(), + "weibull" = nlist("normal", student_t = "t", "cauchy", "exponential"), + "weibull-aft" = nlist("normal", student_t = "t", "cauchy", "exponential"), + "gompertz" = nlist("normal", student_t = "t", "cauchy", "exponential"), + "ms" = nlist("dirichlet"), + "bs" = nlist("normal", student_t = "t", "cauchy"), + "piecewise" = nlist("normal", student_t = "t", "cauchy"), + stop2("Bug found: unknown type of baseline hazard.")) +} + +# Return the default prior distribution for 'prior_aux'. +# +# @param basehaz A list with info about the baseline hazard; see 'handle_basehaz'. +# @return A list corresponding to the default prior. +get_default_prior_for_aux <- function(basehaz) { + nm <- get_basehaz_name(basehaz) + switch(nm, + "exp" = list(), # equivalent to NULL + "exp-aft" = list(), # equivalent to NULL + "weibull" = normal(), + "weibull-aft" = normal(), + "gompertz" = normal(), + "ms" = dirichlet(), + "bs" = normal(), + "piecewise" = normal(), + stop2("Bug found: unknown type of baseline hazard.")) +} + +# Return the names for the group-specific parameters +# +# @param group List returned by rstanarm:::pad_reTerms. +# @return A character vector. +get_b_names <- function(group) { + if (is.null(group)) + return(NULL) # no random effects structure + c(paste0("b[", make_b_nms(group), "]")) +} + +# Return the names for the var-cov parameters +# +# @param group List returned by rstanarm:::pad_reTerms. +# @return A character vector. +get_varcov_names <- function(group) { + if (is.null(group)) + return(NULL) # no random effects structure + paste0("Sigma[", get_Sigma_nms(group$cnms), "]") +} + +# Return the default scale parameter for 'prior_aux'. +# +# @param basehaz A list with info about the baseline hazard; see 'handle_basehaz'. +# @return A scalar. +get_default_aux_scale <- function(basehaz) { + switch(get_basehaz_name(basehaz), + "weibull" = 2, + "weibull-aft" = 2, + "gompertz" = 0.5, + 20) +} + +# Check if the type of baseline hazard has a closed form +# +# @param basehaz A list with info about the baseline hazard; see 'handle_basehaz'. +# @return A logical. +check_for_closed_form <- function(basehaz) { + nm <- get_basehaz_name(basehaz) + nm %in% c("exp", + "exp-aft", + "weibull", + "weibull-aft", + "gompertz", + "ms") +} + +# Replace the parameter names slot of an object of class 'stanfit'. +# +# @param stanfit An object of class 'stanfit'. +# @param new_nms A character vector of new parameter names. +# @return A 'stanfit' object. +replace_stanfit_nms <- function(stanfit, new_nms) { + stanfit@sim$fnames_oi <- new_nms + stanfit +} + +# Return the spline basis for the given type of baseline hazard. +# +# @param times A numeric vector of times at which to evaluate the basis. +# @param basehaz A list with info about the baseline hazard, returned by a +# call to 'handle_basehaz'. +# @param integrate A logical, specifying whether to calculate the integral of +# the specified basis. +# @return A matrix. +make_basis <- function(times, basehaz, integrate = FALSE) { + N <- length(times) + K <- basehaz$nvars + if (!N) { # times is NULL or empty vector + return(matrix(0, 0, K)) + } + switch(basehaz$type_name, + "exp" = matrix(0, N, K), # dud matrix for Stan + "exp-aft" = matrix(0, N, K), # dud matrix for Stan + "weibull" = matrix(0, N, K), # dud matrix for Stan + "weibull-aft" = matrix(0, N, K), # dud matrix for Stan + "gompertz" = matrix(0, N, K), # dud matrix for Stan + "ms" = basis_matrix(times, basis = basehaz$basis, integrate = integrate), + "bs" = basis_matrix(times, basis = basehaz$basis), + "piecewise" = dummy_matrix(times, knots = basehaz$knots), + stop2("Bug found: unknown type of baseline hazard.")) +} + +# Evaluate a spline basis matrix at the specified times +# +# @param time A numeric vector. +# @param basis Info on the spline basis. +# @param integrate A logical, should the integral of the basis be returned? +# @return A two-dimensional array. +basis_matrix <- function(times, basis, integrate = FALSE) { + out <- predict(basis, times) + if (integrate) { + stopifnot(inherits(basis, "MSpline")) + class(basis) <- c("ISpline", "splines2", "matrix") + out <- predict(basis, times) + } + aa(out) +} + +# Parse the model formula and data +# +# @param formula The user input to the formula argument. +# @param data The user input to the data argument (i.e. a data frame). +# @param A list with the model data (following removal of NA rows etc) and +# a number of elements corresponding to different parts of the formula. +parse_formula_and_data <- function(formula, data) { + + formula <- validate_formula(formula, needs_response = TRUE) + + # all variables of entire formula + allvars <- all.vars(formula) + allvars_form <- reformulate(allvars) + + # LHS of entire formula + lhs <- lhs(formula) # LHS as expression + lhs_form <- reformulate_lhs(lhs) # LHS as formula + + # RHS of entire formula + rhs <- rhs(formula) # RHS as expression + rhs_form <- reformulate_rhs(rhs) # RHS as formula + + # evaluate model data (row subsetting etc) + data <- make_model_data(allvars_form, data) + + # evaluated response variables + surv <- eval(lhs, envir = data) # Surv object + surv <- validate_surv(surv) + type <- attr(surv, "type") + + if (type == "right") { + min_t <- 0 + max_t <- max(surv[, "time"]) + status <- as.vector(surv[, "status"]) + t_end <- as.vector(surv[, "time"]) + } else if (type == "counting") { + min_t <- min(surv[, "start"]) + max_t <- max(surv[, "stop"]) + status <- as.vector(surv[, "status"]) + t_end <- as.vector(surv[, "stop"]) + } else if (type == "interval") { + min_t <- 0 + max_t <- max(surv[, c("time1", "time2")]) + status <- as.vector(surv[, "status"]) + t_end <- as.vector(surv[, "time1"]) + } else if (type == "interval2") { + min_t <- 0 + max_t <- max(surv[, c("time1", "time2")]) + status <- as.vector(surv[, "status"]) + t_end <- as.vector(surv[, "time1"]) + } + + if (any(is.na(status))) + stop2("Invalid status indicator in Surv object.") + + if (any(status < 0 | status > 3)) + stop2("Invalid status indicator in Surv object.") + + # deal with tve(x, ...) + tve_stuff <- handle_tve(formula, + min_t = min_t, + max_t = max_t, + times = t_end, + status = status) + tf_form <- tve_stuff$tf_form + td_form <- tve_stuff$td_form # may be NULL + tt_vars <- tve_stuff$tt_vars # may be NULL + tt_frame <- tve_stuff$tt_frame # may be NULL + tt_types <- tve_stuff$tt_types # may be NULL + tt_degrees <- tve_stuff$tt_degrees # may be NULL + tt_calls <- tve_stuff$tt_calls # may be NULL + tt_forms <- tve_stuff$tt_forms # may be NULL + + # just fixed-effect part of formula + fe_form <- lme4::nobars(tf_form) + + # just random-effect part of formula + bars <- lme4::findbars(tf_form) + re_parts <- lapply(bars, split_at_bars) + re_forms <- fetch(re_parts, "re_form") + + nlist(formula, + data, + allvars, + allvars_form, + lhs, + lhs_form, + rhs, + rhs_form, + tf_form, + td_form, + tt_vars, + tt_frame, + tt_types, + tt_degrees, + tt_calls, + tt_forms, + fe_form, + bars, + re_parts, + re_forms, + surv_type = attr(surv, "type")) +} + +# Handle the 'tve(x, ...)' terms in the model formula +# +# @param Terms terms object for the fixed effect part of the model formula. +# @return A named list with the following elements: +# +handle_tve <- function(formula, min_t, max_t, times, status) { + + # extract terms objects for fixed effect part of model formula + Terms <- delete.response(terms(lme4::nobars(formula), specials = "tve")) + + # check which fixed effect terms have a tve() wrapper + sel <- attr(Terms, "specials")$tve + + # if no tve() terms then just return the fixed effect formula as is + if (!length(sel)) { + return(list(tf_form = formula, + td_form = NULL, + tt_vars = NULL, + tt_frame = NULL, + tt_calls = NULL, + tt_forms = NULL)) + } + + # otherwise extract rhs of formula + all_vars <- rownames(attr(Terms, "factors")) # all variables in fe formula + tve_vars <- all_vars[sel] # variables with a tve() wrapper + + # replace 'tve(x, ...)' in formula with 'x' + old_vars <- all_vars + new_vars <- sapply(old_vars, function(x) { + if (x %in% tve_vars) { + # strip tve() from variable + tve <- function(y, ...) { safe_deparse(substitute(y)) } # define locally + return(eval(parse(text = x))) + } else { + # just return variable + return(x) + } + }, USE.NAMES = FALSE) + tf_terms <- attr(Terms, "term.labels") + td_terms <- c() + k <- 0 # initialise td_terms indexing (for creating a new formula) + for (i in sel) { + sel_terms <- which(attr(Terms, "factors")[i, ] > 0) + for (j in sel_terms) { + k <- k + 1 + tf_terms[j] <- td_terms[k] <- gsub(old_vars[i], + new_vars[i], + tf_terms[j], + fixed = TRUE) + } + } + + # extract 'tve(x, ...)' from formula and return '~ x' and '~ bs(times, ...)' + idx <- 1 + tt_vars <- list() + tt_types <- list() + tt_degrees <- list() + tt_calls <- list() + + for (i in seq_along(sel)) { + + # define tve() function locally; uses objects from the parent environment + # + # @param x The variable the time-varying effect is going to be applied to. + # @param type Character string, the type of time-varying effect to use. Can + # currently only be "bs". + # @param df,knots,degree Additional arguments passed to splines2::bSpline. + # @return The call used to construct a time-varying basis. + tve <- function(x, type = "bs", df = NULL, knots = NULL, degree = 3L) { + + type <- match.arg(type) + + if (!is.null(df) && !is.null(knots)) + stop("Cannot specify both 'df' and 'knots' in the 'tve' function.") + + if (degree < 0) + stop("In 'tve' function, 'degree' must be non-negative.") + + if (is.null(df) && is.null(knots)) + df <- 3L + + # note that times and status are taken from the parent environment + tt <- times[status == 1] # uncensored event times + if (is.null(knots) && !length(tt)) { + warning2("No observed events found in the data. Censoring times will ", + "be used to evaluate default knot locations for tve().") + tt <- times + } + + # note that min_t and max_t are taken from the parent environment + if (!is.null(knots)) { + if (any(knots < min_t)) + stop2("In tve(), 'knots' cannot be placed before the earliest entry time.") + if (any(knots > max_t)) + stop2("In tve(), 'knots' cannot be placed beyond the latest event time.") + } + + if (type == "bs") { + + iknots <- get_iknots(tt, df = df, iknots = knots, degree = degree) + + bknots <- c(min_t, max_t) + + new_args <- list(knots = iknots, + Boundary.knots = bknots, + degree = degree) + + return(list( + type = type, + degree = degree, + call = sub("^list\\(", "splines2::bSpline\\(times__, ", + deparse(new_args, 500L, control = c("all", "hexNumeric"))))) + # NB use of hexNumeric to ensure numeric accuracy is maintained + + } + + } + + tt_parsed <- eval(parse(text = all_vars[sel[i]])) + tt_terms <- which(attr(Terms, "factors")[i, ] > 0) + for (j in tt_terms) { + tt_vars [[idx]] <- tf_terms[j] + tt_types [[idx]] <- tt_parsed$type + tt_degrees[[idx]] <- tt_parsed$degree + tt_calls [[idx]] <- tt_parsed$call + idx <- idx + 1 + } + } + + # add on the terms labels from the random effects part of the formula + bars <- lme4::findbars(formula) + if (length(bars)) { + re_terms <- sapply(bars, bracket_wrap) + tf_terms <- c(tf_terms, re_terms) + } + + # formula with all variables but no 'tve(x, ...)' wrappers + tf_form <- reformulate(tf_terms, response = lhs(formula)) + + # formula with only tve variables but no 'tve(x, ...)' wrappers + td_form <- reformulate(td_terms, response = lhs(formula)) + + # unique set of '~ bs(times__, ...)' calls based on all 'tve(x, ...)' terms + tt_frame <- reformulate(unique(unlist(tt_calls)), intercept = FALSE) + + # formula with '~ x' and '~ bs(times__, ...)' from each 'tve(x, ...)' call + tt_vars <- lapply(tt_vars, reformulate) + tt_forms <- lapply(tt_calls, reformulate) + + # return object + nlist(tf_form, + td_form, + tt_vars, + tt_frame, + tt_types, + tt_degrees, + tt_calls, + tt_forms) +} + +# Ensure only valid arguments are passed to the tve() call +validate_tve_args <- function(dots, ok_args) { + + if (!isTRUE(all(names(dots) %in% ok_args))) + stop2("Invalid argument to 'tve' function. ", + "Valid arguments are: ", comma(ok_args)) + + return(dots) +} + +# Deparse an expression and wrap it in brackets +# +# @param x An expression. +# @return A character string. +bracket_wrap <- function(x) { + paste0("(", deparse(x, 500), ")") +} + +# Check input to the formula argument +# +# @param formula The user input to the formula argument. +# @param needs_response A logical; if TRUE then formula must contain a LHS. +# @return A formula. +validate_formula <- function(formula, needs_response = TRUE) { + + if (!inherits(formula, "formula")) { + stop2("'formula' must be a formula.") + } + + if (needs_response) { + len <- length(formula) + if (len < 3) { + stop2("'formula' must contain a response.") + } + } + as.formula(formula) +} + +# Check object is a Surv object with a valid type +# +# @param x A Surv object. That is, the LHS of a formula as evaluated in a +# data frame environment. +# @param ok_types A character vector giving the valid types of Surv object. +# @return A Surv object. +validate_surv <- function(x, ok_types = c("right", "counting", + "interval", "interval2")) { + if (!inherits(x, "Surv")) + stop2("LHS of 'formula' must be a 'Surv' object.") + if (!attr(x, "type") %in% ok_types) + stop2("Surv object type must be one of: ", comma(ok_types)) + x +} + +# Extract LHS of a formula +# +# @param x A formula object. +# @return An expression. +lhs <- function(x, as_formula = FALSE) { + len <- length(x) + if (len == 3L) { + out <- x[[2L]] + } else { + out <- NULL + } + out +} + +# Extract RHS of a formula +# +# @param x A formula object. +# @return An expression. +rhs <- function(x, as_formula = FALSE) { + len <- length(x) + if (len == 3L) { + out <- x[[3L]] + } else { + out <- x[[2L]] + } + out +} + +# Reformulate as LHS of a formula +# +# @param x A character string or expression. +# @return A formula. +reformulate_lhs <- function(x) { + x <- formula(substitute(LHS ~ 1, list(LHS = x))) + x +} + +# Reformulate as RHS of a formula +# +# @param x A character string or expression. +# @return A formula. +reformulate_rhs <- function(x) { + x <- formula(substitute(~ RHS, list(RHS = x))) + x +} + +# Return the response vector (time) +# +# @param model_frame The model frame. +# @param type The type of time variable to return: +# "beg": the entry time for the row in the survival data, +# "end": the exit time for the row in the survival data, +# "gap": the difference between entry and exit times, +# "upp": if the row involved interval censoring, then the exit time +# would have been the lower limit of the interval, and "upp" +# is the upper limit of the interval. +# @return A numeric vector. +make_t <- function(model_frame, type = c("beg", "end", "gap", "upp")) { + + type <- match.arg(type) + resp <- if (survival::is.Surv(model_frame)) + model_frame else model.response(model_frame) + surv <- attr(resp, "type") + err <- paste0("Bug found: cannot handle '", surv, "' Surv objects.") + + t_beg <- switch(surv, + "right" = rep(0, nrow(model_frame)), + "interval" = rep(0, nrow(model_frame)), + "interval2" = rep(0, nrow(model_frame)), + "counting" = as.vector(resp[, "start"]), + stop(err)) + + t_end <- switch(surv, + "right" = as.vector(resp[, "time"]), + "interval" = as.vector(resp[, "time1"]), + "interval2" = as.vector(resp[, "time1"]), + "counting" = as.vector(resp[, "stop"]), + stop(err)) + + t_upp <- switch(surv, + "right" = rep(NaN, nrow(model_frame)), + "counting" = rep(NaN, nrow(model_frame)), + "interval" = as.vector(resp[, "time2"]), + "interval2" = as.vector(resp[, "time2"]), + stop(err)) + + switch(type, + "beg" = t_beg, + "end" = t_end, + "gap" = t_end - t_beg, + "upp" = t_upp, + stop("Bug found: cannot handle specified 'type'.")) +} + +# Return the response vector (status indicator) +# +# @param model_frame The model frame. +# @return A numeric vector. +make_d <- function(model_frame) { + + resp <- if (survival::is.Surv(model_frame)) + model_frame else model.response(model_frame) + surv <- attr(resp, "type") + err <- paste0("Bug found: cannot handle '", surv, "' Surv objects.") + + switch(surv, + "right" = as.vector(resp[, "status"]), + "interval" = as.vector(resp[, "status"]), + "interval2" = as.vector(resp[, "status"]), + "counting" = as.vector(resp[, "status"]), + stop(err)) +} + +# Return a data frame with NAs excluded +# +# @param formula The parsed model formula. +# @param data The (user-specified) data frame. +# @return A data frame, with only complete cases for the variables that +# appear in the model formula. +make_model_data <- function(formula, data) { + mf <- model.frame(formula, data, na.action = na.pass) + include <- apply(mf, 1L, function(row) !any(is.na(row))) + data[include, , drop = FALSE] +} + +# Return the model frame +# +# @param formula The parsed model formula. +# @param data The model data frame. +# @param xlevs Passed to xlev argument of model.frame. +# @param drop.unused.levels Passed to drop.unused.levels argument of model.frame. +# @param check_constant If TRUE then an error is thrown is the returned +# model frame contains any constant variables. +# @return A list with the following elements: +# mf: the model frame based on the formula. +# mt: the model terms associated with the returned model frame. +make_model_frame <- function(formula, + data, + xlevs = NULL, + drop.unused.levels = FALSE, + check_constant = FALSE, + na.action = na.fail) { + + # construct model frame + Terms <- terms(lme4::subbars(formula)) + mf <- stats::model.frame(Terms, + data, + xlev = xlevs, + drop.unused.levels = drop.unused.levels, + na.action = na.action) + + # get predvars for fixed part of formula + TermsF <- terms(lme4::nobars(formula)) + mfF <- stats::model.frame(TermsF, + data, + xlev = xlevs, + drop.unused.levels = drop.unused.levels, + na.action = na.action) + attr(attr(mf, "terms"), "predvars.fixed") <- attr(attr(mfF, "terms"), "predvars") + + # get predvars for random part of formula + has_bars <- length(lme4::findbars(formula)) > 0 + if (has_bars) { + TermsR <- terms(lme4::subbars(justRE(formula, response = TRUE))) + mfR <- stats::model.frame(TermsR, + data, + xlev = xlevs, + drop.unused.levels = drop.unused.levels, + na.action = na.action) + attr(attr(mf, "terms"), "predvars.random") <- attr(attr(mfR, "terms"), "predvars") + } else { + attr(attr(mf, "terms"), "predvars.random") <- NULL + } + + # check no constant vars + if (check_constant) + mf <- check_constant_vars(mf) + + # add additional predvars attributes + + # check for terms + mt <- attr(mf, "terms") + if (is.empty.model(mt)) + stop2("No intercept or predictors specified.") + + nlist(mf, mt) +} + +# Return the predictor matrix +# +# @param formula The parsed model formula. +# @param model_frame The model frame. +# @param xlevs Passed to xlev argument of model.matrix. +# @param check_constant If TRUE then an error is thrown is the returned +# predictor matrix contains any constant columns. +# @return A named list with the following elements: +# x: the fe model matrix, not centered and without intercept. +# x_bar: the column means of the model matrix. +# x_centered: the fe model matrix, centered. +# N: number of rows (observations) in the model matrix. +# K: number of cols (predictors) in the model matrix. +make_x <- function(formula, + model_frame, + xlevs = NULL, + check_constant = TRUE) { + + # uncentred predictor matrix, without intercept + x <- model.matrix(formula, model_frame, xlev = xlevs) + x <- drop_intercept(x) + + # column means of predictor matrix + x_bar <- aa(colMeans(x)) + + # centered predictor matrix + x_centered <- sweep(x, 2, x_bar, FUN = "-") + + # identify any column of x with < 2 unique values (empty interaction levels) + sel <- (apply(x, 2L, n_distinct) < 2) + if (check_constant && any(sel)) { + cols <- paste(colnames(x)[sel], collapse = ", ") + stop2("Cannot deal with empty interaction levels found in columns: ", cols) + } + + nlist(x, x_centered, x_bar, N = NROW(x), K = NCOL(x)) +} + +# Return the tve predictor matrix +# +# @param formula The parsed model formula. +# @param model_frame The model frame. +# @param xlevs Passed to xlev argument of model.matrix. +# @return A named list with the following elements: +# s: model matrix for time-varying terms, not centered and without intercept. +# tt_ncol: stored attribute, a numeric vector with the number of columns in +# the model matrix that correspond to each tve() term in the original +# model formula. +# tt_map: stored attribute, a numeric vector with indexing for the columns +# of the model matrix stating which tve() term in the original model +# formula they correspond to. +make_s <- function(formula, + model_frame, + xlevs = NULL) { + + # create the design matrix for each tve() term + s_parts <- xapply( + formula$tt_vars, # names of variables with a tve() wrapper + formula$tt_forms, # time-transformation functions to interact them with + FUN = function(vn, tt) { + m1 <- make_x(vn, model_frame, xlevs = xlevs, check_constant = FALSE)$x + m2 <- make_x(tt, model_frame, xlevs = xlevs, check_constant = FALSE)$x + m3 <- matrix(apply(m1, 2L, `*`, m2), nrow = nrow(m2)) + colnames(m3) <- uapply(colnames(m1), paste, colnames(m2), sep = ":") + return(m3) + }) + + # bind columns to form one design matrix for tve() terms + s <- do.call("cbind", s_parts) + + # store indexing of the columns in the design matrix + tt_ncol <- sapply(s_parts, ncol) + tt_map <- rep(seq_along(tt_ncol), tt_ncol) + + # return design matrix with indexing info as an attribute + structure(s, tt_ncol = tt_ncol, tt_map = tt_map) +} + +# Check if the only element of a character vector is 'Intercept' +# +# @param x A character vector. +# @return A logical. +intercept_only <- function(x) { + length(x) == 1 && x == "(Intercept)" +} diff --git a/R/stanreg-methods.R b/R/stanreg-methods.R index fa1a9271a..b4c0d140a 100644 --- a/R/stanreg-methods.R +++ b/R/stanreg-methods.R @@ -90,7 +90,7 @@ NULL #' @rdname stanreg-methods #' @export coef.stanreg <- function(object, ...) { - if (is.mer(object)) + if (is.mer(object) && !is.surv(object)) return(coef_mer(object, ...)) object$coefficients @@ -121,7 +121,7 @@ fitted.stanreg <- function(object, ...) { #' @rdname stanreg-methods #' @export nobs.stanreg <- function(object, ...) { - nrow(model.frame(object)) + if (is.surv(object)) object$nobs else nrow(model.frame(object)) } #' @rdname stanreg-methods @@ -399,6 +399,9 @@ family.stanreg <- function(object, ...) object$family #' @param fixed.only See \code{\link[lme4:merMod-class]{model.frame.merMod}}. #' model.frame.stanreg <- function(formula, fixed.only = FALSE, ...) { + if (is.stansurv(formula)) { + return(formula$model_frame) + } if (is.mer(formula)) { fr <- formula$glmod$fr if (fixed.only) { @@ -420,7 +423,7 @@ model.frame.stanreg <- function(formula, fixed.only = FALSE, ...) { #' model.matrix.stanreg <- function(object, ...) { if (inherits(object, "gamm4")) return(object$jam$X) - if (is.mer(object)) return(object$glmod$X) + if (is.mer(object) && !is.surv(object)) return(object$glmod$X) NextMethod("model.matrix") } @@ -431,10 +434,15 @@ model.matrix.stanreg <- function(object, ...) { #' @export #' @param x A stanreg object. #' @param ... Can contain \code{fixed.only} and \code{random.only} arguments -#' that both default to \code{FALSE}. +#' that both default to \code{FALSE}. Also, for stan_surv models, can contain +#' \code{remove.tve} which defaults to FALSE, but if TRUE then any +#' 'tve(varname)' terms in the model formula are returned as 'varname'. #' formula.stanreg <- function(x, ..., m = NULL) { - if (is.mer(x) && !isTRUE(x$stan_function == "stan_gamm4")) return(formula_mer(x, ...)) + if (is.surv(x)) + return(formula_surv(x, ...)) + if (is.mer(x) && !isTRUE(x$stan_function == "stan_gamm4")) + return(formula_mer(x, ...)) x$formula } @@ -444,10 +452,11 @@ formula.stanreg <- function(x, ..., m = NULL) { #' @param x,fixed.only,random.only,... See lme4:::terms.merMod. #' terms.stanreg <- function(x, ..., fixed.only = TRUE, random.only = FALSE) { - if (!is.mer(x)) + if (!any(is.mer(x), is.stansurv(x))) return(NextMethod("terms")) - fr <- x$glmod$fr + fr <- if (is.stansurv(x)) model.frame(x) else x$glmod$fr + if (missing(fixed.only) && random.only) fixed.only <- FALSE if (fixed.only && random.only) @@ -455,11 +464,12 @@ terms.stanreg <- function(x, ..., fixed.only = TRUE, random.only = FALSE) { Terms <- attr(fr, "terms") if (fixed.only) { - Terms <- terms.formula(formula(x, fixed.only = TRUE)) + Terms <- terms.formula(formula(x, fixed.only = TRUE, remove.tve = TRUE)) attr(Terms, "predvars") <- attr(terms(fr), "predvars.fixed") } if (random.only) { - Terms <- terms.formula(lme4::subbars(formula.stanreg(x, random.only = TRUE))) + Terms <- terms.formula(lme4::subbars(formula.stanreg(x, random.only = TRUE, + remove.tve = TRUE))) attr(Terms, "predvars") <- attr(terms(fr), "predvars.random") } @@ -479,11 +489,17 @@ terms.stanreg <- function(x, ..., fixed.only = TRUE, random.only = FALSE) { .glmer_check(object) object$glmod$reTrms$cnms } +.cnms.stansurv <- function(object, ...) { + object$cnms +} .flist <- function(object, ...) UseMethod(".flist") .flist.stanreg <- function(object, ...) { .glmer_check(object) as.list(object$glmod$reTrms$flist) } +.flist.stansurv <- function(object, ...) { + as.list(object$flist) +} coef_mer <- function(object, ...) { if (length(list(...))) @@ -540,3 +556,25 @@ formula_mer <- function (x, fixed.only = FALSE, random.only = FALSE, ...) { return(form) } +formula_surv <- function(x, + fixed.only = FALSE, + random.only = FALSE, + remove.tve = FALSE, + ...) { + if (missing(fixed.only) && random.only) + fixed.only <- FALSE + if (fixed.only && random.only) + stop2("'fixed.only' and 'random.only' can't both be TRUE.") + if (remove.tve) { + form <- x$formula$tf_form + } else { + form <- x$formula$formula + } + if (is.null(form)) + stop2("Can't find formula in model object.") + if (fixed.only) + form[[length(form)]] <- lme4::nobars(form[[length(form)]]) + if (random.only) + form <- justRE(form, response = TRUE) + return(form) +} diff --git a/R/stansurv.R b/R/stansurv.R new file mode 100644 index 000000000..70c3e1377 --- /dev/null +++ b/R/stansurv.R @@ -0,0 +1,132 @@ +# Part of the rstanarm package for estimating model parameters +# Copyright (C) 2015, 2016, 2017 Trustees of Columbia University +# Copyright (C) 2016, 2017, 2018 Sam Brilleman +# +# This program is free software; you can redistribute it and/or +# modify it under the terms of the GNU General Public License +# as published by the Free Software Foundation; either version 3 +# of the License, or (at your option) any later version. +# +# This program is distributed in the hope that it will be useful, +# but WITHOUT ANY WARRANTY; without even the implied warranty of +# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +# GNU General Public License for more details. +# +# You should have received a copy of the GNU General Public License +# along with this program; if not, write to the Free Software +# Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301, USA. + +# Function to create a stansurv object (fitted model object) +# +# @param object A list returned by a call to stan_surv +# @return A stansurv object +# +stansurv <- function(object) { + + alg <- object$algorithm + opt <- alg == "optimizing" + mcmc <- alg == "sampling" + stanfit <- object$stanfit + basehaz <- object$basehaz + K <- NCOL(object$x) + + if (opt) + stop2("Optimisation not implemented for 'stansurv' objects.") + + stan_summary <- make_stan_summary(stanfit) + + # number of parameters + nvars <- + has_intercept(basehaz) + + ncol(object$x) + + ncol(object$s_cpts) + + basehaz$nvars + + nms_beta <- colnames(object$x_cpts) + nms_tve <- get_smooth_name(object$s_cpts, type = "smooth_coefs") + nms_smooth <- get_smooth_name(object$s_cpts, type = "smooth_sd") + nms_int <- get_int_name_basehaz(object$basehaz) + nms_aux <- get_aux_name_basehaz(object$basehaz) + nms_b <- get_b_names(object$group) + nms_vc <- get_varcov_names(object$group) + nms_coefs <- c(nms_int, + nms_beta, + nms_tve, + nms_aux, + nms_b) + + # obtain medians + coefs <- stan_summary[nms_coefs, select_median(alg)] + names(coefs) <- nms_coefs # ensure parameter names are retained + + # obtain standard errors and covariance matrix + stanmat <- as.matrix(stanfit)[, nms_coefs, drop = FALSE] + colnames(stanmat) <- nms_coefs + ses <- apply(stanmat, 2L, mad) + covmat <- cov(stanmat) + + # for mcmc only + if (mcmc) { + check_rhats(stan_summary[, "Rhat"]) # check rhats for all parameters + runtime <- get_runtime(object$stanfit) # run time (in mins) + } + + # return object of class 'stansurv' + out <- nlist( + coefficients = coefs, + ses = ses, + covmat = covmat, + formula = object$formula, + has_tve = object$has_tve, + has_quadrature= object$has_quadrature, + has_bars = object$has_bars, + terms = object$terms, + data = object$data, + model_frame = object$model_frame, + xlevs = object$xlevels, + x = object$x, + x_cpts = object$x_cpts, + s_cpts = object$s_cpts, + z_cpts = object$z_cpts, + cnms = object$cnms, + flist = object$flist, + entrytime = object$t_beg, + eventtime = object$t_end, + event = object$event, + delayed = object$delayed, + basehaz = object$basehaz, + nobs = object$nobs, + nevents = object$nevents, + nlcens = object$nlcens, + nrcens = object$nrcens, + nicens = object$nicens, + ncensor = object$ncensor, + ndelayed = object$ndelayed, + qnodes = object$qnodes, + prior.info = object$prior_info, + algorithm = object$algorithm, + stan_function = object$stan_function, + call = object$call, + runtime = if (mcmc) runtime else NULL, + rstan_version = utils::packageVersion("rstan"), + rstanarm_version = utils::packageVersion("rstanarm"), + stan_summary, + stanfit + ) + out <- rm_null(out, recursive = FALSE) + + structure(out, class = c("stansurv", "stanreg")) +} + + +#---------- internal + +# Return the model fitting time in seconds +# +# @param stanfit An object of class 'stanfit'. +# @return A matrix of runtimes, stratified by warmup/sampling and chain/overall. +get_runtime <- function(stanfit) { + tt <- rstan::get_elapsed_time(stanfit) + tt <- round(tt, digits = 0L) # time per chain + tt <- cbind(tt, total = rowSums(tt)) # time per chain & overall +} diff --git a/data/bcancer.rda b/data/bcancer.rda new file mode 100644 index 000000000..75cb67b94 Binary files /dev/null and b/data/bcancer.rda differ diff --git a/data/frail.rda b/data/frail.rda new file mode 100644 index 000000000..3563ff9bc Binary files /dev/null and b/data/frail.rda differ diff --git a/data/mice.rda b/data/mice.rda new file mode 100644 index 000000000..6da6d4ef6 Binary files /dev/null and b/data/mice.rda differ diff --git a/man-roxygen/args-stansurv-stanjm-object.R b/man-roxygen/args-stansurv-stanjm-object.R new file mode 100644 index 000000000..8ef6c5c85 --- /dev/null +++ b/man-roxygen/args-stansurv-stanjm-object.R @@ -0,0 +1,3 @@ +#' @param <%= stanregArg %> A fitted model object returned by the +#' \code{\link{stan_surv}} or \code{\link{stan_jm}} modelling function. +#' See \code{\link{stanreg-objects}}. diff --git a/src/stan_files/functions/hazard_functions.stan b/src/stan_files/functions/hazard_functions.stan new file mode 100644 index 000000000..fa29e91b5 --- /dev/null +++ b/src/stan_files/functions/hazard_functions.stan @@ -0,0 +1,138 @@ + /** + * Log hazard for exponential distribution + * + * @param eta Vector, linear predictor + * @return A vector + */ + vector exponential_log_haz(vector eta) { + return eta; + } + + /** + * Log hazard for exponential distribution; AFT parameterisation + * + * @param af Vector, acceleration factor at time t + * @return A vector + */ + vector exponentialAFT_log_haz(vector af) { + return log(af); + } + + /** + * Log hazard for Weibull distribution + * + * @param eta Vector, linear predictor + * @param t Vector, event or censoring times + * @param shape Real, Weibull shape + * @return A vector + */ + vector weibull_log_haz(vector eta, vector t, real shape) { + return log(shape) + (shape - 1) * log(t) + eta; + } + + /** + * Log hazard for Weibull distribution; AFT parameterisation + * + * @param af Vector, acceleration factor at time t + * @param caf Vector, cumulative acceleration factor at time t + * @param shape Real, Weibull shape + * @return A vector + */ + vector weibullAFT_log_haz(vector af, vector caf, real shape) { + return log(shape) + (shape - 1) * log(caf) + log(af); + } + + /** + * Log hazard for Gompertz distribution + * + * @param eta Vector, linear predictor + * @param t Vector, event or censoring times + * @param scale Real, Gompertz scale + * @return A vector + */ + vector gompertz_log_haz(vector eta, vector t, real scale) { + return scale * t + eta; + } + + /** + * Log hazard for M-spline model + * + * @param eta Vector, linear predictor + * @param t Vector, event or censoring times + * @param coefs Vector, M-spline coefficients + * @return A vector + */ + vector mspline_log_haz(vector eta, matrix basis, vector coefs) { + return log(basis * coefs) + eta; + } + + /** + * Log hazard for B-spline model + * + * @param eta Vector, linear predictor + * @param t Vector, event or censoring times + * @param coefs Vector, B-spline coefficients + * @return A vector + */ + vector bspline_log_haz(vector eta, matrix basis, vector coefs) { + return basis * coefs + eta; + } + + /** + * Evaluate log survival or log CDF from the log hazard evaluated at + * quadrature points and a corresponding vector of quadrature weights + * + * @param qwts Vector, the quadrature weights + * @param log_hazard Vector, log hazard at the quadrature points + * @param qnodes Integer, the number of quadrature points for each individual + * @param N Integer, the number of individuals (ie. rows(log_hazard) / qnodes) + * @return A vector + */ + real quadrature_log_surv(vector qwts, vector log_hazard) { + return - dot_product(qwts, exp(log_hazard)); // sum across all individuals + } + + vector quadrature_log_cdf1(vector qwts, vector log_hazard, int qnodes, int N) { + int M = rows(log_hazard); + vector[M] hazard = exp(log_hazard); + matrix[N,qnodes] qwts_mat = to_matrix(qwts, N, qnodes); + matrix[N,qnodes] haz_mat = to_matrix(hazard, N, qnodes); + vector[N] chaz = rows_dot_product(qwts_mat, haz_mat); + return log(1 - exp(- chaz)); + } + + vector quadrature_log_cdf2(vector qwts_lower, vector log_hazard_lower, + vector qwts_upper, vector log_hazard_upper, + int qnodes, int N) { + int M = rows(log_hazard_lower); + vector[M] hazard_lower = exp(log_hazard_lower); + vector[M] hazard_upper = exp(log_hazard_upper); + matrix[N,qnodes] qwts_lower_mat = to_matrix(qwts_lower, N, qnodes); + matrix[N,qnodes] qwts_upper_mat = to_matrix(qwts_upper, N, qnodes); + matrix[N,qnodes] haz_lower_mat = to_matrix(hazard_lower, N, qnodes); + matrix[N,qnodes] haz_upper_mat = to_matrix(hazard_upper, N, qnodes); + vector[N] chaz_lower = rows_dot_product(qwts_lower_mat, haz_lower_mat); + vector[N] chaz_upper = rows_dot_product(qwts_upper_mat, haz_upper_mat); + vector[N] surv_lower = exp(- chaz_lower); + vector[N] surv_upper = exp(- chaz_upper); + return log(surv_lower - surv_upper); + } + + + /** + * Evaluate cumulative acceleration factor from the linear predictor evaluated + * at quadrature points and a corresponding vector of quadrature weights + * + * @param qwts Vector, the quadrature weights + * @param eta Vector, linear predictor at the quadrature points + * @param qnodes Integer, the number of quadrature points for each individual + * @param N Integer, the number of individuals (ie. rows(eta) / qnodes) + * @return A vector + */ + vector quadrature_aft(vector qwts, vector eta, int qnodes, int N) { + int M = rows(eta); + vector[M] af = exp(-eta); // time-varying acceleration factor + matrix[N,qnodes] qwts_mat = to_matrix(qwts, N, qnodes); + matrix[N,qnodes] af_mat = to_matrix(af, N, qnodes); + return rows_dot_product(qwts_mat, af_mat); // cumulative acceleration factor + } diff --git a/src/stan_files/functions/jm_functions.stan b/src/stan_files/functions/jm_functions.stan index d10f1ff6a..d5afca303 100644 --- a/src/stan_files/functions/jm_functions.stan +++ b/src/stan_files/functions/jm_functions.stan @@ -28,7 +28,7 @@ * @param dist Integer specifying the type of prior distribution * @param scale Real specifying the scale for the prior distribution * @param df Real specifying the df for the prior distribution - * @return nothing + * @return lp__ */ real basehaz_lpdf(vector aux_unscaled, int dist, vector scale, vector df) { real lp = 0; diff --git a/src/stan_files/functions/mvmer_functions.stan b/src/stan_files/functions/mvmer_functions.stan index 4b51f532b..e2ead710c 100644 --- a/src/stan_files/functions/mvmer_functions.stan +++ b/src/stan_files/functions/mvmer_functions.stan @@ -232,13 +232,13 @@ } /** - * Increment the target with the log-likelihood for the glmer submodel + * Evaluate the log-likelihood for the glmer submodel * * @param z_beta A vector of primitive parameters * @param prior_dist Integer, the type of prior distribution * @param prior_mean,prior_scale Vectors of mean and scale parameters * for the prior distributions - * @return A vector containing the population level parameters (coefficients) + * @return lp__ */ real glm_lpdf(vector y_real, array[] int y_integer, vector eta, array[] real aux, int family, int link, real sum_log_y, vector sqrt_y, vector log_y) { @@ -286,7 +286,7 @@ * @param global Real, the global parameter * @param mix Vector of shrinkage parameters * @param one_over_lambda Real - * @return nothing + * @return lp__ */ real beta_custom_lpdf(vector z_beta, int prior_dist, vector prior_scale, vector prior_df, real global_prior_df, array[] vector local, @@ -338,7 +338,7 @@ * @param mean_ Real, mean of prior distribution * @param scale Real, scale for the prior distribution * @param df Real, df for the prior distribution - * @return nothing + * @return lp__ */ real gamma_custom_lpdf(real gamma, int dist, real mean_, real scale, real df) { real lp = 0; @@ -359,7 +359,7 @@ * @param dist Integer specifying the type of prior distribution * @param scale Real specifying the scale for the prior distribution * @param df Real specifying the df for the prior distribution - * @return nothing + * @return lp__ */ real aux_lpdf(real aux_unscaled, int dist, real scale, real df) { real lp = 0; diff --git a/src/stan_files/mvmer.stan b/src/stan_files/mvmer.stan index 12893f688..cd36961c9 100644 --- a/src/stan_files/mvmer.stan +++ b/src/stan_files/mvmer.stan @@ -40,7 +40,7 @@ transformed data { parameters { // declares: yGamma{1,2,3}, z_yBeta{1,2,3}, z_b, z_T, rho, // zeta, tau, bSd{1,2}, z_bMat{1,2}, bCholesky{1,2}, - // yAux{1,2,3}_unscaled, yGlobal{1,2,3}, yLocal{1,2,3}, + // yAux{1,2,3}_unscaled, yGlobal{1,2,3}, yLocal{1,2,3}, // yOol{1,2,3}, yMix{1,2,3} #include /parameters/parameters_mvmer.stan } diff --git a/src/stan_files/surv.stan b/src/stan_files/surv.stan new file mode 100644 index 000000000..03492c9bf --- /dev/null +++ b/src/stan_files/surv.stan @@ -0,0 +1,1200 @@ +#include /pre/Columbia_copyright.stan +#include /pre/Brilleman_copyright.stan +#include /pre/license.stan + +functions { + + #include /functions/common_functions.stan + #include /functions/hazard_functions.stan + + /** + * Return the lower bound for the baseline hazard parameters + * + * @param type An integer indicating the type of baseline hazard + * @return A real + */ + real coefs_lb(int type) { + real lb; + if (type == 2) // B-splines, on log haz scale + lb = negative_infinity(); + else if (type == 3) // piecewise constant, on log haz scale + lb = negative_infinity(); + else + lb = 0; + return lb; + } + + /** + * Return the required number of local hs parameters + * + * @param prior_dist An integer indicating the prior distribution + * @return An integer + */ + int get_nvars_for_hs(int prior_dist) { + int hs = 0; + if (prior_dist == 3) hs = 2; + else if (prior_dist == 4) hs = 4; + return hs; + } + + /** + * Scale the primitive population level parameters based on prior information + * + * @param z_beta A vector of primitive parameters + * @param prior_dist Integer, the type of prior distribution + * @param prior_mean,prior_scale Vectors of mean and scale parameters + * for the prior distributions + * @return A vector containing the population level parameters (coefficients) + */ + vector make_beta(vector z_beta, int prior_dist, vector prior_mean, + vector prior_scale, vector prior_df, real global_prior_scale, + array[] real global, array[] vector local, array[] real ool, array[] vector mix, + array[] real aux, int family, real slab_scale, array[] real caux) { + vector[rows(z_beta)] beta; + if (prior_dist == 0) beta = z_beta; + else if (prior_dist == 1) beta = z_beta .* prior_scale + prior_mean; + else if (prior_dist == 2) for (k in 1:rows(prior_mean)) { + beta[k] = CFt(z_beta[k], prior_df[k]) * prior_scale[k] + prior_mean[k]; + } + else if (prior_dist == 3) { + real c2 = square(slab_scale) * caux[1]; + if (family == 1) // don't need is_continuous since family == 1 is gaussian in mvmer + beta = hs_prior(z_beta, global, local, global_prior_scale, aux[1], c2); + else + beta = hs_prior(z_beta, global, local, global_prior_scale, 1, c2); + } + else if (prior_dist == 4) { + real c2 = square(slab_scale) * caux[1]; + if (family == 1) // don't need is_continuous since family == 1 is gaussian in mvmer + beta = hsplus_prior(z_beta, global, local, global_prior_scale, aux[1], c2); + else + beta = hsplus_prior(z_beta, global, local, global_prior_scale, 1, c2); + } + else if (prior_dist == 5) // laplace + beta = prior_mean + prior_scale .* sqrt(2 * mix[1]) .* z_beta; + else if (prior_dist == 6) // lasso + beta = prior_mean + ool[1] * prior_scale .* sqrt(2 * mix[1]) .* z_beta; + return beta; + } + + /** + * Log-prior for coefficients + * + * @param z_beta Vector of primative coefficients + * @param prior_dist Integer, the type of prior distribution + * @param prior_scale Real, scale for the prior distribution + * @param prior_df Real, df for the prior distribution + * @param global_prior_df Real, df for the prior for the global hs parameter + * @param local Vector of hs local parameters + * @param global Real, the global parameter + * @param mix Vector of shrinkage parameters + * @param one_over_lambda Real + * @return Real, the log probability. + */ + real beta_custom_lpdf(vector z_beta, int prior_dist, vector prior_scale, + vector prior_df, real global_prior_df, array[] vector local, + array[] real global, array[] vector mix, array[] real one_over_lambda, + real slab_df, array[] real caux) { + real lp = 0; + if (prior_dist == 1) lp += normal_lpdf(z_beta | 0, 1); + else if (prior_dist == 2) lp += normal_lpdf(z_beta | 0, 1); // Student t + else if (prior_dist == 3) { // hs + lp += normal_lpdf(z_beta | 0, 1); + lp += normal_lpdf(local[1] | 0, 1); + lp += inv_gamma_lpdf(local[2] | 0.5 * prior_df, 0.5 * prior_df); + lp += normal_lpdf(global[1] | 0, 1); + lp += inv_gamma_lpdf(global[2] | 0.5 * global_prior_df, 0.5 * global_prior_df); + lp += inv_gamma_lpdf(caux | 0.5 * slab_df, 0.5 * slab_df); + } + else if (prior_dist == 4) { // hs+ + lp += normal_lpdf(z_beta | 0, 1); + lp += normal_lpdf(local[1] | 0, 1); + lp += inv_gamma_lpdf(local[2] | 0.5 * prior_df, 0.5 * prior_df); + lp += normal_lpdf(local[3] | 0, 1); + // unorthodox useage of prior_scale as another df hyperparameter + lp += inv_gamma_lpdf(local[4] | 0.5 * prior_scale, 0.5 * prior_scale); + lp += normal_lpdf(global[1] | 0, 1); + lp += inv_gamma_lpdf(global[2] | 0.5 * global_prior_df, 0.5 * global_prior_df); + lp += inv_gamma_lpdf(caux | 0.5 * slab_df, 0.5 * slab_df); + } + else if (prior_dist == 5) { // laplace + lp += normal_lpdf(z_beta | 0, 1); + lp += exponential_lpdf(mix[1] | 1); + } + else if (prior_dist == 6) { // lasso + lp += normal_lpdf(z_beta | 0, 1); + lp += exponential_lpdf(mix[1] | 1); + lp += chi_square_lpdf(one_over_lambda[1] | prior_df[1]); + } + else if (prior_dist == 7) { // product_normal + lp += normal_lpdf(z_beta | 0, 1); + } + /* else prior_dist is 0 and nothing is added */ + return lp; + } + + /** + * Log-prior for intercept parameters + * + * @param gamma Real, the intercept parameter + * @param dist Integer, the type of prior distribution + * @param mean Real, mean of prior distribution + * @param scale Real, scale for the prior distribution + * @param df Real, df for the prior distribution + * @return Real, the log probability + */ + real gamma_custom_lpdf(real gamma, int dist, real mean, real scale, real df) { + real lp = 0; + if (dist == 1) // normal + lp += normal_lpdf(gamma | mean, scale); + else if (dist == 2) // student_t + lp += student_t_lpdf(gamma | df, mean, scale); + /* else dist is 0 and nothing is added */ + return lp; + } + + /** + * Log-prior for baseline hazard parameters + * + * @param aux_unscaled Vector (potentially of length 1) of unscaled + * auxiliary parameter(s) + * @param dist Integer specifying the type of prior distribution + * @param df Real specifying the df for the prior distribution, or in the case + * of the dirichlet distribution it is the concentration parameter(s) + * @return Real, the log probability + */ + real basehaz_lpdf(vector aux_unscaled, int dist, vector df) { + real lp = 0; + if (dist > 0) { + if (dist == 1) + lp += normal_lpdf(aux_unscaled | 0, 1); + else if (dist == 2) + lp += student_t_lpdf(aux_unscaled | df, 0, 1); + else if (dist == 3) + lp += exponential_lpdf(aux_unscaled | 1); + else + lp += dirichlet_lpdf(aux_unscaled | df); // df is concentration here + } + return lp; + } + + /** + * Log-prior for tve spline coefficients and their smoothing parameters + * + * @param z_beta_tve Vector of unscaled spline coefficients + * @param smooth_sd_raw Vector (potentially of length 1) of smoothing sds + * @param dist Integer specifying the type of prior distribution for the + * smoothing sds + * @param df Vector of reals specifying the df for the prior distribution + * for the smoothing sds + * @return Real, the log probability + */ + real smooth_lpdf(vector z_beta_tve, vector smooth_sd_raw, int dist, vector df) { + real lp = 0; + lp += normal_lpdf(z_beta_tve | 0, 1); + if (dist > 0) { + real log_half = -0.693147180559945286; + if (dist == 1) + lp += normal_lpdf(smooth_sd_raw | 0, 1) - log_half; + else if (dist == 2) + lp += student_t_lpdf(smooth_sd_raw | df, 0, 1) - log_half; + else if (dist == 3) + lp += exponential_lpdf(smooth_sd_raw | 1); + } + return lp; + } + + /** + * Raise each element of x to the power of y + * + * @param x Vector + * @param y Real, the power to raise to + * @return vector + */ + vector pow_vec(vector x, real y) { + int N = rows(x); + vector[N] res; + for (n in 1:N) + res[n] = pow(x[n], y); + return res; + } + + /** + * Log survival and log CDF for exponential distribution + * + * @param eta Vector, linear predictor + * @param t Vector, event or censoring times + * @return A vector + */ + vector exponential_log_surv(vector eta, vector t) { + vector[rows(eta)] res; + res = - t .* exp(eta); + return res; + } + + vector exponential_log_cdf1(vector eta, vector t) { + vector[rows(eta)] res; + res = log(1 - exp(-t .* exp(eta))); + return res; + } + + vector exponential_log_cdf2(vector eta, vector t_lower, vector t_upper) { + int N = rows(eta); + vector[N] exp_eta = exp(eta); + vector[N] surv_lower = exp(-t_lower .* exp_eta); + vector[N] surv_upper = exp(-t_upper .* exp_eta); + vector[N] res; + res = log(surv_lower - surv_upper); + return res; + } + + /** + * Log survival and log CDF for exponential distribution; AFT parameterisation + * + * @param caf Vector, cumulative acceleration factor + * @return A vector + */ + vector exponentialAFT_log_surv(vector caf) { + vector[rows(caf)] res; + res = - caf; + return res; + } + + vector exponentialAFT_log_cdf1(vector caf) { + vector[rows(caf)] res; + res = log(1 - exp(-caf)); + return res; + } + + vector exponentialAFT_log_cdf2(vector caf_lower, vector caf_upper) { + int N = rows(caf_lower); + vector[N] surv_lower = exp(-caf_lower); + vector[N] surv_upper = exp(-caf_upper); + vector[N] res; + res = log(surv_lower - surv_upper); + return res; + } + + /** + * Log survival and log CDF for Weibull distribution + * + * @param eta Vector, linear predictor + * @param t Vector, event or censoring times + * @param shape Real, Weibull shape + * @return A vector + */ + vector weibull_log_surv(vector eta, vector t, real shape) { + vector[rows(eta)] res; + res = - pow_vec(t, shape) .* exp(eta); + return res; + } + + vector weibull_log_cdf1(vector eta, vector t, real shape) { + vector[rows(eta)] res; + res = log(1 - exp(- pow_vec(t, shape) .* exp(eta))); + return res; + } + + vector weibull_log_cdf2(vector eta, vector t_lower, vector t_upper, real shape) { + int N = rows(eta); + vector[N] exp_eta = exp(eta); + vector[N] surv_lower = exp(- pow_vec(t_lower, shape) .* exp_eta); + vector[N] surv_upper = exp(- pow_vec(t_upper, shape) .* exp_eta); + vector[N] res; + res = log(surv_lower - surv_upper); + return res; + } + + /** + * Log survival and log CDF for Weibull distribution; AFT parameterisation + * + * @param caf Vector, cumulative acceleration factor + * @param shape Real, Weibull shape + * @return A vector + */ + vector weibullAFT_log_surv(vector caf, real shape) { + vector[rows(caf)] res; + res = - pow_vec(caf, shape); + return res; + } + + vector weibullAFT_log_cdf1(vector caf, real shape) { + vector[rows(caf)] res; + res = log(1 - exp(- pow_vec(caf, shape))); + return res; + } + + vector weibullAFT_log_cdf2(vector caf_lower, vector caf_upper, real shape) { + int N = rows(caf_lower); + vector[N] surv_lower = exp(- pow_vec(caf_lower, shape)); + vector[N] surv_upper = exp(- pow_vec(caf_upper, shape)); + vector[N] res; + res = log(surv_lower - surv_upper); + return res; + } + + /** + * Log survival and log CDF for Gompertz distribution + * + * @param eta Vector, linear predictor + * @param t Vector, event or censoring times + * @param scale Real, Gompertz scale + * @return A vector + */ + vector gompertz_log_surv(vector eta, vector t, real scale) { + vector[rows(eta)] res; + res = inv(scale) * -(exp(scale * t) - 1) .* exp(eta); + return res; + } + + vector gompertz_log_cdf1(vector eta, vector t, real scale) { + vector[rows(eta)] res; + res = log(1 - exp(inv(scale) * -(exp(scale * t) - 1) .* exp(eta))); + return res; + } + + vector gompertz_log_cdf2(vector eta, vector t_lower, vector t_upper, real scale) { + int N = rows(eta); + real inv_scale = inv(scale); + vector[N] exp_eta = exp(eta); + vector[N] surv_lower = exp(inv_scale * -(exp(scale * t_lower) - 1) .* exp_eta); + vector[N] surv_upper = exp(inv_scale * -(exp(scale * t_upper) - 1) .* exp_eta); + vector[N] res; + res = log(surv_lower - surv_upper); + return res; + } + + /** + * Log survival and log CDF for M-spline model + * + * @param eta Vector, linear predictor + * @param t Vector, event or censoring times + * @param coefs Vector, M-spline coefficients + * @return A vector + */ + vector mspline_log_surv(vector eta, matrix ibasis, vector coefs) { + vector[rows(eta)] res; + res = - (ibasis * coefs) .* exp(eta); + return res; + } + + vector mspline_log_cdf1(vector eta, matrix ibasis, vector coefs) { + vector[rows(eta)] res; + res = log(1 - exp(-(ibasis * coefs) .* exp(eta))); + return res; + } + + vector mspline_log_cdf2(vector eta, matrix ibasis_lower, matrix ibasis_upper, vector coefs) { + int N = rows(eta); + vector[N] exp_eta = exp(eta); + vector[N] surv_lower = exp(-(ibasis_lower * coefs) .* exp_eta); + vector[N] surv_upper = exp(-(ibasis_upper * coefs) .* exp_eta); + vector[N] res; + res = log(surv_lower - surv_upper); + return res; + } + +} + +data { + + // dimensions + int K; // num. cols in predictor matrix (time-fixed) + int S; // num. cols in predictor matrix (time-varying) + int nevent; // num. rows w/ an event (ie. not censored) + int nlcens; // num. rows w/ left censoring + int nrcens; // num. rows w/ right censoring + int nicens; // num. rows w/ interval censoring + int ndelay; // num. rows w/ delayed entry + int Nevent; // num. rows w/ an event; used only w/ quadrature + int Nlcens; // num. rows w/ left cens; used only w/ quadrature + int Nrcens; // num. rows w/ right cens; used only w/ quadrature + int Nicens; // num. rows w/ interval cens; used only w/ quadrature + int Ndelay; // num. rows w/ delayed entry; used only w/ quadrature + int qnodes; // num. nodes for GK quadrature + int qevent; // num. quadrature points for rows w/ an event + int qlcens; // num. quadrature points for rows w/ left censoring + int qrcens; // num. quadrature points for rows w/ right censoring + int qicens; // num. quadrature points for rows w/ interval censoring + int qdelay; // num. quadrature points for rows w/ delayed entry + int nvars; // num. aux parameters for baseline hazard + array[S] int smooth_map; // indexing of smooth sds for tve spline coefs + array[S > 0 ? max(smooth_map) : 0, 2] int smooth_idx; + + // dimensions for random efffects structure, see table 3 of + // https://cran.r-project.org/web/packages/lme4/vignettes/lmer.pdf + int t; // num. terms (maybe 0) with a | in the glmer formula + array[t] int p; // num. variables on the LHS of each | + array[t] int l; // num. levels for the factor(s) on the RHS of each | + int q; // conceptually equals \sum_{i=1}^t p_i \times l_i + + // log crude event rate / time (for centering linear predictor) + real log_crude_event_rate; + + // response and time variables + vector[nevent] t_event; // time of events + vector[nlcens] t_lcens; // time of left censoring + vector[nrcens] t_rcens; // time of right censoring + vector[nicens] t_icenl; // time of lower limit for interval censoring + vector[nicens] t_icenu; // time of upper limit for interval censoring + vector[ndelay] t_delay; // time of entry for delayed entry + + vector[Nevent] epts_event; // time of events + vector[qevent] qpts_event; // qpts for time of events + vector[qlcens] qpts_lcens; // qpts for time of left censoring + vector[qrcens] qpts_rcens; // qpts for time of right censoring + vector[qicens] qpts_icenl; // qpts for time of lower limit for interval censoring + vector[qicens] qpts_icenu; // qpts for time of upper limit for interval censoring + vector[qdelay] qpts_delay; // qpts for time of entry for delayed entry + + // predictor matrices (time-fixed), without quadrature + vector[K] x_bar; // predictor means + matrix[nevent,K] x_event; // for rows with events + matrix[nlcens,K] x_lcens; // for rows with left censoring + matrix[nrcens,K] x_rcens; // for rows with right censoring + matrix[nicens,K] x_icens; // for rows with interval censoring + matrix[ndelay,K] x_delay; // for rows with delayed entry + + // predictor matrices (time-fixed), with quadrature + matrix[Nevent,K] x_epts_event; // for rows with events + matrix[qevent,K] x_qpts_event; // for rows with events + matrix[qlcens,K] x_qpts_lcens; // for rows with left censoring + matrix[qrcens,K] x_qpts_rcens; // for rows with right censoring + matrix[qicens,K] x_qpts_icens; // for rows with interval censoring + matrix[qdelay,K] x_qpts_delay; // for rows with delayed entry + + // predictor matrices (time-varying) + matrix[Nevent,S] s_epts_event; // for rows with events + matrix[qevent,S] s_qpts_event; // for rows with events + matrix[qlcens,S] s_qpts_lcens; // for rows with left censoring + matrix[qrcens,S] s_qpts_rcens; // for rows with right censoring + matrix[qicens,S] s_qpts_icenl; // for rows with interval censoring + matrix[qicens,S] s_qpts_icenu; // for rows with interval censoring + matrix[qdelay,S] s_qpts_delay; // for rows with delayed entry + + // random effects structure, without quadrature + // nnz: number of non-zero elements in the Z matrix + // w: non-zero elements in the implicit Z matrix + // v: column indices for w + // u: where the non-zeros start in each row + int nnz_event; + int nnz_lcens; + int nnz_rcens; + int nnz_icens; + int nnz_delay; + + vector[nnz_event] w_event; + vector[nnz_lcens] w_lcens; + vector[nnz_rcens] w_rcens; + vector[nnz_icens] w_icens; + vector[nnz_delay] w_delay; + + array[nnz_event] int v_event; + array[nnz_lcens] int v_lcens; + array[nnz_rcens] int v_rcens; + array[nnz_icens] int v_icens; + array[nnz_delay] int v_delay; + + array[(t > 0 && nevent > 0) ? nevent + 1 : 0] int u_event; + array[(t > 0 && nlcens > 0) ? nlcens + 1 : 0] int u_lcens; + array[(t > 0 && nrcens > 0) ? nrcens + 1 : 0] int u_rcens; + array[(t > 0 && nicens > 0) ? nicens + 1 : 0] int u_icens; + array[(t > 0 && ndelay > 0) ? ndelay + 1 : 0] int u_delay; + + // random effects structure, with quadrature + // nnz: number of non-zero elements in the Z matrix + // w: non-zero elements in the implicit Z matrix + // v: column indices for w + // u: where the non-zeros start in each row + int nnz_epts_event; + int nnz_qpts_event; + int nnz_qpts_lcens; + int nnz_qpts_rcens; + int nnz_qpts_icens; + int nnz_qpts_delay; + + vector[nnz_epts_event] w_epts_event; + vector[nnz_qpts_event] w_qpts_event; + vector[nnz_qpts_lcens] w_qpts_lcens; + vector[nnz_qpts_rcens] w_qpts_rcens; + vector[nnz_qpts_icens] w_qpts_icens; + vector[nnz_qpts_delay] w_qpts_delay; + + array[nnz_epts_event] int v_epts_event; + array[nnz_qpts_event] int v_qpts_event; + array[nnz_qpts_lcens] int v_qpts_lcens; + array[nnz_qpts_rcens] int v_qpts_rcens; + array[nnz_qpts_icens] int v_qpts_icens; + array[nnz_qpts_delay] int v_qpts_delay; + + array[(t > 0 && Nevent > 0) ? Nevent + 1 : 0] int u_epts_event; + array[(t > 0 && qevent > 0) ? qevent + 1 : 0] int u_qpts_event; + array[(t > 0 && qlcens > 0) ? qlcens + 1 : 0] int u_qpts_lcens; + array[(t > 0 && qrcens > 0) ? qrcens + 1 : 0] int u_qpts_rcens; + array[(t > 0 && qicens > 0) ? qicens + 1 : 0] int u_qpts_icens; + array[(t > 0 && qdelay > 0) ? qdelay + 1 : 0] int u_qpts_delay; + + // basis matrices for M-splines / I-splines, without quadrature + matrix[nevent,nvars] basis_event; // at event time + matrix[nevent,nvars] ibasis_event; // at event time + matrix[nlcens,nvars] ibasis_lcens; // at left censoring time + matrix[nrcens,nvars] ibasis_rcens; // at right censoring time + matrix[nicens,nvars] ibasis_icenl; // at lower limit of interval censoring + matrix[nicens,nvars] ibasis_icenu; // at upper limit of interval censoring + matrix[ndelay,nvars] ibasis_delay; // at delayed entry time + + // basis matrices for M-splines, with quadrature + matrix[Nevent,nvars] basis_epts_event; // at event time + matrix[qevent,nvars] basis_qpts_event; // at qpts for event time + matrix[qlcens,nvars] basis_qpts_lcens; // at qpts for left censoring time + matrix[qrcens,nvars] basis_qpts_rcens; // at qpts for right censoring time + matrix[qicens,nvars] basis_qpts_icenl; // at qpts for lower limit of icens time + matrix[qicens,nvars] basis_qpts_icenu; // at qpts for upper limit of icens time + matrix[qdelay,nvars] basis_qpts_delay; // at qpts for delayed entry time + + // baseline hazard type: + // 1 = weibull + // 2 = B-splines + // 3 = piecewise + // 4 = M-splines + // 5 = exponential + // 6 = gompertz + // 7 = exponential AFT + // 8 = weibull AFT + int type; + + // GK quadrature weights, with (b-a)/2 scaling already incorporated + vector[qevent] qwts_event; + vector[qlcens] qwts_lcens; + vector[qrcens] qwts_rcens; + vector[qicens] qwts_icenl; + vector[qicens] qwts_icenu; + vector[qdelay] qwts_delay; + + // flags + int has_quadrature;// log surv is calculated using quadrature + int has_intercept; // basehaz requires intercept + int prior_PD; // draw only from prior predictive dist. + + // prior family: + // 0 = none + // 1 = normal + // 2 = student_t + // 3 = hs + // 4 = hs_plus + // 5 = laplace + // 6 = lasso + int prior_dist; + + // prior family: + // 0 = none + // 1 = normal + // 2 = student_t + int prior_dist_for_intercept; + + // prior family: + // 0 = none + // 1 = normal + // 2 = student_t + // 3 = exponential + // 4 = dirichlet + int prior_dist_for_aux; + + // prior family: + // 0 = none + // 1 = normal + // 2 = student_t + // 3 = exponential + int prior_dist_for_smooth; + + // hyperparameters (log hazard ratios), set to 0 if there is no prior + vector[K] prior_mean; + vector[K] prior_scale; + vector[K] prior_df; + real global_prior_scale; // for hs priors only + real global_prior_df; + real slab_scale; + real slab_df; + + // hyperparameters (intercept), set to 0 if there is no prior + real prior_mean_for_intercept; + real prior_scale_for_intercept; + real prior_df_for_intercept; + + // hyperparameters (basehaz pars), set to 0 if there is no prior + vector[nvars] prior_scale_for_aux; + vector[nvars] prior_df_for_aux; + vector[nvars] prior_conc_for_aux; // dirichlet concentration pars + + // hyperparameters (tve smooths), set to 0 if there is no prior + vector [S > 0 ? max(smooth_map) : 0] prior_mean_for_smooth; + vector[S > 0 ? max(smooth_map) : 0] prior_scale_for_smooth; + vector[S > 0 ? max(smooth_map) : 0] prior_df_for_smooth; + + // hyperparameters (random effects structure), set to 0 if there is no prior + vector[t] b_prior_shape; + vector[t] b_prior_scale; + int len_theta_L; // length of the theta_L vector + int len_concentration; + int len_regularization; + array[len_concentration] real concentration; + array[len_regularization] real regularization; + int special_case; // is the only term (1|group) + +} + +transformed data { + + int hs = get_nvars_for_hs(prior_dist); + + int sc = special_case; + + array[sc ? t : 0, nevent] int V_event = make_V(nevent, sc ? t : 0, v_event); + array[sc ? t : 0, nlcens] int V_lcens = make_V(nlcens, sc ? t : 0, v_lcens); + array[sc ? t : 0, nrcens] int V_rcens = make_V(nrcens, sc ? t : 0, v_rcens); + array[sc ? t : 0, nicens] int V_icens = make_V(nicens, sc ? t : 0, v_icens); + array[sc ? t : 0, ndelay] int V_delay = make_V(ndelay, sc ? t : 0, v_delay); + + array[sc ? t : 0, Nevent] int V_epts_event = make_V(Nevent, sc ? t : 0, v_epts_event); + array[sc ? t : 0, qevent] int V_qpts_event = make_V(qevent, sc ? t : 0, v_qpts_event); + array[sc ? t : 0, qlcens] int V_qpts_lcens = make_V(qlcens, sc ? t : 0, v_qpts_lcens); + array[sc ? t : 0, qrcens] int V_qpts_rcens = make_V(qrcens, sc ? t : 0, v_qpts_rcens); + array[sc ? t : 0, qicens] int V_qpts_icens = make_V(qicens, sc ? t : 0, v_qpts_icens); + array[sc ? t : 0, qdelay] int V_qpts_delay = make_V(qdelay, sc ? t : 0, v_qpts_delay); + + int pos = 1; + int len_z_T = 0; + int len_rho = sum(p) - t; + array[len_concentration] real delta; + + for (i in 1:t) { + if (p[i] > 1) { + for (j in 1:p[i]) { + delta[pos] = concentration[j]; + pos += 1; + } + } + for (j in 3:p[i]) len_z_T += p[i] - 1; + } + + + +} + +parameters { + + // primitive log hazard ratios + vector[K] z_beta; + + // intercept + array[has_intercept == 1] real gamma; + + // unscaled basehaz parameters + // exp model: nvars = 0, ie. no aux parameter + // weibull model: nvars = 1, ie. shape parameter + // gompertz model: nvars = 1, ie. scale parameter + // M-spline model: nvars = number of basis terms, ie. spline coefs + // B-spline model: nvars = number of basis terms, ie. spline coefs + vector[type == 4 ? 0 : nvars] z_coefs; + simplex[type == 4 ? nvars : 1] ms_coefs; // constrained coefs for M-splines + + // unscaled tve spline coefficients + vector[S] z_beta_tve; + + // hyperparameter, the prior sd for the tve spline coefs + vector[S > 0 ? max(smooth_map) : 0] smooth_sd_raw; + + // parameters for random effects + vector[q] z_b; + vector[len_z_T] z_T; + vector[len_rho] rho; + vector[len_concentration] zeta; + vector[t] tau; + + // parameters for priors + array[hs] real global; + array[hs] vector[hs > 0 ? K : 0] local; + array[hs > 0] real caux; + array[prior_dist == 5 || prior_dist == 6] vector[K] mix; + array[prior_dist == 6] real ool; +} + +transformed parameters { + + // declare log hazard ratios + vector[K] beta; + + // declare basehaz parameters + vector[type == 4 ? 0 : nvars] coefs; + + // declare tve spline coefficients and their hyperparameters + vector[S] beta_tve; + vector[S > 0 ? max(smooth_map) : 0] smooth_sd; // sd for tve splines + + // declare random effects and var-cov parameters + vector[q] b; + vector[len_theta_L] theta_L; + + // define log hazard ratios + if (K > 0) { + beta = make_beta(z_beta, + prior_dist, prior_mean, prior_scale, prior_df, + global_prior_scale, global, local, ool, mix, + rep_array(1.0, 0), 0, slab_scale, caux); + } + + // define basehaz parameters + if (type != 4 && nvars > 0) { + coefs = z_coefs .* prior_scale_for_aux; + } + + // define tve spline coefficients using random walk + if (S > 0) { + smooth_sd = smooth_sd_raw .* prior_scale_for_smooth + prior_mean_for_smooth; + for (i in 1:max(smooth_map)) { + int beg = smooth_idx[i,1]; // index of first spline coef + int end = smooth_idx[i,2]; // index of last spline coef + beta_tve[beg] = z_beta_tve[beg]; // define first spline coef + if (end > beg) { // define subsequent spline coefs + for (j in (beg+1):end) { + real tmp = beta_tve[j-1]; + beta_tve[j] = tmp + z_beta_tve[j] * smooth_sd[smooth_map[j]]; + } + } + } + } + + // define random effects and var-cov parameters + if (t > 0) { + if (special_case == 1) { + int start = 1; + theta_L = b_prior_scale .* tau * 1.0; + if (t == 1) { + b = theta_L[1] * z_b; + } + else for (i in 1:t) { + int end = start + l[i] - 1; + b[start:end] = theta_L[i] * z_b[start:end]; + start = end + 1; + } + } + else { + theta_L = make_theta_L(len_theta_L, p, 1.0, tau, b_prior_scale, zeta, rho, z_T); + b = make_b(z_b, theta_L, p, l); + } + } + +} + +model { + + if (prior_PD == 0) { + + //-------- models without quadrature + + if (has_quadrature == 0) { + + vector[nevent] eta_event; // for events + vector[nlcens] eta_lcens; // for left censored + vector[nrcens] eta_rcens; // for right censored + vector[nicens] eta_icens; // for interval censored + vector[ndelay] eta_delay; // for delayed entry + + // linear predictor + if (K > 0) { + if (nevent > 0) eta_event = x_event * beta; + if (nlcens > 0) eta_lcens = x_lcens * beta; + if (nrcens > 0) eta_rcens = x_rcens * beta; + if (nicens > 0) eta_icens = x_icens * beta; + if (ndelay > 0) eta_delay = x_delay * beta; + } + else { + if (nevent > 0) eta_event = rep_vector(0.0, nevent); + if (nlcens > 0) eta_lcens = rep_vector(0.0, nlcens); + if (nrcens > 0) eta_rcens = rep_vector(0.0, nrcens); + if (nicens > 0) eta_icens = rep_vector(0.0, nicens); + if (ndelay > 0) eta_delay = rep_vector(0.0, ndelay); + } + + // add on log crude event rate / time (helps to center intercept) + if (nevent > 0) eta_event += log_crude_event_rate; + if (nlcens > 0) eta_lcens += log_crude_event_rate; + if (nrcens > 0) eta_rcens += log_crude_event_rate; + if (nicens > 0) eta_icens += log_crude_event_rate; + if (ndelay > 0) eta_delay += log_crude_event_rate; + + // add on intercept to linear predictor + if (has_intercept == 1) { + if (nevent > 0) eta_event += gamma[1]; + if (nlcens > 0) eta_lcens += gamma[1]; + if (nrcens > 0) eta_rcens += gamma[1]; + if (nicens > 0) eta_icens += gamma[1]; + if (ndelay > 0) eta_delay += gamma[1]; + } + + // add on random effects terms to linear predictor + if (t > 0) { + if (special_case) for (i in 1:t) { + if (nevent > 0) eta_event += b[V_event[i]]; + if (nlcens > 0) eta_lcens += b[V_lcens[i]]; + if (nrcens > 0) eta_rcens += b[V_rcens[i]]; + if (nicens > 0) eta_icens += b[V_icens[i]]; + if (ndelay > 0) eta_delay += b[V_delay[i]]; + } + else { + if (nevent > 0) eta_event += + csr_matrix_times_vector(nevent, q, w_event, v_event, u_event, b); + if (nlcens > 0) eta_lcens += + csr_matrix_times_vector(nlcens, q, w_lcens, v_lcens, u_lcens, b); + if (nrcens > 0) eta_rcens += + csr_matrix_times_vector(nrcens, q, w_rcens, v_rcens, u_rcens, b); + if (nicens > 0) eta_icens += + csr_matrix_times_vector(nicens, q, w_icens, v_icens, u_icens, b); + if (ndelay > 0) eta_delay += + csr_matrix_times_vector(ndelay, q, w_delay, v_delay, u_delay, b); + } + } + + // aft models + if (type == 7 || type == 8) { + + // acceleration factor at event times + vector[nevent] af_event = exp(-eta_event); + + // cumulative acceleration factors + vector[nevent] caf_event = t_event .* exp(-eta_event); + vector[nlcens] caf_lcens = t_lcens .* exp(-eta_lcens); + vector[nrcens] caf_rcens = t_rcens .* exp(-eta_rcens); + vector[nicens] caf_icenl = t_icenl .* exp(-eta_icens); + vector[nicens] caf_icenu = t_icenu .* exp(-eta_icens); + vector[ndelay] caf_delay = t_delay .* exp(-eta_delay); + + // increment target with log-lik contributions + if (type == 7) { // exponential AFT model + if (nevent > 0) target += exponentialAFT_log_haz (af_event); + if (nevent > 0) target += exponentialAFT_log_surv(caf_event); + if (nlcens > 0) target += exponentialAFT_log_cdf1(caf_lcens); + if (nrcens > 0) target += exponentialAFT_log_surv(caf_rcens); + if (nicens > 0) target += exponentialAFT_log_cdf2(caf_icenl, caf_icenu); + if (ndelay > 0) target += -exponentialAFT_log_surv(caf_delay); + } else if (type == 8) { // weibull AFT model + real shape = coefs[1]; + if (nevent > 0) target += weibullAFT_log_haz (af_event, caf_event, shape); + if (nevent > 0) target += weibullAFT_log_surv(caf_event, shape); + if (nlcens > 0) target += weibullAFT_log_cdf1(caf_lcens, shape); + if (nrcens > 0) target += weibullAFT_log_surv(caf_rcens, shape); + if (nicens > 0) target += weibullAFT_log_cdf2(caf_icenl, caf_icenu, shape); + if (ndelay > 0) target += -weibullAFT_log_surv(caf_delay, shape); + } + + } + + // hazard models + else { + + // evaluate log hazard and log survival + if (type == 5) { // exponential model + if (nevent > 0) target += exponential_log_haz (eta_event); + if (nevent > 0) target += exponential_log_surv(eta_event, t_event); + if (nlcens > 0) target += exponential_log_cdf1(eta_lcens, t_lcens); + if (nrcens > 0) target += exponential_log_surv(eta_rcens, t_rcens); + if (nicens > 0) target += exponential_log_cdf2(eta_icens, t_icenl, t_icenu); + if (ndelay > 0) target += -exponential_log_surv(eta_delay, t_delay); + } + else if (type == 1) { // weibull model + real shape = coefs[1]; + if (nevent > 0) target += weibull_log_haz (eta_event, t_event, shape); + if (nevent > 0) target += weibull_log_surv(eta_event, t_event, shape); + if (nlcens > 0) target += weibull_log_cdf1(eta_lcens, t_lcens, shape); + if (nrcens > 0) target += weibull_log_surv(eta_rcens, t_rcens, shape); + if (nicens > 0) target += weibull_log_cdf2(eta_icens, t_icenl, t_icenu, shape); + if (ndelay > 0) target += -weibull_log_surv(eta_delay, t_delay, shape); + } + else if (type == 6) { // gompertz model + real scale = coefs[1]; + if (nevent > 0) target += gompertz_log_haz (eta_event, t_event, scale); + if (nevent > 0) target += gompertz_log_surv(eta_event, t_event, scale); + if (nlcens > 0) target += gompertz_log_cdf1(eta_lcens, t_lcens, scale); + if (nrcens > 0) target += gompertz_log_surv(eta_rcens, t_rcens, scale); + if (nicens > 0) target += gompertz_log_cdf2(eta_icens, t_icenl, t_icenu, scale); + if (ndelay > 0) target += -gompertz_log_surv(eta_delay, t_delay, scale); + } + else if (type == 4) { // M-splines, on haz scale + if (nevent > 0) target += mspline_log_haz (eta_event, basis_event, ms_coefs); + if (nevent > 0) target += mspline_log_surv(eta_event, ibasis_event, ms_coefs); + if (nlcens > 0) target += mspline_log_cdf1(eta_lcens, ibasis_lcens, ms_coefs); + if (nrcens > 0) target += mspline_log_surv(eta_rcens, ibasis_rcens, ms_coefs); + if (nicens > 0) target += mspline_log_cdf2(eta_icens, ibasis_icenl, ibasis_icenu, ms_coefs); + if (ndelay > 0) target += -mspline_log_surv(eta_delay, ibasis_delay, ms_coefs); + } + else { + reject("Bug found: invalid baseline hazard (without quadrature)."); + } + + } + } + + //-------- models with quadrature + + else { + + vector[Nevent] eta_epts_event; // for event times + vector[qevent] eta_qpts_event; // for qpts for event time + vector[qlcens] eta_qpts_lcens; // for qpts for left censoring time + vector[qrcens] eta_qpts_rcens; // for qpts for right censoring time + vector[qicens] eta_qpts_icenl; // for qpts for lower limit of icens time + vector[qicens] eta_qpts_icenu; // for qpts for upper limit of icens time + vector[qdelay] eta_qpts_delay; // for qpts for delayed entry time + + // linear predictor (time-fixed part) + if (K > 0) { + if (Nevent > 0) eta_epts_event = x_epts_event * beta; + if (qevent > 0) eta_qpts_event = x_qpts_event * beta; + if (qlcens > 0) eta_qpts_lcens = x_qpts_lcens * beta; + if (qrcens > 0) eta_qpts_rcens = x_qpts_rcens * beta; + if (qicens > 0) eta_qpts_icenl = x_qpts_icens * beta; + if (qicens > 0) eta_qpts_icenu = x_qpts_icens * beta; + if (qdelay > 0) eta_qpts_delay = x_qpts_delay * beta; + } + else { + if (Nevent > 0) eta_epts_event = rep_vector(0.0, Nevent); + if (qevent > 0) eta_qpts_event = rep_vector(0.0, qevent); + if (qlcens > 0) eta_qpts_lcens = rep_vector(0.0, qlcens); + if (qrcens > 0) eta_qpts_rcens = rep_vector(0.0, qrcens); + if (qicens > 0) eta_qpts_icenl = rep_vector(0.0, qicens); + if (qicens > 0) eta_qpts_icenu = rep_vector(0.0, qicens); + if (qdelay > 0) eta_qpts_delay = rep_vector(0.0, qdelay); + } + + // add on time-varying part to linear predictor + if (S > 0) { + if (Nevent > 0) eta_epts_event += s_epts_event * beta_tve; + if (qevent > 0) eta_qpts_event += s_qpts_event * beta_tve; + if (qlcens > 0) eta_qpts_lcens += s_qpts_lcens * beta_tve; + if (qrcens > 0) eta_qpts_rcens += s_qpts_rcens * beta_tve; + if (qicens > 0) eta_qpts_icenl += s_qpts_icenl * beta_tve; + if (qicens > 0) eta_qpts_icenu += s_qpts_icenu * beta_tve; + if (qdelay > 0) eta_qpts_delay += s_qpts_delay * beta_tve; + } + + // add on log crude event rate / time (helps to center intercept) + if (Nevent > 0) eta_epts_event += log_crude_event_rate; + if (qevent > 0) eta_qpts_event += log_crude_event_rate; + if (qlcens > 0) eta_qpts_lcens += log_crude_event_rate; + if (qrcens > 0) eta_qpts_rcens += log_crude_event_rate; + if (qicens > 0) eta_qpts_icenl += log_crude_event_rate; + if (qicens > 0) eta_qpts_icenu += log_crude_event_rate; + if (qdelay > 0) eta_qpts_delay += log_crude_event_rate; + + // add on intercept to linear predictor + if (has_intercept == 1) { + if (Nevent > 0) eta_epts_event += gamma[1]; + if (qevent > 0) eta_qpts_event += gamma[1]; + if (qlcens > 0) eta_qpts_lcens += gamma[1]; + if (qrcens > 0) eta_qpts_rcens += gamma[1]; + if (qicens > 0) eta_qpts_icenl += gamma[1]; + if (qicens > 0) eta_qpts_icenu += gamma[1]; + if (qdelay > 0) eta_qpts_delay += gamma[1]; + } + + // add on random effects terms to linear predictor + if (t > 0) { + if (special_case) for (i in 1:t) { + if (Nevent > 0) eta_epts_event += b[V_epts_event[i]]; + if (qevent > 0) eta_qpts_event += b[V_qpts_event[i]]; + if (qlcens > 0) eta_qpts_lcens += b[V_qpts_lcens[i]]; + if (qrcens > 0) eta_qpts_rcens += b[V_qpts_rcens[i]]; + if (qicens > 0) eta_qpts_icenl += b[V_qpts_icens[i]]; + if (qicens > 0) eta_qpts_icenu += b[V_qpts_icens[i]]; + if (qdelay > 0) eta_qpts_delay += b[V_qpts_delay[i]]; + } + else { + if (Nevent > 0) eta_epts_event += + csr_matrix_times_vector(Nevent, q, w_epts_event, v_epts_event, u_epts_event, b); + if (qevent > 0) eta_qpts_event += + csr_matrix_times_vector(qevent, q, w_qpts_event, v_qpts_event, u_qpts_event, b); + if (qlcens > 0) eta_qpts_lcens += + csr_matrix_times_vector(qlcens, q, w_qpts_lcens, v_qpts_lcens, u_qpts_lcens, b); + if (qrcens > 0) eta_qpts_rcens += + csr_matrix_times_vector(qrcens, q, w_qpts_rcens, v_qpts_rcens, u_qpts_rcens, b); + if (qicens > 0) eta_qpts_icenl += + csr_matrix_times_vector(qicens, q, w_qpts_icens, v_qpts_icens, u_qpts_icens, b); + if (qicens > 0) eta_qpts_icenu += + csr_matrix_times_vector(qicens, q, w_qpts_icens, v_qpts_icens, u_qpts_icens, b); + if (qdelay > 0) eta_qpts_delay += + csr_matrix_times_vector(qdelay, q, w_qpts_delay, v_qpts_delay, u_qpts_delay, b); + } + + } + + // aft models + if (type == 7 || type == 8) { + + vector[Nevent] af_event; + + vector[Nevent] caf_event; + vector[Nlcens] caf_lcens; + vector[Nrcens] caf_rcens; + vector[Nicens] caf_icenl; + vector[Nicens] caf_icenu; + vector[Ndelay] caf_delay; + + // acceleration factor at event time + if (Nevent > 0) af_event = exp(-eta_epts_event); + + // evaluate cumulative acceleration factors + if (Nevent > 0) caf_event = quadrature_aft(qwts_event, eta_qpts_event, qnodes, Nevent); + if (Nlcens > 0) caf_lcens = quadrature_aft(qwts_lcens, eta_qpts_lcens, qnodes, Nlcens); + if (Nrcens > 0) caf_rcens = quadrature_aft(qwts_rcens, eta_qpts_rcens, qnodes, Nrcens); + if (Nicens > 0) caf_icenl = quadrature_aft(qwts_icenl, eta_qpts_icenl, qnodes, Nicens); + if (Nicens > 0) caf_icenu = quadrature_aft(qwts_icenu, eta_qpts_icenu, qnodes, Nicens); + if (Ndelay > 0) caf_delay = quadrature_aft(qwts_delay, eta_qpts_delay, qnodes, Ndelay); + + // increment target with log-lik contributions + if (type == 7) { // exponential AFT model + if (Nevent > 0) target += exponentialAFT_log_haz (af_event); + if (Nevent > 0) target += exponentialAFT_log_surv(caf_event); + if (Nlcens > 0) target += exponentialAFT_log_cdf1(caf_lcens); + if (Nrcens > 0) target += exponentialAFT_log_surv(caf_rcens); + if (Nicens > 0) target += exponentialAFT_log_cdf2(caf_icenl, caf_icenu); + if (Ndelay > 0) target += -exponentialAFT_log_surv(caf_delay); + } else if (type == 8) { // weibull AFT model + real shape = coefs[1]; + if (Nevent > 0) target += weibullAFT_log_haz (af_event, caf_event, shape); + if (Nevent > 0) target += weibullAFT_log_surv(caf_event, shape); + if (Nlcens > 0) target += weibullAFT_log_cdf1(caf_lcens, shape); + if (Nrcens > 0) target += weibullAFT_log_surv(caf_rcens, shape); + if (Nicens > 0) target += weibullAFT_log_cdf2(caf_icenl, caf_icenu, shape); + if (Ndelay > 0) target += -weibullAFT_log_surv(caf_delay, shape); + } + + } + + // hazard models + else { + + vector[Nevent] lhaz_epts_event; + vector[qevent] lhaz_qpts_event; + vector[qlcens] lhaz_qpts_lcens; + vector[qrcens] lhaz_qpts_rcens; + vector[qicens] lhaz_qpts_icenl; + vector[qicens] lhaz_qpts_icenu; + vector[qdelay] lhaz_qpts_delay; + + // evaluate log hazard + if (type == 5) { // exponential model + if (Nevent > 0) lhaz_epts_event = exponential_log_haz(eta_epts_event); + if (qevent > 0) lhaz_qpts_event = exponential_log_haz(eta_qpts_event); + if (qlcens > 0) lhaz_qpts_lcens = exponential_log_haz(eta_qpts_lcens); + if (qrcens > 0) lhaz_qpts_rcens = exponential_log_haz(eta_qpts_rcens); + if (qicens > 0) lhaz_qpts_icenl = exponential_log_haz(eta_qpts_icenl); + if (qicens > 0) lhaz_qpts_icenu = exponential_log_haz(eta_qpts_icenu); + if (qdelay > 0) lhaz_qpts_delay = exponential_log_haz(eta_qpts_delay); + } + else if (type == 1) { // weibull model + real shape = coefs[1]; + if (Nevent > 0) lhaz_epts_event = weibull_log_haz(eta_epts_event, epts_event, shape); + if (qevent > 0) lhaz_qpts_event = weibull_log_haz(eta_qpts_event, qpts_event, shape); + if (qlcens > 0) lhaz_qpts_lcens = weibull_log_haz(eta_qpts_lcens, qpts_lcens, shape); + if (qrcens > 0) lhaz_qpts_rcens = weibull_log_haz(eta_qpts_rcens, qpts_rcens, shape); + if (qicens > 0) lhaz_qpts_icenl = weibull_log_haz(eta_qpts_icenl, qpts_icenl, shape); + if (qicens > 0) lhaz_qpts_icenu = weibull_log_haz(eta_qpts_icenu, qpts_icenu, shape); + if (qdelay > 0) lhaz_qpts_delay = weibull_log_haz(eta_qpts_delay, qpts_delay, shape); + } + else if (type == 6) { // gompertz model + real scale = coefs[1]; + if (Nevent > 0) lhaz_epts_event = gompertz_log_haz(eta_epts_event, epts_event, scale); + if (qevent > 0) lhaz_qpts_event = gompertz_log_haz(eta_qpts_event, qpts_event, scale); + if (qlcens > 0) lhaz_qpts_lcens = gompertz_log_haz(eta_qpts_lcens, qpts_lcens, scale); + if (qrcens > 0) lhaz_qpts_rcens = gompertz_log_haz(eta_qpts_rcens, qpts_rcens, scale); + if (qicens > 0) lhaz_qpts_icenl = gompertz_log_haz(eta_qpts_icenl, qpts_icenl, scale); + if (qicens > 0) lhaz_qpts_icenu = gompertz_log_haz(eta_qpts_icenu, qpts_icenu, scale); + if (qdelay > 0) lhaz_qpts_delay = gompertz_log_haz(eta_qpts_delay, qpts_delay, scale); + } + else if (type == 4) { // M-splines, on haz scale + if (Nevent > 0) lhaz_epts_event = mspline_log_haz(eta_epts_event, basis_epts_event, ms_coefs); + if (qevent > 0) lhaz_qpts_event = mspline_log_haz(eta_qpts_event, basis_qpts_event, ms_coefs); + if (qlcens > 0) lhaz_qpts_lcens = mspline_log_haz(eta_qpts_lcens, basis_qpts_lcens, ms_coefs); + if (qrcens > 0) lhaz_qpts_rcens = mspline_log_haz(eta_qpts_rcens, basis_qpts_rcens, ms_coefs); + if (qicens > 0) lhaz_qpts_icenl = mspline_log_haz(eta_qpts_icenl, basis_qpts_icenl, ms_coefs); + if (qicens > 0) lhaz_qpts_icenu = mspline_log_haz(eta_qpts_icenu, basis_qpts_icenu, ms_coefs); + if (qdelay > 0) lhaz_qpts_delay = mspline_log_haz(eta_qpts_delay, basis_qpts_delay, ms_coefs); + } + else if (type == 2) { // B-splines, on log haz scale + if (Nevent > 0) lhaz_epts_event = bspline_log_haz(eta_epts_event, basis_epts_event, coefs); + if (qevent > 0) lhaz_qpts_event = bspline_log_haz(eta_qpts_event, basis_qpts_event, coefs); + if (qlcens > 0) lhaz_qpts_lcens = bspline_log_haz(eta_qpts_lcens, basis_qpts_lcens, coefs); + if (qrcens > 0) lhaz_qpts_rcens = bspline_log_haz(eta_qpts_rcens, basis_qpts_rcens, coefs); + if (qicens > 0) lhaz_qpts_icenl = bspline_log_haz(eta_qpts_icenl, basis_qpts_icenl, coefs); + if (qicens > 0) lhaz_qpts_icenu = bspline_log_haz(eta_qpts_icenu, basis_qpts_icenu, coefs); + if (qdelay > 0) lhaz_qpts_delay = bspline_log_haz(eta_qpts_delay, basis_qpts_delay, coefs); + } + else { + reject("Bug found: invalid baseline hazard (with quadrature)."); + } + + // increment target with log-lik contributions for event submodel + if (Nevent > 0) target += lhaz_epts_event; + if (qevent > 0) target += quadrature_log_surv(qwts_event, lhaz_qpts_event); + if (qlcens > 0) target += quadrature_log_cdf1(qwts_lcens, lhaz_qpts_lcens, qnodes, Nlcens); + if (qrcens > 0) target += quadrature_log_surv(qwts_rcens, lhaz_qpts_rcens); + if (qicens > 0) target += quadrature_log_cdf2(qwts_icenl, lhaz_qpts_icenl, + qwts_icenu, lhaz_qpts_icenu, qnodes, Nicens); + if (qdelay > 0) target += -quadrature_log_surv(qwts_delay, lhaz_qpts_delay); + + } + + } + + } + + //-------- log priors + + // log priors for coefficients + if (K > 0) { + target += beta_custom_lpdf(z_beta | prior_dist, prior_scale, prior_df, + global_prior_df, local, global, mix, ool, + slab_df, caux); + } + + // log prior for intercept + if (has_intercept == 1) { + target += gamma_custom_lpdf(gamma[1] | prior_dist_for_intercept, + prior_mean_for_intercept, prior_scale_for_intercept, + prior_df_for_intercept); + } + + // log priors for baseline hazard parameters + if (type == 4) { + target += basehaz_lpdf(ms_coefs | prior_dist_for_aux, prior_conc_for_aux); + } + else if (nvars > 0) { + target += basehaz_lpdf(z_coefs | prior_dist_for_aux, prior_df_for_aux); + } + + // log priors for tve spline coefficients and their smoothing parameters + if (S > 0) { + target += smooth_lpdf(z_beta_tve | smooth_sd_raw, + prior_dist_for_smooth, prior_df_for_smooth); + } + + // log prior for random effects + if (t > 0) { + target += decov_lpdf(z_b | z_T, rho, zeta, tau, + regularization, delta, b_prior_shape, t, p); + } + +} + +generated quantities { + // baseline hazard parameters to return + vector[nvars] aux = (type == 4) ? ms_coefs : coefs; + + // transformed intercept + real alpha; + if (has_intercept == 1) { + alpha = log_crude_event_rate - dot_product(x_bar, beta) + gamma[1]; + } else { + alpha = log_crude_event_rate - dot_product(x_bar, beta); + } +} diff --git a/tests/testthat/helper.R b/tests/testthat/helper.R index 18d409528..7842f498f 100644 --- a/tests/testthat/helper.R +++ b/tests/testthat/helper.R @@ -109,7 +109,8 @@ expect_ppd <- function(x) { expect_stanreg <- function(x) expect_s3_class(x, "stanreg") expect_stanmvreg <- function(x) expect_s3_class(x, "stanmvreg") -expect_survfit <- function(x) expect_s3_class(x, "survfit.stanjm") +expect_survfit_jm <- function(x) expect_s3_class(x, "survfit.stanjm") +expect_survfit_surv <- function(x) expect_s3_class(x, "survfit.stansurv") # Use the standard errors from a fitted 'comparison model' to obtain # the tolerance for each parameter in the joint model @@ -128,7 +129,7 @@ expect_survfit <- function(x) expect_s3_class(x, "survfit.stanjm") # @param idvar The name of the ID variable. Used to extract the SDs for # group-specific terms that correspond to the individual/patient. # -get_tols <- function(modLong, modEvent = NULL, tolscales, idvar = "id") { +get_tols_jm <- function(modLong, modEvent = NULL, tolscales, idvar = "id") { if (is.null(modEvent)) modEvent <- modLong # if modLong is already a joint model @@ -179,6 +180,36 @@ get_tols <- function(modLong, modEvent = NULL, tolscales, idvar = "id") { return(ret) } +# Use the standard errors from a fitted 'comparison model' to obtain +# the tolerance for each parameter in the joint model +# Obtain parameter specific tolerances that can be used to assess the +# accuracy of parameter estimates in stan_jm models. The tolerances +# are calculated by taking the SE/SD for the parameter estimate in a +# "gold standard" model and multiplying this by the relevant element +# in the 'tolscales' argument. +# +# @param mod The "gold standard" longitudinal model. Likely to be +# a model estimated using coxph. +# @param toscales A named list with elements 'hr_fixef' and 'tve_fixef'. +# +get_tols_surv <- function(mod, tolscales) { + + cl <- class(mod)[1L] + + if (cl %in% c("coxph", "survreg")) { + fixef_ses <- sqrt(diag(mod$var))[1:length(mod$coefficients)] + fixef_tols <- tolscales$hr_fixef * fixef_ses + names(fixef_tols) <- names(mod$coefficients) + } + + if ("(Intercept)" %in% names(fixef_tols)) + fixef_tols[["(Intercept)"]] <- 2 * fixef_tols[["(Intercept)"]] + + ret <- Filter(function(x) !is.null(x), list(fixef = fixef_tols)) + + return(ret) +} + # Recover parameter estimates and return a list with consistent # parameter names for comparing stan_jm, stan_mvmer, stan_{g}lmer, # {g}lmer, and coxph estimates @@ -190,7 +221,7 @@ get_tols <- function(modLong, modEvent = NULL, tolscales, idvar = "id") { # @param idvar The name of the ID variable. Used to extract the estimates # for group-specific parameters that correspond to the individual/patient. # -recover_pars <- function(modLong, modEvent = NULL, idvar = "id") { +recover_pars_jm <- function(modLong, modEvent = NULL, idvar = "id") { if (is.null(modEvent)) modEvent <- modLong @@ -215,3 +246,34 @@ recover_pars <- function(modLong, modEvent = NULL, idvar = "id") { return(ret) } +# Recover parameter estimates and return a list with consistent +# parameter names for comparing stan_surv and coxph estimates +# +# @param mod The fitted survival model. Likely to be a model estimated +# using either coxph or stan_surv. +# +recover_pars_surv <- function(mod) { + + cl <- class(mod)[1L] + + fixef_pars <- switch(cl, + "coxph" = mod$coefficients, + "survreg" = mod$coefficients, + "stansurv" = fixef(mod), + NULL) + + if (cl == "stansurv") { + sel <- grep(":tve-[a-z][a-z]-coef[0-9]*$", names(fixef_pars)) + # replace stansurv tve names with coxph tt names + if (length(sel)) { + nms <- names(fixef_pars)[sel] + nms <- gsub(":tve-[a-z][a-z]-coef[0-9]*$", "", nms) + nms <- paste0("tt(", nms, ")") + names(fixef_pars)[sel] <- nms + } + } + + ret <- Filter(function(x) !is.null(x), list(fixef = fixef_pars)) + + return(ret) +} diff --git a/tests/testthat/test_stan_jm.R b/tests/testthat/test_stan_jm.R index e5d075fad..4dce97b1f 100644 --- a/tests/testthat/test_stan_jm.R +++ b/tests/testthat/test_stan_jm.R @@ -19,18 +19,20 @@ suppressPackageStartupMessages(library(rstanarm)) library(lme4) library(survival) -ITER <- 1000 -CHAINS <- 1 -SEED <- 12345 + +ITER <- 1000 +CHAINS <- 1 +SEED <- 12345 REFRESH <- 0L + set.seed(SEED) TOLSCALES <- list( - lmer_fixef = 0.25, # how many SEs can stan_jm fixefs be from lmer fixefs - lmer_ranef = 0.05, # how many SDs can stan_jm ranefs be from lmer ranefs - glmer_fixef = 0.5, # how many SEs can stan_jm fixefs be from glmer fixefs - glmer_ranef = 0.1, # how many SDs can stan_jm ranefs be from glmer ranefs - event = 0.3 # how many SEs can stan_jm fixefs be from coxph fixefs + lmer_fixef = 0.25, # how many SEs can stan_jm fixefs be from lmer fixefs + lmer_ranef = 0.05, # how many SDs can stan_jm ranefs be from lmer ranefs + glmer_fixef = 0.5, # how many SEs can stan_jm fixefs be from glmer fixefs + glmer_ranef = 0.1, # how many SDs can stan_jm ranefs be from glmer ranefs + event = 0.3 # how many SEs can stan_jm fixefs be from coxph fixefs ) context("stan_jm") @@ -348,11 +350,11 @@ compare_glmer <- function(fmLong, fam = gaussian, ...) { fmSurv <- Surv(futimeYears, death) ~ sex + trt y1 <- stan_glmer(fmLong, pbcLong, fam, iter = 1000, chains = CHAINS, seed = SEED) s1 <- coxph(fmSurv, data = pbcSurv) - j1 <- stan_jm(fmLong, pbcLong, fmSurv, pbcSurv, time_var = "year", family = fam, - assoc = NULL, iter = 1000, chains = CHAINS, seed = SEED, ...) - tols <- get_tols(y1, s1, tolscales = TOLSCALES) - pars <- recover_pars(y1, s1) - parsjm <- recover_pars(j1) + j1 <- stan_jm(fmLong, pbcLong, fmSurv, pbcSurv, time_var = "year", family = fam, + assoc = NULL, iter = 1000, chains = CHAINS, seed = SEED, ...) + tols <- get_tols_jm(y1, s1, tolscales = TOLSCALES) + pars <- recover_pars_jm(y1, s1) + parsjm <- recover_pars_jm(j1) for (i in names(tols$fixef)) expect_equal(pars$fixef[[i]], parsjm$fixef[[i]], tol = tols$fixef[[i]], info = fam) for (i in names(tols$ranef)) @@ -547,17 +549,17 @@ for (j in c(1:30)) { test_that("posterior_survfit works with estimation data", { SW(ps <- posterior_survfit(mod)) - expect_survfit(ps) + expect_survfit_jm(ps) }) test_that("posterior_survfit works with new data (one individual)", { SW(ps <- posterior_survfit(mod, newdataLong = ndL1, newdataEvent = ndE1)) - expect_survfit(ps) - }) - + expect_survfit_jm(ps) + }) + test_that("posterior_survfit works with new data (multiple individuals)", { SW(ps <- posterior_survfit(mod, newdataLong = ndL2, newdataEvent = ndE2)) - expect_survfit(ps) + expect_survfit_jm(ps) }) } } diff --git a/tests/testthat/test_stan_mvmer.R b/tests/testthat/test_stan_mvmer.R index 43cc1b5b0..931ae4138 100644 --- a/tests/testthat/test_stan_mvmer.R +++ b/tests/testthat/test_stan_mvmer.R @@ -18,17 +18,19 @@ suppressPackageStartupMessages(library(rstanarm)) library(lme4) -ITER <- 1000 -CHAINS <- 1 -SEED <- 12345 + +ITER <- 1000 +CHAINS <- 1 +SEED <- 12345 REFRESH <- 0L + set.seed(SEED) TOLSCALES <- list( - lmer_fixef = 0.25, # how many SEs can stan_jm fixefs be from lmer fixefs - lmer_ranef = 0.05, # how many SDs can stan_jm ranefs be from lmer ranefs - glmer_fixef = 0.3, # how many SEs can stan_jm fixefs be from glmer fixefs - glmer_ranef = 0.1 # how many SDs can stan_jm ranefs be from glmer ranefs + lmer_fixef = 0.25, # how many SEs can stan_jm fixefs be from lmer fixefs + lmer_ranef = 0.05, # how many SDs can stan_jm ranefs be from lmer ranefs + glmer_fixef = 0.3, # how many SEs can stan_jm fixefs be from glmer fixefs + glmer_ranef = 0.1 # how many SDs can stan_jm ranefs be from glmer ranefs ) context("stan_mvmer") @@ -163,9 +165,9 @@ if (interactive()) { compare_glmer <- function(fmLong, fam = gaussian, ...) { SW(y1 <- stan_glmer(fmLong, pbcLong, fam, iter = 1000, chains = CHAINS, seed = SEED, refresh = 0)) SW(y2 <- stan_mvmer(fmLong, pbcLong, fam, iter = 1000, chains = CHAINS, seed = SEED, ..., refresh = 0)) - tols <- get_tols(y1, tolscales = TOLSCALES) - pars <- recover_pars(y1) - pars2 <- recover_pars(y2) + tols <- get_tols_jm(y1, tolscales = TOLSCALES) + pars <- recover_pars_jm(y1) + pars2 <- recover_pars_jm(y2) for (i in names(tols$fixef)) expect_equal(pars$fixef[[i]], pars2$fixef[[i]], tol = tols$fixef[[i]]) for (i in names(tols$ranef)) @@ -176,21 +178,26 @@ if (interactive()) { expect_equal(colMeans(log_lik(y1, newdata = nd)), colMeans(log_lik(y2, newdata = nd)), tol = 0.15) } - test_that("coefs same for stan_jm and stan_lmer/coxph", { - # fails in many cases - # compare_glmer(logBili ~ year + (1 | id), gaussian) - }) + + # fails in many cases + # test_that("coefs same for stan_mvmer and stan_glmer", { + # compare_glmer(logBili ~ year + (1 | id), gaussian)}) + # fails in some cases - # test_that("coefs same for stan_jm and stan_glmer, bernoulli", { + # test_that("coefs same for stan_mvmer and stan_glmer, bernoulli", { # compare_glmer(ybern ~ year + xbern + (1 | id), binomial)}) - test_that("coefs same for stan_jm and stan_glmer, poisson", { + + test_that("coefs same for stan_mvmer and stan_glmer, poisson", { compare_glmer(ypois ~ year + xpois + (1 | id), poisson, init = 0)}) - test_that("coefs same for stan_jm and stan_glmer, negative binomial", { + + test_that("coefs same for stan_mvmer and stan_glmer, negative binomial", { compare_glmer(ynbin ~ year + xpois + (1 | id), neg_binomial_2)}) - test_that("coefs same for stan_jm and stan_glmer, Gamma", { + + test_that("coefs same for stan_mvmer and stan_glmer, Gamma", { compare_glmer(ygamm ~ year + xgamm + (1 | id), Gamma(log))}) -# test_that("coefs same for stan_jm and stan_glmer, inverse gaussian", { -# compare_glmer(ygamm ~ year + xgamm + (1 | id), inverse.gaussian)}) + + # test_that("coefs same for stan_mvmer and stan_glmer, inverse gaussian", { + # compare_glmer(ygamm ~ year + xgamm + (1 | id), inverse.gaussian)}) } #---- Check methods and post-estimation functions @@ -205,16 +212,15 @@ o<-SW(f3 <- update(m2, formula. = list(logBili ~ year + (year | id) + (1 | pract o<-SW(f4 <- update(f3, formula. = list(logBili ~ year + (year | id) + (1 | practice), albumin ~ year + (year | id) + (1 | practice)))) o<-SW(f5 <- update(f3, formula. = list(logBili ~ year + (year | id) + (1 | practice), - ybern ~ year + (year | id) + (1 | practice)), - family = list(gaussian, binomial))) + ybern ~ year + (year | id) + (1 | practice)), + family = list(gaussian, binomial))) for (j in 1:5) { mod <- get(paste0("f", j)) cat("Checking model:", paste0("f", j), "\n") expect_error(posterior_traj(mod), "stanjm") - expect_error(posterior_survfit(mod), "stanjm") - + test_that("posterior_predict works with estimation data", { pp <- posterior_predict(mod, m = 1) expect_ppd(pp) diff --git a/tests/testthat/test_stan_surv.R b/tests/testthat/test_stan_surv.R new file mode 100644 index 000000000..7c3c270ce --- /dev/null +++ b/tests/testthat/test_stan_surv.R @@ -0,0 +1,1043 @@ +# Part of the rstanarm package for estimating model parameters +# Copyright (C) 2015, 2016 Trustees of Columbia University +# Copyright (C) 2017 Sam Brilleman +# +# This program is free software; you can redistribute it and/or +# modify it under the terms of the GNU General Public License +# as published by the Free Software Foundation; either version 3 +# of the License, or (at your option) any later version. +# +# This program is distributed in the hope that it will be useful, +# but WITHOUT ANY WARRANTY; without even the implied warranty of +# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +# GNU General Public License for more details. +# +# You should have received a copy of the GNU General Public License +# along with this program; if not, write to the Free Software +# Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301, USA. + +# tests can be run using devtools::test() or manually by loading testthat +# package and then running the code below possibly with options(mc.cores = 4). + +library(rstanarm) +library(survival) +library(simsurv) +ITER <- 1000 +CHAINS <- 1 +REFRESH <- 0L +SEED <- 12345; set.seed(SEED) +if (interactive()) + options(mc.cores = parallel::detectCores(), + loo.cores = parallel::detectCores()) + +TOLSCALES <- list( + hr_fixef = 0.5 # how many SEs can stan_surv HRs be from coxph/stpm2 HRs +) + +context("stan_surv") + +eo <- function(...) { expect_output (...) } +ee <- function(...) { expect_error (...) } +ew <- function(...) { expect_warning(...) } +es <- function(...) { expect_stanreg(...) } +up <- function(...) { update(...) } + +run_sims <- FALSE # if TRUE then long running simulations are run + + +#----------------- Check model fitting arguments work ----------------------- + +cov1 <- data.frame(id = 1:50, + x1 = stats::rbinom(50, 1, 0.5), + x2 = stats::rnorm (50, -1, 0.5)) + +dat1 <- simsurv(lambdas = 0.1, + gammas = 1.5, + betas = c(x1 = -0.5, x2 = -0.3), + x = cov1, + maxt = 5) + +dat1$s <- Surv(dat1$eventtime, dat1$status) # abbreviated Surv object + +o<-SW(testmod <- stan_surv(formula = s ~ x1 + x2, + data = merge(dat1, cov1), + basehaz = "ms", + iter = 20, + chains = CHAINS, + refresh = REFRESH, + seed = SEED)) + +test_that("prior_PD argument works", { + es(up(testmod, prior_PD = TRUE)) +}) + +test_that("adapt_delta argument works", { + es(up(testmod, adapt_delta = NULL)) + es(up(testmod, adapt_delta = 0.8)) + es(up(testmod, control = list(adapt_delta = NULL))) + es(up(testmod, control = list(adapt_delta = 0.8))) +}) + +test_that("init argument works", { + es(up(testmod, init = "prefit")) + es(up(testmod, init = "0")) + es(up(testmod, init = 0)) + es(up(testmod, init = "random")) +}) + +test_that("qnodes argument works", { + es(up(testmod, qnodes = 7, basehaz = "bs")) + es(up(testmod, qnodes = 11, basehaz = "bs")) + es(up(testmod, qnodes = 15, basehaz = "bs")) + + ew(up(testmod, qnodes = 1), "is being ignored") + ew(up(testmod, qnodes = "wrong"), "is being ignored") + + ee(up(testmod, qnodes = 1, basehaz = "bs"), "7, 11 or 15") + ee(up(testmod, qnodes = c(1,2), basehaz = "bs"), "numeric vector of length 1") + ee(up(testmod, qnodes = "wrong", basehaz = "bs"), "numeric vector of length 1") +}) + +test_that("basehaz argument works", { + es(up(testmod, basehaz = "exp")) + es(up(testmod, basehaz = "weibull")) + es(up(testmod, basehaz = "gompertz")) + es(up(testmod, basehaz = "ms")) + es(up(testmod, basehaz = "bs")) + es(up(testmod, basehaz = "exp-aft")) + es(up(testmod, basehaz = "weibull-aft")) + + dfl <- list(df = 5) + knl <- list(knots = c(1,3,4)) + es(up(testmod, basehaz = "ms", basehaz_ops = dfl)) + es(up(testmod, basehaz = "ms", basehaz_ops = knl)) + es(up(testmod, basehaz = "bs", basehaz_ops = dfl)) + es(up(testmod, basehaz = "bs", basehaz_ops = knl)) + + ee(up(testmod, basehaz_ops = list(junk = 3)), "can only include") + + ee(up(testmod, basehaz_ops = list(df = 1)), "cannot be negative") + ee(up(testmod, basehaz_ops = list(knots = -1)), "earliest entry time") + ee(up(testmod, basehaz_ops = list(knots = c(1,2,50))), "latest event time") +}) + +test_that("prior arguments work", { + es(up(testmod, prior = normal())) + es(up(testmod, prior = student_t())) + es(up(testmod, prior = cauchy())) + es(up(testmod, prior = hs())) + es(up(testmod, prior = hs_plus())) + es(up(testmod, prior = lasso())) + es(up(testmod, prior = laplace())) + + es(up(testmod, prior_intercept = normal())) + es(up(testmod, prior_intercept = student_t())) + es(up(testmod, prior_intercept = cauchy())) + + es(up(testmod, prior_aux = dirichlet())) + es(up(testmod, prior_aux = normal(), basehaz = "weibull")) + es(up(testmod, prior_aux = student_t(), basehaz = "weibull")) + es(up(testmod, prior_aux = cauchy(), basehaz = "weibull")) + es(up(testmod, prior_aux = exponential(), basehaz = "weibull")) + + es(up(testmod, prior_smooth = exponential())) + es(up(testmod, prior_smooth = normal())) + es(up(testmod, prior_smooth = student_t())) + es(up(testmod, prior_smooth = cauchy())) + + ee(up(testmod, prior_intercept = lasso()), "prior distribution") + ee(up(testmod, prior_aux = lasso()), "prior distribution") + ee(up(testmod, prior_smooth = lasso()), "prior distribution") +}) + +test_that("tve function works", { + es(up(testmod, formula. = s ~ tve(x1) + x2)) + es(up(testmod, formula. = s ~ tve(x1) + tve(x2))) + es(up(testmod, formula. = s ~ tve(x1, knots = 1) + tve(x2, knots = 2))) +}) + +test_that("tve function works: b-spline optional arguments", { + es(up(testmod, formula. = s ~ tve(x1, knots = c(1,2)) + x2)) + es(up(testmod, formula. = s ~ tve(x1, df = 4) + x2)) + es(up(testmod, formula. = s ~ tve(x1, degree = 0) + x2)) + ee(up(testmod, formula. = s ~ tve(x1, junk = 2) + x2), "unused") +}) + + +#---------------- Check post-estimation functions work ---------------------- + +# use PBC data +pbcSurv$t0 <- 0 +pbcSurv$t0[pbcSurv$futimeYears > 2] <- 1 # fake delayed entry +pbcSurv$t1 <- pbcSurv$futimeYears - 1 # fake lower limit for interval censoring +pbcSurv$t1[pbcSurv$t1 <= 0] <- -Inf # fake left censoring +pbcSurv$site <- cut(pbcSurv$id, # fake group for frailty models + breaks = c(0,10,20,30,40), + labels = FALSE) + +# different baseline hazards +o<-SW(f1 <- stan_surv(Surv(futimeYears, death) ~ sex + trt, + data = pbcSurv, + basehaz = "ms", + chains = 1, + iter = 20, + refresh = REFRESH, + seed = SEED)) +o<-SW(f2 <- up(f1, basehaz = "bs")) +o<-SW(f3 <- up(f1, basehaz = "exp")) +o<-SW(f4 <- up(f1, basehaz = "weibull")) +o<-SW(f5 <- up(f1, basehaz = "gompertz")) +o<-SW(f6 <- up(f1, basehaz = "exp-aft")) +o<-SW(f7 <- up(f1, basehaz = "weibull-aft")) + +# time-varying effects +o<-SW(f8 <- up(f1, Surv(futimeYears, death) ~ sex + tve(trt))) +o<-SW(f9 <- up(f2, Surv(futimeYears, death) ~ sex + tve(trt))) +o<-SW(f10 <- up(f3, Surv(futimeYears, death) ~ sex + tve(trt))) +o<-SW(f11 <- up(f4, Surv(futimeYears, death) ~ sex + tve(trt))) +o<-SW(f12 <- up(f5, Surv(futimeYears, death) ~ sex + tve(trt))) +o<-SW(f13 <- up(f6, Surv(futimeYears, death) ~ sex + tve(trt))) +o<-SW(f14 <- up(f7, Surv(futimeYears, death) ~ sex + tve(trt))) + +o<-SW(f15 <- up(f1, Surv(futimeYears, death) ~ sex + tve(trt, degree = 0))) +o<-SW(f16 <- up(f2, Surv(futimeYears, death) ~ sex + tve(trt, degree = 0))) +o<-SW(f17 <- up(f3, Surv(futimeYears, death) ~ sex + tve(trt, degree = 0))) +o<-SW(f18 <- up(f4, Surv(futimeYears, death) ~ sex + tve(trt, degree = 0))) +o<-SW(f19 <- up(f5, Surv(futimeYears, death) ~ sex + tve(trt, degree = 0))) +o<-SW(f20 <- up(f6, Surv(futimeYears, death) ~ sex + tve(trt, degree = 0))) +o<-SW(f21 <- up(f7, Surv(futimeYears, death) ~ sex + tve(trt, degree = 0))) + +# start-stop notation (incl. delayed entry) +o<-SW(f22 <- up(f1, Surv(t0, futimeYears, death) ~ sex + trt)) +o<-SW(f23 <- up(f1, Surv(t0, futimeYears, death) ~ sex + tve(trt))) +o<-SW(f24 <- up(f6, Surv(t0, futimeYears, death) ~ sex + tve(trt))) +o<-SW(f25 <- up(f6, Surv(t0, futimeYears, death) ~ sex + tve(trt))) + +# left and interval censoring +o<-SW(f26 <- up(f1, Surv(t1, futimeYears, type = "interval2") ~ sex + trt)) +o<-SW(f27 <- up(f1, Surv(t1, futimeYears, type = "interval2") ~ sex + tve(trt))) +o<-SW(f28 <- up(f6, Surv(t1, futimeYears, type = "interval2") ~ sex + trt)) +o<-SW(f29 <- up(f6, Surv(t1, futimeYears, type = "interval2") ~ sex + tve(trt))) + +# frailty models +o<-SW(f30 <- up(f1, Surv(futimeYears, death) ~ trt + (trt | site))) +o<-SW(f31 <- up(f1, Surv(futimeYears, death) ~ tve(trt) + (1 | site))) +o<-SW(f32 <- up(f1, Surv(t0, futimeYears, death) ~ trt + (trt | site))) +o<-SW(f33 <- up(f1, Surv(t1, futimeYears, type = "interval2") ~ trt + (trt | site))) + +# new data for predictions +nd1 <- pbcSurv[pbcSurv$id == 2,] +nd2 <- pbcSurv[pbcSurv$id %in% c(1,2),] + +# test the models +for (j in c(1:33)) { + + mod <- try(get(paste0("f", j)), silent = TRUE) + + if (class(mod)[1L] == "try-error") { + + cat("Model not found:", paste0("f", j), "\n") + + } else { + + cat("Checking model:", paste0("f", j), "\n") + + test_that("log_lik works with estimation data", { + ll <- log_lik(mod) + expect_matrix(ll) + }) + + test_that("log_lik works with new data (one individual)", { + ll <- log_lik(mod, newdata = nd1) + expect_matrix(ll) + }) + + test_that("log_lik works with new data (multiple individuals)", { + ll <- log_lik(mod, newdata = nd2) + expect_matrix(ll) + }) + + if (mod$ndelayed == 0) # only test if no delayed entry + test_that("posterior_survfit works with estimation data", { + SW(ps <- posterior_survfit(mod)) + expect_survfit_surv(ps) + }) + + test_that("posterior_survfit works with new data (one individual)", { + SW(ps <- posterior_survfit(mod, newdata = nd1)) + expect_survfit_surv(ps) + }) + + test_that("posterior_survfit works with new data (multiple individuals)", { + SW(ps <- posterior_survfit(mod, newdata = nd2)) + expect_survfit_surv(ps) + }) + + } +} + +# test loo for a few models only (too slow to test them all) +for (j in c(1,2,8,26,30)) { + + mod <- try(get(paste0("f", j)), silent = TRUE) + + if (class(mod)[1L] == "try-error") { + + cat("Model not found:", paste0("f", j), "\n") + + } else { + + cat("Checking loo for model:", paste0("f", j), "\n") + + test_that("loo and waic work", { + loo_try <- try(expect_equivalent_loo(mod), silent = TRUE) + if (class(loo_try)[1L] == "try-error") { + # sometimes loo fails with a small number of draws so refit with more + expect_equivalent_loo(up(mod, iter = 80)) + } + }) + + } +} + + +#---- Check accuracy of log_lik and posterior_survfit: for frailty models --- + +fake_data <- + data.frame(id = c(1,2,3,4), + trt = c(1,0,1,0), + age = c(5,8,2,4), + site = c(1,1,2,2), + eventtime = c(2,4,6,8), + status = c(1,1,1,1)) + +o<-SW(stan1 <- stan_surv(formula = Surv(eventtime, status) ~ trt + age + (1 | site), + data = fake_data, + basehaz = "weibull", + chains = 1, + refresh = 0L, + iter = 100, + warmup = 95)) + +stanmat <- as.matrix(stan1) + +stanpars <- list(int = stanmat[, "(Intercept)"], + trt = stanmat[, "trt"], + age = stanmat[, "age"], + shape = stanmat[, "weibull-shape"], + site = list(stanmat[, "b[(Intercept) site:1]"], + stanmat[, "b[(Intercept) site:2]"])) + +N <- nrow(fake_data) +S <- nrow(stanmat) + +# define function to calculate log likelihood manually +llfun <- function(i, j, data, pars) { + exp_eta_ij <- exp(pars$int[j] + + pars$trt[j] * data$trt[i] + + pars$age[j] * data$age[i] + + pars$site[[data$site[i]]][[j]]) + h_ij <- pars$shape[j] * data$eventtime[i] ^ (pars$shape[j] - 1) * exp_eta_ij + H_ij <- data$eventtime[i] ^ (pars$shape[j]) * exp_eta_ij + return(data$status[i] * log(h_ij) - H_ij) +} + +# define function to calculate survival probability manually +survfun <- function(i, j, t, data, pars) { + exp_eta_ij <- exp(pars$int[j] + + pars$trt[j] * data$trt[i] + + pars$age[j] * data$age[i] + + pars$site[[data$site[i]]][[j]]) + H_ij <- t ^ (pars$shape[j]) * exp_eta_ij + return(exp(- H_ij)) +} + +# check log likelihood +L1 <- log_lik(stan1) +L2 <- log_lik(stan1, newdata = fake_data) +L3 <- matrix(NA, S, N) # manually evaluated log likelihood +for (i in 1:N) { + for (j in 1:S) { + L3[j,i] <- llfun(i, j, data = fake_data, pars = stanpars) + } +} +for (i in 1:N) { + for (j in 1:S) { + expect_equal(as.vector(L3[j,i]), as.vector(L1[j,i])) + expect_equal(as.vector(L3[j,i]), as.vector(L2[j,i])) + } +} + +# check survival probability +P1 <- posterior_survfit(stan1, times = 5, extrapolate = FALSE) +P2 <- posterior_survfit(stan1, newdata = fake_data, times = 5, extrapolate = FALSE) +P3 <- matrix(NA, S, N) # manually evaluated survival probability +for (i in 1:N) { + for (j in 1:S) { + P3[j,i] <- survfun(i, j, t = 5, data = fake_data, pars = stanpars) + } +} +for (i in 1:N) { + expect_equal(median(P3[,i]), P1[i, "median"]) + expect_equal(median(P3[,i]), P2[i, "median"]) +} + + +#---------------- Check parameter estimates: stan vs coxph ----------------- + +compare_surv <- function(data, basehaz = "weibull", ...) { + require(survival) + fm <- Surv(eventtime, status) ~ X1 + X2 + surv1 <- coxph(fm, data) + stan1 <- stan_surv(formula = fm, + data = data, + basehaz = basehaz, + iter = ITER, + refresh = REFRESH, + chains = CHAINS, + seed = SEED, ...) + tols <- get_tols_surv(surv1, tolscales = TOLSCALES) + pars_surv <- recover_pars_surv(surv1) + pars_stan <- recover_pars_surv(stan1) + for (i in names(tols$fixef)) + expect_equal(pars_surv$fixef[[i]], + pars_stan$fixef[[i]], + tol = tols$fixef[[i]], + info = basehaz) +} + +N_coxph <- 1000 + +#---- exponential data + +set.seed(543634) +covs <- data.frame(id = 1:N_coxph, + X1 = rbinom(N_coxph, 1, 0.3), + X2 = rnorm (N_coxph, 2, 2.0)) +dat <- simsurv(dist = "weibull", + lambdas = 0.1, + gammas = 1, + betas = c(X1 = 0.3, X2 = -0.5), + x = covs) +dat <- merge(dat, covs) + +compare_surv(data = dat, basehaz = "exp") + +#---- weibull data + +set.seed(543634) +covs <- data.frame(id = 1:N_coxph, + X1 = rbinom(N_coxph, 1, 0.3), + X2 = rnorm (N_coxph, 2, 2.0)) +dat <- simsurv(dist = "weibull", + lambdas = 0.1, + gammas = 1.3, + betas = c(X1 = 0.3, X2 = -0.5), + x = covs) +dat <- merge(dat, covs) + +compare_surv(data = dat, basehaz = "weibull") +compare_surv(data = dat, basehaz = "ms") +compare_surv(data = dat, basehaz = "bs") + +#---- gompertz data + +set.seed(45357) +covs <- data.frame(id = 1:N_coxph, + X1 = rbinom(N_coxph, 1, 0.3), + X2 = rnorm (N_coxph, 2, 2.0)) +dat <- simsurv(dist = "gompertz", + lambdas = 0.1, + gammas = 0.05, + betas = c(X1 = -0.6, X2 = -0.4), + x = covs) +dat <- merge(dat, covs) + +compare_surv(data = dat, basehaz = "gompertz") + + +#----------- Check parameter estimates: stan (AFT) vs survreg --------------- + +compare_surv <- function(data, basehaz = "weibull-aft", ...) { + require(survival) + fm <- Surv(eventtime, status) ~ X1 + X2 + dist <- ifelse(basehaz == "weibull-aft", "weibull", "exponential") + surv1 <- survreg(fm, data, dist = dist) + stan1 <- stan_surv(formula = fm, + data = data, + basehaz = basehaz, + iter = ITER, + refresh = REFRESH, + chains = CHAINS, + seed = SEED, + ...) + tols <- get_tols_surv(surv1, tolscales = TOLSCALES) + pars_surv <- recover_pars_surv(surv1) + pars_stan <- recover_pars_surv(stan1) + for (i in names(tols$fixef)) + expect_equal(pars_surv$fixef[[i]], + pars_stan$fixef[[i]], + tol = tols$fixef[[i]], + info = basehaz) +} + +N_survreg <- 300 + +#---- exponential data + +set.seed(543634) +covs <- data.frame(id = 1:N_survreg, + X1 = rbinom(N_survreg, 1, 0.3), + X2 = rnorm (N_survreg, 2, 2.0)) +dat <- simsurv(dist = "weibull", + lambdas = 0.1, + gammas = 1, + betas = c(X1 = 0.3, X2 = -0.5), + x = covs) +dat <- merge(dat, covs) + +compare_surv(data = dat, basehaz = "exp-aft") + +#---- weibull data + +set.seed(543634) +covs <- data.frame(id = 1:N_survreg, + X1 = rbinom(N_survreg, 1, 0.3), + X2 = rnorm (N_survreg, 2, 2.0)) +dat <- simsurv(dist = "weibull", + lambdas = 0.1, + gammas = 1.3, + betas = c(X1 = 0.3, X2 = -0.5), + x = covs) +dat <- merge(dat, covs) + +compare_surv(data = dat, basehaz = "weibull-aft") + + +#-------- Check parameter estimates: stan (tve) vs coxph (tt) --------------- + +# NB: this only checks piecewise constant hazard ratio + +set.seed(SEED) + +N <- 1000 # number of individuals to simulate + +covs <- data.frame(id = 1:N, + X1 = rbinom(N, 1, 0.3), + X2 = rnorm (N, 2, 2.0)) + +dat <- simsurv(dist = "exponential", + x = covs, + lambdas = c(0.1), + betas = c(X1 = 0.3, X2 = -0.3), + tve = c(X1 = -0.6), + tvefun = function(t) as.numeric(t > 10), + maxt = 30) + +o<-SW(surv1 <- coxph( + formula = Surv(eventtime, status) ~ X1 + tt(X1) + X2, + data = merge(dat, covs), + tt = function(x, t, ...) { x * as.numeric(t > 10) })) + +o<-SW(stan1 <- stan_surv( + formula = Surv(eventtime, status) ~ tve(X1, degree = 0, knots = c(10)) + X2, + data = merge(dat, covs), + basehaz = "exp", + chains = CHAINS, + refresh = REFRESH, + iter = ITER)) + +tols <- get_tols_surv(surv1, tolscales = TOLSCALES) + +pars_surv <- recover_pars_surv(surv1) +pars_stan <- recover_pars_surv(stan1) + +for (i in names(tols$fixef)) + expect_equal(pars_surv$fixef[[i]], + pars_stan$fixef[[i]], + tol = tols$fixef[[i]], + info = "compare_estimates_tve_pw") + + +# COMMENTED OUT TO AVOID ADDING PACKAGES TO SUGGESTS +# +# #------- Compare parameter estimates: stan (icens) vs icenReg ------------- +# +# #---- simulated interval censored weibull data +# +# library(icenReg); set.seed(321) +# +# # simulate interval censored data +# sim_data <- simIC_weib(n = 5000, +# b1 = 0.3, +# b2 = -0.3, +# model = 'ph', +# shape = 2, +# scale = 2, +# inspections = 6, +# inspectLength = 1) +# +# # lower limit = 0 is actually left censoring (stan_surv doesn't accept 0's) +# sim_data$l[sim_data$l == 0] <- -Inf +# +# # fit stan model to interval censored data +# fm <- Surv(l, u, type = 'interval2') ~ x1 + x2 +# ic_stan <- stan_surv(fm, +# data = sim_data, +# basehaz = "weibull", +# iter = ITER, +# refresh = REFRESH, +# chains = CHAINS, +# seed = SEED) +# +# # compare stan estimates to known values from data generating model +# truepars <- c('x1' = 0.3, 'x2' = -0.3, 'weibull-shape' = 2) +# stanpars <- fixef(ic_stan) +# expect_equal(stanpars[['x1']], +# truepars[['x1']], +# tol = 0.01, +# info = "compare estimates (x1) with icenReg") +# expect_equal(stanpars[['x2']], +# truepars[['x2']], +# tol = 0.01, +# info = "compare estimates (x2) with icenReg") +# expect_equal(stanpars[['weibull-shape']], +# truepars[['weibull-shape']], +# tol = 0.1, +# info = "compare estimates (weibull-shape) with icenReg") +# +# # fit model using icenReg package & compare log_lik with stan model +# ic_icen <- ic_par(fm, data = sim_data) +# ll_icen <- ic_icen$llk +# ll_stan <- mean(rowSums(log_lik(ic_stan))) +# expect_equal(ll_icen, +# ll_stan, +# tol = 5, +# info = "compare log lik with icenReg") +# +# +# #---- Compare parameter estimates: stan (tvc & delayed entry) vs phreg ---- +# +# #---- mortality data: contains a time-varying covariate +# +# library(eha); library(dplyr); set.seed(987) +# +# # add a time-fixed covariate to the mortality data +# data(mort); mort <- mort %>% group_by(id) %>% mutate(sesfixed = ses[[1]]) +# +# # fit models using the time-fixed covariate & compare HR estimates +# fm <- Surv(enter, exit, event) ~ sesfixed +# f_weib <- phreg(fm, data = mort) +# f_stan <- stan_surv(fm, +# data = mort, +# basehaz = "weibull", +# iter = ITER, +# refresh = REFRESH, +# chains = CHAINS, +# seed = SEED) +# expect_equal(coef(f_weib)['sesfixedupper'], +# coef(f_stan)['sesfixedupper'], +# tol = 0.01) +# +# # fit models using the time-varying covariate & compare HR estimates +# fm <- Surv(enter, exit, event) ~ ses +# v_weib <- phreg(fm, data = mort) +# v_stan <- stan_surv(fm, +# data = mort, +# basehaz = "weibull", +# iter = ITER, +# refresh = REFRESH, +# chains = CHAINS, +# seed = SEED) +# expect_equal(coef(v_weib)['sesupper'], +# coef(v_stan)['sesupper'], +# tol = 0.01) +# +# # stupidity check; to make sure the hazard ratios actually differed +# # between the models with the time-fixed and time-varying covariate +# expect_error(expect_equal(coef(f_weib)['sesfixedupper'][[1]], +# coef(v_weib)['sesupper'][[1]], +# tol = 0.1), "not equal") + + +#-------- Check parameter estimates: stan (frailty) vs simulated ----------- + +# define a function to simulate a survival dataset +make_data <- function(n = 10, # number of patients per site + K = 30, # number of sites + dist = "exponential", # basehaz for simulation + delay = FALSE, # induce delayed entry + icens = FALSE) { # induce interval censoring + + if (delay && icens) + stop("'delay' and 'icens' cannot both be TRUE.") + + # dimensions + N <- n * K # total num individuals + + # true sd for the random intercepts + true_sd <- 1 + + # sample random intercept for each site + bb <- rnorm(K, 0, true_sd) + + # covariate data + cov <- data.frame(id = 1:N, + site = rep(1:K, each = n), + trt = rbinom(N, 1, 0.5), + b = bb[rep(1:K, each = n)]) + + # simulate event times + dat <- simsurv(dist = dist, + lambdas = 0.1, + gammas = switch(dist, + "weibull" = 1.3, + "gompertz" = 0.05, + NULL), + x = cov, + betas = c(trt = 0.3, b = 1), + maxt = 15) + + # create delayed entry + if (delay) { + dat[["start"]] <- runif(N, 0, dat[["eventtime"]] / 2) + dat[["stop"]] <- dat[["eventtime"]] + } + + # create interval censoring + if (icens) { + + dd <- dat[["status"]] # event indicator + dat[["lower"]] <- rep(NA, nrow(dat)) + dat[["upper"]] <- rep(NA, nrow(dat)) + + # construct lower/upper interval cens times for right censored individuals + dat[dd == 0, "lower"] <- dat[dd == 0, "eventtime"] + dat[dd == 0, "upper"] <- Inf + + # construct lower/upper interval cens times for individuals with events + dat[dd == 1, "lower"] <- runif(sum(dd == 1), + dat[dd == 1, "eventtime"] / 2, + dat[dd == 1, "eventtime"]) + dat[dd == 1, "upper"] <- runif(sum(dd == 1), + dat[dd == 1, "eventtime"], + dat[dd == 1, "eventtime"] * 1.5) + + } + + merge(cov, dat) + +} + +# true parameter values used to simulate & corresponding tolerances for tests +true <- c(intercept = log(0.1), trt = 0.3, b_sd = 1) +tols <- c(0.2, 0.1, 0.2) + +# function to return the parameter estimates to test +get_ests <- function(mod) { + c(intercept = fixef(mod)[["(Intercept)"]], + trt = fixef(mod)[["trt"]], + b_sd = attr(VarCorr(mod)[[1]], "stddev")[[1]]) +} + +# fit right censored models + +# simulate datasets +set.seed(SEED) +n <- 50 +K <- 100 +dat <- make_data(n = n, K = K) +dat_delay <- make_data(n = n, K = K, delay = TRUE) +dat_icens <- make_data(n = n, K = K, icens = TRUE) + +# formulas +ff <- Surv(eventtime, status) ~ trt + (1 | site) # right cens +ffd <- Surv(start, stop, status) ~ trt + (1 | site) # delayed entry +ffi <- Surv(lower, upper, type = "interval2") ~ trt + (1 | site) # interval cens + +# fit the starting model +o<-SW(m1 <- stan_surv(formula = ff, + data = dat, + basehaz = "exp", + iter = ITER, + refresh = REFRESH, + chains = CHAINS, + seed = SEED)) + +# fit the additional models +o<-SW(m2 <- up(m1, formula. = ff, data = dat, basehaz = "weibull")) +o<-SW(m3 <- up(m1, formula. = ff, data = dat, basehaz = "gompertz")) +o<-SW(m4 <- up(m1, formula. = ff, data = dat, basehaz = "ms")) +o<-SW(m5 <- up(m1, formula. = ffd, data = dat_delay, basehaz = "exp")) +o<-SW(m6 <- up(m1, formula. = ffd, data = dat_delay, basehaz = "weibull")) +o<-SW(m7 <- up(m1, formula. = ffd, data = dat_delay, basehaz = "gompertz")) +o<-SW(m8 <- up(m1, formula. = ffd, data = dat_delay, basehaz = "ms")) +o<-SW(m9 <- up(m1, formula. = ffi, data = dat_icens, basehaz = "exp")) +o<-SW(m10 <- up(m1, formula. = ffi, data = dat_icens, basehaz = "weibull")) +o<-SW(m11 <- up(m1, formula. = ffi, data = dat_icens, basehaz = "gompertz")) +o<-SW(m12 <- up(m1, formula. = ffi, data = dat_icens, basehaz = "ms")) + +# check the estimates against the true parameters +for (j in c(1:12)) { + modfrail <- get(paste0("m", j)) + for (i in 1:3) + expect_equal(get_ests(modfrail)[[i]], true[[i]], tol = tols[[i]]) +} + + +#--- previous tests use really weak tolerances to check the +# parameter estimates; therefore the next part conducts a full +# simulation study to test each model specification and uses a +# stronger tolerance, checking that relative bias is less than 5% + +if (run_sims) { + + # number of simulations (for each model specification) + n_sims <- 200 + + # define a function to fit the model to one simulated dataset + sim_run <- function(n = 10, # number of patients per site + K = 30, # number of sites + basehaz = "exp", # basehaz for analysis + dist = "exponential", # basehaz for simulation + delay = FALSE, # induce delayed entry + icens = FALSE, # induce interval censoring + return_relb = FALSE) { + + # simulate data + dat <- make_data(n = n, K = K, dist = dist, delay = delay, icens = icens) + + # define appropriate model formula + if (delay) { + ff <- Surv(start, stop, status) ~ trt + (1 | site) + } else if (icens) { + ff <- Surv(lower, upper, type = "interval2") ~ trt + (1 | site) + } else { + ff <- Surv(eventtime, status) ~ trt + (1 | site) + } + + # fit model + mod <- stan_surv(formula = ff, + data = dat, + basehaz = basehaz, + chains = 1, + refresh = 0, + iter = 2000) + + # true parameters (hard coded here) + true <- c(intercept = log(0.1), + trt = 0.3, + b_sd = 1) + + # extract parameter estimates + ests <- c(intercept = fixef(mod)["(Intercept)"], + trt = fixef(mod)["trt"], + b_sd = attr(VarCorr(mod)[[1]], "stddev")[[1]]) + + # intercept is irrelevant for spline model + if (basehaz %in% c("ms", "bs")) { + true <- true[2:3] + ests <- ests[2:3] + } + + + # check Rhat + rhats <- summary(mod)[, "Rhat"] + rhats <- rhats[!names(rhats) %in% c("lp__", "log-posterior")] + + converged <- (all(rhats <= 1.1, na.rm = TRUE)) + + if (!converged) + ests <- rep(NA, length(ests)) # set estimates to NA if model didn't converge + + if (return_relb) + return(as.vector((ests - true) / true)) + + list(true = true, + ests = ests, + bias = ests - true, + relb = (ests - true) / true) + } + + # functions to summarise the simulations and check relative bias + summarise_sims <- function(x) { + message("Number of simulations that converged: ", + sum(!is.na(do.call(rbind, x["ests",])[,1]))) + rbind(true = colMeans(do.call(rbind, x["true",]), na.rm = TRUE), + ests = colMeans(do.call(rbind, x["ests",]), na.rm = TRUE), + bias = colMeans(do.call(rbind, x["bias",]), na.rm = TRUE), + relb = colMeans(do.call(rbind, x["relb",]), na.rm = TRUE)) + } + + validate_relbias <- function(x, tol = 0.05) { + message("Number of simulations that converged: ", + sum(!is.na(do.call(rbind, x["ests",])[,1]))) + relb <- as.vector(summarise_sims(x)["relb",]) + expect_equal(relb, rep(0, length(relb)), tol = tol) + } + +} + +# right censored models +if (run_sims) { + set.seed(5050) + sims_exp <- replicate(n_sims, sim_run(basehaz = "exp")) + validate_relbias(sims_exp) +} + +if (run_sims) { + set.seed(6060) + sims_weibull <- replicate(n_sims, sim_run(basehaz = "weibull")) + validate_relbias(sims_weibull) +} + +if (run_sims) { + set.seed(7070) + sims_gompertz <- replicate(n_sims, sim_run(basehaz = "gompertz")) + validate_relbias(sims_gompertz) +} + +if (run_sims) { + set.seed(8080) + sims_ms <- replicate(n_sims, sim_run(basehaz = "ms")) + validate_relbias(sims_ms) +} + +# delayed entry models +if (run_sims) { + set.seed(5050) + sims_exp_d <- replicate(n_sims, sim_run(basehaz = "exp", delay = TRUE)) + validate_relbias(sims_exp_d) +} + +if (run_sims) { + set.seed(6060) + sims_weibull_d <- replicate(n_sims, sim_run(basehaz = "weibull", delay = TRUE)) + validate_relbias(sims_weibull_d) +} + +if (run_sims) { + set.seed(7070) + sims_gompertz_d <- replicate(n_sims, sim_run(basehaz = "gompertz", delay = TRUE)) + validate_relbias(sims_gompertz_d) +} + +if (run_sims) { + set.seed(8080) + sims_ms_d <- replicate(n_sims, sim_run(basehaz = "ms", delay = TRUE)) + validate_relbias(sims_ms_d) +} + +# interval censored models +if (run_sims) { + set.seed(5050) + sims_exp_i <- replicate(n_sims, sim_run(basehaz = "exp", icens = TRUE)) + validate_relbias(sims_exp_i) +} + +if (run_sims) { + set.seed(6060) + sims_weibull_i <- replicate(n_sims, sim_run(basehaz = "weibull", icens = TRUE)) + validate_relbias(sims_weibull_i) +} + +if (run_sims) { + set.seed(7070) + sims_gompertz_i <- replicate(n_sims, sim_run(basehaz = "gompertz", icens = TRUE)) + validate_relbias(sims_gompertz_i) +} + +if (run_sims) { + set.seed(8080) + sims_ms_i <- replicate(n_sims, sim_run(basehaz = "ms", icens = TRUE)) + validate_relbias(sims_ms_i) +} + + +# run simulations to check piecewise constant time-varying effects +if (run_sims) { + + # number of simulations (for each model specification) + n_sims <- 250 + + # define a function to fit the model to one simulated dataset + sim_run <- function(N = 600, return_relb = FALSE) { + + # simulate data + covs <- data.frame(id = 1:N, + trt = rbinom(N, 1, 0.5)) + + dat <- simsurv(dist = "exponential", + x = covs, + lambdas = c(0.15), + betas = c(trt = -0.4), + tde = c(trt = 0.8), + tdefun = function(t) as.numeric(t > 4), + maxt = 15) + + dat <- merge(covs, dat) + + # define appropriate model formula + ff <- Surv(eventtime, status) ~ tve(trt, degree = 0, knots = 4) + + # fit model + mod <- stan_surv(formula = ff, + data = dat, + basehaz = "exp", + chains = 1, + refresh = 0, + iter = 1000) + + # true parameters (hard coded here) + true <- c(intercept = log(0.15), + trt = -0.4, + trt_tve = 0.8) + + # extract parameter estimates + ests <- c(intercept = fixef(mod)[1L], + trt = fixef(mod)[2L], + trt_tve = fixef(mod)[3L]) + + # check Rhat + rhats <- summary(mod)[, "Rhat"] + rhats <- rhats[!names(rhats) %in% c("lp__", "log-posterior")] + + converged <- (all(rhats <= 1.1, na.rm = TRUE)) + + if (!converged) + ests <- rep(NA, length(ests)) # set estimates to NA if model didn't converge + + if (return_relb) + return(as.vector((ests - true) / true)) + + list(true = true, + ests = ests, + bias = ests - true, + relb = (ests - true) / true) + } + + # functions to summarise the simulations and check relative bias + summarise_sims <- function(x) { + message("Number of simulations that converged: ", + sum(!is.na(do.call(rbind, x["ests",])[,1]))) + rbind(true = colMeans(do.call(rbind, x["true",])), + ests = colMeans(do.call(rbind, x["ests",])), + bias = colMeans(do.call(rbind, x["bias",])), + relb = colMeans(do.call(rbind, x["relb",]))) + } + + validate_relbias <- function(x, tol = 0.05) { + message("Number of simulations that converged: ", + sum(!is.na(do.call(rbind, x["ests",])[,1]))) + relb <- as.vector(summarise_sims(x)["relb",]) + expect_equal(relb, rep(0, length(relb)), tol = tol) + } + +} + +# tve models +if (run_sims) { + set.seed(5050) + sims_pw <- replicate(n_sims, sim_run()) + validate_relbias(sims_pw) +} diff --git a/vignettes/surv.Rmd b/vignettes/surv.Rmd new file mode 100644 index 000000000..222536623 --- /dev/null +++ b/vignettes/surv.Rmd @@ -0,0 +1,1563 @@ +--- +title: "Estimating Survival (Time-to-Event) Models with rstanarm" +author: "Sam Brilleman" +date: "`r Sys.Date()`" +output: + html_vignette: + toc: true + number_sections: false +params: + EVAL: !r identical(Sys.getenv("NOT_CRAN"), "true") +--- + + + + + +```{r, child="children/SETTINGS-knitr.txt"} +``` +```{r, child="children/SETTINGS-gg.txt"} +``` + +```{r setup_jm, include=FALSE, message=FALSE} +knitr::opts_chunk$set(fig.width=10, fig.height=4) +library(rstanarm) +set.seed(989898) + +CHAINS <- 1 +CORES <- 1 +SEED <- 12345 +ITER <- 1000 +``` + + +# Preamble + +This vignette provides an introduction to the `stan_surv` modelling function in the __rstanarm__ package. The `stan_surv` function allows the user to fit survival models (sometimes known as time-to-event models) under a Bayesian framework. + + +# Introduction + +Survival (or time-to-event) analysis is concerned with the analysis of an outcome variable that corresponds to the time from some defined baseline until an event of interest occurs. The methodology is used in a range of disciplines where it is known by a variety of different names. These include survival analysis (medicine), duration analysis (economics), reliability analysis (engineering), and event history analysis (sociology). Survival analyses are particularly common in health and medical research, where a classic example of survival outcome data is the time from diagnosis of a disease until the occurrence of death. + +In standard survival analysis, one event time is measured for each observational unit. In practice however that event time may be unobserved due to left, right, or interval censoring, in which case the event time is only known to have occurred within the relevant censoring interval. The combined aspects of time and censoring make survival analysis methodology distinct from many other regression modelling approaches. + +There are two common approaches to modelling survival data. The first is to model the instantaneous rate of the event (known as the hazard) as a function of time. This includes the class of models known as proportional and non-proportional hazards regression models. The second is to model the event time itself. This includes the class of models known as accelerated failure time (AFT) models. Under both of these modelling frameworks a number of extensions have been proposed. For instance the handling of recurrent events, competing events, clustered survival data, cure models, and more. More recently, methods for modelling both longitudinal (e.g. a repeatedly measured biomarker) and survival data have become increasingly popular (as described in the `stan_jm` [vignette](priors.html)). + +This vignette is structured as follows. In the next sections we describe the modelling, estimation, and prediction frameworks underpinning survival models in __rstanarm__. Following that we describe the implementation and arguments for the `stan_surv` modelling function. Following that we demonstrate usage of the package through a series of examples. + +# Modelling framework + +## Data and notation + +We assume that a true event time for individual $i$ ($i = 1,...,N$) exists and can be denoted $T_i^*$. However, in practice $T_i^*$ may not be observed due to left, right, or interval censoring. We therefore observe outcome data $\mathcal{D}_i = \{T_i, T_i^U, T_i^E, d_i\}$ for individual $i$ where: + +- $T_i$ denotes the observed event or censoring time +- $T_i^U$ denotes the observed upper limit for interval censored individuals +- $T_i^E$ denotes the observed entry time (i.e. the time at which an individual became at risk for the event); and +- $d_i \in \{0,1,2,3\}$ denotes an event indicator taking value 0 if individual $i$ was right censored (i.e. $T_i^* > T_i$), value 1 if individual $i$ was uncensored (i.e. $T_i^* = T_i$), value 2 if individual $i$ was left censored (i.e. $T_i^* < T_i$), or value 3 if individual $i$ was interval censored (i.e. $T_i < T_i^* < T_i^U$). + +### Hazard, cumulative hazard, and survival + +There are three key quantities of interest in standard survival analysis: the hazard rate, the cumulative hazard, and the survival probability. It is these quantities that are used to form the likelihood function for the survival models described in later sections. + +The hazard is the instantaneous rate of occurrence for the event at time $t$. Mathematically, it is defined as: +\ +\begin{equation} +\begin{split} +h_i(t) = \lim_{\Delta t \to 0} + \frac{P(t \leq T_i^* < t + \Delta t | T_i^* > t)}{\Delta t} +\end{split} +\end{equation} +\ +where $\Delta t$ is the width of some small time interval. + +The numerator is the conditional probability of the individual experiencing the event during the time interval $[t, t + \Delta t)$, given that they were still at risk of the event at time $t$. The denominator converts the conditional probability to a rate per unit of time. As $\Delta t$ approaches the limit, the width of the interval approaches zero and the instantaneous event rate is obtained. + +The cumulative hazard is defined as: +\ +\begin{equation} +\begin{split} +H_i(t) = \int_{s=0}^t h_i(s) ds +\end{split} +\end{equation} +\ +and the survival probability is defined as: +\ +\begin{equation} +\begin{split} +S_i(t) = \exp \left[ -H_i(t) \right] = \exp \left[ -\int_{s=0}^t h_i(s) ds \right] +\end{split} +\end{equation} + +It can be seen here that in the standard survival analysis setting -- where there is one event type of interest (i.e. no competing events) -- there is a one-to-one relationship between each of the hazard, the cumulative hazard, and the survival probability. + +### Delayed entry + +Delayed entry, also known as left truncation, occurs when an individual is not at risk of the event until some time $t > 0$. As previously described we use $T_i^E$ to denote the entry time at which the individual becomes at risk. A common situation where delayed entry occurs is when age is used as the time scale. With age as the time scale it is likely that our study will only be concerned with the observation of individuals starting from some time (i.e. age) $t > 0$. + +To allow for delayed entry we essentially want to work with a conditional survival probability: +\ +\begin{equation} +\begin{split} +S_i \left(t \mid T_i^E > 0 \right) = \frac{S_i(t)}{S_i \left( T_i^E \right)} +\end{split} +\end{equation} + +Here the survival probability is evaluated conditional on the individual having survived up to the entry time. This conditional survival probability is used to allow for delayed entry in the log likelihood of our survival model. + +## Model formulations + +Our modelling approaches are twofold. First, we define a class of models on the hazard scale. This includes both proportional and non-proportional hazard regression models. Second, we define a class of models on the scale of the survival time. These are often known as accelerated failure time (AFT) models and can include both time-fixed and time-varying acceleration factors. + +These two classes of models and their respective features are described in the following sections. + +### Hazard scale models + +Under a hazard scale formulation, we model the hazard of the event for individual $i$ at time $t$ using the regression model: +/ +\begin{equation} +\begin{split} +h_i(t) = h_0(t) \exp \left( \eta_i(t) \right) +\end{split} +\end{equation} +/ +where $h_0(t)$ is the baseline hazard (i.e. the hazard for an individual with all covariates set equal to zero) at time $t$, and $\eta_i(t)$ denotes the linear predictor evaluated for individual $i$ at time $t$. + +For full generality we allow the linear predictor to be time-varying. That is, it may be a function of time-varying covariates and/or time-varying coefficients (e.g. a time-varying hazard ratio). However, if there are no time-varying covariates or time-varying coefficients in the model, then the linear predictor reduces to a time-fixed quantity and the definition of the hazard function reduces to: +/ +\begin{equation} +\begin{split} +h_i(t) = h_0(t) \exp \left( \eta_i \right) +\end{split} +\end{equation} +/ +where the linear predictor $\eta_i$ is no longer a function of time. We describe the linear predictor in detail in later sections. + +Different distributional assumptions can be made for the baseline hazard $h_0(t)$ and affect how the baseline hazard changes as a function of time. The __rstanarm__ package currently accommodates several standard parametric distributions for the baseline hazard (exponential, Weibull, Gompertz) as well as more flexible approaches that directly model the baseline hazard as a piecewise or smooth function of time using splines. + +The following describes the baseline hazards that are currently implemented in the __rstanarm__ package. + +#### M-splines model (the default): + +Let $M_{l}(t; \boldsymbol{k}, \delta)$ denote the $l^{\text{th}}$ $(l = 1,...,L)$ basis term for a degree $\delta$ M-spline function evaluated at a vector of knot locations $\boldsymbol{k} = \{k_{1},...,k_{J}\}$, and $\gamma_{l}$ denote the $l^{\text{th}}$ M-spline coefficient. We then have: +/ +\begin{equation} +h_i(t) = \sum_{l=1}^{L} \gamma_{l} M_{l}(t; \boldsymbol{k}, \delta) \exp ( \eta_i(t) ) +\end{equation} + +The M-spline basis is evaluated using the method described in \cite{Ramsay:1988} and implemented in the __splines2__ package \citep{Wang:2018}. + +To ensure that the hazard function $h_i(t)$ is not constrained to zero at the origin (i.e. when $t$ approaches 0) the M-spline basis incorporates an intercept. To ensure identifiability of both the M-spline coefficients and the intercept in the linear predictor we constrain the M-spline coefficients to a simplex, that is, $\sum_{l=1}^L{\gamma_l} = 1$. + +The default degree in __rstanarm__ is $\delta = 3$ (i.e. cubic M-splines) such that the baseline hazard can be modelled as a flexible and smooth function of time, however this can be changed by the user. It is worthwhile noting that setting $\delta = 0$ is treated as a special case that corresponds to a piecewise constant baseline hazard. + +#### Exponential model: + +For scale parameter $\lambda_i(t) = \exp ( \eta_i(t) )$ we have: +/ +\begin{equation} +h_i(t) = \lambda_i(t) +\end{equation} + +In the case where the linear predictor is not time-varying, the exponential model leads to a hazard rate that is constant over time. + +#### Weibull model: + +For scale parameter $\lambda_i(t) = \exp ( \eta_i(t) )$ and shape parameter $\gamma > 0$ we have: +/ +\begin{equation} +h_i(t) = \gamma t^{\gamma-1} \lambda_i(t) +\end{equation} + +In the case where the linear predictor is not time-varying, the Weibull model leads to a hazard rate that is monotonically increasing or monotonically decreasing over time. In the special case where $\gamma = 1$ it reduces to the exponential model. + +#### Gompertz model: + +For shape parameter $\lambda_i(t) = \exp ( \eta_i(t) )$ and scale parameter $\gamma > 0$ we have: +/ +\begin{equation} +h_i(t) = \exp(\gamma t) \lambda_i(t) +\end{equation} + +#### B-splines model (for the log baseline hazard): + +Let $B_{l}(t; \boldsymbol{k}, \delta)$ denote the $l^{\text{th}}$ $(l = 1,...,L)$ basis term for a degree $\delta$ B-spline function evaluated at a vector of knot locations $\boldsymbol{k} = \{k_{1},...,k_{J}\}$, and $\gamma_{l}$ denote the $l^{\text{th}}$ B-spline coefficient. We then have: +/ +\begin{equation} +h_i(t) = \exp \left( \sum_{l=1}^{L} \gamma_{l} B_{l}(t; \boldsymbol{k}, \delta) + \eta_i(t) \right) +\end{equation} +/ +The B-spline basis is calculated using the method implemented in the __splines2__ package \citep{Wang:2018}. The B-spline basis does not require an intercept and therefore does not include one; any constant shift in the log hazard is fully captured via an intercept in the linear predictor. By default cubic B-splines are used (i.e. $\delta = 3$) and these allow the log baseline hazard to be modelled as a smooth function of time. + +### Accelerated failure time (AFT) models + +Under an AFT formulation we model the survival probability for individual $i$ at time $t$ using the regression model \citep{Hougaard:1999}: +/ +\begin{equation} \label{eq:aftform-surv} +\begin{split} +S_i(t) = S_0 \left( \int_{u=0}^t \exp \left( - \eta_i(u) \right) du \right) +\end{split} +\end{equation} +/ +where $S_0(t)$ is the baseline survival probability at time $t$, and $\eta_i(t)$ denotes the linear predictor evaluated for individual $i$ at time $t$. For full generality we again allow the linear predictor to be time-varying. This also leads to a corresponding general expression for the hazard function \citep{Hougaard:1999} as follows: +/ +\begin{align} \label{eq:aftform-haz} +\begin{split} +h_i(t) = \exp \left(-\eta_i(t) \right) h_0 \left( \int_{u=0}^t \exp \left( - \eta_i(u) \right) du \right) +\end{split} +\end{align} + +If there are no time-varying covariates or time-varying coefficients in the model, then the definition of the survival probability reduces to: +/ +\begin{equation} +\begin{split} +S_i(t) = S_0 \left( t \exp \left( - \eta_i \right) \right) +\end{split} +\end{equation} +/ +and for the hazard: +/ +\begin{align} +\begin{split} +h_i(t) = \exp \left( -\eta_i \right) h_0 \left( t \exp \left( - \eta_i \right) \right) +\end{split} +\end{align} + +Different distributional assumptions can be made for how the baseline survival probability $S_0(t)$ changes as a function of time. The __rstanarm__ package currently accommodates two standard parametric distributions (exponential, Weibull) although others may be added in the future. The current distributions are implemented as follows. + +#### Exponential model: + +When the linear predictor is time-varying we have: +/ +\begin{equation} +S_i(t) = \exp \left( - \int_{u=0}^t \exp ( -\eta_i(u) ) du \right) +\end{equation} +/ +and when the linear predictor is time-fixed we have: +/ +\begin{equation} +S_i(t) = \exp \left( - t \lambda_i \right) +\end{equation} +/ +for scale parameter $\lambda_i = \exp ( -\eta_i )$. + +#### Weibull model: + +When the linear predictor is time-varying we have: +/ +\begin{equation} +S_i(t) = \exp \left( - \left( \int_{u=0}^t \exp ( -\eta_i(u) ) du \right)^{\gamma} \right) +\end{equation} +/ +for shape parameter $\gamma > 0$ and when the linear predictor is time-fixed we have: +/ +\begin{equation} +S_i(t) = \exp \left( - t^{\gamma} \lambda_i \right) +\end{equation} +/ +for scale parameter $\lambda_i = \exp ( -\gamma \eta_i )$ and shape parameter $\gamma > 0$. + +## Linear predictor + +Under all of the previous model formulations our linear predictor can be defined as: +/ +\begin{equation} \label{eq:eta} +\begin{split} +\eta_i(t) = \boldsymbol{\beta}^T(t) \boldsymbol{X}_i(t) +\end{split} +\end{equation} +/ +where $\boldsymbol{X}_i(t) = [1, x_{i1}(t), ..., x_{iP}(t) ]$ denotes a vector of covariates with $x_{ip}(t)$ denoting the observed value of $p^{th}$ $(p = 1,...,P)$ covariate for the $i^{th}$ $(i=1,...,N)$ individual at time $t$, and $\boldsymbol{\beta}(t) = [ \beta_0, \beta_1(t), ... , \beta_P(t) ]$ denotes a vector of parameters with $\beta_0$ denoting an intercept parameter and $\beta_p(t)$ denoting the possibly time-varying coefficient for the $p^{th}$ covariate. + +### Hazard ratios + +Under a hazard scale formulation the quantity $\exp \left( \beta_p(t) \right)$ is referred to as a \textif{hazard ratio}. + +The hazard ratio quantifies the relative increase in the hazard that is associated with a unit-increase in the relevant covariate, $x_{ip}$, assuming that all other covariates in the model are held constant. For instance, a hazard ratio of 2 means that a unit-increase in the covariate leads to a doubling in the hazard (i.e. the instantaneous rate) of the event. + +### Acceleration factors and survival time ratios + +Under an AFT formulation the quantity $\exp \left( - \beta_p(t) \right)$ is referred to as an \textif{acceleration factor} and the quantity $\exp \left( \beta_p(t) \right)$ is referred to as a \textif{survival time ratio}. + +The acceleration factor quantifies the acceleration (or deceleration) of the event process that is associated with a unit-increase in the relevant covariate, $x_{ip}$. For instance, an acceleration factor of 0.5 means that a unit-increase in the covariate corresponds to approaching the event at half the speed. + +The survival time ratio is interpreted as the increase (or decrease) in the expected survival time that is associated with a unit-increase in the relevant covariate, $x_{ip}$. For instance, a survival time ratio of 2 (which is equivalent to an acceleration factor of 0.5) means that a unit-increase in the covariate leads to an doubling in the expected survival time. + +Note that the survival time ratio is a simple reparameterisation of the acceleration factor. Specifically, the survival time ratio is equal to the reciprocal of the acceleration factor. The survival time ratio and the acceleration factor therefore provide alternative interpretations for the same effect of the same covariate. + +### Time-fixed vs time-varying effects + +Under either a hazard scale or AFT formulation the coefficient $\beta_p(t)$ can be treated as a time-fixed or time-varying quantity. + +When $\beta_p(t)$ is treated as a time-fixed quantity we have: +/ +\begin{equation} +\begin{split} +\beta_p(t) = \theta_{p0} +\end{split} +\end{equation} +/ +such that $\theta_{p0}$ is a time-fixed log hazard ratio (or log survival time ratio). On the hazard scale this is equivalent to assuming proportional hazards, whilst on the AFT scale it is equivalent to assuming a time-fixed acceleration factor. + +When $\beta_p(t)$ is treated as a time-varying quantity we refer to it as a time-varying effect because the effect of the covariate is allowed to change as a function of time. On the hazard scale this leads to non-proportional hazards, whilst on the AFT scale it leads to time-varying acceleration factors. + +When $\beta_p(t)$ is time-varying we must determine how we wish to model it. In __rstanarm__ the default is to use B-splines such that: +/ +\begin{equation} +\begin{split} +\beta_p(t) = \theta_{p0} + \sum_{l=1}^{L} \theta_{pl} B_{l}(t; \boldsymbol{k}, \delta) +\end{split} +\end{equation} +/ +where $\theta_{p0}$ is a constant, $B_{l}(t; \boldsymbol{k}, \delta)$ is the $l^{\text{th}}$ $(l = 1,...,L)$ basis term for a degree $\delta$ B-spline function evaluated at a vector of knot locations $\boldsymbol{k} = \{k_{1},...,k_{J}\}$, and $\theta_{pl}$ is the $l^{\text{th}}$ B-spline coefficient. By default cubic B-splines are used (i.e. $\delta = 3$). These allow the log hazard ratio (or log survival time ratio) to be modelled as a smooth function of time. + +However an alternative is to model $\beta_p(t)$ using a piecewise constant function: +/ +\begin{equation} +\begin{split} +\beta_p(t) = \theta_{p0} + \sum_{l=1}^{L} \theta_{pl} I(k_{l+1} < t \leq k_{l+2}) +\end{split} +\end{equation} +/ +where $I(x)$ is an indicator function taking value 1 if $x$ is true and 0 otherwise, $\theta_{p0}$ is a constant corresponding to the log hazard ratio (or log survival time ratio for AFT models) in the first time interval, $\theta_{pl}$ is the deviation in the log hazard ratio (or log survival time ratio) between the first and $(l+1)^\text{th}$ $(l = 1,...,L)$ time interval, and $\boldsymbol{k} = \{k_{1},...,k_{J}\}$ is a sequence of knot locations (i.e. break points) that includes the lower and upper boundary knots. This allows the log hazard ratio (or log survival time ratio) to be modelled as a piecewise constant function of time. + +Note that we have dropped the subscript $p$ from the knot locations $\boldsymbol{k}$ and degree $\delta$ discussed above. This is just for simplicity of the notation. In fact, if a model has a time-varying effect estimated for more than one covariate, then each of these can be modelled using different knot locations and/or degree if the user desires. These knot locations and/or degree can also differ from those used for modelling the baseline or log baseline hazard described previously in Section \ref{sec:modelformulations}. + +### Relationship between proportional hazards and AFT models + +As shown in Section \ref{sec:modelformulations} some baseline distributions can be parameterised as either a proportional hazards or an AFT model. In __rstanarm__ this currently includes the exponential and Weibull models. One can therefore transform the estimates from an exponential or Weibull proportional hazards model to get the estimates that would be obtained under an exponential or Weibull AFT parameterisation. + +Specifically, the following relationship applies for the exponential model: +/ +\begin{equation} +\begin{split} +\beta_0 & = - \beta_0^* \\ +\beta_p & = - \beta_p^* +\end{split} +\end{equation} +/ +and for the Weibull model: +/ +\begin{equation} +\begin{split} +\beta_0 & = -\gamma \beta_0^* \\ +\beta_p & = -\gamma \beta_p^* +\end{split} +\end{equation} +/ +where the unstarred parameters are from the proportional hazards model and the starred ($*$) parameters are from the AFT model. Note however that these relationships only hold in the absence of time-varying effects. This is demonstrated using a real dataset in the example in Section \ref{sec:aftmodel}. + +## Multilevel survival models + +The definition of the linear predictor in Equation \ref{eq:eta} can be extended to allow for shared frailty or other clustering effects. + +Suppose that the individuals in our sample belong to a series of clusters. The clusters may represent for instance hospitals, families, or GP clinics. We denote the $i^{th}$ individual ($i = 1,...,N_j$) as a member of the $j^{th}$ cluster ($j = 1,...,J$). Moreover, to indicate the fact that individual $i$ is now a member of cluster $j$ we index the observed data (i.e. event times, event indicator, and covariates) with a subscript $j$, that is $T_{ij}^*$, $\mathcal{D}_{ij} = \{T_{ij}, T_{ij}^U, T_{ij}^E, d_{ij}\}$ and $X_{ij}(t)$, as well as estimated quantities such as the hazard rate, cumulative hazard, survival probability, and linear predictor, that is $h_{ij}(t)$, $H_{ij}(t)$, $S_{ij}(t)$, and $\eta_{ij}(t)$. + +To allow for intra-cluster correlation in the event times we include cluster-specific random effects in the linear predictor as follows: +/ +\begin{equation} \label{eq:multileveleta} +\begin{split} +\eta_{ij}(t) = \boldsymbol{\beta}^T \boldsymbol{X}_{ij}(t) + \boldsymbol{b}_{j}^T \boldsymbol{Z}_{ij} +\end{split} +\end{equation} +/ +where $\boldsymbol{Z}_{ij}$ denotes a vector of covariates for the $i^{th}$ individual in the $j^{th}$ cluster, with an associated vector of cluster-specific parameters $\boldsymbol{b}_{j}$. We assume that the cluster-specific parameters are normally distributed such that $\boldsymbol{b}_{j} \sim N(0, \boldsymbol{\Sigma}}_{b})$ for some variance-covariance matrix $\boldsymbol{\Sigma}}_{b}$. We assume that $\boldsymbol{\Sigma}}_{b}$ is unstructured, that is each variance and covariance term is allowed to be different. + +In most cases $\boldsymbol{b}_{j}$ will correspond to just a cluster-specific random intercept (often known as a "shared frailty" term) but more complex random effects structures are possible. + +For simplicitly of notation Equation \ref{eq:multileveleta} also assumes just one clustering factor in the model (indexed by $j = 1,...,J$). However, it is possible to extend the model to multiple clustering factors (in __rstanarm__ there is no limit to the number of clustering factors that can be included). For example, suppose that the $i^{th}$ individual was clustered within the $j^{th}$ hospital that was clustered within the $k^{th}$ geographical region. Then we would have hospital-specific random effects $\boldsymbol{b}_j \sim N(0, \boldsymbol{\Sigma}}_{b})$ and region-specific random effects $\boldsymbol{u}_k \sim N(0, \boldsymbol{\Sigma}}_{u})$ and assume $\boldsymbol{b}_j$ and $\boldsymbol{u}_k$ are independent for all $(j,k)$. + +# Estimation framework + +## Log posterior + +The log posterior for the $i^{th}$ individual in the $j^{th}$ cluster can be specified as: +/ +\begin{equation} +\begin{split} +\log p(\boldsymbol{\theta}, \boldsymbol{b}_{j} \mid \mathcal{D}_{ij}) + \propto + \log p(\mathcal{D}_{ij} \mid \boldsymbol{\theta}, \boldsymbol{b}_{j}) + + \log p(\boldsymbol{b}_{j} \mid \boldsymbol{\theta}) + + \log p(\boldsymbol{\theta}) +\end{split} +\end{equation} +/ +where $\log p(\mathcal{D}_{ij} \mid \boldsymbol{\theta}, \boldsymbol{b}_{j})$ is the log likelihood for the outcome data, $\log p(\boldsymbol{b}_{j} \mid \boldsymbol{\theta})$ is the log likelihood for the distribution of any cluster-specific parameters (i.e. random effects) when relevant, and $\log p(\boldsymbol{\theta})$ represents the log likelihood for the joint prior distribution across all remaining unknown parameters. + +## Log likelihood + +Allowing for the three forms of censoring (left, right, and interval censoring) and potential delayed entry (i.e. left truncation) the log likelihood for the survival model takes the form: +/ +\begin{equation} \label{eq:loglik} +\begin{split} +\log p(\mathcal{D}_{ij} \mid \boldsymbol{\theta}, \boldsymbol{b}_{j}) + & = {I(d_{ij} = 0)} \times \log \left[ S_{ij}(T_{ij}) \right] \\ + & \quad + {I(d_{ij} = 1)} \times \log \left[ h_{ij}(T_{ij}) \right] \\ + & \quad + {I(d_{ij} = 1)} \times \log \left[ S_{ij}(T_{ij}) \right] \\ + & \quad + {I(d_{ij} = 2)} \times \log \left[ 1 - S_{ij}(T_{ij}) \right] \\ + & \quad + {I(d_{ij} = 3)} \times \log \left[ S_{ij}(T_{ij}) - S_{ij}(T_{ij}^U) \right] \\ + & \quad - \log \left[ S_{ij} ( T_{ij}^E ) \right] +\end{split} +\end{equation} +/ +where $I(x)$ is an indicator function taking value 1 if $x$ is true and 0 otherwise. That is, each individual's contribution to the likelihood depends on the type of censoring for their event time. + +The last term on the right hand side of Equation \ref{eq:loglik} accounts for delayed entry. When an individual is at risk from time zero (i.e. no delayed entry) then $T_{ij}^E = 0$ and $S_{ij}(0) = 1$ meaning that the last term disappears from the likelihood. + +### Evaluating integrals in the log likelihood + +When the linear predictor is time-fixed there is a closed form expression for both the hazard rate and survival probability in almost all cases (the single exception is when B-splines are used to model the log baseline hazard). When there is a closed form expression for both the hazard rate and survival probability then there is also a closed form expression for the (log) likelihood function. The details of these expressions are given in Appendix \ref{app:haz-parameterisations} (for hazard models) and Appendix \ref{app:aft-parameterisations} (for AFT models). + +However, when the linear predictor is time-varying there isn't a closed form expression for the survival probability. Instead, Gauss-Kronrod quadrature with $Q$ nodes is used to approximate the necessary integrals. + +For hazard scale models Gauss-Kronrod quadrature is used to evaluate the cumulative hazard, which in turn is used to evaluate the survival probability. Expanding on Equation \ref{eq:survdef} we have: +/ +\begin{equation} +\begin{split} +\int_{u=0}^{T_{ij}} h_{ij}(u) du + \approx \frac{T_{ij}}{2} \sum_{q=1}^{Q} w_q h_{ij} \left( \frac{T_{ij}(1 + v_q)}{2} \right) +\end{split} +\end{equation} +/ +where $w_q$ and $v_q$, respectively, are the standardised weights and locations ("abscissa") for quadrature node $q$ $(q = 1,...,Q)$ \citep{Laurie:1997}. + +For AFT models Gauss-Kronrod quadrature is used to evaluate the cumulative acceleration factor, which in turn is used to evaluate both the survival probability and the hazard rate. Expanding on Equations \ref{eq:aftform-surv} and \ref{eq:aftform-haz} we have: +/ +\begin{equation} +\begin{split} +\int_{u=0}^{T_{ij}} \exp \left( - \eta_{ij}(u) \right) du + \approx \frac{T_{ij}}{2} \sum_{q=1}^{Q} w_q \exp \left( - \eta_{ij} \left( \frac{T_{ij}(1 + v_q)}{2} \right) \right) +\end{split} +\end{equation} + +When quadrature is necessary, the default in __rstanarm__ is to use $Q = 15$ nodes. But the number of nodes can be changed by the user. + +## Prior distributions + +For each of the parameters a number of prior distributions are available. Default choices exist, but the user can explicitly specify the priors if they wish. + +### Intercept + +All models include an intercept parameter in the linear predictor ($\beta_0$) which effectively forms part of the baseline hazard. Choices of prior distribution for $\beta_0$ include the normal, t, or Cauchy distributions. The default is a normal distribution with mean 0 and standard deviation of 20. + +However it is worth noting that -- internally (but not in the reported parameter estimates) -- the prior is placed on the intercept after centering the predictors at their sample means and after applying a constant shift of $\log \left( \frac{E}{T} \right)$ where $E$ is the total number of events and $T$ is the total follow up time. For instance, the default prior is not centered on an intercept of zero when all predictors are at their sample means, but rather, it is centered on the log crude event rate when all predictors are at their sample means. This is intended to help with numerical stability and sampling, but does not impact on the reported estimates (i.e. the intercept is back-transformed before being returned to the user). + +\subsubsection{Regression coefficients} + +Choices of prior distribution for the time-fixed regression coefficients $\theta_{p0}$ ($p = 1,...,P$) include normal, t, and Cauchy distributions as well as several shrinkage prior distributions. + +Where relevant, the additional coefficients required for estimating a time-varying effect (i.e. the B-spline coefficients or the interval-specific deviations in the piecewise constant function) are given a random walk prior of the form $\theta_{p,1} \sim N(0,1)$ and $\theta_{p,m} \sim N(\theta_{p,m-1},\tau_p)$ for $m = 2,...,M$, where $M$ is the total number of cubic B-spline basis terms. The prior distribution for the hyperparameter $\tau_p$ can be specified by the user and choices include an exponential, half-normal, half-t, or half-Cauchy distribution. Note that lower values of $\tau_p$ lead to a less flexible (i.e. smoother) function for modelling the time-varying effect. + +\subsubsection{Auxiliary parameters} + +There are several choices of prior distribution for the so-called "auxiliary" parameters related to the baseline hazard (i.e. scalar $\gamma$ for the Weibull and Gompertz models or vector $\boldsymbol{\gamma}$ for the M-spline and B-spline models). These include: + +\begin{itemize} + +\item a Dirichlet prior distribution for the baseline hazard M-spline coefficients $\boldsymbol{\gamma}$; + +\item a half-normal, half-t, half-Cauchy or exponential prior distribution for the Weibull shape parameter $\gamma$; + +\item a half-normal, half-t, half-Cauchy or exponential prior distribution for the Gompertz scale parameter $\gamma$; and + +\item a normal, t, or Cauchy prior distribution for the log baseline hazard B-spline coefficients $\boldsymbol{\gamma}$. + +\end{itemize} + +\subsubsection{Covariance matrices} + +When a multilevel survival model is estimated there is an unstructured covariance matrix estimated for the random effects. Of course, in the situation where there is just one random effect in the model `formula` (e.g. a random intercept or "shared frailty" term) the covariance matrix will reduce to just a single element; i.e. it will be a scalar equal to the variance of the single random effect in the model. + +The prior distribution is based on a decomposition of the covariance matrix. The decomposition takes place as follows. The covariance matrix $\boldsymbol{\Sigma}_b$ is decomposed into a correlation matrix $\boldsymbol{\Omega}$ and vector of variances. The vector of variances is then further decomposed into a simplex $\pi$ (i.e. a probability vector summing to 1) and a scalar equal to the sum of the variances. Lastly, the sum of the variances is set equal to the order of the covariance matrix multiplied by the square of a scale parameter (here we denote that scale parameter $\tau$). + +The prior distribution for the correlation matrix $\boldsymbol{\Omega}$ is the LKJ distribution \citep{Lewandowski:2009}. It is parameterised through a regularisation parameter $\zeta > 0$. The default is $\zeta = 1$ such that the LKJ prior distribution is jointly uniform over all possible correlation matrices. When $\zeta > 1$ the mode of the LKJ distribution is the identity matrix and as $\zeta$ increases the distribution becomes more sharply peaked at the mode. When $0 < \zeta < 1$ the prior has a trough at the identity matrix. + +The prior distribution for the simplex $\boldsymbol{\pi}$ is a symmetric Dirichlet distribution with a single concentration parameter $\phi > 0$. The default is $\phi = 1$ such that the prior is jointly uniform over all possible simplexes. If $\phi > 1$ then the prior mode corresponds to all entries of the simplex being equal (i.e. equal variances for the random effects) and the larger the value of $\phi$ then the more pronounced the mode of the prior. If $0 < \phi < 1$ then the variances are polarised. + +The prior distribution for the scale parameter $\tau$ is a Gamma distribution. The shape and scale parameter for the Gamma distribution are both set equal to 1 by default, however the user can change the value of the shape parameter. The behaviour is such that increasing the shape parameter will help enforce that the trace of $\boldsymbol{\Sigma}_b$ (i.e. sum of the variances of the random effects) be non-zero. + +Further details on this implied prior for covariance matrices can be found in the __rstanarm__ documentation and vignettes. + +## Estimation + +Estimation in __rstanarm__ is based on either full Bayesian inference (Hamiltonian Monte Carlo) or approximate Bayesian inference (either mean-field or full-rank variational inference). The default is full Bayesian inference, but the user can change this if they wish. The approximate Bayesian inference algorithms are much faster, but they only provide approximations for the joint posterior distribution and are therefore not recommended for final inference. + +Hamiltonian Monte Carlo is a form of Markov chain Monte Carlo (MCMC) in which information about the gradient of the log posterior is used to more efficiently sample from the posterior space. Stan uses a specific implementation of Hamiltonian Monte Carlo known as the No-U-Turn Sampler (NUTS) \citep{Hoffman:2014}. A benefit of NUTS is that the tuning parameters are handled automatically during a "warm-up" phase of the estimation. However the __rstanarm__ modelling functions provide arguments that allow the user to retain control over aspects such as the number of MCMC chains, number of warm-up and sampling iterations, and number of computing cores used. + +# Prediction framework + +## Survival predictions without clustering + +If our survival model does not contain any clustering effects (i.e. it is not a multilevel survival model) then our prediction framework is more straightforward. Let $\mathcal{D} = \{ \mathcal{D}_{i}; i = 1,...,N \}$ denote the entire collection of outcome data in our sample and let $T_{\max} = \max \{ T_{i}, T_{i}^U, T_{i}^E; i = 1,...,N \}$ denote the maximum event or censoring time across all individuals in our sample. + +Suppose that for some individual $i^*$ (who may or may not have been in our sample) we have covariate vector $\boldsymbol{x}_{i^*}$. Note that the covariate data must be time-fixed. The predicted probability of being event-free at time $0 < t \leq T_{\max}$, denoted $\hat{S}_{i^*}(t)$, can be generated from the posterior predictive distribution: +/ +\begin{equation} +\begin{split} +p \Big( \hat{S}_{i^*}(t) \mid \boldsymbol{x}_{i^*}, \mathcal{D} \Big) = + \int + p \Big( \hat{S}_{i^*}(t) \mid \boldsymbol{x}_{i^*}, \boldsymbol{\theta} \Big) + p \Big( \boldsymbol{\theta} \mid \mathcal{D} \Big) + d \boldsymbol{\theta} +\end{split} +\end{equation} + +We approximate this posterior predictive distribution by drawing from $p(\hat{S}_{i^*}(t) \mid \boldsymbol{x}_{i^*}, \boldsymbol{\theta}^{(l)})$ where $\boldsymbol{\theta}^{(l)}$ is the $l^{th}$ $(l = 1,...,L)$ MCMC draw from the posterior distribution $p(\boldsymbol{\theta} \mid \mathcal{D})$. + +## Survival predictions with clustering + +When there are clustering effects in the model (i.e. multilevel survival models) then our prediction framework requires conditioning on the cluster-specific parameters. Let $\mathcal{D} = \{ \mathcal{D}_{ij}; i = 1,...,N_j, j = 1,...,J \}$ denote the entire collection of outcome data in our sample and let $T_{\max} = \max \{ T_{ij}, T_{ij}^U, T_{ij}^E; i = 1,...,N_j, j = 1,...,J \}$ denote the maximum event or censoring time across all individuals in our sample. + +Suppose that for some individual $i^*$ (who may or may not have been in our sample) and who is known to come from cluster $j^*$ (which may or may not have been in our sample) we have covariate vectors $\boldsymbol{x}_{i^*j^*}$ and $\boldsymbol{z}_{i^*j^*}$. Note again that the covariate data is assumed to be time-fixed. + +If individual $i^*$ does in fact come from a cluster $j^* = j$ (for some $j \in \{1,...,J\}$) in our sample then the predicted probability of being event-free at time $0 < t \leq T_{\max}$, denoted $S_{i^*j}(t)$, can be generated from the posterior predictive distribution: +/ +\begin{equation} +\begin{split} +p \Big( \hat{S}_{i^*j}(t) \mid \boldsymbol{x}_{i^*j}, \boldsymbol{z}_{i^*j}, \mathcal{D} \Big) = + \int + \int + p \Big( \hat{S}_{i^*j}(t) \mid \boldsymbol{x}_{i^*j}, \boldsymbol{z}_{i^*j}, \boldsymbol{\theta}, \boldsymbol{b}_j \Big) + p \Big( \boldsymbol{\theta}, \boldsymbol{b}_j \mid \mathcal{D} \Big) + d \boldsymbol{b}_j \space d \boldsymbol{\theta} +\end{split} +\end{equation} + +Since cluster $j$ was included in our sample data it is easy for us to approximate this posterior predictive distribution by drawing from $p(\hat{S}_{i^*j}(t) \mid \boldsymbol{x}_{i^*j}, \boldsymbol{z}_{i^*j}, \boldsymbol{\theta}^{(l)}, \boldsymbol{b}_j^{(l)})$ where $\boldsymbol{\theta}^{(l)}$ and $\boldsymbol{b}_j^{(l)}$ are the $l^{th}$ $(l = 1,...,L)$ MCMC draws from the joint posterior distribution $p(\boldsymbol{\theta}, \boldsymbol{b}_j \mid \mathcal{D})$. + +Alternatively, individual $i^*$ may come from a new cluster $j^* \neq j$ (for all $j \in \{1,...,J\}$) that was not in our sample. The predicted probability of being event-free at time $0 < t \leq T_{\max}$ is therefore denoted $\hat{S}_{i^*j^*}(t)$ and can be generated from the posterior predictive distribution: +/ +\begin{equation} +\begin{aligned} +p \Big( \hat{S}_{i^*j^*}(t) \mid \boldsymbol{x}_{i^*j^*}, \boldsymbol{z}_{i^*j^*}, \mathcal{D} \Big) +& = + \int + \int + p \Big( \hat{S}_{i^*j^*}(t) \mid \boldsymbol{x}_{i^*j^*}, \boldsymbol{z}_{i^*j^*}, \boldsymbol{\theta}, \boldsymbol{\tilde{b}}_{j^*} \Big) + p \Big( \boldsymbol{\theta}, \boldsymbol{\tilde{b}}_{j^*} \mid \mathcal{D} \Big) + d \boldsymbol{\tilde{b}}_{j^*} \space d \boldsymbol{\theta} \\ +& = + \int + \int + p \Big( \hat{S}_{i^*j^*}(t) \mid \boldsymbol{x}_{i^*j^*}, \boldsymbol{z}_{i^*j^*}, \boldsymbol{\theta}, \boldsymbol{\tilde{b}}_{j^*} \Big) + p \Big( \boldsymbol{\tilde{b}}_{j^*} \mid \boldsymbol{\theta} \Big) + p \Big( \boldsymbol{\theta} \mid \mathcal{D} \Big) + d \boldsymbol{\tilde{b}}_{j^*} \space d \boldsymbol{\theta} \\ +\end{aligned} +\end{equation} +/ +where $\boldsymbol{\tilde{b}}_{j^*}$ denotes the cluster-specific parameters for the new cluster. We can obtain draws for $\boldsymbol{\tilde{b}}_{j^*}$ during estimation of the model (in a similar manner as for $\boldsymbol{b}_j$). At the $l^{th}$ iteration of the MCMC sampler we obtain $\boldsymbol{\tilde{b}}_{j^*}^{(l)}$ as a random draw from the posterior distribution of the cluster-specific parameters and store it for later use in predictions. The set of random draws $\boldsymbol{\tilde{b}}_{j^*}^{(l)}$ for $l = 1,...,L$ then allow us to essentially marginalise over the distribution of the cluster-specific parameters. This is the method used in __rstanarm__ when generating survival predictions for individuals in new clusters that were not part of the original sample. + +## Conditional survival probabilities + +In some instances we want to evaluate the predicted survival probability conditional on a last known survival time. This is known as a conditional survival probability. + +Suppose that individual $i^*$ is known to be event-free up until $C_{i^*}$ and we wish to predict the survival probability at some time $t > C_{i^*}$. To do this we draw from the conditional posterior predictive distribution: +/ +\begin{equation} +\begin{split} +p \Big( \hat{S}_{i^*}(t) \mid \boldsymbol{x}_{i^*}, \mathcal{D}, t > C_{i^*} \Big) = + \frac + {p \Big( \hat{S}_{i^*}(t) \mid \boldsymbol{x}_{i^*}, \mathcal{D} \Big)} + {p \Big( \hat{S}_{i^*}(C_{i^*}) \mid \boldsymbol{x}_{i^*}, \mathcal{D} \Big)} +\end{split} +\end{equation} +/ +or -- equivalently -- for multilevel survival models we have individual $i^*$ in cluster $j^*$ who is known to be event-free up until $C_{i^*j^*}$: +/ +\begin{equation} +\begin{split} +p \Big( \hat{S}_{i^*j^*}(t) \mid \boldsymbol{x}_{i^*j^*}, \boldsymbol{z}_{i^*j^*}, \mathcal{D}, t > C_{i^*j^*} \Big) = + \frac + {p \Big( \hat{S}_{i^*j^*}(t) \mid \boldsymbol{x}_{i^*j^*}, \boldsymbol{z}_{i^*j^*}, \mathcal{D} \Big)} + {p \Big( \hat{S}_{i^*j^*}(C_{i^*j^*}) \mid \boldsymbol{x}_{i^*j^*}, \boldsymbol{z}_{i^*j^*}, \mathcal{D} \Big)} +\end{split} +\end{equation} + +## Standardised survival probabilities + +All of the previously discussed predictions require conditioning on some covariate values $\boldsymbol{x}_{ij}$ and $\boldsymbol{z}_{ij}$. Even if we have a multilevel survival model and choose to marginalise over the distribution of the cluster-specific parameters, we are still obtaining predictions at some known unique values of the covariates. + +However sometimes we wish to generate an "average" survival probability. One possible approach is to predict at the mean value of all covariates \citep{Cupples:1995}. However this doesn't always make sense, especially not in the presence of categorical covariates. For instance, suppose our covariates are gender and a treatment indicator. Then predicting for an individual at the mean of all covariates might correspond to a 50% male who was 50% treated. That does not make sense and is not what we wish to do. + +A better alternative is to average over the individual survival probabilties. This essentially provides an approximation to marginalising over the joint distribution of the covariates. At any time $t$ it is possible to obtain a so-called standardised survival probability, denoted $\hat{S}^{*}(t)$, by averaging the individual-specific survival probabilities: +/ +\begin{equation} +\begin{split} +p ( \hat{S}^{*}(t) \mid \mathcal{D} ) = + \frac{1}{N^{P}} + \sum_{i=1}^{N^{P}} p ( \hat{S}_i(t) \mid \boldsymbol{x}_{i^*}, \mathcal{D} ) +\end{split} +\end{equation} +/ +where $\hat{S}_i(t)$ is the predicted survival probability for individual $i$ ($i = 1,...,N^{P}$) at time $t$, and $N^{P}$ is the number of individuals included in the predictions. For multilevel survival models the calculation is similar and follows quite naturally (details not shown). + +Note however that if $N^{P}$ is not sufficiently large (for example we predict individual survival probabilities using covariate data for just $N^{P} = 2$ individuals) then averaging over their covariate distribution may not be meaningful. Similarly, if we estimated a multilevel survival model and then predicted standardised survival probabilities based on just $N^{P} = 2$ individuals from our sample, the joint distribution of their cluster-specific parameters would likely be a poor representation of the distribution of cluster-specific parameters for the entire sample and population. + +It is therefore better to calculate standardised survival probabilities by setting $N^{P}$ equal to the total number of individuals in the original sample (i.e. $N^{P} = N$. This approach can then also be used for assessing the fit of the survival model in __rstanarm__ (see the \fct{ps_check} function described in Section \ref{sec:implementation}). Posterior predictive draws of the standardised survival probability are evaluated at a series of time points between 0 and $T_{\max}$ using all individuals in the estimation sample and the predicted standardised survival curve is overlaid with the observed Kaplan-Meier survival curve. + +# Implementation + +## Overview + +The __rstanarm__ package is built on top of the __rstan__ R package \citep{Stan:2019}, which is the R interface for Stan. +Models in __rstanarm__ are written in the Stan programming language, translated into C++ code, and then compiled at the time the package is built. This means that for most users -- who install a binary version of __rstanarm__ from the Comprehensive R Archive Network (CRAN) -- the models in __rstanarm__ will be pre-compiled. This is beneficial for users because there is no compilation time either during installation or when they estimate a model. + +## Main modelling function + +Survival models in __rstanarm__ are implemented around the \fct{stan_surv} modelling function. + +The function signature for \fct{stan_surv} is: +/ +<>= +stan_surv(formula, data, basehaz = "ms", basehaz_ops, qnodes = 15, + prior = normal(), prior_intercept = normal(), prior_aux, + prior_smooth = exponential(autoscale = FALSE), + prior_covariance = decov(), prior_PD = FALSE, + algorithm = c("sampling", "meanfield", "fullrank"), + adapt_delta = 0.95, ...) +@ + +The following provides a brief description of the main features of each of these arguments: + +\begin{itemize} + +\item The `formula` argument accepts objects built around the standard R formula syntax (see \fct{stats::formula}). The left hand side of the formula should be an object returned by the \fct{Surv} function in the __survival__ package \citep{Therneau:2019}. Any random effects structure (for multilevel survival models) can be specified on the right hand side of the formula using the same syntax as the __lme4__ package \citep{Bates:2015} as shown in the example in Section \ref{sec:multilevelmodel}. + +By default, any covariate effects specified in the fixed-effect part of the model `formula` are included under a proportional hazards assumption (for models estimated using a hazard scale formulation) or under the assumption of time-fixed acceleration factors (for models estimated using an AFT formulation). Time-varying effects are specified in the model `formula` by wrapping the covariate name in the \fct{tve} function. For example, if we wanted to estimate a time-varying effect for the covariate \code{sex} then we could specify \code{tve(sex)} in the model formula, e.g. \code{formula = Surv(time, status) ~ tve(sex) + age}. The \fct{tve} function is a special function that only has meaning when used in the `formula` of a model estimated using \fct{stan_surv}. Its functionality is demonstrated in the worked examples in Sections \ref{sec:tvebs} and \ref{sec:tvepw}. + +\item The \code{data} argument accepts an object inheriting the class \class{data.frame}, in other words the usual R data frame. + +\item The choice of parametric baseline hazard (or baseline survival distribution for AFT models) is specified via the \code{basehaz} argument. For the M-spline (\code{"ms"}) and B-spline (\code{"bs"}) models additional options related to the spline degree $\delta$, knot locations $\boldsymbol{k}$, or degrees of freedom $L$ can be specified as a list and passed to the \code{basehaz_ops} argument. For example, specifying \code{basehaz = "ms"} and \code{basehaz_ops = list(degree = 2, knots = c(10,20))} would request a baseline hazard modelled using quadratic M-splines with two internal knots located at $t = 10$ and $t = 20$. + +\item The argument \code{qnodes} is a control argument that allows the user to specify the number of quadrature nodes when quadrature is required (as described in Section \ref{sec:loglikelihood}). + +\item The \code{prior} family of arguments allow the user to specify the prior distributions for each of the parameters, as follows: + + \begin{itemize} + + \item \code{prior} relates to the time-fixed regression coefficients; + + \item \code{prior_intercept} relates to the intercept in the linear predictor; + + \item \code{prior_aux} relates to the so-called "auxiliary" parameters in the baseline hazard ($\gamma$ for the Weibull and Gompertz models or $\boldsymbol{\gamma}$ for the M-spline and B-spline models); + + \item \code{prior_smooth} relates to the hyperparameter $\tau_p$ when the $p^{th}$ covariate has a time-varying effect; and + + \item \code{prior_covariance} relates to the covariance matrix for the random effects when a multilevel survival model is being estimated. + + \end{itemize} + +\item The remaining arguments (\code{prior_PD}, \code{algorithm}, and \code{adapt_delta}) are optional control arguments related to estimation in Stan: + + \begin{itemize} + + \item Setting \code{prior_PD = TRUE} states that the user only wants to draw from the prior predictive distribution and not condition on the data. + + \item The \code{algorithm} argument specifies the estimation routine to use. This includes either Hamiltonian Monte Carlo (\code{"sampling"}) or one of the variational Bayes algorithms (\code{"meanfield"} or \code{"fullrank"}). The model specification is agnostic to the chosen \code{algorithm}. That is, the user can choose from any of the available algorithms regardless of the specified model. + + \item The \code{adapt_delta} argument controls the target average acceptance probability. It is only relevant when \code{algorithm = "sampling"} in which case \code{adapt_delta} should be between 0 and 1, with higher values leading to smaller step sizes and therefore a more robust sampler but longer estimation times. + + \end{itemize} + +\end{itemize} + +The model returned by \fct{stan_surv} is an object of class \class{stansurv} and inheriting the \class{stanreg} class. It is effectively a list with a number of important attributes. There are a range of post-estimation functions that can be called on \class{stansurv} (and \class{stanreg}) objects -- some of the most important ones are described in Section \ref{sec:postest-functions}. + +### Default knot locations + +Default knot locations for the M-spline, B-spline, or piecewise constant functions are the same regardless of whether they are used for modelling the baseline hazard or time-varying effects. By default the vector of knot locations $\boldsymbol{k} = \{k_{1},...,k_{J}\}$ includes a lower boundary knot $k_{1}$ at the earliest entry time (equal to zero if there isn't delayed entry) and an upper boundary knot $k_{J}$ at the latest event or censoring time. The location of the boundary knots cannot be changed by the user. + +Internal knot locations -- that is $k_{2},...,k_{(J-1)}$ when $J \geq 3$ -- can be explicitly specified by the user or are determined by default. The number of internal knots and/or their locations can be controlled via the \code{basehaz_ops} argument to \fct{stan_surv} (for modelling the baseline hazard) or via the arguments to the \fct{tve} function (for modelling a time-varying effect). If knot locations are not explicitly specified by the user, then the default is to place the internal knots at equally spaced percentiles of the distribution of uncensored event times. For instance, if there are three internal knots they would be placed at the $25^{\text{th}}$, $50^{\text{th}}$, and $75^{\text{th}}$ percentiles of the distribution of the uncensored event times. + +## Post-estimation functions + +The __rstanarm__ package provides a range of post-estimation functions that can be used after fitting the survival model. This includes functions for inference (e.g. reporting parameter estimates), diagnostics (e.g. assessing model fit), and generating predictions. We highlight the most important ones here: + +\begin{itemize} + +\item The \fct{print} and \fct{summary} functions provide reports of parameter estimates and some summary information on the data (e.g. number of observations, number of events, etc). They each provide varying levels of detail. For example, the \fct{summary} method provides diagnostic measures such as Gelman and Rubin's Rhat statistic \citep{Gelman:1992} for assessing convergence of the MCMC chains and the number of effective MCMC samples. On the other hand, the \fct{print} method is more concise and does not provide this level of additional detail. + +\item The \fct{fixef} and \fct{ranef} functions report the fixed effect and random effect parameter estimates, respectively. + +\item The \fct{posterior_survfit} function is the primary function for generating survival predictions. The type of prediction is specified via the \code{type} arguments and can currently be any of the following: + + \begin{itemize} + + \item \code{"surv"}: the estimated survival probability; + \item \code{"cumhaz"}: the estimated cumulative hazard; + \item \code{"haz"}: the estimated hazard rate; + \item \code{"cdf"}: the estimated failure probability; + \item \code{"logsurv"}: the estimated log survival probability; + \item \code{"logcumhaz"}: the estimated log cumulative hazard; + \item \code{"loghaz"}: the estimated log hazard rate; or + \item \code{"logcdf"}: the estimated log failure probability. + + \end{itemize} + +There are additional arguments to \fct{posterior_survfit} that control the time at which the predictions are generated (\code{times}), whether they are generated across a time range (referred to as extrapolation, see \code{extrapolate}), whether they are conditional on a last known survival time (\code{condition}), and whether they are averaged across individuals (referred to as standardised predictions, see \code{standardise}). The returned predictions are a data frame with a special class called \class{survfit.stansurv}. The \class{survfit.stansurv} class has both \fct{print} and \fct{plot} methods that can be called on it. These will be demonstrated as part of the examples in Section \ref{sec:usage}. + +\item The \fct{loo} and \fct{waic} functions report model fit statistics. The former is based on approximate leave-one-out cross validation \citep{Vehtari:2017} and is recommended. The latter is a less preferable alternative that reports the Widely Applicable Information Criterion (WAIC) criterion \citep{Watanabe:2010}. Both of these functions are built on top of the __loo__ R package \citep{Vehtari:2019}. The values (objects) returned by either \fct{loo} or \fct{waic} can also be passed to the \fct{loo_compare} function to compare different models estimated on the same dataset. This will be demonstrated as part of the examples in Section \ref{sec:usage}. + +\item The \fct{log_lik} function generates a pointwise log likelihood matrix. That is, it calculates the log likelihood for each observation (either in the original dataset or some new dataset) using each MCMC draw of the model parameters. + +\item The \fct{plot} function allows for a variety of plots depending on the input to the \code{plotfun} argument. The default is to plot the estimated baseline hazard (\code{plotfun = "basehaz"}), but alternatives include a plot of the estimated time-varying hazard ratio for models with time-varying effects (\code{plotfun = "tve"}), plots summarising the parameter estimates (e.g. posterior densities or posterior intervals), and plots providing diagnostics (e.g. MCMC trace plots). + +\item The \fct{ps_check} function provides a quick diagnostic check for the fitted survival function. It is based on the estimation sample and compares the predicted standardised survival curve to the observed Kaplan-Meier survival curve. + +\end{itemize} + +# Usage examples + +## Example: A flexible parametric proportional hazards model + +We will use the German Breast Cancer Study Group dataset (see `?rstanarm-datasets` for details and references). In brief, the data consist of +$N = 686$ patients with primary node positive breast cancer recruited between 1984-1989. The primary response is time to recurrence or death. Median follow-up time was 1084 days. Overall, there were 299 (44%) events and the remaining 387 (56%) individuals were right censored. We concern our analysis here with a 3-category baseline covariate for cancer prognosis (good/medium/poor). + +First, let us load the data and fit the proportional hazards model + +```{r, warning = FALSE, message = FALSE, results='hide'} +mod1 <- stan_surv(formula = Surv(recyrs, status) ~ group, + data = bcancer, + chains = CHAINS, + cores = CORES, + seed = SEED, + iter = ITER) +``` + +The model here is estimated using the default cubic M-splines (with 5 degrees of freedom) for modelling the baseline hazard. Since there are no time-varying effects in the model (i.e. we did not wrap any covariates in the `tve()` function) there is a closed form expression for the cumulative hazard and survival function and so the model is relatively fast to fit. Specifically, the model takes ~3.5 sec for each MCMC chain based on the default 2000 (1000 warm up, 1000 sampling) MCMC iterations. + +We can easily obtain the estimated hazard ratios for the 3-catgeory group covariate using the generic `print` method for `stansurv` objects, as follows + +```{r} +print(mod1, digits = 3) +``` + +We see from this output we see that individuals in the groups with `Poor` or `Medium` prognosis have much higher rates of death relative to the group with `Good` prognosis (as we might expect!). The hazard of death in the `Poor` prognosis group is approximately 5.0-fold higher than the hazard of death in the `Good` prognosis group. Similarly, the hazard of death in the `Medium` prognosis group is approximately 2.3-fold higher than the hazard of death in the `Good` prognosis group. + +It may also be of interest to compare the different types of the baseline hazard we could potentially use. Here, we will fit a series of models, each with a different baseline hazard specification + +```{r, warning = FALSE, message = FALSE, results='hide'} +mod1_exp <- update(mod1, basehaz = "exp") +mod1_weibull <- update(mod1, basehaz = "weibull") +mod1_gompertz <- update(mod1, basehaz = "gompertz") +mod1_bspline <- update(mod1, basehaz = "bs") +mod1_mspline1 <- update(mod1, basehaz = "ms") +mod1_mspline2 <- update(mod1, basehaz = "ms", basehaz_ops = list(df = 10)) +``` + +and then plot the baseline hazards with 95% posterior uncertainty limits using the generic `plot` method for `stansurv` objects (note that the default `plot` for `stansurv` objects is the estimated baseline hazard). We will write a little helper function to adjust the y-axis limits, add a title, and centre the title, on each plot, as follows + +```{r, fig.height=5} +library(ggplot2) + +plotfun <- function(model, title) { + plot(model, plotfun = "basehaz") + # plot baseline hazard + coord_cartesian(ylim = c(0,0.4)) + # adjust y-axis limits + labs(title = title) + # add plot title + theme(plot.title = element_text(hjust = 0.5)) # centre plot title +} + +p_exp <- plotfun(mod1_exp, title = "Exponential") +p_weibull <- plotfun(mod1_weibull, title = "Weibull") +p_gompertz <- plotfun(mod1_gompertz, title = "Gompertz") +p_bspline <- plotfun(mod1_bspline, title = "B-splines with df = 5") +p_mspline1 <- plotfun(mod1_mspline1, title = "M-splines with df = 5") +p_mspline2 <- plotfun(mod1_mspline2, title = "M-splines with df = 10") + +bayesplot::bayesplot_grid(p_exp, + p_weibull, + p_gompertz, + p_bspline, + p_mspline1, + p_mspline2, + grid_args = list(ncol = 3)) +``` + +We can also compare the fit of these models using the `loo` method for `stansurv` objects + +```{r, message=FALSE} +loo_compare(loo(mod1_exp), + loo(mod1_weibull), + loo(mod1_gompertz), + loo(mod1_bspline), + loo(mod1_mspline1), + loo(mod1_mspline2)) +``` + +where we see that models with a flexible parametric (spline-based) baseline hazard fit the data best followed by the standard parametric (Weibull, Gompertz, exponential) models. Roughly speaking, the B-spline and M-spline models seem to fit the data equally well since the differences in `elpd` or `looic` between the models are very small relative to their standard errors. Moreover, increasing the degrees of freedom for the M-splines from 5 to 10 doesn't seem to improve the fit (that is, the default degrees of freedom `df = 5` seems to provide sufficient flexibility to model the baseline hazard). + +After fitting the survival model, we often want to estimate the predicted survival function for individual's with different covariate patterns. Here, let us estimate the predicted survival function between 0 and 5 years for an individual in each of the prognostic groups. To do this, we can use the `posterior_survfit` method for `stansurv` objects, and it's associated `plot` method. First let us construct the prediction (covariate) data + +```{r preddata} +nd <- data.frame(group = c("Good", "Medium", "Poor")) +head(nd) +``` + +and then we will generate the posterior predictions + +```{r predresults} +ps <- posterior_survfit(mod1, newdata = nd, times = 0, extrapolate = TRUE, + control = list(edist = 5)) +head(ps) +``` + +Here we note that the `id` variable in the data frame of posterior predictions identifies which row of `newdata` the predictions correspond to. For demonstration purposes we have also shown a couple of other arguments in the `posterior_survfit` call, namely + +- the `times = 0` argument says that we want to predict at time = 0 (i.e. baseline) for each individual in the `newdata` (this is the default anyway) +- the `extrapolate = TRUE` argument says that we want to extrapolate forward from time 0 (this is also the default) +- the `control = list(edist = 5)` identifies the control of the extrapolation; this is saying extrapolate the survival function forward from time 0 for a distance of 5 time units (the default would have been to extrapolate as far as the largest event or censoring time in the estimation dataset, which is 7.28 years in the `brcancer` data). + +Let us now plot the survival predictions. We will relabel the `id` variable with meaningful labels identifying the covariate profile of each new individual in our prediction data + +```{r predplot} +panel_labels <- c('1' = "Good", '2' = "Medium", '3' = "Poor") +plot(ps) + + ggplot2::facet_wrap(~ id, labeller = ggplot2::labeller(id = panel_labels)) +``` + +We can see from the plot that predicted survival is worst for patients with a `Poor` diagnosis, and best for patients with a `Good` diagnosis, as we would expect based on our previous model estimates. + +Alternatively, if we wanted to obtain and plot the predicted *hazard* function for each individual in our new data (instead of their *survival* function), then we just need to specify `type = "haz"` in our `posterior_survfit` call (the default is `type = "surv"`), as follows + +```{r predhaz} +ph <- posterior_survfit(mod1, newdata = nd, type = "haz") +plot(ph) + + ggplot2::facet_wrap(~ id, labeller = ggplot2::labeller(id = panel_labels)) +``` + +We can quite clearly see in the plot the assumption of proportional hazards. We can also see that the hazard is highest in the `Poor` prognosis group (i.e. worst survival) and the hazard is lowest in the `Good` prognosis group (i.e. best survival). This corresponds to what we saw in the plot of the survival functions previously. + +## Example: Non-proportional hazards modelled using B-splines + +To demonstrate the implementation of time-varying effects in `stan_surv` we will use a simulated dataset, generated using the **simsurv** package (Brilleman, 2018). + +We will simulate a dataset with $N = 200$ individuals with event times generated under the following Weibull hazard function +\ +\begin{align} +h_i(t) = \gamma t^{\gamma-1} \lambda \exp( \beta(t) x_i ) +\end{align} +\ +with scale parameter $\lambda = 0.1$, shape parameter $\gamma = 1.5$, binary baseline covariate $X_i \sim \text{Bern}(0.5)$, and time-varying hazard ratio $\beta(t) = -0.5 + 0.2 t$. We will enforce administrative censoring at 5 years if an individual's simulated event time is >5 years. + +```{r simsurv-simdata} +# load package +library(simsurv) + +# set seed for reproducibility +set.seed(999111) + +# simulate covariate data +covs <- data.frame(id = 1:200, + trt = rbinom(200, 1L, 0.5)) + +# simulate event times +dat <- simsurv(lambdas = 0.1, + gammas = 1.5, + betas = c(trt = -0.5), + tde = c(trt = 0.2), + x = covs, + maxt = 5) + +# merge covariate data and event times +dat <- merge(dat, covs) + +# examine first few rows of data +head(dat) +``` + +Now that we have our simulated dataset, let us fit a model with time-varying hazard ratio for `trt` + +```{r tve_fit1, warning = FALSE, message = FALSE, results='hide'} +mod2 <- stan_surv(formula = Surv(eventtime, status) ~ tve(trt), + data = dat, + chains = CHAINS, + cores = CORES, + seed = SEED, + iter = ITER) +``` + +The `tve` function is used in the model formula to state that we want a time-varying effect (i.e. a time-varying coefficient) to be estimated for the variable `trt`. By default, a cubic B-spline basis with 3 degrees of freedom (i.e. two boundary knots placed at the limits of the range of event times, but no internal knots) is used for modelling the time-varying log hazard ratio. If we wanted to change the degree, knot locations, or degrees of freedom for the B-spline function we can specify additional arguments to the `tve` function. + +For example, to model the time-varying log hazard ratio using quadratic B-splines with 4 degrees of freedom (i.e. two boundary knots placed at the limits of the range of event times, as well as two internal knots placed -- by default -- at the 33.3rd and 66.6th percentiles of the distribution of uncensored event times) we could specify the model formula as + +```{r, warning = FALSE, message = FALSE, results='hide', eval=FALSE} +Surv(eventtime, status) ~ tve(trt, df = 4, degree = 2) +``` + +Let us now plot the estimated time-varying hazard ratio from the fitted model. We can do this using the generic `plot` method for `stansurv` objects, for which we can specify the `plotfun = "tve"` argument. (Note that in this case, there is only one covariate in the model with a time-varying effect, but if there were others, we could specify which covariate(s) we want to plot the time-varying effect for by specifying the `pars` argument to the `plot` call). + +```{r, fig.height=5} +plot(mod2, plotfun = "tve") +``` + +From the plot, we can see how the hazard ratio (i.e. the effect of treatment on the hazard of the event) changes as a function of time. The treatment appears to be protective during the first few years following baseline (i.e. HR < 1), and then the treatment appears to become harmful after about 4 years post-baseline. Thankfully, this is a reflection of the model we simulated under! + +The plot shows a large amount of uncertainty around the estimated time-varying hazard ratio. This is to be expected, since we only simulated a dataset of 200 individuals of which only around 70% experienced the event before being censored at 5 years. So, there is very little data (i.e. very few events) with which to reliably estimate the time-varying hazard ratio. We can also see this reflected in the differences between our data generating model and the estimates from our fitted model. In our data generating model, the time-varying hazard ratio equals 1 (i.e. the log hazard ratio equals 0) at 2.5 years, but in our fitted model the median estimate for our time-varying hazard ratio equals 1 at around ~3 years. This is a reflection of the large amount of sampling error, due to our simulated dataset being so small. + +## Example: Non-proportional hazards modelled using a piecewise constant function + +In the previous example we showed how non-proportional hazards can be modelled by using a smooth B-spline function for the time-varying log hazard ratio. This is the default approach when the `tve` function is used to estimate a time-varying effect for a covariate in the model formula. However, another approach for modelling a time-varying log hazard ratio is to use a piecewise constant function. If we want to use a piecewise constant for the time-varying log hazard ratio (instead of the smooth B-spline function) then we just have to specify the `type` argument to the `tve` function. + +We will again simulate some survival data using the **simsurv** package to show how a piecewise constant hazard ratio can be estimated using `stan_surv`. + +Similar to the previous example, we will simulate a dataset with $N = 500$ individuals with event times generated under a Weibull hazard function with scale parameter $\lambda = 0.1$, shape parameter $\gamma = 1.5$, and binary baseline covariate $X_i \sim \text{Bern}(0.5)$. However, in this example our time-varying hazard ratio will be defined as $\beta(t) = -0.5 + 0.7 \times I(t > 2.5)$ where $I(X)$ is the indicator function taking the value 1 if $X$ is true and 0 otherwise. This corresponds to a piecewise constant log hazard ratio with just two "pieces" or time intervals. The first time interval is $[0,2.5]$ during which the true hazard ratio is $\exp(-0.5) = 0.61$. The second time interval is $(2.5,\infty]$ during which the true log hazard ratio is $\exp(-0.5 + 0.7) = 1.22$. Our example uses only two time intervals for simplicity, but in general we could easily have considered more (although it would have required couple of additional lines of code to simulate the data). We will again enforce administrative censoring at 5 years if an individual's simulated event time is >5 years. + +```{r simsurv-simdata2} +# load package +library(simsurv) + +# set seed for reproducibility +set.seed(888222) + +# simulate covariate data +covs <- data.frame(id = 1:500, + trt = rbinom(500, 1L, 0.5)) + +# simulate event times +dat <- simsurv(lambdas = 0.1, + gammas = 1.5, + betas = c(trt = -0.5), + tde = c(trt = 0.7), + tdefun = function(t) (t > 2.5), + x = covs, + maxt = 5) + +# merge covariate data and event times +dat <- merge(dat, covs) + +# examine first few rows of data +head(dat) +``` + +We now estimate a model with a piecewise constant time-varying effect for the covariate `trt` as + +```{r tve-fit2, warning = FALSE, message = FALSE, results='hide'} +mod3 <- stan_surv(formula = Surv(eventtime, status) ~ + tve(trt, degree = 0, knots = 2.5), + data = dat, + chains = CHAINS, + cores = CORES, + seed = SEED, + iter = ITER) +``` + +This time we specify some additional arguments to the `tve` function, so that our time-varying effect corresponds to the true data generating model used to simulate our event times. Specifically, we specify `degree = 0` to say that we want the time-varying effect (i.e. the time-varying log hazard ratio) to be estimated using a piecewise constant function and `knots = 2.5` says that we only want one internal knot placed at the time $t = 2.5$. + +We can again use the generic `plot` function with argument `plotfun = "tve"` to examine our estimated hazard ratio for treatment + +```{r, fig.height=5} +plot(mod3, plotfun = "tve") +``` + +Here we see that the estimated hazard ratio reasonably reflects our true data generating model (i.e. a hazard ratio of $\approx 0.6$ during the first time interval and a hazard ratio of $\approx 1.2$ during the second time interval) although there is a slight discrepancy due to the sampling variation in the simulated event times. + +## Example: Hierarchical survival models + +To demonstrate the estimation of a hierarchical model for survival data in `stan_surv` we will use the `frail` dataset (see `help("rstanarm-datasets")` for a description). The `frail` datasets contains simulated event times for 200 patients clustered within 20 hospital sites (10 patients per hospital site). The event times are simulated from a parametric proportional hazards model under the following assumptions: (i) a constant (i.e. exponential) baseline hazard rate of 0.1; (ii) a fixed treatment effect with log hazard ratio of 0.3; and (iii) a site-specific random intercept (specified on the log hazard scale) drawn from a N(0,1) distribution. + +Let's look at the first few rows of the data: + +```{r frail-data-head} +head(frail) +``` + +To fit a hierarchical model for clustered survival data we use a formula syntax similar to what is used in the **lme4** R package (Bates et al. (2015)). Let's consider the following model (which aligns with the model used to generate the simulated data): + +```{r frail-fit-model, warning = FALSE, message = FALSE} +mod_randint <- stan_surv( + formula = Surv(eventtime, status) ~ trt + (1 | site), + data = frail, + basehaz = "exp", + chains = CHAINS, + cores = CORES, + seed = SEED, + iter = ITER) +``` + +The model contains a baseline covariate for treatment (0 or 1) as well as a site-specific intercept to allow for correlation in the event times for patients from the same site. We've call the model object `mod_randint` to denote the fact that it includes a site-specific (random) intercept. Let's examine the parameter estimates from the model: + +```{r frail-estimates} +print(mod_randint, digits = 2) +``` + +We see that the estimated log hazard ratio for treatment ($\hat{\beta}_{\text{(trt)}} = 0.46$) is a bit larger than the "true" log hazard ratio used in the data generating model ($\beta_{\text{(trt)}} = 0.3$). The estimated baseline hazard rate is $\exp(-2.3716) = 0.093$, which is pretty close to the baseline hazard rate used in the data generating model ($0.1$). Of course, the differences between the estimated parameters and the true parameters from the data generating model are attributable to sampling noise. + +If this were a real analysis, we might wonder whether the site-specific estimates are necessary! Well, we can assess that by fitting an alternative model that does **not** include the site-specific intercepts and compare it to the model we just estimated. We will compare it using the `loo` function. We first need to fit the model without the site-specific intercept. To do this, we will just use the generic `update` method for `stansurv` objects, since all we are changing is the model formula: + +```{r frail-fixed-model, warning = FALSE, message = FALSE} +mod_fixed <- update(mod_randint, formula. = Surv(eventtime, status) ~ trt) +``` + +Let's calculate the `loo` for both these models and compare them: + +```{r frail-compare-1, warning = FALSE, message = FALSE} +loo_fixed <- loo(mod_fixed) +loo_randint <- loo(mod_randint) +loo_compare(loo_fixed, loo_randint) +``` + +We see strong evidence in favour of the model with the site-specific intercepts! + +But let's not quite finish there. What about if we want to generalise the random effects structure further. For instance, is the site-specific intercept enough? Perhaps we should consider estimating both a site-specific intercept and a site-specific treatment effect. We have minimal data to estimate such a model (recall that there is only 20 sites and 10 patients per site) but for the sake of demonstration we will forge on nonetheless. Let's fit a model with both a site-specific intercept and a site-specific coefficient for the covariate `trt` (i.e. treatment): + +```{r frail-random-trt, warning = FALSE, message = FALSE} +mod_randtrt <- update(mod_randint, formula. = + Surv(eventtime, status) ~ trt + (trt | site)) +print(mod_randtrt, digits = 2) +``` + +We see that we have an estimated standard deviation for the site-specific intercepts and the site-specific coefficients for `trt`, as well as the estimated correlation between those site-specific parameters. + +Let's now compare all three of these models based on `loo`: + +```{r frail-compare-2, warning = FALSE, message = FALSE} +loo_randtrt <- loo(mod_randtrt) +loo_compare(loo_fixed, loo_randint, loo_randtrt) +``` + +It appears that the model with just a site-specific intercept is the best fitting model. It is much better than the model without a site-specific intercept, and slightly better than the model with both a site-specific intercept and a site-specific treatment effect. In other words, including a site-specific intercept appears important, but including a site-specific treatment effect is not. This conclusion is reassuring, because it aligns with the data generating model we used to simulate the data! + + +# References + +Bates D, Maechler M, Bolker B, Walker S. Fitting Linear Mixed-Effects Models Using lme4. *Journal of Statistical Software* 2015;67(1):1--48. \url{https://doi.org/10.18637/jss.v067.i01} + +Brilleman S. (2018) *simsurv: Simulate Survival Data.* R package version 0.2.2. \url{https://CRAN.R-project.org/package=simsurv} + +Hougaard P. Fundamentals of Survival Data. *Biometrics* 1999;55:13--22. + +Ramsay JO. Monotone Regression Splines in Action. *Statistical Science* 1988;3(4):425--461. \url{https://doi.org/10.1214/ss/1177012761} + +Wang W, Yan J. (2018) *splines2: Regression Spline Functions and Classes.* R package version 0.2.8. \url{https://CRAN.R-project.org/package=splines2} + + +# Appendix A: Parameterisations on the hazard scale + +When `basehaz` is set equal to `"exp"`, `"weibull"`, `"gompertz"`, `"ms"` (the default), or `"bs"` then the model is defined on the hazard scale using the following parameterisations. + + +### Exponential model + +The exponential model is parameterised with scale parameter $\lambda_i = \exp(\eta_i)$ where $\eta_i = \beta_0 + \sum_{p=1}^P \beta_p x_{ip}$ denotes our linear predictor. + +For individual $i$ we have: + +\begin{align} +\begin{split} + h_i(T_i) + & = \lambda_i \\ + & = \exp(\eta_i) \\ + H_i(T_i) + & = T_i \lambda_i \\ + & = T_i \exp(\eta_i) \\ + S_i(T_i) + & = \exp \left( - T_i \lambda_i \right) \\ + & = \exp \left( - T_i \exp(\eta_i) \right) \\ + F_i(T_i) + & = 1 - \exp \left( - T_i \lambda_i \right) \\ + & = 1 - \exp \left( - T_i \exp(\eta_i) \right) \\ + S_i(T_i) - S_i(T_i^U) + & = \exp \left( - T_i \lambda_i \right) - \exp \left( - T_i^U \lambda_i \right) \\ + & = \exp \left( - T_i \exp(\eta_i) \right) - \exp \left( - T_i^U \exp(\eta_i) \right) +\end{split} +\end{align} + +or on the log scale: + +\begin{align} +\begin{split} + \log h_i(T_i) + & = \log \lambda_i \\ + & = \eta_i \\ + \log H_i(T_i) + & = \log(T_i) + \log \lambda_i \\ + & = \log(T_i) + \eta_i \\ + \log S_i(T_i) + & = - T_i \lambda_i \\ + & = - T_i \exp(\eta_i) \\ + \log F_i(T_i) + & = \log \left( 1 - \exp \left( - T_i \lambda_i \right) \right) \\ + & = \log \left( 1 - \exp \left( - T_i \exp(\eta_i) \right) \right) \\ + \log (S_i(T_i) - S_i(T_i^U)) + & = \log \left[ \exp \left( - T_i \lambda_i \right) - \exp \left( - T_i^U \lambda_i \right) \right] \\ + & = \log \left[ \exp \left( - T_i \exp(\eta_i) \right) - \exp \left( - T_i^U \exp(\eta_i) \right) \right] +\end{split} +\end{align} + +The definition of $\lambda$ for the baseline is: + +\begin{align} +\begin{split} + \lambda_0 = \exp(\beta_0) \Longleftrightarrow \beta_0 = \log(\lambda_0) +\end{split} +\end{align} + + +### Weibull model + +The Weibull model is parameterised with scale parameter $\lambda_i = \exp(\eta_i)$ and shape parameter $\gamma > 0$ where $\eta_i = \beta_0 + \sum_{p=1}^P \beta_p x_{ip}$ denotes our linear predictor. + +For individual $i$ we have: + +\begin{align} +\begin{split} + h_i(T_i) + & = \gamma t^{\gamma-1} \lambda_i \\ + & = \gamma t^{\gamma-1} \exp(\eta_i) \\ + H_i(T_i) + & = T_i^{\gamma} \lambda_i \\ + & = T_i^{\gamma} \exp(\eta_i) \\ + S_i(T_i) + & = \exp \left( - T_i^{\gamma} \lambda_i \right) \\ + & = \exp \left( - T_i^{\gamma} \exp(\eta_i) \right) \\ + F_i(T_i) + & = 1 - \exp \left( - {(T_i)}^{\gamma} \lambda_i \right) \\ + & = 1 - \exp \left( - {(T_i)}^{\gamma} \exp(\eta_i) \right) \\ + S_i(T_i) - S_i(T_i^U) + & = \exp \left( - {(T_i)}^{\gamma} \lambda_i \right) - \exp \left( - {(T_i^U)}^{\gamma} \lambda_i \right) \\ + & = \exp \left( - {(T_i)}^{\gamma} \exp(\eta_i) \right) - \exp \left( - {(T_i^U)}^{\gamma} \exp(\eta_i) \right) +\end{split} +\end{align} + +or on the log scale: + +\begin{align} +\begin{split} + \log h_i(T_i) + & = \log(\gamma) + (\gamma-1) \log(t) + \log \lambda_i \\ + & = \log(\gamma) + (\gamma-1) \log(t) + \eta_i \\ + \log H_i(T_i) + & = \gamma \log(T_i) + \log \lambda_i \\ + & = \gamma \log(T_i) + \eta_i \\ + \log S_i(T_i) + & = - T_i^{\gamma} \lambda_i \\ + & = - T_i^{\gamma} \exp(\eta_i) \\ + \log F_i(T_i) + & = \log \left( 1 - \exp \left( - {(T_i)}^{\gamma} \lambda_i \right) \right) \\ + & = \log \left( 1 - \exp \left( - {(T_i)}^{\gamma} \exp(\eta_i) \right) \right) \\ + \log (S_i(T_i) - S_i(T_i^U)) + & = \log \left[ \exp \left( - {(T_i)}^{\gamma} \lambda_i \right) - \exp \left( - {(T_i^U)}^{\gamma} \lambda_i \right) \right] \\ + & = \log \left[ \exp \left( - {(T_i)}^{\gamma} \exp(\eta_i) \right) - \exp \left( - {(T_i^U)}^{\gamma} \exp(\eta_i) \right) \right] +\end{split} +\end{align} + +The definition of $\lambda$ for the baseline is: + +\begin{align} +\begin{split} + \lambda_0 = \exp(\beta_0) \Longleftrightarrow \beta_0 = \log(\lambda_0) +\end{split} +\end{align} + + +### Gompertz model + +The Gompertz model is parameterised with shape parameter $\lambda_i = \exp(\eta_i)$ and scale parameter $\gamma > 0$ where $\eta_i = \beta_0 + \sum_{p=1}^P \beta_p x_{ip}$ denotes our linear predictor. + +For individual $i$ we have: + +\begin{align} +\begin{split} + h_i(T_i) + & = \exp(\gamma T_i) \lambda_i \\ + & = \exp(\gamma T_i) \exp(\eta_i) \\ + H_i(T_i) + & = \frac{\exp(\gamma T_i) - 1}{\gamma} \lambda_i \\ + & = \frac{\exp(\gamma T_i) - 1}{\gamma} \exp(\eta_i) \\ + S_i(T_i) + & = \exp \left( \frac{-(\exp(\gamma T_i) - 1)}{\gamma} \lambda_i \right) \\ + & = \exp \left( \frac{-(\exp(\gamma T_i) - 1)}{\gamma} \exp(\eta_i) \right) \\ + F_i(T_i) + & = 1 - \exp \left( \frac{-(\exp(\gamma T_i) - 1)}{\gamma} \lambda_i \right) \\ + & = 1 - \exp \left( \frac{-(\exp(\gamma T_i) - 1)}{\gamma} \exp(\eta_i) \right) \\ + S_i(T_i) - S_i(T_i^U) + & = \exp \left( \frac{-(\exp(\gamma T_i) - 1)}{\gamma} \lambda_i \right) - \exp \left( \frac{-(\exp(\gamma T_i^U) - 1)}{\gamma} \lambda_i \right) \\ + & = \exp \left( \frac{-(\exp(\gamma T_i) - 1)}{\gamma} \exp(\eta_i) \right) - \exp \left( \frac{-(\exp(\gamma T_i^U) - 1)}{\gamma} \exp(\eta_i) \right) +\end{split} +\end{align} + +or on the log scale: + +\begin{align} +\begin{split} + \log h_i(T_i) + & = \gamma T_i + \log \lambda_i \\ + & = \gamma T_i + \eta_i \\ + \log H_i(T_i) + & = \log(\exp(\gamma T_i) - 1) - \log(\gamma) + \log \lambda_i \\ + & = \log(\exp(\gamma T_i) - 1) - \log(\gamma) + \eta_i \\ + \log S_i(T_i) + & = \frac{-(\exp(\gamma T_i) - 1)}{\gamma} \lambda_i \\ + & = \frac{-(\exp(\gamma T_i) - 1)}{\gamma} \exp(\eta_i) \\ + \log F_i(T_i) + & = \log \left( 1 - \exp \left( \frac{-(\exp(\gamma T_i) - 1)}{\gamma} \lambda_i \right) \right) \\ + & = \log \left( 1 - \exp \left( \frac{-(\exp(\gamma T_i) - 1)}{\gamma} \exp(\eta_i) \right) \right) \\ + \log (S_i(T_i) - S_i(T_i^U)) + & = \log \left[ \exp \left( \frac{-(\exp(\gamma T_i) - 1)}{\gamma} \lambda_i \right) - \exp \left( \frac{-(\exp(\gamma T_i^U) - 1)}{\gamma} \lambda_i \right) \right] \\ + & = \log \left[ \exp \left( \frac{-(\exp(\gamma T_i) - 1)}{\gamma} \exp(\eta_i) \right) - \exp \left( \frac{-(\exp(\gamma T_i^U) - 1)}{\gamma} \exp(\eta_i) \right) \right] +\end{split} +\end{align} + +The definition of $\lambda$ for the baseline is: + +\begin{align} +\begin{split} + \lambda_0 = \exp(\beta_0) \Longleftrightarrow \beta_0 = \log(\lambda_0) +\end{split} +\end{align} + + +### M-spline model + +The M-spline model is parameterised with vector of regression coefficients $\boldsymbol{\theta} > 0$ for the baseline hazard and with covariate effects introduced through a linear predictor $\eta_i = \sum_{p=1}^P \beta_p x_{ip}$. Note that there is no intercept in the linear predictor since it is absorbed into the baseline hazard spline function. + +For individual $i$ we have: + +\begin{align} +\begin{split} + h_i(T_i) + & = M(T_i; \boldsymbol{\theta}, \boldsymbol{k_0}) \exp(\eta_i) \\ + H_i(T_i) + & = I(T_i; \boldsymbol{\theta}, \boldsymbol{k_0}) \exp(\eta_i) \\ + S_i(T_i) + & = \exp \left( - I(T_i; \boldsymbol{\theta}, \boldsymbol{k_0}) \exp(\eta_i) \right) \\ + F_i(T_i) + & = 1 - \exp \left( - I(T_i; \boldsymbol{\theta}, \boldsymbol{k_0}) \exp(\eta_i) \right) \\ + S_i(T_i) - S_i(T_i^U) + & = \exp \left( - I(T_i; \boldsymbol{\theta}, \boldsymbol{k_0}) \exp(\eta_i) \right) - \exp \left( - I(T_i^U; \boldsymbol{\theta}, \boldsymbol{k_0}) \exp(\eta_i) \right) +\end{split} +\end{align} + +or on the log scale: + +\begin{align} +\begin{split} + \log h_i(T_i) + & = \log(M(T_i; \boldsymbol{\theta}, \boldsymbol{k_0})) + \eta_i \\ + \log H_i(T_i) + & = \log(I(T_i; \boldsymbol{\theta}, \boldsymbol{k_0})) + \eta_i \\ + \log S_i(T_i) + & = - I(T_i; \boldsymbol{\theta}, \boldsymbol{k_0}) \exp(\eta_i) \\ + \log F_i(T_i) + & = \log \left[ 1 - \exp \left( - I(T_i; \boldsymbol{\theta}, \boldsymbol{k_0}) \exp(\eta_i) \right) \right] \\ + \log (S_i(T_i) - S_i(T_i^U)) + & = \log \left[ \exp \left( - I(T_i; \boldsymbol{\theta}, \boldsymbol{k_0}) \exp(\eta_i) \right) - \exp \left( - I(T_i^U; \boldsymbol{\theta}, \boldsymbol{k_0}) \exp(\eta_i) \right) \right] +\end{split} +\end{align} + +where $M(t; \boldsymbol{\theta}, \boldsymbol{k_0})$ denotes a cubic M-spline function evaluated at time $t$ with regression coefficients $\boldsymbol{\theta}$ and basis evaluated using the vector of knot locations $\boldsymbol{k_0})$. Similarly, $I(t; \boldsymbol{\theta}, \boldsymbol{k_0})$ denotes a cubic I-spline function (i.e. integral of an M-spline) evaluated at time $t$ with regression coefficients $\boldsymbol{\theta}$ and basis evaluated using the vector of knot locations $\boldsymbol{k_0}$. + + +### B-spline model + +The B-spline model is parameterised with vector of regression coefficients $\boldsymbol{\theta}$ and linear predictor where $\eta_i = \sum_{p=1}^P \beta_p x_{ip}$ denotes our linear predictor. Note that there is no intercept in the linear predictor since it is absorbed into the spline function. + + +For individual $i$ we have: + +\begin{align} +\begin{split} + h_i(T_i) + & = \exp \left( B(T_i; \boldsymbol{\theta}, \boldsymbol{k_0}) + \eta_i \right) +\end{split} +\end{align} + +or on the log scale: + +\begin{align} +\begin{split} + \log h_i(T_i) + & = B(T_i; \boldsymbol{\theta}, \boldsymbol{k_0}) + \eta_i +\end{split} +\end{align} + +The cumulative hazard, survival function, and CDF for the B-spline model cannot be calculated analytically. Instead, the model is only defined analytically on the hazard scale and quadrature is used to evaluate the following: + +\begin{align} +\begin{split} + H_i(T_i) + & = \int_0^{T_i} h_i(u) du \\ + S_i(T_i) + & = \exp \left( - \int_0^{T_i} h_i(u) du \right) \\ + F_i(T_i) + & = 1 - \exp \left( - \int_0^{T_i} h_i(u) du \right) \\ + S_i(T_i) - S_i(T_i^U) + & = \exp \left( -\int_0^{T_i} h_i(u) du \right) - \exp \left( - \int_0^{T_i^U} h_i(u) du \right) +\end{split} +\end{align} + + +### Extension to time-varying coefficients (i.e. non-proportional hazards) + +We can extend the previous model formulations to allow for time-varying coefficients (i.e. non-proportional hazards). The time-varying linear predictor is introduced on the hazard scale. That is, $\eta_i$ in our previous model definitions is instead replaced by $\eta_i(t)$. This leads to an analytical form for the hazard and log hazard. However, in general, there is no longer a closed form expression for the cumulative hazard, survival function, or CDF. Therefore, when the linear predictor includes time-varying coefficients, quadrature is used to evaluate the following: + +\begin{align} +\begin{split} + H_i(T_i) + & = \int_0^{T_i} h_i(u) du \\ + S_i(T_i) + & = \exp \left( - \int_0^{T_i} h_i(u) du \right) \\ + F_i(T_i) + & = 1 - \exp \left( - \int_0^{T_i} h_i(u) du \right) \\ + S_i(T_i) - S_i(T_i^U) + & = \exp \left( -\int_0^{T_i} h_i(u) du \right) - \exp \left( - \int_0^{T_i^U} h_i(u) du \right) +\end{split} +\end{align} + + +# Appendix B: Parameterisations under accelerated failure times + +When `basehaz` is set equal to `"exp-aft"`, or `"weibull-aft"` then the model is defined on the accelerated failure time scale using the following parameterisations. + + +### Exponential model + +The exponential model is parameterised with scale parameter $\lambda_i = \exp(-\eta_i)$ where $\eta_i = \beta_0^* + \sum_{p=1}^P \beta_p^* x_{ip}$ denotes our linear predictor. + +For individual $i$ we have: + +\begin{align} +\begin{split} + h_i(T_i) + & = \lambda_i \\ + & = \exp(-\eta_i) \\ + H_i(T_i) + & = T_i \lambda_i \\ + & = T_i \exp(-\eta_i) \\ + S_i(T_i) + & = \exp \left( - T_i \lambda_i \right) \\ + & = \exp \left( - T_i \exp(-\eta_i) \right) \\ + F_i(T_i) + & = 1 - \exp \left( - T_i \lambda_i \right) \\ + & = 1 - \exp \left( - T_i \exp(-\eta_i) \right) \\ + S_i(T_i) - S_i(T_i^U) + & = \exp \left( - T_i \lambda_i \right) - \exp \left( - T_i^U \lambda_i \right) \\ + & = \exp \left( - T_i \exp(-\eta_i) \right) - \exp \left( - T_i^U \exp(-\eta_i) \right) +\end{split} +\end{align} + +or on the log scale: + +\begin{align} +\begin{split} + \log h_i(T_i) + & = \log \lambda_i \\ + & = -\eta_i \\ + \log H_i(T_i) + & = \log(T_i) + \log \lambda_i \\ + & = \log(T_i) - \eta_i \\ + \log S_i(T_i) + & = - T_i \lambda_i \\ + & = - T_i \exp(-\eta_i) \\ + \log F_i(T_i) + & = \log \left( 1 - \exp \left( - T_i \lambda_i \right) \right) \\ + & = \log \left( 1 - \exp \left( - T_i \exp(-\eta_i) \right) \right) \\ + \log (S_i(T_i) - S_i(T_i^U)) + & = \log \left[ \exp \left( - T_i \lambda_i) \right) - \exp \left( - T_i^U \lambda_i \right) \right] \\ + & = \log \left[ \exp \left( - T_i \exp(-\eta_i) \right) - \exp \left( - T_i^U \exp(-\eta_i) \right) \right] +\end{split} +\end{align} + +The definition of $\lambda$ for the baseline is: + +\begin{align} +\begin{split} + \lambda_0 = \exp(-\beta_0^*) \Longleftrightarrow \beta_0^* = -\log(\lambda_0) +\end{split} +\end{align} + +The relationship between coefficients under the PH (unstarred) and AFT (starred) parameterisations are as follows: + +\begin{align} +\begin{split} + \beta_0 & = -\beta_0^* \\ + \beta_p & = -\beta_p^* +\end{split} +\end{align} + +Lastly, the general form for the hazard function and survival function under an AFT model with acceleration factor $\exp(-\eta_i)$ can be used to derive the exponential AFT model defined here by setting $h_0(t) = 1$, $S_0(t) = \exp(-T_i)$, and $\lambda_i = \exp(-\eta_i)$: + +\begin{align} +\begin{split} + h_i(T_i) + & = \exp(-\eta_i) h_0(t \exp(-\eta_i)) \\ + & = \exp(-\eta_i) \\ + & = \lambda_i +\end{split} +\end{align} + +\begin{align} +\begin{split} + S_i(T_i) + & = S_0(t \exp(-\eta_i)) \\ + & = \exp(-T_i \exp(-\eta_i)) \\ + & = \exp(-T_i \lambda_i) +\end{split} +\end{align} + + +### Weibull model + +The Weibull model is parameterised with scale parameter $\lambda_i = \exp(-\gamma \eta_i)$ and shape parameter $\gamma > 0$ where $\eta_i = \beta_0^* + \sum_{p=1}^P \beta_p^* x_{ip}$ denotes our linear predictor. + +For individual $i$ we have: + +\begin{align} +\begin{split} + h_i(T_i) + & = \gamma t^{\gamma-1} \lambda_i \\ + & = \gamma t^{\gamma-1} \exp(-\gamma \eta_i) \\ + H_i(T_i) + & = T_i^{\gamma} \lambda_i \\ + & = T_i^{\gamma} \exp(-\gamma \eta_i) \\ + S_i(T_i) + & = \exp \left( - T_i^{\gamma} \lambda_i \right) \\ + & = \exp \left( - T_i^{\gamma} \exp(-\gamma \eta_i) \right) \\ + F_i(T_i) + & = 1 - \exp \left( - {(T_i)}^{\gamma} \lambda_i \right) \\ + & = 1 - \exp \left( - {(T_i)}^{\gamma} \exp(-\gamma \eta_i) \right) \\ + S_i(T_i) - S_i(T_i^U) + & = \exp \left( - {(T_i)}^{\gamma} \lambda_i \right) - \exp \left( - {(T_i^U)}^{\gamma} \lambda_i \right) \\ + & = \exp \left( - {(T_i)}^{\gamma} \exp(-\gamma \eta_i) \right) - \exp \left( - {(T_i^U)}^{\gamma} \exp(-\gamma \eta_i) \right) +\end{split} +\end{align} + +or on the log scale: + +\begin{align} +\begin{split} + \log h_i(T_i) + & = \log(\gamma) + (\gamma-1) \log(t) + \log \lambda_i \\ + & = \log(\gamma) + (\gamma-1) \log(t) - \gamma \eta_i \\ + \log H_i(T_i) + & = \gamma \log(T_i) + \log \lambda_i \\ + & = \gamma \log(T_i) - \gamma \eta_i \\ + \log S_i(T_i) + & = - T_i^{\gamma} \lambda_i \\ + & = - T_i^{\gamma} \exp(-\gamma \eta_i) \\ + \log F_i(T_i) + & = \log \left( 1 - \exp \left( - {(T_i)}^{\gamma} \lambda_i \right) \right) \\ + & = \log \left( 1 - \exp \left( - {(T_i)}^{\gamma} \exp(-\gamma \eta_i) \right) \right) \\ + \log (S_i(T_i) - S_i(T_i^U)) + & = \log \left[ \exp \left( - {(T_i)}^{\gamma} \lambda_i \right) - \exp \left( - {(T_i^U)}^{\gamma} \lambda_i \right) \right] \\ + & = \log \left[ \exp \left( - {(T_i)}^{\gamma} \exp(-\gamma \eta_i) \right) - \exp \left( - {(T_i^U)}^{\gamma} \exp(-\gamma \eta_i) \right) \right] +\end{split} +\end{align} + +The definition of $\lambda$ for the baseline is: + +\begin{align} +\begin{split} + \lambda_0 = \exp(-\gamma \beta_0^*) \Longleftrightarrow \beta_0^* = \frac{-\log(\lambda_0)}{\gamma} +\end{split} +\end{align} + +The relationship between coefficients under the PH (unstarred) and AFT (starred) parameterisations are as follows: + +\begin{align} +\begin{split} + \beta_0 & = -\gamma \beta_0^* \\ + \beta_p & = -\gamma \beta_p^* +\end{split} +\end{align} + +Lastly, the general form for the hazard function and survival function under an AFT model with acceleration factor $\exp(-\eta_i)$ can be used to derive the Weibull AFT model defined here by setting $h_0(t) = \gamma t^{\gamma - 1}$, $S_0(t) = \exp(-T_i^{\gamma})$, and $\lambda_i = \exp(-\gamma \eta_i)$: + +\begin{align} +\begin{split} + h_i(T_i) + & = \exp(-\eta_i) h_0(t \exp(-\eta_i)) \\ + & = \exp(-\eta_i) \gamma {(t \exp(-\eta_i))}^{\gamma - 1} \\ + & = \exp(-\gamma \eta_i) \gamma t^{\gamma - 1} \\ + & = \lambda_i \gamma t^{\gamma - 1} +\end{split} +\end{align} + +\begin{align} +\begin{split} + S_i(T_i) + & = S_0(t \exp(-\eta_i)) \\ + & = \exp(-(T_i \exp(-\eta_i))^{\gamma}) \\ + & = \exp(-T_i^{\gamma} [\exp(-\eta_i)]^{\gamma}) \\ + & = \exp(-T_i^{\gamma} \exp(-\gamma \eta_i)) \\ + & = \exp(-T_i \lambda_i) +\end{split} +\end{align} + + +### Extension to time-varying coefficients (i.e. time-varying acceleration factors) + +We can extend the previous model formulations to allow for time-varying coefficients (i.e. time-varying acceleration factors). + +The so-called "unmoderated" survival probability for an individual at time $t$ is defined as the baseline survival probability at time $t$, i.e. $S_i(t) = S_0(t)$. With a time-fixed acceleration factor, the survival probability for a so-called "moderated" individual is defined as the baseline survival probability but evaluated at "time $t$ multiplied by the acceleration factor $\exp(-\eta_i)$". That is, the survival probability for the moderated individual is $S_i(t) = S_0(t \exp(-\eta_i))$. + +However, with time-varying acceleration we cannot simply multiply time by a fixed (acceleration) constant. Instead, we must integrate the function for the time-varying acceleration factor over the interval $0$ to $t$. In other words, we must evaluate: +\ +\begin{align} +\begin{split} + S_i(t) = S_0 \left( \int_0^t \exp(-\eta_i(u)) du \right) +\end{split} +\end{align} +\ +as described by Hougaard (1999). + +Hougaard also gives a general expression for the hazard function under time-varying acceleration, as follows: +\ +\begin{align} +\begin{split} + h_i(t) = \exp \left(-\eta_i(t) \right) h_0 \left( \int_0^t \exp(-\eta_i(u)) du \right) +\end{split} +\end{align} + +**Note:** It is interesting to note here that the *hazard* at time $t$ is in fact a function of the full history of covariates and parameters (i.e. the linear predictor) from time $0$ up until time $t$. This is different to the hazard scale formulation of time-varying effects (i.e. non-proportional hazards). Under the hazard scale formulation with time-varying effects, the *survival* probability is a function of the full history between times $0$ and $t$, but the *hazard* is **not**; instead, the hazard is only a function of covariates and parameters as defined at the current time. This is particularly important to consider when fitting accelerated failure time models with time-varying effects in the presence of delayed entry (i.e. left truncation). + +For the exponential distribution, this leads to: + +\begin{align} +\begin{split} + S_i(T_i) + & = S_0 \left( \int_0^{T_i} \exp(-\eta_i(u)) du \right) \\ + & = \exp \left(- \int_0^{T_i} \exp(-\eta_i(u)) du \right) +\end{split} +\end{align} + +\begin{align} +\begin{split} + h_i(T_i) + & = \exp \left(-\eta_i(T_i) \right) h_0 \left( \int_0^{T_i} \exp(-\eta_i(u)) du \right) \\ + & = \exp \left(-\eta_i(T_i) \right) \exp \left(- \int_0^{T_i} \exp(-\eta_i(u)) du \right) +\end{split} +\end{align} + +and for the Weibull distribution, this leads to: + +\begin{align} +\begin{split} + S_i(T_i) + & = S_0 \left( \int_0^{T_i} \exp(-\eta_i(u)) du \right) \\ + & = \exp \left(- \left[\int_0^{T_i} \exp (-\eta_i(u)) du \right]^{\gamma} \right) +\end{split} +\end{align} + +\begin{align} +\begin{split} + h_i(T_i) + & = \exp \left(-\eta_i(T_i) \right) h_0 \left( \int_0^{T_i} \exp(-\eta_i(u)) du \right) \\ + & = \exp \left(-\eta_i(T_i) \right) \exp \left(- \left[\int_0^{T_i} \exp (-\eta_i(u)) du \right]^{\gamma} \right) +\end{split} +\end{align} + +The general expressions for the hazard and survival function under an AFT model with a time-varying linear predictor are used to evaluate the likelihood for the accelerated failure time model in `stan_surv` when time-varying effects are specified in the model formula. Specifically, quadrature is used to evaluate the cumulative acceleration factor $\int_0^t \exp(-\eta_i(u)) du$ and this is then substituted into the relevant expressions for the hazard and survival.