about summary refs log tree commit diff
diff options
context:
space:
mode:
authorDavies Liu <davies@databricks.com>2015-08-28 14:38:20 -0700
committerDavies Liu <davies.liu@gmail.com>2015-08-28 14:38:29 -0700
commit02e10d2df40e18a14c4c388c41699b5b258e57ac (patch)
tree314de1d7441e1af14aa32b90ffd5465a6d123264
parent7f014809de25d1d491c46e09fd88ef6c3d5d0e1b (diff)
downloadspark-02e10d2df40e18a14c4c388c41699b5b258e57ac.tar.gz
spark-02e10d2df40e18a14c4c388c41699b5b258e57ac.tar.bz2
spark-02e10d2df40e18a14c4c388c41699b5b258e57ac.zip
[SPARK-10323] [SQL] fix nullability of In/InSet/ArrayContain
After this PR, In/InSet/ArrayContain will return null if value is null, instead of false. They also will return null even if there is a null in the set/array. Author: Davies Liu <davies@databricks.com> Closes #8492 from davies/fix_in. (cherry picked from commit bb7f35239385ec74b5ee69631b5480fbcee253e4) Signed-off-by: Davies Liu <davies.liu@gmail.com>
-rw-r--r--sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala62
-rw-r--r--sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala71
-rw-r--r--sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala6
-rw-r--r--sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CollectionFunctionsSuite.scala12
-rw-r--r--sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/PredicateSuite.scala49
-rw-r--r--sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/ConstantFoldingSuite.scala21
-rw-r--r--sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala14
7 files changed, 138 insertions, 97 deletions
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala
index 646afa4047..7b8c5b723d 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala
@@ -19,9 +19,7 @@ package org.apache.spark.sql.catalyst.expressions
import java.util.Comparator
import org.apache.spark.sql.catalyst.analysis.TypeCheckResult
-import org.apache.spark.sql.catalyst.expressions.codegen.{
- CodegenFallback, CodeGenContext, GeneratedExpressionCode}
-import org.apache.spark.sql.catalyst.InternalRow
+import org.apache.spark.sql.catalyst.expressions.codegen.{CodeGenContext, CodegenFallback, GeneratedExpressionCode}
import org.apache.spark.sql.types._
/**
@@ -145,46 +143,42 @@ case class ArrayContains(left: Expression, right: Expression)
}
}
- override def nullable: Boolean = false
+ override def nullable: Boolean = {
+ left.nullable || right.nullable || left.dataType.asInstanceOf[ArrayType].containsNull
+ }
- override def eval(input: InternalRow): Boolean = {
- val arr = left.eval(input)
- if (arr == null) {
- false
- } else {
- val value = right.eval(input)
- if (value == null) {
- false
- } else {
- arr.asInstanceOf[ArrayData].foreach(right.dataType, (i, v) =>
- if (v == value) return true
- )
- false
+ override def nullSafeEval(arr: Any, value: Any): Any = {
+ var hasNull = false
+ arr.asInstanceOf[ArrayData].foreach(right.dataType, (i, v) =>
+ if (v == null) {
+ hasNull = true
+ } else if (v == value) {
+ return true
}
+ )
+ if (hasNull) {
+ null
+ } else {
+ false
}
}
override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = {
- val arrGen = left.gen(ctx)
- val elementGen = right.gen(ctx)
- val i = ctx.freshName("i")
- val getValue = ctx.getValue(arrGen.primitive, right.dataType, i)
- s"""
- ${arrGen.code}
- boolean ${ev.isNull} = false;
- boolean ${ev.primitive} = false;
- if (!${arrGen.isNull}) {
- ${elementGen.code}
- if (!${elementGen.isNull}) {
- for (int $i = 0; $i < ${arrGen.primitive}.numElements(); $i ++) {
- if (${ctx.genEqual(right.dataType, elementGen.primitive, getValue)}) {
- ${ev.primitive} = true;
- break;
- }
- }
+ nullSafeCodeGen(ctx, ev, (arr, value) => {
+ val i = ctx.freshName("i")
+ val getValue = ctx.getValue(arr, right.dataType, i)
+ s"""
+ for (int $i = 0; $i < $arr.numElements(); $i ++) {
+ if ($arr.isNullAt($i)) {
+ ${ev.isNull} = true;
+ } else if (${ctx.genEqual(right.dataType, value, getValue)}) {
+ ${ev.isNull} = false;
+ ${ev.primitive} = true;
+ break;
}
}
"""
+ })
}
override def prettyName: String = "array_contains"
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala
index fe7dffb815..65706dba7d 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala
@@ -17,11 +17,9 @@
package org.apache.spark.sql.catalyst.expressions
-import scala.collection.mutable
-
-import org.apache.spark.sql.catalyst.analysis.TypeCheckResult
-import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenFallback, GeneratedExpressionCode, CodeGenContext}
import org.apache.spark.sql.catalyst.InternalRow
+import org.apache.spark.sql.catalyst.analysis.TypeCheckResult
+import org.apache.spark.sql.catalyst.expressions.codegen.{CodeGenContext, GeneratedExpressionCode}
import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
import org.apache.spark.sql.catalyst.util.TypeUtils
import org.apache.spark.sql.types._
@@ -103,6 +101,8 @@ case class Not(child: Expression)
case class In(value: Expression, list: Seq[Expression]) extends Predicate
with ImplicitCastInputTypes {
+ require(list != null, "list should not be null")
+
override def inputTypes: Seq[AbstractDataType] = value.dataType +: list.map(_.dataType)
override def checkInputDataTypes(): TypeCheckResult = {
@@ -116,12 +116,31 @@ case class In(value: Expression, list: Seq[Expression]) extends Predicate
override def children: Seq[Expression] = value +: list
- override def nullable: Boolean = false // TODO: Figure out correct nullability semantics of IN.
+ override def nullable: Boolean = children.exists(_.nullable)
+ override def foldable: Boolean = children.forall(_.foldable)
+
override def toString: String = s"$value IN ${list.mkString("(", ",", ")")}"
override def eval(input: InternalRow): Any = {
val evaluatedValue = value.eval(input)
- list.exists(e => e.eval(input) == evaluatedValue)
+ if (evaluatedValue == null) {
+ null
+ } else {
+ var hasNull = false
+ list.foreach { e =>
+ val v = e.eval(input)
+ if (v == evaluatedValue) {
+ return true
+ } else if (v == null) {
+ hasNull = true
+ }
+ }
+ if (hasNull) {
+ null
+ } else {
+ false
+ }
+ }
}
override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = {
@@ -131,7 +150,10 @@ case class In(value: Expression, list: Seq[Expression]) extends Predicate
s"""
if (!${ev.primitive}) {
${x.code}
- if (${ctx.genEqual(value.dataType, valueGen.primitive, x.primitive)}) {
+ if (${x.isNull}) {
+ ${ev.isNull} = true;
+ } else if (${ctx.genEqual(value.dataType, valueGen.primitive, x.primitive)}) {
+ ${ev.isNull} = false;
${ev.primitive} = true;
}
}
@@ -139,8 +161,10 @@ case class In(value: Expression, list: Seq[Expression]) extends Predicate
s"""
${valueGen.code}
boolean ${ev.primitive} = false;
- boolean ${ev.isNull} = false;
- $listCode
+ boolean ${ev.isNull} = ${valueGen.isNull};
+ if (!${ev.isNull}) {
+ $listCode
+ }
"""
}
}
@@ -151,11 +175,22 @@ case class In(value: Expression, list: Seq[Expression]) extends Predicate
*/
case class InSet(child: Expression, hset: Set[Any]) extends UnaryExpression with Predicate {
- override def nullable: Boolean = false // TODO: Figure out correct nullability semantics of IN.
+ require(hset != null, "hset could not be null")
+
override def toString: String = s"$child INSET ${hset.mkString("(", ",", ")")}"
- override def eval(input: InternalRow): Any = {
- hset.contains(child.eval(input))
+ @transient private[this] lazy val hasNull: Boolean = hset.contains(null)
+
+ override def nullable: Boolean = child.nullable || hasNull
+
+ protected override def nullSafeEval(value: Any): Any = {
+ if (hset.contains(value)) {
+ true
+ } else if (hasNull) {
+ null
+ } else {
+ false
+ }
}
def getHSet(): Set[Any] = hset
@@ -166,12 +201,20 @@ case class InSet(child: Expression, hset: Set[Any]) extends UnaryExpression with
val childGen = child.gen(ctx)
ctx.references += this
val hsetTerm = ctx.freshName("hset")
+ val hasNullTerm = ctx.freshName("hasNull")
ctx.addMutableState(setName, hsetTerm,
s"$hsetTerm = (($InSetName)expressions[${ctx.references.size - 1}]).getHSet();")
+ ctx.addMutableState("boolean", hasNullTerm, s"$hasNullTerm = $hsetTerm.contains(null);")
s"""
${childGen.code}
- boolean ${ev.isNull} = false;
- boolean ${ev.primitive} = $hsetTerm.contains(${childGen.primitive});
+ boolean ${ev.isNull} = ${childGen.isNull};
+ boolean ${ev.primitive} = false;
+ if (!${ev.isNull}) {
+ ${ev.primitive} = $hsetTerm.contains(${childGen.primitive});
+ if (!${ev.primitive} && $hasNullTerm) {
+ ${ev.isNull} = true;
+ }
+ }
"""
}
}
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala
index 854463dd11..a430000bef 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala
@@ -395,12 +395,6 @@ object ConstantFolding extends Rule[LogicalPlan] {
// Fold expressions that are foldable.
case e if e.foldable => Literal.create(e.eval(EmptyRow), e.dataType)
-
- // Fold "literal in (item1, item2, ..., literal, ...)" into true directly.
- case In(Literal(v, _), list) if list.exists {
- case Literal(candidate, _) if candidate == v => true
- case _ => false
- } => Literal.create(true, BooleanType)
}
}
}
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CollectionFunctionsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CollectionFunctionsSuite.scala
index 95f0e38212..a3e81888df 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CollectionFunctionsSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CollectionFunctionsSuite.scala
@@ -70,14 +70,20 @@ class CollectionFunctionsSuite extends SparkFunSuite with ExpressionEvalHelper {
val a0 = Literal.create(Seq(1, 2, 3), ArrayType(IntegerType))
val a1 = Literal.create(Seq[String](null, ""), ArrayType(StringType))
val a2 = Literal.create(Seq(null), ArrayType(LongType))
+ val a3 = Literal.create(null, ArrayType(StringType))
checkEvaluation(ArrayContains(a0, Literal(1)), true)
checkEvaluation(ArrayContains(a0, Literal(0)), false)
- checkEvaluation(ArrayContains(a0, Literal(null)), false)
+ checkEvaluation(ArrayContains(a0, Literal.create(null, IntegerType)), null)
checkEvaluation(ArrayContains(a1, Literal("")), true)
- checkEvaluation(ArrayContains(a1, Literal(null)), false)
+ checkEvaluation(ArrayContains(a1, Literal("a")), null)
+ checkEvaluation(ArrayContains(a1, Literal.create(null, StringType)), null)
- checkEvaluation(ArrayContains(a2, Literal(null)), false)
+ checkEvaluation(ArrayContains(a2, Literal(1L)), null)
+ checkEvaluation(ArrayContains(a2, Literal.create(null, LongType)), null)
+
+ checkEvaluation(ArrayContains(a3, Literal("")), null)
+ checkEvaluation(ArrayContains(a3, Literal.create(null, StringType)), null)
}
}
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/PredicateSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/PredicateSuite.scala
index 54c04faddb..03e7611fce 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/PredicateSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/PredicateSuite.scala
@@ -20,7 +20,6 @@ package org.apache.spark.sql.catalyst.expressions
import scala.collection.immutable.HashSet
import org.apache.spark.SparkFunSuite
-import org.apache.spark.sql.catalyst.dsl.expressions._
import org.apache.spark.sql.RandomDataGenerator
import org.apache.spark.sql.types._
@@ -119,6 +118,12 @@ class PredicateSuite extends SparkFunSuite with ExpressionEvalHelper {
(null, null, null) :: Nil)
test("IN") {
+ checkEvaluation(In(Literal.create(null, IntegerType), Seq(Literal(1), Literal(2))), null)
+ checkEvaluation(In(Literal.create(null, IntegerType), Seq(Literal.create(null, IntegerType))),
+ null)
+ checkEvaluation(In(Literal(1), Seq(Literal.create(null, IntegerType))), null)
+ checkEvaluation(In(Literal(1), Seq(Literal(1), Literal.create(null, IntegerType))), true)
+ checkEvaluation(In(Literal(2), Seq(Literal(1), Literal.create(null, IntegerType))), null)
checkEvaluation(In(Literal(1), Seq(Literal(1), Literal(2))), true)
checkEvaluation(In(Literal(2), Seq(Literal(1), Literal(2))), true)
checkEvaluation(In(Literal(3), Seq(Literal(1), Literal(2))), false)
@@ -126,14 +131,18 @@ class PredicateSuite extends SparkFunSuite with ExpressionEvalHelper {
And(In(Literal(1), Seq(Literal(1), Literal(2))), In(Literal(2), Seq(Literal(1), Literal(2)))),
true)
- checkEvaluation(In(Literal("^Ba*n"), Seq(Literal("^Ba*n"))), true)
+ val ns = Literal.create(null, StringType)
+ checkEvaluation(In(ns, Seq(Literal("1"), Literal("2"))), null)
+ checkEvaluation(In(ns, Seq(ns)), null)
+ checkEvaluation(In(Literal("a"), Seq(ns)), null)
+ checkEvaluation(In(Literal("^Ba*n"), Seq(Literal("^Ba*n"), ns)), true)
checkEvaluation(In(Literal("^Ba*n"), Seq(Literal("aa"), Literal("^Ba*n"))), true)
checkEvaluation(In(Literal("^Ba*n"), Seq(Literal("aa"), Literal("^n"))), false)
val primitiveTypes = Seq(IntegerType, FloatType, DoubleType, StringType, ByteType, ShortType,
LongType, BinaryType, BooleanType, DecimalType.USER_DEFAULT, TimestampType)
primitiveTypes.map { t =>
- val dataGen = RandomDataGenerator.forType(t, nullable = false).get
+ val dataGen = RandomDataGenerator.forType(t, nullable = true).get
val inputData = Seq.fill(10) {
val value = dataGen.apply()
value match {
@@ -142,9 +151,17 @@ class PredicateSuite extends SparkFunSuite with ExpressionEvalHelper {
case _ => value
}
}
- val input = inputData.map(Literal(_))
- checkEvaluation(In(input(0), input.slice(1, 10)),
- inputData.slice(1, 10).contains(inputData(0)))
+ val input = inputData.map(Literal.create(_, t))
+ val expected = if (inputData(0) == null) {
+ null
+ } else if (inputData.slice(1, 10).contains(inputData(0))) {
+ true
+ } else if (inputData.slice(1, 10).contains(null)) {
+ null
+ } else {
+ false
+ }
+ checkEvaluation(In(input(0), input.slice(1, 10)), expected)
}
}
@@ -158,15 +175,15 @@ class PredicateSuite extends SparkFunSuite with ExpressionEvalHelper {
checkEvaluation(InSet(one, hS), true)
checkEvaluation(InSet(two, hS), true)
checkEvaluation(InSet(two, nS), true)
- checkEvaluation(InSet(nl, nS), true)
checkEvaluation(InSet(three, hS), false)
- checkEvaluation(InSet(three, nS), false)
- checkEvaluation(And(InSet(one, hS), InSet(two, hS)), true)
+ checkEvaluation(InSet(three, nS), null)
+ checkEvaluation(InSet(nl, hS), null)
+ checkEvaluation(InSet(nl, nS), null)
val primitiveTypes = Seq(IntegerType, FloatType, DoubleType, StringType, ByteType, ShortType,
LongType, BinaryType, BooleanType, DecimalType.USER_DEFAULT, TimestampType)
primitiveTypes.map { t =>
- val dataGen = RandomDataGenerator.forType(t, nullable = false).get
+ val dataGen = RandomDataGenerator.forType(t, nullable = true).get
val inputData = Seq.fill(10) {
val value = dataGen.apply()
value match {
@@ -176,8 +193,16 @@ class PredicateSuite extends SparkFunSuite with ExpressionEvalHelper {
}
}
val input = inputData.map(Literal(_))
- checkEvaluation(InSet(input(0), inputData.slice(1, 10).toSet),
- inputData.slice(1, 10).contains(inputData(0)))
+ val expected = if (inputData(0) == null) {
+ null
+ } else if (inputData.slice(1, 10).contains(inputData(0))) {
+ true
+ } else if (inputData.slice(1, 10).contains(null)) {
+ null
+ } else {
+ false
+ }
+ checkEvaluation(InSet(input(0), inputData.slice(1, 10).toSet), expected)
}
}
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/ConstantFoldingSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/ConstantFoldingSuite.scala
index ec3b2f1edf..e67606288f 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/ConstantFoldingSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/ConstantFoldingSuite.scala
@@ -250,29 +250,14 @@ class ConstantFoldingSuite extends PlanTest {
}
test("Constant folding test: Fold In(v, list) into true or false") {
- var originalQuery =
+ val originalQuery =
testRelation
.select('a)
.where(In(Literal(1), Seq(Literal(1), Literal(2))))
- var optimized = Optimize.execute(originalQuery.analyze)
-
- var correctAnswer =
- testRelation
- .select('a)
- .where(Literal(true))
- .analyze
-
- comparePlans(optimized, correctAnswer)
-
- originalQuery =
- testRelation
- .select('a)
- .where(In(Literal(1), Seq(Literal(1), 'a.attr)))
-
- optimized = Optimize.execute(originalQuery.analyze)
+ val optimized = Optimize.execute(originalQuery.analyze)
- correctAnswer =
+ val correctAnswer =
testRelation
.select('a)
.where(Literal(true))
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala
index 9d965258e3..3a3f19af14 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala
@@ -366,10 +366,6 @@ class DataFrameFunctionsSuite extends QueryTest with SharedSQLContext {
df.selectExpr("array_contains(a, 1)"),
Seq(Row(true), Row(false))
)
- checkAnswer(
- df.select(array_contains(array(lit(2), lit(null)), 1)),
- Seq(Row(false), Row(false))
- )
// In hive, this errors because null has no type information
intercept[AnalysisException] {
@@ -382,15 +378,13 @@ class DataFrameFunctionsSuite extends QueryTest with SharedSQLContext {
df.selectExpr("array_contains(null, 1)")
}
- // In hive, if either argument has a matching type has a null value, return false, even if
- // the first argument array contains a null and the second argument is null
checkAnswer(
- df.selectExpr("array_contains(array(array(1), null)[1], 1)"),
- Seq(Row(false), Row(false))
+ df.selectExpr("array_contains(array(array(1), null)[0], 1)"),
+ Seq(Row(true), Row(true))
)
checkAnswer(
- df.selectExpr("array_contains(array(0, null), array(1, null)[1])"),
- Seq(Row(false), Row(false))
+ df.selectExpr("array_contains(array(1, null), array(1, null)[0])"),
+ Seq(Row(true), Row(true))
)
}
}