diff options
author | Matei Zaharia <matei@databricks.com> | 2014-01-02 15:54:54 -0500 |
---|---|---|
committer | Matei Zaharia <matei@databricks.com> | 2014-01-02 15:54:54 -0500 |
commit | ca67909cd4b4accea64498b3bc1421c6d75e33a6 (patch) | |
tree | b6e13c3c0be95ea99c857429c97d3687c61050a3 /python | |
parent | 3713f8129a618a633a7aca8c944960c3e7ac9d3b (diff) | |
parent | fec01664a717c8ecf84eaf7a2523a62cd5d3b4b8 (diff) | |
download | spark-ca67909cd4b4accea64498b3bc1421c6d75e33a6.tar.gz spark-ca67909cd4b4accea64498b3bc1421c6d75e33a6.tar.bz2 spark-ca67909cd4b4accea64498b3bc1421c6d75e33a6.zip |
Merge pull request #311 from tmyklebu/master
SPARK-991: Report information gleaned from a Python stacktrace in the UI
Scala:
- Added setCallSite/clearCallSite to SparkContext and JavaSparkContext.
These functions mutate a LocalProperty called "externalCallSite."
- Add a wrapper, getCallSite, that checks for an externalCallSite and, if
none is found, calls the usual Utils.formatSparkCallSite.
- Change everything that calls Utils.formatSparkCallSite to call
getCallSite instead. Except getCallSite.
- Add wrappers to setCallSite/clearCallSite wrappers to JavaSparkContext.
Python:
- Add a gruesome hack to rdd.py that inspects the traceback and guesses
what you want to see in the UI.
- Add a RAII wrapper around said gruesome hack that calls
setCallSite/clearCallSite as appropriate.
- Wire said RAII wrapper up around three calls into the Scala code.
I'm not sure that I hit all the spots with the RAII wrapper. I'm also
not sure that my gruesome hack does exactly what we want.
One could also approach this change by refactoring
runJob/submitJob/runApproximateJob to take a call site, then threading
that parameter through everything that needs to know it.
One might object to the pointless-looking wrappers in JavaSparkContext.
Unfortunately, I can't directly access the SparkContext from
Python---or, if I can, I don't know how---so I need to wrap everything
that matters in JavaSparkContext.
Conflicts:
core/src/main/scala/org/apache/spark/api/java/JavaSparkContext.scala
Diffstat (limited to 'python')
-rw-r--r-- | python/pyspark/rdd.py | 66 |
1 files changed, 55 insertions, 11 deletions
diff --git a/python/pyspark/rdd.py b/python/pyspark/rdd.py index f87923e6fa..6fb4a7b3be 100644 --- a/python/pyspark/rdd.py +++ b/python/pyspark/rdd.py @@ -23,6 +23,7 @@ import operator import os import sys import shlex +import traceback from subprocess import Popen, PIPE from tempfile import NamedTemporaryFile from threading import Thread @@ -39,6 +40,46 @@ from py4j.java_collections import ListConverter, MapConverter __all__ = ["RDD"] +def _extract_concise_traceback(): + tb = traceback.extract_stack() + if len(tb) == 0: + return "I'm lost!" + # HACK: This function is in a file called 'rdd.py' in the top level of + # everything PySpark. Just trim off the directory name and assume + # everything in that tree is PySpark guts. + file, line, module, what = tb[len(tb) - 1] + sparkpath = os.path.dirname(file) + first_spark_frame = len(tb) - 1 + for i in range(0, len(tb)): + file, line, fun, what = tb[i] + if file.startswith(sparkpath): + first_spark_frame = i + break + if first_spark_frame == 0: + file, line, fun, what = tb[0] + return "%s at %s:%d" % (fun, file, line) + sfile, sline, sfun, swhat = tb[first_spark_frame] + ufile, uline, ufun, uwhat = tb[first_spark_frame-1] + return "%s at %s:%d" % (sfun, ufile, uline) + +_spark_stack_depth = 0 + +class _JavaStackTrace(object): + def __init__(self, sc): + self._traceback = _extract_concise_traceback() + self._context = sc + + def __enter__(self): + global _spark_stack_depth + if _spark_stack_depth == 0: + self._context._jsc.setCallSite(self._traceback) + _spark_stack_depth += 1 + + def __exit__(self, type, value, tb): + global _spark_stack_depth + _spark_stack_depth -= 1 + if _spark_stack_depth == 0: + self._context._jsc.setCallSite(None) class RDD(object): """ @@ -401,7 +442,8 @@ class RDD(object): """ Return a list that contains all of the elements in this RDD. """ - bytesInJava = self._jrdd.collect().iterator() + with _JavaStackTrace(self.context) as st: + bytesInJava = self._jrdd.collect().iterator() return list(self._collect_iterator_through_file(bytesInJava)) def _collect_iterator_through_file(self, iterator): @@ -582,13 +624,14 @@ class RDD(object): # TODO(shivaram): Similar to the scala implementation, update the take # method to scan multiple splits based on an estimate of how many elements # we have per-split. - for partition in range(mapped._jrdd.splits().size()): - partitionsToTake = self.ctx._gateway.new_array(self.ctx._jvm.int, 1) - partitionsToTake[0] = partition - iterator = mapped._jrdd.collectPartitions(partitionsToTake)[0].iterator() - items.extend(mapped._collect_iterator_through_file(iterator)) - if len(items) >= num: - break + with _JavaStackTrace(self.context) as st: + for partition in range(mapped._jrdd.splits().size()): + partitionsToTake = self.ctx._gateway.new_array(self.ctx._jvm.int, 1) + partitionsToTake[0] = partition + iterator = mapped._jrdd.collectPartitions(partitionsToTake)[0].iterator() + items.extend(mapped._collect_iterator_through_file(iterator)) + if len(items) >= num: + break return items[:num] def first(self): @@ -765,9 +808,10 @@ class RDD(object): yield outputSerializer.dumps(items) keyed = PipelinedRDD(self, add_shuffle_key) keyed._bypass_serializer = True - pairRDD = self.ctx._jvm.PairwiseRDD(keyed._jrdd.rdd()).asJavaPairRDD() - partitioner = self.ctx._jvm.PythonPartitioner(numPartitions, - id(partitionFunc)) + with _JavaStackTrace(self.context) as st: + pairRDD = self.ctx._jvm.PairwiseRDD(keyed._jrdd.rdd()).asJavaPairRDD() + partitioner = self.ctx._jvm.PythonPartitioner(numPartitions, + id(partitionFunc)) jrdd = pairRDD.partitionBy(partitioner).values() rdd = RDD(jrdd, self.ctx, BatchedSerializer(outputSerializer)) # This is required so that id(partitionFunc) remains unique, even if |