From 5f686cc8b74ea9e36f56c31f14df90d134fd9343 Mon Sep 17 00:00:00 2001 From: gatorsmile Date: Fri, 29 Jan 2016 11:22:12 -0800 Subject: [SPARK-12656] [SQL] Implement Intersect with Left-semi Join Our current Intersect physical operator simply delegates to RDD.intersect. We should remove the Intersect physical operator and simply transform a logical intersect into a semi-join with distinct. This way, we can take advantage of all the benefits of join implementations (e.g. managed memory, code generation, broadcast joins). After a search, I found one of the mainstream RDBMS did the same. In their query explain, Intersect is replaced by Left-semi Join. Left-semi Join could help outer-join elimination in Optimizer, as shown in the PR: https://github.com/apache/spark/pull/10566 Author: gatorsmile Author: xiaoli Author: Xiao Li Closes #10630 from gatorsmile/IntersectBySemiJoin. --- .../spark/sql/catalyst/analysis/Analyzer.scala | 113 +++++++++++---------- .../sql/catalyst/analysis/CheckAnalysis.scala | 14 ++- .../spark/sql/catalyst/optimizer/Optimizer.scala | 45 ++++---- .../catalyst/plans/logical/basicOperators.scala | 32 ++++-- .../sql/catalyst/analysis/AnalysisSuite.scala | 5 + .../optimizer/AggregateOptimizeSuite.scala | 12 --- .../catalyst/optimizer/ReplaceOperatorSuite.scala | 59 +++++++++++ .../sql/catalyst/optimizer/SetOperationSuite.scala | 15 +-- .../spark/sql/execution/SparkStrategies.scala | 5 +- .../spark/sql/execution/basicOperators.scala | 12 --- .../org/apache/spark/sql/DataFrameSuite.scala | 21 ++++ 11 files changed, 211 insertions(+), 122 deletions(-) create mode 100644 sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/ReplaceOperatorSuite.scala diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala index 33d76eeb21..5fe700ee00 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala @@ -344,6 +344,63 @@ class Analyzer( } } + /** + * Generate a new logical plan for the right child with different expression IDs + * for all conflicting attributes. + */ + private def dedupRight (left: LogicalPlan, right: LogicalPlan): LogicalPlan = { + val conflictingAttributes = left.outputSet.intersect(right.outputSet) + logDebug(s"Conflicting attributes ${conflictingAttributes.mkString(",")} " + + s"between $left and $right") + + right.collect { + // Handle base relations that might appear more than once. + case oldVersion: MultiInstanceRelation + if oldVersion.outputSet.intersect(conflictingAttributes).nonEmpty => + val newVersion = oldVersion.newInstance() + (oldVersion, newVersion) + + // Handle projects that create conflicting aliases. + case oldVersion @ Project(projectList, _) + if findAliases(projectList).intersect(conflictingAttributes).nonEmpty => + (oldVersion, oldVersion.copy(projectList = newAliases(projectList))) + + case oldVersion @ Aggregate(_, aggregateExpressions, _) + if findAliases(aggregateExpressions).intersect(conflictingAttributes).nonEmpty => + (oldVersion, oldVersion.copy(aggregateExpressions = newAliases(aggregateExpressions))) + + case oldVersion: Generate + if oldVersion.generatedSet.intersect(conflictingAttributes).nonEmpty => + val newOutput = oldVersion.generatorOutput.map(_.newInstance()) + (oldVersion, oldVersion.copy(generatorOutput = newOutput)) + + case oldVersion @ Window(_, windowExpressions, _, _, child) + if AttributeSet(windowExpressions.map(_.toAttribute)).intersect(conflictingAttributes) + .nonEmpty => + (oldVersion, oldVersion.copy(windowExpressions = newAliases(windowExpressions))) + } + // Only handle first case, others will be fixed on the next pass. + .headOption match { + case None => + /* + * No result implies that there is a logical plan node that produces new references + * that this rule cannot handle. When that is the case, there must be another rule + * that resolves these conflicts. Otherwise, the analysis will fail. + */ + right + case Some((oldRelation, newRelation)) => + val attributeRewrites = AttributeMap(oldRelation.output.zip(newRelation.output)) + val newRight = right transformUp { + case r if r == oldRelation => newRelation + } transformUp { + case other => other transformExpressions { + case a: Attribute => attributeRewrites.get(a).getOrElse(a) + } + } + newRight + } + } + def apply(plan: LogicalPlan): LogicalPlan = plan resolveOperators { case p: LogicalPlan if !p.childrenResolved => p @@ -388,57 +445,11 @@ class Analyzer( .map(_.asInstanceOf[NamedExpression]) a.copy(aggregateExpressions = expanded) - // Special handling for cases when self-join introduce duplicate expression ids. - case j @ Join(left, right, _, _) if !j.selfJoinResolved => - val conflictingAttributes = left.outputSet.intersect(right.outputSet) - logDebug(s"Conflicting attributes ${conflictingAttributes.mkString(",")} in $j") - - right.collect { - // Handle base relations that might appear more than once. - case oldVersion: MultiInstanceRelation - if oldVersion.outputSet.intersect(conflictingAttributes).nonEmpty => - val newVersion = oldVersion.newInstance() - (oldVersion, newVersion) - - // Handle projects that create conflicting aliases. - case oldVersion @ Project(projectList, _) - if findAliases(projectList).intersect(conflictingAttributes).nonEmpty => - (oldVersion, oldVersion.copy(projectList = newAliases(projectList))) - - case oldVersion @ Aggregate(_, aggregateExpressions, _) - if findAliases(aggregateExpressions).intersect(conflictingAttributes).nonEmpty => - (oldVersion, oldVersion.copy(aggregateExpressions = newAliases(aggregateExpressions))) - - case oldVersion: Generate - if oldVersion.generatedSet.intersect(conflictingAttributes).nonEmpty => - val newOutput = oldVersion.generatorOutput.map(_.newInstance()) - (oldVersion, oldVersion.copy(generatorOutput = newOutput)) - - case oldVersion @ Window(_, windowExpressions, _, _, child) - if AttributeSet(windowExpressions.map(_.toAttribute)).intersect(conflictingAttributes) - .nonEmpty => - (oldVersion, oldVersion.copy(windowExpressions = newAliases(windowExpressions))) - } - // Only handle first case, others will be fixed on the next pass. - .headOption match { - case None => - /* - * No result implies that there is a logical plan node that produces new references - * that this rule cannot handle. When that is the case, there must be another rule - * that resolves these conflicts. Otherwise, the analysis will fail. - */ - j - case Some((oldRelation, newRelation)) => - val attributeRewrites = AttributeMap(oldRelation.output.zip(newRelation.output)) - val newRight = right transformUp { - case r if r == oldRelation => newRelation - } transformUp { - case other => other transformExpressions { - case a: Attribute => attributeRewrites.get(a).getOrElse(a) - } - } - j.copy(right = newRight) - } + // To resolve duplicate expression IDs for Join and Intersect + case j @ Join(left, right, _, _) if !j.duplicateResolved => + j.copy(right = dedupRight(left, right)) + case i @ Intersect(left, right) if !i.duplicateResolved => + i.copy(right = dedupRight(left, right)) // When resolve `SortOrder`s in Sort based on child, don't report errors as // we still have chance to resolve it based on grandchild diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala index f2e78d9744..4a2f2b8bc6 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala @@ -214,9 +214,8 @@ trait CheckAnalysis { s"""Only a single table generating function is allowed in a SELECT clause, found: | ${exprs.map(_.prettyString).mkString(",")}""".stripMargin) - // Special handling for cases when self-join introduce duplicate expression ids. - case j @ Join(left, right, _, _) if left.outputSet.intersect(right.outputSet).nonEmpty => - val conflictingAttributes = left.outputSet.intersect(right.outputSet) + case j: Join if !j.duplicateResolved => + val conflictingAttributes = j.left.outputSet.intersect(j.right.outputSet) failAnalysis( s""" |Failure when resolving conflicting references in Join: @@ -224,6 +223,15 @@ trait CheckAnalysis { |Conflicting attributes: ${conflictingAttributes.mkString(",")} |""".stripMargin) + case i: Intersect if !i.duplicateResolved => + val conflictingAttributes = i.left.outputSet.intersect(i.right.outputSet) + failAnalysis( + s""" + |Failure when resolving conflicting references in Intersect: + |$plan + |Conflicting attributes: ${conflictingAttributes.mkString(",")} + |""".stripMargin) + case o if !o.resolved => failAnalysis( s"unresolved operator ${operator.simpleString}") diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala index 6addc20806..f156b5d10a 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala @@ -52,8 +52,10 @@ abstract class Optimizer extends RuleExecutor[LogicalPlan] { // since the other rules might make two separate Unions operators adjacent. Batch("Union", Once, CombineUnions) :: + Batch("Replace Operators", FixedPoint(100), + ReplaceIntersectWithSemiJoin, + ReplaceDistinctWithAggregate) :: Batch("Aggregate", FixedPoint(100), - ReplaceDistinctWithAggregate, RemoveLiteralFromGroupExpressions) :: Batch("Operator Optimizations", FixedPoint(100), // Operator push down @@ -124,18 +126,13 @@ object EliminateSerialization extends Rule[LogicalPlan] { } /** - * Pushes certain operations to both sides of a Union, Intersect or Except operator. + * Pushes certain operations to both sides of a Union or Except operator. * Operations that are safe to pushdown are listed as follows. * Union: * Right now, Union means UNION ALL, which does not de-duplicate rows. So, it is * safe to pushdown Filters and Projections through it. Once we add UNION DISTINCT, * we will not be able to pushdown Projections. * - * Intersect: - * It is not safe to pushdown Projections through it because we need to get the - * intersect of rows by comparing the entire rows. It is fine to pushdown Filters - * with deterministic condition. - * * Except: * It is not safe to pushdown Projections through it because we need to get the * intersect of rows by comparing the entire rows. It is fine to pushdown Filters @@ -153,7 +150,7 @@ object SetOperationPushDown extends Rule[LogicalPlan] with PredicateHelper { /** * Rewrites an expression so that it can be pushed to the right side of a - * Union, Intersect or Except operator. This method relies on the fact that the output attributes + * Union or Except operator. This method relies on the fact that the output attributes * of a union/intersect/except are always equal to the left child's output. */ private def pushToRight[A <: Expression](e: A, rewrites: AttributeMap[Attribute]) = { @@ -210,17 +207,6 @@ object SetOperationPushDown extends Rule[LogicalPlan] with PredicateHelper { } Filter(nondeterministic, Union(newFirstChild +: newOtherChildren)) - // Push down filter through INTERSECT - case Filter(condition, Intersect(left, right)) => - val (deterministic, nondeterministic) = partitionByDeterministic(condition) - val rewrites = buildRewrites(left, right) - Filter(nondeterministic, - Intersect( - Filter(deterministic, left), - Filter(pushToRight(deterministic, rewrites), right) - ) - ) - // Push down filter through EXCEPT case Filter(condition, Except(left, right)) => val (deterministic, nondeterministic) = partitionByDeterministic(condition) @@ -1054,6 +1040,27 @@ object ReplaceDistinctWithAggregate extends Rule[LogicalPlan] { } } +/** + * Replaces logical [[Intersect]] operator with a left-semi [[Join]] operator. + * {{{ + * SELECT a1, a2 FROM Tab1 INTERSECT SELECT b1, b2 FROM Tab2 + * ==> SELECT DISTINCT a1, a2 FROM Tab1 LEFT SEMI JOIN Tab2 ON a1<=>b1 AND a2<=>b2 + * }}} + * + * Note: + * 1. This rule is only applicable to INTERSECT DISTINCT. Do not use it for INTERSECT ALL. + * 2. This rule has to be done after de-duplicating the attributes; otherwise, the generated + * join conditions will be incorrect. + */ +object ReplaceIntersectWithSemiJoin extends Rule[LogicalPlan] { + def apply(plan: LogicalPlan): LogicalPlan = plan transform { + case Intersect(left, right) => + assert(left.output.size == right.output.size) + val joinCond = left.output.zip(right.output).map { case (l, r) => EqualNullSafe(l, r) } + Distinct(Join(left, right, LeftSemi, joinCond.reduceLeftOption(And))) + } +} + /** * Removes literals from group expressions in [[Aggregate]], as they have no effect to the result * but only makes the grouping key bigger. diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala index e9c970cd08..16f4b355b1 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala @@ -19,6 +19,7 @@ package org.apache.spark.sql.catalyst.plans.logical import scala.collection.mutable.ArrayBuffer +import org.apache.spark.sql.catalyst.analysis.MultiInstanceRelation import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.expressions.aggregate.AggregateExpression import org.apache.spark.sql.catalyst.plans._ @@ -90,12 +91,7 @@ case class Filter(condition: Expression, child: LogicalPlan) extends UnaryNode { override def output: Seq[Attribute] = child.output } -abstract class SetOperation(left: LogicalPlan, right: LogicalPlan) extends BinaryNode { - final override lazy val resolved: Boolean = - childrenResolved && - left.output.length == right.output.length && - left.output.zip(right.output).forall { case (l, r) => l.dataType == r.dataType } -} +abstract class SetOperation(left: LogicalPlan, right: LogicalPlan) extends BinaryNode private[sql] object SetOperation { def unapply(p: SetOperation): Option[(LogicalPlan, LogicalPlan)] = Some((p.left, p.right)) @@ -103,15 +99,30 @@ private[sql] object SetOperation { case class Intersect(left: LogicalPlan, right: LogicalPlan) extends SetOperation(left, right) { + def duplicateResolved: Boolean = left.outputSet.intersect(right.outputSet).isEmpty + override def output: Seq[Attribute] = left.output.zip(right.output).map { case (leftAttr, rightAttr) => leftAttr.withNullability(leftAttr.nullable && rightAttr.nullable) } + + // Intersect are only resolved if they don't introduce ambiguous expression ids, + // since the Optimizer will convert Intersect to Join. + override lazy val resolved: Boolean = + childrenResolved && + left.output.length == right.output.length && + left.output.zip(right.output).forall { case (l, r) => l.dataType == r.dataType } && + duplicateResolved } case class Except(left: LogicalPlan, right: LogicalPlan) extends SetOperation(left, right) { /** We don't use right.output because those rows get excluded from the set. */ override def output: Seq[Attribute] = left.output + + override lazy val resolved: Boolean = + childrenResolved && + left.output.length == right.output.length && + left.output.zip(right.output).forall { case (l, r) => l.dataType == r.dataType } } /** Factory for constructing new `Union` nodes. */ @@ -169,13 +180,13 @@ case class Join( } } - def selfJoinResolved: Boolean = left.outputSet.intersect(right.outputSet).isEmpty + def duplicateResolved: Boolean = left.outputSet.intersect(right.outputSet).isEmpty // Joins are only resolved if they don't introduce ambiguous expression ids. override lazy val resolved: Boolean = { childrenResolved && expressions.forall(_.resolved) && - selfJoinResolved && + duplicateResolved && condition.forall(_.dataType == BooleanType) } } @@ -249,7 +260,7 @@ case class Range( end: Long, step: Long, numSlices: Int, - output: Seq[Attribute]) extends LeafNode { + output: Seq[Attribute]) extends LeafNode with MultiInstanceRelation { require(step != 0, "step cannot be 0") val numElements: BigInt = { val safeStart = BigInt(start) @@ -262,6 +273,9 @@ case class Range( } } + override def newInstance(): Range = + Range(start, end, step, numSlices, output.map(_.newInstance())) + override def statistics: Statistics = { val sizeInBytes = LongType.defaultSize * numElements Statistics( sizeInBytes = sizeInBytes ) diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisSuite.scala index ab68028220..1938bce02a 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisSuite.scala @@ -154,6 +154,11 @@ class AnalysisSuite extends AnalysisTest { checkAnalysis(plan, expected) } + test("self intersect should resolve duplicate expression IDs") { + val plan = testRelation.intersect(testRelation) + assertAnalysisSuccess(plan) + } + test("SPARK-8654: invalid CAST in NULL IN(...) expression") { val plan = Project(Alias(In(Literal(null), Seq(Literal(1), Literal(2))), "a")() :: Nil, LocalRelation() diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/AggregateOptimizeSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/AggregateOptimizeSuite.scala index 37148a226f..a4a12c0d62 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/AggregateOptimizeSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/AggregateOptimizeSuite.scala @@ -28,21 +28,9 @@ class AggregateOptimizeSuite extends PlanTest { object Optimize extends RuleExecutor[LogicalPlan] { val batches = Batch("Aggregate", FixedPoint(100), - ReplaceDistinctWithAggregate, RemoveLiteralFromGroupExpressions) :: Nil } - test("replace distinct with aggregate") { - val input = LocalRelation('a.int, 'b.int) - - val query = Distinct(input) - val optimized = Optimize.execute(query.analyze) - - val correctAnswer = Aggregate(input.output, input.output, input) - - comparePlans(optimized, correctAnswer) - } - test("remove literals in grouping expression") { val input = LocalRelation('a.int, 'b.int) diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/ReplaceOperatorSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/ReplaceOperatorSuite.scala new file mode 100644 index 0000000000..f8ae5d9be2 --- /dev/null +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/ReplaceOperatorSuite.scala @@ -0,0 +1,59 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.catalyst.optimizer + +import org.apache.spark.sql.catalyst.dsl.expressions._ +import org.apache.spark.sql.catalyst.dsl.plans._ +import org.apache.spark.sql.catalyst.plans.{LeftSemi, PlanTest} +import org.apache.spark.sql.catalyst.plans.logical._ +import org.apache.spark.sql.catalyst.rules.RuleExecutor + +class ReplaceOperatorSuite extends PlanTest { + + object Optimize extends RuleExecutor[LogicalPlan] { + val batches = + Batch("Replace Operators", FixedPoint(100), + ReplaceDistinctWithAggregate, + ReplaceIntersectWithSemiJoin) :: Nil + } + + test("replace Intersect with Left-semi Join") { + val table1 = LocalRelation('a.int, 'b.int) + val table2 = LocalRelation('c.int, 'd.int) + + val query = Intersect(table1, table2) + val optimized = Optimize.execute(query.analyze) + + val correctAnswer = + Aggregate(table1.output, table1.output, + Join(table1, table2, LeftSemi, Option('a <=> 'c && 'b <=> 'd))).analyze + + comparePlans(optimized, correctAnswer) + } + + test("replace Distinct with Aggregate") { + val input = LocalRelation('a.int, 'b.int) + + val query = Distinct(input) + val optimized = Optimize.execute(query.analyze) + + val correctAnswer = Aggregate(input.output, input.output, input) + + comparePlans(optimized, correctAnswer) + } +} diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/SetOperationSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/SetOperationSuite.scala index 2283f7c008..b8ea32b4df 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/SetOperationSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/SetOperationSuite.scala @@ -39,7 +39,6 @@ class SetOperationSuite extends PlanTest { val testRelation2 = LocalRelation('d.int, 'e.int, 'f.int) val testRelation3 = LocalRelation('g.int, 'h.int, 'i.int) val testUnion = Union(testRelation :: testRelation2 :: testRelation3 :: Nil) - val testIntersect = Intersect(testRelation, testRelation2) val testExcept = Except(testRelation, testRelation2) test("union: combine unions into one unions") { @@ -57,19 +56,12 @@ class SetOperationSuite extends PlanTest { comparePlans(combinedUnionsOptimized, unionOptimized3) } - test("intersect/except: filter to each side") { - val intersectQuery = testIntersect.where('b < 10) + test("except: filter to each side") { val exceptQuery = testExcept.where('c >= 5) - - val intersectOptimized = Optimize.execute(intersectQuery.analyze) val exceptOptimized = Optimize.execute(exceptQuery.analyze) - - val intersectCorrectAnswer = - Intersect(testRelation.where('b < 10), testRelation2.where('e < 10)).analyze val exceptCorrectAnswer = Except(testRelation.where('c >= 5), testRelation2.where('f >= 5)).analyze - comparePlans(intersectOptimized, intersectCorrectAnswer) comparePlans(exceptOptimized, exceptCorrectAnswer) } @@ -95,13 +87,8 @@ class SetOperationSuite extends PlanTest { } test("SPARK-10539: Project should not be pushed down through Intersect or Except") { - val intersectQuery = testIntersect.select('b, 'c) val exceptQuery = testExcept.select('a, 'b, 'c) - - val intersectOptimized = Optimize.execute(intersectQuery.analyze) val exceptOptimized = Optimize.execute(exceptQuery.analyze) - - comparePlans(intersectOptimized, intersectQuery.analyze) comparePlans(exceptOptimized, exceptQuery.analyze) } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala index 60fbb595e5..9293e55141 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala @@ -298,6 +298,9 @@ private[sql] abstract class SparkStrategies extends QueryPlanner[SparkPlan] { case logical.Distinct(child) => throw new IllegalStateException( "logical distinct operator should have been replaced by aggregate in the optimizer") + case logical.Intersect(left, right) => + throw new IllegalStateException( + "logical intersect operator should have been replaced by semi-join in the optimizer") case logical.MapPartitions(f, in, out, child) => execution.MapPartitions(f, in, out, planLater(child)) :: Nil @@ -340,8 +343,6 @@ private[sql] abstract class SparkStrategies extends QueryPlanner[SparkPlan] { execution.Union(unionChildren.map(planLater)) :: Nil case logical.Except(left, right) => execution.Except(planLater(left), planLater(right)) :: Nil - case logical.Intersect(left, right) => - execution.Intersect(planLater(left), planLater(right)) :: Nil case g @ logical.Generate(generator, join, outer, _, _, child) => execution.Generate( generator, join = join, outer = outer, g.output, planLater(child)) :: Nil diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala index e7a73d5fbb..fd81531c93 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala @@ -420,18 +420,6 @@ case class Except(left: SparkPlan, right: SparkPlan) extends BinaryNode { } } -/** - * Returns the rows in left that also appear in right using the built in spark - * intersection function. - */ -case class Intersect(left: SparkPlan, right: SparkPlan) extends BinaryNode { - override def output: Seq[Attribute] = children.head.output - - protected override def doExecute(): RDD[InternalRow] = { - left.execute().map(_.copy()).intersection(right.execute().map(_.copy())) - } -} - /** * A plan node that does nothing but lie about the output of its child. Used to spice a * (hopefully structurally equivalent) tree from a different optimization sequence into an already diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala index 09bbe57a43..4ff99bdf29 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala @@ -349,6 +349,27 @@ class DataFrameSuite extends QueryTest with SharedSQLContext { Row(3, "c") :: Row(4, "d") :: Nil) checkAnswer(lowerCaseData.intersect(upperCaseData), Nil) + + // check null equality + checkAnswer( + nullInts.intersect(nullInts), + Row(1) :: + Row(2) :: + Row(3) :: + Row(null) :: Nil) + + // check if values are de-duplicated + checkAnswer( + allNulls.intersect(allNulls), + Row(null) :: Nil) + + // check if values are de-duplicated + val df = Seq(("id1", 1), ("id1", 1), ("id", 1), ("id1", 2)).toDF("id", "value") + checkAnswer( + df.intersect(df), + Row("id1", 1) :: + Row("id", 1) :: + Row("id1", 2) :: Nil) } test("intersect - nullability") { -- cgit v1.2.3