aboutsummaryrefslogtreecommitdiff
path: root/core/src/main/scala/org/apache/spark/util/Utils.scala
diff options
context:
space:
mode:
Diffstat (limited to 'core/src/main/scala/org/apache/spark/util/Utils.scala')
-rw-r--r--core/src/main/scala/org/apache/spark/util/Utils.scala27
1 files changed, 20 insertions, 7 deletions
diff --git a/core/src/main/scala/org/apache/spark/util/Utils.scala b/core/src/main/scala/org/apache/spark/util/Utils.scala
index ed06384432..2755887fee 100644
--- a/core/src/main/scala/org/apache/spark/util/Utils.scala
+++ b/core/src/main/scala/org/apache/spark/util/Utils.scala
@@ -49,6 +49,11 @@ import org.apache.spark.serializer.{DeserializationStream, SerializationStream,
/** CallSite represents a place in user code. It can have a short and a long form. */
private[spark] case class CallSite(shortForm: String, longForm: String)
+private[spark] object CallSite {
+ val SHORT_FORM = "callSite.short"
+ val LONG_FORM = "callSite.long"
+}
+
/**
* Various utility methods used by Spark.
*/
@@ -859,18 +864,26 @@ private[spark] object Utils extends Logging {
}
}
- /**
- * A regular expression to match classes of the "core" Spark API that we want to skip when
- * finding the call site of a method.
- */
- private val SPARK_CLASS_REGEX = """^org\.apache\.spark(\.api\.java)?(\.util)?(\.rdd)?\.[A-Z]""".r
+ /** Default filtering function for finding call sites using `getCallSite`. */
+ private def coreExclusionFunction(className: String): Boolean = {
+ // A regular expression to match classes of the "core" Spark API that we want to skip when
+ // finding the call site of a method.
+ val SPARK_CORE_CLASS_REGEX = """^org\.apache\.spark(\.api\.java)?(\.util)?(\.rdd)?\.[A-Z]""".r
+ val SCALA_CLASS_REGEX = """^scala""".r
+ val isSparkCoreClass = SPARK_CORE_CLASS_REGEX.findFirstIn(className).isDefined
+ val isScalaClass = SCALA_CLASS_REGEX.findFirstIn(className).isDefined
+ // If the class is a Spark internal class or a Scala class, then exclude.
+ isSparkCoreClass || isScalaClass
+ }
/**
* When called inside a class in the spark package, returns the name of the user code class
* (outside the spark package) that called into Spark, as well as which Spark method they called.
* This is used, for example, to tell users where in their code each RDD got created.
+ *
+ * @param skipClass Function that is used to exclude non-user-code classes.
*/
- def getCallSite: CallSite = {
+ def getCallSite(skipClass: String => Boolean = coreExclusionFunction): CallSite = {
val trace = Thread.currentThread.getStackTrace()
.filterNot { ste:StackTraceElement =>
// When running under some profilers, the current stack trace might contain some bogus
@@ -891,7 +904,7 @@ private[spark] object Utils extends Logging {
for (el <- trace) {
if (insideSpark) {
- if (SPARK_CLASS_REGEX.findFirstIn(el.getClassName).isDefined) {
+ if (skipClass(el.getClassName)) {
lastSparkMethod = if (el.getMethodName == "<init>") {
// Spark method is a constructor; get its class name
el.getClassName.substring(el.getClassName.lastIndexOf('.') + 1)