author     Joseph K. Bradley <joseph@databricks.com>   2014-11-02 17:55:55 -0800
committer  Michael Armbrust <michael@databricks.com>   2014-11-02 17:56:00 -0800
commit     ebd6480587f96e9964d37157253523e0a179171a (patch)
tree       221ceefd1a5febae327cd75810efdf0197d56005 /sql/catalyst
parent     2ebd1df3f17993f3cb472ec44c8832213976d99a (diff)
[SPARK-3572] [SQL] Internal API for User-Defined Types
This PR adds User-Defined Types (UDTs) to SQL. It is a precursor to using SchemaRDD as a Dataset for the new MLlib API. Currently, the UDT API is private since there is incomplete support (e.g., no Java or Python support yet).

Author: Joseph K. Bradley <joseph@databricks.com>
Author: Michael Armbrust <michael@databricks.com>
Author: Xiangrui Meng <meng@databricks.com>

Closes #3063 from marmbrus/udts and squashes the following commits:

7ccfc0d [Michael Armbrust] remove println
46a3aee [Michael Armbrust] Slightly easier to read test output.
6cc434d [Michael Armbrust] Recursively convert rows.
e369b91 [Michael Armbrust] Merge remote-tracking branch 'origin/master' into udts
15c10a6 [Joseph K. Bradley] Merge remote-tracking branch 'upstream/master' into sql-udt2
f3c72fe [Joseph K. Bradley] Fixing merge
e13cd8a [Joseph K. Bradley] Removed Vector UDTs
5817b2b [Joseph K. Bradley] style edits
30ce5b2 [Joseph K. Bradley] updates based on code review
d063380 [Joseph K. Bradley] Cleaned up Java UDT Suite, and added warning about element ordering when creating schema from Java Bean
a571bb6 [Joseph K. Bradley] Removed old UDT code (registry and Java UDTs). Cleaned up other code. Extended JavaUserDefinedTypeSuite
6fddc1c [Joseph K. Bradley] Made MyLabeledPoint into a Java Bean
20630bc [Joseph K. Bradley] fixed scalastyle
fa86b20 [Joseph K. Bradley] Removed Java UserDefinedType, and made UDTs private[spark] for now
8de957c [Joseph K. Bradley] Modified UserDefinedType to store Java class of user type so that registerUDT takes only the udt argument.
8b242ea [Joseph K. Bradley] Fixed merge error after last merge. Note: Last merge commit also removed SQL UDT examples from mllib.
7f29656 [Joseph K. Bradley] Moved udt case to top of all matches. Small cleanups
b028675 [Xiangrui Meng] allow any type in UDT
4500d8a [Xiangrui Meng] update example code
87264a5 [Xiangrui Meng] remove debug code
3143ac3 [Xiangrui Meng] remove unnecessary changes
cfbc321 [Xiangrui Meng] support UDT in parquet
db16139 [Joseph K. Bradley] Added more doc for UserDefinedType. Removed unused code in Suite
759af7a [Joseph K. Bradley] Added more doc to UserDefineType
63626a4 [Joseph K. Bradley] Updated ScalaReflectionsSuite per @marmbrus suggestions
51e5282 [Joseph K. Bradley] fixed 1 test
f025035 [Joseph K. Bradley] Cleanups before PR. Added new tests
85872f6 [Michael Armbrust] Allow schema calculation to be lazy, but ensure its available on executors.
dff99d6 [Joseph K. Bradley] Added UDTs for Vectors in MLlib, plus DatasetExample using the UDTs
cd60cb4 [Joseph K. Bradley] Trying to get other SQL tests to run
34a5831 [Joseph K. Bradley] Added MLlib dependency on SQL.
e1f7b9c [Joseph K. Bradley] blah
2f40c02 [Joseph K. Bradley] renamed UDT types
3579035 [Joseph K. Bradley] udt annotation now working
b226b9e [Joseph K. Bradley] Changing UDT to annotation
fea04af [Joseph K. Bradley] more cleanups
964b32e [Joseph K. Bradley] some cleanups
893ee4c [Joseph K. Bradley] udt finallly working
50f9726 [Joseph K. Bradley] udts
04303c9 [Joseph K. Bradley] udts
39f8707 [Joseph K. Bradley] removed old udt suite
273ac96 [Joseph K. Bradley] basic UDT is working, but deserialization has yet to be done
8bebf24 [Joseph K. Bradley] commented out convertRowToScala for debugging
53de70f [Joseph K. Bradley] more udts...
982c035 [Joseph K. Bradley] still working on UDTs
19b2f60 [Joseph K. Bradley] still working on UDTs
0eaeb81 [Joseph K. Bradley] Still working on UDTs
105c5a3 [Joseph K. Bradley] Adding UserDefinedType to SQL, not done yet.
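To make the new API concrete, here is a minimal sketch of an annotated UDT built against the interfaces added in this patch. ExamplePoint and ExamplePointUDT are illustrative names, not part of the change, and since the commit keeps the API private this is only a sketch of the intended usage:

    import org.apache.spark.sql.catalyst.annotation.SQLUserDefinedType
    import org.apache.spark.sql.catalyst.types._

    // Hypothetical user class, stored in SQL as an array of doubles.
    // Must be top-level: the annotation lookup fails for classes enclosed
    // in an object (see the warning in SQLUserDefinedType.java below).
    @SQLUserDefinedType(udt = classOf[ExamplePointUDT])
    class ExamplePoint(val x: Double, val y: Double) extends Serializable

    // Hypothetical UDT describing how ExamplePoint maps to Catalyst types.
    class ExamplePointUDT extends UserDefinedType[ExamplePoint] {
      override def sqlType: DataType = ArrayType(DoubleType, containsNull = false)
      override def serialize(obj: Any): Any = obj match {
        case p: ExamplePoint => Seq(p.x, p.y)
      }
      override def deserialize(datum: Any): ExamplePoint = datum match {
        case values: Seq[_] =>
          val vs = values.asInstanceOf[Seq[Double]]
          new ExamplePoint(vs(0), vs(1))
      }
      override def userClass: Class[ExamplePoint] = classOf[ExamplePoint]
    }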
Diffstat (limited to 'sql/catalyst')
-rw-r--r--  sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala               155
-rw-r--r--  sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/annotation/SQLUserDefinedType.java   46
-rw-r--r--  sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ScalaUdf.scala            6
-rw-r--r--  sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/types/dataTypes.scala                53
-rw-r--r--  sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/ScalaReflectionSuite.scala           13
5 files changed, 206 insertions, 67 deletions
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala
index 8fbdf664b7..9cda373623 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala
@@ -19,6 +19,8 @@ package org.apache.spark.sql.catalyst
import java.sql.{Date, Timestamp}
+import org.apache.spark.util.Utils
+import org.apache.spark.sql.catalyst.annotation.SQLUserDefinedType
import org.apache.spark.sql.catalyst.expressions.{GenericRow, Attribute, AttributeReference, Row}
import org.apache.spark.sql.catalyst.plans.logical.LocalRelation
import org.apache.spark.sql.catalyst.types._
@@ -35,25 +37,46 @@ object ScalaReflection {
case class Schema(dataType: DataType, nullable: Boolean)
- /** Converts Scala objects to catalyst rows / types */
- def convertToCatalyst(a: Any): Any = a match {
- case o: Option[_] => o.map(convertToCatalyst).orNull
- case s: Seq[_] => s.map(convertToCatalyst)
- case m: Map[_, _] => m.map { case (k, v) => convertToCatalyst(k) -> convertToCatalyst(v) }
- case p: Product => new GenericRow(p.productIterator.map(convertToCatalyst).toArray)
- case d: BigDecimal => Decimal(d)
- case other => other
+ /**
+ * Converts Scala objects to catalyst rows / types.
+ * Note: This is always called after schemaFor has been called.
+ * This ordering is important for UDT registration.
+ */
+ def convertToCatalyst(a: Any, dataType: DataType): Any = (a, dataType) match {
+ // Check UDT first since UDTs can override other types
+ case (obj, udt: UserDefinedType[_]) => udt.serialize(obj)
+ case (o: Option[_], _) => o.map(convertToCatalyst(_, dataType)).orNull
+ case (s: Seq[_], arrayType: ArrayType) => s.map(convertToCatalyst(_, arrayType.elementType))
+ case (m: Map[_, _], mapType: MapType) => m.map { case (k, v) =>
+ convertToCatalyst(k, mapType.keyType) -> convertToCatalyst(v, mapType.valueType)
+ }
+ case (p: Product, structType: StructType) =>
+ new GenericRow(
+ p.productIterator.toSeq.zip(structType.fields).map { case (elem, field) =>
+ convertToCatalyst(elem, field.dataType)
+ }.toArray)
+ case (d: BigDecimal, _) => Decimal(d)
+ case (other, _) => other
}
/** Converts Catalyst types used internally in rows to standard Scala types */
- def convertToScala(a: Any): Any = a match {
- case s: Seq[_] => s.map(convertToScala)
- case m: Map[_, _] => m.map { case (k, v) => convertToScala(k) -> convertToScala(v) }
- case d: Decimal => d.toBigDecimal
- case other => other
+ def convertToScala(a: Any, dataType: DataType): Any = (a, dataType) match {
+ // Check UDT first since UDTs can override other types
+ case (d, udt: UserDefinedType[_]) => udt.deserialize(d)
+ case (s: Seq[_], arrayType: ArrayType) => s.map(convertToScala(_, arrayType.elementType))
+ case (m: Map[_, _], mapType: MapType) => m.map { case (k, v) =>
+ convertToScala(k, mapType.keyType) -> convertToScala(v, mapType.valueType)
+ }
+ case (r: Row, s: StructType) => convertRowToScala(r, s)
+ case (d: Decimal, _: DecimalType) => d.toBigDecimal
+ case (other, _) => other
}
- def convertRowToScala(r: Row): Row = new GenericRow(r.toArray.map(convertToScala))
+ def convertRowToScala(r: Row, schema: StructType): Row = {
+ new GenericRow(
+ r.zip(schema.fields.map(_.dataType))
+ .map(r_dt => convertToScala(r_dt._1, r_dt._2)).toArray)
+ }
/** Returns a Sequence of attributes for the given case class type. */
def attributesFor[T: TypeTag]: Seq[Attribute] = schemaFor[T] match {
@@ -65,52 +88,64 @@ object ScalaReflection {
def schemaFor[T: TypeTag]: Schema = schemaFor(typeOf[T])
/** Returns a catalyst DataType and its nullability for the given Scala Type using reflection. */
- def schemaFor(tpe: `Type`): Schema = tpe match {
- case t if t <:< typeOf[Option[_]] =>
- val TypeRef(_, _, Seq(optType)) = t
- Schema(schemaFor(optType).dataType, nullable = true)
- case t if t <:< typeOf[Product] =>
- val formalTypeArgs = t.typeSymbol.asClass.typeParams
- val TypeRef(_, _, actualTypeArgs) = t
- val params = t.member(nme.CONSTRUCTOR).asMethod.paramss
- Schema(StructType(
- params.head.map { p =>
- val Schema(dataType, nullable) =
- schemaFor(p.typeSignature.substituteTypes(formalTypeArgs, actualTypeArgs))
- StructField(p.name.toString, dataType, nullable)
- }), nullable = true)
- // Need to decide if we actually need a special type here.
- case t if t <:< typeOf[Array[Byte]] => Schema(BinaryType, nullable = true)
- case t if t <:< typeOf[Array[_]] =>
- sys.error(s"Only Array[Byte] supported now, use Seq instead of $t")
- case t if t <:< typeOf[Seq[_]] =>
- val TypeRef(_, _, Seq(elementType)) = t
- val Schema(dataType, nullable) = schemaFor(elementType)
- Schema(ArrayType(dataType, containsNull = nullable), nullable = true)
- case t if t <:< typeOf[Map[_,_]] =>
- val TypeRef(_, _, Seq(keyType, valueType)) = t
- val Schema(valueDataType, valueNullable) = schemaFor(valueType)
- Schema(MapType(schemaFor(keyType).dataType,
- valueDataType, valueContainsNull = valueNullable), nullable = true)
- case t if t <:< typeOf[String] => Schema(StringType, nullable = true)
- case t if t <:< typeOf[Timestamp] => Schema(TimestampType, nullable = true)
- case t if t <:< typeOf[Date] => Schema(DateType, nullable = true)
- case t if t <:< typeOf[BigDecimal] => Schema(DecimalType.Unlimited, nullable = true)
- case t if t <:< typeOf[Decimal] => Schema(DecimalType.Unlimited, nullable = true)
- case t if t <:< typeOf[java.lang.Integer] => Schema(IntegerType, nullable = true)
- case t if t <:< typeOf[java.lang.Long] => Schema(LongType, nullable = true)
- case t if t <:< typeOf[java.lang.Double] => Schema(DoubleType, nullable = true)
- case t if t <:< typeOf[java.lang.Float] => Schema(FloatType, nullable = true)
- case t if t <:< typeOf[java.lang.Short] => Schema(ShortType, nullable = true)
- case t if t <:< typeOf[java.lang.Byte] => Schema(ByteType, nullable = true)
- case t if t <:< typeOf[java.lang.Boolean] => Schema(BooleanType, nullable = true)
- case t if t <:< definitions.IntTpe => Schema(IntegerType, nullable = false)
- case t if t <:< definitions.LongTpe => Schema(LongType, nullable = false)
- case t if t <:< definitions.DoubleTpe => Schema(DoubleType, nullable = false)
- case t if t <:< definitions.FloatTpe => Schema(FloatType, nullable = false)
- case t if t <:< definitions.ShortTpe => Schema(ShortType, nullable = false)
- case t if t <:< definitions.ByteTpe => Schema(ByteType, nullable = false)
- case t if t <:< definitions.BooleanTpe => Schema(BooleanType, nullable = false)
+ def schemaFor(tpe: `Type`): Schema = {
+ val className: String = tpe.erasure.typeSymbol.asClass.fullName
+ tpe match {
+ case t if Utils.classIsLoadable(className) &&
+ Utils.classForName(className).isAnnotationPresent(classOf[SQLUserDefinedType]) =>
+ // Note: We check for classIsLoadable above since Utils.classForName uses Java reflection,
+ // whereas className is from Scala reflection. This can make it hard to find classes
+ // in some cases, such as when a class is enclosed in an object (in which case
+ // Java appends a '$' to the object name but Scala does not).
+ val udt = Utils.classForName(className)
+ .getAnnotation(classOf[SQLUserDefinedType]).udt().newInstance()
+ Schema(udt, nullable = true)
+ case t if t <:< typeOf[Option[_]] =>
+ val TypeRef(_, _, Seq(optType)) = t
+ Schema(schemaFor(optType).dataType, nullable = true)
+ case t if t <:< typeOf[Product] =>
+ val formalTypeArgs = t.typeSymbol.asClass.typeParams
+ val TypeRef(_, _, actualTypeArgs) = t
+ val params = t.member(nme.CONSTRUCTOR).asMethod.paramss
+ Schema(StructType(
+ params.head.map { p =>
+ val Schema(dataType, nullable) =
+ schemaFor(p.typeSignature.substituteTypes(formalTypeArgs, actualTypeArgs))
+ StructField(p.name.toString, dataType, nullable)
+ }), nullable = true)
+ // Need to decide if we actually need a special type here.
+ case t if t <:< typeOf[Array[Byte]] => Schema(BinaryType, nullable = true)
+ case t if t <:< typeOf[Array[_]] =>
+ sys.error(s"Only Array[Byte] supported now, use Seq instead of $t")
+ case t if t <:< typeOf[Seq[_]] =>
+ val TypeRef(_, _, Seq(elementType)) = t
+ val Schema(dataType, nullable) = schemaFor(elementType)
+ Schema(ArrayType(dataType, containsNull = nullable), nullable = true)
+ case t if t <:< typeOf[Map[_, _]] =>
+ val TypeRef(_, _, Seq(keyType, valueType)) = t
+ val Schema(valueDataType, valueNullable) = schemaFor(valueType)
+ Schema(MapType(schemaFor(keyType).dataType,
+ valueDataType, valueContainsNull = valueNullable), nullable = true)
+ case t if t <:< typeOf[String] => Schema(StringType, nullable = true)
+ case t if t <:< typeOf[Timestamp] => Schema(TimestampType, nullable = true)
+ case t if t <:< typeOf[Date] => Schema(DateType, nullable = true)
+ case t if t <:< typeOf[BigDecimal] => Schema(DecimalType.Unlimited, nullable = true)
+ case t if t <:< typeOf[Decimal] => Schema(DecimalType.Unlimited, nullable = true)
+ case t if t <:< typeOf[java.lang.Integer] => Schema(IntegerType, nullable = true)
+ case t if t <:< typeOf[java.lang.Long] => Schema(LongType, nullable = true)
+ case t if t <:< typeOf[java.lang.Double] => Schema(DoubleType, nullable = true)
+ case t if t <:< typeOf[java.lang.Float] => Schema(FloatType, nullable = true)
+ case t if t <:< typeOf[java.lang.Short] => Schema(ShortType, nullable = true)
+ case t if t <:< typeOf[java.lang.Byte] => Schema(ByteType, nullable = true)
+ case t if t <:< typeOf[java.lang.Boolean] => Schema(BooleanType, nullable = true)
+ case t if t <:< definitions.IntTpe => Schema(IntegerType, nullable = false)
+ case t if t <:< definitions.LongTpe => Schema(LongType, nullable = false)
+ case t if t <:< definitions.DoubleTpe => Schema(DoubleType, nullable = false)
+ case t if t <:< definitions.FloatTpe => Schema(FloatType, nullable = false)
+ case t if t <:< definitions.ShortTpe => Schema(ShortType, nullable = false)
+ case t if t <:< definitions.ByteTpe => Schema(ByteType, nullable = false)
+ case t if t <:< definitions.BooleanTpe => Schema(BooleanType, nullable = false)
+ }
}
def typeOfObject: PartialFunction[Any, DataType] = {
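A sketch of the new type-directed conversion, reusing the illustrative ExamplePoint/ExamplePointUDT from above (LabeledPoint is likewise a hypothetical name): schemaFor discovers the UDT through its annotation, and convertToCatalyst then dispatches on the (value, DataType) pair, which is why the doc comment insists schemaFor runs first.

    import org.apache.spark.sql.catalyst.ScalaReflection

    // Hypothetical case class with a UDT-typed field.
    case class LabeledPoint(label: Double, point: ExamplePoint)

    // schemaFor must run first: it is what finds the SQLUserDefinedType
    // annotation and places the UDT into the schema.
    val schema = ScalaReflection.schemaFor[LabeledPoint].dataType

    // The UDT case is matched before the Product case, so `point` is
    // serialized to Seq(0.5, 2.5) inside the resulting GenericRow.
    val row = ScalaReflection.convertToCatalyst(
      LabeledPoint(1.0, new ExamplePoint(0.5, 2.5)), schema)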
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/annotation/SQLUserDefinedType.java b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/annotation/SQLUserDefinedType.java
new file mode 100644
index 0000000000..e966aeea1c
--- /dev/null
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/annotation/SQLUserDefinedType.java
@@ -0,0 +1,46 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.catalyst.annotation;
+
+import java.lang.annotation.*;
+
+import org.apache.spark.annotation.DeveloperApi;
+import org.apache.spark.sql.catalyst.types.UserDefinedType;
+
+/**
+ * ::DeveloperApi::
+ * A user-defined type which can be automatically recognized by a SQLContext and registered.
+ *
+ * WARNING: This annotation will only work if both Java and Scala reflection return the same class
+ * names (after erasure) for the UDT. This will NOT be the case when, e.g., the UDT class
+ * is enclosed in an object (a singleton).
+ *
+ * WARNING: UDTs are currently only supported from Scala.
+ */
+// TODO: Should I use @Documented?
+@DeveloperApi
+@Retention(RetentionPolicy.RUNTIME)
+@Target(ElementType.TYPE)
+public @interface SQLUserDefinedType {
+
+ /**
+ * Returns an instance of the UserDefinedType which can serialize and deserialize the user
+ * class to and from Catalyst built-in types.
+ */
+ Class<? extends UserDefinedType<?> > udt();
+}
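The annotation is consumed through plain Java reflection; schemaFor does roughly the equivalent of the following (again using the illustrative ExamplePoint):

    // Look up the UDT instance declared by the annotation. Because the UDT
    // is created via newInstance(), it needs a public no-arg constructor.
    val udt = classOf[ExamplePoint]
      .getAnnotation(classOf[SQLUserDefinedType])
      .udt()
      .newInstance()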
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ScalaUdf.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ScalaUdf.scala
index 1b687a443e..fa1786e74b 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ScalaUdf.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ScalaUdf.scala
@@ -21,6 +21,10 @@ import org.apache.spark.sql.catalyst.ScalaReflection
import org.apache.spark.sql.catalyst.types.DataType
import org.apache.spark.util.ClosureCleaner
+/**
+ * User-defined function.
+ * @param dataType Return type of function.
+ */
case class ScalaUdf(function: AnyRef, dataType: DataType, children: Seq[Expression])
extends Expression {
@@ -347,6 +351,6 @@ case class ScalaUdf(function: AnyRef, dataType: DataType, children: Seq[Expression])
}
// scalastyle:on
- ScalaReflection.convertToCatalyst(result)
+ ScalaReflection.convertToCatalyst(result, dataType)
}
}
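Since eval's result is now converted against the declared return type rather than by runtime class alone, a UDF can in principle return a user class directly. A hedged sketch, with the hypothetical names from the earlier examples:

    import org.apache.spark.sql.catalyst.expressions.{Literal, ScalaUdf}

    // Hypothetical UDF whose declared return type is a UDT: the ExamplePoint
    // it produces is serialized via ExamplePointUDT inside ScalaUdf.eval.
    val makePoint = ScalaUdf(
      (x: Double, y: Double) => new ExamplePoint(x, y),
      new ExamplePointUDT,
      Seq(Literal(0.5), Literal(2.5)))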
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/types/dataTypes.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/types/dataTypes.scala
index d25f3a619d..cc5015ad3c 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/types/dataTypes.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/types/dataTypes.scala
@@ -29,11 +29,12 @@ import org.json4s.JsonAST.JValue
import org.json4s.JsonDSL._
import org.json4s.jackson.JsonMethods._
+import org.apache.spark.annotation.DeveloperApi
import org.apache.spark.sql.catalyst.ScalaReflectionLock
-import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeReference, Expression}
+import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeReference, Expression, Row}
+import org.apache.spark.sql.catalyst.types.decimal._
import org.apache.spark.sql.catalyst.util.Metadata
import org.apache.spark.util.Utils
-import org.apache.spark.sql.catalyst.types.decimal._
object DataType {
def fromJson(json: String): DataType = parseDataType(parse(json))
@@ -67,6 +68,11 @@ object DataType {
("fields", JArray(fields)),
("type", JString("struct"))) =>
StructType(fields.map(parseStructField))
+
+ case JSortedObject(
+ ("class", JString(udtClass)),
+ ("type", JString("udt"))) =>
+ Class.forName(udtClass).newInstance().asInstanceOf[UserDefinedType[_]]
}
private def parseStructField(json: JValue): StructField = json match {
@@ -342,6 +348,7 @@ object FractionalType {
case _ => false
}
}
+
abstract class FractionalType extends NumericType {
private[sql] val fractional: Fractional[JvmType]
private[sql] val asIntegral: Integral[JvmType]
@@ -565,3 +572,45 @@ case class MapType(
("valueType" -> valueType.jsonValue) ~
("valueContainsNull" -> valueContainsNull)
}
+
+/**
+ * ::DeveloperApi::
+ * The data type for User Defined Types (UDTs).
+ *
+ * This interface allows a user to make their own classes more interoperable with SparkSQL;
+ * e.g., by creating a [[UserDefinedType]] for a class X, it becomes possible to create
+ * a SchemaRDD which has class X in the schema.
+ *
+ * For SparkSQL to recognize UDTs, the UDT must be annotated with
+ * [[org.apache.spark.sql.catalyst.annotation.SQLUserDefinedType]].
+ *
+ * The conversion via `serialize` occurs when instantiating a `SchemaRDD` from another RDD.
+ * The conversion via `deserialize` occurs when reading from a `SchemaRDD`.
+ */
+@DeveloperApi
+abstract class UserDefinedType[UserType] extends DataType with Serializable {
+
+ /** Underlying storage type for this UDT */
+ def sqlType: DataType
+
+ /**
+ * Convert the user type to a SQL datum
+ *
+ * TODO: Can we make this take obj: UserType? The issue is in ScalaReflection.convertToCatalyst,
+ * where we need to convert Any to UserType.
+ */
+ def serialize(obj: Any): Any
+
+ /** Convert a SQL datum to the user type */
+ def deserialize(datum: Any): UserType
+
+ override private[sql] def jsonValue: JValue = {
+ ("type" -> "udt") ~
+ ("class" -> this.getClass.getName)
+ }
+
+ /**
+ * Class object for the UserType
+ */
+ def userClass: java.lang.Class[UserType]
+}
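For the JSON round trip, a UDT serializes as the name of its implementing class rather than as a structural description. A sketch, assuming the illustrative ExamplePointUDT lives in a package org.example:

    // jsonValue emits {"type":"udt","class":"org.example.ExamplePointUDT"};
    // fromJson reinstantiates it via Class.forName(...).newInstance(), so the
    // class must be on the classpath and have a no-arg constructor.
    val dt = DataType.fromJson(
      """{"class":"org.example.ExamplePointUDT","type":"udt"}""")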
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/ScalaReflectionSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/ScalaReflectionSuite.scala
index 21b2c8e20d..ddc3d44869 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/ScalaReflectionSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/ScalaReflectionSuite.scala
@@ -22,6 +22,7 @@ import java.sql.{Date, Timestamp}
import org.scalatest.FunSuite
+import org.apache.spark.sql.catalyst.expressions.Row
import org.apache.spark.sql.catalyst.types._
case class PrimitiveData(
@@ -239,13 +240,17 @@ class ScalaReflectionSuite extends FunSuite {
test("convert PrimitiveData to catalyst") {
val data = PrimitiveData(1, 1, 1, 1, 1, 1, true)
val convertedData = Seq(1, 1.toLong, 1.toDouble, 1.toFloat, 1.toShort, 1.toByte, true)
- assert(convertToCatalyst(data) === convertedData)
+ val dataType = schemaFor[PrimitiveData].dataType
+ assert(convertToCatalyst(data, dataType) === convertedData)
}
test("convert Option[Product] to catalyst") {
val primitiveData = PrimitiveData(1, 1, 1, 1, 1, 1, true)
- val data = OptionalData(Some(1), Some(1), Some(1), Some(1), Some(1), Some(1), Some(true), Some(primitiveData))
- val convertedData = Seq(1, 1.toLong, 1.toDouble, 1.toFloat, 1.toShort, 1.toByte, true, convertToCatalyst(primitiveData))
- assert(convertToCatalyst(data) === convertedData)
+ val data = OptionalData(Some(2), Some(2), Some(2), Some(2), Some(2), Some(2), Some(true),
+ Some(primitiveData))
+ val dataType = schemaFor[OptionalData].dataType
+ val convertedData = Row(2, 2.toLong, 2.toDouble, 2.toFloat, 2.toShort, 2.toByte, true,
+ Row(1, 1, 1, 1, 1, 1, true))
+ assert(convertToCatalyst(data, dataType) === convertedData)
}
}