author     Wenchen Fan <wenchen@databricks.com>    2016-06-03 00:43:02 -0700
committer  Cheng Lian <lian@databricks.com>        2016-06-03 00:43:02 -0700
commit     190ff274fd71662023a804cf98400c71f9f7da4f (patch)
tree       9b3f79aebf252d3c27f53d9593000c5fd58e1509
parent     b9fcfb3bd14592ac9f1a8e5c2bb31412b9603b60 (diff)
[SPARK-15494][SQL] encoder code cleanup
## What changes were proposed in this pull request?

Our encoder framework has evolved a lot. This PR cleans up the code to make it more readable and to emphasise that an encoder should be used as a container of serde expressions:

1. Move the validation logic into the analyzer instead of the encoder.
2. Provide a single `resolveAndBind` method on the encoder instead of separate `resolve` and `bind` methods, as the encoder life-cycle concept no longer exists.
3. `Dataset` no longer needs to keep a resolved encoder, as that concept no longer exists. A bound encoder is still needed to do serialization outside of the query framework.
4. Using `BoundReference` to represent an unresolved field in a deserializer expression is awkward, so this PR adds `GetColumnByOrdinal` for that purpose. (Serializer expressions still use `BoundReference`; we can replace it with `GetColumnByOrdinal` in follow-ups.)

## How was this patch tested?

Existing tests.

Author: Wenchen Fan <wenchen@databricks.com>
Author: Cheng Lian <lian@databricks.com>

Closes #13269 from cloud-fan/clean-encoder.
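As a minimal sketch of the new workflow (an illustration only, assuming the Spark 2.0-era catalyst APIs touched in this diff): an encoder stays an unresolved container of serde expressions, and the single `resolveAndBind` step replaces the old `resolve`/`bind` (and `defaultBinding`) pair, needed only when serialization happens outside the query framework such as `Dataset.collect`.

```scala
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder

// Unresolved encoder: just a container of serializer/deserializer expressions.
val enc = ExpressionEncoder[(String, Int)]()

// Resolve and bind against the encoder's own schema in one step.
val bound = enc.resolveAndBind()

// Only a resolved-and-bound encoder can be used as a serde function directly.
val row: InternalRow = bound.toRow(("a", 1))
val back: (String, Int) = bound.fromRow(row)  // ("a", 1)
```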
-rw-r--r--  mllib/src/test/scala/org/apache/spark/mllib/linalg/UDTSerializationBenchmark.scala | 2
-rw-r--r--  sql/catalyst/src/main/scala/org/apache/spark/sql/Encoders.scala | 3
-rw-r--r--  sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/JavaTypeInference.scala | 6
-rw-r--r--  sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala | 307
-rw-r--r--  sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala | 53
-rw-r--r--  sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/unresolved.scala | 7
-rw-r--r--  sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoder.scala | 134
-rw-r--r--  sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/RowEncoder.scala | 8
-rw-r--r--  sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/object.scala | 19
-rw-r--r--  sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/EncoderResolutionSuite.scala | 42
-rw-r--r--  sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoderSuite.scala | 11
-rw-r--r--  sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/RowEncoderSuite.scala | 14
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala | 48
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/KeyValueGroupedDataset.scala | 26
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/RelationalGroupedDataset.scala | 2
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TypedAggregateExpression.scala | 6
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/functions.scala | 2
-rw-r--r--  sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala | 8
-rw-r--r--  sql/core/src/test/scala/org/apache/spark/sql/QueryTest.scala | 8
-rw-r--r--  sql/core/src/test/scala/org/apache/spark/sql/execution/GroupedIteratorSuite.scala | 6
-rw-r--r--  sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamTest.scala | 4
21 files changed, 324 insertions(+), 392 deletions(-)
diff --git a/mllib/src/test/scala/org/apache/spark/mllib/linalg/UDTSerializationBenchmark.scala b/mllib/src/test/scala/org/apache/spark/mllib/linalg/UDTSerializationBenchmark.scala
index be7110ad6b..8b439e6b7a 100644
--- a/mllib/src/test/scala/org/apache/spark/mllib/linalg/UDTSerializationBenchmark.scala
+++ b/mllib/src/test/scala/org/apache/spark/mllib/linalg/UDTSerializationBenchmark.scala
@@ -29,7 +29,7 @@ object UDTSerializationBenchmark {
val iters = 1e2.toInt
val numRows = 1e3.toInt
- val encoder = ExpressionEncoder[Vector].defaultBinding
+ val encoder = ExpressionEncoder[Vector].resolveAndBind()
val vectors = (1 to numRows).map { i =>
Vectors.dense(Array.fill(1e5.toInt)(1.0 * i))
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/Encoders.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/Encoders.scala
index fa96f8223d..673c587b18 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/Encoders.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/Encoders.scala
@@ -23,6 +23,7 @@ import scala.reflect.{classTag, ClassTag}
import scala.reflect.runtime.universe.TypeTag
import org.apache.spark.annotation.Experimental
+import org.apache.spark.sql.catalyst.analysis.GetColumnByOrdinal
import org.apache.spark.sql.catalyst.encoders.{encoderFor, ExpressionEncoder}
import org.apache.spark.sql.catalyst.expressions.objects.{DecodeUsingSerializer, EncodeUsingSerializer}
import org.apache.spark.sql.catalyst.expressions.BoundReference
@@ -208,7 +209,7 @@ object Encoders {
BoundReference(0, ObjectType(classOf[AnyRef]), nullable = true), kryo = useKryo)),
deserializer =
DecodeUsingSerializer[T](
- BoundReference(0, BinaryType, nullable = true), classTag[T], kryo = useKryo),
+ GetColumnByOrdinal(0, BinaryType), classTag[T], kryo = useKryo),
clsTag = classTag[T]
)
}
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/JavaTypeInference.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/JavaTypeInference.scala
index 1fe143494a..b3a233ae39 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/JavaTypeInference.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/JavaTypeInference.scala
@@ -25,7 +25,7 @@ import scala.language.existentials
import com.google.common.reflect.TypeToken
-import org.apache.spark.sql.catalyst.analysis.{UnresolvedAttribute, UnresolvedExtractValue}
+import org.apache.spark.sql.catalyst.analysis.{GetColumnByOrdinal, UnresolvedAttribute, UnresolvedExtractValue}
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.expressions.objects._
import org.apache.spark.sql.catalyst.util.{ArrayBasedMapData, DateTimeUtils, GenericArrayData}
@@ -177,8 +177,8 @@ object JavaTypeInference {
.map(p => UnresolvedExtractValue(p, expressions.Literal(part)))
.getOrElse(UnresolvedAttribute(part))
- /** Returns the current path or `BoundReference`. */
- def getPath: Expression = path.getOrElse(BoundReference(0, inferDataType(typeToken)._1, true))
+ /** Returns the current path or `GetColumnByOrdinal`. */
+ def getPath: Expression = path.getOrElse(GetColumnByOrdinal(0, inferDataType(typeToken)._1))
typeToken.getRawType match {
case c if !inferExternalType(c).isInstanceOf[ObjectType] => getPath
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala
index 4750861817..78c145d4fd 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala
@@ -17,7 +17,7 @@
package org.apache.spark.sql.catalyst
-import org.apache.spark.sql.catalyst.analysis.{UnresolvedAttribute, UnresolvedExtractValue}
+import org.apache.spark.sql.catalyst.analysis.{GetColumnByOrdinal, UnresolvedAttribute, UnresolvedExtractValue}
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.expressions.objects._
import org.apache.spark.sql.catalyst.util.{ArrayBasedMapData, DateTimeUtils, GenericArrayData}
@@ -156,17 +156,17 @@ object ScalaReflection extends ScalaReflection {
walkedTypePath: Seq[String]): Expression = {
val newPath = path
.map(p => GetStructField(p, ordinal))
- .getOrElse(BoundReference(ordinal, dataType, false))
+ .getOrElse(GetColumnByOrdinal(ordinal, dataType))
upCastToExpectedType(newPath, dataType, walkedTypePath)
}
- /** Returns the current path or `BoundReference`. */
+ /** Returns the current path or `GetColumnByOrdinal`. */
def getPath: Expression = {
val dataType = schemaFor(tpe).dataType
if (path.isDefined) {
path.get
} else {
- upCastToExpectedType(BoundReference(0, dataType, true), dataType, walkedTypePath)
+ upCastToExpectedType(GetColumnByOrdinal(0, dataType), dataType, walkedTypePath)
}
}
@@ -421,7 +421,7 @@ object ScalaReflection extends ScalaReflection {
def serializerFor[T : TypeTag](inputObject: Expression): CreateNamedStruct = {
val tpe = localTypeOf[T]
val clsName = getClassNameFromType(tpe)
- val walkedTypePath = s"""- root class: "${clsName}"""" :: Nil
+ val walkedTypePath = s"""- root class: "$clsName"""" :: Nil
serializerFor(inputObject, tpe, walkedTypePath) match {
case expressions.If(_, _, s: CreateNamedStruct) if definedByConstructorParams(tpe) => s
case other => CreateNamedStruct(expressions.Literal("value") :: other :: Nil)
@@ -449,157 +449,156 @@ object ScalaReflection extends ScalaReflection {
}
}
- if (!inputObject.dataType.isInstanceOf[ObjectType]) {
- inputObject
- } else {
- tpe match {
- case t if t <:< localTypeOf[Option[_]] =>
- val TypeRef(_, _, Seq(optType)) = t
- val className = getClassNameFromType(optType)
- val newPath = s"""- option value class: "$className"""" +: walkedTypePath
- val unwrapped = UnwrapOption(dataTypeFor(optType), inputObject)
- serializerFor(unwrapped, optType, newPath)
-
- // Since List[_] also belongs to localTypeOf[Product], we put this case before
- // "case t if definedByConstructorParams(t)" to make sure it will match to the
- // case "localTypeOf[Seq[_]]"
- case t if t <:< localTypeOf[Seq[_]] =>
- val TypeRef(_, _, Seq(elementType)) = t
- toCatalystArray(inputObject, elementType)
-
- case t if t <:< localTypeOf[Array[_]] =>
- val TypeRef(_, _, Seq(elementType)) = t
- toCatalystArray(inputObject, elementType)
-
- case t if t <:< localTypeOf[Map[_, _]] =>
- val TypeRef(_, _, Seq(keyType, valueType)) = t
-
- val keys =
- Invoke(
- Invoke(inputObject, "keysIterator",
- ObjectType(classOf[scala.collection.Iterator[_]])),
- "toSeq",
- ObjectType(classOf[scala.collection.Seq[_]]))
- val convertedKeys = toCatalystArray(keys, keyType)
-
- val values =
- Invoke(
- Invoke(inputObject, "valuesIterator",
- ObjectType(classOf[scala.collection.Iterator[_]])),
- "toSeq",
- ObjectType(classOf[scala.collection.Seq[_]]))
- val convertedValues = toCatalystArray(values, valueType)
-
- val Schema(keyDataType, _) = schemaFor(keyType)
- val Schema(valueDataType, valueNullable) = schemaFor(valueType)
- NewInstance(
- classOf[ArrayBasedMapData],
- convertedKeys :: convertedValues :: Nil,
- dataType = MapType(keyDataType, valueDataType, valueNullable))
-
- case t if t <:< localTypeOf[String] =>
- StaticInvoke(
- classOf[UTF8String],
- StringType,
- "fromString",
- inputObject :: Nil)
-
- case t if t <:< localTypeOf[java.sql.Timestamp] =>
- StaticInvoke(
- DateTimeUtils.getClass,
- TimestampType,
- "fromJavaTimestamp",
- inputObject :: Nil)
-
- case t if t <:< localTypeOf[java.sql.Date] =>
- StaticInvoke(
- DateTimeUtils.getClass,
- DateType,
- "fromJavaDate",
- inputObject :: Nil)
-
- case t if t <:< localTypeOf[BigDecimal] =>
- StaticInvoke(
- Decimal.getClass,
- DecimalType.SYSTEM_DEFAULT,
- "apply",
- inputObject :: Nil)
-
- case t if t <:< localTypeOf[java.math.BigDecimal] =>
- StaticInvoke(
- Decimal.getClass,
- DecimalType.SYSTEM_DEFAULT,
- "apply",
- inputObject :: Nil)
-
- case t if t <:< localTypeOf[java.math.BigInteger] =>
- StaticInvoke(
- Decimal.getClass,
- DecimalType.BigIntDecimal,
- "apply",
- inputObject :: Nil)
-
- case t if t <:< localTypeOf[scala.math.BigInt] =>
- StaticInvoke(
- Decimal.getClass,
- DecimalType.BigIntDecimal,
- "apply",
- inputObject :: Nil)
-
- case t if t <:< localTypeOf[java.lang.Integer] =>
- Invoke(inputObject, "intValue", IntegerType)
- case t if t <:< localTypeOf[java.lang.Long] =>
- Invoke(inputObject, "longValue", LongType)
- case t if t <:< localTypeOf[java.lang.Double] =>
- Invoke(inputObject, "doubleValue", DoubleType)
- case t if t <:< localTypeOf[java.lang.Float] =>
- Invoke(inputObject, "floatValue", FloatType)
- case t if t <:< localTypeOf[java.lang.Short] =>
- Invoke(inputObject, "shortValue", ShortType)
- case t if t <:< localTypeOf[java.lang.Byte] =>
- Invoke(inputObject, "byteValue", ByteType)
- case t if t <:< localTypeOf[java.lang.Boolean] =>
- Invoke(inputObject, "booleanValue", BooleanType)
-
- case t if t.typeSymbol.annotations.exists(_.tpe =:= typeOf[SQLUserDefinedType]) =>
- val udt = getClassFromType(t)
- .getAnnotation(classOf[SQLUserDefinedType]).udt().newInstance()
- val obj = NewInstance(
- udt.userClass.getAnnotation(classOf[SQLUserDefinedType]).udt(),
- Nil,
- dataType = ObjectType(udt.userClass.getAnnotation(classOf[SQLUserDefinedType]).udt()))
- Invoke(obj, "serialize", udt, inputObject :: Nil)
-
- case t if UDTRegistration.exists(getClassNameFromType(t)) =>
- val udt = UDTRegistration.getUDTFor(getClassNameFromType(t)).get.newInstance()
- .asInstanceOf[UserDefinedType[_]]
- val obj = NewInstance(
- udt.getClass,
- Nil,
- dataType = ObjectType(udt.getClass))
- Invoke(obj, "serialize", udt, inputObject :: Nil)
-
- case t if definedByConstructorParams(t) =>
- val params = getConstructorParameters(t)
- val nonNullOutput = CreateNamedStruct(params.flatMap { case (fieldName, fieldType) =>
- if (javaKeywords.contains(fieldName)) {
- throw new UnsupportedOperationException(s"`$fieldName` is a reserved keyword and " +
- "cannot be used as field name\n" + walkedTypePath.mkString("\n"))
- }
+ tpe match {
+ case _ if !inputObject.dataType.isInstanceOf[ObjectType] => inputObject
- val fieldValue = Invoke(inputObject, fieldName, dataTypeFor(fieldType))
- val clsName = getClassNameFromType(fieldType)
- val newPath = s"""- field (class: "$clsName", name: "$fieldName")""" +: walkedTypePath
- expressions.Literal(fieldName) :: serializerFor(fieldValue, fieldType, newPath) :: Nil
- })
- val nullOutput = expressions.Literal.create(null, nonNullOutput.dataType)
- expressions.If(IsNull(inputObject), nullOutput, nonNullOutput)
-
- case other =>
- throw new UnsupportedOperationException(
- s"No Encoder found for $tpe\n" + walkedTypePath.mkString("\n"))
- }
+ case t if t <:< localTypeOf[Option[_]] =>
+ val TypeRef(_, _, Seq(optType)) = t
+ val className = getClassNameFromType(optType)
+ val newPath = s"""- option value class: "$className"""" +: walkedTypePath
+ val unwrapped = UnwrapOption(dataTypeFor(optType), inputObject)
+ serializerFor(unwrapped, optType, newPath)
+
+ // Since List[_] also belongs to localTypeOf[Product], we put this case before
+ // "case t if definedByConstructorParams(t)" to make sure it will match to the
+ // case "localTypeOf[Seq[_]]"
+ case t if t <:< localTypeOf[Seq[_]] =>
+ val TypeRef(_, _, Seq(elementType)) = t
+ toCatalystArray(inputObject, elementType)
+
+ case t if t <:< localTypeOf[Array[_]] =>
+ val TypeRef(_, _, Seq(elementType)) = t
+ toCatalystArray(inputObject, elementType)
+
+ case t if t <:< localTypeOf[Map[_, _]] =>
+ val TypeRef(_, _, Seq(keyType, valueType)) = t
+
+ val keys =
+ Invoke(
+ Invoke(inputObject, "keysIterator",
+ ObjectType(classOf[scala.collection.Iterator[_]])),
+ "toSeq",
+ ObjectType(classOf[scala.collection.Seq[_]]))
+ val convertedKeys = toCatalystArray(keys, keyType)
+
+ val values =
+ Invoke(
+ Invoke(inputObject, "valuesIterator",
+ ObjectType(classOf[scala.collection.Iterator[_]])),
+ "toSeq",
+ ObjectType(classOf[scala.collection.Seq[_]]))
+ val convertedValues = toCatalystArray(values, valueType)
+
+ val Schema(keyDataType, _) = schemaFor(keyType)
+ val Schema(valueDataType, valueNullable) = schemaFor(valueType)
+ NewInstance(
+ classOf[ArrayBasedMapData],
+ convertedKeys :: convertedValues :: Nil,
+ dataType = MapType(keyDataType, valueDataType, valueNullable))
+
+ case t if t <:< localTypeOf[String] =>
+ StaticInvoke(
+ classOf[UTF8String],
+ StringType,
+ "fromString",
+ inputObject :: Nil)
+
+ case t if t <:< localTypeOf[java.sql.Timestamp] =>
+ StaticInvoke(
+ DateTimeUtils.getClass,
+ TimestampType,
+ "fromJavaTimestamp",
+ inputObject :: Nil)
+
+ case t if t <:< localTypeOf[java.sql.Date] =>
+ StaticInvoke(
+ DateTimeUtils.getClass,
+ DateType,
+ "fromJavaDate",
+ inputObject :: Nil)
+
+ case t if t <:< localTypeOf[BigDecimal] =>
+ StaticInvoke(
+ Decimal.getClass,
+ DecimalType.SYSTEM_DEFAULT,
+ "apply",
+ inputObject :: Nil)
+
+ case t if t <:< localTypeOf[java.math.BigDecimal] =>
+ StaticInvoke(
+ Decimal.getClass,
+ DecimalType.SYSTEM_DEFAULT,
+ "apply",
+ inputObject :: Nil)
+
+ case t if t <:< localTypeOf[java.math.BigInteger] =>
+ StaticInvoke(
+ Decimal.getClass,
+ DecimalType.BigIntDecimal,
+ "apply",
+ inputObject :: Nil)
+
+ case t if t <:< localTypeOf[scala.math.BigInt] =>
+ StaticInvoke(
+ Decimal.getClass,
+ DecimalType.BigIntDecimal,
+ "apply",
+ inputObject :: Nil)
+
+ case t if t <:< localTypeOf[java.lang.Integer] =>
+ Invoke(inputObject, "intValue", IntegerType)
+ case t if t <:< localTypeOf[java.lang.Long] =>
+ Invoke(inputObject, "longValue", LongType)
+ case t if t <:< localTypeOf[java.lang.Double] =>
+ Invoke(inputObject, "doubleValue", DoubleType)
+ case t if t <:< localTypeOf[java.lang.Float] =>
+ Invoke(inputObject, "floatValue", FloatType)
+ case t if t <:< localTypeOf[java.lang.Short] =>
+ Invoke(inputObject, "shortValue", ShortType)
+ case t if t <:< localTypeOf[java.lang.Byte] =>
+ Invoke(inputObject, "byteValue", ByteType)
+ case t if t <:< localTypeOf[java.lang.Boolean] =>
+ Invoke(inputObject, "booleanValue", BooleanType)
+
+ case t if t.typeSymbol.annotations.exists(_.tpe =:= typeOf[SQLUserDefinedType]) =>
+ val udt = getClassFromType(t)
+ .getAnnotation(classOf[SQLUserDefinedType]).udt().newInstance()
+ val obj = NewInstance(
+ udt.userClass.getAnnotation(classOf[SQLUserDefinedType]).udt(),
+ Nil,
+ dataType = ObjectType(udt.userClass.getAnnotation(classOf[SQLUserDefinedType]).udt()))
+ Invoke(obj, "serialize", udt, inputObject :: Nil)
+
+ case t if UDTRegistration.exists(getClassNameFromType(t)) =>
+ val udt = UDTRegistration.getUDTFor(getClassNameFromType(t)).get.newInstance()
+ .asInstanceOf[UserDefinedType[_]]
+ val obj = NewInstance(
+ udt.getClass,
+ Nil,
+ dataType = ObjectType(udt.getClass))
+ Invoke(obj, "serialize", udt, inputObject :: Nil)
+
+ case t if definedByConstructorParams(t) =>
+ val params = getConstructorParameters(t)
+ val nonNullOutput = CreateNamedStruct(params.flatMap { case (fieldName, fieldType) =>
+ if (javaKeywords.contains(fieldName)) {
+ throw new UnsupportedOperationException(s"`$fieldName` is a reserved keyword and " +
+ "cannot be used as field name\n" + walkedTypePath.mkString("\n"))
+ }
+
+ val fieldValue = Invoke(inputObject, fieldName, dataTypeFor(fieldType))
+ val clsName = getClassNameFromType(fieldType)
+ val newPath = s"""- field (class: "$clsName", name: "$fieldName")""" +: walkedTypePath
+ expressions.Literal(fieldName) :: serializerFor(fieldValue, fieldType, newPath) :: Nil
+ })
+ val nullOutput = expressions.Literal.create(null, nonNullOutput.dataType)
+ expressions.If(IsNull(inputObject), nullOutput, nonNullOutput)
+
+ case other =>
+ throw new UnsupportedOperationException(
+ s"No Encoder found for $tpe\n" + walkedTypePath.mkString("\n"))
}
+
}
/**
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
index 02966796af..4f6b4830cd 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
@@ -722,6 +722,7 @@ class Analyzer(
// Else, throw exception.
try {
expr transformUp {
+ case GetColumnByOrdinal(ordinal, _) => plan.output(ordinal)
case u @ UnresolvedAttribute(nameParts) =>
withPosition(u) { plan.resolve(nameParts, resolver).getOrElse(u) }
case UnresolvedExtractValue(child, fieldName) if child.resolved =>
@@ -1924,10 +1925,54 @@ class Analyzer(
} else {
inputAttributes
}
- val unbound = deserializer transform {
- case b: BoundReference => inputs(b.ordinal)
- }
- resolveExpression(unbound, LocalRelation(inputs), throws = true)
+
+ validateTopLevelTupleFields(deserializer, inputs)
+ val resolved = resolveExpression(
+ deserializer, LocalRelation(inputs), throws = true)
+ validateNestedTupleFields(resolved)
+ resolved
+ }
+ }
+
+ private def fail(schema: StructType, maxOrdinal: Int): Unit = {
+ throw new AnalysisException(s"Try to map ${schema.simpleString} to Tuple${maxOrdinal + 1}, " +
+ "but failed as the number of fields does not line up.")
+ }
+
+ /**
+ * For each top-level Tuple field, we use [[GetColumnByOrdinal]] to get its corresponding column
+ * by position. However, the actual number of columns may be different from the number of Tuple
+ * fields. This method is used to check the number of columns and fields, and throw an
+ * exception if they do not match.
+ */
+ private def validateTopLevelTupleFields(
+ deserializer: Expression, inputs: Seq[Attribute]): Unit = {
+ val ordinals = deserializer.collect {
+ case GetColumnByOrdinal(ordinal, _) => ordinal
+ }.distinct.sorted
+
+ if (ordinals.nonEmpty && ordinals != inputs.indices) {
+ fail(inputs.toStructType, ordinals.last)
+ }
+ }
+
+ /**
+ * For each nested Tuple field, we use [[GetStructField]] to get its corresponding struct field
+ * by position. However, the actual number of struct fields may be different from the number
+ * of nested Tuple fields. This method is used to check the number of struct fields and nested
+ * Tuple fields, and throw an exception if they do not match.
+ */
+ private def validateNestedTupleFields(deserializer: Expression): Unit = {
+ val structChildToOrdinals = deserializer
+ .collect { case g: GetStructField => g }
+ .groupBy(_.child)
+ .mapValues(_.map(_.ordinal).distinct.sorted)
+
+ structChildToOrdinals.foreach { case (expr, ordinals) =>
+ val schema = expr.dataType.asInstanceOf[StructType]
+ if (ordinals != schema.indices) {
+ fail(schema, ordinals.last)
+ }
}
}
}
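A hedged illustration of what the validation moved into the analyzer now catches (the triggering query is hypothetical and assumes a `SparkSession` named `spark` with `import spark.implicits._` in scope; the error text follows the `fail` helper added above):

```scala
// Hypothetical reproduction: a three-column relation mapped onto a Tuple2 encoder.
val df = spark.range(1).selectExpr("'a' AS a", "1L AS b", "1 AS c")
df.as[(String, Long)]
// org.apache.spark.sql.AnalysisException: Try to map struct<a:string,b:bigint,c:int>
// to Tuple2, but failed as the number of fields does not line up.
```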
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/unresolved.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/unresolved.scala
index e953eda784..b883546135 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/unresolved.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/unresolved.scala
@@ -366,3 +366,10 @@ case class UnresolvedDeserializer(deserializer: Expression, inputAttributes: Seq
override def nullable: Boolean = throw new UnresolvedException(this, "nullable")
override lazy val resolved = false
}
+
+case class GetColumnByOrdinal(ordinal: Int, dataType: DataType) extends LeafExpression
+ with Unevaluable with NonSQLExpression {
+ override def foldable: Boolean = throw new UnresolvedException(this, "foldable")
+ override def nullable: Boolean = throw new UnresolvedException(this, "nullable")
+ override lazy val resolved = false
+}
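For orientation, a small sketch of where the new `GetColumnByOrdinal` placeholder appears in practice (simplified from the `ExpressionEncoder.tuple` change in this patch; the exact expression-tree contents are an assumption):

```scala
import org.apache.spark.sql.catalyst.analysis.GetColumnByOrdinal
import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder

// Composing two flat encoders into a tuple encoder: each child's deserializer
// reads its column by position via GetColumnByOrdinal instead of a pre-bound
// BoundReference, leaving actual resolution to the analyzer rule added above.
val pairEnc = ExpressionEncoder.tuple(ExpressionEncoder[String](), ExpressionEncoder[Long]())

// The unresolved deserializer carries placeholders for columns 0 and 1, which
// the analyzer later replaces with plan.output(ordinal).
val ordinals = pairEnc.deserializer.collect { case g: GetColumnByOrdinal => g.ordinal }
// ordinals should contain 0 and 1
```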
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoder.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoder.scala
index 2296946cd7..cc59d06fa3 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoder.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoder.scala
@@ -17,19 +17,17 @@
package org.apache.spark.sql.catalyst.encoders
-import java.util.concurrent.ConcurrentMap
-
import scala.reflect.ClassTag
import scala.reflect.runtime.universe.{typeTag, TypeTag}
-import org.apache.spark.sql.{AnalysisException, Encoder}
+import org.apache.spark.sql.Encoder
import org.apache.spark.sql.catalyst.{InternalRow, JavaTypeInference, ScalaReflection}
-import org.apache.spark.sql.catalyst.analysis.{SimpleAnalyzer, UnresolvedAttribute, UnresolvedDeserializer, UnresolvedExtractValue}
+import org.apache.spark.sql.catalyst.analysis.{Analyzer, GetColumnByOrdinal, SimpleAnalyzer, UnresolvedAttribute, UnresolvedExtractValue}
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.expressions.codegen.{GenerateSafeProjection, GenerateUnsafeProjection}
import org.apache.spark.sql.catalyst.expressions.objects.{Invoke, NewInstance}
import org.apache.spark.sql.catalyst.optimizer.SimplifyCasts
-import org.apache.spark.sql.catalyst.plans.logical.{LocalRelation, Project}
+import org.apache.spark.sql.catalyst.plans.logical.{CatalystSerde, DeserializeToObject, LocalRelation}
import org.apache.spark.sql.types.{ObjectType, StructField, StructType}
import org.apache.spark.util.Utils
@@ -121,15 +119,15 @@ object ExpressionEncoder {
val childrenDeserializers = encoders.zipWithIndex.map { case (enc, index) =>
if (enc.flat) {
enc.deserializer.transform {
- case b: BoundReference => b.copy(ordinal = index)
+ case g: GetColumnByOrdinal => g.copy(ordinal = index)
}
} else {
- val input = BoundReference(index, enc.schema, nullable = true)
+ val input = GetColumnByOrdinal(index, enc.schema)
val deserialized = enc.deserializer.transformUp {
case UnresolvedAttribute(nameParts) =>
assert(nameParts.length == 1)
UnresolvedExtractValue(input, Literal(nameParts.head))
- case BoundReference(ordinal, dt, _) => GetStructField(input, ordinal)
+ case GetColumnByOrdinal(ordinal, _) => GetStructField(input, ordinal)
}
If(IsNull(input), Literal.create(null, deserialized.dataType), deserialized)
}
@@ -192,6 +190,26 @@ case class ExpressionEncoder[T](
if (flat) require(serializer.size == 1)
+ /**
+ * Returns a new copy of this encoder, where the `deserializer` is resolved and bound to the
+ * given schema.
+ *
+ * Note that, ideally encoder is used as a container of serde expressions, the resolution and
+ * binding stuff should happen inside query framework. However, in some cases we need to
+ * use encoder as a function to do serialization directly(e.g. Dataset.collect), then we can use
+ * this method to do resolution and binding outside of query framework.
+ */
+ def resolveAndBind(
+ attrs: Seq[Attribute] = schema.toAttributes,
+ analyzer: Analyzer = SimpleAnalyzer): ExpressionEncoder[T] = {
+ val dummyPlan = CatalystSerde.deserialize(LocalRelation(attrs))(this)
+ val analyzedPlan = analyzer.execute(dummyPlan)
+ analyzer.checkAnalysis(analyzedPlan)
+ val resolved = SimplifyCasts(analyzedPlan).asInstanceOf[DeserializeToObject].deserializer
+ val bound = BindReferences.bindReference(resolved, attrs)
+ copy(deserializer = bound)
+ }
+
@transient
private lazy val extractProjection = GenerateUnsafeProjection.generate(serializer)
@@ -202,16 +220,6 @@ case class ExpressionEncoder[T](
private lazy val constructProjection = GenerateSafeProjection.generate(deserializer :: Nil)
/**
- * Returns this encoder where it has been bound to its own output (i.e. no remaping of columns
- * is performed).
- */
- def defaultBinding: ExpressionEncoder[T] = {
- val attrs = schema.toAttributes
- resolve(attrs, OuterScopes.outerScopes).bind(attrs)
- }
-
-
- /**
* Returns a new set (with unique ids) of [[NamedExpression]] that represent the serialized form
* of this object.
*/
@@ -236,7 +244,7 @@ case class ExpressionEncoder[T](
/**
* Returns an object of type `T`, extracting the required values from the provided row. Note that
- * you must `resolve` and `bind` an encoder to a specific schema before you can call this
+ * you must `resolveAndBind` an encoder to a specific schema before you can call this
* function.
*/
def fromRow(row: InternalRow): T = try {
@@ -259,94 +267,6 @@ case class ExpressionEncoder[T](
})
}
- /**
- * Validates `deserializer` to make sure it can be resolved by given schema, and produce
- * friendly error messages to explain why it fails to resolve if there is something wrong.
- */
- def validate(schema: Seq[Attribute]): Unit = {
- def fail(st: StructType, maxOrdinal: Int): Unit = {
- throw new AnalysisException(s"Try to map ${st.simpleString} to Tuple${maxOrdinal + 1}, " +
- "but failed as the number of fields does not line up.\n" +
- " - Input schema: " + StructType.fromAttributes(schema).simpleString + "\n" +
- " - Target schema: " + this.schema.simpleString)
- }
-
- // If this is a tuple encoder or tupled encoder, which means its leaf nodes are all
- // `BoundReference`, make sure their ordinals are all valid.
- var maxOrdinal = -1
- deserializer.foreach {
- case b: BoundReference => if (b.ordinal > maxOrdinal) maxOrdinal = b.ordinal
- case _ =>
- }
- if (maxOrdinal >= 0 && maxOrdinal != schema.length - 1) {
- fail(StructType.fromAttributes(schema), maxOrdinal)
- }
-
- // If we have nested tuple, the `fromRowExpression` will contains `GetStructField` instead of
- // `UnresolvedExtractValue`, so we need to check if their ordinals are all valid.
- // Note that, `BoundReference` contains the expected type, but here we need the actual type, so
- // we unbound it by the given `schema` and propagate the actual type to `GetStructField`, after
- // we resolve the `fromRowExpression`.
- val resolved = SimpleAnalyzer.resolveExpression(
- deserializer,
- LocalRelation(schema),
- throws = true)
-
- val unbound = resolved transform {
- case b: BoundReference => schema(b.ordinal)
- }
-
- val exprToMaxOrdinal = scala.collection.mutable.HashMap.empty[Expression, Int]
- unbound.foreach {
- case g: GetStructField =>
- val maxOrdinal = exprToMaxOrdinal.getOrElse(g.child, -1)
- if (maxOrdinal < g.ordinal) {
- exprToMaxOrdinal.update(g.child, g.ordinal)
- }
- case _ =>
- }
- exprToMaxOrdinal.foreach {
- case (expr, maxOrdinal) =>
- val schema = expr.dataType.asInstanceOf[StructType]
- if (maxOrdinal != schema.length - 1) {
- fail(schema, maxOrdinal)
- }
- }
- }
-
- /**
- * Returns a new copy of this encoder, where the `deserializer` is resolved to the given schema.
- */
- def resolve(
- schema: Seq[Attribute],
- outerScopes: ConcurrentMap[String, AnyRef]): ExpressionEncoder[T] = {
- // Make a fake plan to wrap the deserializer, so that we can go though the whole analyzer, check
- // analysis, go through optimizer, etc.
- val plan = Project(
- Alias(UnresolvedDeserializer(deserializer, schema), "")() :: Nil,
- LocalRelation(schema))
- val analyzedPlan = SimpleAnalyzer.execute(plan)
- SimpleAnalyzer.checkAnalysis(analyzedPlan)
- copy(deserializer = SimplifyCasts(analyzedPlan).expressions.head.children.head)
- }
-
- /**
- * Returns a copy of this encoder where the `deserializer` has been bound to the
- * ordinals of the given schema. Note that you need to first call resolve before bind.
- */
- def bind(schema: Seq[Attribute]): ExpressionEncoder[T] = {
- copy(deserializer = BindReferences.bindReference(deserializer, schema))
- }
-
- /**
- * Returns a new encoder with input columns shifted by `delta` ordinals
- */
- def shift(delta: Int): ExpressionEncoder[T] = {
- copy(deserializer = deserializer transform {
- case r: BoundReference => r.copy(ordinal = r.ordinal + delta)
- })
- }
-
protected val attrs = serializer.flatMap(_.collect {
case _: UnresolvedAttribute => ""
case a: Attribute => s"#${a.exprId}"
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/RowEncoder.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/RowEncoder.scala
index 0de9166aa2..3c6ae1c5cc 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/RowEncoder.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/RowEncoder.scala
@@ -25,6 +25,7 @@ import org.apache.spark.sql.Row
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.util.{ArrayBasedMapData, DateTimeUtils, GenericArrayData}
import org.apache.spark.sql.catalyst.ScalaReflection
+import org.apache.spark.sql.catalyst.analysis.GetColumnByOrdinal
import org.apache.spark.sql.catalyst.expressions.objects._
import org.apache.spark.sql.types._
import org.apache.spark.unsafe.types.UTF8String
@@ -210,12 +211,7 @@ object RowEncoder {
case p: PythonUserDefinedType => p.sqlType
case other => other
}
- val field = BoundReference(i, dt, f.nullable)
- If(
- IsNull(field),
- Literal.create(null, externalDataTypeFor(dt)),
- deserializerFor(field)
- )
+ deserializerFor(GetColumnByOrdinal(i, dt))
}
CreateExternalRow(fields, schema)
}
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/object.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/object.scala
index 98ce5dd2ef..55d8adf040 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/object.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/object.scala
@@ -30,26 +30,13 @@ object CatalystSerde {
DeserializeToObject(deserializer, generateObjAttr[T], child)
}
- def deserialize(child: LogicalPlan, encoder: ExpressionEncoder[Row]): DeserializeToObject = {
- val deserializer = UnresolvedDeserializer(encoder.deserializer)
- DeserializeToObject(deserializer, generateObjAttrForRow(encoder), child)
- }
-
def serialize[T : Encoder](child: LogicalPlan): SerializeFromObject = {
SerializeFromObject(encoderFor[T].namedExpressions, child)
}
- def serialize(child: LogicalPlan, encoder: ExpressionEncoder[Row]): SerializeFromObject = {
- SerializeFromObject(encoder.namedExpressions, child)
- }
-
def generateObjAttr[T : Encoder]: Attribute = {
AttributeReference("obj", encoderFor[T].deserializer.dataType, nullable = false)()
}
-
- def generateObjAttrForRow(encoder: ExpressionEncoder[Row]): Attribute = {
- AttributeReference("obj", encoder.deserializer.dataType, nullable = false)()
- }
}
/**
@@ -128,16 +115,16 @@ object MapPartitionsInR {
schema: StructType,
encoder: ExpressionEncoder[Row],
child: LogicalPlan): LogicalPlan = {
- val deserialized = CatalystSerde.deserialize(child, encoder)
+ val deserialized = CatalystSerde.deserialize(child)(encoder)
val mapped = MapPartitionsInR(
func,
packageNames,
broadcastVars,
encoder.schema,
schema,
- CatalystSerde.generateObjAttrForRow(RowEncoder(schema)),
+ CatalystSerde.generateObjAttr(RowEncoder(schema)),
deserialized)
- CatalystSerde.serialize(mapped, RowEncoder(schema))
+ CatalystSerde.serialize(mapped)(RowEncoder(schema))
}
}
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/EncoderResolutionSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/EncoderResolutionSuite.scala
index 3ad0dae767..7251202c7b 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/EncoderResolutionSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/EncoderResolutionSuite.scala
@@ -41,17 +41,17 @@ class EncoderResolutionSuite extends PlanTest {
// int type can be up cast to long type
val attrs1 = Seq('a.string, 'b.int)
- encoder.resolve(attrs1, null).bind(attrs1).fromRow(InternalRow(str, 1))
+ encoder.resolveAndBind(attrs1).fromRow(InternalRow(str, 1))
// int type can be up cast to string type
val attrs2 = Seq('a.int, 'b.long)
- encoder.resolve(attrs2, null).bind(attrs2).fromRow(InternalRow(1, 2L))
+ encoder.resolveAndBind(attrs2).fromRow(InternalRow(1, 2L))
}
test("real type doesn't match encoder schema but they are compatible: nested product") {
val encoder = ExpressionEncoder[ComplexClass]
val attrs = Seq('a.int, 'b.struct('a.int, 'b.long))
- encoder.resolve(attrs, null).bind(attrs).fromRow(InternalRow(1, InternalRow(2, 3L)))
+ encoder.resolveAndBind(attrs).fromRow(InternalRow(1, InternalRow(2, 3L)))
}
test("real type doesn't match encoder schema but they are compatible: tupled encoder") {
@@ -59,7 +59,7 @@ class EncoderResolutionSuite extends PlanTest {
ExpressionEncoder[StringLongClass],
ExpressionEncoder[Long])
val attrs = Seq('a.struct('a.string, 'b.byte), 'b.int)
- encoder.resolve(attrs, null).bind(attrs).fromRow(InternalRow(InternalRow(str, 1.toByte), 2))
+ encoder.resolveAndBind(attrs).fromRow(InternalRow(InternalRow(str, 1.toByte), 2))
}
test("nullability of array type element should not fail analysis") {
@@ -67,7 +67,7 @@ class EncoderResolutionSuite extends PlanTest {
val attrs = 'a.array(IntegerType) :: Nil
// It should pass analysis
- val bound = encoder.resolve(attrs, null).bind(attrs)
+ val bound = encoder.resolveAndBind(attrs)
// If no null values appear, it should works fine
bound.fromRow(InternalRow(new GenericArrayData(Array(1, 2))))
@@ -84,20 +84,16 @@ class EncoderResolutionSuite extends PlanTest {
{
val attrs = Seq('a.string, 'b.long, 'c.int)
- assert(intercept[AnalysisException](encoder.validate(attrs)).message ==
+ assert(intercept[AnalysisException](encoder.resolveAndBind(attrs)).message ==
"Try to map struct<a:string,b:bigint,c:int> to Tuple2, " +
- "but failed as the number of fields does not line up.\n" +
- " - Input schema: struct<a:string,b:bigint,c:int>\n" +
- " - Target schema: struct<_1:string,_2:bigint>")
+ "but failed as the number of fields does not line up.")
}
{
val attrs = Seq('a.string)
- assert(intercept[AnalysisException](encoder.validate(attrs)).message ==
+ assert(intercept[AnalysisException](encoder.resolveAndBind(attrs)).message ==
"Try to map struct<a:string> to Tuple2, " +
- "but failed as the number of fields does not line up.\n" +
- " - Input schema: struct<a:string>\n" +
- " - Target schema: struct<_1:string,_2:bigint>")
+ "but failed as the number of fields does not line up.")
}
}
@@ -106,26 +102,22 @@ class EncoderResolutionSuite extends PlanTest {
{
val attrs = Seq('a.string, 'b.struct('x.long, 'y.string, 'z.int))
- assert(intercept[AnalysisException](encoder.validate(attrs)).message ==
+ assert(intercept[AnalysisException](encoder.resolveAndBind(attrs)).message ==
"Try to map struct<x:bigint,y:string,z:int> to Tuple2, " +
- "but failed as the number of fields does not line up.\n" +
- " - Input schema: struct<a:string,b:struct<x:bigint,y:string,z:int>>\n" +
- " - Target schema: struct<_1:string,_2:struct<_1:bigint,_2:string>>")
+ "but failed as the number of fields does not line up.")
}
{
val attrs = Seq('a.string, 'b.struct('x.long))
- assert(intercept[AnalysisException](encoder.validate(attrs)).message ==
+ assert(intercept[AnalysisException](encoder.resolveAndBind(attrs)).message ==
"Try to map struct<x:bigint> to Tuple2, " +
- "but failed as the number of fields does not line up.\n" +
- " - Input schema: struct<a:string,b:struct<x:bigint>>\n" +
- " - Target schema: struct<_1:string,_2:struct<_1:bigint,_2:string>>")
+ "but failed as the number of fields does not line up.")
}
}
test("throw exception if real type is not compatible with encoder schema") {
val msg1 = intercept[AnalysisException] {
- ExpressionEncoder[StringIntClass].resolve(Seq('a.string, 'b.long), null)
+ ExpressionEncoder[StringIntClass].resolveAndBind(Seq('a.string, 'b.long))
}.message
assert(msg1 ==
s"""
@@ -138,7 +130,7 @@ class EncoderResolutionSuite extends PlanTest {
val msg2 = intercept[AnalysisException] {
val structType = new StructType().add("a", StringType).add("b", DecimalType.SYSTEM_DEFAULT)
- ExpressionEncoder[ComplexClass].resolve(Seq('a.long, 'b.struct(structType)), null)
+ ExpressionEncoder[ComplexClass].resolveAndBind(Seq('a.long, 'b.struct(structType)))
}.message
assert(msg2 ==
s"""
@@ -171,7 +163,7 @@ class EncoderResolutionSuite extends PlanTest {
val to = ExpressionEncoder[U]
val catalystType = from.schema.head.dataType.simpleString
test(s"cast from $catalystType to ${implicitly[TypeTag[U]].tpe} should success") {
- to.resolve(from.schema.toAttributes, null)
+ to.resolveAndBind(from.schema.toAttributes)
}
}
@@ -180,7 +172,7 @@ class EncoderResolutionSuite extends PlanTest {
val to = ExpressionEncoder[U]
val catalystType = from.schema.head.dataType.simpleString
test(s"cast from $catalystType to ${implicitly[TypeTag[U]].tpe} should fail") {
- intercept[AnalysisException](to.resolve(from.schema.toAttributes, null))
+ intercept[AnalysisException](to.resolveAndBind(from.schema.toAttributes))
}
}
}
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoderSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoderSuite.scala
index 232dcc9ee5..a1f9259f13 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoderSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoderSuite.scala
@@ -27,6 +27,7 @@ import scala.reflect.runtime.universe.TypeTag
import org.apache.spark.sql.Encoders
import org.apache.spark.sql.catalyst.{OptionalData, PrimitiveData}
import org.apache.spark.sql.catalyst.analysis.AnalysisTest
+import org.apache.spark.sql.catalyst.dsl.plans._
import org.apache.spark.sql.catalyst.expressions.{Alias, AttributeReference}
import org.apache.spark.sql.catalyst.plans.PlanTest
import org.apache.spark.sql.catalyst.plans.logical.{LocalRelation, Project}
@@ -334,7 +335,7 @@ class ExpressionEncoderSuite extends PlanTest with AnalysisTest {
val encoder = implicitly[ExpressionEncoder[T]]
val row = encoder.toRow(input)
val schema = encoder.schema.toAttributes
- val boundEncoder = encoder.defaultBinding
+ val boundEncoder = encoder.resolveAndBind()
val convertedBack = try boundEncoder.fromRow(row) catch {
case e: Exception =>
fail(
@@ -350,12 +351,8 @@ class ExpressionEncoderSuite extends PlanTest with AnalysisTest {
}
// Test the correct resolution of serialization / deserialization.
- val attr = AttributeReference("obj", ObjectType(encoder.clsTag.runtimeClass))()
- val inputPlan = LocalRelation(attr)
- val plan =
- Project(Alias(encoder.deserializer, "obj")() :: Nil,
- Project(encoder.namedExpressions,
- inputPlan))
+ val attr = AttributeReference("obj", encoder.deserializer.dataType)()
+ val plan = LocalRelation(attr).serialize[T].deserialize[T]
assertAnalysisSuccess(plan)
val isCorrect = (input, convertedBack) match {
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/RowEncoderSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/RowEncoderSuite.scala
index 39fcc7225b..6f1bc80c1c 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/RowEncoderSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/RowEncoderSuite.scala
@@ -135,7 +135,7 @@ class RowEncoderSuite extends SparkFunSuite {
.add("string", StringType)
.add("double", DoubleType))
- val encoder = RowEncoder(schema)
+ val encoder = RowEncoder(schema).resolveAndBind()
val input: Row = Row((100, "test", 0.123))
val row = encoder.toRow(input)
@@ -152,7 +152,7 @@ class RowEncoderSuite extends SparkFunSuite {
.add("scala_decimal", DecimalType.SYSTEM_DEFAULT)
.add("catalyst_decimal", DecimalType.SYSTEM_DEFAULT)
- val encoder = RowEncoder(schema)
+ val encoder = RowEncoder(schema).resolveAndBind()
val javaDecimal = new java.math.BigDecimal("1234.5678")
val scalaDecimal = BigDecimal("1234.5678")
@@ -169,7 +169,7 @@ class RowEncoderSuite extends SparkFunSuite {
test("RowEncoder should preserve decimal precision and scale") {
val schema = new StructType().add("decimal", DecimalType(10, 5), false)
- val encoder = RowEncoder(schema)
+ val encoder = RowEncoder(schema).resolveAndBind()
val decimal = Decimal("67123.45")
val input = Row(decimal)
val row = encoder.toRow(input)
@@ -179,7 +179,7 @@ class RowEncoderSuite extends SparkFunSuite {
test("RowEncoder should preserve schema nullability") {
val schema = new StructType().add("int", IntegerType, nullable = false)
- val encoder = RowEncoder(schema)
+ val encoder = RowEncoder(schema).resolveAndBind()
assert(encoder.serializer.length == 1)
assert(encoder.serializer.head.dataType == IntegerType)
assert(encoder.serializer.head.nullable == false)
@@ -195,7 +195,7 @@ class RowEncoderSuite extends SparkFunSuite {
new StructType().add("int", IntegerType, nullable = false),
nullable = false),
nullable = false)
- val encoder = RowEncoder(schema)
+ val encoder = RowEncoder(schema).resolveAndBind()
assert(encoder.serializer.length == 1)
assert(encoder.serializer.head.dataType ==
new StructType()
@@ -212,7 +212,7 @@ class RowEncoderSuite extends SparkFunSuite {
.add("array", ArrayType(IntegerType))
.add("nestedArray", ArrayType(ArrayType(StringType)))
.add("deepNestedArray", ArrayType(ArrayType(ArrayType(LongType))))
- val encoder = RowEncoder(schema)
+ val encoder = RowEncoder(schema).resolveAndBind()
val input = Row(
Array(1, 2, null),
Array(Array("abc", null), null),
@@ -226,7 +226,7 @@ class RowEncoderSuite extends SparkFunSuite {
private def encodeDecodeTest(schema: StructType): Unit = {
test(s"encode/decode: ${schema.simpleString}") {
- val encoder = RowEncoder(schema)
+ val encoder = RowEncoder(schema).resolveAndBind()
val inputGenerator = RandomDataGenerator.forType(schema, nullable = false).get
var input: Row = null
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala
index 369b772d32..96c871d034 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala
@@ -192,24 +192,24 @@ class Dataset[T] private[sql](
}
/**
- * An unresolved version of the internal encoder for the type of this [[Dataset]]. This one is
- * marked implicit so that we can use it when constructing new [[Dataset]] objects that have the
- * same object type (that will be possibly resolved to a different schema).
+ * Currently [[ExpressionEncoder]] is the only implementation of [[Encoder]], here we turn the
+ * passed in encoder to [[ExpressionEncoder]] explicitly, and mark it implicit so that we can use
+ * it when constructing new [[Dataset]] objects that have the same object type (that will be
+ * possibly resolved to a different schema).
*/
- private[sql] implicit val unresolvedTEncoder: ExpressionEncoder[T] = encoderFor(encoder)
- unresolvedTEncoder.validate(logicalPlan.output)
-
- /** The encoder for this [[Dataset]] that has been resolved to its output schema. */
- private[sql] val resolvedTEncoder: ExpressionEncoder[T] =
- unresolvedTEncoder.resolve(logicalPlan.output, OuterScopes.outerScopes)
+ private[sql] implicit val exprEnc: ExpressionEncoder[T] = encoderFor(encoder)
/**
- * The encoder where the expressions used to construct an object from an input row have been
- * bound to the ordinals of this [[Dataset]]'s output schema.
+ * Encoder is used mostly as a container of serde expressions in [[Dataset]]. We build logical
+ * plans by these serde expressions and execute it within the query framework. However, for
+ * performance reasons we may want to use encoder as a function to deserialize internal rows to
+ * custom objects, e.g. collect. Here we resolve and bind the encoder so that we can call its
+ * `fromRow` method later.
*/
- private[sql] val boundTEncoder = resolvedTEncoder.bind(logicalPlan.output)
+ private val boundEnc =
+ exprEnc.resolveAndBind(logicalPlan.output, sparkSession.sessionState.analyzer)
- private implicit def classTag = unresolvedTEncoder.clsTag
+ private implicit def classTag = exprEnc.clsTag
// sqlContext must be val because a stable identifier is expected when you import implicits
@transient lazy val sqlContext: SQLContext = sparkSession.sqlContext
@@ -761,7 +761,7 @@ class Dataset[T] private[sql](
// Note that we do this before joining them, to enable the join operator to return null for one
// side, in cases like outer-join.
val left = {
- val combined = if (this.unresolvedTEncoder.flat) {
+ val combined = if (this.exprEnc.flat) {
assert(joined.left.output.length == 1)
Alias(joined.left.output.head, "_1")()
} else {
@@ -771,7 +771,7 @@ class Dataset[T] private[sql](
}
val right = {
- val combined = if (other.unresolvedTEncoder.flat) {
+ val combined = if (other.exprEnc.flat) {
assert(joined.right.output.length == 1)
Alias(joined.right.output.head, "_2")()
} else {
@@ -784,14 +784,14 @@ class Dataset[T] private[sql](
// combine the outputs of each join side.
val conditionExpr = joined.condition.get transformUp {
case a: Attribute if joined.left.outputSet.contains(a) =>
- if (this.unresolvedTEncoder.flat) {
+ if (this.exprEnc.flat) {
left.output.head
} else {
val index = joined.left.output.indexWhere(_.exprId == a.exprId)
GetStructField(left.output.head, index)
}
case a: Attribute if joined.right.outputSet.contains(a) =>
- if (other.unresolvedTEncoder.flat) {
+ if (other.exprEnc.flat) {
right.output.head
} else {
val index = joined.right.output.indexWhere(_.exprId == a.exprId)
@@ -800,7 +800,7 @@ class Dataset[T] private[sql](
}
implicit val tuple2Encoder: Encoder[(T, U)] =
- ExpressionEncoder.tuple(this.unresolvedTEncoder, other.unresolvedTEncoder)
+ ExpressionEncoder.tuple(this.exprEnc, other.exprEnc)
withTypedPlan(Join(left, right, joined.joinType, Some(conditionExpr)))
}
@@ -1024,7 +1024,7 @@ class Dataset[T] private[sql](
sparkSession,
Project(
c1.withInputType(
- unresolvedTEncoder.deserializer,
+ exprEnc.deserializer,
logicalPlan.output).named :: Nil,
logicalPlan),
implicitly[Encoder[U1]])
@@ -1038,7 +1038,7 @@ class Dataset[T] private[sql](
protected def selectUntyped(columns: TypedColumn[_, _]*): Dataset[_] = {
val encoders = columns.map(_.encoder)
val namedColumns =
- columns.map(_.withInputType(unresolvedTEncoder.deserializer, logicalPlan.output).named)
+ columns.map(_.withInputType(exprEnc.deserializer, logicalPlan.output).named)
val execution = new QueryExecution(sparkSession, Project(namedColumns, logicalPlan))
new Dataset(sparkSession, execution, ExpressionEncoder.tuple(encoders))
}
@@ -2153,14 +2153,14 @@ class Dataset[T] private[sql](
*/
def collectAsList(): java.util.List[T] = withCallback("collectAsList", toDF()) { _ =>
withNewExecutionId {
- val values = queryExecution.executedPlan.executeCollect().map(boundTEncoder.fromRow)
+ val values = queryExecution.executedPlan.executeCollect().map(boundEnc.fromRow)
java.util.Arrays.asList(values : _*)
}
}
private def collect(needCallback: Boolean): Array[T] = {
def execute(): Array[T] = withNewExecutionId {
- queryExecution.executedPlan.executeCollect().map(boundTEncoder.fromRow)
+ queryExecution.executedPlan.executeCollect().map(boundEnc.fromRow)
}
if (needCallback) {
@@ -2184,7 +2184,7 @@ class Dataset[T] private[sql](
*/
def toLocalIterator(): java.util.Iterator[T] = withCallback("toLocalIterator", toDF()) { _ =>
withNewExecutionId {
- queryExecution.executedPlan.executeToIterator().map(boundTEncoder.fromRow).asJava
+ queryExecution.executedPlan.executeToIterator().map(boundEnc.fromRow).asJava
}
}
@@ -2322,7 +2322,7 @@ class Dataset[T] private[sql](
* @since 1.6.0
*/
lazy val rdd: RDD[T] = {
- val objectType = unresolvedTEncoder.deserializer.dataType
+ val objectType = exprEnc.deserializer.dataType
val deserialized = CatalystSerde.deserialize[T](logicalPlan)
sparkSession.sessionState.executePlan(deserialized).toRdd.mapPartitions { rows =>
rows.map(_.get(0, objectType).asInstanceOf[T])
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/KeyValueGroupedDataset.scala b/sql/core/src/main/scala/org/apache/spark/sql/KeyValueGroupedDataset.scala
index 53f4ea647c..a6867a67ee 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/KeyValueGroupedDataset.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/KeyValueGroupedDataset.scala
@@ -42,17 +42,9 @@ class KeyValueGroupedDataset[K, V] private[sql](
private val dataAttributes: Seq[Attribute],
private val groupingAttributes: Seq[Attribute]) extends Serializable {
- // Similar to [[Dataset]], we use unresolved encoders for later composition and resolved encoders
- // when constructing new logical plans that will operate on the output of the current
- // queryexecution.
-
- private implicit val unresolvedKEncoder = encoderFor(kEncoder)
- private implicit val unresolvedVEncoder = encoderFor(vEncoder)
-
- private val resolvedKEncoder =
- unresolvedKEncoder.resolve(groupingAttributes, OuterScopes.outerScopes)
- private val resolvedVEncoder =
- unresolvedVEncoder.resolve(dataAttributes, OuterScopes.outerScopes)
+ // Similar to [[Dataset]], we turn the passed in encoder to `ExpressionEncoder` explicitly.
+ private implicit val kExprEnc = encoderFor(kEncoder)
+ private implicit val vExprEnc = encoderFor(vEncoder)
private def logicalPlan = queryExecution.analyzed
private def sparkSession = queryExecution.sparkSession
@@ -67,7 +59,7 @@ class KeyValueGroupedDataset[K, V] private[sql](
def keyAs[L : Encoder]: KeyValueGroupedDataset[L, V] =
new KeyValueGroupedDataset(
encoderFor[L],
- unresolvedVEncoder,
+ vExprEnc,
queryExecution,
dataAttributes,
groupingAttributes)
@@ -187,7 +179,7 @@ class KeyValueGroupedDataset[K, V] private[sql](
def reduceGroups(f: (V, V) => V): Dataset[(K, V)] = {
val func = (key: K, it: Iterator[V]) => Iterator((key, it.reduce(f)))
- implicit val resultEncoder = ExpressionEncoder.tuple(unresolvedKEncoder, unresolvedVEncoder)
+ implicit val resultEncoder = ExpressionEncoder.tuple(kExprEnc, vExprEnc)
flatMapGroups(func)
}
@@ -209,8 +201,8 @@ class KeyValueGroupedDataset[K, V] private[sql](
protected def aggUntyped(columns: TypedColumn[_, _]*): Dataset[_] = {
val encoders = columns.map(_.encoder)
val namedColumns =
- columns.map(_.withInputType(unresolvedVEncoder.deserializer, dataAttributes).named)
- val keyColumn = if (resolvedKEncoder.flat) {
+ columns.map(_.withInputType(vExprEnc.deserializer, dataAttributes).named)
+ val keyColumn = if (kExprEnc.flat) {
assert(groupingAttributes.length == 1)
groupingAttributes.head
} else {
@@ -222,7 +214,7 @@ class KeyValueGroupedDataset[K, V] private[sql](
new Dataset(
sparkSession,
execution,
- ExpressionEncoder.tuple(unresolvedKEncoder +: encoders))
+ ExpressionEncoder.tuple(kExprEnc +: encoders))
}
/**
@@ -287,7 +279,7 @@ class KeyValueGroupedDataset[K, V] private[sql](
def cogroup[U, R : Encoder](
other: KeyValueGroupedDataset[K, U])(
f: (K, Iterator[V], Iterator[U]) => TraversableOnce[R]): Dataset[R] = {
- implicit val uEncoder = other.unresolvedVEncoder
+ implicit val uEncoder = other.vExprEnc
Dataset[R](
sparkSession,
CoGroup(
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/RelationalGroupedDataset.scala b/sql/core/src/main/scala/org/apache/spark/sql/RelationalGroupedDataset.scala
index 58850a7d4b..49b6eab8db 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/RelationalGroupedDataset.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/RelationalGroupedDataset.scala
@@ -215,7 +215,7 @@ class RelationalGroupedDataset protected[sql](
def agg(expr: Column, exprs: Column*): DataFrame = {
toDF((expr +: exprs).map {
case typed: TypedColumn[_, _] =>
- typed.withInputType(df.unresolvedTEncoder.deserializer, df.logicalPlan.output).expr
+ typed.withInputType(df.exprEnc.deserializer, df.logicalPlan.output).expr
case c => c.expr
})
}
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TypedAggregateExpression.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TypedAggregateExpression.scala
index 8f94184764..ecb56e2a28 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TypedAggregateExpression.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TypedAggregateExpression.scala
@@ -33,9 +33,9 @@ object TypedAggregateExpression {
aggregator: Aggregator[_, BUF, OUT]): TypedAggregateExpression = {
val bufferEncoder = encoderFor[BUF]
val bufferSerializer = bufferEncoder.namedExpressions
- val bufferDeserializer = bufferEncoder.deserializer.transform {
- case b: BoundReference => bufferSerializer(b.ordinal).toAttribute
- }
+ val bufferDeserializer = UnresolvedDeserializer(
+ bufferEncoder.deserializer,
+ bufferSerializer.map(_.toAttribute))
val outputEncoder = encoderFor[OUT]
val outputType = if (outputEncoder.flat) {
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala
index d89e98645b..4dbd1665e4 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala
@@ -924,7 +924,7 @@ object functions {
* @since 1.5.0
*/
def broadcast[T](df: Dataset[T]): Dataset[T] = {
- Dataset[T](df.sparkSession, BroadcastHint(df.logicalPlan))(df.unresolvedTEncoder)
+ Dataset[T](df.sparkSession, BroadcastHint(df.logicalPlan))(df.exprEnc)
}
/**
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala
index df8f4b0610..d1c232974e 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala
@@ -566,18 +566,14 @@ class DatasetSuite extends QueryTest with SharedSQLContext {
}.message
assert(message ==
"Try to map struct<a:string,b:int> to Tuple3, " +
- "but failed as the number of fields does not line up.\n" +
- " - Input schema: struct<a:string,b:int>\n" +
- " - Target schema: struct<_1:string,_2:int,_3:bigint>")
+ "but failed as the number of fields does not line up.")
val message2 = intercept[AnalysisException] {
ds.as[Tuple1[String]]
}.message
assert(message2 ==
"Try to map struct<a:string,b:int> to Tuple1, " +
- "but failed as the number of fields does not line up.\n" +
- " - Input schema: struct<a:string,b:int>\n" +
- " - Target schema: struct<_1:string>")
+ "but failed as the number of fields does not line up.")
}
test("SPARK-13440: Resolving option fields") {
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/QueryTest.scala b/sql/core/src/test/scala/org/apache/spark/sql/QueryTest.scala
index a1a9b66c1f..9c044f4e8f 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/QueryTest.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/QueryTest.scala
@@ -81,7 +81,7 @@ abstract class QueryTest extends PlanTest {
expectedAnswer: T*): Unit = {
checkAnswer(
ds.toDF(),
- spark.createDataset(expectedAnswer)(ds.unresolvedTEncoder).toDF().collect().toSeq)
+ spark.createDataset(expectedAnswer)(ds.exprEnc).toDF().collect().toSeq)
checkDecoding(ds, expectedAnswer: _*)
}
@@ -94,8 +94,8 @@ abstract class QueryTest extends PlanTest {
fail(
s"""
|Exception collecting dataset as objects
- |${ds.resolvedTEncoder}
- |${ds.resolvedTEncoder.deserializer.treeString}
+ |${ds.exprEnc}
+ |${ds.exprEnc.deserializer.treeString}
|${ds.queryExecution}
""".stripMargin, e)
}
@@ -114,7 +114,7 @@ abstract class QueryTest extends PlanTest {
fail(
s"""Decoded objects do not match expected objects:
|$comparison
- |${ds.resolvedTEncoder.deserializer.treeString}
+ |${ds.exprEnc.deserializer.treeString}
""".stripMargin)
}
}
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/GroupedIteratorSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/GroupedIteratorSuite.scala
index 6f10e4b805..80340b5552 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/GroupedIteratorSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/GroupedIteratorSuite.scala
@@ -27,7 +27,7 @@ class GroupedIteratorSuite extends SparkFunSuite {
test("basic") {
val schema = new StructType().add("i", IntegerType).add("s", StringType)
- val encoder = RowEncoder(schema)
+ val encoder = RowEncoder(schema).resolveAndBind()
val input = Seq(Row(1, "a"), Row(1, "b"), Row(2, "c"))
val grouped = GroupedIterator(input.iterator.map(encoder.toRow),
Seq('i.int.at(0)), schema.toAttributes)
@@ -45,7 +45,7 @@ class GroupedIteratorSuite extends SparkFunSuite {
test("group by 2 columns") {
val schema = new StructType().add("i", IntegerType).add("l", LongType).add("s", StringType)
- val encoder = RowEncoder(schema)
+ val encoder = RowEncoder(schema).resolveAndBind()
val input = Seq(
Row(1, 2L, "a"),
@@ -72,7 +72,7 @@ class GroupedIteratorSuite extends SparkFunSuite {
test("do nothing to the value iterator") {
val schema = new StructType().add("i", IntegerType).add("s", StringType)
- val encoder = RowEncoder(schema)
+ val encoder = RowEncoder(schema).resolveAndBind()
val input = Seq(Row(1, "a"), Row(1, "b"), Row(2, "c"))
val grouped = GroupedIterator(input.iterator.map(encoder.toRow),
Seq('i.int.at(0)), schema.toAttributes)
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamTest.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamTest.scala
index dd8672aa64..194c3e7307 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamTest.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamTest.scala
@@ -110,7 +110,7 @@ trait StreamTest extends QueryTest with SharedSQLContext with Timeouts {
object CheckAnswer {
def apply[A : Encoder](data: A*): CheckAnswerRows = {
val encoder = encoderFor[A]
- val toExternalRow = RowEncoder(encoder.schema)
+ val toExternalRow = RowEncoder(encoder.schema).resolveAndBind()
CheckAnswerRows(data.map(d => toExternalRow.fromRow(encoder.toRow(d))), false)
}
@@ -124,7 +124,7 @@ trait StreamTest extends QueryTest with SharedSQLContext with Timeouts {
object CheckLastBatch {
def apply[A : Encoder](data: A*): CheckAnswerRows = {
val encoder = encoderFor[A]
- val toExternalRow = RowEncoder(encoder.schema)
+ val toExternalRow = RowEncoder(encoder.schema).resolveAndBind()
CheckAnswerRows(data.map(d => toExternalRow.fromRow(encoder.toRow(d))), true)
}