diff options
4 files changed, 34 insertions, 15 deletions
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlan.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlan.scala index b64352a9e0..64d89f238c 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlan.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlan.scala @@ -63,11 +63,6 @@ abstract class SparkPlan extends QueryPlan[SparkPlan] with Logging with Serializ false } - /** - * Whether the "prepare" method is called. - */ - private val prepareCalled = new AtomicBoolean(false) - /** Overridden make copy also propagates sqlContext to copied plan. */ override def makeCopy(newArgs: Array[AnyRef]): SparkPlan = { SQLContext.setActive(sqlContext) @@ -131,7 +126,7 @@ abstract class SparkPlan extends QueryPlan[SparkPlan] with Logging with Serializ * Execute a query after preparing the query and adding query plan information to created RDDs * for visualization. */ - private final def executeQuery[T](query: => T): T = { + protected final def executeQuery[T](query: => T): T = { RDDOperationScope.withScope(sparkContext, nodeName, false, true) { prepare() waitForSubqueries() @@ -165,7 +160,7 @@ abstract class SparkPlan extends QueryPlan[SparkPlan] with Logging with Serializ /** * Blocks the thread until all subqueries finish evaluation and update the results. */ - protected def waitForSubqueries(): Unit = { + protected def waitForSubqueries(): Unit = synchronized { // fill in the result of subqueries subqueryResults.foreach { case (e, futureResult) => val rows = ThreadUtils.awaitResult(futureResult, Duration.Inf) @@ -185,13 +180,22 @@ abstract class SparkPlan extends QueryPlan[SparkPlan] with Logging with Serializ } /** + * Whether the "prepare" method is called. + */ + private var prepared = false + + /** * Prepare a SparkPlan for execution. It's idempotent. 
*/ final def prepare(): Unit = { - if (prepareCalled.compareAndSet(false, true)) { - doPrepare() - prepareSubqueries() - children.foreach(_.prepare()) + // doPrepare() may depend on its children, so we should call prepare() on all the children first. + children.foreach(_.prepare()) + synchronized { + if (!prepared) { + prepareSubqueries() + doPrepare() + prepared = true + } } } @@ -202,6 +206,8 @@ abstract class SparkPlan extends QueryPlan[SparkPlan] with Logging with Serializ * * Note: the prepare method has already walked down the tree, so the implementation doesn't need * to call children's prepare methods. + * + * This will only be called once, protected by `this`. */ protected def doPrepare(): Unit = {} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/WholeStageCodegen.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/WholeStageCodegen.scala index 23b2eabd0c..944962b1c8 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/WholeStageCodegen.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/WholeStageCodegen.scala @@ -79,10 +79,9 @@ trait CodegenSupport extends SparkPlan { /** * Returns Java source code to process the rows from input RDD. 
*/ - final def produce(ctx: CodegenContext, parent: CodegenSupport): String = { + final def produce(ctx: CodegenContext, parent: CodegenSupport): String = executeQuery { this.parent = parent ctx.freshNamePrefix = variablePrefix - waitForSubqueries() s""" |/*** PRODUCE: ${toCommentSafeString(this.simpleString)} */ |${doProduce(ctx)} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/subquery.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/subquery.scala index 71b6a97852..c023cc573c 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/subquery.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/subquery.scala @@ -48,15 +48,21 @@ case class ScalarSubquery( override def toString: String = s"subquery#${exprId.id}" // the first column in first row from `query`. - private var result: Any = null + @volatile private var result: Any = null + @volatile private var updated: Boolean = false def updateResult(v: Any): Unit = { result = v + updated = true } - override def eval(input: InternalRow): Any = result + override def eval(input: InternalRow): Any = { + require(updated, s"$this has not finished") + result + } override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { + require(updated, s"$this has not finished") Literal.create(result, dataType).doGenCode(ctx, ev) } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/SubquerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/SubquerySuite.scala index d182495757..f9bada156b 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/SubquerySuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/SubquerySuite.scala @@ -123,6 +123,14 @@ class SubquerySuite extends QueryTest with SharedSQLContext { ) } + test("SPARK-14791: scalar subquery inside broadcast join") { + val df = sql("select a, sum(b) as s from l group by a having a > (select avg(a) from l)") + val expected = Row(3, 2.0, 3, 3.0) :: Row(6, null, 6, null) :: Nil + (1 to 
10).foreach { _ => + checkAnswer(r.join(df, $"c" === $"a"), expected) + } + } + test("EXISTS predicate subquery") { checkAnswer( sql("select * from l where exists (select * from r where l.a = r.c)"), |