diff options
Diffstat (limited to 'python/pyspark/rdd.py')
-rw-r--r-- | python/pyspark/rdd.py | 11 |
1 files changed, 7 insertions, 4 deletions
diff --git a/python/pyspark/rdd.py b/python/pyspark/rdd.py index 7019fb8bee..d3c4d13a1e 100644 --- a/python/pyspark/rdd.py +++ b/python/pyspark/rdd.py @@ -54,6 +54,7 @@ class RDD(object): self.is_checkpointed = False self.ctx = ctx self._partitionFunc = None + self._stage_input_is_pairs = False @property def context(self): @@ -344,6 +345,7 @@ class RDD(object): yield pair else: yield pair + java_cartesian._stage_input_is_pairs = True return java_cartesian.flatMap(unpack_batches) def groupBy(self, f, numPartitions=None): @@ -391,8 +393,8 @@ class RDD(object): """ Return a list that contains all of the elements in this RDD. """ - picklesInJava = self._jrdd.collect().iterator() - return list(self._collect_iterator_through_file(picklesInJava)) + bytesInJava = self._jrdd.collect().iterator() + return list(self._collect_iterator_through_file(bytesInJava)) def _collect_iterator_through_file(self, iterator): # Transferring lots of data through Py4J can be slow because @@ -400,7 +402,7 @@ class RDD(object): # file and read it back. tempFile = NamedTemporaryFile(delete=False, dir=self.ctx._temp_dir) tempFile.close() - self.ctx._writeIteratorToPickleFile(iterator, tempFile.name) + self.ctx._writeToFile(iterator, tempFile.name) # Read the data into Python and deserialize it: with open(tempFile.name, 'rb') as tempFile: for item in read_from_pickle_file(tempFile): @@ -941,6 +943,7 @@ class PipelinedRDD(RDD): self.func = func self.preservesPartitioning = preservesPartitioning self._prev_jrdd = prev._jrdd + self._stage_input_is_pairs = prev._stage_input_is_pairs self.is_cached = False self.is_checkpointed = False self.ctx = prev.ctx @@ -959,7 +962,7 @@ class PipelinedRDD(RDD): def batched_func(split, iterator): return batched(oldfunc(split, iterator), batchSize) func = batched_func - cmds = [func, self._bypass_serializer] + cmds = [func, self._bypass_serializer, self._stage_input_is_pairs] pipe_command = ' '.join(b64enc(cloudpickle.dumps(f)) for f in cmds) broadcast_vars = ListConverter().convert( [x._jbroadcast for x in self.ctx._pickled_broadcast_vars], |