diff options
author | root <root@ip-10-165-19-218.ec2.internal> | 2013-07-01 06:20:14 +0000 |
---|---|---|
committer | root <root@ip-10-165-19-218.ec2.internal> | 2013-07-01 06:26:31 +0000 |
commit | ec31e68d5df259e6df001529235d8c906ff02a6f (patch) | |
tree | f71c7fce1c75b8d931676440b0b139a88fd5a7e6 /core/src | |
parent | 3296d132b6ce042843de6e7384800e089b49e5fa (diff) | |
download | spark-ec31e68d5df259e6df001529235d8c906ff02a6f.tar.gz spark-ec31e68d5df259e6df001529235d8c906ff02a6f.tar.bz2 spark-ec31e68d5df259e6df001529235d8c906ff02a6f.zip |
Fixed PySpark perf regression by not using socket.makefile(), and improved
debuggability by letting "print" statements show up in the executor's stderr
Conflicts:
core/src/main/scala/spark/api/python/PythonRDD.scala
Diffstat (limited to 'core/src')
-rw-r--r-- | core/src/main/scala/spark/api/python/PythonRDD.scala | 10 | ||||
-rw-r--r-- | core/src/main/scala/spark/api/python/PythonWorkerFactory.scala | 20 |
2 files changed, 26 insertions, 4 deletions
diff --git a/core/src/main/scala/spark/api/python/PythonRDD.scala b/core/src/main/scala/spark/api/python/PythonRDD.scala index 3f283afa62..31d8ea89d4 100644 --- a/core/src/main/scala/spark/api/python/PythonRDD.scala +++ b/core/src/main/scala/spark/api/python/PythonRDD.scala @@ -22,6 +22,8 @@ private[spark] class PythonRDD[T: ClassManifest]( accumulator: Accumulator[JList[Array[Byte]]]) extends RDD[Array[Byte]](parent) { + val bufferSize = System.getProperty("spark.buffer.size", "65536").toInt + // Similar to Runtime.exec(), if we are given a single string, split it into words // using a standard StringTokenizer (i.e. by spaces) def this(parent: RDD[T], command: String, envVars: JMap[String, String], @@ -45,7 +47,7 @@ private[spark] class PythonRDD[T: ClassManifest]( new Thread("stdin writer for " + pythonExec) { override def run() { SparkEnv.set(env) - val stream = new BufferedOutputStream(worker.getOutputStream) + val stream = new BufferedOutputStream(worker.getOutputStream, bufferSize) val dataOut = new DataOutputStream(stream) val printOut = new PrintWriter(stream) // Partition index @@ -76,7 +78,7 @@ private[spark] class PythonRDD[T: ClassManifest]( }.start() // Return an iterator that read lines from the process's stdout - val stream = new DataInputStream(new BufferedInputStream(worker.getInputStream)) + val stream = new DataInputStream(new BufferedInputStream(worker.getInputStream, bufferSize)) return new Iterator[Array[Byte]] { def next(): Array[Byte] = { val obj = _nextObj @@ -276,6 +278,8 @@ class PythonAccumulatorParam(@transient serverHost: String, serverPort: Int) extends AccumulatorParam[JList[Array[Byte]]] { Utils.checkHost(serverHost, "Expected hostname") + + val bufferSize = System.getProperty("spark.buffer.size", "65536").toInt override def zero(value: JList[Array[Byte]]): JList[Array[Byte]] = new JArrayList @@ -289,7 +293,7 @@ class PythonAccumulatorParam(@transient serverHost: String, serverPort: Int) // This happens on the master, where we pass the updates to Python through a socket val socket = new Socket(serverHost, serverPort) val in = socket.getInputStream - val out = new DataOutputStream(new BufferedOutputStream(socket.getOutputStream)) + val out = new DataOutputStream(new BufferedOutputStream(socket.getOutputStream, bufferSize)) out.writeInt(val2.size) for (array <- val2) { out.writeInt(array.length) diff --git a/core/src/main/scala/spark/api/python/PythonWorkerFactory.scala b/core/src/main/scala/spark/api/python/PythonWorkerFactory.scala index 8844411d73..85d1dfeac8 100644 --- a/core/src/main/scala/spark/api/python/PythonWorkerFactory.scala +++ b/core/src/main/scala/spark/api/python/PythonWorkerFactory.scala @@ -51,7 +51,6 @@ private[spark] class PythonWorkerFactory(pythonExec: String, envVars: Map[String val workerEnv = pb.environment() workerEnv.putAll(envVars) daemon = pb.start() - daemonPort = new DataInputStream(daemon.getInputStream).readInt() // Redirect the stderr to ours new Thread("stderr reader for " + pythonExec) { @@ -69,6 +68,25 @@ private[spark] class PythonWorkerFactory(pythonExec: String, envVars: Map[String } } }.start() + + val in = new DataInputStream(daemon.getInputStream) + daemonPort = in.readInt() + + // Redirect further stdout output to our stderr + new Thread("stdout reader for " + pythonExec) { + override def run() { + scala.util.control.Exception.ignoring(classOf[IOException]) { + // FIXME HACK: We copy the stream on the level of bytes to + // attempt to dodge encoding problems. + var buf = new Array[Byte](1024) + var len = in.read(buf) + while (len != -1) { + System.err.write(buf, 0, len) + len = in.read(buf) + } + } + } + }.start() } catch { case e => { stopDaemon() |