author    Wenchen Fan <wenchen@databricks.com>  2015-11-19 12:54:25 -0800
committer Michael Armbrust <michael@databricks.com>  2015-11-19 12:54:25 -0800
commit    47d1c2325caaf9ffe31695b6fff529314b8582f7 (patch)
tree      794c849f2833c33e1882ff040a7bfe051f8821df
parent    7d4aba18722727c85893ad8d8f07d4494665dcfc (diff)
[SPARK-11750][SQL] revert SPARK-11727 and code clean up

After some experimentation, I found it's not convenient to have separate encoder builders: `FlatEncoder` and `ProductEncoder`. For example, when creating encoders for `ScalaUDF`, we have no idea whether the type `T` is flat or not. So I revert the splitting change in https://github.com/apache/spark/pull/9693, while still keeping the bug fixes and tests.

Author: Wenchen Fan <wenchen@databricks.com>

Closes #9726 from cloud-fan/follow.
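To ground the motivation in code: the patch replaces the explicit `flat` flag with a runtime check in `ExpressionEncoder.apply` (visible in the diff below as `!classOf[Product].isAssignableFrom(cls)`). The following is a minimal, self-contained sketch of that check; the helper name `isFlat` is invented for illustration:

```scala
import scala.reflect.runtime.universe._

// Hypothetical helper mirroring the check added to ExpressionEncoder.apply:
// a type is "flat" (encoded into a single `value` column) iff it is not a Product.
def isFlat[T](implicit tag: TypeTag[T]): Boolean = {
  val cls = tag.mirror.runtimeClass(tag.tpe)
  !classOf[Product].isAssignableFrom(cls)
}

assert(isFlat[Int])            // primitives are flat
assert(isFlat[String])         // so are strings
assert(!isFlat[(Int, String)]) // tuples and case classes are Products
```

This is why a single builder suffices for call sites like `ScalaUDF`: flatness is derived from `T` itself rather than supplied by the caller.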
-rw-r--r--  sql/catalyst/src/main/scala/org/apache/spark/sql/Encoder.scala | 16
-rw-r--r--  sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala | 354
-rw-r--r--  sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoder.scala | 19
-rw-r--r--  sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/FlatEncoder.scala | 50
-rw-r--r--  sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/ProductEncoder.scala | 452
-rw-r--r--  sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/RowEncoder.scala | 12
-rw-r--r--  sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects.scala | 7
-rw-r--r--  sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/ScalaReflectionSuite.scala | 68
-rw-r--r--  sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoderSuite.scala | 218
-rw-r--r--  sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/FlatEncoderSuite.scala | 99
-rw-r--r--  sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/ProductEncoderSuite.scala | 156
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/GroupedDataset.scala | 4
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/SQLImplicits.scala | 23
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/functions.scala | 4
14 files changed, 364 insertions(+), 1118 deletions(-)
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/Encoder.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/Encoder.scala
index d54f2854fb..86bb536459 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/Encoder.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/Encoder.scala
@@ -45,14 +45,14 @@ trait Encoder[T] extends Serializable {
*/
object Encoders {
- def BOOLEAN: Encoder[java.lang.Boolean] = ExpressionEncoder(flat = true)
- def BYTE: Encoder[java.lang.Byte] = ExpressionEncoder(flat = true)
- def SHORT: Encoder[java.lang.Short] = ExpressionEncoder(flat = true)
- def INT: Encoder[java.lang.Integer] = ExpressionEncoder(flat = true)
- def LONG: Encoder[java.lang.Long] = ExpressionEncoder(flat = true)
- def FLOAT: Encoder[java.lang.Float] = ExpressionEncoder(flat = true)
- def DOUBLE: Encoder[java.lang.Double] = ExpressionEncoder(flat = true)
- def STRING: Encoder[java.lang.String] = ExpressionEncoder(flat = true)
+ def BOOLEAN: Encoder[java.lang.Boolean] = ExpressionEncoder()
+ def BYTE: Encoder[java.lang.Byte] = ExpressionEncoder()
+ def SHORT: Encoder[java.lang.Short] = ExpressionEncoder()
+ def INT: Encoder[java.lang.Integer] = ExpressionEncoder()
+ def LONG: Encoder[java.lang.Long] = ExpressionEncoder()
+ def FLOAT: Encoder[java.lang.Float] = ExpressionEncoder()
+ def DOUBLE: Encoder[java.lang.Double] = ExpressionEncoder()
+ def STRING: Encoder[java.lang.String] = ExpressionEncoder()
/**
* (Scala-specific) Creates an encoder that serializes objects of type T using Kryo.
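A hedged usage sketch of the simplified factory (assuming the catalyst module shown in this diff is on the classpath; `Point` is a made-up example class): primitive and product types now share the identical no-argument entry point, which infers flatness itself.

```scala
import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder

case class Point(x: Int, y: Int) // hypothetical example class

val intEncoder   = ExpressionEncoder[java.lang.Integer]() // flatness inferred: true
val pointEncoder = ExpressionEncoder[Point]()             // flatness inferred: false
```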
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala
index 59ccf356f2..33ae700706 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala
@@ -50,39 +50,29 @@ object ScalaReflection extends ScalaReflection {
* Unlike `schemaFor`, this function doesn't do any massaging of types into the Spark SQL type
* system. As a result, ObjectType will be returned for things like boxed Integers
*/
- def dataTypeFor(tpe: `Type`): DataType = tpe match {
- case t if t <:< definitions.IntTpe => IntegerType
- case t if t <:< definitions.LongTpe => LongType
- case t if t <:< definitions.DoubleTpe => DoubleType
- case t if t <:< definitions.FloatTpe => FloatType
- case t if t <:< definitions.ShortTpe => ShortType
- case t if t <:< definitions.ByteTpe => ByteType
- case t if t <:< definitions.BooleanTpe => BooleanType
- case t if t <:< localTypeOf[Array[Byte]] => BinaryType
- case _ =>
- val className: String = tpe.erasure.typeSymbol.asClass.fullName
- className match {
- case "scala.Array" =>
- val TypeRef(_, _, Seq(arrayType)) = tpe
- val cls = arrayType match {
- case t if t <:< definitions.IntTpe => classOf[Array[Int]]
- case t if t <:< definitions.LongTpe => classOf[Array[Long]]
- case t if t <:< definitions.DoubleTpe => classOf[Array[Double]]
- case t if t <:< definitions.FloatTpe => classOf[Array[Float]]
- case t if t <:< definitions.ShortTpe => classOf[Array[Short]]
- case t if t <:< definitions.ByteTpe => classOf[Array[Byte]]
- case t if t <:< definitions.BooleanTpe => classOf[Array[Boolean]]
- case other =>
- // There is probably a better way to do this, but I couldn't find it...
- val elementType = dataTypeFor(other).asInstanceOf[ObjectType].cls
- java.lang.reflect.Array.newInstance(elementType, 1).getClass
+ def dataTypeFor[T : TypeTag]: DataType = dataTypeFor(localTypeOf[T])
- }
- ObjectType(cls)
- case other =>
- val clazz = mirror.runtimeClass(tpe.erasure.typeSymbol.asClass)
- ObjectType(clazz)
- }
+ private def dataTypeFor(tpe: `Type`): DataType = ScalaReflectionLock.synchronized {
+ tpe match {
+ case t if t <:< definitions.IntTpe => IntegerType
+ case t if t <:< definitions.LongTpe => LongType
+ case t if t <:< definitions.DoubleTpe => DoubleType
+ case t if t <:< definitions.FloatTpe => FloatType
+ case t if t <:< definitions.ShortTpe => ShortType
+ case t if t <:< definitions.ByteTpe => ByteType
+ case t if t <:< definitions.BooleanTpe => BooleanType
+ case t if t <:< localTypeOf[Array[Byte]] => BinaryType
+ case _ =>
+ val className: String = tpe.erasure.typeSymbol.asClass.fullName
+ className match {
+ case "scala.Array" =>
+ val TypeRef(_, _, Seq(elementType)) = tpe
+ arrayClassFor(elementType)
+ case other =>
+ val clazz = mirror.runtimeClass(tpe.erasure.typeSymbol.asClass)
+ ObjectType(clazz)
+ }
+ }
}
/**
@@ -90,7 +80,7 @@ object ScalaReflection extends ScalaReflection {
* Array[T]. Special handling is performed for primitive types to map them back to their raw
* JVM form instead of the Scala Array that handles auto boxing.
*/
- def arrayClassFor(tpe: `Type`): DataType = {
+ private def arrayClassFor(tpe: `Type`): DataType = ScalaReflectionLock.synchronized {
val cls = tpe match {
case t if t <:< definitions.IntTpe => classOf[Array[Int]]
case t if t <:< definitions.LongTpe => classOf[Array[Long]]
@@ -109,6 +99,15 @@ object ScalaReflection extends ScalaReflection {
}
/**
+ * Returns true if the value of this data type is same between internal and external.
+ */
+ def isNativeType(dt: DataType): Boolean = dt match {
+ case BooleanType | ByteType | ShortType | IntegerType | LongType |
+ FloatType | DoubleType | BinaryType => true
+ case _ => false
+ }
+
+ /**
* Returns an expression that can be used to construct an object of type `T` given an input
* row with a compatible schema. Fields of the row will be extracted using UnresolvedAttributes
* of the same name as the constructor arguments. Nested classes will have their fields accessed
@@ -116,63 +115,33 @@ object ScalaReflection extends ScalaReflection {
*
* When used on a primitive type, the constructor will instead default to extracting the value
* from ordinal 0 (since there are no names to map to). The actual location can be moved by
- * calling unbind/bind with a new schema.
+ * calling resolve/bind with a new schema.
*/
- def constructorFor[T : TypeTag]: Expression = constructorFor(typeOf[T], None)
+ def constructorFor[T : TypeTag]: Expression = constructorFor(localTypeOf[T], None)
private def constructorFor(
tpe: `Type`,
path: Option[Expression]): Expression = ScalaReflectionLock.synchronized {
/** Returns the current path with a sub-field extracted. */
- def addToPath(part: String): Expression =
- path
- .map(p => UnresolvedExtractValue(p, expressions.Literal(part)))
- .getOrElse(UnresolvedAttribute(part))
+ def addToPath(part: String): Expression = path
+ .map(p => UnresolvedExtractValue(p, expressions.Literal(part)))
+ .getOrElse(UnresolvedAttribute(part))
/** Returns the current path with a field at ordinal extracted. */
- def addToPathOrdinal(ordinal: Int, dataType: DataType): Expression =
- path
- .map(p => GetStructField(p, StructField(s"_$ordinal", dataType), ordinal))
- .getOrElse(BoundReference(ordinal, dataType, false))
+ def addToPathOrdinal(ordinal: Int, dataType: DataType): Expression = path
+ .map(p => GetInternalRowField(p, ordinal, dataType))
+ .getOrElse(BoundReference(ordinal, dataType, false))
- /** Returns the current path or throws an error. */
- def getPath = path.getOrElse(BoundReference(0, schemaFor(tpe).dataType, true))
+ /** Returns the current path or `BoundReference`. */
+ def getPath: Expression = path.getOrElse(BoundReference(0, schemaFor(tpe).dataType, true))
tpe match {
- case t if !dataTypeFor(t).isInstanceOf[ObjectType] =>
- getPath
+ case t if !dataTypeFor(t).isInstanceOf[ObjectType] => getPath
case t if t <:< localTypeOf[Option[_]] =>
val TypeRef(_, _, Seq(optType)) = t
- val boxedType = optType match {
- // For primitive types we must manually box the primitive value.
- case t if t <:< definitions.IntTpe => Some(classOf[java.lang.Integer])
- case t if t <:< definitions.LongTpe => Some(classOf[java.lang.Long])
- case t if t <:< definitions.DoubleTpe => Some(classOf[java.lang.Double])
- case t if t <:< definitions.FloatTpe => Some(classOf[java.lang.Float])
- case t if t <:< definitions.ShortTpe => Some(classOf[java.lang.Short])
- case t if t <:< definitions.ByteTpe => Some(classOf[java.lang.Byte])
- case t if t <:< definitions.BooleanTpe => Some(classOf[java.lang.Boolean])
- case _ => None
- }
-
- boxedType.map { boxedType =>
- val objectType = ObjectType(boxedType)
- WrapOption(
- objectType,
- NewInstance(
- boxedType,
- getPath :: Nil,
- propagateNull = true,
- objectType))
- }.getOrElse {
- val className: String = optType.erasure.typeSymbol.asClass.fullName
- val cls = Utils.classForName(className)
- val objectType = ObjectType(cls)
-
- WrapOption(objectType, constructorFor(optType, path))
- }
+ WrapOption(constructorFor(optType, path))
case t if t <:< localTypeOf[java.lang.Integer] =>
val boxedType = classOf[java.lang.Integer]
@@ -231,11 +200,11 @@ object ScalaReflection extends ScalaReflection {
case t if t <:< localTypeOf[java.math.BigDecimal] =>
Invoke(getPath, "toJavaBigDecimal", ObjectType(classOf[java.math.BigDecimal]))
+ case t if t <:< localTypeOf[BigDecimal] =>
+ Invoke(getPath, "toBigDecimal", ObjectType(classOf[BigDecimal]))
+
case t if t <:< localTypeOf[Array[_]] =>
val TypeRef(_, _, Seq(elementType)) = t
- val elementDataType = dataTypeFor(elementType)
- val Schema(dataType, nullable) = schemaFor(elementType)
-
val primitiveMethod = elementType match {
case t if t <:< definitions.IntTpe => Some("toIntArray")
case t if t <:< definitions.LongTpe => Some("toLongArray")
@@ -248,57 +217,52 @@ object ScalaReflection extends ScalaReflection {
}
primitiveMethod.map { method =>
- Invoke(getPath, method, dataTypeFor(t))
+ Invoke(getPath, method, arrayClassFor(elementType))
}.getOrElse {
- val returnType = dataTypeFor(t)
Invoke(
- MapObjects(p => constructorFor(elementType, Some(p)), getPath, dataType),
+ MapObjects(
+ p => constructorFor(elementType, Some(p)),
+ getPath,
+ schemaFor(elementType).dataType),
"array",
- returnType)
+ arrayClassFor(elementType))
}
+ case t if t <:< localTypeOf[Seq[_]] =>
+ val TypeRef(_, _, Seq(elementType)) = t
+ val arrayData =
+ Invoke(
+ MapObjects(
+ p => constructorFor(elementType, Some(p)),
+ getPath,
+ schemaFor(elementType).dataType),
+ "array",
+ ObjectType(classOf[Array[Any]]))
+
+ StaticInvoke(
+ scala.collection.mutable.WrappedArray,
+ ObjectType(classOf[Seq[_]]),
+ "make",
+ arrayData :: Nil)
+
case t if t <:< localTypeOf[Map[_, _]] =>
val TypeRef(_, _, Seq(keyType, valueType)) = t
- val Schema(keyDataType, _) = schemaFor(keyType)
- val Schema(valueDataType, valueNullable) = schemaFor(valueType)
-
- val primitiveMethodKey = keyType match {
- case t if t <:< definitions.IntTpe => Some("toIntArray")
- case t if t <:< definitions.LongTpe => Some("toLongArray")
- case t if t <:< definitions.DoubleTpe => Some("toDoubleArray")
- case t if t <:< definitions.FloatTpe => Some("toFloatArray")
- case t if t <:< definitions.ShortTpe => Some("toShortArray")
- case t if t <:< definitions.ByteTpe => Some("toByteArray")
- case t if t <:< definitions.BooleanTpe => Some("toBooleanArray")
- case _ => None
- }
val keyData =
Invoke(
MapObjects(
p => constructorFor(keyType, Some(p)),
- Invoke(getPath, "keyArray", ArrayType(keyDataType)),
- keyDataType),
+ Invoke(getPath, "keyArray", ArrayType(schemaFor(keyType).dataType)),
+ schemaFor(keyType).dataType),
"array",
ObjectType(classOf[Array[Any]]))
- val primitiveMethodValue = valueType match {
- case t if t <:< definitions.IntTpe => Some("toIntArray")
- case t if t <:< definitions.LongTpe => Some("toLongArray")
- case t if t <:< definitions.DoubleTpe => Some("toDoubleArray")
- case t if t <:< definitions.FloatTpe => Some("toFloatArray")
- case t if t <:< definitions.ShortTpe => Some("toShortArray")
- case t if t <:< definitions.ByteTpe => Some("toByteArray")
- case t if t <:< definitions.BooleanTpe => Some("toBooleanArray")
- case _ => None
- }
-
val valueData =
Invoke(
MapObjects(
p => constructorFor(valueType, Some(p)),
- Invoke(getPath, "valueArray", ArrayType(valueDataType)),
- valueDataType),
+ Invoke(getPath, "valueArray", ArrayType(schemaFor(valueType).dataType)),
+ schemaFor(valueType).dataType),
"array",
ObjectType(classOf[Array[Any]]))
@@ -308,40 +272,6 @@ object ScalaReflection extends ScalaReflection {
"toScalaMap",
keyData :: valueData :: Nil)
- case t if t <:< localTypeOf[Seq[_]] =>
- val TypeRef(_, _, Seq(elementType)) = t
- val elementDataType = dataTypeFor(elementType)
- val Schema(dataType, nullable) = schemaFor(elementType)
-
- // Avoid boxing when possible by just wrapping a primitive array.
- val primitiveMethod = elementType match {
- case _ if nullable => None
- case t if t <:< definitions.IntTpe => Some("toIntArray")
- case t if t <:< definitions.LongTpe => Some("toLongArray")
- case t if t <:< definitions.DoubleTpe => Some("toDoubleArray")
- case t if t <:< definitions.FloatTpe => Some("toFloatArray")
- case t if t <:< definitions.ShortTpe => Some("toShortArray")
- case t if t <:< definitions.ByteTpe => Some("toByteArray")
- case t if t <:< definitions.BooleanTpe => Some("toBooleanArray")
- case _ => None
- }
-
- val arrayData = primitiveMethod.map { method =>
- Invoke(getPath, method, arrayClassFor(elementType))
- }.getOrElse {
- Invoke(
- MapObjects(p => constructorFor(elementType, Some(p)), getPath, dataType),
- "array",
- arrayClassFor(elementType))
- }
-
- StaticInvoke(
- scala.collection.mutable.WrappedArray,
- ObjectType(classOf[Seq[_]]),
- "make",
- arrayData :: Nil)
-
-
case t if t <:< localTypeOf[Product] =>
val formalTypeArgs = t.typeSymbol.asClass.typeParams
val TypeRef(_, _, actualTypeArgs) = t
@@ -361,8 +291,7 @@ object ScalaReflection extends ScalaReflection {
}
}
- val className: String = t.erasure.typeSymbol.asClass.fullName
- val cls = Utils.classForName(className)
+ val cls = mirror.runtimeClass(tpe.erasure.typeSymbol.asClass)
val arguments = params.head.zipWithIndex.map { case (p, i) =>
val fieldName = p.name.toString
@@ -370,7 +299,7 @@ object ScalaReflection extends ScalaReflection {
val dataType = schemaFor(fieldType).dataType
// For tuples, we grab the inner fields by ordinal instead of by name.
- if (className startsWith "scala.Tuple") {
+ if (cls.getName startsWith "scala.Tuple") {
constructorFor(fieldType, Some(addToPathOrdinal(i, dataType)))
} else {
constructorFor(fieldType, Some(addToPath(fieldName)))
@@ -388,22 +317,19 @@ object ScalaReflection extends ScalaReflection {
} else {
newInstance
}
-
}
}
/** Returns expressions for extracting all the fields from the given type. */
def extractorsFor[T : TypeTag](inputObject: Expression): CreateNamedStruct = {
- ScalaReflectionLock.synchronized {
- extractorFor(inputObject, typeTag[T].tpe) match {
- case s: CreateNamedStruct => s
- case o => CreateNamedStruct(expressions.Literal("value") :: o :: Nil)
- }
+ extractorFor(inputObject, localTypeOf[T]) match {
+ case s: CreateNamedStruct => s
+ case other => CreateNamedStruct(expressions.Literal("value") :: other :: Nil)
}
}
/** Helper for extracting internal fields from a case class. */
- protected def extractorFor(
+ private def extractorFor(
inputObject: Expression,
tpe: `Type`): Expression = ScalaReflectionLock.synchronized {
if (!inputObject.dataType.isInstanceOf[ObjectType]) {
@@ -491,51 +417,36 @@ object ScalaReflection extends ScalaReflection {
case t if t <:< localTypeOf[Array[_]] =>
val TypeRef(_, _, Seq(elementType)) = t
- val elementDataType = dataTypeFor(elementType)
- val Schema(dataType, nullable) = schemaFor(elementType)
-
- if (!elementDataType.isInstanceOf[AtomicType]) {
- MapObjects(extractorFor(_, elementType), inputObject, elementDataType)
- } else {
- NewInstance(
- classOf[GenericArrayData],
- inputObject :: Nil,
- dataType = ArrayType(dataType, nullable))
- }
+ toCatalystArray(inputObject, elementType)
case t if t <:< localTypeOf[Seq[_]] =>
val TypeRef(_, _, Seq(elementType)) = t
- val elementDataType = dataTypeFor(elementType)
- val Schema(dataType, nullable) = schemaFor(elementType)
-
- if (dataType.isInstanceOf[AtomicType]) {
- NewInstance(
- classOf[GenericArrayData],
- inputObject :: Nil,
- dataType = ArrayType(dataType, nullable))
- } else {
- MapObjects(extractorFor(_, elementType), inputObject, elementDataType)
- }
+ toCatalystArray(inputObject, elementType)
case t if t <:< localTypeOf[Map[_, _]] =>
val TypeRef(_, _, Seq(keyType, valueType)) = t
- val Schema(keyDataType, _) = schemaFor(keyType)
- val Schema(valueDataType, valueNullable) = schemaFor(valueType)
- val rawMap = inputObject
val keys =
- NewInstance(
- classOf[GenericArrayData],
- Invoke(rawMap, "keys", ObjectType(classOf[scala.collection.GenIterable[_]])) :: Nil,
- dataType = ObjectType(classOf[ArrayData]))
+ Invoke(
+ Invoke(inputObject, "keysIterator",
+ ObjectType(classOf[scala.collection.Iterator[_]])),
+ "toSeq",
+ ObjectType(classOf[scala.collection.Seq[_]]))
+ val convertedKeys = toCatalystArray(keys, keyType)
+
val values =
- NewInstance(
- classOf[GenericArrayData],
- Invoke(rawMap, "values", ObjectType(classOf[scala.collection.GenIterable[_]])) :: Nil,
- dataType = ObjectType(classOf[ArrayData]))
+ Invoke(
+ Invoke(inputObject, "valuesIterator",
+ ObjectType(classOf[scala.collection.Iterator[_]])),
+ "toSeq",
+ ObjectType(classOf[scala.collection.Seq[_]]))
+ val convertedValues = toCatalystArray(values, valueType)
+
+ val Schema(keyDataType, _) = schemaFor(keyType)
+ val Schema(valueDataType, valueNullable) = schemaFor(valueType)
NewInstance(
classOf[ArrayBasedMapData],
- keys :: values :: Nil,
+ convertedKeys :: convertedValues :: Nil,
dataType = MapType(keyDataType, valueDataType, valueNullable))
case t if t <:< localTypeOf[String] =>
@@ -558,6 +469,7 @@ object ScalaReflection extends ScalaReflection {
DateType,
"fromJavaDate",
inputObject :: Nil)
+
case t if t <:< localTypeOf[BigDecimal] =>
StaticInvoke(
Decimal,
@@ -587,26 +499,24 @@ object ScalaReflection extends ScalaReflection {
case t if t <:< localTypeOf[java.lang.Boolean] =>
Invoke(inputObject, "booleanValue", BooleanType)
- case t if t <:< definitions.IntTpe =>
- BoundReference(0, IntegerType, false)
- case t if t <:< definitions.LongTpe =>
- BoundReference(0, LongType, false)
- case t if t <:< definitions.DoubleTpe =>
- BoundReference(0, DoubleType, false)
- case t if t <:< definitions.FloatTpe =>
- BoundReference(0, FloatType, false)
- case t if t <:< definitions.ShortTpe =>
- BoundReference(0, ShortType, false)
- case t if t <:< definitions.ByteTpe =>
- BoundReference(0, ByteType, false)
- case t if t <:< definitions.BooleanTpe =>
- BoundReference(0, BooleanType, false)
-
case other =>
throw new UnsupportedOperationException(s"Extractor for type $other is not supported")
}
}
}
+
+ private def toCatalystArray(input: Expression, elementType: `Type`): Expression = {
+ val externalDataType = dataTypeFor(elementType)
+ val Schema(catalystType, nullable) = schemaFor(elementType)
+ if (isNativeType(catalystType)) {
+ NewInstance(
+ classOf[GenericArrayData],
+ input :: Nil,
+ dataType = ArrayType(catalystType, nullable))
+ } else {
+ MapObjects(extractorFor(_, elementType), input, externalDataType)
+ }
+ }
}
/**
@@ -635,8 +545,7 @@ trait ScalaReflection {
}
/** Returns a catalyst DataType and its nullability for the given Scala Type using reflection. */
- def schemaFor[T: TypeTag]: Schema =
- ScalaReflectionLock.synchronized { schemaFor(localTypeOf[T]) }
+ def schemaFor[T: TypeTag]: Schema = schemaFor(localTypeOf[T])
/**
* Return the Scala Type for `T` in the current classloader mirror.
@@ -736,39 +645,4 @@ trait ScalaReflection {
assert(methods.length == 1)
methods.head.getParameterTypes
}
-
- def typeOfObject: PartialFunction[Any, DataType] = {
- // The data type can be determined without ambiguity.
- case obj: Boolean => BooleanType
- case obj: Array[Byte] => BinaryType
- case obj: String => StringType
- case obj: UTF8String => StringType
- case obj: Byte => ByteType
- case obj: Short => ShortType
- case obj: Int => IntegerType
- case obj: Long => LongType
- case obj: Float => FloatType
- case obj: Double => DoubleType
- case obj: java.sql.Date => DateType
- case obj: java.math.BigDecimal => DecimalType.SYSTEM_DEFAULT
- case obj: Decimal => DecimalType.SYSTEM_DEFAULT
- case obj: java.sql.Timestamp => TimestampType
- case null => NullType
- // For other cases, there is no obvious mapping from the type of the given object to a
- // Catalyst data type. A user should provide his/her specific rules
- // (in a user-defined PartialFunction) to infer the Catalyst data type for other types of
- // objects and then compose the user-defined PartialFunction with this one.
- }
-
- implicit class CaseClassRelation[A <: Product : TypeTag](data: Seq[A]) {
-
- /**
- * Implicitly added to Sequences of case class objects. Returns a catalyst logical relation
- * for the the data in the sequence.
- */
- def asRelation: LocalRelation = {
- val output = attributesFor[A]
- LocalRelation.fromProduct(output, data)
- }
- }
}
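The consolidated `toCatalystArray` above branches on the relocated `isNativeType`. Here is a self-contained sketch of that dispatch using stand-in types (not Spark's real `DataType` hierarchy); the real code builds `GenericArrayData` or `MapObjects` expressions instead of touching elements directly:

```scala
// Stand-ins for Catalyst's DataType hierarchy, for illustration only.
sealed trait DataType
case object IntegerType extends DataType
case object StringType  extends DataType

// Mirrors ScalaReflection.isNativeType: true when a value's internal and
// external representations coincide, so no per-element conversion is needed.
def isNativeType(dt: DataType): Boolean = dt match {
  case IntegerType => true
  case _           => false
}

// Simplified shape of the toCatalystArray decision.
def toCatalystArray(elems: Seq[Any], dt: DataType): Seq[Any] =
  if (isNativeType(dt)) elems              // wrap as-is (GenericArrayData path)
  else elems.map(e => s"converted($e)")    // stand-in for the MapObjects path
```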
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoder.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoder.scala
index 456b595008..6eeba1442c 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoder.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoder.scala
@@ -30,10 +30,10 @@ import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.expressions.codegen.{GenerateSafeProjection, GenerateUnsafeProjection}
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.ScalaReflection
-import org.apache.spark.sql.types.{NullType, StructField, ObjectType, StructType}
+import org.apache.spark.sql.types.{StructField, ObjectType, StructType}
/**
- * A factory for constructing encoders that convert objects and primitves to and from the
+ * A factory for constructing encoders that convert objects and primitives to and from the
* internal row format using catalyst expressions and code generation. By default, the
* expressions used to retrieve values from an input row when producing an object will be created as
* follows:
@@ -44,20 +44,21 @@ import org.apache.spark.sql.types.{NullType, StructField, ObjectType, StructType
* to the name `value`.
*/
object ExpressionEncoder {
- def apply[T : TypeTag](flat: Boolean = false): ExpressionEncoder[T] = {
+ def apply[T : TypeTag](): ExpressionEncoder[T] = {
// We convert the not-serializable TypeTag into StructType and ClassTag.
val mirror = typeTag[T].mirror
val cls = mirror.runtimeClass(typeTag[T].tpe)
+ val flat = !classOf[Product].isAssignableFrom(cls)
- val inputObject = BoundReference(0, ObjectType(cls), nullable = true)
- val extractExpression = ScalaReflection.extractorsFor[T](inputObject)
- val constructExpression = ScalaReflection.constructorFor[T]
+ val inputObject = BoundReference(0, ScalaReflection.dataTypeFor[T], nullable = true)
+ val toRowExpression = ScalaReflection.extractorsFor[T](inputObject)
+ val fromRowExpression = ScalaReflection.constructorFor[T]
new ExpressionEncoder[T](
- extractExpression.dataType,
+ toRowExpression.dataType,
flat,
- extractExpression.flatten,
- constructExpression,
+ toRowExpression.flatten,
+ fromRowExpression,
ClassTag[T](cls))
}
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/FlatEncoder.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/FlatEncoder.scala
deleted file mode 100644
index 6d307ab13a..0000000000
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/FlatEncoder.scala
+++ /dev/null
@@ -1,50 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one or more
- * contributor license agreements. See the NOTICE file distributed with
- * this work for additional information regarding copyright ownership.
- * The ASF licenses this file to You under the Apache License, Version 2.0
- * (the "License"); you may not use this file except in compliance with
- * the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-package org.apache.spark.sql.catalyst.encoders
-
-import scala.reflect.ClassTag
-import scala.reflect.runtime.universe.{typeTag, TypeTag}
-
-import org.apache.spark.sql.types.StructType
-import org.apache.spark.sql.catalyst.expressions.{Literal, CreateNamedStruct, BoundReference}
-import org.apache.spark.sql.catalyst.ScalaReflection
-
-object FlatEncoder {
- import ScalaReflection.schemaFor
- import ScalaReflection.dataTypeFor
-
- def apply[T : TypeTag]: ExpressionEncoder[T] = {
- // We convert the not-serializable TypeTag into StructType and ClassTag.
- val tpe = typeTag[T].tpe
- val mirror = typeTag[T].mirror
- val cls = mirror.runtimeClass(tpe)
- assert(!schemaFor(tpe).dataType.isInstanceOf[StructType])
-
- val input = BoundReference(0, dataTypeFor(tpe), nullable = true)
- val toRowExpression = CreateNamedStruct(
- Literal("value") :: ProductEncoder.extractorFor(input, tpe) :: Nil)
- val fromRowExpression = ProductEncoder.constructorFor(tpe)
-
- new ExpressionEncoder[T](
- toRowExpression.dataType,
- flat = true,
- toRowExpression.flatten,
- fromRowExpression,
- ClassTag[T](cls))
- }
-}
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/ProductEncoder.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/ProductEncoder.scala
deleted file mode 100644
index 2914c6ee79..0000000000
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/ProductEncoder.scala
+++ /dev/null
@@ -1,452 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one or more
- * contributor license agreements. See the NOTICE file distributed with
- * this work for additional information regarding copyright ownership.
- * The ASF licenses this file to You under the Apache License, Version 2.0
- * (the "License"); you may not use this file except in compliance with
- * the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-package org.apache.spark.sql.catalyst.encoders
-
-import org.apache.spark.util.Utils
-import org.apache.spark.unsafe.types.UTF8String
-import org.apache.spark.sql.types._
-import org.apache.spark.sql.catalyst.ScalaReflectionLock
-import org.apache.spark.sql.catalyst.ScalaReflection
-import org.apache.spark.sql.catalyst.analysis.{UnresolvedAttribute, UnresolvedExtractValue}
-import org.apache.spark.sql.catalyst.expressions
-import org.apache.spark.sql.catalyst.expressions._
-import org.apache.spark.sql.catalyst.util.{DateTimeUtils, ArrayBasedMapData, GenericArrayData}
-
-import scala.reflect.ClassTag
-
-object ProductEncoder {
- import ScalaReflection.universe._
- import ScalaReflection.mirror
- import ScalaReflection.localTypeOf
- import ScalaReflection.dataTypeFor
- import ScalaReflection.Schema
- import ScalaReflection.schemaFor
- import ScalaReflection.arrayClassFor
-
- def apply[T <: Product : TypeTag]: ExpressionEncoder[T] = {
- // We convert the not-serializable TypeTag into StructType and ClassTag.
- val tpe = typeTag[T].tpe
- val mirror = typeTag[T].mirror
- val cls = mirror.runtimeClass(tpe)
-
- val inputObject = BoundReference(0, ObjectType(cls), nullable = true)
- val toRowExpression = extractorFor(inputObject, tpe).asInstanceOf[CreateNamedStruct]
- val fromRowExpression = constructorFor(tpe)
-
- new ExpressionEncoder[T](
- toRowExpression.dataType,
- flat = false,
- toRowExpression.flatten,
- fromRowExpression,
- ClassTag[T](cls))
- }
-
- // The Predef.Map is scala.collection.immutable.Map.
- // Since the map values can be mutable, we explicitly import scala.collection.Map at here.
- import scala.collection.Map
-
- def extractorFor(
- inputObject: Expression,
- tpe: `Type`): Expression = ScalaReflectionLock.synchronized {
- if (!inputObject.dataType.isInstanceOf[ObjectType]) {
- inputObject
- } else {
- tpe match {
- case t if t <:< localTypeOf[Option[_]] =>
- val TypeRef(_, _, Seq(optType)) = t
- optType match {
- // For primitive types we must manually unbox the value of the object.
- case t if t <:< definitions.IntTpe =>
- Invoke(
- UnwrapOption(ObjectType(classOf[java.lang.Integer]), inputObject),
- "intValue",
- IntegerType)
- case t if t <:< definitions.LongTpe =>
- Invoke(
- UnwrapOption(ObjectType(classOf[java.lang.Long]), inputObject),
- "longValue",
- LongType)
- case t if t <:< definitions.DoubleTpe =>
- Invoke(
- UnwrapOption(ObjectType(classOf[java.lang.Double]), inputObject),
- "doubleValue",
- DoubleType)
- case t if t <:< definitions.FloatTpe =>
- Invoke(
- UnwrapOption(ObjectType(classOf[java.lang.Float]), inputObject),
- "floatValue",
- FloatType)
- case t if t <:< definitions.ShortTpe =>
- Invoke(
- UnwrapOption(ObjectType(classOf[java.lang.Short]), inputObject),
- "shortValue",
- ShortType)
- case t if t <:< definitions.ByteTpe =>
- Invoke(
- UnwrapOption(ObjectType(classOf[java.lang.Byte]), inputObject),
- "byteValue",
- ByteType)
- case t if t <:< definitions.BooleanTpe =>
- Invoke(
- UnwrapOption(ObjectType(classOf[java.lang.Boolean]), inputObject),
- "booleanValue",
- BooleanType)
-
- // For non-primitives, we can just extract the object from the Option and then recurse.
- case other =>
- val className: String = optType.erasure.typeSymbol.asClass.fullName
- val classObj = Utils.classForName(className)
- val optionObjectType = ObjectType(classObj)
-
- val unwrapped = UnwrapOption(optionObjectType, inputObject)
- expressions.If(
- IsNull(unwrapped),
- expressions.Literal.create(null, schemaFor(optType).dataType),
- extractorFor(unwrapped, optType))
- }
-
- case t if t <:< localTypeOf[Product] =>
- val formalTypeArgs = t.typeSymbol.asClass.typeParams
- val TypeRef(_, _, actualTypeArgs) = t
- val constructorSymbol = t.member(nme.CONSTRUCTOR)
- val params = if (constructorSymbol.isMethod) {
- constructorSymbol.asMethod.paramss
- } else {
- // Find the primary constructor, and use its parameter ordering.
- val primaryConstructorSymbol: Option[Symbol] =
- constructorSymbol.asTerm.alternatives.find(s =>
- s.isMethod && s.asMethod.isPrimaryConstructor)
-
- if (primaryConstructorSymbol.isEmpty) {
- sys.error("Internal SQL error: Product object did not have a primary constructor.")
- } else {
- primaryConstructorSymbol.get.asMethod.paramss
- }
- }
-
- CreateNamedStruct(params.head.flatMap { p =>
- val fieldName = p.name.toString
- val fieldType = p.typeSignature.substituteTypes(formalTypeArgs, actualTypeArgs)
- val fieldValue = Invoke(inputObject, fieldName, dataTypeFor(fieldType))
- expressions.Literal(fieldName) :: extractorFor(fieldValue, fieldType) :: Nil
- })
-
- case t if t <:< localTypeOf[Array[_]] =>
- val TypeRef(_, _, Seq(elementType)) = t
- toCatalystArray(inputObject, elementType)
-
- case t if t <:< localTypeOf[Seq[_]] =>
- val TypeRef(_, _, Seq(elementType)) = t
- toCatalystArray(inputObject, elementType)
-
- case t if t <:< localTypeOf[Map[_, _]] =>
- val TypeRef(_, _, Seq(keyType, valueType)) = t
-
- val keys =
- Invoke(
- Invoke(inputObject, "keysIterator",
- ObjectType(classOf[scala.collection.Iterator[_]])),
- "toSeq",
- ObjectType(classOf[scala.collection.Seq[_]]))
- val convertedKeys = toCatalystArray(keys, keyType)
-
- val values =
- Invoke(
- Invoke(inputObject, "valuesIterator",
- ObjectType(classOf[scala.collection.Iterator[_]])),
- "toSeq",
- ObjectType(classOf[scala.collection.Seq[_]]))
- val convertedValues = toCatalystArray(values, valueType)
-
- val Schema(keyDataType, _) = schemaFor(keyType)
- val Schema(valueDataType, valueNullable) = schemaFor(valueType)
- NewInstance(
- classOf[ArrayBasedMapData],
- convertedKeys :: convertedValues :: Nil,
- dataType = MapType(keyDataType, valueDataType, valueNullable))
-
- case t if t <:< localTypeOf[String] =>
- StaticInvoke(
- classOf[UTF8String],
- StringType,
- "fromString",
- inputObject :: Nil)
-
- case t if t <:< localTypeOf[java.sql.Timestamp] =>
- StaticInvoke(
- DateTimeUtils,
- TimestampType,
- "fromJavaTimestamp",
- inputObject :: Nil)
-
- case t if t <:< localTypeOf[java.sql.Date] =>
- StaticInvoke(
- DateTimeUtils,
- DateType,
- "fromJavaDate",
- inputObject :: Nil)
-
- case t if t <:< localTypeOf[BigDecimal] =>
- StaticInvoke(
- Decimal,
- DecimalType.SYSTEM_DEFAULT,
- "apply",
- inputObject :: Nil)
-
- case t if t <:< localTypeOf[java.math.BigDecimal] =>
- StaticInvoke(
- Decimal,
- DecimalType.SYSTEM_DEFAULT,
- "apply",
- inputObject :: Nil)
-
- case t if t <:< localTypeOf[java.lang.Integer] =>
- Invoke(inputObject, "intValue", IntegerType)
- case t if t <:< localTypeOf[java.lang.Long] =>
- Invoke(inputObject, "longValue", LongType)
- case t if t <:< localTypeOf[java.lang.Double] =>
- Invoke(inputObject, "doubleValue", DoubleType)
- case t if t <:< localTypeOf[java.lang.Float] =>
- Invoke(inputObject, "floatValue", FloatType)
- case t if t <:< localTypeOf[java.lang.Short] =>
- Invoke(inputObject, "shortValue", ShortType)
- case t if t <:< localTypeOf[java.lang.Byte] =>
- Invoke(inputObject, "byteValue", ByteType)
- case t if t <:< localTypeOf[java.lang.Boolean] =>
- Invoke(inputObject, "booleanValue", BooleanType)
-
- case other =>
- throw new UnsupportedOperationException(s"Encoder for type $other is not supported")
- }
- }
- }
-
- private def toCatalystArray(input: Expression, elementType: `Type`): Expression = {
- val externalDataType = dataTypeFor(elementType)
- val Schema(catalystType, nullable) = schemaFor(elementType)
- if (RowEncoder.isNativeType(catalystType)) {
- NewInstance(
- classOf[GenericArrayData],
- input :: Nil,
- dataType = ArrayType(catalystType, nullable))
- } else {
- MapObjects(extractorFor(_, elementType), input, externalDataType)
- }
- }
-
- def constructorFor(
- tpe: `Type`,
- path: Option[Expression] = None): Expression = ScalaReflectionLock.synchronized {
-
- /** Returns the current path with a sub-field extracted. */
- def addToPath(part: String): Expression = path
- .map(p => UnresolvedExtractValue(p, expressions.Literal(part)))
- .getOrElse(UnresolvedAttribute(part))
-
- /** Returns the current path with a field at ordinal extracted. */
- def addToPathOrdinal(ordinal: Int, dataType: DataType): Expression = path
- .map(p => GetInternalRowField(p, ordinal, dataType))
- .getOrElse(BoundReference(ordinal, dataType, false))
-
- /** Returns the current path or `BoundReference`. */
- def getPath: Expression = path.getOrElse(BoundReference(0, schemaFor(tpe).dataType, true))
-
- tpe match {
- case t if !dataTypeFor(t).isInstanceOf[ObjectType] => getPath
-
- case t if t <:< localTypeOf[Option[_]] =>
- val TypeRef(_, _, Seq(optType)) = t
- WrapOption(null, constructorFor(optType, path))
-
- case t if t <:< localTypeOf[java.lang.Integer] =>
- val boxedType = classOf[java.lang.Integer]
- val objectType = ObjectType(boxedType)
- NewInstance(boxedType, getPath :: Nil, propagateNull = true, objectType)
-
- case t if t <:< localTypeOf[java.lang.Long] =>
- val boxedType = classOf[java.lang.Long]
- val objectType = ObjectType(boxedType)
- NewInstance(boxedType, getPath :: Nil, propagateNull = true, objectType)
-
- case t if t <:< localTypeOf[java.lang.Double] =>
- val boxedType = classOf[java.lang.Double]
- val objectType = ObjectType(boxedType)
- NewInstance(boxedType, getPath :: Nil, propagateNull = true, objectType)
-
- case t if t <:< localTypeOf[java.lang.Float] =>
- val boxedType = classOf[java.lang.Float]
- val objectType = ObjectType(boxedType)
- NewInstance(boxedType, getPath :: Nil, propagateNull = true, objectType)
-
- case t if t <:< localTypeOf[java.lang.Short] =>
- val boxedType = classOf[java.lang.Short]
- val objectType = ObjectType(boxedType)
- NewInstance(boxedType, getPath :: Nil, propagateNull = true, objectType)
-
- case t if t <:< localTypeOf[java.lang.Byte] =>
- val boxedType = classOf[java.lang.Byte]
- val objectType = ObjectType(boxedType)
- NewInstance(boxedType, getPath :: Nil, propagateNull = true, objectType)
-
- case t if t <:< localTypeOf[java.lang.Boolean] =>
- val boxedType = classOf[java.lang.Boolean]
- val objectType = ObjectType(boxedType)
- NewInstance(boxedType, getPath :: Nil, propagateNull = true, objectType)
-
- case t if t <:< localTypeOf[java.sql.Date] =>
- StaticInvoke(
- DateTimeUtils,
- ObjectType(classOf[java.sql.Date]),
- "toJavaDate",
- getPath :: Nil,
- propagateNull = true)
-
- case t if t <:< localTypeOf[java.sql.Timestamp] =>
- StaticInvoke(
- DateTimeUtils,
- ObjectType(classOf[java.sql.Timestamp]),
- "toJavaTimestamp",
- getPath :: Nil,
- propagateNull = true)
-
- case t if t <:< localTypeOf[java.lang.String] =>
- Invoke(getPath, "toString", ObjectType(classOf[String]))
-
- case t if t <:< localTypeOf[java.math.BigDecimal] =>
- Invoke(getPath, "toJavaBigDecimal", ObjectType(classOf[java.math.BigDecimal]))
-
- case t if t <:< localTypeOf[BigDecimal] =>
- Invoke(getPath, "toBigDecimal", ObjectType(classOf[BigDecimal]))
-
- case t if t <:< localTypeOf[Array[_]] =>
- val TypeRef(_, _, Seq(elementType)) = t
- val primitiveMethod = elementType match {
- case t if t <:< definitions.IntTpe => Some("toIntArray")
- case t if t <:< definitions.LongTpe => Some("toLongArray")
- case t if t <:< definitions.DoubleTpe => Some("toDoubleArray")
- case t if t <:< definitions.FloatTpe => Some("toFloatArray")
- case t if t <:< definitions.ShortTpe => Some("toShortArray")
- case t if t <:< definitions.ByteTpe => Some("toByteArray")
- case t if t <:< definitions.BooleanTpe => Some("toBooleanArray")
- case _ => None
- }
-
- primitiveMethod.map { method =>
- Invoke(getPath, method, arrayClassFor(elementType))
- }.getOrElse {
- Invoke(
- MapObjects(
- p => constructorFor(elementType, Some(p)),
- getPath,
- schemaFor(elementType).dataType),
- "array",
- arrayClassFor(elementType))
- }
-
- case t if t <:< localTypeOf[Seq[_]] =>
- val TypeRef(_, _, Seq(elementType)) = t
- val arrayData =
- Invoke(
- MapObjects(
- p => constructorFor(elementType, Some(p)),
- getPath,
- schemaFor(elementType).dataType),
- "array",
- ObjectType(classOf[Array[Any]]))
-
- StaticInvoke(
- scala.collection.mutable.WrappedArray,
- ObjectType(classOf[Seq[_]]),
- "make",
- arrayData :: Nil)
-
- case t if t <:< localTypeOf[Map[_, _]] =>
- val TypeRef(_, _, Seq(keyType, valueType)) = t
-
- val keyData =
- Invoke(
- MapObjects(
- p => constructorFor(keyType, Some(p)),
- Invoke(getPath, "keyArray", ArrayType(schemaFor(keyType).dataType)),
- schemaFor(keyType).dataType),
- "array",
- ObjectType(classOf[Array[Any]]))
-
- val valueData =
- Invoke(
- MapObjects(
- p => constructorFor(valueType, Some(p)),
- Invoke(getPath, "valueArray", ArrayType(schemaFor(valueType).dataType)),
- schemaFor(valueType).dataType),
- "array",
- ObjectType(classOf[Array[Any]]))
-
- StaticInvoke(
- ArrayBasedMapData,
- ObjectType(classOf[Map[_, _]]),
- "toScalaMap",
- keyData :: valueData :: Nil)
-
- case t if t <:< localTypeOf[Product] =>
- val formalTypeArgs = t.typeSymbol.asClass.typeParams
- val TypeRef(_, _, actualTypeArgs) = t
- val constructorSymbol = t.member(nme.CONSTRUCTOR)
- val params = if (constructorSymbol.isMethod) {
- constructorSymbol.asMethod.paramss
- } else {
- // Find the primary constructor, and use its parameter ordering.
- val primaryConstructorSymbol: Option[Symbol] =
- constructorSymbol.asTerm.alternatives.find(s =>
- s.isMethod && s.asMethod.isPrimaryConstructor)
-
- if (primaryConstructorSymbol.isEmpty) {
- sys.error("Internal SQL error: Product object did not have a primary constructor.")
- } else {
- primaryConstructorSymbol.get.asMethod.paramss
- }
- }
-
- val cls = mirror.runtimeClass(tpe.erasure.typeSymbol.asClass)
-
- val arguments = params.head.zipWithIndex.map { case (p, i) =>
- val fieldName = p.name.toString
- val fieldType = p.typeSignature.substituteTypes(formalTypeArgs, actualTypeArgs)
- val dataType = schemaFor(fieldType).dataType
-
- // For tuples, we grab the inner fields by ordinal instead of by name.
- if (cls.getName startsWith "scala.Tuple") {
- constructorFor(fieldType, Some(addToPathOrdinal(i, dataType)))
- } else {
- constructorFor(fieldType, Some(addToPath(fieldName)))
- }
- }
-
- val newInstance = NewInstance(cls, arguments, propagateNull = false, ObjectType(cls))
-
- if (path.nonEmpty) {
- expressions.If(
- IsNull(getPath),
- expressions.Literal.create(null, ObjectType(cls)),
- newInstance
- )
- } else {
- newInstance
- }
- }
- }
-}
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/RowEncoder.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/RowEncoder.scala
index 9bb1602494..4cda4824ac 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/RowEncoder.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/RowEncoder.scala
@@ -23,6 +23,7 @@ import scala.reflect.ClassTag
import org.apache.spark.sql.Row
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.util.{GenericArrayData, ArrayBasedMapData, DateTimeUtils}
+import org.apache.spark.sql.catalyst.ScalaReflection
import org.apache.spark.sql.types._
import org.apache.spark.unsafe.types.UTF8String
@@ -132,17 +133,8 @@ object RowEncoder {
CreateStruct(convertedFields)
}
- /**
- * Returns true if the value of this data type is same between internal and external.
- */
- def isNativeType(dt: DataType): Boolean = dt match {
- case BooleanType | ByteType | ShortType | IntegerType | LongType |
- FloatType | DoubleType | BinaryType => true
- case _ => false
- }
-
private def externalDataTypeFor(dt: DataType): DataType = dt match {
- case _ if isNativeType(dt) => dt
+ case _ if ScalaReflection.isNativeType(dt) => dt
case TimestampType => ObjectType(classOf[java.sql.Timestamp])
case DateType => ObjectType(classOf[java.sql.Date])
case _: DecimalType => ObjectType(classOf[java.math.BigDecimal])
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects.scala
index f865a9408e..ef7399e019 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects.scala
@@ -24,7 +24,6 @@ import org.apache.spark.SparkConf
import org.apache.spark.serializer._
import org.apache.spark.sql.Row
import org.apache.spark.sql.catalyst.analysis.SimpleAnalyzer
-import org.apache.spark.sql.catalyst.encoders.ProductEncoder
import org.apache.spark.sql.catalyst.plans.logical.{Project, LocalRelation}
import org.apache.spark.sql.catalyst.util.GenericArrayData
import org.apache.spark.sql.catalyst.InternalRow
@@ -300,10 +299,9 @@ case class UnwrapOption(
/**
* Converts the result of evaluating `child` into an option, checking both the isNull bit and
* (in the case of reference types) equality with null.
- * @param optionType The datatype to be held inside of the Option.
* @param child The expression to evaluate and wrap.
*/
-case class WrapOption(optionType: DataType, child: Expression)
+case class WrapOption(child: Expression)
extends UnaryExpression with ExpectsInputTypes {
override def dataType: DataType = ObjectType(classOf[Option[_]])
@@ -316,14 +314,13 @@ case class WrapOption(optionType: DataType, child: Expression)
throw new UnsupportedOperationException("Only code-generated evaluation is supported")
override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = {
- val javaType = ctx.javaType(optionType)
val inputObject = child.gen(ctx)
s"""
${inputObject.code}
boolean ${ev.isNull} = false;
- scala.Option<$javaType> ${ev.value} =
+ scala.Option ${ev.value} =
${inputObject.isNull} ?
scala.Option$$.MODULE$$.apply(null) : new scala.Some(${inputObject.value});
"""
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/ScalaReflectionSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/ScalaReflectionSuite.scala
index 4ea410d492..c2aace1ef2 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/ScalaReflectionSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/ScalaReflectionSuite.scala
@@ -186,74 +186,6 @@ class ScalaReflectionSuite extends SparkFunSuite {
nullable = true))
}
- test("get data type of a value") {
- // BooleanType
- assert(BooleanType === typeOfObject(true))
- assert(BooleanType === typeOfObject(false))
-
- // BinaryType
- assert(BinaryType === typeOfObject("string".getBytes))
-
- // StringType
- assert(StringType === typeOfObject("string"))
-
- // ByteType
- assert(ByteType === typeOfObject(127.toByte))
-
- // ShortType
- assert(ShortType === typeOfObject(32767.toShort))
-
- // IntegerType
- assert(IntegerType === typeOfObject(2147483647))
-
- // LongType
- assert(LongType === typeOfObject(9223372036854775807L))
-
- // FloatType
- assert(FloatType === typeOfObject(3.4028235E38.toFloat))
-
- // DoubleType
- assert(DoubleType === typeOfObject(1.7976931348623157E308))
-
- // DecimalType
- assert(DecimalType.SYSTEM_DEFAULT ===
- typeOfObject(new java.math.BigDecimal("1.7976931348623157E318")))
-
- // DateType
- assert(DateType === typeOfObject(Date.valueOf("2014-07-25")))
-
- // TimestampType
- assert(TimestampType === typeOfObject(Timestamp.valueOf("2014-07-25 10:26:00")))
-
- // NullType
- assert(NullType === typeOfObject(null))
-
- def typeOfObject1: PartialFunction[Any, DataType] = typeOfObject orElse {
- case value: java.math.BigInteger => DecimalType.SYSTEM_DEFAULT
- case value: java.math.BigDecimal => DecimalType.SYSTEM_DEFAULT
- case _ => StringType
- }
-
- assert(DecimalType.SYSTEM_DEFAULT === typeOfObject1(
- new BigInteger("92233720368547758070")))
- assert(DecimalType.SYSTEM_DEFAULT === typeOfObject1(
- new java.math.BigDecimal("1.7976931348623157E318")))
- assert(StringType === typeOfObject1(BigInt("92233720368547758070")))
-
- def typeOfObject2: PartialFunction[Any, DataType] = typeOfObject orElse {
- case value: java.math.BigInteger => DecimalType.SYSTEM_DEFAULT
- }
-
- intercept[MatchError](typeOfObject2(BigInt("92233720368547758070")))
-
- def typeOfObject3: PartialFunction[Any, DataType] = typeOfObject orElse {
- case c: Seq[_] => ArrayType(typeOfObject3(c.head))
- }
-
- assert(ArrayType(IntegerType) === typeOfObject3(Seq(1, 2, 3)))
- assert(ArrayType(ArrayType(IntegerType)) === typeOfObject3(Seq(Seq(1, 2, 3))))
- }
-
test("convert PrimitiveData to catalyst") {
val data = PrimitiveData(1, 1, 1, 1, 1, 1, true)
val convertedData = InternalRow(1, 1.toLong, 1.toDouble, 1.toFloat, 1.toShort, 1.toByte, true)
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoderSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoderSuite.scala
index cde0364f3d..76459b34a4 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoderSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoderSuite.scala
@@ -17,24 +17,234 @@
package org.apache.spark.sql.catalyst.encoders
+import java.sql.{Timestamp, Date}
import java.util.Arrays
import java.util.concurrent.ConcurrentMap
+import scala.collection.mutable.ArrayBuffer
+import scala.reflect.runtime.universe.TypeTag
import com.google.common.collect.MapMaker
import org.apache.spark.SparkFunSuite
+import org.apache.spark.sql.Encoders
import org.apache.spark.sql.catalyst.expressions.AttributeReference
import org.apache.spark.sql.catalyst.util.ArrayData
+import org.apache.spark.sql.catalyst.{OptionalData, PrimitiveData}
import org.apache.spark.sql.types.ArrayType
-abstract class ExpressionEncoderSuite extends SparkFunSuite {
- val outers: ConcurrentMap[String, AnyRef] = new MapMaker().weakValues().makeMap()
+case class RepeatedStruct(s: Seq[PrimitiveData])
- protected def encodeDecodeTest[T](
+case class NestedArray(a: Array[Array[Int]]) {
+ override def equals(other: Any): Boolean = other match {
+ case NestedArray(otherArray) =>
+ java.util.Arrays.deepEquals(
+ a.asInstanceOf[Array[AnyRef]],
+ otherArray.asInstanceOf[Array[AnyRef]])
+ case _ => false
+ }
+}
+
+case class BoxedData(
+ intField: java.lang.Integer,
+ longField: java.lang.Long,
+ doubleField: java.lang.Double,
+ floatField: java.lang.Float,
+ shortField: java.lang.Short,
+ byteField: java.lang.Byte,
+ booleanField: java.lang.Boolean)
+
+case class RepeatedData(
+ arrayField: Seq[Int],
+ arrayFieldContainsNull: Seq[java.lang.Integer],
+ mapField: scala.collection.Map[Int, Long],
+ mapFieldNull: scala.collection.Map[Int, java.lang.Long],
+ structField: PrimitiveData)
+
+case class SpecificCollection(l: List[Int])
+
+/** For testing Kryo serialization based encoder. */
+class KryoSerializable(val value: Int) {
+ override def equals(other: Any): Boolean = {
+ this.value == other.asInstanceOf[KryoSerializable].value
+ }
+}
+
+/** For testing Java serialization based encoder. */
+class JavaSerializable(val value: Int) extends Serializable {
+ override def equals(other: Any): Boolean = {
+ this.value == other.asInstanceOf[JavaSerializable].value
+ }
+}
+
+class ExpressionEncoderSuite extends SparkFunSuite {
+ implicit def encoder[T : TypeTag]: ExpressionEncoder[T] = ExpressionEncoder()
+
+ // test flat encoders
+ encodeDecodeTest(false, "primitive boolean")
+ encodeDecodeTest(-3.toByte, "primitive byte")
+ encodeDecodeTest(-3.toShort, "primitive short")
+ encodeDecodeTest(-3, "primitive int")
+ encodeDecodeTest(-3L, "primitive long")
+ encodeDecodeTest(-3.7f, "primitive float")
+ encodeDecodeTest(-3.7, "primitive double")
+
+ encodeDecodeTest(new java.lang.Boolean(false), "boxed boolean")
+ encodeDecodeTest(new java.lang.Byte(-3.toByte), "boxed byte")
+ encodeDecodeTest(new java.lang.Short(-3.toShort), "boxed short")
+ encodeDecodeTest(new java.lang.Integer(-3), "boxed int")
+ encodeDecodeTest(new java.lang.Long(-3L), "boxed long")
+ encodeDecodeTest(new java.lang.Float(-3.7f), "boxed float")
+ encodeDecodeTest(new java.lang.Double(-3.7), "boxed double")
+
+ encodeDecodeTest(BigDecimal("32131413.211321313"), "scala decimal")
+ // encodeDecodeTest(new java.math.BigDecimal("231341.23123"), "java decimal")
+
+ encodeDecodeTest("hello", "string")
+ encodeDecodeTest(Date.valueOf("2012-12-23"), "date")
+ encodeDecodeTest(Timestamp.valueOf("2016-01-29 10:00:00"), "timestamp")
+ encodeDecodeTest(Array[Byte](13, 21, -23), "binary")
+
+ encodeDecodeTest(Seq(31, -123, 4), "seq of int")
+ encodeDecodeTest(Seq("abc", "xyz"), "seq of string")
+ encodeDecodeTest(Seq("abc", null, "xyz"), "seq of string with null")
+ encodeDecodeTest(Seq.empty[Int], "empty seq of int")
+ encodeDecodeTest(Seq.empty[String], "empty seq of string")
+
+ encodeDecodeTest(Seq(Seq(31, -123), null, Seq(4, 67)), "seq of seq of int")
+ encodeDecodeTest(Seq(Seq("abc", "xyz"), Seq[String](null), null, Seq("1", null, "2")),
+ "seq of seq of string")
+
+ encodeDecodeTest(Array(31, -123, 4), "array of int")
+ encodeDecodeTest(Array("abc", "xyz"), "array of string")
+ encodeDecodeTest(Array("a", null, "x"), "array of string with null")
+ encodeDecodeTest(Array.empty[Int], "empty array of int")
+ encodeDecodeTest(Array.empty[String], "empty array of string")
+
+ encodeDecodeTest(Array(Array(31, -123), null, Array(4, 67)), "array of array of int")
+ encodeDecodeTest(Array(Array("abc", "xyz"), Array[String](null), null, Array("1", null, "2")),
+ "array of array of string")
+
+ encodeDecodeTest(Map(1 -> "a", 2 -> "b"), "map")
+ encodeDecodeTest(Map(1 -> "a", 2 -> null), "map with null")
+ encodeDecodeTest(Map(1 -> Map("a" -> 1), 2 -> Map("b" -> 2)), "map of map")
+
+ // Kryo encoders
+ encodeDecodeTest("hello", "kryo string")(encoderFor(Encoders.kryo[String]))
+ encodeDecodeTest(new KryoSerializable(15), "kryo object")(
+ encoderFor(Encoders.kryo[KryoSerializable]))
+
+ // Java encoders
+ encodeDecodeTest("hello", "java string")(encoderFor(Encoders.javaSerialization[String]))
+ encodeDecodeTest(new JavaSerializable(15), "java object")(
+ encoderFor(Encoders.javaSerialization[JavaSerializable]))
+
+ // test product encoders
+ private def productTest[T <: Product : ExpressionEncoder](input: T): Unit = {
+ encodeDecodeTest(input, input.getClass.getSimpleName)
+ }
+
+ case class InnerClass(i: Int)
+ productTest(InnerClass(1))
+
+ productTest(PrimitiveData(1, 1, 1, 1, 1, 1, true))
+
+ productTest(
+ OptionalData(Some(2), Some(2), Some(2), Some(2), Some(2), Some(2), Some(true),
+ Some(PrimitiveData(1, 1, 1, 1, 1, 1, true))))
+
+ productTest(OptionalData(None, None, None, None, None, None, None, None))
+
+ productTest(BoxedData(1, 1L, 1.0, 1.0f, 1.toShort, 1.toByte, true))
+
+ productTest(BoxedData(null, null, null, null, null, null, null))
+
+ productTest(RepeatedStruct(PrimitiveData(1, 1, 1, 1, 1, 1, true) :: Nil))
+
+ productTest((1, "test", PrimitiveData(1, 1, 1, 1, 1, 1, true)))
+
+ productTest(
+ RepeatedData(
+ Seq(1, 2),
+ Seq(new Integer(1), null, new Integer(2)),
+ Map(1 -> 2L),
+ Map(1 -> null),
+ PrimitiveData(1, 1, 1, 1, 1, 1, true)))
+
+ productTest(NestedArray(Array(Array(1, -2, 3), null, Array(4, 5, -6))))
+
+ productTest(("Seq[(String, String)]",
+ Seq(("a", "b"))))
+ productTest(("Seq[(Int, Int)]",
+ Seq((1, 2))))
+ productTest(("Seq[(Long, Long)]",
+ Seq((1L, 2L))))
+ productTest(("Seq[(Float, Float)]",
+ Seq((1.toFloat, 2.toFloat))))
+ productTest(("Seq[(Double, Double)]",
+ Seq((1.toDouble, 2.toDouble))))
+ productTest(("Seq[(Short, Short)]",
+ Seq((1.toShort, 2.toShort))))
+ productTest(("Seq[(Byte, Byte)]",
+ Seq((1.toByte, 2.toByte))))
+ productTest(("Seq[(Boolean, Boolean)]",
+ Seq((true, false))))
+
+ productTest(("ArrayBuffer[(String, String)]",
+ ArrayBuffer(("a", "b"))))
+ productTest(("ArrayBuffer[(Int, Int)]",
+ ArrayBuffer((1, 2))))
+ productTest(("ArrayBuffer[(Long, Long)]",
+ ArrayBuffer((1L, 2L))))
+ productTest(("ArrayBuffer[(Float, Float)]",
+ ArrayBuffer((1.toFloat, 2.toFloat))))
+ productTest(("ArrayBuffer[(Double, Double)]",
+ ArrayBuffer((1.toDouble, 2.toDouble))))
+ productTest(("ArrayBuffer[(Short, Short)]",
+ ArrayBuffer((1.toShort, 2.toShort))))
+ productTest(("ArrayBuffer[(Byte, Byte)]",
+ ArrayBuffer((1.toByte, 2.toByte))))
+ productTest(("ArrayBuffer[(Boolean, Boolean)]",
+ ArrayBuffer((true, false))))
+
+ productTest(("Seq[Seq[(Int, Int)]]",
+ Seq(Seq((1, 2)))))
+
+ // Tests for ExpressionEncoder.tuple
+ encodeDecodeTest(
+ 1 -> 10L,
+ "tuple with 2 flat encoders")(
+ ExpressionEncoder.tuple(ExpressionEncoder[Int], ExpressionEncoder[Long]))
+
+ encodeDecodeTest(
+ (PrimitiveData(1, 1, 1, 1, 1, 1, true), (3, 30L)),
+ "tuple with 2 product encoders")(
+ ExpressionEncoder.tuple(ExpressionEncoder[PrimitiveData], ExpressionEncoder[(Int, Long)]))
+
+ encodeDecodeTest(
+ (PrimitiveData(1, 1, 1, 1, 1, 1, true), 3),
+ "tuple with flat encoder and product encoder")(
+ ExpressionEncoder.tuple(ExpressionEncoder[PrimitiveData], ExpressionEncoder[Int]))
+
+ encodeDecodeTest(
+ (3, PrimitiveData(1, 1, 1, 1, 1, 1, true)),
+ "tuple with product encoder and flat encoder")(
+ ExpressionEncoder.tuple(ExpressionEncoder[Int], ExpressionEncoder[PrimitiveData]))
+
+ encodeDecodeTest(
+ (1, (10, 100L)),
+ "nested tuple encoder") {
+ val intEnc = ExpressionEncoder[Int]
+ val longEnc = ExpressionEncoder[Long]
+ ExpressionEncoder.tuple(intEnc, ExpressionEncoder.tuple(intEnc, longEnc))
+ }
+
+ private val outers: ConcurrentMap[String, AnyRef] = new MapMaker().weakValues().makeMap()
+ outers.put(getClass.getName, this)
+ private def encodeDecodeTest[T : ExpressionEncoder](
input: T,
- encoder: ExpressionEncoder[T],
testName: String): Unit = {
test(s"encode/decode for $testName: $input") {
+ val encoder = implicitly[ExpressionEncoder[T]]
val row = encoder.toRow(input)
val schema = encoder.schema.toAttributes
val boundEncoder = encoder.resolve(schema, outers).bind(schema)
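
With the encoder now supplied through a context bound rather than an explicit parameter, it is worth spelling out the round trip the reworked helper performs. Below is a minimal, self-contained sketch of that round trip, assuming this revision's ExpressionEncoder API (toRow/resolve/bind/fromRow); the object and method names are illustrative, not part of the patch.

import java.util.concurrent.ConcurrentMap

import com.google.common.collect.MapMaker

import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder

object RoundTripSketch {
  // Weak-valued map of outer-scope instances, mirroring the suite's `outers`.
  private val outers: ConcurrentMap[String, AnyRef] = new MapMaker().weakValues().makeMap()

  // Encode `input` into a Catalyst InternalRow and decode it back.
  def roundTrip[T : ExpressionEncoder](input: T): T = {
    val encoder = implicitly[ExpressionEncoder[T]]  // materialized from the context bound
    val row     = encoder.toRow(input)              // object -> InternalRow
    val attrs   = encoder.schema.toAttributes       // attributes for resolution and binding
    val bound   = encoder.resolve(attrs, outers).bind(attrs)
    bound.fromRow(row)                              // InternalRow -> object
  }
}
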
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/FlatEncoderSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/FlatEncoderSuite.scala
deleted file mode 100644
index 07523d49f4..0000000000
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/FlatEncoderSuite.scala
+++ /dev/null
@@ -1,99 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one or more
- * contributor license agreements. See the NOTICE file distributed with
- * this work for additional information regarding copyright ownership.
- * The ASF licenses this file to You under the Apache License, Version 2.0
- * (the "License"); you may not use this file except in compliance with
- * the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-package org.apache.spark.sql.catalyst.encoders
-
-import java.sql.{Date, Timestamp}
-import org.apache.spark.sql.Encoders
-
-class FlatEncoderSuite extends ExpressionEncoderSuite {
- encodeDecodeTest(false, FlatEncoder[Boolean], "primitive boolean")
- encodeDecodeTest(-3.toByte, FlatEncoder[Byte], "primitive byte")
- encodeDecodeTest(-3.toShort, FlatEncoder[Short], "primitive short")
- encodeDecodeTest(-3, FlatEncoder[Int], "primitive int")
- encodeDecodeTest(-3L, FlatEncoder[Long], "primitive long")
- encodeDecodeTest(-3.7f, FlatEncoder[Float], "primitive float")
- encodeDecodeTest(-3.7, FlatEncoder[Double], "primitive double")
-
- encodeDecodeTest(new java.lang.Boolean(false), FlatEncoder[java.lang.Boolean], "boxed boolean")
- encodeDecodeTest(new java.lang.Byte(-3.toByte), FlatEncoder[java.lang.Byte], "boxed byte")
- encodeDecodeTest(new java.lang.Short(-3.toShort), FlatEncoder[java.lang.Short], "boxed short")
- encodeDecodeTest(new java.lang.Integer(-3), FlatEncoder[java.lang.Integer], "boxed int")
- encodeDecodeTest(new java.lang.Long(-3L), FlatEncoder[java.lang.Long], "boxed long")
- encodeDecodeTest(new java.lang.Float(-3.7f), FlatEncoder[java.lang.Float], "boxed float")
- encodeDecodeTest(new java.lang.Double(-3.7), FlatEncoder[java.lang.Double], "boxed double")
-
- encodeDecodeTest(BigDecimal("32131413.211321313"), FlatEncoder[BigDecimal], "scala decimal")
- type JDecimal = java.math.BigDecimal
- // encodeDecodeTest(new JDecimal("231341.23123"), FlatEncoder[JDecimal], "java decimal")
-
- encodeDecodeTest("hello", FlatEncoder[String], "string")
- encodeDecodeTest(Date.valueOf("2012-12-23"), FlatEncoder[Date], "date")
- encodeDecodeTest(Timestamp.valueOf("2016-01-29 10:00:00"), FlatEncoder[Timestamp], "timestamp")
- encodeDecodeTest(Array[Byte](13, 21, -23), FlatEncoder[Array[Byte]], "binary")
-
- encodeDecodeTest(Seq(31, -123, 4), FlatEncoder[Seq[Int]], "seq of int")
- encodeDecodeTest(Seq("abc", "xyz"), FlatEncoder[Seq[String]], "seq of string")
- encodeDecodeTest(Seq("abc", null, "xyz"), FlatEncoder[Seq[String]], "seq of string with null")
- encodeDecodeTest(Seq.empty[Int], FlatEncoder[Seq[Int]], "empty seq of int")
- encodeDecodeTest(Seq.empty[String], FlatEncoder[Seq[String]], "empty seq of string")
-
- encodeDecodeTest(Seq(Seq(31, -123), null, Seq(4, 67)),
- FlatEncoder[Seq[Seq[Int]]], "seq of seq of int")
- encodeDecodeTest(Seq(Seq("abc", "xyz"), Seq[String](null), null, Seq("1", null, "2")),
- FlatEncoder[Seq[Seq[String]]], "seq of seq of string")
-
- encodeDecodeTest(Array(31, -123, 4), FlatEncoder[Array[Int]], "array of int")
- encodeDecodeTest(Array("abc", "xyz"), FlatEncoder[Array[String]], "array of string")
- encodeDecodeTest(Array("a", null, "x"), FlatEncoder[Array[String]], "array of string with null")
- encodeDecodeTest(Array.empty[Int], FlatEncoder[Array[Int]], "empty array of int")
- encodeDecodeTest(Array.empty[String], FlatEncoder[Array[String]], "empty array of string")
-
- encodeDecodeTest(Array(Array(31, -123), null, Array(4, 67)),
- FlatEncoder[Array[Array[Int]]], "array of array of int")
- encodeDecodeTest(Array(Array("abc", "xyz"), Array[String](null), null, Array("1", null, "2")),
- FlatEncoder[Array[Array[String]]], "array of array of string")
-
- encodeDecodeTest(Map(1 -> "a", 2 -> "b"), FlatEncoder[Map[Int, String]], "map")
- encodeDecodeTest(Map(1 -> "a", 2 -> null), FlatEncoder[Map[Int, String]], "map with null")
- encodeDecodeTest(Map(1 -> Map("a" -> 1), 2 -> Map("b" -> 2)),
- FlatEncoder[Map[Int, Map[String, Int]]], "map of map")
-
- // Kryo encoders
- encodeDecodeTest("hello", encoderFor(Encoders.kryo[String]), "kryo string")
- encodeDecodeTest(new KryoSerializable(15),
- encoderFor(Encoders.kryo[KryoSerializable]), "kryo object")
-
- // Java encoders
- encodeDecodeTest("hello", encoderFor(Encoders.javaSerialization[String]), "java string")
- encodeDecodeTest(new JavaSerializable(15),
- encoderFor(Encoders.javaSerialization[JavaSerializable]), "java object")
-}
-
-/** For testing Kryo serialization based encoder. */
-class KryoSerializable(val value: Int) {
- override def equals(other: Any): Boolean = {
- this.value == other.asInstanceOf[KryoSerializable].value
- }
-}
-
-/** For testing Java serialization based encoder. */
-class JavaSerializable(val value: Int) extends Serializable {
- override def equals(other: Any): Boolean = {
- this.value == other.asInstanceOf[JavaSerializable].value
- }
-}
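
With FlatEncoderSuite deleted, its primitive and collection cases live in ExpressionEncoderSuite above and obtain their encoder implicitly. For readers tracking the API change, here is a hedged call-site contrast; the commented line shows the deleted builder and is for comparison only.

import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder

// val intEnc = FlatEncoder[Int]             // before this patch (builder now removed)
val intEnc    = ExpressionEncoder[Int]()     // after: one builder covers flat types
val stringEnc = ExpressionEncoder[String]()  // ...and strings, collections, etc.
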
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/ProductEncoderSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/ProductEncoderSuite.scala
deleted file mode 100644
index 1798514c5c..0000000000
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/ProductEncoderSuite.scala
+++ /dev/null
@@ -1,156 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one or more
- * contributor license agreements. See the NOTICE file distributed with
- * this work for additional information regarding copyright ownership.
- * The ASF licenses this file to You under the Apache License, Version 2.0
- * (the "License"); you may not use this file except in compliance with
- * the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-package org.apache.spark.sql.catalyst.encoders
-
-import scala.collection.mutable.ArrayBuffer
-import scala.reflect.runtime.universe.TypeTag
-
-import org.apache.spark.sql.catalyst.{OptionalData, PrimitiveData}
-
-case class RepeatedStruct(s: Seq[PrimitiveData])
-
-case class NestedArray(a: Array[Array[Int]]) {
- override def equals(other: Any): Boolean = other match {
- case NestedArray(otherArray) =>
- java.util.Arrays.deepEquals(
- a.asInstanceOf[Array[AnyRef]],
- otherArray.asInstanceOf[Array[AnyRef]])
- case _ => false
- }
-}
-
-case class BoxedData(
- intField: java.lang.Integer,
- longField: java.lang.Long,
- doubleField: java.lang.Double,
- floatField: java.lang.Float,
- shortField: java.lang.Short,
- byteField: java.lang.Byte,
- booleanField: java.lang.Boolean)
-
-case class RepeatedData(
- arrayField: Seq[Int],
- arrayFieldContainsNull: Seq[java.lang.Integer],
- mapField: scala.collection.Map[Int, Long],
- mapFieldNull: scala.collection.Map[Int, java.lang.Long],
- structField: PrimitiveData)
-
-case class SpecificCollection(l: List[Int])
-
-class ProductEncoderSuite extends ExpressionEncoderSuite {
- outers.put(getClass.getName, this)
-
- case class InnerClass(i: Int)
- productTest(InnerClass(1))
-
- productTest(PrimitiveData(1, 1, 1, 1, 1, 1, true))
-
- productTest(
- OptionalData(Some(2), Some(2), Some(2), Some(2), Some(2), Some(2), Some(true),
- Some(PrimitiveData(1, 1, 1, 1, 1, 1, true))))
-
- productTest(OptionalData(None, None, None, None, None, None, None, None))
-
- productTest(BoxedData(1, 1L, 1.0, 1.0f, 1.toShort, 1.toByte, true))
-
- productTest(BoxedData(null, null, null, null, null, null, null))
-
- productTest(RepeatedStruct(PrimitiveData(1, 1, 1, 1, 1, 1, true) :: Nil))
-
- productTest((1, "test", PrimitiveData(1, 1, 1, 1, 1, 1, true)))
-
- productTest(
- RepeatedData(
- Seq(1, 2),
- Seq(new Integer(1), null, new Integer(2)),
- Map(1 -> 2L),
- Map(1 -> null),
- PrimitiveData(1, 1, 1, 1, 1, 1, true)))
-
- productTest(NestedArray(Array(Array(1, -2, 3), null, Array(4, 5, -6))))
-
- productTest(("Seq[(String, String)]",
- Seq(("a", "b"))))
- productTest(("Seq[(Int, Int)]",
- Seq((1, 2))))
- productTest(("Seq[(Long, Long)]",
- Seq((1L, 2L))))
- productTest(("Seq[(Float, Float)]",
- Seq((1.toFloat, 2.toFloat))))
- productTest(("Seq[(Double, Double)]",
- Seq((1.toDouble, 2.toDouble))))
- productTest(("Seq[(Short, Short)]",
- Seq((1.toShort, 2.toShort))))
- productTest(("Seq[(Byte, Byte)]",
- Seq((1.toByte, 2.toByte))))
- productTest(("Seq[(Boolean, Boolean)]",
- Seq((true, false))))
-
- productTest(("ArrayBuffer[(String, String)]",
- ArrayBuffer(("a", "b"))))
- productTest(("ArrayBuffer[(Int, Int)]",
- ArrayBuffer((1, 2))))
- productTest(("ArrayBuffer[(Long, Long)]",
- ArrayBuffer((1L, 2L))))
- productTest(("ArrayBuffer[(Float, Float)]",
- ArrayBuffer((1.toFloat, 2.toFloat))))
- productTest(("ArrayBuffer[(Double, Double)]",
- ArrayBuffer((1.toDouble, 2.toDouble))))
- productTest(("ArrayBuffer[(Short, Short)]",
- ArrayBuffer((1.toShort, 2.toShort))))
- productTest(("ArrayBuffer[(Byte, Byte)]",
- ArrayBuffer((1.toByte, 2.toByte))))
- productTest(("ArrayBuffer[(Boolean, Boolean)]",
- ArrayBuffer((true, false))))
-
- productTest(("Seq[Seq[(Int, Int)]]",
- Seq(Seq((1, 2)))))
-
- encodeDecodeTest(
- 1 -> 10L,
- ExpressionEncoder.tuple(FlatEncoder[Int], FlatEncoder[Long]),
- "tuple with 2 flat encoders")
-
- encodeDecodeTest(
- (PrimitiveData(1, 1, 1, 1, 1, 1, true), (3, 30L)),
- ExpressionEncoder.tuple(ProductEncoder[PrimitiveData], ProductEncoder[(Int, Long)]),
- "tuple with 2 product encoders")
-
- encodeDecodeTest(
- (PrimitiveData(1, 1, 1, 1, 1, 1, true), 3),
- ExpressionEncoder.tuple(ProductEncoder[PrimitiveData], FlatEncoder[Int]),
- "tuple with flat encoder and product encoder")
-
- encodeDecodeTest(
- (3, PrimitiveData(1, 1, 1, 1, 1, 1, true)),
- ExpressionEncoder.tuple(FlatEncoder[Int], ProductEncoder[PrimitiveData]),
- "tuple with product encoder and flat encoder")
-
- encodeDecodeTest(
- (1, (10, 100L)),
- {
- val intEnc = FlatEncoder[Int]
- val longEnc = FlatEncoder[Long]
- ExpressionEncoder.tuple(intEnc, ExpressionEncoder.tuple(intEnc, longEnc))
- },
- "nested tuple encoder")
-
- private def productTest[T <: Product : TypeTag](input: T): Unit = {
- encodeDecodeTest(input, ProductEncoder[T], input.getClass.getSimpleName)
- }
-}
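
ProductEncoderSuite's cases likewise fold into ExpressionEncoderSuite, since the unified builder derives a case-class encoder from a TypeTag at runtime. A REPL-style sketch under that assumption; `Point` is a made-up case class used only for illustration.

import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder

case class Point(x: Int, y: Double)

val pointEnc = ExpressionEncoder[Point]()     // derived via reflection, works for any Product
val row      = pointEnc.toRow(Point(1, 2.0))  // struct flattened into an InternalRow
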
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/GroupedDataset.scala b/sql/core/src/main/scala/org/apache/spark/sql/GroupedDataset.scala
index 7e5acbe851..6de3dd6265 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/GroupedDataset.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/GroupedDataset.scala
@@ -21,7 +21,7 @@ import scala.collection.JavaConverters._
import org.apache.spark.annotation.Experimental
import org.apache.spark.api.java.function._
-import org.apache.spark.sql.catalyst.encoders.{FlatEncoder, ExpressionEncoder, encoderFor, OuterScopes}
+import org.apache.spark.sql.catalyst.encoders.{ExpressionEncoder, encoderFor, OuterScopes}
import org.apache.spark.sql.catalyst.expressions.{Alias, CreateStruct, Attribute}
import org.apache.spark.sql.catalyst.plans.logical._
import org.apache.spark.sql.execution.QueryExecution
@@ -242,7 +242,7 @@ class GroupedDataset[K, T] private[sql](
* Returns a [[Dataset]] that contains a tuple with each key and the number of items present
* for that key.
*/
- def count(): Dataset[(K, Long)] = agg(functions.count("*").as(FlatEncoder[Long]))
+ def count(): Dataset[(K, Long)] = agg(functions.count("*").as(ExpressionEncoder[Long]))
/**
* Applies the given function to each set of cogrouped data. For each unique group, the function will
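
The only behavioral surface touched here is count(), whose Long side is now decoded through ExpressionEncoder[Long] rather than FlatEncoder[Long]. A REPL-style usage sketch, with `sqlContext` and the sample data assumed:

import sqlContext.implicits._

val words  = sqlContext.createDataset(Seq("a", "b", "a"))
val counts = words.groupBy((w: String) => w).count()  // Dataset[(String, Long)]
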
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/SQLImplicits.scala b/sql/core/src/main/scala/org/apache/spark/sql/SQLImplicits.scala
index 8471eea1b7..25ffdcde17 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/SQLImplicits.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/SQLImplicits.scala
@@ -17,10 +17,6 @@
package org.apache.spark.sql
-import org.apache.spark.sql.catalyst.encoders._
-import org.apache.spark.sql.catalyst.plans.logical.LocalRelation
-import org.apache.spark.sql.execution.datasources.LogicalRelation
-
import scala.language.implicitConversions
import scala.reflect.runtime.universe.TypeTag
@@ -28,6 +24,7 @@ import org.apache.spark.rdd.RDD
import org.apache.spark.sql.types._
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions.SpecificMutableRow
+import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder
import org.apache.spark.sql.types.StructField
import org.apache.spark.unsafe.types.UTF8String
@@ -37,16 +34,16 @@ import org.apache.spark.unsafe.types.UTF8String
abstract class SQLImplicits {
protected def _sqlContext: SQLContext
- implicit def newProductEncoder[T <: Product : TypeTag]: Encoder[T] = ProductEncoder[T]
+ implicit def newProductEncoder[T <: Product : TypeTag]: Encoder[T] = ExpressionEncoder()
- implicit def newIntEncoder: Encoder[Int] = FlatEncoder[Int]
- implicit def newLongEncoder: Encoder[Long] = FlatEncoder[Long]
- implicit def newDoubleEncoder: Encoder[Double] = FlatEncoder[Double]
- implicit def newFloatEncoder: Encoder[Float] = FlatEncoder[Float]
- implicit def newByteEncoder: Encoder[Byte] = FlatEncoder[Byte]
- implicit def newShortEncoder: Encoder[Short] = FlatEncoder[Short]
- implicit def newBooleanEncoder: Encoder[Boolean] = FlatEncoder[Boolean]
- implicit def newStringEncoder: Encoder[String] = FlatEncoder[String]
+ implicit def newIntEncoder: Encoder[Int] = ExpressionEncoder()
+ implicit def newLongEncoder: Encoder[Long] = ExpressionEncoder()
+ implicit def newDoubleEncoder: Encoder[Double] = ExpressionEncoder()
+ implicit def newFloatEncoder: Encoder[Float] = ExpressionEncoder()
+ implicit def newByteEncoder: Encoder[Byte] = ExpressionEncoder()
+ implicit def newShortEncoder: Encoder[Short] = ExpressionEncoder()
+ implicit def newBooleanEncoder: Encoder[Boolean] = ExpressionEncoder()
+ implicit def newStringEncoder: Encoder[String] = ExpressionEncoder()
/**
* Creates a [[Dataset]] from an RDD.
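
Each implicit now delegates to the single ExpressionEncoder() builder instead of picking FlatEncoder or ProductEncoder by hand. A REPL-style sketch of what these implicits enable at a call site; `sqlContext`, `Person`, and the sample data are assumptions.

import sqlContext.implicits._

case class Person(name: String, age: Int)    // hypothetical case class

val people = Seq(Person("ann", 30)).toDS()   // via newProductEncoder -> ExpressionEncoder()
val ints   = Seq(1, 2, 3).toDS()             // via newIntEncoder -> ExpressionEncoder[Int]()
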
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala
index 95158de710..b27b1340cc 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala
@@ -26,7 +26,7 @@ import scala.util.Try
import org.apache.spark.annotation.Experimental
import org.apache.spark.sql.catalyst.{SqlParser, ScalaReflection}
import org.apache.spark.sql.catalyst.analysis.{UnresolvedFunction, Star}
-import org.apache.spark.sql.catalyst.encoders.FlatEncoder
+import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.expressions.aggregate._
import org.apache.spark.sql.catalyst.plans.logical.BroadcastHint
@@ -267,7 +267,7 @@ object functions extends LegacyFunctions {
* @since 1.3.0
*/
def count(columnName: String): TypedColumn[Any, Long] =
- count(Column(columnName)).as(FlatEncoder[Long])
+ count(Column(columnName)).as(ExpressionEncoder[Long])
/**
* Aggregate function: returns the number of distinct items in a group.
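
As in GroupedDataset, the typed count aggregate now carries ExpressionEncoder[Long]. A REPL-style sketch; the Dataset and its default column name "value" are assumptions.

import org.apache.spark.sql.functions.count

val letters = sqlContext.createDataset(Seq("x", "y", "x"))
val total   = letters.select(count("value"))  // Dataset[Long], typed via ExpressionEncoder[Long]
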