From fec01664a717c8ecf84eaf7a2523a62cd5d3b4b8 Mon Sep 17 00:00:00 2001 From: Tor Myklebust Date: Sat, 28 Dec 2013 23:34:16 -0500 Subject: Make Python function/line appear in the UI. --- python/pyspark/rdd.py | 66 ++++++++++++++++++++++++++++++++++++++++++--------- 1 file changed, 55 insertions(+), 11 deletions(-) (limited to 'python/pyspark/rdd.py') diff --git a/python/pyspark/rdd.py b/python/pyspark/rdd.py index f87923e6fa..6fb4a7b3be 100644 --- a/python/pyspark/rdd.py +++ b/python/pyspark/rdd.py @@ -23,6 +23,7 @@ import operator import os import sys import shlex +import traceback from subprocess import Popen, PIPE from tempfile import NamedTemporaryFile from threading import Thread @@ -39,6 +40,46 @@ from py4j.java_collections import ListConverter, MapConverter __all__ = ["RDD"] +def _extract_concise_traceback(): + tb = traceback.extract_stack() + if len(tb) == 0: + return "I'm lost!" + # HACK: This function is in a file called 'rdd.py' in the top level of + # everything PySpark. Just trim off the directory name and assume + # everything in that tree is PySpark guts. + file, line, module, what = tb[len(tb) - 1] + sparkpath = os.path.dirname(file) + first_spark_frame = len(tb) - 1 + for i in range(0, len(tb)): + file, line, fun, what = tb[i] + if file.startswith(sparkpath): + first_spark_frame = i + break + if first_spark_frame == 0: + file, line, fun, what = tb[0] + return "%s at %s:%d" % (fun, file, line) + sfile, sline, sfun, swhat = tb[first_spark_frame] + ufile, uline, ufun, uwhat = tb[first_spark_frame-1] + return "%s at %s:%d" % (sfun, ufile, uline) + +_spark_stack_depth = 0 + +class _JavaStackTrace(object): + def __init__(self, sc): + self._traceback = _extract_concise_traceback() + self._context = sc + + def __enter__(self): + global _spark_stack_depth + if _spark_stack_depth == 0: + self._context._jsc.setCallSite(self._traceback) + _spark_stack_depth += 1 + + def __exit__(self, type, value, tb): + global _spark_stack_depth + _spark_stack_depth -= 1 + if _spark_stack_depth == 0: + self._context._jsc.setCallSite(None) class RDD(object): """ @@ -401,7 +442,8 @@ class RDD(object): """ Return a list that contains all of the elements in this RDD. """ - bytesInJava = self._jrdd.collect().iterator() + with _JavaStackTrace(self.context) as st: + bytesInJava = self._jrdd.collect().iterator() return list(self._collect_iterator_through_file(bytesInJava)) def _collect_iterator_through_file(self, iterator): @@ -582,13 +624,14 @@ class RDD(object): # TODO(shivaram): Similar to the scala implementation, update the take # method to scan multiple splits based on an estimate of how many elements # we have per-split. - for partition in range(mapped._jrdd.splits().size()): - partitionsToTake = self.ctx._gateway.new_array(self.ctx._jvm.int, 1) - partitionsToTake[0] = partition - iterator = mapped._jrdd.collectPartitions(partitionsToTake)[0].iterator() - items.extend(mapped._collect_iterator_through_file(iterator)) - if len(items) >= num: - break + with _JavaStackTrace(self.context) as st: + for partition in range(mapped._jrdd.splits().size()): + partitionsToTake = self.ctx._gateway.new_array(self.ctx._jvm.int, 1) + partitionsToTake[0] = partition + iterator = mapped._jrdd.collectPartitions(partitionsToTake)[0].iterator() + items.extend(mapped._collect_iterator_through_file(iterator)) + if len(items) >= num: + break return items[:num] def first(self): @@ -765,9 +808,10 @@ class RDD(object): yield outputSerializer.dumps(items) keyed = PipelinedRDD(self, add_shuffle_key) keyed._bypass_serializer = True - pairRDD = self.ctx._jvm.PairwiseRDD(keyed._jrdd.rdd()).asJavaPairRDD() - partitioner = self.ctx._jvm.PythonPartitioner(numPartitions, - id(partitionFunc)) + with _JavaStackTrace(self.context) as st: + pairRDD = self.ctx._jvm.PairwiseRDD(keyed._jrdd.rdd()).asJavaPairRDD() + partitioner = self.ctx._jvm.PythonPartitioner(numPartitions, + id(partitionFunc)) jrdd = pairRDD.partitionBy(partitioner).values() rdd = RDD(jrdd, self.ctx, BatchedSerializer(outputSerializer)) # This is required so that id(partitionFunc) remains unique, even if -- cgit v1.2.3