aboutsummaryrefslogtreecommitdiff
path: root/sql/core/src
diff options
context:
space:
mode:
authorAndrew Ray <ray.andrew@gmail.com>2015-11-11 16:23:24 -0800
committerYin Huai <yhuai@databricks.com>2015-11-11 16:23:24 -0800
commitb8ff6888e76b437287d7d6bf2d4b9c759710a195 (patch)
tree3f6a821b341a99dd3534603c88dfa6845d8982a8 /sql/core/src
parent1a21be15f655b9696ddac80aac629445a465f621 (diff)
downloadspark-b8ff6888e76b437287d7d6bf2d4b9c759710a195.tar.gz
spark-b8ff6888e76b437287d7d6bf2d4b9c759710a195.tar.bz2
spark-b8ff6888e76b437287d7d6bf2d4b9c759710a195.zip
[SPARK-8992][SQL] Add pivot to dataframe api
This adds a pivot method to the dataframe api. Following the lead of cube and rollup this adds a Pivot operator that is translated into an Aggregate by the analyzer. Currently the syntax is like: ~~courseSales.pivot(Seq($"year"), $"course", Seq("dotNET", "Java"), sum($"earnings"))~~ ~~Would we be interested in the following syntax also/alternatively? and~~ courseSales.groupBy($"year").pivot($"course", "dotNET", "Java").agg(sum($"earnings")) //or courseSales.groupBy($"year").pivot($"course").agg(sum($"earnings")) Later we can add it to `SQLParser`, but as Hive doesn't support it we cant add it there, right? ~~Also what would be the suggested Java friendly method signature for this?~~ Author: Andrew Ray <ray.andrew@gmail.com> Closes #7841 from aray/sql-pivot.
Diffstat (limited to 'sql/core/src')
-rw-r--r--sql/core/src/main/scala/org/apache/spark/sql/GroupedData.scala103
-rw-r--r--sql/core/src/main/scala/org/apache/spark/sql/SQLConf.scala7
-rw-r--r--sql/core/src/test/scala/org/apache/spark/sql/DataFramePivotSuite.scala87
-rw-r--r--sql/core/src/test/scala/org/apache/spark/sql/test/SQLTestData.scala12
4 files changed, 199 insertions, 10 deletions
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/GroupedData.scala b/sql/core/src/main/scala/org/apache/spark/sql/GroupedData.scala
index 5babf2cc0c..63dd7fbcbe 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/GroupedData.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/GroupedData.scala
@@ -24,8 +24,8 @@ import org.apache.spark.annotation.Experimental
import org.apache.spark.sql.catalyst.analysis.{UnresolvedFunction, UnresolvedAlias, UnresolvedAttribute, Star}
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.expressions.aggregate._
-import org.apache.spark.sql.catalyst.plans.logical.{Rollup, Cube, Aggregate}
-import org.apache.spark.sql.types.NumericType
+import org.apache.spark.sql.catalyst.plans.logical.{Pivot, Rollup, Cube, Aggregate}
+import org.apache.spark.sql.types.{StringType, NumericType}
/**
@@ -50,14 +50,8 @@ class GroupedData protected[sql](
aggExprs
}
- val aliasedAgg = aggregates.map {
- // Wrap UnresolvedAttribute with UnresolvedAlias, as when we resolve UnresolvedAttribute, we
- // will remove intermediate Alias for ExtractValue chain, and we need to alias it again to
- // make it a NamedExpression.
- case u: UnresolvedAttribute => UnresolvedAlias(u)
- case expr: NamedExpression => expr
- case expr: Expression => Alias(expr, expr.prettyString)()
- }
+ val aliasedAgg = aggregates.map(alias)
+
groupType match {
case GroupedData.GroupByType =>
DataFrame(
@@ -68,9 +62,22 @@ class GroupedData protected[sql](
case GroupedData.CubeType =>
DataFrame(
df.sqlContext, Cube(groupingExprs, df.logicalPlan, aliasedAgg))
+ case GroupedData.PivotType(pivotCol, values) =>
+ val aliasedGrps = groupingExprs.map(alias)
+ DataFrame(
+ df.sqlContext, Pivot(aliasedGrps, pivotCol, values, aggExprs, df.logicalPlan))
}
}
+ // Wrap UnresolvedAttribute with UnresolvedAlias, as when we resolve UnresolvedAttribute, we
+ // will remove intermediate Alias for ExtractValue chain, and we need to alias it again to
+ // make it a NamedExpression.
+ private[this] def alias(expr: Expression): NamedExpression = expr match {
+ case u: UnresolvedAttribute => UnresolvedAlias(u)
+ case expr: NamedExpression => expr
+ case expr: Expression => Alias(expr, expr.prettyString)()
+ }
+
private[this] def aggregateNumericColumns(colNames: String*)(f: Expression => AggregateFunction)
: DataFrame = {
@@ -273,6 +280,77 @@ class GroupedData protected[sql](
def sum(colNames: String*): DataFrame = {
aggregateNumericColumns(colNames : _*)(Sum)
}
+
+ /**
+ * (Scala-specific) Pivots a column of the current [[DataFrame]] and preform the specified
+ * aggregation.
+ * {{{
+ * // Compute the sum of earnings for each year by course with each course as a separate column
+ * df.groupBy($"year").pivot($"course", "dotNET", "Java").agg(sum($"earnings"))
+ * // Or without specifying column values
+ * df.groupBy($"year").pivot($"course").agg(sum($"earnings"))
+ * }}}
+ * @param pivotColumn Column to pivot
+ * @param values Optional list of values of pivotColumn that will be translated to columns in the
+ * output data frame. If values are not provided the method with do an immediate
+ * call to .distinct() on the pivot column.
+ * @since 1.6.0
+ */
+ @scala.annotation.varargs
+ def pivot(pivotColumn: Column, values: Column*): GroupedData = groupType match {
+ case _: GroupedData.PivotType =>
+ throw new UnsupportedOperationException("repeated pivots are not supported")
+ case GroupedData.GroupByType =>
+ val pivotValues = if (values.nonEmpty) {
+ values.map {
+ case Column(literal: Literal) => literal
+ case other =>
+ throw new UnsupportedOperationException(
+ s"The values of a pivot must be literals, found $other")
+ }
+ } else {
+ // This is to prevent unintended OOM errors when the number of distinct values is large
+ val maxValues = df.sqlContext.conf.getConf(SQLConf.DATAFRAME_PIVOT_MAX_VALUES)
+ // Get the distinct values of the column and sort them so its consistent
+ val values = df.select(pivotColumn)
+ .distinct()
+ .sort(pivotColumn)
+ .map(_.get(0))
+ .take(maxValues + 1)
+ .map(Literal(_)).toSeq
+ if (values.length > maxValues) {
+ throw new RuntimeException(
+ s"The pivot column $pivotColumn has more than $maxValues distinct values, " +
+ "this could indicate an error. " +
+ "If this was intended, set \"" + SQLConf.DATAFRAME_PIVOT_MAX_VALUES.key + "\" " +
+ s"to at least the number of distinct values of the pivot column.")
+ }
+ values
+ }
+ new GroupedData(df, groupingExprs, GroupedData.PivotType(pivotColumn.expr, pivotValues))
+ case _ =>
+ throw new UnsupportedOperationException("pivot is only supported after a groupBy")
+ }
+
+ /**
+ * Pivots a column of the current [[DataFrame]] and preform the specified aggregation.
+ * {{{
+ * // Compute the sum of earnings for each year by course with each course as a separate column
+ * df.groupBy("year").pivot("course", "dotNET", "Java").sum("earnings")
+ * // Or without specifying column values
+ * df.groupBy("year").pivot("course").sum("earnings")
+ * }}}
+ * @param pivotColumn Column to pivot
+ * @param values Optional list of values of pivotColumn that will be translated to columns in the
+ * output data frame. If values are not provided the method with do an immediate
+ * call to .distinct() on the pivot column.
+ * @since 1.6.0
+ */
+ @scala.annotation.varargs
+ def pivot(pivotColumn: String, values: Any*): GroupedData = {
+ val resolvedPivotColumn = Column(df.resolve(pivotColumn))
+ pivot(resolvedPivotColumn, values.map(functions.lit): _*)
+ }
}
@@ -307,4 +385,9 @@ private[sql] object GroupedData {
* To indicate it's the ROLLUP
*/
private[sql] object RollupType extends GroupType
+
+ /**
+ * To indicate it's the PIVOT
+ */
+ private[sql] case class PivotType(pivotCol: Expression, values: Seq[Literal]) extends GroupType
}
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/SQLConf.scala b/sql/core/src/main/scala/org/apache/spark/sql/SQLConf.scala
index e02b502b7b..41d28d448c 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/SQLConf.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/SQLConf.scala
@@ -437,6 +437,13 @@ private[spark] object SQLConf {
defaultValue = Some(true),
isPublic = false)
+ val DATAFRAME_PIVOT_MAX_VALUES = intConf(
+ "spark.sql.pivotMaxValues",
+ defaultValue = Some(10000),
+ doc = "When doing a pivot without specifying values for the pivot column this is the maximum " +
+ "number of (distinct) values that will be collected without error."
+ )
+
val RUN_SQL_ON_FILES = booleanConf("spark.sql.runSQLOnFiles",
defaultValue = Some(true),
isPublic = false,
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFramePivotSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFramePivotSuite.scala
new file mode 100644
index 0000000000..0c23d14267
--- /dev/null
+++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFramePivotSuite.scala
@@ -0,0 +1,87 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql
+
+import org.apache.spark.sql.functions._
+import org.apache.spark.sql.test.SharedSQLContext
+
+class DataFramePivotSuite extends QueryTest with SharedSQLContext{
+ import testImplicits._
+
+ test("pivot courses with literals") {
+ checkAnswer(
+ courseSales.groupBy($"year").pivot($"course", lit("dotNET"), lit("Java"))
+ .agg(sum($"earnings")),
+ Row(2012, 15000.0, 20000.0) :: Row(2013, 48000.0, 30000.0) :: Nil
+ )
+ }
+
+ test("pivot year with literals") {
+ checkAnswer(
+ courseSales.groupBy($"course").pivot($"year", lit(2012), lit(2013)).agg(sum($"earnings")),
+ Row("dotNET", 15000.0, 48000.0) :: Row("Java", 20000.0, 30000.0) :: Nil
+ )
+ }
+
+ test("pivot courses with literals and multiple aggregations") {
+ checkAnswer(
+ courseSales.groupBy($"year").pivot($"course", lit("dotNET"), lit("Java"))
+ .agg(sum($"earnings"), avg($"earnings")),
+ Row(2012, 15000.0, 7500.0, 20000.0, 20000.0) ::
+ Row(2013, 48000.0, 48000.0, 30000.0, 30000.0) :: Nil
+ )
+ }
+
+ test("pivot year with string values (cast)") {
+ checkAnswer(
+ courseSales.groupBy("course").pivot("year", "2012", "2013").sum("earnings"),
+ Row("dotNET", 15000.0, 48000.0) :: Row("Java", 20000.0, 30000.0) :: Nil
+ )
+ }
+
+ test("pivot year with int values") {
+ checkAnswer(
+ courseSales.groupBy("course").pivot("year", 2012, 2013).sum("earnings"),
+ Row("dotNET", 15000.0, 48000.0) :: Row("Java", 20000.0, 30000.0) :: Nil
+ )
+ }
+
+ test("pivot courses with no values") {
+ // Note Java comes before dotNet in sorted order
+ checkAnswer(
+ courseSales.groupBy($"year").pivot($"course").agg(sum($"earnings")),
+ Row(2012, 20000.0, 15000.0) :: Row(2013, 30000.0, 48000.0) :: Nil
+ )
+ }
+
+ test("pivot year with no values") {
+ checkAnswer(
+ courseSales.groupBy($"course").pivot($"year").agg(sum($"earnings")),
+ Row("dotNET", 15000.0, 48000.0) :: Row("Java", 20000.0, 30000.0) :: Nil
+ )
+ }
+
+ test("pivot max values inforced") {
+ sqlContext.conf.setConf(SQLConf.DATAFRAME_PIVOT_MAX_VALUES, 1)
+ intercept[RuntimeException](
+ courseSales.groupBy($"year").pivot($"course")
+ )
+ sqlContext.conf.setConf(SQLConf.DATAFRAME_PIVOT_MAX_VALUES,
+ SQLConf.DATAFRAME_PIVOT_MAX_VALUES.defaultValue.get)
+ }
+}
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/test/SQLTestData.scala b/sql/core/src/test/scala/org/apache/spark/sql/test/SQLTestData.scala
index 520dea7f7d..abad0d7eaa 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/test/SQLTestData.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/test/SQLTestData.scala
@@ -242,6 +242,17 @@ private[sql] trait SQLTestData { self =>
df
}
+ protected lazy val courseSales: DataFrame = {
+ val df = sqlContext.sparkContext.parallelize(
+ CourseSales("dotNET", 2012, 10000) ::
+ CourseSales("Java", 2012, 20000) ::
+ CourseSales("dotNET", 2012, 5000) ::
+ CourseSales("dotNET", 2013, 48000) ::
+ CourseSales("Java", 2013, 30000) :: Nil).toDF()
+ df.registerTempTable("courseSales")
+ df
+ }
+
/**
* Initialize all test data such that all temp tables are properly registered.
*/
@@ -295,4 +306,5 @@ private[sql] object SQLTestData {
case class Person(id: Int, name: String, age: Int)
case class Salary(personId: Int, salary: Double)
case class ComplexData(m: Map[String, Int], s: TestData, a: Seq[Int], b: Boolean)
+ case class CourseSales(course: String, year: Int, earnings: Double)
}