Diffstat (limited to 'python/pyspark/rdd.py')
-rw-r--r--  python/pyspark/rdd.py | 11 +++++++----
1 file changed, 7 insertions(+), 4 deletions(-)
diff --git a/python/pyspark/rdd.py b/python/pyspark/rdd.py
index 7019fb8bee..d3c4d13a1e 100644
--- a/python/pyspark/rdd.py
+++ b/python/pyspark/rdd.py
@@ -54,6 +54,7 @@ class RDD(object):
         self.is_checkpointed = False
         self.ctx = ctx
         self._partitionFunc = None
+        self._stage_input_is_pairs = False
 
     @property
     def context(self):
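Note: the new _stage_input_is_pairs attribute defaults to False and records whether this stage's input stream carries pair records (two serialized payloads per element, as produced by Java pair RDDs) rather than single payloads. A minimal sketch of the distinction the flag encodes, assuming a hypothetical length-prefixed framing; none of these helper names are PySpark API:

    import struct

    def read_stage_input(stream, stage_input_is_pairs):
        # Hypothetical framing: each payload is a 4-byte big-endian
        # length followed by that many bytes.
        def read_payload():
            header = stream.read(4)
            if len(header) < 4:
                return None
            (length,) = struct.unpack(">i", header)
            return stream.read(length)
        while True:
            first = read_payload()
            if first is None:
                return
            if stage_input_is_pairs:
                # Pair stages see two framed payloads per record.
                yield (first, read_payload())
            else:
                # Ordinary stages see one payload per record.
                yield first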
@@ -344,6 +345,7 @@
                     yield pair
             else:
                 yield pair
+        java_cartesian._stage_input_is_pairs = True
         return java_cartesian.flatMap(unpack_batches)
 
     def groupBy(self, f, numPartitions=None):
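cartesian() is the one place in this diff that sets the flag: the underlying Java cartesian emits Tuple2 records, so the Python stage consuming its output must expect pair-framed input. A usage sketch, assuming a live SparkContext named sc:

    # unpack_batches still flattens Batch-wrapped elements, so the
    # user-visible result is ordinary Python tuples:
    a = sc.parallelize([1, 2])
    b = sc.parallelize(["x", "y"])
    assert sorted(a.cartesian(b).collect()) == \
        [(1, "x"), (1, "y"), (2, "x"), (2, "y")]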
@@ -391,8 +393,8 @@
         """
         Return a list that contains all of the elements in this RDD.
         """
-        picklesInJava = self._jrdd.collect().iterator()
-        return list(self._collect_iterator_through_file(picklesInJava))
+        bytesInJava = self._jrdd.collect().iterator()
+        return list(self._collect_iterator_through_file(bytesInJava))
 
     def _collect_iterator_through_file(self, iterator):
         # Transferring lots of data through Py4J can be slow because
@@ -400,7 +402,7 @@
         # file and read it back.
         tempFile = NamedTemporaryFile(delete=False, dir=self.ctx._temp_dir)
         tempFile.close()
-        self.ctx._writeIteratorToPickleFile(iterator, tempFile.name)
+        self.ctx._writeToFile(iterator, tempFile.name)
         # Read the data into Python and deserialize it:
         with open(tempFile.name, 'rb') as tempFile:
             for item in read_from_pickle_file(tempFile):
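Both collect() hunks are the same change in spirit: the Java side now hands back plain byte arrays rather than pickle-wrapped objects, so picklesInJava becomes bytesInJava and the writer helper becomes the more general _writeToFile. The collect-through-a-temp-file pattern itself is unchanged; a standalone sketch of it, with pickle standing in for PySpark's actual writer and read_from_pickle_file:

    import os
    import pickle
    from tempfile import NamedTemporaryFile

    def collect_through_file(records):
        # Dump everything in one pass, then stream it back, instead of
        # shipping each record over a chatty channel (Py4J in PySpark).
        tmp = NamedTemporaryFile(delete=False)
        try:
            for r in records:
                pickle.dump(r, tmp)
            tmp.close()
            with open(tmp.name, "rb") as f:
                while True:
                    try:
                        yield pickle.load(f)
                    except EOFError:
                        break
        finally:
            os.unlink(tmp.name)

    assert list(collect_through_file(range(3))) == [0, 1, 2]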
@@ -941,6 +943,7 @@ class PipelinedRDD(RDD):
         self.func = func
         self.preservesPartitioning = preservesPartitioning
         self._prev_jrdd = prev._jrdd
+        self._stage_input_is_pairs = prev._stage_input_is_pairs
         self.is_cached = False
         self.is_checkpointed = False
         self.ctx = prev.ctx
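Pipelined stages (chained map/filter/flatMap calls fused into one Python task) read the same upstream input as their parent, so the flag has to travel with them; that is all this hunk does. Continuing the earlier sketch, assuming a live SparkContext sc:

    pairs = sc.parallelize([1, 2]).cartesian(sc.parallelize([3, 4]))
    # cartesian's flatMap result is a PipelinedRDD over java_cartesian,
    # so it inherits _stage_input_is_pairs = True, and so does any
    # further pipelined transformation:
    summed = pairs.map(lambda xy: xy[0] + xy[1])
    assert summed._stage_input_is_pairs is True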
@@ -959,7 +962,7 @@
             def batched_func(split, iterator):
                 return batched(oldfunc(split, iterator), batchSize)
             func = batched_func
-        cmds = [func, self._bypass_serializer]
+        cmds = [func, self._bypass_serializer, self._stage_input_is_pairs]
         pipe_command = ' '.join(b64enc(cloudpickle.dumps(f)) for f in cmds)
         broadcast_vars = ListConverter().convert(
             [x._jbroadcast for x in self.ctx._pickled_broadcast_vars],
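The command shipped to the Python worker is a space-separated list of base64-encoded cloudpickle payloads, and this hunk appends the new flag as a third field so the worker knows how to frame its input before running func. A round-trip sketch of that encoding, with plain pickle standing in for cloudpickle and decode_command being a hypothetical mirror of the worker side, not PySpark's actual code:

    import pickle
    from base64 import b64decode, b64encode

    def encode_command(func, bypass_serializer, stage_input_is_pairs):
        cmds = [func, bypass_serializer, stage_input_is_pairs]
        # base64 output never contains spaces, so ' ' is a safe separator.
        return ' '.join(b64encode(pickle.dumps(f)).decode('ascii')
                        for f in cmds)

    def decode_command(pipe_command):
        return [pickle.loads(b64decode(c)) for c in pipe_command.split(' ')]

    def identity(split, iterator):
        return iterator

    func, bypass, is_pairs = decode_command(encode_command(identity, False, True))
    assert (bypass, is_pairs) == (False, True)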