author      cafreeman <cfreeman@alteryx.com>                    2015-04-17 13:42:19 -0700
committer   Shivaram Venkataraman <shivaram@cs.berkeley.edu>    2015-04-17 13:42:19 -0700
commit      59e206deb7346148412bbf5ba4ab626718fadf18 (patch)
tree        cf4435a81197e76957c4afdcc48686a6e46dc5dc
parent      a83571acc938582865efb41645aa1e414f339e46 (diff)
download    spark-59e206deb7346148412bbf5ba4ab626718fadf18.tar.gz
            spark-59e206deb7346148412bbf5ba4ab626718fadf18.tar.bz2
            spark-59e206deb7346148412bbf5ba4ab626718fadf18.zip
[SPARK-6807] [SparkR] Merge recent SparkR-pkg changes
This PR pulls in recent changes in SparkR-pkg, including cartesian, intersection, sampleByKey, subtract, subtractByKey, except, and some API for StructType and StructField.

Author: cafreeman <cfreeman@alteryx.com>
Author: Davies Liu <davies@databricks.com>
Author: Zongheng Yang <zongheng.y@gmail.com>
Author: Shivaram Venkataraman <shivaram.venkataraman@gmail.com>
Author: Shivaram Venkataraman <shivaram@cs.berkeley.edu>
Author: Sun Rui <rui.sun@intel.com>

Closes #5436 from davies/R3 and squashes the following commits:

c2b09be [Davies Liu] SQLTypes -> schema
a5a02f2 [Davies Liu] Merge branch 'master' of github.com:apache/spark into R3
168b7fe [Davies Liu] sort generics
b1fe460 [Davies Liu] fix conflict in README.md
e74c04e [Davies Liu] fix schema.R
4f5ac09 [Davies Liu] Merge branch 'master' of github.com:apache/spark into R5
41f8184 [Davies Liu] rm man
ae78312 [Davies Liu] Merge pull request #237 from sun-rui/SPARKR-154_3
1bdcb63 [Zongheng Yang] Updates to README.md.
5a553e7 [cafreeman] Use object attribute instead of argument
71372d9 [cafreeman] Update docs and examples
8526d2e71 [cafreeman] Remove `tojson` functions
6ef5f2d [cafreeman] Fix spacing
7741d66 [cafreeman] Rename the SQL DataType function
141efd8 [Shivaram Venkataraman] Merge pull request #245 from hqzizania/upstream
9387402 [Davies Liu] fix style
40199eb [Shivaram Venkataraman] Move except into sorted position
07d0dbc [Sun Rui] [SPARKR-244] Fix test failure after integration of subtract() and subtractByKey() for RDD.
7e8caa3 [Shivaram Venkataraman] Merge pull request #246 from hlin09/fixCombineByKey
ed66c81 [cafreeman] Update `subtract` to work with `generics.R`
f3ba785 [cafreeman] Fixed duplicate export
275deb4 [cafreeman] Update `NAMESPACE` and tests
1a3b63d [cafreeman] new version of `CreateDF`
836c4bf [cafreeman] Update `createDataFrame` and `toDF`
be5d5c1 [cafreeman] refactor schema functions
40338a4 [Zongheng Yang] Merge pull request #244 from sun-rui/SPARKR-154_5
20b97a6 [Zongheng Yang] Merge pull request #234 from hqzizania/assist
ba54e34 [Shivaram Venkataraman] Merge pull request #238 from sun-rui/SPARKR-154_4
c9497a3 [Shivaram Venkataraman] Merge pull request #208 from lythesia/master
b317aa7 [Zongheng Yang] Merge pull request #243 from hqzizania/master
136a07e [Zongheng Yang] Merge pull request #242 from hqzizania/stats
cd66603 [cafreeman] new line at EOF
8b76e81 [Shivaram Venkataraman] Merge pull request #233 from redbaron/fail-early-on-missing-dep
7dd81b7 [cafreeman] Documentation
0e2a94f [cafreeman] Define functions for schema and fields
-rw-r--r--  R/pkg/DESCRIPTION                                                  |   2
-rw-r--r--  R/pkg/NAMESPACE                                                    |  20
-rw-r--r--  R/pkg/R/DataFrame.R                                                |  18
-rw-r--r--  R/pkg/R/RDD.R                                                      | 205
-rw-r--r--  R/pkg/R/SQLContext.R                                               |  44
-rw-r--r--  R/pkg/R/SQLTypes.R                                                 |  64
-rw-r--r--  R/pkg/R/column.R                                                   |   2
-rw-r--r--  R/pkg/R/generics.R                                                 |  46
-rw-r--r--  R/pkg/R/group.R                                                    |   2
-rw-r--r--  R/pkg/R/pairRDD.R                                                  | 192
-rw-r--r--  R/pkg/R/schema.R                                                   | 162
-rw-r--r--  R/pkg/R/serialize.R                                                |   9
-rw-r--r--  R/pkg/R/utils.R                                                    |  80
-rw-r--r--  R/pkg/inst/tests/test_rdd.R                                        | 193
-rw-r--r--  R/pkg/inst/tests/test_shuffle.R                                    |  12
-rw-r--r--  R/pkg/inst/tests/test_sparkSQL.R                                   |  35
-rw-r--r--  R/pkg/inst/worker/worker.R                                         |  59
-rw-r--r--  core/src/main/scala/org/apache/spark/api/r/RRDD.scala              | 131
-rw-r--r--  core/src/main/scala/org/apache/spark/api/r/SerDe.scala             |  14
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/api/r/SQLUtils.scala  |  32
20 files changed, 971 insertions, 351 deletions
diff --git a/R/pkg/DESCRIPTION b/R/pkg/DESCRIPTION
index 052f68c6c2..1c1779a763 100644
--- a/R/pkg/DESCRIPTION
+++ b/R/pkg/DESCRIPTION
@@ -19,7 +19,7 @@ Collate:
'jobj.R'
'RDD.R'
'pairRDD.R'
- 'SQLTypes.R'
+ 'schema.R'
'column.R'
'group.R'
'DataFrame.R'
diff --git a/R/pkg/NAMESPACE b/R/pkg/NAMESPACE
index a354cdce74..8028364386 100644
--- a/R/pkg/NAMESPACE
+++ b/R/pkg/NAMESPACE
@@ -5,6 +5,7 @@ exportMethods(
"aggregateByKey",
"aggregateRDD",
"cache",
+ "cartesian",
"checkpoint",
"coalesce",
"cogroup",
@@ -28,6 +29,7 @@ exportMethods(
"fullOuterJoin",
"glom",
"groupByKey",
+ "intersection",
"join",
"keyBy",
"keys",
@@ -52,11 +54,14 @@ exportMethods(
"reduceByKeyLocally",
"repartition",
"rightOuterJoin",
+ "sampleByKey",
"sampleRDD",
"saveAsTextFile",
"saveAsObjectFile",
"sortBy",
"sortByKey",
+ "subtract",
+ "subtractByKey",
"sumRDD",
"take",
"takeOrdered",
@@ -95,6 +100,7 @@ exportClasses("DataFrame")
exportMethods("columns",
"distinct",
"dtypes",
+ "except",
"explain",
"filter",
"groupBy",
@@ -118,7 +124,6 @@ exportMethods("columns",
"show",
"showDF",
"sortDF",
- "subtract",
"toJSON",
"toRDD",
"unionAll",
@@ -178,5 +183,14 @@ export("cacheTable",
"toDF",
"uncacheTable")
-export("print.structType",
- "print.structField")
+export("sparkRSQL.init",
+ "sparkRHive.init")
+
+export("structField",
+ "structField.jobj",
+ "structField.character",
+ "print.structField",
+ "structType",
+ "structType.jobj",
+ "structType.structField",
+ "print.structType")
diff --git a/R/pkg/R/DataFrame.R b/R/pkg/R/DataFrame.R
index 044fdb4d01..861fe1c78b 100644
--- a/R/pkg/R/DataFrame.R
+++ b/R/pkg/R/DataFrame.R
@@ -17,7 +17,7 @@
# DataFrame.R - DataFrame class and methods implemented in S4 OO classes
-#' @include generics.R jobj.R SQLTypes.R RDD.R pairRDD.R column.R group.R
+#' @include generics.R jobj.R schema.R RDD.R pairRDD.R column.R group.R
NULL
setOldClass("jobj")
@@ -1141,15 +1141,15 @@ setMethod("intersect",
dataFrame(intersected)
})
-#' Subtract
+#' except
#'
#' Return a new DataFrame containing rows in this DataFrame
#' but not in another DataFrame. This is equivalent to `EXCEPT` in SQL.
#'
#' @param x A Spark DataFrame
#' @param y A Spark DataFrame
-#' @return A DataFrame containing the result of the subtract operation.
-#' @rdname subtract
+#' @return A DataFrame containing the result of the except operation.
+#' @rdname except
#' @export
#' @examples
#'\dontrun{
@@ -1157,13 +1157,15 @@ setMethod("intersect",
#' sqlCtx <- sparkRSQL.init(sc)
#' df1 <- jsonFile(sqlCtx, path)
#' df2 <- jsonFile(sqlCtx, path2)
-#' subtractDF <- subtract(df, df2)
+#' exceptDF <- except(df, df2)
#' }
-setMethod("subtract",
+#' @rdname except
+#' @export
+setMethod("except",
signature(x = "DataFrame", y = "DataFrame"),
function(x, y) {
- subtracted <- callJMethod(x@sdf, "except", y@sdf)
- dataFrame(subtracted)
+ excepted <- callJMethod(x@sdf, "except", y@sdf)
+ dataFrame(excepted)
})
#' Save the contents of the DataFrame to a data source
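[Editor's note: the DataFrame-level rename above replaces subtract with except, matching SQL's EXCEPT. A minimal usage sketch based on the roxygen example in this hunk; `path` and `path2` are placeholders for existing JSON files.]

  sc <- sparkR.init()
  sqlCtx <- sparkRSQL.init(sc)
  df1 <- jsonFile(sqlCtx, path)    # path / path2 are placeholder file paths
  df2 <- jsonFile(sqlCtx, path2)
  exceptDF <- except(df1, df2)     # rows of df1 that do not appear in df2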
diff --git a/R/pkg/R/RDD.R b/R/pkg/R/RDD.R
index 820027ef67..128431334c 100644
--- a/R/pkg/R/RDD.R
+++ b/R/pkg/R/RDD.R
@@ -730,6 +730,7 @@ setMethod("take",
index <- -1
jrdd <- getJRDD(x)
numPartitions <- numPartitions(x)
+ serializedModeRDD <- getSerializedMode(x)
# TODO(shivaram): Collect more than one partition based on size
# estimates similar to the scala version of `take`.
@@ -748,13 +749,14 @@ setMethod("take",
elems <- convertJListToRList(partition,
flatten = TRUE,
logicalUpperBound = size,
- serializedMode = getSerializedMode(x))
- # TODO: Check if this append is O(n^2)?
+ serializedMode = serializedModeRDD)
+
resList <- append(resList, elems)
}
resList
})
+
#' First
#'
#' Return the first element of an RDD
@@ -1092,21 +1094,42 @@ takeOrderedElem <- function(x, num, ascending = TRUE) {
if (num < length(part)) {
# R limitation: order works only on primitive types!
ord <- order(unlist(part, recursive = FALSE), decreasing = !ascending)
- list(part[ord[1:num]])
+ part[ord[1:num]]
} else {
- list(part)
+ part
}
}
- reduceFunc <- function(elems, part) {
- newElems <- append(elems, part)
- # R limitation: order works only on primitive types!
- ord <- order(unlist(newElems, recursive = FALSE), decreasing = !ascending)
- newElems[ord[1:num]]
- }
-
newRdd <- mapPartitions(x, partitionFunc)
- reduce(newRdd, reduceFunc)
+
+ resList <- list()
+ index <- -1
+ jrdd <- getJRDD(newRdd)
+ numPartitions <- numPartitions(newRdd)
+ serializedModeRDD <- getSerializedMode(newRdd)
+
+ while (TRUE) {
+ index <- index + 1
+
+ if (index >= numPartitions) {
+ ord <- order(unlist(resList, recursive = FALSE), decreasing = !ascending)
+ resList <- resList[ord[1:num]]
+ break
+ }
+
+ # a JList of byte arrays
+ partitionArr <- callJMethod(jrdd, "collectPartitions", as.list(as.integer(index)))
+ partition <- partitionArr[[1]]
+
+ # elems is capped to have at most `num` elements
+ elems <- convertJListToRList(partition,
+ flatten = TRUE,
+ logicalUpperBound = num,
+ serializedMode = serializedModeRDD)
+
+ resList <- append(resList, elems)
+ }
+ resList
}
#' Returns the first N elements from an RDD in ascending order.
@@ -1465,67 +1488,105 @@ setMethod("zipRDD",
stop("Can only zip RDDs which have the same number of partitions.")
}
- if (getSerializedMode(x) != getSerializedMode(other) ||
- getSerializedMode(x) == "byte") {
- # Append the number of elements in each partition to that partition so that we can later
- # check if corresponding partitions of both RDDs have the same number of elements.
- #
- # Note that this appending also serves the purpose of reserialization, because even if
- # any RDD is serialized, we need to reserialize it to make sure its partitions are encoded
- # as a single byte array. For example, partitions of an RDD generated from partitionBy()
- # may be encoded as multiple byte arrays.
- appendLength <- function(part) {
- part[[length(part) + 1]] <- length(part) + 1
- part
- }
- x <- lapplyPartition(x, appendLength)
- other <- lapplyPartition(other, appendLength)
- }
+ rdds <- appendPartitionLengths(x, other)
+ jrdd <- callJMethod(getJRDD(rdds[[1]]), "zip", getJRDD(rdds[[2]]))
+ # The jrdd's elements are of scala Tuple2 type. The serialized
+ # flag here is used for the elements inside the tuples.
+ rdd <- RDD(jrdd, getSerializedMode(rdds[[1]]))
- zippedJRDD <- callJMethod(getJRDD(x), "zip", getJRDD(other))
- # The zippedRDD's elements are of scala Tuple2 type. The serialized
- # flag Here is used for the elements inside the tuples.
- serializerMode <- getSerializedMode(x)
- zippedRDD <- RDD(zippedJRDD, serializerMode)
+ mergePartitions(rdd, TRUE)
+ })
+
+#' Cartesian product of this RDD and another one.
+#'
+#' Return the Cartesian product of this RDD and another one,
+#' that is, the RDD of all pairs of elements (a, b) where a
+#' is in this and b is in other.
+#'
+#' @param x An RDD.
+#' @param other An RDD.
+#' @return A new RDD which is the Cartesian product of these two RDDs.
+#' @examples
+#'\dontrun{
+#' sc <- sparkR.init()
+#' rdd <- parallelize(sc, 1:2)
+#' sortByKey(cartesian(rdd, rdd))
+#' # list(list(1, 1), list(1, 2), list(2, 1), list(2, 2))
+#'}
+#' @rdname cartesian
+#' @aliases cartesian,RDD,RDD-method
+setMethod("cartesian",
+ signature(x = "RDD", other = "RDD"),
+ function(x, other) {
+ rdds <- appendPartitionLengths(x, other)
+ jrdd <- callJMethod(getJRDD(rdds[[1]]), "cartesian", getJRDD(rdds[[2]]))
+ # The jrdd's elements are of scala Tuple2 type. The serialized
+ # flag here is used for the elements inside the tuples.
+ rdd <- RDD(jrdd, getSerializedMode(rdds[[1]]))
- partitionFunc <- function(split, part) {
- len <- length(part)
- if (len > 0) {
- if (serializerMode == "byte") {
- lengthOfValues <- part[[len]]
- lengthOfKeys <- part[[len - lengthOfValues]]
- stopifnot(len == lengthOfKeys + lengthOfValues)
-
- # check if corresponding partitions of both RDDs have the same number of elements.
- if (lengthOfKeys != lengthOfValues) {
- stop("Can only zip RDDs with same number of elements in each pair of corresponding partitions.")
- }
-
- if (lengthOfKeys > 1) {
- keys <- part[1 : (lengthOfKeys - 1)]
- values <- part[(lengthOfKeys + 1) : (len - 1)]
- } else {
- keys <- list()
- values <- list()
- }
- } else {
- # Keys, values must have same length here, because this has
- # been validated inside the JavaRDD.zip() function.
- keys <- part[c(TRUE, FALSE)]
- values <- part[c(FALSE, TRUE)]
- }
- mapply(
- function(k, v) {
- list(k, v)
- },
- keys,
- values,
- SIMPLIFY = FALSE,
- USE.NAMES = FALSE)
- } else {
- part
- }
+ mergePartitions(rdd, FALSE)
+ })
+
+#' Subtract an RDD with another RDD.
+#'
+#' Return an RDD with the elements from this that are not in other.
+#'
+#' @param x An RDD.
+#' @param other An RDD.
+#' @param numPartitions Number of the partitions in the result RDD.
+#' @return An RDD with the elements from this that are not in other.
+#' @examples
+#'\dontrun{
+#' sc <- sparkR.init()
+#' rdd1 <- parallelize(sc, list(1, 1, 2, 2, 3, 4))
+#' rdd2 <- parallelize(sc, list(2, 4))
+#' collect(subtract(rdd1, rdd2))
+#' # list(1, 1, 3)
+#'}
+#' @rdname subtract
+#' @aliases subtract,RDD
+setMethod("subtract",
+ signature(x = "RDD", other = "RDD"),
+ function(x, other, numPartitions = SparkR::numPartitions(x)) {
+ mapFunction <- function(e) { list(e, NA) }
+ rdd1 <- map(x, mapFunction)
+ rdd2 <- map(other, mapFunction)
+ keys(subtractByKey(rdd1, rdd2, numPartitions))
+ })
+
+#' Intersection of this RDD and another one.
+#'
+#' Return the intersection of this RDD and another one.
+#' The output will not contain any duplicate elements,
+#' even if the input RDDs did. Performs a hash partition
+#' across the cluster.
+#' Note that this method performs a shuffle internally.
+#'
+#' @param x An RDD.
+#' @param other An RDD.
+#' @param numPartitions The number of partitions in the result RDD.
+#' @return An RDD which is the intersection of these two RDDs.
+#' @examples
+#'\dontrun{
+#' sc <- sparkR.init()
+#' rdd1 <- parallelize(sc, list(1, 10, 2, 3, 4, 5))
+#' rdd2 <- parallelize(sc, list(1, 6, 2, 3, 7, 8))
+#' collect(sortBy(intersection(rdd1, rdd2), function(x) { x }))
+#' # list(1, 2, 3)
+#'}
+#' @rdname intersection
+#' @aliases intersection,RDD
+setMethod("intersection",
+ signature(x = "RDD", other = "RDD"),
+ function(x, other, numPartitions = SparkR::numPartitions(x)) {
+ rdd1 <- map(x, function(v) { list(v, NA) })
+ rdd2 <- map(other, function(v) { list(v, NA) })
+
+ filterFunction <- function(elem) {
+ iters <- elem[[2]]
+ all(as.vector(
+ lapply(iters, function(iter) { length(iter) > 0 }), mode = "logical"))
}
-
- PipelinedRDD(zippedRDD, partitionFunc)
+
+ keys(filterRDD(cogroup(rdd1, rdd2, numPartitions = numPartitions), filterFunction))
})
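[Editor's note: a condensed sketch of the three new RDD operations added above (cartesian, subtract, intersection), drawn from their roxygen examples; an initialized SparkContext `sc` is assumed.]

  rdd1 <- parallelize(sc, list(1, 1, 2, 2, 3, 4))
  rdd2 <- parallelize(sc, list(2, 4))
  collect(cartesian(rdd2, rdd2))     # all pairs: list(2, 2), list(2, 4), list(4, 2), list(4, 4), in some order
  collect(subtract(rdd1, rdd2))      # list(1, 1, 3)
  collect(intersection(rdd1, rdd2))  # 2 and 4, duplicates removed, order not guaranteed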
diff --git a/R/pkg/R/SQLContext.R b/R/pkg/R/SQLContext.R
index 930ada22f4..4f05ba524a 100644
--- a/R/pkg/R/SQLContext.R
+++ b/R/pkg/R/SQLContext.R
@@ -54,9 +54,9 @@ infer_type <- function(x) {
# StructType
types <- lapply(x, infer_type)
fields <- lapply(1:length(x), function(i) {
- list(name = names[[i]], type = types[[i]], nullable = TRUE)
+ structField(names[[i]], types[[i]], TRUE)
})
- list(type = "struct", fields = fields)
+ do.call(structType, fields)
}
} else if (length(x) > 1) {
list(type = "array", elementType = type, containsNull = TRUE)
@@ -65,30 +65,6 @@ infer_type <- function(x) {
}
}
-#' dump the schema into JSON string
-tojson <- function(x) {
- if (is.list(x)) {
- names <- names(x)
- if (!is.null(names)) {
- items <- lapply(names, function(n) {
- safe_n <- gsub('"', '\\"', n)
- paste(tojson(safe_n), ':', tojson(x[[n]]), sep = '')
- })
- d <- paste(items, collapse = ', ')
- paste('{', d, '}', sep = '')
- } else {
- l <- paste(lapply(x, tojson), collapse = ', ')
- paste('[', l, ']', sep = '')
- }
- } else if (is.character(x)) {
- paste('"', x, '"', sep = '')
- } else if (is.logical(x)) {
- if (x) "true" else "false"
- } else {
- stop(paste("unexpected type:", class(x)))
- }
-}
-
#' Create a DataFrame from an RDD
#'
#' Converts an RDD to a DataFrame by infer the types.
@@ -134,7 +110,7 @@ createDataFrame <- function(sqlCtx, data, schema = NULL, samplingRatio = 1.0) {
stop(paste("unexpected type:", class(data)))
}
- if (is.null(schema) || is.null(names(schema))) {
+ if (is.null(schema) || (!inherits(schema, "structType") && is.null(names(schema)))) {
row <- first(rdd)
names <- if (is.null(schema)) {
names(row)
@@ -143,7 +119,7 @@ createDataFrame <- function(sqlCtx, data, schema = NULL, samplingRatio = 1.0) {
}
if (is.null(names)) {
names <- lapply(1:length(row), function(x) {
- paste("_", as.character(x), sep = "")
+ paste("_", as.character(x), sep = "")
})
}
@@ -159,20 +135,18 @@ createDataFrame <- function(sqlCtx, data, schema = NULL, samplingRatio = 1.0) {
types <- lapply(row, infer_type)
fields <- lapply(1:length(row), function(i) {
- list(name = names[[i]], type = types[[i]], nullable = TRUE)
+ structField(names[[i]], types[[i]], TRUE)
})
- schema <- list(type = "struct", fields = fields)
+ schema <- do.call(structType, fields)
}
- stopifnot(class(schema) == "list")
- stopifnot(schema$type == "struct")
- stopifnot(class(schema$fields) == "list")
- schemaString <- tojson(schema)
+ stopifnot(class(schema) == "structType")
+ # schemaString <- tojson(schema)
jrdd <- getJRDD(lapply(rdd, function(x) x), "row")
srdd <- callJMethod(jrdd, "rdd")
sdf <- callJStatic("org.apache.spark.sql.api.r.SQLUtils", "createDF",
- srdd, schemaString, sqlCtx)
+ srdd, schema$jobj, sqlCtx)
dataFrame(sdf)
}
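[Editor's note: with tojson removed, createDataFrame now requires the schema to be a structType and passes its Java object (schema$jobj) straight to SQLUtils.createDF. A sketch mirroring the updated tests, with `sc` and `sqlCtx` assumed initialized.]

  rdd <- lapply(parallelize(sc, 1:10), function(x) { list(x, as.character(x)) })
  schema <- structType(structField("a", "integer"), structField("b", "string"))
  df <- createDataFrame(sqlCtx, rdd, schema)
  columns(df)   # c("a", "b")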
diff --git a/R/pkg/R/SQLTypes.R b/R/pkg/R/SQLTypes.R
deleted file mode 100644
index 962fba5b3c..0000000000
--- a/R/pkg/R/SQLTypes.R
+++ /dev/null
@@ -1,64 +0,0 @@
-#
-# Licensed to the Apache Software Foundation (ASF) under one or more
-# contributor license agreements. See the NOTICE file distributed with
-# this work for additional information regarding copyright ownership.
-# The ASF licenses this file to You under the Apache License, Version 2.0
-# (the "License"); you may not use this file except in compliance with
-# the License. You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-#
-
-# Utility functions for handling SparkSQL DataTypes.
-
-# Handler for StructType
-structType <- function(st) {
- obj <- structure(new.env(parent = emptyenv()), class = "structType")
- obj$jobj <- st
- obj$fields <- function() { lapply(callJMethod(st, "fields"), structField) }
- obj
-}
-
-#' Print a Spark StructType.
-#'
-#' This function prints the contents of a StructType returned from the
-#' SparkR JVM backend.
-#'
-#' @param x A StructType object
-#' @param ... further arguments passed to or from other methods
-print.structType <- function(x, ...) {
- fieldsList <- lapply(x$fields(), function(i) { i$print() })
- print(fieldsList)
-}
-
-# Handler for StructField
-structField <- function(sf) {
- obj <- structure(new.env(parent = emptyenv()), class = "structField")
- obj$jobj <- sf
- obj$name <- function() { callJMethod(sf, "name") }
- obj$dataType <- function() { callJMethod(sf, "dataType") }
- obj$dataType.toString <- function() { callJMethod(obj$dataType(), "toString") }
- obj$dataType.simpleString <- function() { callJMethod(obj$dataType(), "simpleString") }
- obj$nullable <- function() { callJMethod(sf, "nullable") }
- obj$print <- function() { paste("StructField(",
- paste(obj$name(), obj$dataType.toString(), obj$nullable(), sep = ", "),
- ")", sep = "") }
- obj
-}
-
-#' Print a Spark StructField.
-#'
-#' This function prints the contents of a StructField returned from the
-#' SparkR JVM backend.
-#'
-#' @param x A StructField object
-#' @param ... further arguments passed to or from other methods
-print.structField <- function(x, ...) {
- cat(x$print())
-}
diff --git a/R/pkg/R/column.R b/R/pkg/R/column.R
index b282001d8b..95fb9ff088 100644
--- a/R/pkg/R/column.R
+++ b/R/pkg/R/column.R
@@ -17,7 +17,7 @@
# Column Class
-#' @include generics.R jobj.R SQLTypes.R
+#' @include generics.R jobj.R schema.R
NULL
setOldClass("jobj")
diff --git a/R/pkg/R/generics.R b/R/pkg/R/generics.R
index 5fb1ccaa84..6c62333901 100644
--- a/R/pkg/R/generics.R
+++ b/R/pkg/R/generics.R
@@ -230,6 +230,10 @@ setGeneric("zipWithUniqueId", function(x) { standardGeneric("zipWithUniqueId") }
############ Binary Functions #############
+#' @rdname cartesian
+#' @export
+setGeneric("cartesian", function(x, other) { standardGeneric("cartesian") })
+
#' @rdname countByKey
#' @export
setGeneric("countByKey", function(x) { standardGeneric("countByKey") })
@@ -238,6 +242,11 @@ setGeneric("countByKey", function(x) { standardGeneric("countByKey") })
#' @export
setGeneric("flatMapValues", function(X, FUN) { standardGeneric("flatMapValues") })
+#' @rdname intersection
+#' @export
+setGeneric("intersection", function(x, other, numPartitions = 1L) {
+ standardGeneric("intersection") })
+
#' @rdname keys
#' @export
setGeneric("keys", function(x) { standardGeneric("keys") })
@@ -250,12 +259,18 @@ setGeneric("lookup", function(x, key) { standardGeneric("lookup") })
#' @export
setGeneric("mapValues", function(X, FUN) { standardGeneric("mapValues") })
+#' @rdname sampleByKey
+#' @export
+setGeneric("sampleByKey",
+ function(x, withReplacement, fractions, seed) {
+ standardGeneric("sampleByKey")
+ })
+
#' @rdname values
#' @export
setGeneric("values", function(x) { standardGeneric("values") })
-
############ Shuffle Functions ############
#' @rdname aggregateByKey
@@ -330,9 +345,24 @@ setGeneric("rightOuterJoin", function(x, y, numPartitions) { standardGeneric("ri
#' @rdname sortByKey
#' @export
-setGeneric("sortByKey", function(x, ascending = TRUE, numPartitions = 1L) {
- standardGeneric("sortByKey")
-})
+setGeneric("sortByKey",
+ function(x, ascending = TRUE, numPartitions = 1L) {
+ standardGeneric("sortByKey")
+ })
+
+#' @rdname subtract
+#' @export
+setGeneric("subtract",
+ function(x, other, numPartitions = 1L) {
+ standardGeneric("subtract")
+ })
+
+#' @rdname subtractByKey
+#' @export
+setGeneric("subtractByKey",
+ function(x, other, numPartitions = 1L) {
+ standardGeneric("subtractByKey")
+ })
################### Broadcast Variable Methods #################
@@ -357,6 +387,10 @@ setGeneric("dtypes", function(x) { standardGeneric("dtypes") })
#' @export
setGeneric("explain", function(x, ...) { standardGeneric("explain") })
+#' @rdname except
+#' @export
+setGeneric("except", function(x, y) { standardGeneric("except") })
+
#' @rdname filter
#' @export
setGeneric("filter", function(x, condition) { standardGeneric("filter") })
@@ -434,10 +468,6 @@ setGeneric("showDF", function(x,...) { standardGeneric("showDF") })
#' @export
setGeneric("sortDF", function(x, col, ...) { standardGeneric("sortDF") })
-#' @rdname subtract
-#' @export
-setGeneric("subtract", function(x, y) { standardGeneric("subtract") })
-
#' @rdname tojson
#' @export
setGeneric("toJSON", function(x) { standardGeneric("toJSON") })
diff --git a/R/pkg/R/group.R b/R/pkg/R/group.R
index 855fbdfc7c..02237b3672 100644
--- a/R/pkg/R/group.R
+++ b/R/pkg/R/group.R
@@ -17,7 +17,7 @@
# group.R - GroupedData class and methods implemented in S4 OO classes
-#' @include generics.R jobj.R SQLTypes.R column.R
+#' @include generics.R jobj.R schema.R column.R
NULL
setOldClass("jobj")
diff --git a/R/pkg/R/pairRDD.R b/R/pkg/R/pairRDD.R
index 5d64822859..13efebc11c 100644
--- a/R/pkg/R/pairRDD.R
+++ b/R/pkg/R/pairRDD.R
@@ -430,7 +430,7 @@ setMethod("combineByKey",
pred <- function(item) exists(item$hash, keys)
lapply(part,
function(item) {
- item$hash <- as.character(item[[1]])
+ item$hash <- as.character(hashCode(item[[1]]))
updateOrCreatePair(item, keys, combiners, pred, mergeValue, createCombiner)
})
convertEnvsToList(keys, combiners)
@@ -443,7 +443,7 @@ setMethod("combineByKey",
pred <- function(item) exists(item$hash, keys)
lapply(part,
function(item) {
- item$hash <- as.character(item[[1]])
+ item$hash <- as.character(hashCode(item[[1]]))
updateOrCreatePair(item, keys, combiners, pred, mergeCombiners, identity)
})
convertEnvsToList(keys, combiners)
@@ -452,19 +452,19 @@ setMethod("combineByKey",
})
#' Aggregate a pair RDD by each key.
-#'
+#'
#' Aggregate the values of each key in an RDD, using given combine functions
#' and a neutral "zero value". This function can return a different result type,
#' U, than the type of the values in this RDD, V. Thus, we need one operation
-#' for merging a V into a U and one operation for merging two U's, The former
-#' operation is used for merging values within a partition, and the latter is
-#' used for merging values between partitions. To avoid memory allocation, both
-#' of these functions are allowed to modify and return their first argument
+#' for merging a V into a U and one operation for merging two U's, The former
+#' operation is used for merging values within a partition, and the latter is
+#' used for merging values between partitions. To avoid memory allocation, both
+#' of these functions are allowed to modify and return their first argument
#' instead of creating a new U.
-#'
+#'
#' @param x An RDD.
#' @param zeroValue A neutral "zero value".
-#' @param seqOp A function to aggregate the values of each key. It may return
+#' @param seqOp A function to aggregate the values of each key. It may return
#' a different result type from the type of the values.
#' @param combOp A function to aggregate results of seqOp.
#' @return An RDD containing the aggregation result.
@@ -476,7 +476,7 @@ setMethod("combineByKey",
#' zeroValue <- list(0, 0)
#' seqOp <- function(x, y) { list(x[[1]] + y, x[[2]] + 1) }
#' combOp <- function(x, y) { list(x[[1]] + y[[1]], x[[2]] + y[[2]]) }
-#' aggregateByKey(rdd, zeroValue, seqOp, combOp, 2L)
+#' aggregateByKey(rdd, zeroValue, seqOp, combOp, 2L)
#' # list(list(1, list(3, 2)), list(2, list(7, 2)))
#'}
#' @rdname aggregateByKey
@@ -493,12 +493,12 @@ setMethod("aggregateByKey",
})
#' Fold a pair RDD by each key.
-#'
+#'
#' Aggregate the values of each key in an RDD, using an associative function "func"
-#' and a neutral "zero value" which may be added to the result an arbitrary
-#' number of times, and must not change the result (e.g., 0 for addition, or
+#' and a neutral "zero value" which may be added to the result an arbitrary
+#' number of times, and must not change the result (e.g., 0 for addition, or
#' 1 for multiplication.).
-#'
+#'
#' @param x An RDD.
#' @param zeroValue A neutral "zero value".
#' @param func An associative function for folding values of each key.
@@ -548,11 +548,11 @@ setMethod("join",
function(x, y, numPartitions) {
xTagged <- lapply(x, function(i) { list(i[[1]], list(1L, i[[2]])) })
yTagged <- lapply(y, function(i) { list(i[[1]], list(2L, i[[2]])) })
-
+
doJoin <- function(v) {
joinTaggedList(v, list(FALSE, FALSE))
}
-
+
joined <- flatMapValues(groupByKey(unionRDD(xTagged, yTagged), numToInt(numPartitions)),
doJoin)
})
@@ -568,8 +568,8 @@ setMethod("join",
#' @param y An RDD to be joined. Should be an RDD where each element is
#' list(K, V).
#' @param numPartitions Number of partitions to create.
-#' @return For each element (k, v) in x, the resulting RDD will either contain
-#' all pairs (k, (v, w)) for (k, w) in rdd2, or the pair (k, (v, NULL))
+#' @return For each element (k, v) in x, the resulting RDD will either contain
+#' all pairs (k, (v, w)) for (k, w) in rdd2, or the pair (k, (v, NULL))
#' if no elements in rdd2 have key k.
#' @examples
#'\dontrun{
@@ -586,11 +586,11 @@ setMethod("leftOuterJoin",
function(x, y, numPartitions) {
xTagged <- lapply(x, function(i) { list(i[[1]], list(1L, i[[2]])) })
yTagged <- lapply(y, function(i) { list(i[[1]], list(2L, i[[2]])) })
-
+
doJoin <- function(v) {
joinTaggedList(v, list(FALSE, TRUE))
}
-
+
joined <- flatMapValues(groupByKey(unionRDD(xTagged, yTagged), numPartitions), doJoin)
})
@@ -623,18 +623,18 @@ setMethod("rightOuterJoin",
function(x, y, numPartitions) {
xTagged <- lapply(x, function(i) { list(i[[1]], list(1L, i[[2]])) })
yTagged <- lapply(y, function(i) { list(i[[1]], list(2L, i[[2]])) })
-
+
doJoin <- function(v) {
joinTaggedList(v, list(TRUE, FALSE))
}
-
+
joined <- flatMapValues(groupByKey(unionRDD(xTagged, yTagged), numPartitions), doJoin)
})
#' Full outer join two RDDs
#'
#' @description
-#' \code{fullouterjoin} This function full-outer-joins two RDDs where every element is of the form list(K, V).
+#' \code{fullouterjoin} This function full-outer-joins two RDDs where every element is of the form list(K, V).
#' The key types of the two RDDs should be the same.
#'
#' @param x An RDD to be joined. Should be an RDD where each element is
@@ -644,7 +644,7 @@ setMethod("rightOuterJoin",
#' @param numPartitions Number of partitions to create.
#' @return For each element (k, v) in x and (k, w) in y, the resulting RDD
#' will contain all pairs (k, (v, w)) for both (k, v) in x and
-#' (k, w) in y, or the pair (k, (NULL, w))/(k, (v, NULL)) if no elements
+#' (k, w) in y, or the pair (k, (NULL, w))/(k, (v, NULL)) if no elements
#' in x/y have key k.
#' @examples
#'\dontrun{
@@ -683,7 +683,7 @@ setMethod("fullOuterJoin",
#' sc <- sparkR.init()
#' rdd1 <- parallelize(sc, list(list(1, 1), list(2, 4)))
#' rdd2 <- parallelize(sc, list(list(1, 2), list(1, 3)))
-#' cogroup(rdd1, rdd2, numPartitions = 2L)
+#' cogroup(rdd1, rdd2, numPartitions = 2L)
#' # list(list(1, list(1, list(2, 3))), list(2, list(list(4), list()))
#'}
#' @rdname cogroup
@@ -694,7 +694,7 @@ setMethod("cogroup",
rdds <- list(...)
rddsLen <- length(rdds)
for (i in 1:rddsLen) {
- rdds[[i]] <- lapply(rdds[[i]],
+ rdds[[i]] <- lapply(rdds[[i]],
function(x) { list(x[[1]], list(i, x[[2]])) })
}
union.rdd <- Reduce(unionRDD, rdds)
@@ -719,7 +719,7 @@ setMethod("cogroup",
}
})
}
- cogroup.rdd <- mapValues(groupByKey(union.rdd, numPartitions),
+ cogroup.rdd <- mapValues(groupByKey(union.rdd, numPartitions),
group.func)
})
@@ -741,18 +741,18 @@ setMethod("sortByKey",
signature(x = "RDD"),
function(x, ascending = TRUE, numPartitions = SparkR::numPartitions(x)) {
rangeBounds <- list()
-
+
if (numPartitions > 1) {
rddSize <- count(x)
# constant from Spark's RangePartitioner
maxSampleSize <- numPartitions * 20
fraction <- min(maxSampleSize / max(rddSize, 1), 1.0)
-
+
samples <- collect(keys(sampleRDD(x, FALSE, fraction, 1L)))
-
+
# Note: the built-in R sort() function only works on atomic vectors
samples <- sort(unlist(samples, recursive = FALSE), decreasing = !ascending)
-
+
if (length(samples) > 0) {
rangeBounds <- lapply(seq_len(numPartitions - 1),
function(i) {
@@ -764,24 +764,146 @@ setMethod("sortByKey",
rangePartitionFunc <- function(key) {
partition <- 0
-
+
# TODO: Use binary search instead of linear search, similar with Spark
while (partition < length(rangeBounds) && key > rangeBounds[[partition + 1]]) {
partition <- partition + 1
}
-
+
if (ascending) {
partition
} else {
numPartitions - partition - 1
}
}
-
+
partitionFunc <- function(part) {
sortKeyValueList(part, decreasing = !ascending)
}
-
+
newRDD <- partitionBy(x, numPartitions, rangePartitionFunc)
lapplyPartition(newRDD, partitionFunc)
})
+#' Subtract a pair RDD with another pair RDD.
+#'
+#' Return an RDD with the pairs from x whose keys are not in other.
+#'
+#' @param x An RDD.
+#' @param other An RDD.
+#' @param numPartitions Number of the partitions in the result RDD.
+#' @return An RDD with the pairs from x whose keys are not in other.
+#' @examples
+#'\dontrun{
+#' sc <- sparkR.init()
+#' rdd1 <- parallelize(sc, list(list("a", 1), list("b", 4),
+#' list("b", 5), list("a", 2)))
+#' rdd2 <- parallelize(sc, list(list("a", 3), list("c", 1)))
+#' collect(subtractByKey(rdd1, rdd2))
+#' # list(list("b", 4), list("b", 5))
+#'}
+#' @rdname subtractByKey
+#' @aliases subtractByKey,RDD
+setMethod("subtractByKey",
+ signature(x = "RDD", other = "RDD"),
+ function(x, other, numPartitions = SparkR::numPartitions(x)) {
+ filterFunction <- function(elem) {
+ iters <- elem[[2]]
+ (length(iters[[1]]) > 0) && (length(iters[[2]]) == 0)
+ }
+
+ flatMapValues(filterRDD(cogroup(x,
+ other,
+ numPartitions = numPartitions),
+ filterFunction),
+ function (v) { v[[1]] })
+ })
+
+#' Return a subset of this RDD sampled by key.
+#'
+#' @description
+#' \code{sampleByKey} Create a sample of this RDD using variable sampling rates
+#' for different keys as specified by fractions, a key to sampling rate map.
+#'
+#' @param x The RDD to sample elements by key, where each element is
+#' list(K, V) or c(K, V).
+#' @param withReplacement Sampling with replacement or not
+#' @param fraction The (rough) sample target fraction
+#' @param seed Randomness seed value
+#' @examples
+#'\dontrun{
+#' sc <- sparkR.init()
+#' rdd <- parallelize(sc, 1:3000)
+#' pairs <- lapply(rdd, function(x) { if (x %% 3 == 0) list("a", x)
+#' else { if (x %% 3 == 1) list("b", x) else list("c", x) }})
+#' fractions <- list(a = 0.2, b = 0.1, c = 0.3)
+#' sample <- sampleByKey(pairs, FALSE, fractions, 1618L)
+#' 100 < length(lookup(sample, "a")) && 300 > length(lookup(sample, "a")) # TRUE
+#' 50 < length(lookup(sample, "b")) && 150 > length(lookup(sample, "b")) # TRUE
+#' 200 < length(lookup(sample, "c")) && 400 > length(lookup(sample, "c")) # TRUE
+#' lookup(sample, "a")[which.min(lookup(sample, "a"))] >= 0 # TRUE
+#' lookup(sample, "a")[which.max(lookup(sample, "a"))] <= 2000 # TRUE
+#' lookup(sample, "b")[which.min(lookup(sample, "b"))] >= 0 # TRUE
+#' lookup(sample, "b")[which.max(lookup(sample, "b"))] <= 2000 # TRUE
+#' lookup(sample, "c")[which.min(lookup(sample, "c"))] >= 0 # TRUE
+#' lookup(sample, "c")[which.max(lookup(sample, "c"))] <= 2000 # TRUE
+#' fractions <- list(a = 0.2, b = 0.1, c = 0.3, d = 0.4)
+#' sample <- sampleByKey(pairs, FALSE, fractions, 1618L) # Key "d" will be ignored
+#' fractions <- list(a = 0.2, b = 0.1)
+#' sample <- sampleByKey(pairs, FALSE, fractions, 1618L) # KeyError: "c"
+#'}
+#' @rdname sampleByKey
+#' @aliases sampleByKey,RDD-method
+setMethod("sampleByKey",
+ signature(x = "RDD", withReplacement = "logical",
+ fractions = "vector", seed = "integer"),
+ function(x, withReplacement, fractions, seed) {
+
+ for (elem in fractions) {
+ if (elem < 0.0) {
+ stop(paste("Negative fraction value ", fractions[which(fractions == elem)]))
+ }
+ }
+
+ # The sampler: takes a partition and returns its sampled version.
+ samplingFunc <- function(split, part) {
+ set.seed(bitwXor(seed, split))
+ res <- vector("list", length(part))
+ len <- 0
+
+ # mixing because the initial seeds are close to each other
+ runif(10)
+
+ for (elem in part) {
+ if (elem[[1]] %in% names(fractions)) {
+ frac <- as.numeric(fractions[which(elem[[1]] == names(fractions))])
+ if (withReplacement) {
+ count <- rpois(1, frac)
+ if (count > 0) {
+ res[(len + 1):(len + count)] <- rep(list(elem), count)
+ len <- len + count
+ }
+ } else {
+ if (runif(1) < frac) {
+ len <- len + 1
+ res[[len]] <- elem
+ }
+ }
+ } else {
+ stop("KeyError: \"", elem[[1]], "\"")
+ }
+ }
+
+ # TODO(zongheng): look into the performance of the current
+ # implementation. Look into some iterator package? Note that
+ # Scala avoids many calls to creating an empty list and PySpark
+ # similarly achieves this using `yield'. (duplicated from sampleRDD)
+ if (len > 0) {
+ res[1:len]
+ } else {
+ list()
+ }
+ }
+
+ lapplyPartitionsWithIndex(x, samplingFunc)
+ })
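[Editor's note: a condensed sketch of the two pair-RDD additions above, following their roxygen examples and the new tests; `sc` is assumed initialized.]

  rdd1 <- parallelize(sc, list(list("a", 1), list("b", 4), list("b", 5), list("a", 2)))
  rdd2 <- parallelize(sc, list(list("a", 3), list("c", 1)))
  collect(subtractByKey(rdd1, rdd2))   # list(list("b", 4), list("b", 5))

  pairs <- lapply(parallelize(sc, 1:2000),
                  function(x) { if (x %% 2 == 0) list("a", x) else list("b", x) })
  fractions <- list(a = 0.2, b = 0.1)
  sample <- sampleByKey(pairs, FALSE, fractions, 1618L)   # keeps roughly 20% of "a" pairs and 10% of "b" pairs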
diff --git a/R/pkg/R/schema.R b/R/pkg/R/schema.R
new file mode 100644
index 0000000000..e442119086
--- /dev/null
+++ b/R/pkg/R/schema.R
@@ -0,0 +1,162 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements. See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+# A set of S3 classes and methods that support the SparkSQL `StructType` and `StructField
+# datatypes. These are used to create and interact with DataFrame schemas.
+
+#' structType
+#'
+#' Create a structType object that contains the metadata for a DataFrame. Intended for
+#' use with createDataFrame and toDF.
+#'
+#' @param x a structField object (created with the field() function)
+#' @param ... additional structField objects
+#' @return a structType object
+#' @export
+#' @examples
+#'\dontrun{
+#' sc <- sparkR.init()
+#' sqlCtx <- sparkRSQL.init(sc)
+#' rdd <- lapply(parallelize(sc, 1:10), function(x) { list(x, as.character(x)) })
+#' schema <- structType(structField("a", "integer"), structField("b", "string"))
+#' df <- createDataFrame(sqlCtx, rdd, schema)
+#' }
+structType <- function(x, ...) {
+ UseMethod("structType", x)
+}
+
+structType.jobj <- function(x) {
+ obj <- structure(list(), class = "structType")
+ obj$jobj <- x
+ obj$fields <- function() { lapply(callJMethod(obj$jobj, "fields"), structField) }
+ obj
+}
+
+structType.structField <- function(x, ...) {
+ fields <- list(x, ...)
+ if (!all(sapply(fields, inherits, "structField"))) {
+ stop("All arguments must be structField objects.")
+ }
+ sfObjList <- lapply(fields, function(field) {
+ field$jobj
+ })
+ stObj <- callJStatic("org.apache.spark.sql.api.r.SQLUtils",
+ "createStructType",
+ listToSeq(sfObjList))
+ structType(stObj)
+}
+
+#' Print a Spark StructType.
+#'
+#' This function prints the contents of a StructType returned from the
+#' SparkR JVM backend.
+#'
+#' @param x A StructType object
+#' @param ... further arguments passed to or from other methods
+print.structType <- function(x, ...) {
+ cat("StructType\n",
+ sapply(x$fields(), function(field) { paste("|-", "name = \"", field$name(),
+ "\", type = \"", field$dataType.toString(),
+ "\", nullable = ", field$nullable(), "\n",
+ sep = "") })
+ , sep = "")
+}
+
+#' structField
+#'
+#' Create a structField object that contains the metadata for a single field in a schema.
+#'
+#' @param x The name of the field
+#' @param type The data type of the field
+#' @param nullable A logical vector indicating whether or not the field is nullable
+#' @return a structField object
+#' @export
+#' @examples
+#'\dontrun{
+#' sc <- sparkR.init()
+#' sqlCtx <- sparkRSQL.init(sc)
+#' rdd <- lapply(parallelize(sc, 1:10), function(x) { list(x, as.character(x)) })
+#' field1 <- structField("a", "integer", TRUE)
+#' field2 <- structField("b", "string", TRUE)
+#' schema <- structType(field1, field2)
+#' df <- createDataFrame(sqlCtx, rdd, schema)
+#' }
+
+structField <- function(x, ...) {
+ UseMethod("structField", x)
+}
+
+structField.jobj <- function(x) {
+ obj <- structure(list(), class = "structField")
+ obj$jobj <- x
+ obj$name <- function() { callJMethod(x, "name") }
+ obj$dataType <- function() { callJMethod(x, "dataType") }
+ obj$dataType.toString <- function() { callJMethod(obj$dataType(), "toString") }
+ obj$dataType.simpleString <- function() { callJMethod(obj$dataType(), "simpleString") }
+ obj$nullable <- function() { callJMethod(x, "nullable") }
+ obj
+}
+
+structField.character <- function(x, type, nullable = TRUE) {
+ if (class(x) != "character") {
+ stop("Field name must be a string.")
+ }
+ if (class(type) != "character") {
+ stop("Field type must be a string.")
+ }
+ if (class(nullable) != "logical") {
+ stop("nullable must be either TRUE or FALSE")
+ }
+ options <- c("byte",
+ "integer",
+ "double",
+ "numeric",
+ "character",
+ "string",
+ "binary",
+ "raw",
+ "logical",
+ "boolean",
+ "timestamp",
+ "date")
+ dataType <- if (type %in% options) {
+ type
+ } else {
+ stop(paste("Unsupported type for Dataframe:", type))
+ }
+ sfObj <- callJStatic("org.apache.spark.sql.api.r.SQLUtils",
+ "createStructField",
+ x,
+ dataType,
+ nullable)
+ structField(sfObj)
+}
+
+#' Print a Spark StructField.
+#'
+#' This function prints the contents of a StructField returned from the
+#' SparkR JVM backend.
+#'
+#' @param x A StructField object
+#' @param ... further arguments passed to or from other methods
+print.structField <- function(x, ...) {
+ cat("StructField(name = \"", x$name(),
+ "\", type = \"", x$dataType.toString(),
+ "\", nullable = ", x$nullable(),
+ ")",
+ sep = "")
+}
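[Editor's note: a minimal sketch of the new S3 schema constructors in use, mirroring the test suite added in this PR.]

  field1 <- structField("a", "string")     # nullable defaults to TRUE
  field2 <- structField("b", "integer")
  schema <- structType(field1, field2)
  schema$fields()[[1]]$name()                   # "a"
  schema$fields()[[1]]$dataType.toString()      # "StringType"
  print(schema)                                 # prints the StructType with one |- line per field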
diff --git a/R/pkg/R/serialize.R b/R/pkg/R/serialize.R
index 8a9c0c652c..c53d0a9610 100644
--- a/R/pkg/R/serialize.R
+++ b/R/pkg/R/serialize.R
@@ -69,8 +69,9 @@ writeJobj <- function(con, value) {
}
writeString <- function(con, value) {
- writeInt(con, as.integer(nchar(value) + 1))
- writeBin(value, con, endian = "big")
+ utfVal <- enc2utf8(value)
+ writeInt(con, as.integer(nchar(utfVal, type = "bytes") + 1))
+ writeBin(utfVal, con, endian = "big")
}
writeInt <- function(con, value) {
@@ -189,7 +190,3 @@ writeArgs <- function(con, args) {
}
}
}
-
-writeStrings <- function(con, stringList) {
- writeLines(unlist(stringList), con)
-}
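[Editor's note: the writeString change matters for non-ASCII strings, where the length prefix must count UTF-8 bytes rather than characters. A small illustration; "café" is an arbitrary example.]

  value <- enc2utf8("café")
  nchar(value)                   # 4 characters
  nchar(value, type = "bytes")   # 5 bytes, since "é" is two bytes in UTF-8
  # writeString(con, "café") therefore writes 6 as the length prefix
  # (5 UTF-8 bytes plus the terminating NUL written by writeBin), then the string itself.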
diff --git a/R/pkg/R/utils.R b/R/pkg/R/utils.R
index c337fb0751..23305d3c67 100644
--- a/R/pkg/R/utils.R
+++ b/R/pkg/R/utils.R
@@ -465,3 +465,83 @@ cleanClosure <- function(func, checkedFuncs = new.env()) {
}
func
}
+
+# Append partition lengths to each partition in two input RDDs if needed.
+# param
+# x An RDD.
+# Other An RDD.
+# return value
+# A list of two result RDDs.
+appendPartitionLengths <- function(x, other) {
+ if (getSerializedMode(x) != getSerializedMode(other) ||
+ getSerializedMode(x) == "byte") {
+ # Append the number of elements in each partition to that partition so that we can later
+ # know the boundary of elements from x and other.
+ #
+ # Note that this appending also serves the purpose of reserialization, because even if
+ # any RDD is serialized, we need to reserialize it to make sure its partitions are encoded
+ # as a single byte array. For example, partitions of an RDD generated from partitionBy()
+ # may be encoded as multiple byte arrays.
+ appendLength <- function(part) {
+ len <- length(part)
+ part[[len + 1]] <- len + 1
+ part
+ }
+ x <- lapplyPartition(x, appendLength)
+ other <- lapplyPartition(other, appendLength)
+ }
+ list (x, other)
+}
+
+# Perform zip or cartesian between elements from two RDDs in each partition
+# param
+# rdd An RDD.
+# zip A boolean flag indicating this call is for zip operation or not.
+# return value
+# A result RDD.
+mergePartitions <- function(rdd, zip) {
+ serializerMode <- getSerializedMode(rdd)
+ partitionFunc <- function(split, part) {
+ len <- length(part)
+ if (len > 0) {
+ if (serializerMode == "byte") {
+ lengthOfValues <- part[[len]]
+ lengthOfKeys <- part[[len - lengthOfValues]]
+ stopifnot(len == lengthOfKeys + lengthOfValues)
+
+ # For zip operation, check if corresponding partitions of both RDDs have the same number of elements.
+ if (zip && lengthOfKeys != lengthOfValues) {
+ stop("Can only zip RDDs with same number of elements in each pair of corresponding partitions.")
+ }
+
+ if (lengthOfKeys > 1) {
+ keys <- part[1 : (lengthOfKeys - 1)]
+ } else {
+ keys <- list()
+ }
+ if (lengthOfValues > 1) {
+ values <- part[(lengthOfKeys + 1) : (len - 1)]
+ } else {
+ values <- list()
+ }
+
+ if (!zip) {
+ return(mergeCompactLists(keys, values))
+ }
+ } else {
+ keys <- part[c(TRUE, FALSE)]
+ values <- part[c(FALSE, TRUE)]
+ }
+ mapply(
+ function(k, v) { list(k, v) },
+ keys,
+ values,
+ SIMPLIFY = FALSE,
+ USE.NAMES = FALSE)
+ } else {
+ part
+ }
+ }
+
+ PipelinedRDD(rdd, partitionFunc)
+}
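[Editor's note: a small worked sketch of the length-tagging contract between appendPartitionLengths and mergePartitions above, using hypothetical elements k1, k2, v1, v2.]

  # keys partition after appendLength:    list(k1, k2, 3)
  # values partition after appendLength:  list(v1, v2, 3)
  # merged partition seen by mergePartitions: list(k1, k2, 3, v1, v2, 3), so len = 6
  # lengthOfValues <- part[[6]]        # 3 (trailing count)
  # lengthOfKeys   <- part[[6 - 3]]    # 3 (count at the end of the keys block)
  # stopifnot(6 == 3 + 3)              # sanity check
  # keys   <- part[1:2]                # list(k1, k2)
  # values <- part[4:5]                # list(v1, v2)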
diff --git a/R/pkg/inst/tests/test_rdd.R b/R/pkg/inst/tests/test_rdd.R
index b76e4db03e..3ba7d17163 100644
--- a/R/pkg/inst/tests/test_rdd.R
+++ b/R/pkg/inst/tests/test_rdd.R
@@ -35,7 +35,7 @@ test_that("get number of partitions in RDD", {
test_that("first on RDD", {
expect_true(first(rdd) == 1)
newrdd <- lapply(rdd, function(x) x + 1)
- expect_true(first(newrdd) == 2)
+ expect_true(first(newrdd) == 2)
})
test_that("count and length on RDD", {
@@ -48,7 +48,7 @@ test_that("count by values and keys", {
actual <- countByValue(mods)
expected <- list(list(0, 3L), list(1, 4L), list(2, 3L))
expect_equal(sortKeyValueList(actual), sortKeyValueList(expected))
-
+
actual <- countByKey(intRdd)
expected <- list(list(2L, 2L), list(1L, 2L))
expect_equal(sortKeyValueList(actual), sortKeyValueList(expected))
@@ -82,11 +82,11 @@ test_that("filterRDD on RDD", {
filtered.rdd <- filterRDD(rdd, function(x) { x %% 2 == 0 })
actual <- collect(filtered.rdd)
expect_equal(actual, list(2, 4, 6, 8, 10))
-
+
filtered.rdd <- Filter(function(x) { x[[2]] < 0 }, intRdd)
actual <- collect(filtered.rdd)
expect_equal(actual, list(list(1L, -1)))
-
+
# Filter out all elements.
filtered.rdd <- filterRDD(rdd, function(x) { x > 10 })
actual <- collect(filtered.rdd)
@@ -96,7 +96,7 @@ test_that("filterRDD on RDD", {
test_that("lookup on RDD", {
vals <- lookup(intRdd, 1L)
expect_equal(vals, list(-1, 200))
-
+
vals <- lookup(intRdd, 3L)
expect_equal(vals, list())
})
@@ -110,7 +110,7 @@ test_that("several transformations on RDD (a benchmark on PipelinedRDD)", {
})
rdd2 <- lapply(rdd2, function(x) x + x)
actual <- collect(rdd2)
- expected <- list(24, 24, 24, 24, 24,
+ expected <- list(24, 24, 24, 24, 24,
168, 170, 172, 174, 176)
expect_equal(actual, expected)
})
@@ -248,10 +248,10 @@ test_that("flatMapValues() on pairwise RDDs", {
l <- parallelize(sc, list(list(1, c(1,2)), list(2, c(3,4))))
actual <- collect(flatMapValues(l, function(x) { x }))
expect_equal(actual, list(list(1,1), list(1,2), list(2,3), list(2,4)))
-
+
# Generate x to x+1 for every value
actual <- collect(flatMapValues(intRdd, function(x) { x:(x + 1) }))
- expect_equal(actual,
+ expect_equal(actual,
list(list(1L, -1), list(1L, 0), list(2L, 100), list(2L, 101),
list(2L, 1), list(2L, 2), list(1L, 200), list(1L, 201)))
})
@@ -348,7 +348,7 @@ test_that("top() on RDDs", {
rdd <- parallelize(sc, l)
actual <- top(rdd, 6L)
expect_equal(actual, as.list(sort(unlist(l), decreasing = TRUE))[1:6])
-
+
l <- list("e", "d", "c", "d", "a")
rdd <- parallelize(sc, l)
actual <- top(rdd, 3L)
@@ -358,7 +358,7 @@ test_that("top() on RDDs", {
test_that("fold() on RDDs", {
actual <- fold(rdd, 0, "+")
expect_equal(actual, Reduce("+", nums, 0))
-
+
rdd <- parallelize(sc, list())
actual <- fold(rdd, 0, "+")
expect_equal(actual, 0)
@@ -371,7 +371,7 @@ test_that("aggregateRDD() on RDDs", {
combOp <- function(x, y) { list(x[[1]] + y[[1]], x[[2]] + y[[2]]) }
actual <- aggregateRDD(rdd, zeroValue, seqOp, combOp)
expect_equal(actual, list(10, 4))
-
+
rdd <- parallelize(sc, list())
actual <- aggregateRDD(rdd, zeroValue, seqOp, combOp)
expect_equal(actual, list(0, 0))
@@ -380,13 +380,13 @@ test_that("aggregateRDD() on RDDs", {
test_that("zipWithUniqueId() on RDDs", {
rdd <- parallelize(sc, list("a", "b", "c", "d", "e"), 3L)
actual <- collect(zipWithUniqueId(rdd))
- expected <- list(list("a", 0), list("b", 3), list("c", 1),
+ expected <- list(list("a", 0), list("b", 3), list("c", 1),
list("d", 4), list("e", 2))
expect_equal(actual, expected)
-
+
rdd <- parallelize(sc, list("a", "b", "c", "d", "e"), 1L)
actual <- collect(zipWithUniqueId(rdd))
- expected <- list(list("a", 0), list("b", 1), list("c", 2),
+ expected <- list(list("a", 0), list("b", 1), list("c", 2),
list("d", 3), list("e", 4))
expect_equal(actual, expected)
})
@@ -394,13 +394,13 @@ test_that("zipWithUniqueId() on RDDs", {
test_that("zipWithIndex() on RDDs", {
rdd <- parallelize(sc, list("a", "b", "c", "d", "e"), 3L)
actual <- collect(zipWithIndex(rdd))
- expected <- list(list("a", 0), list("b", 1), list("c", 2),
+ expected <- list(list("a", 0), list("b", 1), list("c", 2),
list("d", 3), list("e", 4))
expect_equal(actual, expected)
-
+
rdd <- parallelize(sc, list("a", "b", "c", "d", "e"), 1L)
actual <- collect(zipWithIndex(rdd))
- expected <- list(list("a", 0), list("b", 1), list("c", 2),
+ expected <- list(list("a", 0), list("b", 1), list("c", 2),
list("d", 3), list("e", 4))
expect_equal(actual, expected)
})
@@ -427,12 +427,12 @@ test_that("pipeRDD() on RDDs", {
actual <- collect(pipeRDD(rdd, "more"))
expected <- as.list(as.character(1:10))
expect_equal(actual, expected)
-
+
trailed.rdd <- parallelize(sc, c("1", "", "2\n", "3\n\r\n"))
actual <- collect(pipeRDD(trailed.rdd, "sort"))
expected <- list("", "1", "2", "3")
expect_equal(actual, expected)
-
+
rev.nums <- 9:0
rev.rdd <- parallelize(sc, rev.nums, 2L)
actual <- collect(pipeRDD(rev.rdd, "sort"))
@@ -446,11 +446,11 @@ test_that("zipRDD() on RDDs", {
actual <- collect(zipRDD(rdd1, rdd2))
expect_equal(actual,
list(list(0, 1000), list(1, 1001), list(2, 1002), list(3, 1003), list(4, 1004)))
-
+
mockFile = c("Spark is pretty.", "Spark is awesome.")
fileName <- tempfile(pattern="spark-test", fileext=".tmp")
writeLines(mockFile, fileName)
-
+
rdd <- textFile(sc, fileName, 1)
actual <- collect(zipRDD(rdd, rdd))
expected <- lapply(mockFile, function(x) { list(x ,x) })
@@ -465,10 +465,125 @@ test_that("zipRDD() on RDDs", {
actual <- collect(zipRDD(rdd, rdd1))
expected <- lapply(mockFile, function(x) { list(x, x) })
expect_equal(actual, expected)
-
+
+ unlink(fileName)
+})
+
+test_that("cartesian() on RDDs", {
+ rdd <- parallelize(sc, 1:3)
+ actual <- collect(cartesian(rdd, rdd))
+ expect_equal(sortKeyValueList(actual),
+ list(
+ list(1, 1), list(1, 2), list(1, 3),
+ list(2, 1), list(2, 2), list(2, 3),
+ list(3, 1), list(3, 2), list(3, 3)))
+
+ # test case where one RDD is empty
+ emptyRdd <- parallelize(sc, list())
+ actual <- collect(cartesian(rdd, emptyRdd))
+ expect_equal(actual, list())
+
+ mockFile = c("Spark is pretty.", "Spark is awesome.")
+ fileName <- tempfile(pattern="spark-test", fileext=".tmp")
+ writeLines(mockFile, fileName)
+
+ rdd <- textFile(sc, fileName)
+ actual <- collect(cartesian(rdd, rdd))
+ expected <- list(
+ list("Spark is awesome.", "Spark is pretty."),
+ list("Spark is awesome.", "Spark is awesome."),
+ list("Spark is pretty.", "Spark is pretty."),
+ list("Spark is pretty.", "Spark is awesome."))
+ expect_equal(sortKeyValueList(actual), expected)
+
+ rdd1 <- parallelize(sc, 0:1)
+ actual <- collect(cartesian(rdd1, rdd))
+ expect_equal(sortKeyValueList(actual),
+ list(
+ list(0, "Spark is pretty."),
+ list(0, "Spark is awesome."),
+ list(1, "Spark is pretty."),
+ list(1, "Spark is awesome.")))
+
+ rdd1 <- map(rdd, function(x) { x })
+ actual <- collect(cartesian(rdd, rdd1))
+ expect_equal(sortKeyValueList(actual), expected)
+
unlink(fileName)
})
+test_that("subtract() on RDDs", {
+ l <- list(1, 1, 2, 2, 3, 4)
+ rdd1 <- parallelize(sc, l)
+
+ # subtract by itself
+ actual <- collect(subtract(rdd1, rdd1))
+ expect_equal(actual, list())
+
+ # subtract by an empty RDD
+ rdd2 <- parallelize(sc, list())
+ actual <- collect(subtract(rdd1, rdd2))
+ expect_equal(as.list(sort(as.vector(actual, mode="integer"))),
+ l)
+
+ rdd2 <- parallelize(sc, list(2, 4))
+ actual <- collect(subtract(rdd1, rdd2))
+ expect_equal(as.list(sort(as.vector(actual, mode="integer"))),
+ list(1, 1, 3))
+
+ l <- list("a", "a", "b", "b", "c", "d")
+ rdd1 <- parallelize(sc, l)
+ rdd2 <- parallelize(sc, list("b", "d"))
+ actual <- collect(subtract(rdd1, rdd2))
+ expect_equal(as.list(sort(as.vector(actual, mode="character"))),
+ list("a", "a", "c"))
+})
+
+test_that("subtractByKey() on pairwise RDDs", {
+ l <- list(list("a", 1), list("b", 4),
+ list("b", 5), list("a", 2))
+ rdd1 <- parallelize(sc, l)
+
+ # subtractByKey by itself
+ actual <- collect(subtractByKey(rdd1, rdd1))
+ expect_equal(actual, list())
+
+ # subtractByKey by an empty RDD
+ rdd2 <- parallelize(sc, list())
+ actual <- collect(subtractByKey(rdd1, rdd2))
+ expect_equal(sortKeyValueList(actual),
+ sortKeyValueList(l))
+
+ rdd2 <- parallelize(sc, list(list("a", 3), list("c", 1)))
+ actual <- collect(subtractByKey(rdd1, rdd2))
+ expect_equal(actual,
+ list(list("b", 4), list("b", 5)))
+
+ l <- list(list(1, 1), list(2, 4),
+ list(2, 5), list(1, 2))
+ rdd1 <- parallelize(sc, l)
+ rdd2 <- parallelize(sc, list(list(1, 3), list(3, 1)))
+ actual <- collect(subtractByKey(rdd1, rdd2))
+ expect_equal(actual,
+ list(list(2, 4), list(2, 5)))
+})
+
+test_that("intersection() on RDDs", {
+ # intersection with self
+ actual <- collect(intersection(rdd, rdd))
+ expect_equal(sort(as.integer(actual)), nums)
+
+ # intersection with an empty RDD
+ emptyRdd <- parallelize(sc, list())
+ actual <- collect(intersection(rdd, emptyRdd))
+ expect_equal(actual, list())
+
+ rdd1 <- parallelize(sc, list(1, 10, 2, 3, 4, 5))
+ rdd2 <- parallelize(sc, list(1, 6, 2, 3, 7, 8))
+ actual <- collect(intersection(rdd1, rdd2))
+ expect_equal(sort(as.integer(actual)), 1:3)
+})
+
test_that("join() on pairwise RDDs", {
rdd1 <- parallelize(sc, list(list(1,1), list(2,4)))
rdd2 <- parallelize(sc, list(list(1,2), list(1,3)))
@@ -596,9 +711,9 @@ test_that("sortByKey() on pairwise RDDs", {
sortedRdd3 <- sortByKey(rdd3)
actual <- collect(sortedRdd3)
expect_equal(actual, list(list("1", 3), list("2", 5), list("a", 1), list("b", 2), list("d", 4)))
-
+
# test on the boundary cases
-
+
# boundary case 1: the RDD to be sorted has only 1 partition
rdd4 <- parallelize(sc, l, 1L)
sortedRdd4 <- sortByKey(rdd4)
@@ -623,7 +738,7 @@ test_that("sortByKey() on pairwise RDDs", {
rdd7 <- parallelize(sc, l3, 2L)
sortedRdd7 <- sortByKey(rdd7)
actual <- collect(sortedRdd7)
- expect_equal(actual, l3)
+ expect_equal(actual, l3)
})
test_that("collectAsMap() on a pairwise RDD", {
@@ -634,12 +749,36 @@ test_that("collectAsMap() on a pairwise RDD", {
rdd <- parallelize(sc, list(list("a", 1), list("b", 2)))
vals <- collectAsMap(rdd)
expect_equal(vals, list(a = 1, b = 2))
-
+
rdd <- parallelize(sc, list(list(1.1, 2.2), list(1.2, 2.4)))
vals <- collectAsMap(rdd)
expect_equal(vals, list(`1.1` = 2.2, `1.2` = 2.4))
-
+
rdd <- parallelize(sc, list(list(1, "a"), list(2, "b")))
vals <- collectAsMap(rdd)
expect_equal(vals, list(`1` = "a", `2` = "b"))
})
+
+test_that("sampleByKey() on pairwise RDDs", {
+ rdd <- parallelize(sc, 1:2000)
+ pairsRDD <- lapply(rdd, function(x) { if (x %% 2 == 0) list("a", x) else list("b", x) })
+ fractions <- list(a = 0.2, b = 0.1)
+ sample <- sampleByKey(pairsRDD, FALSE, fractions, 1618L)
+ expect_equal(100 < length(lookup(sample, "a")) && 300 > length(lookup(sample, "a")), TRUE)
+ expect_equal(50 < length(lookup(sample, "b")) && 150 > length(lookup(sample, "b")), TRUE)
+ expect_equal(lookup(sample, "a")[which.min(lookup(sample, "a"))] >= 0, TRUE)
+ expect_equal(lookup(sample, "a")[which.max(lookup(sample, "a"))] <= 2000, TRUE)
+ expect_equal(lookup(sample, "b")[which.min(lookup(sample, "b"))] >= 0, TRUE)
+ expect_equal(lookup(sample, "b")[which.max(lookup(sample, "b"))] <= 2000, TRUE)
+
+ rdd <- parallelize(sc, 1:2000)
+ pairsRDD <- lapply(rdd, function(x) { if (x %% 2 == 0) list(2, x) else list(3, x) })
+ fractions <- list(`2` = 0.2, `3` = 0.1)
+ sample <- sampleByKey(pairsRDD, TRUE, fractions, 1618L)
+ expect_equal(100 < length(lookup(sample, 2)) && 300 > length(lookup(sample, 2)), TRUE)
+ expect_equal(50 < length(lookup(sample, 3)) && 150 > length(lookup(sample, 3)), TRUE)
+ expect_equal(lookup(sample, 2)[which.min(lookup(sample, 2))] >= 0, TRUE)
+ expect_equal(lookup(sample, 2)[which.max(lookup(sample, 2))] <= 2000, TRUE)
+ expect_equal(lookup(sample, 3)[which.min(lookup(sample, 3))] >= 0, TRUE)
+ expect_equal(lookup(sample, 3)[which.max(lookup(sample, 3))] <= 2000, TRUE)
+})
diff --git a/R/pkg/inst/tests/test_shuffle.R b/R/pkg/inst/tests/test_shuffle.R
index d1da8232ae..d7dedda553 100644
--- a/R/pkg/inst/tests/test_shuffle.R
+++ b/R/pkg/inst/tests/test_shuffle.R
@@ -87,6 +87,18 @@ test_that("combineByKey for doubles", {
expect_equal(sortKeyValueList(actual), sortKeyValueList(expected))
})
+test_that("combineByKey for characters", {
+ stringKeyRDD <- parallelize(sc,
+ list(list("max", 1L), list("min", 2L),
+ list("other", 3L), list("max", 4L)), 2L)
+ reduced <- combineByKey(stringKeyRDD,
+ function(x) { x }, "+", "+", 2L)
+ actual <- collect(reduced)
+
+ expected <- list(list("max", 5L), list("min", 2L), list("other", 3L))
+ expect_equal(sortKeyValueList(actual), sortKeyValueList(expected))
+})
+
test_that("aggregateByKey", {
# test aggregateByKey for int keys
rdd <- parallelize(sc, list(list(1, 1), list(1, 2), list(2, 3), list(2, 4)))
diff --git a/R/pkg/inst/tests/test_sparkSQL.R b/R/pkg/inst/tests/test_sparkSQL.R
index cf5cf6d169..25831ae2d9 100644
--- a/R/pkg/inst/tests/test_sparkSQL.R
+++ b/R/pkg/inst/tests/test_sparkSQL.R
@@ -44,9 +44,8 @@ test_that("infer types", {
expect_equal(infer_type(list(1L, 2L)),
list(type = 'array', elementType = "integer", containsNull = TRUE))
expect_equal(infer_type(list(a = 1L, b = "2")),
- list(type = "struct",
- fields = list(list(name = "a", type = "integer", nullable = TRUE),
- list(name = "b", type = "string", nullable = TRUE))))
+ structType(structField(x = "a", type = "integer", nullable = TRUE),
+ structField(x = "b", type = "string", nullable = TRUE)))
e <- new.env()
assign("a", 1L, envir = e)
expect_equal(infer_type(e),
@@ -54,6 +53,18 @@ test_that("infer types", {
valueContainsNull = TRUE))
})
+test_that("structType and structField", {
+ testField <- structField("a", "string")
+ expect_true(inherits(testField, "structField"))
+ expect_true(testField$name() == "a")
+ expect_true(testField$nullable())
+
+ testSchema <- structType(testField, structField("b", "integer"))
+ expect_true(inherits(testSchema, "structType"))
+ expect_true(inherits(testSchema$fields()[[2]], "structField"))
+ expect_true(testSchema$fields()[[1]]$dataType.toString() == "StringType")
+})
+
test_that("create DataFrame from RDD", {
rdd <- lapply(parallelize(sc, 1:10), function(x) { list(x, as.character(x)) })
df <- createDataFrame(sqlCtx, rdd, list("a", "b"))
@@ -66,9 +77,8 @@ test_that("create DataFrame from RDD", {
expect_true(inherits(df, "DataFrame"))
expect_equal(columns(df), c("_1", "_2"))
- fields <- list(list(name = "a", type = "integer", nullable = TRUE),
- list(name = "b", type = "string", nullable = TRUE))
- schema <- list(type = "struct", fields = fields)
+ schema <- structType(structField(x = "a", type = "integer", nullable = TRUE),
+ structField(x = "b", type = "string", nullable = TRUE))
df <- createDataFrame(sqlCtx, rdd, schema)
expect_true(inherits(df, "DataFrame"))
expect_equal(columns(df), c("a", "b"))
@@ -94,9 +104,8 @@ test_that("toDF", {
expect_true(inherits(df, "DataFrame"))
expect_equal(columns(df), c("_1", "_2"))
- fields <- list(list(name = "a", type = "integer", nullable = TRUE),
- list(name = "b", type = "string", nullable = TRUE))
- schema <- list(type = "struct", fields = fields)
+ schema <- structType(structField(x = "a", type = "integer", nullable = TRUE),
+ structField(x = "b", type = "string", nullable = TRUE))
df <- toDF(rdd, schema)
expect_true(inherits(df, "DataFrame"))
expect_equal(columns(df), c("a", "b"))
@@ -635,7 +644,7 @@ test_that("isLocal()", {
expect_false(isLocal(df))
})
-test_that("unionAll(), subtract(), and intersect() on a DataFrame", {
+test_that("unionAll(), except(), and intersect() on a DataFrame", {
df <- jsonFile(sqlCtx, jsonPath)
lines <- c("{\"name\":\"Bob\", \"age\":24}",
@@ -650,10 +659,10 @@ test_that("unionAll(), subtract(), and intersect() on a DataFrame", {
expect_true(count(unioned) == 6)
expect_true(first(unioned)$name == "Michael")
- subtracted <- sortDF(subtract(df, df2), desc(df$age))
+ excepted <- sortDF(except(df, df2), desc(df$age))
expect_true(inherits(unioned, "DataFrame"))
- expect_true(count(subtracted) == 2)
- expect_true(first(subtracted)$name == "Justin")
+ expect_true(count(excepted) == 2)
+ expect_true(first(excepted)$name == "Justin")
intersected <- sortDF(intersect(df, df2), df$age)
expect_true(inherits(unioned, "DataFrame"))
diff --git a/R/pkg/inst/worker/worker.R b/R/pkg/inst/worker/worker.R
index c6542928e8..014bf7bd7b 100644
--- a/R/pkg/inst/worker/worker.R
+++ b/R/pkg/inst/worker/worker.R
@@ -17,6 +17,23 @@
# Worker class
+# Get current system time
+currentTimeSecs <- function() {
+ as.numeric(Sys.time())
+}
+
+# Get elapsed time
+elapsedSecs <- function() {
+ proc.time()[3]
+}
+
+# Constants
+specialLengths <- list(END_OF_STREAM = 0L, TIMING_DATA = -1L)
+
+# Timing R process boot
+bootTime <- currentTimeSecs()
+bootElap <- elapsedSecs()
+
rLibDir <- Sys.getenv("SPARKR_RLIBDIR")
# Set libPaths to include SparkR package as loadNamespace needs this
# TODO: Figure out if we can avoid this by not loading any objects that require
@@ -37,7 +54,7 @@ serializer <- SparkR:::readString(inputCon)
# Include packages as required
packageNames <- unserialize(SparkR:::readRaw(inputCon))
for (pkg in packageNames) {
- suppressPackageStartupMessages(require(as.character(pkg), character.only=TRUE))
+ suppressPackageStartupMessages(library(as.character(pkg), character.only=TRUE))
}
# read function dependencies
@@ -46,6 +63,9 @@ computeFunc <- unserialize(SparkR:::readRawLen(inputCon, funcLen))
env <- environment(computeFunc)
parent.env(env) <- .GlobalEnv # Attach under global environment.
+# Timing init envs for computing
+initElap <- elapsedSecs()
+
# Read and set broadcast variables
numBroadcastVars <- SparkR:::readInt(inputCon)
if (numBroadcastVars > 0) {
@@ -56,6 +76,9 @@ if (numBroadcastVars > 0) {
}
}
+# Timing broadcast
+broadcastElap <- elapsedSecs()
+
# If -1: read as normal RDD; if >= 0, treat as pairwise RDD and treat the int
# as number of partitions to create.
numPartitions <- SparkR:::readInt(inputCon)
@@ -73,14 +96,23 @@ if (isEmpty != 0) {
} else if (deserializer == "row") {
data <- SparkR:::readDeserializeRows(inputCon)
}
+ # Timing reading input data for execution
+ inputElap <- elapsedSecs()
+
output <- computeFunc(partition, data)
+ # Timing computing
+ computeElap <- elapsedSecs()
+
if (serializer == "byte") {
SparkR:::writeRawSerialize(outputCon, output)
} else if (serializer == "row") {
SparkR:::writeRowSerialize(outputCon, output)
} else {
- SparkR:::writeStrings(outputCon, output)
+ # write lines one-by-one with flag
+ lapply(output, function(line) SparkR:::writeString(outputCon, line))
}
+ # Timing output
+ outputElap <- elapsedSecs()
} else {
if (deserializer == "byte") {
# Now read as many characters as described in funcLen
@@ -90,6 +122,8 @@ if (isEmpty != 0) {
} else if (deserializer == "row") {
data <- SparkR:::readDeserializeRows(inputCon)
}
+ # Timing reading input data for execution
+ inputElap <- elapsedSecs()
res <- new.env()
@@ -107,6 +141,8 @@ if (isEmpty != 0) {
res[[bucket]] <- acc
}
invisible(lapply(data, hashTupleToEnvir))
+ # Timing computing
+ computeElap <- elapsedSecs()
# Step 2: write out all of the environment as key-value pairs.
for (name in ls(res)) {
@@ -116,13 +152,26 @@ if (isEmpty != 0) {
length(res[[name]]$data) <- res[[name]]$counter
SparkR:::writeRawSerialize(outputCon, res[[name]]$data)
}
+ # Timing output
+ outputElap <- elapsedSecs()
}
+} else {
+ inputElap <- broadcastElap
+ computeElap <- broadcastElap
+ outputElap <- broadcastElap
}
+# Report timing
+SparkR:::writeInt(outputCon, specialLengths$TIMING_DATA)
+SparkR:::writeDouble(outputCon, bootTime)
+SparkR:::writeDouble(outputCon, initElap - bootElap) # init
+SparkR:::writeDouble(outputCon, broadcastElap - initElap) # broadcast
+SparkR:::writeDouble(outputCon, inputElap - broadcastElap) # input
+SparkR:::writeDouble(outputCon, computeElap - inputElap) # compute
+SparkR:::writeDouble(outputCon, outputElap - computeElap) # output
+
# End of output
-if (serializer %in% c("byte", "row")) {
- SparkR:::writeInt(outputCon, 0L)
-}
+SparkR:::writeInt(outputCon, specialLengths$END_OF_STREAM)
close(outputCon)
close(inputCon)
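
The worker's output connection now uses a single length-prefixed framing: a non-negative int is a payload length, while negative ints are control flags (TIMING_DATA precedes the six timing doubles written above, END_OF_STREAM terminates the stream). A minimal Scala sketch of a consumer of that framing, with invented names, purely for illustration:

    import java.io.DataInputStream

    object WorkerStreamSketch {
      val EndOfStream = 0   // specialLengths$END_OF_STREAM in worker.R
      val TimingData  = -1  // specialLengths$TIMING_DATA in worker.R

      // Read one frame from the worker's output; None signals end of stream.
      def readFrame(in: DataInputStream): Option[Array[Byte]] = in.readInt() match {
        case EndOfStream => None
        case TimingData =>
          val bootTime = in.readDouble()               // absolute boot time of the R process
          val stages = Array.fill(5)(in.readDouble())  // init, broadcast, input, compute, output
          println(s"R worker boot=$bootTime stages=${stages.mkString(", ")}")
          readFrame(in)                                // timing is out-of-band; keep reading
        case len if len > 0 =>
          val payload = new Array[Byte](len)
          in.readFully(payload)
          Some(payload)
        case other =>
          sys.error(s"unexpected frame length $other")
      }
    }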
diff --git a/core/src/main/scala/org/apache/spark/api/r/RRDD.scala b/core/src/main/scala/org/apache/spark/api/r/RRDD.scala
index 5fa4d483b8..6fea5e1144 100644
--- a/core/src/main/scala/org/apache/spark/api/r/RRDD.scala
+++ b/core/src/main/scala/org/apache/spark/api/r/RRDD.scala
@@ -42,10 +42,15 @@ private abstract class BaseRRDD[T: ClassTag, U: ClassTag](
rLibDir: String,
broadcastVars: Array[Broadcast[Object]])
extends RDD[U](parent) with Logging {
+ protected var dataStream: DataInputStream = _
+ private var bootTime: Double = _
override def getPartitions: Array[Partition] = parent.partitions
override def compute(partition: Partition, context: TaskContext): Iterator[U] = {
+ // Timing start
+ bootTime = System.currentTimeMillis / 1000.0
+
// The parent may be also an RRDD, so we should launch it first.
val parentIterator = firstParent[T].iterator(partition, context)
@@ -69,7 +74,7 @@ private abstract class BaseRRDD[T: ClassTag, U: ClassTag](
// the socket used to receive the output of task
val outSocket = serverSocket.accept()
val inputStream = new BufferedInputStream(outSocket.getInputStream)
- val dataStream = openDataStream(inputStream)
+ dataStream = new DataInputStream(inputStream)
serverSocket.close()
try {
@@ -155,6 +160,7 @@ private abstract class BaseRRDD[T: ClassTag, U: ClassTag](
} else if (deserializer == SerializationFormats.ROW) {
dataOut.write(elem.asInstanceOf[Array[Byte]])
} else if (deserializer == SerializationFormats.STRING) {
+        // write string (for StringRRDD)
printOut.println(elem)
}
}
@@ -180,9 +186,41 @@ private abstract class BaseRRDD[T: ClassTag, U: ClassTag](
}.start()
}
- protected def openDataStream(input: InputStream): Closeable
+ protected def readData(length: Int): U
- protected def read(): U
+ protected def read(): U = {
+ try {
+ val length = dataStream.readInt()
+
+ length match {
+ case SpecialLengths.TIMING_DATA =>
+ // Timing data from R worker
+ val boot = dataStream.readDouble - bootTime
+ val init = dataStream.readDouble
+ val broadcast = dataStream.readDouble
+ val input = dataStream.readDouble
+ val compute = dataStream.readDouble
+ val output = dataStream.readDouble
+ logInfo(
+ ("Times: boot = %.3f s, init = %.3f s, broadcast = %.3f s, " +
+ "read-input = %.3f s, compute = %.3f s, write-output = %.3f s, " +
+ "total = %.3f s").format(
+ boot,
+ init,
+ broadcast,
+ input,
+ compute,
+ output,
+ boot + init + broadcast + input + compute + output))
+ read()
+ case length if length >= 0 =>
+ readData(length)
+ }
+ } catch {
+ case eof: EOFException =>
+      throw new SparkException("R worker exited unexpectedly (crashed)", eof)
+ }
+ }
}
/**
@@ -202,31 +240,16 @@ private class PairwiseRRDD[T: ClassTag](
SerializationFormats.BYTE, packageNames, rLibDir,
broadcastVars.map(x => x.asInstanceOf[Broadcast[Object]])) {
- private var dataStream: DataInputStream = _
-
- override protected def openDataStream(input: InputStream): Closeable = {
- dataStream = new DataInputStream(input)
- dataStream
- }
-
- override protected def read(): (Int, Array[Byte]) = {
- try {
- val length = dataStream.readInt()
-
- length match {
- case length if length == 2 =>
- val hashedKey = dataStream.readInt()
- val contentPairsLength = dataStream.readInt()
- val contentPairs = new Array[Byte](contentPairsLength)
- dataStream.readFully(contentPairs)
- (hashedKey, contentPairs)
- case _ => null // End of input
- }
- } catch {
- case eof: EOFException => {
- throw new SparkException("R worker exited unexpectedly (crashed)", eof)
- }
- }
+ override protected def readData(length: Int): (Int, Array[Byte]) = {
+ length match {
+ case length if length == 2 =>
+ val hashedKey = dataStream.readInt()
+ val contentPairsLength = dataStream.readInt()
+ val contentPairs = new Array[Byte](contentPairsLength)
+ dataStream.readFully(contentPairs)
+ (hashedKey, contentPairs)
+ case _ => null
+ }
}
lazy val asJavaPairRDD : JavaPairRDD[Int, Array[Byte]] = JavaPairRDD.fromRDD(this)
@@ -247,28 +270,13 @@ private class RRDD[T: ClassTag](
parent, -1, func, deserializer, serializer, packageNames, rLibDir,
broadcastVars.map(x => x.asInstanceOf[Broadcast[Object]])) {
- private var dataStream: DataInputStream = _
-
- override protected def openDataStream(input: InputStream): Closeable = {
- dataStream = new DataInputStream(input)
- dataStream
- }
-
- override protected def read(): Array[Byte] = {
- try {
- val length = dataStream.readInt()
-
- length match {
- case length if length > 0 =>
- val obj = new Array[Byte](length)
- dataStream.readFully(obj, 0, length)
- obj
- case _ => null
- }
- } catch {
- case eof: EOFException => {
- throw new SparkException("R worker exited unexpectedly (crashed)", eof)
- }
+ override protected def readData(length: Int): Array[Byte] = {
+ length match {
+ case length if length > 0 =>
+ val obj = new Array[Byte](length)
+ dataStream.readFully(obj)
+ obj
+ case _ => null
}
}
@@ -289,26 +297,21 @@ private class StringRRDD[T: ClassTag](
parent, -1, func, deserializer, SerializationFormats.STRING, packageNames, rLibDir,
broadcastVars.map(x => x.asInstanceOf[Broadcast[Object]])) {
- private var dataStream: BufferedReader = _
-
- override protected def openDataStream(input: InputStream): Closeable = {
- dataStream = new BufferedReader(new InputStreamReader(input))
- dataStream
- }
-
- override protected def read(): String = {
- try {
- dataStream.readLine()
- } catch {
- case e: IOException => {
- throw new SparkException("R worker exited unexpectedly (crashed)", e)
- }
+ override protected def readData(length: Int): String = {
+ length match {
+ case length if length > 0 =>
+ SerDe.readStringBytes(dataStream, length)
+ case _ => null
}
}
lazy val asJavaRDD : JavaRDD[String] = JavaRDD.fromRDD(this)
}
+private object SpecialLengths {
+ val TIMING_DATA = -1
+}
+
private[r] class BufferedStreamThread(
in: InputStream,
name: String,
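
The change above is a template-method refactor: BaseRRDD.read() now owns the length prefix, timing frames, and EOF handling shared by all three RDD flavours, while each subclass only implements readData(length). A simplified sketch of that shape, with invented class names and none of the RDD plumbing:

    import java.io.{DataInputStream, EOFException}

    import org.apache.spark.SparkException

    // Distilled shape of the refactor: the base class owns the length prefix and error handling,
    // subclasses only decode a payload of known length via readData().
    abstract class FramedReader[T](protected val dataStream: DataInputStream) {
      protected def readData(length: Int): T

      final def read(): T =
        try {
          dataStream.readInt() match {
            case length if length >= 0 => readData(length)
            case flag =>
              // the real BaseRRDD consumes TIMING_DATA frames here and then calls read() again
              throw new SparkException(s"unexpected control frame: $flag")
          }
        } catch {
          case eof: EOFException =>
            throw new SparkException("R worker exited unexpectedly (crashed)", eof)
        }
    }

    // Mirrors PairwiseRRDD.readData: a frame of "length" 2 carries (hashed key, serialized pairs).
    class PairwiseFrameReader(in: DataInputStream) extends FramedReader[(Int, Array[Byte])](in) {
      override protected def readData(length: Int): (Int, Array[Byte]) = length match {
        case 2 =>
          val hashedKey = dataStream.readInt()
          val contentLen = dataStream.readInt()
          val content = new Array[Byte](contentLen)
          dataStream.readFully(content)
          (hashedKey, content)
        case _ => null
      }
    }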
diff --git a/core/src/main/scala/org/apache/spark/api/r/SerDe.scala b/core/src/main/scala/org/apache/spark/api/r/SerDe.scala
index ccb2a371f4..371dfe454d 100644
--- a/core/src/main/scala/org/apache/spark/api/r/SerDe.scala
+++ b/core/src/main/scala/org/apache/spark/api/r/SerDe.scala
@@ -85,13 +85,17 @@ private[spark] object SerDe {
in.readDouble()
}
+ def readStringBytes(in: DataInputStream, len: Int): String = {
+ val bytes = new Array[Byte](len)
+ in.readFully(bytes)
+ assert(bytes(len - 1) == 0)
+ val str = new String(bytes.dropRight(1), "UTF-8")
+ str
+ }
+
def readString(in: DataInputStream): String = {
val len = in.readInt()
- val asciiBytes = new Array[Byte](len)
- in.readFully(asciiBytes)
- assert(asciiBytes(len - 1) == 0)
- val str = new String(asciiBytes.dropRight(1).map(_.toChar))
- str
+ readStringBytes(in, len)
}
def readBoolean(in: DataInputStream): Boolean = {
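
readStringBytes reads a length-prefixed, NUL-terminated payload and decodes it as UTF-8, whereas the old readString mapped each byte to a char and garbled multi-byte characters. A self-contained sketch of that framing (object and buffer names are made up for illustration):

    import java.io.{ByteArrayInputStream, ByteArrayOutputStream, DataInputStream, DataOutputStream}

    object StringFramingSketch {
      // Same decoding as SerDe.readStringBytes: read len bytes, expect a trailing NUL, decode UTF-8.
      def readStringBytes(in: DataInputStream, len: Int): String = {
        val bytes = new Array[Byte](len)
        in.readFully(bytes)
        assert(bytes(len - 1) == 0)
        new String(bytes.dropRight(1), "UTF-8")
      }

      def main(args: Array[String]): Unit = {
        // Frame "héllo" the way the reader expects: int length, UTF-8 bytes, trailing NUL byte.
        val payload = "héllo".getBytes("UTF-8") :+ 0.toByte
        val buf = new ByteArrayOutputStream()
        val out = new DataOutputStream(buf)
        out.writeInt(payload.length)
        out.write(payload)
        val in = new DataInputStream(new ByteArrayInputStream(buf.toByteArray))
        val len = in.readInt()
        println(readStringBytes(in, len))  // prints héllo; the old byte-to-char decoding garbled "é"
      }
    }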
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/api/r/SQLUtils.scala b/sql/core/src/main/scala/org/apache/spark/sql/api/r/SQLUtils.scala
index d1ea7cc3e9..ae77f72998 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/api/r/SQLUtils.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/api/r/SQLUtils.scala
@@ -23,7 +23,7 @@ import org.apache.spark.api.java.{JavaRDD, JavaSparkContext}
import org.apache.spark.api.r.SerDe
import org.apache.spark.rdd.RDD
import org.apache.spark.sql.catalyst.expressions.{Alias, Expression, NamedExpression}
-import org.apache.spark.sql.types.{DataType, StructType}
+import org.apache.spark.sql.types._
import org.apache.spark.sql.{Column, DataFrame, GroupedData, Row, SQLContext, SaveMode}
private[r] object SQLUtils {
@@ -39,8 +39,34 @@ private[r] object SQLUtils {
arr.toSeq
}
- def createDF(rdd: RDD[Array[Byte]], schemaString: String, sqlContext: SQLContext): DataFrame = {
- val schema = DataType.fromJson(schemaString).asInstanceOf[StructType]
+ def createStructType(fields : Seq[StructField]): StructType = {
+ StructType(fields)
+ }
+
+ def getSQLDataType(dataType: String): DataType = {
+ dataType match {
+ case "byte" => org.apache.spark.sql.types.ByteType
+ case "integer" => org.apache.spark.sql.types.IntegerType
+ case "double" => org.apache.spark.sql.types.DoubleType
+ case "numeric" => org.apache.spark.sql.types.DoubleType
+ case "character" => org.apache.spark.sql.types.StringType
+ case "string" => org.apache.spark.sql.types.StringType
+ case "binary" => org.apache.spark.sql.types.BinaryType
+ case "raw" => org.apache.spark.sql.types.BinaryType
+ case "logical" => org.apache.spark.sql.types.BooleanType
+ case "boolean" => org.apache.spark.sql.types.BooleanType
+ case "timestamp" => org.apache.spark.sql.types.TimestampType
+ case "date" => org.apache.spark.sql.types.DateType
+      case _ => throw new IllegalArgumentException(s"Invalid type $dataType")
+ }
+ }
+
+ def createStructField(name: String, dataType: String, nullable: Boolean): StructField = {
+ val dtObj = getSQLDataType(dataType)
+ StructField(name, dtObj, nullable)
+ }
+
+ def createDF(rdd: RDD[Array[Byte]], schema: StructType, sqlContext: SQLContext): DataFrame = {
val num = schema.fields.size
val rowRDD = rdd.map(bytesToRow)
sqlContext.createDataFrame(rowRDD, schema)
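
For context, the structField()/structType() helpers added on the R side resolve through createStructField/getSQLDataType to ordinary Spark SQL types; a short sketch of the schema built by the updated tests (object name invented for illustration):

    import org.apache.spark.sql.types.{IntegerType, StringType, StructField, StructType}

    object SchemaSketch {
      // What structField(x = "a", type = "integer", nullable = TRUE) resolves to after
      // createStructField/getSQLDataType, and likewise for the "string" field:
      val fields = Seq(
        StructField("a", IntegerType, nullable = true),
        StructField("b", StringType, nullable = true))

      // createStructType(fields) is simply StructType(fields); createDF now receives this
      // StructType directly instead of re-parsing a JSON schema string.
      val schema: StructType = StructType(fields)
    }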