diff options
Diffstat (limited to 'sql/core/src')
4 files changed, 69 insertions, 20 deletions
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala index 3955c5dc92..1eb1f8ef11 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala @@ -297,6 +297,9 @@ private[sql] abstract class SparkStrategies extends QueryPlanner[SparkPlan] { case logical.Intersect(left, right) => throw new IllegalStateException( "logical intersect operator should have been replaced by semi-join in the optimizer") + case logical.Except(left, right) => + throw new IllegalStateException( + "logical except operator should have been replaced by anti-join in the optimizer") case logical.DeserializeToObject(deserializer, objAttr, child) => execution.DeserializeToObject(deserializer, objAttr, planLater(child)) :: Nil @@ -347,8 +350,6 @@ private[sql] abstract class SparkStrategies extends QueryPlanner[SparkPlan] { execution.GlobalLimitExec(limit, planLater(child)) :: Nil case logical.Union(unionChildren) => execution.UnionExec(unionChildren.map(planLater)) :: Nil - case logical.Except(left, right) => - execution.ExceptExec(planLater(left), planLater(right)) :: Nil case g @ logical.Generate(generator, join, outer, _, _, child) => execution.GenerateExec( generator, join = join, outer = outer, g.output, planLater(child)) :: Nil diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/basicPhysicalOperators.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/basicPhysicalOperators.scala index 77be613b83..d492fa7c41 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/basicPhysicalOperators.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/basicPhysicalOperators.scala @@ -492,18 +492,6 @@ case class CoalesceExec(numPartitions: Int, child: SparkPlan) extends UnaryExecN } /** - * Physical plan for returning a table with the elements from left that are not in right using - * the built-in spark subtract function. - */ -case class ExceptExec(left: SparkPlan, right: SparkPlan) extends BinaryExecNode { - override def output: Seq[Attribute] = left.output - - protected override def doExecute(): RDD[InternalRow] = { - left.execute().map(_.copy()).subtract(right.execute().map(_.copy())) - } -} - -/** * A plan node that does nothing but lie about the output of its child. Used to spice a * (hopefully structurally equivalent) tree from a different optimization sequence into an already * resolved tree. diff --git a/sql/core/src/test/java/test/org/apache/spark/sql/JavaDatasetSuite.java b/sql/core/src/test/java/test/org/apache/spark/sql/JavaDatasetSuite.java index 5abd62cbc2..f1b1c22e4a 100644 --- a/sql/core/src/test/java/test/org/apache/spark/sql/JavaDatasetSuite.java +++ b/sql/core/src/test/java/test/org/apache/spark/sql/JavaDatasetSuite.java @@ -291,7 +291,7 @@ public class JavaDatasetSuite implements Serializable { unioned.collectAsList()); Dataset<String> subtracted = ds.except(ds2); - Assert.assertEquals(Arrays.asList("abc", "abc"), subtracted.collectAsList()); + Assert.assertEquals(Arrays.asList("abc"), subtracted.collectAsList()); } private static <T> Set<T> toSet(List<T> records) { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala index 681476b6e2..f10d8372ed 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala @@ -398,6 +398,66 @@ class DataFrameSuite extends QueryTest with SharedSQLContext { Row(4, "d") :: Nil) checkAnswer(lowerCaseData.except(lowerCaseData), Nil) checkAnswer(upperCaseData.except(upperCaseData), Nil) + + // check null equality + checkAnswer( + nullInts.except(nullInts.filter("0 = 1")), + nullInts) + checkAnswer( + nullInts.except(nullInts), + Nil) + + // check if values are de-duplicated + checkAnswer( + allNulls.except(allNulls.filter("0 = 1")), + Row(null) :: Nil) + checkAnswer( + allNulls.except(allNulls), + Nil) + + // check if values are de-duplicated + val df = Seq(("id1", 1), ("id1", 1), ("id", 1), ("id1", 2)).toDF("id", "value") + checkAnswer( + df.except(df.filter("0 = 1")), + Row("id1", 1) :: + Row("id", 1) :: + Row("id1", 2) :: Nil) + + // check if the empty set on the left side works + checkAnswer( + allNulls.filter("0 = 1").except(allNulls), + Nil) + } + + test("except distinct - SQL compliance") { + val df_left = Seq(1, 2, 2, 3, 3, 4).toDF("id") + val df_right = Seq(1, 3).toDF("id") + + checkAnswer( + df_left.except(df_right), + Row(2) :: Row(4) :: Nil + ) + } + + test("except - nullability") { + val nonNullableInts = Seq(Tuple1(11), Tuple1(3)).toDF() + assert(nonNullableInts.schema.forall(!_.nullable)) + + val df1 = nonNullableInts.except(nullInts) + checkAnswer(df1, Row(11) :: Nil) + assert(df1.schema.forall(!_.nullable)) + + val df2 = nullInts.except(nonNullableInts) + checkAnswer(df2, Row(1) :: Row(2) :: Row(null) :: Nil) + assert(df2.schema.forall(_.nullable)) + + val df3 = nullInts.except(nullInts) + checkAnswer(df3, Nil) + assert(df3.schema.forall(_.nullable)) + + val df4 = nonNullableInts.except(nonNullableInts) + checkAnswer(df4, Nil) + assert(df4.schema.forall(!_.nullable)) } test("intersect") { @@ -433,23 +493,23 @@ class DataFrameSuite extends QueryTest with SharedSQLContext { test("intersect - nullability") { val nonNullableInts = Seq(Tuple1(1), Tuple1(3)).toDF() - assert(nonNullableInts.schema.forall(_.nullable == false)) + assert(nonNullableInts.schema.forall(!_.nullable)) val df1 = nonNullableInts.intersect(nullInts) checkAnswer(df1, Row(1) :: Row(3) :: Nil) - assert(df1.schema.forall(_.nullable == false)) + assert(df1.schema.forall(!_.nullable)) val df2 = nullInts.intersect(nonNullableInts) checkAnswer(df2, Row(1) :: Row(3) :: Nil) - assert(df2.schema.forall(_.nullable == false)) + assert(df2.schema.forall(!_.nullable)) val df3 = nullInts.intersect(nullInts) checkAnswer(df3, Row(1) :: Row(2) :: Row(3) :: Row(null) :: Nil) - assert(df3.schema.forall(_.nullable == true)) + assert(df3.schema.forall(_.nullable)) val df4 = nonNullableInts.intersect(nonNullableInts) checkAnswer(df4, Row(1) :: Row(3) :: Nil) - assert(df4.schema.forall(_.nullable == false)) + assert(df4.schema.forall(!_.nullable)) } test("udf") { |