path: root/sql/core/src
author    gatorsmile <gatorsmile@gmail.com>    2016-04-29 15:30:36 +0800
committer Wenchen Fan <wenchen@databricks.com> 2016-04-29 15:30:36 +0800
commit 222dcf79377df33007d7a9780dafa2c740dbe6a3 (patch)
tree   e251b64b68f42d99d2de4ed96b95ca0b0ff1419c /sql/core/src
parent e249e6f8b551614c82cd62e827c3647166e918e3 (diff)
download spark-222dcf79377df33007d7a9780dafa2c740dbe6a3.tar.gz
spark-222dcf79377df33007d7a9780dafa2c740dbe6a3.tar.bz2
spark-222dcf79377df33007d7a9780dafa2c740dbe6a3.zip
[SPARK-12660][SPARK-14967][SQL] Implement Except Distinct by Left Anti Join
#### What changes were proposed in this pull request?

Replaces a logical `Except` operator with a `Left-anti Join` operator. This way, we can take advantage of all the benefits of join implementations (e.g. managed memory, code generation, broadcast joins).

```SQL
SELECT a1, a2 FROM Tab1 EXCEPT SELECT b1, b2 FROM Tab2
==>
SELECT DISTINCT a1, a2 FROM Tab1 LEFT ANTI JOIN Tab2 ON a1<=>b1 AND a2<=>b2
```

Note:
1. This rule is only applicable to EXCEPT DISTINCT. Do not use it for EXCEPT ALL.
2. This rule has to be applied after de-duplicating the attributes; otherwise, the generated join conditions will be incorrect.

This PR also corrects the existing behavior in Spark. Before this PR, the behavior was:

```scala
test("except") {
  val df_left = Seq(1, 2, 2, 3, 3, 4).toDF("id")
  val df_right = Seq(1, 3).toDF("id")

  checkAnswer(
    df_left.except(df_right),
    Row(2) :: Row(2) :: Row(4) :: Nil
  )
}
```

After this PR, the result is corrected: we strictly follow the SQL semantics of `EXCEPT DISTINCT`.

#### How was this patch tested?

Modified and added a few test cases to verify the optimization rule and the results of operators.

Author: gatorsmile <gatorsmile@gmail.com>

Closes #12736 from gatorsmile/exceptByAntiJoin.
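For reference, here is a minimal sketch of how such a rewrite can be expressed as a Catalyst optimizer rule. The actual rule lives in the catalyst module and is not part of the `sql/core` diff below; the imports and the exact `Join` constructor shape are assumptions based on the 2.0-era APIs.

```scala
import org.apache.spark.sql.catalyst.expressions.{And, EqualNullSafe}
import org.apache.spark.sql.catalyst.plans.LeftAnti
import org.apache.spark.sql.catalyst.plans.logical.{Distinct, Except, Join, LogicalPlan}
import org.apache.spark.sql.catalyst.rules.Rule

// Sketch: rewrite EXCEPT (DISTINCT) into a null-safe left anti join plus Distinct.
object ReplaceExceptWithAntiJoin extends Rule[LogicalPlan] {
  def apply(plan: LogicalPlan): LogicalPlan = plan transform {
    case Except(left, right) =>
      assert(left.output.length == right.output.length)
      // Pair up the output attributes and compare them with <=> so that NULLs match NULLs.
      val joinCond = left.output.zip(right.output).map { case (l, r) => EqualNullSafe(l, r) }
      Distinct(Join(left, right, LeftAnti, joinCond.reduceLeftOption(And)))
  }
}
```

On this plan shape, the planner can then pick a hash- or broadcast-based anti join instead of the RDD-level `subtract` used by the removed `ExceptExec`.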
Diffstat (limited to 'sql/core/src')
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala        |  5
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/execution/basicPhysicalOperators.scala | 12
-rw-r--r--  sql/core/src/test/java/test/org/apache/spark/sql/JavaDatasetSuite.java              |  2
-rw-r--r--  sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala                   | 70
4 files changed, 69 insertions(+), 20 deletions(-)
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala
index 3955c5dc92..1eb1f8ef11 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala
@@ -297,6 +297,9 @@ private[sql] abstract class SparkStrategies extends QueryPlanner[SparkPlan] {
case logical.Intersect(left, right) =>
throw new IllegalStateException(
"logical intersect operator should have been replaced by semi-join in the optimizer")
+ case logical.Except(left, right) =>
+ throw new IllegalStateException(
+ "logical except operator should have been replaced by anti-join in the optimizer")
case logical.DeserializeToObject(deserializer, objAttr, child) =>
execution.DeserializeToObject(deserializer, objAttr, planLater(child)) :: Nil
@@ -347,8 +350,6 @@ private[sql] abstract class SparkStrategies extends QueryPlanner[SparkPlan] {
execution.GlobalLimitExec(limit, planLater(child)) :: Nil
case logical.Union(unionChildren) =>
execution.UnionExec(unionChildren.map(planLater)) :: Nil
- case logical.Except(left, right) =>
- execution.ExceptExec(planLater(left), planLater(right)) :: Nil
case g @ logical.Generate(generator, join, outer, _, _, child) =>
execution.GenerateExec(
generator, join = join, outer = outer, g.output, planLater(child)) :: Nil
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/basicPhysicalOperators.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/basicPhysicalOperators.scala
index 77be613b83..d492fa7c41 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/basicPhysicalOperators.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/basicPhysicalOperators.scala
@@ -492,18 +492,6 @@ case class CoalesceExec(numPartitions: Int, child: SparkPlan) extends UnaryExecN
}
/**
- * Physical plan for returning a table with the elements from left that are not in right using
- * the built-in spark subtract function.
- */
-case class ExceptExec(left: SparkPlan, right: SparkPlan) extends BinaryExecNode {
- override def output: Seq[Attribute] = left.output
-
- protected override def doExecute(): RDD[InternalRow] = {
- left.execute().map(_.copy()).subtract(right.execute().map(_.copy()))
- }
-}
-
-/**
* A plan node that does nothing but lie about the output of its child. Used to spice a
* (hopefully structurally equivalent) tree from a different optimization sequence into an already
* resolved tree.
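Since `ExceptExec` (which relied on `RDD.subtract`) is gone, `except()` now has to go through the anti-join planning path. A quick, illustrative way to check this from `spark-shell` (the sample data and the exact operator names in the printed plan are assumptions, not taken from this patch):

```scala
import spark.implicits._

val left  = Seq(1, 2, 2, 3, 3, 4).toDF("id")
val right = Seq(1, 3).toDF("id")

// The physical plan should contain a LeftAnti join (e.g. a broadcast or
// shuffled hash join) plus an aggregate for the DISTINCT, and no ExceptExec.
left.except(right).explain()

left.except(right).show()  // expected: 2 and 4, each exactly once
```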
diff --git a/sql/core/src/test/java/test/org/apache/spark/sql/JavaDatasetSuite.java b/sql/core/src/test/java/test/org/apache/spark/sql/JavaDatasetSuite.java
index 5abd62cbc2..f1b1c22e4a 100644
--- a/sql/core/src/test/java/test/org/apache/spark/sql/JavaDatasetSuite.java
+++ b/sql/core/src/test/java/test/org/apache/spark/sql/JavaDatasetSuite.java
@@ -291,7 +291,7 @@ public class JavaDatasetSuite implements Serializable {
unioned.collectAsList());
Dataset<String> subtracted = ds.except(ds2);
- Assert.assertEquals(Arrays.asList("abc", "abc"), subtracted.collectAsList());
+ Assert.assertEquals(Arrays.asList("abc"), subtracted.collectAsList());
}
private static <T> Set<T> toSet(List<T> records) {
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala
index 681476b6e2..f10d8372ed 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala
@@ -398,6 +398,66 @@ class DataFrameSuite extends QueryTest with SharedSQLContext {
Row(4, "d") :: Nil)
checkAnswer(lowerCaseData.except(lowerCaseData), Nil)
checkAnswer(upperCaseData.except(upperCaseData), Nil)
+
+ // check null equality
+ checkAnswer(
+ nullInts.except(nullInts.filter("0 = 1")),
+ nullInts)
+ checkAnswer(
+ nullInts.except(nullInts),
+ Nil)
+
+ // check if values are de-duplicated
+ checkAnswer(
+ allNulls.except(allNulls.filter("0 = 1")),
+ Row(null) :: Nil)
+ checkAnswer(
+ allNulls.except(allNulls),
+ Nil)
+
+ // check if values are de-duplicated
+ val df = Seq(("id1", 1), ("id1", 1), ("id", 1), ("id1", 2)).toDF("id", "value")
+ checkAnswer(
+ df.except(df.filter("0 = 1")),
+ Row("id1", 1) ::
+ Row("id", 1) ::
+ Row("id1", 2) :: Nil)
+
+ // check if the empty set on the left side works
+ checkAnswer(
+ allNulls.filter("0 = 1").except(allNulls),
+ Nil)
+ }
+
+ test("except distinct - SQL compliance") {
+ val df_left = Seq(1, 2, 2, 3, 3, 4).toDF("id")
+ val df_right = Seq(1, 3).toDF("id")
+
+ checkAnswer(
+ df_left.except(df_right),
+ Row(2) :: Row(4) :: Nil
+ )
+ }
+
+ test("except - nullability") {
+ val nonNullableInts = Seq(Tuple1(11), Tuple1(3)).toDF()
+ assert(nonNullableInts.schema.forall(!_.nullable))
+
+ val df1 = nonNullableInts.except(nullInts)
+ checkAnswer(df1, Row(11) :: Nil)
+ assert(df1.schema.forall(!_.nullable))
+
+ val df2 = nullInts.except(nonNullableInts)
+ checkAnswer(df2, Row(1) :: Row(2) :: Row(null) :: Nil)
+ assert(df2.schema.forall(_.nullable))
+
+ val df3 = nullInts.except(nullInts)
+ checkAnswer(df3, Nil)
+ assert(df3.schema.forall(_.nullable))
+
+ val df4 = nonNullableInts.except(nonNullableInts)
+ checkAnswer(df4, Nil)
+ assert(df4.schema.forall(!_.nullable))
}
test("intersect") {
@@ -433,23 +493,23 @@ class DataFrameSuite extends QueryTest with SharedSQLContext {
test("intersect - nullability") {
val nonNullableInts = Seq(Tuple1(1), Tuple1(3)).toDF()
- assert(nonNullableInts.schema.forall(_.nullable == false))
+ assert(nonNullableInts.schema.forall(!_.nullable))
val df1 = nonNullableInts.intersect(nullInts)
checkAnswer(df1, Row(1) :: Row(3) :: Nil)
- assert(df1.schema.forall(_.nullable == false))
+ assert(df1.schema.forall(!_.nullable))
val df2 = nullInts.intersect(nonNullableInts)
checkAnswer(df2, Row(1) :: Row(3) :: Nil)
- assert(df2.schema.forall(_.nullable == false))
+ assert(df2.schema.forall(!_.nullable))
val df3 = nullInts.intersect(nullInts)
checkAnswer(df3, Row(1) :: Row(2) :: Row(3) :: Row(null) :: Nil)
- assert(df3.schema.forall(_.nullable == true))
+ assert(df3.schema.forall(_.nullable))
val df4 = nonNullableInts.intersect(nonNullableInts)
checkAnswer(df4, Row(1) :: Row(3) :: Nil)
- assert(df4.schema.forall(_.nullable == false))
+ assert(df4.schema.forall(!_.nullable))
}
test("udf") {