-rw-r--r--  python/pyspark/sql/dataframe.py                                     74
-rw-r--r--  python/pyspark/sql/functions.py                                      2
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/DataFrameImpl.scala     4
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/GroupedData.scala      57
-rw-r--r--  sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala   12
5 files changed, 123 insertions(+), 26 deletions(-)
diff --git a/python/pyspark/sql/dataframe.py b/python/pyspark/sql/dataframe.py
index 1438fe5285..28a59e73a3 100644
--- a/python/pyspark/sql/dataframe.py
+++ b/python/pyspark/sql/dataframe.py
@@ -664,6 +664,18 @@ def dfapi(f):
return _api
+def df_varargs_api(f):
+ def _api(self, *args):
+ jargs = ListConverter().convert(args,
+ self.sql_ctx._sc._gateway._gateway_client)
+ name = f.__name__
+ jdf = getattr(self._jdf, name)(self.sql_ctx._sc._jvm.PythonUtils.toSeq(jargs))
+ return DataFrame(jdf, self.sql_ctx)
+ _api.__name__ = f.__name__
+ _api.__doc__ = f.__doc__
+ return _api
+
+
class GroupedData(object):
"""
@@ -714,30 +726,60 @@ class GroupedData(object):
[Row(age=2, count=1), Row(age=5, count=1)]
"""
- @dfapi
- def mean(self):
+ @df_varargs_api
+ def mean(self, *cols):
"""Compute the average value for each numeric columns
- for each group. This is an alias for `avg`."""
+ for each group. This is an alias for `avg`.
- @dfapi
- def avg(self):
+ >>> df.groupBy().mean('age').collect()
+ [Row(AVG(age#0)=3.5)]
+ >>> df3.groupBy().mean('age', 'height').collect()
+ [Row(AVG(age#4)=3.5, AVG(height#5)=82.5)]
+ """
+
+ @df_varargs_api
+ def avg(self, *cols):
"""Compute the average value for each numeric columns
- for each group."""
+ for each group.
- @dfapi
- def max(self):
+ >>> df.groupBy().avg('age').collect()
+ [Row(AVG(age#0)=3.5)]
+ >>> df3.groupBy().avg('age', 'height').collect()
+ [Row(AVG(age#4)=3.5, AVG(height#5)=82.5)]
+ """
+
+ @df_varargs_api
+ def max(self, *cols):
"""Compute the max value for each numeric columns for
- each group. """
+ each group.
- @dfapi
- def min(self):
+ >>> df.groupBy().max('age').collect()
+ [Row(MAX(age#0)=5)]
+ >>> df3.groupBy().max('age', 'height').collect()
+ [Row(MAX(age#4)=5, MAX(height#5)=85)]
+ """
+
+ @df_varargs_api
+ def min(self, *cols):
"""Compute the min value for each numeric column for
- each group."""
+ each group.
- @dfapi
- def sum(self):
+ >>> df.groupBy().min('age').collect()
+ [Row(MIN(age#0)=2)]
+ >>> df3.groupBy().min('age', 'height').collect()
+ [Row(MIN(age#4)=2, MIN(height#5)=80)]
+ """
+
+ @df_varargs_api
+ def sum(self, *cols):
"""Compute the sum for each numeric columns for each
- group."""
+ group.
+
+ >>> df.groupBy().sum('age').collect()
+ [Row(SUM(age#0)=7)]
+ >>> df3.groupBy().sum('age', 'height').collect()
+ [Row(SUM(age#4)=7, SUM(height#5)=165)]
+ """
def _create_column_from_literal(literal):
@@ -945,6 +987,8 @@ def _test():
globs['sqlCtx'] = SQLContext(sc)
globs['df'] = sc.parallelize([Row(name='Alice', age=2), Row(name='Bob', age=5)]).toDF()
globs['df2'] = sc.parallelize([Row(name='Tom', height=80), Row(name='Bob', height=85)]).toDF()
+ globs['df3'] = sc.parallelize([Row(name='Alice', age=2, height=80),
+ Row(name='Bob', age=5, height=85)]).toDF()
(failure_count, test_count) = doctest.testmod(
pyspark.sql.dataframe, globs=globs,
optionflags=doctest.ELLIPSIS | doctest.NORMALIZE_WHITESPACE)
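
The dataframe.py hunks above replace the zero-argument `dfapi` wrappers with `df_varargs_api`, which forwards any column names through Py4J as a Scala Seq. A minimal usage sketch, assuming a live SparkContext `sc` and the same `df3` fixture defined in the doctest globals above:

    from pyspark.sql import Row

    df3 = sc.parallelize([Row(name='Alice', age=2, height=80),
                          Row(name='Bob', age=5, height=85)]).toDF()

    # No arguments: aggregate every numeric column (the old behavior).
    df3.groupBy().sum().collect()

    # Column names: aggregate only the listed columns (the new behavior).
    df3.groupBy().mean('age', 'height').collect()
    # [Row(AVG(age#...)=3.5, AVG(height#...)=82.5)]  (expression ids vary by session)
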
diff --git a/python/pyspark/sql/functions.py b/python/pyspark/sql/functions.py
index 39aa550eeb..d0e090607f 100644
--- a/python/pyspark/sql/functions.py
+++ b/python/pyspark/sql/functions.py
@@ -158,6 +158,8 @@ def _test():
globs['sqlCtx'] = SQLContext(sc)
globs['df'] = sc.parallelize([Row(name='Alice', age=2), Row(name='Bob', age=5)]).toDF()
globs['df2'] = sc.parallelize([Row(name='Tom', height=80), Row(name='Bob', height=85)]).toDF()
+ globs['df3'] = sc.parallelize([Row(name='Alice', age=2, height=80),
+ Row(name='Bob', age=5, height=85)]).toDF()
(failure_count, test_count) = doctest.testmod(
pyspark.sql.functions, globs=globs,
optionflags=doctest.ELLIPSIS | doctest.NORMALIZE_WHITESPACE)
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameImpl.scala b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameImpl.scala
index 7b7efbe347..9eb0c13140 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameImpl.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameImpl.scala
@@ -88,12 +88,12 @@ private[sql] class DataFrameImpl protected[sql](
}
}
- protected[sql] def numericColumns: Seq[Expression] = {
+ protected[sql] def numericColumns(): Seq[Expression] = {
schema.fields.filter(_.dataType.isInstanceOf[NumericType]).map { n =>
queryExecution.analyzed.resolve(n.name, sqlContext.analyzer.resolver).get
}
}
-
+
override def toDF(colNames: String*): DataFrame = {
require(schema.size == colNames.size,
"The number of columns doesn't match.\n" +
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/GroupedData.scala b/sql/core/src/main/scala/org/apache/spark/sql/GroupedData.scala
index 0868013fe7..a5a677b688 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/GroupedData.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/GroupedData.scala
@@ -23,6 +23,8 @@ import scala.collection.JavaConversions._
import org.apache.spark.sql.catalyst.analysis.Star
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.plans.logical.Aggregate
+import org.apache.spark.sql.types.NumericType
+
/**
@@ -39,13 +41,30 @@ class GroupedData protected[sql](df: DataFrameImpl, groupingExprs: Seq[Expressio
df.sqlContext, Aggregate(groupingExprs, namedGroupingExprs ++ aggExprs, df.logicalPlan))
}
- private[this] def aggregateNumericColumns(f: Expression => Expression): Seq[NamedExpression] = {
- df.numericColumns.map { c =>
+ private[this] def aggregateNumericColumns(colNames: String*)(f: Expression => Expression)
+ : Seq[NamedExpression] = {
+
+ val columnExprs = if (colNames.isEmpty) {
+ // No columns specified. Use all numeric columns.
+ df.numericColumns
+ } else {
+ // Make sure all specified columns are numeric
+ colNames.map { colName =>
+ val namedExpr = df.resolve(colName)
+ if (!namedExpr.dataType.isInstanceOf[NumericType]) {
+ throw new AnalysisException(
+ s""""$colName" is not a numeric column. """ +
+ "Aggregation function can only be performed on a numeric column.")
+ }
+ namedExpr
+ }
+ }
+ columnExprs.map { c =>
val a = f(c)
Alias(a, a.toString)()
}
}
-
+
private[this] def strToExpr(expr: String): (Expression => Expression) = {
expr.toLowerCase match {
case "avg" | "average" | "mean" => Average
@@ -152,30 +171,50 @@ class GroupedData protected[sql](df: DataFrameImpl, groupingExprs: Seq[Expressio
/**
* Compute the average value for each numeric column for each group. This is an alias for `avg`.
* The resulting [[DataFrame]] will also contain the grouping columns.
+ * When column names are given, compute the average only for those columns.
*/
- def mean(): DataFrame = aggregateNumericColumns(Average)
-
+ @scala.annotation.varargs
+ def mean(colNames: String*): DataFrame = {
+ aggregateNumericColumns(colNames:_*)(Average)
+ }
+
/**
* Compute the max value for each numeric column for each group.
* The resulting [[DataFrame]] will also contain the grouping columns.
+ * When column names are given, compute the max only for those columns.
*/
- def max(): DataFrame = aggregateNumericColumns(Max)
+ @scala.annotation.varargs
+ def max(colNames: String*): DataFrame = {
+ aggregateNumericColumns(colNames:_*)(Max)
+ }
/**
* Compute the mean value for each numeric column for each group.
* The resulting [[DataFrame]] will also contain the grouping columns.
+ * When column names are given, compute the mean only for those columns.
*/
- def avg(): DataFrame = aggregateNumericColumns(Average)
+ @scala.annotation.varargs
+ def avg(colNames: String*): DataFrame = {
+ aggregateNumericColumns(colNames:_*)(Average)
+ }
/**
* Compute the min value for each numeric column for each group.
* The resulting [[DataFrame]] will also contain the grouping columns.
+ * When column names are given, compute the min only for those columns.
*/
- def min(): DataFrame = aggregateNumericColumns(Min)
+ @scala.annotation.varargs
+ def min(colNames: String*): DataFrame = {
+ aggregateNumericColumns(colNames:_*)(Min)
+ }
/**
* Compute the sum for each numeric column for each group.
* The resulting [[DataFrame]] will also contain the grouping columns.
+ * When column names are given, compute the sum only for those columns.
*/
- def sum(): DataFrame = aggregateNumericColumns(Sum)
+ @scala.annotation.varargs
+ def sum(colNames: String*): DataFrame = {
+ aggregateNumericColumns(colNames:_*)(Sum)
+ }
}
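
The effect of the new validation branch in aggregateNumericColumns is easiest to see from the Python side, where the same varargs API is exposed. A hedged sketch, assuming the `sc` fixture above; in this version the Scala AnalysisException reaches PySpark wrapped in a Py4J error:

    from pyspark.sql import Row
    from py4j.protocol import Py4JJavaError

    df = sc.parallelize([Row(name='Alice', age=2),
                         Row(name='Bob', age=5)]).toDF()

    try:
        # 'name' is a string column, so column resolution now fails fast
        # instead of silently producing no aggregate for it.
        df.groupBy().max('name').collect()
    except Py4JJavaError as e:
        print(e)  # ... "name" is not a numeric column ...
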
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala
index f0cd43632e..524571d9cc 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala
@@ -162,6 +162,18 @@ class DataFrameSuite extends QueryTest {
testData2.groupBy("a").agg(Map("b" -> "sum")),
Row(1, 3) :: Row(2, 3) :: Row(3, 3) :: Nil
)
+
+ val df1 = Seq(("a", 1, 0, "b"), ("b", 2, 4, "c"), ("a", 2, 3, "d"))
+ .toDF("key", "value1", "value2", "rest")
+
+ checkAnswer(
+ df1.groupBy("key").min(),
+ df1.groupBy("key").min("value1", "value2").collect()
+ )
+ checkAnswer(
+ df1.groupBy("key").min("value2"),
+ Seq(Row("a", 0), Row("b", 4))
+ )
}
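
The first assertion above pins down the compatibility contract: a bare min() must agree with min over the full list of numeric columns (the non-numeric "rest" column is skipped either way). The same check can be mirrored from PySpark; a sketch with hypothetical data matching the Scala test, again assuming the `sc` fixture:

    from pyspark.sql import Row

    rows = [Row(key='a', value1=1, value2=0, rest='b'),
            Row(key='b', value1=2, value2=4, rest='c'),
            Row(key='a', value1=2, value2=3, rest='d')]
    df1 = sc.parallelize(rows).toDF()

    implicit_min = sorted(df1.groupBy('key').min().collect())
    explicit_min = sorted(df1.groupBy('key').min('value1', 'value2').collect())
    assert implicit_min == explicit_min
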
test("agg without groups") {