diff options
author | Liang-Chi Hsieh <viirya@gmail.com> | 2015-02-16 10:06:11 -0800 |
---|---|---|
committer | Reynold Xin <rxin@databricks.com> | 2015-02-16 10:06:22 -0800 |
commit | 0165e9d1324e24571c702b32d8d76edca8808887 (patch) | |
tree | 5ba3ab31177f72ddcdc9d0a41887e2d208ff9d67 /sql | |
parent | 78f7edb85be5a397c0d1a2f3fd26aa83675cc0b1 (diff) | |
download | spark-0165e9d1324e24571c702b32d8d76edca8808887.tar.gz spark-0165e9d1324e24571c702b32d8d76edca8808887.tar.bz2 spark-0165e9d1324e24571c702b32d8d76edca8808887.zip |
[SPARK-5799][SQL] Compute aggregation function on specified numeric columns
Compute aggregation function on specified numeric columns. For example:
val df = Seq(("a", 1, 0, "b"), ("b", 2, 4, "c"), ("a", 2, 3, "d")).toDataFrame("key", "value1", "value2", "rest")
df.groupBy("key").min("value2")
Author: Liang-Chi Hsieh <viirya@gmail.com>
Closes #4592 from viirya/specific_cols_agg and squashes the following commits:
9446896 [Liang-Chi Hsieh] For comments.
314c4cd [Liang-Chi Hsieh] Merge remote-tracking branch 'upstream/master' into specific_cols_agg
353fad7 [Liang-Chi Hsieh] For python unit tests.
54ed0c4 [Liang-Chi Hsieh] Address comments.
b079e6b [Liang-Chi Hsieh] Remove duplicate codes.
55100fb [Liang-Chi Hsieh] Merge remote-tracking branch 'upstream/master' into specific_cols_agg
880c2ac [Liang-Chi Hsieh] Fix Python style checks.
4c63a01 [Liang-Chi Hsieh] Fix pyspark.
b1a24fc [Liang-Chi Hsieh] Address comments.
2592f29 [Liang-Chi Hsieh] Merge remote-tracking branch 'upstream/master' into specific_cols_agg
27069c3 [Liang-Chi Hsieh] Combine functions and add varargs annotation.
371a3f7 [Liang-Chi Hsieh] Compute aggregation function on specified numeric columns.
(cherry picked from commit 5c78be7a515fc2fc92cda0517318e7b5d85762f4)
Signed-off-by: Reynold Xin <rxin@databricks.com>
Diffstat (limited to 'sql')
3 files changed, 62 insertions, 11 deletions
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameImpl.scala b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameImpl.scala index 7b7efbe347..9eb0c13140 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameImpl.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameImpl.scala @@ -88,12 +88,12 @@ private[sql] class DataFrameImpl protected[sql]( } } - protected[sql] def numericColumns: Seq[Expression] = { + protected[sql] def numericColumns(): Seq[Expression] = { schema.fields.filter(_.dataType.isInstanceOf[NumericType]).map { n => queryExecution.analyzed.resolve(n.name, sqlContext.analyzer.resolver).get } } - + override def toDF(colNames: String*): DataFrame = { require(schema.size == colNames.size, "The number of columns doesn't match.\n" + diff --git a/sql/core/src/main/scala/org/apache/spark/sql/GroupedData.scala b/sql/core/src/main/scala/org/apache/spark/sql/GroupedData.scala index 0868013fe7..a5a677b688 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/GroupedData.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/GroupedData.scala @@ -23,6 +23,8 @@ import scala.collection.JavaConversions._ import org.apache.spark.sql.catalyst.analysis.Star import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.plans.logical.Aggregate +import org.apache.spark.sql.types.NumericType + /** @@ -39,13 +41,30 @@ class GroupedData protected[sql](df: DataFrameImpl, groupingExprs: Seq[Expressio df.sqlContext, Aggregate(groupingExprs, namedGroupingExprs ++ aggExprs, df.logicalPlan)) } - private[this] def aggregateNumericColumns(f: Expression => Expression): Seq[NamedExpression] = { - df.numericColumns.map { c => + private[this] def aggregateNumericColumns(colNames: String*)(f: Expression => Expression) + : Seq[NamedExpression] = { + + val columnExprs = if (colNames.isEmpty) { + // No columns specified. Use all numeric columns. + df.numericColumns + } else { + // Make sure all specified columns are numeric + colNames.map { colName => + val namedExpr = df.resolve(colName) + if (!namedExpr.dataType.isInstanceOf[NumericType]) { + throw new AnalysisException( + s""""$colName" is not a numeric column. """ + + "Aggregation function can only be performed on a numeric column.") + } + namedExpr + } + } + columnExprs.map { c => val a = f(c) Alias(a, a.toString)() } } - + private[this] def strToExpr(expr: String): (Expression => Expression) = { expr.toLowerCase match { case "avg" | "average" | "mean" => Average @@ -152,30 +171,50 @@ class GroupedData protected[sql](df: DataFrameImpl, groupingExprs: Seq[Expressio /** * Compute the average value for each numeric columns for each group. This is an alias for `avg`. * The resulting [[DataFrame]] will also contain the grouping columns. + * When specified columns are given, only compute the average values for them. */ - def mean(): DataFrame = aggregateNumericColumns(Average) - + @scala.annotation.varargs + def mean(colNames: String*): DataFrame = { + aggregateNumericColumns(colNames:_*)(Average) + } + /** * Compute the max value for each numeric columns for each group. * The resulting [[DataFrame]] will also contain the grouping columns. + * When specified columns are given, only compute the max values for them. */ - def max(): DataFrame = aggregateNumericColumns(Max) + @scala.annotation.varargs + def max(colNames: String*): DataFrame = { + aggregateNumericColumns(colNames:_*)(Max) + } /** * Compute the mean value for each numeric columns for each group. * The resulting [[DataFrame]] will also contain the grouping columns. + * When specified columns are given, only compute the mean values for them. */ - def avg(): DataFrame = aggregateNumericColumns(Average) + @scala.annotation.varargs + def avg(colNames: String*): DataFrame = { + aggregateNumericColumns(colNames:_*)(Average) + } /** * Compute the min value for each numeric column for each group. * The resulting [[DataFrame]] will also contain the grouping columns. + * When specified columns are given, only compute the min values for them. */ - def min(): DataFrame = aggregateNumericColumns(Min) + @scala.annotation.varargs + def min(colNames: String*): DataFrame = { + aggregateNumericColumns(colNames:_*)(Min) + } /** * Compute the sum for each numeric columns for each group. * The resulting [[DataFrame]] will also contain the grouping columns. + * When specified columns are given, only compute the sum for them. */ - def sum(): DataFrame = aggregateNumericColumns(Sum) + @scala.annotation.varargs + def sum(colNames: String*): DataFrame = { + aggregateNumericColumns(colNames:_*)(Sum) + } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala index f0cd43632e..524571d9cc 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala @@ -162,6 +162,18 @@ class DataFrameSuite extends QueryTest { testData2.groupBy("a").agg(Map("b" -> "sum")), Row(1, 3) :: Row(2, 3) :: Row(3, 3) :: Nil ) + + val df1 = Seq(("a", 1, 0, "b"), ("b", 2, 4, "c"), ("a", 2, 3, "d")) + .toDF("key", "value1", "value2", "rest") + + checkAnswer( + df1.groupBy("key").min(), + df1.groupBy("key").min("value1", "value2").collect() + ) + checkAnswer( + df1.groupBy("key").min("value2"), + Seq(Row("a", 0), Row("b", 4)) + ) } test("agg without groups") { |