From bb7f35239385ec74b5ee69631b5480fbcee253e4 Mon Sep 17 00:00:00 2001 From: Davies Liu Date: Fri, 28 Aug 2015 14:38:20 -0700 Subject: [SPARK-10323] [SQL] fix nullability of In/InSet/ArrayContain After this PR, In/InSet/ArrayContain will return null if value is null, instead of false. They also will return null even if there is a null in the set/array. Author: Davies Liu Closes #8492 from davies/fix_in. --- .../expressions/collectionOperations.scala | 62 +++++++++---------- .../sql/catalyst/expressions/predicates.scala | 71 +++++++++++++++++----- .../spark/sql/catalyst/optimizer/Optimizer.scala | 6 -- .../expressions/CollectionFunctionsSuite.scala | 12 +++- .../sql/catalyst/expressions/PredicateSuite.scala | 49 +++++++++++---- .../catalyst/optimizer/ConstantFoldingSuite.scala | 21 +------ .../apache/spark/sql/DataFrameFunctionsSuite.scala | 14 ++--- 7 files changed, 138 insertions(+), 97 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala index 646afa4047..7b8c5b723d 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala @@ -19,9 +19,7 @@ package org.apache.spark.sql.catalyst.expressions import java.util.Comparator import org.apache.spark.sql.catalyst.analysis.TypeCheckResult -import org.apache.spark.sql.catalyst.expressions.codegen.{ - CodegenFallback, CodeGenContext, GeneratedExpressionCode} -import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.catalyst.expressions.codegen.{CodeGenContext, CodegenFallback, GeneratedExpressionCode} import org.apache.spark.sql.types._ /** @@ -145,46 +143,42 @@ case class ArrayContains(left: Expression, right: Expression) } } - override def nullable: Boolean = false + override def nullable: Boolean = { + left.nullable || right.nullable || left.dataType.asInstanceOf[ArrayType].containsNull + } - override def eval(input: InternalRow): Boolean = { - val arr = left.eval(input) - if (arr == null) { - false - } else { - val value = right.eval(input) - if (value == null) { - false - } else { - arr.asInstanceOf[ArrayData].foreach(right.dataType, (i, v) => - if (v == value) return true - ) - false + override def nullSafeEval(arr: Any, value: Any): Any = { + var hasNull = false + arr.asInstanceOf[ArrayData].foreach(right.dataType, (i, v) => + if (v == null) { + hasNull = true + } else if (v == value) { + return true } + ) + if (hasNull) { + null + } else { + false } } override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = { - val arrGen = left.gen(ctx) - val elementGen = right.gen(ctx) - val i = ctx.freshName("i") - val getValue = ctx.getValue(arrGen.primitive, right.dataType, i) - s""" - ${arrGen.code} - boolean ${ev.isNull} = false; - boolean ${ev.primitive} = false; - if (!${arrGen.isNull}) { - ${elementGen.code} - if (!${elementGen.isNull}) { - for (int $i = 0; $i < ${arrGen.primitive}.numElements(); $i ++) { - if (${ctx.genEqual(right.dataType, elementGen.primitive, getValue)}) { - ${ev.primitive} = true; - break; - } - } + nullSafeCodeGen(ctx, ev, (arr, value) => { + val i = ctx.freshName("i") + val getValue = ctx.getValue(arr, right.dataType, i) + s""" + for (int $i = 0; $i < $arr.numElements(); $i ++) { + if ($arr.isNullAt($i)) { + ${ev.isNull} = true; + } else if (${ctx.genEqual(right.dataType, value, getValue)}) { + ${ev.isNull} = false; + ${ev.primitive} = true; + break; } } """ + }) } override def prettyName: String = "array_contains" diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala index fe7dffb815..65706dba7d 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala @@ -17,11 +17,9 @@ package org.apache.spark.sql.catalyst.expressions -import scala.collection.mutable - -import org.apache.spark.sql.catalyst.analysis.TypeCheckResult -import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenFallback, GeneratedExpressionCode, CodeGenContext} import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.catalyst.analysis.TypeCheckResult +import org.apache.spark.sql.catalyst.expressions.codegen.{CodeGenContext, GeneratedExpressionCode} import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan import org.apache.spark.sql.catalyst.util.TypeUtils import org.apache.spark.sql.types._ @@ -103,6 +101,8 @@ case class Not(child: Expression) case class In(value: Expression, list: Seq[Expression]) extends Predicate with ImplicitCastInputTypes { + require(list != null, "list should not be null") + override def inputTypes: Seq[AbstractDataType] = value.dataType +: list.map(_.dataType) override def checkInputDataTypes(): TypeCheckResult = { @@ -116,12 +116,31 @@ case class In(value: Expression, list: Seq[Expression]) extends Predicate override def children: Seq[Expression] = value +: list - override def nullable: Boolean = false // TODO: Figure out correct nullability semantics of IN. + override def nullable: Boolean = children.exists(_.nullable) + override def foldable: Boolean = children.forall(_.foldable) + override def toString: String = s"$value IN ${list.mkString("(", ",", ")")}" override def eval(input: InternalRow): Any = { val evaluatedValue = value.eval(input) - list.exists(e => e.eval(input) == evaluatedValue) + if (evaluatedValue == null) { + null + } else { + var hasNull = false + list.foreach { e => + val v = e.eval(input) + if (v == evaluatedValue) { + return true + } else if (v == null) { + hasNull = true + } + } + if (hasNull) { + null + } else { + false + } + } } override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = { @@ -131,7 +150,10 @@ case class In(value: Expression, list: Seq[Expression]) extends Predicate s""" if (!${ev.primitive}) { ${x.code} - if (${ctx.genEqual(value.dataType, valueGen.primitive, x.primitive)}) { + if (${x.isNull}) { + ${ev.isNull} = true; + } else if (${ctx.genEqual(value.dataType, valueGen.primitive, x.primitive)}) { + ${ev.isNull} = false; ${ev.primitive} = true; } } @@ -139,8 +161,10 @@ case class In(value: Expression, list: Seq[Expression]) extends Predicate s""" ${valueGen.code} boolean ${ev.primitive} = false; - boolean ${ev.isNull} = false; - $listCode + boolean ${ev.isNull} = ${valueGen.isNull}; + if (!${ev.isNull}) { + $listCode + } """ } } @@ -151,11 +175,22 @@ case class In(value: Expression, list: Seq[Expression]) extends Predicate */ case class InSet(child: Expression, hset: Set[Any]) extends UnaryExpression with Predicate { - override def nullable: Boolean = false // TODO: Figure out correct nullability semantics of IN. + require(hset != null, "hset could not be null") + override def toString: String = s"$child INSET ${hset.mkString("(", ",", ")")}" - override def eval(input: InternalRow): Any = { - hset.contains(child.eval(input)) + @transient private[this] lazy val hasNull: Boolean = hset.contains(null) + + override def nullable: Boolean = child.nullable || hasNull + + protected override def nullSafeEval(value: Any): Any = { + if (hset.contains(value)) { + true + } else if (hasNull) { + null + } else { + false + } } def getHSet(): Set[Any] = hset @@ -166,12 +201,20 @@ case class InSet(child: Expression, hset: Set[Any]) extends UnaryExpression with val childGen = child.gen(ctx) ctx.references += this val hsetTerm = ctx.freshName("hset") + val hasNullTerm = ctx.freshName("hasNull") ctx.addMutableState(setName, hsetTerm, s"$hsetTerm = (($InSetName)expressions[${ctx.references.size - 1}]).getHSet();") + ctx.addMutableState("boolean", hasNullTerm, s"$hasNullTerm = $hsetTerm.contains(null);") s""" ${childGen.code} - boolean ${ev.isNull} = false; - boolean ${ev.primitive} = $hsetTerm.contains(${childGen.primitive}); + boolean ${ev.isNull} = ${childGen.isNull}; + boolean ${ev.primitive} = false; + if (!${ev.isNull}) { + ${ev.primitive} = $hsetTerm.contains(${childGen.primitive}); + if (!${ev.primitive} && $hasNullTerm) { + ${ev.isNull} = true; + } + } """ } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala index 854463dd11..a430000bef 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala @@ -395,12 +395,6 @@ object ConstantFolding extends Rule[LogicalPlan] { // Fold expressions that are foldable. case e if e.foldable => Literal.create(e.eval(EmptyRow), e.dataType) - - // Fold "literal in (item1, item2, ..., literal, ...)" into true directly. - case In(Literal(v, _), list) if list.exists { - case Literal(candidate, _) if candidate == v => true - case _ => false - } => Literal.create(true, BooleanType) } } } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CollectionFunctionsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CollectionFunctionsSuite.scala index 95f0e38212..a3e81888df 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CollectionFunctionsSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CollectionFunctionsSuite.scala @@ -70,14 +70,20 @@ class CollectionFunctionsSuite extends SparkFunSuite with ExpressionEvalHelper { val a0 = Literal.create(Seq(1, 2, 3), ArrayType(IntegerType)) val a1 = Literal.create(Seq[String](null, ""), ArrayType(StringType)) val a2 = Literal.create(Seq(null), ArrayType(LongType)) + val a3 = Literal.create(null, ArrayType(StringType)) checkEvaluation(ArrayContains(a0, Literal(1)), true) checkEvaluation(ArrayContains(a0, Literal(0)), false) - checkEvaluation(ArrayContains(a0, Literal(null)), false) + checkEvaluation(ArrayContains(a0, Literal.create(null, IntegerType)), null) checkEvaluation(ArrayContains(a1, Literal("")), true) - checkEvaluation(ArrayContains(a1, Literal(null)), false) + checkEvaluation(ArrayContains(a1, Literal("a")), null) + checkEvaluation(ArrayContains(a1, Literal.create(null, StringType)), null) - checkEvaluation(ArrayContains(a2, Literal(null)), false) + checkEvaluation(ArrayContains(a2, Literal(1L)), null) + checkEvaluation(ArrayContains(a2, Literal.create(null, LongType)), null) + + checkEvaluation(ArrayContains(a3, Literal("")), null) + checkEvaluation(ArrayContains(a3, Literal.create(null, StringType)), null) } } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/PredicateSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/PredicateSuite.scala index 54c04faddb..03e7611fce 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/PredicateSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/PredicateSuite.scala @@ -20,7 +20,6 @@ package org.apache.spark.sql.catalyst.expressions import scala.collection.immutable.HashSet import org.apache.spark.SparkFunSuite -import org.apache.spark.sql.catalyst.dsl.expressions._ import org.apache.spark.sql.RandomDataGenerator import org.apache.spark.sql.types._ @@ -119,6 +118,12 @@ class PredicateSuite extends SparkFunSuite with ExpressionEvalHelper { (null, null, null) :: Nil) test("IN") { + checkEvaluation(In(Literal.create(null, IntegerType), Seq(Literal(1), Literal(2))), null) + checkEvaluation(In(Literal.create(null, IntegerType), Seq(Literal.create(null, IntegerType))), + null) + checkEvaluation(In(Literal(1), Seq(Literal.create(null, IntegerType))), null) + checkEvaluation(In(Literal(1), Seq(Literal(1), Literal.create(null, IntegerType))), true) + checkEvaluation(In(Literal(2), Seq(Literal(1), Literal.create(null, IntegerType))), null) checkEvaluation(In(Literal(1), Seq(Literal(1), Literal(2))), true) checkEvaluation(In(Literal(2), Seq(Literal(1), Literal(2))), true) checkEvaluation(In(Literal(3), Seq(Literal(1), Literal(2))), false) @@ -126,14 +131,18 @@ class PredicateSuite extends SparkFunSuite with ExpressionEvalHelper { And(In(Literal(1), Seq(Literal(1), Literal(2))), In(Literal(2), Seq(Literal(1), Literal(2)))), true) - checkEvaluation(In(Literal("^Ba*n"), Seq(Literal("^Ba*n"))), true) + val ns = Literal.create(null, StringType) + checkEvaluation(In(ns, Seq(Literal("1"), Literal("2"))), null) + checkEvaluation(In(ns, Seq(ns)), null) + checkEvaluation(In(Literal("a"), Seq(ns)), null) + checkEvaluation(In(Literal("^Ba*n"), Seq(Literal("^Ba*n"), ns)), true) checkEvaluation(In(Literal("^Ba*n"), Seq(Literal("aa"), Literal("^Ba*n"))), true) checkEvaluation(In(Literal("^Ba*n"), Seq(Literal("aa"), Literal("^n"))), false) val primitiveTypes = Seq(IntegerType, FloatType, DoubleType, StringType, ByteType, ShortType, LongType, BinaryType, BooleanType, DecimalType.USER_DEFAULT, TimestampType) primitiveTypes.map { t => - val dataGen = RandomDataGenerator.forType(t, nullable = false).get + val dataGen = RandomDataGenerator.forType(t, nullable = true).get val inputData = Seq.fill(10) { val value = dataGen.apply() value match { @@ -142,9 +151,17 @@ class PredicateSuite extends SparkFunSuite with ExpressionEvalHelper { case _ => value } } - val input = inputData.map(Literal(_)) - checkEvaluation(In(input(0), input.slice(1, 10)), - inputData.slice(1, 10).contains(inputData(0))) + val input = inputData.map(Literal.create(_, t)) + val expected = if (inputData(0) == null) { + null + } else if (inputData.slice(1, 10).contains(inputData(0))) { + true + } else if (inputData.slice(1, 10).contains(null)) { + null + } else { + false + } + checkEvaluation(In(input(0), input.slice(1, 10)), expected) } } @@ -158,15 +175,15 @@ class PredicateSuite extends SparkFunSuite with ExpressionEvalHelper { checkEvaluation(InSet(one, hS), true) checkEvaluation(InSet(two, hS), true) checkEvaluation(InSet(two, nS), true) - checkEvaluation(InSet(nl, nS), true) checkEvaluation(InSet(three, hS), false) - checkEvaluation(InSet(three, nS), false) - checkEvaluation(And(InSet(one, hS), InSet(two, hS)), true) + checkEvaluation(InSet(three, nS), null) + checkEvaluation(InSet(nl, hS), null) + checkEvaluation(InSet(nl, nS), null) val primitiveTypes = Seq(IntegerType, FloatType, DoubleType, StringType, ByteType, ShortType, LongType, BinaryType, BooleanType, DecimalType.USER_DEFAULT, TimestampType) primitiveTypes.map { t => - val dataGen = RandomDataGenerator.forType(t, nullable = false).get + val dataGen = RandomDataGenerator.forType(t, nullable = true).get val inputData = Seq.fill(10) { val value = dataGen.apply() value match { @@ -176,8 +193,16 @@ class PredicateSuite extends SparkFunSuite with ExpressionEvalHelper { } } val input = inputData.map(Literal(_)) - checkEvaluation(InSet(input(0), inputData.slice(1, 10).toSet), - inputData.slice(1, 10).contains(inputData(0))) + val expected = if (inputData(0) == null) { + null + } else if (inputData.slice(1, 10).contains(inputData(0))) { + true + } else if (inputData.slice(1, 10).contains(null)) { + null + } else { + false + } + checkEvaluation(InSet(input(0), inputData.slice(1, 10).toSet), expected) } } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/ConstantFoldingSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/ConstantFoldingSuite.scala index ec3b2f1edf..e67606288f 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/ConstantFoldingSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/ConstantFoldingSuite.scala @@ -250,29 +250,14 @@ class ConstantFoldingSuite extends PlanTest { } test("Constant folding test: Fold In(v, list) into true or false") { - var originalQuery = + val originalQuery = testRelation .select('a) .where(In(Literal(1), Seq(Literal(1), Literal(2)))) - var optimized = Optimize.execute(originalQuery.analyze) - - var correctAnswer = - testRelation - .select('a) - .where(Literal(true)) - .analyze - - comparePlans(optimized, correctAnswer) - - originalQuery = - testRelation - .select('a) - .where(In(Literal(1), Seq(Literal(1), 'a.attr))) - - optimized = Optimize.execute(originalQuery.analyze) + val optimized = Optimize.execute(originalQuery.analyze) - correctAnswer = + val correctAnswer = testRelation .select('a) .where(Literal(true)) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala index 9d965258e3..3a3f19af14 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala @@ -366,10 +366,6 @@ class DataFrameFunctionsSuite extends QueryTest with SharedSQLContext { df.selectExpr("array_contains(a, 1)"), Seq(Row(true), Row(false)) ) - checkAnswer( - df.select(array_contains(array(lit(2), lit(null)), 1)), - Seq(Row(false), Row(false)) - ) // In hive, this errors because null has no type information intercept[AnalysisException] { @@ -382,15 +378,13 @@ class DataFrameFunctionsSuite extends QueryTest with SharedSQLContext { df.selectExpr("array_contains(null, 1)") } - // In hive, if either argument has a matching type has a null value, return false, even if - // the first argument array contains a null and the second argument is null checkAnswer( - df.selectExpr("array_contains(array(array(1), null)[1], 1)"), - Seq(Row(false), Row(false)) + df.selectExpr("array_contains(array(array(1), null)[0], 1)"), + Seq(Row(true), Row(true)) ) checkAnswer( - df.selectExpr("array_contains(array(0, null), array(1, null)[1])"), - Seq(Row(false), Row(false)) + df.selectExpr("array_contains(array(1, null), array(1, null)[0])"), + Seq(Row(true), Row(true)) ) } } -- cgit v1.2.3