From ebd6480587f96e9964d37157253523e0a179171a Mon Sep 17 00:00:00 2001
From: "Joseph K. Bradley"
Date: Sun, 2 Nov 2014 17:55:55 -0800
Subject: [SPARK-3572] [SQL] Internal API for User-Defined Types

This PR adds User-Defined Types (UDTs) to SQL. It is a precursor to using SchemaRDD as a Dataset for the new MLlib API.

Currently, the UDT API is private since there is incomplete support (e.g., no Java or Python support yet).

Author: Joseph K. Bradley
Author: Michael Armbrust
Author: Xiangrui Meng

Closes #3063 from marmbrus/udts and squashes the following commits:

7ccfc0d [Michael Armbrust] remove println
46a3aee [Michael Armbrust] Slightly easier to read test output.
6cc434d [Michael Armbrust] Recursively convert rows.
e369b91 [Michael Armbrust] Merge remote-tracking branch 'origin/master' into udts
15c10a6 [Joseph K. Bradley] Merge remote-tracking branch 'upstream/master' into sql-udt2
f3c72fe [Joseph K. Bradley] Fixing merge
e13cd8a [Joseph K. Bradley] Removed Vector UDTs
5817b2b [Joseph K. Bradley] style edits
30ce5b2 [Joseph K. Bradley] updates based on code review
d063380 [Joseph K. Bradley] Cleaned up Java UDT Suite, and added warning about element ordering when creating schema from Java Bean
a571bb6 [Joseph K. Bradley] Removed old UDT code (registry and Java UDTs). Cleaned up other code. Extended JavaUserDefinedTypeSuite
6fddc1c [Joseph K. Bradley] Made MyLabeledPoint into a Java Bean
20630bc [Joseph K. Bradley] fixed scalastyle
fa86b20 [Joseph K. Bradley] Removed Java UserDefinedType, and made UDTs private[spark] for now
8de957c [Joseph K. Bradley] Modified UserDefinedType to store Java class of user type so that registerUDT takes only the udt argument.
8b242ea [Joseph K. Bradley] Fixed merge error after last merge. Note: Last merge commit also removed SQL UDT examples from mllib.
7f29656 [Joseph K. Bradley] Moved udt case to top of all matches. Small cleanups
b028675 [Xiangrui Meng] allow any type in UDT
4500d8a [Xiangrui Meng] update example code
87264a5 [Xiangrui Meng] remove debug code
3143ac3 [Xiangrui Meng] remove unnecessary changes
cfbc321 [Xiangrui Meng] support UDT in parquet
db16139 [Joseph K. Bradley] Added more doc for UserDefinedType. Removed unused code in Suite
759af7a [Joseph K. Bradley] Added more doc to UserDefineType
63626a4 [Joseph K. Bradley] Updated ScalaReflectionsSuite per @marmbrus suggestions
51e5282 [Joseph K. Bradley] fixed 1 test
f025035 [Joseph K. Bradley] Cleanups before PR. Added new tests
85872f6 [Michael Armbrust] Allow schema calculation to be lazy, but ensure its available on executors.
dff99d6 [Joseph K. Bradley] Added UDTs for Vectors in MLlib, plus DatasetExample using the UDTs
cd60cb4 [Joseph K. Bradley] Trying to get other SQL tests to run
34a5831 [Joseph K. Bradley] Added MLlib dependency on SQL.
e1f7b9c [Joseph K. Bradley] blah
2f40c02 [Joseph K. Bradley] renamed UDT types
3579035 [Joseph K. Bradley] udt annotation now working
b226b9e [Joseph K. Bradley] Changing UDT to annotation
fea04af [Joseph K. Bradley] more cleanups
964b32e [Joseph K. Bradley] some cleanups
893ee4c [Joseph K. Bradley] udt finallly working
50f9726 [Joseph K. Bradley] udts
04303c9 [Joseph K. Bradley] udts
39f8707 [Joseph K. Bradley] removed old udt suite
273ac96 [Joseph K. Bradley] basic UDT is working, but deserialization has yet to be done
8bebf24 [Joseph K. Bradley] commented out convertRowToScala for debugging
53de70f [Joseph K. Bradley] more udts...
982c035 [Joseph K. Bradley] still working on UDTs
19b2f60 [Joseph K. Bradley] still working on UDTs
0eaeb81 [Joseph K.
Bradley] Still working on UDTs 105c5a3 [Joseph K. Bradley] Adding UserDefinedType to SQL, not done yet. --- .../spark/sql/catalyst/ScalaReflection.scala | 155 +++++++++++++-------- .../catalyst/annotation/SQLUserDefinedType.java | 46 ++++++ .../spark/sql/catalyst/expressions/ScalaUdf.scala | 6 +- .../spark/sql/catalyst/types/dataTypes.scala | 53 ++++++- .../spark/sql/catalyst/ScalaReflectionSuite.scala | 13 +- .../apache/spark/sql/api/java/UserDefinedType.java | 53 +++++++ .../scala/org/apache/spark/sql/SQLContext.scala | 6 +- .../scala/org/apache/spark/sql/SchemaRDD.scala | 30 ++-- .../scala/org/apache/spark/sql/SchemaRDDLike.scala | 2 +- .../org/apache/spark/sql/UdfRegistration.scala | 46 +++--- .../apache/spark/sql/api/java/JavaSQLContext.scala | 29 ++-- .../apache/spark/sql/api/java/UDTWrappers.scala | 75 ++++++++++ .../apache/spark/sql/execution/ExistingRDD.scala | 11 +- .../org/apache/spark/sql/execution/SparkPlan.scala | 5 +- .../spark/sql/execution/SparkStrategies.scala | 3 +- .../spark/sql/execution/basicOperators.scala | 6 +- .../spark/sql/parquet/ParquetConverter.scala | 13 +- .../spark/sql/parquet/ParquetTableSupport.scala | 3 +- .../apache/spark/sql/parquet/ParquetTypes.scala | 3 + .../spark/sql/types/util/DataTypeConversions.scala | 22 ++- .../sql/api/java/JavaUserDefinedTypeSuite.java | 88 ++++++++++++ .../apache/spark/sql/UserDefinedTypeSuite.scala | 83 +++++++++++ .../org/apache/spark/sql/json/JsonSuite.scala | 11 +- .../org/apache/spark/sql/hive/HiveContext.scala | 4 +- 24 files changed, 620 insertions(+), 146 deletions(-) create mode 100644 sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/annotation/SQLUserDefinedType.java create mode 100644 sql/core/src/main/java/org/apache/spark/sql/api/java/UserDefinedType.java create mode 100644 sql/core/src/main/scala/org/apache/spark/sql/api/java/UDTWrappers.scala create mode 100644 sql/core/src/test/java/org/apache/spark/sql/api/java/JavaUserDefinedTypeSuite.java create mode 100644 sql/core/src/test/scala/org/apache/spark/sql/UserDefinedTypeSuite.scala (limited to 'sql') diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala index 8fbdf664b7..9cda373623 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala @@ -19,6 +19,8 @@ package org.apache.spark.sql.catalyst import java.sql.{Date, Timestamp} +import org.apache.spark.util.Utils +import org.apache.spark.sql.catalyst.annotation.SQLUserDefinedType import org.apache.spark.sql.catalyst.expressions.{GenericRow, Attribute, AttributeReference, Row} import org.apache.spark.sql.catalyst.plans.logical.LocalRelation import org.apache.spark.sql.catalyst.types._ @@ -35,25 +37,46 @@ object ScalaReflection { case class Schema(dataType: DataType, nullable: Boolean) - /** Converts Scala objects to catalyst rows / types */ - def convertToCatalyst(a: Any): Any = a match { - case o: Option[_] => o.map(convertToCatalyst).orNull - case s: Seq[_] => s.map(convertToCatalyst) - case m: Map[_, _] => m.map { case (k, v) => convertToCatalyst(k) -> convertToCatalyst(v) } - case p: Product => new GenericRow(p.productIterator.map(convertToCatalyst).toArray) - case d: BigDecimal => Decimal(d) - case other => other + /** + * Converts Scala objects to catalyst rows / types. + * Note: This is always called after schemaFor has been called. 
+ * This ordering is important for UDT registration. + */ + def convertToCatalyst(a: Any, dataType: DataType): Any = (a, dataType) match { + // Check UDT first since UDTs can override other types + case (obj, udt: UserDefinedType[_]) => udt.serialize(obj) + case (o: Option[_], _) => o.map(convertToCatalyst(_, dataType)).orNull + case (s: Seq[_], arrayType: ArrayType) => s.map(convertToCatalyst(_, arrayType.elementType)) + case (m: Map[_, _], mapType: MapType) => m.map { case (k, v) => + convertToCatalyst(k, mapType.keyType) -> convertToCatalyst(v, mapType.valueType) + } + case (p: Product, structType: StructType) => + new GenericRow( + p.productIterator.toSeq.zip(structType.fields).map { case (elem, field) => + convertToCatalyst(elem, field.dataType) + }.toArray) + case (d: BigDecimal, _) => Decimal(d) + case (other, _) => other } /** Converts Catalyst types used internally in rows to standard Scala types */ - def convertToScala(a: Any): Any = a match { - case s: Seq[_] => s.map(convertToScala) - case m: Map[_, _] => m.map { case (k, v) => convertToScala(k) -> convertToScala(v) } - case d: Decimal => d.toBigDecimal - case other => other + def convertToScala(a: Any, dataType: DataType): Any = (a, dataType) match { + // Check UDT first since UDTs can override other types + case (d, udt: UserDefinedType[_]) => udt.deserialize(d) + case (s: Seq[_], arrayType: ArrayType) => s.map(convertToScala(_, arrayType.elementType)) + case (m: Map[_, _], mapType: MapType) => m.map { case (k, v) => + convertToScala(k, mapType.keyType) -> convertToScala(v, mapType.valueType) + } + case (r: Row, s: StructType) => convertRowToScala(r, s) + case (d: Decimal, _: DecimalType) => d.toBigDecimal + case (other, _) => other } - def convertRowToScala(r: Row): Row = new GenericRow(r.toArray.map(convertToScala)) + def convertRowToScala(r: Row, schema: StructType): Row = { + new GenericRow( + r.zip(schema.fields.map(_.dataType)) + .map(r_dt => convertToScala(r_dt._1, r_dt._2)).toArray) + } /** Returns a Sequence of attributes for the given case class type. */ def attributesFor[T: TypeTag]: Seq[Attribute] = schemaFor[T] match { @@ -65,52 +88,64 @@ object ScalaReflection { def schemaFor[T: TypeTag]: Schema = schemaFor(typeOf[T]) /** Returns a catalyst DataType and its nullability for the given Scala Type using reflection. */ - def schemaFor(tpe: `Type`): Schema = tpe match { - case t if t <:< typeOf[Option[_]] => - val TypeRef(_, _, Seq(optType)) = t - Schema(schemaFor(optType).dataType, nullable = true) - case t if t <:< typeOf[Product] => - val formalTypeArgs = t.typeSymbol.asClass.typeParams - val TypeRef(_, _, actualTypeArgs) = t - val params = t.member(nme.CONSTRUCTOR).asMethod.paramss - Schema(StructType( - params.head.map { p => - val Schema(dataType, nullable) = - schemaFor(p.typeSignature.substituteTypes(formalTypeArgs, actualTypeArgs)) - StructField(p.name.toString, dataType, nullable) - }), nullable = true) - // Need to decide if we actually need a special type here. 
- case t if t <:< typeOf[Array[Byte]] => Schema(BinaryType, nullable = true) - case t if t <:< typeOf[Array[_]] => - sys.error(s"Only Array[Byte] supported now, use Seq instead of $t") - case t if t <:< typeOf[Seq[_]] => - val TypeRef(_, _, Seq(elementType)) = t - val Schema(dataType, nullable) = schemaFor(elementType) - Schema(ArrayType(dataType, containsNull = nullable), nullable = true) - case t if t <:< typeOf[Map[_,_]] => - val TypeRef(_, _, Seq(keyType, valueType)) = t - val Schema(valueDataType, valueNullable) = schemaFor(valueType) - Schema(MapType(schemaFor(keyType).dataType, - valueDataType, valueContainsNull = valueNullable), nullable = true) - case t if t <:< typeOf[String] => Schema(StringType, nullable = true) - case t if t <:< typeOf[Timestamp] => Schema(TimestampType, nullable = true) - case t if t <:< typeOf[Date] => Schema(DateType, nullable = true) - case t if t <:< typeOf[BigDecimal] => Schema(DecimalType.Unlimited, nullable = true) - case t if t <:< typeOf[Decimal] => Schema(DecimalType.Unlimited, nullable = true) - case t if t <:< typeOf[java.lang.Integer] => Schema(IntegerType, nullable = true) - case t if t <:< typeOf[java.lang.Long] => Schema(LongType, nullable = true) - case t if t <:< typeOf[java.lang.Double] => Schema(DoubleType, nullable = true) - case t if t <:< typeOf[java.lang.Float] => Schema(FloatType, nullable = true) - case t if t <:< typeOf[java.lang.Short] => Schema(ShortType, nullable = true) - case t if t <:< typeOf[java.lang.Byte] => Schema(ByteType, nullable = true) - case t if t <:< typeOf[java.lang.Boolean] => Schema(BooleanType, nullable = true) - case t if t <:< definitions.IntTpe => Schema(IntegerType, nullable = false) - case t if t <:< definitions.LongTpe => Schema(LongType, nullable = false) - case t if t <:< definitions.DoubleTpe => Schema(DoubleType, nullable = false) - case t if t <:< definitions.FloatTpe => Schema(FloatType, nullable = false) - case t if t <:< definitions.ShortTpe => Schema(ShortType, nullable = false) - case t if t <:< definitions.ByteTpe => Schema(ByteType, nullable = false) - case t if t <:< definitions.BooleanTpe => Schema(BooleanType, nullable = false) + def schemaFor(tpe: `Type`): Schema = { + val className: String = tpe.erasure.typeSymbol.asClass.fullName + tpe match { + case t if Utils.classIsLoadable(className) && + Utils.classForName(className).isAnnotationPresent(classOf[SQLUserDefinedType]) => + // Note: We check for classIsLoadable above since Utils.classForName uses Java reflection, + // whereas className is from Scala reflection. This can make it hard to find classes + // in some cases, such as when a class is enclosed in an object (in which case + // Java appends a '$' to the object name but Scala does not). + val udt = Utils.classForName(className) + .getAnnotation(classOf[SQLUserDefinedType]).udt().newInstance() + Schema(udt, nullable = true) + case t if t <:< typeOf[Option[_]] => + val TypeRef(_, _, Seq(optType)) = t + Schema(schemaFor(optType).dataType, nullable = true) + case t if t <:< typeOf[Product] => + val formalTypeArgs = t.typeSymbol.asClass.typeParams + val TypeRef(_, _, actualTypeArgs) = t + val params = t.member(nme.CONSTRUCTOR).asMethod.paramss + Schema(StructType( + params.head.map { p => + val Schema(dataType, nullable) = + schemaFor(p.typeSignature.substituteTypes(formalTypeArgs, actualTypeArgs)) + StructField(p.name.toString, dataType, nullable) + }), nullable = true) + // Need to decide if we actually need a special type here. 
+ case t if t <:< typeOf[Array[Byte]] => Schema(BinaryType, nullable = true) + case t if t <:< typeOf[Array[_]] => + sys.error(s"Only Array[Byte] supported now, use Seq instead of $t") + case t if t <:< typeOf[Seq[_]] => + val TypeRef(_, _, Seq(elementType)) = t + val Schema(dataType, nullable) = schemaFor(elementType) + Schema(ArrayType(dataType, containsNull = nullable), nullable = true) + case t if t <:< typeOf[Map[_, _]] => + val TypeRef(_, _, Seq(keyType, valueType)) = t + val Schema(valueDataType, valueNullable) = schemaFor(valueType) + Schema(MapType(schemaFor(keyType).dataType, + valueDataType, valueContainsNull = valueNullable), nullable = true) + case t if t <:< typeOf[String] => Schema(StringType, nullable = true) + case t if t <:< typeOf[Timestamp] => Schema(TimestampType, nullable = true) + case t if t <:< typeOf[Date] => Schema(DateType, nullable = true) + case t if t <:< typeOf[BigDecimal] => Schema(DecimalType.Unlimited, nullable = true) + case t if t <:< typeOf[Decimal] => Schema(DecimalType.Unlimited, nullable = true) + case t if t <:< typeOf[java.lang.Integer] => Schema(IntegerType, nullable = true) + case t if t <:< typeOf[java.lang.Long] => Schema(LongType, nullable = true) + case t if t <:< typeOf[java.lang.Double] => Schema(DoubleType, nullable = true) + case t if t <:< typeOf[java.lang.Float] => Schema(FloatType, nullable = true) + case t if t <:< typeOf[java.lang.Short] => Schema(ShortType, nullable = true) + case t if t <:< typeOf[java.lang.Byte] => Schema(ByteType, nullable = true) + case t if t <:< typeOf[java.lang.Boolean] => Schema(BooleanType, nullable = true) + case t if t <:< definitions.IntTpe => Schema(IntegerType, nullable = false) + case t if t <:< definitions.LongTpe => Schema(LongType, nullable = false) + case t if t <:< definitions.DoubleTpe => Schema(DoubleType, nullable = false) + case t if t <:< definitions.FloatTpe => Schema(FloatType, nullable = false) + case t if t <:< definitions.ShortTpe => Schema(ShortType, nullable = false) + case t if t <:< definitions.ByteTpe => Schema(ByteType, nullable = false) + case t if t <:< definitions.BooleanTpe => Schema(BooleanType, nullable = false) + } } def typeOfObject: PartialFunction[Any, DataType] = { diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/annotation/SQLUserDefinedType.java b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/annotation/SQLUserDefinedType.java new file mode 100644 index 0000000000..e966aeea1c --- /dev/null +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/annotation/SQLUserDefinedType.java @@ -0,0 +1,46 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.sql.catalyst.annotation; + +import java.lang.annotation.*; + +import org.apache.spark.annotation.DeveloperApi; +import org.apache.spark.sql.catalyst.types.UserDefinedType; + +/** + * ::DeveloperApi:: + * A user-defined type which can be automatically recognized by a SQLContext and registered. + * + * WARNING: This annotation will only work if both Java and Scala reflection return the same class + * names (after erasure) for the UDT. This will NOT be the case when, e.g., the UDT class + * is enclosed in an object (a singleton). + * + * WARNING: UDTs are currently only supported from Scala. + */ +// TODO: Should I used @Documented ? +@DeveloperApi +@Retention(RetentionPolicy.RUNTIME) +@Target(ElementType.TYPE) +public @interface SQLUserDefinedType { + + /** + * Returns an instance of the UserDefinedType which can serialize and deserialize the user + * class to and from Catalyst built-in types. + */ + Class > udt(); +} diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ScalaUdf.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ScalaUdf.scala index 1b687a443e..fa1786e74b 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ScalaUdf.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ScalaUdf.scala @@ -21,6 +21,10 @@ import org.apache.spark.sql.catalyst.ScalaReflection import org.apache.spark.sql.catalyst.types.DataType import org.apache.spark.util.ClosureCleaner +/** + * User-defined function. + * @param dataType Return type of function. + */ case class ScalaUdf(function: AnyRef, dataType: DataType, children: Seq[Expression]) extends Expression { @@ -347,6 +351,6 @@ case class ScalaUdf(function: AnyRef, dataType: DataType, children: Seq[Expressi } // scalastyle:on - ScalaReflection.convertToCatalyst(result) + ScalaReflection.convertToCatalyst(result, dataType) } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/types/dataTypes.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/types/dataTypes.scala index d25f3a619d..cc5015ad3c 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/types/dataTypes.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/types/dataTypes.scala @@ -29,11 +29,12 @@ import org.json4s.JsonAST.JValue import org.json4s.JsonDSL._ import org.json4s.jackson.JsonMethods._ +import org.apache.spark.annotation.DeveloperApi import org.apache.spark.sql.catalyst.ScalaReflectionLock -import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeReference, Expression} +import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeReference, Expression, Row} +import org.apache.spark.sql.catalyst.types.decimal._ import org.apache.spark.sql.catalyst.util.Metadata import org.apache.spark.util.Utils -import org.apache.spark.sql.catalyst.types.decimal._ object DataType { def fromJson(json: String): DataType = parseDataType(parse(json)) @@ -67,6 +68,11 @@ object DataType { ("fields", JArray(fields)), ("type", JString("struct"))) => StructType(fields.map(parseStructField)) + + case JSortedObject( + ("class", JString(udtClass)), + ("type", JString("udt"))) => + Class.forName(udtClass).newInstance().asInstanceOf[UserDefinedType[_]] } private def parseStructField(json: JValue): StructField = json match { @@ -342,6 +348,7 @@ object FractionalType { case _ => false } } + abstract class FractionalType extends NumericType { private[sql] val fractional: 
Fractional[JvmType] private[sql] val asIntegral: Integral[JvmType] @@ -565,3 +572,45 @@ case class MapType( ("valueType" -> valueType.jsonValue) ~ ("valueContainsNull" -> valueContainsNull) } + +/** + * ::DeveloperApi:: + * The data type for User Defined Types (UDTs). + * + * This interface allows a user to make their own classes more interoperable with SparkSQL; + * e.g., by creating a [[UserDefinedType]] for a class X, it becomes possible to create + * a SchemaRDD which has class X in the schema. + * + * For SparkSQL to recognize UDTs, the UDT must be annotated with + * [[org.apache.spark.sql.catalyst.annotation.SQLUserDefinedType]]. + * + * The conversion via `serialize` occurs when instantiating a `SchemaRDD` from another RDD. + * The conversion via `deserialize` occurs when reading from a `SchemaRDD`. + */ +@DeveloperApi +abstract class UserDefinedType[UserType] extends DataType with Serializable { + + /** Underlying storage type for this UDT */ + def sqlType: DataType + + /** + * Convert the user type to a SQL datum + * + * TODO: Can we make this take obj: UserType? The issue is in ScalaReflection.convertToCatalyst, + * where we need to convert Any to UserType. + */ + def serialize(obj: Any): Any + + /** Convert a SQL datum to the user type */ + def deserialize(datum: Any): UserType + + override private[sql] def jsonValue: JValue = { + ("type" -> "udt") ~ + ("class" -> this.getClass.getName) + } + + /** + * Class object for the UserType + */ + def userClass: java.lang.Class[UserType] +} diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/ScalaReflectionSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/ScalaReflectionSuite.scala index 21b2c8e20d..ddc3d44869 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/ScalaReflectionSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/ScalaReflectionSuite.scala @@ -22,6 +22,7 @@ import java.sql.{Date, Timestamp} import org.scalatest.FunSuite +import org.apache.spark.sql.catalyst.expressions.Row import org.apache.spark.sql.catalyst.types._ case class PrimitiveData( @@ -239,13 +240,17 @@ class ScalaReflectionSuite extends FunSuite { test("convert PrimitiveData to catalyst") { val data = PrimitiveData(1, 1, 1, 1, 1, 1, true) val convertedData = Seq(1, 1.toLong, 1.toDouble, 1.toFloat, 1.toShort, 1.toByte, true) - assert(convertToCatalyst(data) === convertedData) + val dataType = schemaFor[PrimitiveData].dataType + assert(convertToCatalyst(data, dataType) === convertedData) } test("convert Option[Product] to catalyst") { val primitiveData = PrimitiveData(1, 1, 1, 1, 1, 1, true) - val data = OptionalData(Some(1), Some(1), Some(1), Some(1), Some(1), Some(1), Some(true), Some(primitiveData)) - val convertedData = Seq(1, 1.toLong, 1.toDouble, 1.toFloat, 1.toShort, 1.toByte, true, convertToCatalyst(primitiveData)) - assert(convertToCatalyst(data) === convertedData) + val data = OptionalData(Some(2), Some(2), Some(2), Some(2), Some(2), Some(2), Some(true), + Some(primitiveData)) + val dataType = schemaFor[OptionalData].dataType + val convertedData = Row(2, 2.toLong, 2.toDouble, 2.toFloat, 2.toShort, 2.toByte, true, + Row(1, 1, 1, 1, 1, 1, true)) + assert(convertToCatalyst(data, dataType) === convertedData) } } diff --git a/sql/core/src/main/java/org/apache/spark/sql/api/java/UserDefinedType.java b/sql/core/src/main/java/org/apache/spark/sql/api/java/UserDefinedType.java new file mode 100644 index 0000000000..b751847b46 --- /dev/null +++ 
b/sql/core/src/main/java/org/apache/spark/sql/api/java/UserDefinedType.java @@ -0,0 +1,53 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.api.java; + +import java.io.Serializable; + +import org.apache.spark.annotation.DeveloperApi; + +/** + * ::DeveloperApi:: + * The data type representing User-Defined Types (UDTs). + * UDTs may use any other DataType for an underlying representation. + */ +@DeveloperApi +public abstract class UserDefinedType extends DataType implements Serializable { + + protected UserDefinedType() { } + + @Override + public boolean equals(Object o) { + if (this == o) return true; + if (o == null || getClass() != o.getClass()) return false; + UserDefinedType that = (UserDefinedType) o; + return this.sqlType().equals(that.sqlType()); + } + + /** Underlying storage type for this UDT */ + public abstract DataType sqlType(); + + /** Convert the user type to a SQL datum */ + public abstract Object serialize(Object obj); + + /** Convert a SQL datum to the user type */ + public abstract UserType deserialize(Object datum); + + /** Class object for the UserType */ + public abstract Class userClass(); +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala b/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala index 3cf6af5f7a..9e61d18f7e 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala @@ -107,8 +107,10 @@ class SQLContext(@transient val sparkContext: SparkContext) */ implicit def createSchemaRDD[A <: Product: TypeTag](rdd: RDD[A]) = { SparkPlan.currentContext.set(self) - new SchemaRDD(this, - LogicalRDD(ScalaReflection.attributesFor[A], RDDConversions.productToRowRdd(rdd))(self)) + val attributeSeq = ScalaReflection.attributesFor[A] + val schema = StructType.fromAttributes(attributeSeq) + val rowRDD = RDDConversions.productToRowRdd(rdd, schema) + new SchemaRDD(this, LogicalRDD(attributeSeq, rowRDD)(self)) } implicit def baseRelationToSchemaRDD(baseRelation: BaseRelation): SchemaRDD = { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/SchemaRDD.scala b/sql/core/src/main/scala/org/apache/spark/sql/SchemaRDD.scala index 018a18c4ac..3ee2ea05cf 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/SchemaRDD.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/SchemaRDD.scala @@ -17,26 +17,24 @@ package org.apache.spark.sql -import java.util.{Map => JMap, List => JList} - -import org.apache.spark.sql.catalyst.ScalaReflection -import org.apache.spark.storage.StorageLevel +import java.util.{List => JList} import scala.collection.JavaConversions._ -import scala.collection.JavaConverters._ import net.razorvine.pickle.Pickler import org.apache.spark.{Dependency, 
OneToOneDependency, Partition, Partitioner, TaskContext} import org.apache.spark.annotation.{AlphaComponent, Experimental} +import org.apache.spark.api.java.JavaRDD import org.apache.spark.rdd.RDD import org.apache.spark.sql.api.java.JavaSchemaRDD +import org.apache.spark.sql.catalyst.ScalaReflection import org.apache.spark.sql.catalyst.analysis._ import org.apache.spark.sql.catalyst.expressions._ -import org.apache.spark.sql.catalyst.plans.logical._ import org.apache.spark.sql.catalyst.plans.{Inner, JoinType} +import org.apache.spark.sql.catalyst.plans.logical._ import org.apache.spark.sql.execution.{LogicalRDD, EvaluatePython} -import org.apache.spark.api.java.JavaRDD +import org.apache.spark.storage.StorageLevel /** * :: AlphaComponent :: @@ -114,18 +112,22 @@ class SchemaRDD( // ========================================================================================= override def compute(split: Partition, context: TaskContext): Iterator[Row] = - firstParent[Row].compute(split, context).map(ScalaReflection.convertRowToScala) + firstParent[Row].compute(split, context).map(ScalaReflection.convertRowToScala(_, this.schema)) override def getPartitions: Array[Partition] = firstParent[Row].partitions - override protected def getDependencies: Seq[Dependency[_]] = + override protected def getDependencies: Seq[Dependency[_]] = { + schema // Force reification of the schema so it is available on executors. + List(new OneToOneDependency(queryExecution.toRdd)) + } - /** Returns the schema of this SchemaRDD (represented by a [[StructType]]). - * - * @group schema - */ - def schema: StructType = queryExecution.analyzed.schema + /** + * Returns the schema of this SchemaRDD (represented by a [[StructType]]). + * + * @group schema + */ + lazy val schema: StructType = queryExecution.analyzed.schema // ======================================================================= // Query DSL diff --git a/sql/core/src/main/scala/org/apache/spark/sql/SchemaRDDLike.scala b/sql/core/src/main/scala/org/apache/spark/sql/SchemaRDDLike.scala index 15516afb95..fd5f4abcbc 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/SchemaRDDLike.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/SchemaRDDLike.scala @@ -26,7 +26,7 @@ import org.apache.spark.sql.execution.LogicalRDD * Contains functions that are shared between all SchemaRDD types (i.e., Scala, Java) */ private[sql] trait SchemaRDDLike { - @transient val sqlContext: SQLContext + @transient def sqlContext: SQLContext @transient val baseLogicalPlan: LogicalPlan private[sql] def baseSchemaRDD: SchemaRDD diff --git a/sql/core/src/main/scala/org/apache/spark/sql/UdfRegistration.scala b/sql/core/src/main/scala/org/apache/spark/sql/UdfRegistration.scala index 595b4aa36e..6d4c0d82ac 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/UdfRegistration.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/UdfRegistration.scala @@ -78,7 +78,7 @@ private[sql] trait UDFRegistration { s""" def registerFunction[T: TypeTag](name: String, func: Function$x[$types, T]): Unit = { def builder(e: Seq[Expression]) = - ScalaUdf(func, ScalaReflection.schemaFor(typeTag[T]).dataType, e) + ScalaUdf(func, ScalaReflection.schemaFor[T].dataType, e) functionRegistry.registerFunction(name, builder) } """ @@ -87,112 +87,112 @@ private[sql] trait UDFRegistration { // scalastyle:off def registerFunction[T: TypeTag](name: String, func: Function1[_, T]): Unit = { - def builder(e: Seq[Expression]) = ScalaUdf(func, ScalaReflection.schemaFor(typeTag[T]).dataType, e) + def builder(e: 
Seq[Expression]) = ScalaUdf(func, ScalaReflection.schemaFor[T].dataType, e) functionRegistry.registerFunction(name, builder) } def registerFunction[T: TypeTag](name: String, func: Function2[_, _, T]): Unit = { - def builder(e: Seq[Expression]) = ScalaUdf(func, ScalaReflection.schemaFor(typeTag[T]).dataType, e) + def builder(e: Seq[Expression]) = ScalaUdf(func, ScalaReflection.schemaFor[T].dataType, e) functionRegistry.registerFunction(name, builder) } def registerFunction[T: TypeTag](name: String, func: Function3[_, _, _, T]): Unit = { - def builder(e: Seq[Expression]) = ScalaUdf(func, ScalaReflection.schemaFor(typeTag[T]).dataType, e) + def builder(e: Seq[Expression]) = ScalaUdf(func, ScalaReflection.schemaFor[T].dataType, e) functionRegistry.registerFunction(name, builder) } def registerFunction[T: TypeTag](name: String, func: Function4[_, _, _, _, T]): Unit = { - def builder(e: Seq[Expression]) = ScalaUdf(func, ScalaReflection.schemaFor(typeTag[T]).dataType, e) + def builder(e: Seq[Expression]) = ScalaUdf(func, ScalaReflection.schemaFor[T].dataType, e) functionRegistry.registerFunction(name, builder) } def registerFunction[T: TypeTag](name: String, func: Function5[_, _, _, _, _, T]): Unit = { - def builder(e: Seq[Expression]) = ScalaUdf(func, ScalaReflection.schemaFor(typeTag[T]).dataType, e) + def builder(e: Seq[Expression]) = ScalaUdf(func, ScalaReflection.schemaFor[T].dataType, e) functionRegistry.registerFunction(name, builder) } def registerFunction[T: TypeTag](name: String, func: Function6[_, _, _, _, _, _, T]): Unit = { - def builder(e: Seq[Expression]) = ScalaUdf(func, ScalaReflection.schemaFor(typeTag[T]).dataType, e) + def builder(e: Seq[Expression]) = ScalaUdf(func, ScalaReflection.schemaFor[T].dataType, e) functionRegistry.registerFunction(name, builder) } def registerFunction[T: TypeTag](name: String, func: Function7[_, _, _, _, _, _, _, T]): Unit = { - def builder(e: Seq[Expression]) = ScalaUdf(func, ScalaReflection.schemaFor(typeTag[T]).dataType, e) + def builder(e: Seq[Expression]) = ScalaUdf(func, ScalaReflection.schemaFor[T].dataType, e) functionRegistry.registerFunction(name, builder) } def registerFunction[T: TypeTag](name: String, func: Function8[_, _, _, _, _, _, _, _, T]): Unit = { - def builder(e: Seq[Expression]) = ScalaUdf(func, ScalaReflection.schemaFor(typeTag[T]).dataType, e) + def builder(e: Seq[Expression]) = ScalaUdf(func, ScalaReflection.schemaFor[T].dataType, e) functionRegistry.registerFunction(name, builder) } def registerFunction[T: TypeTag](name: String, func: Function9[_, _, _, _, _, _, _, _, _, T]): Unit = { - def builder(e: Seq[Expression]) = ScalaUdf(func, ScalaReflection.schemaFor(typeTag[T]).dataType, e) + def builder(e: Seq[Expression]) = ScalaUdf(func, ScalaReflection.schemaFor[T].dataType, e) functionRegistry.registerFunction(name, builder) } def registerFunction[T: TypeTag](name: String, func: Function10[_, _, _, _, _, _, _, _, _, _, T]): Unit = { - def builder(e: Seq[Expression]) = ScalaUdf(func, ScalaReflection.schemaFor(typeTag[T]).dataType, e) + def builder(e: Seq[Expression]) = ScalaUdf(func, ScalaReflection.schemaFor[T].dataType, e) functionRegistry.registerFunction(name, builder) } def registerFunction[T: TypeTag](name: String, func: Function11[_, _, _, _, _, _, _, _, _, _, _, T]): Unit = { - def builder(e: Seq[Expression]) = ScalaUdf(func, ScalaReflection.schemaFor(typeTag[T]).dataType, e) + def builder(e: Seq[Expression]) = ScalaUdf(func, ScalaReflection.schemaFor[T].dataType, e) functionRegistry.registerFunction(name, builder) } 
def registerFunction[T: TypeTag](name: String, func: Function12[_, _, _, _, _, _, _, _, _, _, _, _, T]): Unit = { - def builder(e: Seq[Expression]) = ScalaUdf(func, ScalaReflection.schemaFor(typeTag[T]).dataType, e) + def builder(e: Seq[Expression]) = ScalaUdf(func, ScalaReflection.schemaFor[T].dataType, e) functionRegistry.registerFunction(name, builder) } def registerFunction[T: TypeTag](name: String, func: Function13[_, _, _, _, _, _, _, _, _, _, _, _, _, T]): Unit = { - def builder(e: Seq[Expression]) = ScalaUdf(func, ScalaReflection.schemaFor(typeTag[T]).dataType, e) + def builder(e: Seq[Expression]) = ScalaUdf(func, ScalaReflection.schemaFor[T].dataType, e) functionRegistry.registerFunction(name, builder) } def registerFunction[T: TypeTag](name: String, func: Function14[_, _, _, _, _, _, _, _, _, _, _, _, _, _, T]): Unit = { - def builder(e: Seq[Expression]) = ScalaUdf(func, ScalaReflection.schemaFor(typeTag[T]).dataType, e) + def builder(e: Seq[Expression]) = ScalaUdf(func, ScalaReflection.schemaFor[T].dataType, e) functionRegistry.registerFunction(name, builder) } def registerFunction[T: TypeTag](name: String, func: Function15[_, _, _, _, _, _, _, _, _, _, _, _, _, _, _, T]): Unit = { - def builder(e: Seq[Expression]) = ScalaUdf(func, ScalaReflection.schemaFor(typeTag[T]).dataType, e) + def builder(e: Seq[Expression]) = ScalaUdf(func, ScalaReflection.schemaFor[T].dataType, e) functionRegistry.registerFunction(name, builder) } def registerFunction[T: TypeTag](name: String, func: Function16[_, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, T]): Unit = { - def builder(e: Seq[Expression]) = ScalaUdf(func, ScalaReflection.schemaFor(typeTag[T]).dataType, e) + def builder(e: Seq[Expression]) = ScalaUdf(func, ScalaReflection.schemaFor[T].dataType, e) functionRegistry.registerFunction(name, builder) } def registerFunction[T: TypeTag](name: String, func: Function17[_, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, T]): Unit = { - def builder(e: Seq[Expression]) = ScalaUdf(func, ScalaReflection.schemaFor(typeTag[T]).dataType, e) + def builder(e: Seq[Expression]) = ScalaUdf(func, ScalaReflection.schemaFor[T].dataType, e) functionRegistry.registerFunction(name, builder) } def registerFunction[T: TypeTag](name: String, func: Function18[_, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, T]): Unit = { - def builder(e: Seq[Expression]) = ScalaUdf(func, ScalaReflection.schemaFor(typeTag[T]).dataType, e) + def builder(e: Seq[Expression]) = ScalaUdf(func, ScalaReflection.schemaFor[T].dataType, e) functionRegistry.registerFunction(name, builder) } def registerFunction[T: TypeTag](name: String, func: Function19[_, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, T]): Unit = { - def builder(e: Seq[Expression]) = ScalaUdf(func, ScalaReflection.schemaFor(typeTag[T]).dataType, e) + def builder(e: Seq[Expression]) = ScalaUdf(func, ScalaReflection.schemaFor[T].dataType, e) functionRegistry.registerFunction(name, builder) } def registerFunction[T: TypeTag](name: String, func: Function20[_, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, T]): Unit = { - def builder(e: Seq[Expression]) = ScalaUdf(func, ScalaReflection.schemaFor(typeTag[T]).dataType, e) + def builder(e: Seq[Expression]) = ScalaUdf(func, ScalaReflection.schemaFor[T].dataType, e) functionRegistry.registerFunction(name, builder) } def registerFunction[T: TypeTag](name: String, func: Function21[_, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, T]): Unit = { - def builder(e: Seq[Expression]) = ScalaUdf(func, 
ScalaReflection.schemaFor(typeTag[T]).dataType, e) + def builder(e: Seq[Expression]) = ScalaUdf(func, ScalaReflection.schemaFor[T].dataType, e) functionRegistry.registerFunction(name, builder) } def registerFunction[T: TypeTag](name: String, func: Function22[_, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, T]): Unit = { - def builder(e: Seq[Expression]) = ScalaUdf(func, ScalaReflection.schemaFor(typeTag[T]).dataType, e) + def builder(e: Seq[Expression]) = ScalaUdf(func, ScalaReflection.schemaFor[T].dataType, e) functionRegistry.registerFunction(name, builder) } // scalastyle:on diff --git a/sql/core/src/main/scala/org/apache/spark/sql/api/java/JavaSQLContext.scala b/sql/core/src/main/scala/org/apache/spark/sql/api/java/JavaSQLContext.scala index 60065509bf..4c0869e05b 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/api/java/JavaSQLContext.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/api/java/JavaSQLContext.scala @@ -23,13 +23,14 @@ import org.apache.hadoop.conf.Configuration import org.apache.spark.annotation.{DeveloperApi, Experimental} import org.apache.spark.api.java.{JavaRDD, JavaSparkContext} -import org.apache.spark.sql.json.JsonRDD -import org.apache.spark.sql.sources.{LogicalRelation, BaseRelation} -import org.apache.spark.sql.types.util.DataTypeConversions import org.apache.spark.sql.{SQLContext, StructType => SStructType} +import org.apache.spark.sql.catalyst.annotation.SQLUserDefinedType import org.apache.spark.sql.catalyst.expressions.{AttributeReference, GenericRow, Row => ScalaRow} -import org.apache.spark.sql.parquet.ParquetRelation import org.apache.spark.sql.execution.LogicalRDD +import org.apache.spark.sql.json.JsonRDD +import org.apache.spark.sql.parquet.ParquetRelation +import org.apache.spark.sql.sources.{LogicalRelation, BaseRelation} +import org.apache.spark.sql.types.util.DataTypeConversions import org.apache.spark.sql.types.util.DataTypeConversions.asScalaDataType import org.apache.spark.util.Utils @@ -91,9 +92,12 @@ class JavaSQLContext(val sqlContext: SQLContext) extends UDFRegistration { /** * Applies a schema to an RDD of Java Beans. + * + * WARNING: Since there is no guaranteed ordering for fields in a Java Bean, + * SELECT * queries will return the columns in an undefined order. */ def applySchema(rdd: JavaRDD[_], beanClass: Class[_]): JavaSchemaRDD = { - val schema = getSchema(beanClass) + val attributeSeq = getSchema(beanClass) val className = beanClass.getName val rowRdd = rdd.rdd.mapPartitions { iter => // BeanInfo is not serializable so we must rediscover it remotely for each partition. @@ -104,11 +108,13 @@ class JavaSQLContext(val sqlContext: SQLContext) extends UDFRegistration { iter.map { row => new GenericRow( - extractors.map(e => DataTypeConversions.convertJavaToCatalyst(e.invoke(row))).toArray[Any] + extractors.zip(attributeSeq).map { case (e, attr) => + DataTypeConversions.convertJavaToCatalyst(e.invoke(row), attr.dataType) + }.toArray[Any] ): ScalaRow } } - new JavaSchemaRDD(sqlContext, LogicalRDD(schema, rowRdd)(sqlContext)) + new JavaSchemaRDD(sqlContext, LogicalRDD(attributeSeq, rowRdd)(sqlContext)) } /** @@ -195,14 +201,21 @@ class JavaSQLContext(val sqlContext: SQLContext) extends UDFRegistration { sqlContext.registerRDDAsTable(rdd.baseSchemaRDD, tableName) } - /** Returns a Catalyst Schema for the given java bean class. */ + /** + * Returns a Catalyst Schema for the given java bean class. 
+ */ protected def getSchema(beanClass: Class[_]): Seq[AttributeReference] = { // TODO: All of this could probably be moved to Catalyst as it is mostly not Spark specific. val beanInfo = Introspector.getBeanInfo(beanClass) + // Note: The ordering of elements may differ from when the schema is inferred in Scala. + // This is because beanInfo.getPropertyDescriptors gives no guarantees about + // element ordering. val fields = beanInfo.getPropertyDescriptors.filterNot(_.getName == "class") fields.map { property => val (dataType, nullable) = property.getPropertyType match { + case c: Class[_] if c.isAnnotationPresent(classOf[SQLUserDefinedType]) => + (c.getAnnotation(classOf[SQLUserDefinedType]).udt().newInstance(), true) case c: Class[_] if c == classOf[java.lang.String] => (org.apache.spark.sql.StringType, true) case c: Class[_] if c == java.lang.Short.TYPE => diff --git a/sql/core/src/main/scala/org/apache/spark/sql/api/java/UDTWrappers.scala b/sql/core/src/main/scala/org/apache/spark/sql/api/java/UDTWrappers.scala new file mode 100644 index 0000000000..a7d0f4f127 --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/api/java/UDTWrappers.scala @@ -0,0 +1,75 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.sql.api.java + +import org.apache.spark.sql.catalyst.types.{UserDefinedType => ScalaUserDefinedType} +import org.apache.spark.sql.{DataType => ScalaDataType} +import org.apache.spark.sql.types.util.DataTypeConversions + +/** + * Scala wrapper for a Java UserDefinedType + */ +private[sql] class JavaToScalaUDTWrapper[UserType](val javaUDT: UserDefinedType[UserType]) + extends ScalaUserDefinedType[UserType] with Serializable { + + /** Underlying storage type for this UDT */ + val sqlType: ScalaDataType = DataTypeConversions.asScalaDataType(javaUDT.sqlType()) + + /** Convert the user type to a SQL datum */ + def serialize(obj: Any): Any = javaUDT.serialize(obj) + + /** Convert a SQL datum to the user type */ + def deserialize(datum: Any): UserType = javaUDT.deserialize(datum) + + val userClass: java.lang.Class[UserType] = javaUDT.userClass() +} + +/** + * Java wrapper for a Scala UserDefinedType + */ +private[sql] class ScalaToJavaUDTWrapper[UserType](val scalaUDT: ScalaUserDefinedType[UserType]) + extends UserDefinedType[UserType] with Serializable { + + /** Underlying storage type for this UDT */ + val sqlType: DataType = DataTypeConversions.asJavaDataType(scalaUDT.sqlType) + + /** Convert the user type to a SQL datum */ + def serialize(obj: Any): java.lang.Object = scalaUDT.serialize(obj).asInstanceOf[java.lang.Object] + + /** Convert a SQL datum to the user type */ + def deserialize(datum: Any): UserType = scalaUDT.deserialize(datum) + + val userClass: java.lang.Class[UserType] = scalaUDT.userClass +} + +private[sql] object UDTWrappers { + + def wrapAsScala(udtType: UserDefinedType[_]): ScalaUserDefinedType[_] = { + udtType match { + case t: ScalaToJavaUDTWrapper[_] => t.scalaUDT + case _ => new JavaToScalaUDTWrapper(udtType) + } + } + + def wrapAsJava(udtType: ScalaUserDefinedType[_]): UserDefinedType[_] = { + udtType match { + case t: JavaToScalaUDTWrapper[_] => t.javaUDT + case _ => new ScalaToJavaUDTWrapper(udtType) + } + } +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/ExistingRDD.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/ExistingRDD.scala index d64c5af89e..ed6b95dc6d 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/ExistingRDD.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/ExistingRDD.scala @@ -19,29 +19,32 @@ package org.apache.spark.sql.execution import org.apache.spark.annotation.DeveloperApi import org.apache.spark.rdd.RDD +import org.apache.spark.sql.{DataType, StructType, Row, SQLContext} import org.apache.spark.sql.catalyst.ScalaReflection +import org.apache.spark.sql.catalyst.ScalaReflection.Schema import org.apache.spark.sql.catalyst.analysis.MultiInstanceRelation import org.apache.spark.sql.catalyst.expressions.{Attribute, GenericMutableRow} import org.apache.spark.sql.catalyst.plans.logical.{LogicalPlan, Statistics} -import org.apache.spark.sql.{Row, SQLContext} +import org.apache.spark.sql.catalyst.types.UserDefinedType /** * :: DeveloperApi :: */ @DeveloperApi object RDDConversions { - def productToRowRdd[A <: Product](data: RDD[A]): RDD[Row] = { + def productToRowRdd[A <: Product](data: RDD[A], schema: StructType): RDD[Row] = { data.mapPartitions { iterator => if (iterator.isEmpty) { Iterator.empty } else { val bufferedIterator = iterator.buffered val mutableRow = new GenericMutableRow(bufferedIterator.head.productArity) - + val schemaFields = schema.fields.toArray bufferedIterator.map { r => var i = 0 while (i < mutableRow.length) { - mutableRow(i) = 
ScalaReflection.convertToCatalyst(r.productElement(i)) + mutableRow(i) = + ScalaReflection.convertToCatalyst(r.productElement(i), schemaFields(i).dataType) i += 1 } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlan.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlan.scala index aafcce0572..81c60e0050 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlan.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlan.scala @@ -20,8 +20,6 @@ package org.apache.spark.sql.execution import org.apache.spark.annotation.DeveloperApi import org.apache.spark.Logging import org.apache.spark.rdd.RDD - - import org.apache.spark.sql.SQLContext import org.apache.spark.sql.catalyst.{ScalaReflection, trees} import org.apache.spark.sql.catalyst.analysis.MultiInstanceRelation @@ -82,7 +80,8 @@ abstract class SparkPlan extends QueryPlan[SparkPlan] with Logging with Serializ /** * Runs this query returning the result as an array. */ - def executeCollect(): Array[Row] = execute().map(ScalaReflection.convertRowToScala).collect() + def executeCollect(): Array[Row] = + execute().map(ScalaReflection.convertRowToScala(_, schema)).collect() protected def newProjection( expressions: Seq[Expression], inputSchema: Seq[Attribute]): Projection = { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala index 2cd3063bc3..cc7e0c05ff 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala @@ -280,7 +280,8 @@ private[sql] abstract class SparkStrategies extends QueryPlanner[SparkPlan] { val nPartitions = if (data.isEmpty) 1 else numPartitions PhysicalRDD( output, - RDDConversions.productToRowRdd(sparkContext.parallelize(data, nPartitions))) :: Nil + RDDConversions.productToRowRdd(sparkContext.parallelize(data, nPartitions), + StructType.fromAttributes(output))) :: Nil case logical.Limit(IntegerLiteral(limit), child) => execution.Limit(limit, planLater(child)) :: Nil case Unions(unionChildren) => diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala index e6cd1a9d04..1b8ba3ace2 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala @@ -143,7 +143,7 @@ case class Limit(limit: Int, child: SparkPlan) partsScanned += numPartsToTry } - buf.toArray.map(ScalaReflection.convertRowToScala) + buf.toArray.map(ScalaReflection.convertRowToScala(_, this.schema)) } override def execute() = { @@ -179,8 +179,8 @@ case class TakeOrdered(limit: Int, sortOrder: Seq[SortOrder], child: SparkPlan) val ord = new RowOrdering(sortOrder, child.output) // TODO: Is this copying for no reason? - override def executeCollect() = - child.execute().map(_.copy()).takeOrdered(limit)(ord).map(ScalaReflection.convertRowToScala) + override def executeCollect() = child.execute().map(_.copy()).takeOrdered(limit)(ord) + .map(ScalaReflection.convertRowToScala(_, this.schema)) // TODO: Terminal split should be implemented differently from non-terminal split. // TODO: Pick num splits based on |limit|. 
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetConverter.scala b/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetConverter.scala index 08feced61a..1bbb66aaa1 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetConverter.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetConverter.scala @@ -77,6 +77,9 @@ private[sql] object CatalystConverter { parent: CatalystConverter): Converter = { val fieldType: DataType = field.dataType fieldType match { + case udt: UserDefinedType[_] => { + createConverter(field.copy(dataType = udt.sqlType), fieldIndex, parent) + } // For native JVM types we use a converter with native arrays case ArrayType(elementType: NativeType, false) => { new CatalystNativeArrayConverter(elementType, fieldIndex, parent) @@ -255,8 +258,8 @@ private[parquet] class CatalystGroupConverter( schema, index, parent, - current=null, - buffer=new ArrayBuffer[Row]( + current = null, + buffer = new ArrayBuffer[Row]( CatalystArrayConverter.INITIAL_ARRAY_SIZE)) /** @@ -301,7 +304,7 @@ private[parquet] class CatalystGroupConverter( override def end(): Unit = { if (!isRootConverter) { - assert(current!=null) // there should be no empty groups + assert(current != null) // there should be no empty groups buffer.append(new GenericRow(current.toArray)) parent.updateField(index, new GenericRow(buffer.toArray.asInstanceOf[Array[Any]])) } @@ -358,7 +361,7 @@ private[parquet] class CatalystPrimitiveRowConverter( override def end(): Unit = {} - // Overriden here to avoid auto-boxing for primitive types + // Overridden here to avoid auto-boxing for primitive types override protected[parquet] def updateBoolean(fieldIndex: Int, value: Boolean): Unit = current.setBoolean(fieldIndex, value) @@ -533,7 +536,7 @@ private[parquet] class CatalystNativeArrayConverter( override protected[parquet] def updateField(fieldIndex: Int, value: Any): Unit = throw new UnsupportedOperationException - // Overriden here to avoid auto-boxing for primitive types + // Overridden here to avoid auto-boxing for primitive types override protected[parquet] def updateBoolean(fieldIndex: Int, value: Boolean): Unit = { checkGrowBuffer() buffer(elements) = value.asInstanceOf[NativeType] diff --git a/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetTableSupport.scala b/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetTableSupport.scala index 2a5f23b24e..7bc2496600 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetTableSupport.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetTableSupport.scala @@ -20,7 +20,6 @@ package org.apache.spark.sql.parquet import java.util.{HashMap => JHashMap} import org.apache.hadoop.conf.Configuration -import org.apache.spark.sql.catalyst.types.decimal.Decimal import parquet.column.ParquetProperties import parquet.hadoop.ParquetOutputFormat import parquet.hadoop.api.ReadSupport.ReadContext @@ -31,6 +30,7 @@ import parquet.schema.MessageType import org.apache.spark.Logging import org.apache.spark.sql.catalyst.expressions.{Attribute, Row} import org.apache.spark.sql.catalyst.types._ +import org.apache.spark.sql.catalyst.types.decimal.Decimal /** * A `parquet.io.api.RecordMaterializer` for Rows. 
@@ -174,6 +174,7 @@ private[parquet] class RowWriteSupport extends WriteSupport[Row] with Logging { private[parquet] def writeValue(schema: DataType, value: Any): Unit = { if (value != null) { schema match { + case t: UserDefinedType[_] => writeValue(t.sqlType, value) case t @ ArrayType(_, _) => writeArray( t, value.asInstanceOf[CatalystConverter.ArrayScalaType[_]]) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetTypes.scala b/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetTypes.scala index e5077de8dd..fa37d1f2ae 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetTypes.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetTypes.scala @@ -290,6 +290,9 @@ private[parquet] object ParquetTypesConverter extends Logging { builder.named(name) }.getOrElse { ctype match { + case udt: UserDefinedType[_] => { + fromDataType(udt.sqlType, name, nullable, inArray) + } case ArrayType(elementType, false) => { val parquetElementType = fromDataType( elementType, diff --git a/sql/core/src/main/scala/org/apache/spark/sql/types/util/DataTypeConversions.scala b/sql/core/src/main/scala/org/apache/spark/sql/types/util/DataTypeConversions.scala index 7564bf3923..1bc15146f0 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/types/util/DataTypeConversions.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/types/util/DataTypeConversions.scala @@ -17,12 +17,16 @@ package org.apache.spark.sql.types.util +import scala.collection.JavaConverters._ + import org.apache.spark.sql._ -import org.apache.spark.sql.api.java.{DataType => JDataType, StructField => JStructField, MetadataBuilder => JMetaDataBuilder} +import org.apache.spark.sql.api.java.{DataType => JDataType, StructField => JStructField, + MetadataBuilder => JMetaDataBuilder, UDTWrappers, JavaToScalaUDTWrapper} import org.apache.spark.sql.api.java.{DecimalType => JDecimalType} import org.apache.spark.sql.catalyst.types.decimal.Decimal +import org.apache.spark.sql.catalyst.ScalaReflection +import org.apache.spark.sql.catalyst.types.UserDefinedType -import scala.collection.JavaConverters._ protected[sql] object DataTypeConversions { @@ -41,6 +45,9 @@ protected[sql] object DataTypeConversions { * Returns the equivalent DataType in Java for the given DataType in Scala. */ def asJavaDataType(scalaDataType: DataType): JDataType = scalaDataType match { + case udtType: UserDefinedType[_] => + UDTWrappers.wrapAsJava(udtType) + case StringType => JDataType.StringType case BinaryType => JDataType.BinaryType case BooleanType => JDataType.BooleanType @@ -80,6 +87,9 @@ protected[sql] object DataTypeConversions { * Returns the equivalent DataType in Scala for the given DataType in Java. 
*/ def asScalaDataType(javaDataType: JDataType): DataType = javaDataType match { + case udtType: org.apache.spark.sql.api.java.UserDefinedType[_] => + UDTWrappers.wrapAsScala(udtType) + case stringType: org.apache.spark.sql.api.java.StringType => StringType case binaryType: org.apache.spark.sql.api.java.BinaryType => @@ -121,9 +131,11 @@ protected[sql] object DataTypeConversions { } /** Converts Java objects to catalyst rows / types */ - def convertJavaToCatalyst(a: Any): Any = a match { - case d: java.math.BigDecimal => Decimal(BigDecimal(d)) - case other => other + def convertJavaToCatalyst(a: Any, dataType: DataType): Any = (a, dataType) match { + case (obj, udt: UserDefinedType[_]) => ScalaReflection.convertToCatalyst(obj, udt) // Scala type + case (d: java.math.BigDecimal, _) => Decimal(BigDecimal(d)) + case (d: java.math.BigDecimal, _) => BigDecimal(d) + case (other, _) => other } /** Converts Java objects to catalyst rows / types */ diff --git a/sql/core/src/test/java/org/apache/spark/sql/api/java/JavaUserDefinedTypeSuite.java b/sql/core/src/test/java/org/apache/spark/sql/api/java/JavaUserDefinedTypeSuite.java new file mode 100644 index 0000000000..0caa8219a6 --- /dev/null +++ b/sql/core/src/test/java/org/apache/spark/sql/api/java/JavaUserDefinedTypeSuite.java @@ -0,0 +1,88 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.sql.api.java; + +import java.io.Serializable; +import java.util.*; + +import org.junit.After; +import org.junit.Assert; +import org.junit.Before; +import org.junit.Test; + +import org.apache.spark.api.java.JavaRDD; +import org.apache.spark.api.java.JavaSparkContext; +import org.apache.spark.sql.MyDenseVector; +import org.apache.spark.sql.MyLabeledPoint; + +public class JavaUserDefinedTypeSuite implements Serializable { + private transient JavaSparkContext javaCtx; + private transient JavaSQLContext javaSqlCtx; + + @Before + public void setUp() { + javaCtx = new JavaSparkContext("local", "JavaUserDefinedTypeSuite"); + javaSqlCtx = new JavaSQLContext(javaCtx); + } + + @After + public void tearDown() { + javaCtx.stop(); + javaCtx = null; + javaSqlCtx = null; + } + + @Test + public void useScalaUDT() { + List<MyLabeledPoint> points = Arrays.asList( + new MyLabeledPoint(1.0, new MyDenseVector(new double[]{0.1, 1.0})), + new MyLabeledPoint(0.0, new MyDenseVector(new double[]{0.2, 2.0}))); + JavaRDD<MyLabeledPoint> pointsRDD = javaCtx.parallelize(points); + + JavaSchemaRDD schemaRDD = javaSqlCtx.applySchema(pointsRDD, MyLabeledPoint.class); + schemaRDD.registerTempTable("points"); + + List<Row> actualLabelRows = javaSqlCtx.sql("SELECT label FROM points").collect(); + List<Double> actualLabels = new LinkedList<Double>(); + for (Row r : actualLabelRows) { + actualLabels.add(r.getDouble(0)); + } + for (MyLabeledPoint lp : points) { + Assert.assertTrue(actualLabels.contains(lp.label())); + } + + List<Row> actualFeatureRows = javaSqlCtx.sql("SELECT features FROM points").collect(); + List<MyDenseVector> actualFeatures = new LinkedList<MyDenseVector>(); + for (Row r : actualFeatureRows) { + actualFeatures.add((MyDenseVector)r.get(0)); + } + for (MyLabeledPoint lp : points) { + Assert.assertTrue(actualFeatures.contains(lp.features())); + } + + List<Row> actual = javaSqlCtx.sql("SELECT label, features FROM points").collect(); + List<MyLabeledPoint> actualPoints = + new LinkedList<MyLabeledPoint>(); + for (Row r : actual) { + actualPoints.add(new MyLabeledPoint(r.getDouble(0), (MyDenseVector)r.get(1))); + } + for (MyLabeledPoint lp : points) { + Assert.assertTrue(actualPoints.contains(lp)); + } + } +} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/UserDefinedTypeSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/UserDefinedTypeSuite.scala new file mode 100644 index 0000000000..666235e57f --- /dev/null +++ b/sql/core/src/test/scala/org/apache/spark/sql/UserDefinedTypeSuite.scala @@ -0,0 +1,83 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License.
+ */ + +package org.apache.spark.sql + +import scala.beans.{BeanInfo, BeanProperty} + +import org.apache.spark.rdd.RDD +import org.apache.spark.sql.catalyst.annotation.SQLUserDefinedType +import org.apache.spark.sql.catalyst.types.UserDefinedType +import org.apache.spark.sql.test.TestSQLContext._ + +@SQLUserDefinedType(udt = classOf[MyDenseVectorUDT]) +private[sql] class MyDenseVector(val data: Array[Double]) extends Serializable { + override def equals(other: Any): Boolean = other match { + case v: MyDenseVector => + java.util.Arrays.equals(this.data, v.data) + case _ => false + } +} + +@BeanInfo +private[sql] case class MyLabeledPoint( + @BeanProperty label: Double, + @BeanProperty features: MyDenseVector) + +private[sql] class MyDenseVectorUDT extends UserDefinedType[MyDenseVector] { + + override def sqlType: DataType = ArrayType(DoubleType, containsNull = false) + + override def serialize(obj: Any): Seq[Double] = { + obj match { + case features: MyDenseVector => + features.data.toSeq + } + } + + override def deserialize(datum: Any): MyDenseVector = { + datum match { + case data: Seq[_] => + new MyDenseVector(data.asInstanceOf[Seq[Double]].toArray) + } + } + + override def userClass = classOf[MyDenseVector] +} + +class UserDefinedTypeSuite extends QueryTest { + + test("register user type: MyDenseVector for MyLabeledPoint") { + val points = Seq( + MyLabeledPoint(1.0, new MyDenseVector(Array(0.1, 1.0))), + MyLabeledPoint(0.0, new MyDenseVector(Array(0.2, 2.0)))) + val pointsRDD: RDD[MyLabeledPoint] = sparkContext.parallelize(points) + + val labels: RDD[Double] = pointsRDD.select('label).map { case Row(v: Double) => v } + val labelsArrays: Array[Double] = labels.collect() + assert(labelsArrays.size === 2) + assert(labelsArrays.contains(1.0)) + assert(labelsArrays.contains(0.0)) + + val features: RDD[MyDenseVector] = + pointsRDD.select('features).map { case Row(v: MyDenseVector) => v } + val featuresArrays: Array[MyDenseVector] = features.collect() + assert(featuresArrays.size === 2) + assert(featuresArrays.contains(new MyDenseVector(Array(0.1, 1.0)))) + assert(featuresArrays.contains(new MyDenseVector(Array(0.2, 2.0)))) + } +} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/json/JsonSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/json/JsonSuite.scala index 4b851d1b96..cade244f7a 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/json/JsonSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/json/JsonSuite.scala @@ -21,8 +21,7 @@ import org.apache.spark.sql.catalyst.types._ import org.apache.spark.sql.catalyst.types.decimal.Decimal import org.apache.spark.sql.catalyst.util._ import org.apache.spark.sql.json.JsonRDD.{enforceCorrectType, compatibleType} -import org.apache.spark.sql.QueryTest -import org.apache.spark.sql.SQLConf +import org.apache.spark.sql.{Row, SQLConf, QueryTest} import org.apache.spark.sql.test.TestSQLContext import org.apache.spark.sql.test.TestSQLContext._ @@ -233,8 +232,8 @@ class JsonSuite extends QueryTest { StructField("field2", StringType, true) :: StructField("field3", StringType, true) :: Nil), false), true) :: StructField("struct", StructType( - StructField("field1", BooleanType, true) :: - StructField("field2", DecimalType.Unlimited, true) :: Nil), true) :: + StructField("field1", BooleanType, true) :: + StructField("field2", DecimalType.Unlimited, true) :: Nil), true) :: StructField("structWithArrayFields", StructType( StructField("field1", ArrayType(IntegerType, false), true) :: StructField("field2", ArrayType(StringType, 
false), true) :: Nil), true) :: Nil) @@ -292,8 +291,8 @@ class JsonSuite extends QueryTest { // Access a struct and fields inside of it. checkAnswer( sql("select struct, struct.field1, struct.field2 from jsonTable"), - ( - Seq(true, BigDecimal("92233720368547758070")), + Row( + Row(true, BigDecimal("92233720368547758070")), true, BigDecimal("92233720368547758070")) :: Nil ) diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveContext.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveContext.scala index 0fe59f42f2..f025169ad5 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveContext.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveContext.scala @@ -374,8 +374,6 @@ class HiveContext(sc: SparkContext) extends SQLContext(sc) { /** Extends QueryExecution with hive specific features. */ protected[sql] abstract class QueryExecution extends super.QueryExecution { - override lazy val toRdd: RDD[Row] = executedPlan.execute().map(_.copy()) - protected val primitiveTypes = Seq(StringType, IntegerType, LongType, DoubleType, FloatType, BooleanType, ByteType, ShortType, DateType, TimestampType, BinaryType) @@ -433,7 +431,7 @@ class HiveContext(sc: SparkContext) extends SQLContext(sc) { command.executeCollect().map(_.head.toString) case other => - val result: Seq[Seq[Any]] = toRdd.collect().toSeq + val result: Seq[Seq[Any]] = toRdd.map(_.copy()).collect().toSeq // We need the types so we can output struct field names val types = analyzed.output.map(_.dataType) // Reformat to match hive tab delimited output. -- cgit v1.2.3
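
For reference, the pattern the new (still private[spark]) UDT API expects is the one exercised by UserDefinedTypeSuite above: annotate the user class with @SQLUserDefinedType and supply a UserDefinedType subclass that maps it onto an existing Catalyst sqlType. The sketch below is illustrative only; Point2D and Point2DUDT are hypothetical names that are not part of this patch, and the mechanics simply mirror MyDenseVector / MyDenseVectorUDT from the test suite.

    import org.apache.spark.sql.catalyst.annotation.SQLUserDefinedType
    import org.apache.spark.sql.catalyst.types._

    // Hypothetical user class backed by a UDT (illustration only).
    @SQLUserDefinedType(udt = classOf[Point2DUDT])
    class Point2D(val x: Double, val y: Double) extends Serializable

    // Maps Point2D onto an existing Catalyst type: a non-null array of doubles,
    // the same representation MyDenseVectorUDT uses in UserDefinedTypeSuite.
    class Point2DUDT extends UserDefinedType[Point2D] {

      override def sqlType: DataType = ArrayType(DoubleType, containsNull = false)

      // Convert the user object into its sqlType representation.
      override def serialize(obj: Any): Seq[Double] = obj match {
        case p: Point2D => Seq(p.x, p.y)
      }

      // Convert the sqlType representation back into the user object.
      override def deserialize(datum: Any): Point2D = datum match {
        case data: Seq[_] =>
          val Seq(x, y) = data.asInstanceOf[Seq[Double]]
          new Point2D(x, y)
      }

      override def userClass = classOf[Point2D]
    }

With that in place, an RDD of case-class rows containing Point2D fields can be turned into a SchemaRDD and queried just as in UserDefinedTypeSuite; the Parquet and Java-API hunks above apply the same sqlType indirection on the write path and when wrapping DataTypes across the Scala/Java boundary.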