diff options
author | Wenchen Fan <wenchen@databricks.com> | 2016-06-03 00:43:02 -0700 |
---|---|---|
committer | Cheng Lian <lian@databricks.com> | 2016-06-03 00:43:02 -0700 |
commit | 190ff274fd71662023a804cf98400c71f9f7da4f (patch) | |
tree | 9b3f79aebf252d3c27f53d9593000c5fd58e1509 /sql/catalyst | |
parent | b9fcfb3bd14592ac9f1a8e5c2bb31412b9603b60 (diff) | |
download | spark-190ff274fd71662023a804cf98400c71f9f7da4f.tar.gz spark-190ff274fd71662023a804cf98400c71f9f7da4f.tar.bz2 spark-190ff274fd71662023a804cf98400c71f9f7da4f.zip |
[SPARK-15494][SQL] encoder code cleanup
## What changes were proposed in this pull request?
Our encoder framework has been evolved a lot, this PR tries to clean up the code to make it more readable and emphasise the concept that encoder should be used as a container of serde expressions.
1. move validation logic to analyzer instead of encoder
2. only have a `resolveAndBind` method in encoder instead of `resolve` and `bind`, as we don't have the encoder life cycle concept anymore.
3. `Dataset` don't need to keep a resolved encoder, as there is no such concept anymore. bound encoder is still needed to do serialization outside of query framework.
4. Using `BoundReference` to represent an unresolved field in deserializer expression is kind of weird, this PR adds a `GetColumnByOrdinal` for this purpose. (serializer expression still use `BoundReference`, we can replace it with `GetColumnByOrdinal` in follow-ups)
## How was this patch tested?
existing test
Author: Wenchen Fan <wenchen@databricks.com>
Author: Cheng Lian <lian@databricks.com>
Closes #13269 from cloud-fan/clean-encoder.
Diffstat (limited to 'sql/catalyst')
11 files changed, 274 insertions, 330 deletions
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/Encoders.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/Encoders.scala index fa96f8223d..673c587b18 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/Encoders.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/Encoders.scala @@ -23,6 +23,7 @@ import scala.reflect.{classTag, ClassTag} import scala.reflect.runtime.universe.TypeTag import org.apache.spark.annotation.Experimental +import org.apache.spark.sql.catalyst.analysis.GetColumnByOrdinal import org.apache.spark.sql.catalyst.encoders.{encoderFor, ExpressionEncoder} import org.apache.spark.sql.catalyst.expressions.objects.{DecodeUsingSerializer, EncodeUsingSerializer} import org.apache.spark.sql.catalyst.expressions.BoundReference @@ -208,7 +209,7 @@ object Encoders { BoundReference(0, ObjectType(classOf[AnyRef]), nullable = true), kryo = useKryo)), deserializer = DecodeUsingSerializer[T]( - BoundReference(0, BinaryType, nullable = true), classTag[T], kryo = useKryo), + GetColumnByOrdinal(0, BinaryType), classTag[T], kryo = useKryo), clsTag = classTag[T] ) } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/JavaTypeInference.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/JavaTypeInference.scala index 1fe143494a..b3a233ae39 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/JavaTypeInference.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/JavaTypeInference.scala @@ -25,7 +25,7 @@ import scala.language.existentials import com.google.common.reflect.TypeToken -import org.apache.spark.sql.catalyst.analysis.{UnresolvedAttribute, UnresolvedExtractValue} +import org.apache.spark.sql.catalyst.analysis.{GetColumnByOrdinal, UnresolvedAttribute, UnresolvedExtractValue} import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.expressions.objects._ import org.apache.spark.sql.catalyst.util.{ArrayBasedMapData, DateTimeUtils, GenericArrayData} @@ -177,8 +177,8 @@ object JavaTypeInference { .map(p => UnresolvedExtractValue(p, expressions.Literal(part))) .getOrElse(UnresolvedAttribute(part)) - /** Returns the current path or `BoundReference`. */ - def getPath: Expression = path.getOrElse(BoundReference(0, inferDataType(typeToken)._1, true)) + /** Returns the current path or `GetColumnByOrdinal`. */ + def getPath: Expression = path.getOrElse(GetColumnByOrdinal(0, inferDataType(typeToken)._1)) typeToken.getRawType match { case c if !inferExternalType(c).isInstanceOf[ObjectType] => getPath diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala index 4750861817..78c145d4fd 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala @@ -17,7 +17,7 @@ package org.apache.spark.sql.catalyst -import org.apache.spark.sql.catalyst.analysis.{UnresolvedAttribute, UnresolvedExtractValue} +import org.apache.spark.sql.catalyst.analysis.{GetColumnByOrdinal, UnresolvedAttribute, UnresolvedExtractValue} import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.expressions.objects._ import org.apache.spark.sql.catalyst.util.{ArrayBasedMapData, DateTimeUtils, GenericArrayData} @@ -156,17 +156,17 @@ object ScalaReflection extends ScalaReflection { walkedTypePath: Seq[String]): Expression = { val newPath = path .map(p => GetStructField(p, ordinal)) - .getOrElse(BoundReference(ordinal, dataType, false)) + .getOrElse(GetColumnByOrdinal(ordinal, dataType)) upCastToExpectedType(newPath, dataType, walkedTypePath) } - /** Returns the current path or `BoundReference`. */ + /** Returns the current path or `GetColumnByOrdinal`. */ def getPath: Expression = { val dataType = schemaFor(tpe).dataType if (path.isDefined) { path.get } else { - upCastToExpectedType(BoundReference(0, dataType, true), dataType, walkedTypePath) + upCastToExpectedType(GetColumnByOrdinal(0, dataType), dataType, walkedTypePath) } } @@ -421,7 +421,7 @@ object ScalaReflection extends ScalaReflection { def serializerFor[T : TypeTag](inputObject: Expression): CreateNamedStruct = { val tpe = localTypeOf[T] val clsName = getClassNameFromType(tpe) - val walkedTypePath = s"""- root class: "${clsName}"""" :: Nil + val walkedTypePath = s"""- root class: "$clsName"""" :: Nil serializerFor(inputObject, tpe, walkedTypePath) match { case expressions.If(_, _, s: CreateNamedStruct) if definedByConstructorParams(tpe) => s case other => CreateNamedStruct(expressions.Literal("value") :: other :: Nil) @@ -449,157 +449,156 @@ object ScalaReflection extends ScalaReflection { } } - if (!inputObject.dataType.isInstanceOf[ObjectType]) { - inputObject - } else { - tpe match { - case t if t <:< localTypeOf[Option[_]] => - val TypeRef(_, _, Seq(optType)) = t - val className = getClassNameFromType(optType) - val newPath = s"""- option value class: "$className"""" +: walkedTypePath - val unwrapped = UnwrapOption(dataTypeFor(optType), inputObject) - serializerFor(unwrapped, optType, newPath) - - // Since List[_] also belongs to localTypeOf[Product], we put this case before - // "case t if definedByConstructorParams(t)" to make sure it will match to the - // case "localTypeOf[Seq[_]]" - case t if t <:< localTypeOf[Seq[_]] => - val TypeRef(_, _, Seq(elementType)) = t - toCatalystArray(inputObject, elementType) - - case t if t <:< localTypeOf[Array[_]] => - val TypeRef(_, _, Seq(elementType)) = t - toCatalystArray(inputObject, elementType) - - case t if t <:< localTypeOf[Map[_, _]] => - val TypeRef(_, _, Seq(keyType, valueType)) = t - - val keys = - Invoke( - Invoke(inputObject, "keysIterator", - ObjectType(classOf[scala.collection.Iterator[_]])), - "toSeq", - ObjectType(classOf[scala.collection.Seq[_]])) - val convertedKeys = toCatalystArray(keys, keyType) - - val values = - Invoke( - Invoke(inputObject, "valuesIterator", - ObjectType(classOf[scala.collection.Iterator[_]])), - "toSeq", - ObjectType(classOf[scala.collection.Seq[_]])) - val convertedValues = toCatalystArray(values, valueType) - - val Schema(keyDataType, _) = schemaFor(keyType) - val Schema(valueDataType, valueNullable) = schemaFor(valueType) - NewInstance( - classOf[ArrayBasedMapData], - convertedKeys :: convertedValues :: Nil, - dataType = MapType(keyDataType, valueDataType, valueNullable)) - - case t if t <:< localTypeOf[String] => - StaticInvoke( - classOf[UTF8String], - StringType, - "fromString", - inputObject :: Nil) - - case t if t <:< localTypeOf[java.sql.Timestamp] => - StaticInvoke( - DateTimeUtils.getClass, - TimestampType, - "fromJavaTimestamp", - inputObject :: Nil) - - case t if t <:< localTypeOf[java.sql.Date] => - StaticInvoke( - DateTimeUtils.getClass, - DateType, - "fromJavaDate", - inputObject :: Nil) - - case t if t <:< localTypeOf[BigDecimal] => - StaticInvoke( - Decimal.getClass, - DecimalType.SYSTEM_DEFAULT, - "apply", - inputObject :: Nil) - - case t if t <:< localTypeOf[java.math.BigDecimal] => - StaticInvoke( - Decimal.getClass, - DecimalType.SYSTEM_DEFAULT, - "apply", - inputObject :: Nil) - - case t if t <:< localTypeOf[java.math.BigInteger] => - StaticInvoke( - Decimal.getClass, - DecimalType.BigIntDecimal, - "apply", - inputObject :: Nil) - - case t if t <:< localTypeOf[scala.math.BigInt] => - StaticInvoke( - Decimal.getClass, - DecimalType.BigIntDecimal, - "apply", - inputObject :: Nil) - - case t if t <:< localTypeOf[java.lang.Integer] => - Invoke(inputObject, "intValue", IntegerType) - case t if t <:< localTypeOf[java.lang.Long] => - Invoke(inputObject, "longValue", LongType) - case t if t <:< localTypeOf[java.lang.Double] => - Invoke(inputObject, "doubleValue", DoubleType) - case t if t <:< localTypeOf[java.lang.Float] => - Invoke(inputObject, "floatValue", FloatType) - case t if t <:< localTypeOf[java.lang.Short] => - Invoke(inputObject, "shortValue", ShortType) - case t if t <:< localTypeOf[java.lang.Byte] => - Invoke(inputObject, "byteValue", ByteType) - case t if t <:< localTypeOf[java.lang.Boolean] => - Invoke(inputObject, "booleanValue", BooleanType) - - case t if t.typeSymbol.annotations.exists(_.tpe =:= typeOf[SQLUserDefinedType]) => - val udt = getClassFromType(t) - .getAnnotation(classOf[SQLUserDefinedType]).udt().newInstance() - val obj = NewInstance( - udt.userClass.getAnnotation(classOf[SQLUserDefinedType]).udt(), - Nil, - dataType = ObjectType(udt.userClass.getAnnotation(classOf[SQLUserDefinedType]).udt())) - Invoke(obj, "serialize", udt, inputObject :: Nil) - - case t if UDTRegistration.exists(getClassNameFromType(t)) => - val udt = UDTRegistration.getUDTFor(getClassNameFromType(t)).get.newInstance() - .asInstanceOf[UserDefinedType[_]] - val obj = NewInstance( - udt.getClass, - Nil, - dataType = ObjectType(udt.getClass)) - Invoke(obj, "serialize", udt, inputObject :: Nil) - - case t if definedByConstructorParams(t) => - val params = getConstructorParameters(t) - val nonNullOutput = CreateNamedStruct(params.flatMap { case (fieldName, fieldType) => - if (javaKeywords.contains(fieldName)) { - throw new UnsupportedOperationException(s"`$fieldName` is a reserved keyword and " + - "cannot be used as field name\n" + walkedTypePath.mkString("\n")) - } + tpe match { + case _ if !inputObject.dataType.isInstanceOf[ObjectType] => inputObject - val fieldValue = Invoke(inputObject, fieldName, dataTypeFor(fieldType)) - val clsName = getClassNameFromType(fieldType) - val newPath = s"""- field (class: "$clsName", name: "$fieldName")""" +: walkedTypePath - expressions.Literal(fieldName) :: serializerFor(fieldValue, fieldType, newPath) :: Nil - }) - val nullOutput = expressions.Literal.create(null, nonNullOutput.dataType) - expressions.If(IsNull(inputObject), nullOutput, nonNullOutput) - - case other => - throw new UnsupportedOperationException( - s"No Encoder found for $tpe\n" + walkedTypePath.mkString("\n")) - } + case t if t <:< localTypeOf[Option[_]] => + val TypeRef(_, _, Seq(optType)) = t + val className = getClassNameFromType(optType) + val newPath = s"""- option value class: "$className"""" +: walkedTypePath + val unwrapped = UnwrapOption(dataTypeFor(optType), inputObject) + serializerFor(unwrapped, optType, newPath) + + // Since List[_] also belongs to localTypeOf[Product], we put this case before + // "case t if definedByConstructorParams(t)" to make sure it will match to the + // case "localTypeOf[Seq[_]]" + case t if t <:< localTypeOf[Seq[_]] => + val TypeRef(_, _, Seq(elementType)) = t + toCatalystArray(inputObject, elementType) + + case t if t <:< localTypeOf[Array[_]] => + val TypeRef(_, _, Seq(elementType)) = t + toCatalystArray(inputObject, elementType) + + case t if t <:< localTypeOf[Map[_, _]] => + val TypeRef(_, _, Seq(keyType, valueType)) = t + + val keys = + Invoke( + Invoke(inputObject, "keysIterator", + ObjectType(classOf[scala.collection.Iterator[_]])), + "toSeq", + ObjectType(classOf[scala.collection.Seq[_]])) + val convertedKeys = toCatalystArray(keys, keyType) + + val values = + Invoke( + Invoke(inputObject, "valuesIterator", + ObjectType(classOf[scala.collection.Iterator[_]])), + "toSeq", + ObjectType(classOf[scala.collection.Seq[_]])) + val convertedValues = toCatalystArray(values, valueType) + + val Schema(keyDataType, _) = schemaFor(keyType) + val Schema(valueDataType, valueNullable) = schemaFor(valueType) + NewInstance( + classOf[ArrayBasedMapData], + convertedKeys :: convertedValues :: Nil, + dataType = MapType(keyDataType, valueDataType, valueNullable)) + + case t if t <:< localTypeOf[String] => + StaticInvoke( + classOf[UTF8String], + StringType, + "fromString", + inputObject :: Nil) + + case t if t <:< localTypeOf[java.sql.Timestamp] => + StaticInvoke( + DateTimeUtils.getClass, + TimestampType, + "fromJavaTimestamp", + inputObject :: Nil) + + case t if t <:< localTypeOf[java.sql.Date] => + StaticInvoke( + DateTimeUtils.getClass, + DateType, + "fromJavaDate", + inputObject :: Nil) + + case t if t <:< localTypeOf[BigDecimal] => + StaticInvoke( + Decimal.getClass, + DecimalType.SYSTEM_DEFAULT, + "apply", + inputObject :: Nil) + + case t if t <:< localTypeOf[java.math.BigDecimal] => + StaticInvoke( + Decimal.getClass, + DecimalType.SYSTEM_DEFAULT, + "apply", + inputObject :: Nil) + + case t if t <:< localTypeOf[java.math.BigInteger] => + StaticInvoke( + Decimal.getClass, + DecimalType.BigIntDecimal, + "apply", + inputObject :: Nil) + + case t if t <:< localTypeOf[scala.math.BigInt] => + StaticInvoke( + Decimal.getClass, + DecimalType.BigIntDecimal, + "apply", + inputObject :: Nil) + + case t if t <:< localTypeOf[java.lang.Integer] => + Invoke(inputObject, "intValue", IntegerType) + case t if t <:< localTypeOf[java.lang.Long] => + Invoke(inputObject, "longValue", LongType) + case t if t <:< localTypeOf[java.lang.Double] => + Invoke(inputObject, "doubleValue", DoubleType) + case t if t <:< localTypeOf[java.lang.Float] => + Invoke(inputObject, "floatValue", FloatType) + case t if t <:< localTypeOf[java.lang.Short] => + Invoke(inputObject, "shortValue", ShortType) + case t if t <:< localTypeOf[java.lang.Byte] => + Invoke(inputObject, "byteValue", ByteType) + case t if t <:< localTypeOf[java.lang.Boolean] => + Invoke(inputObject, "booleanValue", BooleanType) + + case t if t.typeSymbol.annotations.exists(_.tpe =:= typeOf[SQLUserDefinedType]) => + val udt = getClassFromType(t) + .getAnnotation(classOf[SQLUserDefinedType]).udt().newInstance() + val obj = NewInstance( + udt.userClass.getAnnotation(classOf[SQLUserDefinedType]).udt(), + Nil, + dataType = ObjectType(udt.userClass.getAnnotation(classOf[SQLUserDefinedType]).udt())) + Invoke(obj, "serialize", udt, inputObject :: Nil) + + case t if UDTRegistration.exists(getClassNameFromType(t)) => + val udt = UDTRegistration.getUDTFor(getClassNameFromType(t)).get.newInstance() + .asInstanceOf[UserDefinedType[_]] + val obj = NewInstance( + udt.getClass, + Nil, + dataType = ObjectType(udt.getClass)) + Invoke(obj, "serialize", udt, inputObject :: Nil) + + case t if definedByConstructorParams(t) => + val params = getConstructorParameters(t) + val nonNullOutput = CreateNamedStruct(params.flatMap { case (fieldName, fieldType) => + if (javaKeywords.contains(fieldName)) { + throw new UnsupportedOperationException(s"`$fieldName` is a reserved keyword and " + + "cannot be used as field name\n" + walkedTypePath.mkString("\n")) + } + + val fieldValue = Invoke(inputObject, fieldName, dataTypeFor(fieldType)) + val clsName = getClassNameFromType(fieldType) + val newPath = s"""- field (class: "$clsName", name: "$fieldName")""" +: walkedTypePath + expressions.Literal(fieldName) :: serializerFor(fieldValue, fieldType, newPath) :: Nil + }) + val nullOutput = expressions.Literal.create(null, nonNullOutput.dataType) + expressions.If(IsNull(inputObject), nullOutput, nonNullOutput) + + case other => + throw new UnsupportedOperationException( + s"No Encoder found for $tpe\n" + walkedTypePath.mkString("\n")) } + } /** diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala index 02966796af..4f6b4830cd 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala @@ -722,6 +722,7 @@ class Analyzer( // Else, throw exception. try { expr transformUp { + case GetColumnByOrdinal(ordinal, _) => plan.output(ordinal) case u @ UnresolvedAttribute(nameParts) => withPosition(u) { plan.resolve(nameParts, resolver).getOrElse(u) } case UnresolvedExtractValue(child, fieldName) if child.resolved => @@ -1924,10 +1925,54 @@ class Analyzer( } else { inputAttributes } - val unbound = deserializer transform { - case b: BoundReference => inputs(b.ordinal) - } - resolveExpression(unbound, LocalRelation(inputs), throws = true) + + validateTopLevelTupleFields(deserializer, inputs) + val resolved = resolveExpression( + deserializer, LocalRelation(inputs), throws = true) + validateNestedTupleFields(resolved) + resolved + } + } + + private def fail(schema: StructType, maxOrdinal: Int): Unit = { + throw new AnalysisException(s"Try to map ${schema.simpleString} to Tuple${maxOrdinal + 1}, " + + "but failed as the number of fields does not line up.") + } + + /** + * For each top-level Tuple field, we use [[GetColumnByOrdinal]] to get its corresponding column + * by position. However, the actual number of columns may be different from the number of Tuple + * fields. This method is used to check the number of columns and fields, and throw an + * exception if they do not match. + */ + private def validateTopLevelTupleFields( + deserializer: Expression, inputs: Seq[Attribute]): Unit = { + val ordinals = deserializer.collect { + case GetColumnByOrdinal(ordinal, _) => ordinal + }.distinct.sorted + + if (ordinals.nonEmpty && ordinals != inputs.indices) { + fail(inputs.toStructType, ordinals.last) + } + } + + /** + * For each nested Tuple field, we use [[GetStructField]] to get its corresponding struct field + * by position. However, the actual number of struct fields may be different from the number + * of nested Tuple fields. This method is used to check the number of struct fields and nested + * Tuple fields, and throw an exception if they do not match. + */ + private def validateNestedTupleFields(deserializer: Expression): Unit = { + val structChildToOrdinals = deserializer + .collect { case g: GetStructField => g } + .groupBy(_.child) + .mapValues(_.map(_.ordinal).distinct.sorted) + + structChildToOrdinals.foreach { case (expr, ordinals) => + val schema = expr.dataType.asInstanceOf[StructType] + if (ordinals != schema.indices) { + fail(schema, ordinals.last) + } } } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/unresolved.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/unresolved.scala index e953eda784..b883546135 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/unresolved.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/unresolved.scala @@ -366,3 +366,10 @@ case class UnresolvedDeserializer(deserializer: Expression, inputAttributes: Seq override def nullable: Boolean = throw new UnresolvedException(this, "nullable") override lazy val resolved = false } + +case class GetColumnByOrdinal(ordinal: Int, dataType: DataType) extends LeafExpression + with Unevaluable with NonSQLExpression { + override def foldable: Boolean = throw new UnresolvedException(this, "foldable") + override def nullable: Boolean = throw new UnresolvedException(this, "nullable") + override lazy val resolved = false +} diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoder.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoder.scala index 2296946cd7..cc59d06fa3 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoder.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoder.scala @@ -17,19 +17,17 @@ package org.apache.spark.sql.catalyst.encoders -import java.util.concurrent.ConcurrentMap - import scala.reflect.ClassTag import scala.reflect.runtime.universe.{typeTag, TypeTag} -import org.apache.spark.sql.{AnalysisException, Encoder} +import org.apache.spark.sql.Encoder import org.apache.spark.sql.catalyst.{InternalRow, JavaTypeInference, ScalaReflection} -import org.apache.spark.sql.catalyst.analysis.{SimpleAnalyzer, UnresolvedAttribute, UnresolvedDeserializer, UnresolvedExtractValue} +import org.apache.spark.sql.catalyst.analysis.{Analyzer, GetColumnByOrdinal, SimpleAnalyzer, UnresolvedAttribute, UnresolvedExtractValue} import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.expressions.codegen.{GenerateSafeProjection, GenerateUnsafeProjection} import org.apache.spark.sql.catalyst.expressions.objects.{Invoke, NewInstance} import org.apache.spark.sql.catalyst.optimizer.SimplifyCasts -import org.apache.spark.sql.catalyst.plans.logical.{LocalRelation, Project} +import org.apache.spark.sql.catalyst.plans.logical.{CatalystSerde, DeserializeToObject, LocalRelation} import org.apache.spark.sql.types.{ObjectType, StructField, StructType} import org.apache.spark.util.Utils @@ -121,15 +119,15 @@ object ExpressionEncoder { val childrenDeserializers = encoders.zipWithIndex.map { case (enc, index) => if (enc.flat) { enc.deserializer.transform { - case b: BoundReference => b.copy(ordinal = index) + case g: GetColumnByOrdinal => g.copy(ordinal = index) } } else { - val input = BoundReference(index, enc.schema, nullable = true) + val input = GetColumnByOrdinal(index, enc.schema) val deserialized = enc.deserializer.transformUp { case UnresolvedAttribute(nameParts) => assert(nameParts.length == 1) UnresolvedExtractValue(input, Literal(nameParts.head)) - case BoundReference(ordinal, dt, _) => GetStructField(input, ordinal) + case GetColumnByOrdinal(ordinal, _) => GetStructField(input, ordinal) } If(IsNull(input), Literal.create(null, deserialized.dataType), deserialized) } @@ -192,6 +190,26 @@ case class ExpressionEncoder[T]( if (flat) require(serializer.size == 1) + /** + * Returns a new copy of this encoder, where the `deserializer` is resolved and bound to the + * given schema. + * + * Note that, ideally encoder is used as a container of serde expressions, the resolution and + * binding stuff should happen inside query framework. However, in some cases we need to + * use encoder as a function to do serialization directly(e.g. Dataset.collect), then we can use + * this method to do resolution and binding outside of query framework. + */ + def resolveAndBind( + attrs: Seq[Attribute] = schema.toAttributes, + analyzer: Analyzer = SimpleAnalyzer): ExpressionEncoder[T] = { + val dummyPlan = CatalystSerde.deserialize(LocalRelation(attrs))(this) + val analyzedPlan = analyzer.execute(dummyPlan) + analyzer.checkAnalysis(analyzedPlan) + val resolved = SimplifyCasts(analyzedPlan).asInstanceOf[DeserializeToObject].deserializer + val bound = BindReferences.bindReference(resolved, attrs) + copy(deserializer = bound) + } + @transient private lazy val extractProjection = GenerateUnsafeProjection.generate(serializer) @@ -202,16 +220,6 @@ case class ExpressionEncoder[T]( private lazy val constructProjection = GenerateSafeProjection.generate(deserializer :: Nil) /** - * Returns this encoder where it has been bound to its own output (i.e. no remaping of columns - * is performed). - */ - def defaultBinding: ExpressionEncoder[T] = { - val attrs = schema.toAttributes - resolve(attrs, OuterScopes.outerScopes).bind(attrs) - } - - - /** * Returns a new set (with unique ids) of [[NamedExpression]] that represent the serialized form * of this object. */ @@ -236,7 +244,7 @@ case class ExpressionEncoder[T]( /** * Returns an object of type `T`, extracting the required values from the provided row. Note that - * you must `resolve` and `bind` an encoder to a specific schema before you can call this + * you must `resolveAndBind` an encoder to a specific schema before you can call this * function. */ def fromRow(row: InternalRow): T = try { @@ -259,94 +267,6 @@ case class ExpressionEncoder[T]( }) } - /** - * Validates `deserializer` to make sure it can be resolved by given schema, and produce - * friendly error messages to explain why it fails to resolve if there is something wrong. - */ - def validate(schema: Seq[Attribute]): Unit = { - def fail(st: StructType, maxOrdinal: Int): Unit = { - throw new AnalysisException(s"Try to map ${st.simpleString} to Tuple${maxOrdinal + 1}, " + - "but failed as the number of fields does not line up.\n" + - " - Input schema: " + StructType.fromAttributes(schema).simpleString + "\n" + - " - Target schema: " + this.schema.simpleString) - } - - // If this is a tuple encoder or tupled encoder, which means its leaf nodes are all - // `BoundReference`, make sure their ordinals are all valid. - var maxOrdinal = -1 - deserializer.foreach { - case b: BoundReference => if (b.ordinal > maxOrdinal) maxOrdinal = b.ordinal - case _ => - } - if (maxOrdinal >= 0 && maxOrdinal != schema.length - 1) { - fail(StructType.fromAttributes(schema), maxOrdinal) - } - - // If we have nested tuple, the `fromRowExpression` will contains `GetStructField` instead of - // `UnresolvedExtractValue`, so we need to check if their ordinals are all valid. - // Note that, `BoundReference` contains the expected type, but here we need the actual type, so - // we unbound it by the given `schema` and propagate the actual type to `GetStructField`, after - // we resolve the `fromRowExpression`. - val resolved = SimpleAnalyzer.resolveExpression( - deserializer, - LocalRelation(schema), - throws = true) - - val unbound = resolved transform { - case b: BoundReference => schema(b.ordinal) - } - - val exprToMaxOrdinal = scala.collection.mutable.HashMap.empty[Expression, Int] - unbound.foreach { - case g: GetStructField => - val maxOrdinal = exprToMaxOrdinal.getOrElse(g.child, -1) - if (maxOrdinal < g.ordinal) { - exprToMaxOrdinal.update(g.child, g.ordinal) - } - case _ => - } - exprToMaxOrdinal.foreach { - case (expr, maxOrdinal) => - val schema = expr.dataType.asInstanceOf[StructType] - if (maxOrdinal != schema.length - 1) { - fail(schema, maxOrdinal) - } - } - } - - /** - * Returns a new copy of this encoder, where the `deserializer` is resolved to the given schema. - */ - def resolve( - schema: Seq[Attribute], - outerScopes: ConcurrentMap[String, AnyRef]): ExpressionEncoder[T] = { - // Make a fake plan to wrap the deserializer, so that we can go though the whole analyzer, check - // analysis, go through optimizer, etc. - val plan = Project( - Alias(UnresolvedDeserializer(deserializer, schema), "")() :: Nil, - LocalRelation(schema)) - val analyzedPlan = SimpleAnalyzer.execute(plan) - SimpleAnalyzer.checkAnalysis(analyzedPlan) - copy(deserializer = SimplifyCasts(analyzedPlan).expressions.head.children.head) - } - - /** - * Returns a copy of this encoder where the `deserializer` has been bound to the - * ordinals of the given schema. Note that you need to first call resolve before bind. - */ - def bind(schema: Seq[Attribute]): ExpressionEncoder[T] = { - copy(deserializer = BindReferences.bindReference(deserializer, schema)) - } - - /** - * Returns a new encoder with input columns shifted by `delta` ordinals - */ - def shift(delta: Int): ExpressionEncoder[T] = { - copy(deserializer = deserializer transform { - case r: BoundReference => r.copy(ordinal = r.ordinal + delta) - }) - } - protected val attrs = serializer.flatMap(_.collect { case _: UnresolvedAttribute => "" case a: Attribute => s"#${a.exprId}" diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/RowEncoder.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/RowEncoder.scala index 0de9166aa2..3c6ae1c5cc 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/RowEncoder.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/RowEncoder.scala @@ -25,6 +25,7 @@ import org.apache.spark.sql.Row import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.util.{ArrayBasedMapData, DateTimeUtils, GenericArrayData} import org.apache.spark.sql.catalyst.ScalaReflection +import org.apache.spark.sql.catalyst.analysis.GetColumnByOrdinal import org.apache.spark.sql.catalyst.expressions.objects._ import org.apache.spark.sql.types._ import org.apache.spark.unsafe.types.UTF8String @@ -210,12 +211,7 @@ object RowEncoder { case p: PythonUserDefinedType => p.sqlType case other => other } - val field = BoundReference(i, dt, f.nullable) - If( - IsNull(field), - Literal.create(null, externalDataTypeFor(dt)), - deserializerFor(field) - ) + deserializerFor(GetColumnByOrdinal(i, dt)) } CreateExternalRow(fields, schema) } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/object.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/object.scala index 98ce5dd2ef..55d8adf040 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/object.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/object.scala @@ -30,26 +30,13 @@ object CatalystSerde { DeserializeToObject(deserializer, generateObjAttr[T], child) } - def deserialize(child: LogicalPlan, encoder: ExpressionEncoder[Row]): DeserializeToObject = { - val deserializer = UnresolvedDeserializer(encoder.deserializer) - DeserializeToObject(deserializer, generateObjAttrForRow(encoder), child) - } - def serialize[T : Encoder](child: LogicalPlan): SerializeFromObject = { SerializeFromObject(encoderFor[T].namedExpressions, child) } - def serialize(child: LogicalPlan, encoder: ExpressionEncoder[Row]): SerializeFromObject = { - SerializeFromObject(encoder.namedExpressions, child) - } - def generateObjAttr[T : Encoder]: Attribute = { AttributeReference("obj", encoderFor[T].deserializer.dataType, nullable = false)() } - - def generateObjAttrForRow(encoder: ExpressionEncoder[Row]): Attribute = { - AttributeReference("obj", encoder.deserializer.dataType, nullable = false)() - } } /** @@ -128,16 +115,16 @@ object MapPartitionsInR { schema: StructType, encoder: ExpressionEncoder[Row], child: LogicalPlan): LogicalPlan = { - val deserialized = CatalystSerde.deserialize(child, encoder) + val deserialized = CatalystSerde.deserialize(child)(encoder) val mapped = MapPartitionsInR( func, packageNames, broadcastVars, encoder.schema, schema, - CatalystSerde.generateObjAttrForRow(RowEncoder(schema)), + CatalystSerde.generateObjAttr(RowEncoder(schema)), deserialized) - CatalystSerde.serialize(mapped, RowEncoder(schema)) + CatalystSerde.serialize(mapped)(RowEncoder(schema)) } } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/EncoderResolutionSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/EncoderResolutionSuite.scala index 3ad0dae767..7251202c7b 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/EncoderResolutionSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/EncoderResolutionSuite.scala @@ -41,17 +41,17 @@ class EncoderResolutionSuite extends PlanTest { // int type can be up cast to long type val attrs1 = Seq('a.string, 'b.int) - encoder.resolve(attrs1, null).bind(attrs1).fromRow(InternalRow(str, 1)) + encoder.resolveAndBind(attrs1).fromRow(InternalRow(str, 1)) // int type can be up cast to string type val attrs2 = Seq('a.int, 'b.long) - encoder.resolve(attrs2, null).bind(attrs2).fromRow(InternalRow(1, 2L)) + encoder.resolveAndBind(attrs2).fromRow(InternalRow(1, 2L)) } test("real type doesn't match encoder schema but they are compatible: nested product") { val encoder = ExpressionEncoder[ComplexClass] val attrs = Seq('a.int, 'b.struct('a.int, 'b.long)) - encoder.resolve(attrs, null).bind(attrs).fromRow(InternalRow(1, InternalRow(2, 3L))) + encoder.resolveAndBind(attrs).fromRow(InternalRow(1, InternalRow(2, 3L))) } test("real type doesn't match encoder schema but they are compatible: tupled encoder") { @@ -59,7 +59,7 @@ class EncoderResolutionSuite extends PlanTest { ExpressionEncoder[StringLongClass], ExpressionEncoder[Long]) val attrs = Seq('a.struct('a.string, 'b.byte), 'b.int) - encoder.resolve(attrs, null).bind(attrs).fromRow(InternalRow(InternalRow(str, 1.toByte), 2)) + encoder.resolveAndBind(attrs).fromRow(InternalRow(InternalRow(str, 1.toByte), 2)) } test("nullability of array type element should not fail analysis") { @@ -67,7 +67,7 @@ class EncoderResolutionSuite extends PlanTest { val attrs = 'a.array(IntegerType) :: Nil // It should pass analysis - val bound = encoder.resolve(attrs, null).bind(attrs) + val bound = encoder.resolveAndBind(attrs) // If no null values appear, it should works fine bound.fromRow(InternalRow(new GenericArrayData(Array(1, 2)))) @@ -84,20 +84,16 @@ class EncoderResolutionSuite extends PlanTest { { val attrs = Seq('a.string, 'b.long, 'c.int) - assert(intercept[AnalysisException](encoder.validate(attrs)).message == + assert(intercept[AnalysisException](encoder.resolveAndBind(attrs)).message == "Try to map struct<a:string,b:bigint,c:int> to Tuple2, " + - "but failed as the number of fields does not line up.\n" + - " - Input schema: struct<a:string,b:bigint,c:int>\n" + - " - Target schema: struct<_1:string,_2:bigint>") + "but failed as the number of fields does not line up.") } { val attrs = Seq('a.string) - assert(intercept[AnalysisException](encoder.validate(attrs)).message == + assert(intercept[AnalysisException](encoder.resolveAndBind(attrs)).message == "Try to map struct<a:string> to Tuple2, " + - "but failed as the number of fields does not line up.\n" + - " - Input schema: struct<a:string>\n" + - " - Target schema: struct<_1:string,_2:bigint>") + "but failed as the number of fields does not line up.") } } @@ -106,26 +102,22 @@ class EncoderResolutionSuite extends PlanTest { { val attrs = Seq('a.string, 'b.struct('x.long, 'y.string, 'z.int)) - assert(intercept[AnalysisException](encoder.validate(attrs)).message == + assert(intercept[AnalysisException](encoder.resolveAndBind(attrs)).message == "Try to map struct<x:bigint,y:string,z:int> to Tuple2, " + - "but failed as the number of fields does not line up.\n" + - " - Input schema: struct<a:string,b:struct<x:bigint,y:string,z:int>>\n" + - " - Target schema: struct<_1:string,_2:struct<_1:bigint,_2:string>>") + "but failed as the number of fields does not line up.") } { val attrs = Seq('a.string, 'b.struct('x.long)) - assert(intercept[AnalysisException](encoder.validate(attrs)).message == + assert(intercept[AnalysisException](encoder.resolveAndBind(attrs)).message == "Try to map struct<x:bigint> to Tuple2, " + - "but failed as the number of fields does not line up.\n" + - " - Input schema: struct<a:string,b:struct<x:bigint>>\n" + - " - Target schema: struct<_1:string,_2:struct<_1:bigint,_2:string>>") + "but failed as the number of fields does not line up.") } } test("throw exception if real type is not compatible with encoder schema") { val msg1 = intercept[AnalysisException] { - ExpressionEncoder[StringIntClass].resolve(Seq('a.string, 'b.long), null) + ExpressionEncoder[StringIntClass].resolveAndBind(Seq('a.string, 'b.long)) }.message assert(msg1 == s""" @@ -138,7 +130,7 @@ class EncoderResolutionSuite extends PlanTest { val msg2 = intercept[AnalysisException] { val structType = new StructType().add("a", StringType).add("b", DecimalType.SYSTEM_DEFAULT) - ExpressionEncoder[ComplexClass].resolve(Seq('a.long, 'b.struct(structType)), null) + ExpressionEncoder[ComplexClass].resolveAndBind(Seq('a.long, 'b.struct(structType))) }.message assert(msg2 == s""" @@ -171,7 +163,7 @@ class EncoderResolutionSuite extends PlanTest { val to = ExpressionEncoder[U] val catalystType = from.schema.head.dataType.simpleString test(s"cast from $catalystType to ${implicitly[TypeTag[U]].tpe} should success") { - to.resolve(from.schema.toAttributes, null) + to.resolveAndBind(from.schema.toAttributes) } } @@ -180,7 +172,7 @@ class EncoderResolutionSuite extends PlanTest { val to = ExpressionEncoder[U] val catalystType = from.schema.head.dataType.simpleString test(s"cast from $catalystType to ${implicitly[TypeTag[U]].tpe} should fail") { - intercept[AnalysisException](to.resolve(from.schema.toAttributes, null)) + intercept[AnalysisException](to.resolveAndBind(from.schema.toAttributes)) } } } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoderSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoderSuite.scala index 232dcc9ee5..a1f9259f13 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoderSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoderSuite.scala @@ -27,6 +27,7 @@ import scala.reflect.runtime.universe.TypeTag import org.apache.spark.sql.Encoders import org.apache.spark.sql.catalyst.{OptionalData, PrimitiveData} import org.apache.spark.sql.catalyst.analysis.AnalysisTest +import org.apache.spark.sql.catalyst.dsl.plans._ import org.apache.spark.sql.catalyst.expressions.{Alias, AttributeReference} import org.apache.spark.sql.catalyst.plans.PlanTest import org.apache.spark.sql.catalyst.plans.logical.{LocalRelation, Project} @@ -334,7 +335,7 @@ class ExpressionEncoderSuite extends PlanTest with AnalysisTest { val encoder = implicitly[ExpressionEncoder[T]] val row = encoder.toRow(input) val schema = encoder.schema.toAttributes - val boundEncoder = encoder.defaultBinding + val boundEncoder = encoder.resolveAndBind() val convertedBack = try boundEncoder.fromRow(row) catch { case e: Exception => fail( @@ -350,12 +351,8 @@ class ExpressionEncoderSuite extends PlanTest with AnalysisTest { } // Test the correct resolution of serialization / deserialization. - val attr = AttributeReference("obj", ObjectType(encoder.clsTag.runtimeClass))() - val inputPlan = LocalRelation(attr) - val plan = - Project(Alias(encoder.deserializer, "obj")() :: Nil, - Project(encoder.namedExpressions, - inputPlan)) + val attr = AttributeReference("obj", encoder.deserializer.dataType)() + val plan = LocalRelation(attr).serialize[T].deserialize[T] assertAnalysisSuccess(plan) val isCorrect = (input, convertedBack) match { diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/RowEncoderSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/RowEncoderSuite.scala index 39fcc7225b..6f1bc80c1c 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/RowEncoderSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/RowEncoderSuite.scala @@ -135,7 +135,7 @@ class RowEncoderSuite extends SparkFunSuite { .add("string", StringType) .add("double", DoubleType)) - val encoder = RowEncoder(schema) + val encoder = RowEncoder(schema).resolveAndBind() val input: Row = Row((100, "test", 0.123)) val row = encoder.toRow(input) @@ -152,7 +152,7 @@ class RowEncoderSuite extends SparkFunSuite { .add("scala_decimal", DecimalType.SYSTEM_DEFAULT) .add("catalyst_decimal", DecimalType.SYSTEM_DEFAULT) - val encoder = RowEncoder(schema) + val encoder = RowEncoder(schema).resolveAndBind() val javaDecimal = new java.math.BigDecimal("1234.5678") val scalaDecimal = BigDecimal("1234.5678") @@ -169,7 +169,7 @@ class RowEncoderSuite extends SparkFunSuite { test("RowEncoder should preserve decimal precision and scale") { val schema = new StructType().add("decimal", DecimalType(10, 5), false) - val encoder = RowEncoder(schema) + val encoder = RowEncoder(schema).resolveAndBind() val decimal = Decimal("67123.45") val input = Row(decimal) val row = encoder.toRow(input) @@ -179,7 +179,7 @@ class RowEncoderSuite extends SparkFunSuite { test("RowEncoder should preserve schema nullability") { val schema = new StructType().add("int", IntegerType, nullable = false) - val encoder = RowEncoder(schema) + val encoder = RowEncoder(schema).resolveAndBind() assert(encoder.serializer.length == 1) assert(encoder.serializer.head.dataType == IntegerType) assert(encoder.serializer.head.nullable == false) @@ -195,7 +195,7 @@ class RowEncoderSuite extends SparkFunSuite { new StructType().add("int", IntegerType, nullable = false), nullable = false), nullable = false) - val encoder = RowEncoder(schema) + val encoder = RowEncoder(schema).resolveAndBind() assert(encoder.serializer.length == 1) assert(encoder.serializer.head.dataType == new StructType() @@ -212,7 +212,7 @@ class RowEncoderSuite extends SparkFunSuite { .add("array", ArrayType(IntegerType)) .add("nestedArray", ArrayType(ArrayType(StringType))) .add("deepNestedArray", ArrayType(ArrayType(ArrayType(LongType)))) - val encoder = RowEncoder(schema) + val encoder = RowEncoder(schema).resolveAndBind() val input = Row( Array(1, 2, null), Array(Array("abc", null), null), @@ -226,7 +226,7 @@ class RowEncoderSuite extends SparkFunSuite { private def encodeDecodeTest(schema: StructType): Unit = { test(s"encode/decode: ${schema.simpleString}") { - val encoder = RowEncoder(schema) + val encoder = RowEncoder(schema).resolveAndBind() val inputGenerator = RandomDataGenerator.forType(schema, nullable = false).get var input: Row = null |