Skip to content

Commit

Permalink
use type checkers in backend-*.R files (#1556)
Browse files Browse the repository at this point in the history
  • Loading branch information
simonpcouch authored Nov 12, 2024
1 parent b1e8bab commit d854986
Show file tree
Hide file tree
Showing 13 changed files with 80 additions and 96 deletions.
8 changes: 4 additions & 4 deletions NEWS.md
Original file line number Diff line number Diff line change
@@ -1,9 +1,9 @@
# dbplyr (development version)

* Tightened argument checks for Snowflake SQL translations. These changes should
result in more informative errors in cases where code already failed; if you
see errors with code that used to run without issue, please report them to
the package authors (@simonpcouch, #1554).
* Tightened argument checks for SQL translations. These changes should
result in more informative errors in cases where code already failed, possibly
silently; if you see errors with code that used to run correctly, please report
them to the package authors (@simonpcouch, #1554, #1555).

* `clock::add_years()` translates to correct SQL on Spark (@ablack3, #1510).

Expand Down
2 changes: 2 additions & 0 deletions R/backend-.R
Original file line number Diff line number Diff line change
Expand Up @@ -252,6 +252,7 @@ base_scalar <- sql_translator(
# base R
nchar = sql_prefix("LENGTH", 1),
nzchar = function(x, keepNA = FALSE) {
check_bool(keepNA)
if (keepNA) {
exp <- expr(!!x != "")
translate_sql(!!exp, con = sql_current_con())
Expand Down Expand Up @@ -281,6 +282,7 @@ base_scalar <- sql_translator(
str_c = sql_paste(""),
str_sub = sql_str_sub("SUBSTR"),
str_like = function(string, pattern, ignore_case = TRUE) {
check_bool(ignore_case)
if (isTRUE(ignore_case)) {
sql_expr(!!string %LIKE% !!pattern)
} else {
Expand Down
1 change: 1 addition & 0 deletions R/backend-hive.R
Original file line number Diff line number Diff line change
Expand Up @@ -94,6 +94,7 @@ sql_table_analyze.Hive <- function(con, table, ...) {

#' @export
sql_query_set_op.Hive <- function(con, x, y, method, ..., all = FALSE, lvl = 0) {
check_bool(all)
# parentheses are not allowed
method <- paste0(method, if (all) " ALL")
glue_sql2(
Expand Down
33 changes: 15 additions & 18 deletions R/backend-mssql.R
Original file line number Diff line number Diff line change
Expand Up @@ -107,11 +107,15 @@ simulate_mssql <- function(version = "15.0") {
conflict = c("error", "ignore"),
returning_cols = NULL,
method = NULL) {
method <- method %||% "where_not_exists"
arg_match(method, "where_not_exists", error_arg = "method")
# https://stackoverflow.com/questions/25969/insert-into-values-select-from
conflict <- rows_check_conflict(conflict)

check_character(returning_cols, allow_null = TRUE)

check_string(method, allow_null = TRUE)
method <- method %||% "where_not_exists"
arg_match(method, "where_not_exists", error_arg = "method")

parts <- rows_insert_prep(con, table, from, insert_cols, by, lvl = 0)

clauses <- list2(
Expand Down Expand Up @@ -177,6 +181,7 @@ simulate_mssql <- function(version = "15.0") {
...,
returning_cols = NULL,
method = NULL) {
check_string(method, allow_null = TRUE)
method <- method %||% "merge"
arg_match(method, "merge", error_arg = "method")

Expand Down Expand Up @@ -333,6 +338,7 @@ simulate_mssql <- function(version = "15.0") {
second = function(x) sql_expr(DATEPART(SECOND, !!x)),

month = function(x, label = FALSE, abbr = TRUE) {
check_bool(label)
if (!label) {
sql_expr(DATEPART(MONTH, !!x))
} else {
Expand All @@ -342,6 +348,7 @@ simulate_mssql <- function(version = "15.0") {
},

quarter = function(x, with_year = FALSE, fiscal_start = 1) {
check_bool(with_year)
check_unsupported_arg(fiscal_start, 1, backend = "SQL Server")

if (with_year) {
Expand All @@ -361,6 +368,7 @@ simulate_mssql <- function(version = "15.0") {
sql_expr(DATEADD(YEAR, !!n, !!x))
},
date_build = function(year, month = 1L, day = 1L, ..., invalid = NULL) {
check_unsupported_arg(invalid, allow_null = TRUE)
sql_expr(DATEFROMPARTS(!!year, !!month, !!day))
},
get_year = function(x) {
Expand All @@ -373,27 +381,16 @@ simulate_mssql <- function(version = "15.0") {
sql_expr(DATEPART(DAY, !!x))
},
date_count_between = function(start, end, precision, ..., n = 1L){

check_dots_empty()
if (precision != "day") {
cli_abort("{.arg precision} must be {.val day} on SQL backends.")
}
if (n != 1) {
cli_abort("{.arg n} must be {.val 1} on SQL backends.")
}
check_unsupported_arg(precision, allowed = "day")
check_unsupported_arg(n, allowed = 1L)

sql_expr(DATEDIFF(DAY, !!start, !!end))
},

difftime = function(time1, time2, tz, units = "days") {

if (!missing(tz)) {
cli::cli_abort("The {.arg tz} argument is not supported for SQL backends.")
}

if (units[1] != "days") {
cli::cli_abort('The only supported value for {.arg units} on SQL backends is "days"')
}
check_unsupported_arg(tz)
check_unsupported_arg(units, allowed = "days")

sql_expr(DATEDIFF(DAY, !!time2, !!time1))
}
Expand Down Expand Up @@ -545,7 +542,7 @@ mssql_version <- function(con) {

#' @export
`sql_returning_cols.Microsoft SQL Server` <- function(con, cols, table, ...) {
stopifnot(table %in% c("DELETED", "INSERTED"))
arg_match(table, values = c("DELETED", "INSERTED"))
returning_cols <- sql_named_cols(con, cols, table = table)

sql_clause("OUTPUT", returning_cols)
Expand Down
28 changes: 15 additions & 13 deletions R/backend-postgres.R
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,7 @@ postgres_grepl <- function(pattern,
check_unsupported_arg(perl, FALSE, backend = "PostgreSQL")
check_unsupported_arg(fixed, FALSE, backend = "PostgreSQL")
check_unsupported_arg(useBytes, FALSE, backend = "PostgreSQL")
check_bool(ignore.case)

if (ignore.case) {
sql_expr(((!!x)) %~*% ((!!pattern)))
Expand Down Expand Up @@ -123,6 +124,7 @@ sql_translation.PqConnection <- function(con) {
},
# https://www.postgresql.org/docs/current/functions-matching.html
str_like = function(string, pattern, ignore_case = TRUE) {
check_bool(ignore_case)
if (isTRUE(ignore_case)) {
sql_expr(!!string %ILIKE% !!pattern)
} else {
Expand Down Expand Up @@ -162,6 +164,9 @@ sql_translation.PqConnection <- function(con) {
sql_expr(EXTRACT(DAY %FROM% !!x))
},
wday = function(x, label = FALSE, abbr = TRUE, week_start = NULL) {
check_bool(label)
check_bool(abbr)
check_number_whole(week_start, allow_null = TRUE)
if (!label) {
week_start <- week_start %||% getOption("lubridate.week.start", 7)
offset <- as.integer(7 - week_start)
Expand All @@ -182,6 +187,8 @@ sql_translation.PqConnection <- function(con) {
sql_expr(EXTRACT(WEEK %FROM% !!x))
},
month = function(x, label = FALSE, abbr = TRUE) {
check_bool(label)
check_bool(abbr)
if (!label) {
sql_expr(EXTRACT(MONTH %FROM% !!x))
} else {
Expand All @@ -193,6 +200,7 @@ sql_translation.PqConnection <- function(con) {
}
},
quarter = function(x, with_year = FALSE, fiscal_start = 1) {
check_bool(with_year)
check_unsupported_arg(fiscal_start, 1, backend = "PostgreSQL")

if (with_year) {
Expand Down Expand Up @@ -246,17 +254,14 @@ sql_translation.PqConnection <- function(con) {
glue_sql2(sql_current_con(), "({.col x} + {.val n}*INTERVAL'1 year')")
},
date_build = function(year, month = 1L, day = 1L, ..., invalid = NULL) {
check_unsupported_arg(invalid, allow_null = TRUE)
sql_expr(make_date(!!year, !!month, !!day))
},
date_count_between = function(start, end, precision, ..., n = 1L){

check_dots_empty()
if (precision != "day") {
cli_abort("{.arg precision} must be {.val day} on SQL backends.")
}
if (n != 1) {
cli_abort("{.arg n} must be {.val 1} on SQL backends.")
}
check_unsupported_arg(precision, allowed = "day")
check_unsupported_arg(n, allowed = 1L)

sql_expr(!!end - !!start)
},
Expand All @@ -272,13 +277,8 @@ sql_translation.PqConnection <- function(con) {

difftime = function(time1, time2, tz, units = "days") {

if (!missing(tz)) {
cli::cli_abort("The {.arg tz} argument is not supported for SQL backends.")
}

if (units[1] != "days") {
cli::cli_abort('The only supported value for {.arg units} on SQL backends is "days"')
}
check_unsupported_arg(tz)
check_unsupported_arg(units, allowed = "days")

sql_expr((CAST(!!time1 %AS% DATE) - CAST(!!time2 %AS% DATE)))
},
Expand Down Expand Up @@ -344,6 +344,7 @@ sql_query_insert.PqConnection <- function(con,
...,
returning_cols = NULL,
method = NULL) {
check_string(method, allow_null = TRUE)
method <- method %||% "on_conflict"
arg_match(method, c("on_conflict", "where_not_exists"), error_arg = "method")
if (method == "where_not_exists") {
Expand Down Expand Up @@ -379,6 +380,7 @@ sql_query_upsert.PqConnection <- function(con,
...,
returning_cols = NULL,
method = NULL) {
check_string(method, allow_null = TRUE)
method <- method %||% "on_conflict"
arg_match(method, c("cte_update", "on_conflict"), error_arg = "method")

Expand Down
20 changes: 5 additions & 15 deletions R/backend-redshift.R
Original file line number Diff line number Diff line change
Expand Up @@ -72,6 +72,7 @@ sql_translation.RedshiftConnection <- function(con) {
sql_expr(DATEADD(YEAR, !!n, !!x))
},
date_build = function(year, month = 1L, day = 1L, ..., invalid = NULL) {
check_unsupported_arg(invalid, allow_null = TRUE)
glue_sql2(sql_current_con(), "TO_DATE(CAST({.val year} AS TEXT) || '-' CAST({.val month} AS TEXT) || '-' || CAST({.val day} AS TEXT)), 'YYYY-MM-DD')")
},
get_year = function(x) {
Expand All @@ -84,27 +85,16 @@ sql_translation.RedshiftConnection <- function(con) {
sql_expr(DATE_PART('day', !!x))
},
date_count_between = function(start, end, precision, ..., n = 1L){

check_dots_empty()
if (precision != "day") {
cli_abort("{.arg precision} must be {.val day} on SQL backends.")
}
if (n != 1) {
cli_abort("{.arg n} must be {.val 1} on SQL backends.")
}
check_unsupported_arg(precision, allowed = "day")
check_unsupported_arg(n, allowed = 1L)

sql_expr(DATEDIFF(DAY, !!start, !!end))
},

difftime = function(time1, time2, tz, units = "days") {

if (!missing(tz)) {
cli::cli_abort("The {.arg tz} argument is not supported for SQL backends.")
}

if (units[1] != "days") {
cli::cli_abort('The only supported value for {.arg units} on SQL backends is "days"')
}
check_unsupported_arg(tz)
check_unsupported_arg(units, allowed = "days")

sql_expr(DATEDIFF(DAY, !!time2, !!time1))
}
Expand Down
23 changes: 7 additions & 16 deletions R/backend-spark-sql.R
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,7 @@ simulate_spark_sql <- function() simulate_dbi("Spark SQL")
sql_expr(add_months(!!x, !!n*12))
},
date_build = function(year, month = 1L, day = 1L, ..., invalid = NULL) {
check_unsupported_arg(invalid, allow_null = TRUE)
sql_expr(make_date(!!year, !!month, !!day))
},
get_year = function(x) {
Expand All @@ -59,27 +60,16 @@ simulate_spark_sql <- function() simulate_dbi("Spark SQL")
sql_expr(date_part('DAY', !!x))
},
date_count_between = function(start, end, precision, ..., n = 1L){

check_dots_empty()
if (precision != "day") {
cli_abort("{.arg precision} must be {.val day} on SQL backends.")
}
if (n != 1) {
cli_abort("{.arg n} must be {.val 1} on SQL backends.")
}
check_unsupported_arg(precision, allowed = "day")
check_unsupported_arg(n, allowed = 1L)

sql_expr(datediff(!!end, !!start))
},

difftime = function(time1, time2, tz, units = "days") {

if (!missing(tz)) {
cli::cli_abort("The {.arg tz} argument is not supported for SQL backends.")
}

if (units[1] != "days") {
cli::cli_abort('The only supported value for {.arg units} on SQL backends is "days"')
}
check_unsupported_arg(tz)
check_unsupported_arg(units, allowed = "days")

sql_expr(datediff(!!time2, !!time1))
}
Expand Down Expand Up @@ -153,7 +143,8 @@ simulate_spark_sql <- function() simulate_dbi("Spark SQL")
indexes = list(),
analyze = TRUE,
in_transaction = FALSE) {

check_bool(overwrite)
check_bool(temporary)
sql <- glue_sql2(
con,
"CREATE ", if (overwrite) "OR REPLACE ",
Expand Down
2 changes: 2 additions & 0 deletions R/backend-teradata.R
Original file line number Diff line number Diff line change
Expand Up @@ -153,6 +153,7 @@ sql_translation.Teradata <- function(con) {
row_number = win_rank("ROW_NUMBER", empty_order = TRUE),
weighted.mean = function(x, w, na.rm = T) {
# nocov start
check_unsupported_arg(na.rm, allowed = TRUE)
win_over(
sql_expr(SUM((!!x * !!w))/SUM(!!w)),
win_current_group(),
Expand Down Expand Up @@ -191,6 +192,7 @@ sql_translation.Teradata <- function(con) {
},
weighted.mean = function(x, w, na.rm = T) {
# nocov start
check_unsupported_arg(na.rm, allowed = TRUE)
win_over(
sql_expr(SUM((!!x * !!w))/SUM(!!w)),
win_current_group(),
Expand Down
1 change: 1 addition & 0 deletions R/utils.R
Original file line number Diff line number Diff line change
Expand Up @@ -83,6 +83,7 @@ res_warn_incomplete <- function(res, hint = "n = -1") {
}

add_temporary_prefix <- function(con, table, temporary = TRUE) {
check_bool(temporary)
check_table_path(table)

if (!temporary) {
Expand Down
25 changes: 7 additions & 18 deletions tests/testthat/_snaps/backend-mssql.md
Original file line number Diff line number Diff line change
Expand Up @@ -47,31 +47,34 @@
test_translate_sql(date_count_between(date_column_1, date_column_2, "year"))
Condition
Error in `date_count_between()`:
! `precision` must be "day" on SQL backends.
! `precision = "year"` isn't supported on database backends.
i It must be "day" instead.

---

Code
test_translate_sql(date_count_between(date_column_1, date_column_2, "day", n = 5))
Condition
Error in `date_count_between()`:
! `n` must be "1" on SQL backends.
! `n = 5` isn't supported on database backends.
i It must be 1 instead.

# difftime is translated correctly

Code
test_translate_sql(difftime(start_date, end_date, units = "auto"))
Condition
Error in `difftime()`:
! The only supported value for `units` on SQL backends is "days"
! `units = "auto"` isn't supported on database backends.
i It must be "days" instead.

---

Code
test_translate_sql(difftime(start_date, end_date, tz = "UTC", units = "days"))
Condition
Error in `difftime()`:
! The `tz` argument is not supported for SQL backends.
! Argument `tz` isn't supported on database backends.

# convert between bit and boolean as needed

Expand Down Expand Up @@ -494,20 +497,6 @@
FROM `df`
ORDER BY `y`

# can copy_to() and compute() with temporary tables (#438)

Code
db <- copy_to(con, df, name = unique_table_name(), temporary = TRUE)
Message
Created a temporary table named #dbplyr_{tmp}

---

Code
db2 <- db %>% mutate(y = x + 1) %>% compute()
Message
Created a temporary table named #dbplyr_{tmp}

# add prefix to temporary table

Code
Expand Down
Loading

0 comments on commit d854986

Please sign in to comment.