author     root <root@ip-10-165-19-218.ec2.internal>  2013-07-01 02:45:00 +0000
committer  root <root@ip-10-165-19-218.ec2.internal>  2013-07-01 06:25:43 +0000
commit     3296d132b6ce042843de6e7384800e089b49e5fa (patch)
tree       31ca339fb26ce8a59092f1c1ab4e53ce0842a12a /core
parent     39ae073b5cd0dcfe4a00d9f205c88bad9df37870 (diff)
download   spark-3296d132b6ce042843de6e7384800e089b49e5fa.tar.gz
           spark-3296d132b6ce042843de6e7384800e089b49e5fa.tar.bz2
           spark-3296d132b6ce042843de6e7384800e089b49e5fa.zip
Fix performance bug with new Python code not using buffered streams
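Without buffering, every DataOutputStream or PrintWriter call here wrote straight to the worker socket, so each small write (a length prefix, a broadcast ID, a pickled element) cost its own system call; wrapping the socket streams in BufferedOutputStream/BufferedInputStream batches those writes into far fewer calls.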
Diffstat (limited to 'core')
-rw-r--r--  core/src/main/scala/spark/SparkEnv.scala              3
-rw-r--r--  core/src/main/scala/spark/api/python/PythonRDD.scala  33
2 files changed, 19 insertions, 17 deletions
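The essence of the patch, as a minimal standalone sketch (the writeToWorker, socket, payloads, and commands names are illustrative, not taken from the patch): layer both the binary writer and the text writer over a single BufferedOutputStream so their output is batched and stays in order, and flush only at protocol boundaries.

import java.io.{BufferedOutputStream, DataOutputStream, PrintWriter}
import java.net.Socket

def writeToWorker(socket: Socket, payloads: Seq[Array[Byte]], commands: Seq[String]): Unit = {
  // One buffered stream over the raw socket; both writers share it, so
  // binary and text output interleave in order and are batched together.
  val stream   = new BufferedOutputStream(socket.getOutputStream)
  val dataOut  = new DataOutputStream(stream)
  val printOut = new PrintWriter(stream)

  for (p <- payloads) {
    dataOut.writeInt(p.length)  // length prefix
    dataOut.write(p)            // sits in the buffer until it fills or we flush
  }
  dataOut.flush()               // flush once after the loop, not per element

  for (c <- commands) {
    printOut.println(c)
  }
  printOut.flush()              // PrintWriter buffers chars; push them through the shared stream
}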
diff --git a/core/src/main/scala/spark/SparkEnv.scala b/core/src/main/scala/spark/SparkEnv.scala
index 7ccde2e818..ec59b4f48f 100644
--- a/core/src/main/scala/spark/SparkEnv.scala
+++ b/core/src/main/scala/spark/SparkEnv.scala
@@ -59,7 +59,8 @@ class SparkEnv (
 
   def createPythonWorker(pythonExec: String, envVars: Map[String, String]): java.net.Socket = {
     synchronized {
-      pythonWorkers.getOrElseUpdate((pythonExec, envVars), new PythonWorkerFactory(pythonExec, envVars)).create()
+      val key = (pythonExec, envVars)
+      pythonWorkers.getOrElseUpdate(key, new PythonWorkerFactory(pythonExec, envVars)).create()
     }
   }
 
diff --git a/core/src/main/scala/spark/api/python/PythonRDD.scala b/core/src/main/scala/spark/api/python/PythonRDD.scala
index 63140cf37f..3f283afa62 100644
--- a/core/src/main/scala/spark/api/python/PythonRDD.scala
+++ b/core/src/main/scala/spark/api/python/PythonRDD.scala
@@ -45,37 +45,38 @@ private[spark] class PythonRDD[T: ClassManifest](
     new Thread("stdin writer for " + pythonExec) {
       override def run() {
         SparkEnv.set(env)
-        val out = new PrintWriter(worker.getOutputStream)
-        val dOut = new DataOutputStream(worker.getOutputStream)
+        val stream = new BufferedOutputStream(worker.getOutputStream)
+        val dataOut = new DataOutputStream(stream)
+        val printOut = new PrintWriter(stream)
         // Partition index
-        dOut.writeInt(split.index)
+        dataOut.writeInt(split.index)
         // sparkFilesDir
-        PythonRDD.writeAsPickle(SparkFiles.getRootDirectory, dOut)
+        PythonRDD.writeAsPickle(SparkFiles.getRootDirectory, dataOut)
         // Broadcast variables
-        dOut.writeInt(broadcastVars.length)
+        dataOut.writeInt(broadcastVars.length)
         for (broadcast <- broadcastVars) {
-          dOut.writeLong(broadcast.id)
-          dOut.writeInt(broadcast.value.length)
-          dOut.write(broadcast.value)
-          dOut.flush()
+          dataOut.writeLong(broadcast.id)
+          dataOut.writeInt(broadcast.value.length)
+          dataOut.write(broadcast.value)
         }
+        dataOut.flush()
         // Serialized user code
         for (elem <- command) {
-          out.println(elem)
+          printOut.println(elem)
         }
-        out.flush()
+        printOut.flush()
         // Data values
         for (elem <- parent.iterator(split, context)) {
-          PythonRDD.writeAsPickle(elem, dOut)
+          PythonRDD.writeAsPickle(elem, dataOut)
        }
-        dOut.flush()
-        out.flush()
+        dataOut.flush()
+        printOut.flush()
         worker.shutdownOutput()
       }
     }.start()
 
     // Return an iterator that read lines from the process's stdout
-    val stream = new DataInputStream(worker.getInputStream)
+    val stream = new DataInputStream(new BufferedInputStream(worker.getInputStream))
     return new Iterator[Array[Byte]] {
       def next(): Array[Byte] = {
         val obj = _nextObj
@@ -288,7 +289,7 @@ class PythonAccumulatorParam(@transient serverHost: String, serverPort: Int)
     // This happens on the master, where we pass the updates to Python through a socket
     val socket = new Socket(serverHost, serverPort)
     val in = socket.getInputStream
-    val out = new DataOutputStream(socket.getOutputStream)
+    val out = new DataOutputStream(new BufferedOutputStream(socket.getOutputStream))
     out.writeInt(val2.size)
     for (array <- val2) {
       out.writeInt(array.length)
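Two details of the diff are easy to miss. First, the per-broadcast dOut.flush() inside the loop is gone: flushing after every element defeats the buffer, so the patched code flushes once after the loop and again at the end of each protocol section. Second, the reading side is buffered too, since DataInputStream issues many small reads (readInt, readLong) that would otherwise each hit the socket. A sketch of that read pattern (the readOnePayload and socket names are again illustrative):

import java.io.{BufferedInputStream, DataInputStream}
import java.net.Socket

def readOnePayload(socket: Socket): Array[Byte] = {
  // Buffer the socket input so small header reads don't each cost a syscall.
  val dataIn = new DataInputStream(new BufferedInputStream(socket.getInputStream))
  val length = dataIn.readInt()   // 4-byte length prefix
  val buf = new Array[Byte](length)
  dataIn.readFully(buf)           // block until exactly `length` bytes arrive
  buf
}

The SparkEnv.scala hunk, by contrast, is a pure readability change: it extracts the (pythonExec, envVars) tuple into a named key before the getOrElseUpdate call, with no behavior difference.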