author     root <root@ip-10-165-19-218.ec2.internal>  2013-07-01 02:45:00 +0000
committer  root <root@ip-10-165-19-218.ec2.internal>  2013-07-01 06:25:43 +0000
commit     3296d132b6ce042843de6e7384800e089b49e5fa (patch)
tree       31ca339fb26ce8a59092f1c1ab4e53ce0842a12a /core/src
parent     39ae073b5cd0dcfe4a00d9f205c88bad9df37870 (diff)
Fix performance bug with new Python code not using buffered streams
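For context, the idea behind the fix (this sketch is not part of the patch): PySpark's driver-to-worker protocol writes many small values (ints, longs, length-prefixed byte arrays, pickled records), and without buffering each of those calls becomes its own write on the socket. Wrapping the raw socket streams in BufferedOutputStream/BufferedInputStream coalesces them into large writes; the helper name below is illustrative, only the java.io/java.net classes are real.

import java.io.{BufferedInputStream, BufferedOutputStream, DataInputStream, DataOutputStream}
import java.net.Socket

// Illustrative helper (not part of this patch): layer DataOutputStream/DataInputStream
// on top of buffered streams so many small writeInt/write calls are batched into a few
// large socket writes instead of one system call each.
object BufferedSocketStreams {
  def wrap(socket: Socket): (DataOutputStream, DataInputStream) = {
    val out = new DataOutputStream(new BufferedOutputStream(socket.getOutputStream))
    val in  = new DataInputStream(new BufferedInputStream(socket.getInputStream))
    (out, in)
  }
}

The trade-off is that buffered output does not reach the peer until flush() is called, which is why the patch below keeps explicit flush() calls after the broadcast block, after the serialized command, and after the data values.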
Diffstat (limited to 'core/src')
-rw-r--r--  core/src/main/scala/spark/SparkEnv.scala             |  3
-rw-r--r--  core/src/main/scala/spark/api/python/PythonRDD.scala | 33
2 files changed, 19 insertions(+), 17 deletions(-)
diff --git a/core/src/main/scala/spark/SparkEnv.scala b/core/src/main/scala/spark/SparkEnv.scala
index 7ccde2e818..ec59b4f48f 100644
--- a/core/src/main/scala/spark/SparkEnv.scala
+++ b/core/src/main/scala/spark/SparkEnv.scala
@@ -59,7 +59,8 @@ class SparkEnv (
   def createPythonWorker(pythonExec: String, envVars: Map[String, String]): java.net.Socket = {
     synchronized {
-      pythonWorkers.getOrElseUpdate((pythonExec, envVars), new PythonWorkerFactory(pythonExec, envVars)).create()
+      val key = (pythonExec, envVars)
+      pythonWorkers.getOrElseUpdate(key, new PythonWorkerFactory(pythonExec, envVars)).create()
     }
   }
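The SparkEnv hunk above only pulls the lookup key into a named value; the behaviour it preserves is a pool of one PythonWorkerFactory per (pythonExec, envVars) combination. A minimal sketch of that pattern follows, with a stubbed-out factory; the stub and method names are hypothetical, not Spark's API.

import java.net.Socket
import scala.collection.mutable

// Hypothetical stand-in for PythonWorkerFactory, only so the sketch compiles; the real
// factory launches (or reuses) a Python worker process and returns its socket.
class WorkerFactoryStub(pythonExec: String, envVars: Map[String, String]) {
  def create(): Socket = ???  // left unimplemented in this sketch
}

// One factory is cached per (pythonExec, envVars) key; getOrElseUpdate constructs a
// factory only the first time a given key is seen and reuses it afterwards.
class WorkerPool {
  private val factories = mutable.HashMap.empty[(String, Map[String, String]), WorkerFactoryStub]

  def createWorker(pythonExec: String, envVars: Map[String, String]): Socket = synchronized {
    val key = (pythonExec, envVars)
    factories.getOrElseUpdate(key, new WorkerFactoryStub(pythonExec, envVars)).create()
  }
}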
diff --git a/core/src/main/scala/spark/api/python/PythonRDD.scala b/core/src/main/scala/spark/api/python/PythonRDD.scala
index 63140cf37f..3f283afa62 100644
--- a/core/src/main/scala/spark/api/python/PythonRDD.scala
+++ b/core/src/main/scala/spark/api/python/PythonRDD.scala
@@ -45,37 +45,38 @@ private[spark] class PythonRDD[T: ClassManifest](
     new Thread("stdin writer for " + pythonExec) {
       override def run() {
         SparkEnv.set(env)
-        val out = new PrintWriter(worker.getOutputStream)
-        val dOut = new DataOutputStream(worker.getOutputStream)
+        val stream = new BufferedOutputStream(worker.getOutputStream)
+        val dataOut = new DataOutputStream(stream)
+        val printOut = new PrintWriter(stream)
         // Partition index
-        dOut.writeInt(split.index)
+        dataOut.writeInt(split.index)
         // sparkFilesDir
-        PythonRDD.writeAsPickle(SparkFiles.getRootDirectory, dOut)
+        PythonRDD.writeAsPickle(SparkFiles.getRootDirectory, dataOut)
         // Broadcast variables
-        dOut.writeInt(broadcastVars.length)
+        dataOut.writeInt(broadcastVars.length)
         for (broadcast <- broadcastVars) {
-          dOut.writeLong(broadcast.id)
-          dOut.writeInt(broadcast.value.length)
-          dOut.write(broadcast.value)
-          dOut.flush()
+          dataOut.writeLong(broadcast.id)
+          dataOut.writeInt(broadcast.value.length)
+          dataOut.write(broadcast.value)
         }
+        dataOut.flush()
         // Serialized user code
         for (elem <- command) {
-          out.println(elem)
+          printOut.println(elem)
         }
-        out.flush()
+        printOut.flush()
         // Data values
         for (elem <- parent.iterator(split, context)) {
-          PythonRDD.writeAsPickle(elem, dOut)
+          PythonRDD.writeAsPickle(elem, dataOut)
         }
-        dOut.flush()
-        out.flush()
+        dataOut.flush()
+        printOut.flush()
         worker.shutdownOutput()
       }
     }.start()
     // Return an iterator that read lines from the process's stdout
-    val stream = new DataInputStream(worker.getInputStream)
+    val stream = new DataInputStream(new BufferedInputStream(worker.getInputStream))
     return new Iterator[Array[Byte]] {
       def next(): Array[Byte] = {
         val obj = _nextObj
@@ -288,7 +289,7 @@ class PythonAccumulatorParam(@transient serverHost: String, serverPort: Int)
       // This happens on the master, where we pass the updates to Python through a socket
       val socket = new Socket(serverHost, serverPort)
       val in = socket.getInputStream
-      val out = new DataOutputStream(socket.getOutputStream)
+      val out = new DataOutputStream(new BufferedOutputStream(socket.getOutputStream))
       out.writeInt(val2.size)
       for (array <- val2) {
         out.writeInt(array.length)
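The last hunk applies the same fix on the driver side: accumulator updates are shipped to Python as a count followed by length-prefixed byte arrays, so each update previously cost at least two unbuffered socket writes. A rough sketch of that framing over a buffered stream; the object, function, and parameter names here are illustrative, not Spark's API.

import java.io.{BufferedOutputStream, DataOutputStream}
import java.net.Socket

object AccumulatorClientSketch {
  // Write the number of updates, then each update as <length><bytes>, all through a
  // buffered stream so the whole batch leaves in a handful of socket writes.
  def sendUpdates(serverHost: String, serverPort: Int, updates: Seq[Array[Byte]]): Unit = {
    val socket = new Socket(serverHost, serverPort)
    try {
      val out = new DataOutputStream(new BufferedOutputStream(socket.getOutputStream))
      out.writeInt(updates.size)
      for (bytes <- updates) {
        out.writeInt(bytes.length)
        out.write(bytes)
      }
      out.flush()  // nothing reaches the Python side until the buffer is flushed
    } finally {
      socket.close()
    }
  }
}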