aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorgatorsmile <gatorsmile@gmail.com>2016-01-19 11:35:58 -0800
committerCheng Lian <lian@databricks.com>2016-01-19 11:35:58 -0800
commitb72e01e82148a908eb19bb3f526f9777bfe27dde (patch)
treebf93b6df8b5b569c238b75ba9474cc76407f7fb7
parent2388de51912efccaceeb663ac56fc500a79d2ceb (diff)
downloadspark-b72e01e82148a908eb19bb3f526f9777bfe27dde.tar.gz
spark-b72e01e82148a908eb19bb3f526f9777bfe27dde.tar.bz2
spark-b72e01e82148a908eb19bb3f526f9777bfe27dde.zip
[SPARK-12867][SQL] Nullability of Intersect can be stricter
JIRA: https://issues.apache.org/jira/browse/SPARK-12867 When intersecting one nullable column with one non-nullable column, the result will not contain any null. Thus, we can make nullability of `intersect` stricter. liancheng Could you please check if the code changes are appropriate? Also added test cases to verify the results. Thanks! Author: gatorsmile <gatorsmile@gmail.com> Closes #10812 from gatorsmile/nullabilityIntersect.
-rw-r--r--sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala18
-rw-r--r--sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala21
2 files changed, 33 insertions, 6 deletions
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala
index 2a1b1b131d..f4a3d85d2a 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala
@@ -91,11 +91,6 @@ case class Filter(condition: Expression, child: LogicalPlan) extends UnaryNode {
}
abstract class SetOperation(left: LogicalPlan, right: LogicalPlan) extends BinaryNode {
- override def output: Seq[Attribute] =
- left.output.zip(right.output).map { case (leftAttr, rightAttr) =>
- leftAttr.withNullability(leftAttr.nullable || rightAttr.nullable)
- }
-
final override lazy val resolved: Boolean =
childrenResolved &&
left.output.length == right.output.length &&
@@ -108,13 +103,24 @@ private[sql] object SetOperation {
case class Union(left: LogicalPlan, right: LogicalPlan) extends SetOperation(left, right) {
+ override def output: Seq[Attribute] =
+ left.output.zip(right.output).map { case (leftAttr, rightAttr) =>
+ leftAttr.withNullability(leftAttr.nullable || rightAttr.nullable)
+ }
+
override def statistics: Statistics = {
val sizeInBytes = left.statistics.sizeInBytes + right.statistics.sizeInBytes
Statistics(sizeInBytes = sizeInBytes)
}
}
-case class Intersect(left: LogicalPlan, right: LogicalPlan) extends SetOperation(left, right)
+case class Intersect(left: LogicalPlan, right: LogicalPlan) extends SetOperation(left, right) {
+
+ override def output: Seq[Attribute] =
+ left.output.zip(right.output).map { case (leftAttr, rightAttr) =>
+ leftAttr.withNullability(leftAttr.nullable && rightAttr.nullable)
+ }
+}
case class Except(left: LogicalPlan, right: LogicalPlan) extends SetOperation(left, right) {
/** We don't use right.output because those rows get excluded from the set. */
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala
index afc8df07fd..bd11a387a1 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala
@@ -337,6 +337,27 @@ class DataFrameSuite extends QueryTest with SharedSQLContext {
checkAnswer(lowerCaseData.intersect(upperCaseData), Nil)
}
+ test("intersect - nullability") {
+ val nonNullableInts = Seq(Tuple1(1), Tuple1(3)).toDF()
+ assert(nonNullableInts.schema.forall(_.nullable == false))
+
+ val df1 = nonNullableInts.intersect(nullInts)
+ checkAnswer(df1, Row(1) :: Row(3) :: Nil)
+ assert(df1.schema.forall(_.nullable == false))
+
+ val df2 = nullInts.intersect(nonNullableInts)
+ checkAnswer(df2, Row(1) :: Row(3) :: Nil)
+ assert(df2.schema.forall(_.nullable == false))
+
+ val df3 = nullInts.intersect(nullInts)
+ checkAnswer(df3, Row(1) :: Row(2) :: Row(3) :: Row(null) :: Nil)
+ assert(df3.schema.forall(_.nullable == true))
+
+ val df4 = nonNullableInts.intersect(nonNullableInts)
+ checkAnswer(df4, Row(1) :: Row(3) :: Nil)
+ assert(df4.schema.forall(_.nullable == false))
+ }
+
test("udf") {
val foo = udf((a: Int, b: String) => a.toString + b)