diff --git a/DESCRIPTION b/DESCRIPTION index 8a06e6f56..6fecc73c4 100755 --- a/DESCRIPTION +++ b/DESCRIPTION @@ -21,7 +21,8 @@ Description: This package introduces a common data structure for epidemiological work with revisions to these data sets over time, and offers associated utilities to perform basic signal processing tasks. License: MIT + file LICENSE -Imports: +Imports: + cli, data.table, dplyr (>= 1.0.0), fabletools, @@ -48,7 +49,7 @@ Suggests: knitr, outbreaks, rmarkdown, - testthat (>= 3.0.0), + testthat (>= 3.1.5), waldo (>= 0.3.1), withr VignetteBuilder: diff --git a/NAMESPACE b/NAMESPACE index 10847e6c9..065302d75 100644 --- a/NAMESPACE +++ b/NAMESPACE @@ -84,6 +84,7 @@ importFrom(dplyr,ungroup) importFrom(lubridate,days) importFrom(lubridate,weeks) importFrom(magrittr,"%>%") +importFrom(purrr,map_lgl) importFrom(rlang,"!!!") importFrom(rlang,"!!") importFrom(rlang,.data) @@ -91,6 +92,7 @@ importFrom(rlang,.env) importFrom(rlang,arg_match) importFrom(rlang,enquo) importFrom(rlang,enquos) +importFrom(rlang,is_missing) importFrom(rlang,is_quosure) importFrom(rlang,quo_is_missing) importFrom(rlang,sym) @@ -101,3 +103,4 @@ importFrom(tidyr,unnest) importFrom(tidyselect,eval_select) importFrom(tidyselect,starts_with) importFrom(tsibble,as_tsibble) +importFrom(utils,tail) diff --git a/R/grouped_epi_archive.R b/R/grouped_epi_archive.R index 1c6bd3110..76b079a4a 100644 --- a/R/grouped_epi_archive.R +++ b/R/grouped_epi_archive.R @@ -220,6 +220,11 @@ grouped_epi_archive = ref_time_values = sort(ref_time_values) } + # Check that `f` takes enough args + if (!missing(f) && is.function(f)) { + assert_sufficient_f_args(f, ...) + } + # Validate and pre-process `before`: if (missing(before)) { Abort("`before` is required (and must be passed by name); diff --git a/R/slide.R b/R/slide.R index ea2b93cba..d8d6becba 100644 --- a/R/slide.R +++ b/R/slide.R @@ -155,7 +155,12 @@ epi_slide = function(x, f, ..., before, after, ref_time_values, # Check we have an `epi_df` object if (!inherits(x, "epi_df")) Abort("`x` must be of class `epi_df`.") - + + # Check that `f` takes enough args + if (!missing(f) && is.function(f)) { + assert_sufficient_f_args(f, ...) + } + # Arrange by increasing time_value x = arrange(x, time_value) diff --git a/R/utils.R b/R/utils.R index b398ff5b6..d17f05d4e 100644 --- a/R/utils.R +++ b/R/utils.R @@ -100,6 +100,87 @@ paste_lines = function(lines) { Abort = function(msg, ...) rlang::abort(break_str(msg, init = "Error: "), ...) Warn = function(msg, ...) rlang::warn(break_str(msg, init = "Warning: "), ...) +#' Assert that a sliding computation function takes enough args +#' +#' @param f Function; specifies a computation to slide over an `epi_df` or +#' `epi_archive` in `epi_slide` or `epix_slide`. +#' @param ... Dots that will be forwarded to `f` from the dots of `epi_slide` or +#' `epix_slide`. +#' +#' @importFrom rlang is_missing +#' @importFrom purrr map_lgl +#' @importFrom utils tail +#' +#' @noRd +assert_sufficient_f_args <- function(f, ...) { + mandatory_f_args_labels <- c("window data", "group key") + n_mandatory_f_args <- length(mandatory_f_args_labels) + args = formals(args(f)) + args_names = names(args) + # Remove named arguments forwarded from `epi[x]_slide`'s `...`: + forwarded_dots_names = names(rlang::call_match(dots_expand = FALSE)[["..."]]) + args_matched_in_dots = + # positional calling args will skip over args matched by named calling args + args_names %in% forwarded_dots_names & + # extreme edge case: `epi[x]_slide(, dot = 1, `...` = 2)` + args_names != "..." + remaining_args = args[!args_matched_in_dots] + remaining_args_names = names(remaining_args) + # note that this doesn't include unnamed args forwarded through `...`. + dots_i <- which(remaining_args_names == "...") # integer(0) if no match + n_f_args_before_dots <- dots_i - 1L + if (length(dots_i) != 0L) { # `f` has a dots "arg" + # Keep all arg names before `...` + mandatory_args_mapped_names <- remaining_args_names[seq_len(n_f_args_before_dots)] + + if (n_f_args_before_dots < n_mandatory_f_args) { + mandatory_f_args_in_f_dots = + tail(mandatory_f_args_labels, n_mandatory_f_args - n_f_args_before_dots) + cli::cli_warn( + "`f` might not have enough positional arguments before its `...`; in the current `epi[x]_slide` call, the {mandatory_f_args_in_f_dots} will be included in `f`'s `...`; if `f` doesn't expect those arguments, it may produce confusing error messages", + class = "epiprocess__assert_sufficient_f_args__mandatory_f_args_passed_to_f_dots", + epiprocess__f = f, + epiprocess__mandatory_f_args_in_f_dots = mandatory_f_args_in_f_dots + ) + } + } else { # `f` doesn't have a dots "arg" + if (length(args_names) < n_mandatory_f_args + rlang::dots_n(...)) { + # `f` doesn't take enough args. + if (rlang::dots_n(...) == 0L) { + # common case; try for friendlier error message + Abort(sprintf("`f` must take at least %s arguments", n_mandatory_f_args), + class = "epiprocess__assert_sufficient_f_args__f_needs_min_args", + epiprocess__f = f) + } else { + # less common; highlight that they are (accidentally?) using dots forwarding + Abort(sprintf("`f` must take at least %s arguments plus the %s arguments forwarded through `epi[x]_slide`'s `...`, or a named argument to `epi[x]_slide` was misspelled", n_mandatory_f_args, rlang::dots_n(...)), + class = "epiprocess__assert_sufficient_f_args__f_needs_min_args_plus_forwarded", + epiprocess__f = f) + } + } + } + # Check for args with defaults that are filled with mandatory positional + # calling args. If `f` has fewer than n_mandatory_f_args before `...`, then we + # only need to check those args for defaults. Note that `n_f_args_before_dots` is + # length 0 if `f` doesn't accept `...`. + n_remaining_args_for_default_check = min(c(n_f_args_before_dots, n_mandatory_f_args)) + default_check_args = remaining_args[seq_len(n_remaining_args_for_default_check)] + default_check_args_names = names(default_check_args) + has_default_replaced_by_mandatory = map_lgl(default_check_args, ~!is_missing(.x)) + if (any(has_default_replaced_by_mandatory)) { + default_check_mandatory_args_labels = + mandatory_f_args_labels[seq_len(n_remaining_args_for_default_check)] + # ^ excludes any mandatory args absorbed by f's `...`'s: + mandatory_args_replacing_defaults = + default_check_mandatory_args_labels[has_default_replaced_by_mandatory] + args_with_default_replaced_by_mandatory = + rlang::syms(default_check_args_names[has_default_replaced_by_mandatory]) + cli::cli_abort("`epi[x]_slide` would pass the {mandatory_args_replacing_defaults} to `f`'s {args_with_default_replaced_by_mandatory} argument{?s}, which {?has a/have} default value{?s}; we suspect that `f` doesn't expect {?this arg/these args} at all and may produce confusing error messages. Please add additional arguments to `f` or remove defaults as appropriate.", + class = "epiprocess__assert_sufficient_f_args__required_args_contain_defaults", + epiprocess__f = f) + } +} + ########## in_range = function(x, rng) pmin(pmax(x, rng[1]), rng[2]) diff --git a/tests/testthat/test-epi_slide.R b/tests/testthat/test-epi_slide.R index eebcc55b9..84192f940 100644 --- a/tests/testthat/test-epi_slide.R +++ b/tests/testthat/test-epi_slide.R @@ -86,3 +86,14 @@ test_that("these doesn't produce an error; the error appears only if the ref tim dplyr::select("geo_value","slide_value_value"), dplyr::tibble(geo_value = c("ak", "al"), slide_value_value = c(2, -2))) # not out of range for either group }) + +test_that("epi_slide alerts if the provided f doesn't take enough args", { + f_xg = function(x, g) dplyr::tibble(value=mean(x$value), count=length(x$value)) + # If `regexp` is NA, asserts that there should be no errors/messages. + expect_error(epi_slide(grouped, f_xg, before = 1L, ref_time_values = d+1), regexp = NA) + expect_warning(epi_slide(grouped, f_xg, before = 1L, ref_time_values = d+1), regexp = NA) + + f_x_dots = function(x, ...) dplyr::tibble(value=mean(x$value), count=length(x$value)) + expect_warning(epi_slide(grouped, f_x_dots, before = 1L, ref_time_values = d+1), + class = "epiprocess__assert_sufficient_f_args__mandatory_f_args_passed_to_f_dots") +}) diff --git a/tests/testthat/test-epix_slide.R b/tests/testthat/test-epix_slide.R index 9ef2f9afd..5eeb5c2c1 100644 --- a/tests/testthat/test-epix_slide.R +++ b/tests/testthat/test-epix_slide.R @@ -348,3 +348,14 @@ test_that("epix_slide with all_versions option works as intended",{ expect_identical(xx1,xx3) # This and * Imply xx2 and xx3 are identical }) + +test_that("epix_slide alerts if the provided f doesn't take enough args", { + f_xg = function(x, g) dplyr::tibble(value=mean(x$binary), count=length(x$binary)) + # If `regexp` is NA, asserts that there should be no errors/messages. + expect_error(epix_slide(xx, f = f_xg, before = 2L), regexp = NA) + expect_warning(epix_slide(xx, f = f_xg, before = 2L), regexp = NA) + + f_x_dots = function(x, ...) dplyr::tibble(value=mean(x$binary), count=length(x$binary)) + expect_warning(epix_slide(xx, f_x_dots, before = 2L), + class = "epiprocess__assert_sufficient_f_args__mandatory_f_args_passed_to_f_dots") +}) diff --git a/tests/testthat/test-utils.R b/tests/testthat/test-utils.R index 08b28c97e..6648ce3ce 100644 --- a/tests/testthat/test-utils.R +++ b/tests/testthat/test-utils.R @@ -107,4 +107,80 @@ test_that("enlist works",{ my_list <- enlist(x=1,y=2,z=3) expect_equal(my_list$x,1) expect_true(inherits(my_list,"list")) -}) \ No newline at end of file +}) + +test_that("assert_sufficient_f_args alerts if the provided f doesn't take enough args", { + f_xg = function(x, g) dplyr::tibble(value=mean(x$binary), count=length(x$binary)) + f_xg_dots = function(x, g, ...) dplyr::tibble(value=mean(x$binary), count=length(x$binary)) + + # If `regexp` is NA, asserts that there should be no errors/messages. + expect_error(assert_sufficient_f_args(f_xg), regexp = NA) + expect_warning(assert_sufficient_f_args(f_xg), regexp = NA) + expect_error(assert_sufficient_f_args(f_xg_dots), regexp = NA) + expect_warning(assert_sufficient_f_args(f_xg_dots), regexp = NA) + + f_x_dots = function(x, ...) dplyr::tibble(value=mean(x$binary), count=length(x$binary)) + f_dots = function(...) dplyr::tibble(value=c(5), count=c(2)) + f_x = function(x) dplyr::tibble(value=mean(x$binary), count=length(x$binary)) + f = function() dplyr::tibble(value=c(5), count=c(2)) + + expect_warning(assert_sufficient_f_args(f_x_dots), + regexp = ", the group key will be included", + class = "epiprocess__assert_sufficient_f_args__mandatory_f_args_passed_to_f_dots") + expect_warning(assert_sufficient_f_args(f_dots), + regexp = ", the window data and group key will be included", + class = "epiprocess__assert_sufficient_f_args__mandatory_f_args_passed_to_f_dots") + expect_error(assert_sufficient_f_args(f_x), + class = "epiprocess__assert_sufficient_f_args__f_needs_min_args") + expect_error(assert_sufficient_f_args(f), + class = "epiprocess__assert_sufficient_f_args__f_needs_min_args") + + f_xs_dots = function(x, setting="a", ...) dplyr::tibble(value=mean(x$binary), count=length(x$binary)) + f_xs = function(x, setting="a") dplyr::tibble(value=mean(x$binary), count=length(x$binary)) + expect_warning(assert_sufficient_f_args(f_xs_dots, setting="b"), + class = "epiprocess__assert_sufficient_f_args__mandatory_f_args_passed_to_f_dots") + expect_error(assert_sufficient_f_args(f_xs, setting="b"), + class = "epiprocess__assert_sufficient_f_args__f_needs_min_args_plus_forwarded") + + expect_error(assert_sufficient_f_args(f_xg, "b"), + class = "epiprocess__assert_sufficient_f_args__f_needs_min_args_plus_forwarded") +}) + +test_that("assert_sufficient_f_args alerts if the provided f has defaults for the required args", { + f_xg = function(x, g=1) dplyr::tibble(value=mean(x$binary), count=length(x$binary)) + f_xg_dots = function(x=1, g, ...) dplyr::tibble(value=mean(x$binary), count=length(x$binary)) + f_x_dots = function(x=1, ...) dplyr::tibble(value=mean(x$binary), count=length(x$binary)) + + expect_error(assert_sufficient_f_args(f_xg), + regexp = "pass the group key to `f`'s g argument,", + class = "epiprocess__assert_sufficient_f_args__required_args_contain_defaults") + expect_error(assert_sufficient_f_args(f_xg_dots), + regexp = "pass the window data to `f`'s x argument,", + class = "epiprocess__assert_sufficient_f_args__required_args_contain_defaults") + expect_error(suppressWarnings(assert_sufficient_f_args(f_x_dots)), + class = "epiprocess__assert_sufficient_f_args__required_args_contain_defaults") + + f_xsg = function(x, setting="a", g) dplyr::tibble(value=mean(x$binary), count=length(x$binary)) + f_xsg_dots = function(x, setting="a", g, ...) dplyr::tibble(value=mean(x$binary), count=length(x$binary)) + f_xs_dots = function(x=1, setting="a", ...) dplyr::tibble(value=mean(x$binary), count=length(x$binary)) + + # forwarding named dots should prevent some complaints: + expect_no_error(assert_sufficient_f_args(f_xsg, setting = "b")) + expect_no_error(assert_sufficient_f_args(f_xsg_dots, setting = "b")) + expect_error(suppressWarnings(assert_sufficient_f_args(f_xs_dots, setting = "b")), + regexp = "window data to `f`'s x argument", + class = "epiprocess__assert_sufficient_f_args__required_args_contain_defaults") + + # forwarding unnamed dots should not: + expect_error(assert_sufficient_f_args(f_xsg, "b"), + class = "epiprocess__assert_sufficient_f_args__required_args_contain_defaults") + expect_error(assert_sufficient_f_args(f_xsg_dots, "b"), + class = "epiprocess__assert_sufficient_f_args__required_args_contain_defaults") + expect_error(assert_sufficient_f_args(f_xs_dots, "b"), + class = "epiprocess__assert_sufficient_f_args__required_args_contain_defaults") + + # forwarding no dots should produce a different error message in some cases: + expect_error(assert_sufficient_f_args(f_xs_dots), + regexp = "window data and group key to `f`'s x and setting argument", + class = "epiprocess__assert_sufficient_f_args__required_args_contain_defaults") +})