author     Daoyuan Wang <daoyuan.wang@intel.com>    2015-07-17 16:45:46 -0700
committer  Michael Armbrust <michael@databricks.com>    2015-07-17 16:45:46 -0700
commit     1707238601690fd0e8e173e2c47f1b4286644a29 (patch)
tree       5a17ff1d0dd82694aa8e52f32fa21d237a738032
parent     b13ef7723f254c10c685b93eb8dc08a52527ec73 (diff)
[SPARK-7026] [SQL] fix left semi join with equi key and non-equi condition
When the `condition` extracted by `ExtractEquiJoinKeys` contains a join predicate for a left semi join, we cannot plan it as a semi join with a filter on top. For example:

    SELECT * FROM testData2 x
    LEFT SEMI JOIN testData2 y
    ON x.b = y.b AND x.a >= y.a + 2

The condition `x.a >= y.a + 2` cannot be evaluated against table `x` alone (a semi join's output carries only `x`'s columns), so planning it as a filter above the join throws errors.

Author: Daoyuan Wang <daoyuan.wang@intel.com>

Closes #5643 from adrian-wang/spark7026 and squashes the following commits:

cc09809 [Daoyuan Wang] refactor semijoin and add plan test
575a7c8 [Daoyuan Wang] fix notserializable
27841de [Daoyuan Wang] fix rebase
10bf124 [Daoyuan Wang] fix style
72baa02 [Daoyuan Wang] fix style
8e0afca [Daoyuan Wang] merge commits for rebase
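For context, a minimal sketch of the intended semantics over plain Scala collections (hypothetical data shaped like the testData2 fixture; not the planner's actual code path): the non-equi predicate has to be evaluated while the matching `y` row is still in scope, because a Filter placed above the semi join would have no `y.a` left to read.

    // Plain-Scala model of the left semi join from the example query.
    object SemiJoinSketch extends App {
      val x = Seq((1, 1), (1, 2), (2, 1), (2, 2), (3, 1), (3, 2)) // (a, b)
      val y = x

      // Evaluate the non-equi condition inside the join, while the
      // candidate y row is still available.
      val result = x.filter { case (xa, xb) =>
        y.exists { case (ya, yb) => xb == yb && xa >= ya + 2 }
      }
      println(result) // List((3,1), (3,2))
    }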
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala                 | 10
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastLeftSemiJoinHash.scala | 42
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashOuterJoin.scala             |  3
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashSemiJoin.scala              | 91
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/execution/joins/LeftSemiJoinHash.scala          | 35
-rw-r--r--  sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala                             | 12
-rw-r--r--  sql/core/src/test/scala/org/apache/spark/sql/execution/joins/SemiJoinSuite.scala             | 74
7 files changed, 208 insertions(+), 59 deletions(-)
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala
index 73b463471e..240332a80a 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala
@@ -38,14 +38,12 @@ private[sql] abstract class SparkStrategies extends QueryPlanner[SparkPlan] {
case ExtractEquiJoinKeys(LeftSemi, leftKeys, rightKeys, condition, left, right)
if sqlContext.conf.autoBroadcastJoinThreshold > 0 &&
right.statistics.sizeInBytes <= sqlContext.conf.autoBroadcastJoinThreshold =>
- val semiJoin = joins.BroadcastLeftSemiJoinHash(
- leftKeys, rightKeys, planLater(left), planLater(right))
- condition.map(Filter(_, semiJoin)).getOrElse(semiJoin) :: Nil
+ joins.BroadcastLeftSemiJoinHash(
+ leftKeys, rightKeys, planLater(left), planLater(right), condition) :: Nil
// Find left semi joins where at least some predicates can be evaluated by matching join keys
case ExtractEquiJoinKeys(LeftSemi, leftKeys, rightKeys, condition, left, right) =>
- val semiJoin = joins.LeftSemiJoinHash(
- leftKeys, rightKeys, planLater(left), planLater(right))
- condition.map(Filter(_, semiJoin)).getOrElse(semiJoin) :: Nil
+ joins.LeftSemiJoinHash(
+ leftKeys, rightKeys, planLater(left), planLater(right), condition) :: Nil
// no predicate can be evaluated by matching hash keys
case logical.Join(left, right, LeftSemi, condition) =>
joins.LeftSemiJoinBNL(planLater(left), planLater(right), condition) :: Nil
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastLeftSemiJoinHash.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastLeftSemiJoinHash.scala
index f7b46d6888..2750f58b00 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastLeftSemiJoinHash.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastLeftSemiJoinHash.scala
@@ -33,37 +33,27 @@ case class BroadcastLeftSemiJoinHash(
leftKeys: Seq[Expression],
rightKeys: Seq[Expression],
left: SparkPlan,
- right: SparkPlan) extends BinaryNode with HashJoin {
-
- override val buildSide: BuildSide = BuildRight
-
- override def output: Seq[Attribute] = left.output
+ right: SparkPlan,
+ condition: Option[Expression]) extends BinaryNode with HashSemiJoin {
protected override def doExecute(): RDD[InternalRow] = {
- val buildIter = buildPlan.execute().map(_.copy()).collect().toIterator
- val hashSet = new java.util.HashSet[InternalRow]()
- var currentRow: InternalRow = null
+ val buildIter = right.execute().map(_.copy()).collect().toIterator
- // Create a Hash set of buildKeys
- while (buildIter.hasNext) {
- currentRow = buildIter.next()
- val rowKey = buildSideKeyGenerator(currentRow)
- if (!rowKey.anyNull) {
- val keyExists = hashSet.contains(rowKey)
- if (!keyExists) {
- // rowKey may be not serializable (from codegen)
- hashSet.add(rowKey.copy())
- }
- }
- }
+ if (condition.isEmpty) {
+ // rowKey may be not serializable (from codegen)
+ val hashSet = buildKeyHashSet(buildIter, copy = true)
+ val broadcastedRelation = sparkContext.broadcast(hashSet)
- val broadcastedRelation = sparkContext.broadcast(hashSet)
+ left.execute().mapPartitions { streamIter =>
+ hashSemiJoin(streamIter, broadcastedRelation.value)
+ }
+ } else {
+ val hashRelation = HashedRelation(buildIter, rightKeyGenerator)
+ val broadcastedRelation = sparkContext.broadcast(hashRelation)
- streamedPlan.execute().mapPartitions { streamIter =>
- val joinKeys = streamSideKeyGenerator()
- streamIter.filter(current => {
- !joinKeys(current).anyNull && broadcastedRelation.value.contains(joinKeys.currentValue)
- })
+ left.execute().mapPartitions { streamIter =>
+ hashSemiJoin(streamIter, broadcastedRelation.value)
+ }
}
}
}
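The rewritten operator chooses between two broadcast payloads: with no extra condition, a de-duplicated set of build-side keys is enough for membership tests; with a condition, the full build rows must survive (the HashedRelation, a key-to-rows map) so the predicate can see both sides. A hedged plain-Scala sketch of the distinction, with hypothetical rows:

    // Two build-side shapes, modeled with ordinary collections.
    val build = Seq((2, 3.0), (3, 2.0), (4, 1.0))  // (key, payload)

    // No condition: key membership alone decides the match.
    val keySet: Set[Int] = build.map(_._1).toSet

    // With a condition: keep whole rows grouped by key, so each
    // candidate build row can be checked against the stream row.
    val byKey: Map[Int, Seq[(Int, Double)]] = build.groupBy(_._1)

    val stream = Seq((2, 1.0), (3, 3.0))
    val cond = (s: (Int, Double), b: (Int, Double)) => s._2 < b._2
    val matched = stream.filter { s =>
      byKey.get(s._1).exists(_.exists(b => cond(s, b)))
    }
    // matched == List((2,1.0)): 1.0 < 3.0 holds; (3,3.0) fails since 3.0 < 2.0 is false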
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashOuterJoin.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashOuterJoin.scala
index 0522ee85ee..74a7db7761 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashOuterJoin.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashOuterJoin.scala
@@ -65,8 +65,7 @@ override def outputPartitioning: Partitioning = joinType match {
@transient private[this] lazy val leftNullRow = new GenericInternalRow(left.output.length)
@transient private[this] lazy val rightNullRow = new GenericInternalRow(right.output.length)
@transient private[this] lazy val boundCondition =
- condition.map(
- newPredicate(_, left.output ++ right.output)).getOrElse((row: InternalRow) => true)
+ newPredicate(condition.getOrElse(Literal(true)), left.output ++ right.output)
// TODO we need to rewrite all of the iterators with our own implementation instead of the Scala
// iterator for performance purpose.
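The boundCondition change here is the same defaulting trick the new HashSemiJoin trait uses below: an absent condition collapses to a constant-true predicate, so the join code always evaluates exactly one predicate instead of branching on the Option. A plain-Scala sketch of the pattern (names are illustrative, not Spark's):

    // An absent condition becomes a predicate that accepts every row.
    def bound[A](condition: Option[A => Boolean]): A => Boolean =
      condition.getOrElse((_: A) => true)

    val keepAll  = bound[Int](None)
    val keepEven = bound[Int](Some((n: Int) => n % 2 == 0))
    // keepAll(3) == true; keepEven(3) == false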
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashSemiJoin.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashSemiJoin.scala
new file mode 100644
index 0000000000..1b983bc3a9
--- /dev/null
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashSemiJoin.scala
@@ -0,0 +1,91 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.execution.joins
+
+import org.apache.spark.sql.catalyst.InternalRow
+import org.apache.spark.sql.catalyst.expressions._
+import org.apache.spark.sql.execution.SparkPlan
+
+
+trait HashSemiJoin {
+ self: SparkPlan =>
+ val leftKeys: Seq[Expression]
+ val rightKeys: Seq[Expression]
+ val left: SparkPlan
+ val right: SparkPlan
+ val condition: Option[Expression]
+
+ override def output: Seq[Attribute] = left.output
+
+ @transient protected lazy val rightKeyGenerator: Projection =
+ newProjection(rightKeys, right.output)
+
+ @transient protected lazy val leftKeyGenerator: () => MutableProjection =
+ newMutableProjection(leftKeys, left.output)
+
+ @transient private lazy val boundCondition =
+ newPredicate(condition.getOrElse(Literal(true)), left.output ++ right.output)
+
+ protected def buildKeyHashSet(
+ buildIter: Iterator[InternalRow],
+ copy: Boolean): java.util.Set[InternalRow] = {
+ val hashSet = new java.util.HashSet[InternalRow]()
+ var currentRow: InternalRow = null
+
+ // Create a Hash set of buildKeys
+ while (buildIter.hasNext) {
+ currentRow = buildIter.next()
+ val rowKey = rightKeyGenerator(currentRow)
+ if (!rowKey.anyNull) {
+ val keyExists = hashSet.contains(rowKey)
+ if (!keyExists) {
+ if (copy) {
+ hashSet.add(rowKey.copy())
+ } else {
+ // rowKey may be not serializable (from codegen)
+ hashSet.add(rowKey)
+ }
+ }
+ }
+ }
+ hashSet
+ }
+
+ protected def hashSemiJoin(
+ streamIter: Iterator[InternalRow],
+ hashedRelation: HashedRelation): Iterator[InternalRow] = {
+ val joinKeys = leftKeyGenerator()
+ val joinedRow = new JoinedRow
+ streamIter.filter(current => {
+ lazy val rowBuffer = hashedRelation.get(joinKeys.currentValue)
+ !joinKeys(current).anyNull && rowBuffer != null && rowBuffer.exists {
+ (build: InternalRow) => boundCondition(joinedRow(current, build))
+ }
+ })
+ }
+
+ protected def hashSemiJoin(
+ streamIter: Iterator[InternalRow],
+ hashSet: java.util.Set[InternalRow]): Iterator[InternalRow] = {
+ val joinKeys = leftKeyGenerator()
+ val joinedRow = new JoinedRow
+ streamIter.filter(current => {
+ !joinKeys(current.copy()).anyNull && hashSet.contains(joinKeys.currentValue)
+ })
+ }
+}
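One subtlety shared by both hashSemiJoin overloads: rows whose join keys contain a null never match, on either side (the anyNull guards), which mirrors SQL semantics where NULL = NULL does not evaluate to TRUE. A small sketch of that behavior with hypothetical nullable keys:

    // Null keys are dropped at build time and rejected at probe time.
    val buildKeys: Set[Int] = Seq(Some(1), None, Some(2)).flatten.toSet  // Set(1, 2)
    val streamKeys = Seq(Some(1), None, Some(3))
    val kept = streamKeys.filter {
      case Some(k) => buildKeys.contains(k)  // non-null: plain membership test
      case None    => false                  // null key: never matches
    }
    // kept == List(Some(1))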
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/LeftSemiJoinHash.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/LeftSemiJoinHash.scala
index 611ba928a1..9eaac817d9 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/LeftSemiJoinHash.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/LeftSemiJoinHash.scala
@@ -20,7 +20,7 @@ package org.apache.spark.sql.execution.joins
import org.apache.spark.annotation.DeveloperApi
import org.apache.spark.rdd.RDD
import org.apache.spark.sql.catalyst.InternalRow
-import org.apache.spark.sql.catalyst.expressions.{Attribute, Expression}
+import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.plans.physical.ClusteredDistribution
import org.apache.spark.sql.execution.{BinaryNode, SparkPlan}
@@ -34,36 +34,21 @@ case class LeftSemiJoinHash(
leftKeys: Seq[Expression],
rightKeys: Seq[Expression],
left: SparkPlan,
- right: SparkPlan) extends BinaryNode with HashJoin {
-
- override val buildSide: BuildSide = BuildRight
+ right: SparkPlan,
+ condition: Option[Expression]) extends BinaryNode with HashSemiJoin {
override def requiredChildDistribution: Seq[ClusteredDistribution] =
ClusteredDistribution(leftKeys) :: ClusteredDistribution(rightKeys) :: Nil
- override def output: Seq[Attribute] = left.output
-
protected override def doExecute(): RDD[InternalRow] = {
- buildPlan.execute().zipPartitions(streamedPlan.execute()) { (buildIter, streamIter) =>
- val hashSet = new java.util.HashSet[InternalRow]()
- var currentRow: InternalRow = null
-
- // Create a Hash set of buildKeys
- while (buildIter.hasNext) {
- currentRow = buildIter.next()
- val rowKey = buildSideKeyGenerator(currentRow)
- if (!rowKey.anyNull) {
- val keyExists = hashSet.contains(rowKey)
- if (!keyExists) {
- hashSet.add(rowKey)
- }
- }
+ right.execute().zipPartitions(left.execute()) { (buildIter, streamIter) =>
+ if (condition.isEmpty) {
+ val hashSet = buildKeyHashSet(buildIter, copy = false)
+ hashSemiJoin(streamIter, hashSet)
+ } else {
+ val hashRelation = HashedRelation(buildIter, rightKeyGenerator)
+ hashSemiJoin(streamIter, hashRelation)
}
-
- val joinKeys = streamSideKeyGenerator()
- streamIter.filter(current => {
- !joinKeys(current).anyNull && hashSet.contains(joinKeys.currentValue)
- })
}
}
}
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala
index 5b8b70ed5a..61d5f2061a 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala
@@ -395,6 +395,18 @@ class SQLQuerySuite extends QueryTest with BeforeAndAfterAll with SQLTestUtils {
)
}
+ test("left semi greater than predicate and equal operator") {
+ checkAnswer(
+ sql("SELECT * FROM testData2 x LEFT SEMI JOIN testData2 y ON x.b = y.b and x.a >= y.a + 2"),
+ Seq(Row(3, 1), Row(3, 2))
+ )
+
+ checkAnswer(
+ sql("SELECT * FROM testData2 x LEFT SEMI JOIN testData2 y ON x.b = y.a and x.a >= y.b + 1"),
+ Seq(Row(2, 1), Row(2, 2), Row(3, 1), Row(3, 2))
+ )
+ }
+
test("index into array of arrays") {
checkAnswer(
sql(
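A quick hand check of the second expected answer, using a plain-Scala stand-in for the testData2 fixture ((1,1) through (3,2)):

    // x.b = y.a AND x.a >= y.b + 1, modeled over collections.
    val testData2 = Seq((1, 1), (1, 2), (2, 1), (2, 2), (3, 1), (3, 2))
    val answer = testData2.filter { case (xa, xb) =>
      testData2.exists { case (ya, yb) => xb == ya && xa >= yb + 1 }
    }
    // answer == List((2,1), (2,2), (3,1), (3,2)), matching the test's expected Rows.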
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/SemiJoinSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/SemiJoinSuite.scala
new file mode 100644
index 0000000000..927e85a7db
--- /dev/null
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/SemiJoinSuite.scala
@@ -0,0 +1,74 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.execution.joins
+
+import org.apache.spark.sql.Row
+import org.apache.spark.sql.catalyst.dsl.expressions._
+import org.apache.spark.sql.catalyst.expressions.{LessThan, Expression}
+import org.apache.spark.sql.execution.{SparkPlan, SparkPlanTest}
+
+
+class SemiJoinSuite extends SparkPlanTest{
+ val left = Seq(
+ (1, 2.0),
+ (1, 2.0),
+ (2, 1.0),
+ (2, 1.0),
+ (3, 3.0)
+ ).toDF("a", "b")
+
+ val right = Seq(
+ (2, 3.0),
+ (2, 3.0),
+ (3, 2.0),
+ (4, 1.0)
+ ).toDF("c", "d")
+
+ val leftKeys: List[Expression] = 'a :: Nil
+ val rightKeys: List[Expression] = 'c :: Nil
+ val condition = Some(LessThan('b, 'd))
+
+ test("left semi join hash") {
+ checkAnswer2(left, right, (left: SparkPlan, right: SparkPlan) =>
+ LeftSemiJoinHash(leftKeys, rightKeys, left, right, condition),
+ Seq(
+ (2, 1.0),
+ (2, 1.0)
+ ).map(Row.fromTuple))
+ }
+
+ test("left semi join BNL") {
+ checkAnswer2(left, right, (left: SparkPlan, right: SparkPlan) =>
+ LeftSemiJoinBNL(left, right, condition),
+ Seq(
+ (1, 2.0),
+ (1, 2.0),
+ (2, 1.0),
+ (2, 1.0)
+ ).map(Row.fromTuple))
+ }
+
+ test("broadcast left semi join hash") {
+ checkAnswer2(left, right, (left: SparkPlan, right: SparkPlan) =>
+ BroadcastLeftSemiJoinHash(leftKeys, rightKeys, left, right, condition),
+ Seq(
+ (2, 1.0),
+ (2, 1.0)
+ ).map(Row.fromTuple))
+ }
+}