path: root/sql
authorDavies Liu <davies@databricks.com>2016-02-23 15:00:10 -0800
committerDavies Liu <davies.liu@gmail.com>2016-02-23 15:00:10 -0800
commit9cdd867da978629ea2f61f94e3c346fa0bfecf0e (patch)
treeb8aa4cceb32a21911a7e8ed40a9c50859df25266 /sql
parentc481bdf512f09060c9b9f341a5ce9fce00427d08 (diff)
[SPARK-13373] [SQL] generate sort merge join
## What changes were proposed in this pull request?
Generates code for SortMergeJoin.

## How was this patch tested?
Unit tests, plus manual testing with TPCDS Q72, which showed a 70% performance improvement (from 42s to 25s). Micro benchmarks only show minor improvements; the gain likely depends on the data distribution and the number of columns.

Author: Davies Liu <davies@databricks.com>

Closes #11248 from davies/gen_smj.
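
For context, a minimal self-contained Scala sketch of the kind of query this patch targets, modeled on the benchmark added below. The conf key and data sizes are taken from that benchmark; the object name and master URL are illustrative only. With broadcast joins disabled, the equi-join is planned as a SortMergeJoin, and with whole-stage codegen enabled (the default) it now runs through the generated code path.

import org.apache.spark.{SparkConf, SparkContext}
import org.apache.spark.sql.SQLContext
import org.apache.spark.sql.functions.col

object GeneratedSortMergeJoinExample {
  def main(args: Array[String]): Unit = {
    val conf = new SparkConf()
      .setMaster("local[1]")
      .setAppName("generated-smj")
      // Force SortMergeJoin by disabling broadcast joins, as the benchmark below does.
      .set("spark.sql.autoBroadcastJoinThreshold", "0")
    val sc = new SparkContext(conf)
    val sqlContext = new SQLContext(sc)

    val n = 2 << 20
    val df1 = sqlContext.range(n).selectExpr("id * 2 as k1")
    val df2 = sqlContext.range(n).selectExpr("id * 3 as k2")

    // With whole-stage codegen enabled, the equi-join below exercises the
    // generated sort merge join introduced by this patch.
    println(df1.join(df2, col("k1") === col("k2")).count())

    sc.stop()
  }
}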
Diffstat (limited to 'sql')
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/execution/BufferedRowIterator.java            |  23
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/execution/Expand.scala                        |   4
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/execution/WholeStageCodegen.scala             |  71
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregate.scala   |   4
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala                |  20
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastHashJoin.scala       |   4
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/execution/joins/CartesianProduct.scala        |   1
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeJoin.scala           | 247
-rw-r--r--  sql/core/src/test/scala/org/apache/spark/sql/execution/BenchmarkWholeStageCodegen.scala    |  34
-rw-r--r--  sql/hive/src/test/scala/org/apache/spark/sql/sources/BucketedReadSuite.scala               |   3
10 files changed, 359 insertions(+), 52 deletions(-)
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/BufferedRowIterator.java b/sql/core/src/main/scala/org/apache/spark/sql/execution/BufferedRowIterator.java
index ea20115770..1d1d7edb24 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/BufferedRowIterator.java
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/BufferedRowIterator.java
@@ -29,12 +29,9 @@ import org.apache.spark.sql.catalyst.expressions.UnsafeRow;
/**
* An iterator interface used to pull the output from generated function for multiple operators
* (whole stage codegen).
- *
- * TODO: replaced it by batched columnar format.
*/
-public class BufferedRowIterator {
+public abstract class BufferedRowIterator {
protected LinkedList<InternalRow> currentRows = new LinkedList<>();
- protected Iterator<InternalRow> input;
// used when there is no column in output
protected UnsafeRow unsafeRow = new UnsafeRow(0);
@@ -49,8 +46,16 @@ public class BufferedRowIterator {
return currentRows.remove();
}
- public void setInput(Iterator<InternalRow> iter) {
- input = iter;
+ /**
+ * Initializes from array of iterators of InternalRow.
+ */
+ public abstract void init(Iterator<InternalRow> iters[]);
+
+ /**
+ * Append a row to currentRows.
+ */
+ protected void append(InternalRow row) {
+ currentRows.add(row);
}
/**
@@ -74,9 +79,5 @@ public class BufferedRowIterator {
*
* After it's called, if currentRow is still null, it means no more rows left.
*/
- protected void processNext() throws IOException {
- if (input.hasNext()) {
- currentRows.add(input.next());
- }
- }
+ protected abstract void processNext() throws IOException;
}
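
To make the new contract concrete, here is a minimal hand-written Scala stand-in for what WholeStageCodegen normally generates: a subclass that receives its upstream iterators through init() and pushes rows out via the new append() helper, mimicking the pass-through behavior the removed default processNext() used to provide. The class and field names are illustrative; only BufferedRowIterator and InternalRow come from the code above.

import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.execution.BufferedRowIterator

// Hand-written analogue of a generated iterator: forwards rows from its single
// input, one per processNext() call.
class PassThroughRowIterator extends BufferedRowIterator {
  private var input: Iterator[InternalRow] = _

  override def init(iters: Array[Iterator[InternalRow]]): Unit = {
    // The framework passes up to two upstream iterators; a pass-through needs one.
    input = iters(0)
  }

  override protected def processNext(): Unit = {
    if (input.hasNext) {
      append(input.next())
    }
  }
}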
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/Expand.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/Expand.scala
index a9e77abbda..12998a38f5 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/Expand.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/Expand.scala
@@ -85,8 +85,8 @@ case class Expand(
}
}
- override def upstream(): RDD[InternalRow] = {
- child.asInstanceOf[CodegenSupport].upstream()
+ override def upstreams(): Seq[RDD[InternalRow]] = {
+ child.asInstanceOf[CodegenSupport].upstreams()
}
protected override def doProduce(ctx: CodegenContext): String = {
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/WholeStageCodegen.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/WholeStageCodegen.scala
index d79b547137..08c52e5f43 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/WholeStageCodegen.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/WholeStageCodegen.scala
@@ -29,7 +29,7 @@ import org.apache.spark.sql.catalyst.plans.physical.Partitioning
import org.apache.spark.sql.catalyst.rules.Rule
import org.apache.spark.sql.catalyst.util.toCommentSafeString
import org.apache.spark.sql.execution.aggregate.TungstenAggregate
-import org.apache.spark.sql.execution.joins.{BroadcastHashJoin, BuildLeft, BuildRight}
+import org.apache.spark.sql.execution.joins.{BroadcastHashJoin, BuildLeft, BuildRight, SortMergeJoin}
import org.apache.spark.sql.execution.metric.LongSQLMetricValue
/**
@@ -40,7 +40,8 @@ trait CodegenSupport extends SparkPlan {
/** Prefix used in the current operator's variable names. */
private def variablePrefix: String = this match {
case _: TungstenAggregate => "agg"
- case _: BroadcastHashJoin => "join"
+ case _: BroadcastHashJoin => "bhj"
+ case _: SortMergeJoin => "smj"
case _ => nodeName.toLowerCase
}
@@ -68,9 +69,11 @@ trait CodegenSupport extends SparkPlan {
private var parent: CodegenSupport = null
/**
- * Returns the RDD of InternalRow which generates the input rows.
+ * Returns all the RDDs of InternalRow which generates the input rows.
+ *
+ * Note: right now we support up to two RDDs.
*/
- def upstream(): RDD[InternalRow]
+ def upstreams(): Seq[RDD[InternalRow]]
/**
* Returns Java source code to process the rows from upstream.
@@ -179,19 +182,23 @@ case class InputAdapter(child: SparkPlan) extends LeafNode with CodegenSupport {
override def supportCodegen: Boolean = false
- override def upstream(): RDD[InternalRow] = {
- child.execute()
+ override def upstreams(): Seq[RDD[InternalRow]] = {
+ child.execute() :: Nil
}
override def doProduce(ctx: CodegenContext): String = {
+ val input = ctx.freshName("input")
+ // Right now, InputAdapter is only used when there is one upstream.
+ ctx.addMutableState("scala.collection.Iterator", input, s"$input = inputs[0];")
+
val exprs = output.zipWithIndex.map(x => new BoundReference(x._2, x._1.dataType, true))
val row = ctx.freshName("row")
ctx.INPUT_ROW = row
ctx.currentVars = null
val columns = exprs.map(_.gen(ctx))
s"""
- | while (input.hasNext()) {
- | InternalRow $row = (InternalRow) input.next();
+ | while ($input.hasNext()) {
+ | InternalRow $row = (InternalRow) $input.next();
| ${columns.map(_.code).mkString("\n").trim}
| ${consume(ctx, columns).trim}
| if (shouldStop()) {
@@ -215,7 +222,7 @@ case class InputAdapter(child: SparkPlan) extends LeafNode with CodegenSupport {
*
* -> execute()
* |
- * doExecute() ---------> upstream() -------> upstream() ------> execute()
+ * doExecute() ---------> upstreams() -------> upstreams() ------> execute()
* |
* -----------------> produce()
* |
@@ -267,6 +274,9 @@ case class WholeStageCodegen(plan: CodegenSupport, children: Seq[SparkPlan])
public GeneratedIterator(Object[] references) {
this.references = references;
+ }
+
+ public void init(scala.collection.Iterator inputs[]) {
${ctx.initMutableStates()}
}
@@ -283,19 +293,33 @@ case class WholeStageCodegen(plan: CodegenSupport, children: Seq[SparkPlan])
// println(s"${CodeFormatter.format(cleanedSource)}")
CodeGenerator.compile(cleanedSource)
- plan.upstream().mapPartitions { iter =>
-
- val clazz = CodeGenerator.compile(source)
- val buffer = clazz.generate(references).asInstanceOf[BufferedRowIterator]
- buffer.setInput(iter)
- new Iterator[InternalRow] {
- override def hasNext: Boolean = buffer.hasNext
- override def next: InternalRow = buffer.next()
+ val rdds = plan.upstreams()
+ assert(rdds.size <= 2, "Up to two upstream RDDs can be supported")
+ if (rdds.length == 1) {
+ rdds.head.mapPartitions { iter =>
+ val clazz = CodeGenerator.compile(cleanedSource)
+ val buffer = clazz.generate(references).asInstanceOf[BufferedRowIterator]
+ buffer.init(Array(iter))
+ new Iterator[InternalRow] {
+ override def hasNext: Boolean = buffer.hasNext
+ override def next: InternalRow = buffer.next()
+ }
+ }
+ } else {
+ // Right now, we support up to two upstreams.
+ rdds.head.zipPartitions(rdds(1)) { (leftIter, rightIter) =>
+ val clazz = CodeGenerator.compile(cleanedSource)
+ val buffer = clazz.generate(references).asInstanceOf[BufferedRowIterator]
+ buffer.init(Array(leftIter, rightIter))
+ new Iterator[InternalRow] {
+ override def hasNext: Boolean = buffer.hasNext
+ override def next: InternalRow = buffer.next()
+ }
}
}
}
- override def upstream(): RDD[InternalRow] = {
+ override def upstreams(): Seq[RDD[InternalRow]] = {
throw new UnsupportedOperationException
}
@@ -312,7 +336,7 @@ case class WholeStageCodegen(plan: CodegenSupport, children: Seq[SparkPlan])
if (row != null) {
// There is an UnsafeRow already
s"""
- | currentRows.add($row.copy());
+ |append($row.copy());
""".stripMargin
} else {
assert(input != null)
@@ -324,13 +348,13 @@ case class WholeStageCodegen(plan: CodegenSupport, children: Seq[SparkPlan])
ctx.currentVars = input
val code = GenerateUnsafeProjection.createCode(ctx, colExprs, false)
s"""
- | ${code.code.trim}
- | currentRows.add(${code.value}.copy());
+ |${code.code.trim}
+ |append(${code.value}.copy());
""".stripMargin
} else {
// There is no columns
s"""
- | currentRows.add(unsafeRow);
+ |append(unsafeRow);
""".stripMargin
}
}
@@ -402,6 +426,9 @@ private[sql] case class CollapseCodegenStages(sqlContext: SQLContext) extends Ru
b.copy(left = apply(left))
case b @ BroadcastHashJoin(_, _, _, BuildRight, _, left, right) =>
b.copy(right = apply(right))
+ case j @ SortMergeJoin(_, _, _, left, right) =>
+ // The children of SortMergeJoin should do codegen separately.
+ j.copy(left = apply(left), right = apply(right))
case p if !supportCodegen(p) =>
val input = apply(p) // collapse them recursively
inputs += input
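
The two-upstream branch above relies on RDD.zipPartitions, which runs one task per pair of co-numbered partitions and hands both iterators to the same closure, just like leftIter and rightIter in the generated init() call. A small self-contained sketch of that mechanism with illustrative data:

import org.apache.spark.{SparkConf, SparkContext}

object ZipPartitionsSketch {
  def main(args: Array[String]): Unit = {
    val sc = new SparkContext(
      new SparkConf().setMaster("local[1]").setAppName("zip-partitions-sketch"))

    // Both RDDs must have the same number of partitions, which the planner
    // arranges for SortMergeJoin via its required child distribution.
    val left = sc.parallelize(1 to 6, 2)
    val right = sc.parallelize(Seq("a", "b", "c", "d"), 2)

    // The i-th partition of each RDD is processed by the same task; here we
    // just count the rows each side contributes per task.
    val zipped = left.zipPartitions(right) { (leftIter, rightIter) =>
      Iterator.single((leftIter.size, rightIter.size))
    }
    println(zipped.collect().toSeq)  // Seq((3,2), (3,2))

    sc.stop()
  }
}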
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregate.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregate.scala
index 852203f374..a46722963a 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregate.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregate.scala
@@ -121,8 +121,8 @@ case class TungstenAggregate(
!aggregateExpressions.exists(_.aggregateFunction.isInstanceOf[ImperativeAggregate])
}
- override def upstream(): RDD[InternalRow] = {
- child.asInstanceOf[CodegenSupport].upstream()
+ override def upstreams(): Seq[RDD[InternalRow]] = {
+ child.asInstanceOf[CodegenSupport].upstreams()
}
protected override def doProduce(ctx: CodegenContext): String = {
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala
index 55bddd196e..b2f443c0e9 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala
@@ -31,8 +31,8 @@ case class Project(projectList: Seq[NamedExpression], child: SparkPlan)
override def output: Seq[Attribute] = projectList.map(_.toAttribute)
- override def upstream(): RDD[InternalRow] = {
- child.asInstanceOf[CodegenSupport].upstream()
+ override def upstreams(): Seq[RDD[InternalRow]] = {
+ child.asInstanceOf[CodegenSupport].upstreams()
}
protected override def doProduce(ctx: CodegenContext): String = {
@@ -69,8 +69,8 @@ case class Filter(condition: Expression, child: SparkPlan) extends UnaryNode wit
private[sql] override lazy val metrics = Map(
"numOutputRows" -> SQLMetrics.createLongMetric(sparkContext, "number of output rows"))
- override def upstream(): RDD[InternalRow] = {
- child.asInstanceOf[CodegenSupport].upstream()
+ override def upstreams(): Seq[RDD[InternalRow]] = {
+ child.asInstanceOf[CodegenSupport].upstreams()
}
protected override def doProduce(ctx: CodegenContext): String = {
@@ -156,8 +156,9 @@ case class Range(
private[sql] override lazy val metrics = Map(
"numOutputRows" -> SQLMetrics.createLongMetric(sparkContext, "number of output rows"))
- override def upstream(): RDD[InternalRow] = {
- sqlContext.sparkContext.parallelize(0 until numSlices, numSlices).map(i => InternalRow(i))
+ override def upstreams(): Seq[RDD[InternalRow]] = {
+ sqlContext.sparkContext.parallelize(0 until numSlices, numSlices)
+ .map(i => InternalRow(i)) :: Nil
}
protected override def doProduce(ctx: CodegenContext): String = {
@@ -213,12 +214,15 @@ case class Range(
| }
""".stripMargin)
+ val input = ctx.freshName("input")
+ // Right now, Range is only used when there is one upstream.
+ ctx.addMutableState("scala.collection.Iterator", input, s"$input = inputs[0];")
s"""
| // initialize Range
| if (!$initTerm) {
| $initTerm = true;
- | if (input.hasNext()) {
- | initRange(((InternalRow) input.next()).getInt(0));
+ | if ($input.hasNext()) {
+ | initRange(((InternalRow) $input.next()).getInt(0));
| } else {
| return;
| }
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastHashJoin.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastHashJoin.scala
index ddc08822f3..6699dbafe7 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastHashJoin.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastHashJoin.scala
@@ -99,8 +99,8 @@ case class BroadcastHashJoin(
}
}
- override def upstream(): RDD[InternalRow] = {
- streamedPlan.asInstanceOf[CodegenSupport].upstream()
+ override def upstreams(): Seq[RDD[InternalRow]] = {
+ streamedPlan.asInstanceOf[CodegenSupport].upstreams()
}
override def doProduce(ctx: CodegenContext): String = {
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/CartesianProduct.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/CartesianProduct.scala
index e417079b61..fabd2fbe1e 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/CartesianProduct.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/CartesianProduct.scala
@@ -27,7 +27,6 @@ import org.apache.spark.sql.execution.metric.SQLMetrics
import org.apache.spark.util.CompletionIterator
import org.apache.spark.util.collection.unsafe.sort.UnsafeExternalSorter
-
/**
* An optimized CartesianRDD for UnsafeRow, which will cache the rows from second child RDD,
* will be much faster than building the right partition for every row in left RDD, it also
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeJoin.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeJoin.scala
index cd8a5670e2..7ec4027188 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeJoin.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeJoin.scala
@@ -22,9 +22,10 @@ import scala.collection.mutable.ArrayBuffer
import org.apache.spark.rdd.RDD
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions._
+import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, ExprCode}
import org.apache.spark.sql.catalyst.plans.physical._
-import org.apache.spark.sql.execution.{BinaryNode, RowIterator, SparkPlan}
-import org.apache.spark.sql.execution.metric.{LongSQLMetric, SQLMetrics}
+import org.apache.spark.sql.execution.{BinaryNode, CodegenSupport, RowIterator, SparkPlan}
+import org.apache.spark.sql.execution.metric.SQLMetrics
/**
* Performs an sort merge join of two child relations.
@@ -34,7 +35,7 @@ case class SortMergeJoin(
rightKeys: Seq[Expression],
condition: Option[Expression],
left: SparkPlan,
- right: SparkPlan) extends BinaryNode {
+ right: SparkPlan) extends BinaryNode with CodegenSupport {
override private[sql] lazy val metrics = Map(
"numOutputRows" -> SQLMetrics.createLongMetric(sparkContext, "number of output rows"))
@@ -125,6 +126,246 @@ case class SortMergeJoin(
}.toScala
}
}
+
+ override def upstreams(): Seq[RDD[InternalRow]] = {
+ left.execute() :: right.execute() :: Nil
+ }
+
+ private def createJoinKey(
+ ctx: CodegenContext,
+ row: String,
+ keys: Seq[Expression],
+ input: Seq[Attribute]): Seq[ExprCode] = {
+ ctx.INPUT_ROW = row
+ keys.map(BindReferences.bindReference(_, input).gen(ctx))
+ }
+
+ private def copyKeys(ctx: CodegenContext, vars: Seq[ExprCode]): Seq[ExprCode] = {
+ vars.zipWithIndex.map { case (ev, i) =>
+ val value = ctx.freshName("value")
+ ctx.addMutableState(ctx.javaType(leftKeys(i).dataType), value, "")
+ val code =
+ s"""
+ |$value = ${ev.value};
+ """.stripMargin
+ ExprCode(code, "false", value)
+ }
+ }
+
+ private def genComparision(ctx: CodegenContext, a: Seq[ExprCode], b: Seq[ExprCode]): String = {
+ val comparisons = a.zip(b).zipWithIndex.map { case ((l, r), i) =>
+ s"""
+ |if (comp == 0) {
+ | comp = ${ctx.genComp(leftKeys(i).dataType, l.value, r.value)};
+ |}
+ """.stripMargin.trim
+ }
+ s"""
+ |comp = 0;
+ |${comparisons.mkString("\n")}
+ """.stripMargin
+ }
+
+ /**
+ * Generate a function to scan both left and right to find a match, returns the term for
+ * matched one row from left side and buffered rows from right side.
+ */
+ private def genScanner(ctx: CodegenContext): (String, String) = {
+ // Create class member for next row from both sides.
+ val leftRow = ctx.freshName("leftRow")
+ ctx.addMutableState("InternalRow", leftRow, "")
+ val rightRow = ctx.freshName("rightRow")
+ ctx.addMutableState("InternalRow", rightRow, s"$rightRow = null;")
+
+ // Create variables for join keys from both sides.
+ val leftKeyVars = createJoinKey(ctx, leftRow, leftKeys, left.output)
+ val leftAnyNull = leftKeyVars.map(_.isNull).mkString(" || ")
+ val rightKeyTmpVars = createJoinKey(ctx, rightRow, rightKeys, right.output)
+ val rightAnyNull = rightKeyTmpVars.map(_.isNull).mkString(" || ")
+ // Copy the right key as class members so they could be used in next function call.
+ val rightKeyVars = copyKeys(ctx, rightKeyTmpVars)
+
+ // A list to hold all matched rows from right side.
+ val matches = ctx.freshName("matches")
+ val clsName = classOf[java.util.ArrayList[InternalRow]].getName
+ ctx.addMutableState(clsName, matches, s"$matches = new $clsName();")
+ // Copy the left keys as class members so they could be used in next function call.
+ val matchedKeyVars = copyKeys(ctx, leftKeyVars)
+
+ ctx.addNewFunction("findNextInnerJoinRows",
+ s"""
+ |private boolean findNextInnerJoinRows(
+ | scala.collection.Iterator leftIter,
+ | scala.collection.Iterator rightIter) {
+ | $leftRow = null;
+ | int comp = 0;
+ | while ($leftRow == null) {
+ | if (!leftIter.hasNext()) return false;
+ | $leftRow = (InternalRow) leftIter.next();
+ | ${leftKeyVars.map(_.code).mkString("\n")}
+ | if ($leftAnyNull) {
+ | $leftRow = null;
+ | continue;
+ | }
+ | if (!$matches.isEmpty()) {
+ | ${genComparision(ctx, leftKeyVars, matchedKeyVars)}
+ | if (comp == 0) {
+ | return true;
+ | }
+ | $matches.clear();
+ | }
+ |
+ | do {
+ | if ($rightRow == null) {
+ | if (!rightIter.hasNext()) {
+ | ${matchedKeyVars.map(_.code).mkString("\n")}
+ | return !$matches.isEmpty();
+ | }
+ | $rightRow = (InternalRow) rightIter.next();
+ | ${rightKeyTmpVars.map(_.code).mkString("\n")}
+ | if ($rightAnyNull) {
+ | $rightRow = null;
+ | continue;
+ | }
+ | ${rightKeyVars.map(_.code).mkString("\n")}
+ | }
+ | ${genComparision(ctx, leftKeyVars, rightKeyVars)}
+ | if (comp > 0) {
+ | $rightRow = null;
+ | } else if (comp < 0) {
+ | if (!$matches.isEmpty()) {
+ | ${matchedKeyVars.map(_.code).mkString("\n")}
+ | return true;
+ | }
+ | $leftRow = null;
+ | } else {
+ | $matches.add($rightRow.copy());
+ | $rightRow = null;;
+ | }
+ | } while ($leftRow != null);
+ | }
+ | return false; // unreachable
+ |}
+ """.stripMargin)
+
+ (leftRow, matches)
+ }
+
+ /**
+ * Creates variables for left part of result row.
+ *
+ * In order to defer the access after condition and also only access once in the loop,
+ * the variables should be declared separately from accessing the columns, we can't use the
+ * codegen of BoundReference here.
+ */
+ private def createLeftVars(ctx: CodegenContext, leftRow: String): Seq[ExprCode] = {
+ ctx.INPUT_ROW = leftRow
+ left.output.zipWithIndex.map { case (a, i) =>
+ val value = ctx.freshName("value")
+ val valueCode = ctx.getValue(leftRow, a.dataType, i.toString)
+ // declare it as class member, so we can access the column before or in the loop.
+ ctx.addMutableState(ctx.javaType(a.dataType), value, "")
+ if (a.nullable) {
+ val isNull = ctx.freshName("isNull")
+ ctx.addMutableState("boolean", isNull, "")
+ val code =
+ s"""
+ |$isNull = $leftRow.isNullAt($i);
+ |$value = $isNull ? ${ctx.defaultValue(a.dataType)} : ($valueCode);
+ """.stripMargin
+ ExprCode(code, isNull, value)
+ } else {
+ ExprCode(s"$value = $valueCode;", "false", value)
+ }
+ }
+ }
+
+ /**
+ * Creates the variables for right part of result row, using BoundReference, since the right
+ * part are accessed inside the loop.
+ */
+ private def createRightVar(ctx: CodegenContext, rightRow: String): Seq[ExprCode] = {
+ ctx.INPUT_ROW = rightRow
+ right.output.zipWithIndex.map { case (a, i) =>
+ BoundReference(i, a.dataType, a.nullable).gen(ctx)
+ }
+ }
+
+ /**
+ * Splits variables based on whether it's used by condition or not, returns the code to create
+ * these variables before the condition and after the condition.
+ *
+ * Only a few columns are used by condition, then we can skip the accessing of those columns
+ * that are not used by condition also filtered out by condition.
+ */
+ private def splitVarsByCondition(
+ attributes: Seq[Attribute],
+ variables: Seq[ExprCode]): (String, String) = {
+ if (condition.isDefined) {
+ val condRefs = condition.get.references
+ val (used, notUsed) = attributes.zip(variables).partition{ case (a, ev) =>
+ condRefs.contains(a)
+ }
+ val beforeCond = used.map(_._2.code).mkString("\n")
+ val afterCond = notUsed.map(_._2.code).mkString("\n")
+ (beforeCond, afterCond)
+ } else {
+ (variables.map(_.code).mkString("\n"), "")
+ }
+ }
+
+ override def doProduce(ctx: CodegenContext): String = {
+ val leftInput = ctx.freshName("leftInput")
+ ctx.addMutableState("scala.collection.Iterator", leftInput, s"$leftInput = inputs[0];")
+ val rightInput = ctx.freshName("rightInput")
+ ctx.addMutableState("scala.collection.Iterator", rightInput, s"$rightInput = inputs[1];")
+
+ val (leftRow, matches) = genScanner(ctx)
+
+ // Create variables for row from both sides.
+ val leftVars = createLeftVars(ctx, leftRow)
+ val rightRow = ctx.freshName("rightRow")
+ val rightVars = createRightVar(ctx, rightRow)
+ val resultVars = leftVars ++ rightVars
+
+ // Check condition
+ ctx.currentVars = resultVars
+ val cond = if (condition.isDefined) {
+ BindReferences.bindReference(condition.get, output).gen(ctx)
+ } else {
+ ExprCode("", "false", "true")
+ }
+ // Split the code of creating variables based on whether it's used by condition or not.
+ val loaded = ctx.freshName("loaded")
+ val (leftBefore, leftAfter) = splitVarsByCondition(left.output, leftVars)
+ val (rightBefore, rightAfter) = splitVarsByCondition(right.output, rightVars)
+
+
+ val size = ctx.freshName("size")
+ val i = ctx.freshName("i")
+ val numOutput = metricTerm(ctx, "numOutputRows")
+ s"""
+ |while (findNextInnerJoinRows($leftInput, $rightInput)) {
+ | int $size = $matches.size();
+ | boolean $loaded = false;
+ | $leftBefore
+ | for (int $i = 0; $i < $size; $i ++) {
+ | InternalRow $rightRow = (InternalRow) $matches.get($i);
+ | $rightBefore
+ | ${cond.code}
+ | if (${cond.isNull} || !${cond.value}) continue;
+ | if (!$loaded) {
+ | $loaded = true;
+ | $leftAfter
+ | }
+ | $rightAfter
+ | $numOutput.add(1);
+ | ${consume(ctx, resultVars)}
+ | }
+ | if (shouldStop()) return;
+ |}
+ """.stripMargin
+ }
}
/**
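
The generated findNextInnerJoinRows/doProduce pair above implements the classic merge-join loop. Below is a plain-Scala sketch of the same algorithm on already-sorted inputs, with simplified types (key/value pairs instead of InternalRow) and without the null-key handling or deferred column loading of the generated code; the object and function names are illustrative only.

import scala.collection.mutable.ArrayBuffer

object MergeJoinSketch {
  // Inner merge join of two inputs already sorted by key, mirroring the generated
  // code's structure: advance the side with the smaller key, buffer all right rows
  // for the current key (the `matches` list above), then pair them with every left
  // row carrying the same key.
  def innerMergeJoin[K, L, R](left: IndexedSeq[(K, L)], right: IndexedSeq[(K, R)])
      (implicit ord: Ordering[K]): Seq[(L, R)] = {
    val out = ArrayBuffer.empty[(L, R)]
    var i = 0
    var j = 0
    while (i < left.length && j < right.length) {
      val c = ord.compare(left(i)._1, right(j)._1)
      if (c < 0) i += 1
      else if (c > 0) j += 1
      else {
        val key = left(i)._1
        // Buffer the run of right rows sharing this key.
        var jEnd = j
        while (jEnd < right.length && ord.equiv(right(jEnd)._1, key)) jEnd += 1
        // Emit the cross product with every left row of the same key.
        while (i < left.length && ord.equiv(left(i)._1, key)) {
          var k = j
          while (k < jEnd) { out += ((left(i)._2, right(k)._2)); k += 1 }
          i += 1
        }
        j = jEnd
      }
    }
    out.toList
  }

  def main(args: Array[String]): Unit = {
    val l = Vector(1 -> "a", 2 -> "b", 2 -> "c", 4 -> "d")
    val r = Vector(2 -> "x", 2 -> "y", 3 -> "z", 4 -> "w")
    println(innerMergeJoin(l, r))  // List((b,x), (b,y), (c,x), (c,y), (d,w))
  }
}

The generated version additionally skips rows whose join keys are null and, via splitVarsByCondition, defers loading columns the join condition does not reference until at least one buffered row has passed the condition.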
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/BenchmarkWholeStageCodegen.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/BenchmarkWholeStageCodegen.scala
index bcac660a35..6d6cc0186a 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/BenchmarkWholeStageCodegen.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/BenchmarkWholeStageCodegen.scala
@@ -38,6 +38,7 @@ import org.apache.spark.util.Benchmark
class BenchmarkWholeStageCodegen extends SparkFunSuite {
lazy val conf = new SparkConf().setMaster("local[1]").setAppName("benchmark")
.set("spark.sql.shuffle.partitions", "1")
+ .set("spark.sql.autoBroadcastJoinThreshold", "0")
lazy val sc = SparkContext.getOrCreate(conf)
lazy val sqlContext = SQLContext.getOrCreate(sc)
@@ -187,6 +188,39 @@ class BenchmarkWholeStageCodegen extends SparkFunSuite {
*/
}
+ ignore("sort merge join") {
+ val N = 2 << 20
+ runBenchmark("merge join", N) {
+ val df1 = sqlContext.range(N).selectExpr(s"id * 2 as k1")
+ val df2 = sqlContext.range(N).selectExpr(s"id * 3 as k2")
+ df1.join(df2, col("k1") === col("k2")).count()
+ }
+
+ /**
+ Intel(R) Core(TM) i7-4558U CPU @ 2.80GHz
+ merge join: Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative
+ -------------------------------------------------------------------------------------------
+ merge join codegen=false 1588 / 1880 1.3 757.1 1.0X
+ merge join codegen=true 1477 / 1531 1.4 704.2 1.1X
+ */
+
+ runBenchmark("sort merge join", N) {
+ val df1 = sqlContext.range(N)
+ .selectExpr(s"(id * 15485863) % ${N*10} as k1")
+ val df2 = sqlContext.range(N)
+ .selectExpr(s"(id * 15485867) % ${N*10} as k2")
+ df1.join(df2, col("k1") === col("k2")).count()
+ }
+
+ /**
+ Intel(R) Core(TM) i7-4558U CPU @ 2.80GHz
+ sort merge join: Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative
+ -------------------------------------------------------------------------------------------
+ sort merge join codegen=false 3626 / 3667 0.6 1728.9 1.0X
+ sort merge join codegen=true 3405 / 3438 0.6 1623.8 1.1X
+ */
+ }
+
ignore("rube") {
val N = 5 << 20
diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/sources/BucketedReadSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/sources/BucketedReadSuite.scala
index a05a57c0f5..0ef42f45e3 100644
--- a/sql/hive/src/test/scala/org/apache/spark/sql/sources/BucketedReadSuite.scala
+++ b/sql/hive/src/test/scala/org/apache/spark/sql/sources/BucketedReadSuite.scala
@@ -240,7 +240,8 @@ class BucketedReadSuite extends QueryTest with SQLTestUtils with TestHiveSinglet
withBucket(df1.write.format("parquet"), bucketSpecLeft).saveAsTable("bucketed_table1")
withBucket(df2.write.format("parquet"), bucketSpecRight).saveAsTable("bucketed_table2")
- withSQLConf(SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> "0") {
+ withSQLConf(SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> "0",
+ SQLConf.WHOLESTAGE_CODEGEN_ENABLED.key -> "false") {
val t1 = hiveContext.table("bucketed_table1")
val t2 = hiveContext.table("bucketed_table2")
val joined = t1.join(t2, joinCondition(t1, t2, joinColumns))
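
For experimenting with the two code paths outside this test, the same two settings can be toggled on a live session. A hedged sketch, reusing df1, df2, and col from the first example above; the string conf keys are assumptions corresponding to SQLConf.AUTO_BROADCASTJOIN_THRESHOLD and SQLConf.WHOLESTAGE_CODEGEN_ENABLED used in the test.

// Assumed conf keys; adjust if they differ in your build.
sqlContext.setConf("spark.sql.autoBroadcastJoinThreshold", "0")

sqlContext.setConf("spark.sql.codegen.wholeStage", "false")
df1.join(df2, col("k1") === col("k2")).explain()  // interpreted SortMergeJoin

sqlContext.setConf("spark.sql.codegen.wholeStage", "true")
df1.join(df2, col("k1") === col("k2")).explain()  // SortMergeJoin under whole-stage codegen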