aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorMatei Zaharia <matei@databricks.com>2014-01-02 15:54:54 -0500
committerMatei Zaharia <matei@databricks.com>2014-01-02 15:54:54 -0500
commitca67909cd4b4accea64498b3bc1421c6d75e33a6 (patch)
treeb6e13c3c0be95ea99c857429c97d3687c61050a3
parent3713f8129a618a633a7aca8c944960c3e7ac9d3b (diff)
parentfec01664a717c8ecf84eaf7a2523a62cd5d3b4b8 (diff)
downloadspark-ca67909cd4b4accea64498b3bc1421c6d75e33a6.tar.gz
spark-ca67909cd4b4accea64498b3bc1421c6d75e33a6.tar.bz2
spark-ca67909cd4b4accea64498b3bc1421c6d75e33a6.zip
Merge pull request #311 from tmyklebu/master
SPARK-991: Report information gleaned from a Python stacktrace in the UI Scala: - Added setCallSite/clearCallSite to SparkContext and JavaSparkContext. These functions mutate a LocalProperty called "externalCallSite." - Add a wrapper, getCallSite, that checks for an externalCallSite and, if none is found, calls the usual Utils.formatSparkCallSite. - Change everything that calls Utils.formatSparkCallSite to call getCallSite instead. Except getCallSite. - Add wrappers to setCallSite/clearCallSite wrappers to JavaSparkContext. Python: - Add a gruesome hack to rdd.py that inspects the traceback and guesses what you want to see in the UI. - Add a RAII wrapper around said gruesome hack that calls setCallSite/clearCallSite as appropriate. - Wire said RAII wrapper up around three calls into the Scala code. I'm not sure that I hit all the spots with the RAII wrapper. I'm also not sure that my gruesome hack does exactly what we want. One could also approach this change by refactoring runJob/submitJob/runApproximateJob to take a call site, then threading that parameter through everything that needs to know it. One might object to the pointless-looking wrappers in JavaSparkContext. Unfortunately, I can't directly access the SparkContext from Python---or, if I can, I don't know how---so I need to wrap everything that matters in JavaSparkContext. Conflicts: core/src/main/scala/org/apache/spark/api/java/JavaSparkContext.scala
-rw-r--r--core/src/main/scala/org/apache/spark/SparkContext.scala26
-rw-r--r--core/src/main/scala/org/apache/spark/api/java/JavaSparkContext.scala14
-rw-r--r--core/src/main/scala/org/apache/spark/rdd/RDD.scala2
-rw-r--r--python/pyspark/rdd.py66
4 files changed, 93 insertions, 15 deletions
diff --git a/core/src/main/scala/org/apache/spark/SparkContext.scala b/core/src/main/scala/org/apache/spark/SparkContext.scala
index 84bd0f7ffd..4d6a97e255 100644
--- a/core/src/main/scala/org/apache/spark/SparkContext.scala
+++ b/core/src/main/scala/org/apache/spark/SparkContext.scala
@@ -729,6 +729,26 @@ class SparkContext(
}
/**
+ * Support function for API backtraces.
+ */
+ def setCallSite(site: String) {
+ setLocalProperty("externalCallSite", site)
+ }
+
+ /**
+ * Support function for API backtraces.
+ */
+ def clearCallSite() {
+ setLocalProperty("externalCallSite", null)
+ }
+
+ private[spark] def getCallSite(): String = {
+ val callSite = getLocalProperty("externalCallSite")
+ if (callSite == null) return Utils.formatSparkCallSite
+ callSite
+ }
+
+ /**
* Run a function on a given set of partitions in an RDD and pass the results to the given
* handler function. This is the main entry point for all actions in Spark. The allowLocal
* flag specifies whether the scheduler can run the computation on the driver rather than
@@ -740,7 +760,7 @@ class SparkContext(
partitions: Seq[Int],
allowLocal: Boolean,
resultHandler: (Int, U) => Unit) {
- val callSite = Utils.formatSparkCallSite
+ val callSite = getCallSite
val cleanedFunc = clean(func)
logInfo("Starting job: " + callSite)
val start = System.nanoTime
@@ -824,7 +844,7 @@ class SparkContext(
func: (TaskContext, Iterator[T]) => U,
evaluator: ApproximateEvaluator[U, R],
timeout: Long): PartialResult[R] = {
- val callSite = Utils.formatSparkCallSite
+ val callSite = getCallSite
logInfo("Starting job: " + callSite)
val start = System.nanoTime
val result = dagScheduler.runApproximateJob(rdd, func, evaluator, callSite, timeout,
@@ -844,7 +864,7 @@ class SparkContext(
resultFunc: => R): SimpleFutureAction[R] =
{
val cleanF = clean(processPartition)
- val callSite = Utils.formatSparkCallSite
+ val callSite = getCallSite
val waiter = dagScheduler.submitJob(
rdd,
(context: TaskContext, iter: Iterator[T]) => cleanF(iter),
diff --git a/core/src/main/scala/org/apache/spark/api/java/JavaSparkContext.scala b/core/src/main/scala/org/apache/spark/api/java/JavaSparkContext.scala
index 0680a065a3..5be5317f40 100644
--- a/core/src/main/scala/org/apache/spark/api/java/JavaSparkContext.scala
+++ b/core/src/main/scala/org/apache/spark/api/java/JavaSparkContext.scala
@@ -411,6 +411,20 @@ class JavaSparkContext(val sc: SparkContext) extends JavaSparkContextVarargsWork
* changed at runtime.
*/
def getConf: SparkConf = sc.getConf
+
+ /**
+ * Pass-through to SparkContext.setCallSite. For API support only.
+ */
+ def setCallSite(site: String) {
+ sc.setCallSite(site)
+ }
+
+ /**
+ * Pass-through to SparkContext.setCallSite. For API support only.
+ */
+ def clearCallSite() {
+ sc.clearCallSite()
+ }
}
object JavaSparkContext {
diff --git a/core/src/main/scala/org/apache/spark/rdd/RDD.scala b/core/src/main/scala/org/apache/spark/rdd/RDD.scala
index 6a7b0f8a86..3f41b66279 100644
--- a/core/src/main/scala/org/apache/spark/rdd/RDD.scala
+++ b/core/src/main/scala/org/apache/spark/rdd/RDD.scala
@@ -953,7 +953,7 @@ abstract class RDD[T: ClassTag](
private var storageLevel: StorageLevel = StorageLevel.NONE
/** Record user function generating this RDD. */
- @transient private[spark] val origin = Utils.formatSparkCallSite
+ @transient private[spark] val origin = sc.getCallSite
private[spark] def elementClassTag: ClassTag[T] = classTag[T]
diff --git a/python/pyspark/rdd.py b/python/pyspark/rdd.py
index f87923e6fa..6fb4a7b3be 100644
--- a/python/pyspark/rdd.py
+++ b/python/pyspark/rdd.py
@@ -23,6 +23,7 @@ import operator
import os
import sys
import shlex
+import traceback
from subprocess import Popen, PIPE
from tempfile import NamedTemporaryFile
from threading import Thread
@@ -39,6 +40,46 @@ from py4j.java_collections import ListConverter, MapConverter
__all__ = ["RDD"]
+def _extract_concise_traceback():
+ tb = traceback.extract_stack()
+ if len(tb) == 0:
+ return "I'm lost!"
+ # HACK: This function is in a file called 'rdd.py' in the top level of
+ # everything PySpark. Just trim off the directory name and assume
+ # everything in that tree is PySpark guts.
+ file, line, module, what = tb[len(tb) - 1]
+ sparkpath = os.path.dirname(file)
+ first_spark_frame = len(tb) - 1
+ for i in range(0, len(tb)):
+ file, line, fun, what = tb[i]
+ if file.startswith(sparkpath):
+ first_spark_frame = i
+ break
+ if first_spark_frame == 0:
+ file, line, fun, what = tb[0]
+ return "%s at %s:%d" % (fun, file, line)
+ sfile, sline, sfun, swhat = tb[first_spark_frame]
+ ufile, uline, ufun, uwhat = tb[first_spark_frame-1]
+ return "%s at %s:%d" % (sfun, ufile, uline)
+
+_spark_stack_depth = 0
+
+class _JavaStackTrace(object):
+ def __init__(self, sc):
+ self._traceback = _extract_concise_traceback()
+ self._context = sc
+
+ def __enter__(self):
+ global _spark_stack_depth
+ if _spark_stack_depth == 0:
+ self._context._jsc.setCallSite(self._traceback)
+ _spark_stack_depth += 1
+
+ def __exit__(self, type, value, tb):
+ global _spark_stack_depth
+ _spark_stack_depth -= 1
+ if _spark_stack_depth == 0:
+ self._context._jsc.setCallSite(None)
class RDD(object):
"""
@@ -401,7 +442,8 @@ class RDD(object):
"""
Return a list that contains all of the elements in this RDD.
"""
- bytesInJava = self._jrdd.collect().iterator()
+ with _JavaStackTrace(self.context) as st:
+ bytesInJava = self._jrdd.collect().iterator()
return list(self._collect_iterator_through_file(bytesInJava))
def _collect_iterator_through_file(self, iterator):
@@ -582,13 +624,14 @@ class RDD(object):
# TODO(shivaram): Similar to the scala implementation, update the take
# method to scan multiple splits based on an estimate of how many elements
# we have per-split.
- for partition in range(mapped._jrdd.splits().size()):
- partitionsToTake = self.ctx._gateway.new_array(self.ctx._jvm.int, 1)
- partitionsToTake[0] = partition
- iterator = mapped._jrdd.collectPartitions(partitionsToTake)[0].iterator()
- items.extend(mapped._collect_iterator_through_file(iterator))
- if len(items) >= num:
- break
+ with _JavaStackTrace(self.context) as st:
+ for partition in range(mapped._jrdd.splits().size()):
+ partitionsToTake = self.ctx._gateway.new_array(self.ctx._jvm.int, 1)
+ partitionsToTake[0] = partition
+ iterator = mapped._jrdd.collectPartitions(partitionsToTake)[0].iterator()
+ items.extend(mapped._collect_iterator_through_file(iterator))
+ if len(items) >= num:
+ break
return items[:num]
def first(self):
@@ -765,9 +808,10 @@ class RDD(object):
yield outputSerializer.dumps(items)
keyed = PipelinedRDD(self, add_shuffle_key)
keyed._bypass_serializer = True
- pairRDD = self.ctx._jvm.PairwiseRDD(keyed._jrdd.rdd()).asJavaPairRDD()
- partitioner = self.ctx._jvm.PythonPartitioner(numPartitions,
- id(partitionFunc))
+ with _JavaStackTrace(self.context) as st:
+ pairRDD = self.ctx._jvm.PairwiseRDD(keyed._jrdd.rdd()).asJavaPairRDD()
+ partitioner = self.ctx._jvm.PythonPartitioner(numPartitions,
+ id(partitionFunc))
jrdd = pairRDD.partitionBy(partitioner).values()
rdd = RDD(jrdd, self.ctx, BatchedSerializer(outputSerializer))
# This is required so that id(partitionFunc) remains unique, even if