From 958a0ec8fa58ff091f595db2b574a7aa3ff41253 Mon Sep 17 00:00:00 2001 From: Jia Li Date: Tue, 27 Oct 2015 10:57:08 +0100 Subject: [SPARK-11277][SQL] sort_array throws exception scala.MatchError I'm new to spark. I was trying out the sort_array function then hit this exception. I looked into the spark source code. I found the root cause is that sort_array does not check for an array of NULLs. It's not meaningful to sort an array of entirely NULLs anyway. I'm adding a check on the input array type to SortArray. If the array consists of NULLs entirely, there is no need to sort such array. I have also added a test case for this. Please help to review my fix. Thanks! Author: Jia Li Closes #9247 from jliwork/SPARK-11277. --- .../spark/sql/catalyst/expressions/collectionOperations.scala | 6 +++++- .../spark/sql/catalyst/expressions/CollectionFunctionsSuite.scala | 7 +++++++ 2 files changed, 12 insertions(+), 1 deletion(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala index 89d87726ac..2cf19b939f 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala @@ -68,6 +68,7 @@ case class SortArray(base: Expression, ascendingOrder: Expression) private lazy val lt: Comparator[Any] = { val ordering = base.dataType match { case _ @ ArrayType(n: AtomicType, _) => n.ordering.asInstanceOf[Ordering[Any]] + case _ @ ArrayType(s: StructType, _) => s.interpretedOrdering.asInstanceOf[Ordering[Any]] } new Comparator[Any]() { @@ -89,6 +90,7 @@ case class SortArray(base: Expression, ascendingOrder: Expression) private lazy val gt: Comparator[Any] = { val ordering = base.dataType match { case _ @ ArrayType(n: AtomicType, _) => n.ordering.asInstanceOf[Ordering[Any]] + case _ @ ArrayType(s: StructType, _) => s.interpretedOrdering.asInstanceOf[Ordering[Any]] } new Comparator[Any]() { @@ -109,7 +111,9 @@ case class SortArray(base: Expression, ascendingOrder: Expression) override def nullSafeEval(array: Any, ascending: Any): Any = { val elementType = base.dataType.asInstanceOf[ArrayType].elementType val data = array.asInstanceOf[ArrayData].toArray[AnyRef](elementType) - java.util.Arrays.sort(data, if (ascending.asInstanceOf[Boolean]) lt else gt) + if (elementType != NullType) { + java.util.Arrays.sort(data, if (ascending.asInstanceOf[Boolean]) lt else gt) + } new GenericArrayData(data.asInstanceOf[Array[Any]]) } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CollectionFunctionsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CollectionFunctionsSuite.scala index a3e81888df..1aae4678d6 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CollectionFunctionsSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CollectionFunctionsSuite.scala @@ -49,6 +49,7 @@ class CollectionFunctionsSuite extends SparkFunSuite with ExpressionEvalHelper { val a1 = Literal.create(Seq[Integer](), ArrayType(IntegerType)) val a2 = Literal.create(Seq("b", "a"), ArrayType(StringType)) val a3 = Literal.create(Seq("b", null, "a"), ArrayType(StringType)) + val a4 = Literal.create(Seq(null, null), ArrayType(NullType)) checkEvaluation(new SortArray(a0), Seq(1, 2, 3)) checkEvaluation(new SortArray(a1), Seq[Integer]()) @@ -64,6 +65,12 @@ class CollectionFunctionsSuite extends SparkFunSuite with ExpressionEvalHelper { checkEvaluation(new SortArray(a3, Literal(false)), Seq("b", "a", null)) checkEvaluation(Literal.create(null, ArrayType(StringType)), null) + checkEvaluation(new SortArray(a4), Seq(null, null)) + + val typeAS = ArrayType(StructType(StructField("a", IntegerType) :: Nil)) + val arrayStruct = Literal.create(Seq(create_row(2), create_row(1)), typeAS) + + checkEvaluation(new SortArray(arrayStruct), Seq(create_row(1), create_row(2))) } test("Array contains") { -- cgit v1.2.3