From b9d6783f36d527f5082bf13a4ee6fd108e97795c Mon Sep 17 00:00:00 2001
From: Matei Zaharia
Date: Sun, 28 Jul 2013 23:28:42 -0400
Subject: Optimize Python take() to not compute entire first partition

---
 .../main/scala/spark/api/python/PythonRDD.scala | 64 ++++++++++++----------
 python/pyspark/rdd.py                           | 15 +++--
 2 files changed, 45 insertions(+), 34 deletions(-)

diff --git a/core/src/main/scala/spark/api/python/PythonRDD.scala b/core/src/main/scala/spark/api/python/PythonRDD.scala
index af10822dbd..2dd79f7100 100644
--- a/core/src/main/scala/spark/api/python/PythonRDD.scala
+++ b/core/src/main/scala/spark/api/python/PythonRDD.scala
@@ -63,34 +63,42 @@ private[spark] class PythonRDD[T: ClassManifest](
     // Start a thread to feed the process input from our parent's iterator
     new Thread("stdin writer for " + pythonExec) {
       override def run() {
-        SparkEnv.set(env)
-        val stream = new BufferedOutputStream(worker.getOutputStream, bufferSize)
-        val dataOut = new DataOutputStream(stream)
-        val printOut = new PrintWriter(stream)
-        // Partition index
-        dataOut.writeInt(split.index)
-        // sparkFilesDir
-        PythonRDD.writeAsPickle(SparkFiles.getRootDirectory, dataOut)
-        // Broadcast variables
-        dataOut.writeInt(broadcastVars.length)
-        for (broadcast <- broadcastVars) {
-          dataOut.writeLong(broadcast.id)
-          dataOut.writeInt(broadcast.value.length)
-          dataOut.write(broadcast.value)
-        }
-        dataOut.flush()
-        // Serialized user code
-        for (elem <- command) {
-          printOut.println(elem)
-        }
-        printOut.flush()
-        // Data values
-        for (elem <- parent.iterator(split, context)) {
-          PythonRDD.writeAsPickle(elem, dataOut)
+        try {
+          SparkEnv.set(env)
+          val stream = new BufferedOutputStream(worker.getOutputStream, bufferSize)
+          val dataOut = new DataOutputStream(stream)
+          val printOut = new PrintWriter(stream)
+          // Partition index
+          dataOut.writeInt(split.index)
+          // sparkFilesDir
+          PythonRDD.writeAsPickle(SparkFiles.getRootDirectory, dataOut)
+          // Broadcast variables
+          dataOut.writeInt(broadcastVars.length)
+          for (broadcast <- broadcastVars) {
+            dataOut.writeLong(broadcast.id)
+            dataOut.writeInt(broadcast.value.length)
+            dataOut.write(broadcast.value)
+          }
+          dataOut.flush()
+          // Serialized user code
+          for (elem <- command) {
+            printOut.println(elem)
+          }
+          printOut.flush()
+          // Data values
+          for (elem <- parent.iterator(split, context)) {
+            PythonRDD.writeAsPickle(elem, dataOut)
+          }
+          dataOut.flush()
+          printOut.flush()
+          worker.shutdownOutput()
+        } catch {
+          case e: IOException =>
+            // This can happen for legitimate reasons if the Python code stops returning data before we are done
+            // passing elements through, e.g., for take(). Just log a message to say it happened.
+            logInfo("stdin writer to Python finished early")
+            logDebug("stdin writer to Python finished early", e)
         }
-        dataOut.flush()
-        printOut.flush()
-        worker.shutdownOutput()
       }
     }.start()
 
@@ -297,7 +305,7 @@ class PythonAccumulatorParam(@transient serverHost: String, serverPort: Int)
   Utils.checkHost(serverHost, "Expected hostname")
 
   val bufferSize = System.getProperty("spark.buffer.size", "65536").toInt
-  
+
   override def zero(value: JList[Array[Byte]]): JList[Array[Byte]] = new JArrayList
 
   override def addInPlace(val1: JList[Array[Byte]], val2: JList[Array[Byte]])
diff --git a/python/pyspark/rdd.py b/python/pyspark/rdd.py
index c6a6b24c5a..6efa61aa66 100644
--- a/python/pyspark/rdd.py
+++ b/python/pyspark/rdd.py
@@ -386,13 +386,16 @@ class RDD(object):
         >>> sc.parallelize([2, 3, 4, 5, 6]).take(10)
         [2, 3, 4, 5, 6]
         """
+        def takeUpToNum(iterator):
+            taken = 0
+            while taken < num:
+                yield next(iterator)
+                taken += 1
+        # Take only up to num elements from each partition we try
+        mapped = self.mapPartitions(takeUpToNum)
         items = []
-        for partition in range(self._jrdd.splits().size()):
-            iterator = self.ctx._takePartition(self._jrdd.rdd(), partition)
-            # Each item in the iterator is a string, Python object, batch of
-            # Python objects. Regardless, it is sufficient to take `num`
-            # of these objects in order to collect `num` Python objects:
-            iterator = iterator.take(num)
+        for partition in range(mapped._jrdd.splits().size()):
+            iterator = self.ctx._takePartition(mapped._jrdd.rdd(), partition)
             items.extend(self._collect_iterator_through_file(iterator))
             if len(items) >= num:
                 break
-- 
cgit v1.2.3
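
The change above makes take() truncate each partition on the Python side: takeUpToNum is applied through mapPartitions so a partition scan stops after num elements, and the new IOException handler in the writer thread covers the case the comment describes, where the Python worker stops consuming data before the whole partition has been written. Below is a rough standalone analogue of that helper, with plain Python iterators standing in for partitions; take_up_to_num and the sample data are illustrative only and are not part of pyspark or this patch.

    # Sketch: pull at most `num` elements from each "partition", and stop
    # scanning partitions once `num` items have been gathered overall.
    def take_up_to_num(iterator, num):
        taken = 0
        for elem in iterator:
            if taken >= num:
                break
            yield elem
            taken += 1

    partitions = [iter(range(0, 5)), iter(range(5, 10))]  # stand-ins for RDD partitions
    num = 3
    items = []
    for part in partitions:
        items.extend(take_up_to_num(part, num))
        if len(items) >= num:
            break
    print(items[:num])  # [0, 1, 2]; the second partition is never touched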