author    Cheng Hao <hao.cheng@intel.com>    2015-05-20 19:58:22 -0700
committer Yin Huai <yhuai@databricks.com>    2015-05-20 19:58:22 -0700
commit    42c592adb381ff20832cce55e0849ed68dd7eee4 (patch)
tree      d574a7389ec67b5837ef0ec7eb26c9d456fe810c /sql
parent    895baf8f77e630ce32b0e25b00bf5ee45d17398f (diff)
[SPARK-7320] [SQL] Add Cube / Rollup for dataframe
This is a follow-up for #6257, which broke the maven test.

Add cube & rollup for DataFrame. For example:

```scala
testData.rollup($"a" + $"b", $"b").agg(sum($"a" - $"b"))
testData.cube($"a" + $"b", $"b").agg(sum($"a" - $"b"))
```

Author: Cheng Hao <hao.cheng@intel.com>

Closes #6304 from chenghao-intel/rollup and squashes the following commits:

04bb1de [Cheng Hao] move the table register/unregister into beforeAll/afterAll
a6069f1 [Cheng Hao] cancel the implicit keyword
ced4b8f [Cheng Hao] remove the unnecessary code changes
9959dfa [Cheng Hao] update the code as comments
e1d88aa [Cheng Hao] update the code as suggested
03bc3d9 [Cheng Hao] Remove the CubedData & RollupedData
5fd62d0 [Cheng Hao] hide the CubedData & RollupedData
5ffb196 [Cheng Hao] Add Cube / Rollup for dataframe
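For readers new to these operators: `rollup` and `cube` differ only in which grouping sets they aggregate over. A minimal sketch of the distinction (assumes a Spark 1.4-era `SQLContext` named `sqlContext` in scope; the column names are made up for illustration):

```scala
import org.apache.spark.sql.functions._
import sqlContext.implicits._  // assumption: an in-scope SQLContext named sqlContext

// A hypothetical two-column DataFrame
val df = Seq((1, 2), (2, 4)).toDF("a", "b")

// rollup(a, b) aggregates over the grouping sets {(a, b), (a), ()},
// i.e. n columns yield n + 1 grouping sets
df.rollup($"a", $"b").agg(sum($"b")).show()

// cube(a, b) aggregates over every subset {(a, b), (a), (b), ()},
// i.e. n columns yield 2^n grouping sets
df.cube($"a", $"b").agg(sum($"b")).show()
```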
Diffstat (limited to 'sql')
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala                         | 104
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/GroupedData.scala                       |  92
-rw-r--r--  sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveDataFrameAnalyticsSuite.scala  |  69
3 files changed, 237 insertions(+), 28 deletions(-)
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala b/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala
index adad85806d..d78b4c2f89 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala
@@ -685,7 +685,53 @@ class DataFrame private[sql](
* @since 1.3.0
*/
@scala.annotation.varargs
- def groupBy(cols: Column*): GroupedData = new GroupedData(this, cols.map(_.expr))
+ def groupBy(cols: Column*): GroupedData = {
+ GroupedData(this, cols.map(_.expr), GroupedData.GroupByType)
+ }
+
+ /**
+ * Create a multi-dimensional rollup for the current [[DataFrame]] using the specified columns,
+ * so we can run aggregation on them.
+ * See [[GroupedData]] for all the available aggregate functions.
+ *
+ * {{{
+ * // Compute the average for all numeric columns rolled up by department and group.
+ * df.rollup($"department", $"group").avg()
+ *
+ * // Compute the max age and average salary, rolled up by department and gender.
+ * df.rollup($"department", $"gender").agg(Map(
+ * "salary" -> "avg",
+ * "age" -> "max"
+ * ))
+ * }}}
+ * @group dfops
+ * @since 1.4.0
+ */
+ @scala.annotation.varargs
+ def rollup(cols: Column*): GroupedData = {
+ GroupedData(this, cols.map(_.expr), GroupedData.RollupType)
+ }
+
+ /**
+ * Create a multi-dimensional cube for the current [[DataFrame]] using the specified columns,
+ * so we can run aggregation on them.
+ * See [[GroupedData]] for all the available aggregate functions.
+ *
+ * {{{
+ * // Compute the average for all numeric columns cubed by department and group.
+ * df.cube($"department", $"group").avg()
+ *
+ * // Compute the max age and average salary, cubed by department and gender.
+ * df.cube($"department", $"gender").agg(Map(
+ * "salary" -> "avg",
+ * "age" -> "max"
+ * ))
+ * }}}
+ * @group dfops
+ * @since 1.4.0
+ */
+ @scala.annotation.varargs
+ def cube(cols: Column*): GroupedData = GroupedData(this, cols.map(_.expr), GroupedData.CubeType)
/**
* Groups the [[DataFrame]] using the specified columns, so we can run aggregation on them.
@@ -710,7 +756,61 @@ class DataFrame private[sql](
@scala.annotation.varargs
def groupBy(col1: String, cols: String*): GroupedData = {
val colNames: Seq[String] = col1 +: cols
- new GroupedData(this, colNames.map(colName => resolve(colName)))
+ GroupedData(this, colNames.map(colName => resolve(colName)), GroupedData.GroupByType)
+ }
+
+ /**
+ * Create a multi-dimensional rollup for the current [[DataFrame]] using the specified columns,
+ * so we can run aggregation on them.
+ * See [[GroupedData]] for all the available aggregate functions.
+ *
+ * This is a variant of rollup that can only group by existing columns using column names
+ * (i.e. cannot construct expressions).
+ *
+ * {{{
+ * // Compute the average for all numeric columns rolled up by department and group.
+ * df.rollup("department", "group").avg()
+ *
+ * // Compute the max age and average salary, rolled up by department and gender.
+ * df.rollup("department", "gender").agg(Map(
+ * "salary" -> "avg",
+ * "age" -> "max"
+ * ))
+ * }}}
+ * @group dfops
+ * @since 1.4.0
+ */
+ @scala.annotation.varargs
+ def rollup(col1: String, cols: String*): GroupedData = {
+ val colNames: Seq[String] = col1 +: cols
+ GroupedData(this, colNames.map(colName => resolve(colName)), GroupedData.RollupType)
+ }
+
+ /**
+ * Create a multi-dimensional cube for the current [[DataFrame]] using the specified columns,
+ * so we can run aggregation on them.
+ * See [[GroupedData]] for all the available aggregate functions.
+ *
+ * This is a variant of cube that can only group by existing columns using column names
+ * (i.e. cannot construct expressions).
+ *
+ * {{{
+ * // Compute the average for all numeric columns cubed by department and group.
+ * df.cube("department", "group").avg()
+ *
+ * // Compute the max age and average salary, cubed by department and gender.
+ * df.cube("department", "gender").agg(Map(
+ * "salary" -> "avg",
+ * "age" -> "max"
+ * ))
+ * }}}
+ * @group dfops
+ * @since 1.4.0
+ */
+ @scala.annotation.varargs
+ def cube(col1: String, cols: String*): GroupedData = {
+ val colNames: Seq[String] = col1 +: cols
+ GroupedData(this, colNames.map(colName => resolve(colName)), GroupedData.CubeType)
}
/**
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/GroupedData.scala b/sql/core/src/main/scala/org/apache/spark/sql/GroupedData.scala
index 1381b9f1a6..f730e4ae00 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/GroupedData.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/GroupedData.scala
@@ -23,9 +23,40 @@ import scala.language.implicitConversions
import org.apache.spark.annotation.Experimental
import org.apache.spark.sql.catalyst.analysis.Star
import org.apache.spark.sql.catalyst.expressions._
-import org.apache.spark.sql.catalyst.plans.logical.Aggregate
+import org.apache.spark.sql.catalyst.plans.logical.{Rollup, Cube, Aggregate}
import org.apache.spark.sql.types.NumericType
+/**
+ * Companion object for GroupedData
+ */
+private[sql] object GroupedData {
+ def apply(
+ df: DataFrame,
+ groupingExprs: Seq[Expression],
+ groupType: GroupType): GroupedData = {
+ new GroupedData(df, groupingExprs, groupType)
+ }
+
+ /**
+ * The type of grouping applied to a [[GroupedData]]
+ */
+ trait GroupType
+
+ /**
+ * Indicates a standard GROUP BY
+ */
+ object GroupByType extends GroupType
+
+ /**
+ * Indicates a CUBE
+ */
+ object CubeType extends GroupType
+
+ /**
+ * Indicates a ROLLUP
+ */
+ object RollupType extends GroupType
+}
/**
* :: Experimental ::
@@ -34,19 +65,37 @@ import org.apache.spark.sql.types.NumericType
* @since 1.3.0
*/
@Experimental
-class GroupedData protected[sql](df: DataFrame, groupingExprs: Seq[Expression]) {
+class GroupedData protected[sql](
+ df: DataFrame,
+ groupingExprs: Seq[Expression],
+ private val groupType: GroupedData.GroupType) {
- private[sql] implicit def toDF(aggExprs: Seq[NamedExpression]): DataFrame = {
- val namedGroupingExprs = groupingExprs.map {
- case expr: NamedExpression => expr
- case expr: Expression => Alias(expr, expr.prettyString)()
+ private[this] def toDF(aggExprs: Seq[NamedExpression]): DataFrame = {
+ val aggregates = if (df.sqlContext.conf.dataFrameRetainGroupColumns) {
+ val retainedExprs = groupingExprs.map {
+ case expr: NamedExpression => expr
+ case expr: Expression => Alias(expr, expr.prettyString)()
+ }
+ retainedExprs ++ aggExprs
+ } else {
+ aggExprs
+ }
+
+ groupType match {
+ case GroupedData.GroupByType =>
+ DataFrame(
+ df.sqlContext, Aggregate(groupingExprs, aggregates, df.logicalPlan))
+ case GroupedData.RollupType =>
+ DataFrame(
+ df.sqlContext, Rollup(groupingExprs, df.logicalPlan, aggregates))
+ case GroupedData.CubeType =>
+ DataFrame(
+ df.sqlContext, Cube(groupingExprs, df.logicalPlan, aggregates))
}
- DataFrame(
- df.sqlContext, Aggregate(groupingExprs, namedGroupingExprs ++ aggExprs, df.logicalPlan))
}
private[this] def aggregateNumericColumns(colNames: String*)(f: Expression => Expression)
- : Seq[NamedExpression] = {
+ : DataFrame = {
val columnExprs = if (colNames.isEmpty) {
// No columns specified. Use all numeric columns.
@@ -63,10 +112,10 @@ class GroupedData protected[sql](df: DataFrame, groupingExprs: Seq[Expression])
namedExpr
}
}
- columnExprs.map { c =>
+ toDF(columnExprs.map { c =>
val a = f(c)
Alias(a, a.prettyString)()
- }
+ })
}
private[this] def strToExpr(expr: String): (Expression => Expression) = {
@@ -119,10 +168,10 @@ class GroupedData protected[sql](df: DataFrame, groupingExprs: Seq[Expression])
* @since 1.3.0
*/
def agg(exprs: Map[String, String]): DataFrame = {
- exprs.map { case (colName, expr) =>
+ toDF(exprs.map { case (colName, expr) =>
val a = strToExpr(expr)(df(colName).expr)
Alias(a, a.prettyString)()
- }.toSeq
+ }.toSeq)
}
/**
@@ -175,19 +224,10 @@ class GroupedData protected[sql](df: DataFrame, groupingExprs: Seq[Expression])
*/
@scala.annotation.varargs
def agg(expr: Column, exprs: Column*): DataFrame = {
- val aggExprs = (expr +: exprs).map(_.expr).map {
+ toDF((expr +: exprs).map(_.expr).map {
case expr: NamedExpression => expr
case expr: Expression => Alias(expr, expr.prettyString)()
- }
- if (df.sqlContext.conf.dataFrameRetainGroupColumns) {
- val retainedExprs = groupingExprs.map {
- case expr: NamedExpression => expr
- case expr: Expression => Alias(expr, expr.prettyString)()
- }
- DataFrame(df.sqlContext, Aggregate(groupingExprs, retainedExprs ++ aggExprs, df.logicalPlan))
- } else {
- DataFrame(df.sqlContext, Aggregate(groupingExprs, aggExprs, df.logicalPlan))
- }
+ })
}
/**
@@ -196,7 +236,7 @@ class GroupedData protected[sql](df: DataFrame, groupingExprs: Seq[Expression])
*
* @since 1.3.0
*/
- def count(): DataFrame = Seq(Alias(Count(Literal(1)), "count")())
+ def count(): DataFrame = toDF(Seq(Alias(Count(Literal(1)), "count")()))
/**
* Compute the average value for each numeric columns for each group. This is an alias for `avg`.
@@ -256,5 +296,5 @@ class GroupedData protected[sql](df: DataFrame, groupingExprs: Seq[Expression])
@scala.annotation.varargs
def sum(colNames: String*): DataFrame = {
aggregateNumericColumns(colNames:_*)(Sum)
- }
+ }
}
diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveDataFrameAnalyticsSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveDataFrameAnalyticsSuite.scala
new file mode 100644
index 0000000000..99de14660f
--- /dev/null
+++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveDataFrameAnalyticsSuite.scala
@@ -0,0 +1,69 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.hive
+
+import org.apache.spark.sql.QueryTest
+import org.apache.spark.sql.functions._
+import org.apache.spark.sql.hive.test.TestHive
+import org.apache.spark.sql.hive.test.TestHive._
+import org.apache.spark.sql.hive.test.TestHive.implicits._
+import org.scalatest.BeforeAndAfterAll
+
+case class TestData2Int(a: Int, b: Int)
+
+// TODO: ideally this test suite should live in the `sql` package, since the
+// `hive` package is optional at compile time; however, `SQLContext.sql`
+// doesn't support `cube` or `rollup` yet.
+class HiveDataFrameAnalyticsSuite extends QueryTest with BeforeAndAfterAll {
+ val testData =
+ TestHive.sparkContext.parallelize(
+ TestData2Int(1, 2) ::
+ TestData2Int(2, 4) :: Nil).toDF()
+
+ override def beforeAll(): Unit = {
+ TestHive.registerDataFrameAsTable(testData, "mytable")
+ }
+
+ override def afterAll(): Unit = {
+ TestHive.dropTempTable("mytable")
+ }
+
+ test("rollup") {
+ checkAnswer(
+ testData.rollup($"a" + $"b", $"b").agg(sum($"a" - $"b")),
+ sql("select a + b, b, sum(a - b) from mytable group by a + b, b with rollup").collect()
+ )
+
+ checkAnswer(
+ testData.rollup("a", "b").agg(sum("b")),
+ sql("select a, b, sum(b) from mytable group by a, b with rollup").collect()
+ )
+ }
+
+ test("cube") {
+ checkAnswer(
+ testData.cube($"a" + $"b", $"b").agg(sum($"a" - $"b")),
+ sql("select a + b, b, sum(a - b) from mytable group by a + b, b with cube").collect()
+ )
+
+ checkAnswer(
+ testData.cube("a", "b").agg(sum("b")),
+ sql("select a, b, sum(b) from mytable group by a, b with cube").collect()
+ )
+ }
+}
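A note on what the suite's assertions imply: on `testData` (rows `(1, 2)` and `(2, 4)`), the rollup and cube results are small enough to work out by hand. A sketch reusing the suite's names, with the expected rows shown as comments:

```scala
// rollup("a", "b").agg(sum("b")) yields one row per grouping set
// {(a, b), (a), ()}:
//
//   a     b     sum(b)
//   1     2     2
//   2     4     4
//   1     null  2
//   2     null  4
//   null  null  6
//
// cube("a", "b") adds the {(b)} grouping set, i.e. two more rows:
//   null  2     2
//   null  4     4
testData.cube("a", "b").agg(sum("b")).show()
```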