aboutsummaryrefslogtreecommitdiff
path: root/R/pkg/inst/worker/worker.R
blob: c6542928e8dddfdd58c78eafe48c9d54d71a8228 (plain) (blame)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
#
# Licensed to the Apache Software Foundation (ASF) under one or more
# contributor license agreements.  See the NOTICE file distributed with
# this work for additional information regarding copyright ownership.
# The ASF licenses this file to You under the Apache License, Version 2.0
# (the "License"); you may not use this file except in compliance with
# the License.  You may obtain a copy of the License at
#
#    http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#

# Worker script: loads SparkR, connects to the port given in
# SPARKR_WORKER_PORT, reads a serialized task (packages, function, broadcast
# variables, partition data) from the socket, runs it, and writes the result
# back.

# Directory to add to the library search path, taken from the SPARKR_RLIBDIR
# environment variable.
rLibDir <- Sys.getenv("SPARKR_RLIBDIR")
# Set libPaths to include SparkR package as loadNamespace needs this
# TODO: Figure out if we can avoid this by not loading any objects that require
# SparkR namespace
.libPaths(c(rLibDir, .libPaths()))
suppressPackageStartupMessages(library(SparkR))

# Open two blocking, binary-mode connections to the port named by the
# SPARKR_WORKER_PORT environment variable: inputCon is opened for reading
# ("rb") and outputCon for writing ("wb"). No host argument is passed, so
# socketConnection's default host applies.
# NOTE(review): as.integer() yields NA when the variable is unset or not
# numeric; no explicit check is made here.
port <- as.integer(Sys.getenv("SPARKR_WORKER_PORT"))
inputCon <- socketConnection(port = port, blocking = TRUE, open = "rb")
outputCon <- socketConnection(port = port, blocking = TRUE, open = "wb")

# read the index of the current partition inside the RDD
partition <- SparkR:::readInt(inputCon)

# Names of the input and output formats, read as two consecutive strings.
# The branches further down recognise "byte", "string" and "row" as input
# formats and "byte" and "row" as output formats.
deserializer <- SparkR:::readString(inputCon)
serializer <- SparkR:::readString(inputCon)

# Include packages as required
# The package list arrives as a serialized R object. library() is used rather
# than require() so that a package missing on this worker stops the task
# immediately with a clear error; require() would only return FALSE and let
# the failure surface later inside the user function.
packageNames <- unserialize(SparkR:::readRaw(inputCon))
for (pkg in packageNames) {
  suppressPackageStartupMessages(
    library(as.character(pkg), character.only = TRUE)
  )
}

# read function dependencies
# The function to run is sent as an integer byte count followed by that many
# bytes, which are unserialized into an R closure.
funcLen <- SparkR:::readInt(inputCon)
computeFunc <- unserialize(SparkR:::readRawLen(inputCon, funcLen))
# Re-parent the closure's environment onto the global environment, so names
# not found in the closure are looked up in this script's globals and then in
# the attached packages.
# NOTE(review): assumes environment(computeFunc) is not .GlobalEnv itself —
# confirm against the code that serializes the closure.
env <- environment(computeFunc)
parent.env(env) <- .GlobalEnv  # Attach under global environment.

# Read and set broadcast variables
# The stream carries a count followed by that many (integer id, serialized
# value) pairs; each pair is registered through SparkR's setBroadcastValue.
numBroadcastVars <- SparkR:::readInt(inputCon)
if (numBroadcastVars > 0) {
  for (bcast in seq_len(numBroadcastVars)) {
    bcastId <- SparkR:::readInt(inputCon)
    value <- unserialize(SparkR:::readRaw(inputCon))
    # Namespace-qualified like the other SparkR helpers in this script, so the
    # call resolves whether or not setBroadcastValue is exported.
    SparkR:::setBroadcastValue(bcastId, value)
  }
}

# If -1: read as normal RDD; if >= 0, treat as pairwise RDD and treat the int
# as number of partitions to create.
numPartitions <- SparkR:::readInt(inputCon)

# Flag sent ahead of the partition data; the data is read and processed only
# when the flag is non-zero.
# NOTE(review): despite the name, non-zero appears to mean "there is data" —
# confirm against the sender.
isEmpty <- SparkR:::readInt(inputCon)

if (isEmpty != 0) {

  if (numPartitions == -1) {
    # Plain RDD: read the whole partition, apply the function once, and write
    # the result in the requested output format.
    if (deserializer == "byte") {
      data <- SparkR:::readDeserialize(inputCon)
    } else if (deserializer == "string") {
      # One list element per input line.
      data <- as.list(readLines(inputCon))
    } else if (deserializer == "row") {
      data <- SparkR:::readDeserializeRows(inputCon)
    } else {
      # Without this branch `data` would stay unassigned and the name would
      # resolve to the utils::data function, which would then be handed to
      # computeFunc as if it were the partition contents.
      stop("Unsupported deserializer: ", deserializer, call. = FALSE)
    }
    output <- computeFunc(partition, data)
    if (serializer == "byte") {
      SparkR:::writeRawSerialize(outputCon, output)
    } else if (serializer == "row") {
      SparkR:::writeRowSerialize(outputCon, output)
    } else {
      # Any other serializer name falls back to string output.
      SparkR:::writeStrings(outputCon, output)
    }
  } else {
    # Pairwise RDD: read the tuples, bucket them by hashed key into
    # numPartitions groups, and write each non-empty bucket out.
    if (deserializer == "byte") {
      data <- SparkR:::readDeserialize(inputCon)
    } else if (deserializer == "string") {
      data <- readLines(inputCon)
    } else if (deserializer == "row") {
      data <- SparkR:::readDeserializeRows(inputCon)
    } else {
      # See the matching branch above: fail loudly instead of iterating over
      # the utils::data function.
      stop("Unsupported deserializer: ", deserializer, call. = FALSE)
    }

    # Maps bucket id (as a string) to the accumulator holding its tuples.
    res <- new.env()

    # Step 1: hash the data to an environment
    # The first element of each tuple is hashed with computeFunc, the hash is
    # reduced modulo numPartitions to choose a bucket, and the whole tuple is
    # added to that bucket's accumulator.
    hashTupleToEnvir <- function(tuple) {
      # NOTE: computeFunc is the hash function here
      hashVal <- computeFunc(tuple[[1]])
      bucket <- as.character(hashVal %% numPartitions)
      acc <- res[[bucket]]
      # Create a new accumulator
      if (is.null(acc)) {
        acc <- SparkR:::initAccumulator()
      }
      # The return value is ignored, so the accumulator is presumably updated
      # in place — TODO confirm.
      SparkR:::addItemToAccumulator(acc, tuple)
      res[[bucket]] <- acc
    }
    invisible(lapply(data, hashTupleToEnvir))

    # Step 2: write out all of the environment as key-value pairs.
    for (name in ls(res)) {
      # NOTE(review): the leading 2L presumably announces a two-element
      # (bucket id, tuples) record — confirm against the reader.
      SparkR:::writeInt(outputCon, 2L)
      SparkR:::writeInt(outputCon, as.integer(name))
      # Truncate the accumulator list to the number of elements we have
      length(res[[name]]$data) <- res[[name]]$counter
      SparkR:::writeRawSerialize(outputCon, res[[name]]$data)
    }
  }
}

# End of output
# For the "byte" and "row" output formats a trailing integer 0 is written
# after the results; other formats get no trailing value here.
if (serializer %in% c("byte", "row")) {
  SparkR:::writeInt(outputCon, 0L)
}

# Close the write side first, then the read side.
close(outputCon)
close(inputCon)