diff options
Diffstat (limited to 'sql/catalyst')
5 files changed, 103 insertions, 14 deletions
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala index 31c6e5def1..7bcaea7ea2 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala @@ -441,6 +441,22 @@ object ScalaReflection extends ScalaReflection { val newPath = s"""- array element class: "$clsName"""" +: walkedTypePath MapObjects(serializerFor(_, elementType, newPath), input, dt) + case dt @ (BooleanType | ByteType | ShortType | IntegerType | LongType | + FloatType | DoubleType) => + val cls = input.dataType.asInstanceOf[ObjectType].cls + if (cls.isArray && cls.getComponentType.isPrimitive) { + StaticInvoke( + classOf[UnsafeArrayData], + ArrayType(dt, false), + "fromPrimitiveArray", + input :: Nil) + } else { + NewInstance( + classOf[GenericArrayData], + input :: Nil, + dataType = ArrayType(dt, schemaFor(elementType).nullable)) + } + case dt => NewInstance( classOf[GenericArrayData], diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/RowEncoder.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/RowEncoder.scala index 2a6fcd03a2..e95e97b9dc 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/RowEncoder.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/RowEncoder.scala @@ -23,7 +23,7 @@ import scala.reflect.ClassTag import org.apache.spark.SparkException import org.apache.spark.sql.Row import org.apache.spark.sql.catalyst.expressions._ -import org.apache.spark.sql.catalyst.util.{ArrayBasedMapData, DateTimeUtils, GenericArrayData} +import org.apache.spark.sql.catalyst.util.{ArrayBasedMapData, ArrayData, DateTimeUtils, GenericArrayData} import org.apache.spark.sql.catalyst.ScalaReflection import org.apache.spark.sql.catalyst.analysis.GetColumnByOrdinal import org.apache.spark.sql.catalyst.expressions.objects._ @@ -119,18 +119,19 @@ object RowEncoder { "fromString", inputObject :: Nil) - case t @ ArrayType(et, _) => et match { - case BooleanType | ByteType | ShortType | IntegerType | LongType | FloatType | DoubleType => - // TODO: validate input type for primitive array. - NewInstance( - classOf[GenericArrayData], - inputObject :: Nil, - dataType = t) - case _ => MapObjects( - element => serializerFor(ValidateExternalType(element, et), et), - inputObject, - ObjectType(classOf[Object])) - } + case t @ ArrayType(et, cn) => + et match { + case BooleanType | ByteType | ShortType | IntegerType | LongType | FloatType | DoubleType => + StaticInvoke( + classOf[ArrayData], + t, + "toArrayData", + inputObject :: Nil) + case _ => MapObjects( + element => serializerFor(ValidateExternalType(element, et), et), + inputObject, + ObjectType(classOf[Object])) + } case t @ MapType(kt, vt, valueNullable) => val keys = diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/ArrayData.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/ArrayData.scala index cad4a08b0d..140e86d670 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/ArrayData.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/ArrayData.scala @@ -19,9 +19,22 @@ package org.apache.spark.sql.catalyst.util import scala.reflect.ClassTag -import org.apache.spark.sql.catalyst.expressions.SpecializedGetters +import org.apache.spark.sql.catalyst.expressions.{SpecializedGetters, UnsafeArrayData} import org.apache.spark.sql.types.DataType +object ArrayData { + def toArrayData(input: Any): ArrayData = input match { + case a: Array[Boolean] => UnsafeArrayData.fromPrimitiveArray(a) + case a: Array[Byte] => UnsafeArrayData.fromPrimitiveArray(a) + case a: Array[Short] => UnsafeArrayData.fromPrimitiveArray(a) + case a: Array[Int] => UnsafeArrayData.fromPrimitiveArray(a) + case a: Array[Long] => UnsafeArrayData.fromPrimitiveArray(a) + case a: Array[Float] => UnsafeArrayData.fromPrimitiveArray(a) + case a: Array[Double] => UnsafeArrayData.fromPrimitiveArray(a) + case other => new GenericArrayData(other) + } +} + abstract class ArrayData extends SpecializedGetters with Serializable { def numElements(): Int diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/CatalystTypeConvertersSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/CatalystTypeConvertersSuite.scala index 03bb102c67..f3702ec92b 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/CatalystTypeConvertersSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/CatalystTypeConvertersSuite.scala @@ -19,6 +19,8 @@ package org.apache.spark.sql.catalyst import org.apache.spark.SparkFunSuite import org.apache.spark.sql.Row +import org.apache.spark.sql.catalyst.expressions.UnsafeArrayData +import org.apache.spark.sql.catalyst.util.GenericArrayData import org.apache.spark.sql.types._ class CatalystTypeConvertersSuite extends SparkFunSuite { @@ -61,4 +63,35 @@ class CatalystTypeConvertersSuite extends SparkFunSuite { test("option handling in createToCatalystConverter") { assert(CatalystTypeConverters.createToCatalystConverter(IntegerType)(Some(123)) === 123) } + + test("primitive array handling") { + val intArray = Array(1, 100, 10000) + val intUnsafeArray = UnsafeArrayData.fromPrimitiveArray(intArray) + val intArrayType = ArrayType(IntegerType, false) + assert(CatalystTypeConverters.createToScalaConverter(intArrayType)(intUnsafeArray) === intArray) + + val doubleArray = Array(1.1, 111.1, 11111.1) + val doubleUnsafeArray = UnsafeArrayData.fromPrimitiveArray(doubleArray) + val doubleArrayType = ArrayType(DoubleType, false) + assert(CatalystTypeConverters.createToScalaConverter(doubleArrayType)(doubleUnsafeArray) + === doubleArray) + } + + test("An array with null handling") { + val intArray = Array(1, null, 100, null, 10000) + val intGenericArray = new GenericArrayData(intArray) + val intArrayType = ArrayType(IntegerType, true) + assert(CatalystTypeConverters.createToScalaConverter(intArrayType)(intGenericArray) + === intArray) + assert(CatalystTypeConverters.createToCatalystConverter(intArrayType)(intArray) + == intGenericArray) + + val doubleArray = Array(1.1, null, 111.1, null, 11111.1) + val doubleGenericArray = new GenericArrayData(doubleArray) + val doubleArrayType = ArrayType(DoubleType, true) + assert(CatalystTypeConverters.createToScalaConverter(doubleArrayType)(doubleGenericArray) + === doubleArray) + assert(CatalystTypeConverters.createToCatalystConverter(doubleArrayType)(doubleArray) + == doubleGenericArray) + } } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/RowEncoderSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/RowEncoderSuite.scala index 2e513ea22c..1a5569a77d 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/RowEncoderSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/RowEncoderSuite.scala @@ -191,6 +191,32 @@ class RowEncoderSuite extends SparkFunSuite { assert(encoder.serializer.head.nullable == false) } + test("RowEncoder should support primitive arrays") { + val schema = new StructType() + .add("booleanPrimitiveArray", ArrayType(BooleanType, false)) + .add("bytePrimitiveArray", ArrayType(ByteType, false)) + .add("shortPrimitiveArray", ArrayType(ShortType, false)) + .add("intPrimitiveArray", ArrayType(IntegerType, false)) + .add("longPrimitiveArray", ArrayType(LongType, false)) + .add("floatPrimitiveArray", ArrayType(FloatType, false)) + .add("doublePrimitiveArray", ArrayType(DoubleType, false)) + val encoder = RowEncoder(schema).resolveAndBind() + val input = Seq( + Array(true, false), + Array(1.toByte, 64.toByte, Byte.MaxValue), + Array(1.toShort, 255.toShort, Short.MaxValue), + Array(1, 10000, Int.MaxValue), + Array(1.toLong, 1000000.toLong, Long.MaxValue), + Array(1.1.toFloat, 123.456.toFloat, Float.MaxValue), + Array(11.1111, 123456.7890123, Double.MaxValue) + ) + val row = encoder.toRow(Row.fromSeq(input)) + val convertedBack = encoder.fromRow(row) + input.zipWithIndex.map { case (array, index) => + assert(convertedBack.getSeq(index) === array) + } + } + test("RowEncoder should support array as the external type for ArrayType") { val schema = new StructType() .add("array", ArrayType(IntegerType)) |