author     Wenchen Fan <wenchen@databricks.com>    2016-05-17 17:02:52 +0800
committer  Wenchen Fan <wenchen@databricks.com>    2016-05-17 17:02:52 +0800
commit     c36ca651f9177f8e7a3f6a0098cba5a810ee9deb
tree       2a0405085ef6df1670715b9864004dc8d6327fe0
parent     122302cbf5cbf1133067a5acdffd6ab96765dafe
[SPARK-15351][SQL] RowEncoder should support array as the external type for ArrayType
## What changes were proposed in this pull request?
This PR improves `RowEncoder` and `MapObjects` to support arrays as the external type for `ArrayType`. The idea is straightforward: we use `Object` as the external input type for `ArrayType` and determine the concrete collection type at runtime in `MapObjects`.
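As a concrete illustration, here is a minimal round-trip sketch of the new behavior, modeled on the test added in `RowEncoderSuite` below (the one-column schema and the column name `values` are made up for the example):

```scala
import org.apache.spark.sql.Row
import org.apache.spark.sql.catalyst.encoders.RowEncoder
import org.apache.spark.sql.types._

// Hypothetical one-column schema; any ArrayType column behaves the same way.
val schema = new StructType().add("values", ArrayType(IntegerType))
val encoder = RowEncoder(schema)

// Before this change only Seq was accepted as the external value of an
// ArrayType field; now an Array works too, dispatched at runtime.
val fromSeq = encoder.toRow(Row(Seq(1, 2, 3)))
val fromArray = encoder.toRow(Row(Array(1, 2, 3)))

// Decoding still yields a Seq, matching the Row external-type contract.
assert(encoder.fromRow(fromArray).getSeq[Int](0) == Seq(1, 2, 3))
```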
## How was this patch tested?
A new test in `RowEncoderSuite`.
Author: Wenchen Fan <wenchen@databricks.com>
Closes #13138 from cloud-fan/map-object.
5 files changed, 92 insertions(+), 55 deletions(-)
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/Row.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/Row.scala
index 726291b96c..a257b831dd 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/Row.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/Row.scala
@@ -151,7 +151,7 @@ trait Row extends Serializable {
  *   BinaryType -> byte array
  *   ArrayType -> scala.collection.Seq (use getList for java.util.List)
  *   MapType -> scala.collection.Map (use getJavaMap for java.util.Map)
- *   StructType -> org.apache.spark.sql.Row (or Product)
+ *   StructType -> org.apache.spark.sql.Row
  * }}}
  */
 def apply(i: Int): Any = get(i)
@@ -176,7 +176,7 @@ trait Row extends Serializable {
  *   BinaryType -> byte array
  *   ArrayType -> scala.collection.Seq (use getList for java.util.List)
  *   MapType -> scala.collection.Map (use getJavaMap for java.util.Map)
- *   StructType -> org.apache.spark.sql.Row (or Product)
+ *   StructType -> org.apache.spark.sql.Row
  * }}}
  */
 def get(i: Int): Any
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/RowEncoder.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/RowEncoder.scala
index ae842a9f87..a5f39aaa23 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/RowEncoder.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/RowEncoder.scala
@@ -32,6 +32,26 @@ import org.apache.spark.unsafe.types.UTF8String
 /**
  * A factory for constructing encoders that convert external row to/from the Spark SQL
  * internal binary representation.
+ *
+ * The following is a mapping between Spark SQL types and its allowed external types:
+ * {{{
+ *   BooleanType -> java.lang.Boolean
+ *   ByteType -> java.lang.Byte
+ *   ShortType -> java.lang.Short
+ *   IntegerType -> java.lang.Integer
+ *   FloatType -> java.lang.Float
+ *   DoubleType -> java.lang.Double
+ *   StringType -> String
+ *   DecimalType -> java.math.BigDecimal or scala.math.BigDecimal or Decimal
+ *
+ *   DateType -> java.sql.Date
+ *   TimestampType -> java.sql.Timestamp
+ *
+ *   BinaryType -> byte array
+ *   ArrayType -> scala.collection.Seq or Array
+ *   MapType -> scala.collection.Map
+ *   StructType -> org.apache.spark.sql.Row or Product
+ * }}}
  */
 object RowEncoder {
   def apply(schema: StructType): ExpressionEncoder[Row] = {
@@ -166,6 +186,8 @@ object RowEncoder {
     // In order to support both Decimal and java/scala BigDecimal in external row, we make this
     // as java.lang.Object.
     case _: DecimalType => ObjectType(classOf[java.lang.Object])
+    // In order to support both Array and Seq in external row, we make this as java.lang.Object.
+    case _: ArrayType => ObjectType(classOf[java.lang.Object])
     case _ => externalDataTypeFor(dt)
   }
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala
index e8a6c742bf..7df6e06805 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala
@@ -376,45 +376,6 @@ case class MapObjects private(
     lambdaFunction: Expression,
     inputData: Expression) extends Expression with NonSQLExpression {
 
-  @tailrec
-  private def itemAccessorMethod(dataType: DataType): String => String = dataType match {
-    case NullType =>
-      val nullTypeClassName = NullType.getClass.getName + ".MODULE$"
-      (i: String) => s".get($i, $nullTypeClassName)"
-    case IntegerType => (i: String) => s".getInt($i)"
-    case LongType => (i: String) => s".getLong($i)"
-    case FloatType => (i: String) => s".getFloat($i)"
-    case DoubleType => (i: String) => s".getDouble($i)"
-    case ByteType => (i: String) => s".getByte($i)"
-    case ShortType => (i: String) => s".getShort($i)"
-    case BooleanType => (i: String) => s".getBoolean($i)"
-    case StringType => (i: String) => s".getUTF8String($i)"
-    case s: StructType => (i: String) => s".getStruct($i, ${s.size})"
-    case a: ArrayType => (i: String) => s".getArray($i)"
-    case _: MapType => (i: String) => s".getMap($i)"
-    case udt: UserDefinedType[_] => itemAccessorMethod(udt.sqlType)
-    case DecimalType.Fixed(p, s) => (i: String) => s".getDecimal($i, $p, $s)"
-    case DateType => (i: String) => s".getInt($i)"
-  }
-
-  private lazy val (lengthFunction, itemAccessor, primitiveElement) = inputData.dataType match {
-    case ObjectType(cls) if classOf[Seq[_]].isAssignableFrom(cls) =>
-      (".size()", (i: String) => s".apply($i)", false)
-    case ObjectType(cls) if cls.isArray =>
-      (".length", (i: String) => s"[$i]", false)
-    case ObjectType(cls) if classOf[java.util.List[_]].isAssignableFrom(cls) =>
-      (".size()", (i: String) => s".get($i)", false)
-    case ArrayType(t, _) =>
-      val (sqlType, primitiveElement) = t match {
-        case m: MapType => (m, false)
-        case s: StructType => (s, false)
-        case s: StringType => (s, false)
-        case udt: UserDefinedType[_] => (udt.sqlType, false)
-        case o => (o, true)
-      }
-      (".numElements()", itemAccessorMethod(sqlType), primitiveElement)
-  }
-
   override def nullable: Boolean = true
 
   override def children: Seq[Expression] = lambdaFunction :: inputData :: Nil
@@ -425,7 +386,6 @@ case class MapObjects private(
   override def dataType: DataType = ArrayType(lambdaFunction.dataType)
 
   override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
-    val javaType = ctx.javaType(dataType)
     val elementJavaType = ctx.javaType(loopVar.dataType)
     ctx.addMutableState("boolean", loopVar.isNull, "")
     ctx.addMutableState(elementJavaType, loopVar.value, "")
@@ -448,27 +408,61 @@ case class MapObjects private(
       s"new $convertedType[$dataLength]"
     }
 
-    val loopNullCheck = if (primitiveElement) {
-      s"${loopVar.isNull} = ${genInputData.value}.isNullAt($loopIndex);"
-    } else {
-      s"${loopVar.isNull} = ${genInputData.isNull} || ${loopVar.value} == null;"
+    // In RowEncoder, we use `Object` to represent Array or Seq, so we need to determine the type
+    // of input collection at runtime for this case.
+    val seq = ctx.freshName("seq")
+    val array = ctx.freshName("array")
+    val determineCollectionType = inputData.dataType match {
+      case ObjectType(cls) if cls == classOf[Object] =>
+        val seqClass = classOf[Seq[_]].getName
+        s"""
+          $seqClass $seq = null;
+          $elementJavaType[] $array = null;
+          if (${genInputData.value}.getClass().isArray()) {
+            $array = ($elementJavaType[]) ${genInputData.value};
+          } else {
+            $seq = ($seqClass) ${genInputData.value};
+          }
+        """
+      case _ => ""
+    }
+
+    val (getLength, getLoopVar) = inputData.dataType match {
+      case ObjectType(cls) if classOf[Seq[_]].isAssignableFrom(cls) =>
+        s"${genInputData.value}.size()" -> s"${genInputData.value}.apply($loopIndex)"
+      case ObjectType(cls) if cls.isArray =>
+        s"${genInputData.value}.length" -> s"${genInputData.value}[$loopIndex]"
+      case ObjectType(cls) if classOf[java.util.List[_]].isAssignableFrom(cls) =>
+        s"${genInputData.value}.size()" -> s"${genInputData.value}.get($loopIndex)"
+      case ArrayType(et, _) =>
+        s"${genInputData.value}.numElements()" -> ctx.getValue(genInputData.value, et, loopIndex)
+      case ObjectType(cls) if cls == classOf[Object] =>
+        s"$seq == null ? $array.length : $seq.size()" ->
+          s"$seq == null ? $array[$loopIndex] : $seq.apply($loopIndex)"
+    }
+
+    val loopNullCheck = inputData.dataType match {
+      case _: ArrayType => s"${loopVar.isNull} = ${genInputData.value}.isNullAt($loopIndex);"
+      // The element of primitive array will never be null.
+      case ObjectType(cls) if cls.isArray && cls.getComponentType.isPrimitive =>
+        s"${loopVar.isNull} = false"
+      case _ => s"${loopVar.isNull} = ${loopVar.value} == null;"
     }
 
     val code = s"""
       ${genInputData.code}
+      ${ctx.javaType(dataType)} ${ev.value} = ${ctx.defaultValue(dataType)};
 
-      boolean ${ev.isNull} = ${genInputData.value} == null;
-      $javaType ${ev.value} = ${ctx.defaultValue(dataType)};
-
-      if (!${ev.isNull}) {
+      if (!${genInputData.isNull}) {
+        $determineCollectionType
         $convertedType[] $convertedArray = null;
-        int $dataLength = ${genInputData.value}$lengthFunction;
+        int $dataLength = $getLength;
         $convertedArray = $arrayConstructor;
 
         int $loopIndex = 0;
         while ($loopIndex < $dataLength) {
-          ${loopVar.value} =
-            ($elementJavaType)${genInputData.value}${itemAccessor(loopIndex)};
+          ${loopVar.value} = ($elementJavaType) ($getLoopVar);
           $loopNullCheck
 
           ${genFunction.code}
@@ -481,11 +475,10 @@ case class MapObjects private(
           $loopIndex += 1;
         }
 
-        ${ev.isNull} = false;
         ${ev.value} = new ${classOf[GenericArrayData].getName}($convertedArray);
       }
     """
-    ev.copy(code = code)
+    ev.copy(code = code, isNull = genInputData.isNull)
   }
 }
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/GenericArrayData.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/GenericArrayData.scala
index 2b8cdc1e23..3a665d3708 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/GenericArrayData.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/GenericArrayData.scala
@@ -37,6 +37,11 @@ class GenericArrayData(val array: Array[Any]) extends ArrayData {
   def this(primitiveArray: Array[Byte]) = this(primitiveArray.toSeq)
   def this(primitiveArray: Array[Boolean]) = this(primitiveArray.toSeq)
 
+  def this(seqOrArray: Any) = this(seqOrArray match {
+    case seq: Seq[Any] => seq
+    case array: Array[_] => array.toSeq
+  })
+
   override def copy(): ArrayData = new GenericArrayData(array.clone())
 
   override def numElements(): Int = array.length
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/RowEncoderSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/RowEncoderSuite.scala
index 4800e2e26e..7bb006c173 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/RowEncoderSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/RowEncoderSuite.scala
@@ -185,6 +185,23 @@ class RowEncoderSuite extends SparkFunSuite {
     assert(encoder.serializer.head.nullable == false)
   }
 
+  test("RowEncoder should support array as the external type for ArrayType") {
+    val schema = new StructType()
+      .add("array", ArrayType(IntegerType))
+      .add("nestedArray", ArrayType(ArrayType(StringType)))
+      .add("deepNestedArray", ArrayType(ArrayType(ArrayType(LongType))))
+    val encoder = RowEncoder(schema)
+    val input = Row(
+      Array(1, 2, null),
+      Array(Array("abc", null), null),
+      Array(Seq(Array(0L, null), null), null))
+    val row = encoder.toRow(input)
+    val convertedBack = encoder.fromRow(row)
+    assert(convertedBack.getSeq(0) == Seq(1, 2, null))
+    assert(convertedBack.getSeq(1) == Seq(Seq("abc", null), null))
+    assert(convertedBack.getSeq(2) == Seq(Seq(Seq(0L, null), null), null))
+  }
+
   private def encodeDecodeTest(schema: StructType): Unit = {
     test(s"encode/decode: ${schema.simpleString}") {
       val encoder = RowEncoder(schema)
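For readers skimming the codegen diff above: when the input's static type is plain `Object`, the generated Java inspects the runtime class once and then branches between array and `Seq` accessors. The following hand-written Scala analogue sketches that dispatch (illustrative only; the real generated code casts to the concrete element types rather than using reflection):

```scala
import java.lang.reflect.{Array => JArray}

// Sketch of the Object-typed branch in MapObjects.doGenCode: decide once
// whether the input is an array or a Seq, then loop with the matching
// length and element accessors, applying the lambda to each element.
def mapElements(input: AnyRef)(f: Any => Any): Array[Any] = {
  val isArray = input.getClass.isArray
  val length =
    if (isArray) JArray.getLength(input)  // works for primitive arrays too
    else input.asInstanceOf[Seq[_]].size
  val result = new Array[Any](length)
  var i = 0
  while (i < length) {
    val element =
      if (isArray) JArray.get(input, i)   // boxes primitives automatically
      else input.asInstanceOf[Seq[_]](i)
    result(i) = f(element)
    i += 1
  }
  result
}

// Both call styles behave identically, which is exactly what RowEncoder
// relies on once ArrayType's external type is relaxed to Object:
//   mapElements(Seq(1, 2, 3))(identity)   // from a Seq
//   mapElements(Array(1, 2, 3))(identity) // from an Array
```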