From 328d1b3e4bc39cce653342e04f9e08af12dd7ed8 Mon Sep 17 00:00:00 2001 From: Michael Armbrust Date: Tue, 13 Oct 2015 17:09:17 -0700 Subject: [SPARK-11090] [SQL] Constructor for Product types from InternalRow This is a first draft of the ability to construct expressions that will take a catalyst internal row and construct a Product (case class or tuple) that has fields with the correct names. Support include: - Nested classes - Maps - Efficiently handling of arrays of primitive types Not yet supported: - Case classes that require custom collection types (i.e. List instead of Seq). Author: Michael Armbrust Closes #9100 from marmbrus/productContructor. --- .../sql/catalyst/expressions/UnsafeArrayData.java | 4 + .../spark/sql/catalyst/ScalaReflection.scala | 302 ++++++++++++++++- .../spark/sql/catalyst/encoders/Encoder.scala | 14 + .../sql/catalyst/encoders/ProductEncoder.scala | 26 +- .../spark/sql/catalyst/expressions/objects.scala | 154 +++++++-- .../apache/spark/sql/types/ArrayBasedMapData.scala | 4 + .../org/apache/spark/sql/types/ArrayData.scala | 5 + .../apache/spark/sql/types/GenericArrayData.scala | 4 +- .../catalyst/encoders/ProductEncoderSuite.scala | 369 +++++++++++++-------- 9 files changed, 723 insertions(+), 159 deletions(-) (limited to 'sql/catalyst') diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeArrayData.java b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeArrayData.java index 796f8abec9..4c63abb071 100644 --- a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeArrayData.java +++ b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeArrayData.java @@ -74,6 +74,10 @@ public class UnsafeArrayData extends ArrayData { assert ordinal < numElements : "ordinal (" + ordinal + ") should < " + numElements; } + public Object[] array() { + throw new UnsupportedOperationException("Only supported on GenericArrayData."); + } + /** * Construct a new UnsafeArrayData. The resulting UnsafeArrayData won't be usable until * `pointTo()` has been called, since the value returned by this constructor is equivalent diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala index 8b733f2a0b..8edd6498e5 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala @@ -17,6 +17,7 @@ package org.apache.spark.sql.catalyst +import org.apache.spark.sql.catalyst.analysis.{UnresolvedExtractValue, UnresolvedAttribute} import org.apache.spark.sql.catalyst.util.DateTimeUtils import org.apache.spark.unsafe.types.UTF8String import org.apache.spark.util.Utils @@ -80,6 +81,9 @@ trait ScalaReflection { * Returns the Spark SQL DataType for a given scala type. Where this is not an exact mapping * to a native type, an ObjectType is returned. Special handling is also used for Arrays including * those that hold primitive types. + * + * Unlike `schemaFor`, this function doesn't do any massaging of types into the Spark SQL type + * system. As a result, ObjectType will be returned for things like boxed Integers */ def dataTypeFor(tpe: `Type`): DataType = tpe match { case t if t <:< definitions.IntTpe => IntegerType @@ -114,6 +118,298 @@ trait ScalaReflection { } } + /** + * Given a type `T` this function constructs and ObjectType that holds a class of type + * Array[T]. Special handling is performed for primitive types to map them back to their raw + * JVM form instead of the Scala Array that handles auto boxing. + */ + def arrayClassFor(tpe: `Type`): DataType = { + val cls = tpe match { + case t if t <:< definitions.IntTpe => classOf[Array[Int]] + case t if t <:< definitions.LongTpe => classOf[Array[Long]] + case t if t <:< definitions.DoubleTpe => classOf[Array[Double]] + case t if t <:< definitions.FloatTpe => classOf[Array[Float]] + case t if t <:< definitions.ShortTpe => classOf[Array[Short]] + case t if t <:< definitions.ByteTpe => classOf[Array[Byte]] + case t if t <:< definitions.BooleanTpe => classOf[Array[Boolean]] + case other => + // There is probably a better way to do this, but I couldn't find it... + val elementType = dataTypeFor(other).asInstanceOf[ObjectType].cls + java.lang.reflect.Array.newInstance(elementType, 1).getClass + + } + ObjectType(cls) + } + + /** + * Returns an expression that can be used to construct an object of type `T` given a an input + * row with a compatible schema. Fields of the row will be extracted using UnresolvedAttributes + * of the same name as the constructor arguments. Nested classes will have their fields accessed + * using UnresolvedExtractValue. + */ + def constructorFor[T : TypeTag]: Expression = constructorFor(typeOf[T], None) + + protected def constructorFor( + tpe: `Type`, + path: Option[Expression]): Expression = ScalaReflectionLock.synchronized { + + /** Returns the current path with a sub-field extracted. */ + def addToPath(part: String) = + path + .map(p => UnresolvedExtractValue(p, expressions.Literal(part))) + .getOrElse(UnresolvedAttribute(part)) + + /** Returns the current path or throws an error. */ + def getPath = path.getOrElse(sys.error("Constructors must start at a class type")) + + tpe match { + case t if !dataTypeFor(t).isInstanceOf[ObjectType] => + getPath + + case t if t <:< localTypeOf[Option[_]] => + val TypeRef(_, _, Seq(optType)) = t + val boxedType = optType match { + // For primitive types we must manually box the primitive value. + case t if t <:< definitions.IntTpe => Some(classOf[java.lang.Integer]) + case t if t <:< definitions.LongTpe => Some(classOf[java.lang.Long]) + case t if t <:< definitions.DoubleTpe => Some(classOf[java.lang.Double]) + case t if t <:< definitions.FloatTpe => Some(classOf[java.lang.Float]) + case t if t <:< definitions.ShortTpe => Some(classOf[java.lang.Short]) + case t if t <:< definitions.ByteTpe => Some(classOf[java.lang.Byte]) + case t if t <:< definitions.BooleanTpe => Some(classOf[java.lang.Boolean]) + case _ => None + } + + boxedType.map { boxedType => + val objectType = ObjectType(boxedType) + WrapOption( + objectType, + NewInstance( + boxedType, + getPath :: Nil, + propagateNull = true, + objectType)) + }.getOrElse { + val className: String = optType.erasure.typeSymbol.asClass.fullName + val cls = Utils.classForName(className) + val objectType = ObjectType(cls) + + WrapOption(objectType, constructorFor(optType, path)) + } + + case t if t <:< localTypeOf[java.lang.Integer] => + val boxedType = classOf[java.lang.Integer] + val objectType = ObjectType(boxedType) + NewInstance(boxedType, getPath :: Nil, propagateNull = true, objectType) + + case t if t <:< localTypeOf[java.lang.Long] => + val boxedType = classOf[java.lang.Long] + val objectType = ObjectType(boxedType) + NewInstance(boxedType, getPath :: Nil, propagateNull = true, objectType) + + case t if t <:< localTypeOf[java.lang.Double] => + val boxedType = classOf[java.lang.Double] + val objectType = ObjectType(boxedType) + NewInstance(boxedType, getPath :: Nil, propagateNull = true, objectType) + + case t if t <:< localTypeOf[java.lang.Float] => + val boxedType = classOf[java.lang.Float] + val objectType = ObjectType(boxedType) + NewInstance(boxedType, getPath :: Nil, propagateNull = true, objectType) + + case t if t <:< localTypeOf[java.lang.Short] => + val boxedType = classOf[java.lang.Short] + val objectType = ObjectType(boxedType) + NewInstance(boxedType, getPath :: Nil, propagateNull = true, objectType) + + case t if t <:< localTypeOf[java.lang.Byte] => + val boxedType = classOf[java.lang.Byte] + val objectType = ObjectType(boxedType) + NewInstance(boxedType, getPath :: Nil, propagateNull = true, objectType) + + case t if t <:< localTypeOf[java.lang.Boolean] => + val boxedType = classOf[java.lang.Boolean] + val objectType = ObjectType(boxedType) + NewInstance(boxedType, getPath :: Nil, propagateNull = true, objectType) + + case t if t <:< localTypeOf[java.sql.Date] => + StaticInvoke( + DateTimeUtils, + ObjectType(classOf[java.sql.Date]), + "toJavaDate", + getPath :: Nil, + propagateNull = true) + + case t if t <:< localTypeOf[java.sql.Timestamp] => + StaticInvoke( + DateTimeUtils, + ObjectType(classOf[java.sql.Timestamp]), + "toJavaTimestamp", + getPath :: Nil, + propagateNull = true) + + case t if t <:< localTypeOf[java.lang.String] => + Invoke(getPath, "toString", ObjectType(classOf[String])) + + case t if t <:< localTypeOf[java.math.BigDecimal] => + Invoke(getPath, "toJavaBigDecimal", ObjectType(classOf[java.math.BigDecimal])) + + case t if t <:< localTypeOf[Array[_]] => + val TypeRef(_, _, Seq(elementType)) = t + val elementDataType = dataTypeFor(elementType) + val Schema(dataType, nullable) = schemaFor(elementType) + + val primitiveMethod = elementType match { + case t if t <:< definitions.IntTpe => Some("toIntArray") + case t if t <:< definitions.LongTpe => Some("toLongArray") + case t if t <:< definitions.DoubleTpe => Some("toDoubleArray") + case t if t <:< definitions.FloatTpe => Some("toFloatArray") + case t if t <:< definitions.ShortTpe => Some("toShortArray") + case t if t <:< definitions.ByteTpe => Some("toByteArray") + case t if t <:< definitions.BooleanTpe => Some("toBooleanArray") + case _ => None + } + + primitiveMethod.map { method => + Invoke(getPath, method, dataTypeFor(t)) + }.getOrElse { + val returnType = dataTypeFor(t) + Invoke( + MapObjects(p => constructorFor(elementType, Some(p)), getPath, dataType), + "array", + returnType) + } + + case t if t <:< localTypeOf[Map[_, _]] => + val TypeRef(_, _, Seq(keyType, valueType)) = t + val Schema(keyDataType, _) = schemaFor(keyType) + val Schema(valueDataType, valueNullable) = schemaFor(valueType) + + val primitiveMethodKey = keyType match { + case t if t <:< definitions.IntTpe => Some("toIntArray") + case t if t <:< definitions.LongTpe => Some("toLongArray") + case t if t <:< definitions.DoubleTpe => Some("toDoubleArray") + case t if t <:< definitions.FloatTpe => Some("toFloatArray") + case t if t <:< definitions.ShortTpe => Some("toShortArray") + case t if t <:< definitions.ByteTpe => Some("toByteArray") + case t if t <:< definitions.BooleanTpe => Some("toBooleanArray") + case _ => None + } + + val keyData = + Invoke( + MapObjects( + p => constructorFor(keyType, Some(p)), + Invoke(getPath, "keyArray", ArrayType(keyDataType)), + keyDataType), + "array", + ObjectType(classOf[Array[Any]])) + + val primitiveMethodValue = valueType match { + case t if t <:< definitions.IntTpe => Some("toIntArray") + case t if t <:< definitions.LongTpe => Some("toLongArray") + case t if t <:< definitions.DoubleTpe => Some("toDoubleArray") + case t if t <:< definitions.FloatTpe => Some("toFloatArray") + case t if t <:< definitions.ShortTpe => Some("toShortArray") + case t if t <:< definitions.ByteTpe => Some("toByteArray") + case t if t <:< definitions.BooleanTpe => Some("toBooleanArray") + case _ => None + } + + val valueData = + Invoke( + MapObjects( + p => constructorFor(valueType, Some(p)), + Invoke(getPath, "valueArray", ArrayType(valueDataType)), + valueDataType), + "array", + ObjectType(classOf[Array[Any]])) + + StaticInvoke( + ArrayBasedMapData, + ObjectType(classOf[Map[_, _]]), + "toScalaMap", + keyData :: valueData :: Nil) + + case t if t <:< localTypeOf[Seq[_]] => + val TypeRef(_, _, Seq(elementType)) = t + val elementDataType = dataTypeFor(elementType) + val Schema(dataType, nullable) = schemaFor(elementType) + + // Avoid boxing when possible by just wrapping a primitive array. + val primitiveMethod = elementType match { + case _ if nullable => None + case t if t <:< definitions.IntTpe => Some("toIntArray") + case t if t <:< definitions.LongTpe => Some("toLongArray") + case t if t <:< definitions.DoubleTpe => Some("toDoubleArray") + case t if t <:< definitions.FloatTpe => Some("toFloatArray") + case t if t <:< definitions.ShortTpe => Some("toShortArray") + case t if t <:< definitions.ByteTpe => Some("toByteArray") + case t if t <:< definitions.BooleanTpe => Some("toBooleanArray") + case _ => None + } + + val arrayData = primitiveMethod.map { method => + Invoke(getPath, method, arrayClassFor(elementType)) + }.getOrElse { + Invoke( + MapObjects(p => constructorFor(elementType, Some(p)), getPath, dataType), + "array", + arrayClassFor(elementType)) + } + + StaticInvoke( + scala.collection.mutable.WrappedArray, + ObjectType(classOf[Seq[_]]), + "make", + arrayData :: Nil) + + + case t if t <:< localTypeOf[Product] => + val formalTypeArgs = t.typeSymbol.asClass.typeParams + val TypeRef(_, _, actualTypeArgs) = t + val constructorSymbol = t.member(nme.CONSTRUCTOR) + val params = if (constructorSymbol.isMethod) { + constructorSymbol.asMethod.paramss + } else { + // Find the primary constructor, and use its parameter ordering. + val primaryConstructorSymbol: Option[Symbol] = + constructorSymbol.asTerm.alternatives.find(s => + s.isMethod && s.asMethod.isPrimaryConstructor) + + if (primaryConstructorSymbol.isEmpty) { + sys.error("Internal SQL error: Product object did not have a primary constructor.") + } else { + primaryConstructorSymbol.get.asMethod.paramss + } + } + + val className: String = t.erasure.typeSymbol.asClass.fullName + val cls = Utils.classForName(className) + + val arguments = params.head.map { p => + val fieldName = p.name.toString + val fieldType = p.typeSignature.substituteTypes(formalTypeArgs, actualTypeArgs) + val dataType = dataTypeFor(fieldType) + + constructorFor(fieldType, Some(addToPath(fieldName))) + } + + val newInstance = NewInstance(cls, arguments, propagateNull = false, ObjectType(cls)) + + if (path.nonEmpty) { + expressions.If( + IsNull(getPath), + expressions.Literal.create(null, ObjectType(cls)), + newInstance + ) + } else { + newInstance + } + + } + } + /** Returns expressions for extracting all the fields from the given type. */ def extractorsFor[T : TypeTag](inputObject: Expression): Seq[Expression] = { ScalaReflectionLock.synchronized { @@ -227,13 +523,13 @@ trait ScalaReflection { val elementDataType = dataTypeFor(elementType) val Schema(dataType, nullable) = schemaFor(elementType) - if (!elementDataType.isInstanceOf[AtomicType]) { - MapObjects(extractorFor(_, elementType), inputObject, elementDataType) - } else { + if (dataType.isInstanceOf[AtomicType]) { NewInstance( classOf[GenericArrayData], inputObject :: Nil, dataType = ArrayType(dataType, nullable)) + } else { + MapObjects(extractorFor(_, elementType), inputObject, elementDataType) } case t if t <:< localTypeOf[Map[_, _]] => diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/Encoder.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/Encoder.scala index 8dacfa9477..3618247d5d 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/Encoder.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/Encoder.scala @@ -17,8 +17,10 @@ package org.apache.spark.sql.catalyst.encoders + import scala.reflect.ClassTag +import org.apache.spark.sql.catalyst.expressions.Attribute import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.types.StructType @@ -41,4 +43,16 @@ trait Encoder[T] { * copy the result before making another call if required. */ def toRow(t: T): InternalRow + + /** + * Returns an object of type `T`, extracting the required values from the provided row. Note that + * you must bind` and encoder to a specific schema before you can call this function. + */ + def fromRow(row: InternalRow): T + + /** + * Returns a new copy of this encoder, where the expressions used by `fromRow` are bound to the + * given schema + */ + def bind(schema: Seq[Attribute]): Encoder[T] } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/ProductEncoder.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/ProductEncoder.scala index a23613673e..b0381880c3 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/ProductEncoder.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/ProductEncoder.scala @@ -17,8 +17,10 @@ package org.apache.spark.sql.catalyst.encoders +import org.apache.spark.sql.catalyst.analysis.SimpleAnalyzer import org.apache.spark.sql.catalyst.expressions._ -import org.apache.spark.sql.catalyst.expressions.codegen.GenerateUnsafeProjection +import org.apache.spark.sql.catalyst.expressions.codegen.{GenerateSafeProjection, GenerateUnsafeProjection} +import org.apache.spark.sql.catalyst.plans.logical.{LocalRelation, Project} import scala.reflect.ClassTag import scala.reflect.runtime.universe.{typeTag, TypeTag} @@ -31,7 +33,7 @@ import org.apache.spark.sql.types.{ObjectType, StructType} * internal binary representation. */ object ProductEncoder { - def apply[T <: Product : TypeTag]: Encoder[T] = { + def apply[T <: Product : TypeTag]: ClassEncoder[T] = { // We convert the not-serializable TypeTag into StructType and ClassTag. val schema = ScalaReflection.schemaFor[T].dataType.asInstanceOf[StructType] val mirror = typeTag[T].mirror @@ -39,7 +41,8 @@ object ProductEncoder { val inputObject = BoundReference(0, ObjectType(cls), nullable = true) val extractExpressions = ScalaReflection.extractorsFor[T](inputObject) - new ClassEncoder[T](schema, extractExpressions, ClassTag[T](cls)) + val constructExpression = ScalaReflection.constructorFor[T] + new ClassEncoder[T](schema, extractExpressions, constructExpression, ClassTag[T](cls)) } } @@ -54,14 +57,31 @@ object ProductEncoder { case class ClassEncoder[T]( schema: StructType, extractExpressions: Seq[Expression], + constructExpression: Expression, clsTag: ClassTag[T]) extends Encoder[T] { private val extractProjection = GenerateUnsafeProjection.generate(extractExpressions) private val inputRow = new GenericMutableRow(1) + private lazy val constructProjection = GenerateSafeProjection.generate(constructExpression :: Nil) + private val dataType = ObjectType(clsTag.runtimeClass) + override def toRow(t: T): InternalRow = { inputRow(0) = t extractProjection(inputRow) } + + override def fromRow(row: InternalRow): T = { + constructProjection(row).get(0, dataType).asInstanceOf[T] + } + + override def bind(schema: Seq[Attribute]): ClassEncoder[T] = { + val plan = Project(Alias(constructExpression, "object")() :: Nil, LocalRelation(schema)) + val analyzedPlan = SimpleAnalyzer.execute(plan) + val resolvedExpression = analyzedPlan.expressions.head.children.head + val boundExpression = BindReferences.bindReference(resolvedExpression, schema) + + copy(constructExpression = boundExpression) + } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects.scala index e1f960a6e6..e8c1c93cf5 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects.scala @@ -17,9 +17,12 @@ package org.apache.spark.sql.catalyst.expressions +import org.apache.spark.sql.catalyst.analysis.SimpleAnalyzer +import org.apache.spark.sql.catalyst.plans.logical.{Project, LocalRelation} + import scala.language.existentials -import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.catalyst.{ScalaReflection, InternalRow} import org.apache.spark.sql.catalyst.expressions.codegen.{GeneratedExpressionCode, CodeGenContext} import org.apache.spark.sql.types._ @@ -48,7 +51,7 @@ case class StaticInvoke( case other => other.getClass.getName.stripSuffix("$") } override def nullable: Boolean = true - override def children: Seq[Expression] = Nil + override def children: Seq[Expression] = arguments override def eval(input: InternalRow): Any = throw new UnsupportedOperationException("Only code-generated evaluation is supported.") @@ -69,7 +72,7 @@ case class StaticInvoke( s""" ${argGen.map(_.code).mkString("\n")} - boolean ${ev.isNull} = true; + boolean ${ev.isNull} = !$argsNonNull; $javaType ${ev.value} = ${ctx.defaultValue(dataType)}; if ($argsNonNull) { @@ -81,8 +84,8 @@ case class StaticInvoke( s""" ${argGen.map(_.code).mkString("\n")} - final boolean ${ev.isNull} = ${ev.value} == null; $javaType ${ev.value} = $objectName.$functionName($argString); + final boolean ${ev.isNull} = ${ev.value} == null; """ } } @@ -92,6 +95,10 @@ case class StaticInvoke( * Calls the specified function on an object, optionally passing arguments. If the `targetObject` * expression evaluates to null then null will be returned. * + * In some cases, due to erasure, the schema may expect a primitive type when in fact the method + * is returning java.lang.Object. In this case, we will generate code that attempts to unbox the + * value automatically. + * * @param targetObject An expression that will return the object to call the method on. * @param functionName The name of the method to call. * @param dataType The expected return type of the function. @@ -109,6 +116,35 @@ case class Invoke( override def eval(input: InternalRow): Any = throw new UnsupportedOperationException("Only code-generated evaluation is supported.") + lazy val method = targetObject.dataType match { + case ObjectType(cls) => + cls + .getMethods + .find(_.getName == functionName) + .getOrElse(sys.error(s"Couldn't find $functionName on $cls")) + .getReturnType + .getName + case _ => "" + } + + lazy val unboxer = (dataType, method) match { + case (IntegerType, "java.lang.Object") => (s: String) => + s"((java.lang.Integer)$s).intValue()" + case (LongType, "java.lang.Object") => (s: String) => + s"((java.lang.Long)$s).longValue()" + case (FloatType, "java.lang.Object") => (s: String) => + s"((java.lang.Float)$s).floatValue()" + case (ShortType, "java.lang.Object") => (s: String) => + s"((java.lang.Short)$s).shortValue()" + case (ByteType, "java.lang.Object") => (s: String) => + s"((java.lang.Byte)$s).byteValue()" + case (DoubleType, "java.lang.Object") => (s: String) => + s"((java.lang.Double)$s).doubleValue()" + case (BooleanType, "java.lang.Object") => (s: String) => + s"((java.lang.Boolean)$s).booleanValue()" + case _ => identity[String] _ + } + override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = { val javaType = ctx.javaType(dataType) val obj = targetObject.gen(ctx) @@ -123,6 +159,8 @@ case class Invoke( "" } + val value = unboxer(s"${obj.value}.$functionName($argString)") + s""" ${obj.code} ${argGen.map(_.code).mkString("\n")} @@ -130,7 +168,7 @@ case class Invoke( boolean ${ev.isNull} = ${obj.value} == null; $javaType ${ev.value} = ${ev.isNull} ? - ${ctx.defaultValue(dataType)} : ($javaType) ${obj.value}.$functionName($argString); + ${ctx.defaultValue(dataType)} : ($javaType) $value; $objNullCheck """ } @@ -190,8 +228,8 @@ case class NewInstance( s""" ${argGen.map(_.code).mkString("\n")} - final boolean ${ev.isNull} = ${ev.value} == null; $javaType ${ev.value} = new $className($argString); + final boolean ${ev.isNull} = ${ev.value} == null; """ } } @@ -210,8 +248,6 @@ case class UnwrapOption( override def nullable: Boolean = true - override def children: Seq[Expression] = Nil - override def inputTypes: Seq[AbstractDataType] = ObjectType :: Nil override def eval(input: InternalRow): Any = @@ -231,6 +267,43 @@ case class UnwrapOption( } } +/** + * Converts the result of evaluating `child` into an option, checking both the isNull bit and + * (in the case of reference types) equality with null. + * @param optionType The datatype to be held inside of the Option. + * @param child The expression to evaluate and wrap. + */ +case class WrapOption(optionType: DataType, child: Expression) + extends UnaryExpression with ExpectsInputTypes { + + override def dataType: DataType = ObjectType(classOf[Option[_]]) + + override def nullable: Boolean = true + + override def inputTypes: Seq[AbstractDataType] = ObjectType :: Nil + + override def eval(input: InternalRow): Any = + throw new UnsupportedOperationException("Only code-generated evaluation is supported") + + override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = { + val javaType = ctx.javaType(optionType) + val inputObject = child.gen(ctx) + + s""" + ${inputObject.code} + + boolean ${ev.isNull} = false; + scala.Option<$javaType> ${ev.value} = + ${inputObject.isNull} ? + scala.Option$$.MODULE$$.apply(null) : new scala.Some(${inputObject.value}); + """ + } +} + +/** + * A place holder for the loop variable used in [[MapObjects]]. This should never be constructed + * manually, but will instead be passed into the provided lambda function. + */ case class LambdaVariable(value: String, isNull: String, dataType: DataType) extends Expression { override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = @@ -251,7 +324,7 @@ case class LambdaVariable(value: String, isNull: String, dataType: DataType) ext * as an ArrayType. This is similar to a typical map operation, but where the lambda function * is expressed using catalyst expressions. * - * The following collection ObjectTypes are currently supported: Seq, Array + * The following collection ObjectTypes are currently supported: Seq, Array, ArrayData * * @param function A function that returns an expression, given an attribute that can be used * to access the current value. This is does as a lambda function so that @@ -265,14 +338,32 @@ case class MapObjects( inputData: Expression, elementType: DataType) extends Expression { - private val loopAttribute = AttributeReference("loopVar", elementType)() - private val completeFunction = function(loopAttribute) + private lazy val loopAttribute = AttributeReference("loopVar", elementType)() + private lazy val completeFunction = function(loopAttribute) - private val (lengthFunction, itemAccessor) = inputData.dataType match { - case ObjectType(cls) if cls.isAssignableFrom(classOf[Seq[_]]) => - (".size()", (i: String) => s".apply($i)") + private lazy val (lengthFunction, itemAccessor, primitiveElement) = inputData.dataType match { + case ObjectType(cls) if classOf[Seq[_]].isAssignableFrom(cls) => + (".size()", (i: String) => s".apply($i)", false) case ObjectType(cls) if cls.isArray => - (".length", (i: String) => s"[$i]") + (".length", (i: String) => s"[$i]", false) + case ArrayType(s: StructType, _) => + (".numElements()", (i: String) => s".getStruct($i, ${s.size})", false) + case ArrayType(a: ArrayType, _) => + (".numElements()", (i: String) => s".getArray($i)", true) + case ArrayType(IntegerType, _) => + (".numElements()", (i: String) => s".getInt($i)", true) + case ArrayType(LongType, _) => + (".numElements()", (i: String) => s".getLong($i)", true) + case ArrayType(FloatType, _) => + (".numElements()", (i: String) => s".getFloat($i)", true) + case ArrayType(DoubleType, _) => + (".numElements()", (i: String) => s".getDouble($i)", true) + case ArrayType(ByteType, _) => + (".numElements()", (i: String) => s".getByte($i)", true) + case ArrayType(ShortType, _) => + (".numElements()", (i: String) => s".getShort($i)", true) + case ArrayType(BooleanType, _) => + (".numElements()", (i: String) => s".getBoolean($i)", true) } override def nullable: Boolean = true @@ -294,15 +385,38 @@ case class MapObjects( val loopIsNull = ctx.freshName("loopIsNull") val loopVariable = LambdaVariable(loopValue, loopIsNull, elementType) - val boundFunction = completeFunction transform { + val substitutedFunction = completeFunction transform { case a: AttributeReference if a == loopAttribute => loopVariable } + // A hack to run this through the analyzer (to bind extractions). + val boundFunction = + SimpleAnalyzer.execute(Project(Alias(substitutedFunction, "")() :: Nil, LocalRelation(Nil))) + .expressions.head.children.head val genFunction = boundFunction.gen(ctx) val dataLength = ctx.freshName("dataLength") val convertedArray = ctx.freshName("convertedArray") val loopIndex = ctx.freshName("loopIndex") + val convertedType = ctx.javaType(boundFunction.dataType) + + // Because of the way Java defines nested arrays, we have to handle the syntax specially. + // Specifically, we have to insert the [$dataLength] in between the type and any extra nested + // array declarations (i.e. new String[1][]). + val arrayConstructor = if (convertedType contains "[]") { + val rawType = convertedType.takeWhile(_ != '[') + val arrayPart = convertedType.reverse.takeWhile(c => c == '[' || c == ']').reverse + s"new $rawType[$dataLength]$arrayPart" + } else { + s"new $convertedType[$dataLength]" + } + + val loopNullCheck = if (primitiveElement) { + s"boolean $loopIsNull = ${genInputData.value}.isNullAt($loopIndex);" + } else { + s"boolean $loopIsNull = ${genInputData.isNull} || $loopValue == null;" + } + s""" ${genInputData.code} @@ -310,19 +424,19 @@ case class MapObjects( $javaType ${ev.value} = ${ctx.defaultValue(dataType)}; if (!${ev.isNull}) { - Object[] $convertedArray = null; + $convertedType[] $convertedArray = null; int $dataLength = ${genInputData.value}$lengthFunction; - $convertedArray = new Object[$dataLength]; + $convertedArray = $arrayConstructor; int $loopIndex = 0; while ($loopIndex < $dataLength) { $elementJavaType $loopValue = ($elementJavaType)${genInputData.value}${itemAccessor(loopIndex)}; - boolean $loopIsNull = $loopValue == null; + $loopNullCheck ${genFunction.code} - $convertedArray[$loopIndex] = ${genFunction.value}; + $convertedArray[$loopIndex] = ($convertedType)${genFunction.value}; $loopIndex += 1; } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/ArrayBasedMapData.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/ArrayBasedMapData.scala index 52069598ee..5f22e59d5f 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/ArrayBasedMapData.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/ArrayBasedMapData.scala @@ -62,4 +62,8 @@ object ArrayBasedMapData { val values = map.valueArray.asInstanceOf[GenericArrayData].array keys.zip(values).toMap } + + def toScalaMap(keys: Array[Any], values: Array[Any]): Map[Any, Any] = { + keys.zip(values).toMap + } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/ArrayData.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/ArrayData.scala index 642c56f12d..b4ea300f5f 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/ArrayData.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/ArrayData.scala @@ -26,6 +26,8 @@ abstract class ArrayData extends SpecializedGetters with Serializable { def copy(): ArrayData + def array: Array[Any] + def toBooleanArray(): Array[Boolean] = { val size = numElements() val values = new Array[Boolean](size) @@ -103,6 +105,9 @@ abstract class ArrayData extends SpecializedGetters with Serializable { values } + def toObjectArray(elementType: DataType): Array[AnyRef] = + toArray[AnyRef](elementType: DataType) + def toArray[T: ClassTag](elementType: DataType): Array[T] = { val size = numElements() val values = new Array[T](size) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/GenericArrayData.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/GenericArrayData.scala index c381603327..9448d88d6c 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/GenericArrayData.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/GenericArrayData.scala @@ -20,7 +20,7 @@ package org.apache.spark.sql.types import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.unsafe.types.{CalendarInterval, UTF8String} -class GenericArrayData(private[sql] val array: Array[Any]) extends ArrayData { +class GenericArrayData(val array: Array[Any]) extends ArrayData { def this(seq: scala.collection.GenIterable[Any]) = this(seq.toArray) @@ -29,6 +29,8 @@ class GenericArrayData(private[sql] val array: Array[Any]) extends ArrayData { def this(primitiveArray: Array[Long]) = this(primitiveArray.toSeq) def this(primitiveArray: Array[Float]) = this(primitiveArray.toSeq) def this(primitiveArray: Array[Double]) = this(primitiveArray.toSeq) + def this(primitiveArray: Array[Short]) = this(primitiveArray.toSeq) + def this(primitiveArray: Array[Byte]) = this(primitiveArray.toSeq) def this(primitiveArray: Array[Boolean]) = this(primitiveArray.toSeq) override def copy(): ArrayData = new GenericArrayData(array.clone()) diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/ProductEncoderSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/ProductEncoderSuite.scala index 99c993d3fe..02e43ddb35 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/ProductEncoderSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/ProductEncoderSuite.scala @@ -17,158 +17,263 @@ package org.apache.spark.sql.catalyst.encoders -import java.sql.{Date, Timestamp} +import java.util + +import org.apache.spark.sql.types.{StructField, ArrayType, ArrayData} + +import scala.collection.mutable.ArrayBuffer +import scala.reflect.runtime.universe._ import org.apache.spark.SparkFunSuite -import org.apache.spark.sql.catalyst.ScalaReflection._ -import org.apache.spark.sql.catalyst.expressions.UnsafeProjection import org.apache.spark.sql.catalyst._ - case class RepeatedStruct(s: Seq[PrimitiveData]) case class NestedArray(a: Array[Array[Int]]) -class ProductEncoderSuite extends SparkFunSuite { +case class BoxedData( + intField: java.lang.Integer, + longField: java.lang.Long, + doubleField: java.lang.Double, + floatField: java.lang.Float, + shortField: java.lang.Short, + byteField: java.lang.Byte, + booleanField: java.lang.Boolean) - test("convert PrimitiveData to InternalRow") { - val inputData = PrimitiveData(1, 1, 1, 1, 1, 1, true) - val encoder = ProductEncoder[PrimitiveData] - val convertedData = encoder.toRow(inputData) - - assert(convertedData.getInt(0) == 1) - assert(convertedData.getLong(1) == 1.toLong) - assert(convertedData.getDouble(2) == 1.toDouble) - assert(convertedData.getFloat(3) == 1.toFloat) - assert(convertedData.getShort(4) == 1.toShort) - assert(convertedData.getByte(5) == 1.toByte) - assert(convertedData.getBoolean(6) == true) - } +case class RepeatedData( + arrayField: Seq[Int], + arrayFieldContainsNull: Seq[java.lang.Integer], + mapField: scala.collection.Map[Int, Long], + mapFieldNull: scala.collection.Map[Int, java.lang.Long], + structField: PrimitiveData) - test("convert Some[_] to InternalRow") { - val primitiveData = PrimitiveData(1, 1, 1, 1, 1, 1, true) - val inputData = OptionalData(Some(2), Some(2), Some(2), Some(2), Some(2), Some(2), Some(true), - Some(primitiveData)) - - val encoder = ProductEncoder[OptionalData] - val convertedData = encoder.toRow(inputData) - - assert(convertedData.getInt(0) == 2) - assert(convertedData.getLong(1) == 2.toLong) - assert(convertedData.getDouble(2) == 2.toDouble) - assert(convertedData.getFloat(3) == 2.toFloat) - assert(convertedData.getShort(4) == 2.toShort) - assert(convertedData.getByte(5) == 2.toByte) - assert(convertedData.getBoolean(6) == true) - - val nestedRow = convertedData.getStruct(7, 7) - assert(nestedRow.getInt(0) == 1) - assert(nestedRow.getLong(1) == 1.toLong) - assert(nestedRow.getDouble(2) == 1.toDouble) - assert(nestedRow.getFloat(3) == 1.toFloat) - assert(nestedRow.getShort(4) == 1.toShort) - assert(nestedRow.getByte(5) == 1.toByte) - assert(nestedRow.getBoolean(6) == true) - } +case class SpecificCollection(l: List[Int]) - test("convert None to InternalRow") { - val inputData = OptionalData(None, None, None, None, None, None, None, None) - val encoder = ProductEncoder[OptionalData] - val convertedData = encoder.toRow(inputData) - - assert(convertedData.isNullAt(0)) - assert(convertedData.isNullAt(1)) - assert(convertedData.isNullAt(2)) - assert(convertedData.isNullAt(3)) - assert(convertedData.isNullAt(4)) - assert(convertedData.isNullAt(5)) - assert(convertedData.isNullAt(6)) - assert(convertedData.isNullAt(7)) - } +class ProductEncoderSuite extends SparkFunSuite { - test("convert nullable but present data to InternalRow") { - val inputData = NullableData( - 1, 1L, 1.0, 1.0f, 1.toShort, 1.toByte, true, "test", new java.math.BigDecimal(1), new Date(0), - new Timestamp(0), Array[Byte](1, 2, 3)) - - val encoder = ProductEncoder[NullableData] - val convertedData = encoder.toRow(inputData) - - assert(convertedData.getInt(0) == 1) - assert(convertedData.getLong(1) == 1.toLong) - assert(convertedData.getDouble(2) == 1.toDouble) - assert(convertedData.getFloat(3) == 1.toFloat) - assert(convertedData.getShort(4) == 1.toShort) - assert(convertedData.getByte(5) == 1.toByte) - assert(convertedData.getBoolean(6) == true) - } + encodeDecodeTest(PrimitiveData(1, 1, 1, 1, 1, 1, true)) - test("convert nullable data to InternalRow") { - val inputData = - NullableData(null, null, null, null, null, null, null, null, null, null, null, null) - - val encoder = ProductEncoder[NullableData] - val convertedData = encoder.toRow(inputData) - - assert(convertedData.isNullAt(0)) - assert(convertedData.isNullAt(1)) - assert(convertedData.isNullAt(2)) - assert(convertedData.isNullAt(3)) - assert(convertedData.isNullAt(4)) - assert(convertedData.isNullAt(5)) - assert(convertedData.isNullAt(6)) - assert(convertedData.isNullAt(7)) - assert(convertedData.isNullAt(8)) - assert(convertedData.isNullAt(9)) - assert(convertedData.isNullAt(10)) - assert(convertedData.isNullAt(11)) - } + // TODO: Support creating specific subclasses of Seq. + ignore("Specific collection types") { encodeDecodeTest(SpecificCollection(1 :: Nil)) } - test("convert repeated struct") { - val inputData = RepeatedStruct(PrimitiveData(1, 1, 1, 1, 1, 1, true) :: Nil) - val encoder = ProductEncoder[RepeatedStruct] - - val converted = encoder.toRow(inputData) - val convertedStruct = converted.getArray(0).getStruct(0, 7) - assert(convertedStruct.getInt(0) == 1) - assert(convertedStruct.getLong(1) == 1.toLong) - assert(convertedStruct.getDouble(2) == 1.toDouble) - assert(convertedStruct.getFloat(3) == 1.toFloat) - assert(convertedStruct.getShort(4) == 1.toShort) - assert(convertedStruct.getByte(5) == 1.toByte) - assert(convertedStruct.getBoolean(6) == true) - } + encodeDecodeTest( + OptionalData( + Some(2), Some(2), Some(2), Some(2), Some(2), Some(2), Some(true), + Some(PrimitiveData(1, 1, 1, 1, 1, 1, true)))) - test("convert nested seq") { - val convertedData = ProductEncoder[Tuple1[Seq[Seq[Int]]]].toRow(Tuple1(Seq(Seq(1)))) - assert(convertedData.getArray(0).getArray(0).getInt(0) == 1) + encodeDecodeTest(OptionalData(None, None, None, None, None, None, None, None)) - val convertedData2 = ProductEncoder[Tuple1[Seq[Seq[Seq[Int]]]]].toRow(Tuple1(Seq(Seq(Seq(1))))) - assert(convertedData2.getArray(0).getArray(0).getArray(0).getInt(0) == 1) - } + encodeDecodeTest( + BoxedData(1, 1L, 1.0, 1.0f, 1.toShort, 1.toByte, true)) - test("convert nested array") { - val convertedData = ProductEncoder[Tuple1[Array[Array[Int]]]].toRow(Tuple1(Array(Array(1)))) - } + encodeDecodeTest( + BoxedData(null, null, null, null, null, null, null)) + + encodeDecodeTest( + RepeatedStruct(PrimitiveData(1, 1, 1, 1, 1, 1, true) :: Nil)) - test("convert complex") { - val inputData = ComplexData( + encodeDecodeTest( + RepeatedData( Seq(1, 2), - Array(1, 2), - 1 :: 2 :: Nil, Seq(new Integer(1), null, new Integer(2)), Map(1 -> 2L), - Map(1 -> new java.lang.Long(2)), - PrimitiveData(1, 1, 1, 1, 1, 1, true), - Array(Array(1))) - - val encoder = ProductEncoder[ComplexData] - val convertedData = encoder.toRow(inputData) - - assert(!convertedData.isNullAt(0)) - val seq = convertedData.getArray(0) - assert(seq.numElements() == 2) - assert(seq.getInt(0) == 1) - assert(seq.getInt(1) == 2) + Map(1 -> null), + PrimitiveData(1, 1, 1, 1, 1, 1, true))) + + encodeDecodeTest(("nullable Seq[Integer]", Seq[Integer](1, null))) + + encodeDecodeTest(("Seq[(String, String)]", + Seq(("a", "b")))) + encodeDecodeTest(("Seq[(Int, Int)]", + Seq((1, 2)))) + encodeDecodeTest(("Seq[(Long, Long)]", + Seq((1L, 2L)))) + encodeDecodeTest(("Seq[(Float, Float)]", + Seq((1.toFloat, 2.toFloat)))) + encodeDecodeTest(("Seq[(Double, Double)]", + Seq((1.toDouble, 2.toDouble)))) + encodeDecodeTest(("Seq[(Short, Short)]", + Seq((1.toShort, 2.toShort)))) + encodeDecodeTest(("Seq[(Byte, Byte)]", + Seq((1.toByte, 2.toByte)))) + encodeDecodeTest(("Seq[(Boolean, Boolean)]", + Seq((true, false)))) + + // TODO: Decoding/encoding of complex maps. + ignore("complex maps") { + encodeDecodeTest(("Map[Int, (String, String)]", + Map(1 ->("a", "b")))) + } + + encodeDecodeTest(("ArrayBuffer[(String, String)]", + ArrayBuffer(("a", "b")))) + encodeDecodeTest(("ArrayBuffer[(Int, Int)]", + ArrayBuffer((1, 2)))) + encodeDecodeTest(("ArrayBuffer[(Long, Long)]", + ArrayBuffer((1L, 2L)))) + encodeDecodeTest(("ArrayBuffer[(Float, Float)]", + ArrayBuffer((1.toFloat, 2.toFloat)))) + encodeDecodeTest(("ArrayBuffer[(Double, Double)]", + ArrayBuffer((1.toDouble, 2.toDouble)))) + encodeDecodeTest(("ArrayBuffer[(Short, Short)]", + ArrayBuffer((1.toShort, 2.toShort)))) + encodeDecodeTest(("ArrayBuffer[(Byte, Byte)]", + ArrayBuffer((1.toByte, 2.toByte)))) + encodeDecodeTest(("ArrayBuffer[(Boolean, Boolean)]", + ArrayBuffer((true, false)))) + + encodeDecodeTest(("Seq[Seq[(Int, Int)]]", + Seq(Seq((1, 2))))) + + encodeDecodeTestCustom(("Array[Array[(Int, Int)]]", + Array(Array((1, 2))))) + { (l, r) => l._2(0)(0) == r._2(0)(0) } + + encodeDecodeTestCustom(("Array[Array[(Int, Int)]]", + Array(Array(Array((1, 2)))))) + { (l, r) => l._2(0)(0)(0) == r._2(0)(0)(0) } + + encodeDecodeTestCustom(("Array[Array[Array[(Int, Int)]]]", + Array(Array(Array(Array((1, 2))))))) + { (l, r) => l._2(0)(0)(0)(0) == r._2(0)(0)(0)(0) } + + encodeDecodeTestCustom(("Array[Array[Array[Array[(Int, Int)]]]]", + Array(Array(Array(Array(Array((1, 2)))))))) + { (l, r) => l._2(0)(0)(0)(0)(0) == r._2(0)(0)(0)(0)(0) } + + + encodeDecodeTestCustom(("Array[Array[Integer]]", + Array(Array[Integer](1)))) + { (l, r) => l._2(0)(0) == r._2(0)(0) } + + encodeDecodeTestCustom(("Array[Array[Int]]", + Array(Array(1)))) + { (l, r) => l._2(0)(0) == r._2(0)(0) } + + encodeDecodeTestCustom(("Array[Array[Int]]", + Array(Array(Array(1))))) + { (l, r) => l._2(0)(0)(0) == r._2(0)(0)(0) } + + encodeDecodeTestCustom(("Array[Array[Array[Int]]]", + Array(Array(Array(Array(1)))))) + { (l, r) => l._2(0)(0)(0)(0) == r._2(0)(0)(0)(0) } + + encodeDecodeTestCustom(("Array[Array[Array[Array[Int]]]]", + Array(Array(Array(Array(Array(1))))))) + { (l, r) => l._2(0)(0)(0)(0)(0) == r._2(0)(0)(0)(0)(0) } + + encodeDecodeTest(("Array[Byte] null", + null: Array[Byte])) + encodeDecodeTestCustom(("Array[Byte]", + Array[Byte](1, 2, 3))) + { (l, r) => util.Arrays.equals(l._2, r._2) } + + encodeDecodeTest(("Array[Int] null", + null: Array[Int])) + encodeDecodeTestCustom(("Array[Int]", + Array[Int](1, 2, 3))) + { (l, r) => util.Arrays.equals(l._2, r._2) } + + encodeDecodeTest(("Array[Long] null", + null: Array[Long])) + encodeDecodeTestCustom(("Array[Long]", + Array[Long](1, 2, 3))) + { (l, r) => util.Arrays.equals(l._2, r._2) } + + encodeDecodeTest(("Array[Double] null", + null: Array[Double])) + encodeDecodeTestCustom(("Array[Double]", + Array[Double](1, 2, 3))) + { (l, r) => util.Arrays.equals(l._2, r._2) } + + encodeDecodeTest(("Array[Float] null", + null: Array[Float])) + encodeDecodeTestCustom(("Array[Float]", + Array[Float](1, 2, 3))) + { (l, r) => util.Arrays.equals(l._2, r._2) } + + encodeDecodeTest(("Array[Boolean] null", + null: Array[Boolean])) + encodeDecodeTestCustom(("Array[Boolean]", + Array[Boolean](true, false))) + { (l, r) => util.Arrays.equals(l._2, r._2) } + + encodeDecodeTest(("Array[Short] null", + null: Array[Short])) + encodeDecodeTestCustom(("Array[Short]", + Array[Short](1, 2, 3))) + { (l, r) => util.Arrays.equals(l._2, r._2) } + + encodeDecodeTestCustom(("java.sql.Timestamp", + new java.sql.Timestamp(1))) + { (l, r) => l._2.toString == r._2.toString } + + encodeDecodeTestCustom(("java.sql.Date", new java.sql.Date(1))) + { (l, r) => l._2.toString == r._2.toString } + + /** Simplified encodeDecodeTestCustom, where the comparison function can be `Object.equals`. */ + protected def encodeDecodeTest[T <: Product : TypeTag](inputData: T) = + encodeDecodeTestCustom[T](inputData)((l, r) => l == r) + + /** + * Constructs a test that round-trips `t` through an encoder, checking the results to ensure it + * matches the original. + */ + protected def encodeDecodeTestCustom[T <: Product : TypeTag]( + inputData: T)( + c: (T, T) => Boolean) = { + test(s"encode/decode: $inputData") { + val encoder = try ProductEncoder[T] catch { + case e: Exception => + fail(s"Exception thrown generating encoder", e) + } + val convertedData = encoder.toRow(inputData) + val schema = encoder.schema.toAttributes + val boundEncoder = encoder.bind(schema) + val convertedBack = try boundEncoder.fromRow(convertedData) catch { + case e: Exception => + fail( + s"""Exception thrown while decoding + |Converted: $convertedData + |Schema: ${schema.mkString(",")} + |${encoder.schema.treeString} + | + |Construct Expressions: + |${boundEncoder.constructExpression.treeString} + | + """.stripMargin, e) + } + + if (!c(inputData, convertedBack)) { + val types = + convertedBack.productIterator.filter(_ != null).map(_.getClass.getName).mkString(",") + + val encodedData = convertedData.toSeq(encoder.schema).zip(encoder.schema).map { + case (a: ArrayData, StructField(_, at: ArrayType, _, _)) => + a.toArray[Any](at.elementType).toSeq + case (other, _) => + other + }.mkString("[", ",", "]") + + fail( + s"""Encoded/Decoded data does not match input data + | + |in: $inputData + |out: $convertedBack + |types: $types + | + |Encoded Data: $encodedData + |Schema: ${schema.mkString(",")} + |${encoder.schema.treeString} + | + |Extract Expressions: + |${boundEncoder.extractExpressions.map(_.treeString).mkString("\n")} + | + |Construct Expressions: + |${boundEncoder.constructExpression.treeString} + | + """.stripMargin) + } + } } } -- cgit v1.2.3