author     Daoyuan <daoyuan.wang@intel.com>    2014-06-09 11:31:36 -0700
committer  Michael Armbrust <michael@databricks.com>    2014-06-09 11:31:36 -0700
commit     0cf600280167a94faec75736223256e8f2e48085 (patch)
tree       94d0bd6d1c02263ac0d3e9dbd9bd793e452944af /sql/core
parent     35630c86ff0e27862c9d902887eb0a24d25867ae (diff)
[SPARK-1495][SQL]add support for left semi join
Just submit another solution for #395

Author: Daoyuan <daoyuan.wang@intel.com>
Author: Michael Armbrust <michael@databricks.com>
Author: Daoyuan Wang <daoyuan.wang@intel.com>

Closes #837 from adrian-wang/left-semi-join-support and squashes the following commits:

d39cd12 [Daoyuan Wang] Merge pull request #1 from marmbrus/pr/837
6713c09 [Michael Armbrust] Better debugging for failed query tests.
035b73e [Michael Armbrust] Add test for left semi that can't be done with a hash join.
5ec6fa4 [Michael Armbrust] Add left semi to SQL Parser.
4c726e5 [Daoyuan] improvement according to Michael
8d4a121 [Daoyuan] add golden files for leftsemijoin
83a3c8a [Daoyuan] scala style fix
14cff80 [Daoyuan] add support for left semi join
Diffstat (limited to 'sql/core')
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala                    1
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala    16
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/execution/joins.scala             131
-rw-r--r--  sql/core/src/test/scala/org/apache/spark/sql/QueryTest.scala                     2
-rw-r--r--  sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala                 7
5 files changed, 156 insertions(+), 1 deletion(-)
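
For readers new to the feature: a left semi join returns each left-side row that has at least one match on the right, without duplicating left rows and without emitting right-side columns. This commit also adds LEFT SEMI JOIN to the SQL parser, so a minimal usage sketch looks like the following (the table and column names are illustrative, not part of this commit):

    // Hypothetical tables "people" and "orders"; only rows from "people"
    // with at least one matching order come back, each exactly once.
    val sqlContext = new org.apache.spark.sql.SQLContext(sc)
    val result = sqlContext.sql(
      "SELECT * FROM people LEFT SEMI JOIN orders ON people.id = orders.customerId")
    result.collect().foreach(println)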
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala b/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala
index 043be58edc..e371c82d81 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala
@@ -193,6 +193,7 @@ class SQLContext(@transient val sparkContext: SparkContext)
val strategies: Seq[Strategy] =
TakeOrdered ::
PartialAggregation ::
+ LeftSemiJoin ::
HashJoin ::
ParquetOperations ::
BasicOperators ::
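
Ordering matters in this list: Catalyst's QueryPlanner tries each strategy in turn and takes the first physical plan produced, so LeftSemiJoin is placed ahead of HashJoin to claim LeftSemi joins before the generic hash-join strategy runs. A simplified sketch of that dispatch, not the real QueryPlanner implementation:

    // Simplified sketch: each strategy returns Nil when it does not apply,
    // so position in the list decides who gets to plan a given node first.
    def planFirst(logical: LogicalPlan): SparkPlan =
      strategies.view.flatMap(strategy => strategy(logical)).head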
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala
index cfa8bdae58..6463f47510 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala
@@ -28,6 +28,22 @@ import org.apache.spark.sql.parquet._
private[sql] abstract class SparkStrategies extends QueryPlanner[SparkPlan] {
self: SQLContext#SparkPlanner =>
+ object LeftSemiJoin extends Strategy with PredicateHelper {
+ def apply(plan: LogicalPlan): Seq[SparkPlan] = plan match {
+ // Find left semi joins where at least some predicates can be evaluated by matching hash
+ // keys using the HashFilteredJoin pattern.
+ case HashFilteredJoin(LeftSemi, leftKeys, rightKeys, condition, left, right) =>
+ val semiJoin = execution.LeftSemiJoinHash(
+ leftKeys, rightKeys, planLater(left), planLater(right))
+ condition.map(Filter(_, semiJoin)).getOrElse(semiJoin) :: Nil
+ // No predicates can be evaluated by matching hash keys
+ case logical.Join(left, right, LeftSemi, condition) =>
+ execution.LeftSemiJoinBNL(
+ planLater(left), planLater(right), condition)(sparkContext) :: Nil
+ case _ => Nil
+ }
+ }
+
object HashJoin extends Strategy with PredicateHelper {
def apply(plan: LogicalPlan): Seq[SparkPlan] = plan match {
// Find inner joins where at least some predicates can be evaluated by matching hash keys
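
The two LeftSemiJoin cases above split on whether any equi-join predicates can be pulled out of the join condition. Roughly, with hypothetical tables a and b:

    // Extractable equi-join keys => LeftSemiJoinHash, with any residual
    // condition applied as a Filter on top of the semi join:
    //   SELECT * FROM a LEFT SEMI JOIN b ON a.key = b.key AND a.x > b.y
    //
    // No hash keys at all => LeftSemiJoinBNL broadcasts the right side and
    // evaluates the raw condition against every (left, right) pair:
    //   SELECT * FROM a LEFT SEMI JOIN b ON a.x > b.y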
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins.scala
index 31cc26962a..88ff3d49a7 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins.scala
@@ -142,6 +142,137 @@ case class HashJoin(
/**
* :: DeveloperApi ::
+ * Build the right table's join keys into a HashSet, then iterate over the left
+ * table to check whether each row's join keys appear in the hash set.
+ */
+@DeveloperApi
+case class LeftSemiJoinHash(
+ leftKeys: Seq[Expression],
+ rightKeys: Seq[Expression],
+ left: SparkPlan,
+ right: SparkPlan) extends BinaryNode {
+
+ override def outputPartitioning: Partitioning = left.outputPartitioning
+
+ override def requiredChildDistribution =
+ ClusteredDistribution(leftKeys) :: ClusteredDistribution(rightKeys) :: Nil
+
+ val (buildPlan, streamedPlan) = (right, left)
+ val (buildKeys, streamedKeys) = (rightKeys, leftKeys)
+
+ def output = left.output
+
+ @transient lazy val buildSideKeyGenerator = new Projection(buildKeys, buildPlan.output)
+ @transient lazy val streamSideKeyGenerator =
+ () => new MutableProjection(streamedKeys, streamedPlan.output)
+
+ def execute() = {
+
+ buildPlan.execute().zipPartitions(streamedPlan.execute()) { (buildIter, streamIter) =>
+ val hashTable = new java.util.HashSet[Row]()
+ var currentRow: Row = null
+
+ // Create a Hash set of buildKeys
+ while (buildIter.hasNext) {
+ currentRow = buildIter.next()
+ val rowKey = buildSideKeyGenerator(currentRow)
+ if (!rowKey.anyNull) {
+ val keyExists = hashTable.contains(rowKey)
+ if (!keyExists) {
+ hashTable.add(rowKey)
+ }
+ }
+ }
+
+ new Iterator[Row] {
+ private[this] var currentStreamedRow: Row = _
+ private[this] var currentHashMatched: Boolean = false
+
+ private[this] val joinKeys = streamSideKeyGenerator()
+
+ override final def hasNext: Boolean =
+ streamIter.hasNext && fetchNext()
+
+ override final def next() = {
+ currentStreamedRow
+ }
+
+ /**
+ * Searches the streamed iterator for the next row that has at least one match in the hash table.
+ *
+ * @return true if the search is successful, and false if the streamed iterator runs out of
+ * tuples.
+ */
+ private final def fetchNext(): Boolean = {
+ currentHashMatched = false
+ while (!currentHashMatched && streamIter.hasNext) {
+ currentStreamedRow = streamIter.next()
+ if (!joinKeys(currentStreamedRow).anyNull) {
+ currentHashMatched = hashTable.contains(joinKeys.currentValue)
+ }
+ }
+ currentHashMatched
+ }
+ }
+ }
+ }
+}
+
+/**
+ * :: DeveloperApi ::
+ * Uses BroadcastNestedLoopJoin to compute the left semi join result when there are no
+ * join keys suitable for a hash join.
+ */
+@DeveloperApi
+case class LeftSemiJoinBNL(
+ streamed: SparkPlan, broadcast: SparkPlan, condition: Option[Expression])
+ (@transient sc: SparkContext)
+ extends BinaryNode {
+ // TODO: Override requiredChildDistribution.
+
+ override def outputPartitioning: Partitioning = streamed.outputPartitioning
+
+ override def otherCopyArgs = sc :: Nil
+
+ def output = left.output
+
+ /** The streamed relation. */
+ def left = streamed
+ /** The broadcast relation. */
+ def right = broadcast
+
+ @transient lazy val boundCondition =
+ InterpretedPredicate(
+ condition
+ .map(c => BindReferences.bindReference(c, left.output ++ right.output))
+ .getOrElse(Literal(true)))
+
+
+ def execute() = {
+ val broadcastedRelation = sc.broadcast(broadcast.execute().map(_.copy()).collect().toIndexedSeq)
+
+ streamed.execute().mapPartitions { streamedIter =>
+ val joinedRow = new JoinedRow
+
+ streamedIter.filter(streamedRow => {
+ var i = 0
+ var matched = false
+
+ while (i < broadcastedRelation.value.size && !matched) {
+ val broadcastedRow = broadcastedRelation.value(i)
+ if (boundCondition(joinedRow(streamedRow, broadcastedRow))) {
+ matched = true
+ }
+ i += 1
+ }
+ matched
+ })
+ }
+ }
+}
+
+/**
+ * :: DeveloperApi ::
*/
@DeveloperApi
case class CartesianProduct(left: SparkPlan, right: SparkPlan) extends BinaryNode {
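
A self-contained sketch of the hash-set idea behind LeftSemiJoinHash above, on plain Scala collections (illustrative only; the real operator builds keys from Rows and skips rows whose join keys contain nulls):

    import scala.collection.mutable

    val left  = Seq((1, "a"), (2, "b"), (2, "c"), (3, "d"))
    val right = Seq(2, 3, 3, 5)

    // Build phase: distinct right-side keys go into a set, mirroring the
    // java.util.HashSet populated from buildKeys above.
    val buildSide: mutable.HashSet[Int] = mutable.HashSet(right: _*)

    // Stream phase: keep each left row whose key is in the set. Left rows
    // are never duplicated, however many right-side matches exist.
    val result = left.filter { case (key, _) => buildSide.contains(key) }
    // result == List((2, "b"), (2, "c"), (3, "d"))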
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/QueryTest.scala b/sql/core/src/test/scala/org/apache/spark/sql/QueryTest.scala
index d6072b402a..d7f6abaf5d 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/QueryTest.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/QueryTest.scala
@@ -44,7 +44,7 @@ class QueryTest extends FunSuite {
fail(
s"""
|Exception thrown while executing query:
- |${rdd.logicalPlan}
+ |${rdd.queryExecution}
|== Exception ==
|$e
""".stripMargin)
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala
index aa0c426f6f..d651b967a6 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala
@@ -40,6 +40,13 @@ class SQLQuerySuite extends QueryTest {
arrayData.map(d => (d.data, d.data(0), d.data(0) + d.data(1), d.data(1))).collect().toSeq)
}
+ test("left semi greater than predicate") {
+ checkAnswer(
+ sql("SELECT * FROM testData2 x LEFT SEMI JOIN testData2 y ON x.a >= y.a + 2"),
+ Seq((3,1), (3,2))
+ )
+ }
+
test("index into array of arrays") {
checkAnswer(
sql(
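
For reference, the expected answer Seq((3,1), (3,2)) follows from left-semi semantics, assuming testData2 holds the standard fixture rows (1,1), (1,2), (2,1), (2,2), (3,1), (3,2): the predicate x.a >= y.a + 2 can only hold when x.a = 3, matched against the y rows with a = 1. Because the condition has no equi-join component, this test exercises the LeftSemiJoinBNL path rather than the hash join. A quick check with plain collections:

    // Assuming the fixture rows above; LEFT SEMI JOIN keeps each x row
    // with at least one matching y row, without duplicating x rows.
    val testData2 = Seq((1, 1), (1, 2), (2, 1), (2, 2), (3, 1), (3, 2))
    val expected = testData2.filter { case (xa, _) =>
      testData2.exists { case (ya, _) => xa >= ya + 2 }
    }
    // expected == List((3, 1), (3, 2))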