path: root/sql/core
author    Liang-Chi Hsieh <viirya@gmail.com>  2015-02-16 10:06:11 -0800
committer Reynold Xin <rxin@databricks.com>   2015-02-16 10:06:11 -0800
commit    5c78be7a515fc2fc92cda0517318e7b5d85762f4 (patch)
tree      b3685d0c4946bf4a005944a465753d8e308ca75c /sql/core
parent    a3afa4a1bff88c4d8a5228fcf1e0cfc132541a22 (diff)
[SPARK-5799][SQL] Compute aggregation function on specified numeric columns
Compute aggregation function on specified numeric columns. For example:

    val df = Seq(("a", 1, 0, "b"), ("b", 2, 4, "c"), ("a", 2, 3, "d"))
      .toDataFrame("key", "value1", "value2", "rest")
    df.groupBy("key").min("value2")

Author: Liang-Chi Hsieh <viirya@gmail.com>

Closes #4592 from viirya/specific_cols_agg and squashes the following commits:

9446896 [Liang-Chi Hsieh] For comments.
314c4cd [Liang-Chi Hsieh] Merge remote-tracking branch 'upstream/master' into specific_cols_agg
353fad7 [Liang-Chi Hsieh] For python unit tests.
54ed0c4 [Liang-Chi Hsieh] Address comments.
b079e6b [Liang-Chi Hsieh] Remove duplicate codes.
55100fb [Liang-Chi Hsieh] Merge remote-tracking branch 'upstream/master' into specific_cols_agg
880c2ac [Liang-Chi Hsieh] Fix Python style checks.
4c63a01 [Liang-Chi Hsieh] Fix pyspark.
b1a24fc [Liang-Chi Hsieh] Address comments.
2592f29 [Liang-Chi Hsieh] Merge remote-tracking branch 'upstream/master' into specific_cols_agg
27069c3 [Liang-Chi Hsieh] Combine functions and add varargs annotation.
371a3f7 [Liang-Chi Hsieh] Compute aggregation function on specified numeric columns.
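For orientation, a minimal usage sketch of the API this commit adds (a hypothetical REPL-style snippet; assumes a Spark 1.3-era SQLContext with its implicits imported so the Seq converts to a DataFrame via toDF, as in the test suite):

    import sqlContext.implicits._

    val df = Seq(("a", 1, 0, "b"), ("b", 2, 4, "c"), ("a", 2, 3, "d"))
      .toDF("key", "value1", "value2", "rest")

    // New: aggregate only the named numeric column(s).
    df.groupBy("key").min("value2").collect()
    // => rows ("a", 0) and ("b", 4), per the test case added below

    // Unchanged: with no arguments, all numeric columns are aggregated.
    df.groupBy("key").min().collect()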
Diffstat (limited to 'sql/core')
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/DataFrameImpl.scala  |  4
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/GroupedData.scala    | 57
-rw-r--r--  sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala | 12
3 files changed, 62 insertions, 11 deletions
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameImpl.scala b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameImpl.scala
index 7b7efbe347..9eb0c13140 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameImpl.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameImpl.scala
@@ -88,12 +88,12 @@ private[sql] class DataFrameImpl protected[sql](
}
}
- protected[sql] def numericColumns: Seq[Expression] = {
+ protected[sql] def numericColumns(): Seq[Expression] = {
schema.fields.filter(_.dataType.isInstanceOf[NumericType]).map { n =>
queryExecution.analyzed.resolve(n.name, sqlContext.analyzer.resolver).get
}
}
-
+
override def toDF(colNames: String*): DataFrame = {
require(schema.size == colNames.size,
"The number of columns doesn't match.\n" +
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/GroupedData.scala b/sql/core/src/main/scala/org/apache/spark/sql/GroupedData.scala
index 0868013fe7..a5a677b688 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/GroupedData.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/GroupedData.scala
@@ -23,6 +23,8 @@ import scala.collection.JavaConversions._
import org.apache.spark.sql.catalyst.analysis.Star
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.plans.logical.Aggregate
+import org.apache.spark.sql.types.NumericType
+
/**
@@ -39,13 +41,30 @@ class GroupedData protected[sql](df: DataFrameImpl, groupingExprs: Seq[Expressio
df.sqlContext, Aggregate(groupingExprs, namedGroupingExprs ++ aggExprs, df.logicalPlan))
}
- private[this] def aggregateNumericColumns(f: Expression => Expression): Seq[NamedExpression] = {
- df.numericColumns.map { c =>
+ private[this] def aggregateNumericColumns(colNames: String*)(f: Expression => Expression)
+ : Seq[NamedExpression] = {
+
+ val columnExprs = if (colNames.isEmpty) {
+ // No columns specified. Use all numeric columns.
+ df.numericColumns
+ } else {
+ // Make sure all specified columns are numeric
+ colNames.map { colName =>
+ val namedExpr = df.resolve(colName)
+ if (!namedExpr.dataType.isInstanceOf[NumericType]) {
+ throw new AnalysisException(
+ s""""$colName" is not a numeric column. """ +
+ "Aggregation function can only be performed on a numeric column.")
+ }
+ namedExpr
+ }
+ }
+ columnExprs.map { c =>
val a = f(c)
Alias(a, a.toString)()
}
}
-
+
private[this] def strToExpr(expr: String): (Expression => Expression) = {
expr.toLowerCase match {
case "avg" | "average" | "mean" => Average
@@ -152,30 +171,50 @@ class GroupedData protected[sql](df: DataFrameImpl, groupingExprs: Seq[Expressio
/**
* Compute the average value for each numeric columns for each group. This is an alias for `avg`.
* The resulting [[DataFrame]] will also contain the grouping columns.
+ * When specified columns are given, only compute the average values for them.
*/
- def mean(): DataFrame = aggregateNumericColumns(Average)
-
+ @scala.annotation.varargs
+ def mean(colNames: String*): DataFrame = {
+ aggregateNumericColumns(colNames:_*)(Average)
+ }
+
/**
* Compute the max value for each numeric columns for each group.
* The resulting [[DataFrame]] will also contain the grouping columns.
+ * When specified columns are given, only compute the max values for them.
*/
- def max(): DataFrame = aggregateNumericColumns(Max)
+ @scala.annotation.varargs
+ def max(colNames: String*): DataFrame = {
+ aggregateNumericColumns(colNames:_*)(Max)
+ }
/**
* Compute the mean value for each numeric columns for each group.
* The resulting [[DataFrame]] will also contain the grouping columns.
+ * When specified columns are given, only compute the mean values for them.
*/
- def avg(): DataFrame = aggregateNumericColumns(Average)
+ @scala.annotation.varargs
+ def avg(colNames: String*): DataFrame = {
+ aggregateNumericColumns(colNames:_*)(Average)
+ }
/**
* Compute the min value for each numeric column for each group.
* The resulting [[DataFrame]] will also contain the grouping columns.
+ * When specified columns are given, only compute the min values for them.
*/
- def min(): DataFrame = aggregateNumericColumns(Min)
+ @scala.annotation.varargs
+ def min(colNames: String*): DataFrame = {
+ aggregateNumericColumns(colNames:_*)(Min)
+ }
/**
* Compute the sum for each numeric columns for each group.
* The resulting [[DataFrame]] will also contain the grouping columns.
+ * When specified columns are given, only compute the sum for them.
*/
- def sum(): DataFrame = aggregateNumericColumns(Sum)
+ @scala.annotation.varargs
+ def sum(colNames: String*): DataFrame = {
+ aggregateNumericColumns(colNames:_*)(Sum)
+ }
}
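All of the varargs overloads above funnel into aggregateNumericColumns, which now validates the requested columns eagerly, before any aggregation runs. A hedged sketch of the failure path (reuses the df1 DataFrame defined in the test diff below; the message text comes from the hunk above):

    import org.apache.spark.sql.AnalysisException

    try {
      df1.groupBy("key").min("rest") // "rest" is a string column
    } catch {
      case e: AnalysisException =>
        println(e.getMessage)
        // "rest" is not a numeric column. Aggregation function can
        // only be performed on a numeric column.
    }

One design note: the @scala.annotation.varargs annotation makes the compiler emit an array-taking bridge method for each overload, so the same min/max/avg/sum/mean calls remain usable from Java.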
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala
index f0cd43632e..524571d9cc 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala
@@ -162,6 +162,18 @@ class DataFrameSuite extends QueryTest {
testData2.groupBy("a").agg(Map("b" -> "sum")),
Row(1, 3) :: Row(2, 3) :: Row(3, 3) :: Nil
)
+
+ val df1 = Seq(("a", 1, 0, "b"), ("b", 2, 4, "c"), ("a", 2, 3, "d"))
+ .toDF("key", "value1", "value2", "rest")
+
+ checkAnswer(
+ df1.groupBy("key").min(),
+ df1.groupBy("key").min("value1", "value2").collect()
+ )
+ checkAnswer(
+ df1.groupBy("key").min("value2"),
+ Seq(Row("a", 0), Row("b", 4))
+ )
}
test("agg without groups") {