aboutsummaryrefslogtreecommitdiff
path: root/sql/catalyst
diff options
context:
space:
mode:
Diffstat (limited to 'sql/catalyst')
-rw-r--r--sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala6
-rw-r--r--sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/ScalaReflectionSuite.scala8
2 files changed, 11 insertions, 3 deletions
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala
index 4def65b01f..90646fd25b 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala
@@ -57,7 +57,11 @@ trait ScalaReflection {
case (obj, udt: UserDefinedType[_]) => udt.serialize(obj)
case (o: Option[_], _) => o.map(convertToCatalyst(_, dataType)).orNull
case (s: Seq[_], arrayType: ArrayType) => s.map(convertToCatalyst(_, arrayType.elementType))
- case (s: Array[_], arrayType: ArrayType) => s.toSeq
+ case (s: Array[_], arrayType: ArrayType) => if (arrayType.elementType.isPrimitive) {
+ s.toSeq
+ } else {
+ s.toSeq.map(convertToCatalyst(_, arrayType.elementType))
+ }
case (m: Map[_, _], mapType: MapType) => m.map { case (k, v) =>
convertToCatalyst(k, mapType.keyType) -> convertToCatalyst(v, mapType.valueType)
}
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/ScalaReflectionSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/ScalaReflectionSuite.scala
index 4a66716e0a..d0f547d187 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/ScalaReflectionSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/ScalaReflectionSuite.scala
@@ -64,7 +64,8 @@ case class ComplexData(
arrayFieldContainsNull: Seq[java.lang.Integer],
mapField: Map[Int, Long],
mapFieldValueContainsNull: Map[Int, java.lang.Long],
- structField: PrimitiveData)
+ structField: PrimitiveData,
+ nestedArrayField: Array[Array[Int]])
case class GenericData[A](
genericField: A)
@@ -158,7 +159,10 @@ class ScalaReflectionSuite extends FunSuite {
StructField("shortField", ShortType, nullable = false),
StructField("byteField", ByteType, nullable = false),
StructField("booleanField", BooleanType, nullable = false))),
- nullable = true))),
+ nullable = true),
+ StructField(
+ "nestedArrayField",
+ ArrayType(ArrayType(IntegerType, containsNull = false), containsNull = true)))),
nullable = true))
}