about summary refs log tree commit diff
path: root/R
diff options
context:
space:
mode:
authorqhuang <qian.huang@intel.com>2015-05-15 14:06:16 -0700
committerShivaram Venkataraman <shivaram@cs.berkeley.edu>2015-05-15 14:06:16 -0700
commit50da9e89161faa0ecdc1feb3ffee6c822a742034 (patch)
tree269c29d8782a270f5522f4b77e00bfb27745583e /R
parent9b6cf285d0b60848b01b6c7e3421e8ac850a88ab (diff)
downloadspark-50da9e89161faa0ecdc1feb3ffee6c822a742034.tar.gz
spark-50da9e89161faa0ecdc1feb3ffee6c822a742034.tar.bz2
spark-50da9e89161faa0ecdc1feb3ffee6c822a742034.zip
[SPARK-7226] [SPARKR] Support math functions in R DataFrame
Author: qhuang <qian.huang@intel.com> Closes #6170 from hqzizania/master and squashes the following commits: f20c39f [qhuang] add tests units and fixes 2a7d121 [qhuang] use a function name more familiar to R users 07aa72e [qhuang] Support math functions in R DataFrame
Diffstat (limited to 'R')
-rw-r--r--R/pkg/NAMESPACE23
-rw-r--r--R/pkg/R/column.R36
-rw-r--r--R/pkg/R/generics.R20
-rw-r--r--R/pkg/inst/tests/test_sparkSQL.R24
4 files changed, 100 insertions(+), 3 deletions(-)
diff --git a/R/pkg/NAMESPACE b/R/pkg/NAMESPACE
index ba29614e7b..64ffdcffc9 100644
--- a/R/pkg/NAMESPACE
+++ b/R/pkg/NAMESPACE
@@ -59,33 +59,56 @@ exportMethods("arrange",
exportClasses("Column")
exportMethods("abs",
+ "acos",
"alias",
"approxCountDistinct",
"asc",
+ "asin",
+ "atan",
+ "atan2",
"avg",
"cast",
+ "cbrt",
+ "ceiling",
"contains",
+ "cos",
+ "cosh",
"countDistinct",
"desc",
"endsWith",
+ "exp",
+ "expm1",
+ "floor",
"getField",
"getItem",
+ "hypot",
"isNotNull",
"isNull",
"last",
"like",
+ "log",
+ "log10",
+ "log1p",
"lower",
"max",
"mean",
"min",
"n",
"n_distinct",
+ "rint",
"rlike",
+ "sign",
+ "sin",
+ "sinh",
"sqrt",
"startsWith",
"substr",
"sum",
"sumDistinct",
+ "tan",
+ "tanh",
+ "toDegrees",
+ "toRadians",
"upper")
exportClasses("GroupedData")
diff --git a/R/pkg/R/column.R b/R/pkg/R/column.R
index 9a68445ab4..80e92d3105 100644
--- a/R/pkg/R/column.R
+++ b/R/pkg/R/column.R
@@ -55,12 +55,17 @@ operators <- list(
"+" = "plus", "-" = "minus", "*" = "multiply", "/" = "divide", "%%" = "mod",
"==" = "equalTo", ">" = "gt", "<" = "lt", "!=" = "notEqual", "<=" = "leq", ">=" = "geq",
# we can not override `&&` and `||`, so use `&` and `|` instead
- "&" = "and", "|" = "or" #, "!" = "unary_$bang"
+ "&" = "and", "|" = "or", #, "!" = "unary_$bang"
+ "^" = "pow"
)
column_functions1 <- c("asc", "desc", "isNull", "isNotNull")
column_functions2 <- c("like", "rlike", "startsWith", "endsWith", "getField", "getItem", "contains")
functions <- c("min", "max", "sum", "avg", "mean", "count", "abs", "sqrt",
- "first", "last", "lower", "upper", "sumDistinct")
+ "first", "last", "lower", "upper", "sumDistinct",
+ "acos", "asin", "atan", "cbrt", "ceiling", "cos", "cosh", "exp",
+ "expm1", "floor", "log", "log10", "log1p", "rint", "sign",
+ "sin", "sinh", "tan", "tanh", "toDegrees", "toRadians")
+binary_mathfunctions<- c("atan2", "hypot")
createOperator <- function(op) {
setMethod(op,
@@ -76,7 +81,11 @@ createOperator <- function(op) {
if (class(e2) == "Column") {
e2 <- e2@jc
}
- callJMethod(e1@jc, operators[[op]], e2)
+ if (op == "^") {
+ jc <- callJStatic("org.apache.spark.sql.functions", operators[[op]], e1@jc, e2)
+ } else {
+ callJMethod(e1@jc, operators[[op]], e2)
+ }
}
column(jc)
})
@@ -106,11 +115,29 @@ createStaticFunction <- function(name) {
setMethod(name,
signature(x = "Column"),
function(x) {
+ if (name == "ceiling") {
+ name <- "ceil"
+ }
+ if (name == "sign") {
+ name <- "signum"
+ }
jc <- callJStatic("org.apache.spark.sql.functions", name, x@jc)
column(jc)
})
}
+createBinaryMathfunctions <- function(name) {
+ setMethod(name,
+ signature(y = "Column"),
+ function(y, x) {
+ if (class(x) == "Column") {
+ x <- x@jc
+ }
+ jc <- callJStatic("org.apache.spark.sql.functions", name, y@jc, x)
+ column(jc)
+ })
+}
+
createMethods <- function() {
for (op in names(operators)) {
createOperator(op)
@@ -124,6 +151,9 @@ createMethods <- function() {
for (x in functions) {
createStaticFunction(x)
}
+ for (name in binary_mathfunctions) {
+ createBinaryMathfunctions(name)
+ }
}
createMethods()
diff --git a/R/pkg/R/generics.R b/R/pkg/R/generics.R
index 6d2bfb1181..a23d3b217b 100644
--- a/R/pkg/R/generics.R
+++ b/R/pkg/R/generics.R
@@ -554,6 +554,10 @@ setGeneric("cast", function(x, dataType) { standardGeneric("cast") })
#' @rdname column
#' @export
+setGeneric("cbrt", function(x) { standardGeneric("cbrt") })
+
+#' @rdname column
+#' @export
setGeneric("contains", function(x, ...) { standardGeneric("contains") })
#' @rdname column
#' @export
@@ -577,6 +581,10 @@ setGeneric("getItem", function(x, ...) { standardGeneric("getItem") })
#' @rdname column
#' @export
+setGeneric("hypot", function(y, x) { standardGeneric("hypot") })
+
+#' @rdname column
+#' @export
setGeneric("isNull", function(x) { standardGeneric("isNull") })
#' @rdname column
@@ -605,6 +613,10 @@ setGeneric("n_distinct", function(x, ...) { standardGeneric("n_distinct") })
#' @rdname column
#' @export
+setGeneric("rint", function(x, ...) { standardGeneric("rint") })
+
+#' @rdname column
+#' @export
setGeneric("rlike", function(x, ...) { standardGeneric("rlike") })
#' @rdname column
@@ -617,5 +629,13 @@ setGeneric("sumDistinct", function(x) { standardGeneric("sumDistinct") })
#' @rdname column
#' @export
+setGeneric("toDegrees", function(x) { standardGeneric("toDegrees") })
+
+#' @rdname column
+#' @export
+setGeneric("toRadians", function(x) { standardGeneric("toRadians") })
+
+#' @rdname column
+#' @export
setGeneric("upper", function(x) { standardGeneric("upper") })
diff --git a/R/pkg/inst/tests/test_sparkSQL.R b/R/pkg/inst/tests/test_sparkSQL.R
index 1109e8fdba..3e5658eb5b 100644
--- a/R/pkg/inst/tests/test_sparkSQL.R
+++ b/R/pkg/inst/tests/test_sparkSQL.R
@@ -530,6 +530,7 @@ test_that("column operators", {
c2 <- (- c + 1 - 2) * 3 / 4.0
c3 <- (c + c2 - c2) * c2 %% c2
c4 <- (c > c2) & (c2 <= c3) | (c == c2) & (c2 != c3)
+ c5 <- c2 ^ c3 ^ c4
})
test_that("column functions", {
@@ -538,6 +539,29 @@ test_that("column functions", {
c3 <- lower(c) + upper(c) + first(c) + last(c)
c4 <- approxCountDistinct(c) + countDistinct(c) + cast(c, "string")
c5 <- n(c) + n_distinct(c)
+ c5 <- acos(c) + asin(c) + atan(c) + cbrt(c)
+ c6 <- ceiling(c) + cos(c) + cosh(c) + exp(c) + expm1(c)
+ c7 <- floor(c) + log(c) + log10(c) + log1p(c) + rint(c)
+ c8 <- sign(c) + sin(c) + sinh(c) + tan(c) + tanh(c)
+ c9 <- toDegrees(c) + toRadians(c)
+})
+
+test_that("column binary mathfunctions", {
+ lines <- c("{\"a\":1, \"b\":5}",
+ "{\"a\":2, \"b\":6}",
+ "{\"a\":3, \"b\":7}",
+ "{\"a\":4, \"b\":8}")
+ jsonPathWithDup <- tempfile(pattern="sparkr-test", fileext=".tmp")
+ writeLines(lines, jsonPathWithDup)
+ df <- jsonFile(sqlCtx, jsonPathWithDup)
+ expect_equal(collect(select(df, atan2(df$a, df$b)))[1, "ATAN2(a, b)"], atan2(1, 5))
+ expect_equal(collect(select(df, atan2(df$a, df$b)))[2, "ATAN2(a, b)"], atan2(2, 6))
+ expect_equal(collect(select(df, atan2(df$a, df$b)))[3, "ATAN2(a, b)"], atan2(3, 7))
+ expect_equal(collect(select(df, atan2(df$a, df$b)))[4, "ATAN2(a, b)"], atan2(4, 8))
+ expect_equal(collect(select(df, hypot(df$a, df$b)))[1, "HYPOT(a, b)"], sqrt(1^2 + 5^2))
+ expect_equal(collect(select(df, hypot(df$a, df$b)))[2, "HYPOT(a, b)"], sqrt(2^2 + 6^2))
+ expect_equal(collect(select(df, hypot(df$a, df$b)))[3, "HYPOT(a, b)"], sqrt(3^2 + 7^2))
+ expect_equal(collect(select(df, hypot(df$a, df$b)))[4, "HYPOT(a, b)"], sqrt(4^2 + 8^2))
})
test_that("string operators", {