author    Wenchen Fan <wenchen@databricks.com>    2015-11-18 10:33:17 -0800
committer Michael Armbrust <michael@databricks.com>    2015-11-18 10:33:17 -0800
commit    dbf428c87ab34b6f76c75946043bdf5f60c9b1b3 (patch)
tree      ac7f578f3d695509a53e5bb8d46ed4492985b502
parent    33b837333435ceb0c04d1f361a5383c4fe6a5a75 (diff)
[SPARK-11795][SQL] combine grouping attributes into a single NamedExpression
We use `ExpressionEncoder.tuple` to build the result encoder, which assumes that a non-flat input encoder points to a struct-typed field. However, our key encoder always points to flat fields (`groupingAttributes`), so we should combine them into a single `NamedExpression`.

Author: Wenchen Fan <wenchen@databricks.com>

Closes #9792 from cloud-fan/agg.
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/GroupedDataset.scala  9
-rw-r--r--  sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala    5
2 files changed, 9 insertions(+), 5 deletions(-)
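A minimal sketch (not part of the patch) of how the fix exposes the grouping key as a single NamedExpression, so that `ExpressionEncoder.tuple` can bind the key encoder to one struct-typed field. The helper name `keyColumnFor` is hypothetical; the Catalyst classes are the same ones touched in the diff below.

    import org.apache.spark.sql.catalyst.expressions.{Alias, Attribute, CreateStruct, NamedExpression}

    // When there are several grouping attributes, wrap them in one struct
    // aliased as "key"; a single attribute is already a NamedExpression and
    // can be used directly.
    def keyColumnFor(groupingAttributes: Seq[Attribute]): NamedExpression =
      if (groupingAttributes.length > 1) {
        Alias(CreateStruct(groupingAttributes), "key")()
      } else {
        groupingAttributes.head
      }

The resulting column is then prepended to the aggregation output (`keyColumn +: namedColumns`), as the GroupedDataset.scala hunk shows.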
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/GroupedDataset.scala b/sql/core/src/main/scala/org/apache/spark/sql/GroupedDataset.scala
index c66162ee21..3f84e22a10 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/GroupedDataset.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/GroupedDataset.scala
@@ -22,7 +22,7 @@ import scala.collection.JavaConverters._
import org.apache.spark.annotation.Experimental
import org.apache.spark.api.java.function._
import org.apache.spark.sql.catalyst.encoders.{FlatEncoder, ExpressionEncoder, encoderFor}
-import org.apache.spark.sql.catalyst.expressions.Attribute
+import org.apache.spark.sql.catalyst.expressions.{Alias, CreateStruct, Attribute}
import org.apache.spark.sql.catalyst.plans.logical._
import org.apache.spark.sql.execution.QueryExecution
@@ -187,7 +187,12 @@ class GroupedDataset[K, T] private[sql](
val namedColumns =
columns.map(
_.withInputType(resolvedTEncoder, dataAttributes).named)
- val aggregate = Aggregate(groupingAttributes, groupingAttributes ++ namedColumns, logicalPlan)
+ val keyColumn = if (groupingAttributes.length > 1) {
+ Alias(CreateStruct(groupingAttributes), "key")()
+ } else {
+ groupingAttributes.head
+ }
+ val aggregate = Aggregate(groupingAttributes, keyColumn +: namedColumns, logicalPlan)
val execution = new QueryExecution(sqlContext, aggregate)
new Dataset(
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala
index 198962b8fb..b6db583dfe 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala
@@ -84,8 +84,7 @@ class DatasetSuite extends QueryTest with SharedSQLContext {
("a", 2), ("b", 3), ("c", 4))
}
- ignore("Dataset should set the resolved encoders internally for maps") {
- // TODO: Enable this once we fix SPARK-11793.
+ test("map and group by with class data") {
// We inject a group by here to make sure this test case is future proof
// when we implement better pipelining and local execution mode.
val ds: Dataset[(ClassData, Long)] = Seq(ClassData("one", 1), ClassData("two", 2)).toDS()
@@ -94,7 +93,7 @@ class DatasetSuite extends QueryTest with SharedSQLContext {
checkAnswer(
ds,
- (ClassData("one", 1), 1L), (ClassData("two", 2), 1L))
+ (ClassData("one", 2), 1L), (ClassData("two", 3), 1L))
}
test("select") {