@@ -186,15 +186,16 @@ epi_slide <- function(
186186
187187 # Validate arguments
188188 assert_class(.x , " epi_df" )
189- if (checkmate :: test_class(.x , " grouped_df" )) {
189+ .x_orig_groups <- groups(.x )
190+ if (inherits(.x , " grouped_df" )) {
190191 expected_group_keys <- .x %> %
191192 key_colnames(exclude = " time_value" ) %> %
192193 sort()
193194 if (! identical(.x %> % group_vars() %> % sort(), expected_group_keys )) {
194195 cli_abort(
195- " epi_slide: `.x` must be either grouped by {expected_group_keys}. (Or you can just ungroup
196- `.x` and we'll do this grouping automatically.) You may need to aggregate your data first,
197- see aggregate_epi_df()." ,
196+ " `.x` must be either grouped by {expected_group_keys} or ungrouped; if the latter,
197+ we'll temporarily group by {expected_group_keys} for this operation. You may need
198+ to aggregate your data first, see aggregate_epi_df()." ,
198199 class = " epiprocess__epi_slide__invalid_grouping"
199200 )
200201 }
@@ -300,7 +301,6 @@ epi_slide <- function(
300301 # `epi_slide_one_group`.
301302 # - `...` from top of `epi_slide` are forwarded to `.f` here through
302303 # group_map and through the lambda.
303- .x_groups <- groups(.x )
304304 result <- group_map(
305305 .x ,
306306 .f = function (.data_group , .group_key , ... ) {
@@ -324,7 +324,7 @@ epi_slide <- function(
324324 filter(.real ) %> %
325325 select(- .real ) %> %
326326 arrange_col_canonical() %> %
327- group_by(!!! .x_groups )
327+ group_by(!!! .x_orig_groups )
328328
329329 # If every group in epi_slide_one_group takes the
330330 # length(available_ref_time_values) == 0 branch then we end up here.
@@ -691,12 +691,30 @@ epi_slide_opt <- function(
691691 )
692692 }
693693
694+ assert_class(.x , " epi_df" )
695+ .x_orig_groups <- groups(.x )
696+ if (inherits(.x , " grouped_df" )) {
697+ expected_group_keys <- .x %> %
698+ key_colnames(exclude = " time_value" ) %> %
699+ sort()
700+ if (! identical(.x %> % group_vars() %> % sort(), expected_group_keys )) {
701+ cli_abort(
702+ " `.x` must be either grouped by {expected_group_keys} or ungrouped; if the latter,
703+ we'll temporarily group by {expected_group_keys} for this operation. You may need
704+ to aggregate your data first, see aggregate_epi_df()." ,
705+ class = " epiprocess__epi_slide__invalid_grouping"
706+ )
707+ }
708+ } else {
709+ .x <- group_epi_df(.x , exclude = " time_value" )
710+ }
694711 if (nrow(.x ) == 0L ) {
695712 cli_abort(
696713 c(
697714 " input data `.x` unexpectedly has 0 rows" ,
698715 " i" = " If this computation is occuring within an `epix_slide` call,
699- check that `epix_slide` `.versions` argument was set appropriately"
716+ check that `epix_slide` `.versions` argument was set appropriately
717+ so that you don't get any completely-empty snapshots"
700718 ),
701719 class = " epiprocess__epi_slide_opt__0_row_input" ,
702720 epiprocess__x = .x
@@ -857,27 +875,9 @@ epi_slide_opt <- function(
857875 arrange(.data $ time_value )
858876
859877 if (f_from_package == " data.table" ) {
860- # If a group contains duplicate time values, `frollmean` will still only
861- # use the last `k` obs. It isn't looking at dates, it just goes in row
862- # order. So if the computation is aggregating across multiple obs for the
863- # same date, `epi_slide_opt` and derivates will produce incorrect results;
864- # `epi_slide` should be used instead.
865- if (anyDuplicated(.data_group $ time_value ) != 0L ) {
866- cli_abort(
867- c(
868- " group contains duplicate time values. Using `epi_slide_[opt/mean/sum]` on this
869- group will result in incorrect results" ,
870- " i" = " Please change the grouping structure of the input data so that
871- each group has non-duplicate time values (e.g. `x %>% group_by(geo_value)
872- %>% epi_slide_opt(.f = frollmean)`)" ,
873- " i" = " Use `epi_slide` to aggregate across groups"
874- ),
875- class = " epiprocess__epi_slide_opt__duplicate_time_values" ,
876- epiprocess__data_group = .data_group ,
877- epiprocess__group_key = .group_key
878- )
879- }
880-
878+ # Grouping should ensure that we don't have duplicate time values.
879+ # Completion above should ensure we have at least .window_size rows. Check
880+ # that we don't have more than .window_size rows (or fewer somehow):
881881 if (nrow(.data_group ) != length(c(all_dates , pad_early_dates , pad_late_dates ))) {
882882 cli_abort(
883883 c(
@@ -928,7 +928,8 @@ epi_slide_opt <- function(
928928 group_modify(slide_one_grp , ... , .keep = FALSE ) %> %
929929 filter(.data $ .real ) %> %
930930 select(- .real ) %> %
931- arrange_col_canonical()
931+ arrange_col_canonical() %> %
932+ group_by(!!! .x_orig_groups )
932933
933934 if (.all_rows ) {
934935 result [! (result $ time_value %in% ref_time_values ), result_col_names ] <- NA
0 commit comments