Skip to content

Commit

Permalink
Fix issues with forthcoming {marginaleffects} (#363)
Browse files Browse the repository at this point in the history
* Just for testing purposes, if latest changes in marginaleffects work

* version

* fix

* fix

* fix

* fix?

* comments

* fix test

* skip for now

* clean-up todo's

* test

* validate against predict and me

* Copy data sets from ggeffects, to remove dependency

* wordlist

* update data

* fix?

* fix

* wordlist

* pkgdown

* back to rdata files

* fix

* dont ignore data folder

* fix

* fix

* fix

* fix

* Update get_marginalcontrasts.R

* Update test-estimate_contrasts.R

* Update get_marginalcontrasts.R

* Update get_marginalcontrasts.R

* Update get_marginalcontrasts.R

* Update format.R

* reverse order

* Update format.R

* Update get_marginalcontrasts.R

* update

* fix sorting

* fix

* fix

* Update test-estimate_contrasts.R

* Update get_marginalmeans.R

* Update format.R

* Update test-estimate_contrasts.R

* fix

* Update test-estimate_contrasts.R

* fix

* Update format.R

* Update format.R

* Update test-estimate_contrasts.R

* Update format.R

* Update format.R

* fixes

* update

* Update test-glmmTMB.R

* update snaps

* remove redundant code, fix tests

* update readme

* update tests

* some snaps on windows only
  • Loading branch information
strengejacke authored Jan 25, 2025
1 parent 305fa52 commit 2795a46
Show file tree
Hide file tree
Showing 36 changed files with 905 additions and 669 deletions.
1 change: 0 additions & 1 deletion .Rbuildignore
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,6 @@
^\.travis.yml
^\_pkgdown.yml
^LICENSE
^data/.
^docs/.
^paper/.
^vignettes/d.
Expand Down
5 changes: 3 additions & 2 deletions DESCRIPTION
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
Type: Package
Package: modelbased
Title: Estimation of Model-Based Predictions, Contrasts and Means
Version: 0.8.9.101
Version: 0.8.9.102
Authors@R:
c(person(given = "Dominique",
family = "Makowski",
Expand Down Expand Up @@ -55,7 +55,6 @@ Suggests:
Formula,
gamm4,
gganimate,
ggeffects,
ggplot2,
glmmTMB,
httr2,
Expand Down Expand Up @@ -89,3 +88,5 @@ Config/testthat/parallel: true
Roxygen: list(markdown = TRUE)
Config/Needs/check: stan-dev/cmdstanr
Config/Needs/website: easystats/easystatstemplate
Remotes: vincentarelbundock/marginaleffects
LazyData: true
40 changes: 40 additions & 0 deletions R/data.R
Original file line number Diff line number Diff line change
@@ -0,0 +1,40 @@
#' @docType data
#' @title Sample data set
#' @name fish
#' @keywords data
#'
#' @description A sample data set, used in tests and some examples. Useful for
#' demonstrating count models (with or without zero-inflation component). It
#' consists of nine variables from 250 observations.
NULL


#' @docType data
#' @title Sample dataset from the EFC Survey
#' @name efc
#' @keywords data
#'
#' @description Selected variables from the EUROFAMCARE survey. Useful when
#' testing on "real-life" data sets, including random missing values. This
#' data set also has value and variable label attributes.
NULL


#' @docType data
#' @title Sample dataset from a course about analysis of factorial designs
#' @name coffee_data
#' @keywords data
#'
#' @description A sample data set from a course about the analysis of factorial
#' designs, by Mattan S. Ben-Shachar. See following link for more information:
#' https://github.com/mattansb/Analysis-of-Factorial-Designs-foR-Psychologists
#'
#' The data consists of five variables from 120 observations:
#'
#' - `ID`: A unique identifier for each participant
#' - `sex`: The participant's sex
#' - `time`: The time of day the participant was tested (morning, noon, or afternoon)
#' - `coffee`: Group indicator, whether participant drank coffee or not
#' ("`coffee"` or `"control"`).
#' - `alertness`: The participant's alertness score.
NULL
6 changes: 6 additions & 0 deletions R/estimate_contrasts.R
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,7 @@
#'
#' @examplesIf all(insight::check_if_installed(c("lme4", "marginaleffects", "rstanarm"), quietly = TRUE))
#' \dontrun{
#' options(marginaleffects_safe = FALSE)
#' # Basic usage
#' model <- lm(Sepal.Width ~ Species, data = iris)
#' estimate_contrasts(model)
Expand Down Expand Up @@ -115,6 +116,11 @@ estimate_contrasts <- function(model,
predict <- transform
}

# update comparison argument - if user provides a formula for the new
# marginaleffects version, we still want the string-option for internal
# processing...
comparison <- .get_marginaleffects_hypothesis_argument(comparison)$comparison

if (backend == "emmeans") {
# Emmeans ------------------------------------------------------------------
estimated <- get_emcontrasts(model,
Expand Down
86 changes: 75 additions & 11 deletions R/format.R
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,16 @@ format.estimate_contrasts <- function(x, format = NULL, ...) {

# arrange columns (not for contrast now)
by <- rev(attr(x, "focal_terms", exact = TRUE))

Check warning on line 13 in R/format.R

View workflow job for this annotation

GitHub Actions / lint / lint

file=R/format.R,line=13,col=3,[object_overwrite_linter] 'by' is an exported object from package 'base'. Avoid re-using such symbols.
if (!is.null(by) && all(by %in% colnames(x))) {
# add "Level" columns from contrasts
if (all(c("Level1", "Level2") %in% colnames(x))) {
by <- unique(c("Level1", "Level2", by))

Check warning on line 16 in R/format.R

View workflow job for this annotation

GitHub Actions / lint / lint

file=R/format.R,line=16,col=5,[object_overwrite_linter] 'by' is an exported object from package 'base'. Avoid re-using such symbols.
}
# check which columns actually exist
if (!is.null(by)) {
by <- intersect(by, colnames(x))

Check warning on line 20 in R/format.R

View workflow job for this annotation

GitHub Actions / lint / lint

file=R/format.R,line=20,col=5,[object_overwrite_linter] 'by' is an exported object from package 'base'. Avoid re-using such symbols.
}
# sort
if (length(by)) {
# arrange predictions
x <- datawizard::data_arrange(x, select = by)
# protect integers, only for focal terms
Expand Down Expand Up @@ -138,13 +147,20 @@ format.marginaleffects_slopes <- function(x, model, ci = 0.95, ...) {


#' @export
format.marginaleffects_contrasts <- function(x, model, p_adjust, comparison, ...) {
format.marginaleffects_contrasts <- function(x, model = NULL, p_adjust = NULL, comparison = NULL, ...) {

Check warning on line 150 in R/format.R

View workflow job for this annotation

GitHub Actions / lint / lint

file=R/format.R,line=150,col=1,[cyclocomp_linter] Reduce the cyclomatic complexity of this function from 65 to at most 40.
predict <- attributes(x)$predict

Check warning on line 151 in R/format.R

View workflow job for this annotation

GitHub Actions / lint / lint

file=R/format.R,line=151,col=3,[object_overwrite_linter] 'predict' is an exported object from package 'stats'. Avoid re-using such symbols.
by <- attributes(x)$by

Check warning on line 152 in R/format.R

View workflow job for this annotation

GitHub Actions / lint / lint

file=R/format.R,line=152,col=3,[object_overwrite_linter] 'by' is an exported object from package 'base'. Avoid re-using such symbols.
contrast <- attributes(x)$contrast
focal_terms <- attributes(x)$focal_terms
dgrid <- attributes(x)$datagrid

# sanity check - method "get_marginalmeans()" calls "format.estimate_means()"
# for printing, and that method doesn't pass "comparison" - thus, we have to
# extract it from the attributes
if (is.null(comparison)) {
comparison <- attributes(x)$comparison
}

# clean "by" and contrast variable names, for the special cases. for example,
# if we have `by = "name [fivenum]"`, we just want "name"
for (i in focal_terms) {
Expand All @@ -158,14 +174,15 @@ format.marginaleffects_contrasts <- function(x, model, p_adjust, comparison, ...

# only when we have a comparison based on these options from marginaleffects,
# we want to "clean" the parameter names
valid_options <- c(
"pairwise", "reference", "sequential", "meandev", "meanotherdev",
"revpairwise", "revreference", "revsequential"
)
valid_options <- .valid_hypothesis_strings()

# Column name for coefficient - fix needed for contrasting slopes
colnames(x)[colnames(x) == "Slope"] <- "Difference"

## TODO: we should be able to process more ways of comparisons here,
## e.g. also prettify labels and prepare levels for certain formula-written
## comparisons. Need to find out which ones.

# for contrasting slopes, we do nothing more here. for other contrasts,
# we prettify labels now
if (!is.null(comparison) && is.character(comparison) && comparison %in% valid_options) {
Expand Down Expand Up @@ -218,18 +235,51 @@ format.marginaleffects_contrasts <- function(x, model, p_adjust, comparison, ...

# for more than one term, we have comma-separated levels.
if (length(focal_terms) > 1) {
# levels may contain the separator char. to be 100% certain we extract
# levels correctly, we now replace levels with a special "token", and later
# replace those tokens with the original levels again

# extract all comparison levels
all_levels <- unlist(lapply(dgrid[focal_terms], function(i) as.character(unique(i))), use.names = FALSE)
# create replacement vector
replace_levels <- NULL
# this looks strange, but we need to make sure we have unique tokens that
# do not contain any letters or numbers, or similar characters that may
# appear as a single level in the data. thus, we use a sequence of "~"
# characters, which are unlikely to appear in the data
for (i in seq_along(all_levels)) {
replace_levels <- c(replace_levels, paste0("#", paste(rep_len("~", i), collapse = ""), "#"))
}

# replace all comparison levels with tokens
params[] <- lapply(params, function(comparison_pair) {
for (j in seq_along(all_levels)) {
comparison_pair <- sub(paste0("\\<", all_levels[j], "\\>"), replace_levels[j], comparison_pair)
}
comparison_pair
})

# we now have a data frame with each comparison-pairs as single column.
# next, we need to separate the levels from the different variables at the ","
# next, we need to separate the levels from the different variables at the
# separator char, "," for old marginaleffects, "_" for new marginaleffects
params <- datawizard::data_separate(
params,
separator = ",",
separator = "[_, ]",
guess_columns = "max",
verbose = FALSE
)
new_colnames <- paste0(
rep.int(focal_terms, 2),
rep(1:2, each = length(focal_terms))
)

# finally, replace all tokens with original comparison levels again
params[] <- lapply(params, function(comparison_pair) {
for (j in seq_along(all_levels)) {
comparison_pair <- sub(replace_levels[j], all_levels[j], comparison_pair, fixed = TRUE)
}
comparison_pair
})
} else {
new_colnames <- paste0(focal_terms, 1:2)
}
Expand Down Expand Up @@ -332,9 +382,22 @@ format.marginaleffects_contrasts <- function(x, model, p_adjust, comparison, ...
for (i in focal_terms) {
x[[i]] <- factor(x[[i]], levels = unique(x[[i]]))
}
# make sure filtering terms in `by` are factors, for data_arrange later
if (!is.null(by) && length(by)) {
for (i in by) {
if (i %in% colnames(dgrid) && i %in% colnames(x) && is.factor(dgrid[[i]]) && !is.factor(x[[i]])) { # nolint
x[[i]] <- factor(x[[i]], levels = unique(x[[i]]))
}
}
}
}
}

# remove () for single columns
if ("Parameter" %in% colnames(x)) {
x$Parameter <- gsub("(", "", gsub(")", "", x$Parameter, fixed = TRUE), fixed = TRUE)
}

x
}

Expand Down Expand Up @@ -411,9 +474,10 @@ format.marginaleffects_contrasts <- function(x, model, p_adjust, comparison, ...
if (!is.null(estimate_name) && !tolower(estimate_name) %in% .brms_aux_elements()) {
estimate_name <- coefficient_name
}
# and rename the "term" column (which we get from contrasts)
colnames(params)[colnames(params) == "term"] <- "Parameter"
}
# rename the "term" and "hypothesis" column (which we get from contrasts)
colnames(params)[colnames(params) == "term"] <- "Parameter"
colnames(params)[colnames(params) == "hypothesis"] <- "Parameter"

# add back ci? these are missing when contrasts are computed
params <- .add_contrasts_ci(is_contrast_analysis, params)
Expand Down Expand Up @@ -520,7 +584,7 @@ format.marginaleffects_contrasts <- function(x, model, p_adjust, comparison, ...
} else if (!predict_type %in% c("none", "link") && (info$is_binomial || info$is_bernoulli)) {
estimate_name <- "Probability"
} else if (predict_type %in% c("zprob", "zero")) {
estimate_name <- "Probability" ## TODO: could be renamed into ZI-Probability?
estimate_name <- "Probability"
} else {
estimate_name <- "Mean"
}
Expand Down
56 changes: 41 additions & 15 deletions R/get_marginalcontrasts.R
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,8 @@ get_marginalcontrasts <- function(model,
# marginaleffects versions - newer versions don't accept a string argument,
# only formulas (older versions don't accept formulas)
hypothesis_arg <- .get_marginaleffects_hypothesis_argument(comparison, ...)
# update / reset argument
comparison <- hypothesis_arg$comparison

# check whether contrasts should be made for numerics or categorical
model_data <- insight::get_data(model, source = "mf", verbose = FALSE)
Expand Down Expand Up @@ -58,7 +60,7 @@ get_marginalcontrasts <- function(model,
trend = my_args$contrast,
by = my_args$by,
ci = ci,
hypothesis = hypothesis_arg,
hypothesis = hypothesis_arg$hypothesis,
backend = "marginaleffects",
verbose = verbose,
...
Expand All @@ -69,7 +71,7 @@ get_marginalcontrasts <- function(model,
model = model,
by = unique(c(my_args$contrast, my_args$by)),
ci = ci,
hypothesis = hypothesis_arg,
hypothesis = hypothesis_arg$hypothesis,
predict = predict,
backend = "marginaleffects",
marginalize = marginalize,
Expand All @@ -83,6 +85,7 @@ get_marginalcontrasts <- function(model,
out <- .p_adjust(model, out, p_adjust, verbose, ...)
}


# Last step: Save information in attributes --------------------------------
# ---------------------------------------------------------------------------

Expand All @@ -98,30 +101,53 @@ get_marginalcontrasts <- function(model,
)
)

class(out) <- unique(c("marginaleffects_contrasts", class(out)))
# remove "estimate_means" class attribute
class(out) <- setdiff(
unique(c("marginaleffects_contrasts", class(out))),
"estimate_means"
)
out
}


# make "comparison" argument compatible -----------------------------------

.get_marginaleffects_hypothesis_argument <- function(comparison, ...) {
# these are the string values that need to be converted to formulas
hypothesis_strings <- c(
# save original argument
hypothesis <- comparison
# check if we have such a string
if (!is.null(comparison)) {
if (is.character(comparison) &&
comparison %in% .valid_hypothesis_strings() &&
isTRUE(insight::check_if_installed("marginaleffects", quietly = TRUE)) &&
utils::packageVersion("marginaleffects") > "0.24.0") {
# convert to formula
hypothesis <- stats::as.formula(paste("~", comparison))
} else if (inherits(comparison, "formula")) {
# convert to character
comparison_string <- all.vars(comparison)
# update comparison
if (length(comparison_string) == 1 && comparison_string %in% .valid_hypothesis_strings()) {
comparison <- comparison_string
}
}
}
# we want: "hypothesis" is the original argument provided by the user,
# can be a formula like ~pairwise, or a string like "pairwise". This is
# converted into the appropriate type depending on the marginaleffects
# version. "comparison" should always be a character string, for internal
# processing.
list(hypothesis = hypothesis, comparison = comparison)
}


# these are the string values that need to be converted to formulas
.valid_hypothesis_strings <- function() {
c(
"pairwise", "reference", "sequential", "meandev", "meanotherdev",
"revpairwise", "revreference", "revsequential", "poly", "helmert",
"trt_vs_ctrl"
)
# check if we have such a string
if (!is.null(comparison) &&
is.character(comparison) &&
comparison %in% hypothesis_strings &&
isTRUE(insight::check_if_installed("marginaleffects", quietly = TRUE)) &&
utils::packageVersion("marginaleffects") > "0.24.0") {
# convert to formula
comparison <- stats::as.formula(paste("~", comparison))
}
comparison
}


Expand Down
3 changes: 3 additions & 0 deletions R/get_marginalmeans.R
Original file line number Diff line number Diff line change
Expand Up @@ -171,7 +171,10 @@ get_marginalmeans <- function(model,
# =========================================================================
# fix term label for custom hypothesis
if (.is_custom_comparison(comparison)) {
## TODO: check which column name is used in marginaleffects update, and
## keep only the new one later
means$term <- gsub(" ", "", comparison, fixed = TRUE)
means$hypothesis <- gsub(" ", "", comparison, fixed = TRUE)
}

# Last step: Save information in attributes --------------------------------
Expand Down
2 changes: 0 additions & 2 deletions R/visualisation_recipe_internal.R
Original file line number Diff line number Diff line change
Expand Up @@ -285,8 +285,6 @@

#' @keywords internal
.visualization_recipe_rawdata <- function(x, aes) {
# TODO: In the main function, don't forget to NOT add raw data when `predict` is not "response"

model <- attributes(x)$model
rawdata <- insight::get_data(model, verbose = FALSE)

Expand Down
1 change: 1 addition & 0 deletions README.Rmd
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@ knitr::opts_chunk$set(
message = FALSE,
out.width = "100%"
)
options(marginaleffects_safe = FALSE)
```


Expand Down
Loading

0 comments on commit 2795a46

Please sign in to comment.