author     Hossein <hossein@databricks.com>        2016-10-12 10:32:38 -0700
committer  Felix Cheung <felixcheung@apache.org>   2016-10-12 10:32:38 -0700
commit     5cc503f4fe9737a4c7947a80eecac053780606df (patch)
tree       02cfea5ff7007d7375b17786880d55a6867eedb7
parent     d5580ebaa086b9feb72d5428f24c5b60cd7da745 (diff)
[SPARK-17790][SPARKR] Support for parallelizing R data.frame larger than 2GB
## What changes were proposed in this pull request?

If the R data structure that is being parallelized is larger than `INT_MAX`, we use files to transfer data to the JVM. The serialization protocol mimics Python pickling. This allows us to simply call `PythonRDD.readRDDFromFile` to create the RDD.

I tested this on my MacBook. The following code works with this patch:

```R
intMax <- .Machine$integer.max
largeVec <- 1:intMax
rdd <- SparkR:::parallelize(sc, largeVec, 2)
```

## How was this patch tested?

* [x] Unit tests

Author: Hossein <hossein@databricks.com>

Closes #15375 from falaki/SPARK-17790.
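To exercise the new file-based code path without materializing a multi-gigabyte object, the unit test added in this patch suggests a shortcut: lower `spark.r.maxAllocationLimit` so that even a small data.frame exceeds it. The sketch below is based on that test and assumes an active SparkR session; `callJMethod` is an internal SparkR helper, hence the `:::`.

```R
library(SparkR)
sparkSession <- sparkR.session()

# Shrink the per-call limit so even iris takes the file-based transfer path
conf <- SparkR:::callJMethod(sparkSession, "conf")
SparkR:::callJMethod(conf, "set", "spark.r.maxAllocationLimit", "100")

df <- createDataFrame(iris)
dim(df)  # should equal dim(iris)

# Restore the default (.Machine$integer.max / 10, roughly 200MB)
SparkR:::callJMethod(conf, "set", "spark.r.maxAllocationLimit",
                     toString(.Machine$integer.max / 10))
```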
-rw-r--r--  R/pkg/R/context.R                                                  45
-rw-r--r--  R/pkg/inst/tests/testthat/test_sparkSQL.R                          11
-rw-r--r--  core/src/main/scala/org/apache/spark/api/r/RBackendHandler.scala    2
-rw-r--r--  core/src/main/scala/org/apache/spark/api/r/RRDD.scala              13
4 files changed, 68 insertions, 3 deletions
diff --git a/R/pkg/R/context.R b/R/pkg/R/context.R
index fe2f3e3d10..438d77a388 100644
--- a/R/pkg/R/context.R
+++ b/R/pkg/R/context.R
@@ -87,6 +87,10 @@ objectFile <- function(sc, path, minPartitions = NULL) {
#' in the list are split into \code{numSlices} slices and distributed to nodes
#' in the cluster.
#'
+#' If the total size of the serialized slices exceeds spark.r.maxAllocationLimit (200MB by
+#' default), the function writes them to disk and sends the file name to the JVM. The number of
+#' slices may also be increased so that no single slice is larger than that limit.
+#'
#' @param sc SparkContext to use
#' @param coll collection to parallelize
#' @param numSlices number of partitions to create in the RDD
@@ -120,6 +124,11 @@ parallelize <- function(sc, coll, numSlices = 1) {
coll <- as.list(coll)
}
+  sizeLimit <- getMaxAllocationLimit(sc)
+  objectSize <- object.size(coll)
+
+  # For large objects we make sure the size of each slice is also smaller than sizeLimit
+  numSlices <- max(numSlices, ceiling(objectSize / sizeLimit))
  if (numSlices > length(coll))
    numSlices <- length(coll)
@@ -130,12 +139,44 @@ parallelize <- function(sc, coll, numSlices = 1) {
# 2-tuples of raws
serializedSlices <- lapply(slices, serialize, connection = NULL)
-  jrdd <- callJStatic("org.apache.spark.api.r.RRDD",
-                      "createRDDFromArray", sc, serializedSlices)
+  # The RPC backend cannot handle arguments larger than 2GB (INT_MAX).
+  # If the serialized data is safely below that threshold, we send it over the RPC channel;
+  # otherwise, we write it to a file and send the file name.
+  if (objectSize < sizeLimit) {
+    jrdd <- callJStatic("org.apache.spark.api.r.RRDD", "createRDDFromArray", sc, serializedSlices)
+  } else {
+    fileName <- writeToTempFile(serializedSlices)
+    jrdd <- tryCatch(callJStatic(
+        "org.apache.spark.api.r.RRDD", "createRDDFromFile", sc, fileName, as.integer(numSlices)),
+      finally = {
+        file.remove(fileName)
+    })
+  }
RDD(jrdd, "byte")
}
+getMaxAllocationLimit <- function(sc) {
+  conf <- callJMethod(sc, "getConf")
+  as.numeric(
+    callJMethod(conf,
+      "get",
+      "spark.r.maxAllocationLimit",
+      toString(.Machine$integer.max / 10) # Default to a safe value: 200MB
+    ))
+}
+
+writeToTempFile <- function(serializedSlices) {
+  fileName <- tempfile()
+  conn <- file(fileName, "wb")
+  for (slice in serializedSlices) {
+    writeBin(as.integer(length(slice)), conn, endian = "big")
+    writeBin(slice, conn, endian = "big")
+  }
+  close(conn)
+  fileName
+}
+
#' Include this specified package on all workers
#'
#' This function can be used to include a package on all workers before the
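The temporary file that `writeToTempFile` produces above is just a sequence of length-prefixed records: for each serialized slice, a 4-byte big-endian length followed by the raw bytes, which matches what `PythonRDD.readRDDFromFile` parses on the JVM side (an int length, then that many bytes, repeated until EOF). `parallelize` also raises `numSlices` to `ceiling(object.size(coll) / sizeLimit)` so that no single slice exceeds the limit. The following standalone sketch (illustrative only, plain base R, no Spark required) round-trips that record layout:

```R
# Illustrative only: write two serialized slices in the length-prefixed layout
# used by writeToTempFile(), then read them back.
slices <- lapply(list(1:3, letters[1:5]), serialize, connection = NULL)

fileName <- tempfile()
conn <- file(fileName, "wb")
for (slice in slices) {
  writeBin(as.integer(length(slice)), conn, endian = "big")  # 4-byte length prefix
  writeBin(slice, conn, endian = "big")                      # raw serialized slice
}
close(conn)

# Read the records back, mirroring what the JVM side does with the file.
conn <- file(fileName, "rb")
recovered <- list()
repeat {
  len <- readBin(conn, "integer", n = 1, endian = "big")
  if (length(len) == 0) break
  recovered[[length(recovered) + 1]] <- unserialize(readBin(conn, "raw", n = len))
}
close(conn)
file.remove(fileName)

stopifnot(identical(recovered[[2]], letters[1:5]))
```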
diff --git a/R/pkg/inst/tests/testthat/test_sparkSQL.R b/R/pkg/inst/tests/testthat/test_sparkSQL.R
index 61554248ee..af81d0586e 100644
--- a/R/pkg/inst/tests/testthat/test_sparkSQL.R
+++ b/R/pkg/inst/tests/testthat/test_sparkSQL.R
@@ -208,6 +208,17 @@ test_that("create DataFrame from RDD", {
unsetHiveContext()
})
+test_that("createDataFrame uses files for large objects", {
+ # To simulate a large file scenario, we set spark.r.maxAllocationLimit to a smaller value
+ conf <- callJMethod(sparkSession, "conf")
+ callJMethod(conf, "set", "spark.r.maxAllocationLimit", "100")
+ df <- createDataFrame(iris)
+
+ # Resetting the conf back to default value
+ callJMethod(conf, "set", "spark.r.maxAllocationLimit", toString(.Machine$integer.max / 10))
+ expect_equal(dim(df), dim(iris))
+})
+
test_that("read/write csv as DataFrame", {
csvPath <- tempfile(pattern = "sparkr-test", fileext = ".csv")
mockLinesCsv <- c("year,make,model,comment,blank",
diff --git a/core/src/main/scala/org/apache/spark/api/r/RBackendHandler.scala b/core/src/main/scala/org/apache/spark/api/r/RBackendHandler.scala
index 7d5348266b..1422ef888f 100644
--- a/core/src/main/scala/org/apache/spark/api/r/RBackendHandler.scala
+++ b/core/src/main/scala/org/apache/spark/api/r/RBackendHandler.scala
@@ -168,7 +168,7 @@ private[r] class RBackendHandler(server: RBackend)
}
} catch {
case e: Exception =>
- logError(s"$methodName on $objId failed")
+ logError(s"$methodName on $objId failed", e)
writeInt(dos, -1)
// Writing the error message of the cause for the exception. This will be returned
// to user in the R process.
diff --git a/core/src/main/scala/org/apache/spark/api/r/RRDD.scala b/core/src/main/scala/org/apache/spark/api/r/RRDD.scala
index 59c8429c80..a1a5eb8cf5 100644
--- a/core/src/main/scala/org/apache/spark/api/r/RRDD.scala
+++ b/core/src/main/scala/org/apache/spark/api/r/RRDD.scala
@@ -24,6 +24,7 @@ import scala.reflect.ClassTag
import org.apache.spark._
import org.apache.spark.api.java.{JavaPairRDD, JavaRDD, JavaSparkContext}
+import org.apache.spark.api.python.PythonRDD
import org.apache.spark.broadcast.Broadcast
import org.apache.spark.internal.Logging
import org.apache.spark.rdd.RDD
@@ -140,4 +141,16 @@ private[r] object RRDD {
def createRDDFromArray(jsc: JavaSparkContext, arr: Array[Array[Byte]]): JavaRDD[Array[Byte]] = {
JavaRDD.fromRDD(jsc.sc.parallelize(arr, arr.length))
}
+
+  /**
+   * Create an RRDD given a temporary file name. This is used to create an RRDD when parallelize
+   * is called on large R objects.
+   *
+   * @param fileName name of the temporary file on the driver machine
+   * @param parallelism number of slices to create
+   */
+  def createRDDFromFile(jsc: JavaSparkContext, fileName: String, parallelism: Int):
+  JavaRDD[Array[Byte]] = {
+    PythonRDD.readRDDFromFile(jsc, fileName, parallelism)
+  }
}
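On the JVM side, `createRDDFromFile` simply delegates to `PythonRDD.readRDDFromFile`, which parses those length-prefixed records into a `JavaRDD[Array[Byte]]` with the requested number of slices. For completeness, a hedged sketch of the low-level call the R side makes on the large-object branch; it assumes `sc` is the SparkContext handle from the commit-message example and that `fileName` was produced by `writeToTempFile`:

```R
# Illustrative only: what parallelize() does once the serialized data exceeds
# spark.r.maxAllocationLimit. `sc` and `fileName` are assumed to already exist.
jrdd <- SparkR:::callJStatic("org.apache.spark.api.r.RRDD", "createRDDFromFile",
                             sc, fileName, as.integer(2))
rdd <- SparkR:::RDD(jrdd, "byte")
file.remove(fileName)  # the caller cleans up the temporary file
```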