author     root <root@ip-10-165-19-218.ec2.internal>  2013-07-01 06:20:14 +0000
committer  root <root@ip-10-165-19-218.ec2.internal>  2013-07-01 06:26:31 +0000
commit     ec31e68d5df259e6df001529235d8c906ff02a6f (patch)
tree       f71c7fce1c75b8d931676440b0b139a88fd5a7e6 /core/src
parent     3296d132b6ce042843de6e7384800e089b49e5fa (diff)
download   spark-ec31e68d5df259e6df001529235d8c906ff02a6f.tar.gz
           spark-ec31e68d5df259e6df001529235d8c906ff02a6f.tar.bz2
           spark-ec31e68d5df259e6df001529235d8c906ff02a6f.zip
Fixed PySpark perf regression by not using socket.makefile(), and improved
debuggability by letting "print" statements show up in the executor's stderr

Conflicts:
	core/src/main/scala/spark/api/python/PythonRDD.scala
Diffstat (limited to 'core/src')
-rw-r--r--  core/src/main/scala/spark/api/python/PythonRDD.scala            | 10
-rw-r--r--  core/src/main/scala/spark/api/python/PythonWorkerFactory.scala  | 20
2 files changed, 26 insertions, 4 deletions
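
For reference, the common pattern in both files below is that the buffered socket streams are now sized from the spark.buffer.size system property instead of the JDK's 8 KB default. A minimal, illustrative sketch of that pattern (not part of the patch; BufferSizeSketch and bufferedOut are made-up names, while the property name and the 64 KB default match the diff):

    import java.io.{BufferedOutputStream, DataOutputStream}
    import java.net.Socket

    object BufferSizeSketch {
      // Same property name and 64 KB default that the patch introduces.
      val bufferSize: Int = System.getProperty("spark.buffer.size", "65536").toInt

      // Hypothetical helper: wrap a socket's output stream with the configured
      // buffer instead of BufferedOutputStream's 8 KB default.
      def bufferedOut(socket: Socket): DataOutputStream =
        new DataOutputStream(new BufferedOutputStream(socket.getOutputStream, bufferSize))
    }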
diff --git a/core/src/main/scala/spark/api/python/PythonRDD.scala b/core/src/main/scala/spark/api/python/PythonRDD.scala
index 3f283afa62..31d8ea89d4 100644
--- a/core/src/main/scala/spark/api/python/PythonRDD.scala
+++ b/core/src/main/scala/spark/api/python/PythonRDD.scala
@@ -22,6 +22,8 @@ private[spark] class PythonRDD[T: ClassManifest](
accumulator: Accumulator[JList[Array[Byte]]])
extends RDD[Array[Byte]](parent) {
+ val bufferSize = System.getProperty("spark.buffer.size", "65536").toInt
+
// Similar to Runtime.exec(), if we are given a single string, split it into words
// using a standard StringTokenizer (i.e. by spaces)
def this(parent: RDD[T], command: String, envVars: JMap[String, String],
@@ -45,7 +47,7 @@ private[spark] class PythonRDD[T: ClassManifest](
new Thread("stdin writer for " + pythonExec) {
override def run() {
SparkEnv.set(env)
- val stream = new BufferedOutputStream(worker.getOutputStream)
+ val stream = new BufferedOutputStream(worker.getOutputStream, bufferSize)
val dataOut = new DataOutputStream(stream)
val printOut = new PrintWriter(stream)
// Partition index
@@ -76,7 +78,7 @@ private[spark] class PythonRDD[T: ClassManifest](
}.start()
// Return an iterator that read lines from the process's stdout
- val stream = new DataInputStream(new BufferedInputStream(worker.getInputStream))
+ val stream = new DataInputStream(new BufferedInputStream(worker.getInputStream, bufferSize))
return new Iterator[Array[Byte]] {
def next(): Array[Byte] = {
val obj = _nextObj
@@ -276,6 +278,8 @@ class PythonAccumulatorParam(@transient serverHost: String, serverPort: Int)
extends AccumulatorParam[JList[Array[Byte]]] {
Utils.checkHost(serverHost, "Expected hostname")
+
+ val bufferSize = System.getProperty("spark.buffer.size", "65536").toInt
override def zero(value: JList[Array[Byte]]): JList[Array[Byte]] = new JArrayList
@@ -289,7 +293,7 @@ class PythonAccumulatorParam(@transient serverHost: String, serverPort: Int)
// This happens on the master, where we pass the updates to Python through a socket
val socket = new Socket(serverHost, serverPort)
val in = socket.getInputStream
- val out = new DataOutputStream(new BufferedOutputStream(socket.getOutputStream))
+ val out = new DataOutputStream(new BufferedOutputStream(socket.getOutputStream, bufferSize))
out.writeInt(val2.size)
for (array <- val2) {
out.writeInt(array.length)
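
As a reading aid for the accumulator hunk above: on the master, updates are sent to the Python-side server as a count followed by length-prefixed byte arrays over the newly sized buffered stream. A rough, self-contained sketch of that framing (AccumulatorWireSketch and sendUpdates are hypothetical names; the final out.write(array) step is implied by the truncated hunk):

    import java.io.{BufferedOutputStream, DataOutputStream}
    import java.net.Socket
    import java.util.{List => JList}
    import scala.collection.JavaConversions._

    object AccumulatorWireSketch {
      // Write a count, then each update as a length-prefixed byte array,
      // through a buffer of the configured size.
      def sendUpdates(host: String, port: Int, updates: JList[Array[Byte]], bufferSize: Int) {
        val socket = new Socket(host, port)
        val out = new DataOutputStream(new BufferedOutputStream(socket.getOutputStream, bufferSize))
        out.writeInt(updates.size)     // how many serialized updates follow
        for (array <- updates) {       // JavaConversions lets us iterate the Java list directly
          out.writeInt(array.length)   // length prefix for this update
          out.write(array)             // the update's raw bytes
        }
        out.flush()
        socket.close()
      }
    }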
diff --git a/core/src/main/scala/spark/api/python/PythonWorkerFactory.scala b/core/src/main/scala/spark/api/python/PythonWorkerFactory.scala
index 8844411d73..85d1dfeac8 100644
--- a/core/src/main/scala/spark/api/python/PythonWorkerFactory.scala
+++ b/core/src/main/scala/spark/api/python/PythonWorkerFactory.scala
@@ -51,7 +51,6 @@ private[spark] class PythonWorkerFactory(pythonExec: String, envVars: Map[String
val workerEnv = pb.environment()
workerEnv.putAll(envVars)
daemon = pb.start()
- daemonPort = new DataInputStream(daemon.getInputStream).readInt()
// Redirect the stderr to ours
new Thread("stderr reader for " + pythonExec) {
@@ -69,6 +68,25 @@ private[spark] class PythonWorkerFactory(pythonExec: String, envVars: Map[String
}
}
}.start()
+
+ val in = new DataInputStream(daemon.getInputStream)
+ daemonPort = in.readInt()
+
+ // Redirect further stdout output to our stderr
+ new Thread("stdout reader for " + pythonExec) {
+ override def run() {
+ scala.util.control.Exception.ignoring(classOf[IOException]) {
+ // FIXME HACK: We copy the stream on the level of bytes to
+ // attempt to dodge encoding problems.
+ var buf = new Array[Byte](1024)
+ var len = in.read(buf)
+ while (len != -1) {
+ System.err.write(buf, 0, len)
+ len = in.read(buf)
+ }
+ }
+ }
+ }.start()
} catch {
case e => {
stopDaemon()
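
The "stdout reader" thread added above is what makes Python "print" output visible: once the daemon's port has been read, every remaining byte on the daemon's stdout is copied to the JVM's stderr. A standalone sketch of that copy loop (illustrative only; StdoutRedirectSketch and redirectToStderr are made-up names, and the real code lives inline in PythonWorkerFactory):

    import java.io.{IOException, InputStream}

    object StdoutRedirectSketch {
      // Copy raw bytes from `in` to the JVM's stderr on a background thread until
      // EOF, so anything the Python worker prints stays visible in the executor
      // logs. Working at the byte level sidesteps character-encoding problems.
      def redirectToStderr(in: InputStream, name: String): Thread = {
        val t = new Thread("stdout reader for " + name) {
          override def run() {
            try {
              val buf = new Array[Byte](1024)
              var len = in.read(buf)
              while (len != -1) {
                System.err.write(buf, 0, len)
                len = in.read(buf)
              }
            } catch {
              case _: IOException => // stream closed; nothing more to copy
            }
          }
        }
        t.start()
        t
      }
    }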