-rw-r--r--  .gitignore | 2
-rw-r--r--  .rat-excludes | 2
-rw-r--r--  R/.gitignore | 6
-rw-r--r--  R/DOCUMENTATION.md | 12
-rw-r--r--  R/README.md | 67
-rw-r--r--  R/WINDOWS.md | 13
-rwxr-xr-x  R/create-docs.sh | 46
-rw-r--r--  R/install-dev.bat | 27
-rwxr-xr-x  R/install-dev.sh | 36
-rw-r--r--  R/log4j.properties | 28
-rw-r--r--  R/pkg/DESCRIPTION | 35
-rw-r--r--  R/pkg/NAMESPACE | 182
-rw-r--r--  R/pkg/R/DataFrame.R | 1270
-rw-r--r--  R/pkg/R/RDD.R | 1539
-rw-r--r--  R/pkg/R/SQLContext.R | 520
-rw-r--r--  R/pkg/R/SQLTypes.R | 64
-rw-r--r--  R/pkg/R/backend.R | 115
-rw-r--r--  R/pkg/R/broadcast.R | 86
-rw-r--r--  R/pkg/R/client.R | 57
-rw-r--r--  R/pkg/R/column.R | 199
-rw-r--r--  R/pkg/R/context.R | 225
-rw-r--r--  R/pkg/R/deserialize.R | 184
-rw-r--r--  R/pkg/R/generics.R | 543
-rw-r--r--  R/pkg/R/group.R | 132
-rw-r--r--  R/pkg/R/jobj.R | 101
-rw-r--r--  R/pkg/R/pairRDD.R | 789
-rw-r--r--  R/pkg/R/serialize.R | 195
-rw-r--r--  R/pkg/R/sparkR.R | 266
-rw-r--r--  R/pkg/R/utils.R | 467
-rw-r--r--  R/pkg/R/zzz.R | 21
-rw-r--r--  R/pkg/inst/profile/general.R | 22
-rw-r--r--  R/pkg/inst/profile/shell.R | 31
-rw-r--r--  R/pkg/inst/tests/test_binaryFile.R | 90
-rw-r--r--  R/pkg/inst/tests/test_binary_function.R | 68
-rw-r--r--  R/pkg/inst/tests/test_broadcast.R | 48
-rw-r--r--  R/pkg/inst/tests/test_context.R | 50
-rw-r--r--  R/pkg/inst/tests/test_includePackage.R | 57
-rw-r--r--  R/pkg/inst/tests/test_parallelize_collect.R | 109
-rw-r--r--  R/pkg/inst/tests/test_rdd.R | 644
-rw-r--r--  R/pkg/inst/tests/test_shuffle.R | 209
-rw-r--r--  R/pkg/inst/tests/test_sparkSQL.R | 695
-rw-r--r--  R/pkg/inst/tests/test_take.R | 67
-rw-r--r--  R/pkg/inst/tests/test_textFile.R | 162
-rw-r--r--  R/pkg/inst/tests/test_utils.R | 137
-rw-r--r--  R/pkg/inst/worker/daemon.R | 52
-rw-r--r--  R/pkg/inst/worker/worker.R | 128
-rw-r--r--  R/pkg/src/Makefile | 27
-rw-r--r--  R/pkg/src/Makefile.win | 27
-rw-r--r--  R/pkg/src/string_hash_code.c | 49
-rw-r--r--  R/pkg/tests/run-all.R | 21
-rwxr-xr-x  R/run-tests.sh | 39
-rwxr-xr-x  bin/sparkR | 39
-rw-r--r--  bin/sparkR.cmd | 23
-rw-r--r--  bin/sparkR2.cmd | 26
-rw-r--r--  core/pom.xml | 51
-rw-r--r--  core/src/main/scala/org/apache/spark/api/r/RBackend.scala | 145
-rw-r--r--  core/src/main/scala/org/apache/spark/api/r/RBackendHandler.scala | 223
-rw-r--r--  core/src/main/scala/org/apache/spark/api/r/RRDD.scala | 450
-rw-r--r--  core/src/main/scala/org/apache/spark/api/r/SerDe.scala | 340
-rw-r--r--  core/src/main/scala/org/apache/spark/deploy/RRunner.scala | 92
-rw-r--r--  core/src/main/scala/org/apache/spark/deploy/SparkSubmit.scala | 73
-rw-r--r--  core/src/main/scala/org/apache/spark/deploy/SparkSubmitArguments.scala | 8
-rwxr-xr-x  dev/run-tests | 15
-rw-r--r--  dev/run-tests-codes.sh | 1
-rwxr-xr-x  dev/run-tests-jenkins | 2
-rw-r--r--  docs/README.md | 12
-rwxr-xr-x  docs/_layouts/global.html | 1
-rw-r--r--  docs/_plugins/copy_api_dirs.rb | 15
-rw-r--r--  examples/src/main/r/kmeans.R | 93
-rw-r--r--  examples/src/main/r/linear_solver_mnist.R | 107
-rw-r--r--  examples/src/main/r/logistic_regression.R | 62
-rw-r--r--  examples/src/main/r/pi.R | 46
-rw-r--r--  examples/src/main/r/wordcount.R | 42
-rw-r--r--  launcher/src/main/java/org/apache/spark/launcher/CommandBuilderUtils.java | 8
-rw-r--r--  launcher/src/main/java/org/apache/spark/launcher/SparkSubmitCommandBuilder.java | 87
-rw-r--r--  launcher/src/test/java/org/apache/spark/launcher/CommandBuilderUtilsSuite.java | 6
-rw-r--r--  pom.xml | 3
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/GroupedData.scala | 2
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/api/r/SQLUtils.scala | 127
-rw-r--r--  yarn/src/main/scala/org/apache/spark/deploy/yarn/ApplicationMaster.scala | 3
-rw-r--r--  yarn/src/main/scala/org/apache/spark/deploy/yarn/ApplicationMasterArguments.scala | 11
-rw-r--r--  yarn/src/main/scala/org/apache/spark/deploy/yarn/Client.scala | 13
-rw-r--r--  yarn/src/main/scala/org/apache/spark/deploy/yarn/ClientArguments.scala | 11
83 files changed, 12043 insertions(+), 55 deletions(-)
diff --git a/.gitignore b/.gitignore
index d162fa9cca..d54d21b802 100644
--- a/.gitignore
+++ b/.gitignore
@@ -63,6 +63,8 @@ ec2/lib/
rat-results.txt
scalastyle.txt
scalastyle-output.xml
+R-unit-tests.log
+R/unit-tests.out
# For Hive
metastore_db/
diff --git a/.rat-excludes b/.rat-excludes
index 8c61e67a0c..8aca5a7f7a 100644
--- a/.rat-excludes
+++ b/.rat-excludes
@@ -67,3 +67,5 @@ logs
.*scalastyle-output.xml
.*dependency-reduced-pom.xml
known_translations
+DESCRIPTION
+NAMESPACE
diff --git a/R/.gitignore b/R/.gitignore
new file mode 100644
index 0000000000..9a5889ba28
--- /dev/null
+++ b/R/.gitignore
@@ -0,0 +1,6 @@
+*.o
+*.so
+*.Rd
+lib
+pkg/man
+pkg/html
diff --git a/R/DOCUMENTATION.md b/R/DOCUMENTATION.md
new file mode 100644
index 0000000000..931d01549b
--- /dev/null
+++ b/R/DOCUMENTATION.md
@@ -0,0 +1,12 @@
+# SparkR Documentation
+
+SparkR documentation is generated from in-source comments annotated with
+`roxygen2`. After making changes to the documentation, you can regenerate the man
+pages by running the following from an R console in the SparkR home directory:
+
+ library(devtools)
+ devtools::document(pkg="./pkg", roclets=c("rd"))
+
+You can verify that your changes pass the package checks by running
+
+ R CMD check pkg/
diff --git a/R/README.md b/R/README.md
new file mode 100644
index 0000000000..a6970e39b5
--- /dev/null
+++ b/R/README.md
@@ -0,0 +1,67 @@
+# R on Spark
+
+SparkR is an R package that provides a light-weight frontend to use Spark from R.
+
+### SparkR development
+
+#### Build Spark
+
+Build Spark with [Maven](http://spark.apache.org/docs/latest/building-spark.html#building-with-buildmvn) and include the `-Psparkr` profile to build the R package. For example, to use the default Hadoop versions you can run
+```
+ build/mvn -DskipTests -Psparkr package
+```
+
+#### Running sparkR
+
+You can start using SparkR by launching the SparkR shell with
+
+ ./bin/sparkR
+
+The `sparkR` script automatically creates a SparkContext with Spark in local mode
+by default. To specify the Spark master of a cluster for the automatically created
+SparkContext, you can run
+
+ ./bin/sparkR --master "local[2]"
+
+To set other options, such as driver memory or executor memory, you can pass the [spark-submit](http://spark.apache.org/docs/latest/submitting-applications.html) arguments to `./bin/sparkR`.
+
+#### Using SparkR from RStudio
+
+If you wish to use SparkR from RStudio or other R frontends, you will need to set some environment variables that point SparkR to your Spark installation. For example:
+```
+# Set this to where Spark is installed
+Sys.setenv(SPARK_HOME="/Users/shivaram/spark")
+# This line loads SparkR from the installed directory
+.libPaths(c(file.path(Sys.getenv("SPARK_HOME"), "R", "lib"), .libPaths()))
+library(SparkR)
+sc <- sparkR.init(master="local")
+```
+
+#### Making changes to SparkR
+
+The [instructions](https://cwiki.apache.org/confluence/display/SPARK/Contributing+to+Spark) for making contributions to Spark also apply to SparkR.
+If you only make R file changes (i.e., no Scala changes), you can simply re-install the R package using `R/install-dev.sh` and test your changes.
+Once you have made your changes, please include unit tests for them and run existing unit tests using the `run-tests.sh` script as described below.
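+
+As a minimal sketch of that loop (run from the Spark home directory, assuming Spark has already been built):
+```
+# Re-install the SparkR package after editing files under R/pkg
+./R/install-dev.sh
+# Run the SparkR unit tests
+./R/run-tests.sh
+```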
+
+#### Generating documentation
+
+The SparkR documentation (Rd files and HTML files) is not part of the source repository. To generate it, you can run the script `R/create-docs.sh`. This script uses `devtools` and `knitr` to generate the docs, and these packages need to be installed on the machine before using the script.
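+
+For instance, a minimal sketch (assuming the `devtools` and `knitr` packages are available from CRAN):
+```
+R -e 'install.packages(c("devtools", "knitr"), repos="http://cran.us.r-project.org")'
+./R/create-docs.sh
+```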
+
+### Examples, Unit tests
+
+SparkR comes with several sample programs in the `examples/src/main/r` directory.
+To run one of them, use `./bin/sparkR <filename> <args>`. For example:
+
+ ./bin/sparkR examples/src/main/r/pi.R local[2]
+
+You can also run the unit tests for SparkR by running the following (you need to install the [testthat](http://cran.r-project.org/web/packages/testthat/index.html) package first):
+
+ R -e 'install.packages("testthat", repos="http://cran.us.r-project.org")'
+ ./R/run-tests.sh
+
+### Running on YARN
+Both `./bin/spark-submit` and `./bin/sparkR` can be used to submit jobs to YARN clusters. You will need to set the YARN configuration directory (`YARN_CONF_DIR`) before doing so. For example, on CDH you can run
+```
+export YARN_CONF_DIR=/etc/hadoop/conf
+./bin/spark-submit --master yarn examples/src/main/r/pi.R 4
+```
diff --git a/R/WINDOWS.md b/R/WINDOWS.md
new file mode 100644
index 0000000000..3f889c0ca3
--- /dev/null
+++ b/R/WINDOWS.md
@@ -0,0 +1,13 @@
+## Building SparkR on Windows
+
+To build SparkR on Windows, the following steps are required
+
+1. Install R (>= 3.1) and [Rtools](http://cran.r-project.org/bin/windows/Rtools/). Make sure to
+include Rtools and R in `PATH`.
+2. Install
+[JDK7](http://www.oracle.com/technetwork/java/javase/downloads/jdk7-downloads-1880260.html) and set
+`JAVA_HOME` in the system environment variables.
+3. Download and install [Maven](http://maven.apache.org/download.html). Also include the `bin`
+directory in Maven in `PATH`.
+4. Set `MAVEN_OPTS` as described in [Building Spark](http://spark.apache.org/docs/latest/building-spark.html).
+5. Open a command shell (`cmd`) in the Spark directory and run `mvn -DskipTests -Psparkr package`, as sketched below.
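+
+For example, steps 4 and 5 might look like the following in `cmd` (the `MAVEN_OPTS` value here is only illustrative; use the settings recommended in the Building Spark guide):
+```
+set MAVEN_OPTS=-Xmx2g -XX:MaxPermSize=512M
+mvn -DskipTests -Psparkr package
+```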
diff --git a/R/create-docs.sh b/R/create-docs.sh
new file mode 100755
index 0000000000..4194172a2e
--- /dev/null
+++ b/R/create-docs.sh
@@ -0,0 +1,46 @@
+#!/bin/bash
+
+#
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements. See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+# Script to create API docs for SparkR
+# This requires `devtools` and `knitr` to be installed on the machine.
+
+# After running this script the html docs can be found in
+# $SPARK_HOME/R/pkg/html
+
+# Figure out where the script is
+export FWDIR="$(cd "`dirname "$0"`"; pwd)"
+pushd $FWDIR
+
+# Generate Rd file
+Rscript -e 'library(devtools); devtools::document(pkg="./pkg", roclets=c("rd"))'
+
+# Install the package
+./install-dev.sh
+
+# Now create HTML files
+
+# knit_rd puts html in current working directory
+mkdir -p pkg/html
+pushd pkg/html
+
+Rscript -e 'library(SparkR, lib.loc="../../lib"); library(knitr); knit_rd("SparkR")'
+
+popd
+
+popd
diff --git a/R/install-dev.bat b/R/install-dev.bat
new file mode 100644
index 0000000000..008a5c668b
--- /dev/null
+++ b/R/install-dev.bat
@@ -0,0 +1,27 @@
+@echo off
+
+rem
+rem Licensed to the Apache Software Foundation (ASF) under one or more
+rem contributor license agreements. See the NOTICE file distributed with
+rem this work for additional information regarding copyright ownership.
+rem The ASF licenses this file to You under the Apache License, Version 2.0
+rem (the "License"); you may not use this file except in compliance with
+rem the License. You may obtain a copy of the License at
+rem
+rem http://www.apache.org/licenses/LICENSE-2.0
+rem
+rem Unless required by applicable law or agreed to in writing, software
+rem distributed under the License is distributed on an "AS IS" BASIS,
+rem WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+rem See the License for the specific language governing permissions and
+rem limitations under the License.
+rem
+
+rem Install development version of SparkR
+rem
+
+set SPARK_HOME=%~dp0..
+
+MKDIR %SPARK_HOME%\R\lib
+
+R.exe CMD INSTALL --library="%SPARK_HOME%\R\lib" %SPARK_HOME%\R\pkg\
diff --git a/R/install-dev.sh b/R/install-dev.sh
new file mode 100755
index 0000000000..55ed6f4be1
--- /dev/null
+++ b/R/install-dev.sh
@@ -0,0 +1,36 @@
+#!/bin/bash
+
+#
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements. See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+# This script packages the SparkR source files (R and C files) and
+# creates a package that can be loaded in R. The package is by default installed to
+# $FWDIR/lib and the package can be loaded by using the following command in R:
+#
+# library(SparkR, lib.loc="$FWDIR/lib")
+#
+# NOTE(shivaram): Right now we use $SPARK_HOME/R/lib as the installation directory
+# so that the SparkR package can be loaded on the worker nodes.
+
+
+FWDIR="$(cd `dirname $0`; pwd)"
+LIB_DIR="$FWDIR/lib"
+
+mkdir -p $LIB_DIR
+
+# Install the SparkR package
+R CMD INSTALL --library=$LIB_DIR $FWDIR/pkg/
diff --git a/R/log4j.properties b/R/log4j.properties
new file mode 100644
index 0000000000..701adb2a3d
--- /dev/null
+++ b/R/log4j.properties
@@ -0,0 +1,28 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements. See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+# Set everything to be logged to the file R-unit-tests.log
+log4j.rootCategory=INFO, file
+log4j.appender.file=org.apache.log4j.FileAppender
+log4j.appender.file.append=true
+log4j.appender.file.file=R-unit-tests.log
+log4j.appender.file.layout=org.apache.log4j.PatternLayout
+log4j.appender.file.layout.ConversionPattern=%d{yy/MM/dd HH:mm:ss.SSS} %t %p %c{1}: %m%n
+
+# Ignore messages below warning level from Jetty, because it's a bit verbose
+log4j.logger.org.eclipse.jetty=WARN
+org.eclipse.jetty.LEVEL=WARN
diff --git a/R/pkg/DESCRIPTION b/R/pkg/DESCRIPTION
new file mode 100644
index 0000000000..1842b97d43
--- /dev/null
+++ b/R/pkg/DESCRIPTION
@@ -0,0 +1,35 @@
+Package: SparkR
+Type: Package
+Title: R frontend for Spark
+Version: 1.4.0
+Date: 2013-09-09
+Author: The Apache Software Foundation
+Maintainer: Shivaram Venkataraman <shivaram@cs.berkeley.edu>
+Imports:
+ methods
+Depends:
+ R (>= 3.0),
+ methods
+Suggests:
+ testthat
+Description: R frontend for Spark
+License: Apache License (== 2.0)
+Collate:
+ 'generics.R'
+ 'jobj.R'
+ 'SQLTypes.R'
+ 'RDD.R'
+ 'pairRDD.R'
+ 'column.R'
+ 'group.R'
+ 'DataFrame.R'
+ 'SQLContext.R'
+ 'broadcast.R'
+ 'context.R'
+ 'deserialize.R'
+ 'serialize.R'
+ 'sparkR.R'
+ 'backend.R'
+ 'client.R'
+ 'utils.R'
+ 'zzz.R'
diff --git a/R/pkg/NAMESPACE b/R/pkg/NAMESPACE
new file mode 100644
index 0000000000..a354cdce74
--- /dev/null
+++ b/R/pkg/NAMESPACE
@@ -0,0 +1,182 @@
+#exportPattern("^[[:alpha:]]+")
+exportClasses("RDD")
+exportClasses("Broadcast")
+exportMethods(
+ "aggregateByKey",
+ "aggregateRDD",
+ "cache",
+ "checkpoint",
+ "coalesce",
+ "cogroup",
+ "collect",
+ "collectAsMap",
+ "collectPartition",
+ "combineByKey",
+ "count",
+ "countByKey",
+ "countByValue",
+ "distinct",
+ "Filter",
+ "filterRDD",
+ "first",
+ "flatMap",
+ "flatMapValues",
+ "fold",
+ "foldByKey",
+ "foreach",
+ "foreachPartition",
+ "fullOuterJoin",
+ "glom",
+ "groupByKey",
+ "join",
+ "keyBy",
+ "keys",
+ "length",
+ "lapply",
+ "lapplyPartition",
+ "lapplyPartitionsWithIndex",
+ "leftOuterJoin",
+ "lookup",
+ "map",
+ "mapPartitions",
+ "mapPartitionsWithIndex",
+ "mapValues",
+ "maximum",
+ "minimum",
+ "numPartitions",
+ "partitionBy",
+ "persist",
+ "pipeRDD",
+ "reduce",
+ "reduceByKey",
+ "reduceByKeyLocally",
+ "repartition",
+ "rightOuterJoin",
+ "sampleRDD",
+ "saveAsTextFile",
+ "saveAsObjectFile",
+ "sortBy",
+ "sortByKey",
+ "sumRDD",
+ "take",
+ "takeOrdered",
+ "takeSample",
+ "top",
+ "unionRDD",
+ "unpersist",
+ "value",
+ "values",
+ "zipRDD",
+ "zipWithIndex",
+ "zipWithUniqueId"
+ )
+
+# S3 methods exported
+export(
+ "textFile",
+ "objectFile",
+ "parallelize",
+ "hashCode",
+ "includePackage",
+ "broadcast",
+ "setBroadcastValue",
+ "setCheckpointDir"
+ )
+export("sparkR.init")
+export("sparkR.stop")
+export("print.jobj")
+useDynLib(SparkR, stringHashCode)
+importFrom(methods, setGeneric, setMethod, setOldClass)
+
+# SparkRSQL
+
+exportClasses("DataFrame")
+
+exportMethods("columns",
+ "distinct",
+ "dtypes",
+ "explain",
+ "filter",
+ "groupBy",
+ "head",
+ "insertInto",
+ "intersect",
+ "isLocal",
+ "limit",
+ "orderBy",
+ "names",
+ "printSchema",
+ "registerTempTable",
+ "repartition",
+ "sampleDF",
+ "saveAsParquetFile",
+ "saveAsTable",
+ "saveDF",
+ "schema",
+ "select",
+ "selectExpr",
+ "show",
+ "showDF",
+ "sortDF",
+ "subtract",
+ "toJSON",
+ "toRDD",
+ "unionAll",
+ "where",
+ "withColumn",
+ "withColumnRenamed")
+
+exportClasses("Column")
+
+exportMethods("abs",
+ "alias",
+ "approxCountDistinct",
+ "asc",
+ "avg",
+ "cast",
+ "contains",
+ "countDistinct",
+ "desc",
+ "endsWith",
+ "getField",
+ "getItem",
+ "isNotNull",
+ "isNull",
+ "last",
+ "like",
+ "lower",
+ "max",
+ "mean",
+ "min",
+ "rlike",
+ "sqrt",
+ "startsWith",
+ "substr",
+ "sum",
+ "sumDistinct",
+ "upper")
+
+exportClasses("GroupedData")
+exportMethods("agg")
+
+export("sparkRSQL.init",
+ "sparkRHive.init")
+
+export("cacheTable",
+ "clearCache",
+ "createDataFrame",
+ "createExternalTable",
+ "dropTempTable",
+ "jsonFile",
+ "jsonRDD",
+ "loadDF",
+ "parquetFile",
+ "sql",
+ "table",
+ "tableNames",
+ "tables",
+ "toDF",
+ "uncacheTable")
+
+export("print.structType",
+ "print.structField")
diff --git a/R/pkg/R/DataFrame.R b/R/pkg/R/DataFrame.R
new file mode 100644
index 0000000000..feafd56909
--- /dev/null
+++ b/R/pkg/R/DataFrame.R
@@ -0,0 +1,1270 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements. See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+# DataFrame.R - DataFrame class and methods implemented in S4 OO classes
+
+#' @include jobj.R SQLTypes.R RDD.R pairRDD.R column.R group.R
+NULL
+
+setOldClass("jobj")
+
+#' @title S4 class that represents a DataFrame
+#' @description DataFrames can be created using functions like
+#' \code{jsonFile}, \code{table} etc.
+#' @rdname DataFrame
+#' @seealso jsonFile, table
+#'
+#' @param env An R environment that stores bookkeeping states of the DataFrame
+#' @param sdf A Java object reference to the backing Scala DataFrame
+#' @export
+setClass("DataFrame",
+ slots = list(env = "environment",
+ sdf = "jobj"))
+
+setMethod("initialize", "DataFrame", function(.Object, sdf, isCached) {
+ .Object@env <- new.env()
+ .Object@env$isCached <- isCached
+
+ .Object@sdf <- sdf
+ .Object
+})
+
+#' @rdname DataFrame
+#' @export
+dataFrame <- function(sdf, isCached = FALSE) {
+ new("DataFrame", sdf, isCached)
+}
+
+############################ DataFrame Methods ##############################################
+
+#' Print Schema of a DataFrame
+#'
+#' Prints out the schema in tree format
+#'
+#' @param x A SparkSQL DataFrame
+#'
+#' @rdname printSchema
+#' @export
+#' @examples
+#'\dontrun{
+#' sc <- sparkR.init()
+#' sqlCtx <- sparkRSQL.init(sc)
+#' path <- "path/to/file.json"
+#' df <- jsonFile(sqlCtx, path)
+#' printSchema(df)
+#'}
+setMethod("printSchema",
+ signature(x = "DataFrame"),
+ function(x) {
+ schemaString <- callJMethod(schema(x)$jobj, "treeString")
+ cat(schemaString)
+ })
+
+#' Get schema object
+#'
+#' Returns the schema of this DataFrame as a structType object.
+#'
+#' @param x A SparkSQL DataFrame
+#'
+#' @rdname schema
+#' @export
+#' @examples
+#'\dontrun{
+#' sc <- sparkR.init()
+#' sqlCtx <- sparkRSQL.init(sc)
+#' path <- "path/to/file.json"
+#' df <- jsonFile(sqlCtx, path)
+#' dfSchema <- schema(df)
+#'}
+setMethod("schema",
+ signature(x = "DataFrame"),
+ function(x) {
+ structType(callJMethod(x@sdf, "schema"))
+ })
+
+#' Explain
+#'
+#' Print the logical and physical Catalyst plans to the console for debugging.
+#'
+#' @param x A SparkSQL DataFrame
+#' @param extended Logical. If extended is FALSE, explain() only prints the physical plan.
+#' @rdname explain
+#' @export
+#' @examples
+#'\dontrun{
+#' sc <- sparkR.init()
+#' sqlCtx <- sparkRSQL.init(sc)
+#' path <- "path/to/file.json"
+#' df <- jsonFile(sqlCtx, path)
+#' explain(df, TRUE)
+#'}
+setMethod("explain",
+ signature(x = "DataFrame"),
+ function(x, extended = FALSE) {
+ queryExec <- callJMethod(x@sdf, "queryExecution")
+ if (extended) {
+ cat(callJMethod(queryExec, "toString"))
+ } else {
+ execPlan <- callJMethod(queryExec, "executedPlan")
+ cat(callJMethod(execPlan, "toString"))
+ }
+ })
+
+#' isLocal
+#'
+#' Returns TRUE if the `collect` and `take` methods can be run locally
+#' (without any Spark executors).
+#'
+#' @param x A SparkSQL DataFrame
+#'
+#' @rdname isLocal
+#' @export
+#' @examples
+#'\dontrun{
+#' sc <- sparkR.init()
+#' sqlCtx <- sparkRSQL.init(sc)
+#' path <- "path/to/file.json"
+#' df <- jsonFile(sqlCtx, path)
+#' isLocal(df)
+#'}
+setMethod("isLocal",
+ signature(x = "DataFrame"),
+ function(x) {
+ callJMethod(x@sdf, "isLocal")
+ })
+
+#' ShowDF
+#'
+#' Print the first numRows rows of a DataFrame
+#'
+#' @param x A SparkSQL DataFrame
+#' @param numRows The number of rows to print. Defaults to 20.
+#'
+#' @rdname showDF
+#' @export
+#' @examples
+#'\dontrun{
+#' sc <- sparkR.init()
+#' sqlCtx <- sparkRSQL.init(sc)
+#' path <- "path/to/file.json"
+#' df <- jsonFile(sqlCtx, path)
+#' showDF(df)
+#'}
+setMethod("showDF",
+ signature(x = "DataFrame"),
+ function(x, numRows = 20) {
+ cat(callJMethod(x@sdf, "showString", numToInt(numRows)), "\n")
+ })
+
+#' show
+#'
+#' Print the DataFrame column names and types
+#'
+#' @param x A SparkSQL DataFrame
+#'
+#' @rdname show
+#' @export
+#' @examples
+#'\dontrun{
+#' sc <- sparkR.init()
+#' sqlCtx <- sparkRSQL.init(sc)
+#' path <- "path/to/file.json"
+#' df <- jsonFile(sqlCtx, path)
+#' show(df)
+#'}
+setMethod("show", "DataFrame",
+ function(object) {
+ cols <- lapply(dtypes(object), function(l) {
+ paste(l, collapse = ":")
+ })
+ s <- paste(cols, collapse = ", ")
+ cat(paste("DataFrame[", s, "]\n", sep = ""))
+ })
+
+#' DataTypes
+#'
+#' Return all column names and their data types as a list
+#'
+#' @param x A SparkSQL DataFrame
+#'
+#' @rdname dtypes
+#' @export
+#' @examples
+#'\dontrun{
+#' sc <- sparkR.init()
+#' sqlCtx <- sparkRSQL.init(sc)
+#' path <- "path/to/file.json"
+#' df <- jsonFile(sqlCtx, path)
+#' dtypes(df)
+#'}
+setMethod("dtypes",
+ signature(x = "DataFrame"),
+ function(x) {
+ lapply(schema(x)$fields(), function(f) {
+ c(f$name(), f$dataType.simpleString())
+ })
+ })
+
+#' Column names
+#'
+#' Return all column names as a list
+#'
+#' @param x A SparkSQL DataFrame
+#'
+#' @rdname columns
+#' @export
+#' @examples
+#'\dontrun{
+#' sc <- sparkR.init()
+#' sqlCtx <- sparkRSQL.init(sc)
+#' path <- "path/to/file.json"
+#' df <- jsonFile(sqlCtx, path)
+#' columns(df)
+#'}
+setMethod("columns",
+ signature(x = "DataFrame"),
+ function(x) {
+ sapply(schema(x)$fields(), function(f) {
+ f$name()
+ })
+ })
+
+#' @rdname columns
+#' @export
+setMethod("names",
+ signature(x = "DataFrame"),
+ function(x) {
+ columns(x)
+ })
+
+#' Register Temporary Table
+#'
+#' Registers a DataFrame as a Temporary Table in the SQLContext
+#'
+#' @param x A SparkSQL DataFrame
+#' @param tableName A character vector containing the name of the table
+#'
+#' @rdname registerTempTable
+#' @export
+#' @examples
+#'\dontrun{
+#' sc <- sparkR.init()
+#' sqlCtx <- sparkRSQL.init(sc)
+#' path <- "path/to/file.json"
+#' df <- jsonFile(sqlCtx, path)
+#' registerTempTable(df, "json_df")
+#' new_df <- sql(sqlCtx, "SELECT * FROM json_df")
+#'}
+setMethod("registerTempTable",
+ signature(x = "DataFrame", tableName = "character"),
+ function(x, tableName) {
+ callJMethod(x@sdf, "registerTempTable", tableName)
+ })
+
+#' insertInto
+#'
+#' Insert the contents of a DataFrame into a table registered in the current SQL Context.
+#'
+#' @param x A SparkSQL DataFrame
+#' @param tableName A character vector containing the name of the table
+#' @param overwrite A logical argument indicating whether or not to overwrite
+#' the existing rows in the table.
+#'
+#' @rdname insertInto
+#' @export
+#' @examples
+#'\dontrun{
+#' sc <- sparkR.init()
+#' sqlCtx <- sparkRSQL.init(sc)
+#' df <- loadDF(sqlCtx, path, "parquet")
+#' df2 <- loadDF(sqlCtx, path2, "parquet")
+#' registerTempTable(df, "table1")
+#' insertInto(df2, "table1", overwrite = TRUE)
+#'}
+setMethod("insertInto",
+ signature(x = "DataFrame", tableName = "character"),
+ function(x, tableName, overwrite = FALSE) {
+ callJMethod(x@sdf, "insertInto", tableName, overwrite)
+ })
+
+#' Cache
+#'
+#' Persist with the default storage level (MEMORY_ONLY).
+#'
+#' @param x A SparkSQL DataFrame
+#'
+#' @rdname cache-methods
+#' @export
+#' @examples
+#'\dontrun{
+#' sc <- sparkR.init()
+#' sqlCtx <- sparkRSQL.init(sc)
+#' path <- "path/to/file.json"
+#' df <- jsonFile(sqlCtx, path)
+#' cache(df)
+#'}
+setMethod("cache",
+ signature(x = "DataFrame"),
+ function(x) {
+ cached <- callJMethod(x@sdf, "cache")
+ x@env$isCached <- TRUE
+ x
+ })
+
+#' Persist
+#'
+#' Persist this DataFrame with the specified storage level. For details of the
+#' supported storage levels, refer to
+#' http://spark.apache.org/docs/latest/programming-guide.html#rdd-persistence.
+#'
+#' @param x The DataFrame to persist
+#' @rdname persist
+#' @export
+#' @examples
+#'\dontrun{
+#' sc <- sparkR.init()
+#' sqlCtx <- sparkRSQL.init(sc)
+#' path <- "path/to/file.json"
+#' df <- jsonFile(sqlCtx, path)
+#' persist(df, "MEMORY_AND_DISK")
+#'}
+setMethod("persist",
+ signature(x = "DataFrame", newLevel = "character"),
+ function(x, newLevel) {
+ callJMethod(x@sdf, "persist", getStorageLevel(newLevel))
+ x@env$isCached <- TRUE
+ x
+ })
+
+#' Unpersist
+#'
+#' Mark this DataFrame as non-persistent, and remove all blocks for it from memory and
+#' disk.
+#'
+#' @param x The DataFrame to unpersist
+#' @param blocking Whether to block until all blocks are deleted
+#' @rdname unpersist-methods
+#' @export
+#' @examples
+#'\dontrun{
+#' sc <- sparkR.init()
+#' sqlCtx <- sparkRSQL.init(sc)
+#' path <- "path/to/file.json"
+#' df <- jsonFile(sqlCtx, path)
+#' persist(df, "MEMORY_AND_DISK")
+#' unpersist(df)
+#'}
+setMethod("unpersist",
+ signature(x = "DataFrame"),
+ function(x, blocking = TRUE) {
+ callJMethod(x@sdf, "unpersist", blocking)
+ x@env$isCached <- FALSE
+ x
+ })
+
+#' Repartition
+#'
+#' Return a new DataFrame that has exactly numPartitions partitions.
+#'
+#' @param x A SparkSQL DataFrame
+#' @param numPartitions The number of partitions to use.
+#' @rdname repartition
+#' @export
+#' @examples
+#'\dontrun{
+#' sc <- sparkR.init()
+#' sqlCtx <- sparkRSQL.init(sc)
+#' path <- "path/to/file.json"
+#' df <- jsonFile(sqlCtx, path)
+#' newDF <- repartition(df, 2L)
+#'}
+setMethod("repartition",
+ signature(x = "DataFrame", numPartitions = "numeric"),
+ function(x, numPartitions) {
+ sdf <- callJMethod(x@sdf, "repartition", numToInt(numPartitions))
+ dataFrame(sdf)
+ })
+
+#' toJSON
+#'
+#' Convert the rows of a DataFrame into JSON objects and return an RDD where
+#' each element contains a JSON string.
+#'
+#' @param x A SparkSQL DataFrame
+#' @return A StringRRDD of JSON objects
+#' @rdname tojson
+#' @export
+#' @examples
+#'\dontrun{
+#' sc <- sparkR.init()
+#' sqlCtx <- sparkRSQL.init(sc)
+#' path <- "path/to/file.json"
+#' df <- jsonFile(sqlCtx, path)
+#' newRDD <- toJSON(df)
+#'}
+setMethod("toJSON",
+ signature(x = "DataFrame"),
+ function(x) {
+ rdd <- callJMethod(x@sdf, "toJSON")
+ jrdd <- callJMethod(rdd, "toJavaRDD")
+ RDD(jrdd, serializedMode = "string")
+ })
+
+#' saveAsParquetFile
+#'
+#' Save the contents of a DataFrame as a Parquet file, preserving the schema. Files written out
+#' with this method can be read back in as a DataFrame using parquetFile().
+#'
+#' @param x A SparkSQL DataFrame
+#' @param path The directory where the file is saved
+#' @rdname saveAsParquetFile
+#' @export
+#' @examples
+#'\dontrun{
+#' sc <- sparkR.init()
+#' sqlCtx <- sparkRSQL.init(sc)
+#' path <- "path/to/file.json"
+#' df <- jsonFile(sqlCtx, path)
+#' saveAsParquetFile(df, "/tmp/sparkr-tmp/")
+#'}
+setMethod("saveAsParquetFile",
+ signature(x = "DataFrame", path = "character"),
+ function(x, path) {
+ invisible(callJMethod(x@sdf, "saveAsParquetFile", path))
+ })
+
+#' Distinct
+#'
+#' Return a new DataFrame containing the distinct rows in this DataFrame.
+#'
+#' @param x A SparkSQL DataFrame
+#' @rdname distinct
+#' @export
+#' @examples
+#'\dontrun{
+#' sc <- sparkR.init()
+#' sqlCtx <- sparkRSQL.init(sc)
+#' path <- "path/to/file.json"
+#' df <- jsonFile(sqlCtx, path)
+#' distinctDF <- distinct(df)
+#'}
+setMethod("distinct",
+ signature(x = "DataFrame"),
+ function(x) {
+ sdf <- callJMethod(x@sdf, "distinct")
+ dataFrame(sdf)
+ })
+
+#' SampleDF
+#'
+#' Return a sampled subset of this DataFrame using a random seed.
+#'
+#' @param x A SparkSQL DataFrame
+#' @param withReplacement Sampling with replacement or not
+#' @param fraction The (rough) sample target fraction
+#' @rdname sampleDF
+#' @export
+#' @examples
+#'\dontrun{
+#' sc <- sparkR.init()
+#' sqlCtx <- sparkRSQL.init(sc)
+#' path <- "path/to/file.json"
+#' df <- jsonFile(sqlCtx, path)
+#' collect(sampleDF(df, FALSE, 0.5))
+#' collect(sampleDF(df, TRUE, 0.5))
+#'}
+setMethod("sampleDF",
+ # TODO : Figure out how to send integer as java.lang.Long to JVM so
+ # we can send seed as an argument through callJMethod
+ signature(x = "DataFrame", withReplacement = "logical",
+ fraction = "numeric"),
+ function(x, withReplacement, fraction) {
+ if (fraction < 0.0) stop(paste("Negative fraction value:", fraction))
+ sdf <- callJMethod(x@sdf, "sample", withReplacement, fraction)
+ dataFrame(sdf)
+ })
+
+#' Count
+#'
+#' Returns the number of rows in a DataFrame
+#'
+#' @param x A SparkSQL DataFrame
+#'
+#' @rdname count
+#' @export
+#' @examples
+#'\dontrun{
+#' sc <- sparkR.init()
+#' sqlCtx <- sparkRSQL.init(sc)
+#' path <- "path/to/file.json"
+#' df <- jsonFile(sqlCtx, path)
+#' count(df)
+#' }
+setMethod("count",
+ signature(x = "DataFrame"),
+ function(x) {
+ callJMethod(x@sdf, "count")
+ })
+
+#' Collects all the elements of a Spark DataFrame and coerces them into an R data.frame.
+#'
+#' @param x A SparkSQL DataFrame
+#' @param stringsAsFactors (Optional) A logical indicating whether or not string columns
+#' should be converted to factors. FALSE by default.
+#'
+#' @rdname collect-methods
+#' @export
+#' @examples
+#'\dontrun{
+#' sc <- sparkR.init()
+#' sqlCtx <- sparkRSQL.init(sc)
+#' path <- "path/to/file.json"
+#' df <- jsonFile(sqlCtx, path)
+#' collected <- collect(df)
+#' firstName <- collected[[1]]$name
+#' }
+setMethod("collect",
+ signature(x = "DataFrame"),
+ function(x, stringsAsFactors = FALSE) {
+ # listCols is a list of raw vectors, one per column
+ listCols <- callJStatic("org.apache.spark.sql.api.r.SQLUtils", "dfToCols", x@sdf)
+ cols <- lapply(listCols, function(col) {
+ objRaw <- rawConnection(col)
+ numRows <- readInt(objRaw)
+ col <- readCol(objRaw, numRows)
+ close(objRaw)
+ col
+ })
+ names(cols) <- columns(x)
+ do.call(cbind.data.frame, list(cols, stringsAsFactors = stringsAsFactors))
+ })
+
+#' Limit
+#'
+#' Limit the resulting DataFrame to the number of rows specified.
+#'
+#' @param x A SparkSQL DataFrame
+#' @param num The number of rows to return
+#' @return A new DataFrame containing the number of rows specified.
+#'
+#' @rdname limit
+#' @export
+#' @examples
+#' \dontrun{
+#' sc <- sparkR.init()
+#' sqlCtx <- sparkRSQL.init(sc)
+#' path <- "path/to/file.json"
+#' df <- jsonFile(sqlCtx, path)
+#' limitedDF <- limit(df, 10)
+#' }
+setMethod("limit",
+ signature(x = "DataFrame", num = "numeric"),
+ function(x, num) {
+ res <- callJMethod(x@sdf, "limit", as.integer(num))
+ dataFrame(res)
+ })
+
+# Take the first NUM rows of a DataFrame and return the results as a data.frame
+
+#' @rdname take
+#' @export
+#' @examples
+#'\dontrun{
+#' sc <- sparkR.init()
+#' sqlCtx <- sparkRSQL.init(sc)
+#' path <- "path/to/file.json"
+#' df <- jsonFile(sqlCtx, path)
+#' take(df, 2)
+#' }
+setMethod("take",
+ signature(x = "DataFrame", num = "numeric"),
+ function(x, num) {
+ limited <- limit(x, num)
+ collect(limited)
+ })
+
+#' Head
+#'
+#' Return the first NUM rows of a DataFrame as a data.frame. If NUM is NULL,
+#' then head() returns the first 6 rows in keeping with the current data.frame
+#' convention in R.
+#'
+#' @param x A SparkSQL DataFrame
+#' @param num The number of rows to return. Default is 6.
+#' @return A data.frame
+#'
+#' @rdname head
+#' @export
+#' @examples
+#'\dontrun{
+#' sc <- sparkR.init()
+#' sqlCtx <- sparkRSQL.init(sc)
+#' path <- "path/to/file.json"
+#' df <- jsonFile(sqlCtx, path)
+#' head(df)
+#' }
+setMethod("head",
+ signature(x = "DataFrame"),
+ function(x, num = 6L) {
+ # Default num is 6L in keeping with R's data.frame convention
+ take(x, num)
+ })
+
+#' Return the first row of a DataFrame
+#'
+#' @param x A SparkSQL DataFrame
+#'
+#' @rdname first
+#' @export
+#' @examples
+#'\dontrun{
+#' sc <- sparkR.init()
+#' sqlCtx <- sparkRSQL.init(sc)
+#' path <- "path/to/file.json"
+#' df <- jsonFile(sqlCtx, path)
+#' first(df)
+#' }
+setMethod("first",
+ signature(x = "DataFrame"),
+ function(x) {
+ take(x, 1)
+ })
+
+#' toRDD()
+#'
+#' Converts a Spark DataFrame to an RDD while preserving column names.
+#'
+#' @param x A Spark DataFrame
+#'
+#' @rdname DataFrame
+#' @export
+#' @examples
+#'\dontrun{
+#' sc <- sparkR.init()
+#' sqlCtx <- sparkRSQL.init(sc)
+#' path <- "path/to/file.json"
+#' df <- jsonFile(sqlCtx, path)
+#' rdd <- toRDD(df)
+#' }
+setMethod("toRDD",
+ signature(x = "DataFrame"),
+ function(x) {
+ jrdd <- callJStatic("org.apache.spark.sql.api.r.SQLUtils", "dfToRowRDD", x@sdf)
+ colNames <- callJMethod(x@sdf, "columns")
+ rdd <- RDD(jrdd, serializedMode = "row")
+ lapply(rdd, function(row) {
+ names(row) <- colNames
+ row
+ })
+ })
+
+#' GroupBy
+#'
+#' Groups the DataFrame using the specified columns, so we can run aggregation on them.
+#'
+#' @param x a DataFrame
+#' @return a GroupedData
+#' @seealso GroupedData
+#' @rdname DataFrame
+#' @export
+#' @examples
+#' \dontrun{
+#' # Compute the average for all numeric columns grouped by department.
+#' avg(groupBy(df, "department"))
+#'
+#' # Compute the max age and average salary, grouped by department and gender.
+#' agg(groupBy(df, "department", "gender"), salary="avg", "age" -> "max")
+#' }
+setMethod("groupBy",
+ signature(x = "DataFrame"),
+ function(x, ...) {
+ cols <- list(...)
+ if (length(cols) >= 1 && class(cols[[1]]) == "character") {
+ sgd <- callJMethod(x@sdf, "groupBy", cols[[1]], listToSeq(cols[-1]))
+ } else {
+ jcol <- lapply(cols, function(c) { c@jc })
+ sgd <- callJMethod(x@sdf, "groupBy", listToSeq(jcol))
+ }
+ groupedData(sgd)
+ })
+
+#' Agg
+#'
+#' Compute aggregates by specifying a list of columns
+#'
+#' @rdname DataFrame
+#' @export
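+#' @examples
+#'\dontrun{
+#' # A sketch only: assumes df is an existing DataFrame with an "age" column.
+#' # This delegates to agg(groupBy(df), ...) and computes the maximum age over all rows.
+#' agg(df, age = "max")
+#'}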
+setMethod("agg",
+ signature(x = "DataFrame"),
+ function(x, ...) {
+ agg(groupBy(x), ...)
+ })
+
+
+############################## RDD Map Functions ##################################
+# All of the following functions mirror the existing RDD map functions, #
+# but allow for use with DataFrames by first converting to an RRDD before calling #
+# the requested map function. #
+###################################################################################
+
+#' @rdname lapply
+setMethod("lapply",
+ signature(X = "DataFrame", FUN = "function"),
+ function(X, FUN) {
+ rdd <- toRDD(X)
+ lapply(rdd, FUN)
+ })
+
+#' @rdname lapply
+setMethod("map",
+ signature(X = "DataFrame", FUN = "function"),
+ function(X, FUN) {
+ lapply(X, FUN)
+ })
+
+#' @rdname flatMap
+setMethod("flatMap",
+ signature(X = "DataFrame", FUN = "function"),
+ function(X, FUN) {
+ rdd <- toRDD(X)
+ flatMap(rdd, FUN)
+ })
+
+#' @rdname lapplyPartition
+setMethod("lapplyPartition",
+ signature(X = "DataFrame", FUN = "function"),
+ function(X, FUN) {
+ rdd <- toRDD(X)
+ lapplyPartition(rdd, FUN)
+ })
+
+#' @rdname lapplyPartition
+setMethod("mapPartitions",
+ signature(X = "DataFrame", FUN = "function"),
+ function(X, FUN) {
+ lapplyPartition(X, FUN)
+ })
+
+#' @rdname foreach
+setMethod("foreach",
+ signature(x = "DataFrame", func = "function"),
+ function(x, func) {
+ rdd <- toRDD(x)
+ foreach(rdd, func)
+ })
+
+#' @rdname foreach
+setMethod("foreachPartition",
+ signature(x = "DataFrame", func = "function"),
+ function(x, func) {
+ rdd <- toRDD(x)
+ foreachPartition(rdd, func)
+ })
+
+
+############################## SELECT ##################################
+
+getColumn <- function(x, c) {
+ column(callJMethod(x@sdf, "col", c))
+}
+
+#' @rdname select
+setMethod("$", signature(x = "DataFrame"),
+ function(x, name) {
+ getColumn(x, name)
+ })
+
+setMethod("$<-", signature(x = "DataFrame"),
+ function(x, name, value) {
+ stopifnot(class(value) == "Column")
+ cols <- columns(x)
+ if (name %in% cols) {
+ cols <- lapply(cols, function(c) {
+ if (c == name) {
+ alias(value, name)
+ } else {
+ col(c)
+ }
+ })
+ nx <- select(x, cols)
+ } else {
+ nx <- withColumn(x, name, value)
+ }
+ x@sdf <- nx@sdf
+ x
+ })
+
+#' @rdname select
+setMethod("[[", signature(x = "DataFrame"),
+ function(x, i) {
+ if (is.numeric(i)) {
+ cols <- columns(x)
+ i <- cols[[i]]
+ }
+ getColumn(x, i)
+ })
+
+#' @rdname select
+setMethod("[", signature(x = "DataFrame", i = "missing"),
+ function(x, i, j, ...) {
+ if (is.numeric(j)) {
+ cols <- columns(x)
+ j <- cols[j]
+ }
+ if (length(j) > 1) {
+ j <- as.list(j)
+ }
+ select(x, j)
+ })
+
+#' Select
+#'
+#' Selects a set of columns with names or Column expressions.
+#' @param x A DataFrame
+#' @param col A list of columns or single Column or name
+#' @return A new DataFrame with selected columns
+#' @export
+#' @rdname select
+#' @examples
+#' \dontrun{
+#' select(df, "*")
+#' select(df, "col1", "col2")
+#' select(df, df$name, df$age + 1)
+#' select(df, c("col1", "col2"))
+#' select(df, list(df$name, df$age + 1))
+#' # Columns can also be selected using `[[` and `[`
+#' df[[2]] == df[["age"]]
+#' df[,2] == df[,"age"]
+#' # Similar to R data frames columns can also be selected using `$`
+#' df$age
+#' }
+setMethod("select", signature(x = "DataFrame", col = "character"),
+ function(x, col, ...) {
+ sdf <- callJMethod(x@sdf, "select", col, toSeq(...))
+ dataFrame(sdf)
+ })
+
+#' @rdname select
+#' @export
+setMethod("select", signature(x = "DataFrame", col = "Column"),
+ function(x, col, ...) {
+ jcols <- lapply(list(col, ...), function(c) {
+ c@jc
+ })
+ sdf <- callJMethod(x@sdf, "select", listToSeq(jcols))
+ dataFrame(sdf)
+ })
+
+#' @rdname select
+#' @export
+setMethod("select",
+ signature(x = "DataFrame", col = "list"),
+ function(x, col) {
+ cols <- lapply(col, function(c) {
+ if (class(c) == "Column") {
+ c@jc
+ } else {
+ col(c)@jc
+ }
+ })
+ sdf <- callJMethod(x@sdf, "select", listToSeq(cols))
+ dataFrame(sdf)
+ })
+
+#' SelectExpr
+#'
+#' Select from a DataFrame using a set of SQL expressions.
+#'
+#' @param x A DataFrame to be selected from.
+#' @param expr A string containing a SQL expression
+#' @param ... Additional expressions
+#' @return A DataFrame
+#' @rdname selectExpr
+#' @export
+#' @examples
+#'\dontrun{
+#' sc <- sparkR.init()
+#' sqlCtx <- sparkRSQL.init(sc)
+#' path <- "path/to/file.json"
+#' df <- jsonFile(sqlCtx, path)
+#' selectExpr(df, "col1", "(col2 * 5) as newCol")
+#' }
+setMethod("selectExpr",
+ signature(x = "DataFrame", expr = "character"),
+ function(x, expr, ...) {
+ exprList <- list(expr, ...)
+ sdf <- callJMethod(x@sdf, "selectExpr", listToSeq(exprList))
+ dataFrame(sdf)
+ })
+
+#' WithColumn
+#'
+#' Return a new DataFrame with the specified column added.
+#'
+#' @param x A DataFrame
+#' @param colName A string containing the name of the new column.
+#' @param col A Column expression.
+#' @return A DataFrame with the new column added.
+#' @rdname withColumn
+#' @export
+#' @examples
+#'\dontrun{
+#' sc <- sparkR.init()
+#' sqlCtx <- sparkRSQL.init(sc)
+#' path <- "path/to/file.json"
+#' df <- jsonFile(sqlCtx, path)
+#' newDF <- withColumn(df, "newCol", df$col1 * 5)
+#' }
+setMethod("withColumn",
+ signature(x = "DataFrame", colName = "character", col = "Column"),
+ function(x, colName, col) {
+ select(x, x$"*", alias(col, colName))
+ })
+
+#' WithColumnRenamed
+#'
+#' Rename an existing column in a DataFrame.
+#'
+#' @param x A DataFrame
+#' @param existingCol The name of the column you want to change.
+#' @param newCol The new column name.
+#' @return A DataFrame with the column name changed.
+#' @rdname withColumnRenamed
+#' @export
+#' @examples
+#'\dontrun{
+#' sc <- sparkR.init()
+#' sqlCtx <- sparkRSQL.init(sc)
+#' path <- "path/to/file.json"
+#' df <- jsonFile(sqlCtx, path)
+#' newDF <- withColumnRenamed(df, "col1", "newCol1")
+#' }
+setMethod("withColumnRenamed",
+ signature(x = "DataFrame", existingCol = "character", newCol = "character"),
+ function(x, existingCol, newCol) {
+ cols <- lapply(columns(x), function(c) {
+ if (c == existingCol) {
+ alias(col(c), newCol)
+ } else {
+ col(c)
+ }
+ })
+ select(x, cols)
+ })
+
+setClassUnion("characterOrColumn", c("character", "Column"))
+
+#' SortDF
+#'
+#' Sort a DataFrame by the specified column(s).
+#'
+#' @param x A DataFrame to be sorted.
+#' @param col Either a Column object or character vector indicating the field to sort on
+#' @param ... Additional sorting fields
+#' @return A DataFrame where all elements are sorted.
+#' @rdname sortDF
+#' @export
+#' @examples
+#'\dontrun{
+#' sc <- sparkR.init()
+#' sqlCtx <- sparkRSQL.init(sc)
+#' path <- "path/to/file.json"
+#' df <- jsonFile(sqlCtx, path)
+#' sortDF(df, df$col1)
+#' sortDF(df, "col1")
+#' sortDF(df, asc(df$col1), desc(abs(df$col2)))
+#' }
+setMethod("sortDF",
+ signature(x = "DataFrame", col = "characterOrColumn"),
+ function(x, col, ...) {
+ if (class(col) == "character") {
+ sdf <- callJMethod(x@sdf, "sort", col, toSeq(...))
+ } else if (class(col) == "Column") {
+ jcols <- lapply(list(col, ...), function(c) {
+ c@jc
+ })
+ sdf <- callJMethod(x@sdf, "sort", listToSeq(jcols))
+ }
+ dataFrame(sdf)
+ })
+
+#' @rdname sortDF
+#' @export
+setMethod("orderBy",
+ signature(x = "DataFrame", col = "characterOrColumn"),
+ function(x, col) {
+ sortDF(x, col)
+ })
+
+#' Filter
+#'
+#' Filter the rows of a DataFrame according to a given condition.
+#'
+#' @param x A DataFrame to be filtered.
+#' @param condition The condition to filter on. This may either be a Column expression
+#' or a string containing a SQL statement
+#' @return A DataFrame containing only the rows that meet the condition.
+#' @rdname filter
+#' @export
+#' @examples
+#'\dontrun{
+#' sc <- sparkR.init()
+#' sqlCtx <- sparkRSQL.init(sc)
+#' path <- "path/to/file.json"
+#' df <- jsonFile(sqlCtx, path)
+#' filter(df, "col1 > 0")
+#' filter(df, df$col2 != "abcdefg")
+#' }
+setMethod("filter",
+ signature(x = "DataFrame", condition = "characterOrColumn"),
+ function(x, condition) {
+ if (class(condition) == "Column") {
+ condition <- condition@jc
+ }
+ sdf <- callJMethod(x@sdf, "filter", condition)
+ dataFrame(sdf)
+ })
+
+#' @rdname filter
+#' @export
+setMethod("where",
+ signature(x = "DataFrame", condition = "characterOrColumn"),
+ function(x, condition) {
+ filter(x, condition)
+ })
+
+#' Join
+#'
+#' Join two DataFrames based on the given join expression.
+#'
+#' @param x A Spark DataFrame
+#' @param y A Spark DataFrame
+#' @param joinExpr (Optional) The expression used to perform the join. joinExpr must be a
+#' Column expression. If joinExpr is omitted, join() will perform a Cartesian join
+#' @param joinType The type of join to perform. The following join types are available:
+#' 'inner', 'outer', 'left_outer', 'right_outer', 'semijoin'. The default joinType is "inner".
+#' @return A DataFrame containing the result of the join operation.
+#' @rdname join
+#' @export
+#' @examples
+#'\dontrun{
+#' sc <- sparkR.init()
+#' sqlCtx <- sparkRSQL.init(sc)
+#' df1 <- jsonFile(sqlCtx, path)
+#' df2 <- jsonFile(sqlCtx, path2)
+#' join(df1, df2) # Performs a Cartesian join
+#' join(df1, df2, df1$col1 == df2$col2) # Performs an inner join based on expression
+#' join(df1, df2, df1$col1 == df2$col2, "right_outer")
+#' }
+setMethod("join",
+ signature(x = "DataFrame", y = "DataFrame"),
+ function(x, y, joinExpr = NULL, joinType = NULL) {
+ if (is.null(joinExpr)) {
+ sdf <- callJMethod(x@sdf, "join", y@sdf)
+ } else {
+ if (class(joinExpr) != "Column") stop("joinExpr must be a Column")
+ if (is.null(joinType)) {
+ sdf <- callJMethod(x@sdf, "join", y@sdf, joinExpr@jc)
+ } else {
+ if (joinType %in% c("inner", "outer", "left_outer", "right_outer", "semijoin")) {
+ sdf <- callJMethod(x@sdf, "join", y@sdf, joinExpr@jc, joinType)
+ } else {
+ stop("joinType must be one of the following types: ",
+ "'inner', 'outer', 'left_outer', 'right_outer', 'semijoin'")
+ }
+ }
+ }
+ dataFrame(sdf)
+ })
+
+#' UnionAll
+#'
+#' Return a new DataFrame containing the union of rows in this DataFrame
+#' and another DataFrame. This is equivalent to `UNION ALL` in SQL.
+#'
+#' @param x A Spark DataFrame
+#' @param y A Spark DataFrame
+#' @return A DataFrame containing the result of the union.
+#' @rdname unionAll
+#' @export
+#' @examples
+#'\dontrun{
+#' sc <- sparkR.init()
+#' sqlCtx <- sparkRSQL.init(sc)
+#' df1 <- jsonFile(sqlCtx, path)
+#' df2 <- jsonFile(sqlCtx, path2)
+#' unioned <- unionAll(df1, df2)
+#' }
+setMethod("unionAll",
+ signature(x = "DataFrame", y = "DataFrame"),
+ function(x, y) {
+ unioned <- callJMethod(x@sdf, "unionAll", y@sdf)
+ dataFrame(unioned)
+ })
+
+#' Intersect
+#'
+#' Return a new DataFrame containing rows only in both this DataFrame
+#' and another DataFrame. This is equivalent to `INTERSECT` in SQL.
+#'
+#' @param x A Spark DataFrame
+#' @param y A Spark DataFrame
+#' @return A DataFrame containing the result of the intersect.
+#' @rdname intersect
+#' @export
+#' @examples
+#'\dontrun{
+#' sc <- sparkR.init()
+#' sqlCtx <- sparkRSQL.init(sc)
+#' df1 <- jsonFile(sqlCtx, path)
+#' df2 <- jsonFile(sqlCtx, path2)
+#' intersectDF <- intersect(df1, df2)
+#' }
+setMethod("intersect",
+ signature(x = "DataFrame", y = "DataFrame"),
+ function(x, y) {
+ intersected <- callJMethod(x@sdf, "intersect", y@sdf)
+ dataFrame(intersected)
+ })
+
+#' Subtract
+#'
+#' Return a new DataFrame containing rows in this DataFrame
+#' but not in another DataFrame. This is equivalent to `EXCEPT` in SQL.
+#'
+#' @param x A Spark DataFrame
+#' @param y A Spark DataFrame
+#' @return A DataFrame containing the result of the subtract operation.
+#' @rdname subtract
+#' @export
+#' @examples
+#'\dontrun{
+#' sc <- sparkR.init()
+#' sqlCtx <- sparkRSQL.init(sc)
+#' df1 <- jsonFile(sqlCtx, path)
+#' df2 <- jsonFile(sqlCtx, path2)
+#' subtractDF <- subtract(df1, df2)
+#' }
+setMethod("subtract",
+ signature(x = "DataFrame", y = "DataFrame"),
+ function(x, y) {
+ subtracted <- callJMethod(x@sdf, "except", y@sdf)
+ dataFrame(subtracted)
+ })
+
+#' Save the contents of the DataFrame to a data source
+#'
+#' The data source is specified by the `source` and a set of options (...).
+#' If `source` is not specified, the default data source configured by
+#' spark.sql.sources.default will be used.
+#'
+#' Additionally, mode is used to specify the behavior of the save operation when
+#' data already exists in the data source. There are four modes:
+#' append: Contents of this DataFrame are expected to be appended to existing data.
+#' overwrite: Existing data is expected to be overwritten by the contents of
+#' this DataFrame.
+#' error: An exception is expected to be thrown.
+#' ignore: The save operation is expected to not save the contents of the DataFrame
+#' and to not change the existing data.
+#'
+#' @param df A SparkSQL DataFrame
+#' @param path The path where the contents of the DataFrame are saved
+#' @param source A name for external data source
+#' @param mode One of 'append', 'overwrite', 'error', 'ignore'
+#'
+#' @rdname saveDF
+#' @export
+#' @examples
+#'\dontrun{
+#' sc <- sparkR.init()
+#' sqlCtx <- sparkRSQL.init(sc)
+#' path <- "path/to/file.json"
+#' df <- jsonFile(sqlCtx, path)
+#' saveAsTable(df, "myfile")
+#' }
+setMethod("saveDF",
+ signature(df = "DataFrame", path = 'character', source = 'character',
+ mode = 'character'),
+ function(df, path = NULL, source = NULL, mode = "append", ...){
+ if (is.null(source)) {
+ sqlCtx <- get(".sparkRSQLsc", envir = .sparkREnv)
+ source <- callJMethod(sqlCtx, "getConf", "spark.sql.sources.default",
+ "org.apache.spark.sql.parquet")
+ }
+ allModes <- c("append", "overwrite", "error", "ignore")
+ if (!(mode %in% allModes)) {
+ stop('mode should be one of "append", "overwrite", "error", "ignore"')
+ }
+ jmode <- callJStatic("org.apache.spark.sql.api.r.SQLUtils", "saveMode", mode)
+ options <- varargsToEnv(...)
+ if (!is.null(path)) {
+ options[['path']] = path
+ }
+ callJMethod(df@sdf, "save", source, jmode, options)
+ })
+
+
+#' saveAsTable
+#'
+#' Save the contents of the DataFrame to a data source as a table
+#'
+#' The data source is specified by the `source` and a set of options (...).
+#' If `source` is not specified, the default data source configured by
+#' spark.sql.sources.default will be used.
+#'
+#' Additionally, mode is used to specify the behavior of the save operation when
+#' data already exists in the data source. There are four modes:
+#' append: Contents of this DataFrame are expected to be appended to existing data.
+#' overwrite: Existing data is expected to be overwritten by the contents of
+#' this DataFrame.
+#' error: An exception is expected to be thrown.
+#' ignore: The save operation is expected to not save the contents of the DataFrame
+#' and to not change the existing data.
+#'
+#' @param df A SparkSQL DataFrame
+#' @param tableName A name for the table
+#' @param source A name for external data source
+#' @param mode One of 'append', 'overwrite', 'error', 'ignore'
+#'
+#' @rdname saveAsTable
+#' @export
+#' @examples
+#'\dontrun{
+#' sc <- sparkR.init()
+#' sqlCtx <- sparkRSQL.init(sc)
+#' path <- "path/to/file.json"
+#' df <- jsonFile(sqlCtx, path)
+#' saveAsTable(df, "myfile")
+#' }
+setMethod("saveAsTable",
+ signature(df = "DataFrame", tableName = 'character', source = 'character',
+ mode = 'character'),
+ function(df, tableName, source = NULL, mode="append", ...){
+ if (is.null(source)) {
+ sqlCtx <- get(".sparkRSQLsc", envir = .sparkREnv)
+ source <- callJMethod(sqlCtx, "getConf", "spark.sql.sources.default",
+ "org.apache.spark.sql.parquet")
+ }
+ allModes <- c("append", "overwrite", "error", "ignore")
+ if (!(mode %in% allModes)) {
+ stop('mode should be one of "append", "overwrite", "error", "ignore"')
+ }
+ jmode <- callJStatic("org.apache.spark.sql.api.r.SQLUtils", "saveMode", mode)
+ options <- varargsToEnv(...)
+ callJMethod(df@sdf, "saveAsTable", tableName, source, jmode, options)
+ })
+
diff --git a/R/pkg/R/RDD.R b/R/pkg/R/RDD.R
new file mode 100644
index 0000000000..604ad03c40
--- /dev/null
+++ b/R/pkg/R/RDD.R
@@ -0,0 +1,1539 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements. See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+# RDD in R implemented in S4 OO system.
+
+setOldClass("jobj")
+
+#' @title S4 class that represents an RDD
+#' @description RDDs can be created using functions like
+#' \code{parallelize}, \code{textFile} etc.
+#' @rdname RDD
+#' @seealso parallelize, textFile
+#'
+#' @slot env An R environment that stores bookkeeping states of the RDD
+#' @slot jrdd Java object reference to the backing JavaRDD
+#' @export
+setClass("RDD",
+ slots = list(env = "environment",
+ jrdd = "jobj"))
+
+setClass("PipelinedRDD",
+ slots = list(prev = "RDD",
+ func = "function",
+ prev_jrdd = "jobj"),
+ contains = "RDD")
+
+setMethod("initialize", "RDD", function(.Object, jrdd, serializedMode,
+ isCached, isCheckpointed) {
+ # Check that the serializedMode passed to the RDD constructor is valid
+ stopifnot(class(serializedMode) == "character")
+ stopifnot(serializedMode %in% c("byte", "string", "row"))
+ # RDD has three serialization types:
+ # byte: The RDD stores data serialized in R.
+ # string: The RDD stores data as strings.
+ # row: The RDD stores the serialized rows of a DataFrame.
+
+ # We use an environment to store mutable states inside an RDD object.
+ # Note that R's call-by-value semantics makes modifying slots inside an
+ # object (passed as an argument into a function, such as cache()) difficult:
+ # i.e. one needs to make a copy of the RDD object and set the new slot value
+ # there.
+
+ # The slots are inherited from the superclass. Here, both `env' and `jrdd' are
+ # inherited from RDD, but only the former is used.
+ .Object@env <- new.env()
+ .Object@env$isCached <- isCached
+ .Object@env$isCheckpointed <- isCheckpointed
+ .Object@env$serializedMode <- serializedMode
+
+ .Object@jrdd <- jrdd
+ .Object
+})
+
+setMethod("initialize", "PipelinedRDD", function(.Object, prev, func, jrdd_val) {
+ .Object@env <- new.env()
+ .Object@env$isCached <- FALSE
+ .Object@env$isCheckpointed <- FALSE
+ .Object@env$jrdd_val <- jrdd_val
+ if (!is.null(jrdd_val)) {
+ # This tracks the serialization mode for jrdd_val
+ .Object@env$serializedMode <- prev@env$serializedMode
+ }
+
+ .Object@prev <- prev
+
+ isPipelinable <- function(rdd) {
+ e <- rdd@env
+ !(e$isCached || e$isCheckpointed)
+ }
+
+ if (!inherits(prev, "PipelinedRDD") || !isPipelinable(prev)) {
+ # This transformation is the first in its stage:
+ .Object@func <- func
+ .Object@prev_jrdd <- getJRDD(prev)
+ .Object@env$prev_serializedMode <- prev@env$serializedMode
+ # NOTE: We use prev_serializedMode to track the serialization mode of prev_JRDD
+ # prev_serializedMode is used during the delayed computation of JRDD in getJRDD
+ } else {
+ pipelinedFunc <- function(split, iterator) {
+ func(split, prev@func(split, iterator))
+ }
+ .Object@func <- pipelinedFunc
+ .Object@prev_jrdd <- prev@prev_jrdd # maintain the pipeline
+ # Get the serialization mode of the parent RDD
+ .Object@env$prev_serializedMode <- prev@env$prev_serializedMode
+ }
+
+ .Object
+})
+
+#' @rdname RDD
+#' @export
+#'
+#' @param jrdd Java object reference to the backing JavaRDD
+#' @param serializedMode Use "byte" if the RDD stores data serialized in R, "string" if the RDD
+#' stores strings, and "row" if the RDD stores the rows of a DataFrame
+#' @param isCached TRUE if the RDD is cached
+#' @param isCheckpointed TRUE if the RDD has been checkpointed
+RDD <- function(jrdd, serializedMode = "byte", isCached = FALSE,
+ isCheckpointed = FALSE) {
+ new("RDD", jrdd, serializedMode, isCached, isCheckpointed)
+}
+
+PipelinedRDD <- function(prev, func) {
+ new("PipelinedRDD", prev, func, NULL)
+}
+
+# Return the serialization mode for an RDD.
+setGeneric("getSerializedMode", function(rdd, ...) { standardGeneric("getSerializedMode") })
+# For normal RDDs we can directly read the serializedMode
+setMethod("getSerializedMode", signature(rdd = "RDD"), function(rdd) rdd@env$serializedMode )
+# For pipelined RDDs, if jrdd_val is set then serializedMode should exist.
+# If it is not set, we return the default serialization mode of "byte", since
+# we don't know the serialization mode at this point.
+setMethod("getSerializedMode", signature(rdd = "PipelinedRDD"),
+ function(rdd) {
+ if (!is.null(rdd@env$jrdd_val)) {
+ return(rdd@env$serializedMode)
+ } else {
+ return("byte")
+ }
+ })
+
+# The jrdd accessor function.
+setMethod("getJRDD", signature(rdd = "RDD"), function(rdd) rdd@jrdd )
+setMethod("getJRDD", signature(rdd = "PipelinedRDD"),
+ function(rdd, serializedMode = "byte") {
+ if (!is.null(rdd@env$jrdd_val)) {
+ return(rdd@env$jrdd_val)
+ }
+
+ computeFunc <- function(split, part) {
+ rdd@func(split, part)
+ }
+
+ packageNamesArr <- serialize(.sparkREnv[[".packages"]],
+ connection = NULL)
+
+ broadcastArr <- lapply(ls(.broadcastNames),
+ function(name) { get(name, .broadcastNames) })
+
+ serializedFuncArr <- serialize(computeFunc, connection = NULL)
+
+ prev_jrdd <- rdd@prev_jrdd
+
+ if (serializedMode == "string") {
+ rddRef <- newJObject("org.apache.spark.api.r.StringRRDD",
+ callJMethod(prev_jrdd, "rdd"),
+ serializedFuncArr,
+ rdd@env$prev_serializedMode,
+ packageNamesArr,
+ as.character(.sparkREnv[["libname"]]),
+ broadcastArr,
+ callJMethod(prev_jrdd, "classTag"))
+ } else {
+ rddRef <- newJObject("org.apache.spark.api.r.RRDD",
+ callJMethod(prev_jrdd, "rdd"),
+ serializedFuncArr,
+ rdd@env$prev_serializedMode,
+ serializedMode,
+ packageNamesArr,
+ as.character(.sparkREnv[["libname"]]),
+ broadcastArr,
+ callJMethod(prev_jrdd, "classTag"))
+ }
+          # Save the serialization flag after we create an RRDD
+ rdd@env$serializedMode <- serializedMode
+ rdd@env$jrdd_val <- callJMethod(rddRef, "asJavaRDD") # rddRef$asJavaRDD()
+ rdd@env$jrdd_val
+ })
+
+setValidity("RDD",
+ function(object) {
+ jrdd <- getJRDD(object)
+ cls <- callJMethod(jrdd, "getClass")
+ className <- callJMethod(cls, "getName")
+              if (grepl("spark.api.java.*RDD*", className)) {
+ TRUE
+ } else {
+ paste("Invalid RDD class ", className)
+ }
+ })
+
+
+############ Actions and Transformations ############
+
+#' Persist an RDD
+#'
+#' Persist this RDD with the default storage level (MEMORY_ONLY).
+#'
+#' @param x The RDD to cache
+#' @examples
+#'\dontrun{
+#' sc <- sparkR.init()
+#' rdd <- parallelize(sc, 1:10, 2L)
+#' cache(rdd)
+#'}
+#' @rdname cache-methods
+#' @aliases cache,RDD-method
+setMethod("cache",
+ signature(x = "RDD"),
+ function(x) {
+ callJMethod(getJRDD(x), "cache")
+ x@env$isCached <- TRUE
+ x
+ })
+
+#' Persist an RDD
+#'
+#' Persist this RDD with the specified storage level. For details of the
+#' supported storage levels, refer to
+#' http://spark.apache.org/docs/latest/programming-guide.html#rdd-persistence.
+#'
+#' @param x The RDD to persist
+#' @param newLevel The new storage level to be assigned
+#' @examples
+#'\dontrun{
+#' sc <- sparkR.init()
+#' rdd <- parallelize(sc, 1:10, 2L)
+#' persist(rdd, "MEMORY_AND_DISK")
+#'}
+#' @rdname persist
+#' @aliases persist,RDD-method
+setMethod("persist",
+ signature(x = "RDD", newLevel = "character"),
+ function(x, newLevel) {
+ callJMethod(getJRDD(x), "persist", getStorageLevel(newLevel))
+ x@env$isCached <- TRUE
+ x
+ })
+
+#' Unpersist an RDD
+#'
+#' Mark the RDD as non-persistent, and remove all blocks for it from memory and
+#' disk.
+#'
+#' @param x The RDD to unpersist
+#' @examples
+#'\dontrun{
+#' sc <- sparkR.init()
+#' rdd <- parallelize(sc, 1:10, 2L)
+#' cache(rdd) # rdd@@env$isCached == TRUE
+#' unpersist(rdd) # rdd@@env$isCached == FALSE
+#'}
+#' @rdname unpersist-methods
+#' @aliases unpersist,RDD-method
+setMethod("unpersist",
+ signature(x = "RDD"),
+ function(x) {
+ callJMethod(getJRDD(x), "unpersist")
+ x@env$isCached <- FALSE
+ x
+ })
+
+#' Checkpoint an RDD
+#'
+#' Mark this RDD for checkpointing. It will be saved to a file inside the
+#' checkpoint directory set with setCheckpointDir() and all references to its
+#' parent RDDs will be removed. This function must be called before any job has
+#' been executed on this RDD. It is strongly recommended that this RDD is
+#' persisted in memory, otherwise saving it on a file will require recomputation.
+#'
+#' @param x The RDD to checkpoint
+#' @examples
+#'\dontrun{
+#' sc <- sparkR.init()
+#' setCheckpointDir(sc, "checkpoints")
+#' rdd <- parallelize(sc, 1:10, 2L)
+#' checkpoint(rdd)
+#'}
+#' @rdname checkpoint-methods
+#' @aliases checkpoint,RDD-method
+setMethod("checkpoint",
+ signature(x = "RDD"),
+ function(x) {
+ jrdd <- getJRDD(x)
+ callJMethod(jrdd, "checkpoint")
+ x@env$isCheckpointed <- TRUE
+ x
+ })
+
+#' Gets the number of partitions of an RDD
+#'
+#' @param x An RDD.
+#' @return The number of partitions of the RDD as an integer.
+#' @examples
+#'\dontrun{
+#' sc <- sparkR.init()
+#' rdd <- parallelize(sc, 1:10, 2L)
+#' numPartitions(rdd) # 2L
+#'}
+#' @rdname numPartitions
+#' @aliases numPartitions,RDD-method
+setMethod("numPartitions",
+ signature(x = "RDD"),
+ function(x) {
+ jrdd <- getJRDD(x)
+ partitions <- callJMethod(jrdd, "splits")
+ callJMethod(partitions, "size")
+ })
+
+#' Collect elements of an RDD
+#'
+#' @description
+#' \code{collect} returns a list that contains all of the elements in this RDD.
+#'
+#' @param x The RDD to collect
+#' @param ... Other optional arguments to collect
+#' @param flatten FALSE if the list should not be flattened
+#' @return a list containing elements in the RDD
+#' @examples
+#'\dontrun{
+#' sc <- sparkR.init()
+#' rdd <- parallelize(sc, 1:10, 2L)
+#' collect(rdd) # list from 1 to 10
+#' collectPartition(rdd, 0L) # list from 1 to 5
+#'}
+#' @rdname collect-methods
+#' @aliases collect,RDD-method
+setMethod("collect",
+ signature(x = "RDD"),
+ function(x, flatten = TRUE) {
+ # Assumes a pairwise RDD is backed by a JavaPairRDD.
+ collected <- callJMethod(getJRDD(x), "collect")
+ convertJListToRList(collected, flatten,
+ serializedMode = getSerializedMode(x))
+ })
+
+
+#' @description
+#' \code{collectPartition} returns a list that contains all of the elements
+#' in the specified partition of the RDD.
+#' @param partitionId the partition to collect (starts from 0)
+#' @rdname collect-methods
+#' @aliases collectPartition,integer,RDD-method
+setMethod("collectPartition",
+ signature(x = "RDD", partitionId = "integer"),
+ function(x, partitionId) {
+ jPartitionsList <- callJMethod(getJRDD(x),
+ "collectPartitions",
+ as.list(as.integer(partitionId)))
+
+ jList <- jPartitionsList[[1]]
+ convertJListToRList(jList, flatten = TRUE,
+ serializedMode = getSerializedMode(x))
+ })
+
+#' @description
+#' \code{collectAsMap} returns a named list as a map that contains all of the elements
+#' in a key-value pair RDD.
+#' @examples
+#'\dontrun{
+#' sc <- sparkR.init()
+#' rdd <- parallelize(sc, list(list(1, 2), list(3, 4)), 2L)
+#' collectAsMap(rdd) # list(`1` = 2, `3` = 4)
+#'}
+#' @rdname collect-methods
+#' @aliases collectAsMap,RDD-method
+setMethod("collectAsMap",
+ signature(x = "RDD"),
+ function(x) {
+ pairList <- collect(x)
+ map <- new.env()
+ lapply(pairList, function(i) { assign(as.character(i[[1]]), i[[2]], envir = map) })
+ as.list(map)
+ })
+
+#' Return the number of elements in the RDD.
+#'
+#' @param x The RDD to count
+#' @return number of elements in the RDD.
+#' @examples
+#'\dontrun{
+#' sc <- sparkR.init()
+#' rdd <- parallelize(sc, 1:10)
+#' count(rdd) # 10
+#' length(rdd) # Same as count
+#'}
+#' @rdname count
+#' @aliases count,RDD-method
+setMethod("count",
+ signature(x = "RDD"),
+ function(x) {
+ countPartition <- function(part) {
+ as.integer(length(part))
+ }
+ valsRDD <- lapplyPartition(x, countPartition)
+ vals <- collect(valsRDD)
+ sum(as.integer(vals))
+ })
+
+#' Return the number of elements in the RDD
+#' @export
+#' @rdname count
+setMethod("length",
+ signature(x = "RDD"),
+ function(x) {
+ count(x)
+ })
+
+#' Return the count of each unique value in this RDD as a list of
+#' (value, count) pairs.
+#'
+#' Same as countByValue in Spark.
+#'
+#' @param x The RDD to count
+#' @return list of (value, count) pairs, where count is the number of times
+#'         each unique value occurs in the RDD.
+#' @examples
+#'\dontrun{
+#' sc <- sparkR.init()
+#' rdd <- parallelize(sc, c(1,2,3,2,1))
+#' countByValue(rdd) # (1,2L), (2,2L), (3,1L)
+#'}
+#' @rdname countByValue
+#' @aliases countByValue,RDD-method
+setMethod("countByValue",
+ signature(x = "RDD"),
+ function(x) {
+ ones <- lapply(x, function(item) { list(item, 1L) })
+ collect(reduceByKey(ones, `+`, numPartitions(x)))
+ })
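+
+# A local, base-R sketch (illustration only) of the pattern used above: map
+# every element to a (value, 1) pair, then sum the counts per unique value.
+if (FALSE) {
+  elems <- c(1, 2, 3, 2, 1)
+  ones <- lapply(elems, function(item) list(item, 1L))
+  tapply(sapply(ones, `[[`, 2), sapply(ones, `[[`, 1), sum)  # 1 -> 2, 2 -> 2, 3 -> 1
+}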
+
+#' Apply a function to all elements
+#'
+#' This function creates a new RDD by applying the given transformation to all
+#' elements of the given RDD.
+#'
+#' @param X The RDD to apply the transformation to.
+#' @param FUN the transformation to apply on each element
+#' @return a new RDD created by the transformation.
+#' @rdname lapply
+#' @aliases lapply
+#' @examples
+#'\dontrun{
+#' sc <- sparkR.init()
+#' rdd <- parallelize(sc, 1:10)
+#' multiplyByTwo <- lapply(rdd, function(x) { x * 2 })
+#' collect(multiplyByTwo) # 2,4,6...
+#'}
+setMethod("lapply",
+ signature(X = "RDD", FUN = "function"),
+ function(X, FUN) {
+ func <- function(split, iterator) {
+ lapply(iterator, FUN)
+ }
+ lapplyPartitionsWithIndex(X, func)
+ })
+
+#' @rdname lapply
+#' @aliases map,RDD,function-method
+setMethod("map",
+ signature(X = "RDD", FUN = "function"),
+ function(X, FUN) {
+ lapply(X, FUN)
+ })
+
+#' Flatten results after applying a function to all elements
+#'
+#' This function returns a new RDD by first applying a function to all
+#' elements of this RDD, and then flattening the results.
+#'
+#' @param X The RDD to apply the transformation to.
+#' @param FUN the transformation to apply on each element
+#' @return a new RDD created by the transformation.
+#' @examples
+#'\dontrun{
+#' sc <- sparkR.init()
+#' rdd <- parallelize(sc, 1:10)
+#' multiplyByTwo <- flatMap(rdd, function(x) { list(x*2, x*10) })
+#' collect(multiplyByTwo) # 2,20,4,40,6,60...
+#'}
+#' @rdname flatMap
+#' @aliases flatMap,RDD,function-method
+setMethod("flatMap",
+ signature(X = "RDD", FUN = "function"),
+ function(X, FUN) {
+ partitionFunc <- function(part) {
+ unlist(
+ lapply(part, FUN),
+            recursive = FALSE
+ )
+ }
+ lapplyPartition(X, partitionFunc)
+ })
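+
+# An illustration (not part of the API) of the one-level flattening that
+# unlist(recursive = FALSE) performs on the per-element results above.
+if (FALSE) {
+  part <- list(1, 2, 3)
+  nested <- lapply(part, function(x) list(x * 2, x * 10))  # a list of lists
+  unlist(nested, recursive = FALSE)  # list(2, 10, 4, 20, 6, 30)
+}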
+
+#' Apply a function to each partition of an RDD
+#'
+#' Return a new RDD by applying a function to each partition of this RDD.
+#'
+#' @param X The RDD to apply the transformation to.
+#' @param FUN the transformation to apply on each partition.
+#' @return a new RDD created by the transformation.
+#' @examples
+#'\dontrun{
+#' sc <- sparkR.init()
+#' rdd <- parallelize(sc, 1:10)
+#' partitionSum <- lapplyPartition(rdd, function(part) { Reduce("+", part) })
+#' collect(partitionSum) # 15, 40
+#'}
+#' @rdname lapplyPartition
+#' @aliases lapplyPartition,RDD,function-method
+setMethod("lapplyPartition",
+ signature(X = "RDD", FUN = "function"),
+ function(X, FUN) {
+ lapplyPartitionsWithIndex(X, function(s, part) { FUN(part) })
+ })
+
+#' mapPartitions is the same as lapplyPartition.
+#'
+#' @rdname lapplyPartition
+#' @aliases mapPartitions,RDD,function-method
+setMethod("mapPartitions",
+ signature(X = "RDD", FUN = "function"),
+ function(X, FUN) {
+ lapplyPartition(X, FUN)
+ })
+
+#' Return a new RDD by applying a function to each partition of this RDD, while
+#' tracking the index of the original partition.
+#'
+#' @param X The RDD to apply the transformation to.
+#' @param FUN the transformation to apply on each partition; takes the partition
+#' index and a list of elements in the particular partition.
+#' @return a new RDD created by the transformation.
+#' @examples
+#'\dontrun{
+#' sc <- sparkR.init()
+#' rdd <- parallelize(sc, 1:10, 5L)
+#' prod <- lapplyPartitionsWithIndex(rdd, function(split, part) {
+#' split * Reduce("+", part) })
+#' collect(prod, flatten = FALSE) # 0, 7, 22, 45, 76
+#'}
+#' @rdname lapplyPartitionsWithIndex
+#' @aliases lapplyPartitionsWithIndex,RDD,function-method
+setMethod("lapplyPartitionsWithIndex",
+ signature(X = "RDD", FUN = "function"),
+ function(X, FUN) {
+ FUN <- cleanClosure(FUN)
+ closureCapturingFunc <- function(split, part) {
+ FUN(split, part)
+ }
+ PipelinedRDD(X, closureCapturingFunc)
+ })
+
+#' @rdname lapplyPartitionsWithIndex
+#' @aliases mapPartitionsWithIndex,RDD,function-method
+setMethod("mapPartitionsWithIndex",
+ signature(X = "RDD", FUN = "function"),
+ function(X, FUN) {
+ lapplyPartitionsWithIndex(X, FUN)
+ })
+
+#' This function returns a new RDD containing only the elements that satisfy
+#' a predicate (i.e. returning TRUE in a given logical function).
+#' The same as `filter()' in Spark.
+#'
+#' @param x The RDD to be filtered.
+#' @param f A unary predicate function.
+#' @examples
+#'\dontrun{
+#' sc <- sparkR.init()
+#' rdd <- parallelize(sc, 1:10)
+#' unlist(collect(filterRDD(rdd, function (x) { x < 3 }))) # c(1, 2)
+#'}
+#' @rdname filterRDD
+#' @aliases filterRDD,RDD,function-method
+setMethod("filterRDD",
+ signature(x = "RDD", f = "function"),
+ function(x, f) {
+ filter.func <- function(part) {
+ Filter(f, part)
+ }
+ lapplyPartition(x, filter.func)
+ })
+
+#' @rdname filterRDD
+#' @aliases Filter
+setMethod("Filter",
+ signature(f = "function", x = "RDD"),
+ function(f, x) {
+ filterRDD(x, f)
+ })
+
+#' Reduce across elements of an RDD.
+#'
+#' This function reduces the elements of this RDD using the
+#' specified commutative and associative binary operator.
+#'
+#' @param x The RDD to reduce
+#' @param func Commutative and associative function to apply on elements
+#' of the RDD.
+#' @examples
+#'\dontrun{
+#' sc <- sparkR.init()
+#' rdd <- parallelize(sc, 1:10)
+#' reduce(rdd, "+") # 55
+#'}
+#' @rdname reduce
+#' @aliases reduce,RDD,ANY-method
+setMethod("reduce",
+ signature(x = "RDD", func = "ANY"),
+ function(x, func) {
+
+ reducePartition <- function(part) {
+ Reduce(func, part)
+ }
+
+ partitionList <- collect(lapplyPartition(x, reducePartition),
+ flatten = FALSE)
+ Reduce(func, partitionList)
+ })
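+
+# A base-R sketch (illustration only) of the two-level reduction above: reduce
+# within each partition first, then reduce the per-partition results.
+if (FALSE) {
+  parts <- list(1:5, 6:10)  # two "partitions"
+  partialResults <- lapply(parts, function(p) Reduce("+", p))  # 15, 40
+  Reduce("+", partialResults)  # 55
+}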
+
+#' Get the maximum element of an RDD.
+#'
+#' @param x The RDD to get the maximum element from
+#' @examples
+#'\dontrun{
+#' sc <- sparkR.init()
+#' rdd <- parallelize(sc, 1:10)
+#' maximum(rdd) # 10
+#'}
+#' @rdname maximum
+#' @aliases maximum,RDD
+setMethod("maximum",
+ signature(x = "RDD"),
+ function(x) {
+ reduce(x, max)
+ })
+
+#' Get the minimum element of an RDD.
+#'
+#' @param x The RDD to get the minimum element from
+#' @examples
+#'\dontrun{
+#' sc <- sparkR.init()
+#' rdd <- parallelize(sc, 1:10)
+#' minimum(rdd) # 1
+#'}
+#' @rdname minimum
+#' @aliases minimum,RDD
+setMethod("minimum",
+ signature(x = "RDD"),
+ function(x) {
+ reduce(x, min)
+ })
+
+#' Add up the elements in an RDD.
+#'
+#' @param x The RDD to add up the elements in
+#' @examples
+#'\dontrun{
+#' sc <- sparkR.init()
+#' rdd <- parallelize(sc, 1:10)
+#' sumRDD(rdd) # 55
+#'}
+#' @rdname sumRDD
+#' @aliases sumRDD,RDD
+setMethod("sumRDD",
+ signature(x = "RDD"),
+ function(x) {
+ reduce(x, "+")
+ })
+
+#' Applies a function to all elements in an RDD, and forces evaluation.
+#'
+#' @param x The RDD to apply the function
+#' @param func The function to be applied.
+#' @return invisible NULL.
+#' @examples
+#'\dontrun{
+#' sc <- sparkR.init()
+#' rdd <- parallelize(sc, 1:10)
+#' foreach(rdd, function(x) { save(x, file=...) })
+#'}
+#' @rdname foreach
+#' @aliases foreach,RDD,function-method
+setMethod("foreach",
+ signature(x = "RDD", func = "function"),
+ function(x, func) {
+ partition.func <- function(x) {
+ lapply(x, func)
+ NULL
+ }
+ invisible(collect(mapPartitions(x, partition.func)))
+ })
+
+#' Applies a function to each partition in an RDD, and forces evaluation.
+#'
+#' @examples
+#'\dontrun{
+#' sc <- sparkR.init()
+#' rdd <- parallelize(sc, 1:10)
+#' foreachPartition(rdd, function(part) { save(part, file=...); NULL })
+#'}
+#' @rdname foreach
+#' @aliases foreachPartition,RDD,function-method
+setMethod("foreachPartition",
+ signature(x = "RDD", func = "function"),
+ function(x, func) {
+ invisible(collect(mapPartitions(x, func)))
+ })
+
+#' Take elements from an RDD.
+#'
+#' This function takes the first NUM elements in the RDD and
+#' returns them in a list.
+#'
+#' @param x The RDD to take elements from
+#' @param num Number of elements to take
+#' @examples
+#'\dontrun{
+#' sc <- sparkR.init()
+#' rdd <- parallelize(sc, 1:10)
+#' take(rdd, 2L) # list(1, 2)
+#'}
+#' @rdname take
+#' @aliases take,RDD,numeric-method
+setMethod("take",
+ signature(x = "RDD", num = "numeric"),
+ function(x, num) {
+ resList <- list()
+ index <- -1
+ jrdd <- getJRDD(x)
+ numPartitions <- numPartitions(x)
+
+ # TODO(shivaram): Collect more than one partition based on size
+ # estimates similar to the scala version of `take`.
+ while (TRUE) {
+ index <- index + 1
+
+ if (length(resList) >= num || index >= numPartitions)
+ break
+
+ # a JList of byte arrays
+ partitionArr <- callJMethod(jrdd, "collectPartitions", as.list(as.integer(index)))
+ partition <- partitionArr[[1]]
+
+ size <- num - length(resList)
+ # elems is capped to have at most `size` elements
+ elems <- convertJListToRList(partition,
+ flatten = TRUE,
+ logicalUpperBound = size,
+ serializedMode = getSerializedMode(x))
+ # TODO: Check if this append is O(n^2)?
+ resList <- append(resList, elems)
+ }
+ resList
+ })
+
+#' First
+#'
+#' Return the first element of an RDD
+#'
+#' @rdname first
+#' @export
+#' @examples
+#'\dontrun{
+#' sc <- sparkR.init()
+#' rdd <- parallelize(sc, 1:10)
+#' first(rdd)
+#' }
+setMethod("first",
+ signature(x = "RDD"),
+ function(x) {
+ take(x, 1)[[1]]
+ })
+
+#' Removes duplicates from an RDD.
+#'
+#' This function returns a new RDD containing the distinct elements in the
+#' given RDD. The same as `distinct()' in Spark.
+#'
+#' @param x The RDD to remove duplicates from.
+#' @param numPartitions Number of partitions to create.
+#' @examples
+#'\dontrun{
+#' sc <- sparkR.init()
+#' rdd <- parallelize(sc, c(1,2,2,3,3,3))
+#' sort(unlist(collect(distinct(rdd)))) # c(1, 2, 3)
+#'}
+#' @rdname distinct
+#' @aliases distinct,RDD-method
+setMethod("distinct",
+ signature(x = "RDD"),
+ function(x, numPartitions = SparkR::numPartitions(x)) {
+ identical.mapped <- lapply(x, function(x) { list(x, NULL) })
+ reduced <- reduceByKey(identical.mapped,
+ function(x, y) { x },
+ numPartitions)
+ resRDD <- lapply(reduced, function(x) { x[[1]] })
+ resRDD
+ })
+
+#' Return an RDD that is a sampled subset of the given RDD.
+#'
+#' The same as `sample()' in Spark. (We rename it due to signature
+#' inconsistencies with the `sample()' function in R's base package.)
+#'
+#' @param x The RDD to sample elements from
+#' @param withReplacement Sampling with replacement or not
+#' @param fraction The (rough) sample target fraction
+#' @param seed Randomness seed value
+#' @examples
+#'\dontrun{
+#' sc <- sparkR.init()
+#' rdd <- parallelize(sc, 1:10) # ensure each num is in its own split
+#' collect(sampleRDD(rdd, FALSE, 0.5, 1618L)) # ~5 distinct elements
+#' collect(sampleRDD(rdd, TRUE, 0.5, 9L)) # ~5 elements possibly with duplicates
+#'}
+#' @rdname sampleRDD
+#' @aliases sampleRDD,RDD
+setMethod("sampleRDD",
+ signature(x = "RDD", withReplacement = "logical",
+ fraction = "numeric", seed = "integer"),
+ function(x, withReplacement, fraction, seed) {
+
+ # The sampler: takes a partition and returns its sampled version.
+ samplingFunc <- function(split, part) {
+ set.seed(seed)
+ res <- vector("list", length(part))
+ len <- 0
+
+ # Discards some random values to ensure each partition has a
+ # different random seed.
+ runif(split)
+
+ for (elem in part) {
+ if (withReplacement) {
+ count <- rpois(1, fraction)
+ if (count > 0) {
+ res[(len + 1):(len + count)] <- rep(list(elem), count)
+ len <- len + count
+ }
+ } else {
+ if (runif(1) < fraction) {
+ len <- len + 1
+ res[[len]] <- elem
+ }
+ }
+ }
+
+          # TODO(zongheng): look into the performance of the current
+          # implementation. Look into some iterator package? Note that the
+          # Scala implementation avoids creating many empty lists, and
+          # PySpark achieves something similar using `yield'.
+ if (len > 0)
+ res[1:len]
+ else
+ list()
+ }
+
+ lapplyPartitionsWithIndex(x, samplingFunc)
+ })
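+
+# A hypothetical, local-only sketch of the two sampling schemes used above:
+# with replacement, each element is replicated rpois(1, fraction) times (a
+# Poisson approximation); without replacement, each element is kept with
+# probability `fraction' (a Bernoulli trial).
+if (FALSE) {
+  set.seed(1618L)
+  fraction <- 0.5
+  part <- as.list(1:10)
+  withRepl <- unlist(lapply(part, function(e) rep(e, rpois(1, fraction))))
+  withoutRepl <- Filter(function(e) runif(1) < fraction, part)
+}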
+
+#' Return a list of the elements that are a sampled subset of the given RDD.
+#'
+#' @param x The RDD to sample elements from
+#' @param withReplacement Sampling with replacement or not
+#' @param num Number of elements to return
+#' @param seed Randomness seed value
+#' @examples
+#'\dontrun{
+#' sc <- sparkR.init()
+#' rdd <- parallelize(sc, 1:100)
+#' # exactly 5 elements sampled, which may not be distinct
+#' takeSample(rdd, TRUE, 5L, 1618L)
+#' # exactly 5 distinct elements sampled
+#' takeSample(rdd, FALSE, 5L, 16181618L)
+#'}
+#' @rdname takeSample
+#' @aliases takeSample,RDD
+setMethod("takeSample", signature(x = "RDD", withReplacement = "logical",
+ num = "integer", seed = "integer"),
+ function(x, withReplacement, num, seed) {
+ # This function is ported from RDD.scala.
+ fraction <- 0.0
+ total <- 0
+ multiplier <- 3.0
+ initialCount <- count(x)
+ maxSelected <- 0
+ MAXINT <- .Machine$integer.max
+
+ if (num < 0)
+ stop(paste("Negative number of elements requested"))
+
+ if (initialCount > MAXINT - 1) {
+ maxSelected <- MAXINT - 1
+ } else {
+ maxSelected <- initialCount
+ }
+
+ if (num > initialCount && !withReplacement) {
+ total <- maxSelected
+ fraction <- multiplier * (maxSelected + 1) / initialCount
+ } else {
+ total <- num
+ fraction <- multiplier * (num + 1) / initialCount
+ }
+
+ set.seed(seed)
+ samples <- collect(sampleRDD(x, withReplacement, fraction,
+ as.integer(ceiling(runif(1,
+ -MAXINT,
+ MAXINT)))))
+ # If the first sample didn't turn out large enough, keep trying to
+ # take samples; this shouldn't happen often because we use a big
+              # multiplier for the initial size
+ while (length(samples) < total)
+ samples <- collect(sampleRDD(x, withReplacement, fraction,
+ as.integer(ceiling(runif(1,
+ -MAXINT,
+ MAXINT)))))
+
+ # TODO(zongheng): investigate if this call is an in-place shuffle?
+ sample(samples)[1:total]
+ })
+
+#' Creates tuples of the elements in this RDD by applying a function.
+#'
+#' @param x The RDD.
+#' @param func The function to be applied.
+#' @examples
+#'\dontrun{
+#' sc <- sparkR.init()
+#' rdd <- parallelize(sc, list(1, 2, 3))
+#' collect(keyBy(rdd, function(x) { x*x })) # list(list(1, 1), list(4, 2), list(9, 3))
+#'}
+#' @rdname keyBy
+#' @aliases keyBy,RDD
+setMethod("keyBy",
+ signature(x = "RDD", func = "function"),
+ function(x, func) {
+ apply.func <- function(x) {
+ list(func(x), x)
+ }
+ lapply(x, apply.func)
+ })
+
+#' Return a new RDD that has exactly numPartitions partitions.
+#' Can increase or decrease the level of parallelism in this RDD. Internally,
+#' this uses a shuffle to redistribute data.
+#' If you are decreasing the number of partitions in this RDD, consider using
+#' coalesce, which can avoid performing a shuffle.
+#'
+#' @param x The RDD.
+#' @param numPartitions Number of partitions to create.
+#' @seealso coalesce
+#' @examples
+#'\dontrun{
+#' sc <- sparkR.init()
+#' rdd <- parallelize(sc, list(1, 2, 3, 4, 5, 6, 7), 4L)
+#' numPartitions(rdd) # 4
+#' numPartitions(repartition(rdd, 2L)) # 2
+#'}
+#' @rdname repartition
+#' @aliases repartition,RDD
+setMethod("repartition",
+ signature(x = "RDD", numPartitions = "numeric"),
+ function(x, numPartitions) {
+ coalesce(x, numToInt(numPartitions), TRUE)
+ })
+
+#' Return a new RDD that is reduced into numPartitions partitions.
+#'
+#' @param x The RDD.
+#' @param numPartitions Number of partitions to create.
+#' @seealso repartition
+#' @examples
+#'\dontrun{
+#' sc <- sparkR.init()
+#' rdd <- parallelize(sc, list(1, 2, 3, 4, 5), 3L)
+#' numPartitions(rdd) # 3
+#' numPartitions(coalesce(rdd, 1L)) # 1
+#'}
+#' @rdname coalesce
+#' @aliases coalesce,RDD
+setMethod("coalesce",
+ signature(x = "RDD", numPartitions = "numeric"),
+ function(x, numPartitions, shuffle = FALSE) {
+ numPartitions <- numToInt(numPartitions)
+ if (shuffle || numPartitions > SparkR::numPartitions(x)) {
+ func <- function(s, part) {
+ set.seed(s) # split as seed
+ start <- as.integer(sample(numPartitions, 1) - 1)
+ lapply(seq_along(part),
+ function(i) {
+ pos <- (start + i) %% numPartitions
+ list(pos, part[[i]])
+ })
+ }
+ shuffled <- lapplyPartitionsWithIndex(x, func)
+ repartitioned <- partitionBy(shuffled, numPartitions)
+ values(repartitioned)
+ } else {
+ jrdd <- callJMethod(getJRDD(x), "coalesce", numPartitions, shuffle)
+ RDD(jrdd)
+ }
+ })
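+
+# An illustration (not the actual shuffle) of how the shuffle branch above
+# spreads one partition's elements over numPartitions target partitions:
+# starting from a random offset, positions are assigned round-robin.
+if (FALSE) {
+  numPartitions <- 2L
+  part <- list("a", "b", "c", "d")
+  start <- 1L  # pretend sample(numPartitions, 1) - 1 returned 1
+  lapply(seq_along(part), function(i) (start + i) %% numPartitions)
+  # list(0, 1, 0, 1) -- alternating target partitions
+}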
+
+#' Save this RDD as a SequenceFile of serialized objects.
+#'
+#' @param x The RDD to save
+#' @param path The directory where the file is saved
+#' @seealso objectFile
+#' @examples
+#'\dontrun{
+#' sc <- sparkR.init()
+#' rdd <- parallelize(sc, 1:3)
+#' saveAsObjectFile(rdd, "/tmp/sparkR-tmp")
+#'}
+#' @rdname saveAsObjectFile
+#' @aliases saveAsObjectFile,RDD
+setMethod("saveAsObjectFile",
+ signature(x = "RDD", path = "character"),
+ function(x, path) {
+            # If serializedMode is not "byte" (e.g. "string"), we need to serialize the
+            # data before saving it, since objectFile() assumes serializedMode == "byte".
+ if (getSerializedMode(x) != "byte") {
+ x <- serializeToBytes(x)
+ }
+ # Return nothing
+ invisible(callJMethod(getJRDD(x), "saveAsObjectFile", path))
+ })
+
+#' Save this RDD as a text file, using string representations of elements.
+#'
+#' @param x The RDD to save
+#' @param path The directory where the splits of the text file are saved
+#' @examples
+#'\dontrun{
+#' sc <- sparkR.init()
+#' rdd <- parallelize(sc, 1:3)
+#' saveAsTextFile(rdd, "/tmp/sparkR-tmp")
+#'}
+#' @rdname saveAsTextFile
+#' @aliases saveAsTextFile,RDD
+setMethod("saveAsTextFile",
+ signature(x = "RDD", path = "character"),
+ function(x, path) {
+ func <- function(str) {
+ toString(str)
+ }
+ stringRdd <- lapply(x, func)
+ # Return nothing
+ invisible(
+ callJMethod(getJRDD(stringRdd, serializedMode = "string"), "saveAsTextFile", path))
+ })
+
+#' Sort an RDD by the given key function.
+#'
+#' @param x An RDD to be sorted.
+#' @param func A function used to compute the sort key for each element.
+#' @param ascending A flag to indicate whether the sorting is ascending or descending.
+#' @param numPartitions Number of partitions to create.
+#' @return An RDD where all elements are sorted.
+#' @examples
+#'\dontrun{
+#' sc <- sparkR.init()
+#' rdd <- parallelize(sc, list(3, 2, 1))
+#' collect(sortBy(rdd, function(x) { x })) # list (1, 2, 3)
+#'}
+#' @rdname sortBy
+#' @aliases sortBy,RDD,RDD-method
+setMethod("sortBy",
+ signature(x = "RDD", func = "function"),
+ function(x, func, ascending = TRUE, numPartitions = SparkR::numPartitions(x)) {
+ values(sortByKey(keyBy(x, func), ascending, numPartitions))
+ })
+
+# Helper function to get first N elements from an RDD in the specified order.
+# Param:
+# x An RDD.
+# num Number of elements to return.
+# ascending A flag to indicate whether the sorting is ascending or descending.
+# Return:
+# A list of the first N elements from the RDD in the specified order.
+#
+takeOrderedElem <- function(x, num, ascending = TRUE) {
+ if (num <= 0L) {
+ return(list())
+ }
+
+ partitionFunc <- function(part) {
+ if (num < length(part)) {
+ # R limitation: order works only on primitive types!
+ ord <- order(unlist(part, recursive = FALSE), decreasing = !ascending)
+ list(part[ord[1:num]])
+ } else {
+ list(part)
+ }
+ }
+
+ reduceFunc <- function(elems, part) {
+ newElems <- append(elems, part)
+ # R limitation: order works only on primitive types!
+ ord <- order(unlist(newElems, recursive = FALSE), decreasing = !ascending)
+ newElems[ord[1:num]]
+ }
+
+ newRdd <- mapPartitions(x, partitionFunc)
+ reduce(newRdd, reduceFunc)
+}
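+
+# A base-R sketch (illustration only) of the per-partition step above: order()
+# on the unlisted elements picks the first `num' in the requested direction,
+# and the reduce step then repeats the same trick on the merged candidates.
+if (FALSE) {
+  part <- list(10, 1, 2, 9, 3)
+  num <- 3L
+  ord <- order(unlist(part, recursive = FALSE), decreasing = FALSE)
+  part[ord[1:num]]  # list(1, 2, 3)
+}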
+
+#' Returns the first N elements from an RDD in ascending order.
+#'
+#' @param x An RDD.
+#' @param num Number of elements to return.
+#' @return The first N elements from the RDD in ascending order.
+#' @examples
+#'\dontrun{
+#' sc <- sparkR.init()
+#' rdd <- parallelize(sc, list(10, 1, 2, 9, 3, 4, 5, 6, 7))
+#' takeOrdered(rdd, 6L) # list(1, 2, 3, 4, 5, 6)
+#'}
+#' @rdname takeOrdered
+#' @aliases takeOrdered,RDD,RDD-method
+setMethod("takeOrdered",
+ signature(x = "RDD", num = "integer"),
+ function(x, num) {
+ takeOrderedElem(x, num)
+ })
+
+#' Returns the top N elements from an RDD.
+#'
+#' @param x An RDD.
+#' @param num Number of elements to return.
+#' @return The top N elements from the RDD.
+#' @rdname top
+#' @export
+#' @examples
+#'\dontrun{
+#' sc <- sparkR.init()
+#' rdd <- parallelize(sc, list(10, 1, 2, 9, 3, 4, 5, 6, 7))
+#' top(rdd, 6L) # list(10, 9, 7, 6, 5, 4)
+#'}
+#' @aliases top,RDD,RDD-method
+setMethod("top",
+ signature(x = "RDD", num = "integer"),
+ function(x, num) {
+ takeOrderedElem(x, num, FALSE)
+ })
+
+#' Fold an RDD using a given associative function and a neutral "zero value".
+#'
+#' Aggregate the elements of each partition, and then the results for all the
+#' partitions, using a given associative function and a neutral "zero value".
+#'
+#' @param x An RDD.
+#' @param zeroValue A neutral "zero value".
+#' @param op An associative function for the folding operation.
+#' @return The folding result.
+#' @rdname fold
+#' @seealso reduce
+#' @export
+#' @examples
+#'\dontrun{
+#' sc <- sparkR.init()
+#' rdd <- parallelize(sc, list(1, 2, 3, 4, 5))
+#' fold(rdd, 0, "+") # 15
+#'}
+#' @aliases fold,RDD,RDD-method
+setMethod("fold",
+ signature(x = "RDD", zeroValue = "ANY", op = "ANY"),
+ function(x, zeroValue, op) {
+ aggregateRDD(x, zeroValue, op, op)
+ })
+
+#' Aggregate an RDD using the given combine functions and a neutral "zero value".
+#'
+#' Aggregate the elements of each partition, and then the results for all the
+#' partitions, using given combine functions and a neutral "zero value".
+#'
+#' @param x An RDD.
+#' @param zeroValue A neutral "zero value".
+#' @param seqOp A function to aggregate the RDD elements. It may return a different
+#' result type from the type of the RDD elements.
+#' @param combOp A function to aggregate results of seqOp.
+#' @return The aggregation result.
+#' @rdname aggregateRDD
+#' @seealso reduce
+#' @export
+#' @examples
+#'\dontrun{
+#' sc <- sparkR.init()
+#' rdd <- parallelize(sc, list(1, 2, 3, 4))
+#' zeroValue <- list(0, 0)
+#' seqOp <- function(x, y) { list(x[[1]] + y, x[[2]] + 1) }
+#' combOp <- function(x, y) { list(x[[1]] + y[[1]], x[[2]] + y[[2]]) }
+#' aggregateRDD(rdd, zeroValue, seqOp, combOp) # list(10, 4)
+#'}
+#' @aliases aggregateRDD,RDD,RDD-method
+setMethod("aggregateRDD",
+ signature(x = "RDD", zeroValue = "ANY", seqOp = "ANY", combOp = "ANY"),
+ function(x, zeroValue, seqOp, combOp) {
+ partitionFunc <- function(part) {
+ Reduce(seqOp, part, zeroValue)
+ }
+
+ partitionList <- collect(lapplyPartition(x, partitionFunc),
+ flatten = FALSE)
+ Reduce(combOp, partitionList, zeroValue)
+ })
+
+#' Pipes elements to a forked external process.
+#'
+#' The same as 'pipe()' in Spark.
+#'
+#' @param x The RDD whose elements are piped to the forked external process.
+#' @param command The command to fork an external process.
+#' @param env A named list to set environment variables of the external process.
+#' @return A new RDD created by piping all elements to a forked external process.
+#' @rdname pipeRDD
+#' @export
+#' @examples
+#'\dontrun{
+#' sc <- sparkR.init()
+#' rdd <- parallelize(sc, 1:10)
+#' collect(pipeRDD(rdd, "more"))
+#' # Output: c("1", "2", ..., "10")
+#'}
+#' @aliases pipeRDD,RDD,character-method
+setMethod("pipeRDD",
+ signature(x = "RDD", command = "character"),
+ function(x, command, env = list()) {
+ func <- function(part) {
+ trim.trailing.func <- function(x) {
+ sub("[\r\n]*$", "", toString(x))
+ }
+ input <- unlist(lapply(part, trim.trailing.func))
+ res <- system2(command, stdout = TRUE, input = input, env = env)
+ lapply(res, trim.trailing.func)
+ }
+ lapplyPartition(x, func)
+ })
+
+# TODO: Consider caching the name in the RDD's environment
+#' Return an RDD's name.
+#'
+#' @param x The RDD whose name is returned.
+#' @rdname name
+#' @export
+#' @examples
+#'\dontrun{
+#' sc <- sparkR.init()
+#' rdd <- parallelize(sc, list(1,2,3))
+#' name(rdd) # NULL (if not set before)
+#'}
+#' @aliases name,RDD
+setMethod("name",
+ signature(x = "RDD"),
+ function(x) {
+ callJMethod(getJRDD(x), "name")
+ })
+
+#' Set an RDD's name.
+#'
+#' @param x The RDD whose name is to be set.
+#' @param name The RDD name to be set.
+#' @return The renamed RDD.
+#' @rdname setName
+#' @export
+#' @examples
+#'\dontrun{
+#' sc <- sparkR.init()
+#' rdd <- parallelize(sc, list(1,2,3))
+#' setName(rdd, "myRDD")
+#' name(rdd) # "myRDD"
+#'}
+#' @aliases setName,RDD
+setMethod("setName",
+ signature(x = "RDD", name = "character"),
+ function(x, name) {
+ callJMethod(getJRDD(x), "setName", name)
+ x
+ })
+
+#' Zip an RDD with generated unique Long IDs.
+#'
+#' Items in the kth partition will get ids k, n+k, 2*n+k, ..., where
+#' n is the number of partitions. So there may exist gaps, but this
+#' method won't trigger a Spark job, which is different from
+#' zipWithIndex.
+#'
+#' @param x An RDD to be zipped.
+#' @return An RDD with zipped items.
+#' @seealso zipWithIndex
+#' @examples
+#'\dontrun{
+#' sc <- sparkR.init()
+#' rdd <- parallelize(sc, list("a", "b", "c", "d", "e"), 3L)
+#' collect(zipWithUniqueId(rdd))
+#' # list(list("a", 0), list("b", 3), list("c", 1), list("d", 4), list("e", 2))
+#'}
+#' @rdname zipWithUniqueId
+#' @aliases zipWithUniqueId,RDD
+setMethod("zipWithUniqueId",
+ signature(x = "RDD"),
+ function(x) {
+ n <- numPartitions(x)
+
+ partitionFunc <- function(split, part) {
+ mapply(
+ function(item, index) {
+ list(item, (index - 1) * n + split)
+ },
+ part,
+ seq_along(part),
+ SIMPLIFY = FALSE)
+ }
+
+ lapplyPartitionsWithIndex(x, partitionFunc)
+ })
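+
+# A worked example (illustration only) of the id formula above for the
+# documented 3-partition RDD, assuming the elements split as noted below:
+# id = (index - 1) * n + split, with n = 3 partitions.
+if (FALSE) {
+  n <- 3L
+  # assume partition 0 holds "a", "b"; partition 1 holds "c", "d"; partition 2 holds "e"
+  (1 - 1) * n + 0  # "a" -> 0
+  (2 - 1) * n + 0  # "b" -> 3
+  (1 - 1) * n + 1  # "c" -> 1
+  (2 - 1) * n + 1  # "d" -> 4
+  (1 - 1) * n + 2  # "e" -> 2
+}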
+
+#' Zip an RDD with its element indices.
+#'
+#' The ordering is first based on the partition index and then the
+#' ordering of items within each partition. So the first item in
+#' the first partition gets index 0, and the last item in the last
+#' partition receives the largest index.
+#'
+#' This method needs to trigger a Spark job when this RDD contains
+#' more than one partition.
+#'
+#' @param x An RDD to be zipped.
+#' @return An RDD with zipped items.
+#' @seealso zipWithUniqueId
+#' @examples
+#'\dontrun{
+#' sc <- sparkR.init()
+#' rdd <- parallelize(sc, list("a", "b", "c", "d", "e"), 3L)
+#' collect(zipWithIndex(rdd))
+#' # list(list("a", 0), list("b", 1), list("c", 2), list("d", 3), list("e", 4))
+#'}
+#' @rdname zipWithIndex
+#' @aliases zipWithIndex,RDD
+setMethod("zipWithIndex",
+ signature(x = "RDD"),
+ function(x) {
+ n <- numPartitions(x)
+ if (n > 1) {
+ nums <- collect(lapplyPartition(x,
+ function(part) {
+ list(length(part))
+ }))
+ startIndices <- Reduce("+", nums, accumulate = TRUE)
+ }
+
+ partitionFunc <- function(split, part) {
+ if (split == 0) {
+ startIndex <- 0
+ } else {
+ startIndex <- startIndices[[split]]
+ }
+
+ mapply(
+ function(item, index) {
+ list(item, index - 1 + startIndex)
+ },
+ part,
+ seq_along(part),
+ SIMPLIFY = FALSE)
+ }
+
+ lapplyPartitionsWithIndex(x, partitionFunc)
+ })
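+
+# A small illustration (not executed) of how the start indices are derived:
+# Reduce("+", ..., accumulate = TRUE) turns per-partition lengths into
+# cumulative sums, and partition k (k > 0) starts at startIndices[[k]].
+if (FALSE) {
+  nums <- list(2, 2, 1)  # number of elements in each partition
+  Reduce("+", nums, accumulate = TRUE)  # 2, 4, 5
+  # partition 0 starts at index 0, partition 1 at 2, partition 2 at 4
+}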
+
+#' Coalesce all elements within each partition of an RDD into a list.
+#'
+#' @param x An RDD.
+#' @return An RDD created by coalescing all elements within
+#' each partition into a list.
+#' @examples
+#'\dontrun{
+#' sc <- sparkR.init()
+#' rdd <- parallelize(sc, as.list(1:4), 2L)
+#' collect(glom(rdd))
+#' # list(list(1, 2), list(3, 4))
+#'}
+#' @rdname glom
+#' @aliases glom,RDD
+setMethod("glom",
+ signature(x = "RDD"),
+ function(x) {
+ partitionFunc <- function(part) {
+ list(part)
+ }
+
+ lapplyPartition(x, partitionFunc)
+ })
+
+############ Binary Functions #############
+
+#' Return the union RDD of two RDDs.
+#' The same as union() in Spark.
+#'
+#' @param x An RDD.
+#' @param y An RDD.
+#' @return a new RDD created by performing the simple union (without removing
+#' duplicates) of two input RDDs.
+#' @examples
+#'\dontrun{
+#' sc <- sparkR.init()
+#' rdd <- parallelize(sc, 1:3)
+#' unionRDD(rdd, rdd) # 1, 2, 3, 1, 2, 3
+#'}
+#' @rdname unionRDD
+#' @aliases unionRDD,RDD,RDD-method
+setMethod("unionRDD",
+ signature(x = "RDD", y = "RDD"),
+ function(x, y) {
+ if (getSerializedMode(x) == getSerializedMode(y)) {
+ jrdd <- callJMethod(getJRDD(x), "union", getJRDD(y))
+ union.rdd <- RDD(jrdd, getSerializedMode(x))
+ } else {
+        # The RDDs have different serialization modes, so serialize any non-byte RDD first.
+ if (getSerializedMode(x) != "byte") x <- serializeToBytes(x)
+ if (getSerializedMode(y) != "byte") y <- serializeToBytes(y)
+ jrdd <- callJMethod(getJRDD(x), "union", getJRDD(y))
+ union.rdd <- RDD(jrdd, "byte")
+ }
+ union.rdd
+ })
+
+#' Zip an RDD with another RDD.
+#'
+#' Zips this RDD with another one, returning key-value pairs with the
+#' first element in each RDD, the second element in each RDD, etc. Assumes
+#' that the two RDDs have the same number of partitions and the same
+#' number of elements in each partition (e.g. one was made through
+#' a map on the other).
+#'
+#' @param x An RDD to be zipped.
+#' @param other Another RDD to be zipped.
+#' @return An RDD zipped from the two RDDs.
+#' @examples
+#'\dontrun{
+#' sc <- sparkR.init()
+#' rdd1 <- parallelize(sc, 0:4)
+#' rdd2 <- parallelize(sc, 1000:1004)
+#' collect(zipRDD(rdd1, rdd2))
+#' # list(list(0, 1000), list(1, 1001), list(2, 1002), list(3, 1003), list(4, 1004))
+#'}
+#' @rdname zipRDD
+#' @aliases zipRDD,RDD
+setMethod("zipRDD",
+ signature(x = "RDD", other = "RDD"),
+ function(x, other) {
+ n1 <- numPartitions(x)
+ n2 <- numPartitions(other)
+ if (n1 != n2) {
+ stop("Can only zip RDDs which have the same number of partitions.")
+ }
+
+ if (getSerializedMode(x) != getSerializedMode(other) ||
+ getSerializedMode(x) == "byte") {
+ # Append the number of elements in each partition to that partition so that we can later
+ # check if corresponding partitions of both RDDs have the same number of elements.
+ #
+ # Note that this appending also serves the purpose of reserialization, because even if
+ # any RDD is serialized, we need to reserialize it to make sure its partitions are encoded
+ # as a single byte array. For example, partitions of an RDD generated from partitionBy()
+ # may be encoded as multiple byte arrays.
+ appendLength <- function(part) {
+ part[[length(part) + 1]] <- length(part) + 1
+ part
+ }
+ x <- lapplyPartition(x, appendLength)
+ other <- lapplyPartition(other, appendLength)
+ }
+
+ zippedJRDD <- callJMethod(getJRDD(x), "zip", getJRDD(other))
+      # The zippedRDD's elements are of Scala Tuple2 type. The serialization
+      # flag here is used for the elements inside the tuples.
+ serializerMode <- getSerializedMode(x)
+ zippedRDD <- RDD(zippedJRDD, serializerMode)
+
+ partitionFunc <- function(split, part) {
+ len <- length(part)
+ if (len > 0) {
+ if (serializerMode == "byte") {
+ lengthOfValues <- part[[len]]
+ lengthOfKeys <- part[[len - lengthOfValues]]
+ stopifnot(len == lengthOfKeys + lengthOfValues)
+
+ # check if corresponding partitions of both RDDs have the same number of elements.
+ if (lengthOfKeys != lengthOfValues) {
+ stop("Can only zip RDDs with same number of elements in each pair of corresponding partitions.")
+ }
+
+ if (lengthOfKeys > 1) {
+ keys <- part[1 : (lengthOfKeys - 1)]
+ values <- part[(lengthOfKeys + 1) : (len - 1)]
+ } else {
+ keys <- list()
+ values <- list()
+ }
+ } else {
+ # Keys, values must have same length here, because this has
+ # been validated inside the JavaRDD.zip() function.
+ keys <- part[c(TRUE, FALSE)]
+ values <- part[c(FALSE, TRUE)]
+ }
+ mapply(
+ function(k, v) {
+ list(k, v)
+ },
+ keys,
+ values,
+ SIMPLIFY = FALSE,
+ USE.NAMES = FALSE)
+ } else {
+ part
+ }
+ }
+
+ PipelinedRDD(zippedRDD, partitionFunc)
+ })
diff --git a/R/pkg/R/SQLContext.R b/R/pkg/R/SQLContext.R
new file mode 100644
index 0000000000..930ada22f4
--- /dev/null
+++ b/R/pkg/R/SQLContext.R
@@ -0,0 +1,520 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements. See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+# SQLContext.R: SQLContext-driven functions
+
+#' Infer the SQL type of an R object
+infer_type <- function(x) {
+ if (is.null(x)) {
+ stop("can not infer type from NULL")
+ }
+
+ # class of POSIXlt is c("POSIXlt" "POSIXt")
+ type <- switch(class(x)[[1]],
+ integer = "integer",
+ character = "string",
+ logical = "boolean",
+ double = "double",
+ numeric = "double",
+ raw = "binary",
+ list = "array",
+ environment = "map",
+ Date = "date",
+ POSIXlt = "timestamp",
+ POSIXct = "timestamp",
+ stop(paste("Unsupported type for DataFrame:", class(x))))
+
+ if (type == "map") {
+ stopifnot(length(x) > 0)
+ key <- ls(x)[[1]]
+ list(type = "map",
+ keyType = "string",
+ valueType = infer_type(get(key, x)),
+ valueContainsNull = TRUE)
+ } else if (type == "array") {
+ stopifnot(length(x) > 0)
+ names <- names(x)
+ if (is.null(names)) {
+ list(type = "array", elementType = infer_type(x[[1]]), containsNull = TRUE)
+ } else {
+ # StructType
+ types <- lapply(x, infer_type)
+ fields <- lapply(1:length(x), function(i) {
+ list(name = names[[i]], type = types[[i]], nullable = TRUE)
+ })
+ list(type = "struct", fields = fields)
+ }
+ } else if (length(x) > 1) {
+ list(type = "array", elementType = type, containsNull = TRUE)
+ } else {
+ type
+ }
+}
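+
+# A few hypothetical calls (illustration only) showing what infer_type returns
+# for common R values.
+if (FALSE) {
+  infer_type(1L)                        # "integer"
+  infer_type("hello")                   # "string"
+  infer_type(TRUE)                      # "boolean"
+  infer_type(as.Date("2015-01-01"))     # "date"
+  infer_type(list(1L, 2L))              # list(type = "array", elementType = "integer", ...)
+  infer_type(list(a = 1L, b = "x"))     # a "struct" with fields a and b
+}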
+
+#' Dump the schema into a JSON string
+tojson <- function(x) {
+ if (is.list(x)) {
+ names <- names(x)
+ if (!is.null(names)) {
+ items <- lapply(names, function(n) {
+ safe_n <- gsub('"', '\\"', n)
+ paste(tojson(safe_n), ':', tojson(x[[n]]), sep = '')
+ })
+ d <- paste(items, collapse = ', ')
+ paste('{', d, '}', sep = '')
+ } else {
+ l <- paste(lapply(x, tojson), collapse = ', ')
+ paste('[', l, ']', sep = '')
+ }
+ } else if (is.character(x)) {
+ paste('"', x, '"', sep = '')
+ } else if (is.logical(x)) {
+ if (x) "true" else "false"
+ } else {
+ stop(paste("unexpected type:", class(x)))
+ }
+}
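+
+# A hypothetical example (illustration only) of the JSON produced for a
+# one-field struct schema.
+if (FALSE) {
+  schema <- list(type = "struct",
+                 fields = list(list(name = "a", type = "integer", nullable = TRUE)))
+  tojson(schema)
+  # {"type":"struct", "fields":[{"name":"a", "type":"integer", "nullable":true}]}
+}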
+
+#' Create a DataFrame from an RDD
+#'
+#' Converts an RDD to a DataFrame by inferring the types.
+#'
+#' @param sqlCtx A SQLContext
+#' @param data An RDD or list or data.frame
+#' @param schema a list of column names or named list (StructType), optional
+#' @return a DataFrame
+#' @export
+#' @examples
+#'\dontrun{
+#' sc <- sparkR.init()
+#' sqlCtx <- sparkRSQL.init(sc)
+#' rdd <- lapply(parallelize(sc, 1:10), function(x) list(a=x, b=as.character(x)))
+#' df <- createDataFrame(sqlCtx, rdd)
+#' }
+
+# TODO(davies): support sampling and infer type from NA
+createDataFrame <- function(sqlCtx, data, schema = NULL, samplingRatio = 1.0) {
+ if (is.data.frame(data)) {
+ # get the names of columns, they will be put into RDD
+ schema <- names(data)
+ n <- nrow(data)
+ m <- ncol(data)
+ # get rid of factor type
+ dropFactor <- function(x) {
+ if (is.factor(x)) {
+ as.character(x)
+ } else {
+ x
+ }
+ }
+ data <- lapply(1:n, function(i) {
+ lapply(1:m, function(j) { dropFactor(data[i,j]) })
+ })
+ }
+ if (is.list(data)) {
+ sc <- callJStatic("org.apache.spark.sql.api.r.SQLUtils", "getJavaSparkContext", sqlCtx)
+ rdd <- parallelize(sc, data)
+ } else if (inherits(data, "RDD")) {
+ rdd <- data
+ } else {
+ stop(paste("unexpected type:", class(data)))
+ }
+
+ if (is.null(schema) || is.null(names(schema))) {
+ row <- first(rdd)
+ names <- if (is.null(schema)) {
+ names(row)
+ } else {
+ as.list(schema)
+ }
+ if (is.null(names)) {
+ names <- lapply(1:length(row), function(x) {
+ paste("_", as.character(x), sep = "")
+ })
+ }
+
+    # Spark SQL does not support '.' in column names, so replace it with '_'
+ # TODO(davies): remove this once SPARK-2775 is fixed
+ names <- lapply(names, function(n) {
+ nn <- gsub("[.]", "_", n)
+ if (nn != n) {
+        warning(paste("Use", nn, "instead of", n, "as column name"))
+ }
+ nn
+ })
+
+ types <- lapply(row, infer_type)
+ fields <- lapply(1:length(row), function(i) {
+ list(name = names[[i]], type = types[[i]], nullable = TRUE)
+ })
+ schema <- list(type = "struct", fields = fields)
+ }
+
+ stopifnot(class(schema) == "list")
+ stopifnot(schema$type == "struct")
+ stopifnot(class(schema$fields) == "list")
+ schemaString <- tojson(schema)
+
+ jrdd <- getJRDD(lapply(rdd, function(x) x), "row")
+ srdd <- callJMethod(jrdd, "rdd")
+ sdf <- callJStatic("org.apache.spark.sql.api.r.SQLUtils", "createDF",
+ srdd, schemaString, sqlCtx)
+ dataFrame(sdf)
+}
+
+#' toDF
+#'
+#' Converts an RDD to a DataFrame by inferring the types.
+#'
+#' @param x An RDD
+#'
+#' @rdname DataFrame
+#' @export
+#' @examples
+#'\dontrun{
+#' sc <- sparkR.init()
+#' sqlCtx <- sparkRSQL.init(sc)
+#' rdd <- lapply(parallelize(sc, 1:10), function(x) list(a=x, b=as.character(x)))
+#' df <- toDF(rdd)
+#' }
+
+setGeneric("toDF", function(x, ...) { standardGeneric("toDF") })
+
+setMethod("toDF", signature(x = "RDD"),
+ function(x, ...) {
+ sqlCtx <- if (exists(".sparkRHivesc", envir = .sparkREnv)) {
+ get(".sparkRHivesc", envir = .sparkREnv)
+ } else if (exists(".sparkRSQLsc", envir = .sparkREnv)) {
+ get(".sparkRSQLsc", envir = .sparkREnv)
+ } else {
+ stop("no SQL context available")
+ }
+ createDataFrame(sqlCtx, x, ...)
+ })
+
+#' Create a DataFrame from a JSON file.
+#'
+#' Loads a JSON file (one object per line), returning the result as a DataFrame.
+#' It goes through the entire dataset once to determine the schema.
+#'
+#' @param sqlCtx SQLContext to use
+#' @param path Path of file to read. A vector of multiple paths is allowed.
+#' @return DataFrame
+#' @export
+#' @examples
+#'\dontrun{
+#' sc <- sparkR.init()
+#' sqlCtx <- sparkRSQL.init(sc)
+#' path <- "path/to/file.json"
+#' df <- jsonFile(sqlCtx, path)
+#' }
+
+jsonFile <- function(sqlCtx, path) {
+  # Allow the user to have a more flexible definition of the text file path
+ path <- normalizePath(path)
+ # Convert a string vector of paths to a string containing comma separated paths
+ path <- paste(path, collapse = ",")
+ sdf <- callJMethod(sqlCtx, "jsonFile", path)
+ dataFrame(sdf)
+}
+
+
+#' JSON RDD
+#'
+#' Loads an RDD storing one JSON object per string as a DataFrame.
+#'
+#' @param sqlCtx SQLContext to use
+#' @param rdd An RDD of JSON string
+#' @param schema A StructType object to use as schema
+#' @param samplingRatio The sampling ratio used to infer the schema
+#' @return A DataFrame
+#' @export
+#' @examples
+#'\dontrun{
+#' sc <- sparkR.init()
+#' sqlCtx <- sparkRSQL.init(sc)
+#' rdd <- textFile(sc, "path/to/json")
+#' df <- jsonRDD(sqlCtx, rdd)
+#' }
+
+# TODO: support schema
+jsonRDD <- function(sqlCtx, rdd, schema = NULL, samplingRatio = 1.0) {
+ rdd <- serializeToString(rdd)
+ if (is.null(schema)) {
+ sdf <- callJMethod(sqlCtx, "jsonRDD", callJMethod(getJRDD(rdd), "rdd"), samplingRatio)
+ dataFrame(sdf)
+ } else {
+ stop("not implemented")
+ }
+}
+
+
+#' Create a DataFrame from a Parquet file.
+#'
+#' Loads a Parquet file, returning the result as a DataFrame.
+#'
+#' @param sqlCtx SQLContext to use
+#' @param ... Path(s) of parquet file(s) to read.
+#' @return DataFrame
+#' @export
+
+# TODO: Implement saveAsParquetFile and write examples for both
+parquetFile <- function(sqlCtx, ...) {
+  # Allow the user to have a more flexible definition of the text file path
+ paths <- lapply(list(...), normalizePath)
+ sdf <- callJMethod(sqlCtx, "parquetFile", paths)
+ dataFrame(sdf)
+}
+
+#' SQL Query
+#'
+#' Executes a SQL query using Spark, returning the result as a DataFrame.
+#'
+#' @param sqlCtx SQLContext to use
+#' @param sqlQuery A character vector containing the SQL query
+#' @return DataFrame
+#' @export
+#' @examples
+#'\dontrun{
+#' sc <- sparkR.init()
+#' sqlCtx <- sparkRSQL.init(sc)
+#' path <- "path/to/file.json"
+#' df <- jsonFile(sqlCtx, path)
+#' registerTempTable(df, "table")
+#' new_df <- sql(sqlCtx, "SELECT * FROM table")
+#' }
+
+sql <- function(sqlCtx, sqlQuery) {
+ sdf <- callJMethod(sqlCtx, "sql", sqlQuery)
+ dataFrame(sdf)
+}
+
+
+#' Create a DataFrame from a SparkSQL Table
+#'
+#' Returns the specified Table as a DataFrame. The Table must have already been registered
+#' in the SQLContext.
+#'
+#' @param sqlCtx SQLContext to use
+#' @param tableName The SparkSQL Table to convert to a DataFrame.
+#' @return DataFrame
+#' @export
+#' @examples
+#'\dontrun{
+#' sc <- sparkR.init()
+#' sqlCtx <- sparkRSQL.init(sc)
+#' path <- "path/to/file.json"
+#' df <- jsonFile(sqlCtx, path)
+#' registerTempTable(df, "table")
+#' new_df <- table(sqlCtx, "table")
+#' }
+
+table <- function(sqlCtx, tableName) {
+ sdf <- callJMethod(sqlCtx, "table", tableName)
+ dataFrame(sdf)
+}
+
+
+#' Tables
+#'
+#' Returns a DataFrame containing names of tables in the given database.
+#'
+#' @param sqlCtx SQLContext to use
+#' @param databaseName name of the database
+#' @return a DataFrame
+#' @export
+#' @examples
+#'\dontrun{
+#' sc <- sparkR.init()
+#' sqlCtx <- sparkRSQL.init(sc)
+#' tables(sqlCtx, "hive")
+#' }
+
+tables <- function(sqlCtx, databaseName = NULL) {
+ jdf <- if (is.null(databaseName)) {
+ callJMethod(sqlCtx, "tables")
+ } else {
+ callJMethod(sqlCtx, "tables", databaseName)
+ }
+ dataFrame(jdf)
+}
+
+
+#' Table Names
+#'
+#' Returns the names of tables in the given database as an array.
+#'
+#' @param sqlCtx SQLContext to use
+#' @param databaseName name of the database
+#' @return a list of table names
+#' @export
+#' @examples
+#'\dontrun{
+#' sc <- sparkR.init()
+#' sqlCtx <- sparkRSQL.init(sc)
+#' tableNames(sqlCtx, "hive")
+#' }
+
+tableNames <- function(sqlCtx, databaseName = NULL) {
+ if (is.null(databaseName)) {
+ callJMethod(sqlCtx, "tableNames")
+ } else {
+ callJMethod(sqlCtx, "tableNames", databaseName)
+ }
+}
+
+
+#' Cache Table
+#'
+#' Caches the specified table in-memory.
+#'
+#' @param sqlCtx SQLContext to use
+#' @param tableName The name of the table being cached
+#' @return DataFrame
+#' @export
+#' @examples
+#'\dontrun{
+#' sc <- sparkR.init()
+#' sqlCtx <- sparkRSQL.init(sc)
+#' path <- "path/to/file.json"
+#' df <- jsonFile(sqlCtx, path)
+#' registerTempTable(df, "table")
+#' cacheTable(sqlCtx, "table")
+#' }
+
+cacheTable <- function(sqlCtx, tableName) {
+ callJMethod(sqlCtx, "cacheTable", tableName)
+}
+
+#' Uncache Table
+#'
+#' Removes the specified table from the in-memory cache.
+#'
+#' @param sqlCtx SQLContext to use
+#' @param tableName The name of the table being uncached
+#' @return DataFrame
+#' @export
+#' @examples
+#'\dontrun{
+#' sc <- sparkR.init()
+#' sqlCtx <- sparkRSQL.init(sc)
+#' path <- "path/to/file.json"
+#' df <- jsonFile(sqlCtx, path)
+#' registerTempTable(df, "table")
+#' uncacheTable(sqlCtx, "table")
+#' }
+
+uncacheTable <- function(sqlCtx, tableName) {
+ callJMethod(sqlCtx, "uncacheTable", tableName)
+}
+
+#' Clear Cache
+#'
+#' Removes all cached tables from the in-memory cache.
+#'
+#' @param sqlCtx SQLContext to use
+#' @examples
+#' \dontrun{
+#' clearCache(sqlCtx)
+#' }
+
+clearCache <- function(sqlCtx) {
+ callJMethod(sqlCtx, "clearCache")
+}
+
+#' Drop Temporary Table
+#'
+#' Drops the temporary table with the given table name in the catalog.
+#' If the table has been cached/persisted before, it's also unpersisted.
+#'
+#' @param sqlCtx SQLContext to use
+#' @param tableName The name of the SparkSQL table to be dropped.
+#' @examples
+#' \dontrun{
+#' sc <- sparkR.init()
+#' sqlCtx <- sparkRSQL.init(sc)
+#' df <- loadDF(sqlCtx, path, "parquet")
+#' registerTempTable(df, "table")
+#' dropTempTable(sqlCtx, "table")
+#' }
+
+dropTempTable <- function(sqlCtx, tableName) {
+ if (class(tableName) != "character") {
+ stop("tableName must be a string.")
+ }
+ callJMethod(sqlCtx, "dropTempTable", tableName)
+}
+
+#' Load a DataFrame
+#'
+#' Returns the dataset in a data source as a DataFrame
+#'
+#' The data source is specified by the `source` and a set of options (...).
+#' If `source` is not specified, the default data source configured by
+#' "spark.sql.sources.default" will be used.
+#'
+#' @param sqlCtx SQLContext to use
+#' @param path The path of files to load
+#' @param source the name of external data source
+#' @return DataFrame
+#' @export
+#' @examples
+#'\dontrun{
+#' sc <- sparkR.init()
+#' sqlCtx <- sparkRSQL.init(sc)
+#' df <- loadDF(sqlCtx, "path/to/file.json", source = "json")
+#' }
+
+loadDF <- function(sqlCtx, path = NULL, source = NULL, ...) {
+ options <- varargsToEnv(...)
+ if (!is.null(path)) {
+ options[['path']] <- path
+ }
+ sdf <- callJMethod(sqlCtx, "load", source, options)
+ dataFrame(sdf)
+}
+
+#' Create an external table
+#'
+#' Creates an external table based on the dataset in a data source and
+#' returns the DataFrame associated with the external table.
+#'
+#' The data source is specified by the `source` and a set of options (...).
+#' If `source` is not specified, the default data source configured by
+#' "spark.sql.sources.default" will be used.
+#'
+#' @param sqlCtx SQLContext to use
+#' @param tableName A name of the table
+#' @param path The path of files to load
+#' @param source the name of external data source
+#' @return DataFrame
+#' @export
+#' @examples
+#'\dontrun{
+#' sc <- sparkR.init()
+#' sqlCtx <- sparkRSQL.init(sc)
+#' df <- createExternalTable(sqlCtx, "myjson", path="path/to/json", source="json")
+#' }
+
+createExternalTable <- function(sqlCtx, tableName, path = NULL, source = NULL, ...) {
+ options <- varargsToEnv(...)
+ if (!is.null(path)) {
+ options[['path']] <- path
+ }
+ sdf <- callJMethod(sqlCtx, "createExternalTable", tableName, source, options)
+ dataFrame(sdf)
+}
diff --git a/R/pkg/R/SQLTypes.R b/R/pkg/R/SQLTypes.R
new file mode 100644
index 0000000000..962fba5b3c
--- /dev/null
+++ b/R/pkg/R/SQLTypes.R
@@ -0,0 +1,64 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements. See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+# Utility functions for handling SparkSQL DataTypes.
+
+# Handler for StructType
+structType <- function(st) {
+ obj <- structure(new.env(parent = emptyenv()), class = "structType")
+ obj$jobj <- st
+ obj$fields <- function() { lapply(callJMethod(st, "fields"), structField) }
+ obj
+}
+
+#' Print a Spark StructType.
+#'
+#' This function prints the contents of a StructType returned from the
+#' SparkR JVM backend.
+#'
+#' @param x A StructType object
+#' @param ... further arguments passed to or from other methods
+print.structType <- function(x, ...) {
+ fieldsList <- lapply(x$fields(), function(i) { i$print() })
+ print(fieldsList)
+}
+
+# Handler for StructField
+structField <- function(sf) {
+ obj <- structure(new.env(parent = emptyenv()), class = "structField")
+ obj$jobj <- sf
+ obj$name <- function() { callJMethod(sf, "name") }
+ obj$dataType <- function() { callJMethod(sf, "dataType") }
+ obj$dataType.toString <- function() { callJMethod(obj$dataType(), "toString") }
+ obj$dataType.simpleString <- function() { callJMethod(obj$dataType(), "simpleString") }
+ obj$nullable <- function() { callJMethod(sf, "nullable") }
+ obj$print <- function() { paste("StructField(",
+ paste(obj$name(), obj$dataType.toString(), obj$nullable(), sep = ", "),
+ ")", sep = "") }
+ obj
+}
+
+#' Print a Spark StructField.
+#'
+#' This function prints the contents of a StructField returned from the
+#' SparkR JVM backend.
+#'
+#' @param x A StructField object
+#' @param ... further arguments passed to or from other methods
+print.structField <- function(x, ...) {
+ cat(x$print())
+}
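+
+# A minimal sketch of how these handlers are used (illustrative only; `jst`
+# stands for a jobj that references a Scala StructType obtained from the
+# backend):
+#
+#   st <- structType(jst)
+#   st$fields()                               # list of structField handlers
+#   print(st)                                 # "StructField(name, StringType, true)", ...
+#   st$fields()[[1]]$dataType.simpleString()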
diff --git a/R/pkg/R/backend.R b/R/pkg/R/backend.R
new file mode 100644
index 0000000000..2fb6fae55f
--- /dev/null
+++ b/R/pkg/R/backend.R
@@ -0,0 +1,115 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements. See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+# Methods to call into SparkRBackend.
+
+
+# Returns TRUE if object is an instance of given class
+isInstanceOf <- function(jobj, className) {
+ stopifnot(class(jobj) == "jobj")
+ cls <- callJStatic("java.lang.Class", "forName", className)
+ callJMethod(cls, "isInstance", jobj)
+}
+
+# Call a Java method named methodName on the object
+# specified by objId. objId should be a "jobj" returned
+# from the SparkRBackend.
+callJMethod <- function(objId, methodName, ...) {
+ stopifnot(class(objId) == "jobj")
+ if (!isValidJobj(objId)) {
+ stop("Invalid jobj ", objId$id,
+ ". If SparkR was restarted, Spark operations need to be re-executed.")
+ }
+ invokeJava(isStatic = FALSE, objId$id, methodName, ...)
+}
+
+# Call a static method on a specified className
+callJStatic <- function(className, methodName, ...) {
+ invokeJava(isStatic = TRUE, className, methodName, ...)
+}
+
+# Create a new object of the specified class name
+newJObject <- function(className, ...) {
+ invokeJava(isStatic = TRUE, className, methodName = "<init>", ...)
+}
+
+# Remove an object from the SparkR backend. This is done
+# automatically when a jobj is garbage collected.
+removeJObject <- function(objId) {
+ invokeJava(isStatic = TRUE, "SparkRHandler", "rm", objId)
+}
+
+isRemoveMethod <- function(isStatic, objId, methodName) {
+ isStatic == TRUE && objId == "SparkRHandler" && methodName == "rm"
+}
+
+# Invoke a Java method on the SparkR backend. Users
+# should typically use one of the higher level methods like
+# callJMethod, callJStatic etc. instead of using this.
+#
+# isStatic - TRUE if the method to be called is static
+# objId - String that refers to the object on which method is invoked
+# Should be a jobj id for non-static methods and the classname
+# for static methods
+# methodName - name of method to be invoked
+invokeJava <- function(isStatic, objId, methodName, ...) {
+ if (!exists(".sparkRCon", .sparkREnv)) {
+ stop("No connection to backend found. Please re-run sparkR.init")
+ }
+
+ # If this isn't a removeJObject call
+ if (!isRemoveMethod(isStatic, objId, methodName)) {
+ objsToRemove <- ls(.toRemoveJobjs)
+ if (length(objsToRemove) > 0) {
+ sapply(objsToRemove,
+ function(e) {
+ removeJObject(e)
+ })
+ rm(list = objsToRemove, envir = .toRemoveJobjs)
+ }
+ }
+
+
+ rc <- rawConnection(raw(0), "r+")
+
+ writeBoolean(rc, isStatic)
+ writeString(rc, objId)
+ writeString(rc, methodName)
+
+ args <- list(...)
+ writeInt(rc, length(args))
+ writeArgs(rc, args)
+
+ # Construct the whole request message to send it once,
+ # avoiding write-write-read pattern in case of Nagle's algorithm.
+ # Refer to http://en.wikipedia.org/wiki/Nagle%27s_algorithm for the details.
+ bytesToSend <- rawConnectionValue(rc)
+ close(rc)
+ rc <- rawConnection(raw(0), "r+")
+ writeInt(rc, length(bytesToSend))
+ writeBin(bytesToSend, rc)
+ requestMessage <- rawConnectionValue(rc)
+ close(rc)
+
+ conn <- get(".sparkRCon", .sparkREnv)
+ writeBin(requestMessage, conn)
+
+ # TODO: check the status code to output error information
+ returnStatus <- readInt(conn)
+ stopifnot(returnStatus == 0)
+ readObject(conn)
+}
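+
+# A minimal usage sketch for the helpers above (illustrative only; it assumes
+# a backend connection has already been established by sparkR.init()):
+#
+#   s <- newJObject("java.lang.String", "hello")
+#   callJMethod(s, "length")                       # 5
+#   callJStatic("java.lang.Math", "max", 1L, 2L)   # 2
+#   isInstanceOf(s, "java.lang.Object")            # TRUE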
diff --git a/R/pkg/R/broadcast.R b/R/pkg/R/broadcast.R
new file mode 100644
index 0000000000..583fa2e7fd
--- /dev/null
+++ b/R/pkg/R/broadcast.R
@@ -0,0 +1,86 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements. See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+# S4 class representing Broadcast variables
+
+# Hidden environment that holds values for broadcast variables
+# This will not be serialized / shipped by default
+.broadcastNames <- new.env()
+.broadcastValues <- new.env()
+.broadcastIdToName <- new.env()
+
+#' @title S4 class that represents a Broadcast variable
+#' @description Broadcast variables can be created using the broadcast
+#' function from a \code{SparkContext}.
+#' @rdname broadcast-class
+#' @seealso broadcast
+#'
+#' @param id Id of the backing Spark broadcast variable
+#' @export
+setClass("Broadcast", slots = list(id = "character"))
+
+#' @rdname broadcast-class
+#' @param value Value of the broadcast variable
+#' @param jBroadcastRef reference to the backing Java broadcast object
+#' @param objName name of broadcasted object
+#' @export
+Broadcast <- function(id, value, jBroadcastRef, objName) {
+ .broadcastValues[[id]] <- value
+ .broadcastNames[[as.character(objName)]] <- jBroadcastRef
+ .broadcastIdToName[[id]] <- as.character(objName)
+ new("Broadcast", id = id)
+}
+
+#' @description
+#' \code{value} can be used to get the value of a broadcast variable inside
+#' a distributed function.
+#'
+#' @param bcast The broadcast variable to get
+#' @rdname broadcast
+#' @aliases value,Broadcast-method
+setMethod("value",
+ signature(bcast = "Broadcast"),
+ function(bcast) {
+ if (exists(bcast@id, envir = .broadcastValues)) {
+ get(bcast@id, envir = .broadcastValues)
+ } else {
+ NULL
+ }
+ })
+
+#' Internal function to set values of a broadcast variable.
+#'
+#' This function is used internally by Spark to set the value of a broadcast
+#' variable on workers. Not intended for use outside the package.
+#'
+#' @rdname broadcast-internal
+#' @seealso broadcast, value
+#'
+#' @param bcastId The id of broadcast variable to set
+#' @param value The value to be set
+#' @export
+setBroadcastValue <- function(bcastId, value) {
+ bcastIdStr <- as.character(bcastId)
+ .broadcastValues[[bcastIdStr]] <- value
+}
+
+#' Helper function to clear the list of broadcast variables we know about
+#' Should be called when the SparkR JVM backend is shut down.
+clearBroadcastVariables <- function() {
+ bcasts <- ls(.broadcastNames)
+ rm(list = bcasts, envir = .broadcastNames)
+}
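+
+# A minimal sketch of the bookkeeping above (illustrative only; ids normally
+# come from the JVM broadcast object created by broadcast() in context.R):
+#
+#   b <- Broadcast("0", list(1, 2, 3), jBroadcastRef = NULL, objName = "nums")
+#   value(b)                     # list(1, 2, 3), looked up in .broadcastValues
+#   setBroadcastValue("0", 42)   # what a worker does after deserializing a task
+#   value(b)                     # 42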
diff --git a/R/pkg/R/client.R b/R/pkg/R/client.R
new file mode 100644
index 0000000000..1281c41213
--- /dev/null
+++ b/R/pkg/R/client.R
@@ -0,0 +1,57 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements. See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+# Client code to connect to SparkRBackend
+
+# Creates a SparkR client connection object
+# if one doesn't already exist
+connectBackend <- function(hostname, port, timeout = 6000) {
+ if (exists(".sparkRcon", envir = .sparkREnv)) {
+ if (isOpen(.sparkREnv[[".sparkRCon"]])) {
+ cat("SparkRBackend client connection already exists\n")
+ return(get(".sparkRcon", envir = .sparkREnv))
+ }
+ }
+
+ con <- socketConnection(host = hostname, port = port, server = FALSE,
+ blocking = TRUE, open = "wb", timeout = timeout)
+
+ assign(".sparkRCon", con, envir = .sparkREnv)
+ con
+}
+
+launchBackend <- function(args, sparkHome, jars, sparkSubmitOpts) {
+  if (.Platform$OS.type == "unix") {
+    sparkSubmitBinName <- "spark-submit"
+  } else {
+    sparkSubmitBinName <- "spark-submit.cmd"
+  }
+
+ if (sparkHome != "") {
+ sparkSubmitBin <- file.path(sparkHome, "bin", sparkSubmitBinName)
+ } else {
+ sparkSubmitBin <- sparkSubmitBinName
+ }
+
+ if (jars != "") {
+ jars <- paste("--jars", jars)
+ }
+
+ combinedArgs <- paste(jars, sparkSubmitOpts, args, sep = " ")
+ cat("Launching java with spark-submit command", sparkSubmitBin, combinedArgs, "\n")
+  invisible(system2(sparkSubmitBin, combinedArgs, wait = FALSE))
+}
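+
+# Rough shape of the launch (illustrative only; the argument values below are
+# hypothetical and are normally assembled by sparkR.init()):
+#
+#   launchBackend(args = "sparkr-shell /tmp/backend_port",
+#                 sparkHome = Sys.getenv("SPARK_HOME"),
+#                 jars = "", sparkSubmitOpts = "--master local[2]")
+#   # runs roughly: <SPARK_HOME>/bin/spark-submit  --master local[2] sparkr-shell /tmp/backend_port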
diff --git a/R/pkg/R/column.R b/R/pkg/R/column.R
new file mode 100644
index 0000000000..e196305186
--- /dev/null
+++ b/R/pkg/R/column.R
@@ -0,0 +1,199 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements. See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+# Column Class
+
+#' @include generics.R jobj.R
+NULL
+
+setOldClass("jobj")
+
+#' @title S4 class that represents a DataFrame column
+#' @description The Column class supports unary and binary operations on
+#' DataFrame columns.
+#'
+#' @rdname column
+#'
+#' @param jc reference to JVM DataFrame column
+#' @export
+setClass("Column",
+ slots = list(jc = "jobj"))
+
+setMethod("initialize", "Column", function(.Object, jc) {
+ .Object@jc <- jc
+ .Object
+})
+
+column <- function(jc) {
+ new("Column", jc)
+}
+
+col <- function(x) {
+ column(callJStatic("org.apache.spark.sql.functions", "col", x))
+}
+
+#' @rdname show
+setMethod("show", "Column",
+ function(object) {
+ cat("Column", callJMethod(object@jc, "toString"), "\n")
+ })
+
+operators <- list(
+ "+" = "plus", "-" = "minus", "*" = "multiply", "/" = "divide", "%%" = "mod",
+ "==" = "equalTo", ">" = "gt", "<" = "lt", "!=" = "notEqual", "<=" = "leq", ">=" = "geq",
+  # we cannot override `&&` and `||`, so use `&` and `|` instead
+ "&" = "and", "|" = "or" #, "!" = "unary_$bang"
+)
+column_functions1 <- c("asc", "desc", "isNull", "isNotNull")
+column_functions2 <- c("like", "rlike", "startsWith", "endsWith", "getField", "getItem", "contains")
+functions <- c("min", "max", "sum", "avg", "mean", "count", "abs", "sqrt",
+ "first", "last", "lower", "upper", "sumDistinct")
+
+createOperator <- function(op) {
+ setMethod(op,
+ signature(e1 = "Column"),
+ function(e1, e2) {
+ jc <- if (missing(e2)) {
+ if (op == "-") {
+ callJMethod(e1@jc, "unary_$minus")
+ } else {
+ callJMethod(e1@jc, operators[[op]])
+ }
+ } else {
+ if (class(e2) == "Column") {
+ e2 <- e2@jc
+ }
+ callJMethod(e1@jc, operators[[op]], e2)
+ }
+ column(jc)
+ })
+}
+
+createColumnFunction1 <- function(name) {
+ setMethod(name,
+ signature(x = "Column"),
+ function(x) {
+ column(callJMethod(x@jc, name))
+ })
+}
+
+createColumnFunction2 <- function(name) {
+ setMethod(name,
+ signature(x = "Column"),
+ function(x, data) {
+ if (class(data) == "Column") {
+ data <- data@jc
+ }
+ jc <- callJMethod(x@jc, name, data)
+ column(jc)
+ })
+}
+
+createStaticFunction <- function(name) {
+ setMethod(name,
+ signature(x = "Column"),
+ function(x) {
+ jc <- callJStatic("org.apache.spark.sql.functions", name, x@jc)
+ column(jc)
+ })
+}
+
+createMethods <- function() {
+ for (op in names(operators)) {
+ createOperator(op)
+ }
+ for (name in column_functions1) {
+ createColumnFunction1(name)
+ }
+ for (name in column_functions2) {
+ createColumnFunction2(name)
+ }
+ for (x in functions) {
+ createStaticFunction(x)
+ }
+}
+
+createMethods()
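+
+# A minimal sketch of the methods generated above (assumes a DataFrame `df`
+# with a numeric column `age` and a string column `name`; illustrative only):
+#
+#   df$age + 1                    # dispatches to the "plus" operator
+#   df$age >= 21 & isNotNull(df$name)
+#   lower(df$name)                # static function org.apache.spark.sql.functions.lower
+#   alias(df$age * 2, "age2")     # rename an expression column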
+
+#' alias
+#'
+#' Set a new name for a column
+setMethod("alias",
+ signature(object = "Column"),
+ function(object, data) {
+ if (is.character(data)) {
+ column(callJMethod(object@jc, "as", data))
+ } else {
+ stop("data should be character")
+ }
+ })
+
+#' An expression that returns a substring.
+#'
+#' @param start starting position
+#' @param stop ending position
+setMethod("substr", signature(x = "Column"),
+ function(x, start, stop) {
+ jc <- callJMethod(x@jc, "substr", as.integer(start - 1), as.integer(stop - start + 1))
+ column(jc)
+ })
+
+#' Casts the column to a different data type.
+#' @examples
+#' \dontrun{
+#' cast(df$age, "string")
+#' cast(df$name, list(type="array", elementType="byte", containsNull = TRUE))
+#' }
+setMethod("cast",
+ signature(x = "Column"),
+ function(x, dataType) {
+ if (is.character(dataType)) {
+ column(callJMethod(x@jc, "cast", dataType))
+ } else if (is.list(dataType)) {
+ json <- tojson(dataType)
+ jdataType <- callJStatic("org.apache.spark.sql.types.DataType", "fromJson", json)
+ column(callJMethod(x@jc, "cast", jdataType))
+ } else {
+ stop("dataType should be character or list")
+ }
+ })
+
+#' Approx Count Distinct
+#'
+#' Returns the approximate number of distinct items in a group.
+#'
+setMethod("approxCountDistinct",
+ signature(x = "Column"),
+ function(x, rsd = 0.95) {
+ jc <- callJStatic("org.apache.spark.sql.functions", "approxCountDistinct", x@jc, rsd)
+ column(jc)
+ })
+
+#' Count Distinct
+#'
+#' Returns the number of distinct items in a group.
+#'
+setMethod("countDistinct",
+ signature(x = "Column"),
+ function(x, ...) {
+ jcol <- lapply(list(...), function (x) {
+ x@jc
+ })
+ jc <- callJStatic("org.apache.spark.sql.functions", "countDistinct", x@jc,
+ listToSeq(jcol))
+ column(jc)
+ })
+
diff --git a/R/pkg/R/context.R b/R/pkg/R/context.R
new file mode 100644
index 0000000000..2fc0bb294b
--- /dev/null
+++ b/R/pkg/R/context.R
@@ -0,0 +1,225 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements. See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+# context.R: SparkContext driven functions
+
+getMinSplits <- function(sc, minSplits) {
+ if (is.null(minSplits)) {
+ defaultParallelism <- callJMethod(sc, "defaultParallelism")
+ minSplits <- min(defaultParallelism, 2)
+ }
+ as.integer(minSplits)
+}
+
+#' Create an RDD from a text file.
+#'
+#' This function reads a text file from HDFS, a local file system (available on all
+#' nodes), or any Hadoop-supported file system URI, and creates an
+#' RDD of strings from it.
+#'
+#' @param sc SparkContext to use
+#' @param path Path of file to read. A vector of multiple paths is allowed.
+#' @param minSplits Minimum number of splits to be created. If NULL, the default
+#' value is chosen based on available parallelism.
+#' @return RDD where each item is of type \code{character}
+#' @export
+#' @examples
+#'\dontrun{
+#' sc <- sparkR.init()
+#' lines <- textFile(sc, "myfile.txt")
+#'}
+textFile <- function(sc, path, minSplits = NULL) {
+  # Allow the user to have a more flexible definition of the text file path
+  path <- suppressWarnings(normalizePath(path))
+  # Convert a string vector of paths to a string containing comma-separated paths
+ path <- paste(path, collapse = ",")
+
+ jrdd <- callJMethod(sc, "textFile", path, getMinSplits(sc, minSplits))
+ # jrdd is of type JavaRDD[String]
+ RDD(jrdd, "string")
+}
+
+#' Load an RDD saved as a SequenceFile containing serialized objects.
+#'
+#' The file to be loaded should be one that was previously generated by calling
+#' saveAsObjectFile() of the RDD class.
+#'
+#' @param sc SparkContext to use
+#' @param path Path of file to read. A vector of multiple paths is allowed.
+#' @param minSplits Minimum number of splits to be created. If NULL, the default
+#' value is chosen based on available parallelism.
+#' @return RDD containing serialized R objects.
+#' @seealso saveAsObjectFile
+#' @export
+#' @examples
+#'\dontrun{
+#' sc <- sparkR.init()
+#' rdd <- objectFile(sc, "myfile")
+#'}
+objectFile <- function(sc, path, minSplits = NULL) {
+  # Allow the user to have a more flexible definition of the text file path
+  path <- suppressWarnings(normalizePath(path))
+  # Convert a string vector of paths to a string containing comma-separated paths
+ path <- paste(path, collapse = ",")
+
+ jrdd <- callJMethod(sc, "objectFile", path, getMinSplits(sc, minSplits))
+ # Assume the RDD contains serialized R objects.
+ RDD(jrdd, "byte")
+}
+
+#' Create an RDD from a homogeneous list or vector.
+#'
+#' This function creates an RDD from a local homogeneous list in R. The elements
+#' in the list are split into \code{numSlices} slices and distributed to nodes
+#' in the cluster.
+#'
+#' @param sc SparkContext to use
+#' @param coll collection to parallelize
+#' @param numSlices number of partitions to create in the RDD
+#' @return an RDD created from this collection
+#' @export
+#' @examples
+#'\dontrun{
+#' sc <- sparkR.init()
+#' rdd <- parallelize(sc, 1:10, 2)
+#' # The RDD should contain 10 elements
+#' length(rdd)
+#'}
+parallelize <- function(sc, coll, numSlices = 1) {
+ # TODO: bound/safeguard numSlices
+ # TODO: unit tests for if the split works for all primitives
+ # TODO: support matrix, data frame, etc
+ if ((!is.list(coll) && !is.vector(coll)) || is.data.frame(coll)) {
+ if (is.data.frame(coll)) {
+ message(paste("context.R: A data frame is parallelized by columns."))
+ } else {
+ if (is.matrix(coll)) {
+ message(paste("context.R: A matrix is parallelized by elements."))
+ } else {
+ message(paste("context.R: parallelize() currently only supports lists and vectors.",
+ "Calling as.list() to coerce coll into a list."))
+ }
+ }
+ coll <- as.list(coll)
+ }
+
+ if (numSlices > length(coll))
+ numSlices <- length(coll)
+
+ sliceLen <- ceiling(length(coll) / numSlices)
+ slices <- split(coll, rep(1:(numSlices + 1), each = sliceLen)[1:length(coll)])
+
+ # Serialize each slice: obtain a list of raws, or a list of lists (slices) of
+ # 2-tuples of raws
+ serializedSlices <- lapply(slices, serialize, connection = NULL)
+
+ jrdd <- callJStatic("org.apache.spark.api.r.RRDD",
+ "createRDDFromArray", sc, serializedSlices)
+
+ RDD(jrdd, "byte")
+}
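+
+# Worked sketch of the slicing above (plain R, no Spark needed):
+#
+#   coll <- 1:5; numSlices <- 2
+#   sliceLen <- ceiling(length(coll) / numSlices)                     # 3
+#   split(coll, rep(1:(numSlices + 1), each = sliceLen)[1:length(coll)])
+#   # $`1`: 1 2 3    $`2`: 4 5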
+
+#' Include this specified package on all workers
+#'
+#' This function can be used to include a package on all workers before the
+#' user's code is executed. This is useful in scenarios where other R package
+#' functions are used in a function passed to functions like \code{lapply}.
+#' NOTE: The package is assumed to be installed on every node in the Spark
+#' cluster.
+#'
+#' @param sc SparkContext to use
+#' @param pkg Package name
+#'
+#' @export
+#' @examples
+#'\dontrun{
+#' library(Matrix)
+#'
+#' sc <- sparkR.init()
+#' # Include the matrix library we will be using
+#' includePackage(sc, Matrix)
+#'
+#' generateSparse <- function(x) {
+#' sparseMatrix(i=c(1, 2, 3), j=c(1, 2, 3), x=c(1, 2, 3))
+#' }
+#'
+#' rdd <- lapplyPartition(parallelize(sc, 1:2, 2L), generateSparse)
+#' collect(rdd)
+#'}
+includePackage <- function(sc, pkg) {
+ pkg <- as.character(substitute(pkg))
+ if (exists(".packages", .sparkREnv)) {
+ packages <- .sparkREnv$.packages
+ } else {
+ packages <- list()
+ }
+ packages <- c(packages, pkg)
+ .sparkREnv$.packages <- packages
+}
+
+#' @title Broadcast a variable to all workers
+#'
+#' @description
+#' Broadcast a read-only variable to the cluster, returning a \code{Broadcast}
+#' object for reading it in distributed functions.
+#'
+#' @param sc Spark Context to use
+#' @param object Object to be broadcast
+#' @export
+#' @examples
+#'\dontrun{
+#' sc <- sparkR.init()
+#' rdd <- parallelize(sc, 1:2, 2L)
+#'
+#' # Large Matrix object that we want to broadcast
+#' randomMat <- matrix(nrow=100, ncol=10, data=rnorm(1000))
+#' randomMatBr <- broadcast(sc, randomMat)
+#'
+#' # Use the broadcast variable inside the function
+#' useBroadcast <- function(x) {
+#' sum(value(randomMatBr) * x)
+#' }
+#' sumRDD <- lapply(rdd, useBroadcast)
+#'}
+broadcast <- function(sc, object) {
+ objName <- as.character(substitute(object))
+ serializedObj <- serialize(object, connection = NULL)
+
+ jBroadcast <- callJMethod(sc, "broadcast", serializedObj)
+ id <- as.character(callJMethod(jBroadcast, "id"))
+
+ Broadcast(id, object, jBroadcast, objName)
+}
+
+#' @title Set the checkpoint directory
+#'
+#' Set the directory under which RDDs are going to be checkpointed. The
+#' directory must be a HDFS path if running on a cluster.
+#'
+#' @param sc Spark Context to use
+#' @param dirName Directory path
+#' @export
+#' @examples
+#'\dontrun{
+#' sc <- sparkR.init()
+#' setCheckpointDir(sc, "~/checkpoints")
+#' rdd <- parallelize(sc, 1:2, 2L)
+#' checkpoint(rdd)
+#'}
+setCheckpointDir <- function(sc, dirName) {
+ invisible(callJMethod(sc, "setCheckpointDir", suppressWarnings(normalizePath(dirName))))
+}
diff --git a/R/pkg/R/deserialize.R b/R/pkg/R/deserialize.R
new file mode 100644
index 0000000000..257b435607
--- /dev/null
+++ b/R/pkg/R/deserialize.R
@@ -0,0 +1,184 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements. See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+# Utility functions to deserialize objects from Java.
+
+# Type mapping from Java to R
+#
+# void -> NULL
+# Int -> integer
+# String -> character
+# Boolean -> logical
+# Double -> double
+# Long -> double
+# Array[Byte] -> raw
+# Date -> Date
+# Time -> POSIXct
+#
+# Array[T] -> list()
+# Object -> jobj
+
+readObject <- function(con) {
+ # Read type first
+ type <- readType(con)
+ readTypedObject(con, type)
+}
+
+readTypedObject <- function(con, type) {
+ switch (type,
+ "i" = readInt(con),
+ "c" = readString(con),
+ "b" = readBoolean(con),
+ "d" = readDouble(con),
+ "r" = readRaw(con),
+ "D" = readDate(con),
+ "t" = readTime(con),
+ "l" = readList(con),
+ "n" = NULL,
+ "j" = getJobj(readString(con)),
+ stop(paste("Unsupported type for deserialization", type)))
+}
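+
+# Worked sketch of the framing handled above (plain R): an integer arrives as
+# the one-byte type tag "i" followed by a 4-byte big-endian value, so
+# readObject() can be exercised against an in-memory connection:
+#
+#   buf <- c(charToRaw("i"), writeBin(7L, raw(), endian = "big"))
+#   con <- rawConnection(buf, "r")
+#   readObject(con)   # 7
+#   close(con)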
+
+readString <- function(con) {
+ stringLen <- readInt(con)
+ string <- readBin(con, raw(), stringLen, endian = "big")
+ rawToChar(string)
+}
+
+readInt <- function(con) {
+ readBin(con, integer(), n = 1, endian = "big")
+}
+
+readDouble <- function(con) {
+ readBin(con, double(), n = 1, endian = "big")
+}
+
+readBoolean <- function(con) {
+ as.logical(readInt(con))
+}
+
+readType <- function(con) {
+ rawToChar(readBin(con, "raw", n = 1L))
+}
+
+readDate <- function(con) {
+ as.Date(readString(con))
+}
+
+readTime <- function(con) {
+ t <- readDouble(con)
+ as.POSIXct(t, origin = "1970-01-01")
+}
+
+# We only support lists where all elements are of the same type
+readList <- function(con) {
+ type <- readType(con)
+ len <- readInt(con)
+ if (len > 0) {
+ l <- vector("list", len)
+ for (i in 1:len) {
+ l[[i]] <- readTypedObject(con, type)
+ }
+ l
+ } else {
+ list()
+ }
+}
+
+readRaw <- function(con) {
+  dataLen <- readInt(con)
+  readBin(con, raw(), as.integer(dataLen), endian = "big")
+}
+
+readRawLen <- function(con, dataLen) {
+  readBin(con, raw(), as.integer(dataLen), endian = "big")
+}
+
+readDeserialize <- function(con) {
+  # Two cases are possible. In the first, the entire partition is encoded as a
+  # single byte array, so there is only one value to read and firstData is
+  # returned as-is.
+ dataLen <- readInt(con)
+ firstData <- unserialize(
+ readBin(con, raw(), as.integer(dataLen), endian = "big"))
+
+ # Else, read things into a list
+ dataLen <- readInt(con)
+ if (length(dataLen) > 0 && dataLen > 0) {
+ data <- list(firstData)
+ while (length(dataLen) > 0 && dataLen > 0) {
+ data[[length(data) + 1L]] <- unserialize(
+ readBin(con, raw(), as.integer(dataLen), endian = "big"))
+ dataLen <- readInt(con)
+ }
+ unlist(data, recursive = FALSE)
+ } else {
+ firstData
+ }
+}
+
+readDeserializeRows <- function(inputCon) {
+ # readDeserializeRows will deserialize a DataOutputStream composed of
+ # a list of lists. Since the DOS is one continuous stream and
+ # the number of rows varies, we put the readRow function in a while loop
+  # that terminates when the next row is empty.
+ data <- list()
+ while(TRUE) {
+ row <- readRow(inputCon)
+ if (length(row) == 0) {
+ break
+ }
+ data[[length(data) + 1L]] <- row
+ }
+ data # this is a list of named lists now
+}
+
+readRowList <- function(obj) {
+ # readRowList is meant for use inside an lapply. As a result, it is
+ # necessary to open a standalone connection for the row and consume
+ # the numCols bytes inside the read function in order to correctly
+ # deserialize the row.
+ rawObj <- rawConnection(obj, "r+")
+ on.exit(close(rawObj))
+ readRow(rawObj)
+}
+
+readRow <- function(inputCon) {
+ numCols <- readInt(inputCon)
+ if (length(numCols) > 0 && numCols > 0) {
+ lapply(1:numCols, function(x) {
+ obj <- readObject(inputCon)
+ if (is.null(obj)) {
+ NA
+ } else {
+ obj
+ }
+ }) # each row is a list now
+ } else {
+ list()
+ }
+}
+
+# Take a single column as Array[Byte] and deserialize it into an atomic vector
+readCol <- function(inputCon, numRows) {
+  # sapply cannot work with POSIXlt
+ do.call(c, lapply(1:numRows, function(x) {
+ value <- readObject(inputCon)
+ # Replace NULL with NA so we can coerce to vectors
+ if (is.null(value)) NA else value
+ }))
+}
diff --git a/R/pkg/R/generics.R b/R/pkg/R/generics.R
new file mode 100644
index 0000000000..5fb1ccaa84
--- /dev/null
+++ b/R/pkg/R/generics.R
@@ -0,0 +1,543 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements. See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+############ RDD Actions and Transformations ############
+
+#' @rdname aggregateRDD
+#' @seealso reduce
+#' @export
+setGeneric("aggregateRDD", function(x, zeroValue, seqOp, combOp) { standardGeneric("aggregateRDD") })
+
+#' @rdname cache-methods
+#' @export
+setGeneric("cache", function(x) { standardGeneric("cache") })
+
+#' @rdname coalesce
+#' @seealso repartition
+#' @export
+setGeneric("coalesce", function(x, numPartitions, ...) { standardGeneric("coalesce") })
+
+#' @rdname checkpoint-methods
+#' @export
+setGeneric("checkpoint", function(x) { standardGeneric("checkpoint") })
+
+#' @rdname collect-methods
+#' @export
+setGeneric("collect", function(x, ...) { standardGeneric("collect") })
+
+#' @rdname collect-methods
+#' @export
+setGeneric("collectAsMap", function(x) { standardGeneric("collectAsMap") })
+
+#' @rdname collect-methods
+#' @export
+setGeneric("collectPartition",
+ function(x, partitionId) {
+ standardGeneric("collectPartition")
+ })
+
+#' @rdname count
+#' @export
+setGeneric("count", function(x) { standardGeneric("count") })
+
+#' @rdname countByValue
+#' @export
+setGeneric("countByValue", function(x) { standardGeneric("countByValue") })
+
+#' @rdname distinct
+#' @export
+setGeneric("distinct", function(x, numPartitions = 1L) { standardGeneric("distinct") })
+
+#' @rdname filterRDD
+#' @export
+setGeneric("filterRDD", function(x, f) { standardGeneric("filterRDD") })
+
+#' @rdname first
+#' @export
+setGeneric("first", function(x) { standardGeneric("first") })
+
+#' @rdname flatMap
+#' @export
+setGeneric("flatMap", function(X, FUN) { standardGeneric("flatMap") })
+
+#' @rdname fold
+#' @seealso reduce
+#' @export
+setGeneric("fold", function(x, zeroValue, op) { standardGeneric("fold") })
+
+#' @rdname foreach
+#' @export
+setGeneric("foreach", function(x, func) { standardGeneric("foreach") })
+
+#' @rdname foreach
+#' @export
+setGeneric("foreachPartition", function(x, func) { standardGeneric("foreachPartition") })
+
+# The jrdd accessor function.
+setGeneric("getJRDD", function(rdd, ...) { standardGeneric("getJRDD") })
+
+#' @rdname glom
+#' @export
+setGeneric("glom", function(x) { standardGeneric("glom") })
+
+#' @rdname keyBy
+#' @export
+setGeneric("keyBy", function(x, func) { standardGeneric("keyBy") })
+
+#' @rdname lapplyPartition
+#' @export
+setGeneric("lapplyPartition", function(X, FUN) { standardGeneric("lapplyPartition") })
+
+#' @rdname lapplyPartitionsWithIndex
+#' @export
+setGeneric("lapplyPartitionsWithIndex",
+ function(X, FUN) {
+ standardGeneric("lapplyPartitionsWithIndex")
+ })
+
+#' @rdname lapply
+#' @export
+setGeneric("map", function(X, FUN) { standardGeneric("map") })
+
+#' @rdname lapplyPartition
+#' @export
+setGeneric("mapPartitions", function(X, FUN) { standardGeneric("mapPartitions") })
+
+#' @rdname lapplyPartitionsWithIndex
+#' @export
+setGeneric("mapPartitionsWithIndex",
+ function(X, FUN) { standardGeneric("mapPartitionsWithIndex") })
+
+#' @rdname maximum
+#' @export
+setGeneric("maximum", function(x) { standardGeneric("maximum") })
+
+#' @rdname minimum
+#' @export
+setGeneric("minimum", function(x) { standardGeneric("minimum") })
+
+#' @rdname sumRDD
+#' @export
+setGeneric("sumRDD", function(x) { standardGeneric("sumRDD") })
+
+#' @rdname name
+#' @export
+setGeneric("name", function(x) { standardGeneric("name") })
+
+#' @rdname numPartitions
+#' @export
+setGeneric("numPartitions", function(x) { standardGeneric("numPartitions") })
+
+#' @rdname persist
+#' @export
+setGeneric("persist", function(x, newLevel) { standardGeneric("persist") })
+
+#' @rdname pipeRDD
+#' @export
+setGeneric("pipeRDD", function(x, command, env = list()) { standardGeneric("pipeRDD")})
+
+#' @rdname reduce
+#' @export
+setGeneric("reduce", function(x, func) { standardGeneric("reduce") })
+
+#' @rdname repartition
+#' @seealso coalesce
+#' @export
+setGeneric("repartition", function(x, numPartitions) { standardGeneric("repartition") })
+
+#' @rdname sampleRDD
+#' @export
+setGeneric("sampleRDD",
+ function(x, withReplacement, fraction, seed) {
+ standardGeneric("sampleRDD")
+ })
+
+#' @rdname saveAsObjectFile
+#' @seealso objectFile
+#' @export
+setGeneric("saveAsObjectFile", function(x, path) { standardGeneric("saveAsObjectFile") })
+
+#' @rdname saveAsTextFile
+#' @export
+setGeneric("saveAsTextFile", function(x, path) { standardGeneric("saveAsTextFile") })
+
+#' @rdname setName
+#' @export
+setGeneric("setName", function(x, name) { standardGeneric("setName") })
+
+#' @rdname sortBy
+#' @export
+setGeneric("sortBy",
+ function(x, func, ascending = TRUE, numPartitions = 1L) {
+ standardGeneric("sortBy")
+ })
+
+#' @rdname take
+#' @export
+setGeneric("take", function(x, num) { standardGeneric("take") })
+
+#' @rdname takeOrdered
+#' @export
+setGeneric("takeOrdered", function(x, num) { standardGeneric("takeOrdered") })
+
+#' @rdname takeSample
+#' @export
+setGeneric("takeSample",
+ function(x, withReplacement, num, seed) {
+ standardGeneric("takeSample")
+ })
+
+#' @rdname top
+#' @export
+setGeneric("top", function(x, num) { standardGeneric("top") })
+
+#' @rdname unionRDD
+#' @export
+setGeneric("unionRDD", function(x, y) { standardGeneric("unionRDD") })
+
+#' @rdname unpersist-methods
+#' @export
+setGeneric("unpersist", function(x, ...) { standardGeneric("unpersist") })
+
+#' @rdname zipRDD
+#' @export
+setGeneric("zipRDD", function(x, other) { standardGeneric("zipRDD") })
+
+#' @rdname zipWithIndex
+#' @seealso zipWithUniqueId
+#' @export
+setGeneric("zipWithIndex", function(x) { standardGeneric("zipWithIndex") })
+
+#' @rdname zipWithUniqueId
+#' @seealso zipWithIndex
+#' @export
+setGeneric("zipWithUniqueId", function(x) { standardGeneric("zipWithUniqueId") })
+
+
+############ Binary Functions #############
+
+#' @rdname countByKey
+#' @export
+setGeneric("countByKey", function(x) { standardGeneric("countByKey") })
+
+#' @rdname flatMapValues
+#' @export
+setGeneric("flatMapValues", function(X, FUN) { standardGeneric("flatMapValues") })
+
+#' @rdname keys
+#' @export
+setGeneric("keys", function(x) { standardGeneric("keys") })
+
+#' @rdname lookup
+#' @export
+setGeneric("lookup", function(x, key) { standardGeneric("lookup") })
+
+#' @rdname mapValues
+#' @export
+setGeneric("mapValues", function(X, FUN) { standardGeneric("mapValues") })
+
+#' @rdname values
+#' @export
+setGeneric("values", function(x) { standardGeneric("values") })
+
+
+
+############ Shuffle Functions ############
+
+#' @rdname aggregateByKey
+#' @seealso foldByKey, combineByKey
+#' @export
+setGeneric("aggregateByKey",
+ function(x, zeroValue, seqOp, combOp, numPartitions) {
+ standardGeneric("aggregateByKey")
+ })
+
+#' @rdname cogroup
+#' @export
+setGeneric("cogroup",
+ function(..., numPartitions) {
+ standardGeneric("cogroup")
+ },
+ signature = "...")
+
+#' @rdname combineByKey
+#' @seealso groupByKey, reduceByKey
+#' @export
+setGeneric("combineByKey",
+ function(x, createCombiner, mergeValue, mergeCombiners, numPartitions) {
+ standardGeneric("combineByKey")
+ })
+
+#' @rdname foldByKey
+#' @seealso aggregateByKey, combineByKey
+#' @export
+setGeneric("foldByKey",
+ function(x, zeroValue, func, numPartitions) {
+ standardGeneric("foldByKey")
+ })
+
+#' @rdname join-methods
+#' @export
+setGeneric("fullOuterJoin", function(x, y, numPartitions) { standardGeneric("fullOuterJoin") })
+
+#' @rdname groupByKey
+#' @seealso reduceByKey
+#' @export
+setGeneric("groupByKey", function(x, numPartitions) { standardGeneric("groupByKey") })
+
+#' @rdname join-methods
+#' @export
+setGeneric("join", function(x, y, ...) { standardGeneric("join") })
+
+#' @rdname join-methods
+#' @export
+setGeneric("leftOuterJoin", function(x, y, numPartitions) { standardGeneric("leftOuterJoin") })
+
+#' @rdname partitionBy
+#' @export
+setGeneric("partitionBy", function(x, numPartitions, ...) { standardGeneric("partitionBy") })
+
+#' @rdname reduceByKey
+#' @seealso groupByKey
+#' @export
+setGeneric("reduceByKey", function(x, combineFunc, numPartitions) { standardGeneric("reduceByKey")})
+
+#' @rdname reduceByKeyLocally
+#' @seealso reduceByKey
+#' @export
+setGeneric("reduceByKeyLocally",
+ function(x, combineFunc) {
+ standardGeneric("reduceByKeyLocally")
+ })
+
+#' @rdname join-methods
+#' @export
+setGeneric("rightOuterJoin", function(x, y, numPartitions) { standardGeneric("rightOuterJoin") })
+
+#' @rdname sortByKey
+#' @export
+setGeneric("sortByKey", function(x, ascending = TRUE, numPartitions = 1L) {
+ standardGeneric("sortByKey")
+})
+
+
+################### Broadcast Variable Methods #################
+
+#' @rdname broadcast
+#' @export
+setGeneric("value", function(bcast) { standardGeneric("value") })
+
+
+
+#################### DataFrame Methods ########################
+
+#' @rdname schema
+#' @export
+setGeneric("columns", function(x) {standardGeneric("columns") })
+
+#' @rdname schema
+#' @export
+setGeneric("dtypes", function(x) { standardGeneric("dtypes") })
+
+#' @rdname explain
+#' @export
+setGeneric("explain", function(x, ...) { standardGeneric("explain") })
+
+#' @rdname filter
+#' @export
+setGeneric("filter", function(x, condition) { standardGeneric("filter") })
+
+#' @rdname DataFrame
+#' @export
+setGeneric("groupBy", function(x, ...) { standardGeneric("groupBy") })
+
+#' @rdname insertInto
+#' @export
+setGeneric("insertInto", function(x, tableName, ...) { standardGeneric("insertInto") })
+
+#' @rdname intersect
+#' @export
+setGeneric("intersect", function(x, y) { standardGeneric("intersect") })
+
+#' @rdname isLocal
+#' @export
+setGeneric("isLocal", function(x) { standardGeneric("isLocal") })
+
+#' @rdname limit
+#' @export
+setGeneric("limit", function(x, num) {standardGeneric("limit") })
+
+#' @rdname sortDF
+#' @export
+setGeneric("orderBy", function(x, col) { standardGeneric("orderBy") })
+
+#' @rdname schema
+#' @export
+setGeneric("printSchema", function(x) { standardGeneric("printSchema") })
+
+#' @rdname registerTempTable
+#' @export
+setGeneric("registerTempTable", function(x, tableName) { standardGeneric("registerTempTable") })
+
+#' @rdname sampleDF
+#' @export
+setGeneric("sampleDF",
+ function(x, withReplacement, fraction, seed) {
+ standardGeneric("sampleDF")
+ })
+
+#' @rdname saveAsParquetFile
+#' @export
+setGeneric("saveAsParquetFile", function(x, path) { standardGeneric("saveAsParquetFile") })
+
+#' @rdname saveAsTable
+#' @export
+setGeneric("saveAsTable", function(df, tableName, source, mode, ...) {
+ standardGeneric("saveAsTable")
+})
+
+#' @rdname saveAsTable
+#' @export
+setGeneric("saveDF", function(df, path, source, mode, ...) { standardGeneric("saveDF") })
+
+#' @rdname schema
+#' @export
+setGeneric("schema", function(x) { standardGeneric("schema") })
+
+#' @rdname select
+#' @export
+setGeneric("select", function(x, col, ...) { standardGeneric("select") } )
+
+#' @rdname select
+#' @export
+setGeneric("selectExpr", function(x, expr, ...) { standardGeneric("selectExpr") })
+
+#' @rdname showDF
+#' @export
+setGeneric("showDF", function(x,...) { standardGeneric("showDF") })
+
+#' @rdname sortDF
+#' @export
+setGeneric("sortDF", function(x, col, ...) { standardGeneric("sortDF") })
+
+#' @rdname subtract
+#' @export
+setGeneric("subtract", function(x, y) { standardGeneric("subtract") })
+
+#' @rdname tojson
+#' @export
+setGeneric("toJSON", function(x) { standardGeneric("toJSON") })
+
+#' @rdname DataFrame
+#' @export
+setGeneric("toRDD", function(x) { standardGeneric("toRDD") })
+
+#' @rdname unionAll
+#' @export
+setGeneric("unionAll", function(x, y) { standardGeneric("unionAll") })
+
+#' @rdname filter
+#' @export
+setGeneric("where", function(x, condition) { standardGeneric("where") })
+
+#' @rdname withColumn
+#' @export
+setGeneric("withColumn", function(x, colName, col) { standardGeneric("withColumn") })
+
+#' @rdname withColumnRenamed
+#' @export
+setGeneric("withColumnRenamed", function(x, existingCol, newCol) {
+ standardGeneric("withColumnRenamed") })
+
+
+###################### Column Methods ##########################
+
+#' @rdname column
+#' @export
+setGeneric("approxCountDistinct", function(x, ...) { standardGeneric("approxCountDistinct") })
+
+#' @rdname column
+#' @export
+setGeneric("asc", function(x) { standardGeneric("asc") })
+
+#' @rdname column
+#' @export
+setGeneric("avg", function(x, ...) { standardGeneric("avg") })
+
+#' @rdname column
+#' @export
+setGeneric("cast", function(x, dataType) { standardGeneric("cast") })
+
+#' @rdname column
+#' @export
+setGeneric("contains", function(x, ...) { standardGeneric("contains") })
+#' @rdname column
+#' @export
+setGeneric("countDistinct", function(x, ...) { standardGeneric("countDistinct") })
+
+#' @rdname column
+#' @export
+setGeneric("desc", function(x) { standardGeneric("desc") })
+
+#' @rdname column
+#' @export
+setGeneric("endsWith", function(x, ...) { standardGeneric("endsWith") })
+
+#' @rdname column
+#' @export
+setGeneric("getField", function(x, ...) { standardGeneric("getField") })
+
+#' @rdname column
+#' @export
+setGeneric("getItem", function(x, ...) { standardGeneric("getItem") })
+
+#' @rdname column
+#' @export
+setGeneric("isNull", function(x) { standardGeneric("isNull") })
+
+#' @rdname column
+#' @export
+setGeneric("isNotNull", function(x) { standardGeneric("isNotNull") })
+
+#' @rdname column
+#' @export
+setGeneric("last", function(x) { standardGeneric("last") })
+
+#' @rdname column
+#' @export
+setGeneric("like", function(x, ...) { standardGeneric("like") })
+
+#' @rdname column
+#' @export
+setGeneric("lower", function(x) { standardGeneric("lower") })
+
+#' @rdname column
+#' @export
+setGeneric("rlike", function(x, ...) { standardGeneric("rlike") })
+
+#' @rdname column
+#' @export
+setGeneric("startsWith", function(x, ...) { standardGeneric("startsWith") })
+
+#' @rdname column
+#' @export
+setGeneric("sumDistinct", function(x) { standardGeneric("sumDistinct") })
+
+#' @rdname column
+#' @export
+setGeneric("upper", function(x) { standardGeneric("upper") })
+
diff --git a/R/pkg/R/group.R b/R/pkg/R/group.R
new file mode 100644
index 0000000000..09fc0a7abe
--- /dev/null
+++ b/R/pkg/R/group.R
@@ -0,0 +1,132 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements. See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+# group.R - GroupedData class and methods implemented in S4 OO classes
+
+setOldClass("jobj")
+
+#' @title S4 class that represents a GroupedData
+#' @description A GroupedData can be created using groupBy() on a DataFrame
+#' @rdname GroupedData
+#' @seealso groupBy
+#'
+#' @param sgd A Java object reference to the backing Scala GroupedData
+#' @export
+setClass("GroupedData",
+ slots = list(sgd = "jobj"))
+
+setMethod("initialize", "GroupedData", function(.Object, sgd) {
+ .Object@sgd <- sgd
+ .Object
+})
+
+#' @rdname DataFrame
+groupedData <- function(sgd) {
+ new("GroupedData", sgd)
+}
+
+
+#' @rdname show
+setMethod("show", "GroupedData",
+ function(object) {
+ cat("GroupedData\n")
+ })
+
+#' Count
+#'
+#' Count the number of rows for each group.
+#' The resulting DataFrame will also contain the grouping columns.
+#'
+#' @param x a GroupedData
+#' @return a DataFrame
+#' @export
+#' @examples
+#' \dontrun{
+#' count(groupBy(df, "name"))
+#' }
+setMethod("count",
+ signature(x = "GroupedData"),
+ function(x) {
+ dataFrame(callJMethod(x@sgd, "count"))
+ })
+
+#' Agg
+#'
+#' Aggregates a DataFrame over the groups of a GroupedData (see groupBy).
+#' The resulting DataFrame will also contain the grouping columns.
+#'
+#' df2 <- agg(df, <column> = <aggFunction>)
+#' df2 <- agg(df, newColName = aggFunction(column))
+#'
+#' @param x a GroupedData
+#' @return a DataFrame
+#' @rdname agg
+#' @examples
+#' \dontrun{
+#' df2 <- agg(df, age = "sum") # new column name will be created as 'SUM(age#0)'
+#' df2 <- agg(df, ageSum = sum(df$age)) # Creates a new column named ageSum
+#' }
+setGeneric("agg", function (x, ...) { standardGeneric("agg") })
+
+setMethod("agg",
+ signature(x = "GroupedData"),
+ function(x, ...) {
+ cols = list(...)
+ stopifnot(length(cols) > 0)
+ if (is.character(cols[[1]])) {
+ cols <- varargsToEnv(...)
+ sdf <- callJMethod(x@sgd, "agg", cols)
+ } else if (class(cols[[1]]) == "Column") {
+ ns <- names(cols)
+ if (!is.null(ns)) {
+ for (n in ns) {
+ if (n != "") {
+ cols[[n]] = alias(cols[[n]], n)
+ }
+ }
+ }
+ jcols <- lapply(cols, function(c) { c@jc })
+ # the GroupedData.agg(col, cols*) API does not contain grouping Column
+ sdf <- callJStatic("org.apache.spark.sql.api.r.SQLUtils", "aggWithGrouping",
+ x@sgd, listToSeq(jcols))
+ } else {
+ stop("agg can only support Column or character")
+ }
+ dataFrame(sdf)
+ })
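+
+# A minimal sketch of both agg() forms (assumes a DataFrame `df` with columns
+# `name` and `age`; illustrative only):
+#
+#   gd <- groupBy(df, "name")
+#   agg(gd, age = "max")            # character form, handled via varargsToEnv
+#   agg(gd, maxAge = max(df$age))   # Column form, aliased to "maxAge"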
+
+
+# sum/mean/avg/min/max
+methods <- c("sum", "mean", "avg", "min", "max")
+
+createMethod <- function(name) {
+ setMethod(name,
+ signature(x = "GroupedData"),
+ function(x, ...) {
+ sdf <- callJMethod(x@sgd, name, toSeq(...))
+ dataFrame(sdf)
+ })
+}
+
+createMethods <- function() {
+ for (name in methods) {
+ createMethod(name)
+ }
+}
+
+createMethods()
+
diff --git a/R/pkg/R/jobj.R b/R/pkg/R/jobj.R
new file mode 100644
index 0000000000..4180f146b7
--- /dev/null
+++ b/R/pkg/R/jobj.R
@@ -0,0 +1,101 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements. See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+# References to objects that exist on the JVM backend
+# are maintained using the jobj.
+
+# Maintain a reference count of Java object references
+# This allows us to GC the java object when it is safe
+.validJobjs <- new.env(parent = emptyenv())
+
+# List of object ids to be removed
+.toRemoveJobjs <- new.env(parent = emptyenv())
+
+# Check if jobj was created with the current SparkContext
+isValidJobj <- function(jobj) {
+ if (exists(".scStartTime", envir = .sparkREnv)) {
+ jobj$appId == get(".scStartTime", envir = .sparkREnv)
+ } else {
+ FALSE
+ }
+}
+
+getJobj <- function(objId) {
+ newObj <- jobj(objId)
+ if (exists(objId, .validJobjs)) {
+ .validJobjs[[objId]] <- .validJobjs[[objId]] + 1
+ } else {
+ .validJobjs[[objId]] <- 1
+ }
+ newObj
+}
+
+# Handler for a Java object that exists on the backend.
+jobj <- function(objId) {
+ if (!is.character(objId)) {
+ stop("object id must be a character")
+ }
+ # NOTE: We need a new env for a jobj as we can only register
+  # finalizers for environments or external pointers.
+ obj <- structure(new.env(parent = emptyenv()), class = "jobj")
+ obj$id <- objId
+ obj$appId <- get(".scStartTime", envir = .sparkREnv)
+
+ # Register a finalizer to remove the Java object when this reference
+ # is garbage collected in R
+ reg.finalizer(obj, cleanup.jobj)
+ obj
+}
+
+#' Print a JVM object reference.
+#'
+#' This function prints the type and id for an object stored
+#' in the SparkR JVM backend.
+#'
+#' @param x The JVM object reference
+#' @param ... further arguments passed to or from other methods
+print.jobj <- function(x, ...) {
+ cls <- callJMethod(x, "getClass")
+ name <- callJMethod(cls, "getName")
+ cat("Java ref type", name, "id", x$id, "\n", sep = " ")
+}
+
+cleanup.jobj <- function(jobj) {
+ if (isValidJobj(jobj)) {
+ objId <- jobj$id
+ # If we don't know anything about this jobj, ignore it
+ if (exists(objId, envir = .validJobjs)) {
+ .validJobjs[[objId]] <- .validJobjs[[objId]] - 1
+
+ if (.validJobjs[[objId]] == 0) {
+ rm(list = objId, envir = .validJobjs)
+ # NOTE: We cannot call removeJObject here as the finalizer may be run
+ # in the middle of another RPC. Thus we queue up this object Id to be removed
+ # and then run all the removeJObject when the next RPC is called.
+ .toRemoveJobjs[[objId]] <- 1
+ }
+ }
+ }
+}
+
+clearJobjs <- function() {
+ valid <- ls(.validJobjs)
+ rm(list = valid, envir = .validJobjs)
+
+ removeList <- ls(.toRemoveJobjs)
+ rm(list = removeList, envir = .toRemoveJobjs)
+}
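+
+# Sketch of the reference counting above (illustrative only; object ids
+# normally come from the backend, and .sparkREnv must already hold
+# .scStartTime):
+#
+#   a <- getJobj("5"); b <- getJobj("5")   # .validJobjs[["5"]] == 2
+#   rm(a, b); gc()                         # finalizers run; once the count hits
+#                                          # 0 the id is queued in .toRemoveJobjs
+#                                          # for removal on the next RPC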
diff --git a/R/pkg/R/pairRDD.R b/R/pkg/R/pairRDD.R
new file mode 100644
index 0000000000..c2396c32a7
--- /dev/null
+++ b/R/pkg/R/pairRDD.R
@@ -0,0 +1,789 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements. See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+# Operations supported on RDDs containing pairs (i.e. key, value)
+
+############ Actions and Transformations ############
+
+#' Look up elements of a key in an RDD
+#'
+#' @description
+#' \code{lookup} returns a list of values in this RDD for the given key.
+#'
+#' @param x The RDD to search
+#' @param key The key to look up
+#' @return a list of values in this RDD for the given key
+#' @examples
+#'\dontrun{
+#' sc <- sparkR.init()
+#' pairs <- list(c(1, 1), c(2, 2), c(1, 3))
+#' rdd <- parallelize(sc, pairs)
+#' lookup(rdd, 1) # list(1, 3)
+#'}
+#' @rdname lookup
+#' @aliases lookup,RDD-method
+setMethod("lookup",
+ signature(x = "RDD", key = "ANY"),
+ function(x, key) {
+ partitionFunc <- function(part) {
+ filtered <- part[unlist(lapply(part, function(i) { identical(key, i[[1]]) }))]
+ lapply(filtered, function(i) { i[[2]] })
+ }
+ valsRDD <- lapplyPartition(x, partitionFunc)
+ collect(valsRDD)
+ })
+
+#' Count the number of elements for each key, and return the result to the
+#' master as lists of (key, count) pairs.
+#'
+#' Same as countByKey in Spark.
+#'
+#' @param x The RDD to count keys.
+#' @return list of (key, count) pairs, where count is the number of occurrences of each key in the RDD.
+#' @examples
+#'\dontrun{
+#' sc <- sparkR.init()
+#' rdd <- parallelize(sc, list(c("a", 1), c("b", 1), c("a", 1)))
+#' countByKey(rdd) # ("a", 2L), ("b", 1L)
+#'}
+#' @rdname countByKey
+#' @aliases countByKey,RDD-method
+setMethod("countByKey",
+ signature(x = "RDD"),
+ function(x) {
+ keys <- lapply(x, function(item) { item[[1]] })
+ countByValue(keys)
+ })
+
+#' Return an RDD with the keys of each tuple.
+#'
+#' @param x The RDD from which the keys of each tuple are returned.
+#' @examples
+#'\dontrun{
+#' sc <- sparkR.init()
+#' rdd <- parallelize(sc, list(list(1, 2), list(3, 4)))
+#' collect(keys(rdd)) # list(1, 3)
+#'}
+#' @rdname keys
+#' @aliases keys,RDD
+setMethod("keys",
+ signature(x = "RDD"),
+ function(x) {
+ func <- function(k) {
+ k[[1]]
+ }
+ lapply(x, func)
+ })
+
+#' Return an RDD with the values of each tuple.
+#'
+#' @param x The RDD from which the values of each tuple are returned.
+#' @examples
+#'\dontrun{
+#' sc <- sparkR.init()
+#' rdd <- parallelize(sc, list(list(1, 2), list(3, 4)))
+#' collect(values(rdd)) # list(2, 4)
+#'}
+#' @rdname values
+#' @aliases values,RDD
+setMethod("values",
+ signature(x = "RDD"),
+ function(x) {
+ func <- function(v) {
+ v[[2]]
+ }
+ lapply(x, func)
+ })
+
+#' Applies a function to all values of the elements, without modifying the keys.
+#'
+#' The same as 'mapValues()' in Spark.
+#'
+#' @param X The RDD to apply the transformation.
+#' @param FUN the transformation to apply on the value of each element.
+#' @return a new RDD created by the transformation.
+#' @examples
+#'\dontrun{
+#' sc <- sparkR.init()
+#' rdd <- parallelize(sc, 1:10)
+#' makePairs <- lapply(rdd, function(x) { list(x, x) })
+#' collect(mapValues(makePairs, function(x) { x * 2 }))
+#' Output: list(list(1,2), list(2,4), list(3,6), ...)
+#'}
+#' @rdname mapValues
+#' @aliases mapValues,RDD,function-method
+setMethod("mapValues",
+ signature(X = "RDD", FUN = "function"),
+ function(X, FUN) {
+ func <- function(x) {
+ list(x[[1]], FUN(x[[2]]))
+ }
+ lapply(X, func)
+ })
+
+#' Pass each value in the key-value pair RDD through a flatMap function without
+#' changing the keys; this also retains the original RDD's partitioning.
+#'
+#' The same as 'flatMapValues()' in Spark.
+#'
+#' @param X The RDD to apply the transformation.
+#' @param FUN the transformation to apply on the value of each element.
+#' @return a new RDD created by the transformation.
+#' @examples
+#'\dontrun{
+#' sc <- sparkR.init()
+#' rdd <- parallelize(sc, list(list(1, c(1,2)), list(2, c(3,4))))
+#' collect(flatMapValues(rdd, function(x) { x }))
+#' Output: list(list(1,1), list(1,2), list(2,3), list(2,4))
+#'}
+#' @rdname flatMapValues
+#' @aliases flatMapValues,RDD,function-method
+setMethod("flatMapValues",
+ signature(X = "RDD", FUN = "function"),
+ function(X, FUN) {
+ flatMapFunc <- function(x) {
+ lapply(FUN(x[[2]]), function(v) { list(x[[1]], v) })
+ }
+ flatMap(X, flatMapFunc)
+ })
+
+############ Shuffle Functions ############
+
+#' Partition an RDD by key
+#'
+#' This function operates on RDDs where every element is of the form list(K, V) or c(K, V).
+#' For each element of this RDD, the partition function is used to compute a
+#' hash value and the RDD is partitioned using this hash.
+#'
+#' @param x The RDD to partition. Should be an RDD where each element is
+#' list(K, V) or c(K, V).
+#' @param numPartitions Number of partitions to create.
+#' @param ... Other optional arguments to partitionBy.
+#'
+#' @param partitionFunc The partition function to use. Uses a default hashCode
+#' function if not provided
+#' @return An RDD partitioned using the specified partitioner.
+#' @examples
+#'\dontrun{
+#' sc <- sparkR.init()
+#' pairs <- list(list(1, 2), list(1.1, 3), list(1, 4))
+#' rdd <- parallelize(sc, pairs)
+#' parts <- partitionBy(rdd, 2L)
+#' collectPartition(parts, 0L) # First partition should contain list(1, 2) and list(1, 4)
+#'}
+#' @rdname partitionBy
+#' @aliases partitionBy,RDD,integer-method
+setMethod("partitionBy",
+ signature(x = "RDD", numPartitions = "integer"),
+ function(x, numPartitions, partitionFunc = hashCode) {
+
+ #if (missing(partitionFunc)) {
+ # partitionFunc <- hashCode
+ #}
+
+ partitionFunc <- cleanClosure(partitionFunc)
+ serializedHashFuncBytes <- serialize(partitionFunc, connection = NULL)
+
+ packageNamesArr <- serialize(.sparkREnv$.packages,
+ connection = NULL)
+ broadcastArr <- lapply(ls(.broadcastNames), function(name) {
+ get(name, .broadcastNames) })
+ jrdd <- getJRDD(x)
+
+ # We create a PairwiseRRDD that extends RDD[(Array[Byte],
+ # Array[Byte])], where the key is the hashed split, the value is
+ # the content (key-val pairs).
+ pairwiseRRDD <- newJObject("org.apache.spark.api.r.PairwiseRRDD",
+ callJMethod(jrdd, "rdd"),
+ as.integer(numPartitions),
+ serializedHashFuncBytes,
+ getSerializedMode(x),
+ packageNamesArr,
+ as.character(.sparkREnv$libname),
+ broadcastArr,
+ callJMethod(jrdd, "classTag"))
+
+ # Create a corresponding partitioner.
+ rPartitioner <- newJObject("org.apache.spark.HashPartitioner",
+ as.integer(numPartitions))
+
+ # Call partitionBy on the obtained PairwiseRDD.
+ javaPairRDD <- callJMethod(pairwiseRRDD, "asJavaPairRDD")
+ javaPairRDD <- callJMethod(javaPairRDD, "partitionBy", rPartitioner)
+
+ # Call .values() on the result to get back the final result, the
+            # shuffled actual content key-val pairs.
+ r <- callJMethod(javaPairRDD, "values")
+
+ RDD(r, serializedMode = "byte")
+ })
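+
+# Usage sketch (hypothetical session, assuming an active SparkContext `sc` from
+# sparkR.init()): partitioning with the default hashCode-based partitioner, and
+# with a custom partition function that groups even and odd integer keys separately.
+#   rdd <- parallelize(sc, list(list(1, "a"), list(2, "b"), list(3, "c"), list(4, "d")))
+#   byHash <- partitionBy(rdd, 2L)                                # default: hashCode
+#   byParity <- partitionBy(rdd, 2L, function(key) { key %% 2 })  # custom partitionFunc
+#   collectPartition(byParity, 0L)  # keys of one parity class end up together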
+
+#' Group values by key
+#'
+#' This function operates on RDDs where every element is of the form list(K, V) or c(K, V),
+#' and groups the values for each key in the RDD into a single sequence.
+#'
+#' @param x The RDD to group. Should be an RDD where each element is
+#' list(K, V) or c(K, V).
+#' @param numPartitions Number of partitions to create.
+#' @return An RDD where each element is list(K, list(V))
+#' @seealso reduceByKey
+#' @examples
+#'\dontrun{
+#' sc <- sparkR.init()
+#' pairs <- list(list(1, 2), list(1.1, 3), list(1, 4))
+#' rdd <- parallelize(sc, pairs)
+#' parts <- groupByKey(rdd, 2L)
+#' grouped <- collect(parts)
+#' grouped[[1]] # Should be a list(1, list(2, 4))
+#'}
+#' @rdname groupByKey
+#' @aliases groupByKey,RDD,integer-method
+setMethod("groupByKey",
+ signature(x = "RDD", numPartitions = "integer"),
+ function(x, numPartitions) {
+ shuffled <- partitionBy(x, numPartitions)
+ groupVals <- function(part) {
+ vals <- new.env()
+ keys <- new.env()
+ pred <- function(item) exists(item$hash, keys)
+ appendList <- function(acc, i) {
+ addItemToAccumulator(acc, i)
+ acc
+ }
+ makeList <- function(i) {
+ acc <- initAccumulator()
+ addItemToAccumulator(acc, i)
+ acc
+ }
+ # Each item in the partition is list of (K, V)
+ lapply(part,
+ function(item) {
+ item$hash <- as.character(hashCode(item[[1]]))
+ updateOrCreatePair(item, keys, vals, pred,
+ appendList, makeList)
+ })
+ # extract out data field
+ vals <- eapply(vals,
+ function(i) {
+ length(i$data) <- i$counter
+ i$data
+ })
+ # Every key in the environment contains a list
+ # Convert that to list(K, Seq[V])
+ convertEnvsToList(keys, vals)
+ }
+ lapplyPartition(shuffled, groupVals)
+ })
+
+#' Merge values by key
+#'
+#' This function operates on RDDs where every element is of the form list(K, V) or c(K, V),
+#' and merges the values for each key using an associative reduce function.
+#'
+#' @param x The RDD to reduce by key. Should be an RDD where each element is
+#' list(K, V) or c(K, V).
+#' @param combineFunc The associative reduce function to use.
+#' @param numPartitions Number of partitions to create.
+#' @return An RDD where each element is list(K, V') where V' is the merged
+#' value
+#' @examples
+#'\dontrun{
+#' sc <- sparkR.init()
+#' pairs <- list(list(1, 2), list(1.1, 3), list(1, 4))
+#' rdd <- parallelize(sc, pairs)
+#' parts <- reduceByKey(rdd, "+", 2L)
+#' reduced <- collect(parts)
+#' reduced[[1]] # Should be a list(1, 6)
+#'}
+#' @rdname reduceByKey
+#' @aliases reduceByKey,RDD,integer-method
+setMethod("reduceByKey",
+ signature(x = "RDD", combineFunc = "ANY", numPartitions = "integer"),
+ function(x, combineFunc, numPartitions) {
+ reduceVals <- function(part) {
+ vals <- new.env()
+ keys <- new.env()
+ pred <- function(item) exists(item$hash, keys)
+ lapply(part,
+ function(item) {
+ item$hash <- as.character(hashCode(item[[1]]))
+ updateOrCreatePair(item, keys, vals, pred, combineFunc, identity)
+ })
+ convertEnvsToList(keys, vals)
+ }
+ locallyReduced <- lapplyPartition(x, reduceVals)
+ shuffled <- partitionBy(locallyReduced, numPartitions)
+ lapplyPartition(shuffled, reduceVals)
+ })
+
+#' Merge values by key locally
+#'
+#' This function operates on RDDs where every element is of the form list(K, V) or c(K, V),
+#' and merges the values for each key using an associative reduce function, but returns the
+#' results immediately to the driver as an R list.
+#'
+#' @param x The RDD to reduce by key. Should be an RDD where each element is
+#' list(K, V) or c(K, V).
+#' @param combineFunc The associative reduce function to use.
+#' @return A list of elements of type list(K, V') where V' is the merged value for each key
+#' @seealso reduceByKey
+#' @examples
+#'\dontrun{
+#' sc <- sparkR.init()
+#' pairs <- list(list(1, 2), list(1.1, 3), list(1, 4))
+#' rdd <- parallelize(sc, pairs)
+#' reduced <- reduceByKeyLocally(rdd, "+")
+#' reduced # list(list(1, 6), list(1.1, 3))
+#'}
+#' @rdname reduceByKeyLocally
+#' @aliases reduceByKeyLocally,RDD,ANY-method
+setMethod("reduceByKeyLocally",
+ signature(x = "RDD", combineFunc = "ANY"),
+ function(x, combineFunc) {
+ reducePart <- function(part) {
+ vals <- new.env()
+ keys <- new.env()
+ pred <- function(item) exists(item$hash, keys)
+ lapply(part,
+ function(item) {
+ item$hash <- as.character(hashCode(item[[1]]))
+ updateOrCreatePair(item, keys, vals, pred, combineFunc, identity)
+ })
+ list(list(keys, vals)) # return hash to avoid re-compute in merge
+ }
+ mergeParts <- function(accum, x) {
+ pred <- function(item) {
+ exists(item$hash, accum[[1]])
+ }
+ lapply(ls(x[[1]]),
+ function(name) {
+ item <- list(x[[1]][[name]], x[[2]][[name]])
+ item$hash <- name
+ updateOrCreatePair(item, accum[[1]], accum[[2]], pred, combineFunc, identity)
+ })
+ accum
+ }
+ reduced <- mapPartitions(x, reducePart)
+ merged <- reduce(reduced, mergeParts)
+ convertEnvsToList(merged[[1]], merged[[2]])
+ })
+
+#' Combine values by key
+#'
+#' Generic function to combine the elements for each key using a custom set of
+#' aggregation functions. Turns an RDD[(K, V)] into a result of type RDD[(K, C)],
+#' for a "combined type" C. Note that V and C can be different -- for example, one
+#' might group an RDD of type (Int, Int) into an RDD of type (Int, Seq[Int]).
+#'
+#' Users provide three functions:
+#' \itemize{
+#' \item createCombiner, which turns a V into a C (e.g., creates a one-element list)
+#'   \item mergeValue, to merge a V into a C (e.g., adds it to the end of a list)
+#'   \item mergeCombiners, to combine two C's into a single one (e.g., concatenates
+#' two lists).
+#' }
+#'
+#' @param x The RDD to combine. Should be an RDD where each element is
+#' list(K, V) or c(K, V).
+#' @param createCombiner Create a combiner (C) given a value (V)
+#' @param mergeValue Merge the given value (V) with an existing combiner (C)
+#' @param mergeCombiners Merge two combiners and return a new combiner
+#' @param numPartitions Number of partitions to create.
+#' @return An RDD where each element is list(K, C) where C is the combined type
+#'
+#' @seealso groupByKey, reduceByKey
+#' @examples
+#'\dontrun{
+#' sc <- sparkR.init()
+#' pairs <- list(list(1, 2), list(1.1, 3), list(1, 4))
+#' rdd <- parallelize(sc, pairs)
+#' parts <- combineByKey(rdd, function(x) { x }, "+", "+", 2L)
+#' combined <- collect(parts)
+#' combined[[1]] # Should be a list(1, 6)
+#'}
+#' @rdname combineByKey
+#' @aliases combineByKey,RDD,ANY,ANY,ANY,integer-method
+setMethod("combineByKey",
+ signature(x = "RDD", createCombiner = "ANY", mergeValue = "ANY",
+ mergeCombiners = "ANY", numPartitions = "integer"),
+ function(x, createCombiner, mergeValue, mergeCombiners, numPartitions) {
+ combineLocally <- function(part) {
+ combiners <- new.env()
+ keys <- new.env()
+ pred <- function(item) exists(item$hash, keys)
+ lapply(part,
+ function(item) {
+ item$hash <- as.character(item[[1]])
+ updateOrCreatePair(item, keys, combiners, pred, mergeValue, createCombiner)
+ })
+ convertEnvsToList(keys, combiners)
+ }
+ locallyCombined <- lapplyPartition(x, combineLocally)
+ shuffled <- partitionBy(locallyCombined, numPartitions)
+ mergeAfterShuffle <- function(part) {
+ combiners <- new.env()
+ keys <- new.env()
+ pred <- function(item) exists(item$hash, keys)
+ lapply(part,
+ function(item) {
+ item$hash <- as.character(item[[1]])
+ updateOrCreatePair(item, keys, combiners, pred, mergeCombiners, identity)
+ })
+ convertEnvsToList(keys, combiners)
+ }
+ lapplyPartition(shuffled, mergeAfterShuffle)
+ })
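+
+# Usage sketch (hypothetical, assuming an active SparkContext `sc`): the
+# (Int, Seq[Int])-style grouping described above, i.e. collecting the values of
+# each key into a list instead of summing them.
+#   rdd <- parallelize(sc, list(list(1, 2), list(1, 4), list(2, 3)))
+#   combined <- combineByKey(rdd,
+#                            function(v) { list(v) },               # createCombiner
+#                            function(acc, v) { c(acc, list(v)) },  # mergeValue
+#                            function(a, b) { c(a, b) },            # mergeCombiners
+#                            2L)
+#   collect(combined)  # e.g. list(list(1, list(2, 4)), list(2, list(3)))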
+
+#' Aggregate a pair RDD by each key.
+#'
+#' Aggregate the values of each key in an RDD, using given combine functions
+#' and a neutral "zero value". This function can return a different result type,
+#' U, than the type of the values in this RDD, V. Thus, we need one operation
+#' for merging a V into a U and one operation for merging two U's. The former
+#' operation is used for merging values within a partition, and the latter is
+#' used for merging values between partitions. To avoid memory allocation, both
+#' of these functions are allowed to modify and return their first argument
+#' instead of creating a new U.
+#'
+#' @param x An RDD.
+#' @param zeroValue A neutral "zero value".
+#' @param seqOp A function to aggregate the values of each key. It may return
+#' a different result type from the type of the values.
+#' @param combOp A function to aggregate results of seqOp.
+#' @return An RDD containing the aggregation result.
+#' @seealso foldByKey, combineByKey
+#' @examples
+#'\dontrun{
+#' sc <- sparkR.init()
+#' rdd <- parallelize(sc, list(list(1, 1), list(1, 2), list(2, 3), list(2, 4)))
+#' zeroValue <- list(0, 0)
+#' seqOp <- function(x, y) { list(x[[1]] + y, x[[2]] + 1) }
+#' combOp <- function(x, y) { list(x[[1]] + y[[1]], x[[2]] + y[[2]]) }
+#' aggregateByKey(rdd, zeroValue, seqOp, combOp, 2L)
+#' # list(list(1, list(3, 2)), list(2, list(7, 2)))
+#'}
+#' @rdname aggregateByKey
+#' @aliases aggregateByKey,RDD,ANY,ANY,ANY,integer-method
+setMethod("aggregateByKey",
+ signature(x = "RDD", zeroValue = "ANY", seqOp = "ANY",
+ combOp = "ANY", numPartitions = "integer"),
+ function(x, zeroValue, seqOp, combOp, numPartitions) {
+ createCombiner <- function(v) {
+ do.call(seqOp, list(zeroValue, v))
+ }
+
+ combineByKey(x, createCombiner, seqOp, combOp, numPartitions)
+ })
+
+#' Fold a pair RDD by each key.
+#'
+#' Aggregate the values of each key in an RDD, using an associative function "func"
+#' and a neutral "zero value" which may be added to the result an arbitrary
+#' number of times, and must not change the result (e.g., 0 for addition, or
+#' 1 for multiplication).
+#'
+#' @param x An RDD.
+#' @param zeroValue A neutral "zero value".
+#' @param func An associative function for folding values of each key.
+#' @return An RDD containing the aggregation result.
+#' @seealso aggregateByKey, combineByKey
+#' @examples
+#'\dontrun{
+#' sc <- sparkR.init()
+#' rdd <- parallelize(sc, list(list(1, 1), list(1, 2), list(2, 3), list(2, 4)))
+#' foldByKey(rdd, 0, "+", 2L) # list(list(1, 3), list(2, 7))
+#'}
+#' @rdname foldByKey
+#' @aliases foldByKey,RDD,ANY,ANY,integer-method
+setMethod("foldByKey",
+ signature(x = "RDD", zeroValue = "ANY",
+ func = "ANY", numPartitions = "integer"),
+ function(x, zeroValue, func, numPartitions) {
+ aggregateByKey(x, zeroValue, func, func, numPartitions)
+ })
+
+############ Binary Functions #############
+
+#' Join two RDDs
+#'
+#' @description
+#' \code{join} This function joins two RDDs where every element is of the form list(K, V).
+#' The key types of the two RDDs should be the same.
+#'
+#' @param x An RDD to be joined. Should be an RDD where each element is
+#' list(K, V).
+#' @param y An RDD to be joined. Should be an RDD where each element is
+#' list(K, V).
+#' @param numPartitions Number of partitions to create.
+#' @return a new RDD containing all pairs of elements with matching keys in
+#' two input RDDs.
+#' @examples
+#'\dontrun{
+#' sc <- sparkR.init()
+#' rdd1 <- parallelize(sc, list(list(1, 1), list(2, 4)))
+#' rdd2 <- parallelize(sc, list(list(1, 2), list(1, 3)))
+#' join(rdd1, rdd2, 2L) # list(list(1, list(1, 2)), list(1, list(1, 3)))
+#'}
+#' @rdname join-methods
+#' @aliases join,RDD,RDD-method
+setMethod("join",
+ signature(x = "RDD", y = "RDD"),
+ function(x, y, numPartitions) {
+ xTagged <- lapply(x, function(i) { list(i[[1]], list(1L, i[[2]])) })
+ yTagged <- lapply(y, function(i) { list(i[[1]], list(2L, i[[2]])) })
+
+ doJoin <- function(v) {
+ joinTaggedList(v, list(FALSE, FALSE))
+ }
+
+ joined <- flatMapValues(groupByKey(unionRDD(xTagged, yTagged), numToInt(numPartitions)),
+ doJoin)
+ })
+
+#' Left outer join two RDDs
+#'
+#' @description
+#' \code{leftOuterJoin} This function left-outer-joins two RDDs where every element is of the form list(K, V).
+#' The key types of the two RDDs should be the same.
+#'
+#' @param x An RDD to be joined. Should be an RDD where each element is
+#' list(K, V).
+#' @param y An RDD to be joined. Should be an RDD where each element is
+#' list(K, V).
+#' @param numPartitions Number of partitions to create.
+#' @return For each element (k, v) in x, the resulting RDD will either contain
+#'         all pairs (k, (v, w)) for (k, w) in y, or the pair (k, (v, NULL))
+#'         if no elements in y have key k.
+#' @examples
+#'\dontrun{
+#' sc <- sparkR.init()
+#' rdd1 <- parallelize(sc, list(list(1, 1), list(2, 4)))
+#' rdd2 <- parallelize(sc, list(list(1, 2), list(1, 3)))
+#' leftOuterJoin(rdd1, rdd2, 2L)
+#' # list(list(1, list(1, 2)), list(1, list(1, 3)), list(2, list(4, NULL)))
+#'}
+#' @rdname join-methods
+#' @aliases leftOuterJoin,RDD,RDD-method
+setMethod("leftOuterJoin",
+ signature(x = "RDD", y = "RDD", numPartitions = "integer"),
+ function(x, y, numPartitions) {
+ xTagged <- lapply(x, function(i) { list(i[[1]], list(1L, i[[2]])) })
+ yTagged <- lapply(y, function(i) { list(i[[1]], list(2L, i[[2]])) })
+
+ doJoin <- function(v) {
+ joinTaggedList(v, list(FALSE, TRUE))
+ }
+
+ joined <- flatMapValues(groupByKey(unionRDD(xTagged, yTagged), numPartitions), doJoin)
+ })
+
+#' Right outer join two RDDs
+#'
+#' @description
+#' \code{rightOuterJoin} This function right-outer-joins two RDDs where every element is of the form list(K, V).
+#' The key types of the two RDDs should be the same.
+#'
+#' @param x An RDD to be joined. Should be an RDD where each element is
+#' list(K, V).
+#' @param y An RDD to be joined. Should be an RDD where each element is
+#' list(K, V).
+#' @param numPartitions Number of partitions to create.
+#' @return For each element (k, w) in y, the resulting RDD will either contain
+#' all pairs (k, (v, w)) for (k, v) in x, or the pair (k, (NULL, w))
+#' if no elements in x have key k.
+#' @examples
+#'\dontrun{
+#' sc <- sparkR.init()
+#' rdd1 <- parallelize(sc, list(list(1, 2), list(1, 3)))
+#' rdd2 <- parallelize(sc, list(list(1, 1), list(2, 4)))
+#' rightOuterJoin(rdd1, rdd2, 2L)
+#' # list(list(1, list(2, 1)), list(1, list(3, 1)), list(2, list(NULL, 4)))
+#'}
+#' @rdname join-methods
+#' @aliases rightOuterJoin,RDD,RDD-method
+setMethod("rightOuterJoin",
+ signature(x = "RDD", y = "RDD", numPartitions = "integer"),
+ function(x, y, numPartitions) {
+ xTagged <- lapply(x, function(i) { list(i[[1]], list(1L, i[[2]])) })
+ yTagged <- lapply(y, function(i) { list(i[[1]], list(2L, i[[2]])) })
+
+ doJoin <- function(v) {
+ joinTaggedList(v, list(TRUE, FALSE))
+ }
+
+ joined <- flatMapValues(groupByKey(unionRDD(xTagged, yTagged), numPartitions), doJoin)
+ })
+
+#' Full outer join two RDDs
+#'
+#' @description
+#' \code{fullOuterJoin} This function full-outer-joins two RDDs where every element is of the form list(K, V).
+#' The key types of the two RDDs should be the same.
+#'
+#' @param x An RDD to be joined. Should be an RDD where each element is
+#' list(K, V).
+#' @param y An RDD to be joined. Should be an RDD where each element is
+#' list(K, V).
+#' @param numPartitions Number of partitions to create.
+#' @return For each element (k, v) in x and (k, w) in y, the resulting RDD
+#' will contain all pairs (k, (v, w)) for both (k, v) in x and
+#' (k, w) in y, or the pair (k, (NULL, w))/(k, (v, NULL)) if no elements
+#' in x/y have key k.
+#' @examples
+#'\dontrun{
+#' sc <- sparkR.init()
+#' rdd1 <- parallelize(sc, list(list(1, 2), list(1, 3), list(3, 3)))
+#' rdd2 <- parallelize(sc, list(list(1, 1), list(2, 4)))
+#' fullOuterJoin(rdd1, rdd2, 2L) # list(list(1, list(2, 1)),
+#' #                                    list(1, list(3, 1)),
+#' #                                    list(2, list(NULL, 4)),
+#' #                                    list(3, list(3, NULL)))
+#'}
+#' @rdname join-methods
+#' @aliases fullOuterJoin,RDD,RDD-method
+setMethod("fullOuterJoin",
+ signature(x = "RDD", y = "RDD", numPartitions = "integer"),
+ function(x, y, numPartitions) {
+ xTagged <- lapply(x, function(i) { list(i[[1]], list(1L, i[[2]])) })
+ yTagged <- lapply(y, function(i) { list(i[[1]], list(2L, i[[2]])) })
+
+ doJoin <- function(v) {
+ joinTaggedList(v, list(TRUE, TRUE))
+ }
+
+ joined <- flatMapValues(groupByKey(unionRDD(xTagged, yTagged), numPartitions), doJoin)
+ })
+
+#' For each key k in several RDDs, return a resulting RDD whose values are
+#' a list of the values for that key in all of the RDDs.
+#'
+#' @param ... Several RDDs.
+#' @param numPartitions Number of partitions to create.
+#' @return a new RDD where each element is of the form list(K, result), with result
+#'         containing, for each input RDD, the list of values associated with key K.
+#' @examples
+#'\dontrun{
+#' sc <- sparkR.init()
+#' rdd1 <- parallelize(sc, list(list(1, 1), list(2, 4)))
+#' rdd2 <- parallelize(sc, list(list(1, 2), list(1, 3)))
+#' cogroup(rdd1, rdd2, numPartitions = 2L)
+#' # list(list(1, list(list(1), list(2, 3))), list(2, list(list(4), list())))
+#'}
+#' @rdname cogroup
+#' @aliases cogroup,RDD-method
+setMethod("cogroup",
+ "RDD",
+ function(..., numPartitions) {
+ rdds <- list(...)
+ rddsLen <- length(rdds)
+ for (i in 1:rddsLen) {
+ rdds[[i]] <- lapply(rdds[[i]],
+ function(x) { list(x[[1]], list(i, x[[2]])) })
+ # TODO(hao): As issue [SparkR-142] mentions, the right value of i
+ # will not be captured into UDF if getJRDD is not invoked.
+ # It should be resolved together with that issue.
+ getJRDD(rdds[[i]]) # Capture the closure.
+ }
+ union.rdd <- Reduce(unionRDD, rdds)
+ group.func <- function(vlist) {
+ res <- list()
+ length(res) <- rddsLen
+ for (x in vlist) {
+ i <- x[[1]]
+ acc <- res[[i]]
+ # Create an accumulator.
+ if (is.null(acc)) {
+ acc <- initAccumulator()
+ }
+ addItemToAccumulator(acc, x[[2]])
+ res[[i]] <- acc
+ }
+ lapply(res, function(acc) {
+ if (is.null(acc)) {
+ list()
+ } else {
+ acc$data
+ }
+ })
+ }
+ cogroup.rdd <- mapValues(groupByKey(union.rdd, numPartitions),
+ group.func)
+ })
+
+#' Sort a (k, v) pair RDD by k.
+#'
+#' @param x A (k, v) pair RDD to be sorted.
+#' @param ascending A flag to indicate whether the sorting is ascending or descending.
+#' @param numPartitions Number of partitions to create.
+#' @return An RDD where all (k, v) pair elements are sorted.
+#' @examples
+#'\dontrun{
+#' sc <- sparkR.init()
+#' rdd <- parallelize(sc, list(list(3, 1), list(2, 2), list(1, 3)))
+#' collect(sortByKey(rdd)) # list(list(1, 3), list(2, 2), list(3, 1))
+#'}
+#' @rdname sortByKey
+#' @aliases sortByKey,RDD-method
+setMethod("sortByKey",
+ signature(x = "RDD"),
+ function(x, ascending = TRUE, numPartitions = SparkR::numPartitions(x)) {
+ rangeBounds <- list()
+
+ if (numPartitions > 1) {
+ rddSize <- count(x)
+ # constant from Spark's RangePartitioner
+ maxSampleSize <- numPartitions * 20
+ fraction <- min(maxSampleSize / max(rddSize, 1), 1.0)
+
+ samples <- collect(keys(sampleRDD(x, FALSE, fraction, 1L)))
+
+ # Note: the built-in R sort() function only works on atomic vectors
+ samples <- sort(unlist(samples, recursive = FALSE), decreasing = !ascending)
+
+ if (length(samples) > 0) {
+ rangeBounds <- lapply(seq_len(numPartitions - 1),
+ function(i) {
+ j <- ceiling(length(samples) * i / numPartitions)
+ samples[j]
+ })
+ }
+ }
+
+ rangePartitionFunc <- function(key) {
+ partition <- 0
+
+ # TODO: Use binary search instead of linear search, similar with Spark
+ while (partition < length(rangeBounds) && key > rangeBounds[[partition + 1]]) {
+ partition <- partition + 1
+ }
+
+ if (ascending) {
+ partition
+ } else {
+ numPartitions - partition - 1
+ }
+ }
+
+ partitionFunc <- function(part) {
+ sortKeyValueList(part, decreasing = !ascending)
+ }
+
+ newRDD <- partitionBy(x, numPartitions, rangePartitionFunc)
+ lapplyPartition(newRDD, partitionFunc)
+ })
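+
+# Usage sketch (hypothetical, assuming an active SparkContext `sc`): sorting keys in
+# descending order while controlling the number of output partitions.
+#   rdd <- parallelize(sc, list(list(3, 1), list(2, 2), list(1, 3)))
+#   collect(sortByKey(rdd, ascending = FALSE, numPartitions = 2L))
+#   # list(list(3, 1), list(2, 2), list(1, 3))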
+
diff --git a/R/pkg/R/serialize.R b/R/pkg/R/serialize.R
new file mode 100644
index 0000000000..8a9c0c652c
--- /dev/null
+++ b/R/pkg/R/serialize.R
@@ -0,0 +1,195 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements. See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+# Utility functions to serialize R objects so they can be read in Java.
+
+# Type mapping from R to Java
+#
+# NULL -> Void
+# integer -> Int
+# character -> String
+# logical -> Boolean
+# double, numeric -> Double
+# raw -> Array[Byte]
+# Date -> Date
+# POSIXct,POSIXlt -> Time
+#
+# list[T] -> Array[T], where T is one of above mentioned types
+# environment -> Map[String, T], where T is a native type
+# jobj -> Object, where jobj is an object created in the backend
+
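+# Illustrative sketch of the resulting wire format (using the helpers defined below):
+# writing an integer to an in-memory connection yields a one-byte type tag followed
+# by the value in big-endian order.
+#   con <- rawConnection(raw(0), "wb")
+#   writeObject(con, 42L)
+#   rawConnectionValue(con)  # 69 00 00 00 2a  ("i" tag, then 42 as four bytes)
+#   close(con)
+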
+writeObject <- function(con, object, writeType = TRUE) {
+  # NOTE: In R, a scalar is just a length-one vector, so we cannot distinguish scalars
+  # from vectors of the same type. Hence we don't support passing in vectors as arrays
+  # and instead require arrays to be passed as lists.
+ type <- class(object)[[1]] # class of POSIXlt is c("POSIXlt", "POSIXt")
+ if (writeType) {
+ writeType(con, type)
+ }
+ switch(type,
+ NULL = writeVoid(con),
+ integer = writeInt(con, object),
+ character = writeString(con, object),
+ logical = writeBoolean(con, object),
+ double = writeDouble(con, object),
+ numeric = writeDouble(con, object),
+ raw = writeRaw(con, object),
+ list = writeList(con, object),
+ jobj = writeJobj(con, object),
+ environment = writeEnv(con, object),
+ Date = writeDate(con, object),
+ POSIXlt = writeTime(con, object),
+ POSIXct = writeTime(con, object),
+ stop(paste("Unsupported type for serialization", type)))
+}
+
+writeVoid <- function(con) {
+ # no value for NULL
+}
+
+writeJobj <- function(con, value) {
+ if (!isValidJobj(value)) {
+ stop("invalid jobj ", value$id)
+ }
+ writeString(con, value$id)
+}
+
+writeString <- function(con, value) {
+ writeInt(con, as.integer(nchar(value) + 1))
+ writeBin(value, con, endian = "big")
+}
+
+writeInt <- function(con, value) {
+ writeBin(as.integer(value), con, endian = "big")
+}
+
+writeDouble <- function(con, value) {
+ writeBin(value, con, endian = "big")
+}
+
+writeBoolean <- function(con, value) {
+ # TRUE becomes 1, FALSE becomes 0
+ writeInt(con, as.integer(value))
+}
+
+writeRawSerialize <- function(outputCon, batch) {
+ outputSer <- serialize(batch, ascii = FALSE, connection = NULL)
+ writeRaw(outputCon, outputSer)
+}
+
+writeRowSerialize <- function(outputCon, rows) {
+ invisible(lapply(rows, function(r) {
+ bytes <- serializeRow(r)
+ writeRaw(outputCon, bytes)
+ }))
+}
+
+serializeRow <- function(row) {
+ rawObj <- rawConnection(raw(0), "wb")
+ on.exit(close(rawObj))
+ writeRow(rawObj, row)
+ rawConnectionValue(rawObj)
+}
+
+writeRow <- function(con, row) {
+ numCols <- length(row)
+ writeInt(con, numCols)
+ for (i in 1:numCols) {
+ writeObject(con, row[[i]])
+ }
+}
+
+writeRaw <- function(con, batch) {
+ writeInt(con, length(batch))
+ writeBin(batch, con, endian = "big")
+}
+
+writeType <- function(con, class) {
+ type <- switch(class,
+ NULL = "n",
+ integer = "i",
+ character = "c",
+ logical = "b",
+ double = "d",
+ numeric = "d",
+ raw = "r",
+ list = "l",
+ jobj = "j",
+ environment = "e",
+ Date = "D",
+ POSIXlt = 't',
+ POSIXct = 't',
+ stop(paste("Unsupported type for serialization", class)))
+ writeBin(charToRaw(type), con)
+}
+
+# Used to pass arrays where all the elements are of the same type
+writeList <- function(con, arr) {
+ # All elements should be of same type
+ elemType <- unique(sapply(arr, function(elem) { class(elem) }))
+ stopifnot(length(elemType) <= 1)
+
+ # TODO: Empty lists are given type "character" right now.
+ # This may not work if the Java side expects array of any other type.
+ if (length(elemType) == 0) {
+    elemType <- "character"
+ }
+
+ writeType(con, elemType)
+ writeInt(con, length(arr))
+
+ if (length(arr) > 0) {
+ for (a in arr) {
+ writeObject(con, a, FALSE)
+ }
+ }
+}
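+
+# Illustrative sketch: a homogeneous list is written as one element-type tag, the
+# list length, and then the elements without per-element tags. (When reached via
+# writeObject, a leading "l" tag precedes this.)
+#   con <- rawConnection(raw(0), "wb")
+#   writeList(con, list(1L, 2L))
+#   rawConnectionValue(con)  # 69 | 00 00 00 02 | 00 00 00 01 | 00 00 00 02
+#   close(con)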
+
+# Used to pass in hash maps required on Java side.
+writeEnv <- function(con, env) {
+ len <- length(env)
+
+ writeInt(con, len)
+ if (len > 0) {
+ writeList(con, as.list(ls(env)))
+ vals <- lapply(ls(env), function(x) { env[[x]] })
+ writeList(con, as.list(vals))
+ }
+}
+
+writeDate <- function(con, date) {
+ writeString(con, as.character(date))
+}
+
+writeTime <- function(con, time) {
+ writeDouble(con, as.double(time))
+}
+
+# Used to serialize a list of objects where each
+# object can be of a different type. The serialization format is
+# <object type> <object> for each object.
+writeArgs <- function(con, args) {
+ if (length(args) > 0) {
+ for (a in args) {
+ writeObject(con, a)
+ }
+ }
+}
+
+writeStrings <- function(con, stringList) {
+ writeLines(unlist(stringList), con)
+}
diff --git a/R/pkg/R/sparkR.R b/R/pkg/R/sparkR.R
new file mode 100644
index 0000000000..bc82df01f0
--- /dev/null
+++ b/R/pkg/R/sparkR.R
@@ -0,0 +1,266 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements. See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+.sparkREnv <- new.env()
+
+sparkR.onLoad <- function(libname, pkgname) {
+ .sparkREnv$libname <- libname
+}
+
+# Utility function that returns TRUE if we have an active connection to the
+# backend and FALSE otherwise
+connExists <- function(env) {
+ tryCatch({
+ exists(".sparkRCon", envir = env) && isOpen(env[[".sparkRCon"]])
+ }, error = function(err) {
+ return(FALSE)
+ })
+}
+
+#' Stop the Spark context.
+#'
+#' Also terminates the backend this R session is connected to
+sparkR.stop <- function() {
+ env <- .sparkREnv
+ if (exists(".sparkRCon", envir = env)) {
+ if (exists(".sparkRjsc", envir = env)) {
+ sc <- get(".sparkRjsc", envir = env)
+ callJMethod(sc, "stop")
+ rm(".sparkRjsc", envir = env)
+ }
+
+ if (exists(".backendLaunched", envir = env)) {
+ callJStatic("SparkRHandler", "stopBackend")
+ }
+
+ # Also close the connection and remove it from our env
+ conn <- get(".sparkRCon", envir = env)
+ close(conn)
+
+ rm(".sparkRCon", envir = env)
+ rm(".scStartTime", envir = env)
+ }
+
+ if (exists(".monitorConn", envir = env)) {
+ conn <- get(".monitorConn", envir = env)
+ close(conn)
+ rm(".monitorConn", envir = env)
+ }
+
+ # Clear all broadcast variables we have
+ # as the jobj will not be valid if we restart the JVM
+ clearBroadcastVariables()
+
+ # Clear jobj maps
+ clearJobjs()
+}
+
+#' Initialize a new Spark Context.
+#'
+#' This function initializes a new SparkContext.
+#'
+#' @param master The Spark master URL.
+#' @param appName Application name to register with cluster manager
+#' @param sparkHome Spark Home directory
+#' @param sparkEnvir Named list of environment variables to set on worker nodes.
+#' @param sparkExecutorEnv Named list of environment variables to be used when launching executors.
+#' @param sparkJars Character string vector of jar files to pass to the worker nodes.
+#' @param sparkRLibDir The path where R is installed on the worker nodes.
+#' @export
+#' @examples
+#'\dontrun{
+#' sc <- sparkR.init("local[2]", "SparkR", "/home/spark")
+#' sc <- sparkR.init("local[2]", "SparkR", "/home/spark",
+#' list(spark.executor.memory="1g"))
+#' sc <- sparkR.init("yarn-client", "SparkR", "/home/spark",
+#' list(spark.executor.memory="1g"),
+#' list(LD_LIBRARY_PATH="/directory of JVM libraries (libjvm.so) on workers/"),
+#' c("jarfile1.jar","jarfile2.jar"))
+#'}
+
+sparkR.init <- function(
+ master = "",
+ appName = "SparkR",
+ sparkHome = Sys.getenv("SPARK_HOME"),
+ sparkEnvir = list(),
+ sparkExecutorEnv = list(),
+ sparkJars = "",
+ sparkRLibDir = "") {
+
+ if (exists(".sparkRjsc", envir = .sparkREnv)) {
+ cat("Re-using existing Spark Context. Please stop SparkR with sparkR.stop() or restart R to create a new Spark Context\n")
+ return(get(".sparkRjsc", envir = .sparkREnv))
+ }
+
+ sparkMem <- Sys.getenv("SPARK_MEM", "512m")
+ jars <- suppressWarnings(normalizePath(as.character(sparkJars)))
+
+ # Classpath separator is ";" on Windows
+ # URI needs four /// as from http://stackoverflow.com/a/18522792
+ if (.Platform$OS.type == "unix") {
+ collapseChar <- ":"
+ uriSep <- "//"
+ } else {
+ collapseChar <- ";"
+ uriSep <- "////"
+ }
+
+ existingPort <- Sys.getenv("EXISTING_SPARKR_BACKEND_PORT", "")
+ if (existingPort != "") {
+ backendPort <- existingPort
+ } else {
+ path <- tempfile(pattern = "backend_port")
+ launchBackend(
+ args = path,
+ sparkHome = sparkHome,
+ jars = jars,
+ sparkSubmitOpts = Sys.getenv("SPARKR_SUBMIT_ARGS", "sparkr-shell"))
+    # wait at most 100 seconds for the JVM to launch
+ wait <- 0.1
+ for (i in 1:25) {
+ Sys.sleep(wait)
+ if (file.exists(path)) {
+ break
+ }
+ wait <- wait * 1.25
+ }
+ if (!file.exists(path)) {
+      stop("JVM is not ready after 100 seconds")
+ }
+ f <- file(path, open='rb')
+ backendPort <- readInt(f)
+ monitorPort <- readInt(f)
+ close(f)
+ file.remove(path)
+ if (length(backendPort) == 0 || backendPort == 0 ||
+ length(monitorPort) == 0 || monitorPort == 0) {
+ stop("JVM failed to launch")
+ }
+ assign(".monitorConn", socketConnection(port = monitorPort), envir = .sparkREnv)
+ assign(".backendLaunched", 1, envir = .sparkREnv)
+ }
+
+ .sparkREnv$backendPort <- backendPort
+ tryCatch({
+ connectBackend("localhost", backendPort)
+ }, error = function(err) {
+    stop("Failed to connect to JVM\n")
+ })
+
+ if (nchar(sparkHome) != 0) {
+ sparkHome <- normalizePath(sparkHome)
+ }
+
+ if (nchar(sparkRLibDir) != 0) {
+ .sparkREnv$libname <- sparkRLibDir
+ }
+
+ sparkEnvirMap <- new.env()
+ for (varname in names(sparkEnvir)) {
+ sparkEnvirMap[[varname]] <- sparkEnvir[[varname]]
+ }
+
+ sparkExecutorEnvMap <- new.env()
+ if (!any(names(sparkExecutorEnv) == "LD_LIBRARY_PATH")) {
+ sparkExecutorEnvMap[["LD_LIBRARY_PATH"]] <- paste0("$LD_LIBRARY_PATH:",Sys.getenv("LD_LIBRARY_PATH"))
+ }
+ for (varname in names(sparkExecutorEnv)) {
+ sparkExecutorEnvMap[[varname]] <- sparkExecutorEnv[[varname]]
+ }
+
+ nonEmptyJars <- Filter(function(x) { x != "" }, jars)
+ localJarPaths <- sapply(nonEmptyJars, function(j) { utils::URLencode(paste("file:", uriSep, j, sep = "")) })
+
+ # Set the start time to identify jobjs
+ # Seconds resolution is good enough for this purpose, so use ints
+ assign(".scStartTime", as.integer(Sys.time()), envir = .sparkREnv)
+
+ assign(
+ ".sparkRjsc",
+ callJStatic(
+ "org.apache.spark.api.r.RRDD",
+ "createSparkContext",
+ master,
+ appName,
+ as.character(sparkHome),
+ as.list(localJarPaths),
+ sparkEnvirMap,
+ sparkExecutorEnvMap),
+ envir = .sparkREnv
+ )
+
+ sc <- get(".sparkRjsc", envir = .sparkREnv)
+
+  # Register a finalizer to sleep 1 second on R exit to make RStudio happy
+ reg.finalizer(.sparkREnv, function(x) { Sys.sleep(1) }, onexit = TRUE)
+
+ sc
+}
+
+#' Initialize a new SQLContext.
+#'
+#' This function creates a SparkContext from an existing JavaSparkContext and
+#' then uses it to initialize a new SQLContext
+#'
+#' @param jsc The existing JavaSparkContext created with sparkR.init()
+#' @export
+#' @examples
+#'\dontrun{
+#' sc <- sparkR.init()
+#' sqlCtx <- sparkRSQL.init(sc)
+#'}
+
+sparkRSQL.init <- function(jsc) {
+ if (exists(".sparkRSQLsc", envir = .sparkREnv)) {
+ return(get(".sparkRSQLsc", envir = .sparkREnv))
+ }
+
+ sqlCtx <- callJStatic("org.apache.spark.sql.api.r.SQLUtils",
+ "createSQLContext",
+ jsc)
+ assign(".sparkRSQLsc", sqlCtx, envir = .sparkREnv)
+ sqlCtx
+}
+
+#' Initialize a new HiveContext.
+#'
+#' This function creates a HiveContext from an existing JavaSparkContext
+#'
+#' @param jsc The existing JavaSparkContext created with sparkR.init()
+#' @export
+#' @examples
+#'\dontrun{
+#' sc <- sparkR.init()
+#' sqlCtx <- sparkRHive.init(sc)
+#'}
+
+sparkRHive.init <- function(jsc) {
+ if (exists(".sparkRHivesc", envir = .sparkREnv)) {
+ return(get(".sparkRHivesc", envir = .sparkREnv))
+ }
+
+ ssc <- callJMethod(jsc, "sc")
+ hiveCtx <- tryCatch({
+ newJObject("org.apache.spark.sql.hive.HiveContext", ssc)
+ }, error = function(err) {
+ stop("Spark SQL is not built with Hive support")
+ })
+
+ assign(".sparkRHivesc", hiveCtx, envir = .sparkREnv)
+ hiveCtx
+}
diff --git a/R/pkg/R/utils.R b/R/pkg/R/utils.R
new file mode 100644
index 0000000000..c337fb0751
--- /dev/null
+++ b/R/pkg/R/utils.R
@@ -0,0 +1,467 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements. See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+# Utilities and Helpers
+
+# Given a JList<T>, returns an R list containing the same elements, the number
+# of which is optionally upper bounded by `logicalUpperBound` (by default,
+# return all elements). Takes care of deserializations and type conversions.
+convertJListToRList <- function(jList, flatten, logicalUpperBound = NULL,
+ serializedMode = "byte") {
+ arrSize <- callJMethod(jList, "size")
+
+ # Datasets with serializedMode == "string" (such as an RDD directly generated by textFile()):
+ # each partition is not dense-packed into one Array[Byte], and `arrSize`
+ # here corresponds to number of logical elements. Thus we can prune here.
+ if (serializedMode == "string" && !is.null(logicalUpperBound)) {
+ arrSize <- min(arrSize, logicalUpperBound)
+ }
+
+ results <- if (arrSize > 0) {
+ lapply(0:(arrSize - 1),
+ function(index) {
+ obj <- callJMethod(jList, "get", as.integer(index))
+
+ # Assume it is either an R object or a Java obj ref.
+ if (inherits(obj, "jobj")) {
+ if (isInstanceOf(obj, "scala.Tuple2")) {
+ # JavaPairRDD[Array[Byte], Array[Byte]].
+
+                       keyBytes <- callJMethod(obj, "_1")
+                       valBytes <- callJMethod(obj, "_2")
+ res <- list(unserialize(keyBytes),
+ unserialize(valBytes))
+ } else {
+ stop(paste("utils.R: convertJListToRList only supports",
+ "RDD[Array[Byte]] and",
+ "JavaPairRDD[Array[Byte], Array[Byte]] for now"))
+ }
+ } else {
+ if (inherits(obj, "raw")) {
+ if (serializedMode == "byte") {
+ # RDD[Array[Byte]]. `obj` is a whole partition.
+ res <- unserialize(obj)
+ # For serialized datasets, `obj` (and `rRaw`) here corresponds to
+ # one whole partition dense-packed together. We deserialize the
+ # whole partition first, then cap the number of elements to be returned.
+ } else if (serializedMode == "row") {
+ res <- readRowList(obj)
+ # For DataFrames that have been converted to RRDDs, we call readRowList
+ # which will read in each row of the RRDD as a list and deserialize
+ # each element.
+ flatten <<- FALSE
+ # Use global assignment to change the flatten flag. This means
+ # we don't have to worry about the default argument in other functions
+ # e.g. collect
+ }
+ # TODO: is it possible to distinguish element boundary so that we can
+ # unserialize only what we need?
+ if (!is.null(logicalUpperBound)) {
+ res <- head(res, n = logicalUpperBound)
+ }
+ } else {
+ # obj is of a primitive Java type, is simplified to R's
+ # corresponding type.
+ res <- list(obj)
+ }
+ }
+ res
+ })
+ } else {
+ list()
+ }
+
+ if (flatten) {
+ as.list(unlist(results, recursive = FALSE))
+ } else {
+ as.list(results)
+ }
+}
+
+# Returns TRUE if `name` refers to an RDD in the given environment `env`
+isRDD <- function(name, env) {
+ obj <- get(name, envir = env)
+ inherits(obj, "RDD")
+}
+
+#' Compute the hashCode of an object
+#'
+#' Java-style function to compute the hashCode for the given object. Returns
+#' an integer value.
+#'
+#' @details
+#' This only works for integer, numeric and character types right now.
+#'
+#' @param key the object to be hashed
+#' @return the hash code as an integer
+#' @export
+#' @examples
+#' hashCode(1L) # 1
+#' hashCode(1.0) # 1072693248
+#' hashCode("1") # 49
+hashCode <- function(key) {
+ if (class(key) == "integer") {
+ as.integer(key[[1]])
+ } else if (class(key) == "numeric") {
+ # Convert the double to long and then calculate the hash code
+ rawVec <- writeBin(key[[1]], con = raw())
+ intBits <- packBits(rawToBits(rawVec), "integer")
+ as.integer(bitwXor(intBits[2], intBits[1]))
+ } else if (class(key) == "character") {
+ .Call("stringHashCode", key)
+ } else {
+ warning(paste("Could not hash object, returning 0", sep = ""))
+ as.integer(0)
+ }
+}
+
+# Create a new RDD with serializedMode == "byte".
+# Return itself if already in "byte" format.
+serializeToBytes <- function(rdd) {
+ if (!inherits(rdd, "RDD")) {
+ stop("Argument 'rdd' is not an RDD type.")
+ }
+ if (getSerializedMode(rdd) != "byte") {
+ ser.rdd <- lapply(rdd, function(x) { x })
+ return(ser.rdd)
+ } else {
+ return(rdd)
+ }
+}
+
+# Create a new RDD with serializedMode == "string".
+# Return itself if already in "string" format.
+serializeToString <- function(rdd) {
+ if (!inherits(rdd, "RDD")) {
+ stop("Argument 'rdd' is not an RDD type.")
+ }
+ if (getSerializedMode(rdd) != "string") {
+ ser.rdd <- lapply(rdd, function(x) { toString(x) })
+ # force it to create jrdd using "string"
+ getJRDD(ser.rdd, serializedMode = "string")
+ return(ser.rdd)
+ } else {
+ return(rdd)
+ }
+}
+
+# Fast append to list by using an accumulator.
+# http://stackoverflow.com/questions/17046336/here-we-go-again-append-an-element-to-a-list-in-r
+#
+# The accumulator has three fields: size, counter and data.
+# This function amortizes the allocation cost by doubling
+# the size of the list every time it fills up.
+addItemToAccumulator <- function(acc, item) {
+  if (acc$counter == acc$size) {
+ acc$size <- acc$size * 2
+ length(acc$data) <- acc$size
+ }
+ acc$counter <- acc$counter + 1
+ acc$data[[acc$counter]] <- item
+}
+
+initAccumulator <- function() {
+ acc <- new.env()
+ acc$counter <- 0
+ acc$data <- list(NULL)
+ acc$size <- 1
+ acc
+}
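+
+# Usage sketch: the backing list doubles as items are appended, so adding n items
+# costs amortized O(n) instead of the O(n^2) of repeated list concatenation.
+#   acc <- initAccumulator()
+#   for (x in 1:5) { addItemToAccumulator(acc, x) }
+#   acc$counter                   # 5
+#   acc$size                      # 8 (grown 1 -> 2 -> 4 -> 8)
+#   head(acc$data, acc$counter)   # list(1, 2, 3, 4, 5)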
+
+# Utility function to sort a list of key value pairs
+# Used in unit tests
+sortKeyValueList <- function(kv_list, decreasing = FALSE) {
+ keys <- sapply(kv_list, function(x) x[[1]])
+ kv_list[order(keys, decreasing = decreasing)]
+}
+
+# Utility function to generate compact R lists from grouped rdd
+# Used in Join-family functions
+# param:
+# tagged_list R list generated via groupByKey with tags(1L, 2L, ...)
+# cnull Boolean list where each element determines whether the corresponding list should
+# be converted to list(NULL)
+genCompactLists <- function(tagged_list, cnull) {
+ len <- length(tagged_list)
+ lists <- list(vector("list", len), vector("list", len))
+ index <- list(1, 1)
+
+ for (x in tagged_list) {
+ tag <- x[[1]]
+ idx <- index[[tag]]
+ lists[[tag]][[idx]] <- x[[2]]
+ index[[tag]] <- idx + 1
+ }
+
+ len <- lapply(index, function(x) x - 1)
+ for (i in (1:2)) {
+ if (cnull[[i]] && len[[i]] == 0) {
+ lists[[i]] <- list(NULL)
+ } else {
+ length(lists[[i]]) <- len[[i]]
+ }
+ }
+
+ lists
+}
+
+# Utility function to merge compact R lists
+# Used in Join-family functions
+# param:
+# left/right Two compact lists ready for Cartesian product
+mergeCompactLists <- function(left, right) {
+ result <- list()
+ length(result) <- length(left) * length(right)
+ index <- 1
+ for (i in left) {
+ for (j in right) {
+ result[[index]] <- list(i, j)
+ index <- index + 1
+ }
+ }
+ result
+}
+
+# Utility function wrapping the above two operations
+# Used in Join-family functions
+# param (same as genCompactLists):
+# tagged_list R list generated via groupByKey with tags(1L, 2L, ...)
+# cnull Boolean list where each element determines whether the corresponding list should
+# be converted to list(NULL)
+joinTaggedList <- function(tagged_list, cnull) {
+ lists <- genCompactLists(tagged_list, cnull)
+ mergeCompactLists(lists[[1]], lists[[2]])
+}
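+
+# Usage sketch with a hand-built tagged list, mirroring what groupByKey produces in
+# the join implementations (tag 1L = left RDD, tag 2L = right RDD):
+#   tagged <- list(list(1L, "a"), list(1L, "b"), list(2L, "x"))
+#   joinTaggedList(tagged, list(FALSE, FALSE))
+#   # list(list("a", "x"), list("b", "x"))
+#   joinTaggedList(list(list(1L, "a")), list(FALSE, TRUE))  # empty right side padded
+#   # list(list("a", NULL))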
+
+# Utility function to reduce a key-value list with predicate
+# Used in *ByKey functions
+# param
+# pair key-value pair
+# keys/vals env of key/value with hashes
+# updateOrCreatePred predicate function
+#   updateFn update or merge function for an existing pair, similar to `mergeValue` in combineByKey
+#   createFn create function for a new pair, similar to `createCombiner` in combineByKey
+updateOrCreatePair <- function(pair, keys, vals, updateOrCreatePred, updateFn, createFn) {
+  # Assumes the hash is bound to `$hash`, and the key/value are at indices 1/2
+ hashVal <- pair$hash
+ key <- pair[[1]]
+ val <- pair[[2]]
+ if (updateOrCreatePred(pair)) {
+ assign(hashVal, do.call(updateFn, list(get(hashVal, envir = vals), val)), envir = vals)
+ } else {
+ assign(hashVal, do.call(createFn, list(val)), envir = vals)
+ assign(hashVal, key, envir = keys)
+ }
+}
+
+# Utility function to convert key&values envs into key-val list
+convertEnvsToList <- function(keys, vals) {
+ lapply(ls(keys),
+ function(name) {
+ list(keys[[name]], vals[[name]])
+ })
+}
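+
+# Usage sketch: how the *ByKey implementations combine these two helpers to reduce a
+# partition locally (here summing the values of each key with "+").
+#   keys <- new.env()
+#   vals <- new.env()
+#   pred <- function(item) exists(item$hash, keys)
+#   part <- list(list(1, 2), list(1, 4), list(2, 3))
+#   lapply(part, function(item) {
+#     item$hash <- as.character(hashCode(item[[1]]))
+#     updateOrCreatePair(item, keys, vals, pred, "+", identity)
+#   })
+#   convertEnvsToList(keys, vals)  # e.g. list(list(1, 6), list(2, 3))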
+
+# Utility function to capture the varargs into environment object
+varargsToEnv <- function(...) {
+ pairs <- as.list(substitute(list(...)))[-1L]
+ env <- new.env()
+ for (name in names(pairs)) {
+ env[[name]] <- pairs[[name]]
+ }
+ env
+}
+
+getStorageLevel <- function(newLevel = c("DISK_ONLY",
+ "DISK_ONLY_2",
+ "MEMORY_AND_DISK",
+ "MEMORY_AND_DISK_2",
+ "MEMORY_AND_DISK_SER",
+ "MEMORY_AND_DISK_SER_2",
+ "MEMORY_ONLY",
+ "MEMORY_ONLY_2",
+ "MEMORY_ONLY_SER",
+ "MEMORY_ONLY_SER_2",
+ "OFF_HEAP")) {
+ match.arg(newLevel)
+ storageLevel <- switch(newLevel,
+ "DISK_ONLY" = callJStatic("org.apache.spark.storage.StorageLevel", "DISK_ONLY"),
+ "DISK_ONLY_2" = callJStatic("org.apache.spark.storage.StorageLevel", "DISK_ONLY_2"),
+ "MEMORY_AND_DISK" = callJStatic("org.apache.spark.storage.StorageLevel", "MEMORY_AND_DISK"),
+ "MEMORY_AND_DISK_2" = callJStatic("org.apache.spark.storage.StorageLevel", "MEMORY_AND_DISK_2"),
+ "MEMORY_AND_DISK_SER" = callJStatic("org.apache.spark.storage.StorageLevel", "MEMORY_AND_DISK_SER"),
+ "MEMORY_AND_DISK_SER_2" = callJStatic("org.apache.spark.storage.StorageLevel", "MEMORY_AND_DISK_SER_2"),
+ "MEMORY_ONLY" = callJStatic("org.apache.spark.storage.StorageLevel", "MEMORY_ONLY"),
+ "MEMORY_ONLY_2" = callJStatic("org.apache.spark.storage.StorageLevel", "MEMORY_ONLY_2"),
+ "MEMORY_ONLY_SER" = callJStatic("org.apache.spark.storage.StorageLevel", "MEMORY_ONLY_SER"),
+ "MEMORY_ONLY_SER_2" = callJStatic("org.apache.spark.storage.StorageLevel", "MEMORY_ONLY_SER_2"),
+ "OFF_HEAP" = callJStatic("org.apache.spark.storage.StorageLevel", "OFF_HEAP"))
+}
+
+# Utility function for functions where an argument needs to be integer but we want to allow
+# the user to type (for example) `5` instead of `5L` to avoid a confusing error message.
+numToInt <- function(num) {
+ if (as.integer(num) != num) {
+ warning(paste("Coercing", as.list(sys.call())[[2]], "to integer."))
+ }
+ as.integer(num)
+}
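+
+# Usage sketch: callers can pass a plain numeric where an integer is required.
+#   numToInt(5)    # 5L, silently
+#   numToInt(5.5)  # 5L, with a warning that 5.5 is being coerced to integer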
+
+# create a Seq in JVM
+toSeq <- function(...) {
+ callJStatic("org.apache.spark.sql.api.r.SQLUtils", "toSeq", list(...))
+}
+
+# create a Seq in JVM from a list
+listToSeq <- function(l) {
+ callJStatic("org.apache.spark.sql.api.r.SQLUtils", "toSeq", l)
+}
+
+# Utility function to recursively traverse the Abstract Syntax Tree (AST) of a
+# user defined function (UDF), and to examine variables in the UDF to decide
+# if their values should be included in the new function environment.
+# param
+# node The current AST node in the traversal.
+# oldEnv The original function environment.
+#   defVars An Accumulator of variable names defined in the function's calling environment,
+# including function argument and local variable names.
+#   checkedFuncs An environment of function objects examined during cleanClosure. It can
+# be considered as a "name"-to-"list of functions" mapping.
+# newEnv A new function environment to store necessary function dependencies, an output argument.
+processClosure <- function(node, oldEnv, defVars, checkedFuncs, newEnv) {
+ nodeLen <- length(node)
+
+ if (nodeLen > 1 && typeof(node) == "language") {
+ # Recursive case: current AST node is an internal node, check for its children.
+ if (length(node[[1]]) > 1) {
+ for (i in 1:nodeLen) {
+ processClosure(node[[i]], oldEnv, defVars, checkedFuncs, newEnv)
+ }
+    } else { # if node[[1]] has length 1, check for some special R functions.
+ nodeChar <- as.character(node[[1]])
+ if (nodeChar == "{" || nodeChar == "(") { # Skip start symbol.
+ for (i in 2:nodeLen) {
+ processClosure(node[[i]], oldEnv, defVars, checkedFuncs, newEnv)
+ }
+ } else if (nodeChar == "<-" || nodeChar == "=" ||
+ nodeChar == "<<-") { # Assignment Ops.
+ defVar <- node[[2]]
+ if (length(defVar) == 1 && typeof(defVar) == "symbol") {
+ # Add the defined variable name into defVars.
+ addItemToAccumulator(defVars, as.character(defVar))
+ } else {
+ processClosure(node[[2]], oldEnv, defVars, checkedFuncs, newEnv)
+ }
+ for (i in 3:nodeLen) {
+ processClosure(node[[i]], oldEnv, defVars, checkedFuncs, newEnv)
+ }
+ } else if (nodeChar == "function") { # Function definition.
+ # Add parameter names.
+ newArgs <- names(node[[2]])
+ lapply(newArgs, function(arg) { addItemToAccumulator(defVars, arg) })
+ for (i in 3:nodeLen) {
+ processClosure(node[[i]], oldEnv, defVars, checkedFuncs, newEnv)
+ }
+ } else if (nodeChar == "$") { # Skip the field.
+ processClosure(node[[2]], oldEnv, defVars, checkedFuncs, newEnv)
+ } else if (nodeChar == "::" || nodeChar == ":::") {
+ processClosure(node[[3]], oldEnv, defVars, checkedFuncs, newEnv)
+ } else {
+ for (i in 1:nodeLen) {
+ processClosure(node[[i]], oldEnv, defVars, checkedFuncs, newEnv)
+ }
+ }
+ }
+ } else if (nodeLen == 1 &&
+ (typeof(node) == "symbol" || typeof(node) == "language")) {
+ # Base case: current AST node is a leaf node and a symbol or a function call.
+ nodeChar <- as.character(node)
+ if (!nodeChar %in% defVars$data) { # Not a function parameter or local variable.
+ func.env <- oldEnv
+ topEnv <- parent.env(.GlobalEnv)
+ # Search in function environment, and function's enclosing environments
+ # up to global environment. There is no need to look into package environments
+      # above the global environment, nor into namespace environments other than SparkR's below it,
+ # as they are assumed to be loaded on workers.
+ while (!identical(func.env, topEnv)) {
+ # Namespaces other than "SparkR" will not be searched.
+ if (!isNamespace(func.env) ||
+ (getNamespaceName(func.env) == "SparkR" &&
+ !(nodeChar %in% getNamespaceExports("SparkR")))) { # Only include SparkR internals.
+ # Set parameter 'inherits' to FALSE since we do not need to search in
+ # attached package environments.
+ if (tryCatch(exists(nodeChar, envir = func.env, inherits = FALSE),
+ error = function(e) { FALSE })) {
+ obj <- get(nodeChar, envir = func.env, inherits = FALSE)
+ if (is.function(obj)) { # If the node is a function call.
+                funcList <- mget(nodeChar, envir = checkedFuncs, inherits = FALSE,
+                                 ifnotfound = list(list(NULL)))[[1]]
+                found <- sapply(funcList, function(func) {
+                  identical(func, obj)
+                })
+ if (sum(found) > 0) { # If function has been examined, ignore.
+ break
+ }
+ # Function has not been examined, record it and recursively clean its closure.
+ assign(nodeChar,
+ if (is.null(funcList[[1]])) {
+ list(obj)
+ } else {
+ append(funcList, obj)
+ },
+ envir = checkedFuncs)
+ obj <- cleanClosure(obj, checkedFuncs)
+ }
+ assign(nodeChar, obj, envir = newEnv)
+ break
+ }
+ }
+
+ # Continue to search in enclosure.
+ func.env <- parent.env(func.env)
+ }
+ }
+ }
+}
+
+# Utility function to get user defined function (UDF) dependencies (closure).
+# More specifically, this function captures the values of free variables defined
+# outside a UDF, and stores them in the function's environment.
+# param
+# func A function whose closure needs to be captured.
+#   checkedFuncs An environment of function objects examined during cleanClosure. It can be
+# considered as a "name"-to-"list of functions" mapping.
+# return value
+#   a new version of func that has a correct environment (closure).
+cleanClosure <- function(func, checkedFuncs = new.env()) {
+ if (is.function(func)) {
+ newEnv <- new.env(parent = .GlobalEnv)
+ func.body <- body(func)
+ oldEnv <- environment(func)
+ # defVars is an Accumulator of variables names defined in the function's calling
+ # environment. First, function's arguments are added to defVars.
+ defVars <- initAccumulator()
+ argNames <- names(as.list(args(func)))
+ for (i in 1:(length(argNames) - 1)) { # Remove the ending NULL in pairlist.
+ addItemToAccumulator(defVars, argNames[i])
+ }
+ # Recursively examine variables in the function body.
+ processClosure(func.body, oldEnv, defVars, checkedFuncs, newEnv)
+ environment(func) <- newEnv
+ }
+ func
+}
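+
+# Usage sketch: free variables referenced by a UDF (here `multiplier`) are copied
+# into a fresh environment so the function can be serialized for the workers without
+# carrying its whole defining environment along.
+#   multiplier <- 10
+#   f <- function(x) { x * multiplier }
+#   g <- cleanClosure(f)
+#   ls(environment(g))                         # "multiplier"
+#   get("multiplier", envir = environment(g))  # 10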
diff --git a/R/pkg/R/zzz.R b/R/pkg/R/zzz.R
new file mode 100644
index 0000000000..80d796d467
--- /dev/null
+++ b/R/pkg/R/zzz.R
@@ -0,0 +1,21 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements. See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+.onLoad <- function(libname, pkgname) {
+ sparkR.onLoad(libname, pkgname)
+}
+
diff --git a/R/pkg/inst/profile/general.R b/R/pkg/inst/profile/general.R
new file mode 100644
index 0000000000..8fe711b622
--- /dev/null
+++ b/R/pkg/inst/profile/general.R
@@ -0,0 +1,22 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements. See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+.First <- function() {
+ home <- Sys.getenv("SPARK_HOME")
+ .libPaths(c(file.path(home, "R", "lib"), .libPaths()))
+ Sys.setenv(NOAWT=1)
+}
diff --git a/R/pkg/inst/profile/shell.R b/R/pkg/inst/profile/shell.R
new file mode 100644
index 0000000000..7a7f203115
--- /dev/null
+++ b/R/pkg/inst/profile/shell.R
@@ -0,0 +1,31 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements. See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+.First <- function() {
+ home <- Sys.getenv("SPARK_HOME")
+ .libPaths(c(file.path(home, "R", "lib"), .libPaths()))
+ Sys.setenv(NOAWT=1)
+
+ library(utils)
+ library(SparkR)
+ sc <- sparkR.init(Sys.getenv("MASTER", unset = ""))
+ assign("sc", sc, envir=.GlobalEnv)
+ sqlCtx <- sparkRSQL.init(sc)
+ assign("sqlCtx", sqlCtx, envir=.GlobalEnv)
+ cat("\n Welcome to SparkR!")
+ cat("\n Spark context is available as sc, SQL context is available as sqlCtx\n")
+}
diff --git a/R/pkg/inst/tests/test_binaryFile.R b/R/pkg/inst/tests/test_binaryFile.R
new file mode 100644
index 0000000000..4bb5f58d83
--- /dev/null
+++ b/R/pkg/inst/tests/test_binaryFile.R
@@ -0,0 +1,90 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements. See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+context("functions on binary files")
+
+# JavaSparkContext handle
+sc <- sparkR.init()
+
+mockFile <- c("Spark is pretty.", "Spark is awesome.")
+
+test_that("saveAsObjectFile()/objectFile() following textFile() works", {
+ fileName1 <- tempfile(pattern="spark-test", fileext=".tmp")
+ fileName2 <- tempfile(pattern="spark-test", fileext=".tmp")
+ writeLines(mockFile, fileName1)
+
+ rdd <- textFile(sc, fileName1)
+ saveAsObjectFile(rdd, fileName2)
+ rdd <- objectFile(sc, fileName2)
+ expect_equal(collect(rdd), as.list(mockFile))
+
+ unlink(fileName1)
+ unlink(fileName2, recursive = TRUE)
+})
+
+test_that("saveAsObjectFile()/objectFile() works on a parallelized list", {
+ fileName <- tempfile(pattern="spark-test", fileext=".tmp")
+
+ l <- list(1, 2, 3)
+ rdd <- parallelize(sc, l)
+ saveAsObjectFile(rdd, fileName)
+ rdd <- objectFile(sc, fileName)
+ expect_equal(collect(rdd), l)
+
+ unlink(fileName, recursive = TRUE)
+})
+
+test_that("saveAsObjectFile()/objectFile() following RDD transformations works", {
+ fileName1 <- tempfile(pattern="spark-test", fileext=".tmp")
+ fileName2 <- tempfile(pattern="spark-test", fileext=".tmp")
+ writeLines(mockFile, fileName1)
+
+ rdd <- textFile(sc, fileName1)
+
+ words <- flatMap(rdd, function(line) { strsplit(line, " ")[[1]] })
+ wordCount <- lapply(words, function(word) { list(word, 1L) })
+
+ counts <- reduceByKey(wordCount, "+", 2L)
+
+ saveAsObjectFile(counts, fileName2)
+ counts <- objectFile(sc, fileName2)
+
+ output <- collect(counts)
+ expected <- list(list("awesome.", 1), list("Spark", 2), list("pretty.", 1),
+ list("is", 2))
+ expect_equal(sortKeyValueList(output), sortKeyValueList(expected))
+
+ unlink(fileName1)
+ unlink(fileName2, recursive = TRUE)
+})
+
+test_that("saveAsObjectFile()/objectFile() works with multiple paths", {
+ fileName1 <- tempfile(pattern="spark-test", fileext=".tmp")
+ fileName2 <- tempfile(pattern="spark-test", fileext=".tmp")
+
+ rdd1 <- parallelize(sc, "Spark is pretty.")
+ saveAsObjectFile(rdd1, fileName1)
+ rdd2 <- parallelize(sc, "Spark is awesome.")
+ saveAsObjectFile(rdd2, fileName2)
+
+ rdd <- objectFile(sc, c(fileName1, fileName2))
+ expect_true(count(rdd) == 2)
+
+ unlink(fileName1, recursive = TRUE)
+ unlink(fileName2, recursive = TRUE)
+})
+
diff --git a/R/pkg/inst/tests/test_binary_function.R b/R/pkg/inst/tests/test_binary_function.R
new file mode 100644
index 0000000000..c15553ba28
--- /dev/null
+++ b/R/pkg/inst/tests/test_binary_function.R
@@ -0,0 +1,68 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements. See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+context("binary functions")
+
+# JavaSparkContext handle
+sc <- sparkR.init()
+
+# Data
+nums <- 1:10
+rdd <- parallelize(sc, nums, 2L)
+
+# File content
+mockFile <- c("Spark is pretty.", "Spark is awesome.")
+
+test_that("union on two RDDs", {
+ actual <- collect(unionRDD(rdd, rdd))
+ expect_equal(actual, as.list(rep(nums, 2)))
+
+ fileName <- tempfile(pattern="spark-test", fileext=".tmp")
+ writeLines(mockFile, fileName)
+
+ text.rdd <- textFile(sc, fileName)
+ union.rdd <- unionRDD(rdd, text.rdd)
+ actual <- collect(union.rdd)
+ expect_equal(actual, c(as.list(nums), mockFile))
+ expect_true(getSerializedMode(union.rdd) == "byte")
+
+ rdd <- map(text.rdd, function(x) { x })
+ union.rdd <- unionRDD(rdd, text.rdd)
+ actual <- collect(union.rdd)
+ expect_equal(actual, as.list(c(mockFile, mockFile)))
+ expect_true(getSerializedMode(union.rdd) == "byte")
+
+ unlink(fileName)
+})
+
+test_that("cogroup on two RDDs", {
+ rdd1 <- parallelize(sc, list(list(1, 1), list(2, 4)))
+ rdd2 <- parallelize(sc, list(list(1, 2), list(1, 3)))
+ cogroup.rdd <- cogroup(rdd1, rdd2, numPartitions = 2L)
+ actual <- collect(cogroup.rdd)
+ expect_equal(actual,
+ list(list(1, list(list(1), list(2, 3))), list(2, list(list(4), list()))))
+
+ rdd1 <- parallelize(sc, list(list("a", 1), list("a", 4)))
+ rdd2 <- parallelize(sc, list(list("b", 2), list("a", 3)))
+ cogroup.rdd <- cogroup(rdd1, rdd2, numPartitions = 2L)
+ actual <- collect(cogroup.rdd)
+
+ expected <- list(list("b", list(list(), list(2))), list("a", list(list(1, 4), list(3))))
+ expect_equal(sortKeyValueList(actual),
+ sortKeyValueList(expected))
+})
diff --git a/R/pkg/inst/tests/test_broadcast.R b/R/pkg/inst/tests/test_broadcast.R
new file mode 100644
index 0000000000..fee91a427d
--- /dev/null
+++ b/R/pkg/inst/tests/test_broadcast.R
@@ -0,0 +1,48 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements. See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+context("broadcast variables")
+
+# JavaSparkContext handle
+sc <- sparkR.init()
+
+# Partitioned data
+nums <- 1:2
+rrdd <- parallelize(sc, nums, 2L)
+
+test_that("using broadcast variable", {
+ randomMat <- matrix(nrow=10, ncol=10, data=rnorm(100))
+ randomMatBr <- broadcast(sc, randomMat)
+
+ useBroadcast <- function(x) {
+ sum(value(randomMatBr) * x)
+ }
+ actual <- collect(lapply(rrdd, useBroadcast))
+ expected <- list(sum(randomMat) * 1, sum(randomMat) * 2)
+ expect_equal(actual, expected)
+})
+
+test_that("without using broadcast variable", {
+ randomMat <- matrix(nrow=10, ncol=10, data=rnorm(100))
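+ # The matrix is captured directly in the worker closure here; the expected
+ # values match the broadcast version above, so this case exercises plain
+ # closure serialization rather than broadcast variables.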
+
+ useBroadcast <- function(x) {
+ sum(randomMat * x)
+ }
+ actual <- collect(lapply(rrdd, useBroadcast))
+ expected <- list(sum(randomMat) * 1, sum(randomMat) * 2)
+ expect_equal(actual, expected)
+})
diff --git a/R/pkg/inst/tests/test_context.R b/R/pkg/inst/tests/test_context.R
new file mode 100644
index 0000000000..e4aab37436
--- /dev/null
+++ b/R/pkg/inst/tests/test_context.R
@@ -0,0 +1,50 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements. See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+context("test functions in sparkR.R")
+
+test_that("repeatedly starting and stopping SparkR", {
+ for (i in 1:4) {
+ sc <- sparkR.init()
+ rdd <- parallelize(sc, 1:20, 2L)
+ expect_equal(count(rdd), 20)
+ sparkR.stop()
+ }
+})
+
+test_that("rdd GC across sparkR.stop", {
+ sparkR.stop()
+ sc <- sparkR.init() # sc should get id 0
+ rdd1 <- parallelize(sc, 1:20, 2L) # rdd1 should get id 1
+ rdd2 <- parallelize(sc, 1:10, 2L) # rdd2 should get id 2
+ sparkR.stop()
+
+ sc <- sparkR.init() # sc should get id 0 again
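+ # Object ids restart in the new context, so rdd3/rdd4 below reuse the ids
+ # that rdd1/rdd2 held in the old context; the count() calls at the end check
+ # that garbage-collecting the stale handles does not invalidate the new RDDs.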
+
+ # GC rdd1 before creating rdd3 and rdd2 after
+ rm(rdd1)
+ gc()
+
+ rdd3 <- parallelize(sc, 1:20, 2L) # rdd3 should get id 1 now
+ rdd4 <- parallelize(sc, 1:10, 2L) # rdd4 should get id 2 now
+
+ rm(rdd2)
+ gc()
+
+ count(rdd3)
+ count(rdd4)
+})
diff --git a/R/pkg/inst/tests/test_includePackage.R b/R/pkg/inst/tests/test_includePackage.R
new file mode 100644
index 0000000000..8152b448d0
--- /dev/null
+++ b/R/pkg/inst/tests/test_includePackage.R
@@ -0,0 +1,57 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements. See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+context("include R packages")
+
+# JavaSparkContext handle
+sc <- sparkR.init()
+
+# Partitioned data
+nums <- 1:2
+rdd <- parallelize(sc, nums, 2L)
+
+test_that("include inside function", {
+ # Only run the test if plyr is installed.
+ if ("plyr" %in% rownames(installed.packages())) {
+ suppressPackageStartupMessages(library(plyr))
+ generateData <- function(x) {
+ suppressPackageStartupMessages(library(plyr))
+ attach(airquality)
+ result <- transform(Ozone, logOzone = log(Ozone))
+ result
+ }
+
+ data <- lapplyPartition(rdd, generateData)
+ actual <- collect(data)
+ }
+})
+
+test_that("use include package", {
+ # Only run the test if plyr is installed.
+ if ("plyr" %in% rownames(installed.packages())) {
+ suppressPackageStartupMessages(library(plyr))
+ generateData <- function(x) {
+ attach(airquality)
+ result <- transform(Ozone, logOzone = log(Ozone))
+ result
+ }
+
+ includePackage(sc, plyr)
+ data <- lapplyPartition(rdd, generateData)
+ actual <- collect(data)
+ }
+})
diff --git a/R/pkg/inst/tests/test_parallelize_collect.R b/R/pkg/inst/tests/test_parallelize_collect.R
new file mode 100644
index 0000000000..fff028657d
--- /dev/null
+++ b/R/pkg/inst/tests/test_parallelize_collect.R
@@ -0,0 +1,109 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements. See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+context("parallelize() and collect()")
+
+# Mock data
+numVector <- c(-10:97)
+numList <- list(sqrt(1), sqrt(2), sqrt(3), 4 ** 10)
+strVector <- c("Dexter Morgan: I suppose I should be upset, even feel",
+ "violated, but I'm not. No, in fact, I think this is a friendly",
+ "message, like \"Hey, wanna play?\" and yes, I want to play. ",
+ "I really, really do.")
+strList <- list("Dexter Morgan: Blood. Sometimes it sets my teeth on edge, ",
+ "other times it helps me control the chaos.",
+ "Dexter Morgan: Harry and Dorris Morgan did a wonderful job ",
+ "raising me. But they're both dead now. I didn't kill them. Honest.")
+
+numPairs <- list(list(1, 1), list(1, 2), list(2, 2), list(2, 3))
+strPairs <- list(list(strList, strList), list(strList, strList))
+
+# JavaSparkContext handle
+jsc <- sparkR.init()
+
+# Tests
+
+test_that("parallelize() on simple vectors and lists returns an RDD", {
+ numVectorRDD <- parallelize(jsc, numVector, 1)
+ numVectorRDD2 <- parallelize(jsc, numVector, 10)
+ numListRDD <- parallelize(jsc, numList, 1)
+ numListRDD2 <- parallelize(jsc, numList, 4)
+ strVectorRDD <- parallelize(jsc, strVector, 2)
+ strVectorRDD2 <- parallelize(jsc, strVector, 3)
+ strListRDD <- parallelize(jsc, strList, 4)
+ strListRDD2 <- parallelize(jsc, strList, 1)
+
+ rdds <- c(numVectorRDD,
+ numVectorRDD2,
+ numListRDD,
+ numListRDD2,
+ strVectorRDD,
+ strVectorRDD2,
+ strListRDD,
+ strListRDD2)
+
+ for (rdd in rdds) {
+ expect_true(inherits(rdd, "RDD"))
+ expect_true(.hasSlot(rdd, "jrdd")
+ && inherits(rdd@jrdd, "jobj")
+ && isInstanceOf(rdd@jrdd, "org.apache.spark.api.java.JavaRDD"))
+ }
+})
+
+test_that("collect(), following a parallelize(), gives back the original collections", {
+ numVectorRDD <- parallelize(jsc, numVector, 10)
+ expect_equal(collect(numVectorRDD), as.list(numVector))
+
+ numListRDD <- parallelize(jsc, numList, 1)
+ numListRDD2 <- parallelize(jsc, numList, 4)
+ expect_equal(collect(numListRDD), as.list(numList))
+ expect_equal(collect(numListRDD2), as.list(numList))
+
+ strVectorRDD <- parallelize(jsc, strVector, 2)
+ strVectorRDD2 <- parallelize(jsc, strVector, 3)
+ expect_equal(collect(strVectorRDD), as.list(strVector))
+ expect_equal(collect(strVectorRDD2), as.list(strVector))
+
+ strListRDD <- parallelize(jsc, strList, 4)
+ strListRDD2 <- parallelize(jsc, strList, 1)
+ expect_equal(collect(strListRDD), as.list(strList))
+ expect_equal(collect(strListRDD2), as.list(strList))
+})
+
+test_that("regression: collect() following a parallelize() does not drop elements", {
+ # 10 elements over 6 partitions: 10 %/% 6 = 1 but ceiling(10 / 6) = 2, so
+ # uneven slicing must not drop elements
+ collLen <- 10
+ numPart <- 6
+ expected <- runif(collLen)
+ actual <- collect(parallelize(jsc, expected, numPart))
+ expect_equal(actual, as.list(expected))
+})
+
+test_that("parallelize() and collect() work for lists of pairs (pairwise data)", {
+ # lists of two-element lists are treated as pairwise (key-value) data
+ numPairsRDD1 <- parallelize(jsc, numPairs, 1)
+ numPairsRDD2 <- parallelize(jsc, numPairs, 2)
+ numPairsRDD3 <- parallelize(jsc, numPairs, 3)
+ expect_equal(collect(numPairsRDD1), numPairs)
+ expect_equal(collect(numPairsRDD2), numPairs)
+ expect_equal(collect(numPairsRDD3), numPairs)
+ # the same holds for pairs whose keys and values are themselves lists
+ strPairsRDD1 <- parallelize(jsc, strPairs, 1)
+ strPairsRDD2 <- parallelize(jsc, strPairs, 2)
+ expect_equal(collect(strPairsRDD1), strPairs)
+ expect_equal(collect(strPairsRDD2), strPairs)
+})
diff --git a/R/pkg/inst/tests/test_rdd.R b/R/pkg/inst/tests/test_rdd.R
new file mode 100644
index 0000000000..f75e0817b9
--- /dev/null
+++ b/R/pkg/inst/tests/test_rdd.R
@@ -0,0 +1,644 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements. See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+context("basic RDD functions")
+
+# JavaSparkContext handle
+sc <- sparkR.init()
+
+# Data
+nums <- 1:10
+rdd <- parallelize(sc, nums, 2L)
+
+intPairs <- list(list(1L, -1), list(2L, 100), list(2L, 1), list(1L, 200))
+intRdd <- parallelize(sc, intPairs, 2L)
+
+test_that("get number of partitions in RDD", {
+ expect_equal(numPartitions(rdd), 2)
+ expect_equal(numPartitions(intRdd), 2)
+})
+
+test_that("first on RDD", {
+ expect_true(first(rdd) == 1)
+ newrdd <- lapply(rdd, function(x) x + 1)
+ expect_true(first(newrdd) == 2)
+})
+
+test_that("count and length on RDD", {
+ expect_equal(count(rdd), 10)
+ expect_equal(length(rdd), 10)
+})
+
+test_that("count by values and keys", {
+ mods <- lapply(rdd, function(x) { x %% 3 })
+ actual <- countByValue(mods)
+ expected <- list(list(0, 3L), list(1, 4L), list(2, 3L))
+ expect_equal(sortKeyValueList(actual), sortKeyValueList(expected))
+
+ actual <- countByKey(intRdd)
+ expected <- list(list(2L, 2L), list(1L, 2L))
+ expect_equal(sortKeyValueList(actual), sortKeyValueList(expected))
+})
+
+test_that("lapply on RDD", {
+ multiples <- lapply(rdd, function(x) { 2 * x })
+ actual <- collect(multiples)
+ expect_equal(actual, as.list(nums * 2))
+})
+
+test_that("lapplyPartition on RDD", {
+ sums <- lapplyPartition(rdd, function(part) { sum(unlist(part)) })
+ actual <- collect(sums)
+ expect_equal(actual, list(15, 40))
+})
+
+test_that("mapPartitions on RDD", {
+ sums <- mapPartitions(rdd, function(part) { sum(unlist(part)) })
+ actual <- collect(sums)
+ expect_equal(actual, list(15, 40))
+})
+
+test_that("flatMap() on RDDs", {
+ flat <- flatMap(intRdd, function(x) { list(x, x) })
+ actual <- collect(flat)
+ expect_equal(actual, rep(intPairs, each=2))
+})
+
+test_that("filterRDD on RDD", {
+ filtered.rdd <- filterRDD(rdd, function(x) { x %% 2 == 0 })
+ actual <- collect(filtered.rdd)
+ expect_equal(actual, list(2, 4, 6, 8, 10))
+
+ filtered.rdd <- Filter(function(x) { x[[2]] < 0 }, intRdd)
+ actual <- collect(filtered.rdd)
+ expect_equal(actual, list(list(1L, -1)))
+
+ # Filter out all elements.
+ filtered.rdd <- filterRDD(rdd, function(x) { x > 10 })
+ actual <- collect(filtered.rdd)
+ expect_equal(actual, list())
+})
+
+test_that("lookup on RDD", {
+ vals <- lookup(intRdd, 1L)
+ expect_equal(vals, list(-1, 200))
+
+ vals <- lookup(intRdd, 3L)
+ expect_equal(vals, list())
+})
+
+test_that("several transformations on RDD (a benchmark on PipelinedRDD)", {
+ rdd2 <- rdd
+ for (i in 1:12)
+ rdd2 <- lapplyPartitionsWithIndex(
+ rdd2, function(split, part) {
+ part <- as.list(unlist(part) * split + i)
+ })
+ rdd2 <- lapply(rdd2, function(x) x + x)
+ actual <- collect(rdd2)
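+ # Partition 0 (split 0) maps every element to 0 * x + i, so after the final
+ # pass (i = 12) each element is 12 and doubling gives 24. Partition 1
+ # (split 1) adds i on each of the 12 passes, so 6:10 become 6:10 + sum(1:12)
+ # and doubling gives 168, 170, ..., 176.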
+ expected <- list(24, 24, 24, 24, 24,
+ 168, 170, 172, 174, 176)
+ expect_equal(actual, expected)
+})
+
+test_that("PipelinedRDD support actions: cache(), persist(), unpersist(), checkpoint()", {
+ # RDD
+ rdd2 <- rdd
+ # PipelinedRDD
+ rdd2 <- lapplyPartitionsWithIndex(
+ rdd2,
+ function(split, part) {
+ part <- as.list(unlist(part) * split)
+ })
+
+ cache(rdd2)
+ expect_true(rdd2@env$isCached)
+ rdd2 <- lapply(rdd2, function(x) x)
+ expect_false(rdd2@env$isCached)
+
+ unpersist(rdd2)
+ expect_false(rdd2@env$isCached)
+
+ persist(rdd2, "MEMORY_AND_DISK")
+ expect_true(rdd2@env$isCached)
+ rdd2 <- lapply(rdd2, function(x) x)
+ expect_false(rdd2@env$isCached)
+
+ unpersist(rdd2)
+ expect_false(rdd2@env$isCached)
+
+ setCheckpointDir(sc, "checkpoints")
+ checkpoint(rdd2)
+ expect_true(rdd2@env$isCheckpointed)
+
+ rdd2 <- lapply(rdd2, function(x) x)
+ expect_false(rdd2@env$isCached)
+ expect_false(rdd2@env$isCheckpointed)
+
+ # make sure the data is collectable
+ collect(rdd2)
+
+ unlink("checkpoints")
+})
+
+test_that("reduce on RDD", {
+ sum <- reduce(rdd, "+")
+ expect_equal(sum, 55)
+
+ # Also test with an inline function
+ sumInline <- reduce(rdd, function(x, y) { x + y })
+ expect_equal(sumInline, 55)
+})
+
+test_that("lapply with dependency", {
+ fa <- 5
+ multiples <- lapply(rdd, function(x) { fa * x })
+ actual <- collect(multiples)
+
+ expect_equal(actual, as.list(nums * 5))
+})
+
+test_that("lapplyPartitionsWithIndex on RDDs", {
+ func <- function(splitIndex, part) { list(splitIndex, Reduce("+", part)) }
+ actual <- collect(lapplyPartitionsWithIndex(rdd, func), flatten = FALSE)
+ expect_equal(actual, list(list(0, 15), list(1, 40)))
+
+ pairsRDD <- parallelize(sc, list(list(1, 2), list(3, 4), list(4, 8)), 1L)
+ partitionByParity <- function(key) { if (key %% 2 == 1) 0 else 1 }
+ mkTup <- function(splitIndex, part) { list(splitIndex, part) }
+ actual <- collect(lapplyPartitionsWithIndex(
+ partitionBy(pairsRDD, 2L, partitionByParity),
+ mkTup),
+ FALSE)
+ expect_equal(actual, list(list(0, list(list(1, 2), list(3, 4))),
+ list(1, list(list(4, 8)))))
+})
+
+test_that("sampleRDD() on RDDs", {
+ expect_equal(unlist(collect(sampleRDD(rdd, FALSE, 1.0, 2014L))), nums)
+})
+
+test_that("takeSample() on RDDs", {
+ # ported from RDDSuite.scala, modified seeds
+ data <- parallelize(sc, 1:100, 2L)
+ for (seed in 4:5) {
+ s <- takeSample(data, FALSE, 20L, seed)
+ expect_equal(length(s), 20L)
+ expect_equal(length(unique(s)), 20L)
+ for (elem in s) {
+ expect_true(elem >= 1 && elem <= 100)
+ }
+ }
+ for (seed in 4:5) {
+ s <- takeSample(data, FALSE, 200L, seed)
+ expect_equal(length(s), 100L)
+ expect_equal(length(unique(s)), 100L)
+ for (elem in s) {
+ expect_true(elem >= 1 && elem <= 100)
+ }
+ }
+ for (seed in 4:5) {
+ s <- takeSample(data, TRUE, 20L, seed)
+ expect_equal(length(s), 20L)
+ for (elem in s) {
+ expect_true(elem >= 1 && elem <= 100)
+ }
+ }
+ for (seed in 4:5) {
+ s <- takeSample(data, TRUE, 100L, seed)
+ expect_equal(length(s), 100L)
+ # Chance of getting all distinct elements is astronomically low, so test we
+ # got < 100
+ expect_true(length(unique(s)) < 100L)
+ }
+ for (seed in 4:5) {
+ s <- takeSample(data, TRUE, 200L, seed)
+ expect_equal(length(s), 200L)
+ # Chance of getting all distinct elements is still quite low, so test we
+ # got < 100
+ expect_true(length(unique(s)) < 100L)
+ }
+})
+
+test_that("mapValues() on pairwise RDDs", {
+ multiples <- mapValues(intRdd, function(x) { x * 2 })
+ actual <- collect(multiples)
+ expected <- lapply(intPairs, function(x) {
+ list(x[[1]], x[[2]] * 2)
+ })
+ expect_equal(sortKeyValueList(actual), sortKeyValueList(expected))
+})
+
+test_that("flatMapValues() on pairwise RDDs", {
+ l <- parallelize(sc, list(list(1, c(1,2)), list(2, c(3,4))))
+ actual <- collect(flatMapValues(l, function(x) { x }))
+ expect_equal(actual, list(list(1,1), list(1,2), list(2,3), list(2,4)))
+
+ # Generate x to x+1 for every value
+ actual <- collect(flatMapValues(intRdd, function(x) { x:(x + 1) }))
+ expect_equal(actual,
+ list(list(1L, -1), list(1L, 0), list(2L, 100), list(2L, 101),
+ list(2L, 1), list(2L, 2), list(1L, 200), list(1L, 201)))
+})
+
+test_that("reduceByKeyLocally() on PairwiseRDDs", {
+ pairs <- parallelize(sc, list(list(1, 2), list(1.1, 3), list(1, 4)), 2L)
+ actual <- reduceByKeyLocally(pairs, "+")
+ expect_equal(sortKeyValueList(actual),
+ sortKeyValueList(list(list(1, 6), list(1.1, 3))))
+
+ pairs <- parallelize(sc, list(list("abc", 1.2), list(1.1, 0), list("abc", 1.3),
+ list("bb", 5)), 4L)
+ actual <- reduceByKeyLocally(pairs, "+")
+ expect_equal(sortKeyValueList(actual),
+ sortKeyValueList(list(list("abc", 2.5), list(1.1, 0), list("bb", 5))))
+})
+
+test_that("distinct() on RDDs", {
+ nums.rep2 <- rep(1:10, 2)
+ rdd.rep2 <- parallelize(sc, nums.rep2, 2L)
+ uniques <- distinct(rdd.rep2)
+ actual <- sort(unlist(collect(uniques)))
+ expect_equal(actual, nums)
+})
+
+test_that("maximum() on RDDs", {
+ max <- maximum(rdd)
+ expect_equal(max, 10)
+})
+
+test_that("minimum() on RDDs", {
+ min <- minimum(rdd)
+ expect_equal(min, 1)
+})
+
+test_that("sumRDD() on RDDs", {
+ sum <- sumRDD(rdd)
+ expect_equal(sum, 55)
+})
+
+test_that("keyBy on RDDs", {
+ func <- function(x) { x*x }
+ keys <- keyBy(rdd, func)
+ actual <- collect(keys)
+ expect_equal(actual, lapply(nums, function(x) { list(func(x), x) }))
+})
+
+test_that("repartition/coalesce on RDDs", {
+ rdd <- parallelize(sc, 1:20, 4L) # each partition contains 5 elements
+
+ # repartition
+ r1 <- repartition(rdd, 2)
+ expect_equal(numPartitions(r1), 2L)
+ count <- length(collectPartition(r1, 0L))
+ expect_true(count >= 8 && count <= 12)
+
+ r2 <- repartition(rdd, 6)
+ expect_equal(numPartitions(r2), 6L)
+ count <- length(collectPartition(r2, 0L))
+ expect_true(count >= 0 && count <= 4)
+
+ # coalesce
+ r3 <- coalesce(rdd, 1)
+ expect_equal(numPartitions(r3), 1L)
+ count <- length(collectPartition(r3, 0L))
+ expect_equal(count, 20)
+})
+
+test_that("sortBy() on RDDs", {
+ sortedRdd <- sortBy(rdd, function(x) { x * x }, ascending = FALSE)
+ actual <- collect(sortedRdd)
+ expect_equal(actual, as.list(sort(nums, decreasing = TRUE)))
+
+ rdd2 <- parallelize(sc, sort(nums, decreasing = TRUE), 2L)
+ sortedRdd2 <- sortBy(rdd2, function(x) { x * x })
+ actual <- collect(sortedRdd2)
+ expect_equal(actual, as.list(nums))
+})
+
+test_that("takeOrdered() on RDDs", {
+ l <- list(10, 1, 2, 9, 3, 4, 5, 6, 7)
+ rdd <- parallelize(sc, l)
+ actual <- takeOrdered(rdd, 6L)
+ expect_equal(actual, as.list(sort(unlist(l)))[1:6])
+
+ l <- list("e", "d", "c", "d", "a")
+ rdd <- parallelize(sc, l)
+ actual <- takeOrdered(rdd, 3L)
+ expect_equal(actual, as.list(sort(unlist(l)))[1:3])
+})
+
+test_that("top() on RDDs", {
+ l <- list(10, 1, 2, 9, 3, 4, 5, 6, 7)
+ rdd <- parallelize(sc, l)
+ actual <- top(rdd, 6L)
+ expect_equal(actual, as.list(sort(unlist(l), decreasing = TRUE))[1:6])
+
+ l <- list("e", "d", "c", "d", "a")
+ rdd <- parallelize(sc, l)
+ actual <- top(rdd, 3L)
+ expect_equal(actual, as.list(sort(unlist(l), decreasing = TRUE))[1:3])
+})
+
+test_that("fold() on RDDs", {
+ actual <- fold(rdd, 0, "+")
+ expect_equal(actual, Reduce("+", nums, 0))
+
+ rdd <- parallelize(sc, list())
+ actual <- fold(rdd, 0, "+")
+ expect_equal(actual, 0)
+})
+
+test_that("aggregateRDD() on RDDs", {
+ rdd <- parallelize(sc, list(1, 2, 3, 4))
+ zeroValue <- list(0, 0)
+ seqOp <- function(x, y) { list(x[[1]] + y, x[[2]] + 1) }
+ combOp <- function(x, y) { list(x[[1]] + y[[1]], x[[2]] + y[[2]]) }
+ actual <- aggregateRDD(rdd, zeroValue, seqOp, combOp)
+ expect_equal(actual, list(10, 4))
+
+ rdd <- parallelize(sc, list())
+ actual <- aggregateRDD(rdd, zeroValue, seqOp, combOp)
+ expect_equal(actual, list(0, 0))
+})
+
+test_that("zipWithUniqueId() on RDDs", {
+ rdd <- parallelize(sc, list("a", "b", "c", "d", "e"), 3L)
+ actual <- collect(zipWithUniqueId(rdd))
+ expected <- list(list("a", 0), list("b", 3), list("c", 1),
+ list("d", 4), list("e", 2))
+ expect_equal(actual, expected)
+
+ rdd <- parallelize(sc, list("a", "b", "c", "d", "e"), 1L)
+ actual <- collect(zipWithUniqueId(rdd))
+ expected <- list(list("a", 0), list("b", 1), list("c", 2),
+ list("d", 3), list("e", 4))
+ expect_equal(actual, expected)
+})
+
+test_that("zipWithIndex() on RDDs", {
+ rdd <- parallelize(sc, list("a", "b", "c", "d", "e"), 3L)
+ actual <- collect(zipWithIndex(rdd))
+ expected <- list(list("a", 0), list("b", 1), list("c", 2),
+ list("d", 3), list("e", 4))
+ expect_equal(actual, expected)
+
+ rdd <- parallelize(sc, list("a", "b", "c", "d", "e"), 1L)
+ actual <- collect(zipWithIndex(rdd))
+ expected <- list(list("a", 0), list("b", 1), list("c", 2),
+ list("d", 3), list("e", 4))
+ expect_equal(actual, expected)
+})
+
+test_that("glom() on RDD", {
+ rdd <- parallelize(sc, as.list(1:4), 2L)
+ actual <- collect(glom(rdd))
+ expect_equal(actual, list(list(1, 2), list(3, 4)))
+})
+
+test_that("keys() on RDDs", {
+ keys <- keys(intRdd)
+ actual <- collect(keys)
+ expect_equal(actual, lapply(intPairs, function(x) { x[[1]] }))
+})
+
+test_that("values() on RDDs", {
+ values <- values(intRdd)
+ actual <- collect(values)
+ expect_equal(actual, lapply(intPairs, function(x) { x[[2]] }))
+})
+
+test_that("pipeRDD() on RDDs", {
+ actual <- collect(pipeRDD(rdd, "more"))
+ expected <- as.list(as.character(1:10))
+ expect_equal(actual, expected)
+
+ trailed.rdd <- parallelize(sc, c("1", "", "2\n", "3\n\r\n"))
+ actual <- collect(pipeRDD(trailed.rdd, "sort"))
+ expected <- list("", "1", "2", "3")
+ expect_equal(actual, expected)
+
+ rev.nums <- 9:0
+ rev.rdd <- parallelize(sc, rev.nums, 2L)
+ actual <- collect(pipeRDD(rev.rdd, "sort"))
+ expected <- as.list(as.character(c(5:9, 0:4)))
+ expect_equal(actual, expected)
+})
+
+test_that("zipRDD() on RDDs", {
+ rdd1 <- parallelize(sc, 0:4, 2)
+ rdd2 <- parallelize(sc, 1000:1004, 2)
+ actual <- collect(zipRDD(rdd1, rdd2))
+ expect_equal(actual,
+ list(list(0, 1000), list(1, 1001), list(2, 1002), list(3, 1003), list(4, 1004)))
+
+ mockFile <- c("Spark is pretty.", "Spark is awesome.")
+ fileName <- tempfile(pattern="spark-test", fileext=".tmp")
+ writeLines(mockFile, fileName)
+
+ rdd <- textFile(sc, fileName, 1)
+ actual <- collect(zipRDD(rdd, rdd))
+ expected <- lapply(mockFile, function(x) { list(x, x) })
+ expect_equal(actual, expected)
+
+ rdd1 <- parallelize(sc, 0:1, 1)
+ actual <- collect(zipRDD(rdd1, rdd))
+ expected <- lapply(0:1, function(x) { list(x, mockFile[x + 1]) })
+ expect_equal(actual, expected)
+
+ rdd1 <- map(rdd, function(x) { x })
+ actual <- collect(zipRDD(rdd, rdd1))
+ expected <- lapply(mockFile, function(x) { list(x, x) })
+ expect_equal(actual, expected)
+
+ unlink(fileName)
+})
+
+test_that("join() on pairwise RDDs", {
+ rdd1 <- parallelize(sc, list(list(1,1), list(2,4)))
+ rdd2 <- parallelize(sc, list(list(1,2), list(1,3)))
+ actual <- collect(join(rdd1, rdd2, 2L))
+ expect_equal(sortKeyValueList(actual),
+ sortKeyValueList(list(list(1, list(1, 2)), list(1, list(1, 3)))))
+
+ rdd1 <- parallelize(sc, list(list("a",1), list("b",4)))
+ rdd2 <- parallelize(sc, list(list("a",2), list("a",3)))
+ actual <- collect(join(rdd1, rdd2, 2L))
+ expect_equal(sortKeyValueList(actual),
+ sortKeyValueList(list(list("a", list(1, 2)), list("a", list(1, 3)))))
+
+ rdd1 <- parallelize(sc, list(list(1,1), list(2,2)))
+ rdd2 <- parallelize(sc, list(list(3,3), list(4,4)))
+ actual <- collect(join(rdd1, rdd2, 2L))
+ expect_equal(actual, list())
+
+ rdd1 <- parallelize(sc, list(list("a",1), list("b",2)))
+ rdd2 <- parallelize(sc, list(list("c",3), list("d",4)))
+ actual <- collect(join(rdd1, rdd2, 2L))
+ expect_equal(actual, list())
+})
+
+test_that("leftOuterJoin() on pairwise RDDs", {
+ rdd1 <- parallelize(sc, list(list(1,1), list(2,4)))
+ rdd2 <- parallelize(sc, list(list(1,2), list(1,3)))
+ actual <- collect(leftOuterJoin(rdd1, rdd2, 2L))
+ expected <- list(list(1, list(1, 2)), list(1, list(1, 3)), list(2, list(4, NULL)))
+ expect_equal(sortKeyValueList(actual),
+ sortKeyValueList(expected))
+
+ rdd1 <- parallelize(sc, list(list("a",1), list("b",4)))
+ rdd2 <- parallelize(sc, list(list("a",2), list("a",3)))
+ actual <- collect(leftOuterJoin(rdd1, rdd2, 2L))
+ expected <- list(list("b", list(4, NULL)), list("a", list(1, 2)), list("a", list(1, 3)))
+ expect_equal(sortKeyValueList(actual),
+ sortKeyValueList(expected))
+
+ rdd1 <- parallelize(sc, list(list(1,1), list(2,2)))
+ rdd2 <- parallelize(sc, list(list(3,3), list(4,4)))
+ actual <- collect(leftOuterJoin(rdd1, rdd2, 2L))
+ expected <- list(list(1, list(1, NULL)), list(2, list(2, NULL)))
+ expect_equal(sortKeyValueList(actual),
+ sortKeyValueList(expected))
+
+ rdd1 <- parallelize(sc, list(list("a",1), list("b",2)))
+ rdd2 <- parallelize(sc, list(list("c",3), list("d",4)))
+ actual <- collect(leftOuterJoin(rdd1, rdd2, 2L))
+ expected <- list(list("b", list(2, NULL)), list("a", list(1, NULL)))
+ expect_equal(sortKeyValueList(actual),
+ sortKeyValueList(expected))
+})
+
+test_that("rightOuterJoin() on pairwise RDDs", {
+ rdd1 <- parallelize(sc, list(list(1,2), list(1,3)))
+ rdd2 <- parallelize(sc, list(list(1,1), list(2,4)))
+ actual <- collect(rightOuterJoin(rdd1, rdd2, 2L))
+ expected <- list(list(1, list(2, 1)), list(1, list(3, 1)), list(2, list(NULL, 4)))
+ expect_equal(sortKeyValueList(actual), sortKeyValueList(expected))
+
+ rdd1 <- parallelize(sc, list(list("a",2), list("a",3)))
+ rdd2 <- parallelize(sc, list(list("a",1), list("b",4)))
+ actual <- collect(rightOuterJoin(rdd1, rdd2, 2L))
+ expected <- list(list("b", list(NULL, 4)), list("a", list(2, 1)), list("a", list(3, 1)))
+ expect_equal(sortKeyValueList(actual),
+ sortKeyValueList(expected))
+
+ rdd1 <- parallelize(sc, list(list(1,1), list(2,2)))
+ rdd2 <- parallelize(sc, list(list(3,3), list(4,4)))
+ actual <- collect(rightOuterJoin(rdd1, rdd2, 2L))
+ expect_equal(sortKeyValueList(actual),
+ sortKeyValueList(list(list(3, list(NULL, 3)), list(4, list(NULL, 4)))))
+
+ rdd1 <- parallelize(sc, list(list("a",1), list("b",2)))
+ rdd2 <- parallelize(sc, list(list("c",3), list("d",4)))
+ actual <- collect(rightOuterJoin(rdd1, rdd2, 2L))
+ expect_equal(sortKeyValueList(actual),
+ sortKeyValueList(list(list("d", list(NULL, 4)), list("c", list(NULL, 3)))))
+})
+
+test_that("fullOuterJoin() on pairwise RDDs", {
+ rdd1 <- parallelize(sc, list(list(1,2), list(1,3), list(3,3)))
+ rdd2 <- parallelize(sc, list(list(1,1), list(2,4)))
+ actual <- collect(fullOuterJoin(rdd1, rdd2, 2L))
+ expected <- list(list(1, list(2, 1)), list(1, list(3, 1)), list(2, list(NULL, 4)), list(3, list(3, NULL)))
+ expect_equal(sortKeyValueList(actual), sortKeyValueList(expected))
+
+ rdd1 <- parallelize(sc, list(list("a",2), list("a",3), list("c", 1)))
+ rdd2 <- parallelize(sc, list(list("a",1), list("b",4)))
+ actual <- collect(fullOuterJoin(rdd1, rdd2, 2L))
+ expected <- list(list("b", list(NULL, 4)), list("a", list(2, 1)), list("a", list(3, 1)), list("c", list(1, NULL)))
+ expect_equal(sortKeyValueList(actual),
+ sortKeyValueList(expected))
+
+ rdd1 <- parallelize(sc, list(list(1,1), list(2,2)))
+ rdd2 <- parallelize(sc, list(list(3,3), list(4,4)))
+ actual <- collect(fullOuterJoin(rdd1, rdd2, 2L))
+ expect_equal(sortKeyValueList(actual),
+ sortKeyValueList(list(list(1, list(1, NULL)), list(2, list(2, NULL)), list(3, list(NULL, 3)), list(4, list(NULL, 4)))))
+
+ rdd1 <- parallelize(sc, list(list("a",1), list("b",2)))
+ rdd2 <- parallelize(sc, list(list("c",3), list("d",4)))
+ actual <- collect(fullOuterJoin(rdd1, rdd2, 2L))
+ expect_equal(sortKeyValueList(actual),
+ sortKeyValueList(list(list("a", list(1, NULL)), list("b", list(2, NULL)), list("d", list(NULL, 4)), list("c", list(NULL, 3)))))
+})
+
+test_that("sortByKey() on pairwise RDDs", {
+ numPairsRdd <- map(rdd, function(x) { list(x, x) })
+ sortedRdd <- sortByKey(numPairsRdd, ascending = FALSE)
+ actual <- collect(sortedRdd)
+ numPairs <- lapply(nums, function(x) { list(x, x) })
+ expect_equal(actual, sortKeyValueList(numPairs, decreasing = TRUE))
+
+ rdd2 <- parallelize(sc, sort(nums, decreasing = TRUE), 2L)
+ numPairsRdd2 <- map(rdd2, function(x) { list(x, x) })
+ sortedRdd2 <- sortByKey(numPairsRdd2)
+ actual <- collect(sortedRdd2)
+ expect_equal(actual, numPairs)
+
+ # sort by string keys
+ l <- list(list("a", 1), list("b", 2), list("1", 3), list("d", 4), list("2", 5))
+ rdd3 <- parallelize(sc, l, 2L)
+ sortedRdd3 <- sortByKey(rdd3)
+ actual <- collect(sortedRdd3)
+ expect_equal(actual, list(list("1", 3), list("2", 5), list("a", 1), list("b", 2), list("d", 4)))
+
+ # test on the boundary cases
+
+ # boundary case 1: the RDD to be sorted has only 1 partition
+ rdd4 <- parallelize(sc, l, 1L)
+ sortedRdd4 <- sortByKey(rdd4)
+ actual <- collect(sortedRdd4)
+ expect_equal(actual, list(list("1", 3), list("2", 5), list("a", 1), list("b", 2), list("d", 4)))
+
+ # boundary case 2: the sorted RDD has only 1 partition
+ rdd5 <- parallelize(sc, l, 2L)
+ sortedRdd5 <- sortByKey(rdd5, numPartitions = 1L)
+ actual <- collect(sortedRdd5)
+ expect_equal(actual, list(list("1", 3), list("2", 5), list("a", 1), list("b", 2), list("d", 4)))
+
+ # boundary case 3: the RDD to be sorted has only 1 element
+ l2 <- list(list("a", 1))
+ rdd6 <- parallelize(sc, l2, 2L)
+ sortedRdd6 <- sortByKey(rdd6)
+ actual <- collect(sortedRdd6)
+ expect_equal(actual, l2)
+
+ # boundary case 4: the RDD to be sorted has 0 element
+ l3 <- list()
+ rdd7 <- parallelize(sc, l3, 2L)
+ sortedRdd7 <- sortByKey(rdd7)
+ actual <- collect(sortedRdd7)
+ expect_equal(actual, l3)
+})
+
+test_that("collectAsMap() on a pairwise RDD", {
+ rdd <- parallelize(sc, list(list(1, 2), list(3, 4)))
+ vals <- collectAsMap(rdd)
+ expect_equal(vals, list(`1` = 2, `3` = 4))
+
+ rdd <- parallelize(sc, list(list("a", 1), list("b", 2)))
+ vals <- collectAsMap(rdd)
+ expect_equal(vals, list(a = 1, b = 2))
+
+ rdd <- parallelize(sc, list(list(1.1, 2.2), list(1.2, 2.4)))
+ vals <- collectAsMap(rdd)
+ expect_equal(vals, list(`1.1` = 2.2, `1.2` = 2.4))
+
+ rdd <- parallelize(sc, list(list(1, "a"), list(2, "b")))
+ vals <- collectAsMap(rdd)
+ expect_equal(vals, list(`1` = "a", `2` = "b"))
+})
diff --git a/R/pkg/inst/tests/test_shuffle.R b/R/pkg/inst/tests/test_shuffle.R
new file mode 100644
index 0000000000..d1da8232ae
--- /dev/null
+++ b/R/pkg/inst/tests/test_shuffle.R
@@ -0,0 +1,209 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements. See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+context("partitionBy, groupByKey, reduceByKey etc.")
+
+# JavaSparkContext handle
+sc <- sparkR.init()
+
+# Data
+intPairs <- list(list(1L, -1), list(2L, 100), list(2L, 1), list(1L, 200))
+intRdd <- parallelize(sc, intPairs, 2L)
+
+doublePairs <- list(list(1.5, -1), list(2.5, 100), list(2.5, 1), list(1.5, 200))
+doubleRdd <- parallelize(sc, doublePairs, 2L)
+
+numPairs <- list(list(1L, 100), list(2L, 200), list(4L, -1), list(3L, 1),
+ list(3L, 0))
+numPairsRdd <- parallelize(sc, numPairs, length(numPairs))
+
+strList <- list("Dexter Morgan: Blood. Sometimes it sets my teeth on edge and ",
+ "Dexter Morgan: Harry and Dorris Morgan did a wonderful job ")
+strListRDD <- parallelize(sc, strList, 4)
+
+test_that("groupByKey for integers", {
+ grouped <- groupByKey(intRdd, 2L)
+
+ actual <- collect(grouped)
+
+ expected <- list(list(2L, list(100, 1)), list(1L, list(-1, 200)))
+ expect_equal(sortKeyValueList(actual), sortKeyValueList(expected))
+})
+
+test_that("groupByKey for doubles", {
+ grouped <- groupByKey(doubleRdd, 2L)
+
+ actual <- collect(grouped)
+
+ expected <- list(list(1.5, list(-1, 200)), list(2.5, list(100, 1)))
+ expect_equal(sortKeyValueList(actual), sortKeyValueList(expected))
+})
+
+test_that("reduceByKey for ints", {
+ reduced <- reduceByKey(intRdd, "+", 2L)
+
+ actual <- collect(reduced)
+
+ expected <- list(list(2L, 101), list(1L, 199))
+ expect_equal(sortKeyValueList(actual), sortKeyValueList(expected))
+})
+
+test_that("reduceByKey for doubles", {
+ reduced <- reduceByKey(doubleRdd, "+", 2L)
+ actual <- collect(reduced)
+
+ expected <- list(list(1.5, 199), list(2.5, 101))
+ expect_equal(sortKeyValueList(actual), sortKeyValueList(expected))
+})
+
+test_that("combineByKey for ints", {
+ reduced <- combineByKey(intRdd, function(x) { x }, "+", "+", 2L)
+
+ actual <- collect(reduced)
+
+ expected <- list(list(2L, 101), list(1L, 199))
+ expect_equal(sortKeyValueList(actual), sortKeyValueList(expected))
+})
+
+test_that("combineByKey for doubles", {
+ reduced <- combineByKey(doubleRdd, function(x) { x }, "+", "+", 2L)
+ actual <- collect(reduced)
+
+ expected <- list(list(1.5, 199), list(2.5, 101))
+ expect_equal(sortKeyValueList(actual), sortKeyValueList(expected))
+})
+
+test_that("aggregateByKey", {
+ # test aggregateByKey for int keys
+ rdd <- parallelize(sc, list(list(1, 1), list(1, 2), list(2, 3), list(2, 4)))
+
+ zeroValue <- list(0, 0)
+ seqOp <- function(x, y) { list(x[[1]] + y, x[[2]] + 1) }
+ combOp <- function(x, y) { list(x[[1]] + y[[1]], x[[2]] + y[[2]]) }
+ aggregatedRDD <- aggregateByKey(rdd, zeroValue, seqOp, combOp, 2L)
+
+ actual <- collect(aggregatedRDD)
+
+ expected <- list(list(1, list(3, 2)), list(2, list(7, 2)))
+ expect_equal(sortKeyValueList(actual), sortKeyValueList(expected))
+
+ # test aggregateByKey for string keys
+ rdd <- parallelize(sc, list(list("a", 1), list("a", 2), list("b", 3), list("b", 4)))
+
+ zeroValue <- list(0, 0)
+ seqOp <- function(x, y) { list(x[[1]] + y, x[[2]] + 1) }
+ combOp <- function(x, y) { list(x[[1]] + y[[1]], x[[2]] + y[[2]]) }
+ aggregatedRDD <- aggregateByKey(rdd, zeroValue, seqOp, combOp, 2L)
+
+ actual <- collect(aggregatedRDD)
+
+ expected <- list(list("a", list(3, 2)), list("b", list(7, 2)))
+ expect_equal(sortKeyValueList(actual), sortKeyValueList(expected))
+})
+
+test_that("foldByKey", {
+ # test foldByKey for int keys
+ folded <- foldByKey(intRdd, 0, "+", 2L)
+
+ actual <- collect(folded)
+
+ expected <- list(list(2L, 101), list(1L, 199))
+ expect_equal(sortKeyValueList(actual), sortKeyValueList(expected))
+
+ # test foldByKey for double keys
+ folded <- foldByKey(doubleRdd, 0, "+", 2L)
+
+ actual <- collect(folded)
+
+ expected <- list(list(1.5, 199), list(2.5, 101))
+ expect_equal(sortKeyValueList(actual), sortKeyValueList(expected))
+
+ # test foldByKey for string keys
+ stringKeyPairs <- list(list("a", -1), list("b", 100), list("b", 1), list("a", 200))
+
+ stringKeyRDD <- parallelize(sc, stringKeyPairs)
+ folded <- foldByKey(stringKeyRDD, 0, "+", 2L)
+
+ actual <- collect(folded)
+
+ expected <- list(list("b", 101), list("a", 199))
+ expect_equal(sortKeyValueList(actual), sortKeyValueList(expected))
+
+ # test foldByKey for empty pair RDD
+ rdd <- parallelize(sc, list())
+ folded <- foldByKey(rdd, 0, "+", 2L)
+ actual <- collect(folded)
+ expected <- list()
+ expect_equal(actual, expected)
+
+ # test foldByKey for RDD with only 1 pair
+ rdd <- parallelize(sc, list(list(1, 1)))
+ folded <- foldByKey(rdd, 0, "+", 2L)
+ actual <- collect(folded)
+ expected <- list(list(1, 1))
+ expect_equal(actual, expected)
+})
+
+test_that("partitionBy() partitions data correctly", {
+ # Partition by magnitude
+ partitionByMagnitude <- function(key) { if (key >= 3) 1 else 0 }
+
+ resultRDD <- partitionBy(numPairsRdd, 2L, partitionByMagnitude)
+
+ expected_first <- list(list(1, 100), list(2, 200)) # key < 3
+ expected_second <- list(list(4, -1), list(3, 1), list(3, 0)) # key >= 3
+ actual_first <- collectPartition(resultRDD, 0L)
+ actual_second <- collectPartition(resultRDD, 1L)
+
+ expect_equal(sortKeyValueList(actual_first), sortKeyValueList(expected_first))
+ expect_equal(sortKeyValueList(actual_second), sortKeyValueList(expected_second))
+})
+
+test_that("partitionBy works with dependencies", {
+ kOne <- 1
+ partitionByParity <- function(key) { if (key %% 2 == kOne) 7 else 4 }
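+ # The partition function closes over kOne, so this also checks that closure
+ # dependencies are shipped with the partitioner; with numPartitions = 2 the
+ # return values 7 and 4 resolve to partitions 1 and 0 respectively.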
+
+ # Partition by parity
+ resultRDD <- partitionBy(numPairsRdd, numPartitions = 2L, partitionByParity)
+
+ # even keys (2, 4) go to the first partition
+ expected_first <- list(list(2, 200), list(4, -1))
+ # odd keys (1, 3) go to the second partition
+ expected_second <- list(list(1, 100), list(3, 1), list(3, 0))
+ actual_first <- collectPartition(resultRDD, 0L)
+ actual_second <- collectPartition(resultRDD, 1L)
+
+ expect_equal(sortKeyValueList(actual_first), sortKeyValueList(expected_first))
+ expect_equal(sortKeyValueList(actual_second), sortKeyValueList(expected_second))
+})
+
+test_that("test partitionBy with string keys", {
+ words <- flatMap(strListRDD, function(line) { strsplit(line, " ")[[1]] })
+ wordCount <- lapply(words, function(word) { list(word, 1L) })
+
+ resultRDD <- partitionBy(wordCount, 2L)
+ expected_first <- list(list("Dexter", 1), list("Dexter", 1))
+ expected_second <- list(list("and", 1), list("and", 1))
+
+ actual_first <- Filter(function(item) { item[[1]] == "Dexter" },
+ collectPartition(resultRDD, 0L))
+ actual_second <- Filter(function(item) { item[[1]] == "and" },
+ collectPartition(resultRDD, 1L))
+
+ expect_equal(sortKeyValueList(actual_first), sortKeyValueList(expected_first))
+ expect_equal(sortKeyValueList(actual_second), sortKeyValueList(expected_second))
+})
diff --git a/R/pkg/inst/tests/test_sparkSQL.R b/R/pkg/inst/tests/test_sparkSQL.R
new file mode 100644
index 0000000000..cf5cf6d169
--- /dev/null
+++ b/R/pkg/inst/tests/test_sparkSQL.R
@@ -0,0 +1,695 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements. See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+library(testthat)
+
+context("SparkSQL functions")
+
+# Tests for SparkSQL functions in SparkR
+
+sc <- sparkR.init()
+
+sqlCtx <- sparkRSQL.init(sc)
+
+mockLines <- c("{\"name\":\"Michael\"}",
+ "{\"name\":\"Andy\", \"age\":30}",
+ "{\"name\":\"Justin\", \"age\":19}")
+jsonPath <- tempfile(pattern="sparkr-test", fileext=".tmp")
+parquetPath <- tempfile(pattern="sparkr-test", fileext=".parquet")
+writeLines(mockLines, jsonPath)
+
+test_that("infer types", {
+ expect_equal(infer_type(1L), "integer")
+ expect_equal(infer_type(1.0), "double")
+ expect_equal(infer_type("abc"), "string")
+ expect_equal(infer_type(TRUE), "boolean")
+ expect_equal(infer_type(as.Date("2015-03-11")), "date")
+ expect_equal(infer_type(as.POSIXlt("2015-03-11 12:13:04.043")), "timestamp")
+ expect_equal(infer_type(c(1L, 2L)),
+ list(type = 'array', elementType = "integer", containsNull = TRUE))
+ expect_equal(infer_type(list(1L, 2L)),
+ list(type = 'array', elementType = "integer", containsNull = TRUE))
+ expect_equal(infer_type(list(a = 1L, b = "2")),
+ list(type = "struct",
+ fields = list(list(name = "a", type = "integer", nullable = TRUE),
+ list(name = "b", type = "string", nullable = TRUE))))
+ e <- new.env()
+ assign("a", 1L, envir = e)
+ expect_equal(infer_type(e),
+ list(type = "map", keyType = "string", valueType = "integer",
+ valueContainsNull = TRUE))
+})
+
+test_that("create DataFrame from RDD", {
+ rdd <- lapply(parallelize(sc, 1:10), function(x) { list(x, as.character(x)) })
+ df <- createDataFrame(sqlCtx, rdd, list("a", "b"))
+ expect_true(inherits(df, "DataFrame"))
+ expect_true(count(df) == 10)
+ expect_equal(columns(df), c("a", "b"))
+ expect_equal(dtypes(df), list(c("a", "int"), c("b", "string")))
+
+ df <- createDataFrame(sqlCtx, rdd)
+ expect_true(inherits(df, "DataFrame"))
+ expect_equal(columns(df), c("_1", "_2"))
+
+ fields <- list(list(name = "a", type = "integer", nullable = TRUE),
+ list(name = "b", type = "string", nullable = TRUE))
+ schema <- list(type = "struct", fields = fields)
+ df <- createDataFrame(sqlCtx, rdd, schema)
+ expect_true(inherits(df, "DataFrame"))
+ expect_equal(columns(df), c("a", "b"))
+ expect_equal(dtypes(df), list(c("a", "int"), c("b", "string")))
+
+ rdd <- lapply(parallelize(sc, 1:10), function(x) { list(a = x, b = as.character(x)) })
+ df <- createDataFrame(sqlCtx, rdd)
+ expect_true(inherits(df, "DataFrame"))
+ expect_true(count(df) == 10)
+ expect_equal(columns(df), c("a", "b"))
+ expect_equal(dtypes(df), list(c("a", "int"), c("b", "string")))
+})
+
+test_that("toDF", {
+ rdd <- lapply(parallelize(sc, 1:10), function(x) { list(x, as.character(x)) })
+ df <- toDF(rdd, list("a", "b"))
+ expect_true(inherits(df, "DataFrame"))
+ expect_true(count(df) == 10)
+ expect_equal(columns(df), c("a", "b"))
+ expect_equal(dtypes(df), list(c("a", "int"), c("b", "string")))
+
+ df <- toDF(rdd)
+ expect_true(inherits(df, "DataFrame"))
+ expect_equal(columns(df), c("_1", "_2"))
+
+ fields <- list(list(name = "a", type = "integer", nullable = TRUE),
+ list(name = "b", type = "string", nullable = TRUE))
+ schema <- list(type = "struct", fields = fields)
+ df <- toDF(rdd, schema)
+ expect_true(inherits(df, "DataFrame"))
+ expect_equal(columns(df), c("a", "b"))
+ expect_equal(dtypes(df), list(c("a", "int"), c("b", "string")))
+
+ rdd <- lapply(parallelize(sc, 1:10), function(x) { list(a = x, b = as.character(x)) })
+ df <- toDF(rdd)
+ expect_true(inherits(df, "DataFrame"))
+ expect_true(count(df) == 10)
+ expect_equal(columns(df), c("a", "b"))
+ expect_equal(dtypes(df), list(c("a", "int"), c("b", "string")))
+})
+
+test_that("create DataFrame from list or data.frame", {
+ l <- list(list(1, 2), list(3, 4))
+ df <- createDataFrame(sqlCtx, l, c("a", "b"))
+ expect_equal(columns(df), c("a", "b"))
+
+ l <- list(list(a=1, b=2), list(a=3, b=4))
+ df <- createDataFrame(sqlCtx, l)
+ expect_equal(columns(df), c("a", "b"))
+
+ a <- 1:3
+ b <- c("a", "b", "c")
+ ldf <- data.frame(a, b)
+ df <- createDataFrame(sqlCtx, ldf)
+ expect_equal(columns(df), c("a", "b"))
+ expect_equal(dtypes(df), list(c("a", "int"), c("b", "string")))
+ expect_equal(count(df), 3)
+ ldf2 <- collect(df)
+ expect_equal(ldf$a, ldf2$a)
+})
+
+test_that("create DataFrame with different data types", {
+ l <- list(a = 1L, b = 2, c = TRUE, d = "ss", e = as.Date("2012-12-13"),
+ f = as.POSIXct("2015-03-15 12:13:14.056"))
+ df <- createDataFrame(sqlCtx, list(l))
+ expect_equal(dtypes(df), list(c("a", "int"), c("b", "double"), c("c", "boolean"),
+ c("d", "string"), c("e", "date"), c("f", "timestamp")))
+ expect_equal(count(df), 1)
+ expect_equal(collect(df), data.frame(l, stringsAsFactors = FALSE))
+})
+
+# TODO: enable this test after fixing serialization for nested objects
+#test_that("create DataFrame with nested array and struct", {
+# e <- new.env()
+# assign("n", 3L, envir = e)
+# l <- list(1:10, list("a", "b"), e, list(a="aa", b=3L))
+# df <- createDataFrame(sqlCtx, list(l), c("a", "b", "c", "d"))
+# expect_equal(dtypes(df), list(c("a", "array<int>"), c("b", "array<string>"),
+# c("c", "map<string,int>"), c("d", "struct<a:string,b:int>")))
+# expect_equal(count(df), 1)
+# ldf <- collect(df)
+# expect_equal(ldf[1,], l[[1]])
+#})
+
+test_that("jsonFile() on a local file returns a DataFrame", {
+ df <- jsonFile(sqlCtx, jsonPath)
+ expect_true(inherits(df, "DataFrame"))
+ expect_true(count(df) == 3)
+})
+
+test_that("jsonRDD() on a RDD with json string", {
+ rdd <- parallelize(sc, mockLines)
+ expect_true(count(rdd) == 3)
+ df <- jsonRDD(sqlCtx, rdd)
+ expect_true(inherits(df, "DataFrame"))
+ expect_true(count(df) == 3)
+
+ rdd2 <- flatMap(rdd, function(x) c(x, x))
+ df <- jsonRDD(sqlCtx, rdd2)
+ expect_true(inherits(df, "DataFrame"))
+ expect_true(count(df) == 6)
+})
+
+test_that("test cache, uncache and clearCache", {
+ df <- jsonFile(sqlCtx, jsonPath)
+ registerTempTable(df, "table1")
+ cacheTable(sqlCtx, "table1")
+ uncacheTable(sqlCtx, "table1")
+ clearCache(sqlCtx)
+ dropTempTable(sqlCtx, "table1")
+})
+
+test_that("test tableNames and tables", {
+ df <- jsonFile(sqlCtx, jsonPath)
+ registerTempTable(df, "table1")
+ expect_true(length(tableNames(sqlCtx)) == 1)
+ df <- tables(sqlCtx)
+ expect_true(count(df) == 1)
+ dropTempTable(sqlCtx, "table1")
+})
+
+test_that("registerTempTable() results in a queryable table and sql() results in a new DataFrame", {
+ df <- jsonFile(sqlCtx, jsonPath)
+ registerTempTable(df, "table1")
+ newdf <- sql(sqlCtx, "SELECT * FROM table1 where name = 'Michael'")
+ expect_true(inherits(newdf, "DataFrame"))
+ expect_true(count(newdf) == 1)
+ dropTempTable(sqlCtx, "table1")
+})
+
+test_that("insertInto() on a registered table", {
+ df <- loadDF(sqlCtx, jsonPath, "json")
+ saveDF(df, parquetPath, "parquet", "overwrite")
+ dfParquet <- loadDF(sqlCtx, parquetPath, "parquet")
+
+ lines <- c("{\"name\":\"Bob\", \"age\":24}",
+ "{\"name\":\"James\", \"age\":35}")
+ jsonPath2 <- tempfile(pattern="jsonPath2", fileext=".tmp")
+ parquetPath2 <- tempfile(pattern = "parquetPath2", fileext = ".parquet")
+ writeLines(lines, jsonPath2)
+ df2 <- loadDF(sqlCtx, jsonPath2, "json")
+ saveDF(df2, parquetPath2, "parquet", "overwrite")
+ dfParquet2 <- loadDF(sqlCtx, parquetPath2, "parquet")
+
+ registerTempTable(dfParquet, "table1")
+ insertInto(dfParquet2, "table1")
+ expect_true(count(sql(sqlCtx, "select * from table1")) == 5)
+ expect_true(first(sql(sqlCtx, "select * from table1 order by age"))$name == "Michael")
+ dropTempTable(sqlCtx, "table1")
+
+ registerTempTable(dfParquet, "table1")
+ insertInto(dfParquet2, "table1", overwrite = TRUE)
+ expect_true(count(sql(sqlCtx, "select * from table1")) == 2)
+ expect_true(first(sql(sqlCtx, "select * from table1 order by age"))$name == "Bob")
+ dropTempTable(sqlCtx, "table1")
+})
+
+test_that("table() returns a new DataFrame", {
+ df <- jsonFile(sqlCtx, jsonPath)
+ registerTempTable(df, "table1")
+ tabledf <- table(sqlCtx, "table1")
+ expect_true(inherits(tabledf, "DataFrame"))
+ expect_true(count(tabledf) == 3)
+ dropTempTable(sqlCtx, "table1")
+})
+
+test_that("toRDD() returns an RRDD", {
+ df <- jsonFile(sqlCtx, jsonPath)
+ testRDD <- toRDD(df)
+ expect_true(inherits(testRDD, "RDD"))
+ expect_true(count(testRDD) == 3)
+})
+
+test_that("union on two RDDs created from DataFrames returns an RRDD", {
+ df <- jsonFile(sqlCtx, jsonPath)
+ RDD1 <- toRDD(df)
+ RDD2 <- toRDD(df)
+ unioned <- unionRDD(RDD1, RDD2)
+ expect_true(inherits(unioned, "RDD"))
+ expect_true(SparkR:::getSerializedMode(unioned) == "byte")
+ expect_true(collect(unioned)[[2]]$name == "Andy")
+})
+
+test_that("union on mixed serialization types correctly returns a byte RRDD", {
+ # Byte RDD
+ nums <- 1:10
+ rdd <- parallelize(sc, nums, 2L)
+
+ # String RDD
+ textLines <- c("Michael",
+ "Andy, 30",
+ "Justin, 19")
+ textPath <- tempfile(pattern="sparkr-textLines", fileext=".tmp")
+ writeLines(textLines, textPath)
+ textRDD <- textFile(sc, textPath)
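+ # textFile() RDDs use string serialization while parallelize() and DataFrame
+ # RDDs use byte serialization; the expectations below check that a mixed
+ # union is promoted to byte serialization.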
+
+ df <- jsonFile(sqlCtx, jsonPath)
+ dfRDD <- toRDD(df)
+
+ unionByte <- unionRDD(rdd, dfRDD)
+ expect_true(inherits(unionByte, "RDD"))
+ expect_true(SparkR:::getSerializedMode(unionByte) == "byte")
+ expect_true(collect(unionByte)[[1]] == 1)
+ expect_true(collect(unionByte)[[12]]$name == "Andy")
+
+ unionString <- unionRDD(textRDD, dfRDD)
+ expect_true(inherits(unionString, "RDD"))
+ expect_true(SparkR:::getSerializedMode(unionString) == "byte")
+ expect_true(collect(unionString)[[1]] == "Michael")
+ expect_true(collect(unionString)[[5]]$name == "Andy")
+})
+
+test_that("objectFile() works with row serialization", {
+ objectPath <- tempfile(pattern="spark-test", fileext=".tmp")
+ df <- jsonFile(sqlCtx, jsonPath)
+ dfRDD <- toRDD(df)
+ saveAsObjectFile(coalesce(dfRDD, 1L), objectPath)
+ objectIn <- objectFile(sc, objectPath)
+
+ expect_true(inherits(objectIn, "RDD"))
+ expect_equal(SparkR:::getSerializedMode(objectIn), "byte")
+ expect_equal(collect(objectIn)[[2]]$age, 30)
+})
+
+test_that("lapply() on a DataFrame returns an RDD with the correct columns", {
+ df <- jsonFile(sqlCtx, jsonPath)
+ testRDD <- lapply(df, function(row) {
+ row$newCol <- row$age + 5
+ row
+ })
+ expect_true(inherits(testRDD, "RDD"))
+ collected <- collect(testRDD)
+ expect_true(collected[[1]]$name == "Michael")
+ expect_true(collected[[2]]$newCol == "35")
+})
+
+test_that("collect() returns a data.frame", {
+ df <- jsonFile(sqlCtx, jsonPath)
+ rdf <- collect(df)
+ expect_true(is.data.frame(rdf))
+ expect_true(names(rdf)[1] == "age")
+ expect_true(nrow(rdf) == 3)
+ expect_true(ncol(rdf) == 2)
+})
+
+test_that("limit() returns DataFrame with the correct number of rows", {
+ df <- jsonFile(sqlCtx, jsonPath)
+ dfLimited <- limit(df, 2)
+ expect_true(inherits(dfLimited, "DataFrame"))
+ expect_true(count(dfLimited) == 2)
+})
+
+test_that("collect() and take() on a DataFrame return the same number of rows and columns", {
+ df <- jsonFile(sqlCtx, jsonPath)
+ expect_true(nrow(collect(df)) == nrow(take(df, 10)))
+ expect_true(ncol(collect(df)) == ncol(take(df, 10)))
+})
+
+test_that("multiple pipeline transformations starting with a DataFrame result in an RDD with the correct values", {
+ df <- jsonFile(sqlCtx, jsonPath)
+ first <- lapply(df, function(row) {
+ row$age <- row$age + 5
+ row
+ })
+ second <- lapply(first, function(row) {
+ row$testCol <- !is.na(row$age) && row$age == 35
+ row
+ })
+ expect_true(inherits(second, "RDD"))
+ expect_true(count(second) == 3)
+ expect_true(collect(second)[[2]]$age == 35)
+ expect_true(collect(second)[[2]]$testCol)
+ expect_false(collect(second)[[3]]$testCol)
+})
+
+test_that("cache(), persist(), and unpersist() on a DataFrame", {
+ df <- jsonFile(sqlCtx, jsonPath)
+ expect_false(df@env$isCached)
+ cache(df)
+ expect_true(df@env$isCached)
+
+ unpersist(df)
+ expect_false(df@env$isCached)
+
+ persist(df, "MEMORY_AND_DISK")
+ expect_true(df@env$isCached)
+
+ unpersist(df)
+ expect_false(df@env$isCached)
+
+ # make sure the data is collectable
+ expect_true(is.data.frame(collect(df)))
+})
+
+test_that("schema(), dtypes(), columns(), names() return the correct values/format", {
+ df <- jsonFile(sqlCtx, jsonPath)
+ testSchema <- schema(df)
+ expect_true(length(testSchema$fields()) == 2)
+ expect_true(testSchema$fields()[[1]]$dataType.toString() == "LongType")
+ expect_true(testSchema$fields()[[2]]$dataType.simpleString() == "string")
+ expect_true(testSchema$fields()[[1]]$name() == "age")
+
+ testTypes <- dtypes(df)
+ expect_true(length(testTypes[[1]]) == 2)
+ expect_true(testTypes[[1]][1] == "age")
+
+ testCols <- columns(df)
+ expect_true(length(testCols) == 2)
+ expect_true(testCols[2] == "name")
+
+ testNames <- names(df)
+ expect_true(length(testNames) == 2)
+ expect_true(testNames[2] == "name")
+})
+
+test_that("head() and first() return the correct data", {
+ df <- jsonFile(sqlCtx, jsonPath)
+ testHead <- head(df)
+ expect_true(nrow(testHead) == 3)
+ expect_true(ncol(testHead) == 2)
+
+ testHead2 <- head(df, 2)
+ expect_true(nrow(testHead2) == 2)
+ expect_true(ncol(testHead2) == 2)
+
+ testFirst <- first(df)
+ expect_true(nrow(testFirst) == 1)
+})
+
+test_that("distinct() on DataFrames", {
+ lines <- c("{\"name\":\"Michael\"}",
+ "{\"name\":\"Andy\", \"age\":30}",
+ "{\"name\":\"Justin\", \"age\":19}",
+ "{\"name\":\"Justin\", \"age\":19}")
+ jsonPathWithDup <- tempfile(pattern="sparkr-test", fileext=".tmp")
+ writeLines(lines, jsonPathWithDup)
+
+ df <- jsonFile(sqlCtx, jsonPathWithDup)
+ uniques <- distinct(df)
+ expect_true(inherits(uniques, "DataFrame"))
+ expect_true(count(uniques) == 3)
+})
+
+test_that("sampleDF on a DataFrame", {
+ df <- jsonFile(sqlCtx, jsonPath)
+ sampled <- sampleDF(df, FALSE, 1.0)
+ expect_equal(nrow(collect(sampled)), count(df))
+ expect_true(inherits(sampled, "DataFrame"))
+ sampled2 <- sampleDF(df, FALSE, 0.1)
+ expect_true(count(sampled2) < 3)
+})
+
+test_that("select operators", {
+ df <- select(jsonFile(sqlCtx, jsonPath), "name", "age")
+ expect_true(inherits(df$name, "Column"))
+ expect_true(inherits(df[[2]], "Column"))
+ expect_true(inherits(df[["age"]], "Column"))
+
+ expect_true(inherits(df[,1], "DataFrame"))
+ expect_equal(columns(df[,1]), c("name"))
+ expect_equal(columns(df[,"age"]), c("age"))
+ df2 <- df[,c("age", "name")]
+ expect_true(inherits(df2, "DataFrame"))
+ expect_equal(columns(df2), c("age", "name"))
+
+ df$age2 <- df$age
+ expect_equal(columns(df), c("name", "age", "age2"))
+ expect_equal(count(where(df, df$age2 == df$age)), 2)
+ df$age2 <- df$age * 2
+ expect_equal(columns(df), c("name", "age", "age2"))
+ expect_equal(count(where(df, df$age2 == df$age * 2)), 2)
+})
+
+test_that("select with column", {
+ df <- jsonFile(sqlCtx, jsonPath)
+ df1 <- select(df, "name")
+ expect_true(columns(df1) == c("name"))
+ expect_true(count(df1) == 3)
+
+ df2 <- select(df, df$age)
+ expect_true(columns(df2) == c("age"))
+ expect_true(count(df2) == 3)
+})
+
+test_that("selectExpr() on a DataFrame", {
+ df <- jsonFile(sqlCtx, jsonPath)
+ selected <- selectExpr(df, "age * 2")
+ expect_true(names(selected) == "(age * 2)")
+ expect_equal(collect(selected), collect(select(df, df$age * 2L)))
+
+ selected2 <- selectExpr(df, "name as newName", "abs(age) as age")
+ expect_equal(names(selected2), c("newName", "age"))
+ expect_true(count(selected2) == 3)
+})
+
+test_that("column calculation", {
+ df <- jsonFile(sqlCtx, jsonPath)
+ d <- collect(select(df, alias(df$age + 1, "age2")))
+ expect_true(names(d) == c("age2"))
+ df2 <- select(df, lower(df$name), abs(df$age))
+ expect_true(inherits(df2, "DataFrame"))
+ expect_true(count(df2) == 3)
+})
+
+test_that("load() from json file", {
+ df <- loadDF(sqlCtx, jsonPath, "json")
+ expect_true(inherits(df, "DataFrame"))
+ expect_true(count(df) == 3)
+})
+
+test_that("save() as parquet file", {
+ df <- loadDF(sqlCtx, jsonPath, "json")
+ saveDF(df, parquetPath, "parquet", mode="overwrite")
+ df2 <- loadDF(sqlCtx, parquetPath, "parquet")
+ expect_true(inherits(df2, "DataFrame"))
+ expect_true(count(df2) == 3)
+})
+
+test_that("test HiveContext", {
+ hiveCtx <- tryCatch({
+ newJObject("org.apache.spark.sql.hive.test.TestHiveContext", ssc)
+ }, error = function(err) {
+ skip("Hive is not build with SparkSQL, skipped")
+ })
+ df <- createExternalTable(hiveCtx, "json", jsonPath, "json")
+ expect_true(inherits(df, "DataFrame"))
+ expect_true(count(df) == 3)
+ df2 <- sql(hiveCtx, "select * from json")
+ expect_true(inherits(df2, "DataFrame"))
+ expect_true(count(df2) == 3)
+
+ jsonPath2 <- tempfile(pattern="sparkr-test", fileext=".tmp")
+ saveAsTable(df, "json", "json", "append", path = jsonPath2)
+ df3 <- sql(hiveCtx, "select * from json")
+ expect_true(inherits(df3, "DataFrame"))
+ expect_true(count(df3) == 6)
+})
+
+test_that("column operators", {
+ c <- SparkR:::col("a")
+ c2 <- (- c + 1 - 2) * 3 / 4.0
+ c3 <- (c + c2 - c2) * c2 %% c2
+ c4 <- (c > c2) & (c2 <= c3) | (c == c2) & (c2 != c3)
+})
+
+test_that("column functions", {
+ c <- SparkR:::col("a")
+ c2 <- min(c) + max(c) + sum(c) + avg(c) + count(c) + abs(c) + sqrt(c)
+ c3 <- lower(c) + upper(c) + first(c) + last(c)
+ c4 <- approxCountDistinct(c) + countDistinct(c) + cast(c, "string")
+})
+
+test_that("string operators", {
+ df <- jsonFile(sqlCtx, jsonPath)
+ expect_equal(count(where(df, like(df$name, "A%"))), 1)
+ expect_equal(count(where(df, startsWith(df$name, "A"))), 1)
+ expect_equal(first(select(df, substr(df$name, 1, 2)))[[1]], "Mi")
+ expect_equal(collect(select(df, cast(df$age, "string")))[[2, 1]], "30")
+})
+
+test_that("group by", {
+ df <- jsonFile(sqlCtx, jsonPath)
+ df1 <- agg(df, name = "max", age = "sum")
+ expect_true(1 == count(df1))
+ df1 <- agg(df, age2 = max(df$age))
+ expect_true(1 == count(df1))
+ expect_equal(columns(df1), c("age2"))
+
+ gd <- groupBy(df, "name")
+ expect_true(inherits(gd, "GroupedData"))
+ df2 <- count(gd)
+ expect_true(inherits(df2, "DataFrame"))
+ expect_true(3 == count(df2))
+
+ df3 <- agg(gd, age = "sum")
+ expect_true(inherits(df3, "DataFrame"))
+ expect_true(3 == count(df3))
+
+ df3 <- agg(gd, age = sum(df$age))
+ expect_true(inherits(df3, "DataFrame"))
+ expect_true(3 == count(df3))
+ expect_equal(columns(df3), c("name", "age"))
+
+ df4 <- sum(gd, "age")
+ expect_true(inherits(df4, "DataFrame"))
+ expect_true(3 == count(df4))
+ expect_true(3 == count(mean(gd, "age")))
+ expect_true(3 == count(max(gd, "age")))
+})
+
+test_that("sortDF() and orderBy() on a DataFrame", {
+ df <- jsonFile(sqlCtx, jsonPath)
+ sorted <- sortDF(df, df$age)
+ expect_true(collect(sorted)[1,2] == "Michael")
+
+ sorted2 <- sortDF(df, "name")
+ expect_true(collect(sorted2)[2,"age"] == 19)
+
+ sorted3 <- orderBy(df, asc(df$age))
+ expect_true(is.na(first(sorted3)$age))
+ expect_true(collect(sorted3)[2, "age"] == 19)
+
+ sorted4 <- orderBy(df, desc(df$name))
+ expect_true(first(sorted4)$name == "Michael")
+ expect_true(collect(sorted4)[3,"name"] == "Andy")
+})
+
+test_that("filter() on a DataFrame", {
+ df <- jsonFile(sqlCtx, jsonPath)
+ filtered <- filter(df, "age > 20")
+ expect_true(count(filtered) == 1)
+ expect_true(collect(filtered)$name == "Andy")
+ filtered2 <- where(df, df$name != "Michael")
+ expect_true(count(filtered2) == 2)
+ expect_true(collect(filtered2)$age[2] == 19)
+})
+
+test_that("join() on a DataFrame", {
+ df <- jsonFile(sqlCtx, jsonPath)
+
+ mockLines2 <- c("{\"name\":\"Michael\", \"test\": \"yes\"}",
+ "{\"name\":\"Andy\", \"test\": \"no\"}",
+ "{\"name\":\"Justin\", \"test\": \"yes\"}",
+ "{\"name\":\"Bob\", \"test\": \"yes\"}")
+ jsonPath2 <- tempfile(pattern="sparkr-test", fileext=".tmp")
+ writeLines(mockLines2, jsonPath2)
+ df2 <- jsonFile(sqlCtx, jsonPath2)
+
+ joined <- join(df, df2)
+ expect_equal(names(joined), c("age", "name", "name", "test"))
+ expect_true(count(joined) == 12)
+
+ joined2 <- join(df, df2, df$name == df2$name)
+ expect_equal(names(joined2), c("age", "name", "name", "test"))
+ expect_true(count(joined2) == 3)
+
+ joined3 <- join(df, df2, df$name == df2$name, "right_outer")
+ expect_equal(names(joined3), c("age", "name", "name", "test"))
+ expect_true(count(joined3) == 4)
+ expect_true(is.na(collect(orderBy(joined3, joined3$age))$age[2]))
+
+ joined4 <- select(join(df, df2, df$name == df2$name, "outer"),
+ alias(df$age + 5, "newAge"), df$name, df2$test)
+ expect_equal(names(joined4), c("newAge", "name", "test"))
+ expect_true(count(joined4) == 4)
+ expect_equal(collect(orderBy(joined4, joined4$name))$newAge[3], 24)
+})
+
+test_that("toJSON() returns an RDD of the correct values", {
+ df <- jsonFile(sqlCtx, jsonPath)
+ testRDD <- toJSON(df)
+ expect_true(inherits(testRDD, "RDD"))
+ expect_true(SparkR:::getSerializedMode(testRDD) == "string")
+ expect_equal(collect(testRDD)[[1]], mockLines[1])
+})
+
+test_that("showDF()", {
+ df <- jsonFile(sqlCtx, jsonPath)
+ expect_output(showDF(df), "age name \nnull Michael\n30 Andy \n19 Justin ")
+})
+
+test_that("isLocal()", {
+ df <- jsonFile(sqlCtx, jsonPath)
+ expect_false(isLocal(df))
+})
+
+test_that("unionAll(), subtract(), and intersect() on a DataFrame", {
+ df <- jsonFile(sqlCtx, jsonPath)
+
+ lines <- c("{\"name\":\"Bob\", \"age\":24}",
+ "{\"name\":\"Andy\", \"age\":30}",
+ "{\"name\":\"James\", \"age\":35}")
+ jsonPath2 <- tempfile(pattern="sparkr-test", fileext=".tmp")
+ writeLines(lines, jsonPath2)
+ df2 <- loadDF(sqlCtx, jsonPath2, "json")
+
+ unioned <- sortDF(unionAll(df, df2), df$age)
+ expect_true(inherits(unioned, "DataFrame"))
+ expect_true(count(unioned) == 6)
+ expect_true(first(unioned)$name == "Michael")
+
+ subtracted <- sortDF(subtract(df, df2), desc(df$age))
+ expect_true(inherits(unioned, "DataFrame"))
+ expect_true(count(subtracted) == 2)
+ expect_true(first(subtracted)$name == "Justin")
+
+ intersected <- sortDF(intersect(df, df2), df$age)
+ expect_true(inherits(unioned, "DataFrame"))
+ expect_true(count(intersected) == 1)
+ expect_true(first(intersected)$name == "Andy")
+})
+
+test_that("withColumn() and withColumnRenamed()", {
+ df <- jsonFile(sqlCtx, jsonPath)
+ newDF <- withColumn(df, "newAge", df$age + 2)
+ expect_true(length(columns(newDF)) == 3)
+ expect_true(columns(newDF)[3] == "newAge")
+ expect_true(first(filter(newDF, df$name != "Michael"))$newAge == 32)
+
+ newDF2 <- withColumnRenamed(df, "age", "newerAge")
+ expect_true(length(columns(newDF2)) == 2)
+ expect_true(columns(newDF2)[1] == "newerAge")
+})
+
+test_that("saveDF() on DataFrame and works with parquetFile", {
+ df <- jsonFile(sqlCtx, jsonPath)
+ saveDF(df, parquetPath, "parquet", mode="overwrite")
+ parquetDF <- parquetFile(sqlCtx, parquetPath)
+ expect_true(inherits(parquetDF, "DataFrame"))
+ expect_equal(count(df), count(parquetDF))
+})
+
+test_that("parquetFile works with multiple input paths", {
+ df <- jsonFile(sqlCtx, jsonPath)
+ saveDF(df, parquetPath, "parquet", mode="overwrite")
+ parquetPath2 <- tempfile(pattern = "parquetPath2", fileext = ".parquet")
+ saveDF(df, parquetPath2, "parquet", mode="overwrite")
+ parquetDF <- parquetFile(sqlCtx, parquetPath, parquetPath2)
+ expect_true(inherits(parquetDF, "DataFrame"))
+ expect_true(count(parquetDF) == count(df)*2)
+})
+
+unlink(parquetPath)
+unlink(jsonPath)
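Note: the DataFrame tests above rely on ssc, sqlCtx, mockLines, jsonPath and parquetPath, which are created near the top of test_sparkSQL.R and fall outside this hunk. The following is a minimal sketch of that setup, assuming sparkRSQL.init from SQLContext.R creates the SQLContext and with the JSON contents inferred from the expectations in the tests; treat it as an illustration, not the exact code in the file.

    # Assumed setup mirroring how ssc, sqlCtx, mockLines, jsonPath and parquetPath
    # are used above; the real definitions live earlier in test_sparkSQL.R.
    ssc <- sparkR.init()
    sqlCtx <- sparkRSQL.init(ssc)
    mockLines <- c("{\"name\":\"Michael\"}",
                   "{\"name\":\"Andy\", \"age\":30}",
                   "{\"name\":\"Justin\", \"age\":19}")
    jsonPath <- tempfile(pattern = "sparkr-test", fileext = ".tmp")
    writeLines(mockLines, jsonPath)
    parquetPath <- tempfile(pattern = "sparkr-test", fileext = ".parquet")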
diff --git a/R/pkg/inst/tests/test_take.R b/R/pkg/inst/tests/test_take.R
new file mode 100644
index 0000000000..7f4c7c315d
--- /dev/null
+++ b/R/pkg/inst/tests/test_take.R
@@ -0,0 +1,67 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements. See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+context("tests RDD function take()")
+
+# Mock data
+numVector <- c(-10:97)
+numList <- list(sqrt(1), sqrt(2), sqrt(3), 4 ** 10)
+strVector <- c("Dexter Morgan: I suppose I should be upset, even feel",
+ "violated, but I'm not. No, in fact, I think this is a friendly",
+ "message, like \"Hey, wanna play?\" and yes, I want to play. ",
+ "I really, really do.")
+strList <- list("Dexter Morgan: Blood. Sometimes it sets my teeth on edge, ",
+ "other times it helps me control the chaos.",
+ "Dexter Morgan: Harry and Dorris Morgan did a wonderful job ",
+ "raising me. But they're both dead now. I didn't kill them. Honest.")
+
+# JavaSparkContext handle
+jsc <- sparkR.init()
+
+test_that("take() gives back the original elements in correct count and order", {
+ numVectorRDD <- parallelize(jsc, numVector, 10)
+ # case: number of elements to take is less than the size of the first partition
+ expect_equal(take(numVectorRDD, 1), as.list(head(numVector, n = 1)))
+ # case: number of elements to take is the same as the size of the first partition
+ expect_equal(take(numVectorRDD, 11), as.list(head(numVector, n = 11)))
+ # case: number of elements to take is greater than all elements
+ expect_equal(take(numVectorRDD, length(numVector)), as.list(numVector))
+ expect_equal(take(numVectorRDD, length(numVector) + 1), as.list(numVector))
+
+ numListRDD <- parallelize(jsc, numList, 1)
+ numListRDD2 <- parallelize(jsc, numList, 4)
+ expect_equal(take(numListRDD, 3), take(numListRDD2, 3))
+ expect_equal(take(numListRDD, 5), take(numListRDD2, 5))
+ expect_equal(take(numListRDD, 1), as.list(head(numList, n = 1)))
+ expect_equal(take(numListRDD2, 999), numList)
+
+ strVectorRDD <- parallelize(jsc, strVector, 2)
+ strVectorRDD2 <- parallelize(jsc, strVector, 3)
+ expect_equal(take(strVectorRDD, 4), as.list(strVector))
+ expect_equal(take(strVectorRDD2, 2), as.list(head(strVector, n = 2)))
+
+ strListRDD <- parallelize(jsc, strList, 4)
+ strListRDD2 <- parallelize(jsc, strList, 1)
+ expect_equal(take(strListRDD, 3), as.list(head(strList, n = 3)))
+ expect_equal(take(strListRDD2, 1), as.list(head(strList, n = 1)))
+
+ expect_true(length(take(strListRDD, 0)) == 0)
+ expect_true(length(take(strVectorRDD, 0)) == 0)
+ expect_true(length(take(numListRDD, 0)) == 0)
+ expect_true(length(take(numVectorRDD, 0)) == 0)
+})
+
diff --git a/R/pkg/inst/tests/test_textFile.R b/R/pkg/inst/tests/test_textFile.R
new file mode 100644
index 0000000000..7bb3e80031
--- /dev/null
+++ b/R/pkg/inst/tests/test_textFile.R
@@ -0,0 +1,162 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements. See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+context("the textFile() function")
+
+# JavaSparkContext handle
+sc <- sparkR.init()
+
+mockFile <- c("Spark is pretty.", "Spark is awesome.")
+
+test_that("textFile() on a local file returns an RDD", {
+ fileName <- tempfile(pattern="spark-test", fileext=".tmp")
+ writeLines(mockFile, fileName)
+
+ rdd <- textFile(sc, fileName)
+ expect_true(inherits(rdd, "RDD"))
+ expect_true(count(rdd) > 0)
+ expect_true(count(rdd) == 2)
+
+ unlink(fileName)
+})
+
+test_that("textFile() followed by a collect() returns the same content", {
+ fileName <- tempfile(pattern="spark-test", fileext=".tmp")
+ writeLines(mockFile, fileName)
+
+ rdd <- textFile(sc, fileName)
+ expect_equal(collect(rdd), as.list(mockFile))
+
+ unlink(fileName)
+})
+
+test_that("textFile() word count works as expected", {
+ fileName <- tempfile(pattern="spark-test", fileext=".tmp")
+ writeLines(mockFile, fileName)
+
+ rdd <- textFile(sc, fileName)
+
+ words <- flatMap(rdd, function(line) { strsplit(line, " ")[[1]] })
+ wordCount <- lapply(words, function(word) { list(word, 1L) })
+
+ counts <- reduceByKey(wordCount, "+", 2L)
+ output <- collect(counts)
+ expected <- list(list("pretty.", 1), list("is", 2), list("awesome.", 1),
+ list("Spark", 2))
+ expect_equal(sortKeyValueList(output), sortKeyValueList(expected))
+
+ unlink(fileName)
+})
+
+test_that("several transformations on RDD created by textFile()", {
+ fileName <- tempfile(pattern="spark-test", fileext=".tmp")
+ writeLines(mockFile, fileName)
+
+ rdd <- textFile(sc, fileName) # RDD
+ for (i in 1:10) {
+ # PipelinedRDD initially created from RDD
+ rdd <- lapply(rdd, function(x) paste(x, x))
+ }
+ collect(rdd)
+
+ unlink(fileName)
+})
+
+test_that("textFile() followed by a saveAsTextFile() returns the same content", {
+ fileName1 <- tempfile(pattern="spark-test", fileext=".tmp")
+ fileName2 <- tempfile(pattern="spark-test", fileext=".tmp")
+ writeLines(mockFile, fileName1)
+
+ rdd <- textFile(sc, fileName1)
+ saveAsTextFile(rdd, fileName2)
+ rdd <- textFile(sc, fileName2)
+ expect_equal(collect(rdd), as.list(mockFile))
+
+ unlink(fileName1)
+ unlink(fileName2)
+})
+
+test_that("saveAsTextFile() on a parallelized list works as expected", {
+ fileName <- tempfile(pattern="spark-test", fileext=".tmp")
+ l <- list(1, 2, 3)
+ rdd <- parallelize(sc, l)
+ saveAsTextFile(rdd, fileName)
+ rdd <- textFile(sc, fileName)
+ expect_equal(collect(rdd), lapply(l, function(x) {toString(x)}))
+
+ unlink(fileName)
+})
+
+test_that("textFile() and saveAsTextFile() word count works as expected", {
+ fileName1 <- tempfile(pattern="spark-test", fileext=".tmp")
+ fileName2 <- tempfile(pattern="spark-test", fileext=".tmp")
+ writeLines(mockFile, fileName1)
+
+ rdd <- textFile(sc, fileName1)
+
+ words <- flatMap(rdd, function(line) { strsplit(line, " ")[[1]] })
+ wordCount <- lapply(words, function(word) { list(word, 1L) })
+
+ counts <- reduceByKey(wordCount, "+", 2L)
+
+ saveAsTextFile(counts, fileName2)
+ rdd <- textFile(sc, fileName2)
+
+ output <- collect(rdd)
+ expected <- list(list("awesome.", 1), list("Spark", 2),
+ list("pretty.", 1), list("is", 2))
+ expectedStr <- lapply(expected, function(x) { toString(x) })
+ expect_equal(sortKeyValueList(output), sortKeyValueList(expectedStr))
+
+ unlink(fileName1)
+ unlink(fileName2)
+})
+
+test_that("textFile() on multiple paths", {
+ fileName1 <- tempfile(pattern="spark-test", fileext=".tmp")
+ fileName2 <- tempfile(pattern="spark-test", fileext=".tmp")
+ writeLines("Spark is pretty.", fileName1)
+ writeLines("Spark is awesome.", fileName2)
+
+ rdd <- textFile(sc, c(fileName1, fileName2))
+ expect_true(count(rdd) == 2)
+
+ unlink(fileName1)
+ unlink(fileName2)
+})
+
+test_that("Pipelined operations on RDDs created using textFile", {
+ fileName <- tempfile(pattern="spark-test", fileext=".tmp")
+ writeLines(mockFile, fileName)
+
+ rdd <- textFile(sc, fileName)
+
+ lengths <- lapply(rdd, function(x) { length(x) })
+ expect_equal(collect(lengths), list(1, 1))
+
+ lengthsPipelined <- lapply(lengths, function(x) { x + 10 })
+ expect_equal(collect(lengthsPipelined), list(11, 11))
+
+ lengths30 <- lapply(lengthsPipelined, function(x) { x + 20 })
+ expect_equal(collect(lengths30), list(31, 31))
+
+ lengths20 <- lapply(lengths, function(x) { x + 20 })
+ expect_equal(collect(lengths20), list(21, 21))
+
+ unlink(fileName)
+})
+
diff --git a/R/pkg/inst/tests/test_utils.R b/R/pkg/inst/tests/test_utils.R
new file mode 100644
index 0000000000..9c5bb42793
--- /dev/null
+++ b/R/pkg/inst/tests/test_utils.R
@@ -0,0 +1,137 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements. See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+context("functions in utils.R")
+
+# JavaSparkContext handle
+sc <- sparkR.init()
+
+test_that("convertJListToRList() gives back (deserializes) the original JLists
+ of strings and integers", {
+ # It's hard to manually create a Java List using rJava, since it does not
+ # support generics well. Instead, we rely on collect() returning a
+ # JList.
+ nums <- as.list(1:10)
+ rdd <- parallelize(sc, nums, 1L)
+ jList <- callJMethod(rdd@jrdd, "collect")
+ rList <- convertJListToRList(jList, flatten = TRUE)
+ expect_equal(rList, nums)
+
+ strs <- as.list(c("hello", "spark"))
+ rdd <- parallelize(sc, strs, 2L)
+ jList <- callJMethod(rdd@jrdd, "collect")
+ rList <- convertJListToRList(jList, flatten = TRUE)
+ expect_equal(rList, strs)
+})
+
+test_that("serializeToBytes on RDD", {
+ # File content
+ mockFile <- c("Spark is pretty.", "Spark is awesome.")
+ fileName <- tempfile(pattern="spark-test", fileext=".tmp")
+ writeLines(mockFile, fileName)
+
+ text.rdd <- textFile(sc, fileName)
+ expect_true(getSerializedMode(text.rdd) == "string")
+ ser.rdd <- serializeToBytes(text.rdd)
+ expect_equal(collect(ser.rdd), as.list(mockFile))
+ expect_true(getSerializedMode(ser.rdd) == "byte")
+
+ unlink(fileName)
+})
+
+test_that("cleanClosure on R functions", {
+ y <- c(1, 2, 3)
+ g <- function(x) { x + 1 }
+ f <- function(x) { g(x) + y }
+ newF <- cleanClosure(f)
+ env <- environment(newF)
+ expect_equal(length(ls(env)), 2) # y, g
+ actual <- get("y", envir = env, inherits = FALSE)
+ expect_equal(actual, y)
+ actual <- get("g", envir = env, inherits = FALSE)
+ expect_equal(actual, g)
+
+ # Test for nested enclosures and package variables.
+ env2 <- new.env()
+ funcEnv <- new.env(parent = env2)
+ f <- function(x) { log(g(x) + y) }
+ environment(f) <- funcEnv # enclosing relationship: f -> funcEnv -> env2 -> .GlobalEnv
+ newF <- cleanClosure(f)
+ env <- environment(newF)
+ expect_equal(length(ls(env)), 2) # "min" should not be included
+ actual <- get("y", envir = env, inherits = FALSE)
+ expect_equal(actual, y)
+ actual <- get("g", envir = env, inherits = FALSE)
+ expect_equal(actual, g)
+
+ base <- c(1, 2, 3)
+ l <- list(field = matrix(1))
+ field <- matrix(2)
+ defUse <- 3
+ g <- function(x) { x + y }
+ f <- function(x) {
+ defUse <- base::as.integer(x) + 1 # Test for access operators `::`.
+ lapply(x, g) + 1 # Test for capturing function call "g"'s closure as an argument of lapply.
+ l$field[1,1] <- 3 # Test for access operators `$`.
+ res <- defUse + l$field[1,] # Test for the def-use chain of "defUse", and the empty ("") symbol from the missing index.
+ f(res) # Test for recursive calls.
+ }
+ newF <- cleanClosure(f)
+ env <- environment(newF)
+ expect_equal(length(ls(env)), 3) # Only "g", "l" and "f". No "base", "field" or "defUse".
+ expect_true("g" %in% ls(env))
+ expect_true("l" %in% ls(env))
+ expect_true("f" %in% ls(env))
+ expect_equal(get("l", envir = env, inherits = FALSE), l)
+ # "y" should be in the environemnt of g.
+ newG <- get("g", envir = env, inherits = FALSE)
+ env <- environment(newG)
+ expect_equal(length(ls(env)), 1)
+ actual <- get("y", envir = env, inherits = FALSE)
+ expect_equal(actual, y)
+
+ # Test for function (and variable) definitions.
+ f <- function(x) {
+ g <- function(y) { y * 2 }
+ g(x)
+ }
+ newF <- cleanClosure(f)
+ env <- environment(newF)
+ expect_equal(length(ls(env)), 0) # "y" and "g" should not be included.
+
+ # Test for overriding variables in base namespace (Issue: SparkR-196).
+ nums <- as.list(1:10)
+ rdd <- parallelize(sc, nums, 2L)
+ t = 4 # Override base::t in .GlobalEnv.
+ f <- function(x) { x > t }
+ newF <- cleanClosure(f)
+ env <- environment(newF)
+ expect_equal(ls(env), "t")
+ expect_equal(get("t", envir = env, inherits = FALSE), t)
+ actual <- collect(lapply(rdd, f))
+ expected <- as.list(c(rep(FALSE, 4), rep(TRUE, 6)))
+ expect_equal(actual, expected)
+
+ # Test for broadcast variables.
+ a <- matrix(nrow=10, ncol=10, data=rnorm(100))
+ aBroadcast <- broadcast(sc, a)
+ normMultiply <- function(x) { norm(aBroadcast$value) * x }
+ newnormMultiply <- SparkR:::cleanClosure(normMultiply)
+ env <- environment(newnormMultiply)
+ expect_equal(ls(env), "aBroadcast")
+ expect_equal(get("aBroadcast", envir = env, inherits = FALSE), aBroadcast)
+})
diff --git a/R/pkg/inst/worker/daemon.R b/R/pkg/inst/worker/daemon.R
new file mode 100644
index 0000000000..3584b418a7
--- /dev/null
+++ b/R/pkg/inst/worker/daemon.R
@@ -0,0 +1,52 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements. See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+# Worker daemon
+
+rLibDir <- Sys.getenv("SPARKR_RLIBDIR")
+script <- paste(rLibDir, "SparkR/worker/worker.R", sep = "/")
+
+# Preload the SparkR package to speed up the worker
+.libPaths(c(rLibDir, .libPaths()))
+suppressPackageStartupMessages(library(SparkR))
+
+port <- as.integer(Sys.getenv("SPARKR_WORKER_PORT"))
+inputCon <- socketConnection(port = port, open = "rb", blocking = TRUE, timeout = 3600)
+
+while (TRUE) {
+ ready <- socketSelect(list(inputCon))
+ if (ready) {
+ port <- SparkR:::readInt(inputCon)
+ # There is a small chance that the read is interrupted by a signal, so retry once
+ if (length(port) == 0) {
+ port <- SparkR:::readInt(inputCon)
+ if (length(port) == 0) {
+ cat("quitting daemon\n")
+ quit(save = "no")
+ }
+ }
+ p <- parallel:::mcfork()
+ if (inherits(p, "masterProcess")) {
+ close(inputCon)
+ Sys.setenv(SPARKR_WORKER_PORT = port)
+ source(script)
+ # Send SIGUSR1 to this child process so that it can exit
+ tools::pskill(Sys.getpid(), tools::SIGUSR1)
+ parallel:::mcexit(0L)
+ }
+ }
+}
diff --git a/R/pkg/inst/worker/worker.R b/R/pkg/inst/worker/worker.R
new file mode 100644
index 0000000000..c6542928e8
--- /dev/null
+++ b/R/pkg/inst/worker/worker.R
@@ -0,0 +1,128 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements. See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+# Worker script
+
+rLibDir <- Sys.getenv("SPARKR_RLIBDIR")
+# Set libPaths to include SparkR package as loadNamespace needs this
+# TODO: Figure out if we can avoid this by not loading any objects that require
+# SparkR namespace
+.libPaths(c(rLibDir, .libPaths()))
+suppressPackageStartupMessages(library(SparkR))
+
+port <- as.integer(Sys.getenv("SPARKR_WORKER_PORT"))
+inputCon <- socketConnection(port = port, blocking = TRUE, open = "rb")
+outputCon <- socketConnection(port = port, blocking = TRUE, open = "wb")
+
+# read the index of the current partition inside the RDD
+partition <- SparkR:::readInt(inputCon)
+
+deserializer <- SparkR:::readString(inputCon)
+serializer <- SparkR:::readString(inputCon)
+
+# Include packages as required
+packageNames <- unserialize(SparkR:::readRaw(inputCon))
+for (pkg in packageNames) {
+ suppressPackageStartupMessages(require(as.character(pkg), character.only=TRUE))
+}
+
+# read function dependencies
+funcLen <- SparkR:::readInt(inputCon)
+computeFunc <- unserialize(SparkR:::readRawLen(inputCon, funcLen))
+env <- environment(computeFunc)
+parent.env(env) <- .GlobalEnv # Attach under global environment.
+
+# Read and set broadcast variables
+numBroadcastVars <- SparkR:::readInt(inputCon)
+if (numBroadcastVars > 0) {
+ for (bcast in seq(1:numBroadcastVars)) {
+ bcastId <- SparkR:::readInt(inputCon)
+ value <- unserialize(SparkR:::readRaw(inputCon))
+ setBroadcastValue(bcastId, value)
+ }
+}
+
+# If numPartitions is -1, read the input as a normal RDD; if it is >= 0, treat the
+# input as a pairwise RDD and use the value as the number of partitions to create.
+numPartitions <- SparkR:::readInt(inputCon)
+
+isEmpty <- SparkR:::readInt(inputCon)
+
+if (isEmpty != 0) {
+
+ if (numPartitions == -1) {
+ if (deserializer == "byte") {
+ # Read the serialized data for this partition
+ data <- SparkR:::readDeserialize(inputCon)
+ } else if (deserializer == "string") {
+ data <- as.list(readLines(inputCon))
+ } else if (deserializer == "row") {
+ data <- SparkR:::readDeserializeRows(inputCon)
+ }
+ output <- computeFunc(partition, data)
+ if (serializer == "byte") {
+ SparkR:::writeRawSerialize(outputCon, output)
+ } else if (serializer == "row") {
+ SparkR:::writeRowSerialize(outputCon, output)
+ } else {
+ SparkR:::writeStrings(outputCon, output)
+ }
+ } else {
+ if (deserializer == "byte") {
+ # Read the serialized data for this partition
+ data <- SparkR:::readDeserialize(inputCon)
+ } else if (deserializer == "string") {
+ data <- readLines(inputCon)
+ } else if (deserializer == "row") {
+ data <- SparkR:::readDeserializeRows(inputCon)
+ }
+
+ res <- new.env()
+
+ # Step 1: hash the data to an environment
+ hashTupleToEnvir <- function(tuple) {
+ # NOTE: computeFunc is the hash function here
+ hashVal <- computeFunc(tuple[[1]])
+ bucket <- as.character(hashVal %% numPartitions)
+ acc <- res[[bucket]]
+ # Create a new accumulator
+ if (is.null(acc)) {
+ acc <- SparkR:::initAccumulator()
+ }
+ SparkR:::addItemToAccumulator(acc, tuple)
+ res[[bucket]] <- acc
+ }
+ invisible(lapply(data, hashTupleToEnvir))
+
+ # Step 2: write out all of the environment as key-value pairs.
+ for (name in ls(res)) {
+ SparkR:::writeInt(outputCon, 2L)
+ SparkR:::writeInt(outputCon, as.integer(name))
+ # Truncate the accumulator list to the number of elements we have
+ length(res[[name]]$data) <- res[[name]]$counter
+ SparkR:::writeRawSerialize(outputCon, res[[name]]$data)
+ }
+ }
+}
+
+# End of output
+if (serializer %in% c("byte", "row")) {
+ SparkR:::writeInt(outputCon, 0L)
+}
+
+close(outputCon)
+close(inputCon)
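Note: in the pairwise branch above, the worker hashes every (key, value) tuple into a bucket numbered hashVal %% numPartitions before writing the buckets back to the JVM. Below is a minimal, self-contained R sketch of that bucketing step, assuming a plain list in place of SparkR's internal accumulator; it illustrates the idea rather than reproducing the worker's exact code.

    # Toy version of "Step 1" above: group tuples into an environment keyed by bucket id.
    hashIntoBuckets <- function(tuples, hashFunc, numPartitions) {
      res <- new.env()
      for (tuple in tuples) {
        bucket <- as.character(hashFunc(tuple[[1]]) %% numPartitions)
        res[[bucket]] <- c(res[[bucket]], list(tuple))  # plain list instead of an accumulator
      }
      res
    }

    # Example: bucket words by string length into 2 partitions.
    buckets <- hashIntoBuckets(list(list("Spark", 1L), list("is", 1L), list("awesome.", 1L)),
                               nchar, 2L)
    ls(buckets)  # bucket ids, e.g. "0" and "1"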
diff --git a/R/pkg/src/Makefile b/R/pkg/src/Makefile
new file mode 100644
index 0000000000..a55a56fe80
--- /dev/null
+++ b/R/pkg/src/Makefile
@@ -0,0 +1,27 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements. See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+all: sharelib
+
+sharelib: string_hash_code.c
+ R CMD SHLIB -o SparkR.so string_hash_code.c
+
+clean:
+ rm -f *.o
+ rm -f *.so
+
+.PHONY: all clean
diff --git a/R/pkg/src/Makefile.win b/R/pkg/src/Makefile.win
new file mode 100644
index 0000000000..aa486d8228
--- /dev/null
+++ b/R/pkg/src/Makefile.win
@@ -0,0 +1,27 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements. See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+all: sharelib
+
+sharelib: string_hash_code.c
+ R CMD SHLIB -o SparkR.dll string_hash_code.c
+
+clean:
+ rm -f *.o
+ rm -f *.dll
+
+.PHONY: all clean
diff --git a/R/pkg/src/string_hash_code.c b/R/pkg/src/string_hash_code.c
new file mode 100644
index 0000000000..e3274b9a0c
--- /dev/null
+++ b/R/pkg/src/string_hash_code.c
@@ -0,0 +1,49 @@
+/*
+ Licensed to the Apache Software Foundation (ASF) under one or more
+ contributor license agreements. See the NOTICE file distributed with
+ this work for additional information regarding copyright ownership.
+ The ASF licenses this file to You under the Apache License, Version 2.0
+ (the "License"); you may not use this file except in compliance with
+ the License. You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+ Unless required by applicable law or agreed to in writing, software
+ distributed under the License is distributed on an "AS IS" BASIS,
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ See the License for the specific language governing permissions and
+ limitations under the License.
+*/
+
+/*
+ * A C function for R extension which implements the Java String hash algorithm.
+ * Refer to http://en.wikipedia.org/wiki/Java_hashCode%28%29#The_java.lang.String_hash_function
+ *
+ */
+
+#include <R.h>
+#include <Rinternals.h>
+
+/* for compatibility with R before 3.1 */
+#ifndef IS_SCALAR
+#define IS_SCALAR(x, type) (TYPEOF(x) == (type) && XLENGTH(x) == 1)
+#endif
+
+SEXP stringHashCode(SEXP string) {
+ const char* str;
+ R_xlen_t len, i;
+ int hashCode = 0;
+
+ if (!IS_SCALAR(string, STRSXP)) {
+ error("invalid input");
+ }
+
+ str = CHAR(asChar(string));
+ len = XLENGTH(asChar(string));
+
+ for (i = 0; i < len; i++) {
+ hashCode = (hashCode << 5) - hashCode + *str++;
+ }
+
+ return ScalarInteger(hashCode);
+}
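Note: the (hashCode << 5) - hashCode update above is 31 * hashCode + c, the same recurrence as Java's String.hashCode(), which keeps hash partitioning on the R side consistent with the JVM. Once SparkR.so/SparkR.dll is built and loaded, the routine would presumably be reached through .Call; how SparkR registers the symbol is not part of this hunk, so the snippet below is an assumption for illustration only.

    # Assumed invocation of the compiled routine; symbol loading/registration is outside this hunk.
    h <- .Call("stringHashCode", "Spark", PACKAGE = "SparkR")
    # The same recurrence in plain R (ignoring 32-bit overflow, which only matters for longer
    # strings): start at 0 and fold in each character code.
    Reduce(function(h, c) 31L * h + c, utf8ToInt("Spark"), 0L)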
diff --git a/R/pkg/tests/run-all.R b/R/pkg/tests/run-all.R
new file mode 100644
index 0000000000..4f8a1ed2d8
--- /dev/null
+++ b/R/pkg/tests/run-all.R
@@ -0,0 +1,21 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements. See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+library(testthat)
+library(SparkR)
+
+test_package("SparkR")
diff --git a/R/run-tests.sh b/R/run-tests.sh
new file mode 100755
index 0000000000..e82ad0ba2c
--- /dev/null
+++ b/R/run-tests.sh
@@ -0,0 +1,39 @@
+#!/bin/bash
+
+#
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements. See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+FWDIR="$(cd `dirname $0`; pwd)"
+
+FAILED=0
+LOGFILE=$FWDIR/unit-tests.out
+rm -f $LOGFILE
+
+SPARK_TESTING=1 $FWDIR/../bin/sparkR --driver-java-options "-Dlog4j.configuration=file:$FWDIR/log4j.properties" $FWDIR/pkg/tests/run-all.R 2>&1 | tee -a $LOGFILE
+FAILED=$((PIPESTATUS[0]||$FAILED))
+
+if [[ $FAILED != 0 ]]; then
+ cat $LOGFILE
+ echo -en "\033[31m" # Red
+ echo "Had test failures; see logs."
+ echo -en "\033[0m" # No color
+ exit -1
+else
+ echo -en "\033[32m" # Green
+ echo "Tests passed."
+ echo -en "\033[0m" # No color
+fi
diff --git a/bin/sparkR b/bin/sparkR
new file mode 100755
index 0000000000..8c918e2b09
--- /dev/null
+++ b/bin/sparkR
@@ -0,0 +1,39 @@
+#!/bin/bash
+
+#
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements. See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+# Figure out where Spark is installed
+export SPARK_HOME="$(cd "`dirname "$0"`"/..; pwd)"
+
+source "$SPARK_HOME"/bin/load-spark-env.sh
+
+function usage() {
+ if [ -n "$1" ]; then
+ echo $1
+ fi
+ echo "Usage: ./bin/sparkR [options]" 1>&2
+ "$SPARK_HOME"/bin/spark-submit --help 2>&1 | grep -v Usage 1>&2
+ exit $2
+}
+export -f usage
+
+if [[ "$@" = *--help ]] || [[ "$@" = *-h ]]; then
+ usage
+fi
+
+exec "$SPARK_HOME"/bin/spark-submit sparkr-shell-main "$@"
diff --git a/bin/sparkR.cmd b/bin/sparkR.cmd
new file mode 100644
index 0000000000..d7b60183ca
--- /dev/null
+++ b/bin/sparkR.cmd
@@ -0,0 +1,23 @@
+@echo off
+
+rem
+rem Licensed to the Apache Software Foundation (ASF) under one or more
+rem contributor license agreements. See the NOTICE file distributed with
+rem this work for additional information regarding copyright ownership.
+rem The ASF licenses this file to You under the Apache License, Version 2.0
+rem (the "License"); you may not use this file except in compliance with
+rem the License. You may obtain a copy of the License at
+rem
+rem http://www.apache.org/licenses/LICENSE-2.0
+rem
+rem Unless required by applicable law or agreed to in writing, software
+rem distributed under the License is distributed on an "AS IS" BASIS,
+rem WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+rem See the License for the specific language governing permissions and
+rem limitations under the License.
+rem
+
+rem This is the entry point for running SparkR. To avoid polluting the
+rem environment, it just launches a new cmd to do the real work.
+
+cmd /V /E /C %~dp0sparkR2.cmd %*
diff --git a/bin/sparkR2.cmd b/bin/sparkR2.cmd
new file mode 100644
index 0000000000..e47f22c730
--- /dev/null
+++ b/bin/sparkR2.cmd
@@ -0,0 +1,26 @@
+@echo off
+
+rem
+rem Licensed to the Apache Software Foundation (ASF) under one or more
+rem contributor license agreements. See the NOTICE file distributed with
+rem this work for additional information regarding copyright ownership.
+rem The ASF licenses this file to You under the Apache License, Version 2.0
+rem (the "License"); you may not use this file except in compliance with
+rem the License. You may obtain a copy of the License at
+rem
+rem http://www.apache.org/licenses/LICENSE-2.0
+rem
+rem Unless required by applicable law or agreed to in writing, software
+rem distributed under the License is distributed on an "AS IS" BASIS,
+rem WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+rem See the License for the specific language governing permissions and
+rem limitations under the License.
+rem
+
+rem Figure out where the Spark framework is installed
+set SPARK_HOME=%~dp0..
+
+call %SPARK_HOME%\bin\load-spark-env.cmd
+
+
+call %SPARK_HOME%\bin\spark-submit2.cmd sparkr-shell-main %*
diff --git a/core/pom.xml b/core/pom.xml
index 6cd1965ec3..e80829b7a7 100644
--- a/core/pom.xml
+++ b/core/pom.xml
@@ -442,4 +442,55 @@
</resources>
</build>
+ <profiles>
+ <profile>
+ <id>Windows</id>
+ <activation>
+ <os>
+ <family>Windows</family>
+ </os>
+ </activation>
+ <properties>
+ <path.separator>\</path.separator>
+ <script.extension>.bat</script.extension>
+ </properties>
+ </profile>
+ <profile>
+ <id>unix</id>
+ <activation>
+ <os>
+ <family>unix</family>
+ </os>
+ </activation>
+ <properties>
+ <path.separator>/</path.separator>
+ <script.extension>.sh</script.extension>
+ </properties>
+ </profile>
+ <profile>
+ <id>sparkr</id>
+ <build>
+ <plugins>
+ <plugin>
+ <groupId>org.codehaus.mojo</groupId>
+ <artifactId>exec-maven-plugin</artifactId>
+ <version>1.3.2</version>
+ <executions>
+ <execution>
+ <id>sparkr-pkg</id>
+ <phase>compile</phase>
+ <goals>
+ <goal>exec</goal>
+ </goals>
+ </execution>
+ </executions>
+ <configuration>
+ <executable>..${path.separator}R${path.separator}install-dev${script.extension}</executable>
+ </configuration>
+ </plugin>
+ </plugins>
+ </build>
+ </profile>
+ </profiles>
+
</project>
diff --git a/core/src/main/scala/org/apache/spark/api/r/RBackend.scala b/core/src/main/scala/org/apache/spark/api/r/RBackend.scala
new file mode 100644
index 0000000000..3a2c94bd9d
--- /dev/null
+++ b/core/src/main/scala/org/apache/spark/api/r/RBackend.scala
@@ -0,0 +1,145 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.api.r
+
+import java.io.{DataOutputStream, File, FileOutputStream, IOException}
+import java.net.{InetSocketAddress, ServerSocket}
+import java.util.concurrent.TimeUnit
+
+import io.netty.bootstrap.ServerBootstrap
+import io.netty.channel.{ChannelFuture, ChannelInitializer, EventLoopGroup}
+import io.netty.channel.nio.NioEventLoopGroup
+import io.netty.channel.socket.SocketChannel
+import io.netty.channel.socket.nio.NioServerSocketChannel
+import io.netty.handler.codec.LengthFieldBasedFrameDecoder
+import io.netty.handler.codec.bytes.{ByteArrayDecoder, ByteArrayEncoder}
+
+import org.apache.spark.Logging
+
+/**
+ * Netty-based backend server that is used to communicate between R and Java.
+ */
+private[spark] class RBackend {
+
+ private[this] var channelFuture: ChannelFuture = null
+ private[this] var bootstrap: ServerBootstrap = null
+ private[this] var bossGroup: EventLoopGroup = null
+
+ def init(): Int = {
+ bossGroup = new NioEventLoopGroup(2)
+ val workerGroup = bossGroup
+ val handler = new RBackendHandler(this)
+
+ bootstrap = new ServerBootstrap()
+ .group(bossGroup, workerGroup)
+ .channel(classOf[NioServerSocketChannel])
+
+ bootstrap.childHandler(new ChannelInitializer[SocketChannel]() {
+ def initChannel(ch: SocketChannel): Unit = {
+ ch.pipeline()
+ .addLast("encoder", new ByteArrayEncoder())
+ .addLast("frameDecoder",
+ // maxFrameLength = 2G
+ // lengthFieldOffset = 0
+ // lengthFieldLength = 4
+ // lengthAdjustment = 0
+ // initialBytesToStrip = 4, i.e. strip out the length field itself
+ new LengthFieldBasedFrameDecoder(Integer.MAX_VALUE, 0, 4, 0, 4))
+ .addLast("decoder", new ByteArrayDecoder())
+ .addLast("handler", handler)
+ }
+ })
+
+ channelFuture = bootstrap.bind(new InetSocketAddress(0))
+ channelFuture.syncUninterruptibly()
+ channelFuture.channel().localAddress().asInstanceOf[InetSocketAddress].getPort()
+ }
+
+ def run(): Unit = {
+ channelFuture.channel.closeFuture().syncUninterruptibly()
+ }
+
+ def close(): Unit = {
+ if (channelFuture != null) {
+ // close is a local operation and should finish within milliseconds; timeout just to be safe
+ channelFuture.channel().close().awaitUninterruptibly(10, TimeUnit.SECONDS)
+ channelFuture = null
+ }
+ if (bootstrap != null && bootstrap.group() != null) {
+ bootstrap.group().shutdownGracefully()
+ }
+ if (bootstrap != null && bootstrap.childGroup() != null) {
+ bootstrap.childGroup().shutdownGracefully()
+ }
+ bootstrap = null
+ }
+
+}
+
+private[spark] object RBackend extends Logging {
+ def main(args: Array[String]): Unit = {
+ if (args.length < 1) {
+ System.err.println("Usage: RBackend <tempFilePath>")
+ System.exit(-1)
+ }
+ val sparkRBackend = new RBackend()
+ try {
+ // bind to random port
+ val boundPort = sparkRBackend.init()
+ val serverSocket = new ServerSocket(0, 1)
+ val listenPort = serverSocket.getLocalPort()
+
+ // tell the R process via temporary file
+ val path = args(0)
+ val f = new File(path + ".tmp")
+ val dos = new DataOutputStream(new FileOutputStream(f))
+ dos.writeInt(boundPort)
+ dos.writeInt(listenPort)
+ dos.close()
+ f.renameTo(new File(path))
+
+ // wait for the connection from R to close, then exit
+ new Thread("wait for socket to close") {
+ setDaemon(true)
+ override def run(): Unit = {
+ // any uncaught exception will also shut down the JVM
+ val buf = new Array[Byte](1024)
+ // shut down the JVM if R does not connect back within 10 seconds
+ serverSocket.setSoTimeout(10000)
+ try {
+ val inSocket = serverSocket.accept()
+ serverSocket.close()
+ // wait for the socket to close, which happens when the R process dies
+ inSocket.getInputStream().read(buf)
+ } finally {
+ sparkRBackend.close()
+ System.exit(0)
+ }
+ }
+ }.start()
+
+ sparkRBackend.run()
+ } catch {
+ case e: IOException =>
+ logError("Server shutting down: failed with exception ", e)
+ sparkRBackend.close()
+ System.exit(1)
+ }
+ System.exit(0)
+ }
+}
diff --git a/core/src/main/scala/org/apache/spark/api/r/RBackendHandler.scala b/core/src/main/scala/org/apache/spark/api/r/RBackendHandler.scala
new file mode 100644
index 0000000000..0075d96371
--- /dev/null
+++ b/core/src/main/scala/org/apache/spark/api/r/RBackendHandler.scala
@@ -0,0 +1,223 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.api.r
+
+import java.io.{ByteArrayInputStream, ByteArrayOutputStream, DataInputStream, DataOutputStream}
+
+import scala.collection.mutable.HashMap
+
+import io.netty.channel.ChannelHandler.Sharable
+import io.netty.channel.{ChannelHandlerContext, SimpleChannelInboundHandler}
+
+import org.apache.spark.Logging
+import org.apache.spark.api.r.SerDe._
+
+/**
+ * Handler for RBackend
+ * TODO: This is marked as sharable to get a handle to RBackend. Is it safe to re-use
+ * this across connections?
+ */
+@Sharable
+private[r] class RBackendHandler(server: RBackend)
+ extends SimpleChannelInboundHandler[Array[Byte]] with Logging {
+
+ override def channelRead0(ctx: ChannelHandlerContext, msg: Array[Byte]): Unit = {
+ val bis = new ByteArrayInputStream(msg)
+ val dis = new DataInputStream(bis)
+
+ val bos = new ByteArrayOutputStream()
+ val dos = new DataOutputStream(bos)
+
+ // First bit is isStatic
+ val isStatic = readBoolean(dis)
+ val objId = readString(dis)
+ val methodName = readString(dis)
+ val numArgs = readInt(dis)
+
+ if (objId == "SparkRHandler") {
+ methodName match {
+ case "stopBackend" =>
+ writeInt(dos, 0)
+ writeType(dos, "void")
+ server.close()
+ case "rm" =>
+ try {
+ val t = readObjectType(dis)
+ assert(t == 'c')
+ val objToRemove = readString(dis)
+ JVMObjectTracker.remove(objToRemove)
+ writeInt(dos, 0)
+ writeObject(dos, null)
+ } catch {
+ case e: Exception =>
+ logError(s"Removing $objId failed", e)
+ writeInt(dos, -1)
+ }
+ case _ => dos.writeInt(-1)
+ }
+ } else {
+ handleMethodCall(isStatic, objId, methodName, numArgs, dis, dos)
+ }
+
+ val reply = bos.toByteArray
+ ctx.write(reply)
+ }
+
+ override def channelReadComplete(ctx: ChannelHandlerContext): Unit = {
+ ctx.flush()
+ }
+
+ override def exceptionCaught(ctx: ChannelHandlerContext, cause: Throwable): Unit = {
+ // Close the connection when an exception is raised.
+ cause.printStackTrace()
+ ctx.close()
+ }
+
+ def handleMethodCall(
+ isStatic: Boolean,
+ objId: String,
+ methodName: String,
+ numArgs: Int,
+ dis: DataInputStream,
+ dos: DataOutputStream): Unit = {
+ var obj: Object = null
+ try {
+ val cls = if (isStatic) {
+ Class.forName(objId)
+ } else {
+ JVMObjectTracker.get(objId) match {
+ case None => throw new IllegalArgumentException("Object not found " + objId)
+ case Some(o) =>
+ obj = o
+ o.getClass
+ }
+ }
+
+ val args = readArgs(numArgs, dis)
+
+ val methods = cls.getMethods
+ val selectedMethods = methods.filter(m => m.getName == methodName)
+ if (selectedMethods.length > 0) {
+ val methods = selectedMethods.filter { x =>
+ matchMethod(numArgs, args, x.getParameterTypes)
+ }
+ if (methods.isEmpty) {
+ logWarning(s"cannot find matching method ${cls}.$methodName. "
+ + s"Candidates are:")
+ selectedMethods.foreach { method =>
+ logWarning(s"$methodName(${method.getParameterTypes.mkString(",")})")
+ }
+ throw new Exception(s"No matched method found for $cls.$methodName")
+ }
+ val ret = methods.head.invoke(obj, args:_*)
+
+ // Write status bit
+ writeInt(dos, 0)
+ writeObject(dos, ret.asInstanceOf[AnyRef])
+ } else if (methodName == "<init>") {
+ // methodName should be "<init>" for constructor
+ val ctor = cls.getConstructors.filter { x =>
+ matchMethod(numArgs, args, x.getParameterTypes)
+ }.head
+
+ val obj = ctor.newInstance(args:_*)
+
+ writeInt(dos, 0)
+ writeObject(dos, obj.asInstanceOf[AnyRef])
+ } else {
+ throw new IllegalArgumentException("invalid method " + methodName + " for object " + objId)
+ }
+ } catch {
+ case e: Exception =>
+ logError(s"$methodName on $objId failed", e)
+ writeInt(dos, -1)
+ }
+ }
+
+ // Read a number of arguments from the data input stream
+ def readArgs(numArgs: Int, dis: DataInputStream): Array[java.lang.Object] = {
+ (0 until numArgs).map { arg =>
+ readObject(dis)
+ }.toArray
+ }
+
+ // Checks if the arguments passed in args matches the parameter types.
+ // NOTE: Currently we do exact match. We may add type conversions later.
+ def matchMethod(
+ numArgs: Int,
+ args: Array[java.lang.Object],
+ parameterTypes: Array[Class[_]]): Boolean = {
+ if (parameterTypes.length != numArgs) {
+ return false
+ }
+
+ for (i <- 0 to numArgs - 1) {
+ val parameterType = parameterTypes(i)
+ var parameterWrapperType = parameterType
+
+ // Convert native parameters to Object types as args is Array[Object] here
+ if (parameterType.isPrimitive) {
+ parameterWrapperType = parameterType match {
+ case java.lang.Integer.TYPE => classOf[java.lang.Integer]
+ case java.lang.Double.TYPE => classOf[java.lang.Double]
+ case java.lang.Boolean.TYPE => classOf[java.lang.Boolean]
+ case _ => parameterType
+ }
+ }
+ if (!parameterWrapperType.isInstance(args(i))) {
+ return false
+ }
+ }
+ true
+ }
+}
+
+/**
+ * Helper singleton that tracks JVM objects returned to R.
+ * This is useful for referencing these objects in RPC calls.
+ */
+private[r] object JVMObjectTracker {
+
+ // TODO: This map should be thread-safe if we want to support multiple
+ // connections at the same time
+ private[this] val objMap = new HashMap[String, Object]
+
+ // TODO: We support only one connection now, so an integer is fine.
+ // Investigate using an atomic integer in the future.
+ private[this] var objCounter: Int = 0
+
+ def getObject(id: String): Object = {
+ objMap(id)
+ }
+
+ def get(id: String): Option[Object] = {
+ objMap.get(id)
+ }
+
+ def put(obj: Object): String = {
+ val objId = objCounter.toString
+ objCounter = objCounter + 1
+ objMap.put(objId, obj)
+ objId
+ }
+
+ def remove(id: String): Option[Object] = {
+ objMap.remove(id)
+ }
+
+}
diff --git a/core/src/main/scala/org/apache/spark/api/r/RRDD.scala b/core/src/main/scala/org/apache/spark/api/r/RRDD.scala
new file mode 100644
index 0000000000..5fa4d483b8
--- /dev/null
+++ b/core/src/main/scala/org/apache/spark/api/r/RRDD.scala
@@ -0,0 +1,450 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.api.r
+
+import java.io._
+import java.net.ServerSocket
+import java.util.{Map => JMap}
+
+import scala.collection.JavaConversions._
+import scala.io.Source
+import scala.reflect.ClassTag
+import scala.util.Try
+
+import org.apache.spark._
+import org.apache.spark.api.java.{JavaPairRDD, JavaRDD, JavaSparkContext}
+import org.apache.spark.broadcast.Broadcast
+import org.apache.spark.rdd.RDD
+import org.apache.spark.util.Utils
+
+private abstract class BaseRRDD[T: ClassTag, U: ClassTag](
+ parent: RDD[T],
+ numPartitions: Int,
+ func: Array[Byte],
+ deserializer: String,
+ serializer: String,
+ packageNames: Array[Byte],
+ rLibDir: String,
+ broadcastVars: Array[Broadcast[Object]])
+ extends RDD[U](parent) with Logging {
+ override def getPartitions: Array[Partition] = parent.partitions
+
+ override def compute(partition: Partition, context: TaskContext): Iterator[U] = {
+
+ // The parent may also be an RRDD, so we should launch it first.
+ val parentIterator = firstParent[T].iterator(partition, context)
+
+ // we expect two connections
+ val serverSocket = new ServerSocket(0, 2)
+ val listenPort = serverSocket.getLocalPort()
+
+ // stdout/stderr are shared by multiple tasks, because we use one daemon
+ // to launch each child process as a worker.
+ val errThread = RRDD.createRWorker(rLibDir, listenPort)
+
+ // We use two sockets to separate input and output, which makes it easier to manage
+ // their lifecycles and avoid deadlock.
+ // TODO: optimize it to use one socket
+
+ // the socket used to send out the input of task
+ serverSocket.setSoTimeout(10000)
+ val inSocket = serverSocket.accept()
+ startStdinThread(inSocket.getOutputStream(), parentIterator, partition.index)
+
+ // the socket used to receive the output of task
+ val outSocket = serverSocket.accept()
+ val inputStream = new BufferedInputStream(outSocket.getInputStream)
+ val dataStream = openDataStream(inputStream)
+ serverSocket.close()
+
+ try {
+
+ return new Iterator[U] {
+ def next(): U = {
+ val obj = _nextObj
+ if (hasNext) {
+ _nextObj = read()
+ }
+ obj
+ }
+
+ var _nextObj = read()
+
+ def hasNext(): Boolean = {
+ val hasMore = (_nextObj != null)
+ if (!hasMore) {
+ dataStream.close()
+ }
+ hasMore
+ }
+ }
+ } catch {
+ case e: Exception =>
+ throw new SparkException("R computation failed with\n " + errThread.getLines())
+ }
+ }
+
+ /**
+ * Start a thread to write RDD data to the R process.
+ */
+ private def startStdinThread[T](
+ output: OutputStream,
+ iter: Iterator[T],
+ partition: Int): Unit = {
+
+ val env = SparkEnv.get
+ val bufferSize = System.getProperty("spark.buffer.size", "65536").toInt
+ val stream = new BufferedOutputStream(output, bufferSize)
+
+ new Thread("writer for R") {
+ override def run(): Unit = {
+ try {
+ SparkEnv.set(env)
+ val dataOut = new DataOutputStream(stream)
+ dataOut.writeInt(partition)
+
+ SerDe.writeString(dataOut, deserializer)
+ SerDe.writeString(dataOut, serializer)
+
+ dataOut.writeInt(packageNames.length)
+ dataOut.write(packageNames)
+
+ dataOut.writeInt(func.length)
+ dataOut.write(func)
+
+ dataOut.writeInt(broadcastVars.length)
+ broadcastVars.foreach { broadcast =>
+ // TODO(shivaram): Read a Long in R to avoid this cast
+ dataOut.writeInt(broadcast.id.toInt)
+ // TODO: Pass a byte array from R to avoid this cast?
+ val broadcastByteArr = broadcast.value.asInstanceOf[Array[Byte]]
+ dataOut.writeInt(broadcastByteArr.length)
+ dataOut.write(broadcastByteArr)
+ }
+
+ dataOut.writeInt(numPartitions)
+
+ if (!iter.hasNext) {
+ dataOut.writeInt(0)
+ } else {
+ dataOut.writeInt(1)
+ }
+
+ val printOut = new PrintStream(stream)
+
+ def writeElem(elem: Any): Unit = {
+ if (deserializer == SerializationFormats.BYTE) {
+ val elemArr = elem.asInstanceOf[Array[Byte]]
+ dataOut.writeInt(elemArr.length)
+ dataOut.write(elemArr)
+ } else if (deserializer == SerializationFormats.ROW) {
+ dataOut.write(elem.asInstanceOf[Array[Byte]])
+ } else if (deserializer == SerializationFormats.STRING) {
+ printOut.println(elem)
+ }
+ }
+
+ for (elem <- iter) {
+ elem match {
+ case (key, value) =>
+ writeElem(key)
+ writeElem(value)
+ case _ =>
+ writeElem(elem)
+ }
+ }
+ stream.flush()
+ } catch {
+ // TODO: We should propagate this error to the task thread
+ case e: Exception =>
+ logError("R Writer thread got an exception", e)
+ } finally {
+ Try(output.close())
+ }
+ }
+ }.start()
+ }
+
+ protected def openDataStream(input: InputStream): Closeable
+
+ protected def read(): U
+}
+
+/**
+ * Form an RDD[(Int, Array[Byte])] from key-value pairs returned from R.
+ * This is used by SparkR's shuffle operations.
+ */
+private class PairwiseRRDD[T: ClassTag](
+ parent: RDD[T],
+ numPartitions: Int,
+ hashFunc: Array[Byte],
+ deserializer: String,
+ packageNames: Array[Byte],
+ rLibDir: String,
+ broadcastVars: Array[Object])
+ extends BaseRRDD[T, (Int, Array[Byte])](
+ parent, numPartitions, hashFunc, deserializer,
+ SerializationFormats.BYTE, packageNames, rLibDir,
+ broadcastVars.map(x => x.asInstanceOf[Broadcast[Object]])) {
+
+ private var dataStream: DataInputStream = _
+
+ override protected def openDataStream(input: InputStream): Closeable = {
+ dataStream = new DataInputStream(input)
+ dataStream
+ }
+
+ override protected def read(): (Int, Array[Byte]) = {
+ try {
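+ // Each record from R is an element count (2), the hashed key as an int, and a
+ // length-prefixed byte payload; any other count marks the end of input.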
+ val length = dataStream.readInt()
+
+ length match {
+ case length if length == 2 =>
+ val hashedKey = dataStream.readInt()
+ val contentPairsLength = dataStream.readInt()
+ val contentPairs = new Array[Byte](contentPairsLength)
+ dataStream.readFully(contentPairs)
+ (hashedKey, contentPairs)
+ case _ => null // End of input
+ }
+ } catch {
+ case eof: EOFException => {
+ throw new SparkException("R worker exited unexpectedly (crashed)", eof)
+ }
+ }
+ }
+
+ lazy val asJavaPairRDD : JavaPairRDD[Int, Array[Byte]] = JavaPairRDD.fromRDD(this)
+}
+
+/**
+ * An RDD that stores serialized R objects as Array[Byte].
+ */
+private class RRDD[T: ClassTag](
+ parent: RDD[T],
+ func: Array[Byte],
+ deserializer: String,
+ serializer: String,
+ packageNames: Array[Byte],
+ rLibDir: String,
+ broadcastVars: Array[Object])
+ extends BaseRRDD[T, Array[Byte]](
+ parent, -1, func, deserializer, serializer, packageNames, rLibDir,
+ broadcastVars.map(x => x.asInstanceOf[Broadcast[Object]])) {
+
+ private var dataStream: DataInputStream = _
+
+ override protected def openDataStream(input: InputStream): Closeable = {
+ dataStream = new DataInputStream(input)
+ dataStream
+ }
+
+ override protected def read(): Array[Byte] = {
+ try {
+ val length = dataStream.readInt()
+
+ length match {
+ case length if length > 0 =>
+ val obj = new Array[Byte](length)
+ dataStream.readFully(obj, 0, length)
+ obj
+ case _ => null
+ }
+ } catch {
+ case eof: EOFException => {
+ throw new SparkException("R worker exited unexpectedly (crashed)", eof)
+ }
+ }
+ }
+
+ lazy val asJavaRDD : JavaRDD[Array[Byte]] = JavaRDD.fromRDD(this)
+}
+
+/**
+ * An RDD that stores R objects as String values.
+ */
+private class StringRRDD[T: ClassTag](
+ parent: RDD[T],
+ func: Array[Byte],
+ deserializer: String,
+ packageNames: Array[Byte],
+ rLibDir: String,
+ broadcastVars: Array[Object])
+ extends BaseRRDD[T, String](
+ parent, -1, func, deserializer, SerializationFormats.STRING, packageNames, rLibDir,
+ broadcastVars.map(x => x.asInstanceOf[Broadcast[Object]])) {
+
+ private var dataStream: BufferedReader = _
+
+ override protected def openDataStream(input: InputStream): Closeable = {
+ dataStream = new BufferedReader(new InputStreamReader(input))
+ dataStream
+ }
+
+ override protected def read(): String = {
+ try {
+ dataStream.readLine()
+ } catch {
+ case e: IOException => {
+ throw new SparkException("R worker exited unexpectedly (crashed)", e)
+ }
+ }
+ }
+
+ lazy val asJavaRDD : JavaRDD[String] = JavaRDD.fromRDD(this)
+}
+
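+/**
+ * Buffers the last `errBufferSize` lines read from the stream (for error reporting
+ * via getLines) and logs every line as it arrives.
+ */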
+private[r] class BufferedStreamThread(
+ in: InputStream,
+ name: String,
+ errBufferSize: Int) extends Thread(name) with Logging {
+ val lines = new Array[String](errBufferSize)
+ var lineIdx = 0
+ override def run() {
+ for (line <- Source.fromInputStream(in).getLines) {
+ synchronized {
+ lines(lineIdx) = line
+ lineIdx = (lineIdx + 1) % errBufferSize
+ }
+ logInfo(line)
+ }
+ }
+
+ def getLines(): String = synchronized {
+ (0 until errBufferSize).filter { x =>
+ lines((x + lineIdx) % errBufferSize) != null
+ }.map { x =>
+ lines((x + lineIdx) % errBufferSize)
+ }.mkString("\n")
+ }
+}
+
+private[r] object RRDD {
+ // Because forking processes from Java is expensive, we prefer to launch
+ // a single R daemon (daemon.R) and tell it to fork new workers for our tasks.
+ // This daemon currently only works on UNIX-based systems, so on Windows (or when
+ // the daemon is disabled) we fall back to launching workers (worker.R) directly.
+ private[this] var errThread: BufferedStreamThread = _
+ private[this] var daemonChannel: DataOutputStream = _
+
+ def createSparkContext(
+ master: String,
+ appName: String,
+ sparkHome: String,
+ jars: Array[String],
+ sparkEnvirMap: JMap[Object, Object],
+ sparkExecutorEnvMap: JMap[Object, Object]): JavaSparkContext = {
+
+ val sparkConf = new SparkConf().setAppName(appName)
+ .setSparkHome(sparkHome)
+ .setJars(jars)
+
+ // Override `master` if we have a user-specified value
+ if (master != "") {
+ sparkConf.setMaster(master)
+ } else {
+ // If the conf has no master set, default it to "local" to maintain
+ // backwards compatibility
+ sparkConf.setIfMissing("spark.master", "local")
+ }
+
+ for ((name, value) <- sparkEnvirMap) {
+ sparkConf.set(name.asInstanceOf[String], value.asInstanceOf[String])
+ }
+ for ((name, value) <- sparkExecutorEnvMap) {
+ sparkConf.setExecutorEnv(name.asInstanceOf[String], value.asInstanceOf[String])
+ }
+
+ new JavaSparkContext(sparkConf)
+ }
+
+ /**
+ * Start a thread that reads the R process's output (stderr is merged into stdout) and logs it.
+ */
+ private def startStdoutThread(proc: Process): BufferedStreamThread = {
+ val BUFFER_SIZE = 100
+ val thread = new BufferedStreamThread(proc.getInputStream, "stdout reader for R", BUFFER_SIZE)
+ thread.setDaemon(true)
+ thread.start()
+ thread
+ }
+
+ private def createRProcess(rLibDir: String, port: Int, script: String): BufferedStreamThread = {
+ val rCommand = "Rscript"
+ val rOptions = "--vanilla"
+ val rExecScript = rLibDir + "/SparkR/worker/" + script
+ val pb = new ProcessBuilder(List(rCommand, rOptions, rExecScript))
+ // Unset the R_TESTS environment variable for workers.
+ // This is set by R CMD check as startup.Rs
+ // (http://svn.r-project.org/R/trunk/src/library/tools/R/testing.R)
+ // and confuses the worker script, which would try to load a non-existent file
+ pb.environment().put("R_TESTS", "")
+ pb.environment().put("SPARKR_RLIBDIR", rLibDir)
+ pb.environment().put("SPARKR_WORKER_PORT", port.toString)
+ pb.redirectErrorStream(true) // redirect stderr into stdout
+ val proc = pb.start()
+ val errThread = startStdoutThread(proc)
+ errThread
+ }
+
+ /**
+ * Launch a new R worker, either via the daemon (on UNIX-based systems) or by running worker.R directly.
+ */
+ def createRWorker(rLibDir: String, port: Int): BufferedStreamThread = {
+ val useDaemon = SparkEnv.get.conf.getBoolean("spark.sparkr.use.daemon", true)
+ if (!Utils.isWindows && useDaemon) {
+ synchronized {
+ if (daemonChannel == null) {
+ // we expect one connection
+ val serverSocket = new ServerSocket(0, 1)
+ val daemonPort = serverSocket.getLocalPort
+ errThread = createRProcess(rLibDir, daemonPort, "daemon.R")
+ // accept the daemon's connection; worker ports are sent over this channel
+ serverSocket.setSoTimeout(10000)
+ val sock = serverSocket.accept()
+ daemonChannel = new DataOutputStream(new BufferedOutputStream(sock.getOutputStream))
+ serverSocket.close()
+ }
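+ // Ask the daemon to fork a new worker that will connect back on `port`.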
+ try {
+ daemonChannel.writeInt(port)
+ daemonChannel.flush()
+ } catch {
+ case e: IOException =>
+ // daemon process died
+ daemonChannel.close()
+ daemonChannel = null
+ errThread = null
+ // fail the current task so the scheduler will retry it
+ throw e
+ }
+ errThread
+ }
+ } else {
+ createRProcess(rLibDir, port, "worker.R")
+ }
+ }
+
+ /**
+ * Create a JavaRDD of serialized R objects from an array of byte arrays. Used when
+ * `parallelize` is called from R.
+ */
+ def createRDDFromArray(jsc: JavaSparkContext, arr: Array[Array[Byte]]): JavaRDD[Array[Byte]] = {
+ JavaRDD.fromRDD(jsc.sc.parallelize(arr, arr.length))
+ }
+
+}
diff --git a/core/src/main/scala/org/apache/spark/api/r/SerDe.scala b/core/src/main/scala/org/apache/spark/api/r/SerDe.scala
new file mode 100644
index 0000000000..ccb2a371f4
--- /dev/null
+++ b/core/src/main/scala/org/apache/spark/api/r/SerDe.scala
@@ -0,0 +1,340 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.api.r
+
+import java.io.{DataInputStream, DataOutputStream}
+import java.sql.{Date, Time}
+
+import scala.collection.JavaConversions._
+
+/**
+ * Utility functions to serialize and deserialize objects to and from R.
+ */
+private[spark] object SerDe {
+
+ // Type mapping from R to Java
+ //
+ // NULL -> void
+ // integer -> Int
+ // character -> String
+ // logical -> Boolean
+ // double, numeric -> Double
+ // raw -> Array[Byte]
+ // Date -> Date
+ // POSIXlt/POSIXct -> Time
+ //
+ // list[T] -> Array[T], where T is one of above mentioned types
+ // environment -> Map[String, T], where T is a native type
+ // jobj -> Object, where jobj is an object created in the backend
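+ //
+ // On the wire, each value is a one-byte type tag (see readObjectType below)
+ // followed by a type-specific payload.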
+
+ def readObjectType(dis: DataInputStream): Char = {
+ dis.readByte().toChar
+ }
+
+ def readObject(dis: DataInputStream): Object = {
+ val dataType = readObjectType(dis)
+ readTypedObject(dis, dataType)
+ }
+
+ def readTypedObject(
+ dis: DataInputStream,
+ dataType: Char): Object = {
+ dataType match {
+ case 'n' => null
+ case 'i' => new java.lang.Integer(readInt(dis))
+ case 'd' => new java.lang.Double(readDouble(dis))
+ case 'b' => new java.lang.Boolean(readBoolean(dis))
+ case 'c' => readString(dis)
+ case 'e' => readMap(dis)
+ case 'r' => readBytes(dis)
+ case 'l' => readList(dis)
+ case 'D' => readDate(dis)
+ case 't' => readTime(dis)
+ case 'j' => JVMObjectTracker.getObject(readString(dis))
+ case _ => throw new IllegalArgumentException(s"Invalid type $dataType")
+ }
+ }
+
+ def readBytes(in: DataInputStream): Array[Byte] = {
+ val len = readInt(in)
+ val out = new Array[Byte](len)
+ in.readFully(out)
+ out
+ }
+
+ def readInt(in: DataInputStream): Int = {
+ in.readInt()
+ }
+
+ def readDouble(in: DataInputStream): Double = {
+ in.readDouble()
+ }
+
+ def readString(in: DataInputStream): String = {
+ val len = in.readInt()
+ val asciiBytes = new Array[Byte](len)
+ in.readFully(asciiBytes)
+ assert(asciiBytes(len - 1) == 0)
+ val str = new String(asciiBytes.dropRight(1).map(_.toChar))
+ str
+ }
+
+ def readBoolean(in: DataInputStream): Boolean = {
+ val intVal = in.readInt()
+ if (intVal == 0) false else true
+ }
+
+ def readDate(in: DataInputStream): Date = {
+ Date.valueOf(readString(in))
+ }
+
+ def readTime(in: DataInputStream): Time = {
+ val t = in.readDouble()
+ new Time((t * 1000L).toLong)
+ }
+
+ def readBytesArr(in: DataInputStream): Array[Array[Byte]] = {
+ val len = readInt(in)
+ (0 until len).map(_ => readBytes(in)).toArray
+ }
+
+ def readIntArr(in: DataInputStream): Array[Int] = {
+ val len = readInt(in)
+ (0 until len).map(_ => readInt(in)).toArray
+ }
+
+ def readDoubleArr(in: DataInputStream): Array[Double] = {
+ val len = readInt(in)
+ (0 until len).map(_ => readDouble(in)).toArray
+ }
+
+ def readBooleanArr(in: DataInputStream): Array[Boolean] = {
+ val len = readInt(in)
+ (0 until len).map(_ => readBoolean(in)).toArray
+ }
+
+ def readStringArr(in: DataInputStream): Array[String] = {
+ val len = readInt(in)
+ (0 until len).map(_ => readString(in)).toArray
+ }
+
+ def readList(dis: DataInputStream): Array[_] = {
+ val arrType = readObjectType(dis)
+ arrType match {
+ case 'i' => readIntArr(dis)
+ case 'c' => readStringArr(dis)
+ case 'd' => readDoubleArr(dis)
+ case 'b' => readBooleanArr(dis)
+ case 'j' => readStringArr(dis).map(x => JVMObjectTracker.getObject(x))
+ case 'r' => readBytesArr(dis)
+ case _ => throw new IllegalArgumentException(s"Invalid array type $arrType")
+ }
+ }
+
+ def readMap(in: DataInputStream): java.util.Map[Object, Object] = {
+ val len = readInt(in)
+ if (len > 0) {
+ val keysType = readObjectType(in)
+ val keysLen = readInt(in)
+ val keys = (0 until keysLen).map(_ => readTypedObject(in, keysType))
+
+ val valuesType = readObjectType(in)
+ val valuesLen = readInt(in)
+ val values = (0 until valuesLen).map(_ => readTypedObject(in, valuesType))
+ mapAsJavaMap(keys.zip(values).toMap)
+ } else {
+ new java.util.HashMap[Object, Object]()
+ }
+ }
+
+ // Methods to write out data from Java to R
+ //
+ // Type mapping from Java to R
+ //
+ // void -> NULL
+ // Int -> integer
+ // String -> character
+ // Boolean -> logical
+ // Double -> double
+ // Long -> double
+ // Array[Byte] -> raw
+ // Date -> Date
+ // Time -> POSIXct
+ //
+ // Array[T] -> list()
+ // Object -> jobj
+
+ def writeType(dos: DataOutputStream, typeStr: String): Unit = {
+ typeStr match {
+ case "void" => dos.writeByte('n')
+ case "character" => dos.writeByte('c')
+ case "double" => dos.writeByte('d')
+ case "integer" => dos.writeByte('i')
+ case "logical" => dos.writeByte('b')
+ case "date" => dos.writeByte('D')
+ case "time" => dos.writeByte('t')
+ case "raw" => dos.writeByte('r')
+ case "list" => dos.writeByte('l')
+ case "jobj" => dos.writeByte('j')
+ case _ => throw new IllegalArgumentException(s"Invalid type $typeStr")
+ }
+ }
+
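+ // Dispatch on the runtime class name; entries such as "[B" and "[I" are the JVM's
+ // descriptors for primitive arrays (byte[] and int[] respectively).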
+ def writeObject(dos: DataOutputStream, value: Object): Unit = {
+ if (value == null) {
+ writeType(dos, "void")
+ } else {
+ value.getClass.getName match {
+ case "java.lang.String" =>
+ writeType(dos, "character")
+ writeString(dos, value.asInstanceOf[String])
+ case "long" | "java.lang.Long" =>
+ writeType(dos, "double")
+ writeDouble(dos, value.asInstanceOf[Long].toDouble)
+ case "double" | "java.lang.Double" =>
+ writeType(dos, "double")
+ writeDouble(dos, value.asInstanceOf[Double])
+ case "int" | "java.lang.Integer" =>
+ writeType(dos, "integer")
+ writeInt(dos, value.asInstanceOf[Int])
+ case "boolean" | "java.lang.Boolean" =>
+ writeType(dos, "logical")
+ writeBoolean(dos, value.asInstanceOf[Boolean])
+ case "java.sql.Date" =>
+ writeType(dos, "date")
+ writeDate(dos, value.asInstanceOf[Date])
+ case "java.sql.Time" =>
+ writeType(dos, "time")
+ writeTime(dos, value.asInstanceOf[Time])
+ case "[B" =>
+ writeType(dos, "raw")
+ writeBytes(dos, value.asInstanceOf[Array[Byte]])
+ // TODO: Types not handled right now include
+ // byte, char, short, float
+
+ // Handle arrays
+ case "[Ljava.lang.String;" =>
+ writeType(dos, "list")
+ writeStringArr(dos, value.asInstanceOf[Array[String]])
+ case "[I" =>
+ writeType(dos, "list")
+ writeIntArr(dos, value.asInstanceOf[Array[Int]])
+ case "[J" =>
+ writeType(dos, "list")
+ writeDoubleArr(dos, value.asInstanceOf[Array[Long]].map(_.toDouble))
+ case "[D" =>
+ writeType(dos, "list")
+ writeDoubleArr(dos, value.asInstanceOf[Array[Double]])
+ case "[Z" =>
+ writeType(dos, "list")
+ writeBooleanArr(dos, value.asInstanceOf[Array[Boolean]])
+ case "[[B" =>
+ writeType(dos, "list")
+ writeBytesArr(dos, value.asInstanceOf[Array[Array[Byte]]])
+ case otherName =>
+ // Handle array of objects
+ if (otherName.startsWith("[L")) {
+ val objArr = value.asInstanceOf[Array[Object]]
+ writeType(dos, "list")
+ writeType(dos, "jobj")
+ dos.writeInt(objArr.length)
+ objArr.foreach(o => writeJObj(dos, o))
+ } else {
+ writeType(dos, "jobj")
+ writeJObj(dos, value)
+ }
+ }
+ }
+ }
+
+ def writeInt(out: DataOutputStream, value: Int): Unit = {
+ out.writeInt(value)
+ }
+
+ def writeDouble(out: DataOutputStream, value: Double): Unit = {
+ out.writeDouble(value)
+ }
+
+ def writeBoolean(out: DataOutputStream, value: Boolean): Unit = {
+ val intValue = if (value) 1 else 0
+ out.writeInt(intValue)
+ }
+
+ def writeDate(out: DataOutputStream, value: Date): Unit = {
+ writeString(out, value.toString)
+ }
+
+ def writeTime(out: DataOutputStream, value: Time): Unit = {
+ out.writeDouble(value.getTime.toDouble / 1000.0)
+ }
+
+
+ // NOTE: Only works for ASCII right now
+ def writeString(out: DataOutputStream, value: String): Unit = {
+ val len = value.length
+ out.writeInt(len + 1) // For the \0
+ out.writeBytes(value)
+ out.writeByte(0)
+ }
+
+ def writeBytes(out: DataOutputStream, value: Array[Byte]): Unit = {
+ out.writeInt(value.length)
+ out.write(value)
+ }
+
+ def writeJObj(out: DataOutputStream, value: Object): Unit = {
+ val objId = JVMObjectTracker.put(value)
+ writeString(out, objId)
+ }
+
+ def writeIntArr(out: DataOutputStream, value: Array[Int]): Unit = {
+ writeType(out, "integer")
+ out.writeInt(value.length)
+ value.foreach(v => out.writeInt(v))
+ }
+
+ def writeDoubleArr(out: DataOutputStream, value: Array[Double]): Unit = {
+ writeType(out, "double")
+ out.writeInt(value.length)
+ value.foreach(v => out.writeDouble(v))
+ }
+
+ def writeBooleanArr(out: DataOutputStream, value: Array[Boolean]): Unit = {
+ writeType(out, "logical")
+ out.writeInt(value.length)
+ value.foreach(v => writeBoolean(out, v))
+ }
+
+ def writeStringArr(out: DataOutputStream, value: Array[String]): Unit = {
+ writeType(out, "character")
+ out.writeInt(value.length)
+ value.foreach(v => writeString(out, v))
+ }
+
+ def writeBytesArr(out: DataOutputStream, value: Array[Array[Byte]]): Unit = {
+ writeType(out, "raw")
+ out.writeInt(value.length)
+ value.foreach(v => writeBytes(out, v))
+ }
+}
+
+private[r] object SerializationFormats {
+ val BYTE = "byte"
+ val STRING = "string"
+ val ROW = "row"
+}
diff --git a/core/src/main/scala/org/apache/spark/deploy/RRunner.scala b/core/src/main/scala/org/apache/spark/deploy/RRunner.scala
new file mode 100644
index 0000000000..e99779f299
--- /dev/null
+++ b/core/src/main/scala/org/apache/spark/deploy/RRunner.scala
@@ -0,0 +1,92 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.deploy
+
+import java.io._
+import java.util.concurrent.{Semaphore, TimeUnit}
+
+import scala.collection.JavaConversions._
+
+import org.apache.hadoop.fs.Path
+
+import org.apache.spark.api.r.RBackend
+import org.apache.spark.util.RedirectThread
+
+/**
+ * Main class used to launch SparkR applications using spark-submit. It executes R as a
+ * subprocess and then has it connect back to the JVM to access system properties etc.
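+ *
+ * For example, `bin/spark-submit examples/src/main/r/pi.R` launches this class (via
+ * SparkSubmit) with the R file as its first argument.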
+ */
+object RRunner {
+ def main(args: Array[String]): Unit = {
+ val rFile = PythonRunner.formatPath(args(0))
+
+ val otherArgs = args.slice(1, args.length)
+
+ // Time to wait for SparkR backend to initialize in seconds
+ val backendTimeout = sys.env.getOrElse("SPARKR_BACKEND_TIMEOUT", "120").toInt
+ val rCommand = "Rscript"
+
+ // Check if the file path exists. If not, use just the file name, since in YARN
+ // cluster mode the file will have been distributed to the container's working directory.
+ val rF = new File(rFile)
+ val rFileNormalized = if (!rF.exists()) {
+ new Path(rFile).getName
+ } else {
+ rFile
+ }
+
+ // Launch a SparkR backend server for the R process to connect to; this will let it see our
+ // Java system properties etc.
+ val sparkRBackend = new RBackend()
+ @volatile var sparkRBackendPort = 0
+ val initialized = new Semaphore(0)
+ val sparkRBackendThread = new Thread("SparkR backend") {
+ override def run() {
+ sparkRBackendPort = sparkRBackend.init()
+ initialized.release()
+ sparkRBackend.run()
+ }
+ }
+
+ sparkRBackendThread.start()
+ // Wait for RBackend initialization to finish
+ if (initialized.tryAcquire(backendTimeout, TimeUnit.SECONDS)) {
+ // Launch R
+ val returnCode = try {
+ val builder = new ProcessBuilder(Seq(rCommand, rFileNormalized) ++ otherArgs)
+ val env = builder.environment()
+ env.put("EXISTING_SPARKR_BACKEND_PORT", sparkRBackendPort.toString)
+ val sparkHome = System.getenv("SPARK_HOME")
+ env.put("R_PROFILE_USER",
+ Seq(sparkHome, "R", "lib", "SparkR", "profile", "general.R").mkString(File.separator))
+ builder.redirectErrorStream(true) // Ugly but needed for stdout and stderr to synchronize
+ val process = builder.start()
+
+ new RedirectThread(process.getInputStream, System.out, "redirect R output").start()
+
+ process.waitFor()
+ } finally {
+ sparkRBackend.close()
+ }
+ System.exit(returnCode)
+ } else {
+ System.err.println("SparkR backend did not initialize in " + backendTimeout + " seconds")
+ System.exit(-1)
+ }
+ }
+}
diff --git a/core/src/main/scala/org/apache/spark/deploy/SparkSubmit.scala b/core/src/main/scala/org/apache/spark/deploy/SparkSubmit.scala
index 660307d19e..60bc243ebf 100644
--- a/core/src/main/scala/org/apache/spark/deploy/SparkSubmit.scala
+++ b/core/src/main/scala/org/apache/spark/deploy/SparkSubmit.scala
@@ -77,6 +77,7 @@ object SparkSubmit {
// Special primary resource names that represent shells rather than application jars.
private val SPARK_SHELL = "spark-shell"
private val PYSPARK_SHELL = "pyspark-shell"
+ private val SPARKR_SHELL = "sparkr-shell"
private val CLASS_NOT_FOUND_EXIT_STATUS = 101
@@ -284,6 +285,13 @@ object SparkSubmit {
}
}
+ // Require all R files to be local
+ if (args.isR && !isYarnCluster) {
+ if (Utils.nonLocalPaths(args.primaryResource).nonEmpty) {
+ printErrorAndExit(s"Only local R files are supported: $args.primaryResource")
+ }
+ }
+
// The following modes are not supported or applicable
(clusterManager, deployMode) match {
case (MESOS, CLUSTER) =>
@@ -291,6 +299,9 @@ object SparkSubmit {
case (STANDALONE, CLUSTER) if args.isPython =>
printErrorAndExit("Cluster deploy mode is currently not supported for python " +
"applications on standalone clusters.")
+ case (STANDALONE, CLUSTER) if args.isR =>
+ printErrorAndExit("Cluster deploy mode is currently not supported for R " +
+ "applications on standalone clusters.")
case (_, CLUSTER) if isShell(args.primaryResource) =>
printErrorAndExit("Cluster deploy mode is not applicable to Spark shells.")
case (_, CLUSTER) if isSqlShell(args.mainClass) =>
@@ -317,11 +328,32 @@ object SparkSubmit {
}
}
- // In yarn-cluster mode for a python app, add primary resource and pyFiles to files
- // that can be distributed with the job
- if (args.isPython && isYarnCluster) {
- args.files = mergeFileLists(args.files, args.primaryResource)
- args.files = mergeFileLists(args.files, args.pyFiles)
+ // If we're running an R app, set the main class to our specific R runner
+ if (args.isR && deployMode == CLIENT) {
+ if (args.primaryResource == SPARKR_SHELL) {
+ args.mainClass = "org.apache.spark.api.r.RBackend"
+ } else {
+ // If an R file is provided, add it to the child arguments and list of files to deploy.
+ // Usage: RRunner <main R file> [app arguments]
+ args.mainClass = "org.apache.spark.deploy.RRunner"
+ args.childArgs = ArrayBuffer(args.primaryResource) ++ args.childArgs
+ args.files = mergeFileLists(args.files, args.primaryResource)
+ }
+ }
+
+ if (isYarnCluster) {
+ // In yarn-cluster mode for a python app, add primary resource and pyFiles to files
+ // that can be distributed with the job
+ if (args.isPython) {
+ args.files = mergeFileLists(args.files, args.primaryResource)
+ args.files = mergeFileLists(args.files, args.pyFiles)
+ }
+
+ // In yarn-cluster mode for an R app, add primary resource to files
+ // that can be distributed with the job
+ if (args.isR) {
+ args.files = mergeFileLists(args.files, args.primaryResource)
+ }
}
// Special flag to avoid deprecation warnings at the client
@@ -405,8 +437,8 @@ object SparkSubmit {
// Add the application jar automatically so the user doesn't have to call sc.addJar
// For YARN cluster mode, the jar is already distributed on each node as "app.jar"
- // For python files, the primary resource is already distributed as a regular file
- if (!isYarnCluster && !args.isPython) {
+ // For python and R files, the primary resource is already distributed as a regular file
+ if (!isYarnCluster && !args.isPython && !args.isR) {
var jars = sysProps.get("spark.jars").map(x => x.split(",").toSeq).getOrElse(Seq.empty)
if (isUserJar(args.primaryResource)) {
jars = jars ++ Seq(args.primaryResource)
@@ -447,6 +479,10 @@ object SparkSubmit {
childArgs += ("--py-files", pyFilesNames)
}
childArgs += ("--class", "org.apache.spark.deploy.PythonRunner")
+ } else if (args.isR) {
+ val mainFile = new Path(args.primaryResource).getName
+ childArgs += ("--primary-r-file", mainFile)
+ childArgs += ("--class", "org.apache.spark.deploy.RRunner")
} else {
if (args.primaryResource != SPARK_INTERNAL) {
childArgs += ("--jar", args.primaryResource)
@@ -591,15 +627,15 @@ object SparkSubmit {
/**
* Return whether the given primary resource represents a user jar.
*/
- private def isUserJar(primaryResource: String): Boolean = {
- !isShell(primaryResource) && !isPython(primaryResource) && !isInternal(primaryResource)
+ private[deploy] def isUserJar(res: String): Boolean = {
+ !isShell(res) && !isPython(res) && !isInternal(res) && !isR(res)
}
/**
* Return whether the given primary resource represents a shell.
*/
- private[deploy] def isShell(primaryResource: String): Boolean = {
- primaryResource == SPARK_SHELL || primaryResource == PYSPARK_SHELL
+ private[deploy] def isShell(res: String): Boolean = {
+ (res == SPARK_SHELL || res == PYSPARK_SHELL || res == SPARKR_SHELL)
}
/**
@@ -619,12 +655,19 @@ object SparkSubmit {
/**
* Return whether the given primary resource requires running python.
*/
- private[deploy] def isPython(primaryResource: String): Boolean = {
- primaryResource.endsWith(".py") || primaryResource == PYSPARK_SHELL
+ private[deploy] def isPython(res: String): Boolean = {
+ res != null && res.endsWith(".py") || res == PYSPARK_SHELL
+ }
+
+ /**
+ * Return whether the given primary resource requires running R.
+ */
+ private[deploy] def isR(res: String): Boolean = {
+ res != null && res.endsWith(".R") || res == SPARKR_SHELL
}
- private[deploy] def isInternal(primaryResource: String): Boolean = {
- primaryResource == SPARK_INTERNAL
+ private[deploy] def isInternal(res: String): Boolean = {
+ res == SPARK_INTERNAL
}
/**
diff --git a/core/src/main/scala/org/apache/spark/deploy/SparkSubmitArguments.scala b/core/src/main/scala/org/apache/spark/deploy/SparkSubmitArguments.scala
index 6eb73c4347..03ecf3fd99 100644
--- a/core/src/main/scala/org/apache/spark/deploy/SparkSubmitArguments.scala
+++ b/core/src/main/scala/org/apache/spark/deploy/SparkSubmitArguments.scala
@@ -59,6 +59,7 @@ private[deploy] class SparkSubmitArguments(args: Seq[String], env: Map[String, S
var verbose: Boolean = false
var isPython: Boolean = false
var pyFiles: String = null
+ var isR: Boolean = false
var action: SparkSubmitAction = null
val sparkProperties: HashMap[String, String] = new HashMap[String, String]()
var proxyUser: String = null
@@ -158,7 +159,7 @@ private[deploy] class SparkSubmitArguments(args: Seq[String], env: Map[String, S
.getOrElse(sparkProperties.get("spark.executor.instances").orNull)
// Try to set main class from JAR if no --class argument is given
- if (mainClass == null && !isPython && primaryResource != null) {
+ if (mainClass == null && !isPython && !isR && primaryResource != null) {
val uri = new URI(primaryResource)
val uriScheme = uri.getScheme()
@@ -211,9 +212,9 @@ private[deploy] class SparkSubmitArguments(args: Seq[String], env: Map[String, S
printUsageAndExit(-1)
}
if (primaryResource == null) {
- SparkSubmit.printErrorAndExit("Must specify a primary resource (JAR or Python file)")
+ SparkSubmit.printErrorAndExit("Must specify a primary resource (JAR or Python or R file)")
}
- if (mainClass == null && !isPython) {
+ if (mainClass == null && SparkSubmit.isUserJar(primaryResource)) {
SparkSubmit.printErrorAndExit("No main class set in JAR; please specify one with --class")
}
if (pyFiles != null && !isPython) {
@@ -414,6 +415,7 @@ private[deploy] class SparkSubmitArguments(args: Seq[String], env: Map[String, S
opt
}
isPython = SparkSubmit.isPython(opt)
+ isR = SparkSubmit.isR(opt)
false
}
diff --git a/dev/run-tests b/dev/run-tests
index 561d7fc9e7..1b6cf78b5d 100755
--- a/dev/run-tests
+++ b/dev/run-tests
@@ -236,3 +236,18 @@ echo "========================================================================="
CURRENT_BLOCK=$BLOCK_PYSPARK_UNIT_TESTS
./python/run-tests
+
+echo ""
+echo "========================================================================="
+echo "Running SparkR tests"
+echo "========================================================================="
+
+CURRENT_BLOCK=$BLOCK_SPARKR_UNIT_TESTS
+
+if [ $(command -v R) ]; then
+ ./R/install-dev.sh
+ ./R/run-tests.sh
+else
+ echo "Ignoring SparkR tests as R was not found in PATH"
+fi
+
diff --git a/dev/run-tests-codes.sh b/dev/run-tests-codes.sh
index 8ab6db6925..154e01255b 100644
--- a/dev/run-tests-codes.sh
+++ b/dev/run-tests-codes.sh
@@ -25,3 +25,4 @@ readonly BLOCK_BUILD=14
readonly BLOCK_MIMA=15
readonly BLOCK_SPARK_UNIT_TESTS=16
readonly BLOCK_PYSPARK_UNIT_TESTS=17
+readonly BLOCK_SPARKR_UNIT_TESTS=18
diff --git a/dev/run-tests-jenkins b/dev/run-tests-jenkins
index f10aa6b59e..f6372835a6 100755
--- a/dev/run-tests-jenkins
+++ b/dev/run-tests-jenkins
@@ -210,6 +210,8 @@ done
failing_test="Spark unit tests"
elif [ "$test_result" -eq "$BLOCK_PYSPARK_UNIT_TESTS" ]; then
failing_test="PySpark unit tests"
+ elif [ "$test_result" -eq "$BLOCK_SPARKR_UNIT_TESTS" ]; then
+ failing_test="SparkR unit tests"
else
failing_test="some tests"
fi
diff --git a/docs/README.md b/docs/README.md
index 3773ea25c8..5852f972a0 100644
--- a/docs/README.md
+++ b/docs/README.md
@@ -58,13 +58,19 @@ phase, use the following sytax:
We use Sphinx to generate Python API docs, so you will need to install it by running
`sudo pip install sphinx`.
-## API Docs (Scaladoc and Sphinx)
+## knitr, devtools
+
+SparkR documentation is written using `roxygen2`, and we use `knitr` and `devtools` to generate
+the documentation. To install these packages, run `install.packages(c("knitr", "devtools"))` from an
+R console.
+
+## API Docs (Scaladoc, Sphinx, roxygen2)
You can build just the Spark scaladoc by running `build/sbt unidoc` from the SPARK_PROJECT_ROOT directory.
Similarly, you can build just the PySpark docs by running `make html` from the
SPARK_PROJECT_ROOT/python/docs directory. Documentation is only generated for classes that are listed as
-public in `__init__.py`.
+public in `__init__.py`. The SparkR docs can be built by running SPARK_PROJECT_ROOT/R/create-docs.sh.
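+For example, mirroring what the jekyll plugin does, `cd R && ./create-docs.sh` from the Spark root
+directory generates them, assuming R and the `knitr`/`devtools` packages above are installed.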
When you run `jekyll` in the `docs` directory, it will also copy over the scaladoc for the various
Spark subprojects into the `docs` directory (and then also into the `_site` directory). We use a
@@ -72,5 +78,5 @@ jekyll plugin to run `build/sbt unidoc` before building the site so if you haven
may take some time as it generates all of the scaladoc. The jekyll plugin also generates the
PySpark docs [Sphinx](http://sphinx-doc.org/).
-NOTE: To skip the step of building and copying over the Scala and Python API docs, run `SKIP_API=1
+NOTE: To skip the step of building and copying over the Scala, Python, and R API docs, run `SKIP_API=1
jekyll`.
diff --git a/docs/_layouts/global.html b/docs/_layouts/global.html
index 2e88b30936..b92c75f90b 100755
--- a/docs/_layouts/global.html
+++ b/docs/_layouts/global.html
@@ -84,6 +84,7 @@
<li><a href="api/scala/index.html#org.apache.spark.package">Scala</a></li>
<li><a href="api/java/index.html">Java</a></li>
<li><a href="api/python/index.html">Python</a></li>
+ <li><a href="api/R/index.html">R</a></li>
</ul>
</li>
diff --git a/docs/_plugins/copy_api_dirs.rb b/docs/_plugins/copy_api_dirs.rb
index 3c626a0b7f..0ea3f8eab4 100644
--- a/docs/_plugins/copy_api_dirs.rb
+++ b/docs/_plugins/copy_api_dirs.rb
@@ -78,5 +78,18 @@ if not (ENV['SKIP_API'] == '1' or ENV['SKIP_SCALADOC'] == '1')
puts "cp -r python/docs/_build/html/. docs/api/python"
cp_r("python/docs/_build/html/.", "docs/api/python")
- cd("..")
+ # Build SparkR API docs
+ puts "Moving to R directory and building roxygen docs."
+ cd("R")
+ puts `./create-docs.sh`
+
+ puts "Moving back into home dir."
+ cd("../")
+
+ puts "Making directory api/R"
+ mkdir_p "docs/api/R"
+
+ puts "cp -r R/pkg/html/. docs/api/R"
+ cp_r("R/pkg/html/.", "docs/api/R")
+
end
diff --git a/examples/src/main/r/kmeans.R b/examples/src/main/r/kmeans.R
new file mode 100644
index 0000000000..6e6b5cb937
--- /dev/null
+++ b/examples/src/main/r/kmeans.R
@@ -0,0 +1,93 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements. See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+library(SparkR)
+
+# K-means clustering in Spark.
+# Note: unlike the example in Scala, a point here is represented as a vector of
+# doubles.
+
+parseVectors <- function(lines) {
+ lines <- strsplit(as.character(lines) , " ", fixed = TRUE)
+ list(matrix(as.numeric(unlist(lines)), ncol = length(lines[[1]])))
+}
+
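+# Squared Euclidean distance from every point (row of P) to every center (row of C);
+# returns an nrow(P) x nrow(C) matrix.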
+dist.fun <- function(P, C) {
+ apply(
+ C,
+ 1,
+ function(x) {
+ colSums((t(P) - x)^2)
+ }
+ )
+}
+
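+# Index of the nearest center for each point.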
+closestPoint <- function(P, C) {
+ max.col(-dist.fun(P, C))
+}
+# Main program
+
+args <- commandArgs(trailing = TRUE)
+
+if (length(args) != 3) {
+ print("Usage: kmeans <file> <K> <convergeDist>")
+ q("no")
+}
+
+sc <- sparkR.init(appName = "RKMeans")
+K <- as.integer(args[[2]])
+convergeDist <- as.double(args[[3]])
+
+lines <- textFile(sc, args[[1]])
+points <- cache(lapplyPartition(lines, parseVectors))
+# kPoints <- take(points, K)
+kPoints <- do.call(rbind, takeSample(points, FALSE, K, 16189L))
+tempDist <- 1.0
+
+while (tempDist > convergeDist) {
+ closest <- lapplyPartition(
+ lapply(points,
+ function(p) {
+ cp <- closestPoint(p, kPoints);
+ mapply(list, unique(cp), split.data.frame(cbind(1, p), cp), SIMPLIFY=FALSE)
+ }),
+ function(x) {do.call(c, x)
+ })
+
+ pointStats <- reduceByKey(closest,
+ function(p1, p2) {
+ t(colSums(rbind(p1, p2)))
+ },
+ 2L)
+
+ newPoints <- do.call(
+ rbind,
+ collect(lapply(pointStats,
+ function(tup) {
+ point.sum <- tup[[2]][, -1]
+ point.count <- tup[[2]][, 1]
+ point.sum/point.count
+ })))
+
+ D <- dist.fun(kPoints, newPoints)
+ tempDist <- sum(D[cbind(1:K, max.col(-D))])
+ kPoints <- newPoints
+ cat("Finished iteration (delta = ", tempDist, ")\n")
+}
+
+cat("Final centers:\n")
+writeLines(unlist(lapply(kPoints, paste, collapse = " ")))
diff --git a/examples/src/main/r/linear_solver_mnist.R b/examples/src/main/r/linear_solver_mnist.R
new file mode 100644
index 0000000000..c864a4232d
--- /dev/null
+++ b/examples/src/main/r/linear_solver_mnist.R
@@ -0,0 +1,107 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements. See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+# Instructions: https://github.com/amplab-extras/SparkR-pkg/wiki/SparkR-Example:-Digit-Recognition-on-EC2
+
+library(SparkR)
+library(Matrix)
+
+args <- commandArgs(trailing = TRUE)
+
+# number of random features; default to 1100
+D <- ifelse(length(args) > 0, as.integer(args[[1]]), 1100)
+# number of partitions for training dataset
+trainParts <- 12
+# dimension of digits
+d <- 784
+# number of training examples
+NTrain <- 60000
+# number of test examples
+NTest <- 10000
+# scale of features
+gamma <- 4e-4
+
+sc <- sparkR.init(appName = "SparkR-LinearSolver")
+
+# You can also use an HDFS path to speed things up:
+# hdfs://<master>/train-mnist-dense-with-labels.data
+file <- textFile(sc, "/data/train-mnist-dense-with-labels.data", trainParts)
+
+W <- gamma * matrix(nrow=D, ncol=d, data=rnorm(D*d))
+b <- 2 * pi * matrix(nrow=D, ncol=1, data=runif(D))
+broadcastW <- broadcast(sc, W)
+broadcastB <- broadcast(sc, b)
+
+includePackage(sc, Matrix)
+numericLines <- lapplyPartitionsWithIndex(file,
+ function(split, part) {
+ matList <- sapply(part, function(line) {
+ as.numeric(strsplit(line, ",", fixed=TRUE)[[1]])
+ }, simplify=FALSE)
+ mat <- Matrix(ncol=d+1, data=unlist(matList, F, F),
+ sparse=T, byrow=T)
+ mat
+ })
+
+featureLabels <- cache(lapplyPartition(
+ numericLines,
+ function(part) {
+ label <- part[,1]
+ mat <- part[,-1]
+ ones <- rep(1, nrow(mat))
+ features <- cos(
+ mat %*% t(value(broadcastW)) + (matrix(ncol=1, data=ones) %*% t(value(broadcastB))))
+ onesMat <- Matrix(ones)
+ featuresPlus <- cBind(features, onesMat)
+ labels <- matrix(nrow=nrow(mat), ncol=10, data=-1)
+ for (i in 1:nrow(mat)) {
+ labels[i, label[i]] <- 1
+ }
+ list(label=labels, features=featuresPlus)
+ }))
+
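+# Accumulate the normal equations across partitions: FTF = t(F) %*% F and FTY = t(F) %*% Y,
+# where F holds the random features and Y the one-vs-all labels.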
+FTF <- Reduce("+", collect(lapplyPartition(featureLabels,
+ function(part) {
+ t(part$features) %*% part$features
+ }), flatten=F))
+
+FTY <- Reduce("+", collect(lapplyPartition(featureLabels,
+ function(part) {
+ t(part$features) %*% part$label
+ }), flatten=F))
+
+# solve for the coefficient matrix
+C <- solve(FTF, FTY)
+
+test <- Matrix(as.matrix(read.csv("/data/test-mnist-dense-with-labels.data",
+ header=F), sparse=T))
+testData <- test[,-1]
+testLabels <- matrix(ncol=1, test[,1])
+
+err <- 0
+
+# construct the feature maps for all test examples
+featuresTest <- cos(testData %*% t(value(broadcastW)) +
+ (matrix(ncol=1, data=rep(1, NTest)) %*% t(value(broadcastB))))
+featuresTest <- cBind(featuresTest, Matrix(rep(1, NTest)))
+
+# extract the one vs. all assignment
+results <- featuresTest %*% C
+labelsGot <- apply(results, 1, which.max)
+err <- sum(testLabels != labelsGot) / nrow(testLabels)
+
+cat("\nFinished running. The error rate is: ", err, ".\n")
diff --git a/examples/src/main/r/logistic_regression.R b/examples/src/main/r/logistic_regression.R
new file mode 100644
index 0000000000..2a86aa9816
--- /dev/null
+++ b/examples/src/main/r/logistic_regression.R
@@ -0,0 +1,62 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements. See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+library(SparkR)
+
+args <- commandArgs(trailing = TRUE)
+
+if (length(args) != 3) {
+ print("Usage: logistic_regression <file> <iters> <dimension>")
+ q("no")
+}
+
+# Initialize Spark context
+sc <- sparkR.init(appName = "LogisticRegressionR")
+iterations <- as.integer(args[[2]])
+D <- as.integer(args[[3]])
+
+readPartition <- function(part) {
+ part <- strsplit(part, " ", fixed = TRUE)
+ list(matrix(as.numeric(unlist(part)), ncol = length(part[[1]])))
+}
+
+# Read data points and convert each partition to a matrix
+points <- cache(lapplyPartition(textFile(sc, args[[1]]), readPartition))
+
+# Initialize w to a random value
+w <- runif(n=D, min = -1, max = 1)
+cat("Initial w: ", w, "\n")
+
+# Compute logistic regression gradient for a matrix of data points
+gradient <- function(partition) {
+ partition <- partition[[1]]
+ Y <- partition[, 1] # point labels (first column of input file)
+ X <- partition[, -1] # point coordinates
+
+ # For each point (x, y), compute gradient function
+ dot <- X %*% w
+ logit <- 1 / (1 + exp(-Y * dot))
+ grad <- t(X) %*% ((logit - 1) * Y)
+ list(grad)
+}
+
+for (i in 1:iterations) {
+ cat("On iteration ", i, "\n")
+ w <- w - reduce(lapplyPartition(points, gradient), "+")
+}
+
+cat("Final w: ", w, "\n")
diff --git a/examples/src/main/r/pi.R b/examples/src/main/r/pi.R
new file mode 100644
index 0000000000..aa7a833e14
--- /dev/null
+++ b/examples/src/main/r/pi.R
@@ -0,0 +1,46 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements. See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+library(SparkR)
+
+args <- commandArgs(trailing = TRUE)
+
+sc <- sparkR.init(appName = "PiR")
+
+slices <- ifelse(length(args) > 1, as.integer(args[[2]]), 2)
+
+n <- 100000 * slices
+
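+# Per-element sampler, kept for reference; the job below uses the vectorized piFuncVec.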
+piFunc <- function(elem) {
+ rands <- runif(n = 2, min = -1, max = 1)
+ val <- ifelse((rands[1]^2 + rands[2]^2) < 1, 1.0, 0.0)
+ val
+}
+
+
+piFuncVec <- function(elems) {
+ message(length(elems))
+ rands1 <- runif(n = length(elems), min = -1, max = 1)
+ rands2 <- runif(n = length(elems), min = -1, max = 1)
+ val <- ifelse((rands1^2 + rands2^2) < 1, 1.0, 0.0)
+ sum(val)
+}
+
+rdd <- parallelize(sc, 1:n, slices)
+count <- reduce(lapplyPartition(rdd, piFuncVec), sum)
+cat("Pi is roughly", 4.0 * count / n, "\n")
+cat("Num elements in RDD ", count(rdd), "\n")
diff --git a/examples/src/main/r/wordcount.R b/examples/src/main/r/wordcount.R
new file mode 100644
index 0000000000..b734cb0ecf
--- /dev/null
+++ b/examples/src/main/r/wordcount.R
@@ -0,0 +1,42 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements. See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+library(SparkR)
+
+args <- commandArgs(trailing = TRUE)
+
+if (length(args) != 1) {
+ print("Usage: wordcount <file>")
+ q("no")
+}
+
+# Initialize Spark context
+sc <- sparkR.init(appName = "RwordCount")
+lines <- textFile(sc, args[[1]])
+
+words <- flatMap(lines,
+ function(line) {
+ strsplit(line, " ")[[1]]
+ })
+wordCount <- lapply(words, function(word) { list(word, 1L) })
+
+counts <- reduceByKey(wordCount, "+", 2L)
+output <- collect(counts)
+
+for (wordcount in output) {
+ cat(wordcount[[1]], ": ", wordcount[[2]], "\n")
+}
diff --git a/launcher/src/main/java/org/apache/spark/launcher/CommandBuilderUtils.java b/launcher/src/main/java/org/apache/spark/launcher/CommandBuilderUtils.java
index 9b04732afe..f4ebc25bdd 100644
--- a/launcher/src/main/java/org/apache/spark/launcher/CommandBuilderUtils.java
+++ b/launcher/src/main/java/org/apache/spark/launcher/CommandBuilderUtils.java
@@ -274,14 +274,14 @@ class CommandBuilderUtils {
}
/**
- * Quotes a string so that it can be used in a command string and be parsed back into a single
- * argument by python's "shlex.split()" function.
- *
+ * Quotes a string so that it can be used in a command string.
* Basically, just add simple escapes. E.g.:
* original single argument : ab "cd" ef
* after: "ab \"cd\" ef"
+ *
+ * This can be parsed back into a single argument by python's "shlex.split()" function.
*/
- static String quoteForPython(String s) {
+ static String quoteForCommandString(String s) {
StringBuilder quoted = new StringBuilder().append('"');
for (int i = 0; i < s.length(); i++) {
int cp = s.codePointAt(i);
diff --git a/launcher/src/main/java/org/apache/spark/launcher/SparkSubmitCommandBuilder.java b/launcher/src/main/java/org/apache/spark/launcher/SparkSubmitCommandBuilder.java
index 91dcf70f10..a73c9c87e3 100644
--- a/launcher/src/main/java/org/apache/spark/launcher/SparkSubmitCommandBuilder.java
+++ b/launcher/src/main/java/org/apache/spark/launcher/SparkSubmitCommandBuilder.java
@@ -17,14 +17,9 @@
package org.apache.spark.launcher;
+import java.io.File;
import java.io.IOException;
-import java.util.ArrayList;
-import java.util.Arrays;
-import java.util.Collections;
-import java.util.HashMap;
-import java.util.List;
-import java.util.Map;
-import java.util.Properties;
+import java.util.*;
import static org.apache.spark.launcher.CommandBuilderUtils.*;
@@ -54,6 +49,20 @@ class SparkSubmitCommandBuilder extends AbstractCommandBuilder {
static final String PYSPARK_SHELL_RESOURCE = "pyspark-shell";
/**
+ * Name of the app resource used to identify the SparkR shell. The command line parser expects
+ * the resource name to be the very first argument to spark-submit in this case.
+ *
+ * NOTE: this cannot be "sparkr-shell" since that identifies the SparkR shell to SparkSubmit
+ * (see sparkR.R), and can cause this code to enter into an infinite loop.
+ */
+ static final String SPARKR_SHELL = "sparkr-shell-main";
+
+ /**
+ * This is the actual resource name that identifies the SparkR shell to SparkSubmit.
+ */
+ static final String SPARKR_SHELL_RESOURCE = "sparkr-shell";
+
+ /**
* This map must match the class names for available special classes, since this modifies the way
* command line parsing works. This maps the class name to the resource to use when calling
* spark-submit.
@@ -87,6 +96,10 @@ class SparkSubmitCommandBuilder extends AbstractCommandBuilder {
this.allowsMixedArguments = true;
appResource = PYSPARK_SHELL_RESOURCE;
submitArgs = args.subList(1, args.size());
+ } else if (args.size() > 0 && args.get(0).equals(SPARKR_SHELL)) {
+ this.allowsMixedArguments = true;
+ appResource = SPARKR_SHELL_RESOURCE;
+ submitArgs = args.subList(1, args.size());
} else {
this.allowsMixedArguments = false;
}
@@ -98,6 +111,8 @@ class SparkSubmitCommandBuilder extends AbstractCommandBuilder {
public List<String> buildCommand(Map<String, String> env) throws IOException {
if (PYSPARK_SHELL_RESOURCE.equals(appResource)) {
return buildPySparkShellCommand(env);
+ } else if (SPARKR_SHELL_RESOURCE.equals(appResource)) {
+ return buildSparkRCommand(env);
} else {
return buildSparkSubmitCommand(env);
}
@@ -213,36 +228,62 @@ class SparkSubmitCommandBuilder extends AbstractCommandBuilder {
return buildCommand(env);
}
- // When launching the pyspark shell, the spark-submit arguments should be stored in the
- // PYSPARK_SUBMIT_ARGS env variable. The executable is the PYSPARK_DRIVER_PYTHON env variable
- // set by the pyspark script, followed by PYSPARK_DRIVER_PYTHON_OPTS.
checkArgument(appArgs.isEmpty(), "pyspark does not support any application options.");
+ // When launching the pyspark shell, the spark-submit arguments should be stored in the
+ // PYSPARK_SUBMIT_ARGS env variable.
+ constructEnvVarArgs(env, "PYSPARK_SUBMIT_ARGS");
+
+ // The executable is the PYSPARK_DRIVER_PYTHON env variable set by the pyspark script,
+ // followed by PYSPARK_DRIVER_PYTHON_OPTS.
+ List<String> pyargs = new ArrayList<String>();
+ pyargs.add(firstNonEmpty(System.getenv("PYSPARK_DRIVER_PYTHON"), "python"));
+ String pyOpts = System.getenv("PYSPARK_DRIVER_PYTHON_OPTS");
+ if (!isEmpty(pyOpts)) {
+ pyargs.addAll(parseOptionString(pyOpts));
+ }
+
+ return pyargs;
+ }
+
+ private List<String> buildSparkRCommand(Map<String, String> env) throws IOException {
+ if (!appArgs.isEmpty() && appArgs.get(0).endsWith(".R")) {
+ appResource = appArgs.get(0);
+ appArgs.remove(0);
+ return buildCommand(env);
+ }
+ // When launching the SparkR shell, store the spark-submit arguments in the SPARKR_SUBMIT_ARGS
+ // env variable.
+ constructEnvVarArgs(env, "SPARKR_SUBMIT_ARGS");
+
+ // Set shell.R as R_PROFILE_USER to load the SparkR package when the shell comes up.
+ String sparkHome = System.getenv("SPARK_HOME");
+ env.put("R_PROFILE_USER",
+ join(File.separator, sparkHome, "R", "lib", "SparkR", "profile", "shell.R"));
+
+ List<String> args = new ArrayList<String>();
+ args.add(firstNonEmpty(System.getenv("SPARKR_DRIVER_R"), "R"));
+ return args;
+ }
+
+ private void constructEnvVarArgs(
+ Map<String, String> env,
+ String submitArgsEnvVariable) throws IOException {
Properties props = loadPropertiesFile();
mergeEnvPathList(env, getLibPathEnvName(),
firstNonEmptyValue(SparkLauncher.DRIVER_EXTRA_LIBRARY_PATH, conf, props));
- // Store spark-submit arguments in an environment variable, since there's no way to pass
- // them to shell.py on the comand line.
StringBuilder submitArgs = new StringBuilder();
for (String arg : buildSparkSubmitArgs()) {
if (submitArgs.length() > 0) {
submitArgs.append(" ");
}
- submitArgs.append(quoteForPython(arg));
+ submitArgs.append(quoteForCommandString(arg));
}
- env.put("PYSPARK_SUBMIT_ARGS", submitArgs.toString());
-
- List<String> pyargs = new ArrayList<String>();
- pyargs.add(firstNonEmpty(System.getenv("PYSPARK_DRIVER_PYTHON"), "python"));
- String pyOpts = System.getenv("PYSPARK_DRIVER_PYTHON_OPTS");
- if (!isEmpty(pyOpts)) {
- pyargs.addAll(parseOptionString(pyOpts));
- }
-
- return pyargs;
+ env.put(submitArgsEnvVariable, submitArgs.toString());
}
+
private boolean isClientMode(Properties userProps) {
String userMaster = firstNonEmpty(master, (String) userProps.get(SparkLauncher.SPARK_MASTER));
// Default master is "local[*]", so assume client mode in that case.
diff --git a/launcher/src/test/java/org/apache/spark/launcher/CommandBuilderUtilsSuite.java b/launcher/src/test/java/org/apache/spark/launcher/CommandBuilderUtilsSuite.java
index dba0203867..1ae42eed8a 100644
--- a/launcher/src/test/java/org/apache/spark/launcher/CommandBuilderUtilsSuite.java
+++ b/launcher/src/test/java/org/apache/spark/launcher/CommandBuilderUtilsSuite.java
@@ -79,9 +79,9 @@ public class CommandBuilderUtilsSuite {
@Test
public void testPythonArgQuoting() {
- assertEquals("\"abc\"", quoteForPython("abc"));
- assertEquals("\"a b c\"", quoteForPython("a b c"));
- assertEquals("\"a \\\"b\\\" c\"", quoteForPython("a \"b\" c"));
+ assertEquals("\"abc\"", quoteForCommandString("abc"));
+ assertEquals("\"a b c\"", quoteForCommandString("a b c"));
+ assertEquals("\"a \\\"b\\\" c\"", quoteForCommandString("a \"b\" c"));
}
private void testOpt(String opts, List<String> expected) {
diff --git a/pom.xml b/pom.xml
index 42bd926a2f..70e297c4f0 100644
--- a/pom.xml
+++ b/pom.xml
@@ -1749,5 +1749,8 @@
<profile>
<id>parquet-provided</id>
</profile>
+ <profile>
+ <id>sparkr</id>
+ </profile>
</profiles>
</project>
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/GroupedData.scala b/sql/core/src/main/scala/org/apache/spark/sql/GroupedData.scala
index a5e6b638d2..53ad67372e 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/GroupedData.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/GroupedData.scala
@@ -34,7 +34,7 @@ import org.apache.spark.sql.types.NumericType
@Experimental
class GroupedData protected[sql](df: DataFrame, groupingExprs: Seq[Expression]) {
- private[this] implicit def toDF(aggExprs: Seq[NamedExpression]): DataFrame = {
+ private[sql] implicit def toDF(aggExprs: Seq[NamedExpression]): DataFrame = {
val namedGroupingExprs = groupingExprs.map {
case expr: NamedExpression => expr
case expr: Expression => Alias(expr, expr.prettyString)()
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/api/r/SQLUtils.scala b/sql/core/src/main/scala/org/apache/spark/sql/api/r/SQLUtils.scala
new file mode 100644
index 0000000000..d1ea7cc3e9
--- /dev/null
+++ b/sql/core/src/main/scala/org/apache/spark/sql/api/r/SQLUtils.scala
@@ -0,0 +1,127 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.api.r
+
+import java.io.{ByteArrayInputStream, ByteArrayOutputStream, DataInputStream, DataOutputStream}
+
+import org.apache.spark.api.java.{JavaRDD, JavaSparkContext}
+import org.apache.spark.api.r.SerDe
+import org.apache.spark.rdd.RDD
+import org.apache.spark.sql.catalyst.expressions.{Alias, Expression, NamedExpression}
+import org.apache.spark.sql.types.{DataType, StructType}
+import org.apache.spark.sql.{Column, DataFrame, GroupedData, Row, SQLContext, SaveMode}
+
+private[r] object SQLUtils {
+ def createSQLContext(jsc: JavaSparkContext): SQLContext = {
+ new SQLContext(jsc)
+ }
+
+ def getJavaSparkContext(sqlCtx: SQLContext): JavaSparkContext = {
+ new JavaSparkContext(sqlCtx.sparkContext)
+ }
+
+ def toSeq[T](arr: Array[T]): Seq[T] = {
+ arr.toSeq
+ }
+
+ def createDF(rdd: RDD[Array[Byte]], schemaString: String, sqlContext: SQLContext): DataFrame = {
+ val schema = DataType.fromJson(schemaString).asInstanceOf[StructType]
+ val num = schema.fields.size
+ val rowRDD = rdd.map(bytesToRow)
+ sqlContext.createDataFrame(rowRDD, schema)
+ }
+
+ // A helper to include grouping columns in agg()
+ def aggWithGrouping(gd: GroupedData, exprs: Column*): DataFrame = {
+ val aggExprs = exprs.map { col =>
+ col.expr match {
+ case expr: NamedExpression => expr
+ case expr: Expression => Alias(expr, expr.simpleString)()
+ }
+ }
+ gd.toDF(aggExprs)
+ }
+
+ def dfToRowRDD(df: DataFrame): JavaRDD[Array[Byte]] = {
+ df.map(r => rowToRBytes(r))
+ }
+
+ private[this] def bytesToRow(bytes: Array[Byte]): Row = {
+ val bis = new ByteArrayInputStream(bytes)
+ val dis = new DataInputStream(bis)
+ val num = SerDe.readInt(dis)
+ Row.fromSeq((0 until num).map { i =>
+ SerDe.readObject(dis)
+ }.toSeq)
+ }
+
+ private[this] def rowToRBytes(row: Row): Array[Byte] = {
+ val bos = new ByteArrayOutputStream()
+ val dos = new DataOutputStream(bos)
+
+ SerDe.writeInt(dos, row.length)
+ (0 until row.length).map { idx =>
+ val obj: Object = row(idx).asInstanceOf[Object]
+ SerDe.writeObject(dos, obj)
+ }
+ bos.toByteArray()
+ }
+
+ def dfToCols(df: DataFrame): Array[Array[Byte]] = {
+ // localDF is Array[Row]
+ val localDF = df.collect()
+ val numCols = df.columns.length
+ // dfCols is Array[Array[Any]]
+ val dfCols = convertRowsToColumns(localDF, numCols)
+
+ dfCols.map { col =>
+ colToRBytes(col)
+ }
+ }
+
+ def convertRowsToColumns(localDF: Array[Row], numCols: Int): Array[Array[Any]] = {
+ (0 until numCols).map { colIdx =>
+ localDF.map { row =>
+ row(colIdx)
+ }
+ }.toArray
+ }
+
+ def colToRBytes(col: Array[Any]): Array[Byte] = {
+ val numRows = col.length
+ val bos = new ByteArrayOutputStream()
+ val dos = new DataOutputStream(bos)
+
+ SerDe.writeInt(dos, numRows)
+
+ col.map { item =>
+ val obj: Object = item.asInstanceOf[Object]
+ SerDe.writeObject(dos, obj)
+ }
+ bos.toByteArray()
+ }
+
+ def saveMode(mode: String): SaveMode = {
+ mode match {
+ case "append" => SaveMode.Append
+ case "overwrite" => SaveMode.Overwrite
+ case "error" => SaveMode.ErrorIfExists
+ case "ignore" => SaveMode.Ignore
+ }
+ }
+}
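Note: the serialization helpers above define the byte layout exchanged with the R side: each row (rowToRBytes/bytesToRow) or column (colToRBytes) is written as an element count followed by the individual values, each encoded through the internal SerDe. Below is a minimal, self-contained sketch of that count-then-values framing using plain java.io streams; writeValue/readValue stand in for SerDe.writeObject/SerDe.readObject and are placeholders, not the actual SerDe protocol.

import java.io.{ByteArrayInputStream, ByteArrayOutputStream, DataInputStream, DataOutputStream}

// Sketch of the framing used by rowToRBytes/colToRBytes and bytesToRow above.
// writeValue/readValue are stand-ins for SerDe's per-object encoding.
object FramingSketch {
  def encode[T](values: Seq[T], writeValue: (DataOutputStream, T) => Unit): Array[Byte] = {
    val bos = new ByteArrayOutputStream()
    val dos = new DataOutputStream(bos)
    dos.writeInt(values.length)              // element count first
    values.foreach(v => writeValue(dos, v))  // then each value in order
    dos.flush()
    bos.toByteArray
  }

  def decode[T](bytes: Array[Byte], readValue: DataInputStream => T): Seq[T] = {
    val dis = new DataInputStream(new ByteArrayInputStream(bytes))
    val n = dis.readInt()                    // read the count back
    (0 until n).map(_ => readValue(dis))     // then the values
  }
}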
diff --git a/yarn/src/main/scala/org/apache/spark/deploy/yarn/ApplicationMaster.scala b/yarn/src/main/scala/org/apache/spark/deploy/yarn/ApplicationMaster.scala
index 24a1e02795..32bc4e5663 100644
--- a/yarn/src/main/scala/org/apache/spark/deploy/yarn/ApplicationMaster.scala
+++ b/yarn/src/main/scala/org/apache/spark/deploy/yarn/ApplicationMaster.scala
@@ -469,6 +469,9 @@ private[spark] class ApplicationMaster(
System.setProperty("spark.submit.pyFiles",
PythonRunner.formatPaths(args.pyFiles).mkString(","))
}
+ if (args.primaryRFile != null && args.primaryRFile.endsWith(".R")) {
+ // TODO(davies): add R dependencies here
+ }
val mainMethod = userClassLoader.loadClass(args.userClass)
.getMethod("main", classOf[Array[String]])
diff --git a/yarn/src/main/scala/org/apache/spark/deploy/yarn/ApplicationMasterArguments.scala b/yarn/src/main/scala/org/apache/spark/deploy/yarn/ApplicationMasterArguments.scala
index e1a992af3a..ae6dc1094d 100644
--- a/yarn/src/main/scala/org/apache/spark/deploy/yarn/ApplicationMasterArguments.scala
+++ b/yarn/src/main/scala/org/apache/spark/deploy/yarn/ApplicationMasterArguments.scala
@@ -25,6 +25,7 @@ class ApplicationMasterArguments(val args: Array[String]) {
var userJar: String = null
var userClass: String = null
var primaryPyFile: String = null
+ var primaryRFile: String = null
var pyFiles: String = null
var userArgs: Seq[String] = Seq[String]()
var executorMemory = 1024
@@ -54,6 +55,10 @@ class ApplicationMasterArguments(val args: Array[String]) {
primaryPyFile = value
args = tail
+ case ("--primary-r-file") :: value :: tail =>
+ primaryRFile = value
+ args = tail
+
case ("--py-files") :: value :: tail =>
pyFiles = value
args = tail
@@ -79,6 +84,11 @@ class ApplicationMasterArguments(val args: Array[String]) {
}
}
+ if (primaryPyFile != null && primaryRFile != null) {
+ System.err.println("Cannot have primary-py-file and primary-r-file at the same time")
+ System.exit(-1)
+ }
+
userArgs = userArgsBuffer.readOnly
}
@@ -92,6 +102,7 @@ class ApplicationMasterArguments(val args: Array[String]) {
| --jar JAR_PATH Path to your application's JAR file
| --class CLASS_NAME Name of your application's main class
| --primary-py-file A main Python file
+ | --primary-r-file A main R file
| --py-files PY_FILES Comma-separated list of .zip, .egg, or .py files to
| place on the PYTHONPATH for Python apps.
| --args ARGS Arguments to be passed to your application's main class.
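Note: the new --primary-r-file option follows the same parsing idiom as the existing flags: the argument array is treated as a List and consumed with "flag :: value :: tail" patterns, advancing on the tail. A small standalone sketch of that idiom, using made-up option names rather than the real YARN arguments.

// Standalone sketch of the flag :: value :: tail parsing loop used above.
// The option names here are illustrative only.
object ArgParseSketch {
  def parse(argv: List[String]): Map[String, String] = {
    var args = argv
    var parsed = Map.empty[String, String]
    while (args.nonEmpty) {
      args match {
        case "--primary-file" :: value :: tail =>
          parsed += ("primary-file" -> value)
          args = tail
        case "--name" :: value :: tail =>
          parsed += ("name" -> value)
          args = tail
        case unknown :: _ =>
          throw new IllegalArgumentException(s"Unknown option: $unknown")
        case Nil => // unreachable: the loop guard ensures args is non-empty
      }
    }
    parsed
  }
}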
diff --git a/yarn/src/main/scala/org/apache/spark/deploy/yarn/Client.scala b/yarn/src/main/scala/org/apache/spark/deploy/yarn/Client.scala
index 7219852c0a..c1effd3c8a 100644
--- a/yarn/src/main/scala/org/apache/spark/deploy/yarn/Client.scala
+++ b/yarn/src/main/scala/org/apache/spark/deploy/yarn/Client.scala
@@ -491,6 +491,12 @@ private[spark] class Client(
} else {
Nil
}
+ val primaryRFile =
+ if (args.primaryRFile != null) {
+ Seq("--primary-r-file", args.primaryRFile)
+ } else {
+ Nil
+ }
val amClass =
if (isClusterMode) {
Class.forName("org.apache.spark.deploy.yarn.ApplicationMaster").getName
@@ -500,12 +506,15 @@ private[spark] class Client(
if (args.primaryPyFile != null && args.primaryPyFile.endsWith(".py")) {
args.userArgs = ArrayBuffer(args.primaryPyFile, args.pyFiles) ++ args.userArgs
}
+ if (args.primaryRFile != null && args.primaryRFile.endsWith(".R")) {
+ args.userArgs = ArrayBuffer(args.primaryRFile) ++ args.userArgs
+ }
val userArgs = args.userArgs.flatMap { arg =>
Seq("--arg", YarnSparkHadoopUtil.escapeForShell(arg))
}
val amArgs =
- Seq(amClass) ++ userClass ++ userJar ++ primaryPyFile ++ pyFiles ++ userArgs ++
- Seq(
+ Seq(amClass) ++ userClass ++ userJar ++ primaryPyFile ++ pyFiles ++ primaryRFile ++
+ userArgs ++ Seq(
"--executor-memory", args.executorMemory.toString + "m",
"--executor-cores", args.executorCores.toString,
"--num-executors ", args.numExecutors.toString)
diff --git a/yarn/src/main/scala/org/apache/spark/deploy/yarn/ClientArguments.scala b/yarn/src/main/scala/org/apache/spark/deploy/yarn/ClientArguments.scala
index 3bc7eb1abf..da6798cb1b 100644
--- a/yarn/src/main/scala/org/apache/spark/deploy/yarn/ClientArguments.scala
+++ b/yarn/src/main/scala/org/apache/spark/deploy/yarn/ClientArguments.scala
@@ -32,6 +32,7 @@ private[spark] class ClientArguments(args: Array[String], sparkConf: SparkConf)
var userClass: String = null
var pyFiles: String = null
var primaryPyFile: String = null
+ var primaryRFile: String = null
var userArgs: ArrayBuffer[String] = new ArrayBuffer[String]()
var executorMemory = 1024 // MB
var executorCores = 1
@@ -150,6 +151,10 @@ private[spark] class ClientArguments(args: Array[String], sparkConf: SparkConf)
primaryPyFile = value
args = tail
+ case ("--primary-r-file") :: value :: tail =>
+ primaryRFile = value
+ args = tail
+
case ("--args" | "--arg") :: value :: tail =>
if (args(0) == "--args") {
println("--args is deprecated. Use --arg instead.")
@@ -228,6 +233,11 @@ private[spark] class ClientArguments(args: Array[String], sparkConf: SparkConf)
throw new IllegalArgumentException(getUsageMessage(args))
}
}
+
+ if (primaryPyFile != null && primaryRFile != null) {
+ throw new IllegalArgumentException("Cannot have primary-py-file and primary-r-file" +
+ " at the same time")
+ }
}
private def getUsageMessage(unknownParam: List[String] = null): String = {
@@ -240,6 +250,7 @@ private[spark] class ClientArguments(args: Array[String], sparkConf: SparkConf)
| mode)
| --class CLASS_NAME Name of your application's main class (required)
| --primary-py-file A main Python file
+ | --primary-r-file A main R file
| --arg ARG Argument to be passed to your application's main class.
| Multiple invocations are possible, each will be passed in order.
| --num-executors NUM Number of executors to start (Default: 2)