author     Matei Zaharia <matei@eecs.berkeley.edu>    2013-01-20 12:47:55 -0800
committer  Matei Zaharia <matei@eecs.berkeley.edu>    2013-01-20 12:47:55 -0800
commit     86057ec7c868262763d1e31b3f3c94bd43eeafb3
tree       cdde5e4264549bd10b67e4da73322391b922e14e /core
parent     76ff962edcb7f41601c6c2d4fc6714bbc885faa7
parent     9f54d7e1f5a5e6f80b3d710de67f800bef943d33
Merge branch 'master' into streaming
Conflicts:
	core/src/main/scala/spark/api/python/PythonRDD.scala
Diffstat (limited to 'core')
-rw-r--r--  core/src/main/scala/spark/Utils.scala                                11
-rw-r--r--  core/src/main/scala/spark/api/python/PythonRDD.scala                 83
-rw-r--r--  core/src/main/scala/spark/broadcast/HttpBroadcast.scala               2
-rw-r--r--  core/src/main/scala/spark/executor/StandaloneExecutorBackend.scala    3
4 files changed, 78 insertions, 21 deletions
diff --git a/core/src/main/scala/spark/Utils.scala b/core/src/main/scala/spark/Utils.scala
index d08921b25f..b3421df27c 100644
--- a/core/src/main/scala/spark/Utils.scala
+++ b/core/src/main/scala/spark/Utils.scala
@@ -134,7 +134,7 @@ private object Utils extends Logging {
*/
def fetchFile(url: String, targetDir: File) {
val filename = url.split("/").last
- val tempDir = System.getProperty("spark.local.dir", System.getProperty("java.io.tmpdir"))
+ val tempDir = getLocalDir
val tempFile = File.createTempFile("fetchFileTemp", null, new File(tempDir))
val targetFile = new File(targetDir, filename)
val uri = new URI(url)
@@ -205,6 +205,15 @@ private object Utils extends Logging {
}
/**
+ * Get a temporary directory using Spark's spark.local.dir property, if set. This will always
+ * return a single directory, even though the spark.local.dir property might be a list of
+ * multiple paths.
+ */
+ def getLocalDir: String = {
+ System.getProperty("spark.local.dir", System.getProperty("java.io.tmpdir")).split(',')(0)
+ }
+
+ /**
* Shuffle the elements of a collection into a random order, returning the
* result in a new collection. Unlike scala.util.Random.shuffle, this method
* uses a local random number generator, avoiding inter-thread contention.
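Note on the change above: `getLocalDir` returns only the first entry when `spark.local.dir` is a comma-separated list of paths, falling back to `java.io.tmpdir` when the property is unset. A minimal, self-contained sketch of the same selection logic (the object name and `main` wrapper are illustrative only, not part of the Spark sources):

```scala
object LocalDirExample {
  // Same selection logic as the new Utils.getLocalDir: prefer spark.local.dir,
  // fall back to java.io.tmpdir, and keep only the first comma-separated path.
  def firstLocalDir(): String =
    System.getProperty("spark.local.dir", System.getProperty("java.io.tmpdir")).split(',')(0)

  def main(args: Array[String]): Unit = {
    System.setProperty("spark.local.dir", "/mnt/disk1/spark,/mnt/disk2/spark")
    println(firstLocalDir())  // prints /mnt/disk1/spark
  }
}
```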
diff --git a/core/src/main/scala/spark/api/python/PythonRDD.scala b/core/src/main/scala/spark/api/python/PythonRDD.scala
index 0138b22d38..89f7c316dc 100644
--- a/core/src/main/scala/spark/api/python/PythonRDD.scala
+++ b/core/src/main/scala/spark/api/python/PythonRDD.scala
@@ -1,7 +1,8 @@
package spark.api.python
import java.io._
-import java.util.{List => JList}
+import java.net._
+import java.util.{List => JList, ArrayList => JArrayList, Collections}
import scala.collection.JavaConversions._
import scala.io.Source
@@ -10,25 +11,26 @@ import spark.api.java.{JavaSparkContext, JavaPairRDD, JavaRDD}
import spark.broadcast.Broadcast
import spark._
import spark.rdd.PipedRDD
-import java.util
private[spark] class PythonRDD[T: ClassManifest](
- parent: RDD[T],
- command: Seq[String],
- envVars: java.util.Map[String, String],
- preservePartitoning: Boolean,
- pythonExec: String,
- broadcastVars: java.util.List[Broadcast[Array[Byte]]])
+ parent: RDD[T],
+ command: Seq[String],
+ envVars: java.util.Map[String, String],
+ preservePartitoning: Boolean,
+ pythonExec: String,
+ broadcastVars: JList[Broadcast[Array[Byte]]],
+ accumulator: Accumulator[JList[Array[Byte]]])
extends RDD[Array[Byte]](parent) {
// Similar to Runtime.exec(), if we are given a single string, split it into words
// using a standard StringTokenizer (i.e. by spaces)
def this(parent: RDD[T], command: String, envVars: java.util.Map[String, String],
- preservePartitoning: Boolean, pythonExec: String,
- broadcastVars: java.util.List[Broadcast[Array[Byte]]]) =
+ preservePartitoning: Boolean, pythonExec: String,
+ broadcastVars: JList[Broadcast[Array[Byte]]],
+ accumulator: Accumulator[JList[Array[Byte]]]) =
this(parent, PipedRDD.tokenize(command), envVars, preservePartitoning, pythonExec,
- broadcastVars)
+ broadcastVars, accumulator)
override def getSplits = parent.splits
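The constructor changes above thread a shared `Accumulator[JList[Array[Byte]]]` through both PythonRDD constructors so that tasks can report pickled accumulator updates back to the driver; the auxiliary constructor still tokenizes a single command string before delegating to the primary one. A minimal sketch of that delegation pattern, with illustrative names that are not part of the Spark sources:

```scala
// Illustrative only: an auxiliary constructor that tokenizes a command string
// and delegates to the primary constructor, passing extra state through unchanged.
class CommandSpec(command: Seq[String], updates: java.util.List[Array[Byte]]) {
  def this(command: String, updates: java.util.List[Array[Byte]]) =
    this(command.split("\\s+").toSeq, updates)
}
```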
@@ -91,18 +93,30 @@ private[spark] class PythonRDD[T: ClassManifest](
// Return an iterator that read lines from the process's stdout
val stream = new DataInputStream(proc.getInputStream)
return new Iterator[Array[Byte]] {
- def next() = {
+ def next(): Array[Byte] = {
val obj = _nextObj
_nextObj = read()
obj
}
- private def read() = {
+ private def read(): Array[Byte] = {
try {
val length = stream.readInt()
- val obj = new Array[Byte](length)
- stream.readFully(obj)
- obj
+ if (length != -1) {
+ val obj = new Array[Byte](length)
+ stream.readFully(obj)
+ obj
+ } else {
+ // We've finished the data section of the output, but we can still read some
+ // accumulator updates; let's do that, breaking when we get EOFException
+ while (true) {
+ val len2 = stream.readInt()
+ val update = new Array[Byte](len2)
+ stream.readFully(update)
+ accumulator += Collections.singletonList(update)
+ }
+ new Array[Byte](0)
+ }
} catch {
case eof: EOFException => {
val exitStatus = proc.waitFor()
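The new `read()` above implements a simple framing protocol on the worker's stdout: each result record is a 4-byte length followed by that many bytes, a length of -1 marks the end of the data section, and any length-prefixed blocks after the marker are accumulator updates, read until the stream ends with `EOFException`. A self-contained sketch of that framing, reading from an in-memory stream instead of a subprocess (the `FramingExample` object and `readAll` helper are illustrative, not Spark code):

```scala
import java.io.{ByteArrayInputStream, ByteArrayOutputStream, DataInputStream, DataOutputStream, EOFException}
import scala.collection.mutable.ArrayBuffer

object FramingExample {
  // Read length-prefixed records until a -1 marker, then treat any further
  // length-prefixed blocks as accumulator updates, stopping at end of stream.
  def readAll(in: DataInputStream): (Seq[Array[Byte]], Seq[Array[Byte]]) = {
    val records = ArrayBuffer[Array[Byte]]()
    val updates = ArrayBuffer[Array[Byte]]()
    var inDataSection = true
    try {
      while (true) {
        val length = in.readInt()
        if (length == -1) {
          inDataSection = false
        } else {
          val block = new Array[Byte](length)
          in.readFully(block)
          if (inDataSection) records += block else updates += block
        }
      }
    } catch {
      case _: EOFException => // end of stream: nothing more to read
    }
    (records, updates)
  }

  def main(args: Array[String]): Unit = {
    val bytes = new ByteArrayOutputStream()
    val out = new DataOutputStream(bytes)
    for (s <- Seq("record1", "record2")) { out.writeInt(s.length); out.write(s.getBytes("UTF-8")) }
    out.writeInt(-1)                              // end-of-data marker
    out.writeInt(6); out.write("update".getBytes("UTF-8"))
    val (records, updates) = readAll(new DataInputStream(new ByteArrayInputStream(bytes.toByteArray)))
    println(records.map(b => new String(b, "UTF-8")))  // ArrayBuffer(record1, record2)
    println(updates.map(b => new String(b, "UTF-8")))  // ArrayBuffer(update)
  }
}
```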
@@ -246,3 +260,40 @@ private class ExtractValue extends spark.api.java.function.Function[(Array[Byte]
private class BytesToString extends spark.api.java.function.Function[Array[Byte], String] {
override def call(arr: Array[Byte]) : String = new String(arr, "UTF-8")
}
+
+/**
+ * Internal class that acts as an `AccumulatorParam` for Python accumulators. Inside, it
+ * collects a list of pickled strings that we pass to Python through a socket.
+ */
+class PythonAccumulatorParam(@transient serverHost: String, serverPort: Int)
+ extends AccumulatorParam[JList[Array[Byte]]] {
+
+ override def zero(value: JList[Array[Byte]]): JList[Array[Byte]] = new JArrayList
+
+ override def addInPlace(val1: JList[Array[Byte]], val2: JList[Array[Byte]])
+ : JList[Array[Byte]] = {
+ if (serverHost == null) {
+ // This happens on the worker node, where we just want to remember all the updates
+ val1.addAll(val2)
+ val1
+ } else {
+ // This happens on the master, where we pass the updates to Python through a socket
+ val socket = new Socket(serverHost, serverPort)
+ val in = socket.getInputStream
+ val out = new DataOutputStream(socket.getOutputStream)
+ out.writeInt(val2.size)
+ for (array <- val2) {
+ out.writeInt(array.length)
+ out.write(array)
+ }
+ out.flush()
+ // Wait for a byte from the Python side as an acknowledgement
+ val byteRead = in.read()
+ if (byteRead == -1) {
+ throw new SparkException("EOF reached before Python server acknowledged")
+ }
+ socket.close()
+ null
+ }
+ }
+}
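In `addInPlace` above, the worker-side branch (`serverHost == null`) simply merges the update lists in place, while the master-side branch pushes the new updates, length-prefixed, over a socket to a listener on the Python side and waits for a one-byte acknowledgement. A minimal sketch of the worker-side merge path (the wrapper object is illustrative; it only assumes the `PythonAccumulatorParam` class above is on the classpath):

```scala
import java.util.{ArrayList => JArrayList, List => JList}

object WorkerMergeExample {
  def main(args: Array[String]): Unit = {
    // A null serverHost selects the worker-side branch: updates are merged
    // locally and no socket connection is attempted.
    val workerParam = new PythonAccumulatorParam(null, 0)
    val soFar: JList[Array[Byte]] = new JArrayList[Array[Byte]]()
    val incoming: JList[Array[Byte]] = new JArrayList[Array[Byte]]()
    incoming.add("pickled-update".getBytes("UTF-8"))
    val merged = workerParam.addInPlace(soFar, incoming)
    println(merged.size())  // 1: the incoming update was appended to soFar
  }
}
```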
diff --git a/core/src/main/scala/spark/broadcast/HttpBroadcast.scala b/core/src/main/scala/spark/broadcast/HttpBroadcast.scala
index fef264aab1..8e490e6bad 100644
--- a/core/src/main/scala/spark/broadcast/HttpBroadcast.scala
+++ b/core/src/main/scala/spark/broadcast/HttpBroadcast.scala
@@ -95,7 +95,7 @@ private object HttpBroadcast extends Logging {
}
private def createServer() {
- broadcastDir = Utils.createTempDir()
+ broadcastDir = Utils.createTempDir(Utils.getLocalDir)
server = new HttpServer(broadcastDir)
server.start()
serverUri = server.uri
diff --git a/core/src/main/scala/spark/executor/StandaloneExecutorBackend.scala b/core/src/main/scala/spark/executor/StandaloneExecutorBackend.scala
index 915f71ba9f..a29bf974d2 100644
--- a/core/src/main/scala/spark/executor/StandaloneExecutorBackend.scala
+++ b/core/src/main/scala/spark/executor/StandaloneExecutorBackend.scala
@@ -24,9 +24,6 @@ private[spark] class StandaloneExecutorBackend(
with ExecutorBackend
with Logging {
- val threadPool = new ThreadPoolExecutor(
- 1, 128, 600, TimeUnit.SECONDS, new SynchronousQueue[Runnable])
-
var master: ActorRef = null
override def preStart() {