aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorWenchen Fan <wenchen@databricks.com>2015-12-01 10:22:55 -0800
committerMichael Armbrust <michael@databricks.com>2015-12-01 10:22:55 -0800
commit8ddc55f1d582cccc3ca135510b2ea776e889e481 (patch)
tree610f9cb1c37effa35102f356f01cd5a64ef63fd4
parent69dbe6b40df35d488d4ee343098ac70d00bbdafb (diff)
downloadspark-8ddc55f1d582cccc3ca135510b2ea776e889e481.tar.gz
spark-8ddc55f1d582cccc3ca135510b2ea776e889e481.tar.bz2
spark-8ddc55f1d582cccc3ca135510b2ea776e889e481.zip
[SPARK-12068][SQL] use a single column in Dataset.groupBy and count will fail
The reason is that, for a single column `RowEncoder` (or a single-field product encoder), when we use it as the encoder for a grouping key, we should also combine the grouping attributes, even though there is only one grouping attribute. Author: Wenchen Fan <wenchen@databricks.com> Closes #10059 from cloud-fan/bug.
-rw-r--r--sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala2
-rw-r--r--sql/core/src/main/scala/org/apache/spark/sql/GroupedDataset.scala7
-rw-r--r--sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala19
-rw-r--r--sql/core/src/test/scala/org/apache/spark/sql/QueryTest.scala6
4 files changed, 27 insertions, 7 deletions
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala
index da46001332..c357f88a94 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala
@@ -70,7 +70,7 @@ class Dataset[T] private[sql](
* implicit so that we can use it when constructing new [[Dataset]] objects that have the same
* object type (that will be possibly resolved to a different schema).
*/
- private implicit val unresolvedTEncoder: ExpressionEncoder[T] = encoderFor(tEncoder)
+ private[sql] implicit val unresolvedTEncoder: ExpressionEncoder[T] = encoderFor(tEncoder)
/** The encoder for this [[Dataset]] that has been resolved to its output schema. */
private[sql] val resolvedTEncoder: ExpressionEncoder[T] =
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/GroupedDataset.scala b/sql/core/src/main/scala/org/apache/spark/sql/GroupedDataset.scala
index a10a89342f..4bf0b256fc 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/GroupedDataset.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/GroupedDataset.scala
@@ -228,10 +228,11 @@ class GroupedDataset[K, V] private[sql](
val namedColumns =
columns.map(
_.withInputType(resolvedVEncoder, dataAttributes).named)
- val keyColumn = if (groupingAttributes.length > 1) {
- Alias(CreateStruct(groupingAttributes), "key")()
- } else {
+ val keyColumn = if (resolvedKEncoder.flat) {
+ assert(groupingAttributes.length == 1)
groupingAttributes.head
+ } else {
+ Alias(CreateStruct(groupingAttributes), "key")()
}
val aggregate = Aggregate(groupingAttributes, keyColumn +: namedColumns, logicalPlan)
val execution = new QueryExecution(sqlContext, aggregate)
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala
index 7d539180de..a2c8d20156 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala
@@ -272,6 +272,16 @@ class DatasetSuite extends QueryTest with SharedSQLContext {
3 -> "abcxyz", 5 -> "hello")
}
+ test("groupBy single field class, count") {
+ val ds = Seq("abc", "xyz", "hello").toDS()
+ val count = ds.groupBy(s => Tuple1(s.length)).count()
+
+ checkAnswer(
+ count,
+ (Tuple1(3), 2L), (Tuple1(5), 1L)
+ )
+ }
+
test("groupBy columns, map") {
val ds = Seq(("a", 10), ("a", 20), ("b", 1), ("b", 2), ("c", 1)).toDS()
val grouped = ds.groupBy($"_1")
@@ -282,6 +292,15 @@ class DatasetSuite extends QueryTest with SharedSQLContext {
("a", 30), ("b", 3), ("c", 1))
}
+ test("groupBy columns, count") {
+ val ds = Seq("a" -> 1, "b" -> 1, "a" -> 2).toDS()
+ val count = ds.groupBy($"_1").count()
+
+ checkAnswer(
+ count,
+ (Row("a"), 2L), (Row("b"), 1L))
+ }
+
test("groupBy columns asKey, map") {
val ds = Seq(("a", 10), ("a", 20), ("b", 1), ("b", 2), ("c", 1)).toDS()
val grouped = ds.groupBy($"_1").keyAs[String]
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/QueryTest.scala b/sql/core/src/test/scala/org/apache/spark/sql/QueryTest.scala
index 6ea1fe4ccf..8f476dd0f9 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/QueryTest.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/QueryTest.scala
@@ -64,12 +64,12 @@ abstract class QueryTest extends PlanTest {
* for cases where reordering is done on fields. For such tests, user `checkDecoding` instead
* which performs a subset of the checks done by this function.
*/
- protected def checkAnswer[T : Encoder](
- ds: => Dataset[T],
+ protected def checkAnswer[T](
+ ds: Dataset[T],
expectedAnswer: T*): Unit = {
checkAnswer(
ds.toDF(),
- sqlContext.createDataset(expectedAnswer).toDF().collect().toSeq)
+ sqlContext.createDataset(expectedAnswer)(ds.unresolvedTEncoder).toDF().collect().toSeq)
checkDecoding(ds, expectedAnswer: _*)
}