author    Reynold Xin <rxin@databricks.com>    2015-02-04 18:35:51 -0800
committer Reynold Xin <rxin@databricks.com>    2015-02-04 18:35:51 -0800
commit    1fbd124b1bd6159086d8e88b139ce0817af02322 (patch)
tree      1124f3c60011e2ac5b0e3d867f32bb9c0277d41c /sql
parent    9a7ce70eabc0ccaa036e142fc97bf0d37faa0b63 (diff)
[SPARK-5605][SQL][DF] Allow using String to specify column name in DSL aggregate functions
Author: Reynold Xin <rxin@databricks.com>

Closes #4376 from rxin/SPARK-5605 and squashes the following commits:

c55f5fa [Reynold Xin] Added a Python test.
f4b8dbb [Reynold Xin] [SPARK-5605][SQL][DF] Allow using String to specify column name in DSL aggregate functions.
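A minimal sketch of what this change enables, assuming a hypothetical DataFrame df with department, age, and salary columns (the names are illustrative, not part of the commit):

    import org.apache.spark.sql.Column
    import org.apache.spark.sql.Dsl._

    // Before this patch, the DSL aggregate functions accepted only Column expressions:
    df.groupBy(Column("department")).agg(avg(Column("age")), max(Column("salary")))

    // With the String overloads added here, a plain column name works directly:
    df.groupBy("department").agg(avg("age"), max("salary"))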
Diffstat (limited to 'sql')
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala           |  8
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/DataFrameImpl.scala       |  8
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/Dsl.scala                 | 37
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/GroupedData.scala (renamed from sql/core/src/main/scala/org/apache/spark/sql/GroupedDataFrame.scala) |  2
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/IncomputableColumn.scala  |  4
5 files changed, 48 insertions(+), 11 deletions(-)
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala b/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala
index a4997fb293..92e04ce17c 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala
@@ -290,7 +290,7 @@ trait DataFrame extends RDDApi[Row] {
/**
* Groups the [[DataFrame]] using the specified columns, so we can run aggregation on them.
- * See [[GroupedDataFrame]] for all the available aggregate functions.
+ * See [[GroupedData]] for all the available aggregate functions.
*
* {{{
* // Compute the average for all numeric columns grouped by department.
@@ -304,11 +304,11 @@ trait DataFrame extends RDDApi[Row] {
* }}}
*/
@scala.annotation.varargs
- def groupBy(cols: Column*): GroupedDataFrame
+ def groupBy(cols: Column*): GroupedData
/**
* Groups the [[DataFrame]] using the specified columns, so we can run aggregation on them.
- * See [[GroupedDataFrame]] for all the available aggregate functions.
+ * See [[GroupedData]] for all the available aggregate functions.
*
* This is a variant of groupBy that can only group by existing columns using column names
* (i.e. cannot construct expressions).
@@ -325,7 +325,7 @@ trait DataFrame extends RDDApi[Row] {
* }}}
*/
@scala.annotation.varargs
- def groupBy(col1: String, cols: String*): GroupedDataFrame
+ def groupBy(col1: String, cols: String*): GroupedData
/**
* (Scala-specific) Compute aggregates by specifying a map from column name to
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameImpl.scala b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameImpl.scala
index c702adcb65..d6df927f9d 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameImpl.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameImpl.scala
@@ -201,13 +201,13 @@ private[sql] class DataFrameImpl protected[sql](
filter(condition)
}
- override def groupBy(cols: Column*): GroupedDataFrame = {
- new GroupedDataFrame(this, cols.map(_.expr))
+ override def groupBy(cols: Column*): GroupedData = {
+ new GroupedData(this, cols.map(_.expr))
}
- override def groupBy(col1: String, cols: String*): GroupedDataFrame = {
+ override def groupBy(col1: String, cols: String*): GroupedData = {
val colNames: Seq[String] = col1 +: cols
- new GroupedDataFrame(this, colNames.map(colName => resolve(colName)))
+ new GroupedData(this, colNames.map(colName => resolve(colName)))
}
override def limit(n: Int): DataFrame = {
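As the DataFrameImpl hunk above shows, the String variant of groupBy resolves each name against the DataFrame's own plan via resolve, while the Column variant defers resolution to the analyzer. A sketch of the resulting equivalence, again with a hypothetical df:

    // Effectively equivalent after this patch (for column names that exist in df):
    df.groupBy("department")
    df.groupBy(Column("department"))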
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/Dsl.scala b/sql/core/src/main/scala/org/apache/spark/sql/Dsl.scala
index 50f442dd87..9afe496edc 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/Dsl.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/Dsl.scala
@@ -94,38 +94,75 @@ object Dsl {
/** Aggregate function: returns the sum of all values in the expression. */
def sum(e: Column): Column = Sum(e.expr)
+ /** Aggregate function: returns the sum of all values in the given column. */
+ def sum(columnName: String): Column = sum(Column(columnName))
+
/** Aggregate function: returns the sum of distinct values in the expression. */
def sumDistinct(e: Column): Column = SumDistinct(e.expr)
+ /** Aggregate function: returns the sum of distinct values in the expression. */
+ def sumDistinct(columnName: String): Column = sumDistinct(Column(columnName))
+
/** Aggregate function: returns the number of items in a group. */
def count(e: Column): Column = Count(e.expr)
+ /** Aggregate function: returns the number of items in a group. */
+ def count(columnName: String): Column = count(Column(columnName))
+
/** Aggregate function: returns the number of distinct items in a group. */
@scala.annotation.varargs
def countDistinct(expr: Column, exprs: Column*): Column =
CountDistinct((expr +: exprs).map(_.expr))
+ /** Aggregate function: returns the number of distinct items in a group. */
+ @scala.annotation.varargs
+ def countDistinct(columnName: String, columnNames: String*): Column =
+ countDistinct(Column(columnName), columnNames.map(Column.apply) :_*)
+
/** Aggregate function: returns the approximate number of distinct items in a group. */
def approxCountDistinct(e: Column): Column = ApproxCountDistinct(e.expr)
/** Aggregate function: returns the approximate number of distinct items in a group. */
+ def approxCountDistinct(columnName: String): Column = approxCountDistinct(column(columnName))
+
+ /** Aggregate function: returns the approximate number of distinct items in a group. */
def approxCountDistinct(e: Column, rsd: Double): Column = ApproxCountDistinct(e.expr, rsd)
+ /** Aggregate function: returns the approximate number of distinct items in a group. */
+ def approxCountDistinct(columnName: String, rsd: Double): Column = {
+ approxCountDistinct(Column(columnName), rsd)
+ }
+
/** Aggregate function: returns the average of the values in a group. */
def avg(e: Column): Column = Average(e.expr)
+ /** Aggregate function: returns the average of the values in a group. */
+ def avg(columnName: String): Column = avg(Column(columnName))
+
/** Aggregate function: returns the first value in a group. */
def first(e: Column): Column = First(e.expr)
+ /** Aggregate function: returns the first value of a column in a group. */
+ def first(columnName: String): Column = first(Column(columnName))
+
/** Aggregate function: returns the last value in a group. */
def last(e: Column): Column = Last(e.expr)
+ /** Aggregate function: returns the last value of the column in a group. */
+ def last(columnName: String): Column = last(Column(columnName))
+
/** Aggregate function: returns the minimum value of the expression in a group. */
def min(e: Column): Column = Min(e.expr)
+ /** Aggregate function: returns the minimum value of the column in a group. */
+ def min(columnName: String): Column = min(Column(columnName))
+
/** Aggregate function: returns the maximum value of the expression in a group. */
def max(e: Column): Column = Max(e.expr)
+ /** Aggregate function: returns the maximum value of the column in a group. */
+ def max(columnName: String): Column = max(Column(columnName))
+
//////////////////////////////////////////////////////////////////////////////////////////////
//////////////////////////////////////////////////////////////////////////////////////////////
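Every String overload added to Dsl.scala above follows the same delegation pattern: wrap each name in a Column and call the existing expression-based overload. For the varargs case, as a sketch with hypothetical column names:

    // countDistinct("a", "b") expands to the Column-based form:
    countDistinct(Column("a"), Column("b"))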
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/GroupedDataFrame.scala b/sql/core/src/main/scala/org/apache/spark/sql/GroupedData.scala
index 7963cb0312..3c20676355 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/GroupedDataFrame.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/GroupedData.scala
@@ -28,7 +28,7 @@ import org.apache.spark.sql.catalyst.plans.logical.Aggregate
/**
* A set of methods for aggregations on a [[DataFrame]], created by [[DataFrame.groupBy]].
*/
-class GroupedDataFrame protected[sql](df: DataFrameImpl, groupingExprs: Seq[Expression]) {
+class GroupedData protected[sql](df: DataFrameImpl, groupingExprs: Seq[Expression]) {
private[this] implicit def toDataFrame(aggExprs: Seq[NamedExpression]): DataFrame = {
val namedGroupingExprs = groupingExprs.map {
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/IncomputableColumn.scala b/sql/core/src/main/scala/org/apache/spark/sql/IncomputableColumn.scala
index 6b032d3d69..fedd7f06ef 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/IncomputableColumn.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/IncomputableColumn.scala
@@ -90,9 +90,9 @@ private[sql] class IncomputableColumn(protected[sql] val expr: Expression) exten
override def apply(condition: Column): DataFrame = err()
- override def groupBy(cols: Column*): GroupedDataFrame = err()
+ override def groupBy(cols: Column*): GroupedData = err()
- override def groupBy(col1: String, cols: String*): GroupedDataFrame = err()
+ override def groupBy(col1: String, cols: String*): GroupedData = err()
override def limit(n: Int): DataFrame = err()