diff options
author | Daoyuan <daoyuan.wang@intel.com> | 2014-06-09 11:31:36 -0700 |
---|---|---|
committer | Michael Armbrust <michael@databricks.com> | 2014-06-09 11:31:36 -0700 |
commit | 0cf600280167a94faec75736223256e8f2e48085 (patch) | |
tree | 94d0bd6d1c02263ac0d3e9dbd9bd793e452944af /sql/core | |
parent | 35630c86ff0e27862c9d902887eb0a24d25867ae (diff) | |
download | spark-0cf600280167a94faec75736223256e8f2e48085.tar.gz spark-0cf600280167a94faec75736223256e8f2e48085.tar.bz2 spark-0cf600280167a94faec75736223256e8f2e48085.zip |
[SPARK-1495][SQL]add support for left semi join
Just submit another solution for #395
Author: Daoyuan <daoyuan.wang@intel.com>
Author: Michael Armbrust <michael@databricks.com>
Author: Daoyuan Wang <daoyuan.wang@intel.com>
Closes #837 from adrian-wang/left-semi-join-support and squashes the following commits:
d39cd12 [Daoyuan Wang] Merge pull request #1 from marmbrus/pr/837
6713c09 [Michael Armbrust] Better debugging for failed query tests.
035b73e [Michael Armbrust] Add test for left semi that can't be done with a hash join.
5ec6fa4 [Michael Armbrust] Add left semi to SQL Parser.
4c726e5 [Daoyuan] improvement according to Michael
8d4a121 [Daoyuan] add golden files for leftsemijoin
83a3c8a [Daoyuan] scala style fix
14cff80 [Daoyuan] add support for left semi join
Diffstat (limited to 'sql/core')
5 files changed, 156 insertions, 1 deletions
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala b/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala index 043be58edc..e371c82d81 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala @@ -193,6 +193,7 @@ class SQLContext(@transient val sparkContext: SparkContext) val strategies: Seq[Strategy] = TakeOrdered :: PartialAggregation :: + LeftSemiJoin :: HashJoin :: ParquetOperations :: BasicOperators :: diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala index cfa8bdae58..6463f47510 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala @@ -28,6 +28,22 @@ import org.apache.spark.sql.parquet._ private[sql] abstract class SparkStrategies extends QueryPlanner[SparkPlan] { self: SQLContext#SparkPlanner => + object LeftSemiJoin extends Strategy with PredicateHelper { + def apply(plan: LogicalPlan): Seq[SparkPlan] = plan match { + // Find left semi joins where at least some predicates can be evaluated by matching hash + // keys using the HashFilteredJoin pattern. + case HashFilteredJoin(LeftSemi, leftKeys, rightKeys, condition, left, right) => + val semiJoin = execution.LeftSemiJoinHash( + leftKeys, rightKeys, planLater(left), planLater(right)) + condition.map(Filter(_, semiJoin)).getOrElse(semiJoin) :: Nil + // no predicate can be evaluated by matching hash keys + case logical.Join(left, right, LeftSemi, condition) => + execution.LeftSemiJoinBNL( + planLater(left), planLater(right), condition)(sparkContext) :: Nil + case _ => Nil + } + } + object HashJoin extends Strategy with PredicateHelper { def apply(plan: LogicalPlan): Seq[SparkPlan] = plan match { // Find inner joins where at least some predicates can be evaluated by matching hash keys diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins.scala index 31cc26962a..88ff3d49a7 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins.scala @@ -142,6 +142,137 @@ case class HashJoin( /** * :: DeveloperApi :: + * Build the right table's join keys into a HashSet, and iteratively go through the left + * table, to find the if join keys are in the Hash set. + */ +@DeveloperApi +case class LeftSemiJoinHash( + leftKeys: Seq[Expression], + rightKeys: Seq[Expression], + left: SparkPlan, + right: SparkPlan) extends BinaryNode { + + override def outputPartitioning: Partitioning = left.outputPartitioning + + override def requiredChildDistribution = + ClusteredDistribution(leftKeys) :: ClusteredDistribution(rightKeys) :: Nil + + val (buildPlan, streamedPlan) = (right, left) + val (buildKeys, streamedKeys) = (rightKeys, leftKeys) + + def output = left.output + + @transient lazy val buildSideKeyGenerator = new Projection(buildKeys, buildPlan.output) + @transient lazy val streamSideKeyGenerator = + () => new MutableProjection(streamedKeys, streamedPlan.output) + + def execute() = { + + buildPlan.execute().zipPartitions(streamedPlan.execute()) { (buildIter, streamIter) => + val hashTable = new java.util.HashSet[Row]() + var currentRow: Row = null + + // Create a Hash set of buildKeys + while (buildIter.hasNext) { + currentRow = buildIter.next() + val rowKey = buildSideKeyGenerator(currentRow) + if(!rowKey.anyNull) { + val keyExists = hashTable.contains(rowKey) + if (!keyExists) { + hashTable.add(rowKey) + } + } + } + + new Iterator[Row] { + private[this] var currentStreamedRow: Row = _ + private[this] var currentHashMatched: Boolean = false + + private[this] val joinKeys = streamSideKeyGenerator() + + override final def hasNext: Boolean = + streamIter.hasNext && fetchNext() + + override final def next() = { + currentStreamedRow + } + + /** + * Searches the streamed iterator for the next row that has at least one match in hashtable. + * + * @return true if the search is successful, and false the streamed iterator runs out of + * tuples. + */ + private final def fetchNext(): Boolean = { + currentHashMatched = false + while (!currentHashMatched && streamIter.hasNext) { + currentStreamedRow = streamIter.next() + if (!joinKeys(currentStreamedRow).anyNull) { + currentHashMatched = hashTable.contains(joinKeys.currentValue) + } + } + currentHashMatched + } + } + } + } +} + +/** + * :: DeveloperApi :: + * Using BroadcastNestedLoopJoin to calculate left semi join result when there's no join keys + * for hash join. + */ +@DeveloperApi +case class LeftSemiJoinBNL( + streamed: SparkPlan, broadcast: SparkPlan, condition: Option[Expression]) + (@transient sc: SparkContext) + extends BinaryNode { + // TODO: Override requiredChildDistribution. + + override def outputPartitioning: Partitioning = streamed.outputPartitioning + + override def otherCopyArgs = sc :: Nil + + def output = left.output + + /** The Streamed Relation */ + def left = streamed + /** The Broadcast relation */ + def right = broadcast + + @transient lazy val boundCondition = + InterpretedPredicate( + condition + .map(c => BindReferences.bindReference(c, left.output ++ right.output)) + .getOrElse(Literal(true))) + + + def execute() = { + val broadcastedRelation = sc.broadcast(broadcast.execute().map(_.copy()).collect().toIndexedSeq) + + streamed.execute().mapPartitions { streamedIter => + val joinedRow = new JoinedRow + + streamedIter.filter(streamedRow => { + var i = 0 + var matched = false + + while (i < broadcastedRelation.value.size && !matched) { + val broadcastedRow = broadcastedRelation.value(i) + if (boundCondition(joinedRow(streamedRow, broadcastedRow))) { + matched = true + } + i += 1 + } + matched + }) + } + } +} + +/** + * :: DeveloperApi :: */ @DeveloperApi case class CartesianProduct(left: SparkPlan, right: SparkPlan) extends BinaryNode { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/QueryTest.scala b/sql/core/src/test/scala/org/apache/spark/sql/QueryTest.scala index d6072b402a..d7f6abaf5d 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/QueryTest.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/QueryTest.scala @@ -44,7 +44,7 @@ class QueryTest extends FunSuite { fail( s""" |Exception thrown while executing query: - |${rdd.logicalPlan} + |${rdd.queryExecution} |== Exception == |$e """.stripMargin) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala index aa0c426f6f..d651b967a6 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala @@ -40,6 +40,13 @@ class SQLQuerySuite extends QueryTest { arrayData.map(d => (d.data, d.data(0), d.data(0) + d.data(1), d.data(1))).collect().toSeq) } + test("left semi greater than predicate") { + checkAnswer( + sql("SELECT * FROM testData2 x LEFT SEMI JOIN testData2 y ON x.a >= y.a + 2"), + Seq((3,1), (3,2)) + ) + } + test("index into array of arrays") { checkAnswer( sql( |