Diffstat (limited to 'sql')
-rw-r--r--  sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/BoundAttribute.scala  |  7
-rw-r--r--  sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala  |  2
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/execution/ExistingRDD.scala  | 41
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/execution/Expand.scala  |  4
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/execution/WholeStageCodegen.scala  | 96
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregate.scala  | 50
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala  | 32
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastHashJoin.scala  | 84
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeJoin.scala  | 63
9 files changed, 224 insertions(+), 155 deletions(-)
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/BoundAttribute.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/BoundAttribute.scala
index 4727ff1885..72fe065459 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/BoundAttribute.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/BoundAttribute.scala
@@ -62,9 +62,10 @@ case class BoundReference(ordinal: Int, dataType: DataType, nullable: Boolean)
val javaType = ctx.javaType(dataType)
val value = ctx.getValue(ctx.INPUT_ROW, dataType, ordinal.toString)
if (ctx.currentVars != null && ctx.currentVars(ordinal) != null) {
- ev.isNull = ctx.currentVars(ordinal).isNull
- ev.value = ctx.currentVars(ordinal).value
- ""
+ val oev = ctx.currentVars(ordinal)
+ ev.isNull = oev.isNull
+ ev.value = oev.value
+ oev.code
} else if (nullable) {
s"""
boolean ${ev.isNull} = ${ctx.INPUT_ROW}.isNullAt($ordinal);
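
A note on why forwarding oev.code matters: under deferred evaluation, the statements that compute a column travel with its ExprCode until some operator emits them, and the old "" return silently dropped those pending statements. A minimal sketch with a simplified stand-in for ExprCode (not Spark's class):

case class ExprCode(var code: String, isNull: String, value: String)

object ForwardPendingCode {
  def main(args: Array[String]): Unit = {
    // A column whose evaluation the producer has deferred: the statements
    // that compute value1 are still attached to the variable.
    val oev = ExprCode("int value1 = inputRow.getInt(0);", "false", "value1")
    // New behavior: reuse the terms and forward the pending code.
    val ev = ExprCode(oev.code, oev.isNull, oev.value)
    assert(ev.code.nonEmpty) // the old version returned "" and lost these statements
  }
}
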
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala
index 63e19564dd..c4265a7539 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala
@@ -37,6 +37,8 @@ import org.apache.spark.util.Utils
* Java source for evaluating an [[Expression]] given a [[InternalRow]] of input.
*
* @param code The sequence of statements required to evaluate the expression.
+ * It should be an empty string if `isNull` and `value` already exist, or if no code is
+ * needed to evaluate them (literals).
* @param isNull A term that holds a boolean value representing whether the expression evaluated
* to null.
* @param value A term for a (possibly primitive) value of the result of the evaluation. Not
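
For concreteness, a hedged illustration of the three fields for a literal versus a nullable column read (a simplified stand-in class with hand-written strings in the style of the emitted Java, not Spark's actual generated output):

case class ExprCode(var code: String, isNull: String, value: String)

object ExprCodeShapes {
  // A literal: isNull and value are final by construction, so code is "".
  val lit = ExprCode(code = "", isNull = "false", value = "42")

  // A nullable column read: these statements must run (exactly once) before
  // isNull0/value0 may be referenced.
  val col = ExprCode(
    code = """boolean isNull0 = i.isNullAt(0);
             |int value0 = isNull0 ? -1 : i.getInt(0);""".stripMargin,
    isNull = "isNull0",
    value = "value0")
}
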
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/ExistingRDD.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/ExistingRDD.scala
index 4ad07508ca..3662ed74d2 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/ExistingRDD.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/ExistingRDD.scala
@@ -151,9 +151,6 @@ private[sql] case class PhysicalRDD(
val exprs = output.zipWithIndex.map(x => new BoundReference(x._2, x._1.dataType, true))
val row = ctx.freshName("row")
val numOutputRows = metricTerm(ctx, "numOutputRows")
- ctx.INPUT_ROW = row
- ctx.currentVars = null
- val columns = exprs.map(_.gen(ctx))
// The input RDD can either return (all) ColumnarBatches or InternalRows. We determine this
// by looking at the first value of the RDD and then calling the function which will process
@@ -161,7 +158,9 @@ private[sql] case class PhysicalRDD(
// TODO: The abstractions between this class and SqlNewHadoopRDD makes it difficult to know
// here which path to use. Fix this.
-
+ ctx.INPUT_ROW = row
+ ctx.currentVars = null
+ val columns1 = exprs.map(_.gen(ctx))
val scanBatches = ctx.freshName("processBatches")
ctx.addNewFunction(scanBatches,
s"""
@@ -170,12 +169,11 @@ private[sql] case class PhysicalRDD(
| int numRows = $batch.numRows();
| if ($idx == 0) $numOutputRows.add(numRows);
|
- | while ($idx < numRows) {
+ | while (!shouldStop() && $idx < numRows) {
| InternalRow $row = $batch.getRow($idx++);
- | ${columns.map(_.code).mkString("\n").trim}
- | ${consume(ctx, columns).trim}
- | if (shouldStop()) return;
+ | ${consume(ctx, columns1).trim}
| }
+ | if (shouldStop()) return;
|
| if (!$input.hasNext()) {
| $batch = null;
@@ -186,30 +184,37 @@ private[sql] case class PhysicalRDD(
| }
| }""".stripMargin)
+ ctx.INPUT_ROW = row
+ ctx.currentVars = null
+ val columns2 = exprs.map(_.gen(ctx))
+ val inputRow = if (isUnsafeRow) row else null
val scanRows = ctx.freshName("processRows")
ctx.addNewFunction(scanRows,
s"""
| private void $scanRows(InternalRow $row) throws java.io.IOException {
- | while (true) {
+ | boolean firstRow = true;
+ | while (!shouldStop() && (firstRow || $input.hasNext())) {
+ | if (firstRow) {
+ | firstRow = false;
+ | } else {
+ | $row = (InternalRow) $input.next();
+ | }
| $numOutputRows.add(1);
- | ${columns.map(_.code).mkString("\n").trim}
- | ${consume(ctx, columns).trim}
- | if (shouldStop()) return;
- | if (!$input.hasNext()) break;
- | $row = (InternalRow)$input.next();
+ | ${consume(ctx, columns2, inputRow).trim}
| }
| }""".stripMargin)
+ val value = ctx.freshName("value")
s"""
| if ($batch != null) {
| $scanBatches();
| } else if ($input.hasNext()) {
- | Object value = $input.next();
- | if (value instanceof $columnarBatchClz) {
- | $batch = ($columnarBatchClz)value;
+ | Object $value = $input.next();
+ | if ($value instanceof $columnarBatchClz) {
+ | $batch = ($columnarBatchClz)$value;
| $scanBatches();
| } else {
- | $scanRows((InternalRow)value);
+ | $scanRows((InternalRow) $value);
| }
| }
""".stripMargin
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/Expand.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/Expand.scala
index 12998a38f5..524285bc87 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/Expand.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/Expand.scala
@@ -185,8 +185,10 @@ case class Expand(
val numOutput = metricTerm(ctx, "numOutputRows")
val i = ctx.freshName("i")
+ // These columns have to be declared before the loop.
+ val evaluate = evaluateVariables(outputColumns)
s"""
- |${outputColumns.map(_.code).mkString("\n").trim}
+ |$evaluate
|for (int $i = 0; $i < ${projections.length}; $i ++) {
| switch ($i) {
| ${cases.mkString("\n").trim}
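
The comment above is the crux: every projection case in the loop references the same input columns, so their code must run once, before the loop. A self-contained sketch of evaluateVariables (defined in the WholeStageCodegen hunk below) and the hoisting it enables, using a simplified stand-in for ExprCode:

object ExpandHoisting extends App {
  case class ExprCode(var code: String, isNull: String, value: String)

  // Same logic as the helper introduced in the WholeStageCodegen hunk below:
  // emit each variable's pending code once, then blank it so later consumers
  // find nothing left to evaluate.
  def evaluateVariables(variables: Seq[ExprCode]): String = {
    val evaluate = variables.filter(_.code != "").map(_.code.trim).mkString("\n")
    variables.foreach(_.code = "")
    evaluate
  }

  val cols = Seq(ExprCode("int value1 = i.getInt(0);", "false", "value1"))
  val hoisted = evaluateVariables(cols) // emitted once, before the loop
  assert(hoisted.nonEmpty && cols.head.code == "")
  // the per-projection cases inside the loop now reference value1 directly
}
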
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/WholeStageCodegen.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/WholeStageCodegen.scala
index 6d231bf74a..45578d50bf 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/WholeStageCodegen.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/WholeStageCodegen.scala
@@ -81,11 +81,14 @@ trait CodegenSupport extends SparkPlan {
this.parent = parent
ctx.freshNamePrefix = variablePrefix
waitForSubqueries()
- doProduce(ctx)
+ s"""
+ |/*** PRODUCE: ${toCommentSafeString(this.simpleString)} */
+ |${doProduce(ctx)}
+ """.stripMargin
}
/**
- * Generate the Java source code to process, should be overrided by subclass to support codegen.
+ * Generate the Java source code to process, should be overridden by subclass to support codegen.
*
* doProduce() usually generate the framework, for example, aggregation could generate this:
*
@@ -94,11 +97,11 @@ trait CodegenSupport extends SparkPlan {
* # call child.produce()
* initialized = true;
* }
- * while (hashmap.hasNext()) {
+ * while (!shouldStop() && hashmap.hasNext()) {
* row = hashmap.next();
* # build the aggregation results
- * # create varialbles for results
- * # call consume(), wich will call parent.doConsume()
+ * # create variables for results
+ * # call consume(), which will call parent.doConsume()
* }
*/
protected def doProduce(ctx: CodegenContext): String
@@ -114,27 +117,71 @@ trait CodegenSupport extends SparkPlan {
}
/**
- * Consume the columns generated from it's child, call doConsume() or emit the rows.
+ * Returns source code to evaluate all the variables, and clears their code to prevent
+ * them from being evaluated twice.
+ */
+ protected def evaluateVariables(variables: Seq[ExprCode]): String = {
+ val evaluate = variables.filter(_.code != "").map(_.code.trim).mkString("\n")
+ variables.foreach(_.code = "")
+ evaluate
+ }
+
+ /**
+ * Returns source code to evaluate the variables for the required attributes, and clears the
+ * code of the evaluated variables to prevent them from being evaluated twice.
*/
+ protected def evaluateRequiredVariables(
+ attributes: Seq[Attribute],
+ variables: Seq[ExprCode],
+ required: AttributeSet): String = {
+ var evaluateVars = ""
+ variables.zipWithIndex.foreach { case (ev, i) =>
+ if (ev.code != "" && required.contains(attributes(i))) {
+ evaluateVars += ev.code.trim + "\n"
+ ev.code = ""
+ }
+ }
+ evaluateVars
+ }
+
+ /**
+ * The subset of inputSet that should be evaluated before this plan.
+ *
+ * We will use this to insert code that accesses the columns actually used by the current
+ * plan before calling doConsume().
+ */
+ def usedInputs: AttributeSet = references
+
+ /**
+ * Consume the columns generated from its child, call doConsume() or emit the rows.
+ *
+ * An operator could generate variables for the output, or a row, either one could be null.
+ *
+ * If the row is not null, we create variables to access the columns that are actually used by
+ * the current plan before calling doConsume().
+ */
def consumeChild(
ctx: CodegenContext,
child: SparkPlan,
input: Seq[ExprCode],
row: String = null): String = {
ctx.freshNamePrefix = variablePrefix
- if (row != null) {
- ctx.currentVars = null
- ctx.INPUT_ROW = row
- val evals = child.output.zipWithIndex.map { case (attr, i) =>
- BoundReference(i, attr.dataType, attr.nullable).gen(ctx)
+ val inputVars =
+ if (row != null) {
+ ctx.currentVars = null
+ ctx.INPUT_ROW = row
+ child.output.zipWithIndex.map { case (attr, i) =>
+ BoundReference(i, attr.dataType, attr.nullable).gen(ctx)
+ }
+ } else {
+ input
}
- s"""
- | ${evals.map(_.code).mkString("\n")}
- | ${doConsume(ctx, evals)}
- """.stripMargin
- } else {
- doConsume(ctx, input)
- }
+ s"""
+ |
+ |/*** CONSUME: ${toCommentSafeString(this.simpleString)} */
+ |${evaluateRequiredVariables(child.output, inputVars, usedInputs)}
+ |${doConsume(ctx, inputVars)}
+ """.stripMargin
}
/**
@@ -145,9 +192,8 @@ trait CodegenSupport extends SparkPlan {
* For example, Filter will generate the code like this:
*
* # code to evaluate the predicate expression, result is isNull1 and value2
- * if (isNull1 || value2) {
- * # call consume(), which will call parent.doConsume()
- * }
+ * if (isNull1 || !value2) continue;
+ * # call consume(), which will call parent.doConsume()
*/
protected def doConsume(ctx: CodegenContext, input: Seq[ExprCode]): String = {
throw new UnsupportedOperationException
@@ -190,13 +236,9 @@ case class InputAdapter(child: SparkPlan) extends UnaryNode with CodegenSupport
ctx.currentVars = null
val columns = exprs.map(_.gen(ctx))
s"""
- | while ($input.hasNext()) {
+ | while (!shouldStop() && $input.hasNext()) {
| InternalRow $row = (InternalRow) $input.next();
- | ${columns.map(_.code).mkString("\n").trim}
| ${consume(ctx, columns).trim}
- | if (shouldStop()) {
- | return;
- | }
| }
""".stripMargin
}
@@ -332,10 +374,12 @@ case class WholeStageCodegen(child: SparkPlan) extends UnaryNode with CodegenSup
val colExprs = output.zipWithIndex.map { case (attr, i) =>
BoundReference(i, attr.dataType, attr.nullable)
}
+ val evaluateInputs = evaluateVariables(input)
// generate the code to create a UnsafeRow
ctx.currentVars = input
val code = GenerateUnsafeProjection.createCode(ctx, colExprs, false)
s"""
+ |$evaluateInputs
|${code.code.trim}
|append(${code.value}.copy());
""".stripMargin.trim
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregate.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregate.scala
index a46722963a..f07add83d5 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregate.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregate.scala
@@ -116,6 +116,8 @@ case class TungstenAggregate(
// all the mode of aggregate expressions
private val modes = aggregateExpressions.map(_.mode).distinct
+ override def usedInputs: AttributeSet = inputSet
+
override def supportCodegen: Boolean = {
// ImperativeAggregate is not supported right now
!aggregateExpressions.exists(_.aggregateFunction.isInstanceOf[ImperativeAggregate])
@@ -164,23 +166,24 @@ case class TungstenAggregate(
""".stripMargin
ExprCode(ev.code + initVars, isNull, value)
}
+ val initBufVar = evaluateVariables(bufVars)
// generate variables for output
- val bufferAttrs = functions.flatMap(_.aggBufferAttributes)
val (resultVars, genResult) = if (modes.contains(Final) || modes.contains(Complete)) {
// evaluate aggregate results
ctx.currentVars = bufVars
val aggResults = functions.map(_.evaluateExpression).map { e =>
- BindReferences.bindReference(e, bufferAttrs).gen(ctx)
+ BindReferences.bindReference(e, aggregateBufferAttributes).gen(ctx)
}
+ val evaluateAggResults = evaluateVariables(aggResults)
// evaluate result expressions
ctx.currentVars = aggResults
val resultVars = resultExpressions.map { e =>
BindReferences.bindReference(e, aggregateAttributes).gen(ctx)
}
(resultVars, s"""
- | ${aggResults.map(_.code).mkString("\n")}
- | ${resultVars.map(_.code).mkString("\n")}
+ |$evaluateAggResults
+ |${evaluateVariables(resultVars)}
""".stripMargin)
} else if (modes.contains(Partial) || modes.contains(PartialMerge)) {
// output the aggregate buffer directly
@@ -188,7 +191,7 @@ case class TungstenAggregate(
} else {
// no aggregate function, the result should be literals
val resultVars = resultExpressions.map(_.gen(ctx))
- (resultVars, resultVars.map(_.code).mkString("\n"))
+ (resultVars, evaluateVariables(resultVars))
}
val doAgg = ctx.freshName("doAggregateWithoutKey")
@@ -196,7 +199,7 @@ case class TungstenAggregate(
s"""
| private void $doAgg() throws java.io.IOException {
| // initialize aggregation buffer
- | ${bufVars.map(_.code).mkString("\n")}
+ | $initBufVar
|
| ${child.asInstanceOf[CodegenSupport].produce(ctx, this)}
| }
@@ -204,7 +207,7 @@ case class TungstenAggregate(
val numOutput = metricTerm(ctx, "numOutputRows")
s"""
- | if (!$initAgg) {
+ | while (!$initAgg) {
| $initAgg = true;
| $doAgg();
|
@@ -241,7 +244,7 @@ case class TungstenAggregate(
}
s"""
| // do aggregate
- | ${aggVals.map(_.code).mkString("\n").trim}
+ | ${evaluateVariables(aggVals)}
| // update aggregation buffer
| ${updates.mkString("\n").trim}
""".stripMargin
@@ -252,8 +255,7 @@ case class TungstenAggregate(
private val declFunctions = aggregateExpressions.map(_.aggregateFunction)
.filter(_.isInstanceOf[DeclarativeAggregate])
.map(_.asInstanceOf[DeclarativeAggregate])
- private val bufferAttributes = declFunctions.flatMap(_.aggBufferAttributes)
- private val bufferSchema = StructType.fromAttributes(bufferAttributes)
+ private val bufferSchema = StructType.fromAttributes(aggregateBufferAttributes)
// The name for HashMap
private var hashMapTerm: String = _
@@ -318,7 +320,7 @@ case class TungstenAggregate(
val mergeExpr = declFunctions.flatMap(_.mergeExpressions)
val mergeProjection = newMutableProjection(
mergeExpr,
- bufferAttributes ++ declFunctions.flatMap(_.inputAggBufferAttributes),
+ aggregateBufferAttributes ++ declFunctions.flatMap(_.inputAggBufferAttributes),
subexpressionEliminationEnabled)()
val joinedRow = new JoinedRow()
@@ -380,15 +382,18 @@ case class TungstenAggregate(
val keyVars = groupingExpressions.zipWithIndex.map { case (e, i) =>
BoundReference(i, e.dataType, e.nullable).gen(ctx)
}
+ val evaluateKeyVars = evaluateVariables(keyVars)
ctx.INPUT_ROW = bufferTerm
- val bufferVars = bufferAttributes.zipWithIndex.map { case (e, i) =>
+ val bufferVars = aggregateBufferAttributes.zipWithIndex.map { case (e, i) =>
BoundReference(i, e.dataType, e.nullable).gen(ctx)
}
+ val evaluateBufferVars = evaluateVariables(bufferVars)
// evaluate the aggregation result
ctx.currentVars = bufferVars
val aggResults = declFunctions.map(_.evaluateExpression).map { e =>
- BindReferences.bindReference(e, bufferAttributes).gen(ctx)
+ BindReferences.bindReference(e, aggregateBufferAttributes).gen(ctx)
}
+ val evaluateAggResults = evaluateVariables(aggResults)
// generate the final result
ctx.currentVars = keyVars ++ aggResults
val inputAttrs = groupingAttributes ++ aggregateAttributes
@@ -396,11 +401,9 @@ case class TungstenAggregate(
BindReferences.bindReference(e, inputAttrs).gen(ctx)
}
s"""
- ${keyVars.map(_.code).mkString("\n")}
- ${bufferVars.map(_.code).mkString("\n")}
- ${aggResults.map(_.code).mkString("\n")}
- ${resultVars.map(_.code).mkString("\n")}
-
+ $evaluateKeyVars
+ $evaluateBufferVars
+ $evaluateAggResults
${consume(ctx, resultVars)}
"""
@@ -422,10 +425,7 @@ case class TungstenAggregate(
val eval = resultExpressions.map{ e =>
BindReferences.bindReference(e, groupingAttributes).gen(ctx)
}
- s"""
- ${eval.map(_.code).mkString("\n")}
- ${consume(ctx, eval)}
- """
+ consume(ctx, eval)
}
}
@@ -508,8 +508,8 @@ case class TungstenAggregate(
ctx.currentVars = input
val hashEval = BindReferences.bindReference(hashExpr, child.output).gen(ctx)
- val inputAttr = bufferAttributes ++ child.output
- ctx.currentVars = new Array[ExprCode](bufferAttributes.length) ++ input
+ val inputAttr = aggregateBufferAttributes ++ child.output
+ ctx.currentVars = new Array[ExprCode](aggregateBufferAttributes.length) ++ input
ctx.INPUT_ROW = buffer
// TODO: support subexpression elimination
val evals = updateExpr.map(BindReferences.bindReference(_, inputAttr).gen(ctx))
@@ -557,7 +557,7 @@ case class TungstenAggregate(
$incCounter
// evaluate aggregate function
- ${evals.map(_.code).mkString("\n").trim}
+ ${evaluateVariables(evals)}
// update aggregate buffer
${updates.mkString("\n").trim}
"""
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala
index b2f443c0e9..4a9e736f7a 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala
@@ -39,15 +39,26 @@ case class Project(projectList: Seq[NamedExpression], child: SparkPlan)
child.asInstanceOf[CodegenSupport].produce(ctx, this)
}
+ override def usedInputs: AttributeSet = {
+ // Only the attributes that are used at least twice should be evaluated before this plan;
+ // otherwise we can defer the evaluation until the output attribute is actually used.
+ val usedExprIds = projectList.flatMap(_.collect {
+ case a: Attribute => a.exprId
+ })
+ val usedMoreThanOnce = usedExprIds.groupBy(id => id).filter(_._2.size > 1).keySet
+ references.filter(a => usedMoreThanOnce.contains(a.exprId))
+ }
+
override def doConsume(ctx: CodegenContext, input: Seq[ExprCode]): String = {
val exprs = projectList.map(x =>
ExpressionCanonicalizer.execute(BindReferences.bindReference(x, child.output)))
ctx.currentVars = input
- val output = exprs.map(_.gen(ctx))
+ val resultVars = exprs.map(_.gen(ctx))
+ // Evaluation of non-deterministic expressions can't be deferred.
+ val nonDeterministicAttrs = projectList.filterNot(_.deterministic).map(_.toAttribute)
s"""
- | ${output.map(_.code).mkString("\n")}
- |
- | ${consume(ctx, output)}
+ |${evaluateRequiredVariables(output, resultVars, AttributeSet(nonDeterministicAttrs))}
+ |${consume(ctx, resultVars)}
""".stripMargin
}
@@ -89,11 +100,10 @@ case class Filter(condition: Expression, child: SparkPlan) extends UnaryNode wit
s""
}
s"""
- | ${eval.code}
- | if ($nullCheck ${eval.value}) {
- | $numOutput.add(1);
- | ${consume(ctx, ctx.currentVars)}
- | }
+ |${eval.code}
+ |if (!($nullCheck ${eval.value})) continue;
+ |$numOutput.add(1);
+ |${consume(ctx, ctx.currentVars)}
""".stripMargin
}
@@ -228,15 +238,13 @@ case class Range(
| }
| }
|
- | while (!$overflow && $checkEnd) {
+ | while (!$overflow && $checkEnd && !shouldStop()) {
| long $value = $number;
| $number += ${step}L;
| if ($number < $value ^ ${step}L < 0) {
| $overflow = true;
| }
| ${consume(ctx, Seq(ev))}
- |
- | if (shouldStop()) return;
| }
""".stripMargin
}
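
The "used at least twice" heuristic in Project.usedInputs is small enough to model in isolation; a sketch over plain Ints standing in for ExprIds:

object ProjectUsedInputs extends App {
  // Each projection is modeled as the list of attribute ids it references;
  // Spark's version tracks ExprIds on real Attributes.
  def usedMoreThanOnce(projections: Seq[Seq[Int]]): Set[Int] = {
    val usedExprIds = projections.flatten
    usedExprIds.groupBy(identity).filter(_._2.size > 1).keySet
  }

  // SELECT a + b, a * 2: `a` (id 1) appears twice, so it is evaluated up
  // front; `b` (id 2) appears once and its evaluation stays deferred.
  assert(usedMoreThanOnce(Seq(Seq(1, 2), Seq(1))) == Set(1))
}
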
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastHashJoin.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastHashJoin.scala
index 6699dbafe7..c52662a61e 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastHashJoin.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastHashJoin.scala
@@ -190,40 +190,38 @@ case class BroadcastHashJoin(
val (keyEv, anyNull) = genStreamSideJoinKey(ctx, input)
val matched = ctx.freshName("matched")
val buildVars = genBuildSideVars(ctx, matched)
- val resultVars = buildSide match {
- case BuildLeft => buildVars ++ input
- case BuildRight => input ++ buildVars
- }
val numOutput = metricTerm(ctx, "numOutputRows")
- val outputCode = if (condition.isDefined) {
+ val checkCondition = if (condition.isDefined) {
+ val expr = condition.get
+ // evaluate the variables from the build side that are used by the condition
+ val eval = evaluateRequiredVariables(buildPlan.output, buildVars, expr.references)
// filter the output via condition
- ctx.currentVars = resultVars
- val ev = BindReferences.bindReference(condition.get, this.output).gen(ctx)
+ ctx.currentVars = input ++ buildVars
+ val ev = BindReferences.bindReference(expr, streamedPlan.output ++ buildPlan.output).gen(ctx)
s"""
+ |$eval
|${ev.code}
- |if (!${ev.isNull} && ${ev.value}) {
- | $numOutput.add(1);
- | ${consume(ctx, resultVars)}
- |}
+ |if (${ev.isNull} || !${ev.value}) continue;
""".stripMargin
} else {
- s"""
- |$numOutput.add(1);
- |${consume(ctx, resultVars)}
- """.stripMargin
+ ""
}
+ val resultVars = buildSide match {
+ case BuildLeft => buildVars ++ input
+ case BuildRight => input ++ buildVars
+ }
if (broadcastRelation.value.isInstanceOf[UniqueHashedRelation]) {
s"""
|// generate join key for stream side
|${keyEv.code}
|// find matches from HashedRelation
|UnsafeRow $matched = $anyNull ? null: (UnsafeRow)$relationTerm.getValue(${keyEv.value});
- |if ($matched != null) {
- | ${buildVars.map(_.code).mkString("\n")}
- | $outputCode
- |}
+ |if ($matched == null) continue;
+ |$checkCondition
+ |$numOutput.add(1);
+ |${consume(ctx, resultVars)}
""".stripMargin
} else {
@@ -236,13 +234,13 @@ case class BroadcastHashJoin(
|${keyEv.code}
|// find matches from HashRelation
|$bufferType $matches = $anyNull ? null : ($bufferType)$relationTerm.get(${keyEv.value});
- |if ($matches != null) {
- | int $size = $matches.size();
- | for (int $i = 0; $i < $size; $i++) {
- | UnsafeRow $matched = (UnsafeRow) $matches.apply($i);
- | ${buildVars.map(_.code).mkString("\n")}
- | $outputCode
- | }
+ |if ($matches == null) continue;
+ |int $size = $matches.size();
+ |for (int $i = 0; $i < $size; $i++) {
+ | UnsafeRow $matched = (UnsafeRow) $matches.apply($i);
+ | $checkCondition
+ | $numOutput.add(1);
+ | ${consume(ctx, resultVars)}
|}
""".stripMargin
}
@@ -257,21 +255,21 @@ case class BroadcastHashJoin(
val (keyEv, anyNull) = genStreamSideJoinKey(ctx, input)
val matched = ctx.freshName("matched")
val buildVars = genBuildSideVars(ctx, matched)
- val resultVars = buildSide match {
- case BuildLeft => buildVars ++ input
- case BuildRight => input ++ buildVars
- }
val numOutput = metricTerm(ctx, "numOutputRows")
// filter the output via condition
val conditionPassed = ctx.freshName("conditionPassed")
val checkCondition = if (condition.isDefined) {
- ctx.currentVars = resultVars
- val ev = BindReferences.bindReference(condition.get, this.output).gen(ctx)
+ val expr = condition.get
+ // evaluate the variables from the build side that are used by the condition
+ val eval = evaluateRequiredVariables(buildPlan.output, buildVars, expr.references)
+ ctx.currentVars = input ++ buildVars
+ val ev = BindReferences.bindReference(expr, streamedPlan.output ++ buildPlan.output).gen(ctx)
s"""
|boolean $conditionPassed = true;
+ |${eval.trim}
+ |${ev.code}
|if ($matched != null) {
- | ${ev.code}
| $conditionPassed = !${ev.isNull} && ${ev.value};
|}
""".stripMargin
@@ -279,17 +277,21 @@ case class BroadcastHashJoin(
s"final boolean $conditionPassed = true;"
}
+ val resultVars = buildSide match {
+ case BuildLeft => buildVars ++ input
+ case BuildRight => input ++ buildVars
+ }
if (broadcastRelation.value.isInstanceOf[UniqueHashedRelation]) {
s"""
|// generate join key for stream side
|${keyEv.code}
|// find matches from HashedRelation
|UnsafeRow $matched = $anyNull ? null: (UnsafeRow)$relationTerm.getValue(${keyEv.value});
- |${buildVars.map(_.code).mkString("\n")}
|${checkCondition.trim}
|if (!$conditionPassed) {
- | // reset to null
- | ${buildVars.map(v => s"${v.isNull} = true;").mkString("\n")}
+ | $matched = null;
+ // reset the variables that have already been evaluated.
+ | ${buildVars.filter(_.code == "").map(v => s"${v.isNull} = true;").mkString("\n")}
|}
|$numOutput.add(1);
|${consume(ctx, resultVars)}
@@ -311,13 +313,11 @@ case class BroadcastHashJoin(
|// the last iteration of this loop is to emit an empty row if there is no matched rows.
|for (int $i = 0; $i <= $size; $i++) {
| UnsafeRow $matched = $i < $size ? (UnsafeRow) $matches.apply($i) : null;
- | ${buildVars.map(_.code).mkString("\n")}
| ${checkCondition.trim}
- | if ($conditionPassed && ($i < $size || !$found)) {
- | $found = true;
- | $numOutput.add(1);
- | ${consume(ctx, resultVars)}
- | }
+ | if (!$conditionPassed || ($i == $size && $found)) continue;
+ | $found = true;
+ | $numOutput.add(1);
+ | ${consume(ctx, resultVars)}
|}
""".stripMargin
}
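
The outer-join reset above carries a subtle guard worth spelling out: only build-side variables that were already materialized (their code was blanked by evaluateRequiredVariables) need isNull forced to true; still-deferred variables will be generated from $matched, which has just been set to null. A sketch with the simplified stand-in:

object OuterJoinReset extends App {
  case class ExprCode(var code: String, isNull: String, value: String)

  // Mirrors `buildVars.filter(_.code == "").map(...)` in the hunk above.
  def resetEvaluated(buildVars: Seq[ExprCode]): String =
    buildVars.filter(_.code == "").map(v => s"${v.isNull} = true;").mkString("\n")

  val evaluated = ExprCode("", "isNull1", "value1") // already materialized
  val deferred = ExprCode("int value2 = matched.getInt(1);", "isNull2", "value2")
  // Only the materialized variable is reset; the deferred one will be
  // generated from `matched`, which is now null, via its null-checked path.
  assert(resetEvaluated(Seq(evaluated, deferred)) == "isNull1 = true;")
}
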
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeJoin.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeJoin.scala
index 7ec4027188..cffd6f6032 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeJoin.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeJoin.scala
@@ -306,11 +306,11 @@ case class SortMergeJoin(
val (used, notUsed) = attributes.zip(variables).partition{ case (a, ev) =>
condRefs.contains(a)
}
- val beforeCond = used.map(_._2.code).mkString("\n")
- val afterCond = notUsed.map(_._2.code).mkString("\n")
+ val beforeCond = evaluateVariables(used.map(_._2))
+ val afterCond = evaluateVariables(notUsed.map(_._2))
(beforeCond, afterCond)
} else {
- (variables.map(_.code).mkString("\n"), "")
+ (evaluateVariables(variables), "")
}
}
@@ -326,41 +326,48 @@ case class SortMergeJoin(
val leftVars = createLeftVars(ctx, leftRow)
val rightRow = ctx.freshName("rightRow")
val rightVars = createRightVar(ctx, rightRow)
- val resultVars = leftVars ++ rightVars
-
- // Check condition
- ctx.currentVars = resultVars
- val cond = if (condition.isDefined) {
- BindReferences.bindReference(condition.get, output).gen(ctx)
- } else {
- ExprCode("", "false", "true")
- }
- // Split the code of creating variables based on whether it's used by condition or not.
- val loaded = ctx.freshName("loaded")
- val (leftBefore, leftAfter) = splitVarsByCondition(left.output, leftVars)
- val (rightBefore, rightAfter) = splitVarsByCondition(right.output, rightVars)
-
val size = ctx.freshName("size")
val i = ctx.freshName("i")
val numOutput = metricTerm(ctx, "numOutputRows")
+ val (beforeLoop, condCheck) = if (condition.isDefined) {
+ // Split the code of creating variables based on whether it's used by condition or not.
+ val loaded = ctx.freshName("loaded")
+ val (leftBefore, leftAfter) = splitVarsByCondition(left.output, leftVars)
+ val (rightBefore, rightAfter) = splitVarsByCondition(right.output, rightVars)
+ // Generate code for condition
+ ctx.currentVars = leftVars ++ rightVars
+ val cond = BindReferences.bindReference(condition.get, output).gen(ctx)
+ // evaluate the columns that are used by the condition before the loop
+ val before = s"""
+ |boolean $loaded = false;
+ |$leftBefore
+ """.stripMargin
+
+ val checking = s"""
+ |$rightBefore
+ |${cond.code}
+ |if (${cond.isNull} || !${cond.value}) continue;
+ |if (!$loaded) {
+ | $loaded = true;
+ | $leftAfter
+ |}
+ |$rightAfter
+ """.stripMargin
+ (before, checking)
+ } else {
+ (evaluateVariables(leftVars), "")
+ }
+
s"""
|while (findNextInnerJoinRows($leftInput, $rightInput)) {
| int $size = $matches.size();
- | boolean $loaded = false;
- | $leftBefore
+ | ${beforeLoop.trim}
| for (int $i = 0; $i < $size; $i ++) {
| InternalRow $rightRow = (InternalRow) $matches.get($i);
- | $rightBefore
- | ${cond.code}
- | if (${cond.isNull} || !${cond.value}) continue;
- | if (!$loaded) {
- | $loaded = true;
- | $leftAfter
- | }
- | $rightAfter
+ | ${condCheck.trim}
| $numOutput.add(1);
- | ${consume(ctx, resultVars)}
+ | ${consume(ctx, leftVars ++ rightVars)}
| }
| if (shouldStop()) return;
|}
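
Finally, the split that drives this loop, sketched with simplified stand-ins: columns the condition references are evaluated before the check, the rest only after a match survives it, and the loaded flag keeps the left side's deferred columns from running more than once per left row when multiple right rows match.

object SplitByCondition extends App {
  case class Attr(name: String)
  case class ExprCode(var code: String, isNull: String, value: String)

  // Same partition as splitVarsByCondition in the hunk above.
  def splitVarsByCondition(
      attributes: Seq[Attr],
      variables: Seq[ExprCode],
      condRefs: Set[Attr]): (String, String) = {
    val (used, notUsed) =
      attributes.zip(variables).partition { case (a, _) => condRefs.contains(a) }
    (used.map(_._2.code).mkString("\n"), notUsed.map(_._2.code).mkString("\n"))
  }

  val attrs = Seq(Attr("k"), Attr("payload"))
  val vars = Seq(
    ExprCode("int k = leftRow.getInt(0);", "false", "k"),
    ExprCode("long payload = leftRow.getLong(1);", "false", "payload"))
  val (beforeCond, afterCond) = splitVarsByCondition(attrs, vars, Set(Attr("k")))
  // `k` is needed to check the condition; `payload` is evaluated only after
  // a match survives it, at most once per left row thanks to `loaded`.
  assert(beforeCond.contains("getInt") && afterCond.contains("getLong"))
}
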