author     Matei Zaharia <matei@eecs.berkeley.edu>    2013-01-20 01:57:44 -0800
committer  Matei Zaharia <matei@eecs.berkeley.edu>    2013-01-20 01:57:44 -0800
commit     8e7f098a2c9e5e85cb9435f28d53a3a5847c14aa (patch)
tree       a4b5d1891501e99f44b4d797bee3a10504e0b2fd /core
parent     54c0f9f185576e9b844fa8f81ca410f188daa51c (diff)
Added accumulators to PySpark
Diffstat (limited to 'core')
-rw-r--r--  core/src/main/scala/spark/api/python/PythonRDD.scala | 83
1 file changed, 67 insertions, 16 deletions
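
The commit threads a new Accumulator[JList[Array[Byte]]] through PythonRDD and forwards accumulator updates that the Python worker emits after its regular output. As a rough illustration of the byte stream the read() loop added to PythonRDD.scala below expects, here is a minimal writer-side sketch; it is not part of this commit, and WorkerOutputSketch/writeWorkerOutput are hypothetical names used only for illustration.

import java.io.{DataOutputStream, OutputStream}

object WorkerOutputSketch {
  // Hedged sketch: emits length-prefixed result records, a -1 sentinel ending the
  // data section, then length-prefixed accumulator updates, as the reader expects.
  def writeWorkerOutput(
      rawOut: OutputStream,
      results: Seq[Array[Byte]],
      accumUpdates: Seq[Array[Byte]]): Unit = {
    val out = new DataOutputStream(rawOut)
    for (r <- results) {        // data section: one length-prefixed record per element
      out.writeInt(r.length)
      out.write(r)
    }
    out.writeInt(-1)            // sentinel: end of the data section
    for (u <- accumUpdates) {   // accumulator updates, also length-prefixed
      out.writeInt(u.length)
      out.write(u)
    }
    out.flush()                 // closing the stream afterwards gives the reader its EOF
  }
}
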
diff --git a/core/src/main/scala/spark/api/python/PythonRDD.scala b/core/src/main/scala/spark/api/python/PythonRDD.scala
index f431ef28d3..fb13e84658 100644
--- a/core/src/main/scala/spark/api/python/PythonRDD.scala
+++ b/core/src/main/scala/spark/api/python/PythonRDD.scala
@@ -1,7 +1,8 @@
package spark.api.python
import java.io._
-import java.util.{List => JList}
+import java.net._
+import java.util.{List => JList, ArrayList => JArrayList, Collections}
import scala.collection.JavaConversions._
import scala.io.Source
@@ -10,25 +11,26 @@ import spark.api.java.{JavaSparkContext, JavaPairRDD, JavaRDD}
import spark.broadcast.Broadcast
import spark._
import spark.rdd.PipedRDD
-import java.util
private[spark] class PythonRDD[T: ClassManifest](
-  parent: RDD[T],
-  command: Seq[String],
-  envVars: java.util.Map[String, String],
-  preservePartitoning: Boolean,
-  pythonExec: String,
-  broadcastVars: java.util.List[Broadcast[Array[Byte]]])
+    parent: RDD[T],
+    command: Seq[String],
+    envVars: java.util.Map[String, String],
+    preservePartitoning: Boolean,
+    pythonExec: String,
+    broadcastVars: JList[Broadcast[Array[Byte]]],
+    accumulator: Accumulator[JList[Array[Byte]]])
  extends RDD[Array[Byte]](parent.context) {
  // Similar to Runtime.exec(), if we are given a single string, split it into words
  // using a standard StringTokenizer (i.e. by spaces)
  def this(parent: RDD[T], command: String, envVars: java.util.Map[String, String],
-    preservePartitoning: Boolean, pythonExec: String,
-    broadcastVars: java.util.List[Broadcast[Array[Byte]]]) =
+      preservePartitoning: Boolean, pythonExec: String,
+      broadcastVars: JList[Broadcast[Array[Byte]]],
+      accumulator: Accumulator[JList[Array[Byte]]]) =
    this(parent, PipedRDD.tokenize(command), envVars, preservePartitoning, pythonExec,
-      broadcastVars)
+      broadcastVars, accumulator)
  override def splits = parent.splits
@@ -93,18 +95,30 @@ private[spark] class PythonRDD[T: ClassManifest](
    // Return an iterator that read lines from the process's stdout
    val stream = new DataInputStream(proc.getInputStream)
    return new Iterator[Array[Byte]] {
-      def next() = {
+      def next(): Array[Byte] = {
        val obj = _nextObj
        _nextObj = read()
        obj
      }
-      private def read() = {
+      private def read(): Array[Byte] = {
        try {
          val length = stream.readInt()
-          val obj = new Array[Byte](length)
-          stream.readFully(obj)
-          obj
+          if (length != -1) {
+            val obj = new Array[Byte](length)
+            stream.readFully(obj)
+            obj
+          } else {
+            // We've finished the data section of the output, but we can still read some
+            // accumulator updates; let's do that, breaking when we get EOFException
+            while (true) {
+              val len2 = stream.readInt()
+              val update = new Array[Byte](len2)
+              stream.readFully(update)
+              accumulator += Collections.singletonList(update)
+            }
+            new Array[Byte](0)
+          }
        } catch {
          case eof: EOFException => {
            val exitStatus = proc.waitFor()
@@ -246,3 +260,40 @@ private class ExtractValue extends spark.api.java.function.Function[(Array[Byte]
private class BytesToString extends spark.api.java.function.Function[Array[Byte], String] {
  override def call(arr: Array[Byte]) : String = new String(arr, "UTF-8")
}
+
+/**
+ * Internal class that acts as an `AccumulatorParam` for Python accumulators. Inside, it
+ * collects a list of pickled strings that we pass to Python through a socket.
+ */
+class PythonAccumulatorParam(@transient serverHost: String, serverPort: Int)
+  extends AccumulatorParam[JList[Array[Byte]]] {
+
+  override def zero(value: JList[Array[Byte]]): JList[Array[Byte]] = new JArrayList
+
+  override def addInPlace(val1: JList[Array[Byte]], val2: JList[Array[Byte]])
+      : JList[Array[Byte]] = {
+    if (serverHost == null) {
+      // This happens on the worker node, where we just want to remember all the updates
+      val1.addAll(val2)
+      val1
+    } else {
+      // This happens on the master, where we pass the updates to Python through a socket
+      val socket = new Socket(serverHost, serverPort)
+      val in = socket.getInputStream
+      val out = new DataOutputStream(socket.getOutputStream)
+      out.writeInt(val2.size)
+      for (array <- val2) {
+        out.writeInt(array.length)
+        out.write(array)
+      }
+      out.flush()
+      // Wait for a byte from the Python side as an acknowledgement
+      val byteRead = in.read()
+      if (byteRead == -1) {
+        throw new SparkException("EOF reached before Python server acknowledged")
+      }
+      socket.close()
+      null
+    }
+  }
+}
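
On the master, addInPlace above opens a plain TCP connection to serverHost:serverPort, writes the number of updates, then each update as a length-prefixed byte array, and waits for a single acknowledgement byte. The real receiving end is implemented on the Python side of PySpark; the following JVM-side sketch is illustrative only (PythonAccumulatorServerSketch and serveOnce are hypothetical names) and simply mirrors the handshake the code above performs.

import java.io.{DataInputStream, DataOutputStream}
import java.net.ServerSocket

object PythonAccumulatorServerSketch {
  // Hedged sketch: accept one connection, read the batch of accumulator updates,
  // and acknowledge it with a single byte, matching addInPlace's expectations.
  def serveOnce(server: ServerSocket): Seq[Array[Byte]] = {
    val socket = server.accept()
    val in = new DataInputStream(socket.getInputStream)
    val out = new DataOutputStream(socket.getOutputStream)
    val numUpdates = in.readInt()            // matches out.writeInt(val2.size)
    val updates = (0 until numUpdates).map { _ =>
      val length = in.readInt()              // matches out.writeInt(array.length)
      val bytes = new Array[Byte](length)
      in.readFully(bytes)                    // matches out.write(array)
      bytes
    }
    out.write(1)                             // the acknowledgement byte addInPlace waits for
    out.flush()
    socket.close()
    updates
  }
}

Calling serveOnce(new ServerSocket(port)) would collect one batch of updates and acknowledge it, after which addInPlace closes its socket and returns.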