aboutsummaryrefslogtreecommitdiff
path: root/python/pyspark/context.py
diff options
context:
space:
mode:
authorAaron Davidson <aaron@databricks.com>2014-05-31 13:04:57 -0700
committerReynold Xin <rxin@apache.org>2014-05-31 13:04:57 -0700
commit9909efc10aaa62c47fd7c4c9da73ac8c56a454d5 (patch)
tree86ab8e6477ab4a631b3a91f0e89007ca69c78d37 /python/pyspark/context.py
parent7d52777effd0ff41aed545f53d2ab8de2364a188 (diff)
downloadspark-9909efc10aaa62c47fd7c4c9da73ac8c56a454d5.tar.gz
spark-9909efc10aaa62c47fd7c4c9da73ac8c56a454d5.tar.bz2
spark-9909efc10aaa62c47fd7c4c9da73ac8c56a454d5.zip
SPARK-1839: PySpark RDD#take() shouldn't always read from driver
This patch simply ports over the Scala implementation of RDD#take(), which reads the first partition at the driver, then decides how many more partitions it needs to read and will possibly start a real job if it's more than 1. (Note that SparkContext#runJob(allowLocal=true) only runs the job locally if there's 1 partition selected and no parent stages.) Author: Aaron Davidson <aaron@databricks.com> Closes #922 from aarondav/take and squashes the following commits: fa06df9 [Aaron Davidson] SPARK-1839: PySpark RDD#take() shouldn't always read from driver
Diffstat (limited to 'python/pyspark/context.py')
-rw-r--r--python/pyspark/context.py26
1 files changed, 26 insertions, 0 deletions
diff --git a/python/pyspark/context.py b/python/pyspark/context.py
index 56746cb7aa..9ae9305d4f 100644
--- a/python/pyspark/context.py
+++ b/python/pyspark/context.py
@@ -537,6 +537,32 @@ class SparkContext(object):
"""
self._jsc.sc().cancelAllJobs()
+ def runJob(self, rdd, partitionFunc, partitions = None, allowLocal = False):
+ """
+ Executes the given partitionFunc on the specified set of partitions,
+ returning the result as an array of elements.
+
+ If 'partitions' is not specified, this will run over all partitions.
+
+ >>> myRDD = sc.parallelize(range(6), 3)
+ >>> sc.runJob(myRDD, lambda part: [x * x for x in part])
+ [0, 1, 4, 9, 16, 25]
+
+ >>> myRDD = sc.parallelize(range(6), 3)
+ >>> sc.runJob(myRDD, lambda part: [x * x for x in part], [0, 2], True)
+ [0, 1, 16, 25]
+ """
+ if partitions == None:
+ partitions = range(rdd._jrdd.splits().size())
+ javaPartitions = ListConverter().convert(partitions, self._gateway._gateway_client)
+
+ # Implementation note: This is implemented as a mapPartitions followed
+ # by runJob() in order to avoid having to pass a Python lambda into
+ # SparkContext#runJob.
+ mappedRDD = rdd.mapPartitions(partitionFunc)
+ it = self._jvm.PythonRDD.runJob(self._jsc.sc(), mappedRDD._jrdd, javaPartitions, allowLocal)
+ return list(mappedRDD._collect_iterator_through_file(it))
+
def _test():
import atexit
import doctest