aboutsummaryrefslogtreecommitdiff
path: root/sql/catalyst
diff options
context:
space:
mode:
authorMichael Armbrust <michael@databricks.com>2015-11-12 17:20:30 -0800
committerMichael Armbrust <michael@databricks.com>2015-11-12 17:20:30 -0800
commit41bbd2300472501d69ed46f0407d5ed7cbede4a8 (patch)
treed158a4f04eda93959aeb8ebb1b69b186fe1a8ee7 /sql/catalyst
parentdcb896fd8cec83483f700ee985c352be61cdf233 (diff)
downloadspark-41bbd2300472501d69ed46f0407d5ed7cbede4a8.tar.gz
spark-41bbd2300472501d69ed46f0407d5ed7cbede4a8.tar.bz2
spark-41bbd2300472501d69ed46f0407d5ed7cbede4a8.zip
[SPARK-11654][SQL] add reduce to GroupedDataset
This PR adds a new method, `reduce`, to `GroupedDataset`, which allows similar operations to `reduceByKey` on a traditional `PairRDD`. ```scala val ds = Seq("abc", "xyz", "hello").toDS() ds.groupBy(_.length).reduce(_ + _).collect() // not actually commutative :P res0: Array(3 -> "abcxyz", 5 -> "hello") ``` While implementing this method and its test cases several more deficiencies were found in our encoder handling. Specifically, in order to support positional resolution, named resolution and tuple composition, it is important to keep the unresolved encoder around and to use it when constructing new `Datasets` with the same object type but different output attributes. We now divide the encoder lifecycle into three phases (that mirror the lifecycle of standard expressions) and have checks at various boundaries: - Unresoved Encoders: all users facing encoders (those constructed by implicits, static methods, or tuple composition) are unresolved, meaning they have only `UnresolvedAttributes` for named fields and `BoundReferences` for fields accessed by ordinal. - Resolved Encoders: internal to a `[Grouped]Dataset` the encoder is resolved, meaning all input has been resolved to a specific `AttributeReference`. Any encoders that are placed into a logical plan for use in object construction should be resolved. - BoundEncoder: Are constructed by physical plans, right before actual conversion from row -> object is performed. It is left to future work to add explicit checks for resolution and provide good error messages when it fails. We might also consider enforcing the above constraints in the type system (i.e. `fromRow` only exists on a `ResolvedEncoder`), but we should probably wait before spending too much time on this. Author: Michael Armbrust <michael@databricks.com> Author: Wenchen Fan <wenchen@databricks.com> Closes #9673 from marmbrus/pr/9628.
Diffstat (limited to 'sql/catalyst')
-rw-r--r--sql/catalyst/src/main/scala/org/apache/spark/sql/Encoder.scala10
-rw-r--r--sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoder.scala124
-rw-r--r--sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/package.scala11
-rw-r--r--sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeExtractors.scala7
-rw-r--r--sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala15
5 files changed, 103 insertions, 64 deletions
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/Encoder.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/Encoder.scala
index 6134f9e036..5f619d6c33 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/Encoder.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/Encoder.scala
@@ -84,7 +84,7 @@ object Encoders {
private def tuple(encoders: Seq[ExpressionEncoder[_]]): ExpressionEncoder[_] = {
assert(encoders.length > 1)
// make sure all encoders are resolved, i.e. `Attribute` has been resolved to `BoundReference`.
- assert(encoders.forall(_.constructExpression.find(_.isInstanceOf[Attribute]).isEmpty))
+ assert(encoders.forall(_.fromRowExpression.find(_.isInstanceOf[Attribute]).isEmpty))
val schema = StructType(encoders.zipWithIndex.map {
case (e, i) => StructField(s"_${i + 1}", if (e.flat) e.schema.head.dataType else e.schema)
@@ -93,8 +93,8 @@ object Encoders {
val cls = Utils.getContextOrSparkClassLoader.loadClass(s"scala.Tuple${encoders.size}")
val extractExpressions = encoders.map {
- case e if e.flat => e.extractExpressions.head
- case other => CreateStruct(other.extractExpressions)
+ case e if e.flat => e.toRowExpressions.head
+ case other => CreateStruct(other.toRowExpressions)
}.zipWithIndex.map { case (expr, index) =>
expr.transformUp {
case BoundReference(0, t: ObjectType, _) =>
@@ -107,11 +107,11 @@ object Encoders {
val constructExpressions = encoders.zipWithIndex.map { case (enc, index) =>
if (enc.flat) {
- enc.constructExpression.transform {
+ enc.fromRowExpression.transform {
case b: BoundReference => b.copy(ordinal = index)
}
} else {
- enc.constructExpression.transformUp {
+ enc.fromRowExpression.transformUp {
case BoundReference(ordinal, dt, _) =>
GetInternalRowField(BoundReference(index, enc.schema, nullable = true), ordinal, dt)
}
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoder.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoder.scala
index 294afde534..0d3e4aafb0 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoder.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoder.scala
@@ -28,7 +28,7 @@ import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.expressions.codegen.{GenerateSafeProjection, GenerateUnsafeProjection}
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.ScalaReflection
-import org.apache.spark.sql.types.{StructField, ObjectType, StructType}
+import org.apache.spark.sql.types.{NullType, StructField, ObjectType, StructType}
/**
* A factory for constructing encoders that convert objects and primitves to and from the
@@ -61,20 +61,39 @@ object ExpressionEncoder {
/**
* Given a set of N encoders, constructs a new encoder that produce objects as items in an
- * N-tuple. Note that these encoders should first be bound correctly to the combined input
- * schema.
+ * N-tuple. Note that these encoders should be unresolved so that information about
+ * name/positional binding is preserved.
*/
def tuple(encoders: Seq[ExpressionEncoder[_]]): ExpressionEncoder[_] = {
+ encoders.foreach(_.assertUnresolved())
+
val schema =
StructType(
- encoders.zipWithIndex.map { case (e, i) => StructField(s"_${i + 1}", e.schema)})
+ encoders.zipWithIndex.map {
+ case (e, i) => StructField(s"_${i + 1}", if (e.flat) e.schema.head.dataType else e.schema)
+ })
val cls = Utils.getContextOrSparkClassLoader.loadClass(s"scala.Tuple${encoders.size}")
- val extractExpressions = encoders.map {
- case e if e.flat => e.extractExpressions.head
- case other => CreateStruct(other.extractExpressions)
+
+ // Rebind the encoders to the nested schema.
+ val newConstructExpressions = encoders.zipWithIndex.map {
+ case (e, i) if !e.flat => e.nested(i).fromRowExpression
+ case (e, i) => e.shift(i).fromRowExpression
}
+
val constructExpression =
- NewInstance(cls, encoders.map(_.constructExpression), false, ObjectType(cls))
+ NewInstance(cls, newConstructExpressions, false, ObjectType(cls))
+
+ val input = BoundReference(0, ObjectType(cls), false)
+ val extractExpressions = encoders.zipWithIndex.map {
+ case (e, i) if !e.flat => CreateStruct(e.toRowExpressions.map(_ transformUp {
+ case b: BoundReference =>
+ Invoke(input, s"_${i + 1}", b.dataType, Nil)
+ }))
+ case (e, i) => e.toRowExpressions.head transformUp {
+ case b: BoundReference =>
+ Invoke(input, s"_${i + 1}", b.dataType, Nil)
+ }
+ }
new ExpressionEncoder[Any](
schema,
@@ -95,35 +114,40 @@ object ExpressionEncoder {
* A generic encoder for JVM objects.
*
* @param schema The schema after converting `T` to a Spark SQL row.
- * @param extractExpressions A set of expressions, one for each top-level field that can be used to
- * extract the values from a raw object.
+ * @param toRowExpressions A set of expressions, one for each top-level field that can be used to
+ * extract the values from a raw object into an [[InternalRow]].
+ * @param fromRowExpression An expression that will construct an object given an [[InternalRow]].
* @param clsTag A classtag for `T`.
*/
case class ExpressionEncoder[T](
schema: StructType,
flat: Boolean,
- extractExpressions: Seq[Expression],
- constructExpression: Expression,
+ toRowExpressions: Seq[Expression],
+ fromRowExpression: Expression,
clsTag: ClassTag[T])
extends Encoder[T] {
- if (flat) require(extractExpressions.size == 1)
+ if (flat) require(toRowExpressions.size == 1)
@transient
- private lazy val extractProjection = GenerateUnsafeProjection.generate(extractExpressions)
+ private lazy val extractProjection = GenerateUnsafeProjection.generate(toRowExpressions)
private val inputRow = new GenericMutableRow(1)
@transient
- private lazy val constructProjection = GenerateSafeProjection.generate(constructExpression :: Nil)
+ private lazy val constructProjection = GenerateSafeProjection.generate(fromRowExpression :: Nil)
/**
* Returns an encoded version of `t` as a Spark SQL row. Note that multiple calls to
* toRow are allowed to return the same actual [[InternalRow]] object. Thus, the caller should
* copy the result before making another call if required.
*/
- def toRow(t: T): InternalRow = {
+ def toRow(t: T): InternalRow = try {
inputRow(0) = t
extractProjection(inputRow)
+ } catch {
+ case e: Exception =>
+ throw new RuntimeException(
+ s"Error while encoding: $e\n${toRowExpressions.map(_.treeString).mkString("\n")}", e)
}
/**
@@ -135,7 +159,20 @@ case class ExpressionEncoder[T](
constructProjection(row).get(0, ObjectType(clsTag.runtimeClass)).asInstanceOf[T]
} catch {
case e: Exception =>
- throw new RuntimeException(s"Error while decoding: $e\n${constructExpression.treeString}", e)
+ throw new RuntimeException(s"Error while decoding: $e\n${fromRowExpression.treeString}", e)
+ }
+
+ /**
+ * The process of resolution to a given schema throws away information about where a given field
+ * is being bound by ordinal instead of by name. This method checks to make sure this process
+ * has not been done already in places where we plan to do later composition of encoders.
+ */
+ def assertUnresolved(): Unit = {
+ (fromRowExpression +: toRowExpressions).foreach(_.foreach {
+ case a: AttributeReference =>
+ sys.error(s"Unresolved encoder expected, but $a was found.")
+ case _ =>
+ })
}
/**
@@ -143,9 +180,14 @@ case class ExpressionEncoder[T](
* given schema.
*/
def resolve(schema: Seq[Attribute]): ExpressionEncoder[T] = {
- val plan = Project(Alias(constructExpression, "")() :: Nil, LocalRelation(schema))
+ val positionToAttribute = AttributeMap.toIndex(schema)
+ val unbound = fromRowExpression transform {
+ case b: BoundReference => positionToAttribute(b.ordinal)
+ }
+
+ val plan = Project(Alias(unbound, "")() :: Nil, LocalRelation(schema))
val analyzedPlan = SimpleAnalyzer.execute(plan)
- copy(constructExpression = analyzedPlan.expressions.head.children.head)
+ copy(fromRowExpression = analyzedPlan.expressions.head.children.head)
}
/**
@@ -154,39 +196,14 @@ case class ExpressionEncoder[T](
* resolve before bind.
*/
def bind(schema: Seq[Attribute]): ExpressionEncoder[T] = {
- copy(constructExpression = BindReferences.bindReference(constructExpression, schema))
- }
-
- /**
- * Replaces any bound references in the schema with the attributes at the corresponding ordinal
- * in the provided schema. This can be used to "relocate" a given encoder to pull values from
- * a different schema than it was initially bound to. It can also be used to assign attributes
- * to ordinal based extraction (i.e. because the input data was a tuple).
- */
- def unbind(schema: Seq[Attribute]): ExpressionEncoder[T] = {
- val positionToAttribute = AttributeMap.toIndex(schema)
- copy(constructExpression = constructExpression transform {
- case b: BoundReference => positionToAttribute(b.ordinal)
- })
+ copy(fromRowExpression = BindReferences.bindReference(fromRowExpression, schema))
}
/**
- * Given an encoder that has already been bound to a given schema, returns a new encoder
- * where the positions are mapped from `oldSchema` to `newSchema`. This can be used, for example,
- * when you are trying to use an encoder on grouping keys that were originally part of a larger
- * row, but now you have projected out only the key expressions.
+ * Returns a new encoder with input columns shifted by `delta` ordinals
*/
- def rebind(oldSchema: Seq[Attribute], newSchema: Seq[Attribute]): ExpressionEncoder[T] = {
- val positionToAttribute = AttributeMap.toIndex(oldSchema)
- val attributeToNewPosition = AttributeMap.byIndex(newSchema)
- copy(constructExpression = constructExpression transform {
- case r: BoundReference =>
- r.copy(ordinal = attributeToNewPosition(positionToAttribute(r.ordinal)))
- })
- }
-
def shift(delta: Int): ExpressionEncoder[T] = {
- copy(constructExpression = constructExpression transform {
+ copy(fromRowExpression = fromRowExpression transform {
case r: BoundReference => r.copy(ordinal = r.ordinal + delta)
})
}
@@ -196,11 +213,14 @@ case class ExpressionEncoder[T](
* input row have been modified to pull the object out from a nested struct, instead of the
* top level fields.
*/
- def nested(input: Expression = BoundReference(0, schema, true)): ExpressionEncoder[T] = {
- copy(constructExpression = constructExpression transform {
- case u: Attribute if u != input =>
+ private def nested(i: Int): ExpressionEncoder[T] = {
+ // We don't always know our input type at this point since it might be unresolved.
+ // We fill in null and it will get unbound to the actual attribute at this position.
+ val input = BoundReference(i, NullType, nullable = true)
+ copy(fromRowExpression = fromRowExpression transformUp {
+ case u: Attribute =>
UnresolvedExtractValue(input, Literal(u.name))
- case b: BoundReference if b != input =>
+ case b: BoundReference =>
GetStructField(
input,
StructField(s"i[${b.ordinal}]", b.dataType),
@@ -208,7 +228,7 @@ case class ExpressionEncoder[T](
})
}
- protected val attrs = extractExpressions.flatMap(_.collect {
+ protected val attrs = toRowExpressions.flatMap(_.collect {
case _: UnresolvedAttribute => ""
case a: Attribute => s"#${a.exprId}"
case b: BoundReference => s"[${b.ordinal}]"
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/package.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/package.scala
index 2c35adca9c..9e283f5eb6 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/package.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/package.scala
@@ -18,10 +18,19 @@
package org.apache.spark.sql.catalyst
import org.apache.spark.sql.Encoder
+import org.apache.spark.sql.catalyst.expressions.AttributeReference
package object encoders {
+ /**
+ * Returns an internal encoder object that can be used to serialize / deserialize JVM objects
+ * into Spark SQL rows. The implicit encoder should always be unresolved (i.e. have no attribute
+ * references from a specific schema.) This requirement allows us to preserve whether a given
+ * object type is being bound by name or by ordinal when doing resolution.
+ */
private[sql] def encoderFor[A : Encoder]: ExpressionEncoder[A] = implicitly[Encoder[A]] match {
- case e: ExpressionEncoder[A] => e
+ case e: ExpressionEncoder[A] =>
+ e.assertUnresolved()
+ e
case _ => sys.error(s"Only expression encoders are supported today")
}
}
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeExtractors.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeExtractors.scala
index 41cd0a104a..f871b737ff 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeExtractors.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeExtractors.scala
@@ -97,11 +97,16 @@ object ExtractValue {
* Returns the value of fields in the Struct `child`.
*
* No need to do type checking since it is handled by [[ExtractValue]].
+ * TODO: Unify with [[GetInternalRowField]], remove the need to specify a [[StructField]].
*/
case class GetStructField(child: Expression, field: StructField, ordinal: Int)
extends UnaryExpression {
- override def dataType: DataType = field.dataType
+ override def dataType: DataType = child.dataType match {
+ case s: StructType => s(ordinal).dataType
+ // This is a hack to avoid breaking existing code until we remove the need for the struct field
+ case _ => field.dataType
+ }
override def nullable: Boolean = child.nullable || field.nullable
override def toString: String = s"$child.${field.name}"
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala
index 32b09b59af..d9f046efce 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala
@@ -483,9 +483,12 @@ case class MapPartitions[T, U](
/** Factory for constructing new `AppendColumn` nodes. */
object AppendColumn {
- def apply[T : Encoder, U : Encoder](func: T => U, child: LogicalPlan): AppendColumn[T, U] = {
+ def apply[T, U : Encoder](
+ func: T => U,
+ tEncoder: ExpressionEncoder[T],
+ child: LogicalPlan): AppendColumn[T, U] = {
val attrs = encoderFor[U].schema.toAttributes
- new AppendColumn[T, U](func, encoderFor[T], encoderFor[U], attrs, child)
+ new AppendColumn[T, U](func, tEncoder, encoderFor[U], attrs, child)
}
}
@@ -506,14 +509,16 @@ case class AppendColumn[T, U](
/** Factory for constructing new `MapGroups` nodes. */
object MapGroups {
- def apply[K : Encoder, T : Encoder, U : Encoder](
+ def apply[K, T, U : Encoder](
func: (K, Iterator[T]) => TraversableOnce[U],
+ kEncoder: ExpressionEncoder[K],
+ tEncoder: ExpressionEncoder[T],
groupingAttributes: Seq[Attribute],
child: LogicalPlan): MapGroups[K, T, U] = {
new MapGroups(
func,
- encoderFor[K],
- encoderFor[T],
+ kEncoder,
+ tEncoder,
encoderFor[U],
groupingAttributes,
encoderFor[U].schema.toAttributes,