diff options
author | Ala Luszczak <ala@databricks.com> | 2017-02-15 17:06:04 +0100 |
---|---|---|
committer | Reynold Xin <rxin@databricks.com> | 2017-02-15 17:06:04 +0100 |
commit | b55563c17ec67f56017fa6bda5a18310c38dbefb (patch) | |
tree | fce28a24d93367d24b8ac8facded0e3534bbbe29 /sql | |
parent | 3973403d5d90a48e3a995159680239ba5240e30c (diff) | |
download | spark-b55563c17ec67f56017fa6bda5a18310c38dbefb.tar.gz spark-b55563c17ec67f56017fa6bda5a18310c38dbefb.tar.bz2 spark-b55563c17ec67f56017fa6bda5a18310c38dbefb.zip |
[SPARK-19607] Finding QueryExecution that matches provided executionId
## What changes were proposed in this pull request?
Implementing a mapping between executionId and corresponding QueryExecution in SQLExecution.
## How was this patch tested?
Adds a unit test.
Author: Ala Luszczak <ala@databricks.com>
Closes #16940 from ala/execution-id.
Diffstat (limited to 'sql')
-rw-r--r-- | sql/core/src/main/scala/org/apache/spark/sql/execution/SQLExecution.scala | 9 | ||||
-rw-r--r-- | sql/core/src/test/scala/org/apache/spark/sql/execution/SQLExecutionSuite.scala | 32 |
2 files changed, 41 insertions, 0 deletions
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SQLExecution.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SQLExecution.scala index ec07aab359..be35916e34 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SQLExecution.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SQLExecution.scala @@ -17,6 +17,7 @@ package org.apache.spark.sql.execution +import java.util.concurrent.ConcurrentHashMap import java.util.concurrent.atomic.AtomicLong import org.apache.spark.SparkContext @@ -32,6 +33,12 @@ object SQLExecution { private def nextExecutionId: Long = _nextExecutionId.getAndIncrement + private val executionIdToQueryExecution = new ConcurrentHashMap[Long, QueryExecution]() + + def getQueryExecution(executionId: Long): QueryExecution = { + executionIdToQueryExecution.get(executionId) + } + /** * Wrap an action that will execute "queryExecution" to track all Spark jobs in the body so that * we can connect them with an execution. @@ -44,6 +51,7 @@ object SQLExecution { if (oldExecutionId == null) { val executionId = SQLExecution.nextExecutionId sc.setLocalProperty(EXECUTION_ID_KEY, executionId.toString) + executionIdToQueryExecution.put(executionId, queryExecution) val r = try { // sparkContext.getCallSite() would first try to pick up any call site that was previously // set, then fall back to Utils.getCallSite(); call Utils.getCallSite() directly on @@ -60,6 +68,7 @@ object SQLExecution { executionId, System.currentTimeMillis())) } } finally { + executionIdToQueryExecution.remove(executionId) sc.setLocalProperty(EXECUTION_ID_KEY, null) } r diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/SQLExecutionSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/SQLExecutionSuite.scala index ad41111bec..b059706783 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/SQLExecutionSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/SQLExecutionSuite.scala @@ -20,6 +20,7 @@ package org.apache.spark.sql.execution import java.util.Properties import org.apache.spark.{SparkConf, SparkContext, SparkFunSuite} +import org.apache.spark.scheduler.{SparkListener, SparkListenerJobStart} import org.apache.spark.sql.SparkSession class SQLExecutionSuite extends SparkFunSuite { @@ -102,6 +103,33 @@ class SQLExecutionSuite extends SparkFunSuite { } } + + test("Finding QueryExecution for given executionId") { + val spark = SparkSession.builder.master("local[*]").appName("test").getOrCreate() + import spark.implicits._ + + var queryExecution: QueryExecution = null + + spark.sparkContext.addSparkListener(new SparkListener { + override def onJobStart(jobStart: SparkListenerJobStart): Unit = { + val executionIdStr = jobStart.properties.getProperty(SQLExecution.EXECUTION_ID_KEY) + if (executionIdStr != null) { + queryExecution = SQLExecution.getQueryExecution(executionIdStr.toLong) + } + SQLExecutionSuite.canProgress = true + } + }) + + val df = spark.range(1).map { x => + while (!SQLExecutionSuite.canProgress) { + Thread.sleep(1) + } + x + } + df.collect() + + assert(df.queryExecution === queryExecution) + } } /** @@ -114,3 +142,7 @@ private class BadSparkContext(conf: SparkConf) extends SparkContext(conf) { override protected def initialValue(): Properties = new Properties() } } + +object SQLExecutionSuite { + @volatile var canProgress = false +} |