aboutsummaryrefslogtreecommitdiff
path: root/sql
diff options
context:
space:
mode:
authorJoseph K. Bradley <joseph@databricks.com>2014-11-02 17:55:55 -0800
committerMichael Armbrust <michael@databricks.com>2014-11-02 17:56:00 -0800
commitebd6480587f96e9964d37157253523e0a179171a (patch)
tree221ceefd1a5febae327cd75810efdf0197d56005 /sql
parent2ebd1df3f17993f3cb472ec44c8832213976d99a (diff)
downloadspark-ebd6480587f96e9964d37157253523e0a179171a.tar.gz
spark-ebd6480587f96e9964d37157253523e0a179171a.tar.bz2
spark-ebd6480587f96e9964d37157253523e0a179171a.zip
[SPARK-3572] [SQL] Internal API for User-Defined Types
This PR adds User-Defined Types (UDTs) to SQL. It is a precursor to using SchemaRDD as a Dataset for the new MLlib API. Currently, the UDT API is private since there is incomplete support (e.g., no Java or Python support yet). Author: Joseph K. Bradley <joseph@databricks.com> Author: Michael Armbrust <michael@databricks.com> Author: Xiangrui Meng <meng@databricks.com> Closes #3063 from marmbrus/udts and squashes the following commits: 7ccfc0d [Michael Armbrust] remove println 46a3aee [Michael Armbrust] Slightly easier to read test output. 6cc434d [Michael Armbrust] Recursively convert rows. e369b91 [Michael Armbrust] Merge remote-tracking branch 'origin/master' into udts 15c10a6 [Joseph K. Bradley] Merge remote-tracking branch 'upstream/master' into sql-udt2 f3c72fe [Joseph K. Bradley] Fixing merge e13cd8a [Joseph K. Bradley] Removed Vector UDTs 5817b2b [Joseph K. Bradley] style edits 30ce5b2 [Joseph K. Bradley] updates based on code review d063380 [Joseph K. Bradley] Cleaned up Java UDT Suite, and added warning about element ordering when creating schema from Java Bean a571bb6 [Joseph K. Bradley] Removed old UDT code (registry and Java UDTs). Cleaned up other code. Extended JavaUserDefinedTypeSuite 6fddc1c [Joseph K. Bradley] Made MyLabeledPoint into a Java Bean 20630bc [Joseph K. Bradley] fixed scalastyle fa86b20 [Joseph K. Bradley] Removed Java UserDefinedType, and made UDTs private[spark] for now 8de957c [Joseph K. Bradley] Modified UserDefinedType to store Java class of user type so that registerUDT takes only the udt argument. 8b242ea [Joseph K. Bradley] Fixed merge error after last merge. Note: Last merge commit also removed SQL UDT examples from mllib. 7f29656 [Joseph K. Bradley] Moved udt case to top of all matches. Small cleanups b028675 [Xiangrui Meng] allow any type in UDT 4500d8a [Xiangrui Meng] update example code 87264a5 [Xiangrui Meng] remove debug code 3143ac3 [Xiangrui Meng] remove unnecessary changes cfbc321 [Xiangrui Meng] support UDT in parquet db16139 [Joseph K. Bradley] Added more doc for UserDefinedType. Removed unused code in Suite 759af7a [Joseph K. Bradley] Added more doc to UserDefineType 63626a4 [Joseph K. Bradley] Updated ScalaReflectionsSuite per @marmbrus suggestions 51e5282 [Joseph K. Bradley] fixed 1 test f025035 [Joseph K. Bradley] Cleanups before PR. Added new tests 85872f6 [Michael Armbrust] Allow schema calculation to be lazy, but ensure its available on executors. dff99d6 [Joseph K. Bradley] Added UDTs for Vectors in MLlib, plus DatasetExample using the UDTs cd60cb4 [Joseph K. Bradley] Trying to get other SQL tests to run 34a5831 [Joseph K. Bradley] Added MLlib dependency on SQL. e1f7b9c [Joseph K. Bradley] blah 2f40c02 [Joseph K. Bradley] renamed UDT types 3579035 [Joseph K. Bradley] udt annotation now working b226b9e [Joseph K. Bradley] Changing UDT to annotation fea04af [Joseph K. Bradley] more cleanups 964b32e [Joseph K. Bradley] some cleanups 893ee4c [Joseph K. Bradley] udt finallly working 50f9726 [Joseph K. Bradley] udts 04303c9 [Joseph K. Bradley] udts 39f8707 [Joseph K. Bradley] removed old udt suite 273ac96 [Joseph K. Bradley] basic UDT is working, but deserialization has yet to be done 8bebf24 [Joseph K. Bradley] commented out convertRowToScala for debugging 53de70f [Joseph K. Bradley] more udts... 982c035 [Joseph K. Bradley] still working on UDTs 19b2f60 [Joseph K. Bradley] still working on UDTs 0eaeb81 [Joseph K. Bradley] Still working on UDTs 105c5a3 [Joseph K. Bradley] Adding UserDefinedType to SQL, not done yet.
Diffstat (limited to 'sql')
-rw-r--r--sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala155
-rw-r--r--sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/annotation/SQLUserDefinedType.java46
-rw-r--r--sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ScalaUdf.scala6
-rw-r--r--sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/types/dataTypes.scala53
-rw-r--r--sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/ScalaReflectionSuite.scala13
-rw-r--r--sql/core/src/main/java/org/apache/spark/sql/api/java/UserDefinedType.java53
-rw-r--r--sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala6
-rw-r--r--sql/core/src/main/scala/org/apache/spark/sql/SchemaRDD.scala30
-rw-r--r--sql/core/src/main/scala/org/apache/spark/sql/SchemaRDDLike.scala2
-rw-r--r--sql/core/src/main/scala/org/apache/spark/sql/UdfRegistration.scala46
-rw-r--r--sql/core/src/main/scala/org/apache/spark/sql/api/java/JavaSQLContext.scala29
-rw-r--r--sql/core/src/main/scala/org/apache/spark/sql/api/java/UDTWrappers.scala75
-rw-r--r--sql/core/src/main/scala/org/apache/spark/sql/execution/ExistingRDD.scala11
-rw-r--r--sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlan.scala5
-rw-r--r--sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala3
-rw-r--r--sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala6
-rw-r--r--sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetConverter.scala13
-rw-r--r--sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetTableSupport.scala3
-rw-r--r--sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetTypes.scala3
-rw-r--r--sql/core/src/main/scala/org/apache/spark/sql/types/util/DataTypeConversions.scala22
-rw-r--r--sql/core/src/test/java/org/apache/spark/sql/api/java/JavaUserDefinedTypeSuite.java88
-rw-r--r--sql/core/src/test/scala/org/apache/spark/sql/UserDefinedTypeSuite.scala83
-rw-r--r--sql/core/src/test/scala/org/apache/spark/sql/json/JsonSuite.scala11
-rw-r--r--sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveContext.scala4
24 files changed, 620 insertions, 146 deletions
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala
index 8fbdf664b7..9cda373623 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala
@@ -19,6 +19,8 @@ package org.apache.spark.sql.catalyst
import java.sql.{Date, Timestamp}
+import org.apache.spark.util.Utils
+import org.apache.spark.sql.catalyst.annotation.SQLUserDefinedType
import org.apache.spark.sql.catalyst.expressions.{GenericRow, Attribute, AttributeReference, Row}
import org.apache.spark.sql.catalyst.plans.logical.LocalRelation
import org.apache.spark.sql.catalyst.types._
@@ -35,25 +37,46 @@ object ScalaReflection {
case class Schema(dataType: DataType, nullable: Boolean)
- /** Converts Scala objects to catalyst rows / types */
- def convertToCatalyst(a: Any): Any = a match {
- case o: Option[_] => o.map(convertToCatalyst).orNull
- case s: Seq[_] => s.map(convertToCatalyst)
- case m: Map[_, _] => m.map { case (k, v) => convertToCatalyst(k) -> convertToCatalyst(v) }
- case p: Product => new GenericRow(p.productIterator.map(convertToCatalyst).toArray)
- case d: BigDecimal => Decimal(d)
- case other => other
+ /**
+ * Converts Scala objects to catalyst rows / types.
+ * Note: This is always called after schemaFor has been called.
+ * This ordering is important for UDT registration.
+ */
+ def convertToCatalyst(a: Any, dataType: DataType): Any = (a, dataType) match {
+ // Check UDT first since UDTs can override other types
+ case (obj, udt: UserDefinedType[_]) => udt.serialize(obj)
+ case (o: Option[_], _) => o.map(convertToCatalyst(_, dataType)).orNull
+ case (s: Seq[_], arrayType: ArrayType) => s.map(convertToCatalyst(_, arrayType.elementType))
+ case (m: Map[_, _], mapType: MapType) => m.map { case (k, v) =>
+ convertToCatalyst(k, mapType.keyType) -> convertToCatalyst(v, mapType.valueType)
+ }
+ case (p: Product, structType: StructType) =>
+ new GenericRow(
+ p.productIterator.toSeq.zip(structType.fields).map { case (elem, field) =>
+ convertToCatalyst(elem, field.dataType)
+ }.toArray)
+ case (d: BigDecimal, _) => Decimal(d)
+ case (other, _) => other
}
/** Converts Catalyst types used internally in rows to standard Scala types */
- def convertToScala(a: Any): Any = a match {
- case s: Seq[_] => s.map(convertToScala)
- case m: Map[_, _] => m.map { case (k, v) => convertToScala(k) -> convertToScala(v) }
- case d: Decimal => d.toBigDecimal
- case other => other
+ def convertToScala(a: Any, dataType: DataType): Any = (a, dataType) match {
+ // Check UDT first since UDTs can override other types
+ case (d, udt: UserDefinedType[_]) => udt.deserialize(d)
+ case (s: Seq[_], arrayType: ArrayType) => s.map(convertToScala(_, arrayType.elementType))
+ case (m: Map[_, _], mapType: MapType) => m.map { case (k, v) =>
+ convertToScala(k, mapType.keyType) -> convertToScala(v, mapType.valueType)
+ }
+ case (r: Row, s: StructType) => convertRowToScala(r, s)
+ case (d: Decimal, _: DecimalType) => d.toBigDecimal
+ case (other, _) => other
}
- def convertRowToScala(r: Row): Row = new GenericRow(r.toArray.map(convertToScala))
+ def convertRowToScala(r: Row, schema: StructType): Row = {
+ new GenericRow(
+ r.zip(schema.fields.map(_.dataType))
+ .map(r_dt => convertToScala(r_dt._1, r_dt._2)).toArray)
+ }
/** Returns a Sequence of attributes for the given case class type. */
def attributesFor[T: TypeTag]: Seq[Attribute] = schemaFor[T] match {
@@ -65,52 +88,64 @@ object ScalaReflection {
def schemaFor[T: TypeTag]: Schema = schemaFor(typeOf[T])
/** Returns a catalyst DataType and its nullability for the given Scala Type using reflection. */
- def schemaFor(tpe: `Type`): Schema = tpe match {
- case t if t <:< typeOf[Option[_]] =>
- val TypeRef(_, _, Seq(optType)) = t
- Schema(schemaFor(optType).dataType, nullable = true)
- case t if t <:< typeOf[Product] =>
- val formalTypeArgs = t.typeSymbol.asClass.typeParams
- val TypeRef(_, _, actualTypeArgs) = t
- val params = t.member(nme.CONSTRUCTOR).asMethod.paramss
- Schema(StructType(
- params.head.map { p =>
- val Schema(dataType, nullable) =
- schemaFor(p.typeSignature.substituteTypes(formalTypeArgs, actualTypeArgs))
- StructField(p.name.toString, dataType, nullable)
- }), nullable = true)
- // Need to decide if we actually need a special type here.
- case t if t <:< typeOf[Array[Byte]] => Schema(BinaryType, nullable = true)
- case t if t <:< typeOf[Array[_]] =>
- sys.error(s"Only Array[Byte] supported now, use Seq instead of $t")
- case t if t <:< typeOf[Seq[_]] =>
- val TypeRef(_, _, Seq(elementType)) = t
- val Schema(dataType, nullable) = schemaFor(elementType)
- Schema(ArrayType(dataType, containsNull = nullable), nullable = true)
- case t if t <:< typeOf[Map[_,_]] =>
- val TypeRef(_, _, Seq(keyType, valueType)) = t
- val Schema(valueDataType, valueNullable) = schemaFor(valueType)
- Schema(MapType(schemaFor(keyType).dataType,
- valueDataType, valueContainsNull = valueNullable), nullable = true)
- case t if t <:< typeOf[String] => Schema(StringType, nullable = true)
- case t if t <:< typeOf[Timestamp] => Schema(TimestampType, nullable = true)
- case t if t <:< typeOf[Date] => Schema(DateType, nullable = true)
- case t if t <:< typeOf[BigDecimal] => Schema(DecimalType.Unlimited, nullable = true)
- case t if t <:< typeOf[Decimal] => Schema(DecimalType.Unlimited, nullable = true)
- case t if t <:< typeOf[java.lang.Integer] => Schema(IntegerType, nullable = true)
- case t if t <:< typeOf[java.lang.Long] => Schema(LongType, nullable = true)
- case t if t <:< typeOf[java.lang.Double] => Schema(DoubleType, nullable = true)
- case t if t <:< typeOf[java.lang.Float] => Schema(FloatType, nullable = true)
- case t if t <:< typeOf[java.lang.Short] => Schema(ShortType, nullable = true)
- case t if t <:< typeOf[java.lang.Byte] => Schema(ByteType, nullable = true)
- case t if t <:< typeOf[java.lang.Boolean] => Schema(BooleanType, nullable = true)
- case t if t <:< definitions.IntTpe => Schema(IntegerType, nullable = false)
- case t if t <:< definitions.LongTpe => Schema(LongType, nullable = false)
- case t if t <:< definitions.DoubleTpe => Schema(DoubleType, nullable = false)
- case t if t <:< definitions.FloatTpe => Schema(FloatType, nullable = false)
- case t if t <:< definitions.ShortTpe => Schema(ShortType, nullable = false)
- case t if t <:< definitions.ByteTpe => Schema(ByteType, nullable = false)
- case t if t <:< definitions.BooleanTpe => Schema(BooleanType, nullable = false)
+ def schemaFor(tpe: `Type`): Schema = {
+ val className: String = tpe.erasure.typeSymbol.asClass.fullName
+ tpe match {
+ case t if Utils.classIsLoadable(className) &&
+ Utils.classForName(className).isAnnotationPresent(classOf[SQLUserDefinedType]) =>
+ // Note: We check for classIsLoadable above since Utils.classForName uses Java reflection,
+ // whereas className is from Scala reflection. This can make it hard to find classes
+ // in some cases, such as when a class is enclosed in an object (in which case
+ // Java appends a '$' to the object name but Scala does not).
+ val udt = Utils.classForName(className)
+ .getAnnotation(classOf[SQLUserDefinedType]).udt().newInstance()
+ Schema(udt, nullable = true)
+ case t if t <:< typeOf[Option[_]] =>
+ val TypeRef(_, _, Seq(optType)) = t
+ Schema(schemaFor(optType).dataType, nullable = true)
+ case t if t <:< typeOf[Product] =>
+ val formalTypeArgs = t.typeSymbol.asClass.typeParams
+ val TypeRef(_, _, actualTypeArgs) = t
+ val params = t.member(nme.CONSTRUCTOR).asMethod.paramss
+ Schema(StructType(
+ params.head.map { p =>
+ val Schema(dataType, nullable) =
+ schemaFor(p.typeSignature.substituteTypes(formalTypeArgs, actualTypeArgs))
+ StructField(p.name.toString, dataType, nullable)
+ }), nullable = true)
+ // Need to decide if we actually need a special type here.
+ case t if t <:< typeOf[Array[Byte]] => Schema(BinaryType, nullable = true)
+ case t if t <:< typeOf[Array[_]] =>
+ sys.error(s"Only Array[Byte] supported now, use Seq instead of $t")
+ case t if t <:< typeOf[Seq[_]] =>
+ val TypeRef(_, _, Seq(elementType)) = t
+ val Schema(dataType, nullable) = schemaFor(elementType)
+ Schema(ArrayType(dataType, containsNull = nullable), nullable = true)
+ case t if t <:< typeOf[Map[_, _]] =>
+ val TypeRef(_, _, Seq(keyType, valueType)) = t
+ val Schema(valueDataType, valueNullable) = schemaFor(valueType)
+ Schema(MapType(schemaFor(keyType).dataType,
+ valueDataType, valueContainsNull = valueNullable), nullable = true)
+ case t if t <:< typeOf[String] => Schema(StringType, nullable = true)
+ case t if t <:< typeOf[Timestamp] => Schema(TimestampType, nullable = true)
+ case t if t <:< typeOf[Date] => Schema(DateType, nullable = true)
+ case t if t <:< typeOf[BigDecimal] => Schema(DecimalType.Unlimited, nullable = true)
+ case t if t <:< typeOf[Decimal] => Schema(DecimalType.Unlimited, nullable = true)
+ case t if t <:< typeOf[java.lang.Integer] => Schema(IntegerType, nullable = true)
+ case t if t <:< typeOf[java.lang.Long] => Schema(LongType, nullable = true)
+ case t if t <:< typeOf[java.lang.Double] => Schema(DoubleType, nullable = true)
+ case t if t <:< typeOf[java.lang.Float] => Schema(FloatType, nullable = true)
+ case t if t <:< typeOf[java.lang.Short] => Schema(ShortType, nullable = true)
+ case t if t <:< typeOf[java.lang.Byte] => Schema(ByteType, nullable = true)
+ case t if t <:< typeOf[java.lang.Boolean] => Schema(BooleanType, nullable = true)
+ case t if t <:< definitions.IntTpe => Schema(IntegerType, nullable = false)
+ case t if t <:< definitions.LongTpe => Schema(LongType, nullable = false)
+ case t if t <:< definitions.DoubleTpe => Schema(DoubleType, nullable = false)
+ case t if t <:< definitions.FloatTpe => Schema(FloatType, nullable = false)
+ case t if t <:< definitions.ShortTpe => Schema(ShortType, nullable = false)
+ case t if t <:< definitions.ByteTpe => Schema(ByteType, nullable = false)
+ case t if t <:< definitions.BooleanTpe => Schema(BooleanType, nullable = false)
+ }
}
def typeOfObject: PartialFunction[Any, DataType] = {
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/annotation/SQLUserDefinedType.java b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/annotation/SQLUserDefinedType.java
new file mode 100644
index 0000000000..e966aeea1c
--- /dev/null
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/annotation/SQLUserDefinedType.java
@@ -0,0 +1,46 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.catalyst.annotation;
+
+import java.lang.annotation.*;
+
+import org.apache.spark.annotation.DeveloperApi;
+import org.apache.spark.sql.catalyst.types.UserDefinedType;
+
+/**
+ * ::DeveloperApi::
+ * A user-defined type which can be automatically recognized by a SQLContext and registered.
+ *
+ * WARNING: This annotation will only work if both Java and Scala reflection return the same class
+ * names (after erasure) for the UDT. This will NOT be the case when, e.g., the UDT class
+ * is enclosed in an object (a singleton).
+ *
+ * WARNING: UDTs are currently only supported from Scala.
+ */
+// TODO: Should I used @Documented ?
+@DeveloperApi
+@Retention(RetentionPolicy.RUNTIME)
+@Target(ElementType.TYPE)
+public @interface SQLUserDefinedType {
+
+ /**
+ * Returns an instance of the UserDefinedType which can serialize and deserialize the user
+ * class to and from Catalyst built-in types.
+ */
+ Class<? extends UserDefinedType<?> > udt();
+}
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ScalaUdf.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ScalaUdf.scala
index 1b687a443e..fa1786e74b 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ScalaUdf.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ScalaUdf.scala
@@ -21,6 +21,10 @@ import org.apache.spark.sql.catalyst.ScalaReflection
import org.apache.spark.sql.catalyst.types.DataType
import org.apache.spark.util.ClosureCleaner
+/**
+ * User-defined function.
+ * @param dataType Return type of function.
+ */
case class ScalaUdf(function: AnyRef, dataType: DataType, children: Seq[Expression])
extends Expression {
@@ -347,6 +351,6 @@ case class ScalaUdf(function: AnyRef, dataType: DataType, children: Seq[Expressi
}
// scalastyle:on
- ScalaReflection.convertToCatalyst(result)
+ ScalaReflection.convertToCatalyst(result, dataType)
}
}
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/types/dataTypes.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/types/dataTypes.scala
index d25f3a619d..cc5015ad3c 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/types/dataTypes.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/types/dataTypes.scala
@@ -29,11 +29,12 @@ import org.json4s.JsonAST.JValue
import org.json4s.JsonDSL._
import org.json4s.jackson.JsonMethods._
+import org.apache.spark.annotation.DeveloperApi
import org.apache.spark.sql.catalyst.ScalaReflectionLock
-import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeReference, Expression}
+import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeReference, Expression, Row}
+import org.apache.spark.sql.catalyst.types.decimal._
import org.apache.spark.sql.catalyst.util.Metadata
import org.apache.spark.util.Utils
-import org.apache.spark.sql.catalyst.types.decimal._
object DataType {
def fromJson(json: String): DataType = parseDataType(parse(json))
@@ -67,6 +68,11 @@ object DataType {
("fields", JArray(fields)),
("type", JString("struct"))) =>
StructType(fields.map(parseStructField))
+
+ case JSortedObject(
+ ("class", JString(udtClass)),
+ ("type", JString("udt"))) =>
+ Class.forName(udtClass).newInstance().asInstanceOf[UserDefinedType[_]]
}
private def parseStructField(json: JValue): StructField = json match {
@@ -342,6 +348,7 @@ object FractionalType {
case _ => false
}
}
+
abstract class FractionalType extends NumericType {
private[sql] val fractional: Fractional[JvmType]
private[sql] val asIntegral: Integral[JvmType]
@@ -565,3 +572,45 @@ case class MapType(
("valueType" -> valueType.jsonValue) ~
("valueContainsNull" -> valueContainsNull)
}
+
+/**
+ * ::DeveloperApi::
+ * The data type for User Defined Types (UDTs).
+ *
+ * This interface allows a user to make their own classes more interoperable with SparkSQL;
+ * e.g., by creating a [[UserDefinedType]] for a class X, it becomes possible to create
+ * a SchemaRDD which has class X in the schema.
+ *
+ * For SparkSQL to recognize UDTs, the UDT must be annotated with
+ * [[org.apache.spark.sql.catalyst.annotation.SQLUserDefinedType]].
+ *
+ * The conversion via `serialize` occurs when instantiating a `SchemaRDD` from another RDD.
+ * The conversion via `deserialize` occurs when reading from a `SchemaRDD`.
+ */
+@DeveloperApi
+abstract class UserDefinedType[UserType] extends DataType with Serializable {
+
+ /** Underlying storage type for this UDT */
+ def sqlType: DataType
+
+ /**
+ * Convert the user type to a SQL datum
+ *
+ * TODO: Can we make this take obj: UserType? The issue is in ScalaReflection.convertToCatalyst,
+ * where we need to convert Any to UserType.
+ */
+ def serialize(obj: Any): Any
+
+ /** Convert a SQL datum to the user type */
+ def deserialize(datum: Any): UserType
+
+ override private[sql] def jsonValue: JValue = {
+ ("type" -> "udt") ~
+ ("class" -> this.getClass.getName)
+ }
+
+ /**
+ * Class object for the UserType
+ */
+ def userClass: java.lang.Class[UserType]
+}
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/ScalaReflectionSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/ScalaReflectionSuite.scala
index 21b2c8e20d..ddc3d44869 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/ScalaReflectionSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/ScalaReflectionSuite.scala
@@ -22,6 +22,7 @@ import java.sql.{Date, Timestamp}
import org.scalatest.FunSuite
+import org.apache.spark.sql.catalyst.expressions.Row
import org.apache.spark.sql.catalyst.types._
case class PrimitiveData(
@@ -239,13 +240,17 @@ class ScalaReflectionSuite extends FunSuite {
test("convert PrimitiveData to catalyst") {
val data = PrimitiveData(1, 1, 1, 1, 1, 1, true)
val convertedData = Seq(1, 1.toLong, 1.toDouble, 1.toFloat, 1.toShort, 1.toByte, true)
- assert(convertToCatalyst(data) === convertedData)
+ val dataType = schemaFor[PrimitiveData].dataType
+ assert(convertToCatalyst(data, dataType) === convertedData)
}
test("convert Option[Product] to catalyst") {
val primitiveData = PrimitiveData(1, 1, 1, 1, 1, 1, true)
- val data = OptionalData(Some(1), Some(1), Some(1), Some(1), Some(1), Some(1), Some(true), Some(primitiveData))
- val convertedData = Seq(1, 1.toLong, 1.toDouble, 1.toFloat, 1.toShort, 1.toByte, true, convertToCatalyst(primitiveData))
- assert(convertToCatalyst(data) === convertedData)
+ val data = OptionalData(Some(2), Some(2), Some(2), Some(2), Some(2), Some(2), Some(true),
+ Some(primitiveData))
+ val dataType = schemaFor[OptionalData].dataType
+ val convertedData = Row(2, 2.toLong, 2.toDouble, 2.toFloat, 2.toShort, 2.toByte, true,
+ Row(1, 1, 1, 1, 1, 1, true))
+ assert(convertToCatalyst(data, dataType) === convertedData)
}
}
diff --git a/sql/core/src/main/java/org/apache/spark/sql/api/java/UserDefinedType.java b/sql/core/src/main/java/org/apache/spark/sql/api/java/UserDefinedType.java
new file mode 100644
index 0000000000..b751847b46
--- /dev/null
+++ b/sql/core/src/main/java/org/apache/spark/sql/api/java/UserDefinedType.java
@@ -0,0 +1,53 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.api.java;
+
+import java.io.Serializable;
+
+import org.apache.spark.annotation.DeveloperApi;
+
+/**
+ * ::DeveloperApi::
+ * The data type representing User-Defined Types (UDTs).
+ * UDTs may use any other DataType for an underlying representation.
+ */
+@DeveloperApi
+public abstract class UserDefinedType<UserType> extends DataType implements Serializable {
+
+ protected UserDefinedType() { }
+
+ @Override
+ public boolean equals(Object o) {
+ if (this == o) return true;
+ if (o == null || getClass() != o.getClass()) return false;
+ UserDefinedType<UserType> that = (UserDefinedType<UserType>) o;
+ return this.sqlType().equals(that.sqlType());
+ }
+
+ /** Underlying storage type for this UDT */
+ public abstract DataType sqlType();
+
+ /** Convert the user type to a SQL datum */
+ public abstract Object serialize(Object obj);
+
+ /** Convert a SQL datum to the user type */
+ public abstract UserType deserialize(Object datum);
+
+ /** Class object for the UserType */
+ public abstract Class<UserType> userClass();
+}
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala b/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala
index 3cf6af5f7a..9e61d18f7e 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala
@@ -107,8 +107,10 @@ class SQLContext(@transient val sparkContext: SparkContext)
*/
implicit def createSchemaRDD[A <: Product: TypeTag](rdd: RDD[A]) = {
SparkPlan.currentContext.set(self)
- new SchemaRDD(this,
- LogicalRDD(ScalaReflection.attributesFor[A], RDDConversions.productToRowRdd(rdd))(self))
+ val attributeSeq = ScalaReflection.attributesFor[A]
+ val schema = StructType.fromAttributes(attributeSeq)
+ val rowRDD = RDDConversions.productToRowRdd(rdd, schema)
+ new SchemaRDD(this, LogicalRDD(attributeSeq, rowRDD)(self))
}
implicit def baseRelationToSchemaRDD(baseRelation: BaseRelation): SchemaRDD = {
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/SchemaRDD.scala b/sql/core/src/main/scala/org/apache/spark/sql/SchemaRDD.scala
index 018a18c4ac..3ee2ea05cf 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/SchemaRDD.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/SchemaRDD.scala
@@ -17,26 +17,24 @@
package org.apache.spark.sql
-import java.util.{Map => JMap, List => JList}
-
-import org.apache.spark.sql.catalyst.ScalaReflection
-import org.apache.spark.storage.StorageLevel
+import java.util.{List => JList}
import scala.collection.JavaConversions._
-import scala.collection.JavaConverters._
import net.razorvine.pickle.Pickler
import org.apache.spark.{Dependency, OneToOneDependency, Partition, Partitioner, TaskContext}
import org.apache.spark.annotation.{AlphaComponent, Experimental}
+import org.apache.spark.api.java.JavaRDD
import org.apache.spark.rdd.RDD
import org.apache.spark.sql.api.java.JavaSchemaRDD
+import org.apache.spark.sql.catalyst.ScalaReflection
import org.apache.spark.sql.catalyst.analysis._
import org.apache.spark.sql.catalyst.expressions._
-import org.apache.spark.sql.catalyst.plans.logical._
import org.apache.spark.sql.catalyst.plans.{Inner, JoinType}
+import org.apache.spark.sql.catalyst.plans.logical._
import org.apache.spark.sql.execution.{LogicalRDD, EvaluatePython}
-import org.apache.spark.api.java.JavaRDD
+import org.apache.spark.storage.StorageLevel
/**
* :: AlphaComponent ::
@@ -114,18 +112,22 @@ class SchemaRDD(
// =========================================================================================
override def compute(split: Partition, context: TaskContext): Iterator[Row] =
- firstParent[Row].compute(split, context).map(ScalaReflection.convertRowToScala)
+ firstParent[Row].compute(split, context).map(ScalaReflection.convertRowToScala(_, this.schema))
override def getPartitions: Array[Partition] = firstParent[Row].partitions
- override protected def getDependencies: Seq[Dependency[_]] =
+ override protected def getDependencies: Seq[Dependency[_]] = {
+ schema // Force reification of the schema so it is available on executors.
+
List(new OneToOneDependency(queryExecution.toRdd))
+ }
- /** Returns the schema of this SchemaRDD (represented by a [[StructType]]).
- *
- * @group schema
- */
- def schema: StructType = queryExecution.analyzed.schema
+ /**
+ * Returns the schema of this SchemaRDD (represented by a [[StructType]]).
+ *
+ * @group schema
+ */
+ lazy val schema: StructType = queryExecution.analyzed.schema
// =======================================================================
// Query DSL
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/SchemaRDDLike.scala b/sql/core/src/main/scala/org/apache/spark/sql/SchemaRDDLike.scala
index 15516afb95..fd5f4abcbc 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/SchemaRDDLike.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/SchemaRDDLike.scala
@@ -26,7 +26,7 @@ import org.apache.spark.sql.execution.LogicalRDD
* Contains functions that are shared between all SchemaRDD types (i.e., Scala, Java)
*/
private[sql] trait SchemaRDDLike {
- @transient val sqlContext: SQLContext
+ @transient def sqlContext: SQLContext
@transient val baseLogicalPlan: LogicalPlan
private[sql] def baseSchemaRDD: SchemaRDD
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/UdfRegistration.scala b/sql/core/src/main/scala/org/apache/spark/sql/UdfRegistration.scala
index 595b4aa36e..6d4c0d82ac 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/UdfRegistration.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/UdfRegistration.scala
@@ -78,7 +78,7 @@ private[sql] trait UDFRegistration {
s"""
def registerFunction[T: TypeTag](name: String, func: Function$x[$types, T]): Unit = {
def builder(e: Seq[Expression]) =
- ScalaUdf(func, ScalaReflection.schemaFor(typeTag[T]).dataType, e)
+ ScalaUdf(func, ScalaReflection.schemaFor[T].dataType, e)
functionRegistry.registerFunction(name, builder)
}
"""
@@ -87,112 +87,112 @@ private[sql] trait UDFRegistration {
// scalastyle:off
def registerFunction[T: TypeTag](name: String, func: Function1[_, T]): Unit = {
- def builder(e: Seq[Expression]) = ScalaUdf(func, ScalaReflection.schemaFor(typeTag[T]).dataType, e)
+ def builder(e: Seq[Expression]) = ScalaUdf(func, ScalaReflection.schemaFor[T].dataType, e)
functionRegistry.registerFunction(name, builder)
}
def registerFunction[T: TypeTag](name: String, func: Function2[_, _, T]): Unit = {
- def builder(e: Seq[Expression]) = ScalaUdf(func, ScalaReflection.schemaFor(typeTag[T]).dataType, e)
+ def builder(e: Seq[Expression]) = ScalaUdf(func, ScalaReflection.schemaFor[T].dataType, e)
functionRegistry.registerFunction(name, builder)
}
def registerFunction[T: TypeTag](name: String, func: Function3[_, _, _, T]): Unit = {
- def builder(e: Seq[Expression]) = ScalaUdf(func, ScalaReflection.schemaFor(typeTag[T]).dataType, e)
+ def builder(e: Seq[Expression]) = ScalaUdf(func, ScalaReflection.schemaFor[T].dataType, e)
functionRegistry.registerFunction(name, builder)
}
def registerFunction[T: TypeTag](name: String, func: Function4[_, _, _, _, T]): Unit = {
- def builder(e: Seq[Expression]) = ScalaUdf(func, ScalaReflection.schemaFor(typeTag[T]).dataType, e)
+ def builder(e: Seq[Expression]) = ScalaUdf(func, ScalaReflection.schemaFor[T].dataType, e)
functionRegistry.registerFunction(name, builder)
}
def registerFunction[T: TypeTag](name: String, func: Function5[_, _, _, _, _, T]): Unit = {
- def builder(e: Seq[Expression]) = ScalaUdf(func, ScalaReflection.schemaFor(typeTag[T]).dataType, e)
+ def builder(e: Seq[Expression]) = ScalaUdf(func, ScalaReflection.schemaFor[T].dataType, e)
functionRegistry.registerFunction(name, builder)
}
def registerFunction[T: TypeTag](name: String, func: Function6[_, _, _, _, _, _, T]): Unit = {
- def builder(e: Seq[Expression]) = ScalaUdf(func, ScalaReflection.schemaFor(typeTag[T]).dataType, e)
+ def builder(e: Seq[Expression]) = ScalaUdf(func, ScalaReflection.schemaFor[T].dataType, e)
functionRegistry.registerFunction(name, builder)
}
def registerFunction[T: TypeTag](name: String, func: Function7[_, _, _, _, _, _, _, T]): Unit = {
- def builder(e: Seq[Expression]) = ScalaUdf(func, ScalaReflection.schemaFor(typeTag[T]).dataType, e)
+ def builder(e: Seq[Expression]) = ScalaUdf(func, ScalaReflection.schemaFor[T].dataType, e)
functionRegistry.registerFunction(name, builder)
}
def registerFunction[T: TypeTag](name: String, func: Function8[_, _, _, _, _, _, _, _, T]): Unit = {
- def builder(e: Seq[Expression]) = ScalaUdf(func, ScalaReflection.schemaFor(typeTag[T]).dataType, e)
+ def builder(e: Seq[Expression]) = ScalaUdf(func, ScalaReflection.schemaFor[T].dataType, e)
functionRegistry.registerFunction(name, builder)
}
def registerFunction[T: TypeTag](name: String, func: Function9[_, _, _, _, _, _, _, _, _, T]): Unit = {
- def builder(e: Seq[Expression]) = ScalaUdf(func, ScalaReflection.schemaFor(typeTag[T]).dataType, e)
+ def builder(e: Seq[Expression]) = ScalaUdf(func, ScalaReflection.schemaFor[T].dataType, e)
functionRegistry.registerFunction(name, builder)
}
def registerFunction[T: TypeTag](name: String, func: Function10[_, _, _, _, _, _, _, _, _, _, T]): Unit = {
- def builder(e: Seq[Expression]) = ScalaUdf(func, ScalaReflection.schemaFor(typeTag[T]).dataType, e)
+ def builder(e: Seq[Expression]) = ScalaUdf(func, ScalaReflection.schemaFor[T].dataType, e)
functionRegistry.registerFunction(name, builder)
}
def registerFunction[T: TypeTag](name: String, func: Function11[_, _, _, _, _, _, _, _, _, _, _, T]): Unit = {
- def builder(e: Seq[Expression]) = ScalaUdf(func, ScalaReflection.schemaFor(typeTag[T]).dataType, e)
+ def builder(e: Seq[Expression]) = ScalaUdf(func, ScalaReflection.schemaFor[T].dataType, e)
functionRegistry.registerFunction(name, builder)
}
def registerFunction[T: TypeTag](name: String, func: Function12[_, _, _, _, _, _, _, _, _, _, _, _, T]): Unit = {
- def builder(e: Seq[Expression]) = ScalaUdf(func, ScalaReflection.schemaFor(typeTag[T]).dataType, e)
+ def builder(e: Seq[Expression]) = ScalaUdf(func, ScalaReflection.schemaFor[T].dataType, e)
functionRegistry.registerFunction(name, builder)
}
def registerFunction[T: TypeTag](name: String, func: Function13[_, _, _, _, _, _, _, _, _, _, _, _, _, T]): Unit = {
- def builder(e: Seq[Expression]) = ScalaUdf(func, ScalaReflection.schemaFor(typeTag[T]).dataType, e)
+ def builder(e: Seq[Expression]) = ScalaUdf(func, ScalaReflection.schemaFor[T].dataType, e)
functionRegistry.registerFunction(name, builder)
}
def registerFunction[T: TypeTag](name: String, func: Function14[_, _, _, _, _, _, _, _, _, _, _, _, _, _, T]): Unit = {
- def builder(e: Seq[Expression]) = ScalaUdf(func, ScalaReflection.schemaFor(typeTag[T]).dataType, e)
+ def builder(e: Seq[Expression]) = ScalaUdf(func, ScalaReflection.schemaFor[T].dataType, e)
functionRegistry.registerFunction(name, builder)
}
def registerFunction[T: TypeTag](name: String, func: Function15[_, _, _, _, _, _, _, _, _, _, _, _, _, _, _, T]): Unit = {
- def builder(e: Seq[Expression]) = ScalaUdf(func, ScalaReflection.schemaFor(typeTag[T]).dataType, e)
+ def builder(e: Seq[Expression]) = ScalaUdf(func, ScalaReflection.schemaFor[T].dataType, e)
functionRegistry.registerFunction(name, builder)
}
def registerFunction[T: TypeTag](name: String, func: Function16[_, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, T]): Unit = {
- def builder(e: Seq[Expression]) = ScalaUdf(func, ScalaReflection.schemaFor(typeTag[T]).dataType, e)
+ def builder(e: Seq[Expression]) = ScalaUdf(func, ScalaReflection.schemaFor[T].dataType, e)
functionRegistry.registerFunction(name, builder)
}
def registerFunction[T: TypeTag](name: String, func: Function17[_, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, T]): Unit = {
- def builder(e: Seq[Expression]) = ScalaUdf(func, ScalaReflection.schemaFor(typeTag[T]).dataType, e)
+ def builder(e: Seq[Expression]) = ScalaUdf(func, ScalaReflection.schemaFor[T].dataType, e)
functionRegistry.registerFunction(name, builder)
}
def registerFunction[T: TypeTag](name: String, func: Function18[_, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, T]): Unit = {
- def builder(e: Seq[Expression]) = ScalaUdf(func, ScalaReflection.schemaFor(typeTag[T]).dataType, e)
+ def builder(e: Seq[Expression]) = ScalaUdf(func, ScalaReflection.schemaFor[T].dataType, e)
functionRegistry.registerFunction(name, builder)
}
def registerFunction[T: TypeTag](name: String, func: Function19[_, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, T]): Unit = {
- def builder(e: Seq[Expression]) = ScalaUdf(func, ScalaReflection.schemaFor(typeTag[T]).dataType, e)
+ def builder(e: Seq[Expression]) = ScalaUdf(func, ScalaReflection.schemaFor[T].dataType, e)
functionRegistry.registerFunction(name, builder)
}
def registerFunction[T: TypeTag](name: String, func: Function20[_, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, T]): Unit = {
- def builder(e: Seq[Expression]) = ScalaUdf(func, ScalaReflection.schemaFor(typeTag[T]).dataType, e)
+ def builder(e: Seq[Expression]) = ScalaUdf(func, ScalaReflection.schemaFor[T].dataType, e)
functionRegistry.registerFunction(name, builder)
}
def registerFunction[T: TypeTag](name: String, func: Function21[_, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, T]): Unit = {
- def builder(e: Seq[Expression]) = ScalaUdf(func, ScalaReflection.schemaFor(typeTag[T]).dataType, e)
+ def builder(e: Seq[Expression]) = ScalaUdf(func, ScalaReflection.schemaFor[T].dataType, e)
functionRegistry.registerFunction(name, builder)
}
def registerFunction[T: TypeTag](name: String, func: Function22[_, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, T]): Unit = {
- def builder(e: Seq[Expression]) = ScalaUdf(func, ScalaReflection.schemaFor(typeTag[T]).dataType, e)
+ def builder(e: Seq[Expression]) = ScalaUdf(func, ScalaReflection.schemaFor[T].dataType, e)
functionRegistry.registerFunction(name, builder)
}
// scalastyle:on
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/api/java/JavaSQLContext.scala b/sql/core/src/main/scala/org/apache/spark/sql/api/java/JavaSQLContext.scala
index 60065509bf..4c0869e05b 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/api/java/JavaSQLContext.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/api/java/JavaSQLContext.scala
@@ -23,13 +23,14 @@ import org.apache.hadoop.conf.Configuration
import org.apache.spark.annotation.{DeveloperApi, Experimental}
import org.apache.spark.api.java.{JavaRDD, JavaSparkContext}
-import org.apache.spark.sql.json.JsonRDD
-import org.apache.spark.sql.sources.{LogicalRelation, BaseRelation}
-import org.apache.spark.sql.types.util.DataTypeConversions
import org.apache.spark.sql.{SQLContext, StructType => SStructType}
+import org.apache.spark.sql.catalyst.annotation.SQLUserDefinedType
import org.apache.spark.sql.catalyst.expressions.{AttributeReference, GenericRow, Row => ScalaRow}
-import org.apache.spark.sql.parquet.ParquetRelation
import org.apache.spark.sql.execution.LogicalRDD
+import org.apache.spark.sql.json.JsonRDD
+import org.apache.spark.sql.parquet.ParquetRelation
+import org.apache.spark.sql.sources.{LogicalRelation, BaseRelation}
+import org.apache.spark.sql.types.util.DataTypeConversions
import org.apache.spark.sql.types.util.DataTypeConversions.asScalaDataType
import org.apache.spark.util.Utils
@@ -91,9 +92,12 @@ class JavaSQLContext(val sqlContext: SQLContext) extends UDFRegistration {
/**
* Applies a schema to an RDD of Java Beans.
+ *
+ * WARNING: Since there is no guaranteed ordering for fields in a Java Bean,
+ * SELECT * queries will return the columns in an undefined order.
*/
def applySchema(rdd: JavaRDD[_], beanClass: Class[_]): JavaSchemaRDD = {
- val schema = getSchema(beanClass)
+ val attributeSeq = getSchema(beanClass)
val className = beanClass.getName
val rowRdd = rdd.rdd.mapPartitions { iter =>
// BeanInfo is not serializable so we must rediscover it remotely for each partition.
@@ -104,11 +108,13 @@ class JavaSQLContext(val sqlContext: SQLContext) extends UDFRegistration {
iter.map { row =>
new GenericRow(
- extractors.map(e => DataTypeConversions.convertJavaToCatalyst(e.invoke(row))).toArray[Any]
+ extractors.zip(attributeSeq).map { case (e, attr) =>
+ DataTypeConversions.convertJavaToCatalyst(e.invoke(row), attr.dataType)
+ }.toArray[Any]
): ScalaRow
}
}
- new JavaSchemaRDD(sqlContext, LogicalRDD(schema, rowRdd)(sqlContext))
+ new JavaSchemaRDD(sqlContext, LogicalRDD(attributeSeq, rowRdd)(sqlContext))
}
/**
@@ -195,14 +201,21 @@ class JavaSQLContext(val sqlContext: SQLContext) extends UDFRegistration {
sqlContext.registerRDDAsTable(rdd.baseSchemaRDD, tableName)
}
- /** Returns a Catalyst Schema for the given java bean class. */
+ /**
+ * Returns a Catalyst Schema for the given java bean class.
+ */
protected def getSchema(beanClass: Class[_]): Seq[AttributeReference] = {
// TODO: All of this could probably be moved to Catalyst as it is mostly not Spark specific.
val beanInfo = Introspector.getBeanInfo(beanClass)
+ // Note: The ordering of elements may differ from when the schema is inferred in Scala.
+ // This is because beanInfo.getPropertyDescriptors gives no guarantees about
+ // element ordering.
val fields = beanInfo.getPropertyDescriptors.filterNot(_.getName == "class")
fields.map { property =>
val (dataType, nullable) = property.getPropertyType match {
+ case c: Class[_] if c.isAnnotationPresent(classOf[SQLUserDefinedType]) =>
+ (c.getAnnotation(classOf[SQLUserDefinedType]).udt().newInstance(), true)
case c: Class[_] if c == classOf[java.lang.String] =>
(org.apache.spark.sql.StringType, true)
case c: Class[_] if c == java.lang.Short.TYPE =>
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/api/java/UDTWrappers.scala b/sql/core/src/main/scala/org/apache/spark/sql/api/java/UDTWrappers.scala
new file mode 100644
index 0000000000..a7d0f4f127
--- /dev/null
+++ b/sql/core/src/main/scala/org/apache/spark/sql/api/java/UDTWrappers.scala
@@ -0,0 +1,75 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.api.java
+
+import org.apache.spark.sql.catalyst.types.{UserDefinedType => ScalaUserDefinedType}
+import org.apache.spark.sql.{DataType => ScalaDataType}
+import org.apache.spark.sql.types.util.DataTypeConversions
+
+/**
+ * Scala wrapper for a Java UserDefinedType
+ */
+private[sql] class JavaToScalaUDTWrapper[UserType](val javaUDT: UserDefinedType[UserType])
+ extends ScalaUserDefinedType[UserType] with Serializable {
+
+ /** Underlying storage type for this UDT */
+ val sqlType: ScalaDataType = DataTypeConversions.asScalaDataType(javaUDT.sqlType())
+
+ /** Convert the user type to a SQL datum */
+ def serialize(obj: Any): Any = javaUDT.serialize(obj)
+
+ /** Convert a SQL datum to the user type */
+ def deserialize(datum: Any): UserType = javaUDT.deserialize(datum)
+
+ val userClass: java.lang.Class[UserType] = javaUDT.userClass()
+}
+
+/**
+ * Java wrapper for a Scala UserDefinedType
+ */
+private[sql] class ScalaToJavaUDTWrapper[UserType](val scalaUDT: ScalaUserDefinedType[UserType])
+ extends UserDefinedType[UserType] with Serializable {
+
+ /** Underlying storage type for this UDT */
+ val sqlType: DataType = DataTypeConversions.asJavaDataType(scalaUDT.sqlType)
+
+ /** Convert the user type to a SQL datum */
+ def serialize(obj: Any): java.lang.Object = scalaUDT.serialize(obj).asInstanceOf[java.lang.Object]
+
+ /** Convert a SQL datum to the user type */
+ def deserialize(datum: Any): UserType = scalaUDT.deserialize(datum)
+
+ val userClass: java.lang.Class[UserType] = scalaUDT.userClass
+}
+
+private[sql] object UDTWrappers {
+
+ def wrapAsScala(udtType: UserDefinedType[_]): ScalaUserDefinedType[_] = {
+ udtType match {
+ case t: ScalaToJavaUDTWrapper[_] => t.scalaUDT
+ case _ => new JavaToScalaUDTWrapper(udtType)
+ }
+ }
+
+ def wrapAsJava(udtType: ScalaUserDefinedType[_]): UserDefinedType[_] = {
+ udtType match {
+ case t: JavaToScalaUDTWrapper[_] => t.javaUDT
+ case _ => new ScalaToJavaUDTWrapper(udtType)
+ }
+ }
+}
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/ExistingRDD.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/ExistingRDD.scala
index d64c5af89e..ed6b95dc6d 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/ExistingRDD.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/ExistingRDD.scala
@@ -19,29 +19,32 @@ package org.apache.spark.sql.execution
import org.apache.spark.annotation.DeveloperApi
import org.apache.spark.rdd.RDD
+import org.apache.spark.sql.{DataType, StructType, Row, SQLContext}
import org.apache.spark.sql.catalyst.ScalaReflection
+import org.apache.spark.sql.catalyst.ScalaReflection.Schema
import org.apache.spark.sql.catalyst.analysis.MultiInstanceRelation
import org.apache.spark.sql.catalyst.expressions.{Attribute, GenericMutableRow}
import org.apache.spark.sql.catalyst.plans.logical.{LogicalPlan, Statistics}
-import org.apache.spark.sql.{Row, SQLContext}
+import org.apache.spark.sql.catalyst.types.UserDefinedType
/**
* :: DeveloperApi ::
*/
@DeveloperApi
object RDDConversions {
- def productToRowRdd[A <: Product](data: RDD[A]): RDD[Row] = {
+ def productToRowRdd[A <: Product](data: RDD[A], schema: StructType): RDD[Row] = {
data.mapPartitions { iterator =>
if (iterator.isEmpty) {
Iterator.empty
} else {
val bufferedIterator = iterator.buffered
val mutableRow = new GenericMutableRow(bufferedIterator.head.productArity)
-
+ val schemaFields = schema.fields.toArray
bufferedIterator.map { r =>
var i = 0
while (i < mutableRow.length) {
- mutableRow(i) = ScalaReflection.convertToCatalyst(r.productElement(i))
+ mutableRow(i) =
+ ScalaReflection.convertToCatalyst(r.productElement(i), schemaFields(i).dataType)
i += 1
}
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlan.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlan.scala
index aafcce0572..81c60e0050 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlan.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlan.scala
@@ -20,8 +20,6 @@ package org.apache.spark.sql.execution
import org.apache.spark.annotation.DeveloperApi
import org.apache.spark.Logging
import org.apache.spark.rdd.RDD
-
-
import org.apache.spark.sql.SQLContext
import org.apache.spark.sql.catalyst.{ScalaReflection, trees}
import org.apache.spark.sql.catalyst.analysis.MultiInstanceRelation
@@ -82,7 +80,8 @@ abstract class SparkPlan extends QueryPlan[SparkPlan] with Logging with Serializ
/**
* Runs this query returning the result as an array.
*/
- def executeCollect(): Array[Row] = execute().map(ScalaReflection.convertRowToScala).collect()
+ def executeCollect(): Array[Row] =
+ execute().map(ScalaReflection.convertRowToScala(_, schema)).collect()
protected def newProjection(
expressions: Seq[Expression], inputSchema: Seq[Attribute]): Projection = {
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala
index 2cd3063bc3..cc7e0c05ff 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala
@@ -280,7 +280,8 @@ private[sql] abstract class SparkStrategies extends QueryPlanner[SparkPlan] {
val nPartitions = if (data.isEmpty) 1 else numPartitions
PhysicalRDD(
output,
- RDDConversions.productToRowRdd(sparkContext.parallelize(data, nPartitions))) :: Nil
+ RDDConversions.productToRowRdd(sparkContext.parallelize(data, nPartitions),
+ StructType.fromAttributes(output))) :: Nil
case logical.Limit(IntegerLiteral(limit), child) =>
execution.Limit(limit, planLater(child)) :: Nil
case Unions(unionChildren) =>
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala
index e6cd1a9d04..1b8ba3ace2 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala
@@ -143,7 +143,7 @@ case class Limit(limit: Int, child: SparkPlan)
partsScanned += numPartsToTry
}
- buf.toArray.map(ScalaReflection.convertRowToScala)
+ buf.toArray.map(ScalaReflection.convertRowToScala(_, this.schema))
}
override def execute() = {
@@ -179,8 +179,8 @@ case class TakeOrdered(limit: Int, sortOrder: Seq[SortOrder], child: SparkPlan)
val ord = new RowOrdering(sortOrder, child.output)
// TODO: Is this copying for no reason?
- override def executeCollect() =
- child.execute().map(_.copy()).takeOrdered(limit)(ord).map(ScalaReflection.convertRowToScala)
+ override def executeCollect() = child.execute().map(_.copy()).takeOrdered(limit)(ord)
+ .map(ScalaReflection.convertRowToScala(_, this.schema))
// TODO: Terminal split should be implemented differently from non-terminal split.
// TODO: Pick num splits based on |limit|.
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetConverter.scala b/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetConverter.scala
index 08feced61a..1bbb66aaa1 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetConverter.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetConverter.scala
@@ -77,6 +77,9 @@ private[sql] object CatalystConverter {
parent: CatalystConverter): Converter = {
val fieldType: DataType = field.dataType
fieldType match {
+ case udt: UserDefinedType[_] => {
+ createConverter(field.copy(dataType = udt.sqlType), fieldIndex, parent)
+ }
// For native JVM types we use a converter with native arrays
case ArrayType(elementType: NativeType, false) => {
new CatalystNativeArrayConverter(elementType, fieldIndex, parent)
@@ -255,8 +258,8 @@ private[parquet] class CatalystGroupConverter(
schema,
index,
parent,
- current=null,
- buffer=new ArrayBuffer[Row](
+ current = null,
+ buffer = new ArrayBuffer[Row](
CatalystArrayConverter.INITIAL_ARRAY_SIZE))
/**
@@ -301,7 +304,7 @@ private[parquet] class CatalystGroupConverter(
override def end(): Unit = {
if (!isRootConverter) {
- assert(current!=null) // there should be no empty groups
+ assert(current != null) // there should be no empty groups
buffer.append(new GenericRow(current.toArray))
parent.updateField(index, new GenericRow(buffer.toArray.asInstanceOf[Array[Any]]))
}
@@ -358,7 +361,7 @@ private[parquet] class CatalystPrimitiveRowConverter(
override def end(): Unit = {}
- // Overriden here to avoid auto-boxing for primitive types
+ // Overridden here to avoid auto-boxing for primitive types
override protected[parquet] def updateBoolean(fieldIndex: Int, value: Boolean): Unit =
current.setBoolean(fieldIndex, value)
@@ -533,7 +536,7 @@ private[parquet] class CatalystNativeArrayConverter(
override protected[parquet] def updateField(fieldIndex: Int, value: Any): Unit =
throw new UnsupportedOperationException
- // Overriden here to avoid auto-boxing for primitive types
+ // Overridden here to avoid auto-boxing for primitive types
override protected[parquet] def updateBoolean(fieldIndex: Int, value: Boolean): Unit = {
checkGrowBuffer()
buffer(elements) = value.asInstanceOf[NativeType]
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetTableSupport.scala b/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetTableSupport.scala
index 2a5f23b24e..7bc2496600 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetTableSupport.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetTableSupport.scala
@@ -20,7 +20,6 @@ package org.apache.spark.sql.parquet
import java.util.{HashMap => JHashMap}
import org.apache.hadoop.conf.Configuration
-import org.apache.spark.sql.catalyst.types.decimal.Decimal
import parquet.column.ParquetProperties
import parquet.hadoop.ParquetOutputFormat
import parquet.hadoop.api.ReadSupport.ReadContext
@@ -31,6 +30,7 @@ import parquet.schema.MessageType
import org.apache.spark.Logging
import org.apache.spark.sql.catalyst.expressions.{Attribute, Row}
import org.apache.spark.sql.catalyst.types._
+import org.apache.spark.sql.catalyst.types.decimal.Decimal
/**
* A `parquet.io.api.RecordMaterializer` for Rows.
@@ -174,6 +174,7 @@ private[parquet] class RowWriteSupport extends WriteSupport[Row] with Logging {
private[parquet] def writeValue(schema: DataType, value: Any): Unit = {
if (value != null) {
schema match {
+ case t: UserDefinedType[_] => writeValue(t.sqlType, value)
case t @ ArrayType(_, _) => writeArray(
t,
value.asInstanceOf[CatalystConverter.ArrayScalaType[_]])
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetTypes.scala b/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetTypes.scala
index e5077de8dd..fa37d1f2ae 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetTypes.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetTypes.scala
@@ -290,6 +290,9 @@ private[parquet] object ParquetTypesConverter extends Logging {
builder.named(name)
}.getOrElse {
ctype match {
+ case udt: UserDefinedType[_] => {
+ fromDataType(udt.sqlType, name, nullable, inArray)
+ }
case ArrayType(elementType, false) => {
val parquetElementType = fromDataType(
elementType,
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/types/util/DataTypeConversions.scala b/sql/core/src/main/scala/org/apache/spark/sql/types/util/DataTypeConversions.scala
index 7564bf3923..1bc15146f0 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/types/util/DataTypeConversions.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/types/util/DataTypeConversions.scala
@@ -17,12 +17,16 @@
package org.apache.spark.sql.types.util
+import scala.collection.JavaConverters._
+
import org.apache.spark.sql._
-import org.apache.spark.sql.api.java.{DataType => JDataType, StructField => JStructField, MetadataBuilder => JMetaDataBuilder}
+import org.apache.spark.sql.api.java.{DataType => JDataType, StructField => JStructField,
+ MetadataBuilder => JMetaDataBuilder, UDTWrappers, JavaToScalaUDTWrapper}
import org.apache.spark.sql.api.java.{DecimalType => JDecimalType}
import org.apache.spark.sql.catalyst.types.decimal.Decimal
+import org.apache.spark.sql.catalyst.ScalaReflection
+import org.apache.spark.sql.catalyst.types.UserDefinedType
-import scala.collection.JavaConverters._
protected[sql] object DataTypeConversions {
@@ -41,6 +45,9 @@ protected[sql] object DataTypeConversions {
* Returns the equivalent DataType in Java for the given DataType in Scala.
*/
def asJavaDataType(scalaDataType: DataType): JDataType = scalaDataType match {
+ case udtType: UserDefinedType[_] =>
+ UDTWrappers.wrapAsJava(udtType)
+
case StringType => JDataType.StringType
case BinaryType => JDataType.BinaryType
case BooleanType => JDataType.BooleanType
@@ -80,6 +87,9 @@ protected[sql] object DataTypeConversions {
* Returns the equivalent DataType in Scala for the given DataType in Java.
*/
def asScalaDataType(javaDataType: JDataType): DataType = javaDataType match {
+ case udtType: org.apache.spark.sql.api.java.UserDefinedType[_] =>
+ UDTWrappers.wrapAsScala(udtType)
+
case stringType: org.apache.spark.sql.api.java.StringType =>
StringType
case binaryType: org.apache.spark.sql.api.java.BinaryType =>
@@ -121,9 +131,11 @@ protected[sql] object DataTypeConversions {
}
/** Converts Java objects to catalyst rows / types */
- def convertJavaToCatalyst(a: Any): Any = a match {
- case d: java.math.BigDecimal => Decimal(BigDecimal(d))
- case other => other
+ def convertJavaToCatalyst(a: Any, dataType: DataType): Any = (a, dataType) match {
+ case (obj, udt: UserDefinedType[_]) => ScalaReflection.convertToCatalyst(obj, udt) // Scala type
+ case (d: java.math.BigDecimal, _) => Decimal(BigDecimal(d))
+ case (d: java.math.BigDecimal, _) => BigDecimal(d)
+ case (other, _) => other
}
/** Converts Java objects to catalyst rows / types */
diff --git a/sql/core/src/test/java/org/apache/spark/sql/api/java/JavaUserDefinedTypeSuite.java b/sql/core/src/test/java/org/apache/spark/sql/api/java/JavaUserDefinedTypeSuite.java
new file mode 100644
index 0000000000..0caa8219a6
--- /dev/null
+++ b/sql/core/src/test/java/org/apache/spark/sql/api/java/JavaUserDefinedTypeSuite.java
@@ -0,0 +1,88 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.api.java;
+
+import java.io.Serializable;
+import java.util.*;
+
+import org.junit.After;
+import org.junit.Assert;
+import org.junit.Before;
+import org.junit.Test;
+
+import org.apache.spark.api.java.JavaRDD;
+import org.apache.spark.api.java.JavaSparkContext;
+import org.apache.spark.sql.MyDenseVector;
+import org.apache.spark.sql.MyLabeledPoint;
+
+public class JavaUserDefinedTypeSuite implements Serializable {
+ private transient JavaSparkContext javaCtx;
+ private transient JavaSQLContext javaSqlCtx;
+
+ @Before
+ public void setUp() {
+ javaCtx = new JavaSparkContext("local", "JavaUserDefinedTypeSuite");
+ javaSqlCtx = new JavaSQLContext(javaCtx);
+ }
+
+ @After
+ public void tearDown() {
+ javaCtx.stop();
+ javaCtx = null;
+ javaSqlCtx = null;
+ }
+
+ @Test
+ public void useScalaUDT() {
+ List<MyLabeledPoint> points = Arrays.asList(
+ new MyLabeledPoint(1.0, new MyDenseVector(new double[]{0.1, 1.0})),
+ new MyLabeledPoint(0.0, new MyDenseVector(new double[]{0.2, 2.0})));
+ JavaRDD<MyLabeledPoint> pointsRDD = javaCtx.parallelize(points);
+
+ JavaSchemaRDD schemaRDD = javaSqlCtx.applySchema(pointsRDD, MyLabeledPoint.class);
+ schemaRDD.registerTempTable("points");
+
+ List<Row> actualLabelRows = javaSqlCtx.sql("SELECT label FROM points").collect();
+ List<Double> actualLabels = new LinkedList<Double>();
+ for (Row r : actualLabelRows) {
+ actualLabels.add(r.getDouble(0));
+ }
+ for (MyLabeledPoint lp : points) {
+ Assert.assertTrue(actualLabels.contains(lp.label()));
+ }
+
+ List<Row> actualFeatureRows = javaSqlCtx.sql("SELECT features FROM points").collect();
+ List<MyDenseVector> actualFeatures = new LinkedList<MyDenseVector>();
+ for (Row r : actualFeatureRows) {
+ actualFeatures.add((MyDenseVector)r.get(0));
+ }
+ for (MyLabeledPoint lp : points) {
+ Assert.assertTrue(actualFeatures.contains(lp.features()));
+ }
+
+ List<Row> actual = javaSqlCtx.sql("SELECT label, features FROM points").collect();
+ List<MyLabeledPoint> actualPoints =
+ new LinkedList<MyLabeledPoint>();
+ for (Row r : actual) {
+ actualPoints.add(new MyLabeledPoint(r.getDouble(0), (MyDenseVector)r.get(1)));
+ }
+ for (MyLabeledPoint lp : points) {
+ Assert.assertTrue(actualPoints.contains(lp));
+ }
+ }
+}
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/UserDefinedTypeSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/UserDefinedTypeSuite.scala
new file mode 100644
index 0000000000..666235e57f
--- /dev/null
+++ b/sql/core/src/test/scala/org/apache/spark/sql/UserDefinedTypeSuite.scala
@@ -0,0 +1,83 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql
+
+import scala.beans.{BeanInfo, BeanProperty}
+
+import org.apache.spark.rdd.RDD
+import org.apache.spark.sql.catalyst.annotation.SQLUserDefinedType
+import org.apache.spark.sql.catalyst.types.UserDefinedType
+import org.apache.spark.sql.test.TestSQLContext._
+
+@SQLUserDefinedType(udt = classOf[MyDenseVectorUDT])
+private[sql] class MyDenseVector(val data: Array[Double]) extends Serializable {
+ override def equals(other: Any): Boolean = other match {
+ case v: MyDenseVector =>
+ java.util.Arrays.equals(this.data, v.data)
+ case _ => false
+ }
+}
+
+@BeanInfo
+private[sql] case class MyLabeledPoint(
+ @BeanProperty label: Double,
+ @BeanProperty features: MyDenseVector)
+
+private[sql] class MyDenseVectorUDT extends UserDefinedType[MyDenseVector] {
+
+ override def sqlType: DataType = ArrayType(DoubleType, containsNull = false)
+
+ override def serialize(obj: Any): Seq[Double] = {
+ obj match {
+ case features: MyDenseVector =>
+ features.data.toSeq
+ }
+ }
+
+ override def deserialize(datum: Any): MyDenseVector = {
+ datum match {
+ case data: Seq[_] =>
+ new MyDenseVector(data.asInstanceOf[Seq[Double]].toArray)
+ }
+ }
+
+ override def userClass = classOf[MyDenseVector]
+}
+
+class UserDefinedTypeSuite extends QueryTest {
+
+ test("register user type: MyDenseVector for MyLabeledPoint") {
+ val points = Seq(
+ MyLabeledPoint(1.0, new MyDenseVector(Array(0.1, 1.0))),
+ MyLabeledPoint(0.0, new MyDenseVector(Array(0.2, 2.0))))
+ val pointsRDD: RDD[MyLabeledPoint] = sparkContext.parallelize(points)
+
+ val labels: RDD[Double] = pointsRDD.select('label).map { case Row(v: Double) => v }
+ val labelsArrays: Array[Double] = labels.collect()
+ assert(labelsArrays.size === 2)
+ assert(labelsArrays.contains(1.0))
+ assert(labelsArrays.contains(0.0))
+
+ val features: RDD[MyDenseVector] =
+ pointsRDD.select('features).map { case Row(v: MyDenseVector) => v }
+ val featuresArrays: Array[MyDenseVector] = features.collect()
+ assert(featuresArrays.size === 2)
+ assert(featuresArrays.contains(new MyDenseVector(Array(0.1, 1.0))))
+ assert(featuresArrays.contains(new MyDenseVector(Array(0.2, 2.0))))
+ }
+}
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/json/JsonSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/json/JsonSuite.scala
index 4b851d1b96..cade244f7a 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/json/JsonSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/json/JsonSuite.scala
@@ -21,8 +21,7 @@ import org.apache.spark.sql.catalyst.types._
import org.apache.spark.sql.catalyst.types.decimal.Decimal
import org.apache.spark.sql.catalyst.util._
import org.apache.spark.sql.json.JsonRDD.{enforceCorrectType, compatibleType}
-import org.apache.spark.sql.QueryTest
-import org.apache.spark.sql.SQLConf
+import org.apache.spark.sql.{Row, SQLConf, QueryTest}
import org.apache.spark.sql.test.TestSQLContext
import org.apache.spark.sql.test.TestSQLContext._
@@ -233,8 +232,8 @@ class JsonSuite extends QueryTest {
StructField("field2", StringType, true) ::
StructField("field3", StringType, true) :: Nil), false), true) ::
StructField("struct", StructType(
- StructField("field1", BooleanType, true) ::
- StructField("field2", DecimalType.Unlimited, true) :: Nil), true) ::
+ StructField("field1", BooleanType, true) ::
+ StructField("field2", DecimalType.Unlimited, true) :: Nil), true) ::
StructField("structWithArrayFields", StructType(
StructField("field1", ArrayType(IntegerType, false), true) ::
StructField("field2", ArrayType(StringType, false), true) :: Nil), true) :: Nil)
@@ -292,8 +291,8 @@ class JsonSuite extends QueryTest {
// Access a struct and fields inside of it.
checkAnswer(
sql("select struct, struct.field1, struct.field2 from jsonTable"),
- (
- Seq(true, BigDecimal("92233720368547758070")),
+ Row(
+ Row(true, BigDecimal("92233720368547758070")),
true,
BigDecimal("92233720368547758070")) :: Nil
)
diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveContext.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveContext.scala
index 0fe59f42f2..f025169ad5 100644
--- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveContext.scala
+++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveContext.scala
@@ -374,8 +374,6 @@ class HiveContext(sc: SparkContext) extends SQLContext(sc) {
/** Extends QueryExecution with hive specific features. */
protected[sql] abstract class QueryExecution extends super.QueryExecution {
- override lazy val toRdd: RDD[Row] = executedPlan.execute().map(_.copy())
-
protected val primitiveTypes =
Seq(StringType, IntegerType, LongType, DoubleType, FloatType, BooleanType, ByteType,
ShortType, DateType, TimestampType, BinaryType)
@@ -433,7 +431,7 @@ class HiveContext(sc: SparkContext) extends SQLContext(sc) {
command.executeCollect().map(_.head.toString)
case other =>
- val result: Seq[Seq[Any]] = toRdd.collect().toSeq
+ val result: Seq[Seq[Any]] = toRdd.map(_.copy()).collect().toSeq
// We need the types so we can output struct field names
val types = analyzed.output.map(_.dataType)
// Reformat to match hive tab delimited output.