author     gatorsmile <gatorsmile@gmail.com>    2016-01-29 11:22:12 -0800
committer  Reynold Xin <rxin@databricks.com>    2016-01-29 11:22:12 -0800
commit     5f686cc8b74ea9e36f56c31f14df90d134fd9343
tree       282fbb236a8a20e5f2ba879c7adf44a2c182d129
parent     c5f745ede01831b59c57effa7de88c648b82c13d
[SPARK-12656] [SQL] Implement Intersect with Left-semi Join
Our current Intersect physical operator simply delegates to RDD.intersect. This patch removes the Intersect physical operator and instead transforms a logical Intersect into a left-semi join followed by a distinct. This way, we can take advantage of all the benefits of our join implementations (e.g. managed memory, code generation, broadcast joins). After a search, I found that at least one mainstream RDBMS does the same: in its query plan, Intersect is replaced by a left-semi join. Left-semi joins can also enable outer-join elimination in the Optimizer, as shown in PR https://github.com/apache/spark/pull/10566.

Author: gatorsmile <gatorsmile@gmail.com>
Author: xiaoli <lixiao1983@gmail.com>
Author: Xiao Li <xiaoli@Xiaos-MacBook-Pro.local>

Closes #10630 from gatorsmile/IntersectBySemiJoin.
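At the SQL level, the rewrite has the following shape (this is the example from the new rule's scaladoc below; the null-safe equality operator <=> is what lets NULL values match, as INTERSECT requires):

    SELECT a1, a2 FROM Tab1 INTERSECT SELECT b1, b2 FROM Tab2
    ==> SELECT DISTINCT a1, a2 FROM Tab1 LEFT SEMI JOIN Tab2 ON a1 <=> b1 AND a2 <=> b2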
 sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala               | 113
 sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala          |  14
 sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala             |  45
 sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala    |  32
 sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisSuite.scala          |   5
 sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/AggregateOptimizeSuite.scala |  12
 sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/ReplaceOperatorSuite.scala  |  59
 sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/SetOperationSuite.scala     |  15
 sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala                    |   5
 sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala                     |  12
 sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala                               |  21
 11 files changed, 211 insertions(+), 122 deletions(-)
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
index 33d76eeb21..5fe700ee00 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
@@ -344,6 +344,63 @@ class Analyzer(
}
}
+ /**
+ * Generate a new logical plan for the right child with different expression IDs
+ * for all conflicting attributes.
+ */
+ private def dedupRight (left: LogicalPlan, right: LogicalPlan): LogicalPlan = {
+ val conflictingAttributes = left.outputSet.intersect(right.outputSet)
+ logDebug(s"Conflicting attributes ${conflictingAttributes.mkString(",")} " +
+ s"between $left and $right")
+
+ right.collect {
+ // Handle base relations that might appear more than once.
+ case oldVersion: MultiInstanceRelation
+ if oldVersion.outputSet.intersect(conflictingAttributes).nonEmpty =>
+ val newVersion = oldVersion.newInstance()
+ (oldVersion, newVersion)
+
+ // Handle projects that create conflicting aliases.
+ case oldVersion @ Project(projectList, _)
+ if findAliases(projectList).intersect(conflictingAttributes).nonEmpty =>
+ (oldVersion, oldVersion.copy(projectList = newAliases(projectList)))
+
+ case oldVersion @ Aggregate(_, aggregateExpressions, _)
+ if findAliases(aggregateExpressions).intersect(conflictingAttributes).nonEmpty =>
+ (oldVersion, oldVersion.copy(aggregateExpressions = newAliases(aggregateExpressions)))
+
+ case oldVersion: Generate
+ if oldVersion.generatedSet.intersect(conflictingAttributes).nonEmpty =>
+ val newOutput = oldVersion.generatorOutput.map(_.newInstance())
+ (oldVersion, oldVersion.copy(generatorOutput = newOutput))
+
+ case oldVersion @ Window(_, windowExpressions, _, _, child)
+ if AttributeSet(windowExpressions.map(_.toAttribute)).intersect(conflictingAttributes)
+ .nonEmpty =>
+ (oldVersion, oldVersion.copy(windowExpressions = newAliases(windowExpressions)))
+ }
+ // Only handle first case, others will be fixed on the next pass.
+ .headOption match {
+ case None =>
+ /*
+ * No result implies that there is a logical plan node that produces new references
+ * that this rule cannot handle. When that is the case, there must be another rule
+ * that resolves these conflicts. Otherwise, the analysis will fail.
+ */
+ right
+ case Some((oldRelation, newRelation)) =>
+ val attributeRewrites = AttributeMap(oldRelation.output.zip(newRelation.output))
+ val newRight = right transformUp {
+ case r if r == oldRelation => newRelation
+ } transformUp {
+ case other => other transformExpressions {
+ case a: Attribute => attributeRewrites.get(a).getOrElse(a)
+ }
+ }
+ newRight
+ }
+ }
+
def apply(plan: LogicalPlan): LogicalPlan = plan resolveOperators {
case p: LogicalPlan if !p.childrenResolved => p
@@ -388,57 +445,11 @@ class Analyzer(
.map(_.asInstanceOf[NamedExpression])
a.copy(aggregateExpressions = expanded)
- // Special handling for cases when self-join introduce duplicate expression ids.
- case j @ Join(left, right, _, _) if !j.selfJoinResolved =>
- val conflictingAttributes = left.outputSet.intersect(right.outputSet)
- logDebug(s"Conflicting attributes ${conflictingAttributes.mkString(",")} in $j")
-
- right.collect {
- // Handle base relations that might appear more than once.
- case oldVersion: MultiInstanceRelation
- if oldVersion.outputSet.intersect(conflictingAttributes).nonEmpty =>
- val newVersion = oldVersion.newInstance()
- (oldVersion, newVersion)
-
- // Handle projects that create conflicting aliases.
- case oldVersion @ Project(projectList, _)
- if findAliases(projectList).intersect(conflictingAttributes).nonEmpty =>
- (oldVersion, oldVersion.copy(projectList = newAliases(projectList)))
-
- case oldVersion @ Aggregate(_, aggregateExpressions, _)
- if findAliases(aggregateExpressions).intersect(conflictingAttributes).nonEmpty =>
- (oldVersion, oldVersion.copy(aggregateExpressions = newAliases(aggregateExpressions)))
-
- case oldVersion: Generate
- if oldVersion.generatedSet.intersect(conflictingAttributes).nonEmpty =>
- val newOutput = oldVersion.generatorOutput.map(_.newInstance())
- (oldVersion, oldVersion.copy(generatorOutput = newOutput))
-
- case oldVersion @ Window(_, windowExpressions, _, _, child)
- if AttributeSet(windowExpressions.map(_.toAttribute)).intersect(conflictingAttributes)
- .nonEmpty =>
- (oldVersion, oldVersion.copy(windowExpressions = newAliases(windowExpressions)))
- }
- // Only handle first case, others will be fixed on the next pass.
- .headOption match {
- case None =>
- /*
- * No result implies that there is a logical plan node that produces new references
- * that this rule cannot handle. When that is the case, there must be another rule
- * that resolves these conflicts. Otherwise, the analysis will fail.
- */
- j
- case Some((oldRelation, newRelation)) =>
- val attributeRewrites = AttributeMap(oldRelation.output.zip(newRelation.output))
- val newRight = right transformUp {
- case r if r == oldRelation => newRelation
- } transformUp {
- case other => other transformExpressions {
- case a: Attribute => attributeRewrites.get(a).getOrElse(a)
- }
- }
- j.copy(right = newRight)
- }
+ // To resolve duplicate expression IDs for Join and Intersect
+ case j @ Join(left, right, _, _) if !j.duplicateResolved =>
+ j.copy(right = dedupRight(left, right))
+ case i @ Intersect(left, right) if !i.duplicateResolved =>
+ i.copy(right = dedupRight(left, right))
// When resolve `SortOrder`s in Sort based on child, don't report errors as
// we still have chance to resolve it based on grandchild
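The dedupRight helper extracted above is what makes self-intersects analyzable: it re-instances the right child so both sides get distinct expression IDs. A minimal sketch using the catalyst test DSL (mirroring the new AnalysisSuite test; testRelation is a placeholder LocalRelation):

    import org.apache.spark.sql.catalyst.dsl.expressions._
    import org.apache.spark.sql.catalyst.dsl.plans._
    import org.apache.spark.sql.catalyst.plans.logical.LocalRelation

    val testRelation = LocalRelation('a.int, 'b.int)
    // Both children initially share the same attribute IDs; dedupRight
    // re-instances the right side so duplicateResolved holds.
    testRelation.intersect(testRelation)
    testRelation.join(testRelation)   // self-joins take the same path as before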
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala
index f2e78d9744..4a2f2b8bc6 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala
@@ -214,9 +214,8 @@ trait CheckAnalysis {
s"""Only a single table generating function is allowed in a SELECT clause, found:
| ${exprs.map(_.prettyString).mkString(",")}""".stripMargin)
- // Special handling for cases when self-join introduce duplicate expression ids.
- case j @ Join(left, right, _, _) if left.outputSet.intersect(right.outputSet).nonEmpty =>
- val conflictingAttributes = left.outputSet.intersect(right.outputSet)
+ case j: Join if !j.duplicateResolved =>
+ val conflictingAttributes = j.left.outputSet.intersect(j.right.outputSet)
failAnalysis(
s"""
|Failure when resolving conflicting references in Join:
@@ -224,6 +223,15 @@ trait CheckAnalysis {
|Conflicting attributes: ${conflictingAttributes.mkString(",")}
|""".stripMargin)
+ case i: Intersect if !i.duplicateResolved =>
+ val conflictingAttributes = i.left.outputSet.intersect(i.right.outputSet)
+ failAnalysis(
+ s"""
+ |Failure when resolving conflicting references in Intersect:
+ |$plan
+ |Conflicting attributes: ${conflictingAttributes.mkString(",")}
+ |""".stripMargin)
+
case o if !o.resolved =>
failAnalysis(
s"unresolved operator ${operator.simpleString}")
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala
index 6addc20806..f156b5d10a 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala
@@ -52,8 +52,10 @@ abstract class Optimizer extends RuleExecutor[LogicalPlan] {
// since the other rules might make two separate Unions operators adjacent.
Batch("Union", Once,
CombineUnions) ::
+ Batch("Replace Operators", FixedPoint(100),
+ ReplaceIntersectWithSemiJoin,
+ ReplaceDistinctWithAggregate) ::
Batch("Aggregate", FixedPoint(100),
- ReplaceDistinctWithAggregate,
RemoveLiteralFromGroupExpressions) ::
Batch("Operator Optimizations", FixedPoint(100),
// Operator push down
@@ -124,18 +126,13 @@ object EliminateSerialization extends Rule[LogicalPlan] {
}
/**
- * Pushes certain operations to both sides of a Union, Intersect or Except operator.
+ * Pushes certain operations to both sides of a Union or Except operator.
* Operations that are safe to pushdown are listed as follows.
* Union:
* Right now, Union means UNION ALL, which does not de-duplicate rows. So, it is
* safe to pushdown Filters and Projections through it. Once we add UNION DISTINCT,
* we will not be able to pushdown Projections.
*
- * Intersect:
- * It is not safe to pushdown Projections through it because we need to get the
- * intersect of rows by comparing the entire rows. It is fine to pushdown Filters
- * with deterministic condition.
- *
* Except:
* It is not safe to pushdown Projections through it because we need to get the
* intersect of rows by comparing the entire rows. It is fine to pushdown Filters
@@ -153,7 +150,7 @@ object SetOperationPushDown extends Rule[LogicalPlan] with PredicateHelper {
/**
* Rewrites an expression so that it can be pushed to the right side of a
- * Union, Intersect or Except operator. This method relies on the fact that the output attributes
+ * Union or Except operator. This method relies on the fact that the output attributes
* of a union/intersect/except are always equal to the left child's output.
*/
private def pushToRight[A <: Expression](e: A, rewrites: AttributeMap[Attribute]) = {
@@ -210,17 +207,6 @@ object SetOperationPushDown extends Rule[LogicalPlan] with PredicateHelper {
}
Filter(nondeterministic, Union(newFirstChild +: newOtherChildren))
- // Push down filter through INTERSECT
- case Filter(condition, Intersect(left, right)) =>
- val (deterministic, nondeterministic) = partitionByDeterministic(condition)
- val rewrites = buildRewrites(left, right)
- Filter(nondeterministic,
- Intersect(
- Filter(deterministic, left),
- Filter(pushToRight(deterministic, rewrites), right)
- )
- )
-
// Push down filter through EXCEPT
case Filter(condition, Except(left, right)) =>
val (deterministic, nondeterministic) = partitionByDeterministic(condition)
@@ -1055,6 +1041,27 @@ object ReplaceDistinctWithAggregate extends Rule[LogicalPlan] {
}
/**
+ * Replaces logical [[Intersect]] operator with a left-semi [[Join]] operator.
+ * {{{
+ * SELECT a1, a2 FROM Tab1 INTERSECT SELECT b1, b2 FROM Tab2
+ * ==> SELECT DISTINCT a1, a2 FROM Tab1 LEFT SEMI JOIN Tab2 ON a1<=>b1 AND a2<=>b2
+ * }}}
+ *
+ * Note:
+ * 1. This rule is only applicable to INTERSECT DISTINCT. Do not use it for INTERSECT ALL.
+ * 2. This rule has to be done after de-duplicating the attributes; otherwise, the generated
+ * join conditions will be incorrect.
+ */
+object ReplaceIntersectWithSemiJoin extends Rule[LogicalPlan] {
+ def apply(plan: LogicalPlan): LogicalPlan = plan transform {
+ case Intersect(left, right) =>
+ assert(left.output.size == right.output.size)
+ val joinCond = left.output.zip(right.output).map { case (l, r) => EqualNullSafe(l, r) }
+ Distinct(Join(left, right, LeftSemi, joinCond.reduceLeftOption(And)))
+ }
+}
+
+/**
* Removes literals from group expressions in [[Aggregate]], as they have no effect to the result
* but only makes the grouping key bigger.
*/
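The rule builds its join condition with EqualNullSafe rather than EqualTo because NULL = NULL evaluates to NULL (the row would be dropped by the join), while NULL <=> NULL evaluates to true, which is what INTERSECT requires. A sketch of the rewrite at the logical-plan level, using the same DSL as the new ReplaceOperatorSuite added below:

    import org.apache.spark.sql.catalyst.dsl.expressions._
    import org.apache.spark.sql.catalyst.plans.LeftSemi
    import org.apache.spark.sql.catalyst.plans.logical._

    val table1 = LocalRelation('a.int, 'b.int)
    val table2 = LocalRelation('c.int, 'd.int)
    // ReplaceIntersectWithSemiJoin, then ReplaceDistinctWithAggregate:
    //   Intersect(table1, table2)
    //   ==> Aggregate(table1.output, table1.output,
    //         Join(table1, table2, LeftSemi, Some('a <=> 'c && 'b <=> 'd)))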
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala
index e9c970cd08..16f4b355b1 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala
@@ -19,6 +19,7 @@ package org.apache.spark.sql.catalyst.plans.logical
import scala.collection.mutable.ArrayBuffer
+import org.apache.spark.sql.catalyst.analysis.MultiInstanceRelation
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.expressions.aggregate.AggregateExpression
import org.apache.spark.sql.catalyst.plans._
@@ -90,12 +91,7 @@ case class Filter(condition: Expression, child: LogicalPlan) extends UnaryNode {
override def output: Seq[Attribute] = child.output
}
-abstract class SetOperation(left: LogicalPlan, right: LogicalPlan) extends BinaryNode {
- final override lazy val resolved: Boolean =
- childrenResolved &&
- left.output.length == right.output.length &&
- left.output.zip(right.output).forall { case (l, r) => l.dataType == r.dataType }
-}
+abstract class SetOperation(left: LogicalPlan, right: LogicalPlan) extends BinaryNode
private[sql] object SetOperation {
def unapply(p: SetOperation): Option[(LogicalPlan, LogicalPlan)] = Some((p.left, p.right))
@@ -103,15 +99,30 @@ private[sql] object SetOperation {
case class Intersect(left: LogicalPlan, right: LogicalPlan) extends SetOperation(left, right) {
+ def duplicateResolved: Boolean = left.outputSet.intersect(right.outputSet).isEmpty
+
override def output: Seq[Attribute] =
left.output.zip(right.output).map { case (leftAttr, rightAttr) =>
leftAttr.withNullability(leftAttr.nullable && rightAttr.nullable)
}
+
+ // Intersect is only resolved if it doesn't introduce ambiguous expression IDs,
+ // since the Optimizer will convert Intersect to Join.
+ override lazy val resolved: Boolean =
+ childrenResolved &&
+ left.output.length == right.output.length &&
+ left.output.zip(right.output).forall { case (l, r) => l.dataType == r.dataType } &&
+ duplicateResolved
}
case class Except(left: LogicalPlan, right: LogicalPlan) extends SetOperation(left, right) {
/** We don't use right.output because those rows get excluded from the set. */
override def output: Seq[Attribute] = left.output
+
+ override lazy val resolved: Boolean =
+ childrenResolved &&
+ left.output.length == right.output.length &&
+ left.output.zip(right.output).forall { case (l, r) => l.dataType == r.dataType }
}
/** Factory for constructing new `Union` nodes. */
@@ -169,13 +180,13 @@ case class Join(
}
}
- def selfJoinResolved: Boolean = left.outputSet.intersect(right.outputSet).isEmpty
+ def duplicateResolved: Boolean = left.outputSet.intersect(right.outputSet).isEmpty
// Joins are only resolved if they don't introduce ambiguous expression ids.
override lazy val resolved: Boolean = {
childrenResolved &&
expressions.forall(_.resolved) &&
- selfJoinResolved &&
+ duplicateResolved &&
condition.forall(_.dataType == BooleanType)
}
}
@@ -249,7 +260,7 @@ case class Range(
end: Long,
step: Long,
numSlices: Int,
- output: Seq[Attribute]) extends LeafNode {
+ output: Seq[Attribute]) extends LeafNode with MultiInstanceRelation {
require(step != 0, "step cannot be 0")
val numElements: BigInt = {
val safeStart = BigInt(start)
@@ -262,6 +273,9 @@ case class Range(
}
}
+ override def newInstance(): Range =
+ Range(start, end, step, numSlices, output.map(_.newInstance()))
+
override def statistics: Statistics = {
val sizeInBytes = LongType.defaultSize * numElements
Statistics( sizeInBytes = sizeInBytes )
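Making Range a MultiInstanceRelation lets dedupRight mint fresh expression IDs when the same Range relation appears on both sides of a join or intersect. A rough usage sketch (assuming the SQLContext API of this era):

    val r = sqlContext.range(0, 10)
    // Without newInstance(), this self-intersect would fail analysis with
    // conflicting attribute references; with it, the right child is re-instanced.
    r.intersect(r)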
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisSuite.scala
index ab68028220..1938bce02a 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisSuite.scala
@@ -154,6 +154,11 @@ class AnalysisSuite extends AnalysisTest {
checkAnalysis(plan, expected)
}
+ test("self intersect should resolve duplicate expression IDs") {
+ val plan = testRelation.intersect(testRelation)
+ assertAnalysisSuccess(plan)
+ }
+
test("SPARK-8654: invalid CAST in NULL IN(...) expression") {
val plan = Project(Alias(In(Literal(null), Seq(Literal(1), Literal(2))), "a")() :: Nil,
LocalRelation()
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/AggregateOptimizeSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/AggregateOptimizeSuite.scala
index 37148a226f..a4a12c0d62 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/AggregateOptimizeSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/AggregateOptimizeSuite.scala
@@ -28,21 +28,9 @@ class AggregateOptimizeSuite extends PlanTest {
object Optimize extends RuleExecutor[LogicalPlan] {
val batches = Batch("Aggregate", FixedPoint(100),
- ReplaceDistinctWithAggregate,
RemoveLiteralFromGroupExpressions) :: Nil
}
- test("replace distinct with aggregate") {
- val input = LocalRelation('a.int, 'b.int)
-
- val query = Distinct(input)
- val optimized = Optimize.execute(query.analyze)
-
- val correctAnswer = Aggregate(input.output, input.output, input)
-
- comparePlans(optimized, correctAnswer)
- }
-
test("remove literals in grouping expression") {
val input = LocalRelation('a.int, 'b.int)
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/ReplaceOperatorSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/ReplaceOperatorSuite.scala
new file mode 100644
index 0000000000..f8ae5d9be2
--- /dev/null
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/ReplaceOperatorSuite.scala
@@ -0,0 +1,59 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.catalyst.optimizer
+
+import org.apache.spark.sql.catalyst.dsl.expressions._
+import org.apache.spark.sql.catalyst.dsl.plans._
+import org.apache.spark.sql.catalyst.plans.{LeftSemi, PlanTest}
+import org.apache.spark.sql.catalyst.plans.logical._
+import org.apache.spark.sql.catalyst.rules.RuleExecutor
+
+class ReplaceOperatorSuite extends PlanTest {
+
+ object Optimize extends RuleExecutor[LogicalPlan] {
+ val batches =
+ Batch("Replace Operators", FixedPoint(100),
+ ReplaceDistinctWithAggregate,
+ ReplaceIntersectWithSemiJoin) :: Nil
+ }
+
+ test("replace Intersect with Left-semi Join") {
+ val table1 = LocalRelation('a.int, 'b.int)
+ val table2 = LocalRelation('c.int, 'd.int)
+
+ val query = Intersect(table1, table2)
+ val optimized = Optimize.execute(query.analyze)
+
+ val correctAnswer =
+ Aggregate(table1.output, table1.output,
+ Join(table1, table2, LeftSemi, Option('a <=> 'c && 'b <=> 'd))).analyze
+
+ comparePlans(optimized, correctAnswer)
+ }
+
+ test("replace Distinct with Aggregate") {
+ val input = LocalRelation('a.int, 'b.int)
+
+ val query = Distinct(input)
+ val optimized = Optimize.execute(query.analyze)
+
+ val correctAnswer = Aggregate(input.output, input.output, input)
+
+ comparePlans(optimized, correctAnswer)
+ }
+}
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/SetOperationSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/SetOperationSuite.scala
index 2283f7c008..b8ea32b4df 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/SetOperationSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/SetOperationSuite.scala
@@ -39,7 +39,6 @@ class SetOperationSuite extends PlanTest {
val testRelation2 = LocalRelation('d.int, 'e.int, 'f.int)
val testRelation3 = LocalRelation('g.int, 'h.int, 'i.int)
val testUnion = Union(testRelation :: testRelation2 :: testRelation3 :: Nil)
- val testIntersect = Intersect(testRelation, testRelation2)
val testExcept = Except(testRelation, testRelation2)
test("union: combine unions into one unions") {
@@ -57,19 +56,12 @@ class SetOperationSuite extends PlanTest {
comparePlans(combinedUnionsOptimized, unionOptimized3)
}
- test("intersect/except: filter to each side") {
- val intersectQuery = testIntersect.where('b < 10)
+ test("except: filter to each side") {
val exceptQuery = testExcept.where('c >= 5)
-
- val intersectOptimized = Optimize.execute(intersectQuery.analyze)
val exceptOptimized = Optimize.execute(exceptQuery.analyze)
-
- val intersectCorrectAnswer =
- Intersect(testRelation.where('b < 10), testRelation2.where('e < 10)).analyze
val exceptCorrectAnswer =
Except(testRelation.where('c >= 5), testRelation2.where('f >= 5)).analyze
- comparePlans(intersectOptimized, intersectCorrectAnswer)
comparePlans(exceptOptimized, exceptCorrectAnswer)
}
@@ -95,13 +87,8 @@ class SetOperationSuite extends PlanTest {
}
test("SPARK-10539: Project should not be pushed down through Intersect or Except") {
- val intersectQuery = testIntersect.select('b, 'c)
val exceptQuery = testExcept.select('a, 'b, 'c)
-
- val intersectOptimized = Optimize.execute(intersectQuery.analyze)
val exceptOptimized = Optimize.execute(exceptQuery.analyze)
-
- comparePlans(intersectOptimized, intersectQuery.analyze)
comparePlans(exceptOptimized, exceptQuery.analyze)
}
}
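The intersect cases removed from this suite are not lost coverage: once ReplaceIntersectWithSemiJoin runs in its own batch ahead of Operator Optimizations, a filter above an intersect is optimized through the join rules instead. Conceptually (a sketch of the expected plan shapes, not verbatim plans):

    // Filter('b < 10, Intersect(r1, r2))
    //   ==> via ReplaceIntersectWithSemiJoin + ReplaceDistinctWithAggregate
    // Filter('b < 10, Aggregate(r1.output, r1.output, Join(r1, r2, LeftSemi, cond)))
    //   ==> via the existing predicate pushdown through aggregates and joins
    // Aggregate(r1.output, r1.output, Join(Filter('b < 10, r1), r2, LeftSemi, cond))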
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala
index 60fbb595e5..9293e55141 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala
@@ -298,6 +298,9 @@ private[sql] abstract class SparkStrategies extends QueryPlanner[SparkPlan] {
case logical.Distinct(child) =>
throw new IllegalStateException(
"logical distinct operator should have been replaced by aggregate in the optimizer")
+ case logical.Intersect(left, right) =>
+ throw new IllegalStateException(
+ "logical intersect operator should have been replaced by semi-join in the optimizer")
case logical.MapPartitions(f, in, out, child) =>
execution.MapPartitions(f, in, out, planLater(child)) :: Nil
@@ -340,8 +343,6 @@ private[sql] abstract class SparkStrategies extends QueryPlanner[SparkPlan] {
execution.Union(unionChildren.map(planLater)) :: Nil
case logical.Except(left, right) =>
execution.Except(planLater(left), planLater(right)) :: Nil
- case logical.Intersect(left, right) =>
- execution.Intersect(planLater(left), planLater(right)) :: Nil
case g @ logical.Generate(generator, join, outer, _, _, child) =>
execution.Generate(
generator, join = join, outer = outer, g.output, planLater(child)) :: Nil
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala
index e7a73d5fbb..fd81531c93 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala
@@ -421,18 +421,6 @@ case class Except(left: SparkPlan, right: SparkPlan) extends BinaryNode {
}
/**
- * Returns the rows in left that also appear in right using the built in spark
- * intersection function.
- */
-case class Intersect(left: SparkPlan, right: SparkPlan) extends BinaryNode {
- override def output: Seq[Attribute] = children.head.output
-
- protected override def doExecute(): RDD[InternalRow] = {
- left.execute().map(_.copy()).intersection(right.execute().map(_.copy()))
- }
-}
-
-/**
* A plan node that does nothing but lie about the output of its child. Used to spice a
* (hopefully structurally equivalent) tree from a different optimization sequence into an already
* resolved tree.
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala
index 09bbe57a43..4ff99bdf29 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala
@@ -349,6 +349,27 @@ class DataFrameSuite extends QueryTest with SharedSQLContext {
Row(3, "c") ::
Row(4, "d") :: Nil)
checkAnswer(lowerCaseData.intersect(upperCaseData), Nil)
+
+ // check null equality
+ checkAnswer(
+ nullInts.intersect(nullInts),
+ Row(1) ::
+ Row(2) ::
+ Row(3) ::
+ Row(null) :: Nil)
+
+ // check if values are de-duplicated
+ checkAnswer(
+ allNulls.intersect(allNulls),
+ Row(null) :: Nil)
+
+ // check if values are de-duplicated
+ val df = Seq(("id1", 1), ("id1", 1), ("id", 1), ("id1", 2)).toDF("id", "value")
+ checkAnswer(
+ df.intersect(df),
+ Row("id1", 1) ::
+ Row("id", 1) ::
+ Row("id1", 2) :: Nil)
}
test("intersect - nullability") {