aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlan.scala28
-rw-r--r--sql/core/src/main/scala/org/apache/spark/sql/execution/WholeStageCodegen.scala3
-rw-r--r--sql/core/src/main/scala/org/apache/spark/sql/execution/subquery.scala10
-rw-r--r--sql/core/src/test/scala/org/apache/spark/sql/SubquerySuite.scala8
4 files changed, 34 insertions, 15 deletions
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlan.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlan.scala
index b64352a9e0..64d89f238c 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlan.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlan.scala
@@ -63,11 +63,6 @@ abstract class SparkPlan extends QueryPlan[SparkPlan] with Logging with Serializ
false
}
- /**
- * Whether the "prepare" method is called.
- */
- private val prepareCalled = new AtomicBoolean(false)
-
/** Overridden make copy also propagates sqlContext to copied plan. */
override def makeCopy(newArgs: Array[AnyRef]): SparkPlan = {
SQLContext.setActive(sqlContext)
@@ -131,7 +126,7 @@ abstract class SparkPlan extends QueryPlan[SparkPlan] with Logging with Serializ
* Execute a query after preparing the query and adding query plan information to created RDDs
* for visualization.
*/
- private final def executeQuery[T](query: => T): T = {
+ protected final def executeQuery[T](query: => T): T = {
RDDOperationScope.withScope(sparkContext, nodeName, false, true) {
prepare()
waitForSubqueries()
@@ -165,7 +160,7 @@ abstract class SparkPlan extends QueryPlan[SparkPlan] with Logging with Serializ
/**
* Blocks the thread until all subqueries finish evaluation and update the results.
*/
- protected def waitForSubqueries(): Unit = {
+ protected def waitForSubqueries(): Unit = synchronized {
// fill in the result of subqueries
subqueryResults.foreach { case (e, futureResult) =>
val rows = ThreadUtils.awaitResult(futureResult, Duration.Inf)
@@ -185,13 +180,22 @@ abstract class SparkPlan extends QueryPlan[SparkPlan] with Logging with Serializ
}
/**
+ * Whether the "prepare" method is called.
+ */
+ private var prepared = false
+
+ /**
* Prepare a SparkPlan for execution. It's idempotent.
*/
final def prepare(): Unit = {
- if (prepareCalled.compareAndSet(false, true)) {
- doPrepare()
- prepareSubqueries()
- children.foreach(_.prepare())
    // doPrepare() may depend on its children, so we should call prepare() on all the children first.
+ children.foreach(_.prepare())
+ synchronized {
+ if (!prepared) {
+ prepareSubqueries()
+ doPrepare()
+ prepared = true
+ }
}
}
@@ -202,6 +206,8 @@ abstract class SparkPlan extends QueryPlan[SparkPlan] with Logging with Serializ
*
* Note: the prepare method has already walked down the tree, so the implementation doesn't need
* to call children's prepare methods.
+ *
+ * This will only be called once, protected by `this`.
*/
protected def doPrepare(): Unit = {}
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/WholeStageCodegen.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/WholeStageCodegen.scala
index 23b2eabd0c..944962b1c8 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/WholeStageCodegen.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/WholeStageCodegen.scala
@@ -79,10 +79,9 @@ trait CodegenSupport extends SparkPlan {
/**
* Returns Java source code to process the rows from input RDD.
*/
- final def produce(ctx: CodegenContext, parent: CodegenSupport): String = {
+ final def produce(ctx: CodegenContext, parent: CodegenSupport): String = executeQuery {
this.parent = parent
ctx.freshNamePrefix = variablePrefix
- waitForSubqueries()
s"""
|/*** PRODUCE: ${toCommentSafeString(this.simpleString)} */
|${doProduce(ctx)}
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/subquery.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/subquery.scala
index 71b6a97852..c023cc573c 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/subquery.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/subquery.scala
@@ -48,15 +48,21 @@ case class ScalarSubquery(
override def toString: String = s"subquery#${exprId.id}"
// the first column in first row from `query`.
- private var result: Any = null
+ @volatile private var result: Any = null
+ @volatile private var updated: Boolean = false
def updateResult(v: Any): Unit = {
result = v
+ updated = true
}
- override def eval(input: InternalRow): Any = result
+ override def eval(input: InternalRow): Any = {
+ require(updated, s"$this has not finished")
+ result
+ }
override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
+ require(updated, s"$this has not finished")
Literal.create(result, dataType).doGenCode(ctx, ev)
}
}
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/SubquerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/SubquerySuite.scala
index d182495757..f9bada156b 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/SubquerySuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/SubquerySuite.scala
@@ -123,6 +123,14 @@ class SubquerySuite extends QueryTest with SharedSQLContext {
)
}
+ test("SPARK-14791: scalar subquery inside broadcast join") {
+ val df = sql("select a, sum(b) as s from l group by a having a > (select avg(a) from l)")
+ val expected = Row(3, 2.0, 3, 3.0) :: Row(6, null, 6, null) :: Nil
+ (1 to 10).foreach { _ =>
+ checkAnswer(r.join(df, $"c" === $"a"), expected)
+ }
+ }
+
test("EXISTS predicate subquery") {
checkAnswer(
sql("select * from l where exists (select * from r where l.a = r.c)"),