author     Wenchen Fan <wenchen@databricks.com>  2016-05-17 17:02:52 +0800
committer  Wenchen Fan <wenchen@databricks.com>  2016-05-17 17:02:52 +0800
commit     c36ca651f9177f8e7a3f6a0098cba5a810ee9deb (patch)
tree       2a0405085ef6df1670715b9864004dc8d6327fe0
parent     122302cbf5cbf1133067a5acdffd6ab96765dafe (diff)
[SPARK-15351][SQL] RowEncoder should support array as the external type for ArrayType
## What changes were proposed in this pull request?

This PR improves `RowEncoder` and `MapObjects` to support array as the external type for `ArrayType`. The idea is straightforward: we use `Object` as the external input type for `ArrayType` and determine the concrete collection type at runtime in `MapObjects`.

## How was this patch tested?

New test in `RowEncoderSuite`.

Author: Wenchen Fan <wenchen@databricks.com>

Closes #13138 from cloud-fan/map-object.
-rw-r--r--  sql/catalyst/src/main/scala/org/apache/spark/sql/Row.scala                                   |  4
-rw-r--r--  sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/RowEncoder.scala          | 22
-rw-r--r--  sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala  | 99
-rw-r--r--  sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/GenericArrayData.scala        |  5
-rw-r--r--  sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/RowEncoderSuite.scala     | 17
5 files changed, 92 insertions, 55 deletions
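
For context, here is a minimal sketch of the behavior this patch enables, modeled on the new test in `RowEncoderSuite` (it assumes the catalyst classes at this revision are on the classpath):

```scala
import org.apache.spark.sql.Row
import org.apache.spark.sql.catalyst.encoders.RowEncoder
import org.apache.spark.sql.types._

object ArrayAsExternalType extends App {
  val schema = new StructType().add("array", ArrayType(IntegerType))
  val encoder = RowEncoder(schema)

  // Before this patch only scala.collection.Seq was accepted for an
  // ArrayType column; an Array now works as well, and always decodes
  // back to a Seq.
  val internalRow = encoder.toRow(Row(Array(1, 2, 3)))
  val roundTripped = encoder.fromRow(internalRow)
  assert(roundTripped.getSeq[Int](0) == Seq(1, 2, 3))
}
```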
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/Row.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/Row.scala
index 726291b96c..a257b831dd 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/Row.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/Row.scala
@@ -151,7 +151,7 @@ trait Row extends Serializable {
* BinaryType -> byte array
* ArrayType -> scala.collection.Seq (use getList for java.util.List)
* MapType -> scala.collection.Map (use getJavaMap for java.util.Map)
- * StructType -> org.apache.spark.sql.Row (or Product)
+ * StructType -> org.apache.spark.sql.Row
* }}}
*/
def apply(i: Int): Any = get(i)
@@ -176,7 +176,7 @@ trait Row extends Serializable {
* BinaryType -> byte array
* ArrayType -> scala.collection.Seq (use getList for java.util.List)
* MapType -> scala.collection.Map (use getJavaMap for java.util.Map)
- * StructType -> org.apache.spark.sql.Row (or Product)
+ * StructType -> org.apache.spark.sql.Row
* }}}
*/
def get(i: Int): Any
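
The scaladoc fix above tightens the accessor contract: a `StructType` field always comes back as a `Row`, never a `Product`. A small illustrative sketch of that contract using the standard `Row` accessors (the values here are made up):

```scala
import org.apache.spark.sql.Row

object RowAccessorContract extends App {
  val row = Row(Row(1, "a"), Seq(1.0, 2.0))
  val struct: Row = row.getStruct(0)              // StructType -> Row, never Product
  val values: Seq[Double] = row.getSeq[Double](1) // ArrayType -> scala.collection.Seq
  assert(struct.getInt(0) == 1 && values.sum == 3.0)
}
```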
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/RowEncoder.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/RowEncoder.scala
index ae842a9f87..a5f39aaa23 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/RowEncoder.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/RowEncoder.scala
@@ -32,6 +32,26 @@ import org.apache.spark.unsafe.types.UTF8String
/**
* A factory for constructing encoders that convert external row to/from the Spark SQL
* internal binary representation.
+ *
+ * The following is a mapping between Spark SQL types and its allowed external types:
+ * {{{
+ * BooleanType -> java.lang.Boolean
+ * ByteType -> java.lang.Byte
+ * ShortType -> java.lang.Short
+ * IntegerType -> java.lang.Integer
+ * FloatType -> java.lang.Float
+ * DoubleType -> java.lang.Double
+ * StringType -> String
+ * DecimalType -> java.math.BigDecimal or scala.math.BigDecimal or Decimal
+ *
+ * DateType -> java.sql.Date
+ * TimestampType -> java.sql.Timestamp
+ *
+ * BinaryType -> byte array
+ * ArrayType -> scala.collection.Seq or Array
+ * MapType -> scala.collection.Map
+ * StructType -> org.apache.spark.sql.Row or Product
+ * }}}
*/
object RowEncoder {
def apply(schema: StructType): ExpressionEncoder[Row] = {
@@ -166,6 +186,8 @@ object RowEncoder {
// In order to support both Decimal and java/scala BigDecimal in external row, we make this
// as java.lang.Object.
case _: DecimalType => ObjectType(classOf[java.lang.Object])
+ // In order to support both Array and Seq in external row, we make this as java.lang.Object.
+ case _: ArrayType => ObjectType(classOf[java.lang.Object])
case _ => externalDataTypeFor(dt)
}
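
The hunk above widens the serializer's declared input type for `ArrayType` to `Object`, mirroring what was already done for `DecimalType`. A simplified, standalone paraphrase of the pattern (`serializerInputType` is an illustrative name, not the real private method):

```scala
import org.apache.spark.sql.types._

object ExternalInputTypeSketch {
  // Types with more than one allowed external representation are declared
  // to the serializer as plain Object; the concrete class is only
  // discovered at runtime.
  def serializerInputType(dt: DataType): Class[_] = dt match {
    case _: DecimalType => classOf[java.lang.Object] // Decimal or java/scala BigDecimal
    case _: ArrayType   => classOf[java.lang.Object] // Seq or Array
    case StringType     => classOf[String]
    case _              => classOf[java.lang.Object] // fallback for this sketch
  }
}
```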
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala
index e8a6c742bf..7df6e06805 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala
@@ -376,45 +376,6 @@ case class MapObjects private(
lambdaFunction: Expression,
inputData: Expression) extends Expression with NonSQLExpression {
- @tailrec
- private def itemAccessorMethod(dataType: DataType): String => String = dataType match {
- case NullType =>
- val nullTypeClassName = NullType.getClass.getName + ".MODULE$"
- (i: String) => s".get($i, $nullTypeClassName)"
- case IntegerType => (i: String) => s".getInt($i)"
- case LongType => (i: String) => s".getLong($i)"
- case FloatType => (i: String) => s".getFloat($i)"
- case DoubleType => (i: String) => s".getDouble($i)"
- case ByteType => (i: String) => s".getByte($i)"
- case ShortType => (i: String) => s".getShort($i)"
- case BooleanType => (i: String) => s".getBoolean($i)"
- case StringType => (i: String) => s".getUTF8String($i)"
- case s: StructType => (i: String) => s".getStruct($i, ${s.size})"
- case a: ArrayType => (i: String) => s".getArray($i)"
- case _: MapType => (i: String) => s".getMap($i)"
- case udt: UserDefinedType[_] => itemAccessorMethod(udt.sqlType)
- case DecimalType.Fixed(p, s) => (i: String) => s".getDecimal($i, $p, $s)"
- case DateType => (i: String) => s".getInt($i)"
- }
-
- private lazy val (lengthFunction, itemAccessor, primitiveElement) = inputData.dataType match {
- case ObjectType(cls) if classOf[Seq[_]].isAssignableFrom(cls) =>
- (".size()", (i: String) => s".apply($i)", false)
- case ObjectType(cls) if cls.isArray =>
- (".length", (i: String) => s"[$i]", false)
- case ObjectType(cls) if classOf[java.util.List[_]].isAssignableFrom(cls) =>
- (".size()", (i: String) => s".get($i)", false)
- case ArrayType(t, _) =>
- val (sqlType, primitiveElement) = t match {
- case m: MapType => (m, false)
- case s: StructType => (s, false)
- case s: StringType => (s, false)
- case udt: UserDefinedType[_] => (udt.sqlType, false)
- case o => (o, true)
- }
- (".numElements()", itemAccessorMethod(sqlType), primitiveElement)
- }
-
override def nullable: Boolean = true
override def children: Seq[Expression] = lambdaFunction :: inputData :: Nil
@@ -425,7 +386,6 @@ case class MapObjects private(
override def dataType: DataType = ArrayType(lambdaFunction.dataType)
override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
- val javaType = ctx.javaType(dataType)
val elementJavaType = ctx.javaType(loopVar.dataType)
ctx.addMutableState("boolean", loopVar.isNull, "")
ctx.addMutableState(elementJavaType, loopVar.value, "")
@@ -448,27 +408,61 @@ case class MapObjects private(
s"new $convertedType[$dataLength]"
}
- val loopNullCheck = if (primitiveElement) {
- s"${loopVar.isNull} = ${genInputData.value}.isNullAt($loopIndex);"
- } else {
- s"${loopVar.isNull} = ${genInputData.isNull} || ${loopVar.value} == null;"
+ // In RowEncoder, we use `Object` to represent Array or Seq, so we need to determine the type
+ // of input collection at runtime for this case.
+ val seq = ctx.freshName("seq")
+ val array = ctx.freshName("array")
+ val determineCollectionType = inputData.dataType match {
+ case ObjectType(cls) if cls == classOf[Object] =>
+ val seqClass = classOf[Seq[_]].getName
+ s"""
+ $seqClass $seq = null;
+ $elementJavaType[] $array = null;
+ if (${genInputData.value}.getClass().isArray()) {
+ $array = ($elementJavaType[]) ${genInputData.value};
+ } else {
+ $seq = ($seqClass) ${genInputData.value};
+ }
+ """
+ case _ => ""
+ }
+
+
+ val (getLength, getLoopVar) = inputData.dataType match {
+ case ObjectType(cls) if classOf[Seq[_]].isAssignableFrom(cls) =>
+ s"${genInputData.value}.size()" -> s"${genInputData.value}.apply($loopIndex)"
+ case ObjectType(cls) if cls.isArray =>
+ s"${genInputData.value}.length" -> s"${genInputData.value}[$loopIndex]"
+ case ObjectType(cls) if classOf[java.util.List[_]].isAssignableFrom(cls) =>
+ s"${genInputData.value}.size()" -> s"${genInputData.value}.get($loopIndex)"
+ case ArrayType(et, _) =>
+ s"${genInputData.value}.numElements()" -> ctx.getValue(genInputData.value, et, loopIndex)
+ case ObjectType(cls) if cls == classOf[Object] =>
+ s"$seq == null ? $array.length : $seq.size()" ->
+ s"$seq == null ? $array[$loopIndex] : $seq.apply($loopIndex)"
+ }
+
+ val loopNullCheck = inputData.dataType match {
+ case _: ArrayType => s"${loopVar.isNull} = ${genInputData.value}.isNullAt($loopIndex);"
+ // The element of primitive array will never be null.
+ case ObjectType(cls) if cls.isArray && cls.getComponentType.isPrimitive =>
+ s"${loopVar.isNull} = false"
+ case _ => s"${loopVar.isNull} = ${loopVar.value} == null;"
}
val code = s"""
${genInputData.code}
+ ${ctx.javaType(dataType)} ${ev.value} = ${ctx.defaultValue(dataType)};
- boolean ${ev.isNull} = ${genInputData.value} == null;
- $javaType ${ev.value} = ${ctx.defaultValue(dataType)};
-
- if (!${ev.isNull}) {
+ if (!${genInputData.isNull}) {
+ $determineCollectionType
$convertedType[] $convertedArray = null;
- int $dataLength = ${genInputData.value}$lengthFunction;
+ int $dataLength = $getLength;
$convertedArray = $arrayConstructor;
int $loopIndex = 0;
while ($loopIndex < $dataLength) {
- ${loopVar.value} =
- ($elementJavaType)${genInputData.value}${itemAccessor(loopIndex)};
+ ${loopVar.value} = ($elementJavaType) ($getLoopVar);
$loopNullCheck
${genFunction.code}
@@ -481,11 +475,10 @@ case class MapObjects private(
$loopIndex += 1;
}
- ${ev.isNull} = false;
${ev.value} = new ${classOf[GenericArrayData].getName}($convertedArray);
}
"""
- ev.copy(code = code)
+ ev.copy(code = code, isNull = genInputData.isNull)
}
}
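
The generated Java above boils down to a runtime dispatch on the input's class. A hand-written Scala analogue (names illustrative, not part of the patch) of what `MapObjects` now emits when the static input type is `Object`:

```scala
object SeqOrArrayDispatch extends App {
  // Mirrors the generated code: when the input is statically Object, decide
  // at runtime whether it is an Array or a Seq, then loop with the matching
  // length and element accessors.
  def mapOverSeqOrArray(input: AnyRef)(f: Any => Any): Seq[Any] = {
    val (length, elementAt): (Int, Int => Any) = input match {
      case arr: Array[_] => (arr.length, (i: Int) => arr(i))
      case seq: Seq[_]   => (seq.size,   (i: Int) => seq(i))
    }
    (0 until length).map(i => f(elementAt(i)))
  }

  assert(mapOverSeqOrArray(Array(1, 2))(identity) ==
    mapOverSeqOrArray(Seq(1, 2))(identity))
}
```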
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/GenericArrayData.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/GenericArrayData.scala
index 2b8cdc1e23..3a665d3708 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/GenericArrayData.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/GenericArrayData.scala
@@ -37,6 +37,11 @@ class GenericArrayData(val array: Array[Any]) extends ArrayData {
def this(primitiveArray: Array[Byte]) = this(primitiveArray.toSeq)
def this(primitiveArray: Array[Boolean]) = this(primitiveArray.toSeq)
+ def this(seqOrArray: Any) = this(seqOrArray match {
+ case seq: Seq[Any] => seq
+ case array: Array[_] => array.toSeq
+ })
+
override def copy(): ArrayData = new GenericArrayData(array.clone())
override def numElements(): Int = array.length
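
A short usage sketch of the new untyped constructor added above: callers that only hold an `Object` (as `MapObjects`' generated code now does) can build a `GenericArrayData` from either collection shape. The `: Any` ascriptions force overload resolution to the new constructor for illustration:

```scala
import org.apache.spark.sql.catalyst.util.GenericArrayData

object SeqOrArrayData extends App {
  val fromSeq   = new GenericArrayData(Seq(1, 2, 3): Any)
  val fromArray = new GenericArrayData(Array("a", "b"): Any)
  assert(fromSeq.numElements() == 3 && fromArray.numElements() == 2)
}
```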
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/RowEncoderSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/RowEncoderSuite.scala
index 4800e2e26e..7bb006c173 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/RowEncoderSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/RowEncoderSuite.scala
@@ -185,6 +185,23 @@ class RowEncoderSuite extends SparkFunSuite {
assert(encoder.serializer.head.nullable == false)
}
+ test("RowEncoder should support array as the external type for ArrayType") {
+ val schema = new StructType()
+ .add("array", ArrayType(IntegerType))
+ .add("nestedArray", ArrayType(ArrayType(StringType)))
+ .add("deepNestedArray", ArrayType(ArrayType(ArrayType(LongType))))
+ val encoder = RowEncoder(schema)
+ val input = Row(
+ Array(1, 2, null),
+ Array(Array("abc", null), null),
+ Array(Seq(Array(0L, null), null), null))
+ val row = encoder.toRow(input)
+ val convertedBack = encoder.fromRow(row)
+ assert(convertedBack.getSeq(0) == Seq(1, 2, null))
+ assert(convertedBack.getSeq(1) == Seq(Seq("abc", null), null))
+ assert(convertedBack.getSeq(2) == Seq(Seq(Seq(0L, null), null), null))
+ }
+
private def encodeDecodeTest(schema: StructType): Unit = {
test(s"encode/decode: ${schema.simpleString}") {
val encoder = RowEncoder(schema)