aboutsummaryrefslogtreecommitdiff
path: root/sql/catalyst
diff options
context:
space:
mode:
authorMichael Armbrust <michael@databricks.com>2015-10-13 17:09:17 -0700
committerMichael Armbrust <michael@databricks.com>2015-10-13 17:09:17 -0700
commit328d1b3e4bc39cce653342e04f9e08af12dd7ed8 (patch)
tree66f9910a45fc4ed6339578f4345563aeb9258476 /sql/catalyst
parent3889b1c7a96da1111946fa63ad69489b83468646 (diff)
downloadspark-328d1b3e4bc39cce653342e04f9e08af12dd7ed8.tar.gz
spark-328d1b3e4bc39cce653342e04f9e08af12dd7ed8.tar.bz2
spark-328d1b3e4bc39cce653342e04f9e08af12dd7ed8.zip
[SPARK-11090] [SQL] Constructor for Product types from InternalRow
This is a first draft of the ability to construct expressions that will take a catalyst internal row and construct a Product (case class or tuple) that has fields with the correct names. Support include: - Nested classes - Maps - Efficiently handling of arrays of primitive types Not yet supported: - Case classes that require custom collection types (i.e. List instead of Seq). Author: Michael Armbrust <michael@databricks.com> Closes #9100 from marmbrus/productContructor.
Diffstat (limited to 'sql/catalyst')
-rw-r--r--sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeArrayData.java4
-rw-r--r--sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala302
-rw-r--r--sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/Encoder.scala14
-rw-r--r--sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/ProductEncoder.scala26
-rw-r--r--sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects.scala154
-rw-r--r--sql/catalyst/src/main/scala/org/apache/spark/sql/types/ArrayBasedMapData.scala4
-rw-r--r--sql/catalyst/src/main/scala/org/apache/spark/sql/types/ArrayData.scala5
-rw-r--r--sql/catalyst/src/main/scala/org/apache/spark/sql/types/GenericArrayData.scala4
-rw-r--r--sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/ProductEncoderSuite.scala369
9 files changed, 723 insertions, 159 deletions
diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeArrayData.java b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeArrayData.java
index 796f8abec9..4c63abb071 100644
--- a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeArrayData.java
+++ b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeArrayData.java
@@ -74,6 +74,10 @@ public class UnsafeArrayData extends ArrayData {
assert ordinal < numElements : "ordinal (" + ordinal + ") should < " + numElements;
}
+ public Object[] array() {
+ throw new UnsupportedOperationException("Only supported on GenericArrayData.");
+ }
+
/**
* Construct a new UnsafeArrayData. The resulting UnsafeArrayData won't be usable until
* `pointTo()` has been called, since the value returned by this constructor is equivalent
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala
index 8b733f2a0b..8edd6498e5 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala
@@ -17,6 +17,7 @@
package org.apache.spark.sql.catalyst
+import org.apache.spark.sql.catalyst.analysis.{UnresolvedExtractValue, UnresolvedAttribute}
import org.apache.spark.sql.catalyst.util.DateTimeUtils
import org.apache.spark.unsafe.types.UTF8String
import org.apache.spark.util.Utils
@@ -80,6 +81,9 @@ trait ScalaReflection {
* Returns the Spark SQL DataType for a given scala type. Where this is not an exact mapping
* to a native type, an ObjectType is returned. Special handling is also used for Arrays including
* those that hold primitive types.
+ *
+ * Unlike `schemaFor`, this function doesn't do any massaging of types into the Spark SQL type
+ * system. As a result, ObjectType will be returned for things like boxed Integers
*/
def dataTypeFor(tpe: `Type`): DataType = tpe match {
case t if t <:< definitions.IntTpe => IntegerType
@@ -114,6 +118,298 @@ trait ScalaReflection {
}
}
+ /**
+ * Given a type `T` this function constructs and ObjectType that holds a class of type
+ * Array[T]. Special handling is performed for primitive types to map them back to their raw
+ * JVM form instead of the Scala Array that handles auto boxing.
+ */
+ def arrayClassFor(tpe: `Type`): DataType = {
+ val cls = tpe match {
+ case t if t <:< definitions.IntTpe => classOf[Array[Int]]
+ case t if t <:< definitions.LongTpe => classOf[Array[Long]]
+ case t if t <:< definitions.DoubleTpe => classOf[Array[Double]]
+ case t if t <:< definitions.FloatTpe => classOf[Array[Float]]
+ case t if t <:< definitions.ShortTpe => classOf[Array[Short]]
+ case t if t <:< definitions.ByteTpe => classOf[Array[Byte]]
+ case t if t <:< definitions.BooleanTpe => classOf[Array[Boolean]]
+ case other =>
+ // There is probably a better way to do this, but I couldn't find it...
+ val elementType = dataTypeFor(other).asInstanceOf[ObjectType].cls
+ java.lang.reflect.Array.newInstance(elementType, 1).getClass
+
+ }
+ ObjectType(cls)
+ }
+
+ /**
+ * Returns an expression that can be used to construct an object of type `T` given a an input
+ * row with a compatible schema. Fields of the row will be extracted using UnresolvedAttributes
+ * of the same name as the constructor arguments. Nested classes will have their fields accessed
+ * using UnresolvedExtractValue.
+ */
+ def constructorFor[T : TypeTag]: Expression = constructorFor(typeOf[T], None)
+
+ protected def constructorFor(
+ tpe: `Type`,
+ path: Option[Expression]): Expression = ScalaReflectionLock.synchronized {
+
+ /** Returns the current path with a sub-field extracted. */
+ def addToPath(part: String) =
+ path
+ .map(p => UnresolvedExtractValue(p, expressions.Literal(part)))
+ .getOrElse(UnresolvedAttribute(part))
+
+ /** Returns the current path or throws an error. */
+ def getPath = path.getOrElse(sys.error("Constructors must start at a class type"))
+
+ tpe match {
+ case t if !dataTypeFor(t).isInstanceOf[ObjectType] =>
+ getPath
+
+ case t if t <:< localTypeOf[Option[_]] =>
+ val TypeRef(_, _, Seq(optType)) = t
+ val boxedType = optType match {
+ // For primitive types we must manually box the primitive value.
+ case t if t <:< definitions.IntTpe => Some(classOf[java.lang.Integer])
+ case t if t <:< definitions.LongTpe => Some(classOf[java.lang.Long])
+ case t if t <:< definitions.DoubleTpe => Some(classOf[java.lang.Double])
+ case t if t <:< definitions.FloatTpe => Some(classOf[java.lang.Float])
+ case t if t <:< definitions.ShortTpe => Some(classOf[java.lang.Short])
+ case t if t <:< definitions.ByteTpe => Some(classOf[java.lang.Byte])
+ case t if t <:< definitions.BooleanTpe => Some(classOf[java.lang.Boolean])
+ case _ => None
+ }
+
+ boxedType.map { boxedType =>
+ val objectType = ObjectType(boxedType)
+ WrapOption(
+ objectType,
+ NewInstance(
+ boxedType,
+ getPath :: Nil,
+ propagateNull = true,
+ objectType))
+ }.getOrElse {
+ val className: String = optType.erasure.typeSymbol.asClass.fullName
+ val cls = Utils.classForName(className)
+ val objectType = ObjectType(cls)
+
+ WrapOption(objectType, constructorFor(optType, path))
+ }
+
+ case t if t <:< localTypeOf[java.lang.Integer] =>
+ val boxedType = classOf[java.lang.Integer]
+ val objectType = ObjectType(boxedType)
+ NewInstance(boxedType, getPath :: Nil, propagateNull = true, objectType)
+
+ case t if t <:< localTypeOf[java.lang.Long] =>
+ val boxedType = classOf[java.lang.Long]
+ val objectType = ObjectType(boxedType)
+ NewInstance(boxedType, getPath :: Nil, propagateNull = true, objectType)
+
+ case t if t <:< localTypeOf[java.lang.Double] =>
+ val boxedType = classOf[java.lang.Double]
+ val objectType = ObjectType(boxedType)
+ NewInstance(boxedType, getPath :: Nil, propagateNull = true, objectType)
+
+ case t if t <:< localTypeOf[java.lang.Float] =>
+ val boxedType = classOf[java.lang.Float]
+ val objectType = ObjectType(boxedType)
+ NewInstance(boxedType, getPath :: Nil, propagateNull = true, objectType)
+
+ case t if t <:< localTypeOf[java.lang.Short] =>
+ val boxedType = classOf[java.lang.Short]
+ val objectType = ObjectType(boxedType)
+ NewInstance(boxedType, getPath :: Nil, propagateNull = true, objectType)
+
+ case t if t <:< localTypeOf[java.lang.Byte] =>
+ val boxedType = classOf[java.lang.Byte]
+ val objectType = ObjectType(boxedType)
+ NewInstance(boxedType, getPath :: Nil, propagateNull = true, objectType)
+
+ case t if t <:< localTypeOf[java.lang.Boolean] =>
+ val boxedType = classOf[java.lang.Boolean]
+ val objectType = ObjectType(boxedType)
+ NewInstance(boxedType, getPath :: Nil, propagateNull = true, objectType)
+
+ case t if t <:< localTypeOf[java.sql.Date] =>
+ StaticInvoke(
+ DateTimeUtils,
+ ObjectType(classOf[java.sql.Date]),
+ "toJavaDate",
+ getPath :: Nil,
+ propagateNull = true)
+
+ case t if t <:< localTypeOf[java.sql.Timestamp] =>
+ StaticInvoke(
+ DateTimeUtils,
+ ObjectType(classOf[java.sql.Timestamp]),
+ "toJavaTimestamp",
+ getPath :: Nil,
+ propagateNull = true)
+
+ case t if t <:< localTypeOf[java.lang.String] =>
+ Invoke(getPath, "toString", ObjectType(classOf[String]))
+
+ case t if t <:< localTypeOf[java.math.BigDecimal] =>
+ Invoke(getPath, "toJavaBigDecimal", ObjectType(classOf[java.math.BigDecimal]))
+
+ case t if t <:< localTypeOf[Array[_]] =>
+ val TypeRef(_, _, Seq(elementType)) = t
+ val elementDataType = dataTypeFor(elementType)
+ val Schema(dataType, nullable) = schemaFor(elementType)
+
+ val primitiveMethod = elementType match {
+ case t if t <:< definitions.IntTpe => Some("toIntArray")
+ case t if t <:< definitions.LongTpe => Some("toLongArray")
+ case t if t <:< definitions.DoubleTpe => Some("toDoubleArray")
+ case t if t <:< definitions.FloatTpe => Some("toFloatArray")
+ case t if t <:< definitions.ShortTpe => Some("toShortArray")
+ case t if t <:< definitions.ByteTpe => Some("toByteArray")
+ case t if t <:< definitions.BooleanTpe => Some("toBooleanArray")
+ case _ => None
+ }
+
+ primitiveMethod.map { method =>
+ Invoke(getPath, method, dataTypeFor(t))
+ }.getOrElse {
+ val returnType = dataTypeFor(t)
+ Invoke(
+ MapObjects(p => constructorFor(elementType, Some(p)), getPath, dataType),
+ "array",
+ returnType)
+ }
+
+ case t if t <:< localTypeOf[Map[_, _]] =>
+ val TypeRef(_, _, Seq(keyType, valueType)) = t
+ val Schema(keyDataType, _) = schemaFor(keyType)
+ val Schema(valueDataType, valueNullable) = schemaFor(valueType)
+
+ val primitiveMethodKey = keyType match {
+ case t if t <:< definitions.IntTpe => Some("toIntArray")
+ case t if t <:< definitions.LongTpe => Some("toLongArray")
+ case t if t <:< definitions.DoubleTpe => Some("toDoubleArray")
+ case t if t <:< definitions.FloatTpe => Some("toFloatArray")
+ case t if t <:< definitions.ShortTpe => Some("toShortArray")
+ case t if t <:< definitions.ByteTpe => Some("toByteArray")
+ case t if t <:< definitions.BooleanTpe => Some("toBooleanArray")
+ case _ => None
+ }
+
+ val keyData =
+ Invoke(
+ MapObjects(
+ p => constructorFor(keyType, Some(p)),
+ Invoke(getPath, "keyArray", ArrayType(keyDataType)),
+ keyDataType),
+ "array",
+ ObjectType(classOf[Array[Any]]))
+
+ val primitiveMethodValue = valueType match {
+ case t if t <:< definitions.IntTpe => Some("toIntArray")
+ case t if t <:< definitions.LongTpe => Some("toLongArray")
+ case t if t <:< definitions.DoubleTpe => Some("toDoubleArray")
+ case t if t <:< definitions.FloatTpe => Some("toFloatArray")
+ case t if t <:< definitions.ShortTpe => Some("toShortArray")
+ case t if t <:< definitions.ByteTpe => Some("toByteArray")
+ case t if t <:< definitions.BooleanTpe => Some("toBooleanArray")
+ case _ => None
+ }
+
+ val valueData =
+ Invoke(
+ MapObjects(
+ p => constructorFor(valueType, Some(p)),
+ Invoke(getPath, "valueArray", ArrayType(valueDataType)),
+ valueDataType),
+ "array",
+ ObjectType(classOf[Array[Any]]))
+
+ StaticInvoke(
+ ArrayBasedMapData,
+ ObjectType(classOf[Map[_, _]]),
+ "toScalaMap",
+ keyData :: valueData :: Nil)
+
+ case t if t <:< localTypeOf[Seq[_]] =>
+ val TypeRef(_, _, Seq(elementType)) = t
+ val elementDataType = dataTypeFor(elementType)
+ val Schema(dataType, nullable) = schemaFor(elementType)
+
+ // Avoid boxing when possible by just wrapping a primitive array.
+ val primitiveMethod = elementType match {
+ case _ if nullable => None
+ case t if t <:< definitions.IntTpe => Some("toIntArray")
+ case t if t <:< definitions.LongTpe => Some("toLongArray")
+ case t if t <:< definitions.DoubleTpe => Some("toDoubleArray")
+ case t if t <:< definitions.FloatTpe => Some("toFloatArray")
+ case t if t <:< definitions.ShortTpe => Some("toShortArray")
+ case t if t <:< definitions.ByteTpe => Some("toByteArray")
+ case t if t <:< definitions.BooleanTpe => Some("toBooleanArray")
+ case _ => None
+ }
+
+ val arrayData = primitiveMethod.map { method =>
+ Invoke(getPath, method, arrayClassFor(elementType))
+ }.getOrElse {
+ Invoke(
+ MapObjects(p => constructorFor(elementType, Some(p)), getPath, dataType),
+ "array",
+ arrayClassFor(elementType))
+ }
+
+ StaticInvoke(
+ scala.collection.mutable.WrappedArray,
+ ObjectType(classOf[Seq[_]]),
+ "make",
+ arrayData :: Nil)
+
+
+ case t if t <:< localTypeOf[Product] =>
+ val formalTypeArgs = t.typeSymbol.asClass.typeParams
+ val TypeRef(_, _, actualTypeArgs) = t
+ val constructorSymbol = t.member(nme.CONSTRUCTOR)
+ val params = if (constructorSymbol.isMethod) {
+ constructorSymbol.asMethod.paramss
+ } else {
+ // Find the primary constructor, and use its parameter ordering.
+ val primaryConstructorSymbol: Option[Symbol] =
+ constructorSymbol.asTerm.alternatives.find(s =>
+ s.isMethod && s.asMethod.isPrimaryConstructor)
+
+ if (primaryConstructorSymbol.isEmpty) {
+ sys.error("Internal SQL error: Product object did not have a primary constructor.")
+ } else {
+ primaryConstructorSymbol.get.asMethod.paramss
+ }
+ }
+
+ val className: String = t.erasure.typeSymbol.asClass.fullName
+ val cls = Utils.classForName(className)
+
+ val arguments = params.head.map { p =>
+ val fieldName = p.name.toString
+ val fieldType = p.typeSignature.substituteTypes(formalTypeArgs, actualTypeArgs)
+ val dataType = dataTypeFor(fieldType)
+
+ constructorFor(fieldType, Some(addToPath(fieldName)))
+ }
+
+ val newInstance = NewInstance(cls, arguments, propagateNull = false, ObjectType(cls))
+
+ if (path.nonEmpty) {
+ expressions.If(
+ IsNull(getPath),
+ expressions.Literal.create(null, ObjectType(cls)),
+ newInstance
+ )
+ } else {
+ newInstance
+ }
+
+ }
+ }
+
/** Returns expressions for extracting all the fields from the given type. */
def extractorsFor[T : TypeTag](inputObject: Expression): Seq[Expression] = {
ScalaReflectionLock.synchronized {
@@ -227,13 +523,13 @@ trait ScalaReflection {
val elementDataType = dataTypeFor(elementType)
val Schema(dataType, nullable) = schemaFor(elementType)
- if (!elementDataType.isInstanceOf[AtomicType]) {
- MapObjects(extractorFor(_, elementType), inputObject, elementDataType)
- } else {
+ if (dataType.isInstanceOf[AtomicType]) {
NewInstance(
classOf[GenericArrayData],
inputObject :: Nil,
dataType = ArrayType(dataType, nullable))
+ } else {
+ MapObjects(extractorFor(_, elementType), inputObject, elementDataType)
}
case t if t <:< localTypeOf[Map[_, _]] =>
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/Encoder.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/Encoder.scala
index 8dacfa9477..3618247d5d 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/Encoder.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/Encoder.scala
@@ -17,8 +17,10 @@
package org.apache.spark.sql.catalyst.encoders
+
import scala.reflect.ClassTag
+import org.apache.spark.sql.catalyst.expressions.Attribute
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.types.StructType
@@ -41,4 +43,16 @@ trait Encoder[T] {
* copy the result before making another call if required.
*/
def toRow(t: T): InternalRow
+
+ /**
+ * Returns an object of type `T`, extracting the required values from the provided row. Note that
+ * you must bind` and encoder to a specific schema before you can call this function.
+ */
+ def fromRow(row: InternalRow): T
+
+ /**
+ * Returns a new copy of this encoder, where the expressions used by `fromRow` are bound to the
+ * given schema
+ */
+ def bind(schema: Seq[Attribute]): Encoder[T]
}
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/ProductEncoder.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/ProductEncoder.scala
index a23613673e..b0381880c3 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/ProductEncoder.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/ProductEncoder.scala
@@ -17,8 +17,10 @@
package org.apache.spark.sql.catalyst.encoders
+import org.apache.spark.sql.catalyst.analysis.SimpleAnalyzer
import org.apache.spark.sql.catalyst.expressions._
-import org.apache.spark.sql.catalyst.expressions.codegen.GenerateUnsafeProjection
+import org.apache.spark.sql.catalyst.expressions.codegen.{GenerateSafeProjection, GenerateUnsafeProjection}
+import org.apache.spark.sql.catalyst.plans.logical.{LocalRelation, Project}
import scala.reflect.ClassTag
import scala.reflect.runtime.universe.{typeTag, TypeTag}
@@ -31,7 +33,7 @@ import org.apache.spark.sql.types.{ObjectType, StructType}
* internal binary representation.
*/
object ProductEncoder {
- def apply[T <: Product : TypeTag]: Encoder[T] = {
+ def apply[T <: Product : TypeTag]: ClassEncoder[T] = {
// We convert the not-serializable TypeTag into StructType and ClassTag.
val schema = ScalaReflection.schemaFor[T].dataType.asInstanceOf[StructType]
val mirror = typeTag[T].mirror
@@ -39,7 +41,8 @@ object ProductEncoder {
val inputObject = BoundReference(0, ObjectType(cls), nullable = true)
val extractExpressions = ScalaReflection.extractorsFor[T](inputObject)
- new ClassEncoder[T](schema, extractExpressions, ClassTag[T](cls))
+ val constructExpression = ScalaReflection.constructorFor[T]
+ new ClassEncoder[T](schema, extractExpressions, constructExpression, ClassTag[T](cls))
}
}
@@ -54,14 +57,31 @@ object ProductEncoder {
case class ClassEncoder[T](
schema: StructType,
extractExpressions: Seq[Expression],
+ constructExpression: Expression,
clsTag: ClassTag[T])
extends Encoder[T] {
private val extractProjection = GenerateUnsafeProjection.generate(extractExpressions)
private val inputRow = new GenericMutableRow(1)
+ private lazy val constructProjection = GenerateSafeProjection.generate(constructExpression :: Nil)
+ private val dataType = ObjectType(clsTag.runtimeClass)
+
override def toRow(t: T): InternalRow = {
inputRow(0) = t
extractProjection(inputRow)
}
+
+ override def fromRow(row: InternalRow): T = {
+ constructProjection(row).get(0, dataType).asInstanceOf[T]
+ }
+
+ override def bind(schema: Seq[Attribute]): ClassEncoder[T] = {
+ val plan = Project(Alias(constructExpression, "object")() :: Nil, LocalRelation(schema))
+ val analyzedPlan = SimpleAnalyzer.execute(plan)
+ val resolvedExpression = analyzedPlan.expressions.head.children.head
+ val boundExpression = BindReferences.bindReference(resolvedExpression, schema)
+
+ copy(constructExpression = boundExpression)
+ }
}
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects.scala
index e1f960a6e6..e8c1c93cf5 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects.scala
@@ -17,9 +17,12 @@
package org.apache.spark.sql.catalyst.expressions
+import org.apache.spark.sql.catalyst.analysis.SimpleAnalyzer
+import org.apache.spark.sql.catalyst.plans.logical.{Project, LocalRelation}
+
import scala.language.existentials
-import org.apache.spark.sql.catalyst.InternalRow
+import org.apache.spark.sql.catalyst.{ScalaReflection, InternalRow}
import org.apache.spark.sql.catalyst.expressions.codegen.{GeneratedExpressionCode, CodeGenContext}
import org.apache.spark.sql.types._
@@ -48,7 +51,7 @@ case class StaticInvoke(
case other => other.getClass.getName.stripSuffix("$")
}
override def nullable: Boolean = true
- override def children: Seq[Expression] = Nil
+ override def children: Seq[Expression] = arguments
override def eval(input: InternalRow): Any =
throw new UnsupportedOperationException("Only code-generated evaluation is supported.")
@@ -69,7 +72,7 @@ case class StaticInvoke(
s"""
${argGen.map(_.code).mkString("\n")}
- boolean ${ev.isNull} = true;
+ boolean ${ev.isNull} = !$argsNonNull;
$javaType ${ev.value} = ${ctx.defaultValue(dataType)};
if ($argsNonNull) {
@@ -81,8 +84,8 @@ case class StaticInvoke(
s"""
${argGen.map(_.code).mkString("\n")}
- final boolean ${ev.isNull} = ${ev.value} == null;
$javaType ${ev.value} = $objectName.$functionName($argString);
+ final boolean ${ev.isNull} = ${ev.value} == null;
"""
}
}
@@ -92,6 +95,10 @@ case class StaticInvoke(
* Calls the specified function on an object, optionally passing arguments. If the `targetObject`
* expression evaluates to null then null will be returned.
*
+ * In some cases, due to erasure, the schema may expect a primitive type when in fact the method
+ * is returning java.lang.Object. In this case, we will generate code that attempts to unbox the
+ * value automatically.
+ *
* @param targetObject An expression that will return the object to call the method on.
* @param functionName The name of the method to call.
* @param dataType The expected return type of the function.
@@ -109,6 +116,35 @@ case class Invoke(
override def eval(input: InternalRow): Any =
throw new UnsupportedOperationException("Only code-generated evaluation is supported.")
+ lazy val method = targetObject.dataType match {
+ case ObjectType(cls) =>
+ cls
+ .getMethods
+ .find(_.getName == functionName)
+ .getOrElse(sys.error(s"Couldn't find $functionName on $cls"))
+ .getReturnType
+ .getName
+ case _ => ""
+ }
+
+ lazy val unboxer = (dataType, method) match {
+ case (IntegerType, "java.lang.Object") => (s: String) =>
+ s"((java.lang.Integer)$s).intValue()"
+ case (LongType, "java.lang.Object") => (s: String) =>
+ s"((java.lang.Long)$s).longValue()"
+ case (FloatType, "java.lang.Object") => (s: String) =>
+ s"((java.lang.Float)$s).floatValue()"
+ case (ShortType, "java.lang.Object") => (s: String) =>
+ s"((java.lang.Short)$s).shortValue()"
+ case (ByteType, "java.lang.Object") => (s: String) =>
+ s"((java.lang.Byte)$s).byteValue()"
+ case (DoubleType, "java.lang.Object") => (s: String) =>
+ s"((java.lang.Double)$s).doubleValue()"
+ case (BooleanType, "java.lang.Object") => (s: String) =>
+ s"((java.lang.Boolean)$s).booleanValue()"
+ case _ => identity[String] _
+ }
+
override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = {
val javaType = ctx.javaType(dataType)
val obj = targetObject.gen(ctx)
@@ -123,6 +159,8 @@ case class Invoke(
""
}
+ val value = unboxer(s"${obj.value}.$functionName($argString)")
+
s"""
${obj.code}
${argGen.map(_.code).mkString("\n")}
@@ -130,7 +168,7 @@ case class Invoke(
boolean ${ev.isNull} = ${obj.value} == null;
$javaType ${ev.value} =
${ev.isNull} ?
- ${ctx.defaultValue(dataType)} : ($javaType) ${obj.value}.$functionName($argString);
+ ${ctx.defaultValue(dataType)} : ($javaType) $value;
$objNullCheck
"""
}
@@ -190,8 +228,8 @@ case class NewInstance(
s"""
${argGen.map(_.code).mkString("\n")}
- final boolean ${ev.isNull} = ${ev.value} == null;
$javaType ${ev.value} = new $className($argString);
+ final boolean ${ev.isNull} = ${ev.value} == null;
"""
}
}
@@ -210,8 +248,6 @@ case class UnwrapOption(
override def nullable: Boolean = true
- override def children: Seq[Expression] = Nil
-
override def inputTypes: Seq[AbstractDataType] = ObjectType :: Nil
override def eval(input: InternalRow): Any =
@@ -231,6 +267,43 @@ case class UnwrapOption(
}
}
+/**
+ * Converts the result of evaluating `child` into an option, checking both the isNull bit and
+ * (in the case of reference types) equality with null.
+ * @param optionType The datatype to be held inside of the Option.
+ * @param child The expression to evaluate and wrap.
+ */
+case class WrapOption(optionType: DataType, child: Expression)
+ extends UnaryExpression with ExpectsInputTypes {
+
+ override def dataType: DataType = ObjectType(classOf[Option[_]])
+
+ override def nullable: Boolean = true
+
+ override def inputTypes: Seq[AbstractDataType] = ObjectType :: Nil
+
+ override def eval(input: InternalRow): Any =
+ throw new UnsupportedOperationException("Only code-generated evaluation is supported")
+
+ override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = {
+ val javaType = ctx.javaType(optionType)
+ val inputObject = child.gen(ctx)
+
+ s"""
+ ${inputObject.code}
+
+ boolean ${ev.isNull} = false;
+ scala.Option<$javaType> ${ev.value} =
+ ${inputObject.isNull} ?
+ scala.Option$$.MODULE$$.apply(null) : new scala.Some(${inputObject.value});
+ """
+ }
+}
+
+/**
+ * A place holder for the loop variable used in [[MapObjects]]. This should never be constructed
+ * manually, but will instead be passed into the provided lambda function.
+ */
case class LambdaVariable(value: String, isNull: String, dataType: DataType) extends Expression {
override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String =
@@ -251,7 +324,7 @@ case class LambdaVariable(value: String, isNull: String, dataType: DataType) ext
* as an ArrayType. This is similar to a typical map operation, but where the lambda function
* is expressed using catalyst expressions.
*
- * The following collection ObjectTypes are currently supported: Seq, Array
+ * The following collection ObjectTypes are currently supported: Seq, Array, ArrayData
*
* @param function A function that returns an expression, given an attribute that can be used
* to access the current value. This is does as a lambda function so that
@@ -265,14 +338,32 @@ case class MapObjects(
inputData: Expression,
elementType: DataType) extends Expression {
- private val loopAttribute = AttributeReference("loopVar", elementType)()
- private val completeFunction = function(loopAttribute)
+ private lazy val loopAttribute = AttributeReference("loopVar", elementType)()
+ private lazy val completeFunction = function(loopAttribute)
- private val (lengthFunction, itemAccessor) = inputData.dataType match {
- case ObjectType(cls) if cls.isAssignableFrom(classOf[Seq[_]]) =>
- (".size()", (i: String) => s".apply($i)")
+ private lazy val (lengthFunction, itemAccessor, primitiveElement) = inputData.dataType match {
+ case ObjectType(cls) if classOf[Seq[_]].isAssignableFrom(cls) =>
+ (".size()", (i: String) => s".apply($i)", false)
case ObjectType(cls) if cls.isArray =>
- (".length", (i: String) => s"[$i]")
+ (".length", (i: String) => s"[$i]", false)
+ case ArrayType(s: StructType, _) =>
+ (".numElements()", (i: String) => s".getStruct($i, ${s.size})", false)
+ case ArrayType(a: ArrayType, _) =>
+ (".numElements()", (i: String) => s".getArray($i)", true)
+ case ArrayType(IntegerType, _) =>
+ (".numElements()", (i: String) => s".getInt($i)", true)
+ case ArrayType(LongType, _) =>
+ (".numElements()", (i: String) => s".getLong($i)", true)
+ case ArrayType(FloatType, _) =>
+ (".numElements()", (i: String) => s".getFloat($i)", true)
+ case ArrayType(DoubleType, _) =>
+ (".numElements()", (i: String) => s".getDouble($i)", true)
+ case ArrayType(ByteType, _) =>
+ (".numElements()", (i: String) => s".getByte($i)", true)
+ case ArrayType(ShortType, _) =>
+ (".numElements()", (i: String) => s".getShort($i)", true)
+ case ArrayType(BooleanType, _) =>
+ (".numElements()", (i: String) => s".getBoolean($i)", true)
}
override def nullable: Boolean = true
@@ -294,15 +385,38 @@ case class MapObjects(
val loopIsNull = ctx.freshName("loopIsNull")
val loopVariable = LambdaVariable(loopValue, loopIsNull, elementType)
- val boundFunction = completeFunction transform {
+ val substitutedFunction = completeFunction transform {
case a: AttributeReference if a == loopAttribute => loopVariable
}
+ // A hack to run this through the analyzer (to bind extractions).
+ val boundFunction =
+ SimpleAnalyzer.execute(Project(Alias(substitutedFunction, "")() :: Nil, LocalRelation(Nil)))
+ .expressions.head.children.head
val genFunction = boundFunction.gen(ctx)
val dataLength = ctx.freshName("dataLength")
val convertedArray = ctx.freshName("convertedArray")
val loopIndex = ctx.freshName("loopIndex")
+ val convertedType = ctx.javaType(boundFunction.dataType)
+
+ // Because of the way Java defines nested arrays, we have to handle the syntax specially.
+ // Specifically, we have to insert the [$dataLength] in between the type and any extra nested
+ // array declarations (i.e. new String[1][]).
+ val arrayConstructor = if (convertedType contains "[]") {
+ val rawType = convertedType.takeWhile(_ != '[')
+ val arrayPart = convertedType.reverse.takeWhile(c => c == '[' || c == ']').reverse
+ s"new $rawType[$dataLength]$arrayPart"
+ } else {
+ s"new $convertedType[$dataLength]"
+ }
+
+ val loopNullCheck = if (primitiveElement) {
+ s"boolean $loopIsNull = ${genInputData.value}.isNullAt($loopIndex);"
+ } else {
+ s"boolean $loopIsNull = ${genInputData.isNull} || $loopValue == null;"
+ }
+
s"""
${genInputData.code}
@@ -310,19 +424,19 @@ case class MapObjects(
$javaType ${ev.value} = ${ctx.defaultValue(dataType)};
if (!${ev.isNull}) {
- Object[] $convertedArray = null;
+ $convertedType[] $convertedArray = null;
int $dataLength = ${genInputData.value}$lengthFunction;
- $convertedArray = new Object[$dataLength];
+ $convertedArray = $arrayConstructor;
int $loopIndex = 0;
while ($loopIndex < $dataLength) {
$elementJavaType $loopValue =
($elementJavaType)${genInputData.value}${itemAccessor(loopIndex)};
- boolean $loopIsNull = $loopValue == null;
+ $loopNullCheck
${genFunction.code}
- $convertedArray[$loopIndex] = ${genFunction.value};
+ $convertedArray[$loopIndex] = ($convertedType)${genFunction.value};
$loopIndex += 1;
}
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/ArrayBasedMapData.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/ArrayBasedMapData.scala
index 52069598ee..5f22e59d5f 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/ArrayBasedMapData.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/ArrayBasedMapData.scala
@@ -62,4 +62,8 @@ object ArrayBasedMapData {
val values = map.valueArray.asInstanceOf[GenericArrayData].array
keys.zip(values).toMap
}
+
+ def toScalaMap(keys: Array[Any], values: Array[Any]): Map[Any, Any] = {
+ keys.zip(values).toMap
+ }
}
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/ArrayData.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/ArrayData.scala
index 642c56f12d..b4ea300f5f 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/ArrayData.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/ArrayData.scala
@@ -26,6 +26,8 @@ abstract class ArrayData extends SpecializedGetters with Serializable {
def copy(): ArrayData
+ def array: Array[Any]
+
def toBooleanArray(): Array[Boolean] = {
val size = numElements()
val values = new Array[Boolean](size)
@@ -103,6 +105,9 @@ abstract class ArrayData extends SpecializedGetters with Serializable {
values
}
+ def toObjectArray(elementType: DataType): Array[AnyRef] =
+ toArray[AnyRef](elementType: DataType)
+
def toArray[T: ClassTag](elementType: DataType): Array[T] = {
val size = numElements()
val values = new Array[T](size)
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/GenericArrayData.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/GenericArrayData.scala
index c381603327..9448d88d6c 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/GenericArrayData.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/GenericArrayData.scala
@@ -20,7 +20,7 @@ package org.apache.spark.sql.types
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.unsafe.types.{CalendarInterval, UTF8String}
-class GenericArrayData(private[sql] val array: Array[Any]) extends ArrayData {
+class GenericArrayData(val array: Array[Any]) extends ArrayData {
def this(seq: scala.collection.GenIterable[Any]) = this(seq.toArray)
@@ -29,6 +29,8 @@ class GenericArrayData(private[sql] val array: Array[Any]) extends ArrayData {
def this(primitiveArray: Array[Long]) = this(primitiveArray.toSeq)
def this(primitiveArray: Array[Float]) = this(primitiveArray.toSeq)
def this(primitiveArray: Array[Double]) = this(primitiveArray.toSeq)
+ def this(primitiveArray: Array[Short]) = this(primitiveArray.toSeq)
+ def this(primitiveArray: Array[Byte]) = this(primitiveArray.toSeq)
def this(primitiveArray: Array[Boolean]) = this(primitiveArray.toSeq)
override def copy(): ArrayData = new GenericArrayData(array.clone())
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/ProductEncoderSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/ProductEncoderSuite.scala
index 99c993d3fe..02e43ddb35 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/ProductEncoderSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/ProductEncoderSuite.scala
@@ -17,158 +17,263 @@
package org.apache.spark.sql.catalyst.encoders
-import java.sql.{Date, Timestamp}
+import java.util
+
+import org.apache.spark.sql.types.{StructField, ArrayType, ArrayData}
+
+import scala.collection.mutable.ArrayBuffer
+import scala.reflect.runtime.universe._
import org.apache.spark.SparkFunSuite
-import org.apache.spark.sql.catalyst.ScalaReflection._
-import org.apache.spark.sql.catalyst.expressions.UnsafeProjection
import org.apache.spark.sql.catalyst._
-
case class RepeatedStruct(s: Seq[PrimitiveData])
case class NestedArray(a: Array[Array[Int]])
-class ProductEncoderSuite extends SparkFunSuite {
+case class BoxedData(
+ intField: java.lang.Integer,
+ longField: java.lang.Long,
+ doubleField: java.lang.Double,
+ floatField: java.lang.Float,
+ shortField: java.lang.Short,
+ byteField: java.lang.Byte,
+ booleanField: java.lang.Boolean)
- test("convert PrimitiveData to InternalRow") {
- val inputData = PrimitiveData(1, 1, 1, 1, 1, 1, true)
- val encoder = ProductEncoder[PrimitiveData]
- val convertedData = encoder.toRow(inputData)
-
- assert(convertedData.getInt(0) == 1)
- assert(convertedData.getLong(1) == 1.toLong)
- assert(convertedData.getDouble(2) == 1.toDouble)
- assert(convertedData.getFloat(3) == 1.toFloat)
- assert(convertedData.getShort(4) == 1.toShort)
- assert(convertedData.getByte(5) == 1.toByte)
- assert(convertedData.getBoolean(6) == true)
- }
+case class RepeatedData(
+ arrayField: Seq[Int],
+ arrayFieldContainsNull: Seq[java.lang.Integer],
+ mapField: scala.collection.Map[Int, Long],
+ mapFieldNull: scala.collection.Map[Int, java.lang.Long],
+ structField: PrimitiveData)
- test("convert Some[_] to InternalRow") {
- val primitiveData = PrimitiveData(1, 1, 1, 1, 1, 1, true)
- val inputData = OptionalData(Some(2), Some(2), Some(2), Some(2), Some(2), Some(2), Some(true),
- Some(primitiveData))
-
- val encoder = ProductEncoder[OptionalData]
- val convertedData = encoder.toRow(inputData)
-
- assert(convertedData.getInt(0) == 2)
- assert(convertedData.getLong(1) == 2.toLong)
- assert(convertedData.getDouble(2) == 2.toDouble)
- assert(convertedData.getFloat(3) == 2.toFloat)
- assert(convertedData.getShort(4) == 2.toShort)
- assert(convertedData.getByte(5) == 2.toByte)
- assert(convertedData.getBoolean(6) == true)
-
- val nestedRow = convertedData.getStruct(7, 7)
- assert(nestedRow.getInt(0) == 1)
- assert(nestedRow.getLong(1) == 1.toLong)
- assert(nestedRow.getDouble(2) == 1.toDouble)
- assert(nestedRow.getFloat(3) == 1.toFloat)
- assert(nestedRow.getShort(4) == 1.toShort)
- assert(nestedRow.getByte(5) == 1.toByte)
- assert(nestedRow.getBoolean(6) == true)
- }
+case class SpecificCollection(l: List[Int])
- test("convert None to InternalRow") {
- val inputData = OptionalData(None, None, None, None, None, None, None, None)
- val encoder = ProductEncoder[OptionalData]
- val convertedData = encoder.toRow(inputData)
-
- assert(convertedData.isNullAt(0))
- assert(convertedData.isNullAt(1))
- assert(convertedData.isNullAt(2))
- assert(convertedData.isNullAt(3))
- assert(convertedData.isNullAt(4))
- assert(convertedData.isNullAt(5))
- assert(convertedData.isNullAt(6))
- assert(convertedData.isNullAt(7))
- }
+class ProductEncoderSuite extends SparkFunSuite {
- test("convert nullable but present data to InternalRow") {
- val inputData = NullableData(
- 1, 1L, 1.0, 1.0f, 1.toShort, 1.toByte, true, "test", new java.math.BigDecimal(1), new Date(0),
- new Timestamp(0), Array[Byte](1, 2, 3))
-
- val encoder = ProductEncoder[NullableData]
- val convertedData = encoder.toRow(inputData)
-
- assert(convertedData.getInt(0) == 1)
- assert(convertedData.getLong(1) == 1.toLong)
- assert(convertedData.getDouble(2) == 1.toDouble)
- assert(convertedData.getFloat(3) == 1.toFloat)
- assert(convertedData.getShort(4) == 1.toShort)
- assert(convertedData.getByte(5) == 1.toByte)
- assert(convertedData.getBoolean(6) == true)
- }
+ encodeDecodeTest(PrimitiveData(1, 1, 1, 1, 1, 1, true))
- test("convert nullable data to InternalRow") {
- val inputData =
- NullableData(null, null, null, null, null, null, null, null, null, null, null, null)
-
- val encoder = ProductEncoder[NullableData]
- val convertedData = encoder.toRow(inputData)
-
- assert(convertedData.isNullAt(0))
- assert(convertedData.isNullAt(1))
- assert(convertedData.isNullAt(2))
- assert(convertedData.isNullAt(3))
- assert(convertedData.isNullAt(4))
- assert(convertedData.isNullAt(5))
- assert(convertedData.isNullAt(6))
- assert(convertedData.isNullAt(7))
- assert(convertedData.isNullAt(8))
- assert(convertedData.isNullAt(9))
- assert(convertedData.isNullAt(10))
- assert(convertedData.isNullAt(11))
- }
+ // TODO: Support creating specific subclasses of Seq.
+ ignore("Specific collection types") { encodeDecodeTest(SpecificCollection(1 :: Nil)) }
- test("convert repeated struct") {
- val inputData = RepeatedStruct(PrimitiveData(1, 1, 1, 1, 1, 1, true) :: Nil)
- val encoder = ProductEncoder[RepeatedStruct]
-
- val converted = encoder.toRow(inputData)
- val convertedStruct = converted.getArray(0).getStruct(0, 7)
- assert(convertedStruct.getInt(0) == 1)
- assert(convertedStruct.getLong(1) == 1.toLong)
- assert(convertedStruct.getDouble(2) == 1.toDouble)
- assert(convertedStruct.getFloat(3) == 1.toFloat)
- assert(convertedStruct.getShort(4) == 1.toShort)
- assert(convertedStruct.getByte(5) == 1.toByte)
- assert(convertedStruct.getBoolean(6) == true)
- }
+ encodeDecodeTest(
+ OptionalData(
+ Some(2), Some(2), Some(2), Some(2), Some(2), Some(2), Some(true),
+ Some(PrimitiveData(1, 1, 1, 1, 1, 1, true))))
- test("convert nested seq") {
- val convertedData = ProductEncoder[Tuple1[Seq[Seq[Int]]]].toRow(Tuple1(Seq(Seq(1))))
- assert(convertedData.getArray(0).getArray(0).getInt(0) == 1)
+ encodeDecodeTest(OptionalData(None, None, None, None, None, None, None, None))
- val convertedData2 = ProductEncoder[Tuple1[Seq[Seq[Seq[Int]]]]].toRow(Tuple1(Seq(Seq(Seq(1)))))
- assert(convertedData2.getArray(0).getArray(0).getArray(0).getInt(0) == 1)
- }
+ encodeDecodeTest(
+ BoxedData(1, 1L, 1.0, 1.0f, 1.toShort, 1.toByte, true))
- test("convert nested array") {
- val convertedData = ProductEncoder[Tuple1[Array[Array[Int]]]].toRow(Tuple1(Array(Array(1))))
- }
+ encodeDecodeTest(
+ BoxedData(null, null, null, null, null, null, null))
+
+ encodeDecodeTest(
+ RepeatedStruct(PrimitiveData(1, 1, 1, 1, 1, 1, true) :: Nil))
- test("convert complex") {
- val inputData = ComplexData(
+ encodeDecodeTest(
+ RepeatedData(
Seq(1, 2),
- Array(1, 2),
- 1 :: 2 :: Nil,
Seq(new Integer(1), null, new Integer(2)),
Map(1 -> 2L),
- Map(1 -> new java.lang.Long(2)),
- PrimitiveData(1, 1, 1, 1, 1, 1, true),
- Array(Array(1)))
-
- val encoder = ProductEncoder[ComplexData]
- val convertedData = encoder.toRow(inputData)
-
- assert(!convertedData.isNullAt(0))
- val seq = convertedData.getArray(0)
- assert(seq.numElements() == 2)
- assert(seq.getInt(0) == 1)
- assert(seq.getInt(1) == 2)
+ Map(1 -> null),
+ PrimitiveData(1, 1, 1, 1, 1, 1, true)))
+
+ encodeDecodeTest(("nullable Seq[Integer]", Seq[Integer](1, null)))
+
+ encodeDecodeTest(("Seq[(String, String)]",
+ Seq(("a", "b"))))
+ encodeDecodeTest(("Seq[(Int, Int)]",
+ Seq((1, 2))))
+ encodeDecodeTest(("Seq[(Long, Long)]",
+ Seq((1L, 2L))))
+ encodeDecodeTest(("Seq[(Float, Float)]",
+ Seq((1.toFloat, 2.toFloat))))
+ encodeDecodeTest(("Seq[(Double, Double)]",
+ Seq((1.toDouble, 2.toDouble))))
+ encodeDecodeTest(("Seq[(Short, Short)]",
+ Seq((1.toShort, 2.toShort))))
+ encodeDecodeTest(("Seq[(Byte, Byte)]",
+ Seq((1.toByte, 2.toByte))))
+ encodeDecodeTest(("Seq[(Boolean, Boolean)]",
+ Seq((true, false))))
+
+ // TODO: Decoding/encoding of complex maps.
+ ignore("complex maps") {
+ encodeDecodeTest(("Map[Int, (String, String)]",
+ Map(1 ->("a", "b"))))
+ }
+
+ encodeDecodeTest(("ArrayBuffer[(String, String)]",
+ ArrayBuffer(("a", "b"))))
+ encodeDecodeTest(("ArrayBuffer[(Int, Int)]",
+ ArrayBuffer((1, 2))))
+ encodeDecodeTest(("ArrayBuffer[(Long, Long)]",
+ ArrayBuffer((1L, 2L))))
+ encodeDecodeTest(("ArrayBuffer[(Float, Float)]",
+ ArrayBuffer((1.toFloat, 2.toFloat))))
+ encodeDecodeTest(("ArrayBuffer[(Double, Double)]",
+ ArrayBuffer((1.toDouble, 2.toDouble))))
+ encodeDecodeTest(("ArrayBuffer[(Short, Short)]",
+ ArrayBuffer((1.toShort, 2.toShort))))
+ encodeDecodeTest(("ArrayBuffer[(Byte, Byte)]",
+ ArrayBuffer((1.toByte, 2.toByte))))
+ encodeDecodeTest(("ArrayBuffer[(Boolean, Boolean)]",
+ ArrayBuffer((true, false))))
+
+ encodeDecodeTest(("Seq[Seq[(Int, Int)]]",
+ Seq(Seq((1, 2)))))
+
+ encodeDecodeTestCustom(("Array[Array[(Int, Int)]]",
+ Array(Array((1, 2)))))
+ { (l, r) => l._2(0)(0) == r._2(0)(0) }
+
+ encodeDecodeTestCustom(("Array[Array[(Int, Int)]]",
+ Array(Array(Array((1, 2))))))
+ { (l, r) => l._2(0)(0)(0) == r._2(0)(0)(0) }
+
+ encodeDecodeTestCustom(("Array[Array[Array[(Int, Int)]]]",
+ Array(Array(Array(Array((1, 2)))))))
+ { (l, r) => l._2(0)(0)(0)(0) == r._2(0)(0)(0)(0) }
+
+ encodeDecodeTestCustom(("Array[Array[Array[Array[(Int, Int)]]]]",
+ Array(Array(Array(Array(Array((1, 2))))))))
+ { (l, r) => l._2(0)(0)(0)(0)(0) == r._2(0)(0)(0)(0)(0) }
+
+
+ encodeDecodeTestCustom(("Array[Array[Integer]]",
+ Array(Array[Integer](1))))
+ { (l, r) => l._2(0)(0) == r._2(0)(0) }
+
+ encodeDecodeTestCustom(("Array[Array[Int]]",
+ Array(Array(1))))
+ { (l, r) => l._2(0)(0) == r._2(0)(0) }
+
+ encodeDecodeTestCustom(("Array[Array[Int]]",
+ Array(Array(Array(1)))))
+ { (l, r) => l._2(0)(0)(0) == r._2(0)(0)(0) }
+
+ encodeDecodeTestCustom(("Array[Array[Array[Int]]]",
+ Array(Array(Array(Array(1))))))
+ { (l, r) => l._2(0)(0)(0)(0) == r._2(0)(0)(0)(0) }
+
+ encodeDecodeTestCustom(("Array[Array[Array[Array[Int]]]]",
+ Array(Array(Array(Array(Array(1)))))))
+ { (l, r) => l._2(0)(0)(0)(0)(0) == r._2(0)(0)(0)(0)(0) }
+
+ encodeDecodeTest(("Array[Byte] null",
+ null: Array[Byte]))
+ encodeDecodeTestCustom(("Array[Byte]",
+ Array[Byte](1, 2, 3)))
+ { (l, r) => util.Arrays.equals(l._2, r._2) }
+
+ encodeDecodeTest(("Array[Int] null",
+ null: Array[Int]))
+ encodeDecodeTestCustom(("Array[Int]",
+ Array[Int](1, 2, 3)))
+ { (l, r) => util.Arrays.equals(l._2, r._2) }
+
+ encodeDecodeTest(("Array[Long] null",
+ null: Array[Long]))
+ encodeDecodeTestCustom(("Array[Long]",
+ Array[Long](1, 2, 3)))
+ { (l, r) => util.Arrays.equals(l._2, r._2) }
+
+ encodeDecodeTest(("Array[Double] null",
+ null: Array[Double]))
+ encodeDecodeTestCustom(("Array[Double]",
+ Array[Double](1, 2, 3)))
+ { (l, r) => util.Arrays.equals(l._2, r._2) }
+
+ encodeDecodeTest(("Array[Float] null",
+ null: Array[Float]))
+ encodeDecodeTestCustom(("Array[Float]",
+ Array[Float](1, 2, 3)))
+ { (l, r) => util.Arrays.equals(l._2, r._2) }
+
+ encodeDecodeTest(("Array[Boolean] null",
+ null: Array[Boolean]))
+ encodeDecodeTestCustom(("Array[Boolean]",
+ Array[Boolean](true, false)))
+ { (l, r) => util.Arrays.equals(l._2, r._2) }
+
+ encodeDecodeTest(("Array[Short] null",
+ null: Array[Short]))
+ encodeDecodeTestCustom(("Array[Short]",
+ Array[Short](1, 2, 3)))
+ { (l, r) => util.Arrays.equals(l._2, r._2) }
+
+ encodeDecodeTestCustom(("java.sql.Timestamp",
+ new java.sql.Timestamp(1)))
+ { (l, r) => l._2.toString == r._2.toString }
+
+ encodeDecodeTestCustom(("java.sql.Date", new java.sql.Date(1)))
+ { (l, r) => l._2.toString == r._2.toString }
+
+ /** Simplified encodeDecodeTestCustom, where the comparison function can be `Object.equals`. */
+ protected def encodeDecodeTest[T <: Product : TypeTag](inputData: T) =
+ encodeDecodeTestCustom[T](inputData)((l, r) => l == r)
+
+ /**
+ * Constructs a test that round-trips `t` through an encoder, checking the results to ensure it
+ * matches the original.
+ */
+ protected def encodeDecodeTestCustom[T <: Product : TypeTag](
+ inputData: T)(
+ c: (T, T) => Boolean) = {
+ test(s"encode/decode: $inputData") {
+ val encoder = try ProductEncoder[T] catch {
+ case e: Exception =>
+ fail(s"Exception thrown generating encoder", e)
+ }
+ val convertedData = encoder.toRow(inputData)
+ val schema = encoder.schema.toAttributes
+ val boundEncoder = encoder.bind(schema)
+ val convertedBack = try boundEncoder.fromRow(convertedData) catch {
+ case e: Exception =>
+ fail(
+ s"""Exception thrown while decoding
+ |Converted: $convertedData
+ |Schema: ${schema.mkString(",")}
+ |${encoder.schema.treeString}
+ |
+ |Construct Expressions:
+ |${boundEncoder.constructExpression.treeString}
+ |
+ """.stripMargin, e)
+ }
+
+ if (!c(inputData, convertedBack)) {
+ val types =
+ convertedBack.productIterator.filter(_ != null).map(_.getClass.getName).mkString(",")
+
+ val encodedData = convertedData.toSeq(encoder.schema).zip(encoder.schema).map {
+ case (a: ArrayData, StructField(_, at: ArrayType, _, _)) =>
+ a.toArray[Any](at.elementType).toSeq
+ case (other, _) =>
+ other
+ }.mkString("[", ",", "]")
+
+ fail(
+ s"""Encoded/Decoded data does not match input data
+ |
+ |in: $inputData
+ |out: $convertedBack
+ |types: $types
+ |
+ |Encoded Data: $encodedData
+ |Schema: ${schema.mkString(",")}
+ |${encoder.schema.treeString}
+ |
+ |Extract Expressions:
+ |${boundEncoder.extractExpressions.map(_.treeString).mkString("\n")}
+ |
+ |Construct Expressions:
+ |${boundEncoder.constructExpression.treeString}
+ |
+ """.stripMargin)
+ }
+ }
}
}