aboutsummaryrefslogtreecommitdiff
path: root/core
diff options
context:
space:
mode:
authorAaron Davidson <aaron@databricks.com>2014-05-31 13:04:57 -0700
committerReynold Xin <rxin@apache.org>2014-05-31 13:04:57 -0700
commit9909efc10aaa62c47fd7c4c9da73ac8c56a454d5 (patch)
tree86ab8e6477ab4a631b3a91f0e89007ca69c78d37 /core
parent7d52777effd0ff41aed545f53d2ab8de2364a188 (diff)
downloadspark-9909efc10aaa62c47fd7c4c9da73ac8c56a454d5.tar.gz
spark-9909efc10aaa62c47fd7c4c9da73ac8c56a454d5.tar.bz2
spark-9909efc10aaa62c47fd7c4c9da73ac8c56a454d5.zip
SPARK-1839: PySpark RDD#take() shouldn't always read from driver
This patch simply ports over the Scala implementation of RDD#take(), which reads the first partition at the driver, then decides how many more partitions it needs to read and will possibly start a real job if it's more than 1. (Note that SparkContext#runJob(allowLocal=true) only runs the job locally if there's 1 partition selected and no parent stages.) Author: Aaron Davidson <aaron@databricks.com> Closes #922 from aarondav/take and squashes the following commits: fa06df9 [Aaron Davidson] SPARK-1839: PySpark RDD#take() shouldn't always read from driver
Diffstat (limited to 'core')
-rw-r--r--core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala20
1 files changed, 20 insertions, 0 deletions
diff --git a/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala b/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala
index 57b28b9972..d1df99300c 100644
--- a/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala
+++ b/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala
@@ -269,6 +269,26 @@ private object SpecialLengths {
private[spark] object PythonRDD {
val UTF8 = Charset.forName("UTF-8")
+ /**
+ * Adapter for calling SparkContext#runJob from Python.
+ *
+ * This method will return an iterator of an array that contains all elements in the RDD
+ * (effectively a collect()), but allows you to run on a certain subset of partitions,
+ * or to enable local execution.
+ */
+ def runJob(
+ sc: SparkContext,
+ rdd: JavaRDD[Array[Byte]],
+ partitions: JArrayList[Int],
+ allowLocal: Boolean): Iterator[Array[Byte]] = {
+ type ByteArray = Array[Byte]
+ type UnrolledPartition = Array[ByteArray]
+ val allPartitions: Array[UnrolledPartition] =
+ sc.runJob(rdd, (x: Iterator[ByteArray]) => x.toArray, partitions, allowLocal)
+ val flattenedPartition: UnrolledPartition = Array.concat(allPartitions: _*)
+ flattenedPartition.iterator
+ }
+
def readRDDFromFile(sc: JavaSparkContext, filename: String, parallelism: Int):
JavaRDD[Array[Byte]] = {
val file = new DataInputStream(new FileInputStream(filename))