Diffstat (limited to 'sql/catalyst')
 sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala | 24 ++++++++++++++++++------
 sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CollectionFunctionsSuite.scala |  4 ++--
 2 files changed, 20 insertions(+), 8 deletions(-)
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala
index 2e8ea1107c..c020029937 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala
@@ -18,13 +18,14 @@ package org.apache.spark.sql.catalyst.expressions
 
 import java.util.Comparator
 
+import org.apache.spark.sql.catalyst.InternalRow
 import org.apache.spark.sql.catalyst.analysis.TypeCheckResult
 import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, CodegenFallback, ExprCode}
 import org.apache.spark.sql.catalyst.util.{ArrayData, GenericArrayData, MapData}
 import org.apache.spark.sql.types._
 
 /**
- * Given an array or map, returns its size.
+ * Given an array or map, returns its size. Returns -1 if null.
  */
 @ExpressionDescription(
   usage = "_FUNC_(expr) - Returns the size of an array or a map.",
@@ -32,14 +33,25 @@ import org.apache.spark.sql.types._
 case class Size(child: Expression) extends UnaryExpression with ExpectsInputTypes {
   override def dataType: DataType = IntegerType
   override def inputTypes: Seq[AbstractDataType] = Seq(TypeCollection(ArrayType, MapType))
-
-  override def nullSafeEval(value: Any): Int = child.dataType match {
-    case _: ArrayType => value.asInstanceOf[ArrayData].numElements()
-    case _: MapType => value.asInstanceOf[MapData].numElements()
+  override def nullable: Boolean = false
+
+  override def eval(input: InternalRow): Any = {
+    val value = child.eval(input)
+    if (value == null) {
+      -1
+    } else child.dataType match {
+      case _: ArrayType => value.asInstanceOf[ArrayData].numElements()
+      case _: MapType => value.asInstanceOf[MapData].numElements()
+    }
   }
 
   override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
-    nullSafeCodeGen(ctx, ev, c => s"${ev.value} = ($c).numElements();")
+    val childGen = child.genCode(ctx)
+    ev.copy(code = s"""
+      boolean ${ev.isNull} = false;
+      ${childGen.code}
+      ${ctx.javaType(dataType)} ${ev.value} = ${childGen.isNull} ? -1 :
+        (${childGen.value}).numElements();""", isNull = "false")
   }
 }
 
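The net effect of the hunk above: Size now evaluates a null array or map to -1 instead of null, and declares itself non-nullable, which matches Hive's size() behaviour. A minimal sketch of the new semantics against catalyst's internal API (illustrative only, not part of this commit; these are internal classes normally exercised from catalyst's test scope):

    import org.apache.spark.sql.catalyst.expressions.{Literal, Size}
    import org.apache.spark.sql.types.{ArrayType, StringType}

    // A concrete array still evaluates to its element count.
    val arr = Literal.create(Seq("a", "b"), ArrayType(StringType))
    assert(Size(arr).eval() == 2)

    // A null array no longer short-circuits to null: it evaluates to -1.
    val nullArr = Literal.create(null, ArrayType(StringType))
    assert(Size(nullArr).eval() == -1)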
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CollectionFunctionsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CollectionFunctionsSuite.scala
index a5f784fdcc..c76dad208e 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CollectionFunctionsSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CollectionFunctionsSuite.scala
@@ -40,8 +40,8 @@ class CollectionFunctionsSuite extends SparkFunSuite with ExpressionEvalHelper {
     checkEvaluation(Size(m1), 0)
     checkEvaluation(Size(m2), 1)
 
-    checkEvaluation(Literal.create(null, MapType(StringType, StringType)), null)
-    checkEvaluation(Literal.create(null, ArrayType(StringType)), null)
+    checkEvaluation(Size(Literal.create(null, MapType(StringType, StringType))), -1)
+    checkEvaluation(Size(Literal.create(null, ArrayType(StringType))), -1)
   }
 
   test("MapKeys/MapValues") {
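Note that the test hunk also fixes a latent bug: the old assertions evaluated the bare null literals rather than Size, so they never exercised the expression at all. The new behaviour is visible from SQL as well; a hedged end-to-end illustration, assuming a SparkSession named spark (not part of this patch):

    // size(NULL) returns -1 rather than NULL, so the result column is non-nullable.
    spark.sql("SELECT size(cast(null as array<int>))").head.getInt(0)  // -1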