# Helper function to process SDcols
#
# Resolves which columns of `x` make up .SD for a grouping call.
#
# Arguments:
#   SDcols_sub      unevaluated `.SDcols` expression captured by the caller
#   SDcols_missing  TRUE when the caller's `.SDcols` argument was not supplied
#   x               table whose names() are the candidate columns
#   jsub            unevaluated `j` expression; only inspected for a use of `.SD`
#   by              grouping columns: a character vector of column names, or an
#                   unevaluated grouping expression (language object)
#   enclos          environment used to evaluate a non-patterns() `.SDcols`
#
# Returns NULL when `jsub` does not mention `.SD`; otherwise a list with
#   ansvars, sdvars  the selected column names
#   ansvals          their integer positions in names(x)
#
# Signals an error (via stopf) for NA entries, unknown column names,
# out-of-range indices, a logical vector of the wrong length, or an
# unsupported type.
.processSDcols = function(SDcols_sub, SDcols_missing, x, jsub, by, enclos = parent.frame()) {
  names_x = names(x)
  # nothing to resolve unless j actually refers to .SD
  if (!(".SD" %chin% all.vars(jsub))) {
    return(NULL)
  }
  if (SDcols_missing) {
    # Default .SD is every column that is not a grouping column.
    # The grouping columns come from the *value* of `by`. substitute(by) here
    # would only yield the caller's argument expression (e.g. the symbol `by`),
    # which wrongly excluded any column that happened to share that name.
    byvars = if (is.character(by)) by else intersect(all.vars(by), names_x)
    ansvars = sdvars = setdiff(unique(names_x), byvars)
    ansvals = match(ansvars, names_x)
    return(list(ansvars = ansvars, sdvars = sdvars, ansvals = ansvals))
  }
  sub.result = SDcols_sub
  if (sub.result %iscall% "patterns") {
    # patterns(...) is resolved against the column names rather than evaluated
    # in `enclos`; the result is validated below like any other .SDcols value
    .SDcols = eval_with_cols(sub.result, names_x)
  } else {
    .SDcols = eval(sub.result, enclos)
  }
  if (anyNA(.SDcols))
    stopf(".SDcols missing at the following indices: %s", brackify(which(is.na(.SDcols))))
  if (is.character(.SDcols)) {
    idx = .SDcols %chin% names_x
    if (!all(idx))
      stopf("Some items of .SDcols are not column names: %s", toString(.SDcols[!idx]))
    ansvars = sdvars = .SDcols
    ansvals = match(ansvars, names_x)
  } else if (is.numeric(.SDcols)) {
    ansvals = as.integer(.SDcols)
    # only positive in-range positions are accepted here
    if (any(ansvals < 1L | ansvals > length(names_x)))
      stopf(".SDcols contains indices out of bounds")
    ansvars = sdvars = names_x[ansvals]
  } else if (is.logical(.SDcols)) {
    # no recycling: the mask must line up one-to-one with the columns
    if (length(.SDcols) != length(names_x))
      stopf(".SDcols is a logical vector of length %d but there are %d columns", length(.SDcols), length(names_x))
    ansvals = which(.SDcols)
    ansvars = sdvars = names_x[ansvals]
  } else {
    stopf(".SDcols must be character, numeric, or logical")
  }
  list(ansvars = ansvars, sdvars = sdvars, ansvals = ansvals)
}
# cube method for data.table: evaluates `j` for every subset (the power set)
# of the grouping columns named in `by`, delegating the actual aggregation to
# groupingsets.data.table.
#
# Arguments:
#   x        a data.table
#   j        expression to evaluate per group (captured unevaluated)
#   by       character vector of grouping column names
#   .SDcols  optional .SD column selection; captured unevaluated and resolved
#            by .processSDcols, so calls such as patterns("value") work
#   id       logical; forwarded to groupingsets.data.table
#   label    forwarded to groupingsets.data.table
#
# Signals an error when `x` is not a data.table, `by` is not character, `id`
# is not logical, `j` is missing, or `.SDcols` fails validation.
cube.data.table = function(x, j, by, .SDcols, id = FALSE, label = NULL, ...) {
  # input data type basic validation
  if (!is.data.table(x))
    stopf("Argument 'x' must be a data.table object", class="dt_invalid_input_error")
  if (!is.character(by))
    stopf("Argument 'by' must be a character vector of column names used in grouping.")
  if (!is.logical(id))
    stopf("Argument 'id' must be a logical scalar.")
  if (missing(j))
    stopf("Argument 'j' is required")
  # the caller's frame: used to evaluate .SDcols and, downstream, j
  enclos = parent.frame()
  jj = substitute(j)
  # Implementing NSE in cube using the helper, .processSDcols.
  # substitute()/missing() must see the original argument, so they are taken
  # before .SDcols is overwritten with the resolved column names below.
  sdcols_result = .processSDcols(SDcols_sub = substitute(.SDcols), SDcols_missing = missing(.SDcols),
                                 x = x, jsub = jj, by = by, enclos = enclos)
  # NULL means j does not use .SD, so there is nothing to restrict
  .SDcols = if (is.null(sdcols_result)) NULL else sdcols_result$sdvars
  # generate grouping sets for cube - power set: http://stackoverflow.com/a/32187892/2490497
  n = length(by)
  keepBool = sapply(2L^(seq_len(n)-1L), function(k) rep(c(FALSE, TRUE), times=k, each=((2L^n)/(2L*k))))
  sets = lapply((2L^n):1L, function(row) by[keepBool[row, ]])
  # redirect to workhorse function
  groupingsets.data.table(x, by=by, sets=sets, .SDcols=.SDcols, id=id, jj=jj, label=label, enclos = enclos)
}
= ".SDcols must be character, numeric, or logical" +) +test(1750.40, + cube(dt, j = lapply(.SD, sum), by = "color", .SDcols = c(1, 99)), + error = "out of bounds" +) # grouping sets with integer64 if (test_bit64) { set.seed(26)