aboutsummaryrefslogtreecommitdiff
path: root/python
diff options
context:
space:
mode:
authorMatei Zaharia <matei@databricks.com>2014-01-02 15:54:54 -0500
committerMatei Zaharia <matei@databricks.com>2014-01-02 15:54:54 -0500
commitca67909cd4b4accea64498b3bc1421c6d75e33a6 (patch)
treeb6e13c3c0be95ea99c857429c97d3687c61050a3 /python
parent3713f8129a618a633a7aca8c944960c3e7ac9d3b (diff)
parentfec01664a717c8ecf84eaf7a2523a62cd5d3b4b8 (diff)
downloadspark-ca67909cd4b4accea64498b3bc1421c6d75e33a6.tar.gz
spark-ca67909cd4b4accea64498b3bc1421c6d75e33a6.tar.bz2
spark-ca67909cd4b4accea64498b3bc1421c6d75e33a6.zip
Merge pull request #311 from tmyklebu/master
SPARK-991: Report information gleaned from a Python stacktrace in the UI Scala: - Added setCallSite/clearCallSite to SparkContext and JavaSparkContext. These functions mutate a LocalProperty called "externalCallSite." - Add a wrapper, getCallSite, that checks for an externalCallSite and, if none is found, calls the usual Utils.formatSparkCallSite. - Change everything that calls Utils.formatSparkCallSite to call getCallSite instead. Except getCallSite. - Add wrappers to setCallSite/clearCallSite wrappers to JavaSparkContext. Python: - Add a gruesome hack to rdd.py that inspects the traceback and guesses what you want to see in the UI. - Add a RAII wrapper around said gruesome hack that calls setCallSite/clearCallSite as appropriate. - Wire said RAII wrapper up around three calls into the Scala code. I'm not sure that I hit all the spots with the RAII wrapper. I'm also not sure that my gruesome hack does exactly what we want. One could also approach this change by refactoring runJob/submitJob/runApproximateJob to take a call site, then threading that parameter through everything that needs to know it. One might object to the pointless-looking wrappers in JavaSparkContext. Unfortunately, I can't directly access the SparkContext from Python---or, if I can, I don't know how---so I need to wrap everything that matters in JavaSparkContext. Conflicts: core/src/main/scala/org/apache/spark/api/java/JavaSparkContext.scala
Diffstat (limited to 'python')
-rw-r--r--python/pyspark/rdd.py66
1 files changed, 55 insertions, 11 deletions
diff --git a/python/pyspark/rdd.py b/python/pyspark/rdd.py
index f87923e6fa..6fb4a7b3be 100644
--- a/python/pyspark/rdd.py
+++ b/python/pyspark/rdd.py
@@ -23,6 +23,7 @@ import operator
import os
import sys
import shlex
+import traceback
from subprocess import Popen, PIPE
from tempfile import NamedTemporaryFile
from threading import Thread
@@ -39,6 +40,46 @@ from py4j.java_collections import ListConverter, MapConverter
__all__ = ["RDD"]
+def _extract_concise_traceback():
+ tb = traceback.extract_stack()
+ if len(tb) == 0:
+ return "I'm lost!"
+ # HACK: This function is in a file called 'rdd.py' in the top level of
+ # everything PySpark. Just trim off the directory name and assume
+ # everything in that tree is PySpark guts.
+ file, line, module, what = tb[len(tb) - 1]
+ sparkpath = os.path.dirname(file)
+ first_spark_frame = len(tb) - 1
+ for i in range(0, len(tb)):
+ file, line, fun, what = tb[i]
+ if file.startswith(sparkpath):
+ first_spark_frame = i
+ break
+ if first_spark_frame == 0:
+ file, line, fun, what = tb[0]
+ return "%s at %s:%d" % (fun, file, line)
+ sfile, sline, sfun, swhat = tb[first_spark_frame]
+ ufile, uline, ufun, uwhat = tb[first_spark_frame-1]
+ return "%s at %s:%d" % (sfun, ufile, uline)
+
+_spark_stack_depth = 0
+
+class _JavaStackTrace(object):
+ def __init__(self, sc):
+ self._traceback = _extract_concise_traceback()
+ self._context = sc
+
+ def __enter__(self):
+ global _spark_stack_depth
+ if _spark_stack_depth == 0:
+ self._context._jsc.setCallSite(self._traceback)
+ _spark_stack_depth += 1
+
+ def __exit__(self, type, value, tb):
+ global _spark_stack_depth
+ _spark_stack_depth -= 1
+ if _spark_stack_depth == 0:
+ self._context._jsc.setCallSite(None)
class RDD(object):
"""
@@ -401,7 +442,8 @@ class RDD(object):
"""
Return a list that contains all of the elements in this RDD.
"""
- bytesInJava = self._jrdd.collect().iterator()
+ with _JavaStackTrace(self.context) as st:
+ bytesInJava = self._jrdd.collect().iterator()
return list(self._collect_iterator_through_file(bytesInJava))
def _collect_iterator_through_file(self, iterator):
@@ -582,13 +624,14 @@ class RDD(object):
# TODO(shivaram): Similar to the scala implementation, update the take
# method to scan multiple splits based on an estimate of how many elements
# we have per-split.
- for partition in range(mapped._jrdd.splits().size()):
- partitionsToTake = self.ctx._gateway.new_array(self.ctx._jvm.int, 1)
- partitionsToTake[0] = partition
- iterator = mapped._jrdd.collectPartitions(partitionsToTake)[0].iterator()
- items.extend(mapped._collect_iterator_through_file(iterator))
- if len(items) >= num:
- break
+ with _JavaStackTrace(self.context) as st:
+ for partition in range(mapped._jrdd.splits().size()):
+ partitionsToTake = self.ctx._gateway.new_array(self.ctx._jvm.int, 1)
+ partitionsToTake[0] = partition
+ iterator = mapped._jrdd.collectPartitions(partitionsToTake)[0].iterator()
+ items.extend(mapped._collect_iterator_through_file(iterator))
+ if len(items) >= num:
+ break
return items[:num]
def first(self):
@@ -765,9 +808,10 @@ class RDD(object):
yield outputSerializer.dumps(items)
keyed = PipelinedRDD(self, add_shuffle_key)
keyed._bypass_serializer = True
- pairRDD = self.ctx._jvm.PairwiseRDD(keyed._jrdd.rdd()).asJavaPairRDD()
- partitioner = self.ctx._jvm.PythonPartitioner(numPartitions,
- id(partitionFunc))
+ with _JavaStackTrace(self.context) as st:
+ pairRDD = self.ctx._jvm.PairwiseRDD(keyed._jrdd.rdd()).asJavaPairRDD()
+ partitioner = self.ctx._jvm.PythonPartitioner(numPartitions,
+ id(partitionFunc))
jrdd = pairRDD.partitionBy(partitioner).values()
rdd = RDD(jrdd, self.ctx, BatchedSerializer(outputSerializer))
# This is required so that id(partitionFunc) remains unique, even if