diff options
author | Wenchen Fan <wenchen@databricks.com> | 2015-12-01 10:22:55 -0800 |
---|---|---|
committer | Michael Armbrust <michael@databricks.com> | 2015-12-01 10:22:55 -0800 |
commit | 8ddc55f1d582cccc3ca135510b2ea776e889e481 (patch) | |
tree | 610f9cb1c37effa35102f356f01cd5a64ef63fd4 /sql/core | |
parent | 69dbe6b40df35d488d4ee343098ac70d00bbdafb (diff) | |
download | spark-8ddc55f1d582cccc3ca135510b2ea776e889e481.tar.gz spark-8ddc55f1d582cccc3ca135510b2ea776e889e481.tar.bz2 spark-8ddc55f1d582cccc3ca135510b2ea776e889e481.zip |
[SPARK-12068][SQL] use a single column in Dataset.groupBy and count will fail
The reason is that, for a single culumn `RowEncoder`(or a single field product encoder), when we use it as the encoder for grouping key, we should also combine the grouping attributes, although there is only one grouping attribute.
Author: Wenchen Fan <wenchen@databricks.com>
Closes #10059 from cloud-fan/bug.
Diffstat (limited to 'sql/core')
4 files changed, 27 insertions, 7 deletions
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala index da46001332..c357f88a94 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala @@ -70,7 +70,7 @@ class Dataset[T] private[sql]( * implicit so that we can use it when constructing new [[Dataset]] objects that have the same * object type (that will be possibly resolved to a different schema). */ - private implicit val unresolvedTEncoder: ExpressionEncoder[T] = encoderFor(tEncoder) + private[sql] implicit val unresolvedTEncoder: ExpressionEncoder[T] = encoderFor(tEncoder) /** The encoder for this [[Dataset]] that has been resolved to its output schema. */ private[sql] val resolvedTEncoder: ExpressionEncoder[T] = diff --git a/sql/core/src/main/scala/org/apache/spark/sql/GroupedDataset.scala b/sql/core/src/main/scala/org/apache/spark/sql/GroupedDataset.scala index a10a89342f..4bf0b256fc 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/GroupedDataset.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/GroupedDataset.scala @@ -228,10 +228,11 @@ class GroupedDataset[K, V] private[sql]( val namedColumns = columns.map( _.withInputType(resolvedVEncoder, dataAttributes).named) - val keyColumn = if (groupingAttributes.length > 1) { - Alias(CreateStruct(groupingAttributes), "key")() - } else { + val keyColumn = if (resolvedKEncoder.flat) { + assert(groupingAttributes.length == 1) groupingAttributes.head + } else { + Alias(CreateStruct(groupingAttributes), "key")() } val aggregate = Aggregate(groupingAttributes, keyColumn +: namedColumns, logicalPlan) val execution = new QueryExecution(sqlContext, aggregate) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala index 7d539180de..a2c8d20156 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala @@ -272,6 +272,16 @@ class DatasetSuite extends QueryTest with SharedSQLContext { 3 -> "abcxyz", 5 -> "hello") } + test("groupBy single field class, count") { + val ds = Seq("abc", "xyz", "hello").toDS() + val count = ds.groupBy(s => Tuple1(s.length)).count() + + checkAnswer( + count, + (Tuple1(3), 2L), (Tuple1(5), 1L) + ) + } + test("groupBy columns, map") { val ds = Seq(("a", 10), ("a", 20), ("b", 1), ("b", 2), ("c", 1)).toDS() val grouped = ds.groupBy($"_1") @@ -282,6 +292,15 @@ class DatasetSuite extends QueryTest with SharedSQLContext { ("a", 30), ("b", 3), ("c", 1)) } + test("groupBy columns, count") { + val ds = Seq("a" -> 1, "b" -> 1, "a" -> 2).toDS() + val count = ds.groupBy($"_1").count() + + checkAnswer( + count, + (Row("a"), 2L), (Row("b"), 1L)) + } + test("groupBy columns asKey, map") { val ds = Seq(("a", 10), ("a", 20), ("b", 1), ("b", 2), ("c", 1)).toDS() val grouped = ds.groupBy($"_1").keyAs[String] diff --git a/sql/core/src/test/scala/org/apache/spark/sql/QueryTest.scala b/sql/core/src/test/scala/org/apache/spark/sql/QueryTest.scala index 6ea1fe4ccf..8f476dd0f9 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/QueryTest.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/QueryTest.scala @@ -64,12 +64,12 @@ abstract class QueryTest extends PlanTest { * for cases where reordering is done on fields. For such tests, user `checkDecoding` instead * which performs a subset of the checks done by this function. */ - protected def checkAnswer[T : Encoder]( - ds: => Dataset[T], + protected def checkAnswer[T]( + ds: Dataset[T], expectedAnswer: T*): Unit = { checkAnswer( ds.toDF(), - sqlContext.createDataset(expectedAnswer).toDF().collect().toSeq) + sqlContext.createDataset(expectedAnswer)(ds.unresolvedTEncoder).toDF().collect().toSeq) checkDecoding(ds, expectedAnswer: _*) } |