author     Liang-Chi Hsieh <viirya@appier.com>    2015-11-16 09:03:42 -0800
committer  Davies Liu <davies.liu@gmail.com>      2015-11-16 09:03:42 -0800
commit     b0c3fd34e4cfa3f0472d83e71ffe774430cfdc87 (patch)
tree       27f0788370b639ad6a94440d3f1009410dae16dd /sql
parent     06f1fdba6d1425afddfc1d45a20dbe9bede15e7a (diff)
[SPARK-11743] [SQL] Add UserDefinedType support to RowEncoder
JIRA: https://issues.apache.org/jira/browse/SPARK-11743

RowEncoder does not support UserDefinedType yet; this patch adds that support.

Author: Liang-Chi Hsieh <viirya@appier.com>

Closes #9712 from viirya/rowencoder-udt.
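For context, a minimal sketch of how the new support is exercised end-to-end, mirroring the round-trip tests added in RowEncoderSuite below. It is not part of the patch; the column name "point" is illustrative, and ExamplePoint/ExamplePointUDT are the test fixtures defined in this commit.

    import org.apache.spark.sql.Row
    import org.apache.spark.sql.catalyst.encoders.RowEncoder
    import org.apache.spark.sql.types._

    // Schema with a single UDT column; ExamplePoint/ExamplePointUDT come
    // from the RowEncoderSuite changes in this patch.
    val schema = new StructType().add("point", new ExamplePointUDT, false)
    val encoder = RowEncoder(schema)

    // Round-trip: toRow invokes the UDT's serialize(), fromRow its deserialize().
    val input = Row(new ExamplePoint(1.0, 2.0))
    val internalRow = encoder.toRow(input)
    val roundTripped = encoder.fromRow(internalRow)
    assert(input.getAs[ExamplePoint](0) == roundTripped.getAs[ExamplePoint](0))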
Diffstat (limited to 'sql')
-rw-r--r--  sql/catalyst/src/main/scala/org/apache/spark/sql/Row.scala                               | 14
-rw-r--r--  sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/RowEncoder.scala      | 24
-rw-r--r--  sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects.scala      | 48
-rw-r--r--  sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/RowEncoderSuite.scala | 82
4 files changed, 139 insertions, 29 deletions
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/Row.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/Row.scala
index ed2fdf9f2f..0f0f200122 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/Row.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/Row.scala
@@ -152,7 +152,7 @@ trait Row extends Serializable {
* BinaryType -> byte array
* ArrayType -> scala.collection.Seq (use getList for java.util.List)
* MapType -> scala.collection.Map (use getJavaMap for java.util.Map)
- * StructType -> org.apache.spark.sql.Row
+ * StructType -> org.apache.spark.sql.Row (or Product)
* }}}
*/
def apply(i: Int): Any = get(i)
@@ -177,7 +177,7 @@ trait Row extends Serializable {
* BinaryType -> byte array
* ArrayType -> scala.collection.Seq (use getList for java.util.List)
* MapType -> scala.collection.Map (use getJavaMap for java.util.Map)
- * StructType -> org.apache.spark.sql.Row
+ * StructType -> org.apache.spark.sql.Row (or Product)
* }}}
*/
def get(i: Int): Any
@@ -306,7 +306,15 @@ trait Row extends Serializable {
*
* @throws ClassCastException when data type does not match.
*/
- def getStruct(i: Int): Row = getAs[Row](i)
+ def getStruct(i: Int): Row = {
+ // Both Product and Row are recognized as StructType in a Row
+ val t = get(i)
+ if (t.isInstanceOf[Product]) {
+ Row.fromTuple(t.asInstanceOf[Product])
+ } else {
+ t.asInstanceOf[Row]
+ }
+ }
/**
* Returns the value at position i.
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/RowEncoder.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/RowEncoder.scala
index e0be896bb3..9bb1602494 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/RowEncoder.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/RowEncoder.scala
@@ -50,6 +50,14 @@ object RowEncoder {
case BooleanType | ByteType | ShortType | IntegerType | LongType |
FloatType | DoubleType | BinaryType => inputObject
+ case udt: UserDefinedType[_] =>
+ val obj = NewInstance(
+ udt.userClass.getAnnotation(classOf[SQLUserDefinedType]).udt(),
+ Nil,
+ false,
+ dataType = ObjectType(udt.userClass.getAnnotation(classOf[SQLUserDefinedType]).udt()))
+ Invoke(obj, "serialize", udt.sqlType, inputObject :: Nil)
+
case TimestampType =>
StaticInvoke(
DateTimeUtils,
@@ -109,11 +117,16 @@ object RowEncoder {
case StructType(fields) =>
val convertedFields = fields.zipWithIndex.map { case (f, i) =>
+ val method = if (f.dataType.isInstanceOf[StructType]) {
+ "getStruct"
+ } else {
+ "get"
+ }
If(
Invoke(inputObject, "isNullAt", BooleanType, Literal(i) :: Nil),
Literal.create(null, f.dataType),
extractorsFor(
- Invoke(inputObject, "get", externalDataTypeFor(f.dataType), Literal(i) :: Nil),
+ Invoke(inputObject, method, externalDataTypeFor(f.dataType), Literal(i) :: Nil),
f.dataType))
}
CreateStruct(convertedFields)
@@ -137,6 +150,7 @@ object RowEncoder {
case _: ArrayType => ObjectType(classOf[scala.collection.Seq[_]])
case _: MapType => ObjectType(classOf[scala.collection.Map[_, _]])
case _: StructType => ObjectType(classOf[Row])
+ case udt: UserDefinedType[_] => ObjectType(udt.userClass)
}
private def constructorFor(schema: StructType): Expression = {
@@ -155,6 +169,14 @@ object RowEncoder {
case BooleanType | ByteType | ShortType | IntegerType | LongType |
FloatType | DoubleType | BinaryType => input
+ case udt: UserDefinedType[_] =>
+ val obj = NewInstance(
+ udt.userClass.getAnnotation(classOf[SQLUserDefinedType]).udt(),
+ Nil,
+ false,
+ dataType = ObjectType(udt.userClass.getAnnotation(classOf[SQLUserDefinedType]).udt()))
+ Invoke(obj, "deserialize", ObjectType(udt.userClass), input :: Nil)
+
case TimestampType =>
StaticInvoke(
DateTimeUtils,
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects.scala
index 4f58464221..5cd19de683 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects.scala
@@ -113,7 +113,7 @@ case class Invoke(
arguments: Seq[Expression] = Nil) extends Expression {
override def nullable: Boolean = true
- override def children: Seq[Expression] = targetObject :: Nil
+ override def children: Seq[Expression] = arguments.+:(targetObject)
override def eval(input: InternalRow): Any =
throw new UnsupportedOperationException("Only code-generated evaluation is supported.")
@@ -343,33 +343,35 @@ case class MapObjects(
private lazy val loopAttribute = AttributeReference("loopVar", elementType)()
private lazy val completeFunction = function(loopAttribute)
+ private def itemAccessorMethod(dataType: DataType): String => String = dataType match {
+ case IntegerType => (i: String) => s".getInt($i)"
+ case LongType => (i: String) => s".getLong($i)"
+ case FloatType => (i: String) => s".getFloat($i)"
+ case DoubleType => (i: String) => s".getDouble($i)"
+ case ByteType => (i: String) => s".getByte($i)"
+ case ShortType => (i: String) => s".getShort($i)"
+ case BooleanType => (i: String) => s".getBoolean($i)"
+ case StringType => (i: String) => s".getUTF8String($i)"
+ case s: StructType => (i: String) => s".getStruct($i, ${s.size})"
+ case a: ArrayType => (i: String) => s".getArray($i)"
+ case _: MapType => (i: String) => s".getMap($i)"
+ case udt: UserDefinedType[_] => itemAccessorMethod(udt.sqlType)
+ }
+
private lazy val (lengthFunction, itemAccessor, primitiveElement) = inputData.dataType match {
case ObjectType(cls) if classOf[Seq[_]].isAssignableFrom(cls) =>
(".size()", (i: String) => s".apply($i)", false)
case ObjectType(cls) if cls.isArray =>
(".length", (i: String) => s"[$i]", false)
- case ArrayType(s: StructType, _) =>
- (".numElements()", (i: String) => s".getStruct($i, ${s.size})", false)
- case ArrayType(a: ArrayType, _) =>
- (".numElements()", (i: String) => s".getArray($i)", true)
- case ArrayType(IntegerType, _) =>
- (".numElements()", (i: String) => s".getInt($i)", true)
- case ArrayType(LongType, _) =>
- (".numElements()", (i: String) => s".getLong($i)", true)
- case ArrayType(FloatType, _) =>
- (".numElements()", (i: String) => s".getFloat($i)", true)
- case ArrayType(DoubleType, _) =>
- (".numElements()", (i: String) => s".getDouble($i)", true)
- case ArrayType(ByteType, _) =>
- (".numElements()", (i: String) => s".getByte($i)", true)
- case ArrayType(ShortType, _) =>
- (".numElements()", (i: String) => s".getShort($i)", true)
- case ArrayType(BooleanType, _) =>
- (".numElements()", (i: String) => s".getBoolean($i)", true)
- case ArrayType(StringType, _) =>
- (".numElements()", (i: String) => s".getUTF8String($i)", false)
- case ArrayType(_: MapType, _) =>
- (".numElements()", (i: String) => s".getMap($i)", false)
+ case ArrayType(t, _) =>
+ val (sqlType, primitiveElement) = t match {
+ case m: MapType => (m, false)
+ case s: StructType => (s, false)
+ case s: StringType => (s, false)
+ case udt: UserDefinedType[_] => (udt.sqlType, false)
+ case o => (o, true)
+ }
+ (".numElements()", itemAccessorMethod(sqlType), primitiveElement)
}
override def nullable: Boolean = true
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/RowEncoderSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/RowEncoderSuite.scala
index e8301e8e06..c868ddec1b 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/RowEncoderSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/RowEncoderSuite.scala
@@ -19,14 +19,62 @@ package org.apache.spark.sql.catalyst.encoders
import org.apache.spark.SparkFunSuite
import org.apache.spark.sql.{RandomDataGenerator, Row}
+import org.apache.spark.sql.catalyst.util.{GenericArrayData, ArrayData}
import org.apache.spark.sql.types._
import org.apache.spark.unsafe.types.UTF8String
+@SQLUserDefinedType(udt = classOf[ExamplePointUDT])
+class ExamplePoint(val x: Double, val y: Double) extends Serializable {
+ override def hashCode: Int = 41 * (41 + x.toInt) + y.toInt
+ override def equals(that: Any): Boolean = {
+ if (that.isInstanceOf[ExamplePoint]) {
+ val e = that.asInstanceOf[ExamplePoint]
+ (this.x == e.x || (this.x.isNaN && e.x.isNaN) || (this.x.isInfinity && e.x.isInfinity)) &&
+ (this.y == e.y || (this.y.isNaN && e.y.isNaN) || (this.y.isInfinity && e.y.isInfinity))
+ } else {
+ false
+ }
+ }
+}
+
+/**
+ * User-defined type for [[ExamplePoint]].
+ */
+class ExamplePointUDT extends UserDefinedType[ExamplePoint] {
+
+ override def sqlType: DataType = ArrayType(DoubleType, false)
+
+ override def pyUDT: String = "pyspark.sql.tests.ExamplePointUDT"
+
+ override def serialize(obj: Any): GenericArrayData = {
+ obj match {
+ case p: ExamplePoint =>
+ val output = new Array[Any](2)
+ output(0) = p.x
+ output(1) = p.y
+ new GenericArrayData(output)
+ }
+ }
+
+ override def deserialize(datum: Any): ExamplePoint = {
+ datum match {
+ case values: ArrayData =>
+ new ExamplePoint(values.getDouble(0), values.getDouble(1))
+ }
+ }
+
+ override def userClass: Class[ExamplePoint] = classOf[ExamplePoint]
+
+ private[spark] override def asNullable: ExamplePointUDT = this
+}
+
class RowEncoderSuite extends SparkFunSuite {
private val structOfString = new StructType().add("str", StringType)
+ private val structOfUDT = new StructType().add("udt", new ExamplePointUDT, false)
private val arrayOfString = ArrayType(StringType)
private val mapOfString = MapType(StringType, StringType)
+ private val arrayOfUDT = ArrayType(new ExamplePointUDT, false)
encodeDecodeTest(
new StructType()
@@ -41,7 +89,8 @@ class RowEncoderSuite extends SparkFunSuite {
.add("string", StringType)
.add("binary", BinaryType)
.add("date", DateType)
- .add("timestamp", TimestampType))
+ .add("timestamp", TimestampType)
+ .add("udt", new ExamplePointUDT, false))
encodeDecodeTest(
new StructType()
@@ -68,7 +117,36 @@ class RowEncoderSuite extends SparkFunSuite {
.add("structOfArray", new StructType().add("array", arrayOfString))
.add("structOfMap", new StructType().add("map", mapOfString))
.add("structOfArrayAndMap",
- new StructType().add("array", arrayOfString).add("map", mapOfString)))
+ new StructType().add("array", arrayOfString).add("map", mapOfString))
+ .add("structOfUDT", structOfUDT))
+
+ test(s"encode/decode: arrayOfUDT") {
+ val schema = new StructType()
+ .add("arrayOfUDT", arrayOfUDT)
+
+ val encoder = RowEncoder(schema)
+
+ val input: Row = Row(Seq(new ExamplePoint(0.1, 0.2), new ExamplePoint(0.3, 0.4)))
+ val row = encoder.toRow(input)
+ val convertedBack = encoder.fromRow(row)
+ assert(input.getSeq[ExamplePoint](0) == convertedBack.getSeq[ExamplePoint](0))
+ }
+
+ test(s"encode/decode: Product") {
+ val schema = new StructType()
+ .add("structAsProduct",
+ new StructType()
+ .add("int", IntegerType)
+ .add("string", StringType)
+ .add("double", DoubleType))
+
+ val encoder = RowEncoder(schema)
+
+ val input: Row = Row((100, "test", 0.123))
+ val row = encoder.toRow(input)
+ val convertedBack = encoder.fromRow(row)
+ assert(input.getStruct(0) == convertedBack.getStruct(0))
+ }
private def encodeDecodeTest(schema: StructType): Unit = {
test(s"encode/decode: ${schema.simpleString}") {