aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorMatei Zaharia <matei@eecs.berkeley.edu>2013-07-28 23:28:42 -0400
committerMatei Zaharia <matei@eecs.berkeley.edu>2013-07-29 02:51:43 -0400
commitb9d6783f36d527f5082bf13a4ee6fd108e97795c (patch)
tree922a260da5591d675b58246ed1cbd23c38abc4e7
parent72ff62a37c7310bab02f0231e91d3ba4d423217a (diff)
downloadspark-b9d6783f36d527f5082bf13a4ee6fd108e97795c.tar.gz
spark-b9d6783f36d527f5082bf13a4ee6fd108e97795c.tar.bz2
spark-b9d6783f36d527f5082bf13a4ee6fd108e97795c.zip
Optimize Python take() to not compute entire first partition
-rw-r--r--core/src/main/scala/spark/api/python/PythonRDD.scala64
-rw-r--r--python/pyspark/rdd.py15
2 files changed, 45 insertions, 34 deletions
diff --git a/core/src/main/scala/spark/api/python/PythonRDD.scala b/core/src/main/scala/spark/api/python/PythonRDD.scala
index af10822dbd..2dd79f7100 100644
--- a/core/src/main/scala/spark/api/python/PythonRDD.scala
+++ b/core/src/main/scala/spark/api/python/PythonRDD.scala
@@ -63,34 +63,42 @@ private[spark] class PythonRDD[T: ClassManifest](
// Start a thread to feed the process input from our parent's iterator
new Thread("stdin writer for " + pythonExec) {
override def run() {
- SparkEnv.set(env)
- val stream = new BufferedOutputStream(worker.getOutputStream, bufferSize)
- val dataOut = new DataOutputStream(stream)
- val printOut = new PrintWriter(stream)
- // Partition index
- dataOut.writeInt(split.index)
- // sparkFilesDir
- PythonRDD.writeAsPickle(SparkFiles.getRootDirectory, dataOut)
- // Broadcast variables
- dataOut.writeInt(broadcastVars.length)
- for (broadcast <- broadcastVars) {
- dataOut.writeLong(broadcast.id)
- dataOut.writeInt(broadcast.value.length)
- dataOut.write(broadcast.value)
- }
- dataOut.flush()
- // Serialized user code
- for (elem <- command) {
- printOut.println(elem)
- }
- printOut.flush()
- // Data values
- for (elem <- parent.iterator(split, context)) {
- PythonRDD.writeAsPickle(elem, dataOut)
+ try {
+ SparkEnv.set(env)
+ val stream = new BufferedOutputStream(worker.getOutputStream, bufferSize)
+ val dataOut = new DataOutputStream(stream)
+ val printOut = new PrintWriter(stream)
+ // Partition index
+ dataOut.writeInt(split.index)
+ // sparkFilesDir
+ PythonRDD.writeAsPickle(SparkFiles.getRootDirectory, dataOut)
+ // Broadcast variables
+ dataOut.writeInt(broadcastVars.length)
+ for (broadcast <- broadcastVars) {
+ dataOut.writeLong(broadcast.id)
+ dataOut.writeInt(broadcast.value.length)
+ dataOut.write(broadcast.value)
+ }
+ dataOut.flush()
+ // Serialized user code
+ for (elem <- command) {
+ printOut.println(elem)
+ }
+ printOut.flush()
+ // Data values
+ for (elem <- parent.iterator(split, context)) {
+ PythonRDD.writeAsPickle(elem, dataOut)
+ }
+ dataOut.flush()
+ printOut.flush()
+ worker.shutdownOutput()
+ } catch {
+ case e: IOException =>
+ // This can happen for legitimate reasons if the Python code stops returning data before we are done
+ // passing elements through, e.g., for take(). Just log a message to say it happened.
+ logInfo("stdin writer to Python finished early")
+ logDebug("stdin writer to Python finished early", e)
}
- dataOut.flush()
- printOut.flush()
- worker.shutdownOutput()
}
}.start()
@@ -297,7 +305,7 @@ class PythonAccumulatorParam(@transient serverHost: String, serverPort: Int)
Utils.checkHost(serverHost, "Expected hostname")
val bufferSize = System.getProperty("spark.buffer.size", "65536").toInt
-
+
override def zero(value: JList[Array[Byte]]): JList[Array[Byte]] = new JArrayList
override def addInPlace(val1: JList[Array[Byte]], val2: JList[Array[Byte]])
diff --git a/python/pyspark/rdd.py b/python/pyspark/rdd.py
index c6a6b24c5a..6efa61aa66 100644
--- a/python/pyspark/rdd.py
+++ b/python/pyspark/rdd.py
@@ -386,13 +386,16 @@ class RDD(object):
>>> sc.parallelize([2, 3, 4, 5, 6]).take(10)
[2, 3, 4, 5, 6]
"""
+ def takeUpToNum(iterator):
+ taken = 0
+ while taken < num:
+ yield next(iterator)
+ taken += 1
+ # Take only up to num elements from each partition we try
+ mapped = self.mapPartitions(takeUpToNum)
items = []
- for partition in range(self._jrdd.splits().size()):
- iterator = self.ctx._takePartition(self._jrdd.rdd(), partition)
- # Each item in the iterator is a string, Python object, batch of
- # Python objects. Regardless, it is sufficient to take `num`
- # of these objects in order to collect `num` Python objects:
- iterator = iterator.take(num)
+ for partition in range(mapped._jrdd.splits().size()):
+ iterator = self.ctx._takePartition(mapped._jrdd.rdd(), partition)
items.extend(self._collect_iterator_through_file(iterator))
if len(items) >= num:
break