aboutsummaryrefslogtreecommitdiff
path: root/python/pyspark/rdd.py
diff options
context:
space:
mode:
Diffstat (limited to 'python/pyspark/rdd.py')
-rw-r--r--python/pyspark/rdd.py32
1 files changed, 20 insertions, 12 deletions
diff --git a/python/pyspark/rdd.py b/python/pyspark/rdd.py
index c6a6b24c5a..99f5967a8e 100644
--- a/python/pyspark/rdd.py
+++ b/python/pyspark/rdd.py
@@ -160,7 +160,7 @@ class RDD(object):
>>> sorted(sc.parallelize([1, 1, 2, 3]).distinct().collect())
[1, 2, 3]
"""
- return self.map(lambda x: (x, "")) \
+ return self.map(lambda x: (x, None)) \
.reduceByKey(lambda x, _: x) \
.map(lambda (x, _): x)
@@ -267,7 +267,11 @@ class RDD(object):
>>> def f(x): print x
>>> sc.parallelize([1, 2, 3, 4, 5]).foreach(f)
"""
- self.map(f).collect() # Force evaluation
+ def processPartition(iterator):
+ for x in iterator:
+ f(x)
+ yield None
+ self.mapPartitions(processPartition).collect() # Force evaluation
def collect(self):
"""
@@ -386,13 +390,16 @@ class RDD(object):
>>> sc.parallelize([2, 3, 4, 5, 6]).take(10)
[2, 3, 4, 5, 6]
"""
+ def takeUpToNum(iterator):
+ taken = 0
+ while taken < num:
+ yield next(iterator)
+ taken += 1
+ # Take only up to num elements from each partition we try
+ mapped = self.mapPartitions(takeUpToNum)
items = []
- for partition in range(self._jrdd.splits().size()):
- iterator = self.ctx._takePartition(self._jrdd.rdd(), partition)
- # Each item in the iterator is a string, Python object, batch of
- # Python objects. Regardless, it is sufficient to take `num`
- # of these objects in order to collect `num` Python objects:
- iterator = iterator.take(num)
+ for partition in range(mapped._jrdd.splits().size()):
+ iterator = self.ctx._takePartition(mapped._jrdd.rdd(), partition)
items.extend(self._collect_iterator_through_file(iterator))
if len(items) >= num:
break
@@ -749,11 +756,12 @@ class PipelinedRDD(RDD):
self.ctx._gateway._gateway_client)
self.ctx._pickled_broadcast_vars.clear()
class_manifest = self._prev_jrdd.classManifest()
- env = copy.copy(self.ctx.environment)
- env['PYTHONPATH'] = os.environ.get("PYTHONPATH", "")
- env = MapConverter().convert(env, self.ctx._gateway._gateway_client)
+ env = MapConverter().convert(self.ctx.environment,
+ self.ctx._gateway._gateway_client)
+ includes = ListConverter().convert(self.ctx._python_includes,
+ self.ctx._gateway._gateway_client)
python_rdd = self.ctx._jvm.PythonRDD(self._prev_jrdd.rdd(),
- pipe_command, env, self.preservesPartitioning, self.ctx.pythonExec,
+ pipe_command, env, includes, self.preservesPartitioning, self.ctx.pythonExec,
broadcast_vars, self.ctx._javaAccumulator, class_manifest)
self._jrdd_val = python_rdd.asJavaRDD()
return self._jrdd_val