author    Wenchen Fan <wenchen@databricks.com>    2015-11-13 11:25:33 -0800
committer Michael Armbrust <michael@databricks.com>    2015-11-13 11:25:33 -0800
commit    d7b2b97ad67f9700fb8c13422c399f2edb72f770 (patch)
tree      8f23b645a0e50ec19feeee1bf0bbfb56c83fc12d
parent    23b8188f75d945ef70fbb1c4dc9720c2c5f8cbc3 (diff)
[SPARK-11727][SQL] Split ExpressionEncoder into FlatEncoder and ProductEncoder
Also adds more tests for encoders, and fixes bugs found along the way:

* When converting an array to a Catalyst array, we can only skip element
  conversion for native types (e.g. int, long, boolean), not for `AtomicType`
  (String is an AtomicType, but it still needs conversion).
* We should also handle Scala `BigDecimal` when converting from Catalyst `Decimal`.
* Complex map types should be supported.

Other issues still under investigation:

* Encoding a java `BigDecimal` and decoding it back seems to lose precision info.
* When encoding a case class defined inside an object, a `ClassNotFound`
  exception is thrown.

Unused code will be removed in a follow-up PR.

Author: Wenchen Fan <wenchen@databricks.com>

Closes #9693 from cloud-fan/split.
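For orientation, here is a minimal sketch (not part of the patch; it simply exercises the APIs introduced in the diff below) of how the two new encoders divide the work:

import org.apache.spark.sql.catalyst.encoders.{FlatEncoder, ProductEncoder}

case class Point(x: Int, y: Int)

// FlatEncoder covers single-value types; it wraps the value in a
// one-column struct named "value" (see FlatEncoder.scala below).
val intEncoder = FlatEncoder[Int]

// ProductEncoder covers case classes; each constructor parameter becomes a column.
val pointEncoder = ProductEncoder[Point]

// Round-trip, following the pattern used by the test suites in this patch:
val row = pointEncoder.toRow(Point(1, 2))
val attrs = pointEncoder.schema.toAttributes
val back = pointEncoder.resolve(attrs).bind(attrs).fromRow(row)  // Point(1,2)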
-rw-r--r--  sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala                    2
-rw-r--r--  sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/FlatEncoder.scala              50
-rw-r--r--  sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/ProductEncoder.scala          452
-rw-r--r--  sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/RowEncoder.scala               58
-rw-r--r--  sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/DateTimeUtils.scala                 2
-rw-r--r--  sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/GenericArrayData.scala              2
-rw-r--r--  sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoderSuite.scala  259
-rw-r--r--  sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/FlatEncoderSuite.scala         74
-rw-r--r--  sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/ProductEncoderSuite.scala     123
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/GroupedDataset.scala                                  7
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/SQLImplicits.scala                                   22
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/functions.scala                                       4
12 files changed, 766 insertions(+), 289 deletions(-)
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala
index 6d822261b0..0b3dd351e3 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala
@@ -75,7 +75,7 @@ trait ScalaReflection {
*
* @see SPARK-5281
*/
- private def localTypeOf[T: TypeTag]: `Type` = typeTag[T].in(mirror).tpe
+ def localTypeOf[T: TypeTag]: `Type` = typeTag[T].in(mirror).tpe
/**
* Returns the Spark SQL DataType for a given scala type. Where this is not an exact mapping
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/FlatEncoder.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/FlatEncoder.scala
new file mode 100644
index 0000000000..6d307ab13a
--- /dev/null
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/FlatEncoder.scala
@@ -0,0 +1,50 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.catalyst.encoders
+
+import scala.reflect.ClassTag
+import scala.reflect.runtime.universe.{typeTag, TypeTag}
+
+import org.apache.spark.sql.types.StructType
+import org.apache.spark.sql.catalyst.expressions.{Literal, CreateNamedStruct, BoundReference}
+import org.apache.spark.sql.catalyst.ScalaReflection
+
+object FlatEncoder {
+ import ScalaReflection.schemaFor
+ import ScalaReflection.dataTypeFor
+
+ def apply[T : TypeTag]: ExpressionEncoder[T] = {
+ // We convert the non-serializable TypeTag into StructType and ClassTag.
+ val tpe = typeTag[T].tpe
+ val mirror = typeTag[T].mirror
+ val cls = mirror.runtimeClass(tpe)
+ assert(!schemaFor(tpe).dataType.isInstanceOf[StructType])
+
+ val input = BoundReference(0, dataTypeFor(tpe), nullable = true)
+ val toRowExpression = CreateNamedStruct(
+ Literal("value") :: ProductEncoder.extractorFor(input, tpe) :: Nil)
+ val fromRowExpression = ProductEncoder.constructorFor(tpe)
+
+ new ExpressionEncoder[T](
+ toRowExpression.dataType,
+ flat = true,
+ toRowExpression.flatten,
+ fromRowExpression,
+ ClassTag[T](cls))
+ }
+}
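A brief usage note, hedged: because of the StructType assert above, FlatEncoder is restricted to non-struct schemas — primitives, strings, and collections — while case classes must go through ProductEncoder. For example:

val seqEncoder = FlatEncoder[Seq[Int]]          // ok: array<int> is not a struct
val mapEncoder = FlatEncoder[Map[Int, String]]  // ok: map<int,string> is not a struct
// FlatEncoder[Point] for a case class Point would trip the assert above,
// since schemaFor(Point) yields a StructType.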
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/ProductEncoder.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/ProductEncoder.scala
new file mode 100644
index 0000000000..414adb2116
--- /dev/null
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/ProductEncoder.scala
@@ -0,0 +1,452 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.catalyst.encoders
+
+import org.apache.spark.util.Utils
+import org.apache.spark.unsafe.types.UTF8String
+import org.apache.spark.sql.types._
+import org.apache.spark.sql.catalyst.ScalaReflectionLock
+import org.apache.spark.sql.catalyst.ScalaReflection
+import org.apache.spark.sql.catalyst.analysis.{UnresolvedAttribute, UnresolvedExtractValue}
+import org.apache.spark.sql.catalyst.expressions
+import org.apache.spark.sql.catalyst.expressions._
+import org.apache.spark.sql.catalyst.util.{DateTimeUtils, ArrayBasedMapData, GenericArrayData}
+
+import scala.reflect.ClassTag
+
+object ProductEncoder {
+ import ScalaReflection.universe._
+ import ScalaReflection.localTypeOf
+ import ScalaReflection.dataTypeFor
+ import ScalaReflection.Schema
+ import ScalaReflection.schemaFor
+ import ScalaReflection.arrayClassFor
+
+ def apply[T <: Product : TypeTag]: ExpressionEncoder[T] = {
+ // We convert the non-serializable TypeTag into StructType and ClassTag.
+ val tpe = typeTag[T].tpe
+ val mirror = typeTag[T].mirror
+ val cls = mirror.runtimeClass(tpe)
+
+ val inputObject = BoundReference(0, ObjectType(cls), nullable = true)
+ val toRowExpression = extractorFor(inputObject, tpe).asInstanceOf[CreateNamedStruct]
+ val fromRowExpression = constructorFor(tpe)
+
+ new ExpressionEncoder[T](
+ toRowExpression.dataType,
+ flat = false,
+ toRowExpression.flatten,
+ fromRowExpression,
+ ClassTag[T](cls))
+ }
+
+ // The Predef.Map is scala.collection.immutable.Map.
+ // Since the map values can be mutable, we explicitly import scala.collection.Map here.
+ import scala.collection.Map
+
+ def extractorFor(
+ inputObject: Expression,
+ tpe: `Type`): Expression = ScalaReflectionLock.synchronized {
+ if (!inputObject.dataType.isInstanceOf[ObjectType]) {
+ inputObject
+ } else {
+ tpe match {
+ case t if t <:< localTypeOf[Option[_]] =>
+ val TypeRef(_, _, Seq(optType)) = t
+ optType match {
+ // For primitive types we must manually unbox the value of the object.
+ case t if t <:< definitions.IntTpe =>
+ Invoke(
+ UnwrapOption(ObjectType(classOf[java.lang.Integer]), inputObject),
+ "intValue",
+ IntegerType)
+ case t if t <:< definitions.LongTpe =>
+ Invoke(
+ UnwrapOption(ObjectType(classOf[java.lang.Long]), inputObject),
+ "longValue",
+ LongType)
+ case t if t <:< definitions.DoubleTpe =>
+ Invoke(
+ UnwrapOption(ObjectType(classOf[java.lang.Double]), inputObject),
+ "doubleValue",
+ DoubleType)
+ case t if t <:< definitions.FloatTpe =>
+ Invoke(
+ UnwrapOption(ObjectType(classOf[java.lang.Float]), inputObject),
+ "floatValue",
+ FloatType)
+ case t if t <:< definitions.ShortTpe =>
+ Invoke(
+ UnwrapOption(ObjectType(classOf[java.lang.Short]), inputObject),
+ "shortValue",
+ ShortType)
+ case t if t <:< definitions.ByteTpe =>
+ Invoke(
+ UnwrapOption(ObjectType(classOf[java.lang.Byte]), inputObject),
+ "byteValue",
+ ByteType)
+ case t if t <:< definitions.BooleanTpe =>
+ Invoke(
+ UnwrapOption(ObjectType(classOf[java.lang.Boolean]), inputObject),
+ "booleanValue",
+ BooleanType)
+
+ // For non-primitives, we can just extract the object from the Option and then recurse.
+ case other =>
+ val className: String = optType.erasure.typeSymbol.asClass.fullName
+ val classObj = Utils.classForName(className)
+ val optionObjectType = ObjectType(classObj)
+
+ val unwrapped = UnwrapOption(optionObjectType, inputObject)
+ expressions.If(
+ IsNull(unwrapped),
+ expressions.Literal.create(null, schemaFor(optType).dataType),
+ extractorFor(unwrapped, optType))
+ }
+
+ case t if t <:< localTypeOf[Product] =>
+ val formalTypeArgs = t.typeSymbol.asClass.typeParams
+ val TypeRef(_, _, actualTypeArgs) = t
+ val constructorSymbol = t.member(nme.CONSTRUCTOR)
+ val params = if (constructorSymbol.isMethod) {
+ constructorSymbol.asMethod.paramss
+ } else {
+ // Find the primary constructor, and use its parameter ordering.
+ val primaryConstructorSymbol: Option[Symbol] =
+ constructorSymbol.asTerm.alternatives.find(s =>
+ s.isMethod && s.asMethod.isPrimaryConstructor)
+
+ if (primaryConstructorSymbol.isEmpty) {
+ sys.error("Internal SQL error: Product object did not have a primary constructor.")
+ } else {
+ primaryConstructorSymbol.get.asMethod.paramss
+ }
+ }
+
+ CreateNamedStruct(params.head.flatMap { p =>
+ val fieldName = p.name.toString
+ val fieldType = p.typeSignature.substituteTypes(formalTypeArgs, actualTypeArgs)
+ val fieldValue = Invoke(inputObject, fieldName, dataTypeFor(fieldType))
+ expressions.Literal(fieldName) :: extractorFor(fieldValue, fieldType) :: Nil
+ })
+
+ case t if t <:< localTypeOf[Array[_]] =>
+ val TypeRef(_, _, Seq(elementType)) = t
+ toCatalystArray(inputObject, elementType)
+
+ case t if t <:< localTypeOf[Seq[_]] =>
+ val TypeRef(_, _, Seq(elementType)) = t
+ toCatalystArray(inputObject, elementType)
+
+ case t if t <:< localTypeOf[Map[_, _]] =>
+ val TypeRef(_, _, Seq(keyType, valueType)) = t
+
+ val keys =
+ Invoke(
+ Invoke(inputObject, "keysIterator",
+ ObjectType(classOf[scala.collection.Iterator[_]])),
+ "toSeq",
+ ObjectType(classOf[scala.collection.Seq[_]]))
+ val convertedKeys = toCatalystArray(keys, keyType)
+
+ val values =
+ Invoke(
+ Invoke(inputObject, "valuesIterator",
+ ObjectType(classOf[scala.collection.Iterator[_]])),
+ "toSeq",
+ ObjectType(classOf[scala.collection.Seq[_]]))
+ val convertedValues = toCatalystArray(values, valueType)
+
+ val Schema(keyDataType, _) = schemaFor(keyType)
+ val Schema(valueDataType, valueNullable) = schemaFor(valueType)
+ NewInstance(
+ classOf[ArrayBasedMapData],
+ convertedKeys :: convertedValues :: Nil,
+ dataType = MapType(keyDataType, valueDataType, valueNullable))
+
+ case t if t <:< localTypeOf[String] =>
+ StaticInvoke(
+ classOf[UTF8String],
+ StringType,
+ "fromString",
+ inputObject :: Nil)
+
+ case t if t <:< localTypeOf[java.sql.Timestamp] =>
+ StaticInvoke(
+ DateTimeUtils,
+ TimestampType,
+ "fromJavaTimestamp",
+ inputObject :: Nil)
+
+ case t if t <:< localTypeOf[java.sql.Date] =>
+ StaticInvoke(
+ DateTimeUtils,
+ DateType,
+ "fromJavaDate",
+ inputObject :: Nil)
+
+ case t if t <:< localTypeOf[BigDecimal] =>
+ StaticInvoke(
+ Decimal,
+ DecimalType.SYSTEM_DEFAULT,
+ "apply",
+ inputObject :: Nil)
+
+ case t if t <:< localTypeOf[java.math.BigDecimal] =>
+ StaticInvoke(
+ Decimal,
+ DecimalType.SYSTEM_DEFAULT,
+ "apply",
+ inputObject :: Nil)
+
+ case t if t <:< localTypeOf[java.lang.Integer] =>
+ Invoke(inputObject, "intValue", IntegerType)
+ case t if t <:< localTypeOf[java.lang.Long] =>
+ Invoke(inputObject, "longValue", LongType)
+ case t if t <:< localTypeOf[java.lang.Double] =>
+ Invoke(inputObject, "doubleValue", DoubleType)
+ case t if t <:< localTypeOf[java.lang.Float] =>
+ Invoke(inputObject, "floatValue", FloatType)
+ case t if t <:< localTypeOf[java.lang.Short] =>
+ Invoke(inputObject, "shortValue", ShortType)
+ case t if t <:< localTypeOf[java.lang.Byte] =>
+ Invoke(inputObject, "byteValue", ByteType)
+ case t if t <:< localTypeOf[java.lang.Boolean] =>
+ Invoke(inputObject, "booleanValue", BooleanType)
+
+ case other =>
+ throw new UnsupportedOperationException(s"Extractor for type $other is not supported")
+ }
+ }
+ }
+
+ private def toCatalystArray(input: Expression, elementType: `Type`): Expression = {
+ val externalDataType = dataTypeFor(elementType)
+ val Schema(catalystType, nullable) = schemaFor(elementType)
+ if (RowEncoder.isNativeType(catalystType)) {
+ NewInstance(
+ classOf[GenericArrayData],
+ input :: Nil,
+ dataType = ArrayType(catalystType, nullable))
+ } else {
+ MapObjects(extractorFor(_, elementType), input, externalDataType)
+ }
+ }
+
+ def constructorFor(
+ tpe: `Type`,
+ path: Option[Expression] = None): Expression = ScalaReflectionLock.synchronized {
+
+ /** Returns the current path with a sub-field extracted. */
+ def addToPath(part: String): Expression = path
+ .map(p => UnresolvedExtractValue(p, expressions.Literal(part)))
+ .getOrElse(UnresolvedAttribute(part))
+
+ /** Returns the current path with a field at ordinal extracted. */
+ def addToPathOrdinal(ordinal: Int, dataType: DataType): Expression = path
+ .map(p => GetInternalRowField(p, ordinal, dataType))
+ .getOrElse(BoundReference(ordinal, dataType, false))
+
+ /** Returns the current path or `BoundReference`. */
+ def getPath: Expression = path.getOrElse(BoundReference(0, schemaFor(tpe).dataType, true))
+
+ tpe match {
+ case t if !dataTypeFor(t).isInstanceOf[ObjectType] => getPath
+
+ case t if t <:< localTypeOf[Option[_]] =>
+ val TypeRef(_, _, Seq(optType)) = t
+ WrapOption(null, constructorFor(optType, path))
+
+ case t if t <:< localTypeOf[java.lang.Integer] =>
+ val boxedType = classOf[java.lang.Integer]
+ val objectType = ObjectType(boxedType)
+ NewInstance(boxedType, getPath :: Nil, propagateNull = true, objectType)
+
+ case t if t <:< localTypeOf[java.lang.Long] =>
+ val boxedType = classOf[java.lang.Long]
+ val objectType = ObjectType(boxedType)
+ NewInstance(boxedType, getPath :: Nil, propagateNull = true, objectType)
+
+ case t if t <:< localTypeOf[java.lang.Double] =>
+ val boxedType = classOf[java.lang.Double]
+ val objectType = ObjectType(boxedType)
+ NewInstance(boxedType, getPath :: Nil, propagateNull = true, objectType)
+
+ case t if t <:< localTypeOf[java.lang.Float] =>
+ val boxedType = classOf[java.lang.Float]
+ val objectType = ObjectType(boxedType)
+ NewInstance(boxedType, getPath :: Nil, propagateNull = true, objectType)
+
+ case t if t <:< localTypeOf[java.lang.Short] =>
+ val boxedType = classOf[java.lang.Short]
+ val objectType = ObjectType(boxedType)
+ NewInstance(boxedType, getPath :: Nil, propagateNull = true, objectType)
+
+ case t if t <:< localTypeOf[java.lang.Byte] =>
+ val boxedType = classOf[java.lang.Byte]
+ val objectType = ObjectType(boxedType)
+ NewInstance(boxedType, getPath :: Nil, propagateNull = true, objectType)
+
+ case t if t <:< localTypeOf[java.lang.Boolean] =>
+ val boxedType = classOf[java.lang.Boolean]
+ val objectType = ObjectType(boxedType)
+ NewInstance(boxedType, getPath :: Nil, propagateNull = true, objectType)
+
+ case t if t <:< localTypeOf[java.sql.Date] =>
+ StaticInvoke(
+ DateTimeUtils,
+ ObjectType(classOf[java.sql.Date]),
+ "toJavaDate",
+ getPath :: Nil,
+ propagateNull = true)
+
+ case t if t <:< localTypeOf[java.sql.Timestamp] =>
+ StaticInvoke(
+ DateTimeUtils,
+ ObjectType(classOf[java.sql.Timestamp]),
+ "toJavaTimestamp",
+ getPath :: Nil,
+ propagateNull = true)
+
+ case t if t <:< localTypeOf[java.lang.String] =>
+ Invoke(getPath, "toString", ObjectType(classOf[String]))
+
+ case t if t <:< localTypeOf[java.math.BigDecimal] =>
+ Invoke(getPath, "toJavaBigDecimal", ObjectType(classOf[java.math.BigDecimal]))
+
+ case t if t <:< localTypeOf[BigDecimal] =>
+ Invoke(getPath, "toBigDecimal", ObjectType(classOf[BigDecimal]))
+
+ case t if t <:< localTypeOf[Array[_]] =>
+ val TypeRef(_, _, Seq(elementType)) = t
+ val primitiveMethod = elementType match {
+ case t if t <:< definitions.IntTpe => Some("toIntArray")
+ case t if t <:< definitions.LongTpe => Some("toLongArray")
+ case t if t <:< definitions.DoubleTpe => Some("toDoubleArray")
+ case t if t <:< definitions.FloatTpe => Some("toFloatArray")
+ case t if t <:< definitions.ShortTpe => Some("toShortArray")
+ case t if t <:< definitions.ByteTpe => Some("toByteArray")
+ case t if t <:< definitions.BooleanTpe => Some("toBooleanArray")
+ case _ => None
+ }
+
+ primitiveMethod.map { method =>
+ Invoke(getPath, method, arrayClassFor(elementType))
+ }.getOrElse {
+ Invoke(
+ MapObjects(
+ p => constructorFor(elementType, Some(p)),
+ getPath,
+ schemaFor(elementType).dataType),
+ "array",
+ arrayClassFor(elementType))
+ }
+
+ case t if t <:< localTypeOf[Seq[_]] =>
+ val TypeRef(_, _, Seq(elementType)) = t
+ val arrayData =
+ Invoke(
+ MapObjects(
+ p => constructorFor(elementType, Some(p)),
+ getPath,
+ schemaFor(elementType).dataType),
+ "array",
+ ObjectType(classOf[Array[Any]]))
+
+ StaticInvoke(
+ scala.collection.mutable.WrappedArray,
+ ObjectType(classOf[Seq[_]]),
+ "make",
+ arrayData :: Nil)
+
+ case t if t <:< localTypeOf[Map[_, _]] =>
+ val TypeRef(_, _, Seq(keyType, valueType)) = t
+
+ val keyData =
+ Invoke(
+ MapObjects(
+ p => constructorFor(keyType, Some(p)),
+ Invoke(getPath, "keyArray", ArrayType(schemaFor(keyType).dataType)),
+ schemaFor(keyType).dataType),
+ "array",
+ ObjectType(classOf[Array[Any]]))
+
+ val valueData =
+ Invoke(
+ MapObjects(
+ p => constructorFor(valueType, Some(p)),
+ Invoke(getPath, "valueArray", ArrayType(schemaFor(valueType).dataType)),
+ schemaFor(valueType).dataType),
+ "array",
+ ObjectType(classOf[Array[Any]]))
+
+ StaticInvoke(
+ ArrayBasedMapData,
+ ObjectType(classOf[Map[_, _]]),
+ "toScalaMap",
+ keyData :: valueData :: Nil)
+
+ case t if t <:< localTypeOf[Product] =>
+ val formalTypeArgs = t.typeSymbol.asClass.typeParams
+ val TypeRef(_, _, actualTypeArgs) = t
+ val constructorSymbol = t.member(nme.CONSTRUCTOR)
+ val params = if (constructorSymbol.isMethod) {
+ constructorSymbol.asMethod.paramss
+ } else {
+ // Find the primary constructor, and use its parameter ordering.
+ val primaryConstructorSymbol: Option[Symbol] =
+ constructorSymbol.asTerm.alternatives.find(s =>
+ s.isMethod && s.asMethod.isPrimaryConstructor)
+
+ if (primaryConstructorSymbol.isEmpty) {
+ sys.error("Internal SQL error: Product object did not have a primary constructor.")
+ } else {
+ primaryConstructorSymbol.get.asMethod.paramss
+ }
+ }
+
+ val className: String = t.erasure.typeSymbol.asClass.fullName
+ val cls = Utils.classForName(className)
+
+ val arguments = params.head.zipWithIndex.map { case (p, i) =>
+ val fieldName = p.name.toString
+ val fieldType = p.typeSignature.substituteTypes(formalTypeArgs, actualTypeArgs)
+ val dataType = schemaFor(fieldType).dataType
+
+ // For tuples, we grab the inner fields by ordinal instead of by name.
+ if (className startsWith "scala.Tuple") {
+ constructorFor(fieldType, Some(addToPathOrdinal(i, dataType)))
+ } else {
+ constructorFor(fieldType, Some(addToPath(fieldName)))
+ }
+ }
+
+ val newInstance = NewInstance(cls, arguments, propagateNull = false, ObjectType(cls))
+
+ if (path.nonEmpty) {
+ expressions.If(
+ IsNull(getPath),
+ expressions.Literal.create(null, ObjectType(cls)),
+ newInstance
+ )
+ } else {
+ newInstance
+ }
+ }
+ }
+}
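As a sketch of the cases handled above: Options of primitives are manually unboxed, nested products recurse into another CreateNamedStruct, and collections go through toCatalystArray/MapObjects. Assuming hypothetical case classes:

case class Inner(n: Int)
case class Outer(id: Option[Int], inner: Inner, tags: Seq[String])

// extractorFor unboxes id via Integer.intValue, recurses into Inner,
// and converts tags element by element (String is not a native type).
val encoder = ProductEncoder[Outer]
val row = encoder.toRow(Outer(Some(1), Inner(2), Seq("a", "b")))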
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/RowEncoder.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/RowEncoder.scala
index 0b42130a01..e0be896bb3 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/RowEncoder.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/RowEncoder.scala
@@ -119,9 +119,17 @@ object RowEncoder {
CreateStruct(convertedFields)
}
- private def externalDataTypeFor(dt: DataType): DataType = dt match {
+ /**
+ * Returns true if the value of this data type is the same between the internal and external representations.
+ */
+ def isNativeType(dt: DataType): Boolean = dt match {
case BooleanType | ByteType | ShortType | IntegerType | LongType |
- FloatType | DoubleType | BinaryType => dt
+ FloatType | DoubleType | BinaryType => true
+ case _ => false
+ }
+
+ private def externalDataTypeFor(dt: DataType): DataType = dt match {
+ case _ if isNativeType(dt) => dt
case TimestampType => ObjectType(classOf[java.sql.Timestamp])
case DateType => ObjectType(classOf[java.sql.Date])
case _: DecimalType => ObjectType(classOf[java.math.BigDecimal])
@@ -137,13 +145,13 @@ object RowEncoder {
If(
IsNull(field),
Literal.create(null, externalDataTypeFor(f.dataType)),
- constructorFor(BoundReference(i, f.dataType, f.nullable), f.dataType)
+ constructorFor(BoundReference(i, f.dataType, f.nullable))
)
}
CreateExternalRow(fields)
}
- private def constructorFor(input: Expression, dataType: DataType): Expression = dataType match {
+ private def constructorFor(input: Expression): Expression = input.dataType match {
case BooleanType | ByteType | ShortType | IntegerType | LongType |
FloatType | DoubleType | BinaryType => input
@@ -170,7 +178,7 @@ object RowEncoder {
case ArrayType(et, nullable) =>
val arrayData =
Invoke(
- MapObjects(constructorFor(_, et), input, et),
+ MapObjects(constructorFor, input, et),
"array",
ObjectType(classOf[Array[_]]))
StaticInvoke(
@@ -181,10 +189,10 @@ object RowEncoder {
case MapType(kt, vt, valueNullable) =>
val keyArrayType = ArrayType(kt, false)
- val keyData = constructorFor(Invoke(input, "keyArray", keyArrayType), keyArrayType)
+ val keyData = constructorFor(Invoke(input, "keyArray", keyArrayType))
val valueArrayType = ArrayType(vt, valueNullable)
- val valueData = constructorFor(Invoke(input, "valueArray", valueArrayType), valueArrayType)
+ val valueData = constructorFor(Invoke(input, "valueArray", valueArrayType))
StaticInvoke(
ArrayBasedMapData,
@@ -197,42 +205,8 @@ object RowEncoder {
If(
Invoke(input, "isNullAt", BooleanType, Literal(i) :: Nil),
Literal.create(null, externalDataTypeFor(f.dataType)),
- constructorFor(getField(input, i, f.dataType), f.dataType))
+ constructorFor(GetInternalRowField(input, i, f.dataType)))
}
CreateExternalRow(convertedFields)
}
-
- private def getField(
- row: Expression,
- ordinal: Int,
- dataType: DataType): Expression = dataType match {
- case BooleanType =>
- Invoke(row, "getBoolean", dataType, Literal(ordinal) :: Nil)
- case ByteType =>
- Invoke(row, "getByte", dataType, Literal(ordinal) :: Nil)
- case ShortType =>
- Invoke(row, "getShort", dataType, Literal(ordinal) :: Nil)
- case IntegerType | DateType =>
- Invoke(row, "getInt", dataType, Literal(ordinal) :: Nil)
- case LongType | TimestampType =>
- Invoke(row, "getLong", dataType, Literal(ordinal) :: Nil)
- case FloatType =>
- Invoke(row, "getFloat", dataType, Literal(ordinal) :: Nil)
- case DoubleType =>
- Invoke(row, "getDouble", dataType, Literal(ordinal) :: Nil)
- case t: DecimalType =>
- Invoke(row, "getDecimal", dataType, Seq(ordinal, t.precision, t.scale).map(Literal(_)))
- case StringType =>
- Invoke(row, "getUTF8String", dataType, Literal(ordinal) :: Nil)
- case BinaryType =>
- Invoke(row, "getBinary", dataType, Literal(ordinal) :: Nil)
- case CalendarIntervalType =>
- Invoke(row, "getInterval", dataType, Literal(ordinal) :: Nil)
- case t: StructType =>
- Invoke(row, "getStruct", dataType, Literal(ordinal) :: Literal(t.size) :: Nil)
- case _: ArrayType =>
- Invoke(row, "getArray", dataType, Literal(ordinal) :: Nil)
- case _: MapType =>
- Invoke(row, "getMap", dataType, Literal(ordinal) :: Nil)
- }
}
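To illustrate the new isNativeType predicate: it is true exactly for the types whose internal and external values coincide, which is why the commit message stresses that String (an AtomicType, but UTF8String internally) cannot skip conversion. A small sketch:

import org.apache.spark.sql.types._

RowEncoder.isNativeType(IntegerType)                 // true: Int on both sides
RowEncoder.isNativeType(BinaryType)                  // true: Array[Byte] on both sides
RowEncoder.isNativeType(StringType)                  // false: UTF8String vs java.lang.String
RowEncoder.isNativeType(DecimalType.SYSTEM_DEFAULT)  // false: Decimal vs BigDecimal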
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/DateTimeUtils.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/DateTimeUtils.scala
index f5fff90e5a..deff8a5378 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/DateTimeUtils.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/DateTimeUtils.scala
@@ -110,7 +110,7 @@ object DateTimeUtils {
}
def stringToTime(s: String): java.util.Date = {
- var indexOfGMT = s.indexOf("GMT");
+ val indexOfGMT = s.indexOf("GMT")
if (indexOfGMT != -1) {
// ISO8601 with a weird time zone specifier (2000-01-01T00:00GMT+01:00)
val s0 = s.substring(0, indexOfGMT)
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/GenericArrayData.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/GenericArrayData.scala
index e9bf7b33e3..96588bb5dc 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/GenericArrayData.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/GenericArrayData.scala
@@ -23,7 +23,7 @@ import org.apache.spark.unsafe.types.{CalendarInterval, UTF8String}
class GenericArrayData(val array: Array[Any]) extends ArrayData {
- def this(seq: scala.collection.GenIterable[Any]) = this(seq.toArray)
+ def this(seq: Seq[Any]) = this(seq.toArray)
// TODO: This is boxing. We should specialize.
def this(primitiveArray: Array[Int]) = this(primitiveArray.toSeq)
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoderSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoderSuite.scala
index b0dacf7f55..9fe64b4cf1 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoderSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoderSuite.scala
@@ -17,232 +17,27 @@
package org.apache.spark.sql.catalyst.encoders
-import scala.collection.mutable.ArrayBuffer
-import scala.reflect.runtime.universe._
+import java.util.Arrays
import org.apache.spark.SparkFunSuite
-import org.apache.spark.sql.catalyst._
+import org.apache.spark.sql.catalyst.expressions.AttributeReference
import org.apache.spark.sql.catalyst.util.ArrayData
-import org.apache.spark.sql.types.{StructField, ArrayType}
-
-case class RepeatedStruct(s: Seq[PrimitiveData])
-
-case class NestedArray(a: Array[Array[Int]])
-
-case class BoxedData(
- intField: java.lang.Integer,
- longField: java.lang.Long,
- doubleField: java.lang.Double,
- floatField: java.lang.Float,
- shortField: java.lang.Short,
- byteField: java.lang.Byte,
- booleanField: java.lang.Boolean)
-
-case class RepeatedData(
- arrayField: Seq[Int],
- arrayFieldContainsNull: Seq[java.lang.Integer],
- mapField: scala.collection.Map[Int, Long],
- mapFieldNull: scala.collection.Map[Int, java.lang.Long],
- structField: PrimitiveData)
-
-case class SpecificCollection(l: List[Int])
-
-class ExpressionEncoderSuite extends SparkFunSuite {
-
- encodeDecodeTest(1)
- encodeDecodeTest(1L)
- encodeDecodeTest(1.toDouble)
- encodeDecodeTest(1.toFloat)
- encodeDecodeTest(true)
- encodeDecodeTest(false)
- encodeDecodeTest(1.toShort)
- encodeDecodeTest(1.toByte)
- encodeDecodeTest("hello")
-
- encodeDecodeTest(PrimitiveData(1, 1, 1, 1, 1, 1, true))
-
- // TODO: Support creating specific subclasses of Seq.
- ignore("Specific collection types") { encodeDecodeTest(SpecificCollection(1 :: Nil)) }
-
- encodeDecodeTest(
- OptionalData(
- Some(2), Some(2), Some(2), Some(2), Some(2), Some(2), Some(true),
- Some(PrimitiveData(1, 1, 1, 1, 1, 1, true))))
-
- encodeDecodeTest(OptionalData(None, None, None, None, None, None, None, None))
-
- encodeDecodeTest(
- BoxedData(1, 1L, 1.0, 1.0f, 1.toShort, 1.toByte, true))
-
- encodeDecodeTest(
- BoxedData(null, null, null, null, null, null, null))
-
- encodeDecodeTest(
- RepeatedStruct(PrimitiveData(1, 1, 1, 1, 1, 1, true) :: Nil))
-
- encodeDecodeTest(
- RepeatedData(
- Seq(1, 2),
- Seq(new Integer(1), null, new Integer(2)),
- Map(1 -> 2L),
- Map(1 -> null),
- PrimitiveData(1, 1, 1, 1, 1, 1, true)))
-
- encodeDecodeTest(("nullable Seq[Integer]", Seq[Integer](1, null)))
-
- encodeDecodeTest(("Seq[(String, String)]",
- Seq(("a", "b"))))
- encodeDecodeTest(("Seq[(Int, Int)]",
- Seq((1, 2))))
- encodeDecodeTest(("Seq[(Long, Long)]",
- Seq((1L, 2L))))
- encodeDecodeTest(("Seq[(Float, Float)]",
- Seq((1.toFloat, 2.toFloat))))
- encodeDecodeTest(("Seq[(Double, Double)]",
- Seq((1.toDouble, 2.toDouble))))
- encodeDecodeTest(("Seq[(Short, Short)]",
- Seq((1.toShort, 2.toShort))))
- encodeDecodeTest(("Seq[(Byte, Byte)]",
- Seq((1.toByte, 2.toByte))))
- encodeDecodeTest(("Seq[(Boolean, Boolean)]",
- Seq((true, false))))
-
- // TODO: Decoding/encoding of complex maps.
- ignore("complex maps") {
- encodeDecodeTest(("Map[Int, (String, String)]",
- Map(1 ->("a", "b"))))
- }
-
- encodeDecodeTest(("ArrayBuffer[(String, String)]",
- ArrayBuffer(("a", "b"))))
- encodeDecodeTest(("ArrayBuffer[(Int, Int)]",
- ArrayBuffer((1, 2))))
- encodeDecodeTest(("ArrayBuffer[(Long, Long)]",
- ArrayBuffer((1L, 2L))))
- encodeDecodeTest(("ArrayBuffer[(Float, Float)]",
- ArrayBuffer((1.toFloat, 2.toFloat))))
- encodeDecodeTest(("ArrayBuffer[(Double, Double)]",
- ArrayBuffer((1.toDouble, 2.toDouble))))
- encodeDecodeTest(("ArrayBuffer[(Short, Short)]",
- ArrayBuffer((1.toShort, 2.toShort))))
- encodeDecodeTest(("ArrayBuffer[(Byte, Byte)]",
- ArrayBuffer((1.toByte, 2.toByte))))
- encodeDecodeTest(("ArrayBuffer[(Boolean, Boolean)]",
- ArrayBuffer((true, false))))
-
- encodeDecodeTest(("Seq[Seq[(Int, Int)]]",
- Seq(Seq((1, 2)))))
-
- encodeDecodeTestCustom(("Array[Array[(Int, Int)]]",
- Array(Array((1, 2)))))
- { (l, r) => l._2(0)(0) == r._2(0)(0) }
-
- encodeDecodeTestCustom(("Array[Array[(Int, Int)]]",
- Array(Array(Array((1, 2))))))
- { (l, r) => l._2(0)(0)(0) == r._2(0)(0)(0) }
-
- encodeDecodeTestCustom(("Array[Array[Array[(Int, Int)]]]",
- Array(Array(Array(Array((1, 2)))))))
- { (l, r) => l._2(0)(0)(0)(0) == r._2(0)(0)(0)(0) }
-
- encodeDecodeTestCustom(("Array[Array[Array[Array[(Int, Int)]]]]",
- Array(Array(Array(Array(Array((1, 2))))))))
- { (l, r) => l._2(0)(0)(0)(0)(0) == r._2(0)(0)(0)(0)(0) }
-
-
- encodeDecodeTestCustom(("Array[Array[Integer]]",
- Array(Array[Integer](1))))
- { (l, r) => l._2(0)(0) == r._2(0)(0) }
-
- encodeDecodeTestCustom(("Array[Array[Int]]",
- Array(Array(1))))
- { (l, r) => l._2(0)(0) == r._2(0)(0) }
-
- encodeDecodeTestCustom(("Array[Array[Int]]",
- Array(Array(Array(1)))))
- { (l, r) => l._2(0)(0)(0) == r._2(0)(0)(0) }
-
- encodeDecodeTestCustom(("Array[Array[Array[Int]]]",
- Array(Array(Array(Array(1))))))
- { (l, r) => l._2(0)(0)(0)(0) == r._2(0)(0)(0)(0) }
-
- encodeDecodeTestCustom(("Array[Array[Array[Array[Int]]]]",
- Array(Array(Array(Array(Array(1)))))))
- { (l, r) => l._2(0)(0)(0)(0)(0) == r._2(0)(0)(0)(0)(0) }
-
- encodeDecodeTest(("Array[Byte] null",
- null: Array[Byte]))
- encodeDecodeTestCustom(("Array[Byte]",
- Array[Byte](1, 2, 3)))
- { (l, r) => java.util.Arrays.equals(l._2, r._2) }
-
- encodeDecodeTest(("Array[Int] null",
- null: Array[Int]))
- encodeDecodeTestCustom(("Array[Int]",
- Array[Int](1, 2, 3)))
- { (l, r) => java.util.Arrays.equals(l._2, r._2) }
-
- encodeDecodeTest(("Array[Long] null",
- null: Array[Long]))
- encodeDecodeTestCustom(("Array[Long]",
- Array[Long](1, 2, 3)))
- { (l, r) => java.util.Arrays.equals(l._2, r._2) }
-
- encodeDecodeTest(("Array[Double] null",
- null: Array[Double]))
- encodeDecodeTestCustom(("Array[Double]",
- Array[Double](1, 2, 3)))
- { (l, r) => java.util.Arrays.equals(l._2, r._2) }
-
- encodeDecodeTest(("Array[Float] null",
- null: Array[Float]))
- encodeDecodeTestCustom(("Array[Float]",
- Array[Float](1, 2, 3)))
- { (l, r) => java.util.Arrays.equals(l._2, r._2) }
-
- encodeDecodeTest(("Array[Boolean] null",
- null: Array[Boolean]))
- encodeDecodeTestCustom(("Array[Boolean]",
- Array[Boolean](true, false)))
- { (l, r) => java.util.Arrays.equals(l._2, r._2) }
-
- encodeDecodeTest(("Array[Short] null",
- null: Array[Short]))
- encodeDecodeTestCustom(("Array[Short]",
- Array[Short](1, 2, 3)))
- { (l, r) => java.util.Arrays.equals(l._2, r._2) }
-
- encodeDecodeTestCustom(("java.sql.Timestamp",
- new java.sql.Timestamp(1)))
- { (l, r) => l._2.toString == r._2.toString }
-
- encodeDecodeTestCustom(("java.sql.Date", new java.sql.Date(1)))
- { (l, r) => l._2.toString == r._2.toString }
-
- /** Simplified encodeDecodeTestCustom, where the comparison function can be `Object.equals`. */
- protected def encodeDecodeTest[T : TypeTag](inputData: T) =
- encodeDecodeTestCustom[T](inputData)((l, r) => l == r)
-
- /**
- * Constructs a test that round-trips `t` through an encoder, checking the results to ensure it
- * matches the original.
- */
- protected def encodeDecodeTestCustom[T : TypeTag](
- inputData: T)(
- c: (T, T) => Boolean) = {
- test(s"encode/decode: $inputData - ${inputData.getClass.getName}") {
- val encoder = try ExpressionEncoder[T]() catch {
- case e: Exception =>
- fail(s"Exception thrown generating encoder", e)
- }
- val convertedData = encoder.toRow(inputData)
+import org.apache.spark.sql.types.ArrayType
+
+abstract class ExpressionEncoderSuite extends SparkFunSuite {
+ protected def encodeDecodeTest[T](
+ input: T,
+ encoder: ExpressionEncoder[T],
+ testName: String): Unit = {
+ test(s"encode/decode for $testName: $input") {
+ val row = encoder.toRow(input)
val schema = encoder.schema.toAttributes
val boundEncoder = encoder.resolve(schema).bind(schema)
- val convertedBack = try boundEncoder.fromRow(convertedData) catch {
+ val convertedBack = try boundEncoder.fromRow(row) catch {
case e: Exception =>
fail(
s"""Exception thrown while decoding
- |Converted: $convertedData
+ |Converted: $row
|Schema: ${schema.mkString(",")}
|${encoder.schema.treeString}
|
@@ -252,18 +47,27 @@ class ExpressionEncoderSuite extends SparkFunSuite {
""".stripMargin, e)
}
- if (!c(inputData, convertedBack)) {
+ val isCorrect = (input, convertedBack) match {
+ case (b1: Array[Byte], b2: Array[Byte]) => Arrays.equals(b1, b2)
+ case (b1: Array[Int], b2: Array[Int]) => Arrays.equals(b1, b2)
+ case (b1: Array[Array[_]], b2: Array[Array[_]]) =>
+ Arrays.deepEquals(b1.asInstanceOf[Array[AnyRef]], b2.asInstanceOf[Array[AnyRef]])
+ case (b1: Array[_], b2: Array[_]) =>
+ Arrays.equals(b1.asInstanceOf[Array[AnyRef]], b2.asInstanceOf[Array[AnyRef]])
+ case _ => input == convertedBack
+ }
+
+ if (!isCorrect) {
val types = convertedBack match {
case c: Product =>
c.productIterator.filter(_ != null).map(_.getClass.getName).mkString(",")
case other => other.getClass.getName
}
-
val encodedData = try {
- convertedData.toSeq(encoder.schema).zip(encoder.schema).map {
- case (a: ArrayData, StructField(_, at: ArrayType, _, _)) =>
- a.toArray[Any](at.elementType).toSeq
+ row.toSeq(encoder.schema).zip(schema).map {
+ case (a: ArrayData, AttributeReference(_, ArrayType(et, _), _, _)) =>
+ a.toArray[Any](et).toSeq
case (other, _) =>
other
}.mkString("[", ",", "]")
@@ -274,7 +78,7 @@ class ExpressionEncoderSuite extends SparkFunSuite {
fail(
s"""Encoded/Decoded data does not match input data
|
- |in: $inputData
+ |in: $input
|out: $convertedBack
|types: $types
|
@@ -282,11 +86,10 @@ class ExpressionEncoderSuite extends SparkFunSuite {
|Schema: ${schema.mkString(",")}
|${encoder.schema.treeString}
|
- |Extract Expressions:
- |$boundEncoder
+ |fromRow Expressions:
+ |${boundEncoder.fromRowExpression.treeString}
""".stripMargin)
- }
}
-
+ }
}
}
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/FlatEncoderSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/FlatEncoderSuite.scala
new file mode 100644
index 0000000000..55821c4370
--- /dev/null
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/FlatEncoderSuite.scala
@@ -0,0 +1,74 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.catalyst.encoders
+
+import java.sql.{Date, Timestamp}
+
+class FlatEncoderSuite extends ExpressionEncoderSuite {
+ encodeDecodeTest(false, FlatEncoder[Boolean], "primitive boolean")
+ encodeDecodeTest(-3.toByte, FlatEncoder[Byte], "primitive byte")
+ encodeDecodeTest(-3.toShort, FlatEncoder[Short], "primitive short")
+ encodeDecodeTest(-3, FlatEncoder[Int], "primitive int")
+ encodeDecodeTest(-3L, FlatEncoder[Long], "primitive long")
+ encodeDecodeTest(-3.7f, FlatEncoder[Float], "primitive float")
+ encodeDecodeTest(-3.7, FlatEncoder[Double], "primitive double")
+
+ encodeDecodeTest(new java.lang.Boolean(false), FlatEncoder[java.lang.Boolean], "boxed boolean")
+ encodeDecodeTest(new java.lang.Byte(-3.toByte), FlatEncoder[java.lang.Byte], "boxed byte")
+ encodeDecodeTest(new java.lang.Short(-3.toShort), FlatEncoder[java.lang.Short], "boxed short")
+ encodeDecodeTest(new java.lang.Integer(-3), FlatEncoder[java.lang.Integer], "boxed int")
+ encodeDecodeTest(new java.lang.Long(-3L), FlatEncoder[java.lang.Long], "boxed long")
+ encodeDecodeTest(new java.lang.Float(-3.7f), FlatEncoder[java.lang.Float], "boxed float")
+ encodeDecodeTest(new java.lang.Double(-3.7), FlatEncoder[java.lang.Double], "boxed double")
+
+ encodeDecodeTest(BigDecimal("32131413.211321313"), FlatEncoder[BigDecimal], "scala decimal")
+ type JDecimal = java.math.BigDecimal
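+ // Disabled: encoding a java.math.BigDecimal and decoding it back seems to lose
+ // precision info (noted as under investigation in the commit message).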
+ // encodeDecodeTest(new JDecimal("231341.23123"), FlatEncoder[JDecimal], "java decimal")
+
+ encodeDecodeTest("hello", FlatEncoder[String], "string")
+ encodeDecodeTest(Date.valueOf("2012-12-23"), FlatEncoder[Date], "date")
+ encodeDecodeTest(Timestamp.valueOf("2016-01-29 10:00:00"), FlatEncoder[Timestamp], "timestamp")
+ encodeDecodeTest(Array[Byte](13, 21, -23), FlatEncoder[Array[Byte]], "binary")
+
+ encodeDecodeTest(Seq(31, -123, 4), FlatEncoder[Seq[Int]], "seq of int")
+ encodeDecodeTest(Seq("abc", "xyz"), FlatEncoder[Seq[String]], "seq of string")
+ encodeDecodeTest(Seq("abc", null, "xyz"), FlatEncoder[Seq[String]], "seq of string with null")
+ encodeDecodeTest(Seq.empty[Int], FlatEncoder[Seq[Int]], "empty seq of int")
+ encodeDecodeTest(Seq.empty[String], FlatEncoder[Seq[String]], "empty seq of string")
+
+ encodeDecodeTest(Seq(Seq(31, -123), null, Seq(4, 67)),
+ FlatEncoder[Seq[Seq[Int]]], "seq of seq of int")
+ encodeDecodeTest(Seq(Seq("abc", "xyz"), Seq[String](null), null, Seq("1", null, "2")),
+ FlatEncoder[Seq[Seq[String]]], "seq of seq of string")
+
+ encodeDecodeTest(Array(31, -123, 4), FlatEncoder[Array[Int]], "array of int")
+ encodeDecodeTest(Array("abc", "xyz"), FlatEncoder[Array[String]], "array of string")
+ encodeDecodeTest(Array("a", null, "x"), FlatEncoder[Array[String]], "array of string with null")
+ encodeDecodeTest(Array.empty[Int], FlatEncoder[Array[Int]], "empty array of int")
+ encodeDecodeTest(Array.empty[String], FlatEncoder[Array[String]], "empty array of string")
+
+ encodeDecodeTest(Array(Array(31, -123), null, Array(4, 67)),
+ FlatEncoder[Array[Array[Int]]], "array of array of int")
+ encodeDecodeTest(Array(Array("abc", "xyz"), Array[String](null), null, Array("1", null, "2")),
+ FlatEncoder[Array[Array[String]]], "array of array of string")
+
+ encodeDecodeTest(Map(1 -> "a", 2 -> "b"), FlatEncoder[Map[Int, String]], "map")
+ encodeDecodeTest(Map(1 -> "a", 2 -> null), FlatEncoder[Map[Int, String]], "map with null")
+ encodeDecodeTest(Map(1 -> Map("a" -> 1), 2 -> Map("b" -> 2)),
+ FlatEncoder[Map[Int, Map[String, Int]]], "map of map")
+}
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/ProductEncoderSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/ProductEncoderSuite.scala
new file mode 100644
index 0000000000..fda978e705
--- /dev/null
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/ProductEncoderSuite.scala
@@ -0,0 +1,123 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.catalyst.encoders
+
+import scala.collection.mutable.ArrayBuffer
+import scala.reflect.runtime.universe.TypeTag
+
+import org.apache.spark.sql.catalyst.{OptionalData, PrimitiveData}
+
+case class RepeatedStruct(s: Seq[PrimitiveData])
+
+case class NestedArray(a: Array[Array[Int]]) {
+ override def equals(other: Any): Boolean = other match {
+ case NestedArray(otherArray) =>
+ java.util.Arrays.deepEquals(
+ a.asInstanceOf[Array[AnyRef]],
+ otherArray.asInstanceOf[Array[AnyRef]])
+ case _ => false
+ }
+}
+
+case class BoxedData(
+ intField: java.lang.Integer,
+ longField: java.lang.Long,
+ doubleField: java.lang.Double,
+ floatField: java.lang.Float,
+ shortField: java.lang.Short,
+ byteField: java.lang.Byte,
+ booleanField: java.lang.Boolean)
+
+case class RepeatedData(
+ arrayField: Seq[Int],
+ arrayFieldContainsNull: Seq[java.lang.Integer],
+ mapField: scala.collection.Map[Int, Long],
+ mapFieldNull: scala.collection.Map[Int, java.lang.Long],
+ structField: PrimitiveData)
+
+case class SpecificCollection(l: List[Int])
+
+class ProductEncoderSuite extends ExpressionEncoderSuite {
+
+ productTest(PrimitiveData(1, 1, 1, 1, 1, 1, true))
+
+ productTest(
+ OptionalData(Some(2), Some(2), Some(2), Some(2), Some(2), Some(2), Some(true),
+ Some(PrimitiveData(1, 1, 1, 1, 1, 1, true))))
+
+ productTest(OptionalData(None, None, None, None, None, None, None, None))
+
+ productTest(BoxedData(1, 1L, 1.0, 1.0f, 1.toShort, 1.toByte, true))
+
+ productTest(BoxedData(null, null, null, null, null, null, null))
+
+ productTest(RepeatedStruct(PrimitiveData(1, 1, 1, 1, 1, 1, true) :: Nil))
+
+ productTest((1, "test", PrimitiveData(1, 1, 1, 1, 1, 1, true)))
+
+ productTest(
+ RepeatedData(
+ Seq(1, 2),
+ Seq(new Integer(1), null, new Integer(2)),
+ Map(1 -> 2L),
+ Map(1 -> null),
+ PrimitiveData(1, 1, 1, 1, 1, 1, true)))
+
+ productTest(NestedArray(Array(Array(1, -2, 3), null, Array(4, 5, -6))))
+
+ productTest(("Seq[(String, String)]",
+ Seq(("a", "b"))))
+ productTest(("Seq[(Int, Int)]",
+ Seq((1, 2))))
+ productTest(("Seq[(Long, Long)]",
+ Seq((1L, 2L))))
+ productTest(("Seq[(Float, Float)]",
+ Seq((1.toFloat, 2.toFloat))))
+ productTest(("Seq[(Double, Double)]",
+ Seq((1.toDouble, 2.toDouble))))
+ productTest(("Seq[(Short, Short)]",
+ Seq((1.toShort, 2.toShort))))
+ productTest(("Seq[(Byte, Byte)]",
+ Seq((1.toByte, 2.toByte))))
+ productTest(("Seq[(Boolean, Boolean)]",
+ Seq((true, false))))
+
+ productTest(("ArrayBuffer[(String, String)]",
+ ArrayBuffer(("a", "b"))))
+ productTest(("ArrayBuffer[(Int, Int)]",
+ ArrayBuffer((1, 2))))
+ productTest(("ArrayBuffer[(Long, Long)]",
+ ArrayBuffer((1L, 2L))))
+ productTest(("ArrayBuffer[(Float, Float)]",
+ ArrayBuffer((1.toFloat, 2.toFloat))))
+ productTest(("ArrayBuffer[(Double, Double)]",
+ ArrayBuffer((1.toDouble, 2.toDouble))))
+ productTest(("ArrayBuffer[(Short, Short)]",
+ ArrayBuffer((1.toShort, 2.toShort))))
+ productTest(("ArrayBuffer[(Byte, Byte)]",
+ ArrayBuffer((1.toByte, 2.toByte))))
+ productTest(("ArrayBuffer[(Boolean, Boolean)]",
+ ArrayBuffer((true, false))))
+
+ productTest(("Seq[Seq[(Int, Int)]]",
+ Seq(Seq((1, 2)))))
+
+ private def productTest[T <: Product : TypeTag](input: T): Unit = {
+ encodeDecodeTest(input, ProductEncoder[T], input.getClass.getSimpleName)
+ }
+}
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/GroupedDataset.scala b/sql/core/src/main/scala/org/apache/spark/sql/GroupedDataset.scala
index 9c16940707..ebcf4c8bfe 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/GroupedDataset.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/GroupedDataset.scala
@@ -22,7 +22,7 @@ import scala.collection.JavaConverters._
import org.apache.spark.annotation.Experimental
import org.apache.spark.api.java.function._
-import org.apache.spark.sql.catalyst.encoders.{ExpressionEncoder, encoderFor}
+import org.apache.spark.sql.catalyst.encoders.{FlatEncoder, ExpressionEncoder, encoderFor}
import org.apache.spark.sql.catalyst.expressions.Attribute
import org.apache.spark.sql.catalyst.plans.logical._
import org.apache.spark.sql.execution.QueryExecution
@@ -56,9 +56,6 @@ class GroupedDataset[K, T] private[sql](
private val resolvedKEncoder = unresolvedKEncoder.resolve(groupingAttributes)
private val resolvedTEncoder = unresolvedTEncoder.resolve(dataAttributes)
- /** Encoders for built in aggregations. */
- private implicit def newLongEncoder: Encoder[Long] = ExpressionEncoder[Long](flat = true)
-
private def logicalPlan = queryExecution.analyzed
private def sqlContext = queryExecution.sqlContext
@@ -211,7 +208,7 @@ class GroupedDataset[K, T] private[sql](
* Returns a [[Dataset]] that contains a tuple with each key and the number of items present
* for that key.
*/
- def count(): Dataset[(K, Long)] = agg(functions.count("*").as[Long])
+ def count(): Dataset[(K, Long)] = agg(functions.count("*").as(FlatEncoder[Long]))
/**
* Applies the given function to each cogrouped data. For each unique group, the function will
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/SQLImplicits.scala b/sql/core/src/main/scala/org/apache/spark/sql/SQLImplicits.scala
index 6da46a5f7e..8471eea1b7 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/SQLImplicits.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/SQLImplicits.scala
@@ -37,17 +37,21 @@ import org.apache.spark.unsafe.types.UTF8String
abstract class SQLImplicits {
protected def _sqlContext: SQLContext
- implicit def newProductEncoder[T <: Product : TypeTag]: Encoder[T] = ExpressionEncoder[T]()
+ implicit def newProductEncoder[T <: Product : TypeTag]: Encoder[T] = ProductEncoder[T]
- implicit def newIntEncoder: Encoder[Int] = ExpressionEncoder[Int](flat = true)
- implicit def newLongEncoder: Encoder[Long] = ExpressionEncoder[Long](flat = true)
- implicit def newDoubleEncoder: Encoder[Double] = ExpressionEncoder[Double](flat = true)
- implicit def newFloatEncoder: Encoder[Float] = ExpressionEncoder[Float](flat = true)
- implicit def newByteEncoder: Encoder[Byte] = ExpressionEncoder[Byte](flat = true)
- implicit def newShortEncoder: Encoder[Short] = ExpressionEncoder[Short](flat = true)
- implicit def newBooleanEncoder: Encoder[Boolean] = ExpressionEncoder[Boolean](flat = true)
- implicit def newStringEncoder: Encoder[String] = ExpressionEncoder[String](flat = true)
+ implicit def newIntEncoder: Encoder[Int] = FlatEncoder[Int]
+ implicit def newLongEncoder: Encoder[Long] = FlatEncoder[Long]
+ implicit def newDoubleEncoder: Encoder[Double] = FlatEncoder[Double]
+ implicit def newFloatEncoder: Encoder[Float] = FlatEncoder[Float]
+ implicit def newByteEncoder: Encoder[Byte] = FlatEncoder[Byte]
+ implicit def newShortEncoder: Encoder[Short] = FlatEncoder[Short]
+ implicit def newBooleanEncoder: Encoder[Boolean] = FlatEncoder[Boolean]
+ implicit def newStringEncoder: Encoder[String] = FlatEncoder[String]
+ /**
+ * Creates a [[Dataset]] from an RDD.
+ * @since 1.6.0
+ */
implicit def rddToDatasetHolder[T : Encoder](rdd: RDD[T]): DatasetHolder[T] = {
DatasetHolder(_sqlContext.createDataset(rdd))
}
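A hedged example of what these implicits now produce (assuming a SQLContext named sqlContext and an RDD[Person] named rdd):

import sqlContext.implicits._

case class Person(name: String, age: Int)

// newProductEncoder resolves ProductEncoder[Person] for the case class...
val people = rdd.toDS()                      // rdd: RDD[Person]

// ...while the primitive encoders now delegate to FlatEncoder.
val counts = people.groupBy(_.age).count()   // count() uses FlatEncoder[Long]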
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala
index 53cc6e0cda..95158de710 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala
@@ -26,7 +26,7 @@ import scala.util.Try
import org.apache.spark.annotation.Experimental
import org.apache.spark.sql.catalyst.{SqlParser, ScalaReflection}
import org.apache.spark.sql.catalyst.analysis.{UnresolvedFunction, Star}
-import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder
+import org.apache.spark.sql.catalyst.encoders.FlatEncoder
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.expressions.aggregate._
import org.apache.spark.sql.catalyst.plans.logical.BroadcastHint
@@ -267,7 +267,7 @@ object functions extends LegacyFunctions {
* @since 1.3.0
*/
def count(columnName: String): TypedColumn[Any, Long] =
- count(Column(columnName)).as(ExpressionEncoder[Long](flat = true))
+ count(Column(columnName)).as(FlatEncoder[Long])
/**
* Aggregate function: returns the number of distinct items in a group.
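For completeness, the visible behavior at both call sites is unchanged; a sketch of the typed count path:

// functions.count(columnName) still returns a TypedColumn[Any, Long], now
// backed by FlatEncoder[Long] instead of ExpressionEncoder[Long](flat = true):
val typedCount: TypedColumn[Any, Long] = functions.count("value")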