about summary refs log tree commit diff
path: root/sql
diff options
context:
space:
mode:
Diffstat (limited to 'sql')
-rw-r--r--  sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala | 24
-rw-r--r--  sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CollectionFunctionsSuite.scala | 4
-rw-r--r--  sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala | 14
3 files changed, 28 insertions, 14 deletions
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala
index 2e8ea1107c..c020029937 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala
@@ -18,13 +18,14 @@ package org.apache.spark.sql.catalyst.expressions
import java.util.Comparator
+import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.analysis.TypeCheckResult
import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, CodegenFallback, ExprCode}
import org.apache.spark.sql.catalyst.util.{ArrayData, GenericArrayData, MapData}
import org.apache.spark.sql.types._
/**
- * Given an array or map, returns its size.
+ * Given an array or map, returns its size. Returns -1 if null.
*/
@ExpressionDescription(
usage = "_FUNC_(expr) - Returns the size of an array or a map.",
@@ -32,14 +33,25 @@ import org.apache.spark.sql.types._
case class Size(child: Expression) extends UnaryExpression with ExpectsInputTypes {
override def dataType: DataType = IntegerType
override def inputTypes: Seq[AbstractDataType] = Seq(TypeCollection(ArrayType, MapType))
-
- override def nullSafeEval(value: Any): Int = child.dataType match {
- case _: ArrayType => value.asInstanceOf[ArrayData].numElements()
- case _: MapType => value.asInstanceOf[MapData].numElements()
+ override def nullable: Boolean = false
+
+ override def eval(input: InternalRow): Any = {
+ val value = child.eval(input)
+ if (value == null) {
+ -1
+ } else child.dataType match {
+ case _: ArrayType => value.asInstanceOf[ArrayData].numElements()
+ case _: MapType => value.asInstanceOf[MapData].numElements()
+ }
}
override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
- nullSafeCodeGen(ctx, ev, c => s"${ev.value} = ($c).numElements();")
+ val childGen = child.genCode(ctx)
+ ev.copy(code = s"""
+ boolean ${ev.isNull} = false;
+ ${childGen.code}
+ ${ctx.javaType(dataType)} ${ev.value} = ${childGen.isNull} ? -1 :
+ (${childGen.value}).numElements();""", isNull = "false")
}
}
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CollectionFunctionsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CollectionFunctionsSuite.scala
index a5f784fdcc..c76dad208e 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CollectionFunctionsSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CollectionFunctionsSuite.scala
@@ -40,8 +40,8 @@ class CollectionFunctionsSuite extends SparkFunSuite with ExpressionEvalHelper {
checkEvaluation(Size(m1), 0)
checkEvaluation(Size(m2), 1)
- checkEvaluation(Literal.create(null, MapType(StringType, StringType)), null)
- checkEvaluation(Literal.create(null, ArrayType(StringType)), null)
+ checkEvaluation(Size(Literal.create(null, MapType(StringType, StringType))), -1)
+ checkEvaluation(Size(Literal.create(null, ArrayType(StringType))), -1)
}
test("MapKeys/MapValues") {
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala
index 0f6c49e759..45db61515e 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala
@@ -324,15 +324,16 @@ class DataFrameFunctionsSuite extends QueryTest with SharedSQLContext {
val df = Seq(
(Seq[Int](1, 2), "x"),
(Seq[Int](), "y"),
- (Seq[Int](1, 2, 3), "z")
+ (Seq[Int](1, 2, 3), "z"),
+ (null, "empty")
).toDF("a", "b")
checkAnswer(
df.select(size($"a")),
- Seq(Row(2), Row(0), Row(3))
+ Seq(Row(2), Row(0), Row(3), Row(-1))
)
checkAnswer(
df.selectExpr("size(a)"),
- Seq(Row(2), Row(0), Row(3))
+ Seq(Row(2), Row(0), Row(3), Row(-1))
)
}
@@ -340,15 +341,16 @@ class DataFrameFunctionsSuite extends QueryTest with SharedSQLContext {
val df = Seq(
(Map[Int, Int](1 -> 1, 2 -> 2), "x"),
(Map[Int, Int](), "y"),
- (Map[Int, Int](1 -> 1, 2 -> 2, 3 -> 3), "z")
+ (Map[Int, Int](1 -> 1, 2 -> 2, 3 -> 3), "z"),
+ (null, "empty")
).toDF("a", "b")
checkAnswer(
df.select(size($"a")),
- Seq(Row(2), Row(0), Row(3))
+ Seq(Row(2), Row(0), Row(3), Row(-1))
)
checkAnswer(
df.selectExpr("size(a)"),
- Seq(Row(2), Row(0), Row(3))
+ Seq(Row(2), Row(0), Row(3), Row(-1))
)
}