author     witgo <witgo@qq.com>                    2014-03-25 13:28:13 -0700
committer  Patrick Wendell <pwendell@gmail.com>    2014-03-25 13:28:13 -0700
commit     8237df8060039af59eb387f5ea5d6611e8f3e526 (patch)
tree       9d7da71bf16b1d4a37ffdf9092e4f54a2376d835
parent     f8111eaeb0e35f6aa9b1e3ec1173fff207174155 (diff)
Avoid Option while generating call site
This is an update on https://github.com/apache/spark/pull/180, which changes the solution from blacklisting "Option.scala" to avoiding the Option code path while generating the call site. Also includes a unit test to prevent this issue in the future, and some minor refactoring.

Thanks @witgo for reporting this issue and working on the initial solution!

Author: witgo <witgo@qq.com>
Author: Aaron Davidson <aaron@databricks.com>

Closes #222 from aarondav/180 and squashes the following commits:

f74aad1 [Aaron Davidson] Avoid Option while generating call site & add unit tests
d2b4980 [witgo] Modify the position of the filter
1bc22d7 [witgo] Fix Stage.name return "apply at Option.scala:120"
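The failure mode is easy to reproduce outside Spark. In the standalone sketch below (FakeSpark, UserApp, and firstUserFrame are illustrative names, not Spark APIs), a call-site walker reports the first stack frame outside the "framework"; when the walk runs lazily as Option.getOrElse's by-name default, frames from Option.scala sit on the stack and get reported instead of the caller:

// Standalone sketch of the bug this commit fixes; all names are illustrative.
object FakeSpark {
  // Stand-in for Utils.SPARK_CLASS_REGEX: treat FakeSpark itself as framework.
  private def isFrameworkFrame(el: StackTraceElement): Boolean =
    el.getClassName.startsWith("FakeSpark")

  // Report the first stack frame that does not belong to the framework.
  def firstUserFrame(): String =
    Thread.currentThread.getStackTrace
      .filterNot(_.getMethodName.contains("getStackTrace"))
      .find(el => !isFrameworkFrame(el))
      .map(el => s"${el.getMethodName} at ${el.getFileName}:${el.getLineNumber}")
      .getOrElse("(unknown)")

  // Buggy shape: the default is evaluated inside Option.getOrElse, so
  // scala.Option is the first non-framework frame when the walk runs.
  def lazyCallSite(externalCallSite: Option[String]): String =
    externalCallSite.getOrElse(firstUserFrame())

  // Fixed shape (what this patch does): walk the stack eagerly, then fall back.
  def eagerCallSite(externalCallSite: Option[String]): String = {
    val default = firstUserFrame()
    externalCallSite.getOrElse(default)
  }
}

object UserApp {
  def main(args: Array[String]): Unit = {
    println(FakeSpark.lazyCallSite(None))  // e.g. "getOrElse at Option.scala:<n>"
    println(FakeSpark.eagerCallSite(None)) // e.g. "main at <this file>:<n>"
  }
}

Computing the default eagerly, as the patch does in SparkContext.getCallSite, keeps Option.scala off the stack while the walk runs.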
-rw-r--r--  core/src/main/scala/org/apache/spark/SparkContext.scala          |  3 ++-
-rw-r--r--  core/src/main/scala/org/apache/spark/rdd/RDD.scala               |  2 +-
-rw-r--r--  core/src/main/scala/org/apache/spark/util/Utils.scala            | 18 +++++++++---------
-rw-r--r--  core/src/test/scala/org/apache/spark/SparkContextInfoSuite.scala | 36 +++++++++++++++++++++++++++++++++-
4 files changed, 47 insertions(+), 12 deletions(-)
diff --git a/core/src/main/scala/org/apache/spark/SparkContext.scala b/core/src/main/scala/org/apache/spark/SparkContext.scala
index a1003b7925..4dd298177f 100644
--- a/core/src/main/scala/org/apache/spark/SparkContext.scala
+++ b/core/src/main/scala/org/apache/spark/SparkContext.scala
@@ -877,7 +877,8 @@ class SparkContext(
    * has overridden the call site, this will return the user's version.
    */
   private[spark] def getCallSite(): String = {
-    Option(getLocalProperty("externalCallSite")).getOrElse(Utils.formatCallSiteInfo())
+    val defaultCallSite = Utils.getCallSiteInfo
+    Option(getLocalProperty("externalCallSite")).getOrElse(defaultCallSite.toString)
   }
 
   /**
diff --git a/core/src/main/scala/org/apache/spark/rdd/RDD.scala b/core/src/main/scala/org/apache/spark/rdd/RDD.scala
index 1b43040c6d..4f9d39f865 100644
--- a/core/src/main/scala/org/apache/spark/rdd/RDD.scala
+++ b/core/src/main/scala/org/apache/spark/rdd/RDD.scala
@@ -1041,7 +1041,7 @@ abstract class RDD[T: ClassTag](
 
   /** User code that created this RDD (e.g. `textFile`, `parallelize`). */
   @transient private[spark] val creationSiteInfo = Utils.getCallSiteInfo
-  private[spark] def getCreationSite = Utils.formatCallSiteInfo(creationSiteInfo)
+  private[spark] def getCreationSite: String = creationSiteInfo.toString
 
   private[spark] def elementClassTag: ClassTag[T] = classTag[T]
diff --git a/core/src/main/scala/org/apache/spark/util/Utils.scala b/core/src/main/scala/org/apache/spark/util/Utils.scala
index ad87fda140..62ee704d58 100644
--- a/core/src/main/scala/org/apache/spark/util/Utils.scala
+++ b/core/src/main/scala/org/apache/spark/util/Utils.scala
@@ -679,7 +679,13 @@ private[spark] object Utils extends Logging {
   private val SPARK_CLASS_REGEX = """^org\.apache\.spark(\.api\.java)?(\.util)?(\.rdd)?\.[A-Z]""".r
 
   private[spark] class CallSiteInfo(val lastSparkMethod: String, val firstUserFile: String,
-    val firstUserLine: Int, val firstUserClass: String)
+    val firstUserLine: Int, val firstUserClass: String) {
+
+    /** Returns a printable version of the call site info suitable for logs. */
+    override def toString = {
+      "%s at %s:%s".format(lastSparkMethod, firstUserFile, firstUserLine)
+    }
+  }
 
   /**
    * When called inside a class in the spark package, returns the name of the user code class
@@ -687,8 +693,8 @@ private[spark] object Utils extends Logging {
    * This is used, for example, to tell users where in their code each RDD got created.
    */
   def getCallSiteInfo: CallSiteInfo = {
-    val trace = Thread.currentThread.getStackTrace().filter( el =>
-      (!el.getMethodName.contains("getStackTrace")))
+    val trace = Thread.currentThread.getStackTrace()
+      .filterNot(_.getMethodName.contains("getStackTrace"))
 
     // Keep crawling up the stack trace until we find the first function not inside of the spark
     // package. We track the last (shallowest) contiguous Spark method. This might be an RDD
@@ -721,12 +727,6 @@ private[spark] object Utils extends Logging {
     new CallSiteInfo(lastSparkMethod, firstUserFile, firstUserLine, firstUserClass)
   }
 
-  /** Returns a printable version of the call site info suitable for logs. */
-  def formatCallSiteInfo(callSiteInfo: CallSiteInfo = Utils.getCallSiteInfo) = {
-    "%s at %s:%s".format(callSiteInfo.lastSparkMethod, callSiteInfo.firstUserFile,
-      callSiteInfo.firstUserLine)
-  }
-
   /** Return a string containing part of a file from byte 'start' to 'end'. */
   def offsetBytes(path: String, start: Long, end: Long): String = {
     val file = new File(path)
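For reference, the toString added to CallSiteInfo above produces the "<method> at <file>:<line>" string that the test below parses back apart. A minimal standalone mirror (the class body is copied from the hunk above; the demo object and its values are hypothetical):

// Standalone mirror of the CallSiteInfo class from the hunk above.
class CallSiteInfo(val lastSparkMethod: String, val firstUserFile: String,
    val firstUserLine: Int, val firstUserClass: String) {

  /** Returns a printable version of the call site info suitable for logs. */
  override def toString =
    "%s at %s:%s".format(lastSparkMethod, firstUserFile, firstUserLine)
}

object CallSiteInfoDemo extends App {
  // Prints "makeRDD at MyApp.scala:42".
  println(new CallSiteInfo("makeRDD", "MyApp.scala", 42, "MyApp"))
}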
diff --git a/core/src/test/scala/org/apache/spark/SparkContextInfoSuite.scala b/core/src/test/scala/org/apache/spark/SparkContextInfoSuite.scala
index 5cb49d9a7f..cd3887dcc7 100644
--- a/core/src/test/scala/org/apache/spark/SparkContextInfoSuite.scala
+++ b/core/src/test/scala/org/apache/spark/SparkContextInfoSuite.scala
@@ -17,7 +17,7 @@
 
 package org.apache.spark
 
-import org.scalatest.FunSuite
+import org.scalatest.{Assertions, FunSuite}
 
 class SparkContextInfoSuite extends FunSuite with LocalSparkContext {
   test("getPersistentRDDs only returns RDDs that are marked as cached") {
@@ -56,4 +56,38 @@ class SparkContextInfoSuite extends FunSuite with LocalSparkContext {
     rdd.collect()
     assert(sc.getRDDStorageInfo.size === 1)
   }
+
+  test("call sites report correct locations") {
+    sc = new SparkContext("local", "test")
+    testPackage.runCallSiteTest(sc)
+  }
+}
+
+/** Call site must be outside of usual org.apache.spark packages (see Utils#SPARK_CLASS_REGEX). */
+package object testPackage extends Assertions {
+  private val CALL_SITE_REGEX = "(.+) at (.+):([0-9]+)".r
+
+  def runCallSiteTest(sc: SparkContext) {
+    val rdd = sc.makeRDD(Array(1, 2, 3, 4), 2)
+    val rddCreationSite = rdd.getCreationSite
+    val curCallSite = sc.getCallSite() // note: 2 lines after definition of "rdd"
+
+    val rddCreationLine = rddCreationSite match {
+      case CALL_SITE_REGEX(func, file, line) => {
+        assert(func === "makeRDD")
+        assert(file === "SparkContextInfoSuite.scala")
+        line.toInt
+      }
+      case _ => fail("Did not match expected call site format")
+    }
+
+    curCallSite match {
+      case CALL_SITE_REGEX(func, file, line) => {
+        assert(func === "getCallSite") // this is correct because we called it from outside of Spark
+        assert(file === "SparkContextInfoSuite.scala")
+        assert(line.toInt === rddCreationLine.toInt + 2)
+      }
+      case _ => fail("Did not match expected call site format")
+    }
+  }
 }
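The test's CALL_SITE_REGEX is the inverse of that format. A quick standalone round trip (values illustrative, not from the test run):

// Round-trip check: CALL_SITE_REGEX decomposes what CallSiteInfo.toString builds.
object CallSiteRegexDemo extends App {
  private val CALL_SITE_REGEX = "(.+) at (.+):([0-9]+)".r

  "makeRDD at SparkContextInfoSuite.scala:64" match {
    case CALL_SITE_REGEX(func, file, line) =>
      println(s"func=$func, file=$file, line=${line.toInt}")
    case other =>
      println(s"unexpected format: $other")
  }
}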