diff options
author | Aaron Davidson <aaron@databricks.com> | 2014-05-31 13:04:57 -0700 |
---|---|---|
committer | Reynold Xin <rxin@apache.org> | 2014-05-31 13:04:57 -0700 |
commit | 9909efc10aaa62c47fd7c4c9da73ac8c56a454d5 (patch) | |
tree | 86ab8e6477ab4a631b3a91f0e89007ca69c78d37 /core | |
parent | 7d52777effd0ff41aed545f53d2ab8de2364a188 (diff) | |
download | spark-9909efc10aaa62c47fd7c4c9da73ac8c56a454d5.tar.gz spark-9909efc10aaa62c47fd7c4c9da73ac8c56a454d5.tar.bz2 spark-9909efc10aaa62c47fd7c4c9da73ac8c56a454d5.zip |
SPARK-1839: PySpark RDD#take() shouldn't always read from driver
This patch simply ports over the Scala implementation of RDD#take(), which reads the first partition at the driver, then decides how many more partitions it needs to read and will possibly start a real job if it's more than 1. (Note that SparkContext#runJob(allowLocal=true) only runs the job locally if there's 1 partition selected and no parent stages.)
Author: Aaron Davidson <aaron@databricks.com>
Closes #922 from aarondav/take and squashes the following commits:
fa06df9 [Aaron Davidson] SPARK-1839: PySpark RDD#take() shouldn't always read from driver
Diffstat (limited to 'core')
-rw-r--r-- | core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala | 20 |
1 files changed, 20 insertions, 0 deletions
diff --git a/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala b/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala index 57b28b9972..d1df99300c 100644 --- a/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala +++ b/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala @@ -269,6 +269,26 @@ private object SpecialLengths { private[spark] object PythonRDD { val UTF8 = Charset.forName("UTF-8") + /** + * Adapter for calling SparkContext#runJob from Python. + * + * This method will return an iterator of an array that contains all elements in the RDD + * (effectively a collect()), but allows you to run on a certain subset of partitions, + * or to enable local execution. + */ + def runJob( + sc: SparkContext, + rdd: JavaRDD[Array[Byte]], + partitions: JArrayList[Int], + allowLocal: Boolean): Iterator[Array[Byte]] = { + type ByteArray = Array[Byte] + type UnrolledPartition = Array[ByteArray] + val allPartitions: Array[UnrolledPartition] = + sc.runJob(rdd, (x: Iterator[ByteArray]) => x.toArray, partitions, allowLocal) + val flattenedPartition: UnrolledPartition = Array.concat(allPartitions: _*) + flattenedPartition.iterator + } + def readRDDFromFile(sc: JavaSparkContext, filename: String, parallelism: Int): JavaRDD[Array[Byte]] = { val file = new DataInputStream(new FileInputStream(filename)) |