author    Wenchen Fan <wenchen@databricks.com>    2016-03-30 11:03:15 -0700
committer Michael Armbrust <michael@databricks.com>    2016-03-30 11:03:15 -0700
commit    d46c71b39da92f5cabf6d9057c953c52f7f3f965 (patch)
tree      c5ed22d51789f1e1d347050a2bf356198f53f6bd
parent    816f359cf043ef719a0bc7df0506a3a830fff70d (diff)
[SPARK-14268][SQL] rename toRowExpressions and fromRowExpression to serializer and deserializer in ExpressionEncoder
## What changes were proposed in this pull request?

In `ExpressionEncoder`, we use `constructorFor` to build `fromRowExpression` as the `deserializer` in `ObjectOperator`. It's kind of confusing; we should make the naming consistent.

## How was this patch tested?

Existing tests.

Author: Wenchen Fan <wenchen@databricks.com>

Closes #12058 from cloud-fan/rename.
-rw-r--r--  sql/catalyst/src/main/scala/org/apache/spark/sql/Encoder.scala | 4
-rw-r--r--  sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/JavaTypeInference.scala | 32
-rw-r--r--  sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala | 40
-rw-r--r--  sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoder.scala | 87
-rw-r--r--  sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/RowEncoder.scala | 34
-rw-r--r--  sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/object.scala | 14
-rw-r--r--  sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/ScalaReflectionSuite.scala | 4
-rw-r--r--  sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoderSuite.scala | 4
-rw-r--r--  sql/core/src/test/scala/org/apache/spark/sql/QueryTest.scala | 4
9 files changed, 110 insertions(+), 113 deletions(-)
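The patch is a pure rename: both halves of an encoder keep their behavior and only get consistent names. The post-patch shape of the case class, as it appears in the `ExpressionEncoder.scala` hunk below:

```scala
// After this patch: toRowExpressions -> serializer, fromRowExpression -> deserializer.
case class ExpressionEncoder[T](
    schema: StructType,
    flat: Boolean,
    serializer: Seq[Expression],  // object -> InternalRow, one expression per top-level field
    deserializer: Expression,     // InternalRow -> object
    clsTag: ClassTag[T])
  extends Encoder[T]
```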
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/Encoder.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/Encoder.scala
index b19538a23f..1f20e26354 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/Encoder.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/Encoder.scala
@@ -245,10 +245,10 @@ object Encoders {
ExpressionEncoder[T](
schema = new StructType().add("value", BinaryType),
flat = true,
- toRowExpressions = Seq(
+ serializer = Seq(
EncodeUsingSerializer(
BoundReference(0, ObjectType(classOf[AnyRef]), nullable = true), kryo = useKryo)),
- fromRowExpression =
+ deserializer =
DecodeUsingSerializer[T](
BoundReference(0, BinaryType, nullable = true), classTag[T], kryo = useKryo),
clsTag = classTag[T]
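For context, `Encoders.kryo` (and its Java-serialization twin, via the `useKryo` flag) is the public entry point that builds the serializer/deserializer pair shown above. A minimal user-side sketch, assuming a 2.0-era Spark build and a hypothetical serializable class `Point`:

```scala
import org.apache.spark.sql.{Encoder, Encoders}

// Hypothetical user class for illustration.
class Point(val x: Double, val y: Double) extends Serializable

// Builds serializer = Seq(EncodeUsingSerializer(..., kryo = true)) and
// deserializer = DecodeUsingSerializer(..., kryo = true), per the hunk above.
val pointEncoder: Encoder[Point] = Encoders.kryo[Point]
// The resulting schema is a single binary column: struct<value: binary>
```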
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/JavaTypeInference.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/JavaTypeInference.scala
index 59ee41d02f..6f9fbbbead 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/JavaTypeInference.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/JavaTypeInference.scala
@@ -155,16 +155,16 @@ object JavaTypeInference {
}
/**
- * Returns an expression that can be used to construct an object of java bean `T` given an input
- * row with a compatible schema. Fields of the row will be extracted using UnresolvedAttributes
+ * Returns an expression that can be used to deserialize an internal row to an object of java bean
+ * `T` with a compatible schema. Fields of the row will be extracted using UnresolvedAttributes
* of the same name as the constructor arguments. Nested classes will have their fields accessed
* using UnresolvedExtractValue.
*/
- def constructorFor(beanClass: Class[_]): Expression = {
- constructorFor(TypeToken.of(beanClass), None)
+ def deserializerFor(beanClass: Class[_]): Expression = {
+ deserializerFor(TypeToken.of(beanClass), None)
}
- private def constructorFor(typeToken: TypeToken[_], path: Option[Expression]): Expression = {
+ private def deserializerFor(typeToken: TypeToken[_], path: Option[Expression]): Expression = {
/** Returns the current path with a sub-field extracted. */
def addToPath(part: String): Expression = path
.map(p => UnresolvedExtractValue(p, expressions.Literal(part)))
@@ -231,7 +231,7 @@ object JavaTypeInference {
}.getOrElse {
Invoke(
MapObjects(
- p => constructorFor(typeToken.getComponentType, Some(p)),
+ p => deserializerFor(typeToken.getComponentType, Some(p)),
getPath,
inferDataType(elementType)._1),
"array",
@@ -243,7 +243,7 @@ object JavaTypeInference {
val array =
Invoke(
MapObjects(
- p => constructorFor(et, Some(p)),
+ p => deserializerFor(et, Some(p)),
getPath,
inferDataType(et)._1),
"array",
@@ -259,7 +259,7 @@ object JavaTypeInference {
val keyData =
Invoke(
MapObjects(
- p => constructorFor(keyType, Some(p)),
+ p => deserializerFor(keyType, Some(p)),
Invoke(getPath, "keyArray", ArrayType(keyDataType)),
keyDataType),
"array",
@@ -268,7 +268,7 @@ object JavaTypeInference {
val valueData =
Invoke(
MapObjects(
- p => constructorFor(valueType, Some(p)),
+ p => deserializerFor(valueType, Some(p)),
Invoke(getPath, "valueArray", ArrayType(valueDataType)),
valueDataType),
"array",
@@ -288,7 +288,7 @@ object JavaTypeInference {
val fieldName = p.getName
val fieldType = typeToken.method(p.getReadMethod).getReturnType
val (_, nullable) = inferDataType(fieldType)
- val constructor = constructorFor(fieldType, Some(addToPath(fieldName)))
+ val constructor = deserializerFor(fieldType, Some(addToPath(fieldName)))
val setter = if (nullable) {
constructor
} else {
@@ -313,14 +313,14 @@ object JavaTypeInference {
}
/**
- * Returns expressions for extracting all the fields from the given type.
+ * Returns an expression for serializing an object of the given type to an internal row.
*/
- def extractorsFor(beanClass: Class[_]): CreateNamedStruct = {
+ def serializerFor(beanClass: Class[_]): CreateNamedStruct = {
val inputObject = BoundReference(0, ObjectType(beanClass), nullable = true)
- extractorFor(inputObject, TypeToken.of(beanClass)).asInstanceOf[CreateNamedStruct]
+ serializerFor(inputObject, TypeToken.of(beanClass)).asInstanceOf[CreateNamedStruct]
}
- private def extractorFor(inputObject: Expression, typeToken: TypeToken[_]): Expression = {
+ private def serializerFor(inputObject: Expression, typeToken: TypeToken[_]): Expression = {
def toCatalystArray(input: Expression, elementType: TypeToken[_]): Expression = {
val (dataType, nullable) = inferDataType(elementType)
@@ -330,7 +330,7 @@ object JavaTypeInference {
input :: Nil,
dataType = ArrayType(dataType, nullable))
} else {
- MapObjects(extractorFor(_, elementType), input, ObjectType(elementType.getRawType))
+ MapObjects(serializerFor(_, elementType), input, ObjectType(elementType.getRawType))
}
}
@@ -403,7 +403,7 @@ object JavaTypeInference {
inputObject,
p.getReadMethod.getName,
inferExternalType(fieldType.getRawType))
- expressions.Literal(fieldName) :: extractorFor(fieldValue, fieldType) :: Nil
+ expressions.Literal(fieldName) :: serializerFor(fieldValue, fieldType) :: Nil
})
} else {
throw new UnsupportedOperationException(
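A minimal sketch of the renamed bean entry points. `JavaTypeInference` is internal (`private[sql]`), so this only compiles from inside Spark's sql packages; the `Person` bean is hypothetical:

```scala
import org.apache.spark.sql.catalyst.JavaTypeInference
import org.apache.spark.sql.catalyst.expressions.{CreateNamedStruct, Expression}

// Hypothetical JavaBean-style class for illustration.
class Person {
  private var name: String = _
  private var age: Int = _
  def getName: String = name
  def setName(n: String): Unit = { name = n }
  def getAge: Int = age
  def setAge(a: Int): Unit = { age = a }
}

val ser: CreateNamedStruct = JavaTypeInference.serializerFor(classOf[Person]) // bean -> row
val de: Expression = JavaTypeInference.deserializerFor(classOf[Person])       // row -> bean
```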
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala
index f208401160..d241b8a79b 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala
@@ -110,8 +110,8 @@ object ScalaReflection extends ScalaReflection {
}
/**
- * Returns an expression that can be used to construct an object of type `T` given an input
- * row with a compatible schema. Fields of the row will be extracted using UnresolvedAttributes
+ * Returns an expression that can be used to deserialize an input row to an object of type `T`
+ * with a compatible schema. Fields of the row will be extracted using UnresolvedAttributes
* of the same name as the constructor arguments. Nested classes will have their fields accessed
* using UnresolvedExtractValue.
*
@@ -119,14 +119,14 @@ object ScalaReflection extends ScalaReflection {
* from ordinal 0 (since there are no names to map to). The actual location can be moved by
* calling resolve/bind with a new schema.
*/
- def constructorFor[T : TypeTag]: Expression = {
+ def deserializerFor[T : TypeTag]: Expression = {
val tpe = localTypeOf[T]
val clsName = getClassNameFromType(tpe)
val walkedTypePath = s"""- root class: "${clsName}"""" :: Nil
- constructorFor(tpe, None, walkedTypePath)
+ deserializerFor(tpe, None, walkedTypePath)
}
- private def constructorFor(
+ private def deserializerFor(
tpe: `Type`,
path: Option[Expression],
walkedTypePath: Seq[String]): Expression = ScalaReflectionLock.synchronized {
@@ -161,7 +161,7 @@ object ScalaReflection extends ScalaReflection {
}
/**
- * When we build the `fromRowExpression` for an encoder, we set up a lot of "unresolved" stuff
+ * When we build the `deserializer` for an encoder, we set up a lot of "unresolved" stuff
* and lose the required data type, which may lead to a runtime error if the real type doesn't
* match the encoder's schema.
* For example, we build an encoder for `case class Data(a: Int, b: String)` and the real type
@@ -188,7 +188,7 @@ object ScalaReflection extends ScalaReflection {
val TypeRef(_, _, Seq(optType)) = t
val className = getClassNameFromType(optType)
val newTypePath = s"""- option value class: "$className"""" +: walkedTypePath
- WrapOption(constructorFor(optType, path, newTypePath), dataTypeFor(optType))
+ WrapOption(deserializerFor(optType, path, newTypePath), dataTypeFor(optType))
case t if t <:< localTypeOf[java.lang.Integer] =>
val boxedType = classOf[java.lang.Integer]
@@ -272,7 +272,7 @@ object ScalaReflection extends ScalaReflection {
val newTypePath = s"""- array element class: "$className"""" +: walkedTypePath
Invoke(
MapObjects(
- p => constructorFor(elementType, Some(p), newTypePath),
+ p => deserializerFor(elementType, Some(p), newTypePath),
getPath,
schemaFor(elementType).dataType),
"array",
@@ -286,7 +286,7 @@ object ScalaReflection extends ScalaReflection {
val newTypePath = s"""- array element class: "$className"""" +: walkedTypePath
val mapFunction: Expression => Expression = p => {
- val converter = constructorFor(elementType, Some(p), newTypePath)
+ val converter = deserializerFor(elementType, Some(p), newTypePath)
if (nullable) {
converter
} else {
@@ -312,7 +312,7 @@ object ScalaReflection extends ScalaReflection {
val keyData =
Invoke(
MapObjects(
- p => constructorFor(keyType, Some(p), walkedTypePath),
+ p => deserializerFor(keyType, Some(p), walkedTypePath),
Invoke(getPath, "keyArray", ArrayType(schemaFor(keyType).dataType)),
schemaFor(keyType).dataType),
"array",
@@ -321,7 +321,7 @@ object ScalaReflection extends ScalaReflection {
val valueData =
Invoke(
MapObjects(
- p => constructorFor(valueType, Some(p), walkedTypePath),
+ p => deserializerFor(valueType, Some(p), walkedTypePath),
Invoke(getPath, "valueArray", ArrayType(schemaFor(valueType).dataType)),
schemaFor(valueType).dataType),
"array",
@@ -344,12 +344,12 @@ object ScalaReflection extends ScalaReflection {
val newTypePath = s"""- field (class: "$clsName", name: "$fieldName")""" +: walkedTypePath
// For tuples, we grab the inner fields by ordinal instead of name.
if (cls.getName startsWith "scala.Tuple") {
- constructorFor(
+ deserializerFor(
fieldType,
Some(addToPathOrdinal(i, dataType, newTypePath)),
newTypePath)
} else {
- val constructor = constructorFor(
+ val constructor = deserializerFor(
fieldType,
Some(addToPath(fieldName, dataType, newTypePath)),
newTypePath)
@@ -387,7 +387,7 @@ object ScalaReflection extends ScalaReflection {
}
/**
- * Returns expressions for extracting all the fields from the given type.
+ * Returns an expression for serializing an object of type T to an internal row.
*
* If the given type is not supported, i.e. no encoder can be built for this type,
* an [[UnsupportedOperationException]] will be thrown with detailed error message to explain
@@ -398,18 +398,18 @@ object ScalaReflection extends ScalaReflection {
* * the element type of [[Array]] or [[Seq]]: `array element class: "abc.xyz.MyClass"`
* * the field of [[Product]]: `field (class: "abc.xyz.MyClass", name: "myField")`
*/
- def extractorsFor[T : TypeTag](inputObject: Expression): CreateNamedStruct = {
+ def serializerFor[T : TypeTag](inputObject: Expression): CreateNamedStruct = {
val tpe = localTypeOf[T]
val clsName = getClassNameFromType(tpe)
val walkedTypePath = s"""- root class: "${clsName}"""" :: Nil
- extractorFor(inputObject, tpe, walkedTypePath) match {
+ serializerFor(inputObject, tpe, walkedTypePath) match {
case expressions.If(_, _, s: CreateNamedStruct) if tpe <:< localTypeOf[Product] => s
case other => CreateNamedStruct(expressions.Literal("value") :: other :: Nil)
}
}
/** Helper for extracting internal fields from a case class. */
- private def extractorFor(
+ private def serializerFor(
inputObject: Expression,
tpe: `Type`,
walkedTypePath: Seq[String]): Expression = ScalaReflectionLock.synchronized {
@@ -425,7 +425,7 @@ object ScalaReflection extends ScalaReflection {
} else {
val clsName = getClassNameFromType(elementType)
val newPath = s"""- array element class: "$clsName"""" +: walkedTypePath
- MapObjects(extractorFor(_, elementType, newPath), input, externalDataType)
+ MapObjects(serializerFor(_, elementType, newPath), input, externalDataType)
}
}
@@ -491,7 +491,7 @@ object ScalaReflection extends ScalaReflection {
expressions.If(
IsNull(unwrapped),
expressions.Literal.create(null, silentSchemaFor(optType).dataType),
- extractorFor(unwrapped, optType, newPath))
+ serializerFor(unwrapped, optType, newPath))
}
case t if t <:< localTypeOf[Product] =>
@@ -500,7 +500,7 @@ object ScalaReflection extends ScalaReflection {
val fieldValue = Invoke(inputObject, fieldName, dataTypeFor(fieldType))
val clsName = getClassNameFromType(fieldType)
val newPath = s"""- field (class: "$clsName", name: "$fieldName")""" +: walkedTypePath
- expressions.Literal(fieldName) :: extractorFor(fieldValue, fieldType, newPath) :: Nil
+ expressions.Literal(fieldName) :: serializerFor(fieldValue, fieldType, newPath) :: Nil
})
val nullOutput = expressions.Literal.create(null, nonNullOutput.dataType)
expressions.If(IsNull(inputObject), nullOutput, nonNullOutput)
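The renamed reflection entry points mirror how `ExpressionEncoder.apply` (next hunk) calls them. A minimal sketch against the internal API at this commit:

```scala
import org.apache.spark.sql.catalyst.ScalaReflection
import org.apache.spark.sql.catalyst.expressions.BoundReference

case class Data(a: Int, b: String)

// The serializer needs a reference to the input object at ordinal 0; the
// deserializer starts from UnresolvedAttributes and is resolved/bound later.
val inputObject = BoundReference(0, ScalaReflection.dataTypeFor[Data], nullable = false)
val serializer = ScalaReflection.serializerFor[Data](inputObject) // CreateNamedStruct
val deserializer = ScalaReflection.deserializerFor[Data]          // Expression
```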
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoder.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoder.scala
index 918233ddcd..1c712fde26 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoder.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoder.scala
@@ -51,8 +51,8 @@ object ExpressionEncoder {
val flat = !classOf[Product].isAssignableFrom(cls)
val inputObject = BoundReference(0, ScalaReflection.dataTypeFor[T], nullable = false)
- val toRowExpression = ScalaReflection.extractorsFor[T](inputObject)
- val fromRowExpression = ScalaReflection.constructorFor[T]
+ val serializer = ScalaReflection.serializerFor[T](inputObject)
+ val deserializer = ScalaReflection.deserializerFor[T]
val schema = ScalaReflection.schemaFor[T] match {
case ScalaReflection.Schema(s: StructType, _) => s
@@ -62,8 +62,8 @@ object ExpressionEncoder {
new ExpressionEncoder[T](
schema,
flat,
- toRowExpression.flatten,
- fromRowExpression,
+ serializer.flatten,
+ deserializer,
ClassTag[T](cls))
}
@@ -72,14 +72,14 @@ object ExpressionEncoder {
val schema = JavaTypeInference.inferDataType(beanClass)._1
assert(schema.isInstanceOf[StructType])
- val toRowExpression = JavaTypeInference.extractorsFor(beanClass)
- val fromRowExpression = JavaTypeInference.constructorFor(beanClass)
+ val serializer = JavaTypeInference.serializerFor(beanClass)
+ val deserializer = JavaTypeInference.deserializerFor(beanClass)
new ExpressionEncoder[T](
schema.asInstanceOf[StructType],
flat = false,
- toRowExpression.flatten,
- fromRowExpression,
+ serializer.flatten,
+ deserializer,
ClassTag[T](beanClass))
}
@@ -103,9 +103,9 @@ object ExpressionEncoder {
val cls = Utils.getContextOrSparkClassLoader.loadClass(s"scala.Tuple${encoders.size}")
- val toRowExpressions = encoders.map {
- case e if e.flat => e.toRowExpressions.head
- case other => CreateStruct(other.toRowExpressions)
+ val serializer = encoders.map {
+ case e if e.flat => e.serializer.head
+ case other => CreateStruct(other.serializer)
}.zipWithIndex.map { case (expr, index) =>
expr.transformUp {
case BoundReference(0, t, _) =>
@@ -116,14 +116,14 @@ object ExpressionEncoder {
}
}
- val fromRowExpressions = encoders.zipWithIndex.map { case (enc, index) =>
+ val childrenDeserializers = encoders.zipWithIndex.map { case (enc, index) =>
if (enc.flat) {
- enc.fromRowExpression.transform {
+ enc.deserializer.transform {
case b: BoundReference => b.copy(ordinal = index)
}
} else {
val input = BoundReference(index, enc.schema, nullable = true)
- enc.fromRowExpression.transformUp {
+ enc.deserializer.transformUp {
case UnresolvedAttribute(nameParts) =>
assert(nameParts.length == 1)
UnresolvedExtractValue(input, Literal(nameParts.head))
@@ -132,14 +132,14 @@ object ExpressionEncoder {
}
}
- val fromRowExpression =
- NewInstance(cls, fromRowExpressions, ObjectType(cls), propagateNull = false)
+ val deserializer =
+ NewInstance(cls, childrenDeserializers, ObjectType(cls), propagateNull = false)
new ExpressionEncoder[Any](
schema,
flat = false,
- toRowExpressions,
- fromRowExpression,
+ serializer,
+ deserializer,
ClassTag(cls))
}
@@ -174,29 +174,29 @@ object ExpressionEncoder {
* A generic encoder for JVM objects.
*
* @param schema The schema after converting `T` to a Spark SQL row.
- * @param toRowExpressions A set of expressions, one for each top-level field that can be used to
- * extract the values from a raw object into an [[InternalRow]].
- * @param fromRowExpression An expression that will construct an object given an [[InternalRow]].
+ * @param serializer A set of expressions, one for each top-level field that can be used to
+ * extract the values from a raw object into an [[InternalRow]].
+ * @param deserializer An expression that will construct an object given an [[InternalRow]].
* @param clsTag A classtag for `T`.
*/
case class ExpressionEncoder[T](
schema: StructType,
flat: Boolean,
- toRowExpressions: Seq[Expression],
- fromRowExpression: Expression,
+ serializer: Seq[Expression],
+ deserializer: Expression,
clsTag: ClassTag[T])
extends Encoder[T] {
- if (flat) require(toRowExpressions.size == 1)
+ if (flat) require(serializer.size == 1)
@transient
- private lazy val extractProjection = GenerateUnsafeProjection.generate(toRowExpressions)
+ private lazy val extractProjection = GenerateUnsafeProjection.generate(serializer)
@transient
private lazy val inputRow = new GenericMutableRow(1)
@transient
- private lazy val constructProjection = GenerateSafeProjection.generate(fromRowExpression :: Nil)
+ private lazy val constructProjection = GenerateSafeProjection.generate(deserializer :: Nil)
/**
* Returns this encoder where it has been bound to its own output (i.e. no remapping of columns
@@ -212,7 +212,7 @@ case class ExpressionEncoder[T](
* Returns a new set (with unique ids) of [[NamedExpression]] that represent the serialized form
* of this object.
*/
- def namedExpressions: Seq[NamedExpression] = schema.map(_.name).zip(toRowExpressions).map {
+ def namedExpressions: Seq[NamedExpression] = schema.map(_.name).zip(serializer).map {
case (_, ne: NamedExpression) => ne.newInstance()
case (name, e) => Alias(e, name)()
}
@@ -228,7 +228,7 @@ case class ExpressionEncoder[T](
} catch {
case e: Exception =>
throw new RuntimeException(
- s"Error while encoding: $e\n${toRowExpressions.map(_.treeString).mkString("\n")}", e)
+ s"Error while encoding: $e\n${serializer.map(_.treeString).mkString("\n")}", e)
}
/**
@@ -240,7 +240,7 @@ case class ExpressionEncoder[T](
constructProjection(row).get(0, ObjectType(clsTag.runtimeClass)).asInstanceOf[T]
} catch {
case e: Exception =>
- throw new RuntimeException(s"Error while decoding: $e\n${fromRowExpression.treeString}", e)
+ throw new RuntimeException(s"Error while decoding: $e\n${deserializer.treeString}", e)
}
/**
@@ -249,7 +249,7 @@ case class ExpressionEncoder[T](
* has not been done already in places where we plan to do later composition of encoders.
*/
def assertUnresolved(): Unit = {
- (fromRowExpression +: toRowExpressions).foreach(_.foreach {
+ (deserializer +: serializer).foreach(_.foreach {
case a: AttributeReference if a.name != "loopVar" =>
sys.error(s"Unresolved encoder expected, but $a was found.")
case _ =>
@@ -257,7 +257,7 @@ case class ExpressionEncoder[T](
}
/**
- * Validates `fromRowExpression` to make sure it can be resolved by given schema, and produce
+ * Validates `deserializer` to make sure it can be resolved by the given schema, and produces
* friendly error messages to explain why it fails to resolve if there is something wrong.
*/
def validate(schema: Seq[Attribute]): Unit = {
@@ -271,7 +271,7 @@ case class ExpressionEncoder[T](
// If this is a tuple encoder or tupled encoder, which means its leaf nodes are all
// `BoundReference`, make sure their ordinals are all valid.
var maxOrdinal = -1
- fromRowExpression.foreach {
+ deserializer.foreach {
case b: BoundReference => if (b.ordinal > maxOrdinal) maxOrdinal = b.ordinal
case _ =>
}
@@ -285,7 +285,7 @@ case class ExpressionEncoder[T](
// we unbound it by the given `schema` and propagate the actual type to `GetStructField`, after
// we resolve the `fromRowExpression`.
val resolved = SimpleAnalyzer.resolveExpression(
- fromRowExpression,
+ deserializer,
LocalRelation(schema),
throws = true)
@@ -312,42 +312,39 @@ case class ExpressionEncoder[T](
}
/**
- * Returns a new copy of this encoder, where the expressions used by `fromRow` are resolved to the
- * given schema.
+ * Returns a new copy of this encoder, where the `deserializer` is resolved to the given schema.
*/
def resolve(
schema: Seq[Attribute],
outerScopes: ConcurrentMap[String, AnyRef]): ExpressionEncoder[T] = {
- val deserializer = SimpleAnalyzer.ResolveReferences.resolveDeserializer(
- fromRowExpression, schema)
+ val resolved = SimpleAnalyzer.ResolveReferences.resolveDeserializer(deserializer, schema)
// Make a fake plan to wrap the deserializer, so that we can go through the whole analyzer, check
// analysis, go through optimizer, etc.
- val plan = Project(Alias(deserializer, "")() :: Nil, LocalRelation(schema))
+ val plan = Project(Alias(resolved, "")() :: Nil, LocalRelation(schema))
val analyzedPlan = SimpleAnalyzer.execute(plan)
SimpleAnalyzer.checkAnalysis(analyzedPlan)
- copy(fromRowExpression = SimplifyCasts(analyzedPlan).expressions.head.children.head)
+ copy(deserializer = SimplifyCasts(analyzedPlan).expressions.head.children.head)
}
/**
- * Returns a copy of this encoder where the expressions used to construct an object from an input
- * row have been bound to the ordinals of the given schema. Note that you need to first call
- * resolve before bind.
+ * Returns a copy of this encoder where the `deserializer` has been bound to the
+ * ordinals of the given schema. Note that you need to first call resolve before bind.
*/
def bind(schema: Seq[Attribute]): ExpressionEncoder[T] = {
- copy(fromRowExpression = BindReferences.bindReference(fromRowExpression, schema))
+ copy(deserializer = BindReferences.bindReference(deserializer, schema))
}
/**
* Returns a new encoder with input columns shifted by `delta` ordinals
*/
def shift(delta: Int): ExpressionEncoder[T] = {
- copy(fromRowExpression = fromRowExpression transform {
+ copy(deserializer = deserializer transform {
case r: BoundReference => r.copy(ordinal = r.ordinal + delta)
})
}
- protected val attrs = toRowExpressions.flatMap(_.collect {
+ protected val attrs = serializer.flatMap(_.collect {
case _: UnresolvedAttribute => ""
case a: Attribute => s"#${a.exprId}"
case b: BoundReference => s"[${b.ordinal}]"
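End to end, the renamed fields drive the two projections above: `extractProjection` is generated from `serializer`, `constructProjection` from `deserializer`. A round-trip sketch against the internal API at this commit, assuming `OuterScopes.outerScopes` from the same package:

```scala
import org.apache.spark.sql.catalyst.encoders.{ExpressionEncoder, OuterScopes}

case class Data(a: Int, b: String)

val enc = ExpressionEncoder[Data]()  // serializerFor[T] + deserializerFor[T] via ScalaReflection
val row = enc.toRow(Data(1, "x"))    // extractProjection, built from `serializer`

// fromRow needs a resolved and bound deserializer (see resolve/bind above).
val attrs = enc.schema.toAttributes
val bound = enc.resolve(attrs, OuterScopes.outerScopes).bind(attrs)
assert(bound.fromRow(row) == Data(1, "x")) // constructProjection, built from `deserializer`
```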
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/RowEncoder.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/RowEncoder.scala
index 30f56d8c2f..a8397aa5e5 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/RowEncoder.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/RowEncoder.scala
@@ -36,23 +36,23 @@ object RowEncoder {
val cls = classOf[Row]
val inputObject = BoundReference(0, ObjectType(cls), nullable = true)
// We use an If expression to wrap extractorsFor result of StructType
- val extractExpressions = extractorsFor(inputObject, schema).asInstanceOf[If].falseValue
- val constructExpression = constructorFor(schema)
+ val serializer = serializerFor(inputObject, schema).asInstanceOf[If].falseValue
+ val deserializer = deserializerFor(schema)
new ExpressionEncoder[Row](
schema,
flat = false,
- extractExpressions.asInstanceOf[CreateStruct].children,
- constructExpression,
+ serializer.asInstanceOf[CreateStruct].children,
+ deserializer,
ClassTag(cls))
}
- private def extractorsFor(
+ private def serializerFor(
inputObject: Expression,
inputType: DataType): Expression = inputType match {
case NullType | BooleanType | ByteType | ShortType | IntegerType | LongType |
FloatType | DoubleType | BinaryType | CalendarIntervalType => inputObject
- case p: PythonUserDefinedType => extractorsFor(inputObject, p.sqlType)
+ case p: PythonUserDefinedType => serializerFor(inputObject, p.sqlType)
case udt: UserDefinedType[_] =>
val obj = NewInstance(
@@ -95,7 +95,7 @@ object RowEncoder {
classOf[GenericArrayData],
inputObject :: Nil,
dataType = t)
- case _ => MapObjects(extractorsFor(_, et), inputObject, externalDataTypeForInput(et))
+ case _ => MapObjects(serializerFor(_, et), inputObject, externalDataTypeForInput(et))
}
case t @ MapType(kt, vt, valueNullable) =>
@@ -104,14 +104,14 @@ object RowEncoder {
Invoke(inputObject, "keysIterator", ObjectType(classOf[scala.collection.Iterator[_]])),
"toSeq",
ObjectType(classOf[scala.collection.Seq[_]]))
- val convertedKeys = extractorsFor(keys, ArrayType(kt, false))
+ val convertedKeys = serializerFor(keys, ArrayType(kt, false))
val values =
Invoke(
Invoke(inputObject, "valuesIterator", ObjectType(classOf[scala.collection.Iterator[_]])),
"toSeq",
ObjectType(classOf[scala.collection.Seq[_]]))
- val convertedValues = extractorsFor(values, ArrayType(vt, valueNullable))
+ val convertedValues = serializerFor(values, ArrayType(vt, valueNullable))
NewInstance(
classOf[ArrayBasedMapData],
@@ -128,7 +128,7 @@ object RowEncoder {
If(
Invoke(inputObject, "isNullAt", BooleanType, Literal(i) :: Nil),
Literal.create(null, f.dataType),
- extractorsFor(
+ serializerFor(
Invoke(inputObject, method, externalDataTypeForInput(f.dataType), Literal(i) :: Nil),
f.dataType))
}
@@ -166,7 +166,7 @@ object RowEncoder {
case _: NullType => ObjectType(classOf[java.lang.Object])
}
- private def constructorFor(schema: StructType): Expression = {
+ private def deserializerFor(schema: StructType): Expression = {
val fields = schema.zipWithIndex.map { case (f, i) =>
val dt = f.dataType match {
case p: PythonUserDefinedType => p.sqlType
@@ -176,13 +176,13 @@ object RowEncoder {
If(
IsNull(field),
Literal.create(null, externalDataTypeFor(dt)),
- constructorFor(field)
+ deserializerFor(field)
)
}
CreateExternalRow(fields, schema)
}
- private def constructorFor(input: Expression): Expression = input.dataType match {
+ private def deserializerFor(input: Expression): Expression = input.dataType match {
case NullType | BooleanType | ByteType | ShortType | IntegerType | LongType |
FloatType | DoubleType | BinaryType | CalendarIntervalType => input
@@ -216,7 +216,7 @@ object RowEncoder {
case ArrayType(et, nullable) =>
val arrayData =
Invoke(
- MapObjects(constructorFor(_), input, et),
+ MapObjects(deserializerFor(_), input, et),
"array",
ObjectType(classOf[Array[_]]))
StaticInvoke(
@@ -227,10 +227,10 @@ object RowEncoder {
case MapType(kt, vt, valueNullable) =>
val keyArrayType = ArrayType(kt, false)
- val keyData = constructorFor(Invoke(input, "keyArray", keyArrayType))
+ val keyData = deserializerFor(Invoke(input, "keyArray", keyArrayType))
val valueArrayType = ArrayType(vt, valueNullable)
- val valueData = constructorFor(Invoke(input, "valueArray", valueArrayType))
+ val valueData = deserializerFor(Invoke(input, "valueArray", valueArrayType))
StaticInvoke(
ArrayBasedMapData.getClass,
@@ -243,7 +243,7 @@ object RowEncoder {
If(
Invoke(input, "isNullAt", BooleanType, Literal(i) :: Nil),
Literal.create(null, externalDataTypeFor(f.dataType)),
- constructorFor(GetStructField(input, i)))
+ deserializerFor(GetStructField(input, i)))
}
If(IsNull(input),
Literal.create(null, externalDataTypeFor(input.dataType)),
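`RowEncoder` builds the same two halves from an explicit schema instead of reflection; the result is still an ordinary `ExpressionEncoder[Row]`. A short sketch:

```scala
import org.apache.spark.sql.Row
import org.apache.spark.sql.catalyst.encoders.RowEncoder
import org.apache.spark.sql.types.{LongType, StringType, StructType}

val schema = new StructType().add("id", LongType).add("name", StringType)

// serializerFor/deserializerFor walk the schema rather than a TypeTag.
val enc = RowEncoder(schema)
val internalRow = enc.toRow(Row(1L, "a"))
```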
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/object.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/object.scala
index da7f81c785..058fb6bff1 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/object.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/object.scala
@@ -71,7 +71,7 @@ object MapPartitions {
child: LogicalPlan): MapPartitions = {
MapPartitions(
func.asInstanceOf[Iterator[Any] => Iterator[Any]],
- encoderFor[T].fromRowExpression,
+ encoderFor[T].deserializer,
encoderFor[U].namedExpressions,
child)
}
@@ -98,7 +98,7 @@ object AppendColumns {
child: LogicalPlan): AppendColumns = {
new AppendColumns(
func.asInstanceOf[Any => Any],
- encoderFor[T].fromRowExpression,
+ encoderFor[T].deserializer,
encoderFor[U].namedExpressions,
child)
}
@@ -133,8 +133,8 @@ object MapGroups {
child: LogicalPlan): MapGroups = {
new MapGroups(
func.asInstanceOf[(Any, Iterator[Any]) => TraversableOnce[Any]],
- encoderFor[K].fromRowExpression,
- encoderFor[T].fromRowExpression,
+ encoderFor[K].deserializer,
+ encoderFor[T].deserializer,
encoderFor[U].namedExpressions,
groupingAttributes,
dataAttributes,
@@ -178,9 +178,9 @@ object CoGroup {
CoGroup(
func.asInstanceOf[(Any, Iterator[Any], Iterator[Any]) => TraversableOnce[Any]],
- encoderFor[Key].fromRowExpression,
- encoderFor[Left].fromRowExpression,
- encoderFor[Right].fromRowExpression,
+ encoderFor[Key].deserializer,
+ encoderFor[Left].deserializer,
+ encoderFor[Right].deserializer,
encoderFor[Result].namedExpressions,
leftGroup,
rightGroup,
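These operator factories are what user-facing Dataset methods call into: each captures the input encoder's `deserializer` to rebuild objects from rows before invoking the user function, and the output encoder's `namedExpressions` to serialize the results. A user-level sketch that goes through `MapPartitions.apply` above, assuming a 2.0-era `SQLContext` named `sqlContext`:

```scala
import sqlContext.implicits._

// mapPartitions deserializes rows with encoderFor[Int].deserializer, runs the
// function, then serializes results via encoderFor[Int].namedExpressions.
val doubled = Seq(1, 2, 3).toDS().mapPartitions(_.map(_ * 2))
```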
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/ScalaReflectionSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/ScalaReflectionSuite.scala
index dd31050bb5..5ca5a72512 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/ScalaReflectionSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/ScalaReflectionSuite.scala
@@ -248,10 +248,10 @@ class ScalaReflectionSuite extends SparkFunSuite {
Seq(
("mirror", () => mirror),
("dataTypeFor", () => dataTypeFor[ComplexData]),
- ("constructorFor", () => constructorFor[ComplexData]),
+ ("constructorFor", () => deserializerFor[ComplexData]),
("extractorsFor", {
val inputObject = BoundReference(0, dataTypeForComplexData, nullable = false)
- () => extractorsFor[ComplexData](inputObject)
+ () => serializerFor[ComplexData](inputObject)
}),
("getConstructorParameters(cls)", () => getConstructorParameters(classOf[ComplexData])),
("getConstructorParameterNames", () => getConstructorParameterNames(classOf[ComplexData])),
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoderSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoderSuite.scala
index f6583bfe42..18752014ea 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoderSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoderSuite.scala
@@ -315,7 +315,7 @@ class ExpressionEncoderSuite extends PlanTest with AnalysisTest {
val attr = AttributeReference("obj", ObjectType(encoder.clsTag.runtimeClass))()
val inputPlan = LocalRelation(attr)
val plan =
- Project(Alias(encoder.fromRowExpression, "obj")() :: Nil,
+ Project(Alias(encoder.deserializer, "obj")() :: Nil,
Project(encoder.namedExpressions,
inputPlan))
assertAnalysisSuccess(plan)
@@ -360,7 +360,7 @@ class ExpressionEncoderSuite extends PlanTest with AnalysisTest {
|${encoder.schema.treeString}
|
|fromRow Expressions:
- |${boundEncoder.fromRowExpression.treeString}
+ |${boundEncoder.deserializer.treeString}
""".stripMargin)
}
}
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/QueryTest.scala b/sql/core/src/test/scala/org/apache/spark/sql/QueryTest.scala
index 7ff4ffcaec..854a662cc4 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/QueryTest.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/QueryTest.scala
@@ -90,7 +90,7 @@ abstract class QueryTest extends PlanTest {
s"""
|Exception collecting dataset as objects
|${ds.resolvedTEncoder}
- |${ds.resolvedTEncoder.fromRowExpression.treeString}
+ |${ds.resolvedTEncoder.deserializer.treeString}
|${ds.queryExecution}
""".stripMargin, e)
}
@@ -109,7 +109,7 @@ abstract class QueryTest extends PlanTest {
fail(
s"""Decoded objects do not match expected objects:
|$comparision
- |${ds.resolvedTEncoder.fromRowExpression.treeString}
+ |${ds.resolvedTEncoder.deserializer.treeString}
""".stripMargin)
}
}