Diffstat (limited to 'sql/catalyst')
-rw-r--r-- sql/catalyst/src/main/scala/org/apache/spark/sql/Encoder.scala                                    | 10
-rw-r--r-- sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoder.scala        | 124
-rw-r--r-- sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/package.scala                  | 11
-rw-r--r-- sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeExtractors.scala | 7
-rw-r--r-- sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala      | 15
5 files changed, 103 insertions(+), 64 deletions(-)
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/Encoder.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/Encoder.scala
index 6134f9e036..5f619d6c33 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/Encoder.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/Encoder.scala
@@ -84,7 +84,7 @@ object Encoders {
private def tuple(encoders: Seq[ExpressionEncoder[_]]): ExpressionEncoder[_] = {
assert(encoders.length > 1)
// make sure all encoders are resolved, i.e. `Attribute` has been resolved to `BoundReference`.
- assert(encoders.forall(_.constructExpression.find(_.isInstanceOf[Attribute]).isEmpty))
+ assert(encoders.forall(_.fromRowExpression.find(_.isInstanceOf[Attribute]).isEmpty))
val schema = StructType(encoders.zipWithIndex.map {
case (e, i) => StructField(s"_${i + 1}", if (e.flat) e.schema.head.dataType else e.schema)
@@ -93,8 +93,8 @@ object Encoders {
val cls = Utils.getContextOrSparkClassLoader.loadClass(s"scala.Tuple${encoders.size}")
val extractExpressions = encoders.map {
- case e if e.flat => e.extractExpressions.head
- case other => CreateStruct(other.extractExpressions)
+ case e if e.flat => e.toRowExpressions.head
+ case other => CreateStruct(other.toRowExpressions)
}.zipWithIndex.map { case (expr, index) =>
expr.transformUp {
case BoundReference(0, t: ObjectType, _) =>
@@ -107,11 +107,11 @@ object Encoders {
val constructExpressions = encoders.zipWithIndex.map { case (enc, index) =>
if (enc.flat) {
- enc.constructExpression.transform {
+ enc.fromRowExpression.transform {
case b: BoundReference => b.copy(ordinal = index)
}
} else {
- enc.constructExpression.transformUp {
+ enc.fromRowExpression.transformUp {
case BoundReference(ordinal, dt, _) =>
GetInternalRowField(BoundReference(index, enc.schema, nullable = true), ordinal, dt)
}
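The Encoder.scala hunks are mostly the mechanical rename (extractExpressions → toRowExpressions, constructExpression → fromRowExpression), but the schema construction also encodes a layout rule: a flat encoder, which produces a single column, contributes that column's type directly as tuple field `_N`, while a multi-field encoder is nested as a struct. A minimal, self-contained sketch of that rule in plain Scala (Schema, Column and SimpleEncoder are illustrative stand-ins, not catalyst types):

```scala
// Model of the tuple-schema rule: a flat encoder contributes its single
// column type directly, a multi-field encoder is nested as a struct.
// Tuple fields follow Scala's accessor names: _1, _2, ...
sealed trait Schema
case class Column(dataType: String) extends Schema
case class Struct(fields: Seq[(String, Schema)]) extends Schema

case class SimpleEncoder(flat: Boolean, schema: Struct)

def tupleSchema(encoders: Seq[SimpleEncoder]): Struct =
  Struct(encoders.zipWithIndex.map { case (e, i) =>
    // Mirrors `if (e.flat) e.schema.head.dataType else e.schema` in the diff.
    val field: Schema = if (e.flat) e.schema.fields.head._2 else e.schema
    s"_${i + 1}" -> field
  })

val intEnc    = SimpleEncoder(flat = true,  Struct(Seq("value" -> Column("int"))))
val personEnc = SimpleEncoder(flat = false, Struct(Seq("name" -> Column("string"), "age" -> Column("int"))))

// Roughly: Struct(List((_1,Column(int)), (_2,Struct(...))))
println(tupleSchema(Seq(intEnc, personEnc)))
```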
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoder.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoder.scala
index 294afde534..0d3e4aafb0 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoder.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoder.scala
@@ -28,7 +28,7 @@ import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.expressions.codegen.{GenerateSafeProjection, GenerateUnsafeProjection}
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.ScalaReflection
-import org.apache.spark.sql.types.{StructField, ObjectType, StructType}
+import org.apache.spark.sql.types.{NullType, StructField, ObjectType, StructType}
/**
* A factory for constructing encoders that convert objects and primitives to and from the
@@ -61,20 +61,39 @@ object ExpressionEncoder {
/**
* Given a set of N encoders, constructs a new encoder that produces objects as items in an
- * N-tuple. Note that these encoders should first be bound correctly to the combined input
- * schema.
+ * N-tuple. Note that these encoders should be unresolved so that information about
+ * name/positional binding is preserved.
*/
def tuple(encoders: Seq[ExpressionEncoder[_]]): ExpressionEncoder[_] = {
+ encoders.foreach(_.assertUnresolved())
+
val schema =
StructType(
- encoders.zipWithIndex.map { case (e, i) => StructField(s"_${i + 1}", e.schema)})
+ encoders.zipWithIndex.map {
+ case (e, i) => StructField(s"_${i + 1}", if (e.flat) e.schema.head.dataType else e.schema)
+ })
val cls = Utils.getContextOrSparkClassLoader.loadClass(s"scala.Tuple${encoders.size}")
- val extractExpressions = encoders.map {
- case e if e.flat => e.extractExpressions.head
- case other => CreateStruct(other.extractExpressions)
+
+ // Rebind the encoders to the nested schema.
+ val newConstructExpressions = encoders.zipWithIndex.map {
+ case (e, i) if !e.flat => e.nested(i).fromRowExpression
+ case (e, i) => e.shift(i).fromRowExpression
}
+
val constructExpression =
- NewInstance(cls, encoders.map(_.constructExpression), false, ObjectType(cls))
+ NewInstance(cls, newConstructExpressions, false, ObjectType(cls))
+
+ val input = BoundReference(0, ObjectType(cls), false)
+ val extractExpressions = encoders.zipWithIndex.map {
+ case (e, i) if !e.flat => CreateStruct(e.toRowExpressions.map(_ transformUp {
+ case b: BoundReference =>
+ Invoke(input, s"_${i + 1}", b.dataType, Nil)
+ }))
+ case (e, i) => e.toRowExpressions.head transformUp {
+ case b: BoundReference =>
+ Invoke(input, s"_${i + 1}", b.dataType, Nil)
+ }
+ }
new ExpressionEncoder[Any](
schema,
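The reworked tuple() no longer expects pre-bound encoders. It rebinds each still-unresolved encoder itself: a flat encoder is shift(i)-ed so its single input column points at tuple slot i, a non-flat encoder is nested(i) so it reads all of its fields from the struct at slot i, and on the serialization side each element is pulled from the tuple object by Invoke-ing the `_N` accessor. A toy functional model of the deserialization half, under the assumption that a row can be modeled as Seq[Any] (Decoder here is invented, not catalyst's API):

```scala
// Rows are Seq[Any]; a nested struct is itself a Seq[Any].
type Row = Seq[Any]
case class Decoder[T](flat: Boolean, fromRow: Row => T)

def tuple2[A, B](a: Decoder[A], b: Decoder[B]): Decoder[(A, B)] = {
  def at[T](d: Decoder[T], i: Int): Row => T = { row =>
    if (d.flat) d.fromRow(Seq(row(i)))        // shift(i): read column i directly
    else d.fromRow(row(i).asInstanceOf[Row])  // nested(i): read the struct at column i
  }
  Decoder(flat = false, row => (at(a, 0)(row), at(b, 1)(row)))
}

val intDec  = Decoder[Int](flat = true, row => row.head.asInstanceOf[Int])
val pairDec = Decoder[(String, Int)](flat = false,
  row => (row(0).asInstanceOf[String], row(1).asInstanceOf[Int]))

println(tuple2(intDec, pairDec).fromRow(Seq(7, Seq("a", 1))))  // (7,(a,1))
```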
@@ -95,35 +114,40 @@ object ExpressionEncoder {
* A generic encoder for JVM objects.
*
* @param schema The schema after converting `T` to a Spark SQL row.
- * @param extractExpressions A set of expressions, one for each top-level field that can be used to
- * extract the values from a raw object.
+ * @param toRowExpressions A set of expressions, one for each top-level field that can be used to
+ * extract the values from a raw object into an [[InternalRow]].
+ * @param fromRowExpression An expression that will construct an object given an [[InternalRow]].
* @param clsTag A classtag for `T`.
*/
case class ExpressionEncoder[T](
schema: StructType,
flat: Boolean,
- extractExpressions: Seq[Expression],
- constructExpression: Expression,
+ toRowExpressions: Seq[Expression],
+ fromRowExpression: Expression,
clsTag: ClassTag[T])
extends Encoder[T] {
- if (flat) require(extractExpressions.size == 1)
+ if (flat) require(toRowExpressions.size == 1)
@transient
- private lazy val extractProjection = GenerateUnsafeProjection.generate(extractExpressions)
+ private lazy val extractProjection = GenerateUnsafeProjection.generate(toRowExpressions)
private val inputRow = new GenericMutableRow(1)
@transient
- private lazy val constructProjection = GenerateSafeProjection.generate(constructExpression :: Nil)
+ private lazy val constructProjection = GenerateSafeProjection.generate(fromRowExpression :: Nil)
/**
* Returns an encoded version of `t` as a Spark SQL row. Note that multiple calls to
* toRow are allowed to return the same actual [[InternalRow]] object. Thus, the caller should
* copy the result before making another call if required.
*/
- def toRow(t: T): InternalRow = {
+ def toRow(t: T): InternalRow = try {
inputRow(0) = t
extractProjection(inputRow)
+ } catch {
+ case e: Exception =>
+ throw new RuntimeException(
+ s"Error while encoding: $e\n${toRowExpressions.map(_.treeString).mkString("\n")}", e)
}
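Two contracts are visible in this hunk: the projection reuses one mutable row across toRow calls, so callers must copy a result they intend to keep, and encoding failures are rethrown with the serializer expression trees attached for debuggability. A hedged sketch of both behaviors with invented names (MutableRow, SimpleToRow), modeling the contract rather than Spark's implementation:

```scala
final class MutableRow(val values: Array[Any]) {
  def copyRow(): MutableRow = new MutableRow(values.clone())
}

class SimpleToRow[T](fields: Seq[T => Any]) {
  private val buffer = new MutableRow(new Array[Any](fields.length))
  def toRow(t: T): MutableRow =
    try {
      for (i <- fields.indices) buffer.values(i) = fields(i)(t)
      buffer  // the SAME object is returned on every call
    } catch {
      // Mirrors the diff: attach context about what failed to encode.
      case e: Exception => throw new RuntimeException(s"Error while encoding: $e", e)
    }
}

val enc   = new SimpleToRow[(String, Int)](Seq(_._1, _._2))
val first = enc.toRow(("a", 1)).copyRow()  // copy before the next toRow call
enc.toRow(("b", 2))
println(first.values.mkString(","))        // a,1 -- unaffected by the second call
```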
/**
@@ -135,7 +159,20 @@ case class ExpressionEncoder[T](
constructProjection(row).get(0, ObjectType(clsTag.runtimeClass)).asInstanceOf[T]
} catch {
case e: Exception =>
- throw new RuntimeException(s"Error while decoding: $e\n${constructExpression.treeString}", e)
+ throw new RuntimeException(s"Error while decoding: $e\n${fromRowExpression.treeString}", e)
+ }
+
+ /**
+ * The process of resolution to a given schema throws away information about where a given field
+ * is being bound by ordinal instead of by name. This method checks to make sure this process
+ * has not been done already in places where we plan to do later composition of encoders.
+ */
+ def assertUnresolved(): Unit = {
+ (fromRowExpression +: toRowExpressions).foreach(_.foreach {
+ case a: AttributeReference =>
+ sys.error(s"Unresolved encoder expected, but $a was found.")
+ case _ =>
+ })
}
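assertUnresolved is a whole-tree check: if any node is already an AttributeReference, resolution has happened too early, and later composition (tuple, nested, shift) would silently lose the name-vs-ordinal distinction. A toy version over an invented expression tree:

```scala
// All class names here are illustrative stand-ins for catalyst's.
sealed trait Expr {
  def children: Seq[Expr]
  def foreachNode(f: Expr => Unit): Unit = {
    f(this)
    children.foreach(_.foreachNode(f))
  }
}
case class UnresolvedField(name: String) extends Expr { def children: Seq[Expr] = Nil }
case class ResolvedAttribute(name: String, exprId: Long) extends Expr { def children: Seq[Expr] = Nil }
case class MakeInstance(args: Seq[Expr]) extends Expr { def children: Seq[Expr] = args }

def assertUnresolved(e: Expr): Unit = e.foreachNode {
  case a: ResolvedAttribute => sys.error(s"Unresolved encoder expected, but $a was found.")
  case _ => ()
}

assertUnresolved(MakeInstance(Seq(UnresolvedField("name"))))           // passes
// assertUnresolved(MakeInstance(Seq(ResolvedAttribute("name", 1L)))) // would throw
```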
/**
@@ -143,9 +180,14 @@ case class ExpressionEncoder[T](
* given schema.
*/
def resolve(schema: Seq[Attribute]): ExpressionEncoder[T] = {
- val plan = Project(Alias(constructExpression, "")() :: Nil, LocalRelation(schema))
+ val positionToAttribute = AttributeMap.toIndex(schema)
+ val unbound = fromRowExpression transform {
+ case b: BoundReference => positionToAttribute(b.ordinal)
+ }
+
+ val plan = Project(Alias(unbound, "")() :: Nil, LocalRelation(schema))
val analyzedPlan = SimpleAnalyzer.execute(plan)
- copy(constructExpression = analyzedPlan.expressions.head.children.head)
+ copy(fromRowExpression = analyzedPlan.expressions.head.children.head)
}
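resolve() now runs in two steps: leftover BoundReferences from positional binding are first mapped back to the attribute at that position, and only then does the analyzer resolve the remaining by-name lookups against the input schema. A simplified model of the two binding modes (Attr, ByOrdinal, ByName and Bound are illustrative types, not catalyst's):

```scala
case class Attr(name: String, ordinal: Int)

sealed trait RefExpr
case class ByOrdinal(ordinal: Int) extends RefExpr  // positional binding (tuples)
case class ByName(name: String) extends RefExpr     // name binding (case classes)
case class Bound(attr: Attr) extends RefExpr

def resolve(e: RefExpr, schema: Seq[Attr]): RefExpr = e match {
  // Step 1 in the diff: positionToAttribute(b.ordinal) turns positions
  // back into the attributes currently at those positions.
  case ByOrdinal(i) => Bound(schema(i))
  // Step 2: the analyzer binds the remaining names against the schema.
  case ByName(n)    => Bound(schema.find(_.name == n).getOrElse(sys.error(s"No field $n")))
  case b: Bound     => b
}

val schema = Seq(Attr("name", 0), Attr("age", 1))
println(resolve(ByOrdinal(1), schema))   // Bound(Attr(age,1))
println(resolve(ByName("name"), schema)) // Bound(Attr(name,0))
```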
/**
@@ -154,39 +196,14 @@ case class ExpressionEncoder[T](
* resolve before bind.
*/
def bind(schema: Seq[Attribute]): ExpressionEncoder[T] = {
- copy(constructExpression = BindReferences.bindReference(constructExpression, schema))
- }
-
- /**
- * Replaces any bound references in the schema with the attributes at the corresponding ordinal
- * in the provided schema. This can be used to "relocate" a given encoder to pull values from
- * a different schema than it was initially bound to. It can also be used to assign attributes
- * to ordinal based extraction (i.e. because the input data was a tuple).
- */
- def unbind(schema: Seq[Attribute]): ExpressionEncoder[T] = {
- val positionToAttribute = AttributeMap.toIndex(schema)
- copy(constructExpression = constructExpression transform {
- case b: BoundReference => positionToAttribute(b.ordinal)
- })
+ copy(fromRowExpression = BindReferences.bindReference(fromRowExpression, schema))
}
/**
- * Given an encoder that has already been bound to a given schema, returns a new encoder
- * where the positions are mapped from `oldSchema` to `newSchema`. This can be used, for example,
- * when you are trying to use an encoder on grouping keys that were originally part of a larger
- * row, but now you have projected out only the key expressions.
+ * Returns a new encoder with input columns shifted by `delta` ordinals.
*/
- def rebind(oldSchema: Seq[Attribute], newSchema: Seq[Attribute]): ExpressionEncoder[T] = {
- val positionToAttribute = AttributeMap.toIndex(oldSchema)
- val attributeToNewPosition = AttributeMap.byIndex(newSchema)
- copy(constructExpression = constructExpression transform {
- case r: BoundReference =>
- r.copy(ordinal = attributeToNewPosition(positionToAttribute(r.ordinal)))
- })
- }
-
def shift(delta: Int): ExpressionEncoder[T] = {
- copy(constructExpression = constructExpression transform {
+ copy(fromRowExpression = fromRowExpression transform {
case r: BoundReference => r.copy(ordinal = r.ordinal + delta)
})
}
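shift(delta) simply moves every positional reference over by delta columns; tuple() uses it to relocate a flat encoder to slot i. A small stand-in model (PositionalDecoder is invented, not catalyst's API):

```scala
case class PositionalDecoder[T](ordinals: Seq[Int], build: Seq[Any] => T) {
  // Every positional reference moves over by `delta` columns.
  def shift(delta: Int): PositionalDecoder[T] = copy(ordinals = ordinals.map(_ + delta))
  def fromRow(row: Seq[Any]): T = build(ordinals.map(row))
}

val pair = PositionalDecoder[(String, Int)](Seq(0, 1),
  vs => (vs(0).asInstanceOf[String], vs(1).asInstanceOf[Int]))

// The same decoder reads columns 2 and 3 once two key columns are prepended.
println(pair.shift(2).fromRow(Seq("k1", "k2", "bob", 42)))  // (bob,42)
```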
@@ -196,11 +213,14 @@ case class ExpressionEncoder[T](
* input row have been modified to pull the object out from a nested struct, instead of the
* top level fields.
*/
- def nested(input: Expression = BoundReference(0, schema, true)): ExpressionEncoder[T] = {
- copy(constructExpression = constructExpression transform {
- case u: Attribute if u != input =>
+ private def nested(i: Int): ExpressionEncoder[T] = {
+ // We don't always know our input type at this point since it might be unresolved.
+ // We fill in null and it will get unbound to the actual attribute at this position.
+ val input = BoundReference(i, NullType, nullable = true)
+ copy(fromRowExpression = fromRowExpression transformUp {
+ case u: Attribute =>
UnresolvedExtractValue(input, Literal(u.name))
- case b: BoundReference if b != input =>
+ case b: BoundReference =>
GetStructField(
input,
StructField(s"i[${b.ordinal}]", b.dataType),
@@ -208,7 +228,7 @@ case class ExpressionEncoder[T](
})
}
- protected val attrs = extractExpressions.flatMap(_.collect {
+ protected val attrs = toRowExpressions.flatMap(_.collect {
case _: UnresolvedAttribute => ""
case a: Attribute => s"#${a.exprId}"
case b: BoundReference => s"[${b.ordinal}]"
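nested(i) rewrites top-level references to pull values out of the struct at column i of the outer row: name-based references become UnresolvedExtractValue and ordinal-based ones become GetStructField. A toy model of the ordinal case (Extract, TopLevel and FromStruct are invented):

```scala
sealed trait Extract
case class TopLevel(ordinal: Int) extends Extract
case class FromStruct(outerOrdinal: Int, innerOrdinal: Int) extends Extract

def nested(refs: Seq[Extract], i: Int): Seq[Extract] = refs.map {
  case TopLevel(ord) => FromStruct(i, ord)  // like GetStructField(input, ...)
  case other         => other
}

def read(row: Seq[Any], e: Extract): Any = e match {
  case TopLevel(o)      => row(o)
  case FromStruct(o, j) => row(o).asInstanceOf[Seq[Any]](j)
}

val outer = Seq("ignored", Seq("bob", 42))
println(nested(Seq(TopLevel(0), TopLevel(1)), 1).map(read(outer, _)))  // List(bob, 42)
```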
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/package.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/package.scala
index 2c35adca9c..9e283f5eb6 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/package.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/package.scala
@@ -18,10 +18,19 @@
package org.apache.spark.sql.catalyst
import org.apache.spark.sql.Encoder
+import org.apache.spark.sql.catalyst.expressions.AttributeReference
package object encoders {
+ /**
+ * Returns an internal encoder object that can be used to serialize / deserialize JVM objects
+ * into Spark SQL rows. The implicit encoder should always be unresolved (i.e. have no attribute
+ * references from a specific schema). This requirement allows us to preserve whether a given
+ * object type is being bound by name or by ordinal when doing resolution.
+ */
private[sql] def encoderFor[A : Encoder]: ExpressionEncoder[A] = implicitly[Encoder[A]] match {
- case e: ExpressionEncoder[A] => e
+ case e: ExpressionEncoder[A] =>
+ e.assertUnresolved()
+ e
case _ => sys.error(s"Only expression encoders are supported today")
}
}
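encoderFor now enforces the invariant stated in its scaladoc at the single point where implicits are summoned: fetch, validate, return. A sketch of the pattern (MiniEncoder is an invented stand-in):

```scala
trait MiniEncoder[A] { def resolved: Boolean }

def encoderFor[A](implicit e: MiniEncoder[A]): MiniEncoder[A] = {
  // Fail fast if a schema-specific (resolved) encoder leaked in as an
  // implicit; resolution must happen per-query so that name vs. ordinal
  // binding information survives until encoder composition is done.
  require(!e.resolved, "Unresolved encoder expected, but a resolved one was found.")
  e
}

implicit val intEncoder: MiniEncoder[Int] = new MiniEncoder[Int] { val resolved = false }
println(encoderFor[Int].resolved)  // false
```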
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeExtractors.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeExtractors.scala
index 41cd0a104a..f871b737ff 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeExtractors.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeExtractors.scala
@@ -97,11 +97,16 @@ object ExtractValue {
* Returns the value of fields in the Struct `child`.
*
* No need to do type checking since it is handled by [[ExtractValue]].
+ * TODO: Unify with [[GetInternalRowField]], remove the need to specify a [[StructField]].
*/
case class GetStructField(child: Expression, field: StructField, ordinal: Int)
extends UnaryExpression {
- override def dataType: DataType = field.dataType
+ override def dataType: DataType = child.dataType match {
+ case s: StructType => s(ordinal).dataType
+ // This is a hack to avoid breaking existing code until we remove the need for the struct field
+ case _ => field.dataType
+ }
override def nullable: Boolean = child.nullable || field.nullable
override def toString: String = s"$child.${field.name}"
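The dataType change makes GetStructField self-correcting after resolution: once the child's type is known to be a struct, the type at `ordinal` in that struct is authoritative, and the StructField captured at construction is only a pre-resolution fallback. A sketch of that precedence (FieldSpec, StructT and GetField are invented stand-ins):

```scala
case class FieldSpec(name: String, dataType: String)
case class StructT(fields: Seq[FieldSpec])

case class GetField(childType: Option[StructT], field: FieldSpec, ordinal: Int) {
  def dataType: String = childType match {
    case Some(s) => s.fields(ordinal).dataType  // resolved child wins
    case None    => field.dataType              // the "hack" fallback path
  }
}

val resolvedChild = StructT(Seq(FieldSpec("age", "long")))
// The captured field claims "int", but the resolved child says "long".
println(GetField(Some(resolvedChild), FieldSpec("age", "int"), 0).dataType)  // long
println(GetField(None, FieldSpec("age", "int"), 0).dataType)                 // int
```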
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala
index 32b09b59af..d9f046efce 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala
@@ -483,9 +483,12 @@ case class MapPartitions[T, U](
/** Factory for constructing new `AppendColumn` nodes. */
object AppendColumn {
- def apply[T : Encoder, U : Encoder](func: T => U, child: LogicalPlan): AppendColumn[T, U] = {
+ def apply[T, U : Encoder](
+ func: T => U,
+ tEncoder: ExpressionEncoder[T],
+ child: LogicalPlan): AppendColumn[T, U] = {
val attrs = encoderFor[U].schema.toAttributes
- new AppendColumn[T, U](func, encoderFor[T], encoderFor[U], attrs, child)
+ new AppendColumn[T, U](func, tEncoder, encoderFor[U], attrs, child)
}
}
@@ -506,14 +509,16 @@ case class AppendColumn[T, U](
/** Factory for constructing new `MapGroups` nodes. */
object MapGroups {
- def apply[K : Encoder, T : Encoder, U : Encoder](
+ def apply[K, T, U : Encoder](
func: (K, Iterator[T]) => TraversableOnce[U],
+ kEncoder: ExpressionEncoder[K],
+ tEncoder: ExpressionEncoder[T],
groupingAttributes: Seq[Attribute],
child: LogicalPlan): MapGroups[K, T, U] = {
new MapGroups(
func,
- encoderFor[K],
- encoderFor[T],
+ kEncoder,
+ tEncoder,
encoderFor[U],
groupingAttributes,
encoderFor[U].schema.toAttributes,
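The factory changes in basicOperators.scala follow the same theme: the input-side encoders (for K and T) may already have been resolved against the child plan, so they are passed in explicitly rather than re-summoned from implicit scope, while the output encoder for U is still looked up implicitly and stays unresolved. A sketch of that shape (MiniEnc and MiniMapGroups are invented):

```scala
trait MiniEnc[A]

case class MiniMapGroups[K, T, U](kEnc: MiniEnc[K], tEnc: MiniEnc[T], uEnc: MiniEnc[U])

object MiniMapGroups {
  // Input-side encoders arrive explicitly; the output encoder is implicit.
  def apply[K, T, U: MiniEnc](kEnc: MiniEnc[K], tEnc: MiniEnc[T]): MiniMapGroups[K, T, U] =
    new MiniMapGroups(kEnc, tEnc, implicitly[MiniEnc[U]])
}

implicit val stringEnc: MiniEnc[String] = new MiniEnc[String] {}
val keyEnc: MiniEnc[Int]  = new MiniEnc[Int] {}   // stands in for a resolved key encoder
val valEnc: MiniEnc[Long] = new MiniEnc[Long] {}  // stands in for a resolved input encoder

val node = MiniMapGroups[Int, Long, String](keyEnc, valEnc)
```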