author     Aaron Davidson <aaron@databricks.com>    2014-05-31 13:04:57 -0700
committer  Reynold Xin <rxin@apache.org>            2014-05-31 13:04:57 -0700
commit     9909efc10aaa62c47fd7c4c9da73ac8c56a454d5 (patch)
tree       86ab8e6477ab4a631b3a91f0e89007ca69c78d37 /python/pyspark/context.py
parent     7d52777effd0ff41aed545f53d2ab8de2364a188 (diff)
SPARK-1839: PySpark RDD#take() shouldn't always read from driver
This patch ports over the Scala implementation of RDD#take(), which reads the first partition at the driver, then estimates how many additional partitions it needs and launches a real job if that number is more than one. (Note that SparkContext#runJob(allowLocal=true) only runs the job locally when a single partition is selected and there are no parent stages.) A sketch of this scan strategy appears after the commit log below.
Author: Aaron Davidson <aaron@databricks.com>
Closes #922 from aarondav/take and squashes the following commits:
fa06df9 [Aaron Davidson] SPARK-1839: PySpark RDD#take() shouldn't always read from driver
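The scan strategy described above can be sketched in plain Python. This is a hedged illustration rather than the actual Spark source: `take_sketch`, `run_job`, and the 4x/1.5x growth factors are stand-ins for Spark's internals.

```python
import itertools

# Hedged sketch of the driver-side scan loop in take(); run_job stands in
# for SparkContext#runJob and evaluates a function over chosen partitions.
def take_sketch(rdd, num, run_job, num_partitions):
    items = []
    parts_scanned = 0
    while len(items) < num and parts_scanned < num_partitions:
        # First pass reads a single partition; later passes estimate how
        # many more are needed from the yield of those already scanned.
        num_parts_to_try = 1
        if parts_scanned > 0:
            if not items:
                num_parts_to_try = parts_scanned * 4  # all empty so far: grow fast
            else:
                # Extrapolate from the observed density, padded by 50%.
                num_parts_to_try = int(1.5 * num * parts_scanned / len(items))
        left = num - len(items)
        partitions = range(parts_scanned,
                           min(parts_scanned + num_parts_to_try, num_partitions))
        items += run_job(rdd, lambda it: list(itertools.islice(it, left)), partitions)
        parts_scanned += num_parts_to_try
    return items[:num]

# Toy demo: partitions are plain lists; run_job maps the function over them.
data = [[0, 1], [2, 3], [4, 5]]
run_job = lambda rdd, f, parts: [x for p in parts for x in f(iter(rdd[p]))]
print(take_sketch(data, 3, run_job, len(data)))  # [0, 1, 2]
```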
Diffstat (limited to 'python/pyspark/context.py')
 python/pyspark/context.py | 26 ++++++++++++++++++++++++++
 1 file changed, 26 insertions(+), 0 deletions(-)
```diff
diff --git a/python/pyspark/context.py b/python/pyspark/context.py
index 56746cb7aa..9ae9305d4f 100644
--- a/python/pyspark/context.py
+++ b/python/pyspark/context.py
@@ -537,6 +537,32 @@ class SparkContext(object):
         """
         self._jsc.sc().cancelAllJobs()
 
+    def runJob(self, rdd, partitionFunc, partitions=None, allowLocal=False):
+        """
+        Executes the given partitionFunc on the specified set of partitions,
+        returning the result as an array of elements.
+
+        If 'partitions' is not specified, this will run over all partitions.
+
+        >>> myRDD = sc.parallelize(range(6), 3)
+        >>> sc.runJob(myRDD, lambda part: [x * x for x in part])
+        [0, 1, 4, 9, 16, 25]
+
+        >>> myRDD = sc.parallelize(range(6), 3)
+        >>> sc.runJob(myRDD, lambda part: [x * x for x in part], [0, 2], True)
+        [0, 1, 16, 25]
+        """
+        if partitions is None:
+            partitions = range(rdd._jrdd.splits().size())
+        javaPartitions = ListConverter().convert(partitions, self._gateway._gateway_client)
+
+        # Implementation note: This is implemented as a mapPartitions followed
+        # by runJob() in order to avoid having to pass a Python lambda into
+        # SparkContext#runJob.
+        mappedRDD = rdd.mapPartitions(partitionFunc)
+        it = self._jvm.PythonRDD.runJob(self._jsc.sc(), mappedRDD._jrdd, javaPartitions, allowLocal)
+        return list(mappedRDD._collect_iterator_through_file(it))
+
 def _test():
     import atexit
     import doctest
```
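For context, here is a minimal usage sketch of the `runJob()` added by this hunk, assuming an active SparkContext `sc` (as in a pyspark shell); the helper `take_from_first_partitions` is hypothetical and not part of the patch.

```python
from itertools import islice

# Hypothetical helper (not in the patch): read at most `num` elements while
# touching only the first `num_parts` partitions, instead of scanning the
# whole RDD at the driver.
def take_from_first_partitions(sc, rdd, num, num_parts):
    res = sc.runJob(rdd,
                    lambda it: list(islice(it, num)),
                    partitions=range(num_parts),
                    allowLocal=True)
    return res[:num]

rdd = sc.parallelize(range(100), 10)
print(take_from_first_partitions(sc, rdd, 5, 1))  # [0, 1, 2, 3, 4]
```

Passing allowLocal=True matches the note in the commit message: the job runs locally only when a single partition is selected and there are no parent stages; otherwise a real job is launched over just the requested partitions.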