aboutsummaryrefslogtreecommitdiff
path: root/python/pyspark/rdd.py
diff options
context:
space:
mode:
Diffstat (limited to 'python/pyspark/rdd.py')
-rw-r--r--python/pyspark/rdd.py66
1 files changed, 55 insertions, 11 deletions
diff --git a/python/pyspark/rdd.py b/python/pyspark/rdd.py
index f87923e6fa..6fb4a7b3be 100644
--- a/python/pyspark/rdd.py
+++ b/python/pyspark/rdd.py
@@ -23,6 +23,7 @@ import operator
import os
import sys
import shlex
+import traceback
from subprocess import Popen, PIPE
from tempfile import NamedTemporaryFile
from threading import Thread
@@ -39,6 +40,46 @@ from py4j.java_collections import ListConverter, MapConverter
__all__ = ["RDD"]
+def _extract_concise_traceback():
+ tb = traceback.extract_stack()
+ if len(tb) == 0:
+ return "I'm lost!"
+ # HACK: This function is in a file called 'rdd.py' in the top level of
+ # everything PySpark. Just trim off the directory name and assume
+ # everything in that tree is PySpark guts.
+ file, line, module, what = tb[len(tb) - 1]
+ sparkpath = os.path.dirname(file)
+ first_spark_frame = len(tb) - 1
+ for i in range(0, len(tb)):
+ file, line, fun, what = tb[i]
+ if file.startswith(sparkpath):
+ first_spark_frame = i
+ break
+ if first_spark_frame == 0:
+ file, line, fun, what = tb[0]
+ return "%s at %s:%d" % (fun, file, line)
+ sfile, sline, sfun, swhat = tb[first_spark_frame]
+ ufile, uline, ufun, uwhat = tb[first_spark_frame-1]
+ return "%s at %s:%d" % (sfun, ufile, uline)
+
+_spark_stack_depth = 0
+
+class _JavaStackTrace(object):
+ def __init__(self, sc):
+ self._traceback = _extract_concise_traceback()
+ self._context = sc
+
+ def __enter__(self):
+ global _spark_stack_depth
+ if _spark_stack_depth == 0:
+ self._context._jsc.setCallSite(self._traceback)
+ _spark_stack_depth += 1
+
+ def __exit__(self, type, value, tb):
+ global _spark_stack_depth
+ _spark_stack_depth -= 1
+ if _spark_stack_depth == 0:
+ self._context._jsc.setCallSite(None)
class RDD(object):
"""
@@ -401,7 +442,8 @@ class RDD(object):
"""
Return a list that contains all of the elements in this RDD.
"""
- bytesInJava = self._jrdd.collect().iterator()
+ with _JavaStackTrace(self.context) as st:
+ bytesInJava = self._jrdd.collect().iterator()
return list(self._collect_iterator_through_file(bytesInJava))
def _collect_iterator_through_file(self, iterator):
@@ -582,13 +624,14 @@ class RDD(object):
# TODO(shivaram): Similar to the scala implementation, update the take
# method to scan multiple splits based on an estimate of how many elements
# we have per-split.
- for partition in range(mapped._jrdd.splits().size()):
- partitionsToTake = self.ctx._gateway.new_array(self.ctx._jvm.int, 1)
- partitionsToTake[0] = partition
- iterator = mapped._jrdd.collectPartitions(partitionsToTake)[0].iterator()
- items.extend(mapped._collect_iterator_through_file(iterator))
- if len(items) >= num:
- break
+ with _JavaStackTrace(self.context) as st:
+ for partition in range(mapped._jrdd.splits().size()):
+ partitionsToTake = self.ctx._gateway.new_array(self.ctx._jvm.int, 1)
+ partitionsToTake[0] = partition
+ iterator = mapped._jrdd.collectPartitions(partitionsToTake)[0].iterator()
+ items.extend(mapped._collect_iterator_through_file(iterator))
+ if len(items) >= num:
+ break
return items[:num]
def first(self):
@@ -765,9 +808,10 @@ class RDD(object):
yield outputSerializer.dumps(items)
keyed = PipelinedRDD(self, add_shuffle_key)
keyed._bypass_serializer = True
- pairRDD = self.ctx._jvm.PairwiseRDD(keyed._jrdd.rdd()).asJavaPairRDD()
- partitioner = self.ctx._jvm.PythonPartitioner(numPartitions,
- id(partitionFunc))
+ with _JavaStackTrace(self.context) as st:
+ pairRDD = self.ctx._jvm.PairwiseRDD(keyed._jrdd.rdd()).asJavaPairRDD()
+ partitioner = self.ctx._jvm.PythonPartitioner(numPartitions,
+ id(partitionFunc))
jrdd = pairRDD.partitionBy(partitioner).values()
rdd = RDD(jrdd, self.ctx, BatchedSerializer(outputSerializer))
# This is required so that id(partitionFunc) remains unique, even if