diff options
author | Wenchen Fan <wenchen@databricks.com> | 2015-11-18 10:33:17 -0800 |
---|---|---|
committer | Michael Armbrust <michael@databricks.com> | 2015-11-18 10:33:17 -0800 |
commit | dbf428c87ab34b6f76c75946043bdf5f60c9b1b3 (patch) | |
tree | ac7f578f3d695509a53e5bb8d46ed4492985b502 | |
parent | 33b837333435ceb0c04d1f361a5383c4fe6a5a75 (diff) | |
download | spark-dbf428c87ab34b6f76c75946043bdf5f60c9b1b3.tar.gz spark-dbf428c87ab34b6f76c75946043bdf5f60c9b1b3.tar.bz2 spark-dbf428c87ab34b6f76c75946043bdf5f60c9b1b3.zip |
[SPARK-11795][SQL] combine grouping attributes into a single NamedExpression
we use `ExpressionEncoder.tuple` to build the result encoder, which assumes the input encoder should point to a struct type field if it’s non-flat.
However, our keyEncoder always point to a flat field/fields: `groupingAttributes`, we should combine them into a single `NamedExpression`.
Author: Wenchen Fan <wenchen@databricks.com>
Closes #9792 from cloud-fan/agg.
-rw-r--r-- | sql/core/src/main/scala/org/apache/spark/sql/GroupedDataset.scala | 9 | ||||
-rw-r--r-- | sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala | 5 |
2 files changed, 9 insertions, 5 deletions
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/GroupedDataset.scala b/sql/core/src/main/scala/org/apache/spark/sql/GroupedDataset.scala index c66162ee21..3f84e22a10 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/GroupedDataset.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/GroupedDataset.scala @@ -22,7 +22,7 @@ import scala.collection.JavaConverters._ import org.apache.spark.annotation.Experimental import org.apache.spark.api.java.function._ import org.apache.spark.sql.catalyst.encoders.{FlatEncoder, ExpressionEncoder, encoderFor} -import org.apache.spark.sql.catalyst.expressions.Attribute +import org.apache.spark.sql.catalyst.expressions.{Alias, CreateStruct, Attribute} import org.apache.spark.sql.catalyst.plans.logical._ import org.apache.spark.sql.execution.QueryExecution @@ -187,7 +187,12 @@ class GroupedDataset[K, T] private[sql]( val namedColumns = columns.map( _.withInputType(resolvedTEncoder, dataAttributes).named) - val aggregate = Aggregate(groupingAttributes, groupingAttributes ++ namedColumns, logicalPlan) + val keyColumn = if (groupingAttributes.length > 1) { + Alias(CreateStruct(groupingAttributes), "key")() + } else { + groupingAttributes.head + } + val aggregate = Aggregate(groupingAttributes, keyColumn +: namedColumns, logicalPlan) val execution = new QueryExecution(sqlContext, aggregate) new Dataset( diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala index 198962b8fb..b6db583dfe 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala @@ -84,8 +84,7 @@ class DatasetSuite extends QueryTest with SharedSQLContext { ("a", 2), ("b", 3), ("c", 4)) } - ignore("Dataset should set the resolved encoders internally for maps") { - // TODO: Enable this once we fix SPARK-11793. + test("map and group by with class data") { // We inject a group by here to make sure this test case is future proof // when we implement better pipelining and local execution mode. val ds: Dataset[(ClassData, Long)] = Seq(ClassData("one", 1), ClassData("two", 2)).toDS() @@ -94,7 +93,7 @@ class DatasetSuite extends QueryTest with SharedSQLContext { checkAnswer( ds, - (ClassData("one", 1), 1L), (ClassData("two", 2), 1L)) + (ClassData("one", 2), 1L), (ClassData("two", 3), 1L)) } test("select") { |