From 0a4844f90a712e796c9404b422cea76d21a5d2e3 Mon Sep 17 00:00:00 2001
From: Reynold Xin
Date: Mon, 11 May 2015 11:35:16 -0700
Subject: [SPARK-7462] By default retain group by columns in aggregate

Updated Java, Scala, Python, and R.

Author: Reynold Xin
Author: Shivaram Venkataraman

Closes #5996 from rxin/groupby-retain and squashes the following commits:

aac7119 [Reynold Xin] Merge branch 'groupby-retain' of github.com:rxin/spark into groupby-retain
f6858f6 [Reynold Xin] Merge branch 'master' into groupby-retain
5f923c0 [Reynold Xin] Merge pull request #15 from shivaram/sparkr-groupby-retrain
c1de670 [Shivaram Venkataraman] Revert workaround in SparkR to retain grouped cols
Based on reverting code added in commit https://github.com/amplab-extras/spark/commit/9a6be746efc9fafad88122fa2267862ef87aa0e1
b8b87e1 [Reynold Xin] Fixed DataFrameJoinSuite.
d910141 [Reynold Xin] Updated rest of the files
1e6e666 [Reynold Xin] [SPARK-7462] By default retain group by columns in aggregate
---
 R/pkg/R/group.R                                    |   4 +-
 python/pyspark/sql/dataframe.py                    |   2 +-
 .../scala/org/apache/spark/sql/GroupedData.scala   |  15 +-
 .../main/scala/org/apache/spark/sql/SQLConf.scala  |   6 +
 .../org/apache/spark/sql/api/r/SQLUtils.scala      |  11 --
 .../spark/sql/execution/stat/StatFunctions.scala   |   2 +-
 .../apache/spark/sql/DataFrameAggregateSuite.scala | 193 +++++++++++++++++++++
 .../org/apache/spark/sql/DataFrameJoinSuite.scala  |   4 +-
 .../org/apache/spark/sql/DataFrameSuite.scala      | 151 +----------------
 .../test/scala/org/apache/spark/sql/TestData.scala |   2 -
 10 files changed, 218 insertions(+), 172 deletions(-)
 create mode 100644 sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala

diff --git a/R/pkg/R/group.R b/R/pkg/R/group.R
index 5a7a8a2cab..b758481997 100644
--- a/R/pkg/R/group.R
+++ b/R/pkg/R/group.R
@@ -102,9 +102,7 @@ setMethod("agg",
         }
       }
       jcols <- lapply(cols, function(c) { c@jc })
-      # the GroupedData.agg(col, cols*) API does not contain grouping Column
-      sdf <- callJStatic("org.apache.spark.sql.api.r.SQLUtils", "aggWithGrouping",
-                         x@sgd, listToSeq(jcols))
+      sdf <- callJMethod(x@sgd, "agg", jcols[[1]], listToSeq(jcols[-1]))
     } else {
       stop("agg can only support Column or character")
     }
diff --git a/python/pyspark/sql/dataframe.py b/python/pyspark/sql/dataframe.py
index a9697999e8..c2fa6c8738 100644
--- a/python/pyspark/sql/dataframe.py
+++ b/python/pyspark/sql/dataframe.py
@@ -1069,7 +1069,7 @@ class GroupedData(object):
 
         >>> from pyspark.sql import functions as F
         >>> gdf.agg(F.min(df.age)).collect()
-        [Row(MIN(age)=2), Row(MIN(age)=5)]
+        [Row(name=u'Alice', MIN(age)=2), Row(name=u'Bob', MIN(age)=5)]
         """
         assert exprs, "exprs should not be empty"
         if len(exprs) == 1 and isinstance(exprs[0], dict):
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/GroupedData.scala b/sql/core/src/main/scala/org/apache/spark/sql/GroupedData.scala
index 53ad67372e..003a620dcc 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/GroupedData.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/GroupedData.scala
@@ -135,8 +135,9 @@ class GroupedData protected[sql](df: DataFrame, groupingExprs: Seq[Expression])
   }
 
   /**
-   * Compute aggregates by specifying a series of aggregate columns. Unlike other methods in this
-   * class, the resulting [[DataFrame]] won't automatically include the grouping columns.
+   * Compute aggregates by specifying a series of aggregate columns. Note that this function by
+   * default retains the grouping columns in its output. To not retain grouping columns, set
+   * `spark.sql.retainGroupColumns` to false.
    *
    * The available aggregate methods are defined in [[org.apache.spark.sql.functions]].
    *
@@ -158,7 +159,15 @@ class GroupedData protected[sql](df: DataFrame, groupingExprs: Seq[Expression])
       case expr: NamedExpression => expr
       case expr: Expression => Alias(expr, expr.prettyString)()
     }
-    DataFrame(df.sqlContext, Aggregate(groupingExprs, aggExprs, df.logicalPlan))
+    if (df.sqlContext.conf.dataFrameRetainGroupColumns) {
+      val retainedExprs = groupingExprs.map {
+        case expr: NamedExpression => expr
+        case expr: Expression => Alias(expr, expr.prettyString)()
+      }
+      DataFrame(df.sqlContext, Aggregate(groupingExprs, retainedExprs ++ aggExprs, df.logicalPlan))
+    } else {
+      DataFrame(df.sqlContext, Aggregate(groupingExprs, aggExprs, df.logicalPlan))
+    }
   }
 
   /**
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/SQLConf.scala b/sql/core/src/main/scala/org/apache/spark/sql/SQLConf.scala
index 98a75bb4ed..dcac97beaf 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/SQLConf.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/SQLConf.scala
@@ -74,6 +74,9 @@ private[spark] object SQLConf {
   // See SPARK-6231.
   val DATAFRAME_SELF_JOIN_AUTO_RESOLVE_AMBIGUITY = "spark.sql.selfJoinAutoResolveAmbiguity"
 
+  // Whether to retain group by columns or not in GroupedData.agg.
+  val DATAFRAME_RETAIN_GROUP_COLUMNS = "spark.sql.retainGroupColumns"
+
   val USE_SQL_SERIALIZER2 = "spark.sql.useSerializer2"
 
   val USE_JACKSON_STREAMING_API = "spark.sql.json.useJacksonStreamingAPI"
@@ -242,6 +245,9 @@ private[sql] class SQLConf extends Serializable with CatalystConf {
 
   private[spark] def dataFrameSelfJoinAutoResolveAmbiguity: Boolean =
     getConf(DATAFRAME_SELF_JOIN_AUTO_RESOLVE_AMBIGUITY, "true").toBoolean
+
+  private[spark] def dataFrameRetainGroupColumns: Boolean =
+    getConf(DATAFRAME_RETAIN_GROUP_COLUMNS, "true").toBoolean
 
 
   /** ********************** SQLConf functionality methods ************ */
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/api/r/SQLUtils.scala b/sql/core/src/main/scala/org/apache/spark/sql/api/r/SQLUtils.scala
index ae77f72998..423ecdff58 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/api/r/SQLUtils.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/api/r/SQLUtils.scala
@@ -72,17 +72,6 @@ private[r] object SQLUtils {
     sqlContext.createDataFrame(rowRDD, schema)
   }
 
-  // A helper to include grouping columns in Agg()
-  def aggWithGrouping(gd: GroupedData, exprs: Column*): DataFrame = {
-    val aggExprs = exprs.map { col =>
-      col.expr match {
-        case expr: NamedExpression => expr
-        case expr: Expression => Alias(expr, expr.simpleString)()
-      }
-    }
-    gd.toDF(aggExprs)
-  }
-
   def dfToRowRDD(df: DataFrame): JavaRDD[Array[Byte]] = {
     df.map(r => rowToRBytes(r))
   }
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/stat/StatFunctions.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/stat/StatFunctions.scala
index 71b7f6c2a6..d22f5fd2d4 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/stat/StatFunctions.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/stat/StatFunctions.scala
@@ -104,7 +104,7 @@ private[sql] object StatFunctions extends Logging {
   /** Generate a table of frequencies for the elements of two columns. */
   private[sql] def crossTabulate(df: DataFrame, col1: String, col2: String): DataFrame = {
     val tableName = s"${col1}_$col2"
-    val counts = df.groupBy(col1, col2).agg(col(col1), col(col2), count("*")).take(1e6.toInt)
+    val counts = df.groupBy(col1, col2).agg(count("*")).take(1e6.toInt)
     if (counts.length == 1e6.toInt) {
       logWarning("The maximum limit of 1e6 pairs have been collected, which may not be all of " +
         "the pairs. Please try reducing the amount of distinct items in your columns.")
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala
new file mode 100644
index 0000000000..35a574f354
--- /dev/null
+++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala
@@ -0,0 +1,193 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql
+
+import org.apache.spark.sql.TestData._
+import org.apache.spark.sql.functions._
+import org.apache.spark.sql.test.TestSQLContext
+import org.apache.spark.sql.test.TestSQLContext.implicits._
+import org.apache.spark.sql.types.DecimalType
+
+
+class DataFrameAggregateSuite extends QueryTest {
+
+  test("groupBy") {
+    checkAnswer(
+      testData2.groupBy("a").agg(sum($"b")),
+      Seq(Row(1, 3), Row(2, 3), Row(3, 3))
+    )
+    checkAnswer(
+      testData2.groupBy("a").agg(sum($"b").as("totB")).agg(sum('totB)),
+      Row(9)
+    )
+    checkAnswer(
+      testData2.groupBy("a").agg(count("*")),
+      Row(1, 2) :: Row(2, 2) :: Row(3, 2) :: Nil
+    )
+    checkAnswer(
+      testData2.groupBy("a").agg(Map("*" -> "count")),
+      Row(1, 2) :: Row(2, 2) :: Row(3, 2) :: Nil
+    )
+    checkAnswer(
+      testData2.groupBy("a").agg(Map("b" -> "sum")),
+      Row(1, 3) :: Row(2, 3) :: Row(3, 3) :: Nil
+    )
+
+    val df1 = Seq(("a", 1, 0, "b"), ("b", 2, 4, "c"), ("a", 2, 3, "d"))
+      .toDF("key", "value1", "value2", "rest")
+
+    checkAnswer(
+      df1.groupBy("key").min(),
+      df1.groupBy("key").min("value1", "value2").collect()
+    )
+    checkAnswer(
+      df1.groupBy("key").min("value2"),
+      Seq(Row("a", 0), Row("b", 4))
+    )
+  }
+
+  test("spark.sql.retainGroupColumns config") {
+    checkAnswer(
+      testData2.groupBy("a").agg(sum($"b")),
+      Seq(Row(1, 3), Row(2, 3), Row(3, 3))
+    )
+
+    TestSQLContext.conf.setConf("spark.sql.retainGroupColumns", "false")
+    checkAnswer(
+      testData2.groupBy("a").agg(sum($"b")),
+      Seq(Row(3), Row(3), Row(3))
+    )
+    TestSQLContext.conf.setConf("spark.sql.retainGroupColumns", "true")
+  }
+
+  test("agg without groups") {
+    checkAnswer(
+      testData2.agg(sum('b)),
+      Row(9)
+    )
+  }
+
+  test("average") {
+    checkAnswer(
+      testData2.agg(avg('a)),
+      Row(2.0))
+
+    // Also check mean
+    checkAnswer(
+      testData2.agg(mean('a)),
+      Row(2.0))
+
+    checkAnswer(
+      testData2.agg(avg('a), sumDistinct('a)), // non-partial
+      Row(2.0, 6.0) :: Nil)
+
+    checkAnswer(
+      decimalData.agg(avg('a)),
+      Row(new java.math.BigDecimal(2.0)))
+    checkAnswer(
+      decimalData.agg(avg('a), sumDistinct('a)), // non-partial
+      Row(new java.math.BigDecimal(2.0), new java.math.BigDecimal(6)) :: Nil)
+
+    checkAnswer(
+      decimalData.agg(avg('a cast DecimalType(10, 2))),
+      Row(new java.math.BigDecimal(2.0)))
+    // non-partial
+    checkAnswer(
+      decimalData.agg(avg('a cast DecimalType(10, 2)), sumDistinct('a cast DecimalType(10, 2))),
+      Row(new java.math.BigDecimal(2.0), new java.math.BigDecimal(6)) :: Nil)
+  }
+
+  test("null average") {
+    checkAnswer(
+      testData3.agg(avg('b)),
+      Row(2.0))
+
+    checkAnswer(
+      testData3.agg(avg('b), countDistinct('b)),
+      Row(2.0, 1))
+
+    checkAnswer(
+      testData3.agg(avg('b), sumDistinct('b)), // non-partial
+      Row(2.0, 2.0))
+  }
+
+  test("zero average") {
+    val emptyTableData = Seq.empty[(Int, Int)].toDF("a", "b")
+    checkAnswer(
+      emptyTableData.agg(avg('a)),
+      Row(null))
+
+    checkAnswer(
+      emptyTableData.agg(avg('a), sumDistinct('b)), // non-partial
+      Row(null, null))
+  }
+
+  test("count") {
+    assert(testData2.count() === testData2.map(_ => 1).count())
+
+    checkAnswer(
+      testData2.agg(count('a), sumDistinct('a)), // non-partial
+      Row(6, 6.0))
+  }
+
+  test("null count") {
+    checkAnswer(
+      testData3.groupBy('a).agg(count('b)),
+      Seq(Row(1,0), Row(2, 1))
+    )
+
+    checkAnswer(
+      testData3.groupBy('a).agg(count('a + 'b)),
+      Seq(Row(1,0), Row(2, 1))
+    )
+
+    checkAnswer(
+      testData3.agg(count('a), count('b), count(lit(1)), countDistinct('a), countDistinct('b)),
+      Row(2, 1, 2, 2, 1)
+    )
+
+    checkAnswer(
+      testData3.agg(count('b), countDistinct('b), sumDistinct('b)), // non-partial
+      Row(1, 1, 2)
+    )
+  }
+
+  test("zero count") {
+    val emptyTableData = Seq.empty[(Int, Int)].toDF("a", "b")
+    assert(emptyTableData.count() === 0)
+
+    checkAnswer(
+      emptyTableData.agg(count('a), sumDistinct('a)), // non-partial
+      Row(0, null))
+  }
+
+  test("zero sum") {
+    val emptyTableData = Seq.empty[(Int, Int)].toDF("a", "b")
+    checkAnswer(
+      emptyTableData.agg(sum('a)),
+      Row(null))
+  }
+
+  test("zero sum distinct") {
+    val emptyTableData = Seq.empty[(Int, Int)].toDF("a", "b")
+    checkAnswer(
+      emptyTableData.agg(sumDistinct('a)),
+      Row(null))
+  }
+
+}
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameJoinSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameJoinSuite.scala
index f005f55b64..787f3f175f 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameJoinSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameJoinSuite.scala
@@ -77,8 +77,8 @@ class DataFrameJoinSuite extends QueryTest {
       df.join(df, df("key") === df("key") && df("value") === 1),
       Row(1, "1", 1, "1") :: Nil)
 
-    val left = df.groupBy("key").agg($"key", count("*"))
-    val right = df.groupBy("key").agg($"key", sum("key"))
+    val left = df.groupBy("key").agg(count("*"))
+    val right = df.groupBy("key").agg(sum("key"))
     checkAnswer(
       left.join(right, left("key") === right("key")),
       Row(1, 1, 1, 1) :: Row(2, 1, 2, 2) :: Nil)
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala
index cf590cbd52..7552c12881 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala
@@ -22,7 +22,6 @@ import scala.language.postfixOps
 
 import org.apache.spark.sql.functions._
 import org.apache.spark.sql.types._
 import org.apache.spark.sql.test.{ExamplePointUDT, ExamplePoint, TestSQLContext}
-import org.apache.spark.sql.test.TestSQLContext.logicalPlanToSparkQuery
 import org.apache.spark.sql.test.TestSQLContext.implicits._
 
@@ -63,7 +62,7 @@ class DataFrameSuite extends QueryTest {
     val df = Seq((1,(1,1))).toDF()
 
     checkAnswer(
-      df.groupBy("_1").agg(col("_1"), sum("_2._1")).toDF("key", "total"),
+      df.groupBy("_1").agg(sum("_2._1")).toDF("key", "total"),
       Row(1, 1) :: Nil)
   }
 
@@ -128,7 +127,7 @@ class DataFrameSuite extends QueryTest {
       df2
         .select('_1 as 'letter, 'number)
         .groupBy('letter)
-        .agg('letter, countDistinct('number)),
+        .agg(countDistinct('number)),
       Row("a", 3) :: Row("b", 2) :: Row("c", 1) :: Nil
     )
   }
@@ -165,48 +164,6 @@ class DataFrameSuite extends QueryTest {
     testData.select('key).collect().toSeq)
   }
 
-  test("groupBy") {
-    checkAnswer(
-      testData2.groupBy("a").agg($"a", sum($"b")),
-      Seq(Row(1, 3), Row(2, 3), Row(3, 3))
-    )
-    checkAnswer(
-      testData2.groupBy("a").agg($"a", sum($"b").as("totB")).agg(sum('totB)),
-      Row(9)
-    )
-    checkAnswer(
-      testData2.groupBy("a").agg(col("a"), count("*")),
-      Row(1, 2) :: Row(2, 2) :: Row(3, 2) :: Nil
-    )
-    checkAnswer(
-      testData2.groupBy("a").agg(Map("*" -> "count")),
-      Row(1, 2) :: Row(2, 2) :: Row(3, 2) :: Nil
-    )
-    checkAnswer(
-      testData2.groupBy("a").agg(Map("b" -> "sum")),
-      Row(1, 3) :: Row(2, 3) :: Row(3, 3) :: Nil
-    )
-
-    val df1 = Seq(("a", 1, 0, "b"), ("b", 2, 4, "c"), ("a", 2, 3, "d"))
-      .toDF("key", "value1", "value2", "rest")
-
-    checkAnswer(
-      df1.groupBy("key").min(),
-      df1.groupBy("key").min("value1", "value2").collect()
-    )
-    checkAnswer(
-      df1.groupBy("key").min("value2"),
-      Seq(Row("a", 0), Row("b", 4))
-    )
-  }
-
-  test("agg without groups") {
-    checkAnswer(
-      testData2.agg(sum('b)),
-      Row(9)
-    )
-  }
-
   test("convert $\"attribute name\" into unresolved attribute") {
     checkAnswer(
       testData.where($"key" === lit(1)).select($"value"),
@@ -303,110 +260,6 @@ class DataFrameSuite extends QueryTest {
     mapData.take(1).map(r => Row.fromSeq(r.productIterator.toSeq)))
   }
 
-  test("average") {
-    checkAnswer(
-      testData2.agg(avg('a)),
-      Row(2.0))
-
-    // Also check mean
-    checkAnswer(
-      testData2.agg(mean('a)),
-      Row(2.0))
-
-    checkAnswer(
-      testData2.agg(avg('a), sumDistinct('a)), // non-partial
-      Row(2.0, 6.0) :: Nil)
-
-    checkAnswer(
-      decimalData.agg(avg('a)),
-      Row(new java.math.BigDecimal(2.0)))
-    checkAnswer(
-      decimalData.agg(avg('a), sumDistinct('a)), // non-partial
-      Row(new java.math.BigDecimal(2.0), new java.math.BigDecimal(6)) :: Nil)
-
-    checkAnswer(
-      decimalData.agg(avg('a cast DecimalType(10, 2))),
-      Row(new java.math.BigDecimal(2.0)))
-    // non-partial
-    checkAnswer(
-      decimalData.agg(avg('a cast DecimalType(10, 2)), sumDistinct('a cast DecimalType(10, 2))),
-      Row(new java.math.BigDecimal(2.0), new java.math.BigDecimal(6)) :: Nil)
-  }
-
-  test("null average") {
-    checkAnswer(
-      testData3.agg(avg('b)),
-      Row(2.0))
-
-    checkAnswer(
-      testData3.agg(avg('b), countDistinct('b)),
-      Row(2.0, 1))
-
-    checkAnswer(
-      testData3.agg(avg('b), sumDistinct('b)), // non-partial
-      Row(2.0, 2.0))
-  }
-
-  test("zero average") {
-    checkAnswer(
-      emptyTableData.agg(avg('a)),
-      Row(null))
-
-    checkAnswer(
-      emptyTableData.agg(avg('a), sumDistinct('b)), // non-partial
-      Row(null, null))
-  }
-
-  test("count") {
-    assert(testData2.count() === testData2.map(_ => 1).count())
-
-    checkAnswer(
-      testData2.agg(count('a), sumDistinct('a)), // non-partial
-      Row(6, 6.0))
-  }
-
-  test("null count") {
-    checkAnswer(
-      testData3.groupBy('a).agg('a, count('b)),
-      Seq(Row(1,0), Row(2, 1))
-    )
-
-    checkAnswer(
-      testData3.groupBy('a).agg('a, count('a + 'b)),
-      Seq(Row(1,0), Row(2, 1))
-    )
-
-    checkAnswer(
-      testData3.agg(count('a), count('b), count(lit(1)), countDistinct('a), countDistinct('b)),
-      Row(2, 1, 2, 2, 1)
-    )
-
-    checkAnswer(
-      testData3.agg(count('b), countDistinct('b), sumDistinct('b)), // non-partial
-      Row(1, 1, 2)
-    )
-  }
-
-  test("zero count") {
-    assert(emptyTableData.count() === 0)
-
-    checkAnswer(
-      emptyTableData.agg(count('a), sumDistinct('a)), // non-partial
-      Row(0, null))
-  }
-
-  test("zero sum") {
-    checkAnswer(
-      emptyTableData.agg(sum('a)),
-      Row(null))
-  }
-
-  test("zero sum distinct") {
-    checkAnswer(
-      emptyTableData.agg(sumDistinct('a)),
-      Row(null))
-  }
-
   test("except") {
     checkAnswer(
       lowerCaseData.except(upperCaseData),
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/TestData.scala b/sql/core/src/test/scala/org/apache/spark/sql/TestData.scala
index 225b51bd73..446771ab2a 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/TestData.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/TestData.scala
@@ -86,8 +86,6 @@ object TestData {
     TestData3(2, Some(2)) :: Nil).toDF()
   testData3.registerTempTable("testData3")
 
-  val emptyTableData = logical.LocalRelation($"a".int, $"b".int)
-
   case class UpperCaseData(N: Int, L: String)
   val upperCaseData = TestSQLContext.sparkContext.parallelize(
-- 
cgit v1.2.3
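
Usage sketch (not part of the patch): the snippet below illustrates the behavior
change, assuming a Spark 1.4-style SQLContext bound to a `sqlContext` val. The
data, column names, and commented-out results are hypothetical; the aggregate
column naming follows the 1.4-era convention seen in the Python doctest above.

    import org.apache.spark.sql.functions._
    import sqlContext.implicits._  // enables toDF on local Seqs

    val df = Seq(("a", 1), ("a", 2), ("b", 3)).toDF("key", "value")

    // New default: the grouping column "key" is retained in the output,
    // matching SQL's GROUP BY semantics.
    df.groupBy("key").agg(sum("value")).collect()
    // => Array(Row("a", 3), Row("b", 3))

    // The pre-patch behavior (aggregate columns only) remains available
    // through the config flag this change introduces.
    sqlContext.setConf("spark.sql.retainGroupColumns", "false")
    df.groupBy("key").agg(sum("value")).collect()
    // => Array(Row(3), Row(3))
    sqlContext.setConf("spark.sql.retainGroupColumns", "true")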