diff options
-rw-r--r-- | R/pkg/NAMESPACE | 28 | ||||
-rw-r--r-- | R/pkg/R/functions.R | 415 | ||||
-rw-r--r-- | R/pkg/R/generics.R | 113 | ||||
-rw-r--r-- | R/pkg/inst/tests/test_sparkSQL.R | 98 | ||||
-rw-r--r-- | core/src/main/scala/org/apache/spark/api/r/RBackendHandler.scala | 1 |
5 files changed, 649 insertions, 6 deletions
diff --git a/R/pkg/NAMESPACE b/R/pkg/NAMESPACE index 8fa12d5ade..111a2dc30d 100644 --- a/R/pkg/NAMESPACE +++ b/R/pkg/NAMESPACE @@ -84,6 +84,7 @@ exportClasses("Column") exportMethods("abs", "acos", + "add_months", "alias", "approxCountDistinct", "asc", @@ -101,12 +102,17 @@ exportMethods("abs", "ceil", "ceiling", "concat", + "concat_ws", "contains", + "conv", "cos", "cosh", "count", "countDistinct", "crc32", + "date_add", + "date_format", + "date_sub", "datediff", "dayofmonth", "dayofyear", @@ -115,9 +121,14 @@ exportMethods("abs", "exp", "explode", "expm1", + "expr", "factorial", "first", "floor", + "format_number", + "format_string", + "from_unixtime", + "from_utc_timestamp", "getField", "getItem", "greatest", @@ -125,6 +136,7 @@ exportMethods("abs", "hour", "hypot", "initcap", + "instr", "isNaN", "isNotNull", "isNull", @@ -135,11 +147,13 @@ exportMethods("abs", "levenshtein", "like", "lit", + "locate", "log", "log10", "log1p", "log2", "lower", + "lpad", "ltrim", "max", "md5", @@ -152,16 +166,26 @@ exportMethods("abs", "n_distinct", "nanvl", "negate", + "next_day", "otherwise", "pmod", "quarter", + "rand", + "randn", + "regexp_extract", + "regexp_replace", "reverse", "rint", "rlike", "round", + "rpad", "rtrim", "second", "sha1", + "sha2", + "shiftLeft", + "shiftRight", + "shiftRightUnsigned", "sign", "signum", "sin", @@ -171,6 +195,7 @@ exportMethods("abs", "sqrt", "startsWith", "substr", + "substring_index", "sum", "sumDistinct", "tan", @@ -178,9 +203,12 @@ exportMethods("abs", "toDegrees", "toRadians", "to_date", + "to_utc_timestamp", + "translate", "trim", "unbase64", "unhex", + "unix_timestamp", "upper", "weekofyear", "when", diff --git a/R/pkg/R/functions.R b/R/pkg/R/functions.R index 366c230e1e..5dba0887d1 100644 --- a/R/pkg/R/functions.R +++ b/R/pkg/R/functions.R @@ -166,6 +166,421 @@ setMethod("n", signature(x = "Column"), count(x) }) +#' date_format +#' +#' Converts a date/timestamp/string to a value of string in the format specified by the date +#' 
format given by the second argument. +#' +#' A pattern could be for instance `dd.MM.yyyy` and could return a string like '18.03.1993'. All +#' pattern letters of `java.text.SimpleDateFormat` can be used. +#' +#' NOTE: Whenever possible, use specialized functions like `year`. These benefit from a +#' specialized implementation. +#' +#' @rdname functions +setMethod("date_format", signature(y = "Column", x = "character"), + function(y, x) { + jc <- callJStatic("org.apache.spark.sql.functions", "date_format", y@jc, x) + column(jc) + }) + +#' from_utc_timestamp +#' +#' Assumes given timestamp is UTC and converts to given timezone. +#' +#' @rdname functions +setMethod("from_utc_timestamp", signature(y = "Column", x = "character"), + function(y, x) { + jc <- callJStatic("org.apache.spark.sql.functions", "from_utc_timestamp", y@jc, x) + column(jc) + }) + +#' instr +#' +#' Locate the position of the first occurrence of substr column in the given string. +#' Returns null if either of the arguments is null. +#' +#' NOTE: The position is not zero based, but 1 based index, returns 0 if substr +#' could not be found in str. +#' +#' @rdname functions +setMethod("instr", signature(y = "Column", x = "character"), + function(y, x) { + jc <- callJStatic("org.apache.spark.sql.functions", "instr", y@jc, x) + column(jc) + }) + +#' next_day +#' +#' Given a date column, returns the first date which is later than the value of the date column +#' that is on the specified day of the week. +#' +#' For example, `next_day('2015-07-27', "Sunday")` returns 2015-08-02 because that is the first +#' Sunday after 2015-07-27. +#' +#' Day of the week parameter is case insensitive, and accepts: +#' "Mon", "Tue", "Wed", "Thu", "Fri", "Sat", "Sun". 
+#' +#' @rdname functions +setMethod("next_day", signature(y = "Column", x = "character"), + function(y, x) { + jc <- callJStatic("org.apache.spark.sql.functions", "next_day", y@jc, x) + column(jc) + }) + +#' to_utc_timestamp +#' +#' Assumes given timestamp is in given timezone and converts to UTC. +#' +#' @rdname functions +setMethod("to_utc_timestamp", signature(y = "Column", x = "character"), + function(y, x) { + jc <- callJStatic("org.apache.spark.sql.functions", "to_utc_timestamp", y@jc, x) + column(jc) + }) + +#' add_months +#' +#' Returns the date that is numMonths after startDate. +#' +#' @rdname functions +setMethod("add_months", signature(y = "Column", x = "numeric"), + function(y, x) { + jc <- callJStatic("org.apache.spark.sql.functions", "add_months", y@jc, as.integer(x)) + column(jc) + }) + +#' date_add +#' +#' Returns the date that is `days` days after `start`. +#' +#' @rdname functions +setMethod("date_add", signature(y = "Column", x = "numeric"), + function(y, x) { + jc <- callJStatic("org.apache.spark.sql.functions", "date_add", y@jc, as.integer(x)) + column(jc) + }) + +#' date_sub +#' +#' Returns the date that is `days` days before `start`. +#' +#' @rdname functions +setMethod("date_sub", signature(y = "Column", x = "numeric"), + function(y, x) { + jc <- callJStatic("org.apache.spark.sql.functions", "date_sub", y@jc, as.integer(x)) + column(jc) + }) + +#' format_number +#' +#' Formats numeric column y to a format like '#,###,###.##', rounded to x decimal places, +#' and returns the result as a string column. +#' +#' If x is 0, the result has no decimal point or fractional part. +#' If x < 0, the result will be null. 
+#' +#' @rdname functions +setMethod("format_number", signature(y = "Column", x = "numeric"), + function(y, x) { + jc <- callJStatic("org.apache.spark.sql.functions", + "format_number", + y@jc, as.integer(x)) + column(jc) + }) + +#' sha2 +#' +#' Calculates the SHA-2 family of hash functions of a binary column and +#' returns the value as a hex string. +#' +#' @rdname functions +#' @param y column to compute SHA-2 on. +#' @param x one of 224, 256, 384, or 512. +setMethod("sha2", signature(y = "Column", x = "numeric"), + function(y, x) { + jc <- callJStatic("org.apache.spark.sql.functions", "sha2", y@jc, as.integer(x)) + column(jc) + }) + +#' shiftLeft +#' +#' Shift the given value numBits left. If the given value is a long value, this function +#' will return a long value else it will return an integer value. +#' +#' @rdname functions +setMethod("shiftLeft", signature(y = "Column", x = "numeric"), + function(y, x) { + jc <- callJStatic("org.apache.spark.sql.functions", + "shiftLeft", + y@jc, as.integer(x)) + column(jc) + }) + +#' shiftRight +#' +#' Shift the given value numBits right. If the given value is a long value, it will return +#' a long value else it will return an integer value. +#' +#' @rdname functions +setMethod("shiftRight", signature(y = "Column", x = "numeric"), + function(y, x) { + jc <- callJStatic("org.apache.spark.sql.functions", + "shiftRight", + y@jc, as.integer(x)) + column(jc) + }) + +#' shiftRightUnsigned +#' +#' Unsigned shift the given value numBits right. If the given value is a long value, +#' it will return a long value else it will return an integer value. +#' +#' @rdname functions +setMethod("shiftRightUnsigned", signature(y = "Column", x = "numeric"), + function(y, x) { + jc <- callJStatic("org.apache.spark.sql.functions", + "shiftRightUnsigned", + y@jc, as.integer(x)) + column(jc) + }) + +#' concat_ws +#' +#' Concatenates multiple input string columns together into a single string column, +#' using the given separator. 
+#' +#' @rdname functions +setMethod("concat_ws", signature(sep = "character", x = "Column"), + function(sep, x, ...) { + jcols <- listToSeq(lapply(list(x, ...), function(x) { x@jc })) + jc <- callJStatic("org.apache.spark.sql.functions", "concat_ws", sep, jcols) + column(jc) + }) + +#' conv +#' +#' Convert a number in a string column from one base to another. +#' +#' @rdname functions +setMethod("conv", signature(x = "Column", fromBase = "numeric", toBase = "numeric"), + function(x, fromBase, toBase) { + fromBase <- as.integer(fromBase) + toBase <- as.integer(toBase) + jc <- callJStatic("org.apache.spark.sql.functions", + "conv", + x@jc, fromBase, toBase) + column(jc) + }) + +#' expr +#' +#' Parses the expression string into the column that it represents, similar to +#' DataFrame.selectExpr +#' +#' @rdname functions +setMethod("expr", signature(x = "character"), + function(x) { + jc <- callJStatic("org.apache.spark.sql.functions", "expr", x) + column(jc) + }) + +#' format_string +#' +#' Formats the arguments in printf-style and returns the result as a string column. +#' +#' @rdname functions +setMethod("format_string", signature(format = "character", x = "Column"), + function(format, x, ...) { + jcols <- listToSeq(lapply(list(x, ...), function(arg) { arg@jc })) + jc <- callJStatic("org.apache.spark.sql.functions", + "format_string", + format, jcols) + column(jc) + }) + +#' from_unixtime +#' +#' Converts the number of seconds from unix epoch (1970-01-01 00:00:00 UTC) to a string +#' representing the timestamp of that moment in the current system time zone in the given +#' format. +#' +#' @rdname functions +setMethod("from_unixtime", signature(x = "Column"), + function(x, format = "yyyy-MM-dd HH:mm:ss") { + jc <- callJStatic("org.apache.spark.sql.functions", + "from_unixtime", + x@jc, format) + column(jc) + }) + +#' locate +#' +#' Locate the position of the first occurrence of substr. 
+#' NOTE: The position is a 1-based index, not zero-based; returns 0 if substr +#' could not be found in str. +#' +#' @rdname functions +setMethod("locate", signature(substr = "character", str = "Column"), + function(substr, str, pos = 0) { + jc <- callJStatic("org.apache.spark.sql.functions", + "locate", + substr, str@jc, as.integer(pos)) + column(jc) + }) + +#' lpad +#' +#' Left-pad the string column with pad to a length of len. +#' +#' @rdname functions +setMethod("lpad", signature(x = "Column", len = "numeric", pad = "character"), + function(x, len, pad) { + jc <- callJStatic("org.apache.spark.sql.functions", + "lpad", + x@jc, as.integer(len), pad) + column(jc) + }) + +#' rand +#' +#' Generate a random column with i.i.d. samples from U[0.0, 1.0]. +#' +#' @rdname functions +setMethod("rand", signature(seed = "missing"), + function(seed) { + jc <- callJStatic("org.apache.spark.sql.functions", "rand") + column(jc) + }) +setMethod("rand", signature(seed = "numeric"), + function(seed) { + jc <- callJStatic("org.apache.spark.sql.functions", "rand", as.integer(seed)) + column(jc) + }) + +#' randn +#' +#' Generate a column with i.i.d. samples from the standard normal distribution. +#' +#' @rdname functions +setMethod("randn", signature(seed = "missing"), + function(seed) { + jc <- callJStatic("org.apache.spark.sql.functions", "randn") + column(jc) + }) +setMethod("randn", signature(seed = "numeric"), + function(seed) { + jc <- callJStatic("org.apache.spark.sql.functions", "randn", as.integer(seed)) + column(jc) + }) + +#' regexp_extract +#' +#' Extract a specific (idx) group identified by a Java regex, from the specified string column. 
+#' +#' @rdname functions +setMethod("regexp_extract", + signature(x = "Column", pattern = "character", idx = "numeric"), + function(x, pattern, idx) { + jc <- callJStatic("org.apache.spark.sql.functions", + "regexp_extract", + x@jc, pattern, as.integer(idx)) + column(jc) + }) + +#' regexp_replace +#' +#' Replace all substrings of the specified string value that match pattern with replacement. +#' +#' @rdname functions +setMethod("regexp_replace", + signature(x = "Column", pattern = "character", replacement = "character"), + function(x, pattern, replacement) { + jc <- callJStatic("org.apache.spark.sql.functions", + "regexp_replace", + x@jc, pattern, replacement) + column(jc) + }) + +#' rpad +#' +#' Right-pad the string column with pad to a length of len. +#' +#' @rdname functions +setMethod("rpad", signature(x = "Column", len = "numeric", pad = "character"), + function(x, len, pad) { + jc <- callJStatic("org.apache.spark.sql.functions", + "rpad", + x@jc, as.integer(len), pad) + column(jc) + }) + +#' substring_index +#' +#' Returns the substring from string str before count occurrences of the delimiter delim. +#' If count is positive, everything to the left of the final delimiter (counting from left) is +#' returned. If count is negative, everything to the right of the final delimiter (counting from the +#' right) is returned. substring_index performs a case-sensitive match when searching for delim. +#' +#' @rdname functions +setMethod("substring_index", + signature(x = "Column", delim = "character", count = "numeric"), + function(x, delim, count) { + jc <- callJStatic("org.apache.spark.sql.functions", + "substring_index", + x@jc, delim, as.integer(count)) + column(jc) + }) + +#' translate +#' +#' Translate any character in the src by a character in replaceString. +#' The characters in replaceString correspond to the characters in matchingString. +#' The translation happens when any character in the string matches a character +#' in the matchingString. 
+#' +#' @rdname functions +setMethod("translate", + signature(x = "Column", matchingString = "character", replaceString = "character"), + function(x, matchingString, replaceString) { + jc <- callJStatic("org.apache.spark.sql.functions", + "translate", x@jc, matchingString, replaceString) + column(jc) + }) + +#' unix_timestamp +#' +#' Gets current Unix timestamp in seconds. +#' +#' @rdname functions +setMethod("unix_timestamp", signature(x = "missing", format = "missing"), + function(x, format) { + jc <- callJStatic("org.apache.spark.sql.functions", "unix_timestamp") + column(jc) + }) +#' unix_timestamp +#' +#' Converts time string in format yyyy-MM-dd HH:mm:ss to Unix timestamp (in seconds), +#' using the default timezone and the default locale, return null if fail. +#' +#' @rdname functions +setMethod("unix_timestamp", signature(x = "Column", format = "missing"), + function(x, format) { + jc <- callJStatic("org.apache.spark.sql.functions", "unix_timestamp", x@jc) + column(jc) + }) +#' unix_timestamp +#' +#' Convert time string with given pattern +#' (see [http://docs.oracle.com/javase/tutorial/i18n/format/simpleDateFormat.html]) +#' to Unix time stamp (in seconds), return null if fail. +#' +#' @rdname functions +setMethod("unix_timestamp", signature(x = "Column", format = "character"), + function(x, format = "yyyy-MM-dd HH:mm:ss") { + jc <- callJStatic("org.apache.spark.sql.functions", "unix_timestamp", x@jc, format) + column(jc) + }) #' when #' #' Evaluates a list of conditions and returns one of multiple possible result expressions. 
diff --git a/R/pkg/R/generics.R b/R/pkg/R/generics.R index 338b32e648..84cb8dfdaa 100644 --- a/R/pkg/R/generics.R +++ b/R/pkg/R/generics.R @@ -664,6 +664,10 @@ setGeneric("otherwise", function(x, value) { standardGeneric("otherwise") }) #' @rdname functions #' @export +setGeneric("add_months", function(y, x) { standardGeneric("add_months") }) + +#' @rdname functions +#' @export setGeneric("ascii", function(x) { standardGeneric("ascii") }) #' @rdname functions @@ -696,6 +700,14 @@ setGeneric("concat", function(x, ...) { standardGeneric("concat") }) #' @rdname functions #' @export +setGeneric("concat_ws", function(sep, x, ...) { standardGeneric("concat_ws") }) + +#' @rdname functions +#' @export +setGeneric("conv", function(x, fromBase, toBase) { standardGeneric("conv") }) + +#' @rdname functions +#' @export setGeneric("crc32", function(x) { standardGeneric("crc32") }) #' @rdname functions @@ -704,6 +716,18 @@ setGeneric("datediff", function(y, x) { standardGeneric("datediff") }) #' @rdname functions #' @export +setGeneric("date_add", function(y, x) { standardGeneric("date_add") }) + +#' @rdname functions +#' @export +setGeneric("date_format", function(y, x) { standardGeneric("date_format") }) + +#' @rdname functions +#' @export +setGeneric("date_sub", function(y, x) { standardGeneric("date_sub") }) + +#' @rdname functions +#' @export setGeneric("dayofmonth", function(x) { standardGeneric("dayofmonth") }) #' @rdname functions @@ -716,6 +740,26 @@ setGeneric("explode", function(x) { standardGeneric("explode") }) #' @rdname functions #' @export +setGeneric("expr", function(x) { standardGeneric("expr") }) + +#' @rdname functions +#' @export +setGeneric("from_utc_timestamp", function(y, x) { standardGeneric("from_utc_timestamp") }) + +#' @rdname functions +#' @export +setGeneric("format_number", function(y, x) { standardGeneric("format_number") }) + +#' @rdname functions +#' @export +setGeneric("format_string", function(format, x, ...) 
{ standardGeneric("format_string") }) + +#' @rdname functions +#' @export +setGeneric("from_unixtime", function(x, ...) { standardGeneric("from_unixtime") }) + +#' @rdname functions +#' @export setGeneric("greatest", function(x, ...) { standardGeneric("greatest") }) #' @rdname functions @@ -732,6 +776,10 @@ setGeneric("initcap", function(x) { standardGeneric("initcap") }) #' @rdname functions #' @export +setGeneric("instr", function(y, x) { standardGeneric("instr") }) + +#' @rdname functions +#' @export setGeneric("isNaN", function(x) { standardGeneric("isNaN") }) #' @rdname functions @@ -752,10 +800,18 @@ setGeneric("lit", function(x) { standardGeneric("lit") }) #' @rdname functions #' @export +setGeneric("locate", function(substr, str, ...) { standardGeneric("locate") }) + +#' @rdname functions +#' @export setGeneric("lower", function(x) { standardGeneric("lower") }) #' @rdname functions #' @export +setGeneric("lpad", function(x, len, pad) { standardGeneric("lpad") }) + +#' @rdname functions +#' @export setGeneric("ltrim", function(x) { standardGeneric("ltrim") }) #' @rdname functions @@ -784,6 +840,10 @@ setGeneric("negate", function(x) { standardGeneric("negate") }) #' @rdname functions #' @export +setGeneric("next_day", function(y, x) { standardGeneric("next_day") }) + +#' @rdname functions +#' @export setGeneric("pmod", function(y, x) { standardGeneric("pmod") }) #' @rdname functions @@ -792,10 +852,31 @@ setGeneric("quarter", function(x) { standardGeneric("quarter") }) #' @rdname functions #' @export +setGeneric("rand", function(seed) { standardGeneric("rand") }) + +#' @rdname functions +#' @export +setGeneric("randn", function(seed) { standardGeneric("randn") }) + +#' @rdname functions +#' @export +setGeneric("regexp_extract", function(x, pattern, idx) { standardGeneric("regexp_extract") }) + +#' @rdname functions +#' @export +setGeneric("regexp_replace", + function(x, pattern, replacement) { standardGeneric("regexp_replace") }) + +#' @rdname functions +#' 
@export setGeneric("reverse", function(x) { standardGeneric("reverse") }) #' @rdname functions #' @export +setGeneric("rpad", function(x, len, pad) { standardGeneric("rpad") }) + +#' @rdname functions +#' @export setGeneric("rtrim", function(x) { standardGeneric("rtrim") }) #' @rdname functions @@ -808,6 +889,22 @@ setGeneric("sha1", function(x) { standardGeneric("sha1") }) #' @rdname functions #' @export +setGeneric("sha2", function(y, x) { standardGeneric("sha2") }) + +#' @rdname functions +#' @export +setGeneric("shiftLeft", function(y, x) { standardGeneric("shiftLeft") }) + +#' @rdname functions +#' @export +setGeneric("shiftRight", function(y, x) { standardGeneric("shiftRight") }) + +#' @rdname functions +#' @export +setGeneric("shiftRightUnsigned", function(y, x) { standardGeneric("shiftRightUnsigned") }) + +#' @rdname functions +#' @export setGeneric("signum", function(x) { standardGeneric("signum") }) #' @rdname functions @@ -820,6 +917,10 @@ setGeneric("soundex", function(x) { standardGeneric("soundex") }) #' @rdname functions #' @export +setGeneric("substring_index", function(x, delim, count) { standardGeneric("substring_index") }) + +#' @rdname functions +#' @export setGeneric("sumDistinct", function(x) { standardGeneric("sumDistinct") }) #' @rdname functions @@ -836,6 +937,14 @@ setGeneric("to_date", function(x) { standardGeneric("to_date") }) #' @rdname functions #' @export +setGeneric("to_utc_timestamp", function(y, x) { standardGeneric("to_utc_timestamp") }) + +#' @rdname functions +#' @export +setGeneric("translate", function(x, matchingString, replaceString) { standardGeneric("translate") }) + +#' @rdname functions +#' @export setGeneric("trim", function(x) { standardGeneric("trim") }) #' @rdname functions @@ -848,6 +957,10 @@ setGeneric("unhex", function(x) { standardGeneric("unhex") }) #' @rdname functions #' @export +setGeneric("unix_timestamp", function(x, format) { standardGeneric("unix_timestamp") }) + +#' @rdname functions +#' @export 
setGeneric("upper", function(x) { standardGeneric("upper") }) #' @rdname functions diff --git a/R/pkg/inst/tests/test_sparkSQL.R b/R/pkg/inst/tests/test_sparkSQL.R index 841de657df..670017ed34 100644 --- a/R/pkg/inst/tests/test_sparkSQL.R +++ b/R/pkg/inst/tests/test_sparkSQL.R @@ -598,6 +598,11 @@ test_that("selectExpr() on a DataFrame", { expect_equal(count(selected2), 3) }) +test_that("expr() on a DataFrame", { + df <- jsonFile(sqlContext, jsonPath) + expect_equal(collect(select(df, expr("abs(-123)")))[1, 1], 123) +}) + test_that("column calculation", { df <- jsonFile(sqlContext, jsonPath) d <- collect(select(df, alias(df$age + 1, "age2"))) @@ -667,16 +672,15 @@ test_that("column functions", { c <- SparkR:::col("a") c1 <- abs(c) + acos(c) + approxCountDistinct(c) + ascii(c) + asin(c) + atan(c) c2 <- avg(c) + base64(c) + bin(c) + bitwiseNOT(c) + cbrt(c) + ceil(c) + cos(c) - c3 <- cosh(c) + count(c) + crc32(c) + dayofmonth(c) + dayofyear(c) + exp(c) + c3 <- cosh(c) + count(c) + crc32(c) + exp(c) c4 <- explode(c) + expm1(c) + factorial(c) + first(c) + floor(c) + hex(c) c5 <- hour(c) + initcap(c) + isNaN(c) + last(c) + last_day(c) + length(c) c6 <- log(c) + (c) + log1p(c) + log2(c) + lower(c) + ltrim(c) + max(c) + md5(c) - c7 <- mean(c) + min(c) + minute(c) + month(c) + negate(c) + quarter(c) - c8 <- reverse(c) + rint(c) + round(c) + rtrim(c) + second(c) + sha1(c) + c7 <- mean(c) + min(c) + month(c) + negate(c) + quarter(c) + c8 <- reverse(c) + rint(c) + round(c) + rtrim(c) + sha1(c) c9 <- signum(c) + sin(c) + sinh(c) + size(c) + soundex(c) + sqrt(c) + sum(c) c10 <- sumDistinct(c) + tan(c) + tanh(c) + toDegrees(c) + toRadians(c) - c11 <- to_date(c) + trim(c) + unbase64(c) + unhex(c) + upper(c) + weekofyear(c) - c12 <- year(c) + c11 <- to_date(c) + trim(c) + unbase64(c) + unhex(c) + upper(c) df <- jsonFile(sqlContext, jsonPath) df2 <- select(df, between(df$age, c(20, 30)), between(df$age, c(10, 20))) @@ -689,8 +693,11 @@ test_that("column functions", { 
expect_equal(collect(df3)[[1, 1]], TRUE) expect_equal(collect(df3)[[2, 1]], FALSE) expect_equal(collect(df3)[[3, 1]], TRUE) -}) + df4 <- createDataFrame(sqlContext, list(list(a = "010101"))) + expect_equal(collect(select(df4, conv(df4$a, 2, 16)))[1, 1], "15") +}) +# test_that("column binary mathfunctions", { lines <- c("{\"a\":1, \"b\":5}", "{\"a\":2, \"b\":6}", @@ -709,6 +716,13 @@ test_that("column binary mathfunctions", { expect_equal(collect(select(df, hypot(df$a, df$b)))[3, "HYPOT(a, b)"], sqrt(3^2 + 7^2)) expect_equal(collect(select(df, hypot(df$a, df$b)))[4, "HYPOT(a, b)"], sqrt(4^2 + 8^2)) ## nolint end + expect_equal(collect(select(df, shiftLeft(df$b, 1)))[4, 1], 16) + expect_equal(collect(select(df, shiftRight(df$b, 1)))[4, 1], 4) + expect_equal(collect(select(df, shiftRightUnsigned(df$b, 1)))[4, 1], 4) + expect_equal(class(collect(select(df, rand()))[2, 1]), "numeric") + expect_equal(collect(select(df, rand(1)))[1, 1], 0.45, tolerance = 0.01) + expect_equal(class(collect(select(df, randn()))[2, 1]), "numeric") + expect_equal(collect(select(df, randn(1)))[1, 1], -0.0111, tolerance = 0.01) }) test_that("string operators", { @@ -718,6 +732,78 @@ test_that("string operators", { expect_equal(first(select(df, substr(df$name, 1, 2)))[[1]], "Mi") expect_equal(collect(select(df, cast(df$age, "string")))[[2, 1]], "30") expect_equal(collect(select(df, concat(df$name, lit(":"), df$age)))[[2, 1]], "Andy:30") + expect_equal(collect(select(df, concat_ws(":", df$name)))[[2, 1]], "Andy") + expect_equal(collect(select(df, concat_ws(":", df$name, df$age)))[[2, 1]], "Andy:30") + expect_equal(collect(select(df, instr(df$name, "i")))[, 1], c(2, 0, 5)) + expect_equal(collect(select(df, format_number(df$age, 2)))[2, 1], "30.00") + expect_equal(collect(select(df, sha1(df$name)))[2, 1], + "ab5a000e88b5d9d0fa2575f5c6263eb93452405d") + expect_equal(collect(select(df, sha2(df$name, 256)))[2, 1], + "80f2aed3c618c423ddf05a2891229fba44942d907173152442cf6591441ed6dc") + 
expect_equal(collect(select(df, format_string("Name:%s", df$name)))[2, 1], "Name:Andy") + expect_equal(collect(select(df, format_string("%s, %d", df$name, df$age)))[2, 1], "Andy, 30") + expect_equal(collect(select(df, regexp_extract(df$name, "(n.y)", 1)))[2, 1], "ndy") + expect_equal(collect(select(df, regexp_replace(df$name, "(n.y)", "ydn")))[2, 1], "Aydn") + + l2 <- list(list(a = "aaads")) + df2 <- createDataFrame(sqlContext, l2) + expect_equal(collect(select(df2, locate("aa", df2$a)))[1, 1], 1) + expect_equal(collect(select(df2, locate("aa", df2$a, 1)))[1, 1], 2) + expect_equal(collect(select(df2, lpad(df2$a, 8, "#")))[1, 1], "###aaads") + expect_equal(collect(select(df2, rpad(df2$a, 8, "#")))[1, 1], "aaads###") + + l3 <- list(list(a = "a.b.c.d")) + df3 <- createDataFrame(sqlContext, l3) + expect_equal(collect(select(df3, substring_index(df3$a, ".", 2)))[1, 1], "a.b") + expect_equal(collect(select(df3, substring_index(df3$a, ".", -3)))[1, 1], "b.c.d") + expect_equal(collect(select(df3, translate(df3$a, "bc", "12")))[1, 1], "a.1.2.d") +}) + +test_that("date functions on a DataFrame", { + .originalTimeZone <- Sys.getenv("TZ") + Sys.setenv(TZ = "UTC") + l <- list(list(a = 1L, b = as.Date("2012-12-13")), + list(a = 2L, b = as.Date("2013-12-14")), + list(a = 3L, b = as.Date("2014-12-15"))) + df <- createDataFrame(sqlContext, l) + expect_equal(collect(select(df, dayofmonth(df$b)))[, 1], c(13, 14, 15)) + expect_equal(collect(select(df, dayofyear(df$b)))[, 1], c(348, 348, 349)) + expect_equal(collect(select(df, weekofyear(df$b)))[, 1], c(50, 50, 51)) + expect_equal(collect(select(df, year(df$b)))[, 1], c(2012, 2013, 2014)) + expect_equal(collect(select(df, month(df$b)))[, 1], c(12, 12, 12)) + expect_equal(collect(select(df, last_day(df$b)))[, 1], + c(as.Date("2012-12-31"), as.Date("2013-12-31"), as.Date("2014-12-31"))) + expect_equal(collect(select(df, next_day(df$b, "MONDAY")))[, 1], + c(as.Date("2012-12-17"), as.Date("2013-12-16"), as.Date("2014-12-22"))) + 
expect_equal(collect(select(df, date_format(df$b, "y")))[, 1], c("2012", "2013", "2014")) + expect_equal(collect(select(df, add_months(df$b, 3)))[, 1], + c(as.Date("2013-03-13"), as.Date("2014-03-14"), as.Date("2015-03-15"))) + expect_equal(collect(select(df, date_add(df$b, 1)))[, 1], + c(as.Date("2012-12-14"), as.Date("2013-12-15"), as.Date("2014-12-16"))) + expect_equal(collect(select(df, date_sub(df$b, 1)))[, 1], + c(as.Date("2012-12-12"), as.Date("2013-12-13"), as.Date("2014-12-14"))) + + l2 <- list(list(a = 1L, b = as.POSIXlt("2012-12-13 12:34:00", tz = "UTC")), + list(a = 2L, b = as.POSIXlt("2014-12-15 01:24:34", tz = "UTC"))) + df2 <- createDataFrame(sqlContext, l2) + expect_equal(collect(select(df2, minute(df2$b)))[, 1], c(34, 24)) + expect_equal(collect(select(df2, second(df2$b)))[, 1], c(0, 34)) + expect_equal(collect(select(df2, from_utc_timestamp(df2$b, "JST")))[, 1], + c(as.POSIXlt("2012-12-13 21:34:00 UTC"), as.POSIXlt("2014-12-15 10:24:34 UTC"))) + expect_equal(collect(select(df2, to_utc_timestamp(df2$b, "JST")))[, 1], + c(as.POSIXlt("2012-12-13 03:34:00 UTC"), as.POSIXlt("2014-12-14 16:24:34 UTC"))) + expect_more_than(collect(select(df2, unix_timestamp()))[1, 1], 0) + expect_more_than(collect(select(df2, unix_timestamp(df2$b)))[1, 1], 0) + expect_more_than(collect(select(df2, unix_timestamp(lit("2015-01-01"), "yyyy-MM-dd")))[1, 1], 0) + + l3 <- list(list(a = 1000), list(a = -1000)) + df3 <- createDataFrame(sqlContext, l3) + result31 <- collect(select(df3, from_unixtime(df3$a))) + expect_equal(grep("\\d{4}-\\d{2}-\\d{2} \\d{2}:\\d{2}:\\d{2}", result31[, 1], perl = TRUE), + c(1, 2)) + result32 <- collect(select(df3, from_unixtime(df3$a, "yyyy"))) + expect_equal(grep("\\d{4}", result32[, 1]), c(1, 2)) + Sys.setenv(TZ = .originalTimeZone) }) test_that("greatest() and least() on a DataFrame", { diff --git a/core/src/main/scala/org/apache/spark/api/r/RBackendHandler.scala b/core/src/main/scala/org/apache/spark/api/r/RBackendHandler.scala index 
14dac4ed28..6ce02e2ea3 100644 --- a/core/src/main/scala/org/apache/spark/api/r/RBackendHandler.scala +++ b/core/src/main/scala/org/apache/spark/api/r/RBackendHandler.scala @@ -182,6 +182,7 @@ private[r] class RBackendHandler(server: RBackend) if (parameterType.isPrimitive) { parameterWrapperType = parameterType match { case java.lang.Integer.TYPE => classOf[java.lang.Integer] + case java.lang.Long.TYPE => classOf[java.lang.Integer] case java.lang.Double.TYPE => classOf[java.lang.Double] case java.lang.Boolean.TYPE => classOf[java.lang.Boolean] case _ => parameterType |