Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
44 changes: 44 additions & 0 deletions R/data.table.R
Original file line number Diff line number Diff line change
Expand Up @@ -521,6 +521,50 @@ replace_dot_alias = function(e) {
list(GForce=GForce, jsub=jsub, jvnames=jvnames)
}

# Helper function to process SDcols
.processSDcols = function(SDcols_sub, SDcols_missing, x, jsub, by, enclos = parent.frame()) {
names_x = names(x)
bysub = substitute(by)
allbyvars = intersect(all.vars(bysub), names_x)
usesSD = ".SD" %chin% all.vars(jsub)
if (!usesSD) {
return(NULL)
}
if (SDcols_missing) {
ansvars = sdvars = setdiff(unique(names_x), union(by, allbyvars))
ansvals = match(ansvars, names_x)
return(list(ansvars = ansvars, sdvars = sdvars, ansvals = ansvals))
}
sub.result = SDcols_sub
if (sub.result %iscall% "patterns") {
.SDcols = eval_with_cols(sub.result, names_x)
} else {
.SDcols = eval(sub.result, enclos)
}
if (anyNA(.SDcols))
stopf(".SDcols missing at the following indices: %s", brackify(which(is.na(.SDcols))))
if (is.character(.SDcols)) {
idx = .SDcols %chin% names_x
if (!all(idx))
stopf("Some items of .SDcols are not column names: %s", toString(.SDcols[!idx]))
ansvars = sdvars = .SDcols
ansvals = match(ansvars, names_x)
} else if (is.numeric(.SDcols)) {
ansvals = as.integer(.SDcols)
if (any(ansvals < 1L | ansvals > length(names_x)))
stopf(".SDcols contains indices out of bounds")
ansvars = sdvars = names_x[ansvals]
} else if (is.logical(.SDcols)) {
if (length(.SDcols) != length(names_x))
stopf(".SDcols is a logical vector of length %d but there are %d columns", length(.SDcols), length(names_x))
ansvals = which(.SDcols)
ansvars = sdvars = names_x[ansvals]
} else {
stopf(".SDcols must be character, numeric, or logical")
}
list(ansvars = ansvars, sdvars = sdvars, ansvals = ansvals)
}

"[.data.table" = function(x, i, j, by, keyby, with=TRUE, nomatch=NA, mult="all", roll=FALSE, rollends=if (roll=="nearest") c(TRUE,TRUE) else if (roll>=0.0) c(FALSE,TRUE) else c(TRUE,FALSE), which=FALSE, .SDcols, verbose=getOption("datatable.verbose"), allow.cartesian=getOption("datatable.allow.cartesian"), drop=NULL, on=NULL, env=NULL, showProgress=getOption("datatable.showProgress", interactive()))
{
# ..selfcount <<- ..selfcount+1 # in dev, we check no self calls, each of which doubles overhead, or could
Expand Down
43 changes: 27 additions & 16 deletions R/groupingsets.R
Original file line number Diff line number Diff line change
Expand Up @@ -20,22 +20,33 @@ cube = function(x, ...) {
UseMethod("cube")
}
cube.data.table = function(x, j, by, .SDcols, id = FALSE, label = NULL, ...) {
# input data type basic validation
if (!is.data.table(x))
stopf("'%s' must be a data.table", "x", class="dt_invalid_input_error")
if (!is.character(by))
stopf("Argument 'by' must be a character vector of column names used in grouping.")
if (!is.logical(id))
stopf("Argument 'id' must be a logical scalar.")
if (missing(j))
stopf("Argument 'j' is required")
# generate grouping sets for cube - power set: http://stackoverflow.com/a/32187892/2490497
n = length(by)
keepBool = sapply(2L^(seq_len(n)-1L), function(k) rep(c(FALSE, TRUE), times=k, each=((2L^n)/(2L*k))))
sets = lapply((2L^n):1L, function(jj) by[keepBool[jj, ]])
# redirect to workhorse function
jj = substitute(j)
groupingsets.data.table(x, by=by, sets=sets, .SDcols=.SDcols, id=id, jj=jj, label=label, enclos = parent.frame())
# input data type basic validation
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Changes in those lines looks to be unrelated to this PR

if (!is.data.table(x))
stopf("Argument 'x' must be a data.table object", class="dt_invalid_input_error")
if (!is.character(by))
stopf("Argument 'by' must be a character vector of column names used in grouping.")
if (!is.logical(id))
stopf("Argument 'id' must be a logical scalar.")
if (missing(j))
stopf("Argument 'j' is required")
# Implementing NSE in cube using the helper, .processSDcols
jj = substitute(j)
sdcols_result = .processSDcols(SDcols_sub = substitute(.SDcols), SDcols_missing = missing(.SDcols), x = x, jsub = jj, by = by, enclos = parent.frame())
if (is.null(sdcols_result)) {
.SDcols = NULL
} else {
ansvars = sdcols_result$ansvars
sdvars = sdcols_result$sdvars
ansvals = sdcols_result$ansvals
.SDcols = sdvars
}
# generate grouping sets for cube - power set: http://stackoverflow.com/a/32187892/2490497
n = length(by)
keepBool = sapply(2L^(seq_len(n)-1L), function(k) rep(c(FALSE, TRUE), times=k, each=((2L^n)/(2L*k))))
sets = lapply((2L^n):1L, function(jj) by[keepBool[jj, ]])
# redirect to workhorse function
jj = substitute(j)
groupingsets.data.table(x, by=by, sets=sets, .SDcols=.SDcols, id=id, jj=jj, label=label, enclos = parent.frame())
}

groupingsets = function(x, ...) {
Expand Down
37 changes: 37 additions & 0 deletions inst/tests/tests.Rraw
Original file line number Diff line number Diff line change
Expand Up @@ -11097,6 +11097,43 @@ test(1750.34,
character(0)),
id = TRUE)
)
test(1750.35,
cube(dt, j = lapply(.SD, sum), by = c("color","year","status"), id=TRUE, .SDcols=patterns("value")),
groupingsets(dt, j = lapply(.SD, sum), by = c("color","year","status"), .SDcols = "value",
sets = list(c("color","year","status"),
c("color","year"),
c("color","status"),
"color",
c("year","status"),
"year",
"status",
character(0)),
id = TRUE)
)
test(1750.36,
cube(dt, j = lapply(.SD, sum), by = "year", .SDcols = c("value", "BADCOL")),
error = "Some items of \\.SDcols are not column names"
)

test(1750.37,
cube(dt, j = lapply(.SD, sum), by = "year", .SDcols = c(TRUE, FALSE)),
error = "\\.SDcols is a logical vector of length"
)

test(1750.38,
cube(dt, j = lapply(.SD, mean), by = "color", .SDcols = c(FALSE, FALSE, FALSE, TRUE, FALSE), id=TRUE),
groupingsets(dt, j = lapply(.SD, mean), by = "color", .SDcols = "amount",
sets = list("color", character(0)),
id = TRUE)
)
test(1750.39,
cube(dt, j = lapply(.SD, sum), by = "color", .SDcols = list("amount")),
error = ".SDcols must be character, numeric, or logical"
)
test(1750.40,
cube(dt, j = lapply(.SD, sum), by = "color", .SDcols = c(1, 99)),
error = "out of bounds"
)
# grouping sets with integer64
if (test_bit64) {
set.seed(26)
Expand Down
Loading