diff options
author | Liang-Chi Hsieh <viirya@appier.com> | 2015-11-16 09:03:42 -0800 |
---|---|---|
committer | Davies Liu <davies.liu@gmail.com> | 2015-11-16 09:03:42 -0800 |
commit | b0c3fd34e4cfa3f0472d83e71ffe774430cfdc87 (patch) | |
tree | 27f0788370b639ad6a94440d3f1009410dae16dd /sql | |
parent | 06f1fdba6d1425afddfc1d45a20dbe9bede15e7a (diff) | |
download | spark-b0c3fd34e4cfa3f0472d83e71ffe774430cfdc87.tar.gz spark-b0c3fd34e4cfa3f0472d83e71ffe774430cfdc87.tar.bz2 spark-b0c3fd34e4cfa3f0472d83e71ffe774430cfdc87.zip |
[SPARK-11743] [SQL] Add UserDefinedType support to RowEncoder
JIRA: https://issues.apache.org/jira/browse/SPARK-11743
RowEncoder doesn't support UserDefinedType now. We should add the support for it.
Author: Liang-Chi Hsieh <viirya@appier.com>
Closes #9712 from viirya/rowencoder-udt.
Diffstat (limited to 'sql')
4 files changed, 139 insertions, 29 deletions
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/Row.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/Row.scala index ed2fdf9f2f..0f0f200122 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/Row.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/Row.scala @@ -152,7 +152,7 @@ trait Row extends Serializable { * BinaryType -> byte array * ArrayType -> scala.collection.Seq (use getList for java.util.List) * MapType -> scala.collection.Map (use getJavaMap for java.util.Map) - * StructType -> org.apache.spark.sql.Row + * StructType -> org.apache.spark.sql.Row (or Product) * }}} */ def apply(i: Int): Any = get(i) @@ -177,7 +177,7 @@ trait Row extends Serializable { * BinaryType -> byte array * ArrayType -> scala.collection.Seq (use getList for java.util.List) * MapType -> scala.collection.Map (use getJavaMap for java.util.Map) - * StructType -> org.apache.spark.sql.Row + * StructType -> org.apache.spark.sql.Row (or Product) * }}} */ def get(i: Int): Any @@ -306,7 +306,15 @@ trait Row extends Serializable { * * @throws ClassCastException when data type does not match. */ - def getStruct(i: Int): Row = getAs[Row](i) + def getStruct(i: Int): Row = { + // Product and Row both are recoginized as StructType in a Row + val t = get(i) + if (t.isInstanceOf[Product]) { + Row.fromTuple(t.asInstanceOf[Product]) + } else { + t.asInstanceOf[Row] + } + } /** * Returns the value at position i. diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/RowEncoder.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/RowEncoder.scala index e0be896bb3..9bb1602494 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/RowEncoder.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/RowEncoder.scala @@ -50,6 +50,14 @@ object RowEncoder { case BooleanType | ByteType | ShortType | IntegerType | LongType | FloatType | DoubleType | BinaryType => inputObject + case udt: UserDefinedType[_] => + val obj = NewInstance( + udt.userClass.getAnnotation(classOf[SQLUserDefinedType]).udt(), + Nil, + false, + dataType = ObjectType(udt.userClass.getAnnotation(classOf[SQLUserDefinedType]).udt())) + Invoke(obj, "serialize", udt.sqlType, inputObject :: Nil) + case TimestampType => StaticInvoke( DateTimeUtils, @@ -109,11 +117,16 @@ object RowEncoder { case StructType(fields) => val convertedFields = fields.zipWithIndex.map { case (f, i) => + val method = if (f.dataType.isInstanceOf[StructType]) { + "getStruct" + } else { + "get" + } If( Invoke(inputObject, "isNullAt", BooleanType, Literal(i) :: Nil), Literal.create(null, f.dataType), extractorsFor( - Invoke(inputObject, "get", externalDataTypeFor(f.dataType), Literal(i) :: Nil), + Invoke(inputObject, method, externalDataTypeFor(f.dataType), Literal(i) :: Nil), f.dataType)) } CreateStruct(convertedFields) @@ -137,6 +150,7 @@ object RowEncoder { case _: ArrayType => ObjectType(classOf[scala.collection.Seq[_]]) case _: MapType => ObjectType(classOf[scala.collection.Map[_, _]]) case _: StructType => ObjectType(classOf[Row]) + case udt: UserDefinedType[_] => ObjectType(udt.userClass) } private def constructorFor(schema: StructType): Expression = { @@ -155,6 +169,14 @@ object RowEncoder { case BooleanType | ByteType | ShortType | IntegerType | LongType | FloatType | DoubleType | BinaryType => input + case udt: UserDefinedType[_] => + val obj = NewInstance( + udt.userClass.getAnnotation(classOf[SQLUserDefinedType]).udt(), + Nil, + false, + dataType = ObjectType(udt.userClass.getAnnotation(classOf[SQLUserDefinedType]).udt())) + Invoke(obj, "deserialize", ObjectType(udt.userClass), input :: Nil) + case TimestampType => StaticInvoke( DateTimeUtils, diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects.scala index 4f58464221..5cd19de683 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects.scala @@ -113,7 +113,7 @@ case class Invoke( arguments: Seq[Expression] = Nil) extends Expression { override def nullable: Boolean = true - override def children: Seq[Expression] = targetObject :: Nil + override def children: Seq[Expression] = arguments.+:(targetObject) override def eval(input: InternalRow): Any = throw new UnsupportedOperationException("Only code-generated evaluation is supported.") @@ -343,33 +343,35 @@ case class MapObjects( private lazy val loopAttribute = AttributeReference("loopVar", elementType)() private lazy val completeFunction = function(loopAttribute) + private def itemAccessorMethod(dataType: DataType): String => String = dataType match { + case IntegerType => (i: String) => s".getInt($i)" + case LongType => (i: String) => s".getLong($i)" + case FloatType => (i: String) => s".getFloat($i)" + case DoubleType => (i: String) => s".getDouble($i)" + case ByteType => (i: String) => s".getByte($i)" + case ShortType => (i: String) => s".getShort($i)" + case BooleanType => (i: String) => s".getBoolean($i)" + case StringType => (i: String) => s".getUTF8String($i)" + case s: StructType => (i: String) => s".getStruct($i, ${s.size})" + case a: ArrayType => (i: String) => s".getArray($i)" + case _: MapType => (i: String) => s".getMap($i)" + case udt: UserDefinedType[_] => itemAccessorMethod(udt.sqlType) + } + private lazy val (lengthFunction, itemAccessor, primitiveElement) = inputData.dataType match { case ObjectType(cls) if classOf[Seq[_]].isAssignableFrom(cls) => (".size()", (i: String) => s".apply($i)", false) case ObjectType(cls) if cls.isArray => (".length", (i: String) => s"[$i]", false) - case ArrayType(s: StructType, _) => - (".numElements()", (i: String) => s".getStruct($i, ${s.size})", false) - case ArrayType(a: ArrayType, _) => - (".numElements()", (i: String) => s".getArray($i)", true) - case ArrayType(IntegerType, _) => - (".numElements()", (i: String) => s".getInt($i)", true) - case ArrayType(LongType, _) => - (".numElements()", (i: String) => s".getLong($i)", true) - case ArrayType(FloatType, _) => - (".numElements()", (i: String) => s".getFloat($i)", true) - case ArrayType(DoubleType, _) => - (".numElements()", (i: String) => s".getDouble($i)", true) - case ArrayType(ByteType, _) => - (".numElements()", (i: String) => s".getByte($i)", true) - case ArrayType(ShortType, _) => - (".numElements()", (i: String) => s".getShort($i)", true) - case ArrayType(BooleanType, _) => - (".numElements()", (i: String) => s".getBoolean($i)", true) - case ArrayType(StringType, _) => - (".numElements()", (i: String) => s".getUTF8String($i)", false) - case ArrayType(_: MapType, _) => - (".numElements()", (i: String) => s".getMap($i)", false) + case ArrayType(t, _) => + val (sqlType, primitiveElement) = t match { + case m: MapType => (m, false) + case s: StructType => (s, false) + case s: StringType => (s, false) + case udt: UserDefinedType[_] => (udt.sqlType, false) + case o => (o, true) + } + (".numElements()", itemAccessorMethod(sqlType), primitiveElement) } override def nullable: Boolean = true diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/RowEncoderSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/RowEncoderSuite.scala index e8301e8e06..c868ddec1b 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/RowEncoderSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/RowEncoderSuite.scala @@ -19,14 +19,62 @@ package org.apache.spark.sql.catalyst.encoders import org.apache.spark.SparkFunSuite import org.apache.spark.sql.{RandomDataGenerator, Row} +import org.apache.spark.sql.catalyst.util.{GenericArrayData, ArrayData} import org.apache.spark.sql.types._ import org.apache.spark.unsafe.types.UTF8String +@SQLUserDefinedType(udt = classOf[ExamplePointUDT]) +class ExamplePoint(val x: Double, val y: Double) extends Serializable { + override def hashCode: Int = 41 * (41 + x.toInt) + y.toInt + override def equals(that: Any): Boolean = { + if (that.isInstanceOf[ExamplePoint]) { + val e = that.asInstanceOf[ExamplePoint] + (this.x == e.x || (this.x.isNaN && e.x.isNaN) || (this.x.isInfinity && e.x.isInfinity)) && + (this.y == e.y || (this.y.isNaN && e.y.isNaN) || (this.y.isInfinity && e.y.isInfinity)) + } else { + false + } + } +} + +/** + * User-defined type for [[ExamplePoint]]. + */ +class ExamplePointUDT extends UserDefinedType[ExamplePoint] { + + override def sqlType: DataType = ArrayType(DoubleType, false) + + override def pyUDT: String = "pyspark.sql.tests.ExamplePointUDT" + + override def serialize(obj: Any): GenericArrayData = { + obj match { + case p: ExamplePoint => + val output = new Array[Any](2) + output(0) = p.x + output(1) = p.y + new GenericArrayData(output) + } + } + + override def deserialize(datum: Any): ExamplePoint = { + datum match { + case values: ArrayData => + new ExamplePoint(values.getDouble(0), values.getDouble(1)) + } + } + + override def userClass: Class[ExamplePoint] = classOf[ExamplePoint] + + private[spark] override def asNullable: ExamplePointUDT = this +} + class RowEncoderSuite extends SparkFunSuite { private val structOfString = new StructType().add("str", StringType) + private val structOfUDT = new StructType().add("udt", new ExamplePointUDT, false) private val arrayOfString = ArrayType(StringType) private val mapOfString = MapType(StringType, StringType) + private val arrayOfUDT = ArrayType(new ExamplePointUDT, false) encodeDecodeTest( new StructType() @@ -41,7 +89,8 @@ class RowEncoderSuite extends SparkFunSuite { .add("string", StringType) .add("binary", BinaryType) .add("date", DateType) - .add("timestamp", TimestampType)) + .add("timestamp", TimestampType) + .add("udt", new ExamplePointUDT, false)) encodeDecodeTest( new StructType() @@ -68,7 +117,36 @@ class RowEncoderSuite extends SparkFunSuite { .add("structOfArray", new StructType().add("array", arrayOfString)) .add("structOfMap", new StructType().add("map", mapOfString)) .add("structOfArrayAndMap", - new StructType().add("array", arrayOfString).add("map", mapOfString))) + new StructType().add("array", arrayOfString).add("map", mapOfString)) + .add("structOfUDT", structOfUDT)) + + test(s"encode/decode: arrayOfUDT") { + val schema = new StructType() + .add("arrayOfUDT", arrayOfUDT) + + val encoder = RowEncoder(schema) + + val input: Row = Row(Seq(new ExamplePoint(0.1, 0.2), new ExamplePoint(0.3, 0.4))) + val row = encoder.toRow(input) + val convertedBack = encoder.fromRow(row) + assert(input.getSeq[ExamplePoint](0) == convertedBack.getSeq[ExamplePoint](0)) + } + + test(s"encode/decode: Product") { + val schema = new StructType() + .add("structAsProduct", + new StructType() + .add("int", IntegerType) + .add("string", StringType) + .add("double", DoubleType)) + + val encoder = RowEncoder(schema) + + val input: Row = Row((100, "test", 0.123)) + val row = encoder.toRow(input) + val convertedBack = encoder.fromRow(row) + assert(input.getStruct(0) == convertedBack.getStruct(0)) + } private def encodeDecodeTest(schema: StructType): Unit = { test(s"encode/decode: ${schema.simpleString}") { |