aboutsummaryrefslogtreecommitdiff
path: root/sql
diff options
context:
space:
mode:
authorDavies Liu <davies@databricks.com>2016-03-07 20:09:08 -0800
committerDavies Liu <davies.liu@gmail.com>2016-03-07 20:09:08 -0800
commit25bba58d160d0d24e40db1ca595200a52db922ed (patch)
tree9822c6c2f20af2a2faa68fd6d7e1d65921b2b877 /sql
parentda7bfac488b2a25c591986fe5f906b5c98dc34ea (diff)
downloadspark-25bba58d160d0d24e40db1ca595200a52db922ed.tar.gz
spark-25bba58d160d0d24e40db1ca595200a52db922ed.tar.bz2
spark-25bba58d160d0d24e40db1ca595200a52db922ed.zip
[SPARK-13404] [SQL] Create variables for input row when it's actually used
## What changes were proposed in this pull request? This PR change the way how we generate the code for the output variables passing from a plan to it's parent. Right now, they are generated before call consume() of it's parent. It's not efficient, if the parent is a Filter or Join, which could filter out most the rows, the time to access some of the columns that are not used by the Filter or Join are wasted. This PR try to improve this by defering the access of columns until they are actually used by a plan. After this PR, a plan does not need to generate code to evaluate the variables for output, just passing the ExprCode to its parent by `consume()`. In `parent.consumeChild()`, it will check the output from child and `usedInputs`, generate the code for those columns that is part of `usedInputs` before calling `doConsume()`. This PR also change the `if` from ``` if (cond) { xxx } ``` to ``` if (!cond) continue; xxx ``` The new one could help to reduce the nested indents for multiple levels of Filter and BroadcastHashJoin. It also added some comments for operators. ## How was the this patch tested? Unit tests. Manually ran TPCDS Q55, this PR improve the performance about 30% (scale=10, from 2.56s to 1.96s) Author: Davies Liu <davies@databricks.com> Closes #11274 from davies/gen_defer.
Diffstat (limited to 'sql')
-rw-r--r--sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/BoundAttribute.scala7
-rw-r--r--sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala2
-rw-r--r--sql/core/src/main/scala/org/apache/spark/sql/execution/ExistingRDD.scala41
-rw-r--r--sql/core/src/main/scala/org/apache/spark/sql/execution/Expand.scala4
-rw-r--r--sql/core/src/main/scala/org/apache/spark/sql/execution/WholeStageCodegen.scala96
-rw-r--r--sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregate.scala50
-rw-r--r--sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala32
-rw-r--r--sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastHashJoin.scala84
-rw-r--r--sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeJoin.scala63
9 files changed, 224 insertions, 155 deletions
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/BoundAttribute.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/BoundAttribute.scala
index 4727ff1885..72fe065459 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/BoundAttribute.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/BoundAttribute.scala
@@ -62,9 +62,10 @@ case class BoundReference(ordinal: Int, dataType: DataType, nullable: Boolean)
val javaType = ctx.javaType(dataType)
val value = ctx.getValue(ctx.INPUT_ROW, dataType, ordinal.toString)
if (ctx.currentVars != null && ctx.currentVars(ordinal) != null) {
- ev.isNull = ctx.currentVars(ordinal).isNull
- ev.value = ctx.currentVars(ordinal).value
- ""
+ val oev = ctx.currentVars(ordinal)
+ ev.isNull = oev.isNull
+ ev.value = oev.value
+ oev.code
} else if (nullable) {
s"""
boolean ${ev.isNull} = ${ctx.INPUT_ROW}.isNullAt($ordinal);
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala
index 63e19564dd..c4265a7539 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala
@@ -37,6 +37,8 @@ import org.apache.spark.util.Utils
* Java source for evaluating an [[Expression]] given a [[InternalRow]] of input.
*
* @param code The sequence of statements required to evaluate the expression.
+ * It should be empty string, if `isNull` and `value` are already existed, or no code
+ * needed to evaluate them (literals).
* @param isNull A term that holds a boolean value representing whether the expression evaluated
* to null.
* @param value A term for a (possibly primitive) value of the result of the evaluation. Not
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/ExistingRDD.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/ExistingRDD.scala
index 4ad07508ca..3662ed74d2 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/ExistingRDD.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/ExistingRDD.scala
@@ -151,9 +151,6 @@ private[sql] case class PhysicalRDD(
val exprs = output.zipWithIndex.map(x => new BoundReference(x._2, x._1.dataType, true))
val row = ctx.freshName("row")
val numOutputRows = metricTerm(ctx, "numOutputRows")
- ctx.INPUT_ROW = row
- ctx.currentVars = null
- val columns = exprs.map(_.gen(ctx))
// The input RDD can either return (all) ColumnarBatches or InternalRows. We determine this
// by looking at the first value of the RDD and then calling the function which will process
@@ -161,7 +158,9 @@ private[sql] case class PhysicalRDD(
// TODO: The abstractions between this class and SqlNewHadoopRDD makes it difficult to know
// here which path to use. Fix this.
-
+ ctx.INPUT_ROW = row
+ ctx.currentVars = null
+ val columns1 = exprs.map(_.gen(ctx))
val scanBatches = ctx.freshName("processBatches")
ctx.addNewFunction(scanBatches,
s"""
@@ -170,12 +169,11 @@ private[sql] case class PhysicalRDD(
| int numRows = $batch.numRows();
| if ($idx == 0) $numOutputRows.add(numRows);
|
- | while ($idx < numRows) {
+ | while (!shouldStop() && $idx < numRows) {
| InternalRow $row = $batch.getRow($idx++);
- | ${columns.map(_.code).mkString("\n").trim}
- | ${consume(ctx, columns).trim}
- | if (shouldStop()) return;
+ | ${consume(ctx, columns1).trim}
| }
+ | if (shouldStop()) return;
|
| if (!$input.hasNext()) {
| $batch = null;
@@ -186,30 +184,37 @@ private[sql] case class PhysicalRDD(
| }
| }""".stripMargin)
+ ctx.INPUT_ROW = row
+ ctx.currentVars = null
+ val columns2 = exprs.map(_.gen(ctx))
+ val inputRow = if (isUnsafeRow) row else null
val scanRows = ctx.freshName("processRows")
ctx.addNewFunction(scanRows,
s"""
| private void $scanRows(InternalRow $row) throws java.io.IOException {
- | while (true) {
+ | boolean firstRow = true;
+ | while (!shouldStop() && (firstRow || $input.hasNext())) {
+ | if (firstRow) {
+ | firstRow = false;
+ | } else {
+ | $row = (InternalRow) $input.next();
+ | }
| $numOutputRows.add(1);
- | ${columns.map(_.code).mkString("\n").trim}
- | ${consume(ctx, columns).trim}
- | if (shouldStop()) return;
- | if (!$input.hasNext()) break;
- | $row = (InternalRow)$input.next();
+ | ${consume(ctx, columns2, inputRow).trim}
| }
| }""".stripMargin)
+ val value = ctx.freshName("value")
s"""
| if ($batch != null) {
| $scanBatches();
| } else if ($input.hasNext()) {
- | Object value = $input.next();
- | if (value instanceof $columnarBatchClz) {
- | $batch = ($columnarBatchClz)value;
+ | Object $value = $input.next();
+ | if ($value instanceof $columnarBatchClz) {
+ | $batch = ($columnarBatchClz)$value;
| $scanBatches();
| } else {
- | $scanRows((InternalRow)value);
+ | $scanRows((InternalRow) $value);
| }
| }
""".stripMargin
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/Expand.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/Expand.scala
index 12998a38f5..524285bc87 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/Expand.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/Expand.scala
@@ -185,8 +185,10 @@ case class Expand(
val numOutput = metricTerm(ctx, "numOutputRows")
val i = ctx.freshName("i")
+ // these column have to declared before the loop.
+ val evaluate = evaluateVariables(outputColumns)
s"""
- |${outputColumns.map(_.code).mkString("\n").trim}
+ |$evaluate
|for (int $i = 0; $i < ${projections.length}; $i ++) {
| switch ($i) {
| ${cases.mkString("\n").trim}
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/WholeStageCodegen.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/WholeStageCodegen.scala
index 6d231bf74a..45578d50bf 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/WholeStageCodegen.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/WholeStageCodegen.scala
@@ -81,11 +81,14 @@ trait CodegenSupport extends SparkPlan {
this.parent = parent
ctx.freshNamePrefix = variablePrefix
waitForSubqueries()
- doProduce(ctx)
+ s"""
+ |/*** PRODUCE: ${toCommentSafeString(this.simpleString)} */
+ |${doProduce(ctx)}
+ """.stripMargin
}
/**
- * Generate the Java source code to process, should be overrided by subclass to support codegen.
+ * Generate the Java source code to process, should be overridden by subclass to support codegen.
*
* doProduce() usually generate the framework, for example, aggregation could generate this:
*
@@ -94,11 +97,11 @@ trait CodegenSupport extends SparkPlan {
* # call child.produce()
* initialized = true;
* }
- * while (hashmap.hasNext()) {
+ * while (!shouldStop() && hashmap.hasNext()) {
* row = hashmap.next();
* # build the aggregation results
- * # create varialbles for results
- * # call consume(), wich will call parent.doConsume()
+ * # create variables for results
+ * # call consume(), which will call parent.doConsume()
* }
*/
protected def doProduce(ctx: CodegenContext): String
@@ -114,27 +117,71 @@ trait CodegenSupport extends SparkPlan {
}
/**
- * Consume the columns generated from it's child, call doConsume() or emit the rows.
+ * Returns source code to evaluate all the variables, and clear the code of them, to prevent
+ * them to be evaluated twice.
+ */
+ protected def evaluateVariables(variables: Seq[ExprCode]): String = {
+ val evaluate = variables.filter(_.code != "").map(_.code.trim).mkString("\n")
+ variables.foreach(_.code = "")
+ evaluate
+ }
+
+ /**
+ * Returns source code to evaluate the variables for required attributes, and clear the code
+ * of evaluated variables, to prevent them to be evaluated twice..
*/
+ protected def evaluateRequiredVariables(
+ attributes: Seq[Attribute],
+ variables: Seq[ExprCode],
+ required: AttributeSet): String = {
+ var evaluateVars = ""
+ variables.zipWithIndex.foreach { case (ev, i) =>
+ if (ev.code != "" && required.contains(attributes(i))) {
+ evaluateVars += ev.code.trim + "\n"
+ ev.code = ""
+ }
+ }
+ evaluateVars
+ }
+
+ /**
+ * The subset of inputSet those should be evaluated before this plan.
+ *
+ * We will use this to insert some code to access those columns that are actually used by current
+ * plan before calling doConsume().
+ */
+ def usedInputs: AttributeSet = references
+
+ /**
+ * Consume the columns generated from its child, call doConsume() or emit the rows.
+ *
+ * An operator could generate variables for the output, or a row, either one could be null.
+ *
+ * If the row is not null, we create variables to access the columns that are actually used by
+ * current plan before calling doConsume().
+ */
def consumeChild(
ctx: CodegenContext,
child: SparkPlan,
input: Seq[ExprCode],
row: String = null): String = {
ctx.freshNamePrefix = variablePrefix
- if (row != null) {
- ctx.currentVars = null
- ctx.INPUT_ROW = row
- val evals = child.output.zipWithIndex.map { case (attr, i) =>
- BoundReference(i, attr.dataType, attr.nullable).gen(ctx)
+ val inputVars =
+ if (row != null) {
+ ctx.currentVars = null
+ ctx.INPUT_ROW = row
+ child.output.zipWithIndex.map { case (attr, i) =>
+ BoundReference(i, attr.dataType, attr.nullable).gen(ctx)
+ }
+ } else {
+ input
}
- s"""
- | ${evals.map(_.code).mkString("\n")}
- | ${doConsume(ctx, evals)}
- """.stripMargin
- } else {
- doConsume(ctx, input)
- }
+ s"""
+ |
+ |/*** CONSUME: ${toCommentSafeString(this.simpleString)} */
+ |${evaluateRequiredVariables(child.output, inputVars, usedInputs)}
+ |${doConsume(ctx, inputVars)}
+ """.stripMargin
}
/**
@@ -145,9 +192,8 @@ trait CodegenSupport extends SparkPlan {
* For example, Filter will generate the code like this:
*
* # code to evaluate the predicate expression, result is isNull1 and value2
- * if (isNull1 || value2) {
- * # call consume(), which will call parent.doConsume()
- * }
+ * if (isNull1 || !value2) continue;
+ * # call consume(), which will call parent.doConsume()
*/
protected def doConsume(ctx: CodegenContext, input: Seq[ExprCode]): String = {
throw new UnsupportedOperationException
@@ -190,13 +236,9 @@ case class InputAdapter(child: SparkPlan) extends UnaryNode with CodegenSupport
ctx.currentVars = null
val columns = exprs.map(_.gen(ctx))
s"""
- | while ($input.hasNext()) {
+ | while (!shouldStop() && $input.hasNext()) {
| InternalRow $row = (InternalRow) $input.next();
- | ${columns.map(_.code).mkString("\n").trim}
| ${consume(ctx, columns).trim}
- | if (shouldStop()) {
- | return;
- | }
| }
""".stripMargin
}
@@ -332,10 +374,12 @@ case class WholeStageCodegen(child: SparkPlan) extends UnaryNode with CodegenSup
val colExprs = output.zipWithIndex.map { case (attr, i) =>
BoundReference(i, attr.dataType, attr.nullable)
}
+ val evaluateInputs = evaluateVariables(input)
// generate the code to create a UnsafeRow
ctx.currentVars = input
val code = GenerateUnsafeProjection.createCode(ctx, colExprs, false)
s"""
+ |$evaluateInputs
|${code.code.trim}
|append(${code.value}.copy());
""".stripMargin.trim
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregate.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregate.scala
index a46722963a..f07add83d5 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregate.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregate.scala
@@ -116,6 +116,8 @@ case class TungstenAggregate(
// all the mode of aggregate expressions
private val modes = aggregateExpressions.map(_.mode).distinct
+ override def usedInputs: AttributeSet = inputSet
+
override def supportCodegen: Boolean = {
// ImperativeAggregate is not supported right now
!aggregateExpressions.exists(_.aggregateFunction.isInstanceOf[ImperativeAggregate])
@@ -164,23 +166,24 @@ case class TungstenAggregate(
""".stripMargin
ExprCode(ev.code + initVars, isNull, value)
}
+ val initBufVar = evaluateVariables(bufVars)
// generate variables for output
- val bufferAttrs = functions.flatMap(_.aggBufferAttributes)
val (resultVars, genResult) = if (modes.contains(Final) || modes.contains(Complete)) {
// evaluate aggregate results
ctx.currentVars = bufVars
val aggResults = functions.map(_.evaluateExpression).map { e =>
- BindReferences.bindReference(e, bufferAttrs).gen(ctx)
+ BindReferences.bindReference(e, aggregateBufferAttributes).gen(ctx)
}
+ val evaluateAggResults = evaluateVariables(aggResults)
// evaluate result expressions
ctx.currentVars = aggResults
val resultVars = resultExpressions.map { e =>
BindReferences.bindReference(e, aggregateAttributes).gen(ctx)
}
(resultVars, s"""
- | ${aggResults.map(_.code).mkString("\n")}
- | ${resultVars.map(_.code).mkString("\n")}
+ |$evaluateAggResults
+ |${evaluateVariables(resultVars)}
""".stripMargin)
} else if (modes.contains(Partial) || modes.contains(PartialMerge)) {
// output the aggregate buffer directly
@@ -188,7 +191,7 @@ case class TungstenAggregate(
} else {
// no aggregate function, the result should be literals
val resultVars = resultExpressions.map(_.gen(ctx))
- (resultVars, resultVars.map(_.code).mkString("\n"))
+ (resultVars, evaluateVariables(resultVars))
}
val doAgg = ctx.freshName("doAggregateWithoutKey")
@@ -196,7 +199,7 @@ case class TungstenAggregate(
s"""
| private void $doAgg() throws java.io.IOException {
| // initialize aggregation buffer
- | ${bufVars.map(_.code).mkString("\n")}
+ | $initBufVar
|
| ${child.asInstanceOf[CodegenSupport].produce(ctx, this)}
| }
@@ -204,7 +207,7 @@ case class TungstenAggregate(
val numOutput = metricTerm(ctx, "numOutputRows")
s"""
- | if (!$initAgg) {
+ | while (!$initAgg) {
| $initAgg = true;
| $doAgg();
|
@@ -241,7 +244,7 @@ case class TungstenAggregate(
}
s"""
| // do aggregate
- | ${aggVals.map(_.code).mkString("\n").trim}
+ | ${evaluateVariables(aggVals)}
| // update aggregation buffer
| ${updates.mkString("\n").trim}
""".stripMargin
@@ -252,8 +255,7 @@ case class TungstenAggregate(
private val declFunctions = aggregateExpressions.map(_.aggregateFunction)
.filter(_.isInstanceOf[DeclarativeAggregate])
.map(_.asInstanceOf[DeclarativeAggregate])
- private val bufferAttributes = declFunctions.flatMap(_.aggBufferAttributes)
- private val bufferSchema = StructType.fromAttributes(bufferAttributes)
+ private val bufferSchema = StructType.fromAttributes(aggregateBufferAttributes)
// The name for HashMap
private var hashMapTerm: String = _
@@ -318,7 +320,7 @@ case class TungstenAggregate(
val mergeExpr = declFunctions.flatMap(_.mergeExpressions)
val mergeProjection = newMutableProjection(
mergeExpr,
- bufferAttributes ++ declFunctions.flatMap(_.inputAggBufferAttributes),
+ aggregateBufferAttributes ++ declFunctions.flatMap(_.inputAggBufferAttributes),
subexpressionEliminationEnabled)()
val joinedRow = new JoinedRow()
@@ -380,15 +382,18 @@ case class TungstenAggregate(
val keyVars = groupingExpressions.zipWithIndex.map { case (e, i) =>
BoundReference(i, e.dataType, e.nullable).gen(ctx)
}
+ val evaluateKeyVars = evaluateVariables(keyVars)
ctx.INPUT_ROW = bufferTerm
- val bufferVars = bufferAttributes.zipWithIndex.map { case (e, i) =>
+ val bufferVars = aggregateBufferAttributes.zipWithIndex.map { case (e, i) =>
BoundReference(i, e.dataType, e.nullable).gen(ctx)
}
+ val evaluateBufferVars = evaluateVariables(bufferVars)
// evaluate the aggregation result
ctx.currentVars = bufferVars
val aggResults = declFunctions.map(_.evaluateExpression).map { e =>
- BindReferences.bindReference(e, bufferAttributes).gen(ctx)
+ BindReferences.bindReference(e, aggregateBufferAttributes).gen(ctx)
}
+ val evaluateAggResults = evaluateVariables(aggResults)
// generate the final result
ctx.currentVars = keyVars ++ aggResults
val inputAttrs = groupingAttributes ++ aggregateAttributes
@@ -396,11 +401,9 @@ case class TungstenAggregate(
BindReferences.bindReference(e, inputAttrs).gen(ctx)
}
s"""
- ${keyVars.map(_.code).mkString("\n")}
- ${bufferVars.map(_.code).mkString("\n")}
- ${aggResults.map(_.code).mkString("\n")}
- ${resultVars.map(_.code).mkString("\n")}
-
+ $evaluateKeyVars
+ $evaluateBufferVars
+ $evaluateAggResults
${consume(ctx, resultVars)}
"""
@@ -422,10 +425,7 @@ case class TungstenAggregate(
val eval = resultExpressions.map{ e =>
BindReferences.bindReference(e, groupingAttributes).gen(ctx)
}
- s"""
- ${eval.map(_.code).mkString("\n")}
- ${consume(ctx, eval)}
- """
+ consume(ctx, eval)
}
}
@@ -508,8 +508,8 @@ case class TungstenAggregate(
ctx.currentVars = input
val hashEval = BindReferences.bindReference(hashExpr, child.output).gen(ctx)
- val inputAttr = bufferAttributes ++ child.output
- ctx.currentVars = new Array[ExprCode](bufferAttributes.length) ++ input
+ val inputAttr = aggregateBufferAttributes ++ child.output
+ ctx.currentVars = new Array[ExprCode](aggregateBufferAttributes.length) ++ input
ctx.INPUT_ROW = buffer
// TODO: support subexpression elimination
val evals = updateExpr.map(BindReferences.bindReference(_, inputAttr).gen(ctx))
@@ -557,7 +557,7 @@ case class TungstenAggregate(
$incCounter
// evaluate aggregate function
- ${evals.map(_.code).mkString("\n").trim}
+ ${evaluateVariables(evals)}
// update aggregate buffer
${updates.mkString("\n").trim}
"""
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala
index b2f443c0e9..4a9e736f7a 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala
@@ -39,15 +39,26 @@ case class Project(projectList: Seq[NamedExpression], child: SparkPlan)
child.asInstanceOf[CodegenSupport].produce(ctx, this)
}
+ override def usedInputs: AttributeSet = {
+ // only the attributes those are used at least twice should be evaluated before this plan,
+ // otherwise we could defer the evaluation until output attribute is actually used.
+ val usedExprIds = projectList.flatMap(_.collect {
+ case a: Attribute => a.exprId
+ })
+ val usedMoreThanOnce = usedExprIds.groupBy(id => id).filter(_._2.size > 1).keySet
+ references.filter(a => usedMoreThanOnce.contains(a.exprId))
+ }
+
override def doConsume(ctx: CodegenContext, input: Seq[ExprCode]): String = {
val exprs = projectList.map(x =>
ExpressionCanonicalizer.execute(BindReferences.bindReference(x, child.output)))
ctx.currentVars = input
- val output = exprs.map(_.gen(ctx))
+ val resultVars = exprs.map(_.gen(ctx))
+ // Evaluation of non-deterministic expressions can't be deferred.
+ val nonDeterministicAttrs = projectList.filterNot(_.deterministic).map(_.toAttribute)
s"""
- | ${output.map(_.code).mkString("\n")}
- |
- | ${consume(ctx, output)}
+ |${evaluateRequiredVariables(output, resultVars, AttributeSet(nonDeterministicAttrs))}
+ |${consume(ctx, resultVars)}
""".stripMargin
}
@@ -89,11 +100,10 @@ case class Filter(condition: Expression, child: SparkPlan) extends UnaryNode wit
s""
}
s"""
- | ${eval.code}
- | if ($nullCheck ${eval.value}) {
- | $numOutput.add(1);
- | ${consume(ctx, ctx.currentVars)}
- | }
+ |${eval.code}
+ |if (!($nullCheck ${eval.value})) continue;
+ |$numOutput.add(1);
+ |${consume(ctx, ctx.currentVars)}
""".stripMargin
}
@@ -228,15 +238,13 @@ case class Range(
| }
| }
|
- | while (!$overflow && $checkEnd) {
+ | while (!$overflow && $checkEnd && !shouldStop()) {
| long $value = $number;
| $number += ${step}L;
| if ($number < $value ^ ${step}L < 0) {
| $overflow = true;
| }
| ${consume(ctx, Seq(ev))}
- |
- | if (shouldStop()) return;
| }
""".stripMargin
}
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastHashJoin.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastHashJoin.scala
index 6699dbafe7..c52662a61e 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastHashJoin.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastHashJoin.scala
@@ -190,40 +190,38 @@ case class BroadcastHashJoin(
val (keyEv, anyNull) = genStreamSideJoinKey(ctx, input)
val matched = ctx.freshName("matched")
val buildVars = genBuildSideVars(ctx, matched)
- val resultVars = buildSide match {
- case BuildLeft => buildVars ++ input
- case BuildRight => input ++ buildVars
- }
val numOutput = metricTerm(ctx, "numOutputRows")
- val outputCode = if (condition.isDefined) {
+ val checkCondition = if (condition.isDefined) {
+ val expr = condition.get
+ // evaluate the variables from build side that used by condition
+ val eval = evaluateRequiredVariables(buildPlan.output, buildVars, expr.references)
// filter the output via condition
- ctx.currentVars = resultVars
- val ev = BindReferences.bindReference(condition.get, this.output).gen(ctx)
+ ctx.currentVars = input ++ buildVars
+ val ev = BindReferences.bindReference(expr, streamedPlan.output ++ buildPlan.output).gen(ctx)
s"""
+ |$eval
|${ev.code}
- |if (!${ev.isNull} && ${ev.value}) {
- | $numOutput.add(1);
- | ${consume(ctx, resultVars)}
- |}
+ |if (${ev.isNull} || !${ev.value}) continue;
""".stripMargin
} else {
- s"""
- |$numOutput.add(1);
- |${consume(ctx, resultVars)}
- """.stripMargin
+ ""
}
+ val resultVars = buildSide match {
+ case BuildLeft => buildVars ++ input
+ case BuildRight => input ++ buildVars
+ }
if (broadcastRelation.value.isInstanceOf[UniqueHashedRelation]) {
s"""
|// generate join key for stream side
|${keyEv.code}
|// find matches from HashedRelation
|UnsafeRow $matched = $anyNull ? null: (UnsafeRow)$relationTerm.getValue(${keyEv.value});
- |if ($matched != null) {
- | ${buildVars.map(_.code).mkString("\n")}
- | $outputCode
- |}
+ |if ($matched == null) continue;
+ |$checkCondition
+ |$numOutput.add(1);
+ |${consume(ctx, resultVars)}
""".stripMargin
} else {
@@ -236,13 +234,13 @@ case class BroadcastHashJoin(
|${keyEv.code}
|// find matches from HashRelation
|$bufferType $matches = $anyNull ? null : ($bufferType)$relationTerm.get(${keyEv.value});
- |if ($matches != null) {
- | int $size = $matches.size();
- | for (int $i = 0; $i < $size; $i++) {
- | UnsafeRow $matched = (UnsafeRow) $matches.apply($i);
- | ${buildVars.map(_.code).mkString("\n")}
- | $outputCode
- | }
+ |if ($matches == null) continue;
+ |int $size = $matches.size();
+ |for (int $i = 0; $i < $size; $i++) {
+ | UnsafeRow $matched = (UnsafeRow) $matches.apply($i);
+ | $checkCondition
+ | $numOutput.add(1);
+ | ${consume(ctx, resultVars)}
|}
""".stripMargin
}
@@ -257,21 +255,21 @@ case class BroadcastHashJoin(
val (keyEv, anyNull) = genStreamSideJoinKey(ctx, input)
val matched = ctx.freshName("matched")
val buildVars = genBuildSideVars(ctx, matched)
- val resultVars = buildSide match {
- case BuildLeft => buildVars ++ input
- case BuildRight => input ++ buildVars
- }
val numOutput = metricTerm(ctx, "numOutputRows")
// filter the output via condition
val conditionPassed = ctx.freshName("conditionPassed")
val checkCondition = if (condition.isDefined) {
- ctx.currentVars = resultVars
- val ev = BindReferences.bindReference(condition.get, this.output).gen(ctx)
+ val expr = condition.get
+ // evaluate the variables from build side that used by condition
+ val eval = evaluateRequiredVariables(buildPlan.output, buildVars, expr.references)
+ ctx.currentVars = input ++ buildVars
+ val ev = BindReferences.bindReference(expr, streamedPlan.output ++ buildPlan.output).gen(ctx)
s"""
|boolean $conditionPassed = true;
+ |${eval.trim}
+ |${ev.code}
|if ($matched != null) {
- | ${ev.code}
| $conditionPassed = !${ev.isNull} && ${ev.value};
|}
""".stripMargin
@@ -279,17 +277,21 @@ case class BroadcastHashJoin(
s"final boolean $conditionPassed = true;"
}
+ val resultVars = buildSide match {
+ case BuildLeft => buildVars ++ input
+ case BuildRight => input ++ buildVars
+ }
if (broadcastRelation.value.isInstanceOf[UniqueHashedRelation]) {
s"""
|// generate join key for stream side
|${keyEv.code}
|// find matches from HashedRelation
|UnsafeRow $matched = $anyNull ? null: (UnsafeRow)$relationTerm.getValue(${keyEv.value});
- |${buildVars.map(_.code).mkString("\n")}
|${checkCondition.trim}
|if (!$conditionPassed) {
- | // reset to null
- | ${buildVars.map(v => s"${v.isNull} = true;").mkString("\n")}
+ | $matched = null;
+ | // reset the variables those are already evaluated.
+ | ${buildVars.filter(_.code == "").map(v => s"${v.isNull} = true;").mkString("\n")}
|}
|$numOutput.add(1);
|${consume(ctx, resultVars)}
@@ -311,13 +313,11 @@ case class BroadcastHashJoin(
|// the last iteration of this loop is to emit an empty row if there is no matched rows.
|for (int $i = 0; $i <= $size; $i++) {
| UnsafeRow $matched = $i < $size ? (UnsafeRow) $matches.apply($i) : null;
- | ${buildVars.map(_.code).mkString("\n")}
| ${checkCondition.trim}
- | if ($conditionPassed && ($i < $size || !$found)) {
- | $found = true;
- | $numOutput.add(1);
- | ${consume(ctx, resultVars)}
- | }
+ | if (!$conditionPassed || ($i == $size && $found)) continue;
+ | $found = true;
+ | $numOutput.add(1);
+ | ${consume(ctx, resultVars)}
|}
""".stripMargin
}
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeJoin.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeJoin.scala
index 7ec4027188..cffd6f6032 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeJoin.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeJoin.scala
@@ -306,11 +306,11 @@ case class SortMergeJoin(
val (used, notUsed) = attributes.zip(variables).partition{ case (a, ev) =>
condRefs.contains(a)
}
- val beforeCond = used.map(_._2.code).mkString("\n")
- val afterCond = notUsed.map(_._2.code).mkString("\n")
+ val beforeCond = evaluateVariables(used.map(_._2))
+ val afterCond = evaluateVariables(notUsed.map(_._2))
(beforeCond, afterCond)
} else {
- (variables.map(_.code).mkString("\n"), "")
+ (evaluateVariables(variables), "")
}
}
@@ -326,41 +326,48 @@ case class SortMergeJoin(
val leftVars = createLeftVars(ctx, leftRow)
val rightRow = ctx.freshName("rightRow")
val rightVars = createRightVar(ctx, rightRow)
- val resultVars = leftVars ++ rightVars
-
- // Check condition
- ctx.currentVars = resultVars
- val cond = if (condition.isDefined) {
- BindReferences.bindReference(condition.get, output).gen(ctx)
- } else {
- ExprCode("", "false", "true")
- }
- // Split the code of creating variables based on whether it's used by condition or not.
- val loaded = ctx.freshName("loaded")
- val (leftBefore, leftAfter) = splitVarsByCondition(left.output, leftVars)
- val (rightBefore, rightAfter) = splitVarsByCondition(right.output, rightVars)
-
val size = ctx.freshName("size")
val i = ctx.freshName("i")
val numOutput = metricTerm(ctx, "numOutputRows")
+ val (beforeLoop, condCheck) = if (condition.isDefined) {
+ // Split the code of creating variables based on whether it's used by condition or not.
+ val loaded = ctx.freshName("loaded")
+ val (leftBefore, leftAfter) = splitVarsByCondition(left.output, leftVars)
+ val (rightBefore, rightAfter) = splitVarsByCondition(right.output, rightVars)
+ // Generate code for condition
+ ctx.currentVars = leftVars ++ rightVars
+ val cond = BindReferences.bindReference(condition.get, output).gen(ctx)
+ // evaluate the columns those used by condition before loop
+ val before = s"""
+ |boolean $loaded = false;
+ |$leftBefore
+ """.stripMargin
+
+ val checking = s"""
+ |$rightBefore
+ |${cond.code}
+ |if (${cond.isNull} || !${cond.value}) continue;
+ |if (!$loaded) {
+ | $loaded = true;
+ | $leftAfter
+ |}
+ |$rightAfter
+ """.stripMargin
+ (before, checking)
+ } else {
+ (evaluateVariables(leftVars), "")
+ }
+
s"""
|while (findNextInnerJoinRows($leftInput, $rightInput)) {
| int $size = $matches.size();
- | boolean $loaded = false;
- | $leftBefore
+ | ${beforeLoop.trim}
| for (int $i = 0; $i < $size; $i ++) {
| InternalRow $rightRow = (InternalRow) $matches.get($i);
- | $rightBefore
- | ${cond.code}
- | if (${cond.isNull} || !${cond.value}) continue;
- | if (!$loaded) {
- | $loaded = true;
- | $leftAfter
- | }
- | $rightAfter
+ | ${condCheck.trim}
| $numOutput.add(1);
- | ${consume(ctx, resultVars)}
+ | ${consume(ctx, leftVars ++ rightVars)}
| }
| if (shouldStop()) return;
|}