Diffstat (limited to 'R/pkg/inst/worker/worker.R')
-rw-r--r--  R/pkg/inst/worker/worker.R | 128
1 file changed, 128 insertions, 0 deletions
diff --git a/R/pkg/inst/worker/worker.R b/R/pkg/inst/worker/worker.R
new file mode 100644
index 0000000000..c6542928e8
--- /dev/null
+++ b/R/pkg/inst/worker/worker.R
@@ -0,0 +1,128 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements. See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+# Worker class
+
+rLibDir <- Sys.getenv("SPARKR_RLIBDIR")
+# Set libPaths to include SparkR package as loadNamespace needs this
+# TODO: Figure out if we can avoid this by not loading any objects that require
+# SparkR namespace
+.libPaths(c(rLibDir, .libPaths()))
+suppressPackageStartupMessages(library(SparkR))
+
+port <- as.integer(Sys.getenv("SPARKR_WORKER_PORT"))
+inputCon <- socketConnection(port = port, blocking = TRUE, open = "rb")
+outputCon <- socketConnection(port = port, blocking = TRUE, open = "wb")
+
+# Read the index of the current partition inside the RDD
+partition <- SparkR:::readInt(inputCon)
+
+deserializer <- SparkR:::readString(inputCon)
+serializer <- SparkR:::readString(inputCon)
+
+# Include packages as required
+packageNames <- unserialize(SparkR:::readRaw(inputCon))
+for (pkg in packageNames) {
+  suppressPackageStartupMessages(require(as.character(pkg), character.only = TRUE))
+}
+
+# Read the serialized compute function along with its dependencies
+funcLen <- SparkR:::readInt(inputCon)
+computeFunc <- unserialize(SparkR:::readRawLen(inputCon, funcLen))
+env <- environment(computeFunc)
+parent.env(env) <- .GlobalEnv # Attach under global environment.
+
+# Read and set broadcast variables
+numBroadcastVars <- SparkR:::readInt(inputCon)
+if (numBroadcastVars > 0) {
+  for (bcast in seq(1:numBroadcastVars)) {
+    bcastId <- SparkR:::readInt(inputCon)
+    value <- unserialize(SparkR:::readRaw(inputCon))
+    setBroadcastValue(bcastId, value)
+  }
+}
+
+# If -1, read as a normal RDD; if >= 0, treat as a pairwise RDD and use the
+# value as the number of partitions to create.
+numPartitions <- SparkR:::readInt(inputCon)
+
+isEmpty <- SparkR:::readInt(inputCon)
+
+if (isEmpty != 0) {
+
+  if (numPartitions == -1) {
+    if (deserializer == "byte") {
+      # Read the input data as a list of deserialized R objects
+      data <- SparkR:::readDeserialize(inputCon)
+    } else if (deserializer == "string") {
+      data <- as.list(readLines(inputCon))
+    } else if (deserializer == "row") {
+      data <- SparkR:::readDeserializeRows(inputCon)
+    }
+    output <- computeFunc(partition, data)
+    if (serializer == "byte") {
+      SparkR:::writeRawSerialize(outputCon, output)
+    } else if (serializer == "row") {
+      SparkR:::writeRowSerialize(outputCon, output)
+    } else {
+      SparkR:::writeStrings(outputCon, output)
+    }
+  } else {
+    if (deserializer == "byte") {
+      # Read the input data as a list of deserialized R objects
+      data <- SparkR:::readDeserialize(inputCon)
+    } else if (deserializer == "string") {
+      data <- readLines(inputCon)
+    } else if (deserializer == "row") {
+      data <- SparkR:::readDeserializeRows(inputCon)
+    }
+
+    res <- new.env()
+
+    # Step 1: hash the data to an environment
+    hashTupleToEnvir <- function(tuple) {
+      # NOTE: computeFunc is the hash function here
+      hashVal <- computeFunc(tuple[[1]])
+      bucket <- as.character(hashVal %% numPartitions)
+      acc <- res[[bucket]]
+      # Create a new accumulator
+      if (is.null(acc)) {
+        acc <- SparkR:::initAccumulator()
+      }
+      SparkR:::addItemToAccumulator(acc, tuple)
+      res[[bucket]] <- acc
+    }
+    invisible(lapply(data, hashTupleToEnvir))
+
+    # Step 2: write out all of the environment as key-value pairs.
+    for (name in ls(res)) {
+      SparkR:::writeInt(outputCon, 2L)
+      SparkR:::writeInt(outputCon, as.integer(name))
+      # Truncate the accumulator list to the number of elements we have
+      length(res[[name]]$data) <- res[[name]]$counter
+      SparkR:::writeRawSerialize(outputCon, res[[name]]$data)
+    }
+  }
+}
+
+# Mark the end of the output stream
+if (serializer %in% c("byte", "row")) {
+  SparkR:::writeInt(outputCon, 0L)
+}
+}
+
+close(outputCon)
+close(inputCon)
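
For reference, the pairwise branch above (numPartitions >= 0) hashes each (key, value) tuple into a bucket given by hashVal %% numPartitions, accumulates the tuples per bucket in an environment, and then writes each bucket to outputCon as a key-value pair. The following is a minimal standalone sketch of that bucketing pattern in base R only; hashKey and bucketTuples are hypothetical names introduced for illustration, a plain list stands in for SparkR's internal initAccumulator/addItemToAccumulator helpers, and a simple character-code sum replaces the deserialized computeFunc.

# Standalone sketch of the hash-to-buckets pattern from the pairwise branch
# (illustrative only; hashKey and bucketTuples are not part of SparkR)
hashKey <- function(key) {
  # Toy hash: sum of the UTF-8 code points of the key's string form
  sum(utf8ToInt(as.character(key)))
}

bucketTuples <- function(tuples, numPartitions) {
  res <- new.env()
  for (tuple in tuples) {
    bucket <- as.character(hashKey(tuple[[1]]) %% numPartitions)
    # A plain list plays the role of the accumulator used in worker.R
    existing <- if (exists(bucket, envir = res, inherits = FALSE)) res[[bucket]] else list()
    res[[bucket]] <- c(existing, list(tuple))
  }
  # Return a named list mapping each bucket id to its list of tuples,
  # mirroring the walk over ls(res) in Step 2 above
  mget(ls(res), envir = res)
}

# Example: three key-value tuples spread over 2 partitions
bucketTuples(list(list("a", 1L), list("b", 2L), list("a", 3L)), 2L)

In worker.R itself each bucket is not returned but streamed to outputCon: a 2L marker, the bucket id as an integer, and the truncated accumulator contents serialized with SparkR:::writeRawSerialize.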