diff options
Diffstat (limited to 'sql/core/src/main')
3 files changed, 26 insertions, 15 deletions
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlan.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlan.scala index b64352a9e0..64d89f238c 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlan.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlan.scala @@ -63,11 +63,6 @@ abstract class SparkPlan extends QueryPlan[SparkPlan] with Logging with Serializ false } - /** - * Whether the "prepare" method is called. - */ - private val prepareCalled = new AtomicBoolean(false) - /** Overridden make copy also propagates sqlContext to copied plan. */ override def makeCopy(newArgs: Array[AnyRef]): SparkPlan = { SQLContext.setActive(sqlContext) @@ -131,7 +126,7 @@ abstract class SparkPlan extends QueryPlan[SparkPlan] with Logging with Serializ * Execute a query after preparing the query and adding query plan information to created RDDs * for visualization. */ - private final def executeQuery[T](query: => T): T = { + protected final def executeQuery[T](query: => T): T = { RDDOperationScope.withScope(sparkContext, nodeName, false, true) { prepare() waitForSubqueries() @@ -165,7 +160,7 @@ abstract class SparkPlan extends QueryPlan[SparkPlan] with Logging with Serializ /** * Blocks the thread until all subqueries finish evaluation and update the results. */ - protected def waitForSubqueries(): Unit = { + protected def waitForSubqueries(): Unit = synchronized { // fill in the result of subqueries subqueryResults.foreach { case (e, futureResult) => val rows = ThreadUtils.awaitResult(futureResult, Duration.Inf) @@ -185,13 +180,22 @@ abstract class SparkPlan extends QueryPlan[SparkPlan] with Logging with Serializ } /** + * Whether the "prepare" method is called. + */ + private var prepared = false + + /** * Prepare a SparkPlan for execution. It's idempotent. */ final def prepare(): Unit = { - if (prepareCalled.compareAndSet(false, true)) { - doPrepare() - prepareSubqueries() - children.foreach(_.prepare()) + // doPrepare() may depend on it's children, we should call prepare() on all the children first. + children.foreach(_.prepare()) + synchronized { + if (!prepared) { + prepareSubqueries() + doPrepare() + prepared = true + } } } @@ -202,6 +206,8 @@ abstract class SparkPlan extends QueryPlan[SparkPlan] with Logging with Serializ * * Note: the prepare method has already walked down the tree, so the implementation doesn't need * to call children's prepare methods. + * + * This will only be called once, protected by `this`. */ protected def doPrepare(): Unit = {} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/WholeStageCodegen.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/WholeStageCodegen.scala index 23b2eabd0c..944962b1c8 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/WholeStageCodegen.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/WholeStageCodegen.scala @@ -79,10 +79,9 @@ trait CodegenSupport extends SparkPlan { /** * Returns Java source code to process the rows from input RDD. */ - final def produce(ctx: CodegenContext, parent: CodegenSupport): String = { + final def produce(ctx: CodegenContext, parent: CodegenSupport): String = executeQuery { this.parent = parent ctx.freshNamePrefix = variablePrefix - waitForSubqueries() s""" |/*** PRODUCE: ${toCommentSafeString(this.simpleString)} */ |${doProduce(ctx)} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/subquery.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/subquery.scala index 71b6a97852..c023cc573c 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/subquery.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/subquery.scala @@ -48,15 +48,21 @@ case class ScalarSubquery( override def toString: String = s"subquery#${exprId.id}" // the first column in first row from `query`. - private var result: Any = null + @volatile private var result: Any = null + @volatile private var updated: Boolean = false def updateResult(v: Any): Unit = { result = v + updated = true } - override def eval(input: InternalRow): Any = result + override def eval(input: InternalRow): Any = { + require(updated, s"$this has not finished") + result + } override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { + require(updated, s"$this has not finished") Literal.create(result, dataType).doGenCode(ctx, ev) } } |