author     gatorsmile <gatorsmile@gmail.com>      2016-04-29 15:30:36 +0800
committer  Wenchen Fan <wenchen@databricks.com>   2016-04-29 15:30:36 +0800
commit     222dcf79377df33007d7a9780dafa2c740dbe6a3 (patch)
tree       e251b64b68f42d99d2de4ed96b95ca0b0ff1419c
parent     e249e6f8b551614c82cd62e827c3647166e918e3 (diff)
[SPARK-12660][SPARK-14967][SQL] Implement Except Distinct by Left Anti Join
#### What changes were proposed in this pull request?
Replaces a logical `Except` operator with a `Left-anti Join` operator. This way, we can take advantage of all the benefits of join implementations (e.g. managed memory, code generation, broadcast joins).
```SQL
SELECT a1, a2 FROM Tab1 EXCEPT SELECT b1, b2 FROM Tab2
==>  SELECT DISTINCT a1, a2 FROM Tab1 LEFT ANTI JOIN Tab2 ON a1<=>b1 AND a2<=>b2
```
Note:
1. This rule is only applicable to EXCEPT DISTINCT. Do not use it for EXCEPT ALL.
2. This rule has to be done after de-duplicating the attributes; otherwise, the generated join conditions will be incorrect.

This PR also corrects the existing behavior in Spark. Before this PR, the behavior was:
```scala
test("except") {
  val df_left = Seq(1, 2, 2, 3, 3, 4).toDF("id")
  val df_right = Seq(1, 3).toDF("id")

  checkAnswer(
    df_left.except(df_right),
    Row(2) :: Row(2) :: Row(4) :: Nil
  )
}
```
After this PR, the result is corrected: we strictly follow the SQL semantics of `EXCEPT DISTINCT`.

#### How was this patch tested?
Modified and added a few test cases to verify the optimization rule and the results of the operators.

Author: gatorsmile <gatorsmile@gmail.com>

Closes #12736 from gatorsmile/exceptByAntiJoin.
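To see the rewrite in isolation, the sketch below (not part of this patch) applies the rules by hand using Catalyst's test DSL. The relations `table1` and `table2` are hypothetical; the only identifiers taken from this commit are `ReplaceExceptWithAntiJoin` and the pre-existing `ReplaceDistinctWithAggregate` rule, and the snippet assumes the `spark-catalyst` test utilities are on the classpath.
```scala
import org.apache.spark.sql.catalyst.dsl.expressions._
import org.apache.spark.sql.catalyst.dsl.plans._
import org.apache.spark.sql.catalyst.optimizer.{ReplaceDistinctWithAggregate, ReplaceExceptWithAntiJoin}
import org.apache.spark.sql.catalyst.plans.logical.{Except, LocalRelation, LogicalPlan}

// Hypothetical two-column relations; any pair with matching arity and types works.
val table1 = LocalRelation('a.int, 'b.int)
val table2 = LocalRelation('c.int, 'd.int)

// EXCEPT DISTINCT expressed as a logical plan.
val plan: LogicalPlan = Except(table1, table2).analyze

// The new rule rewrites Except into Distinct over a null-safe left-anti join
// (ON a <=> c AND b <=> d); ReplaceDistinctWithAggregate then turns the
// Distinct into an Aggregate, the shape the ReplaceOperatorSuite test expects.
val rewritten = ReplaceDistinctWithAggregate(ReplaceExceptWithAntiJoin(plan))
println(rewritten.numberedTreeString)
```
Running the two rules manually mirrors the "Replace Operators" optimizer batch this commit extends; in a real session the optimizer applies them automatically, so `df1.except(df2)` plans as a left-anti join followed by an aggregate.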
-rw-r--r--  sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala                    |  2
-rw-r--r--  sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala               | 11
-rw-r--r--  sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala                  | 60
-rw-r--r--  sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala  |  6
-rw-r--r--  sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercionSuite.scala       |  8
-rw-r--r--  sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/FilterPushdownSuite.scala        | 34
-rw-r--r--  sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/ReplaceOperatorSuite.scala       | 17
-rw-r--r--  sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/SetOperationSuite.scala          | 16
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala                         |  5
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/execution/basicPhysicalOperators.scala                  | 12
-rw-r--r--  sql/core/src/test/java/test/org/apache/spark/sql/JavaDatasetSuite.java                               |  2
-rw-r--r--  sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala                                    | 70
12 files changed, 132 insertions(+), 111 deletions(-)
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
index e37d9760cc..f6a65f7e6c 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
@@ -530,6 +530,8 @@ class Analyzer(
j.copy(right = dedupRight(left, right))
case i @ Intersect(left, right) if !i.duplicateResolved =>
i.copy(right = dedupRight(left, right))
+ case i @ Except(left, right) if !i.duplicateResolved =>
+ i.copy(right = dedupRight(left, right))
// When resolve `SortOrder`s in Sort based on child, don't report errors as
// we still have chance to resolve it based on its descendants
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala
index 6b737d6b78..74f434e063 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala
@@ -283,7 +283,16 @@ trait CheckAnalysis extends PredicateHelper {
|Failure when resolving conflicting references in Intersect:
|$plan
|Conflicting attributes: ${conflictingAttributes.mkString(",")}
- |""".stripMargin)
+ """.stripMargin)
+
+ case e: Except if !e.duplicateResolved =>
+ val conflictingAttributes = e.left.outputSet.intersect(e.right.outputSet)
+ failAnalysis(
+ s"""
+ |Failure when resolving conflicting references in Except:
+ |$plan
+ |Conflicting attributes: ${conflictingAttributes.mkString(",")}
+ """.stripMargin)
case o if !o.resolved =>
failAnalysis(
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala
index 54bf4a5293..434c033c49 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala
@@ -65,6 +65,7 @@ abstract class Optimizer(sessionCatalog: SessionCatalog, conf: CatalystConf)
CombineUnions) ::
Batch("Replace Operators", fixedPoint,
ReplaceIntersectWithSemiJoin,
+ ReplaceExceptWithAntiJoin,
ReplaceDistinctWithAggregate) ::
Batch("Aggregate", fixedPoint,
RemoveLiteralFromGroupExpressions) ::
@@ -232,17 +233,12 @@ object LimitPushDown extends Rule[LogicalPlan] {
}
/**
- * Pushes certain operations to both sides of a Union or Except operator.
+ * Pushes certain operations to both sides of a Union operator.
* Operations that are safe to pushdown are listed as follows.
* Union:
* Right now, Union means UNION ALL, which does not de-duplicate rows. So, it is
* safe to pushdown Filters and Projections through it. Once we add UNION DISTINCT,
* we will not be able to pushdown Projections.
- *
- * Except:
- * It is not safe to pushdown Projections through it because we need to get the
- * intersect of rows by comparing the entire rows. It is fine to pushdown Filters
- * with deterministic condition.
*/
object SetOperationPushDown extends Rule[LogicalPlan] with PredicateHelper {
@@ -310,17 +306,6 @@ object SetOperationPushDown extends Rule[LogicalPlan] with PredicateHelper {
Filter(pushToRight(deterministic, rewrites), child)
}
Filter(nondeterministic, Union(newFirstChild +: newOtherChildren))
-
- // Push down filter through EXCEPT
- case Filter(condition, Except(left, right)) =>
- val (deterministic, nondeterministic) = partitionByDeterministic(condition)
- val rewrites = buildRewrites(left, right)
- Filter(nondeterministic,
- Except(
- Filter(deterministic, left),
- Filter(pushToRight(deterministic, rewrites), right)
- )
- )
}
}
@@ -1007,16 +992,15 @@ object PushDownPredicate extends Rule[LogicalPlan] with PredicateHelper {
filter
}
- case filter @ Filter(condition, child)
- if child.isInstanceOf[Union] || child.isInstanceOf[Intersect] =>
- // Union/Intersect could change the rows, so non-deterministic predicate can't be pushed down
+ case filter @ Filter(condition, union: Union) =>
+ // Union could change the rows, so non-deterministic predicate can't be pushed down
val (pushDown, stayUp) = splitConjunctivePredicates(condition).partition { cond =>
cond.deterministic
}
if (pushDown.nonEmpty) {
val pushDownCond = pushDown.reduceLeft(And)
- val output = child.output
- val newGrandChildren = child.children.map { grandchild =>
+ val output = union.output
+ val newGrandChildren = union.children.map { grandchild =>
val newCond = pushDownCond transform {
case e if output.exists(_.semanticEquals(e)) =>
grandchild.output(output.indexWhere(_.semanticEquals(e)))
@@ -1024,21 +1008,16 @@ object PushDownPredicate extends Rule[LogicalPlan] with PredicateHelper {
assert(newCond.references.subsetOf(grandchild.outputSet))
Filter(newCond, grandchild)
}
- val newChild = child.withNewChildren(newGrandChildren)
+ val newUnion = union.withNewChildren(newGrandChildren)
if (stayUp.nonEmpty) {
- Filter(stayUp.reduceLeft(And), newChild)
+ Filter(stayUp.reduceLeft(And), newUnion)
} else {
- newChild
+ newUnion
}
} else {
filter
}
- case filter @ Filter(condition, e @ Except(left, _)) =>
- pushDownPredicate(filter, e.left) { predicate =>
- e.copy(left = Filter(predicate, left))
- }
-
// two filters should be combine together by other rules
case filter @ Filter(_, f: Filter) => filter
// should not push predicates through sample, or will generate different results.
@@ -1423,6 +1402,27 @@ object ReplaceIntersectWithSemiJoin extends Rule[LogicalPlan] {
}
/**
+ * Replaces logical [[Except]] operator with a left-anti [[Join]] operator.
+ * {{{
+ * SELECT a1, a2 FROM Tab1 EXCEPT SELECT b1, b2 FROM Tab2
+ * ==> SELECT DISTINCT a1, a2 FROM Tab1 LEFT ANTI JOIN Tab2 ON a1<=>b1 AND a2<=>b2
+ * }}}
+ *
+ * Note:
+ * 1. This rule is only applicable to EXCEPT DISTINCT. Do not use it for EXCEPT ALL.
+ * 2. This rule has to be done after de-duplicating the attributes; otherwise, the generated
+ * join conditions will be incorrect.
+ */
+object ReplaceExceptWithAntiJoin extends Rule[LogicalPlan] {
+ def apply(plan: LogicalPlan): LogicalPlan = plan transform {
+ case Except(left, right) =>
+ assert(left.output.size == right.output.size)
+ val joinCond = left.output.zip(right.output).map { case (l, r) => EqualNullSafe(l, r) }
+ Distinct(Join(left, right, LeftAnti, joinCond.reduceLeftOption(And)))
+ }
+}
+
+/**
* Removes literals from group expressions in [[Aggregate]], as they have no effect to the result
* but only makes the grouping key bigger.
*/
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala
index a445ce6947..b358e210da 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala
@@ -165,6 +165,9 @@ case class Intersect(left: LogicalPlan, right: LogicalPlan) extends SetOperation
}
case class Except(left: LogicalPlan, right: LogicalPlan) extends SetOperation(left, right) {
+
+ def duplicateResolved: Boolean = left.outputSet.intersect(right.outputSet).isEmpty
+
/** We don't use right.output because those rows get excluded from the set. */
override def output: Seq[Attribute] = left.output
@@ -173,7 +176,8 @@ case class Except(left: LogicalPlan, right: LogicalPlan) extends SetOperation(le
override lazy val resolved: Boolean =
childrenResolved &&
left.output.length == right.output.length &&
- left.output.zip(right.output).forall { case (l, r) => l.dataType == r.dataType }
+ left.output.zip(right.output).forall { case (l, r) => l.dataType == r.dataType } &&
+ duplicateResolved
override def statistics: Statistics = {
Statistics(sizeInBytes = left.statistics.sizeInBytes)
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercionSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercionSuite.scala
index 18de8b152b..b591861ac0 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercionSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercionSuite.scala
@@ -488,14 +488,6 @@ class HiveTypeCoercionSuite extends PlanTest {
assert(r1.right.isInstanceOf[Project])
assert(r2.left.isInstanceOf[Project])
assert(r2.right.isInstanceOf[Project])
-
- val r3 = wt(Except(firstTable, firstTable)).asInstanceOf[Except]
- checkOutput(r3.left, Seq(IntegerType, DecimalType.SYSTEM_DEFAULT, ByteType, DoubleType))
- checkOutput(r3.right, Seq(IntegerType, DecimalType.SYSTEM_DEFAULT, ByteType, DoubleType))
-
- // Check if no Project is added
- assert(r3.left.isInstanceOf[LocalRelation])
- assert(r3.right.isInstanceOf[LocalRelation])
}
test("WidenSetOperationTypes for union") {
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/FilterPushdownSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/FilterPushdownSuite.scala
index e2cc80c564..e9b4bb002b 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/FilterPushdownSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/FilterPushdownSuite.scala
@@ -710,40 +710,6 @@ class FilterPushdownSuite extends PlanTest {
comparePlans(optimized, correctAnswer)
}
- test("intersect") {
- val testRelation2 = LocalRelation('d.int, 'e.int, 'f.int)
-
- val originalQuery = Intersect(testRelation, testRelation2)
- .where('a === 2L && 'b + Rand(10).as("rnd") === 3)
-
- val optimized = Optimize.execute(originalQuery.analyze)
-
- val correctAnswer = Intersect(
- testRelation.where('a === 2L),
- testRelation2.where('d === 2L))
- .where('b + Rand(10).as("rnd") === 3)
- .analyze
-
- comparePlans(optimized, correctAnswer)
- }
-
- test("except") {
- val testRelation2 = LocalRelation('d.int, 'e.int, 'f.int)
-
- val originalQuery = Except(testRelation, testRelation2)
- .where('a === 2L && 'b + Rand(10).as("rnd") === 3)
-
- val optimized = Optimize.execute(originalQuery.analyze)
-
- val correctAnswer = Except(
- testRelation.where('a === 2L),
- testRelation2)
- .where('b + Rand(10).as("rnd") === 3)
- .analyze
-
- comparePlans(optimized, correctAnswer)
- }
-
test("expand") {
val agg = testRelation
.groupBy(Cube(Seq('a, 'b)))('a, 'b, sum('c))
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/ReplaceOperatorSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/ReplaceOperatorSuite.scala
index f8ae5d9be2..f23e262f28 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/ReplaceOperatorSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/ReplaceOperatorSuite.scala
@@ -19,7 +19,7 @@ package org.apache.spark.sql.catalyst.optimizer
import org.apache.spark.sql.catalyst.dsl.expressions._
import org.apache.spark.sql.catalyst.dsl.plans._
-import org.apache.spark.sql.catalyst.plans.{LeftSemi, PlanTest}
+import org.apache.spark.sql.catalyst.plans.{LeftAnti, LeftSemi, PlanTest}
import org.apache.spark.sql.catalyst.plans.logical._
import org.apache.spark.sql.catalyst.rules.RuleExecutor
@@ -29,6 +29,7 @@ class ReplaceOperatorSuite extends PlanTest {
val batches =
Batch("Replace Operators", FixedPoint(100),
ReplaceDistinctWithAggregate,
+ ReplaceExceptWithAntiJoin,
ReplaceIntersectWithSemiJoin) :: Nil
}
@@ -46,6 +47,20 @@ class ReplaceOperatorSuite extends PlanTest {
comparePlans(optimized, correctAnswer)
}
+ test("replace Except with Left-anti Join") {
+ val table1 = LocalRelation('a.int, 'b.int)
+ val table2 = LocalRelation('c.int, 'd.int)
+
+ val query = Except(table1, table2)
+ val optimized = Optimize.execute(query.analyze)
+
+ val correctAnswer =
+ Aggregate(table1.output, table1.output,
+ Join(table1, table2, LeftAnti, Option('a <=> 'c && 'b <=> 'd))).analyze
+
+ comparePlans(optimized, correctAnswer)
+ }
+
test("replace Distinct with Aggregate") {
val input = LocalRelation('a.int, 'b.int)
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/SetOperationSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/SetOperationSuite.scala
index b08cdc8a36..83ca9d5ec9 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/SetOperationSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/SetOperationSuite.scala
@@ -39,7 +39,6 @@ class SetOperationSuite extends PlanTest {
val testRelation2 = LocalRelation('d.int, 'e.int, 'f.int)
val testRelation3 = LocalRelation('g.int, 'h.int, 'i.int)
val testUnion = Union(testRelation :: testRelation2 :: testRelation3 :: Nil)
- val testExcept = Except(testRelation, testRelation2)
test("union: combine unions into one unions") {
val unionQuery1 = Union(Union(testRelation, testRelation2), testRelation)
@@ -56,15 +55,6 @@ class SetOperationSuite extends PlanTest {
comparePlans(combinedUnionsOptimized, unionOptimized3)
}
- test("except: filter to each side") {
- val exceptQuery = testExcept.where('c >= 5)
- val exceptOptimized = Optimize.execute(exceptQuery.analyze)
- val exceptCorrectAnswer =
- Except(testRelation.where('c >= 5), testRelation2.where('f >= 5)).analyze
-
- comparePlans(exceptOptimized, exceptCorrectAnswer)
- }
-
test("union: filter to each side") {
val unionQuery = testUnion.where('a === 1)
val unionOptimized = Optimize.execute(unionQuery.analyze)
@@ -85,10 +75,4 @@ class SetOperationSuite extends PlanTest {
testRelation3.select('g) :: Nil).analyze
comparePlans(unionOptimized, unionCorrectAnswer)
}
-
- test("SPARK-10539: Project should not be pushed down through Intersect or Except") {
- val exceptQuery = testExcept.select('a, 'b, 'c)
- val exceptOptimized = Optimize.execute(exceptQuery.analyze)
- comparePlans(exceptOptimized, exceptQuery.analyze)
- }
}
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala
index 3955c5dc92..1eb1f8ef11 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala
@@ -297,6 +297,9 @@ private[sql] abstract class SparkStrategies extends QueryPlanner[SparkPlan] {
case logical.Intersect(left, right) =>
throw new IllegalStateException(
"logical intersect operator should have been replaced by semi-join in the optimizer")
+ case logical.Except(left, right) =>
+ throw new IllegalStateException(
+ "logical except operator should have been replaced by anti-join in the optimizer")
case logical.DeserializeToObject(deserializer, objAttr, child) =>
execution.DeserializeToObject(deserializer, objAttr, planLater(child)) :: Nil
@@ -347,8 +350,6 @@ private[sql] abstract class SparkStrategies extends QueryPlanner[SparkPlan] {
execution.GlobalLimitExec(limit, planLater(child)) :: Nil
case logical.Union(unionChildren) =>
execution.UnionExec(unionChildren.map(planLater)) :: Nil
- case logical.Except(left, right) =>
- execution.ExceptExec(planLater(left), planLater(right)) :: Nil
case g @ logical.Generate(generator, join, outer, _, _, child) =>
execution.GenerateExec(
generator, join = join, outer = outer, g.output, planLater(child)) :: Nil
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/basicPhysicalOperators.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/basicPhysicalOperators.scala
index 77be613b83..d492fa7c41 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/basicPhysicalOperators.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/basicPhysicalOperators.scala
@@ -492,18 +492,6 @@ case class CoalesceExec(numPartitions: Int, child: SparkPlan) extends UnaryExecN
}
/**
- * Physical plan for returning a table with the elements from left that are not in right using
- * the built-in spark subtract function.
- */
-case class ExceptExec(left: SparkPlan, right: SparkPlan) extends BinaryExecNode {
- override def output: Seq[Attribute] = left.output
-
- protected override def doExecute(): RDD[InternalRow] = {
- left.execute().map(_.copy()).subtract(right.execute().map(_.copy()))
- }
-}
-
-/**
* A plan node that does nothing but lie about the output of its child. Used to spice a
* (hopefully structurally equivalent) tree from a different optimization sequence into an already
* resolved tree.
diff --git a/sql/core/src/test/java/test/org/apache/spark/sql/JavaDatasetSuite.java b/sql/core/src/test/java/test/org/apache/spark/sql/JavaDatasetSuite.java
index 5abd62cbc2..f1b1c22e4a 100644
--- a/sql/core/src/test/java/test/org/apache/spark/sql/JavaDatasetSuite.java
+++ b/sql/core/src/test/java/test/org/apache/spark/sql/JavaDatasetSuite.java
@@ -291,7 +291,7 @@ public class JavaDatasetSuite implements Serializable {
unioned.collectAsList());
Dataset<String> subtracted = ds.except(ds2);
- Assert.assertEquals(Arrays.asList("abc", "abc"), subtracted.collectAsList());
+ Assert.assertEquals(Arrays.asList("abc"), subtracted.collectAsList());
}
private static <T> Set<T> toSet(List<T> records) {
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala
index 681476b6e2..f10d8372ed 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala
@@ -398,6 +398,66 @@ class DataFrameSuite extends QueryTest with SharedSQLContext {
Row(4, "d") :: Nil)
checkAnswer(lowerCaseData.except(lowerCaseData), Nil)
checkAnswer(upperCaseData.except(upperCaseData), Nil)
+
+ // check null equality
+ checkAnswer(
+ nullInts.except(nullInts.filter("0 = 1")),
+ nullInts)
+ checkAnswer(
+ nullInts.except(nullInts),
+ Nil)
+
+ // check if values are de-duplicated
+ checkAnswer(
+ allNulls.except(allNulls.filter("0 = 1")),
+ Row(null) :: Nil)
+ checkAnswer(
+ allNulls.except(allNulls),
+ Nil)
+
+ // check if values are de-duplicated
+ val df = Seq(("id1", 1), ("id1", 1), ("id", 1), ("id1", 2)).toDF("id", "value")
+ checkAnswer(
+ df.except(df.filter("0 = 1")),
+ Row("id1", 1) ::
+ Row("id", 1) ::
+ Row("id1", 2) :: Nil)
+
+ // check if the empty set on the left side works
+ checkAnswer(
+ allNulls.filter("0 = 1").except(allNulls),
+ Nil)
+ }
+
+ test("except distinct - SQL compliance") {
+ val df_left = Seq(1, 2, 2, 3, 3, 4).toDF("id")
+ val df_right = Seq(1, 3).toDF("id")
+
+ checkAnswer(
+ df_left.except(df_right),
+ Row(2) :: Row(4) :: Nil
+ )
+ }
+
+ test("except - nullability") {
+ val nonNullableInts = Seq(Tuple1(11), Tuple1(3)).toDF()
+ assert(nonNullableInts.schema.forall(!_.nullable))
+
+ val df1 = nonNullableInts.except(nullInts)
+ checkAnswer(df1, Row(11) :: Nil)
+ assert(df1.schema.forall(!_.nullable))
+
+ val df2 = nullInts.except(nonNullableInts)
+ checkAnswer(df2, Row(1) :: Row(2) :: Row(null) :: Nil)
+ assert(df2.schema.forall(_.nullable))
+
+ val df3 = nullInts.except(nullInts)
+ checkAnswer(df3, Nil)
+ assert(df3.schema.forall(_.nullable))
+
+ val df4 = nonNullableInts.except(nonNullableInts)
+ checkAnswer(df4, Nil)
+ assert(df4.schema.forall(!_.nullable))
}
test("intersect") {
@@ -433,23 +493,23 @@ class DataFrameSuite extends QueryTest with SharedSQLContext {
test("intersect - nullability") {
val nonNullableInts = Seq(Tuple1(1), Tuple1(3)).toDF()
- assert(nonNullableInts.schema.forall(_.nullable == false))
+ assert(nonNullableInts.schema.forall(!_.nullable))
val df1 = nonNullableInts.intersect(nullInts)
checkAnswer(df1, Row(1) :: Row(3) :: Nil)
- assert(df1.schema.forall(_.nullable == false))
+ assert(df1.schema.forall(!_.nullable))
val df2 = nullInts.intersect(nonNullableInts)
checkAnswer(df2, Row(1) :: Row(3) :: Nil)
- assert(df2.schema.forall(_.nullable == false))
+ assert(df2.schema.forall(!_.nullable))
val df3 = nullInts.intersect(nullInts)
checkAnswer(df3, Row(1) :: Row(2) :: Row(3) :: Row(null) :: Nil)
- assert(df3.schema.forall(_.nullable == true))
+ assert(df3.schema.forall(_.nullable))
val df4 = nonNullableInts.intersect(nonNullableInts)
checkAnswer(df4, Row(1) :: Row(3) :: Nil)
- assert(df4.schema.forall(_.nullable == false))
+ assert(df4.schema.forall(!_.nullable))
}
test("udf") {