Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Feature/lm list #108

Open
wants to merge 12 commits into
base: master
Choose a base branch
from
1 change: 1 addition & 0 deletions .travis.yml
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ cache: packages

r_github_packages:
- jimhester/covr
- HenrikBengtsson/matrixStats@develop

env:
matrix:
Expand Down
1 change: 1 addition & 0 deletions NAMESPACE
Original file line number Diff line number Diff line change
Expand Up @@ -76,6 +76,7 @@ export(stan_glmer.nb)
export(stan_lm)
export(stan_lm.fit)
export(stan_lm.wfit)
export(stan_lmList)
export(stan_lmer)
export(stan_polr)
export(stan_polr.fit)
Expand Down
11 changes: 8 additions & 3 deletions R/loo.R
Original file line number Diff line number Diff line change
Expand Up @@ -362,7 +362,7 @@ ll_fun <- function(x) {
f <- family(x)
if (!is(f, "family") || is_scobit(x))
return(.ll_polr_i)

if (is.lmList(x)) return(.ll_gaussian_lmList_i)
get(paste0(".ll_", f$family, "_i"))
}

Expand Down Expand Up @@ -390,7 +390,7 @@ ll_args <- function(object, newdata = NULL) {
if (is(f, "family") && !is_scobit(object)) {
fname <- f$family
if (!is.binomial(fname)) {
data <- data.frame(y, x)
data <- data.frame(y, as.matrix(x))
} else {
if (NCOL(y) == 2L) {
trials <- rowSums(y)
Expand All @@ -405,7 +405,9 @@ ll_args <- function(object, newdata = NULL) {
}
draws$beta <- stanmat[, seq_len(ncol(x)), drop = FALSE]
if (is.gaussian(fname))
draws$sigma <- stanmat[, "sigma"]
draws$sigma <- stanmat[, grep("^sigma", colnames(stanmat))]
if (is.lmList(object))
draws$groups <- as.integer(object$groups)
if (is.gamma(fname))
draws$shape <- stanmat[, "shape"]
if (is.ig(fname))
Expand Down Expand Up @@ -492,6 +494,9 @@ ll_args <- function(object, newdata = NULL) {
val <- dnorm(data$y, mean = .mu(data,draws), sd = draws$sigma, log = TRUE)
.weighted(val, data$weights)
}
# Pointwise Gaussian log-likelihood for models with group-specific sigma.
#
# @param i Index of the observation being evaluated; used to look up the
#   group that the observation belongs to.
# @param data Data for the observation; data$y is the outcome and the whole
#   object is forwarded to .mu() to obtain the mean.
# @param draws A list of posterior draws. draws$sigma is indexed by column
#   with the group index (presumably one sigma column per group -- confirm
#   against ll_args) and draws$groups holds the group index of each
#   observation.
# @return The log density values computed by dnorm().
.ll_gaussian_lmList_i <- function(i, data, draws) {
  # Use the exact element name 'groups' rather than relying on partial
  # matching of `$` on lists (draws$group), which silently breaks as soon as
  # another element whose name starts with "group" is added to 'draws'.
  sigma_i <- draws$sigma[, draws$groups[i]]
  # Return the value visibly instead of as the result of an assignment.
  dnorm(data$y, mean = .mu(data, draws), sd = sigma_i, log = TRUE)
}
# Pointwise binomial log-likelihood.
#
# @param i Index of the observation (not used directly in the body).
# @param data Data for the observation; data$y is the number of successes,
#   data$trials the number of trials and data$weights the optional weights.
# @param draws Posterior draws, forwarded to .mu() to obtain the success
#   probability.
# @return The output of .weighted() applied to the log density values.
.ll_binomial_i <- function(i, data, draws) {
  success_prob <- .mu(data, draws)
  log_lik <- dbinom(
    data$y,
    size = data$trials,
    prob = success_prob,
    log = TRUE
  )
  .weighted(log_lik, data$weights)
}
Expand Down
7 changes: 7 additions & 0 deletions R/misc.R
Original file line number Diff line number Diff line change
Expand Up @@ -138,6 +138,13 @@ is.mer <- function(x) {
isTRUE(check1 && check2)
}

# Check whether a fitted model object carries the "lmList" class
#
# @param x A stanreg object.
# @return TRUE if x inherits from class "lmList", FALSE otherwise.
is.lmList <- function(x) {
  target_class <- "lmList"
  inherits(x, what = target_class)
}

# Consistent error message to use when something is only available for
# models fit using MCMC
#
Expand Down
9 changes: 8 additions & 1 deletion R/posterior_predict.R
Original file line number Diff line number Diff line change
Expand Up @@ -159,6 +159,7 @@ posterior_predict <- function(object, newdata = NULL, draws = NULL,
else ppargs <- pp_args(object, data = pp_eta(object, dat, draws))
if (!is(object, "polr") && is.binomial(family(object)$family))
ppargs$trials <- pp_binomial_trials(object, newdata)
if (inherits(object, "lmList")) ppargs$group <- object$group

ppfun <- pp_fun(object)
ytilde <- do.call(ppfun, ppargs)
Expand All @@ -175,6 +176,7 @@ posterior_predict <- function(object, newdata = NULL, draws = NULL,
# functions to draw from the various posterior predictive distributions
# Look up the function that draws from the posterior predictive distribution
#
# The function name is built as ".pp_<suffix>", where the suffix is "polr"
# for polr objects and the family name otherwise; "_grouped" is appended for
# objects that inherit from "lmList".
#
# @param object A fitted model object.
# @return The function found under the constructed name.
pp_fun <- function(object) {
  if (is(object, "polr")) {
    suffix <- "polr"
  } else {
    suffix <- family(object)$family
  }
  if (inherits(object, "lmList")) {
    suffix <- paste0(suffix, "_grouped")
  }
  fun_name <- paste0(".pp_", suffix)
  get(fun_name, mode = "function")
}

Expand All @@ -183,6 +185,11 @@ pp_fun <- function(object) {
rnorm(ncol(mu), mu[s,], sigma[s])
}))
}
# Draw from the Gaussian posterior predictive distribution when the residual
# standard deviation differs by group.
#
# @param mu Matrix of means; rows are indexed by posterior draw and columns
#   by observation.
# @param sigma Matrix of standard deviations; row s is used for draw s and
#   columns are selected with 'group'.
# @param group Column indices into 'sigma', one per column of 'mu'
#   (presumably the group membership of each observation -- confirm against
#   the caller).
# @return A matrix with the same number of rows and columns as 'mu'.
.pp_gaussian_grouped <- function(mu, sigma, group) {
  n_draws <- nrow(mu)
  n_obs <- ncol(mu)
  # Fill a preallocated matrix row by row. Unlike t(sapply(...)), this keeps
  # the draws-by-observations orientation when there is a single observation
  # (sapply would simplify to a vector, and t() would then give 1 x n_draws)
  # and it copes with zero draws, where 1:nrow(mu) would iterate over c(1, 0).
  # Drawing one row at a time keeps the random number stream in the same
  # order as the previous implementation.
  ytilde <- matrix(NA_real_, nrow = n_draws, ncol = n_obs)
  for (s in seq_len(n_draws)) {
    ytilde[s, ] <- rnorm(n_obs, mu[s, ], sigma[s, group])
  }
  ytilde
}
.pp_binomial <- function(mu, trials) {
t(sapply(1:nrow(mu), function(s) {
rbinom(ncol(mu), size = trials, prob = mu[s, ])
Expand Down Expand Up @@ -260,7 +267,7 @@ pp_args <- function(object, data) {
args <- list(mu = inverse_link(eta))
famname <- family(object)$family
if (is.gaussian(famname)) {
args$sigma <- stanmat[, "sigma"]
args$sigma <- stanmat[, grep("^sigma", colnames(stanmat))]
} else if (is.gamma(famname)) {
args$shape <- stanmat[, "shape"]
} else if (is.ig(famname)) {
Expand Down
27 changes: 20 additions & 7 deletions R/pp_data.R
Original file line number Diff line number Diff line change
Expand Up @@ -31,13 +31,26 @@ pp_data <- function(object, newdata = NULL, re.form = NULL, ...) {
return(nlist(x, offset))
}
tt <- terms(object)
Terms <- delete.response(tt)
m <- model.frame(Terms, newdata, xlev = object$xlevels)
if (!is.null(cl <- attr(Terms, "dataClasses")))
.checkMFClasses(cl, m)
x <- model.matrix(Terms, m, contrasts.arg = object$contrasts)
if (is(object, "polr") && !is_scobit(object))
x <- x[,colnames(x) != "(Intercept)", drop = FALSE]
if (inherits(object, "lmList")) {
f <- as.character(object$formula)
f <- as.formula(paste(f[2], "~", "-1 + (", f[3], ")"))
g <- glFormula(f, newdata)
modelframe <- g$fr
x <- t(g$reTrms$Zt)
group_names <- levels(object$groups)
K <- NCOL(x) / length(group_names)
colnames(x) <- c(sapply(group_names, FUN = function(g) paste0(1:K, ":", g)))
x <- x[,sort(colnames(x))]
}
else {
Terms <- delete.response(tt)
m <- model.frame(Terms, newdata, xlev = object$xlevels)
if (!is.null(cl <- attr(Terms, "dataClasses")))
.checkMFClasses(cl, m)
x <- model.matrix(Terms, m, contrasts.arg = object$contrasts)
if (is(object, "polr") && !is_scobit(object))
x <- x[,colnames(x) != "(Intercept)", drop = FALSE]
}
offset <- rep(0, nrow(x))
if (!is.null(off.num <- attr(tt, "offset"))) {
for (i in off.num) {
Expand Down
6 changes: 4 additions & 2 deletions R/priors.R
Original file line number Diff line number Diff line change
Expand Up @@ -400,12 +400,14 @@ make_eta <- function(location, what = c("mode", "mean", "median", "log"), K) {
half_K <- K / 2
if (what == "mode") {
stopifnot(location > 0, location <= 1)
if (K <= 2)
stop(paste("R2 prior error.",
if (K <= 2) {
if (location == 0.5 && K == 2) what <- "mean"
else stop(paste("R2 prior error.",
"The mode of the beta distribution does not exist",
"with fewer than three predictors.",
"Specify 'what' as 'mean', 'median', or 'log' instead."),
call. = FALSE)
}
eta <- (half_K - 1 - location * half_K + location * 2) / location
} else if (what == "mean") {
stopifnot(location > 0, location < 1)
Expand Down
17 changes: 11 additions & 6 deletions R/rstanarm-package.R
Original file line number Diff line number Diff line change
Expand Up @@ -116,11 +116,14 @@
#' overview:
#'
#' \describe{
#' \item{\code{\link{stan_lm}}, \code{stan_aov}, \code{stan_biglm}}{
#' Similar to \code{\link[stats]{lm}} or \code{\link[stats]{aov}} but with
#' novel regularizing priors on the model parameters that are driven by prior
#' beliefs about \eqn{R^2}, the proportion of variance in the outcome
#' attributable to the predictors in a linear model.
#' \item{\code{\link{stan_lm}}, \code{stan_aov}, \code{stan_biglm}, \code{stan_lmList}}{
#' Similar to \code{\link[stats]{lm}} but with novel regularizing priors on the model
#' parameters that are driven by prior beliefs about \eqn{R^2}, the proportion of
#' variance in the outcome attributable to the predictors in a linear model. The
#' \code{stan_aov} variant estimates ANOVA models, the \code{stan_biglm} variant
#' estimates linear models when the design matrix is too large to fit into memory, and
#' the \code{stan_lmList} variant allows all the parameters to vary by group but utilizes
#' hyperpriors to shrink them toward common parameters, much like \code{stan_lmer}.
#' }
#' \item{\code{\link{stan_glm}}, \code{stan_glm.nb}}{
#' Similar to \code{\link[stats]{glm}} but with Gaussian, Student t, Cauchy
Expand All @@ -139,7 +142,9 @@
#' a highly-structured but unknown covariance matrix (for which \pkg{rstanarm}
#' introduces an innovative prior distribution). MCMC provides more
#' appropriate estimates of uncertainty for models that consist of a mix of
#' common and group-specific parameters.
#' common and group-specific parameters. The same line of argument applies even
#' more strongly to \code{stan_lmList}, which is described in the first item of
#' this list.
#' }
#' \item{\code{\link{stan_gamm4}}}{
#' Similar to \code{\link[gamm4]{gamm4}} in the \pkg{gamm4} package, which
Expand Down
22 changes: 18 additions & 4 deletions R/stan_biglm.R
Original file line number Diff line number Diff line change
Expand Up @@ -27,11 +27,17 @@
#' intercept and must utilize \emph{centered} but not \emph{standardized}
#' predictors. See the Details section or the Example.
#' @param xbar A numeric vector of means in the implicit design matrix for
#' the observations included in the model
#' the observations included in the model or --- in the case of
#' \code{stan_biglm.fit} only --- a list of such vectors with one list element
#' for each group
#' @param ybar A numeric scalar indicating the sample mean of the outcome for
#' the observations included in the model
#' the observations included in the model or --- in the case of
#' \code{stan_biglm.fit} only --- a numeric vector of such means with one
#' element for each group
#' @param s_y A numeric scalar indicating the unbiased sample standard deviation
#' of the outcome for the observations included in the model
#' of the outcome for the observations included in the model or --- in the case
#' of \code{stan_biglm.fit} only --- a numeric vector of such standard deviations
#' with one element for each group
#' @param has_intercept A logical scalar indicating whether to add an intercept
#' to the model when estimating it
#' @template args-dots
Expand Down Expand Up @@ -69,6 +75,14 @@
#' The sample mean and sample standard deviation of the outcome must also
#' be passed.
#'
#' The \code{stan_biglm} function calls \code{stan_biglm.fit}, although the
#' latter can be called directly. The first seven arguments to
#' \code{stan_biglm.fit} may be lists (or vectors) of the same length, where
#' that length is equal to the number of mutually exclusive and exhaustive groups that
#' the data have been stratified by. The \code{\link{stan_lmList}} function
#' provides a more conventional interface for a stratified linear regression
#' model and calls \code{stan_biglm.fit} internally.
#'
#' @return The output of both \code{stan_biglm} and \code{stan_biglm.fit} is an object of
#' \code{\link[rstan]{stanfit-class}} rather than \code{\link{stanreg-objects}},
#' which is more limited and less convenient but necessitated by the fact that
Expand Down Expand Up @@ -98,7 +112,7 @@ stan_biglm <- function(biglm, xbar, ybar, s_y, has_intercept = TRUE, ...,
R <- sqrt(biglm$qr$D) * R
return(stan_biglm.fit(b, R, SSR = biglm$qr$ss, N = biglm$n, xbar, ybar, s_y, has_intercept,
...,
prior = prior, prior_intercept = prior_intercept,
prior = prior, prior_intercept = prior_intercept, kappa_mean = 0,
prior_PD = prior_PD, algorithm = algorithm,
adapt_delta = adapt_delta))
}
Expand Down
89 changes: 64 additions & 25 deletions R/stan_biglm.fit.R
Original file line number Diff line number Diff line change
Expand Up @@ -17,10 +17,17 @@

#' @rdname stan_biglm
#' @export
#' @param b A numeric vector of OLS coefficients, excluding the intercept
#' @param R A square upper-triangular matrix from the QR decomposition of the design matrix
#' @param b A numeric vector of OLS coefficients --- excluding the intercept ---
#' or a list of such vectors with one list element for each group
#' @param R A square upper-triangular matrix from the QR decomposition of the
#' centered design matrix or a list of such matrices with one list element for
#' each group
#' @param SSR A numeric scalar indicating the sum-of-squared residuals for OLS
#' @param N A integer scalar indicating the number of included observations
#' or a numeric vector of sums-of-squared residuals with one element for each
#' group
#' @param N An integer scalar indicating the number of included observations or
#' an integer array of sample sizes with one element for each group
#' @template args-kappa_mean
#' @examples
#' # create inputs
#' ols <- lm(mpg ~ wt + qsec + am - 1, # next line is critical for centering
Expand All @@ -39,23 +46,41 @@
#' cbind(lm = b, stan_lm = rstan::get_posterior_mean(post)[14:16]) # shrunk
stan_biglm.fit <- function(b, R, SSR, N, xbar, ybar, s_y, has_intercept = TRUE, ...,
prior = R2(stop("'location' must be specified")),
prior_intercept = NULL, prior_PD = FALSE,
prior_intercept = NULL, kappa_mean = 1, prior_PD = FALSE,
algorithm = c("sampling", "meanfield", "fullrank"),
adapt_delta = NULL) {

J <- 1L
N <- array(N, c(J))
K <- ncol(R)
cn <- names(xbar)
if (is.null(cn)) cn <- names(b)
R_inv <- backsolve(R, diag(K))
JK <- c(J, K)
xbarR_inv <- array(c(xbar %*% R_inv), JK)
Rb <- array(R %*% b, JK)
SSR <- array(SSR, J)
s_Y <- array(s_y, J)
center_y <- if (isTRUE(all.equal(matrix(0, J, K), xbar))) ybar else 0.0
ybar <- array(ybar, J)
if (is.list(b)) {
J <- length(b)
K <- ncol(R[[1]])
cn <- names(xbar[[1]])
if (is.null(cn)) cn <- names(b[[1]])
I <- diag(K)
R_inv <- sapply(R, simplify = FALSE, FUN = backsolve, x = I)
center_y <- 0.0
xbarR_inv <- Rb <- vector("list", J)
names(xbarR_inv) <- names(Rb) <- names(b)
for (j in 1:J) {
xbarR_inv[[j]] <- c(xbar[[j]] %*% R_inv[[j]])
Rb[[j]] <- c(R[[j]] %*% b[[j]])
}
}
else {
J <- 1L
N <- array(N, c(J))
K <- ncol(R)
cn <- names(xbar)
if (is.null(cn)) cn <- names(b)
I <- diag(K)
R_inv <- backsolve(R, I)
JK <- c(J, K)
xbarR_inv <- array(c(xbar %*% R_inv), JK)
Rb <- array(R %*% b, JK)
SSR <- array(SSR, J)
s_y <- array(s_y, J)
center_y <- if (isTRUE(all.equal(matrix(0, J, K), xbar))) ybar else 0.0
ybar <- array(ybar, J)
dim(R_inv) <- c(J, dim(R_inv))
}

if (!length(prior)) {
prior_dist <- 0L
Expand All @@ -77,23 +102,28 @@ stan_biglm.fit <- function(b, R, SSR, N, xbar, ybar, s_y, has_intercept = TRUE,
if (is.null(prior_scale_for_intercept))
prior_scale_for_intercept <- 0
}
dim(R_inv) <- c(J, dim(R_inv))

# initial values
R2 <- array(1 - SSR[1] / ((N - 1) * s_Y^2), J)
if (J == 1) R2 <- array(1 - SSR[1] / ((N - 1) * s_y^2), J)
else R2 <- 1 - SSR / ((N - 1) * s_y^2)
log_omega <- array(0, ifelse(prior_PD == 0, J, 0))
init_fun <- function(chain_id) {
out <- list(R2 = R2, log_omega = log_omega)
if (has_intercept == 0L) out$z_alpha <- double()
if (J == 1) {
out$SMC <- double()
out$mu <- array(0, c(0, K))
out$kappa <- double()
}
return(out)
}
stanfit <- stanmodels$lm
standata <- nlist(K, has_intercept, prior_dist,
prior_dist_for_intercept,
prior_mean_for_intercept,
prior_scale_for_intercept,
prior_scale_for_intercept, kappa_mean,
prior_PD, eta, J, N, xbarR_inv,
ybar, center_y, s_Y, Rb, SSR, R_inv)
ybar, center_y, s_Y = s_y, Rb, SSR, R_inv)
pars <- c(if (has_intercept) "alpha", "beta", "sigma",
if (prior_PD == 0) "log_omega", "R2", "mean_PPD")
algorithm <- match.arg(algorithm)
Expand All @@ -109,9 +139,18 @@ stan_biglm.fit <- function(b, R, SSR, N, xbar, ybar, s_y, has_intercept = TRUE,
init = init_fun, data = standata, pars = pars, show_messages = FALSE)
stanfit <- do.call(sampling, sampling_args)
}
new_names <- c(if (has_intercept) "(Intercept)", cn, "sigma",
if (prior_PD == 0) "log-fit_ratio",
"R2", "mean_PPD", "log-posterior")
if (J == 1) new_names <- c(if (has_intercept) "(Intercept)", cn, "sigma",
if (prior_PD == 0) "log-fit_ratio",
"R2", "mean_PPD", "log-posterior")
else {
group_names <- names(b)
new_names <- c(if (has_intercept) paste0("(Intercept):", group_names),
t(sapply(group_names, FUN = function(g) paste0(cn, ":", g))),
paste0("sigma:", group_names),
if (prior_PD == 0) paste0("log-fit_ratio:", group_names),
paste0("R2:", group_names), paste0("mean_PPD:", group_names),
"log-posterior")
}
stanfit@sim$fnames_oi <- new_names
return(stanfit)
}
1 change: 1 addition & 0 deletions R/stan_lm.fit.R
Original file line number Diff line number Diff line change
Expand Up @@ -62,6 +62,7 @@ stan_lm.wfit <- function(x, y, w, offset = NULL, singular.ok = TRUE, ...,
N = nrow(x), xbar = xbar, ybar = ybar, s_y = sd(y),
has_intercept = has_intercept, ...,
prior = prior, prior_intercept = prior_intercept,
kappa_mean = 0,
prior_PD = prior_PD, algorithm = algorithm,
adapt_delta = adapt_delta))

Expand Down
Loading