diff options
3 files changed, 28 insertions, 14 deletions
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala index 2e8ea1107c..c020029937 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala @@ -18,13 +18,14 @@ package org.apache.spark.sql.catalyst.expressions import java.util.Comparator +import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.analysis.TypeCheckResult import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, CodegenFallback, ExprCode} import org.apache.spark.sql.catalyst.util.{ArrayData, GenericArrayData, MapData} import org.apache.spark.sql.types._ /** - * Given an array or map, returns its size. + * Given an array or map, returns its size. Returns -1 if null. */ @ExpressionDescription( usage = "_FUNC_(expr) - Returns the size of an array or a map.", @@ -32,14 +33,25 @@ import org.apache.spark.sql.types._ case class Size(child: Expression) extends UnaryExpression with ExpectsInputTypes { override def dataType: DataType = IntegerType override def inputTypes: Seq[AbstractDataType] = Seq(TypeCollection(ArrayType, MapType)) - - override def nullSafeEval(value: Any): Int = child.dataType match { - case _: ArrayType => value.asInstanceOf[ArrayData].numElements() - case _: MapType => value.asInstanceOf[MapData].numElements() + override def nullable: Boolean = false + + override def eval(input: InternalRow): Any = { + val value = child.eval(input) + if (value == null) { + -1 + } else child.dataType match { + case _: ArrayType => value.asInstanceOf[ArrayData].numElements() + case _: MapType => value.asInstanceOf[MapData].numElements() + } } override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { - nullSafeCodeGen(ctx, ev, c => s"${ev.value} = ($c).numElements();") + val childGen = child.genCode(ctx) + ev.copy(code = s""" + boolean ${ev.isNull} = false; + ${childGen.code} + ${ctx.javaType(dataType)} ${ev.value} = ${childGen.isNull} ? -1 : + (${childGen.value}).numElements();""", isNull = "false") } } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CollectionFunctionsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CollectionFunctionsSuite.scala index a5f784fdcc..c76dad208e 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CollectionFunctionsSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CollectionFunctionsSuite.scala @@ -40,8 +40,8 @@ class CollectionFunctionsSuite extends SparkFunSuite with ExpressionEvalHelper { checkEvaluation(Size(m1), 0) checkEvaluation(Size(m2), 1) - checkEvaluation(Literal.create(null, MapType(StringType, StringType)), null) - checkEvaluation(Literal.create(null, ArrayType(StringType)), null) + checkEvaluation(Size(Literal.create(null, MapType(StringType, StringType))), -1) + checkEvaluation(Size(Literal.create(null, ArrayType(StringType))), -1) } test("MapKeys/MapValues") { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala index 0f6c49e759..45db61515e 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala @@ -324,15 +324,16 @@ class DataFrameFunctionsSuite extends QueryTest with SharedSQLContext { val df = Seq( (Seq[Int](1, 2), "x"), (Seq[Int](), "y"), - (Seq[Int](1, 2, 3), "z") + (Seq[Int](1, 2, 3), "z"), + (null, "empty") ).toDF("a", "b") checkAnswer( df.select(size($"a")), - Seq(Row(2), Row(0), Row(3)) + Seq(Row(2), Row(0), Row(3), Row(-1)) ) checkAnswer( df.selectExpr("size(a)"), - Seq(Row(2), Row(0), Row(3)) + Seq(Row(2), Row(0), Row(3), Row(-1)) ) } @@ -340,15 +341,16 @@ class DataFrameFunctionsSuite extends QueryTest with SharedSQLContext { val df = Seq( (Map[Int, Int](1 -> 1, 2 -> 2), "x"), (Map[Int, Int](), "y"), - (Map[Int, Int](1 -> 1, 2 -> 2, 3 -> 3), "z") + (Map[Int, Int](1 -> 1, 2 -> 2, 3 -> 3), "z"), + (null, "empty") ).toDF("a", "b") checkAnswer( df.select(size($"a")), - Seq(Row(2), Row(0), Row(3)) + Seq(Row(2), Row(0), Row(3), Row(-1)) ) checkAnswer( df.selectExpr("size(a)"), - Seq(Row(2), Row(0), Row(3)) + Seq(Row(2), Row(0), Row(3), Row(-1)) ) } |