author     Michael Armbrust <michael@databricks.com>  2015-11-12 17:20:30 -0800
committer  Michael Armbrust <michael@databricks.com>  2015-11-12 17:20:30 -0800
commit     41bbd2300472501d69ed46f0407d5ed7cbede4a8 (patch)
tree       d158a4f04eda93959aeb8ebb1b69b186fe1a8ee7
parent     dcb896fd8cec83483f700ee985c352be61cdf233 (diff)
[SPARK-11654][SQL] add reduce to GroupedDataset
This PR adds a new method, `reduce`, to `GroupedDataset`, which supports operations similar to `reduceByKey` on a traditional `PairRDD`.

```scala
val ds = Seq("abc", "xyz", "hello").toDS()
ds.groupBy(_.length).reduce(_ + _).collect() // not actually commutative :P

res0: Array(3 -> "abcxyz", 5 -> "hello")
```

While implementing this method and its test cases, several more deficiencies were found in our encoder handling. Specifically, in order to support positional resolution, named resolution, and tuple composition, it is important to keep the unresolved encoder around and to use it when constructing new `Datasets` with the same object type but different output attributes. We now divide the encoder lifecycle into three phases (that mirror the lifecycle of standard expressions) and have checks at various boundaries:

 - Unresolved Encoders: all user-facing encoders (those constructed by implicits, static methods, or tuple composition) are unresolved, meaning they have only `UnresolvedAttributes` for named fields and `BoundReferences` for fields accessed by ordinal.
 - Resolved Encoders: internal to a `[Grouped]Dataset`, the encoder is resolved, meaning all input has been resolved to a specific `AttributeReference`. Any encoders that are placed into a logical plan for use in object construction should be resolved.
 - Bound Encoders: constructed by physical plans, right before the actual conversion from row -> object is performed.

It is left to future work to add explicit checks for resolution and to provide good error messages when it fails. We might also consider enforcing the above constraints in the type system (i.e. `fromRow` only exists on a `ResolvedEncoder`), but we should probably wait before spending too much time on this.

Author: Michael Armbrust <michael@databricks.com>
Author: Wenchen Fan <wenchen@databricks.com>

Closes #9673 from marmbrus/pr/9628.
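To make the three phases concrete, here is a rough sketch of how an encoder would move through them, using the `ExpressionEncoder` methods touched by this patch (`assertUnresolved`, `resolve`, `bind`, `fromRow`). It is illustrative only, written against the 1.6-era internal catalyst API: `Person`, `childOutput`, and `internalRow` are placeholders rather than identifiers from this change, and the resolve/bind steps are commented out because they need a live plan's output attributes.

```scala
import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder

case class Person(name: String, age: Int)

// Phase 1 - Unresolved: built by implicits, static methods, or tuple composition.
// Named fields are still UnresolvedAttributes; ordinal fields are BoundReferences.
val unresolved = ExpressionEncoder[Person]()
unresolved.assertUnresolved() // check added in this patch; fails if an AttributeReference is present

// Phase 2 - Resolved: inside a [Grouped]Dataset, the encoder's input is tied to the
// concrete AttributeReferences of the analyzed plan's output.
// val resolved = unresolved.resolve(queryExecution.analyzed.output)

// Phase 3 - Bound: physical plans bind right before the row -> object conversion.
// val bound = resolved.bind(childOutput)
// val person: Person = bound.fromRow(internalRow)
```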
-rw-r--r--  sql/catalyst/src/main/scala/org/apache/spark/sql/Encoder.scala | 10
-rw-r--r--  sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoder.scala | 124
-rw-r--r--  sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/package.scala | 11
-rw-r--r--  sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeExtractors.scala | 7
-rw-r--r--  sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala | 15
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/Column.scala | 43
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala | 17
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala | 85
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/GroupedDataset.scala | 98
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TypedAggregateExpression.scala | 13
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala | 7
-rw-r--r--  sql/core/src/test/java/test/org/apache/spark/sql/JavaDatasetSuite.java | 12
-rw-r--r--  sql/core/src/test/scala/org/apache/spark/sql/DatasetAggregatorSuite.scala | 42
-rw-r--r--  sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala | 9
-rw-r--r--  sql/core/src/test/scala/org/apache/spark/sql/QueryTest.scala | 13
15 files changed, 309 insertions, 197 deletions
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/Encoder.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/Encoder.scala
index 6134f9e036..5f619d6c33 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/Encoder.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/Encoder.scala
@@ -84,7 +84,7 @@ object Encoders {
private def tuple(encoders: Seq[ExpressionEncoder[_]]): ExpressionEncoder[_] = {
assert(encoders.length > 1)
// make sure all encoders are resolved, i.e. `Attribute` has been resolved to `BoundReference`.
- assert(encoders.forall(_.constructExpression.find(_.isInstanceOf[Attribute]).isEmpty))
+ assert(encoders.forall(_.fromRowExpression.find(_.isInstanceOf[Attribute]).isEmpty))
val schema = StructType(encoders.zipWithIndex.map {
case (e, i) => StructField(s"_${i + 1}", if (e.flat) e.schema.head.dataType else e.schema)
@@ -93,8 +93,8 @@ object Encoders {
val cls = Utils.getContextOrSparkClassLoader.loadClass(s"scala.Tuple${encoders.size}")
val extractExpressions = encoders.map {
- case e if e.flat => e.extractExpressions.head
- case other => CreateStruct(other.extractExpressions)
+ case e if e.flat => e.toRowExpressions.head
+ case other => CreateStruct(other.toRowExpressions)
}.zipWithIndex.map { case (expr, index) =>
expr.transformUp {
case BoundReference(0, t: ObjectType, _) =>
@@ -107,11 +107,11 @@ object Encoders {
val constructExpressions = encoders.zipWithIndex.map { case (enc, index) =>
if (enc.flat) {
- enc.constructExpression.transform {
+ enc.fromRowExpression.transform {
case b: BoundReference => b.copy(ordinal = index)
}
} else {
- enc.constructExpression.transformUp {
+ enc.fromRowExpression.transformUp {
case BoundReference(ordinal, dt, _) =>
GetInternalRowField(BoundReference(index, enc.schema, nullable = true), ordinal, dt)
}
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoder.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoder.scala
index 294afde534..0d3e4aafb0 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoder.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoder.scala
@@ -28,7 +28,7 @@ import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.expressions.codegen.{GenerateSafeProjection, GenerateUnsafeProjection}
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.ScalaReflection
-import org.apache.spark.sql.types.{StructField, ObjectType, StructType}
+import org.apache.spark.sql.types.{NullType, StructField, ObjectType, StructType}
/**
* A factory for constructing encoders that convert objects and primitves to and from the
@@ -61,20 +61,39 @@ object ExpressionEncoder {
/**
* Given a set of N encoders, constructs a new encoder that produce objects as items in an
- * N-tuple. Note that these encoders should first be bound correctly to the combined input
- * schema.
+ * N-tuple. Note that these encoders should be unresolved so that information about
+ * name/positional binding is preserved.
*/
def tuple(encoders: Seq[ExpressionEncoder[_]]): ExpressionEncoder[_] = {
+ encoders.foreach(_.assertUnresolved())
+
val schema =
StructType(
- encoders.zipWithIndex.map { case (e, i) => StructField(s"_${i + 1}", e.schema)})
+ encoders.zipWithIndex.map {
+ case (e, i) => StructField(s"_${i + 1}", if (e.flat) e.schema.head.dataType else e.schema)
+ })
val cls = Utils.getContextOrSparkClassLoader.loadClass(s"scala.Tuple${encoders.size}")
- val extractExpressions = encoders.map {
- case e if e.flat => e.extractExpressions.head
- case other => CreateStruct(other.extractExpressions)
+
+ // Rebind the encoders to the nested schema.
+ val newConstructExpressions = encoders.zipWithIndex.map {
+ case (e, i) if !e.flat => e.nested(i).fromRowExpression
+ case (e, i) => e.shift(i).fromRowExpression
}
+
val constructExpression =
- NewInstance(cls, encoders.map(_.constructExpression), false, ObjectType(cls))
+ NewInstance(cls, newConstructExpressions, false, ObjectType(cls))
+
+ val input = BoundReference(0, ObjectType(cls), false)
+ val extractExpressions = encoders.zipWithIndex.map {
+ case (e, i) if !e.flat => CreateStruct(e.toRowExpressions.map(_ transformUp {
+ case b: BoundReference =>
+ Invoke(input, s"_${i + 1}", b.dataType, Nil)
+ }))
+ case (e, i) => e.toRowExpressions.head transformUp {
+ case b: BoundReference =>
+ Invoke(input, s"_${i + 1}", b.dataType, Nil)
+ }
+ }
new ExpressionEncoder[Any](
schema,
@@ -95,35 +114,40 @@ object ExpressionEncoder {
* A generic encoder for JVM objects.
*
* @param schema The schema after converting `T` to a Spark SQL row.
- * @param extractExpressions A set of expressions, one for each top-level field that can be used to
- * extract the values from a raw object.
+ * @param toRowExpressions A set of expressions, one for each top-level field that can be used to
+ * extract the values from a raw object into an [[InternalRow]].
+ * @param fromRowExpression An expression that will construct an object given an [[InternalRow]].
* @param clsTag A classtag for `T`.
*/
case class ExpressionEncoder[T](
schema: StructType,
flat: Boolean,
- extractExpressions: Seq[Expression],
- constructExpression: Expression,
+ toRowExpressions: Seq[Expression],
+ fromRowExpression: Expression,
clsTag: ClassTag[T])
extends Encoder[T] {
- if (flat) require(extractExpressions.size == 1)
+ if (flat) require(toRowExpressions.size == 1)
@transient
- private lazy val extractProjection = GenerateUnsafeProjection.generate(extractExpressions)
+ private lazy val extractProjection = GenerateUnsafeProjection.generate(toRowExpressions)
private val inputRow = new GenericMutableRow(1)
@transient
- private lazy val constructProjection = GenerateSafeProjection.generate(constructExpression :: Nil)
+ private lazy val constructProjection = GenerateSafeProjection.generate(fromRowExpression :: Nil)
/**
* Returns an encoded version of `t` as a Spark SQL row. Note that multiple calls to
* toRow are allowed to return the same actual [[InternalRow]] object. Thus, the caller should
* copy the result before making another call if required.
*/
- def toRow(t: T): InternalRow = {
+ def toRow(t: T): InternalRow = try {
inputRow(0) = t
extractProjection(inputRow)
+ } catch {
+ case e: Exception =>
+ throw new RuntimeException(
+ s"Error while encoding: $e\n${toRowExpressions.map(_.treeString).mkString("\n")}", e)
}
/**
@@ -135,7 +159,20 @@ case class ExpressionEncoder[T](
constructProjection(row).get(0, ObjectType(clsTag.runtimeClass)).asInstanceOf[T]
} catch {
case e: Exception =>
- throw new RuntimeException(s"Error while decoding: $e\n${constructExpression.treeString}", e)
+ throw new RuntimeException(s"Error while decoding: $e\n${fromRowExpression.treeString}", e)
+ }
+
+ /**
+ * The process of resolution to a given schema throws away information about where a given field
+ * is being bound by ordinal instead of by name. This method checks to make sure this process
+ * has not been done already in places where we plan to do later composition of encoders.
+ */
+ def assertUnresolved(): Unit = {
+ (fromRowExpression +: toRowExpressions).foreach(_.foreach {
+ case a: AttributeReference =>
+ sys.error(s"Unresolved encoder expected, but $a was found.")
+ case _ =>
+ })
}
/**
@@ -143,9 +180,14 @@ case class ExpressionEncoder[T](
* given schema.
*/
def resolve(schema: Seq[Attribute]): ExpressionEncoder[T] = {
- val plan = Project(Alias(constructExpression, "")() :: Nil, LocalRelation(schema))
+ val positionToAttribute = AttributeMap.toIndex(schema)
+ val unbound = fromRowExpression transform {
+ case b: BoundReference => positionToAttribute(b.ordinal)
+ }
+
+ val plan = Project(Alias(unbound, "")() :: Nil, LocalRelation(schema))
val analyzedPlan = SimpleAnalyzer.execute(plan)
- copy(constructExpression = analyzedPlan.expressions.head.children.head)
+ copy(fromRowExpression = analyzedPlan.expressions.head.children.head)
}
/**
@@ -154,39 +196,14 @@ case class ExpressionEncoder[T](
* resolve before bind.
*/
def bind(schema: Seq[Attribute]): ExpressionEncoder[T] = {
- copy(constructExpression = BindReferences.bindReference(constructExpression, schema))
- }
-
- /**
- * Replaces any bound references in the schema with the attributes at the corresponding ordinal
- * in the provided schema. This can be used to "relocate" a given encoder to pull values from
- * a different schema than it was initially bound to. It can also be used to assign attributes
- * to ordinal based extraction (i.e. because the input data was a tuple).
- */
- def unbind(schema: Seq[Attribute]): ExpressionEncoder[T] = {
- val positionToAttribute = AttributeMap.toIndex(schema)
- copy(constructExpression = constructExpression transform {
- case b: BoundReference => positionToAttribute(b.ordinal)
- })
+ copy(fromRowExpression = BindReferences.bindReference(fromRowExpression, schema))
}
/**
- * Given an encoder that has already been bound to a given schema, returns a new encoder
- * where the positions are mapped from `oldSchema` to `newSchema`. This can be used, for example,
- * when you are trying to use an encoder on grouping keys that were originally part of a larger
- * row, but now you have projected out only the key expressions.
+ * Returns a new encoder with input columns shifted by `delta` ordinals
*/
- def rebind(oldSchema: Seq[Attribute], newSchema: Seq[Attribute]): ExpressionEncoder[T] = {
- val positionToAttribute = AttributeMap.toIndex(oldSchema)
- val attributeToNewPosition = AttributeMap.byIndex(newSchema)
- copy(constructExpression = constructExpression transform {
- case r: BoundReference =>
- r.copy(ordinal = attributeToNewPosition(positionToAttribute(r.ordinal)))
- })
- }
-
def shift(delta: Int): ExpressionEncoder[T] = {
- copy(constructExpression = constructExpression transform {
+ copy(fromRowExpression = fromRowExpression transform {
case r: BoundReference => r.copy(ordinal = r.ordinal + delta)
})
}
@@ -196,11 +213,14 @@ case class ExpressionEncoder[T](
* input row have been modified to pull the object out from a nested struct, instead of the
* top level fields.
*/
- def nested(input: Expression = BoundReference(0, schema, true)): ExpressionEncoder[T] = {
- copy(constructExpression = constructExpression transform {
- case u: Attribute if u != input =>
+ private def nested(i: Int): ExpressionEncoder[T] = {
+ // We don't always know our input type at this point since it might be unresolved.
+ // We fill in null and it will get unbound to the actual attribute at this position.
+ val input = BoundReference(i, NullType, nullable = true)
+ copy(fromRowExpression = fromRowExpression transformUp {
+ case u: Attribute =>
UnresolvedExtractValue(input, Literal(u.name))
- case b: BoundReference if b != input =>
+ case b: BoundReference =>
GetStructField(
input,
StructField(s"i[${b.ordinal}]", b.dataType),
@@ -208,7 +228,7 @@ case class ExpressionEncoder[T](
})
}
- protected val attrs = extractExpressions.flatMap(_.collect {
+ protected val attrs = toRowExpressions.flatMap(_.collect {
case _: UnresolvedAttribute => ""
case a: Attribute => s"#${a.exprId}"
case b: BoundReference => s"[${b.ordinal}]"
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/package.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/package.scala
index 2c35adca9c..9e283f5eb6 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/package.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/package.scala
@@ -18,10 +18,19 @@
package org.apache.spark.sql.catalyst
import org.apache.spark.sql.Encoder
+import org.apache.spark.sql.catalyst.expressions.AttributeReference
package object encoders {
+ /**
+ * Returns an internal encoder object that can be used to serialize / deserialize JVM objects
+ * into Spark SQL rows. The implicit encoder should always be unresolved (i.e. have no attribute
+ * references from a specific schema.) This requirement allows us to preserve whether a given
+ * object type is being bound by name or by ordinal when doing resolution.
+ */
private[sql] def encoderFor[A : Encoder]: ExpressionEncoder[A] = implicitly[Encoder[A]] match {
- case e: ExpressionEncoder[A] => e
+ case e: ExpressionEncoder[A] =>
+ e.assertUnresolved()
+ e
case _ => sys.error(s"Only expression encoders are supported today")
}
}
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeExtractors.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeExtractors.scala
index 41cd0a104a..f871b737ff 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeExtractors.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeExtractors.scala
@@ -97,11 +97,16 @@ object ExtractValue {
* Returns the value of fields in the Struct `child`.
*
* No need to do type checking since it is handled by [[ExtractValue]].
+ * TODO: Unify with [[GetInternalRowField]], remove the need to specify a [[StructField]].
*/
case class GetStructField(child: Expression, field: StructField, ordinal: Int)
extends UnaryExpression {
- override def dataType: DataType = field.dataType
+ override def dataType: DataType = child.dataType match {
+ case s: StructType => s(ordinal).dataType
+ // This is a hack to avoid breaking existing code until we remove the need for the struct field
+ case _ => field.dataType
+ }
override def nullable: Boolean = child.nullable || field.nullable
override def toString: String = s"$child.${field.name}"
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala
index 32b09b59af..d9f046efce 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala
@@ -483,9 +483,12 @@ case class MapPartitions[T, U](
/** Factory for constructing new `AppendColumn` nodes. */
object AppendColumn {
- def apply[T : Encoder, U : Encoder](func: T => U, child: LogicalPlan): AppendColumn[T, U] = {
+ def apply[T, U : Encoder](
+ func: T => U,
+ tEncoder: ExpressionEncoder[T],
+ child: LogicalPlan): AppendColumn[T, U] = {
val attrs = encoderFor[U].schema.toAttributes
- new AppendColumn[T, U](func, encoderFor[T], encoderFor[U], attrs, child)
+ new AppendColumn[T, U](func, tEncoder, encoderFor[U], attrs, child)
}
}
@@ -506,14 +509,16 @@ case class AppendColumn[T, U](
/** Factory for constructing new `MapGroups` nodes. */
object MapGroups {
- def apply[K : Encoder, T : Encoder, U : Encoder](
+ def apply[K, T, U : Encoder](
func: (K, Iterator[T]) => TraversableOnce[U],
+ kEncoder: ExpressionEncoder[K],
+ tEncoder: ExpressionEncoder[T],
groupingAttributes: Seq[Attribute],
child: LogicalPlan): MapGroups[K, T, U] = {
new MapGroups(
func,
- encoderFor[K],
- encoderFor[T],
+ kEncoder,
+ tEncoder,
encoderFor[U],
groupingAttributes,
encoderFor[U].schema.toAttributes,
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/Column.scala b/sql/core/src/main/scala/org/apache/spark/sql/Column.scala
index f0f275e91f..929224460d 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/Column.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/Column.scala
@@ -17,13 +17,15 @@
package org.apache.spark.sql
+import org.apache.spark.sql.execution.aggregate.TypedAggregateExpression
+
import scala.language.implicitConversions
import org.apache.spark.annotation.Experimental
import org.apache.spark.Logging
import org.apache.spark.sql.functions.lit
import org.apache.spark.sql.catalyst.analysis._
-import org.apache.spark.sql.catalyst.encoders.encoderFor
+import org.apache.spark.sql.catalyst.encoders.{ExpressionEncoder, encoderFor}
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.util.DataTypeParser
import org.apache.spark.sql.types._
@@ -45,7 +47,25 @@ private[sql] object Column {
* checked by the analyzer instead of the compiler (i.e. `expr("sum(...)")`).
* @tparam U The output type of this column.
*/
-class TypedColumn[-T, U](expr: Expression, val encoder: Encoder[U]) extends Column(expr)
+class TypedColumn[-T, U](
+ expr: Expression,
+ private[sql] val encoder: ExpressionEncoder[U]) extends Column(expr) {
+
+ /**
+ * Inserts the specific input type and schema into any expressions that are expected to operate
+ * on a decoded object.
+ */
+ private[sql] def withInputType(
+ inputEncoder: ExpressionEncoder[_],
+ schema: Seq[Attribute]): TypedColumn[T, U] = {
+ new TypedColumn[T, U] (expr transform {
+ case ta: TypedAggregateExpression if ta.aEncoder.isEmpty =>
+ ta.copy(
+ aEncoder = Some(inputEncoder.asInstanceOf[ExpressionEncoder[Any]]),
+ children = schema)
+ }, encoder)
+ }
+}
/**
* :: Experimental ::
@@ -73,6 +93,25 @@ class Column(protected[sql] val expr: Expression) extends Logging {
/** Creates a column based on the given expression. */
private def withExpr(newExpr: Expression): Column = new Column(newExpr)
+ /**
+ * Returns the expression for this column either with an existing or auto assigned name.
+ */
+ private[sql] def named: NamedExpression = expr match {
+ // Wrap UnresolvedAttribute with UnresolvedAlias, as when we resolve UnresolvedAttribute, we
+ // will remove intermediate Alias for ExtractValue chain, and we need to alias it again to
+ // make it a NamedExpression.
+ case u: UnresolvedAttribute => UnresolvedAlias(u)
+
+ case expr: NamedExpression => expr
+
+ // Leave an unaliased generator with an empty list of names since the analyzer will generate
+ // the correct defaults after the nested expression's type has been resolved.
+ case explode: Explode => MultiAlias(explode, Nil)
+ case jt: JsonTuple => MultiAlias(jt, Nil)
+
+ case expr: Expression => Alias(expr, expr.prettyString)()
+ }
+
override def toString: String = expr.prettyString
override def equals(that: Any): Boolean = that match {
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala b/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala
index a492099b93..3ba4ba18d2 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala
@@ -735,22 +735,7 @@ class DataFrame private[sql](
*/
@scala.annotation.varargs
def select(cols: Column*): DataFrame = withPlan {
- val namedExpressions = cols.map {
- // Wrap UnresolvedAttribute with UnresolvedAlias, as when we resolve UnresolvedAttribute, we
- // will remove intermediate Alias for ExtractValue chain, and we need to alias it again to
- // make it a NamedExpression.
- case Column(u: UnresolvedAttribute) => UnresolvedAlias(u)
-
- case Column(expr: NamedExpression) => expr
-
- // Leave an unaliased generator with an empty list of names since the analyzer will generate
- // the correct defaults after the nested expression's type has been resolved.
- case Column(explode: Explode) => MultiAlias(explode, Nil)
- case Column(jt: JsonTuple) => MultiAlias(jt, Nil)
-
- case Column(expr: Expression) => Alias(expr, expr.prettyString)()
- }
- Project(namedExpressions.toSeq, logicalPlan)
+ Project(cols.map(_.named), logicalPlan)
}
/**
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala
index 87dae6b331..b930e4661c 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala
@@ -29,7 +29,6 @@ import org.apache.spark.sql.catalyst.analysis.UnresolvedAlias
import org.apache.spark.sql.catalyst.plans.Inner
import org.apache.spark.sql.catalyst.plans.logical._
import org.apache.spark.sql.execution.{Queryable, QueryExecution}
-import org.apache.spark.sql.execution.aggregate.TypedAggregateExpression
import org.apache.spark.sql.types.StructType
/**
@@ -63,15 +62,20 @@ import org.apache.spark.sql.types.StructType
class Dataset[T] private[sql](
@transient val sqlContext: SQLContext,
@transient val queryExecution: QueryExecution,
- unresolvedEncoder: Encoder[T]) extends Queryable with Serializable {
+ tEncoder: Encoder[T]) extends Queryable with Serializable {
+
+ /**
+ * An unresolved version of the internal encoder for the type of this dataset. This one is marked
+ * implicit so that we can use it when constructing new [[Dataset]] objects that have the same
+ * object type (that will be possibly resolved to a different schema).
+ */
+ private implicit val unresolvedTEncoder: ExpressionEncoder[T] = encoderFor(tEncoder)
/** The encoder for this [[Dataset]] that has been resolved to its output schema. */
- private[sql] implicit val encoder: ExpressionEncoder[T] = unresolvedEncoder match {
- case e: ExpressionEncoder[T] => e.resolve(queryExecution.analyzed.output)
- case _ => throw new IllegalArgumentException("Only expression encoders are currently supported")
- }
+ private[sql] val resolvedTEncoder: ExpressionEncoder[T] =
+ unresolvedTEncoder.resolve(queryExecution.analyzed.output)
- private implicit def classTag = encoder.clsTag
+ private implicit def classTag = resolvedTEncoder.clsTag
private[sql] def this(sqlContext: SQLContext, plan: LogicalPlan)(implicit encoder: Encoder[T]) =
this(sqlContext, new QueryExecution(sqlContext, plan), encoder)
@@ -81,7 +85,7 @@ class Dataset[T] private[sql](
*
* @since 1.6.0
*/
- def schema: StructType = encoder.schema
+ def schema: StructType = resolvedTEncoder.schema
/* ************* *
* Conversions *
@@ -134,7 +138,7 @@ class Dataset[T] private[sql](
* @since 1.6.0
*/
def rdd: RDD[T] = {
- val tEnc = encoderFor[T]
+ val tEnc = resolvedTEncoder
val input = queryExecution.analyzed.output
queryExecution.toRdd.mapPartitions { iter =>
val bound = tEnc.bind(input)
@@ -195,7 +199,7 @@ class Dataset[T] private[sql](
* @since 1.6.0
*/
def mapPartitions[U : Encoder](func: Iterator[T] => Iterator[U]): Dataset[U] = {
- new Dataset(
+ new Dataset[U](
sqlContext,
MapPartitions[T, U](
func,
@@ -295,12 +299,12 @@ class Dataset[T] private[sql](
*/
def groupBy[K : Encoder](func: T => K): GroupedDataset[K, T] = {
val inputPlan = queryExecution.analyzed
- val withGroupingKey = AppendColumn(func, inputPlan)
+ val withGroupingKey = AppendColumn(func, resolvedTEncoder, inputPlan)
val executed = sqlContext.executePlan(withGroupingKey)
new GroupedDataset(
- encoderFor[K].resolve(withGroupingKey.newColumns),
- encoderFor[T].bind(inputPlan.output),
+ encoderFor[K],
+ encoderFor[T],
executed,
inputPlan.output,
withGroupingKey.newColumns)
@@ -360,7 +364,15 @@ class Dataset[T] private[sql](
* @since 1.6.0
*/
def select[U1: Encoder](c1: TypedColumn[T, U1]): Dataset[U1] = {
- new Dataset[U1](sqlContext, Project(Alias(withEncoder(c1).expr, "_1")() :: Nil, logicalPlan))
+ // We use an unbound encoder since the expression will make up its own schema.
+ // TODO: This probably doesn't work if we are relying on reordering of the input class fields.
+ new Dataset[U1](
+ sqlContext,
+ Project(
+ c1.withInputType(
+ resolvedTEncoder.bind(queryExecution.analyzed.output),
+ queryExecution.analyzed.output).named :: Nil,
+ logicalPlan))
}
/**
@@ -369,28 +381,14 @@ class Dataset[T] private[sql](
* that cast appropriately for the user facing interface.
*/
protected def selectUntyped(columns: TypedColumn[_, _]*): Dataset[_] = {
- val withEncoders = columns.map(withEncoder)
- val aliases = withEncoders.zipWithIndex.map { case (c, i) => Alias(c.expr, s"_${i + 1}")() }
- val unresolvedPlan = Project(aliases, logicalPlan)
- val execution = new QueryExecution(sqlContext, unresolvedPlan)
- // Rebind the encoders to the nested schema that will be produced by the select.
- val encoders = withEncoders.map(_.encoder.asInstanceOf[ExpressionEncoder[_]]).zip(aliases).map {
- case (e: ExpressionEncoder[_], a) if !e.flat =>
- e.nested(a.toAttribute).resolve(execution.analyzed.output)
- case (e, a) =>
- e.unbind(a.toAttribute :: Nil).resolve(execution.analyzed.output)
- }
- new Dataset(sqlContext, execution, ExpressionEncoder.tuple(encoders))
- }
+ val encoders = columns.map(_.encoder)
+ // We use an unbound encoder since the expression will make up its own schema.
+ // TODO: This probably doesn't work if we are relying on reordering of the input class fields.
+ val namedColumns =
+ columns.map(_.withInputType(unresolvedTEncoder, queryExecution.analyzed.output).named)
+ val execution = new QueryExecution(sqlContext, Project(namedColumns, logicalPlan))
- private def withEncoder(c: TypedColumn[_, _]): TypedColumn[_, _] = {
- val e = c.expr transform {
- case ta: TypedAggregateExpression if ta.aEncoder.isEmpty =>
- ta.copy(
- aEncoder = Some(encoder.asInstanceOf[ExpressionEncoder[Any]]),
- children = queryExecution.analyzed.output)
- }
- new TypedColumn(e, c.encoder)
+ new Dataset(sqlContext, execution, ExpressionEncoder.tuple(encoders))
}
/**
@@ -497,23 +495,18 @@ class Dataset[T] private[sql](
val left = this.logicalPlan
val right = other.logicalPlan
- val leftData = this.encoder match {
+ val leftData = this.unresolvedTEncoder match {
case e if e.flat => Alias(left.output.head, "_1")()
case _ => Alias(CreateStruct(left.output), "_1")()
}
- val rightData = other.encoder match {
+ val rightData = other.unresolvedTEncoder match {
case e if e.flat => Alias(right.output.head, "_2")()
case _ => Alias(CreateStruct(right.output), "_2")()
}
- val leftEncoder =
- if (encoder.flat) encoder else encoder.nested(leftData.toAttribute)
- val rightEncoder =
- if (other.encoder.flat) other.encoder else other.encoder.nested(rightData.toAttribute)
- implicit val tuple2Encoder: Encoder[(T, U)] =
- ExpressionEncoder.tuple(
- leftEncoder,
- rightEncoder.rebind(right.output, left.output ++ right.output))
+
+ implicit val tuple2Encoder: Encoder[(T, U)] =
+ ExpressionEncoder.tuple(this.unresolvedTEncoder, other.unresolvedTEncoder)
withPlan[(T, U)](other) { (left, right) =>
Project(
leftData :: rightData :: Nil,
@@ -580,7 +573,7 @@ class Dataset[T] private[sql](
private[sql] def logicalPlan = queryExecution.analyzed
private[sql] def withPlan(f: LogicalPlan => LogicalPlan): Dataset[T] =
- new Dataset[T](sqlContext, sqlContext.executePlan(f(logicalPlan)), encoder)
+ new Dataset[T](sqlContext, sqlContext.executePlan(f(logicalPlan)), tEncoder)
private[sql] def withPlan[R : Encoder](
other: Dataset[_])(
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/GroupedDataset.scala b/sql/core/src/main/scala/org/apache/spark/sql/GroupedDataset.scala
index 61e2a95450..ae1272ae53 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/GroupedDataset.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/GroupedDataset.scala
@@ -17,20 +17,16 @@
package org.apache.spark.sql
-import java.util.{Iterator => JIterator}
import scala.collection.JavaConverters._
import org.apache.spark.annotation.Experimental
-import org.apache.spark.api.java.function.{Function2 => JFunction2, Function3 => JFunction3, _}
-import org.apache.spark.sql.catalyst.analysis.{UnresolvedAlias, UnresolvedAttribute}
+import org.apache.spark.api.java.function._
import org.apache.spark.sql.catalyst.encoders.{ExpressionEncoder, encoderFor}
-import org.apache.spark.sql.catalyst.expressions.{Expression, NamedExpression, Alias, Attribute}
+import org.apache.spark.sql.catalyst.expressions.Attribute
import org.apache.spark.sql.catalyst.plans.logical._
-import org.apache.spark.sql.execution.aggregate.TypedAggregateExpression
import org.apache.spark.sql.execution.QueryExecution
-
/**
* :: Experimental ::
* A [[Dataset]] has been logically grouped by a user specified grouping key. Users should not
@@ -44,23 +40,21 @@ import org.apache.spark.sql.execution.QueryExecution
*/
@Experimental
class GroupedDataset[K, T] private[sql](
- private val kEncoder: Encoder[K],
- private val tEncoder: Encoder[T],
- queryExecution: QueryExecution,
+ kEncoder: Encoder[K],
+ tEncoder: Encoder[T],
+ val queryExecution: QueryExecution,
private val dataAttributes: Seq[Attribute],
private val groupingAttributes: Seq[Attribute]) extends Serializable {
- private implicit val kEnc = kEncoder match {
- case e: ExpressionEncoder[K] => e.unbind(groupingAttributes).resolve(groupingAttributes)
- case other =>
- throw new UnsupportedOperationException("Only expression encoders are currently supported")
- }
+ // Similar to [[Dataset]], we use unresolved encoders for later composition and resolved encoders
+ // when constructing new logical plans that will operate on the output of the current
+ // queryexecution.
- private implicit val tEnc = tEncoder match {
- case e: ExpressionEncoder[T] => e.resolve(dataAttributes)
- case other =>
- throw new UnsupportedOperationException("Only expression encoders are currently supported")
- }
+ private implicit val unresolvedKEncoder = encoderFor(kEncoder)
+ private implicit val unresolvedTEncoder = encoderFor(tEncoder)
+
+ private val resolvedKEncoder = unresolvedKEncoder.resolve(groupingAttributes)
+ private val resolvedTEncoder = unresolvedTEncoder.resolve(dataAttributes)
/** Encoders for built in aggregations. */
private implicit def newLongEncoder: Encoder[Long] = ExpressionEncoder[Long](flat = true)
@@ -79,7 +73,7 @@ class GroupedDataset[K, T] private[sql](
def asKey[L : Encoder]: GroupedDataset[L, T] =
new GroupedDataset(
encoderFor[L],
- tEncoder,
+ unresolvedTEncoder,
queryExecution,
dataAttributes,
groupingAttributes)
@@ -95,7 +89,7 @@ class GroupedDataset[K, T] private[sql](
}
/**
- * Applies the given function to each group of data. For each unique group, the function will
+ * Applies the given function to each group of data. For each unique group, the function will
* be passed the group key and an iterator that contains all of the elements in the group. The
* function can return an iterator containing elements of an arbitrary type which will be returned
* as a new [[Dataset]].
@@ -108,7 +102,12 @@ class GroupedDataset[K, T] private[sql](
def flatMap[U : Encoder](f: (K, Iterator[T]) => TraversableOnce[U]): Dataset[U] = {
new Dataset[U](
sqlContext,
- MapGroups(f, groupingAttributes, logicalPlan))
+ MapGroups(
+ f,
+ resolvedKEncoder,
+ resolvedTEncoder,
+ groupingAttributes,
+ logicalPlan))
}
def flatMap[U](f: FlatMapGroupFunction[K, T, U], encoder: Encoder[U]): Dataset[U] = {
@@ -127,15 +126,28 @@ class GroupedDataset[K, T] private[sql](
*/
def map[U : Encoder](f: (K, Iterator[T]) => U): Dataset[U] = {
val func = (key: K, it: Iterator[T]) => Iterator(f(key, it))
- new Dataset[U](
- sqlContext,
- MapGroups(func, groupingAttributes, logicalPlan))
+ flatMap(func)
}
def map[U](f: MapGroupFunction[K, T, U], encoder: Encoder[U]): Dataset[U] = {
map((key, data) => f.call(key, data.asJava))(encoder)
}
+ /**
+ * Reduces the elements of each group of data using the specified binary function.
+ * The given function must be commutative and associative or the result may be non-deterministic.
+ */
+ def reduce(f: (T, T) => T): Dataset[(K, T)] = {
+ val func = (key: K, it: Iterator[T]) => Iterator(key -> it.reduce(f))
+
+ implicit val resultEncoder = ExpressionEncoder.tuple(unresolvedKEncoder, unresolvedTEncoder)
+ flatMap(func)
+ }
+
+ def reduce(f: ReduceFunction[T]): Dataset[(K, T)] = {
+ reduce(f.call _)
+ }
+
// To ensure valid overloading.
protected def agg(expr: Column, exprs: Column*): DataFrame =
groupedData.agg(expr, exprs: _*)
@@ -147,37 +159,17 @@ class GroupedDataset[K, T] private[sql](
* TODO: does not handle aggrecations that return nonflat results,
*/
protected def aggUntyped(columns: TypedColumn[_, _]*): Dataset[_] = {
- val aliases = (groupingAttributes ++ columns.map(_.expr)).map {
- case u: UnresolvedAttribute => UnresolvedAlias(u)
- case expr: NamedExpression => expr
- case expr: Expression => Alias(expr, expr.prettyString)()
- }
-
- val unresolvedPlan = Aggregate(groupingAttributes, aliases, logicalPlan)
-
- // Fill in the input encoders for any aggregators in the plan.
- val withEncoders = unresolvedPlan transformAllExpressions {
- case ta: TypedAggregateExpression if ta.aEncoder.isEmpty =>
- ta.copy(
- aEncoder = Some(tEnc.asInstanceOf[ExpressionEncoder[Any]]),
- children = dataAttributes)
- }
- val execution = new QueryExecution(sqlContext, withEncoders)
-
- val columnEncoders = columns.map(_.encoder.asInstanceOf[ExpressionEncoder[_]])
-
- // Rebind the encoders to the nested schema that will be produced by the aggregation.
- val encoders = (kEnc +: columnEncoders).zip(execution.analyzed.output).map {
- case (e: ExpressionEncoder[_], a) if !e.flat =>
- e.nested(a).resolve(execution.analyzed.output)
- case (e, a) =>
- e.unbind(a :: Nil).resolve(execution.analyzed.output)
- }
+ val encoders = columns.map(_.encoder)
+ val namedColumns =
+ columns.map(
+ _.withInputType(resolvedTEncoder.bind(dataAttributes), dataAttributes).named)
+ val aggregate = Aggregate(groupingAttributes, groupingAttributes ++ namedColumns, logicalPlan)
+ val execution = new QueryExecution(sqlContext, aggregate)
new Dataset(
sqlContext,
execution,
- ExpressionEncoder.tuple(encoders))
+ ExpressionEncoder.tuple(unresolvedKEncoder +: encoders))
}
/**
@@ -230,7 +222,7 @@ class GroupedDataset[K, T] private[sql](
def cogroup[U, R : Encoder](
other: GroupedDataset[K, U])(
f: (K, Iterator[T], Iterator[U]) => TraversableOnce[R]): Dataset[R] = {
- implicit def uEnc: Encoder[U] = other.tEncoder
+ implicit def uEnc: Encoder[U] = other.unresolvedTEncoder
new Dataset[R](
sqlContext,
CoGroup(
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TypedAggregateExpression.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TypedAggregateExpression.scala
index dfcbac8687..3f2775896b 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TypedAggregateExpression.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TypedAggregateExpression.scala
@@ -55,7 +55,7 @@ case class TypedAggregateExpression(
aEncoder: Option[ExpressionEncoder[Any]],
bEncoder: ExpressionEncoder[Any],
cEncoder: ExpressionEncoder[Any],
- children: Seq[Expression],
+ children: Seq[Attribute],
mutableAggBufferOffset: Int,
inputAggBufferOffset: Int)
extends ImperativeAggregate with Logging {
@@ -78,8 +78,7 @@ case class TypedAggregateExpression(
override lazy val resolved: Boolean = aEncoder.isDefined
- override lazy val inputTypes: Seq[DataType] =
- aEncoder.map(_.schema.map(_.dataType)).getOrElse(Nil)
+ override lazy val inputTypes: Seq[DataType] = Nil
override val aggBufferSchema: StructType = bEncoder.schema
@@ -90,12 +89,8 @@ case class TypedAggregateExpression(
override val inputAggBufferAttributes: Seq[AttributeReference] =
aggBufferAttributes.map(_.newInstance())
- lazy val inputAttributes = aEncoder.get.schema.toAttributes
- lazy val inputMapping = AttributeMap(inputAttributes.zip(children))
- lazy val boundA =
- aEncoder.get.copy(constructExpression = aEncoder.get.constructExpression transform {
- case a: AttributeReference => inputMapping(a)
- })
+ // We let the dataset do the binding for us.
+ lazy val boundA = aEncoder.get
val bAttributes = bEncoder.schema.toAttributes
lazy val boundB = bEncoder.resolve(bAttributes).bind(bAttributes)
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala
index ae08fb71bf..ed82c9a6a3 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala
@@ -311,6 +311,10 @@ case class AppendColumns[T, U](
newColumns: Seq[Attribute],
child: SparkPlan) extends UnaryNode {
+ // We are using an unsafe combiner.
+ override def canProcessSafeRows: Boolean = false
+ override def canProcessUnsafeRows: Boolean = true
+
override def output: Seq[Attribute] = child.output ++ newColumns
override protected def doExecute(): RDD[InternalRow] = {
@@ -349,11 +353,12 @@ case class MapGroups[K, T, U](
child.execute().mapPartitions { iter =>
val grouped = GroupedIterator(iter, groupingAttributes, child.output)
val groupKeyEncoder = kEncoder.bind(groupingAttributes)
+ val groupDataEncoder = tEncoder.bind(child.output)
grouped.flatMap { case (key, rowIter) =>
val result = func(
groupKeyEncoder.fromRow(key),
- rowIter.map(tEncoder.fromRow))
+ rowIter.map(groupDataEncoder.fromRow))
result.map(uEncoder.toRow)
}
}
diff --git a/sql/core/src/test/java/test/org/apache/spark/sql/JavaDatasetSuite.java b/sql/core/src/test/java/test/org/apache/spark/sql/JavaDatasetSuite.java
index 33d8388f61..46169ca07d 100644
--- a/sql/core/src/test/java/test/org/apache/spark/sql/JavaDatasetSuite.java
+++ b/sql/core/src/test/java/test/org/apache/spark/sql/JavaDatasetSuite.java
@@ -157,7 +157,6 @@ public class JavaDatasetSuite implements Serializable {
Assert.assertEquals(6, reduced);
}
- @Test
public void testGroupBy() {
List<String> data = Arrays.asList("a", "foo", "bar");
Dataset<String> ds = context.createDataset(data, Encoders.STRING());
@@ -196,6 +195,17 @@ public class JavaDatasetSuite implements Serializable {
Assert.assertEquals(Arrays.asList("1a", "3foobar"), flatMapped.collectAsList());
+ Dataset<Tuple2<Integer, String>> reduced = grouped.reduce(new ReduceFunction<String>() {
+ @Override
+ public String call(String v1, String v2) throws Exception {
+ return v1 + v2;
+ }
+ });
+
+ Assert.assertEquals(
+ Arrays.asList(tuple2(1, "a"), tuple2(3, "foobar")),
+ reduced.collectAsList());
+
List<Integer> data2 = Arrays.asList(2, 6, 10);
Dataset<Integer> ds2 = context.createDataset(data2, Encoders.INT());
GroupedDataset<Integer, Integer> grouped2 = ds2.groupBy(new MapFunction<Integer, Integer>() {
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DatasetAggregatorSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DatasetAggregatorSuite.scala
index 378cd36527..20896efdfe 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/DatasetAggregatorSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/DatasetAggregatorSuite.scala
@@ -67,6 +67,28 @@ object ComplexResultAgg extends Aggregator[(String, Int), (Long, Long), (Long, L
override def finish(reduction: (Long, Long)): (Long, Long) = reduction
}
+case class AggData(a: Int, b: String)
+object ClassInputAgg extends Aggregator[AggData, Int, Int] with Serializable {
+ /** A zero value for this aggregation. Should satisfy the property that any b + zero = b */
+ override def zero: Int = 0
+
+ /**
+ * Combine two values to produce a new value. For performance, the function may modify `b` and
+ * return it instead of constructing new object for b.
+ */
+ override def reduce(b: Int, a: AggData): Int = b + a.a
+
+ /**
+ * Transform the output of the reduction.
+ */
+ override def finish(reduction: Int): Int = reduction
+
+ /**
+ * Merge two intermediate values
+ */
+ override def merge(b1: Int, b2: Int): Int = b1 + b2
+}
+
class DatasetAggregatorSuite extends QueryTest with SharedSQLContext {
import testImplicits._
@@ -123,4 +145,24 @@ class DatasetAggregatorSuite extends QueryTest with SharedSQLContext {
ds.select(sum((i: Int) => i), sum((i: Int) => i * 2)),
11 -> 22)
}
+
+ test("typed aggregation: class input") {
+ val ds = Seq(AggData(1, "one"), AggData(2, "two")).toDS()
+
+ checkAnswer(
+ ds.select(ClassInputAgg.toColumn),
+ 3)
+ }
+
+ test("typed aggregation: class input with reordering") {
+ val ds = sql("SELECT 'one' AS b, 1 as a").as[AggData]
+
+ checkAnswer(
+ ds.select(ClassInputAgg.toColumn),
+ 1)
+
+ checkAnswer(
+ ds.groupBy(_.b).agg(ClassInputAgg.toColumn),
+ ("one", 1))
+ }
}
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala
index 6211485287..c23dd46d37 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala
@@ -218,6 +218,15 @@ class DatasetSuite extends QueryTest with SharedSQLContext {
"a", "30", "b", "3", "c", "1")
}
+ test("groupBy function, reduce") {
+ val ds = Seq("abc", "xyz", "hello").toDS()
+ val agged = ds.groupBy(_.length).reduce(_ + _)
+
+ checkAnswer(
+ agged,
+ 3 -> "abcxyz", 5 -> "hello")
+ }
+
test("groupBy columns, map") {
val ds = Seq(("a", 10), ("a", 20), ("b", 1), ("b", 2), ("c", 1)).toDS()
val grouped = ds.groupBy($"_1")
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/QueryTest.scala b/sql/core/src/test/scala/org/apache/spark/sql/QueryTest.scala
index 7a8b7ae5bf..b5417b195f 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/QueryTest.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/QueryTest.scala
@@ -82,18 +82,21 @@ abstract class QueryTest extends PlanTest {
fail(
s"""
|Exception collecting dataset as objects
- |${ds.encoder}
- |${ds.encoder.constructExpression.treeString}
+ |${ds.resolvedTEncoder}
+ |${ds.resolvedTEncoder.fromRowExpression.treeString}
|${ds.queryExecution}
""".stripMargin, e)
}
if (decoded != expectedAnswer.toSet) {
+ val expected = expectedAnswer.toSet.toSeq.map((a: Any) => a.toString).sorted
+ val actual = decoded.toSet.toSeq.map((a: Any) => a.toString).sorted
+
+ val comparision = sideBySide("expected" +: expected, "spark" +: actual).mkString("\n")
fail(
s"""Decoded objects do not match expected objects:
- |Expected: ${expectedAnswer.toSet.toSeq.map((a: Any) => a.toString).sorted}
- |Actual ${decoded.toSet.toSeq.map((a: Any) => a.toString).sorted}
- |${ds.encoder.constructExpression.treeString}
+ |$comparision
+ |${ds.resolvedTEncoder.fromRowExpression.treeString}
""".stripMargin)
}
}