author     Davies Liu <davies@databricks.com>    2016-02-03 10:38:53 -0800
committer  Davies Liu <davies.liu@gmail.com>    2016-02-03 10:38:53 -0800
commit     c4feec26eb677bfd3bfac38e5e28eae05279956e (patch)
tree       fd0da335914c9cbdd6bff0787f902ff42b8dac72 /sql
parent     e9eb248edfa81d75f99c9afc2063e6b3d9ee7392 (diff)
[SPARK-12798] [SQL] generated BroadcastHashJoin
A row from the stream side can match multiple rows on the build side, so the loop over those matched rows must not be interrupted when emitting a row. Instead, the output rows are buffered in a linked list, and the termination condition is checked in the producer loop (for example, Range or Aggregate).

Author: Davies Liu <davies@databricks.com>

Closes #10989 from davies/gen_join.
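Why the buffer is needed, as a minimal sketch (the helpers `lookupMatches` and `joinRows` are illustrative assumptions, not part of this patch): a generated `processNext()` can emit several rows per stream row, so it appends them to `currentRows` and polls `shouldStop()` instead of returning after the first emitted row.

  // Sketch of the control flow this patch enables; lookupMatches and
  // joinRows are hypothetical helpers standing in for generated code.
  protected void processNext() throws IOException {
    while (input.hasNext()) {
      InternalRow streamRow = input.next();
      // one stream-side row can match many build-side rows
      for (InternalRow buildRow : lookupMatches(streamRow)) {
        currentRows.add(joinRows(streamRow, buildRow));
      }
      // termination is checked in the producer loop, never mid-match
      if (shouldStop()) return;
    }
  }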
Diffstat (limited to 'sql')
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/execution/BufferedRowIterator.java          | 30
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/execution/WholeStageCodegen.scala           | 18
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregate.scala |  4
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala              |  2
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastHashJoin.scala     | 92
-rw-r--r--  sql/core/src/test/scala/org/apache/spark/sql/execution/BenchmarkWholeStageCodegen.scala  | 28
-rw-r--r--  sql/core/src/test/scala/org/apache/spark/sql/execution/WholeStageCodegenSuite.scala      | 15
7 files changed, 169 insertions, 20 deletions
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/BufferedRowIterator.java b/sql/core/src/main/scala/org/apache/spark/sql/execution/BufferedRowIterator.java
index 6acf70dbba..ea20115770 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/BufferedRowIterator.java
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/BufferedRowIterator.java
@@ -18,9 +18,11 @@
package org.apache.spark.sql.execution;
import java.io.IOException;
+import java.util.LinkedList;
import scala.collection.Iterator;
+import org.apache.spark.TaskContext;
import org.apache.spark.sql.catalyst.InternalRow;
import org.apache.spark.sql.catalyst.expressions.UnsafeRow;
@@ -31,22 +33,20 @@ import org.apache.spark.sql.catalyst.expressions.UnsafeRow;
* TODO: replace it with a batched columnar format.
*/
public class BufferedRowIterator {
- protected InternalRow currentRow;
+ protected LinkedList<InternalRow> currentRows = new LinkedList<>();
protected Iterator<InternalRow> input;
// used when there are no columns in the output
protected UnsafeRow unsafeRow = new UnsafeRow(0);
public boolean hasNext() throws IOException {
- if (currentRow == null) {
+ if (currentRows.isEmpty()) {
processNext();
}
- return currentRow != null;
+ return !currentRows.isEmpty();
}
public InternalRow next() {
- InternalRow r = currentRow;
- currentRow = null;
- return r;
+ return currentRows.remove();
}
public void setInput(Iterator<InternalRow> iter) {
@@ -54,13 +54,29 @@ public class BufferedRowIterator {
}
/**
+ * Returns whether `processNext()` should stop processing the next row from `input`.
+ *
+ * If it returns true, the caller should exit the loop (return from processNext()).
+ */
+ protected boolean shouldStop() {
+ return !currentRows.isEmpty();
+ }
+
+ /**
+ * Increases the peak execution memory for the current task.
+ */
+ protected void incPeakExecutionMemory(long size) {
+ TaskContext.get().taskMetrics().incPeakExecutionMemory(size);
+ }
+
+ /**
* Processes the input until there is a row as output (in currentRows).
*
* After it's called, if currentRows is still empty, it means no more rows are left.
*/
protected void processNext() throws IOException {
if (input.hasNext()) {
- currentRow = input.next();
+ currentRows.add(input.next());
}
}
}
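For context, a generated subclass of BufferedRowIterator is consumed as an ordinary row iterator; a minimal usage sketch, assuming a generated subclass instance and an input iterator (IOException propagation elided):

  // hasNext() refills currentRows via processNext() when the buffer is
  // empty; next() pops the head of currentRows.
  BufferedRowIterator iter = newGeneratedIterator(); // assumed factory for the generated subclass
  iter.setInput(childRows);                          // assumed scala.collection.Iterator<InternalRow>
  while (iter.hasNext()) {
    InternalRow row = iter.next();
    process(row);                                    // hypothetical consumer
  }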
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/WholeStageCodegen.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/WholeStageCodegen.scala
index 1475496907..131efea20f 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/WholeStageCodegen.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/WholeStageCodegen.scala
@@ -27,6 +27,7 @@ import org.apache.spark.sql.catalyst.expressions.codegen._
import org.apache.spark.sql.catalyst.plans.physical.Partitioning
import org.apache.spark.sql.catalyst.rules.Rule
import org.apache.spark.sql.execution.aggregate.TungstenAggregate
+import org.apache.spark.sql.execution.joins.{BroadcastHashJoin, BuildLeft, BuildRight}
import org.apache.spark.util.Utils
/**
@@ -172,6 +173,9 @@ case class InputAdapter(child: SparkPlan) extends LeafNode with CodegenSupport {
| InternalRow $row = (InternalRow) input.next();
| ${columns.map(_.code).mkString("\n").trim}
| ${consume(ctx, columns).trim}
+ | if (shouldStop()) {
+ | return;
+ | }
| }
""".stripMargin
}
@@ -283,8 +287,7 @@ case class WholeStageCodegen(plan: CodegenSupport, children: Seq[SparkPlan])
if (row != null) {
// There is an UnsafeRow already
s"""
- | currentRow = $row;
- | return;
+ | currentRows.add($row.copy());
""".stripMargin
} else {
assert(input != null)
@@ -297,14 +300,12 @@ case class WholeStageCodegen(plan: CodegenSupport, children: Seq[SparkPlan])
val code = GenerateUnsafeProjection.createCode(ctx, colExprs, false)
s"""
| ${code.code.trim}
- | currentRow = ${code.value};
- | return;
+ | currentRows.add(${code.value}.copy());
""".stripMargin
} else {
// There are no columns
s"""
- | currentRow = unsafeRow;
- | return;
+ | currentRows.add(unsafeRow);
""".stripMargin
}
}
@@ -371,6 +372,11 @@ private[sql] case class CollapseCodegenStages(sqlContext: SQLContext) extends Ru
var inputs = ArrayBuffer[SparkPlan]()
val combined = plan.transform {
+ // The build side can't be compiled into the same stage as the streamed side; collapse it separately
+ case b @ BroadcastHashJoin(_, _, BuildLeft, _, left, right) =>
+ b.copy(left = apply(left))
+ case b @ BroadcastHashJoin(_, _, BuildRight, _, left, right) =>
+ b.copy(right = apply(right))
case p if !supportCodegen(p) =>
val input = apply(p) // collapse them recursively
inputs += input
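The effect of the new rule, sketched on a hypothetical BuildRight plan (illustrative shape, not actual explain output): the streamed side is fused with the join into the enclosing stage, while the build side is collapsed recursively into its own stage.

  WholeStageCodegen
  +- BroadcastHashJoin (BuildRight)
     :- ...streamed-side plan...     <- compiled into this stage
     +- WholeStageCodegen            <- build side, collapsed separately
        +- ...build-side plan...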
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregate.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregate.scala
index d024477061..9d9f14f2dd 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregate.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregate.scala
@@ -471,6 +471,8 @@ case class TungstenAggregate(
UnsafeRow $keyTerm = (UnsafeRow) $iterTerm.getKey();
UnsafeRow $bufferTerm = (UnsafeRow) $iterTerm.getValue();
$outputCode
+
+ if (shouldStop()) return;
}
$iterTerm.close();
@@ -480,7 +482,7 @@ case class TungstenAggregate(
"""
}
- private def doConsumeWithKeys( ctx: CodegenContext, input: Seq[ExprCode]): String = {
+ private def doConsumeWithKeys(ctx: CodegenContext, input: Seq[ExprCode]): String = {
// create grouping key
ctx.currentVars = input
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala
index ae4422195c..6e51c4d848 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala
@@ -237,6 +237,8 @@ case class Range(
| $overflow = true;
| }
| ${consume(ctx, Seq(ev))}
+ |
+ | if (shouldStop()) return;
| }
""".stripMargin
}
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastHashJoin.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastHashJoin.scala
index 04640711d9..8b275e886c 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastHashJoin.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastHashJoin.scala
@@ -20,14 +20,17 @@ package org.apache.spark.sql.execution.joins
import scala.concurrent._
import scala.concurrent.duration._
-import org.apache.spark.{InternalAccumulator, TaskContext}
+import org.apache.spark.TaskContext
+import org.apache.spark.broadcast.Broadcast
import org.apache.spark.rdd.RDD
import org.apache.spark.sql.catalyst.InternalRow
-import org.apache.spark.sql.catalyst.expressions.Expression
+import org.apache.spark.sql.catalyst.expressions.{BindReferences, BoundReference, Expression, UnsafeRow}
+import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, ExprCode, GenerateUnsafeProjection}
import org.apache.spark.sql.catalyst.plans.physical.{Distribution, Partitioning, UnspecifiedDistribution}
-import org.apache.spark.sql.execution.{BinaryNode, SparkPlan, SQLExecution}
+import org.apache.spark.sql.execution.{BinaryNode, CodegenSupport, SparkPlan, SQLExecution}
import org.apache.spark.sql.execution.metric.SQLMetrics
import org.apache.spark.util.ThreadUtils
+import org.apache.spark.util.collection.CompactBuffer
/**
* Performs an inner hash join of two child relations. When the output RDD of this operator is
@@ -42,7 +45,7 @@ case class BroadcastHashJoin(
condition: Option[Expression],
left: SparkPlan,
right: SparkPlan)
- extends BinaryNode with HashJoin {
+ extends BinaryNode with HashJoin with CodegenSupport {
override private[sql] lazy val metrics = Map(
"numLeftRows" -> SQLMetrics.createLongMetric(sparkContext, "number of left rows"),
@@ -117,6 +120,87 @@ case class BroadcastHashJoin(
hashJoin(streamedIter, numStreamedRows, hashedRelation, numOutputRows)
}
}
+
+ // the name of the generated variable holding the broadcast HashedRelation
+ private var relationTerm: String = _
+
+ override def upstream(): RDD[InternalRow] = {
+ streamedPlan.asInstanceOf[CodegenSupport].upstream()
+ }
+
+ override def doProduce(ctx: CodegenContext): String = {
+ // create a fresh variable name for the HashedRelation
+ val broadcastRelation = Await.result(broadcastFuture, timeout)
+ val broadcast = ctx.addReferenceObj("broadcast", broadcastRelation)
+ relationTerm = ctx.freshName("relation")
+ // TODO: create specialized HashRelation for single join key
+ val clsName = classOf[UnsafeHashedRelation].getName
+ ctx.addMutableState(clsName, relationTerm,
+ s"""
+ | $relationTerm = ($clsName) $broadcast.value();
+ | incPeakExecutionMemory($relationTerm.getUnsafeSize());
+ """.stripMargin)
+
+ s"""
+ | ${streamedPlan.asInstanceOf[CodegenSupport].produce(ctx, this)}
+ """.stripMargin
+ }
+
+ override def doConsume(ctx: CodegenContext, input: Seq[ExprCode]): String = {
+ // generate the key as UnsafeRow
+ ctx.currentVars = input
+ val keyExpr = streamedKeys.map(BindReferences.bindReference(_, streamedPlan.output))
+ val keyVal = GenerateUnsafeProjection.createCode(ctx, keyExpr)
+ val keyTerm = keyVal.value
+ val anyNull = if (keyExpr.exists(_.nullable)) s"$keyTerm.anyNull()" else "false"
+
+ // find the matches from HashedRelation
+ val matches = ctx.freshName("matches")
+ val bufferType = classOf[CompactBuffer[UnsafeRow]].getName
+ val i = ctx.freshName("i")
+ val size = ctx.freshName("size")
+ val row = ctx.freshName("row")
+
+ // create variables for output
+ ctx.currentVars = null
+ ctx.INPUT_ROW = row
+ val buildColumns = buildPlan.output.zipWithIndex.map { case (a, i) =>
+ BoundReference(i, a.dataType, a.nullable).gen(ctx)
+ }
+ val resultVars = buildSide match {
+ case BuildLeft => buildColumns ++ input
+ case BuildRight => input ++ buildColumns
+ }
+
+ val outputCode = if (condition.isDefined) {
+ // filter the output via condition
+ ctx.currentVars = resultVars
+ val ev = BindReferences.bindReference(condition.get, this.output).gen(ctx)
+ s"""
+ | ${ev.code}
+ | if (!${ev.isNull} && ${ev.value}) {
+ | ${consume(ctx, resultVars)}
+ | }
+ """.stripMargin
+ } else {
+ consume(ctx, resultVars)
+ }
+
+ s"""
+ | // generate join key
+ | ${keyVal.code}
+ | // find matches from the HashedRelation
+ | $bufferType $matches = $anyNull ? null : ($bufferType) $relationTerm.get($keyTerm);
+ | if ($matches != null) {
+ | int $size = $matches.size();
+ | for (int $i = 0; $i < $size; $i++) {
+ | UnsafeRow $row = (UnsafeRow) $matches.apply($i);
+ | ${buildColumns.map(_.code).mkString("\n")}
+ | $outputCode
+ | }
+ | }
+ """.stripMargin
+ }
}
object BroadcastHashJoin {
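To make the doConsume template above concrete, here is a hedged sketch of the kind of Java it emits for one stream row (real codegen fresh-names every identifier; these names, and the keyProjection helper, are illustrative):

  // generate the join key for the current stream row
  UnsafeRow key = keyProjection.apply(streamRow);  // hypothetical stand-in for GenerateUnsafeProjection output
  // probe the broadcast relation; a key with any null column never matches
  CompactBuffer matches = key.anyNull() ? null : (CompactBuffer) relation.get(key);
  if (matches != null) {
    int size = matches.size();
    for (int i = 0; i < size; i++) {
      UnsafeRow buildRow = (UnsafeRow) matches.apply(i);
      // evaluate the build-side columns, apply the optional join
      // condition, then consume() appends the joined row to currentRows
    }
  }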
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/BenchmarkWholeStageCodegen.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/BenchmarkWholeStageCodegen.scala
index ec2b9ab2cb..15ba773531 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/BenchmarkWholeStageCodegen.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/BenchmarkWholeStageCodegen.scala
@@ -21,6 +21,7 @@ import org.apache.spark.{SparkConf, SparkContext, SparkFunSuite}
import org.apache.spark.memory.{StaticMemoryManager, TaskMemoryManager}
import org.apache.spark.sql.SQLContext
import org.apache.spark.sql.catalyst.expressions.UnsafeRow
+import org.apache.spark.sql.functions._
import org.apache.spark.unsafe.Platform
import org.apache.spark.unsafe.hash.Murmur3_x86_32
import org.apache.spark.unsafe.map.BytesToBytesMap
@@ -130,6 +131,30 @@ class BenchmarkWholeStageCodegen extends SparkFunSuite {
benchmark.run()
}
+ def testBroadcastHashJoin(values: Int): Unit = {
+ val benchmark = new Benchmark("BroadcastHashJoin", values)
+
+ val dim = broadcast(sqlContext.range(1 << 16).selectExpr("id as k", "cast(id as string) as v"))
+
+ benchmark.addCase("BroadcastHashJoin w/o codegen") { iter =>
+ sqlContext.setConf("spark.sql.codegen.wholeStage", "false")
+ sqlContext.range(values).join(dim, (col("id") % 60000) === col("k")).count()
+ }
+ benchmark.addCase(s"BroadcastHashJoin w codegen") { iter =>
+ sqlContext.setConf("spark.sql.codegen.wholeStage", "true")
+ sqlContext.range(values).join(dim, (col("id") % 60000) === col("k")).count()
+ }
+
+ /*
+ Intel(R) Core(TM) i7-4558U CPU @ 2.80GHz
+ BroadcastHashJoin: Avg Time(ms) Avg Rate(M/s) Relative Rate
+ -------------------------------------------------------------------------------
+ BroadcastHashJoin w/o codegen 3053.41 3.43 1.00 X
+ BroadcastHashJoin w codegen 1028.40 10.20 2.97 X
+ */
+ benchmark.run()
+ }
+
def testBytesToBytesMap(values: Int): Unit = {
val benchmark = new Benchmark("BytesToBytesMap", values)
@@ -201,6 +226,7 @@ class BenchmarkWholeStageCodegen extends SparkFunSuite {
// testWholeStage(200 << 20)
// testStatFunctions(20 << 20)
// testAggregateWithKey(20 << 20)
- // testBytesToBytesMap(1024 * 1024 * 50)
+ // testBytesToBytesMap(50 << 20)
+ // testBroadcastHashJoin(10 << 20)
}
}
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/WholeStageCodegenSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/WholeStageCodegenSuite.scala
index c2516509df..9350205d79 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/WholeStageCodegenSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/WholeStageCodegenSuite.scala
@@ -20,8 +20,10 @@ package org.apache.spark.sql.execution
import org.apache.spark.sql.Row
import org.apache.spark.sql.catalyst.dsl.expressions._
import org.apache.spark.sql.execution.aggregate.TungstenAggregate
-import org.apache.spark.sql.functions.{avg, col, max}
+import org.apache.spark.sql.execution.joins.BroadcastHashJoin
+import org.apache.spark.sql.functions.{avg, broadcast, col, max}
import org.apache.spark.sql.test.SharedSQLContext
+import org.apache.spark.sql.types.{IntegerType, StringType, StructType}
class WholeStageCodegenSuite extends SparkPlanTest with SharedSQLContext {
@@ -56,4 +58,15 @@ class WholeStageCodegenSuite extends SparkPlanTest with SharedSQLContext {
p.asInstanceOf[WholeStageCodegen].plan.isInstanceOf[TungstenAggregate]).isDefined)
assert(df.collect() === Array(Row(0, 1), Row(1, 1), Row(2, 1)))
}
+
+ test("BroadcastHashJoin should be included in WholeStageCodegen") {
+ val rdd = sqlContext.sparkContext.makeRDD(Seq(Row(1, "1"), Row(1, "1"), Row(2, "2")))
+ val schema = new StructType().add("k", IntegerType).add("v", StringType)
+ val smallDF = sqlContext.createDataFrame(rdd, schema)
+ val df = sqlContext.range(10).join(broadcast(smallDF), col("k") === col("id"))
+ assert(df.queryExecution.executedPlan.find(p =>
+ p.isInstanceOf[WholeStageCodegen] &&
+ p.asInstanceOf[WholeStageCodegen].plan.isInstanceOf[BroadcastHashJoin]).isDefined)
+ assert(df.collect() === Array(Row(1, 1, "1"), Row(1, 1, "1"), Row(2, 2, "2")))
+ }
}