aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorAndrew Ray <ray.andrew@gmail.com>2015-11-11 16:23:24 -0800
committerYin Huai <yhuai@databricks.com>2015-11-11 16:23:24 -0800
commitb8ff6888e76b437287d7d6bf2d4b9c759710a195 (patch)
tree3f6a821b341a99dd3534603c88dfa6845d8982a8
parent1a21be15f655b9696ddac80aac629445a465f621 (diff)
downloadspark-b8ff6888e76b437287d7d6bf2d4b9c759710a195.tar.gz
spark-b8ff6888e76b437287d7d6bf2d4b9c759710a195.tar.bz2
spark-b8ff6888e76b437287d7d6bf2d4b9c759710a195.zip
[SPARK-8992][SQL] Add pivot to dataframe api
This adds a pivot method to the dataframe api. Following the lead of cube and rollup this adds a Pivot operator that is translated into an Aggregate by the analyzer. Currently the syntax is like: ~~courseSales.pivot(Seq($"year"), $"course", Seq("dotNET", "Java"), sum($"earnings"))~~ ~~Would we be interested in the following syntax also/alternatively? and~~ courseSales.groupBy($"year").pivot($"course", "dotNET", "Java").agg(sum($"earnings")) //or courseSales.groupBy($"year").pivot($"course").agg(sum($"earnings")) Later we can add it to `SQLParser`, but as Hive doesn't support it we cant add it there, right? ~~Also what would be the suggested Java friendly method signature for this?~~ Author: Andrew Ray <ray.andrew@gmail.com> Closes #7841 from aray/sql-pivot.
-rw-r--r--sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala42
-rw-r--r--sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala14
-rw-r--r--sql/core/src/main/scala/org/apache/spark/sql/GroupedData.scala103
-rw-r--r--sql/core/src/main/scala/org/apache/spark/sql/SQLConf.scala7
-rw-r--r--sql/core/src/test/scala/org/apache/spark/sql/DataFramePivotSuite.scala87
-rw-r--r--sql/core/src/test/scala/org/apache/spark/sql/test/SQLTestData.scala12
6 files changed, 255 insertions, 10 deletions
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
index a9cd9a7703..2f4670b55b 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
@@ -72,6 +72,7 @@ class Analyzer(
ResolveRelations ::
ResolveReferences ::
ResolveGroupingAnalytics ::
+ ResolvePivot ::
ResolveSortReferences ::
ResolveGenerate ::
ResolveFunctions ::
@@ -166,6 +167,10 @@ class Analyzer(
case g: GroupingAnalytics if g.child.resolved && hasUnresolvedAlias(g.aggregations) =>
g.withNewAggs(assignAliases(g.aggregations))
+ case Pivot(groupByExprs, pivotColumn, pivotValues, aggregates, child)
+ if child.resolved && hasUnresolvedAlias(groupByExprs) =>
+ Pivot(assignAliases(groupByExprs), pivotColumn, pivotValues, aggregates, child)
+
case Project(projectList, child) if child.resolved && hasUnresolvedAlias(projectList) =>
Project(assignAliases(projectList), child)
}
@@ -248,6 +253,43 @@ class Analyzer(
}
}
+ object ResolvePivot extends Rule[LogicalPlan] {
+ def apply(plan: LogicalPlan): LogicalPlan = plan transform {
+ case p: Pivot if !p.childrenResolved => p
+ case Pivot(groupByExprs, pivotColumn, pivotValues, aggregates, child) =>
+ val singleAgg = aggregates.size == 1
+ val pivotAggregates: Seq[NamedExpression] = pivotValues.flatMap { value =>
+ def ifExpr(expr: Expression) = {
+ If(EqualTo(pivotColumn, value), expr, Literal(null))
+ }
+ aggregates.map { aggregate =>
+ val filteredAggregate = aggregate.transformDown {
+ // Assumption is the aggregate function ignores nulls. This is true for all current
+ // AggregateFunction's with the exception of First and Last in their default mode
+ // (which we handle) and possibly some Hive UDAF's.
+ case First(expr, _) =>
+ First(ifExpr(expr), Literal(true))
+ case Last(expr, _) =>
+ Last(ifExpr(expr), Literal(true))
+ case a: AggregateFunction =>
+ a.withNewChildren(a.children.map(ifExpr))
+ }
+ if (filteredAggregate.fastEquals(aggregate)) {
+ throw new AnalysisException(
+ s"Aggregate expression required for pivot, found '$aggregate'")
+ }
+ val name = if (singleAgg) value.toString else value + "_" + aggregate.prettyString
+ Alias(filteredAggregate, name)()
+ }
+ }
+ val newGroupByExprs = groupByExprs.map {
+ case UnresolvedAlias(e) => e
+ case e => e
+ }
+ Aggregate(newGroupByExprs, groupByExprs ++ pivotAggregates, child)
+ }
+ }
+
/**
* Replaces [[UnresolvedRelation]]s with concrete relations from the catalog.
*/
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala
index 597f03e752..32b09b59af 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala
@@ -386,6 +386,20 @@ case class Rollup(
this.copy(aggregations = aggs)
}
+case class Pivot(
+ groupByExprs: Seq[NamedExpression],
+ pivotColumn: Expression,
+ pivotValues: Seq[Literal],
+ aggregates: Seq[Expression],
+ child: LogicalPlan) extends UnaryNode {
+ override def output: Seq[Attribute] = groupByExprs.map(_.toAttribute) ++ aggregates match {
+ case agg :: Nil => pivotValues.map(value => AttributeReference(value.toString, agg.dataType)())
+ case _ => pivotValues.flatMap{ value =>
+ aggregates.map(agg => AttributeReference(value + "_" + agg.prettyString, agg.dataType)())
+ }
+ }
+}
+
case class Limit(limitExpr: Expression, child: LogicalPlan) extends UnaryNode {
override def output: Seq[Attribute] = child.output
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/GroupedData.scala b/sql/core/src/main/scala/org/apache/spark/sql/GroupedData.scala
index 5babf2cc0c..63dd7fbcbe 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/GroupedData.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/GroupedData.scala
@@ -24,8 +24,8 @@ import org.apache.spark.annotation.Experimental
import org.apache.spark.sql.catalyst.analysis.{UnresolvedFunction, UnresolvedAlias, UnresolvedAttribute, Star}
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.expressions.aggregate._
-import org.apache.spark.sql.catalyst.plans.logical.{Rollup, Cube, Aggregate}
-import org.apache.spark.sql.types.NumericType
+import org.apache.spark.sql.catalyst.plans.logical.{Pivot, Rollup, Cube, Aggregate}
+import org.apache.spark.sql.types.{StringType, NumericType}
/**
@@ -50,14 +50,8 @@ class GroupedData protected[sql](
aggExprs
}
- val aliasedAgg = aggregates.map {
- // Wrap UnresolvedAttribute with UnresolvedAlias, as when we resolve UnresolvedAttribute, we
- // will remove intermediate Alias for ExtractValue chain, and we need to alias it again to
- // make it a NamedExpression.
- case u: UnresolvedAttribute => UnresolvedAlias(u)
- case expr: NamedExpression => expr
- case expr: Expression => Alias(expr, expr.prettyString)()
- }
+ val aliasedAgg = aggregates.map(alias)
+
groupType match {
case GroupedData.GroupByType =>
DataFrame(
@@ -68,9 +62,22 @@ class GroupedData protected[sql](
case GroupedData.CubeType =>
DataFrame(
df.sqlContext, Cube(groupingExprs, df.logicalPlan, aliasedAgg))
+ case GroupedData.PivotType(pivotCol, values) =>
+ val aliasedGrps = groupingExprs.map(alias)
+ DataFrame(
+ df.sqlContext, Pivot(aliasedGrps, pivotCol, values, aggExprs, df.logicalPlan))
}
}
+ // Wrap UnresolvedAttribute with UnresolvedAlias, as when we resolve UnresolvedAttribute, we
+ // will remove intermediate Alias for ExtractValue chain, and we need to alias it again to
+ // make it a NamedExpression.
+ private[this] def alias(expr: Expression): NamedExpression = expr match {
+ case u: UnresolvedAttribute => UnresolvedAlias(u)
+ case expr: NamedExpression => expr
+ case expr: Expression => Alias(expr, expr.prettyString)()
+ }
+
private[this] def aggregateNumericColumns(colNames: String*)(f: Expression => AggregateFunction)
: DataFrame = {
@@ -273,6 +280,77 @@ class GroupedData protected[sql](
def sum(colNames: String*): DataFrame = {
aggregateNumericColumns(colNames : _*)(Sum)
}
+
+ /**
+ * (Scala-specific) Pivots a column of the current [[DataFrame]] and preform the specified
+ * aggregation.
+ * {{{
+ * // Compute the sum of earnings for each year by course with each course as a separate column
+ * df.groupBy($"year").pivot($"course", "dotNET", "Java").agg(sum($"earnings"))
+ * // Or without specifying column values
+ * df.groupBy($"year").pivot($"course").agg(sum($"earnings"))
+ * }}}
+ * @param pivotColumn Column to pivot
+ * @param values Optional list of values of pivotColumn that will be translated to columns in the
+ * output data frame. If values are not provided the method with do an immediate
+ * call to .distinct() on the pivot column.
+ * @since 1.6.0
+ */
+ @scala.annotation.varargs
+ def pivot(pivotColumn: Column, values: Column*): GroupedData = groupType match {
+ case _: GroupedData.PivotType =>
+ throw new UnsupportedOperationException("repeated pivots are not supported")
+ case GroupedData.GroupByType =>
+ val pivotValues = if (values.nonEmpty) {
+ values.map {
+ case Column(literal: Literal) => literal
+ case other =>
+ throw new UnsupportedOperationException(
+ s"The values of a pivot must be literals, found $other")
+ }
+ } else {
+ // This is to prevent unintended OOM errors when the number of distinct values is large
+ val maxValues = df.sqlContext.conf.getConf(SQLConf.DATAFRAME_PIVOT_MAX_VALUES)
+ // Get the distinct values of the column and sort them so its consistent
+ val values = df.select(pivotColumn)
+ .distinct()
+ .sort(pivotColumn)
+ .map(_.get(0))
+ .take(maxValues + 1)
+ .map(Literal(_)).toSeq
+ if (values.length > maxValues) {
+ throw new RuntimeException(
+ s"The pivot column $pivotColumn has more than $maxValues distinct values, " +
+ "this could indicate an error. " +
+ "If this was intended, set \"" + SQLConf.DATAFRAME_PIVOT_MAX_VALUES.key + "\" " +
+ s"to at least the number of distinct values of the pivot column.")
+ }
+ values
+ }
+ new GroupedData(df, groupingExprs, GroupedData.PivotType(pivotColumn.expr, pivotValues))
+ case _ =>
+ throw new UnsupportedOperationException("pivot is only supported after a groupBy")
+ }
+
+ /**
+ * Pivots a column of the current [[DataFrame]] and preform the specified aggregation.
+ * {{{
+ * // Compute the sum of earnings for each year by course with each course as a separate column
+ * df.groupBy("year").pivot("course", "dotNET", "Java").sum("earnings")
+ * // Or without specifying column values
+ * df.groupBy("year").pivot("course").sum("earnings")
+ * }}}
+ * @param pivotColumn Column to pivot
+ * @param values Optional list of values of pivotColumn that will be translated to columns in the
+ * output data frame. If values are not provided the method with do an immediate
+ * call to .distinct() on the pivot column.
+ * @since 1.6.0
+ */
+ @scala.annotation.varargs
+ def pivot(pivotColumn: String, values: Any*): GroupedData = {
+ val resolvedPivotColumn = Column(df.resolve(pivotColumn))
+ pivot(resolvedPivotColumn, values.map(functions.lit): _*)
+ }
}
@@ -307,4 +385,9 @@ private[sql] object GroupedData {
* To indicate it's the ROLLUP
*/
private[sql] object RollupType extends GroupType
+
+ /**
+ * To indicate it's the PIVOT
+ */
+ private[sql] case class PivotType(pivotCol: Expression, values: Seq[Literal]) extends GroupType
}
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/SQLConf.scala b/sql/core/src/main/scala/org/apache/spark/sql/SQLConf.scala
index e02b502b7b..41d28d448c 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/SQLConf.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/SQLConf.scala
@@ -437,6 +437,13 @@ private[spark] object SQLConf {
defaultValue = Some(true),
isPublic = false)
+ val DATAFRAME_PIVOT_MAX_VALUES = intConf(
+ "spark.sql.pivotMaxValues",
+ defaultValue = Some(10000),
+ doc = "When doing a pivot without specifying values for the pivot column this is the maximum " +
+ "number of (distinct) values that will be collected without error."
+ )
+
val RUN_SQL_ON_FILES = booleanConf("spark.sql.runSQLOnFiles",
defaultValue = Some(true),
isPublic = false,
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFramePivotSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFramePivotSuite.scala
new file mode 100644
index 0000000000..0c23d14267
--- /dev/null
+++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFramePivotSuite.scala
@@ -0,0 +1,87 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql
+
+import org.apache.spark.sql.functions._
+import org.apache.spark.sql.test.SharedSQLContext
+
+class DataFramePivotSuite extends QueryTest with SharedSQLContext{
+ import testImplicits._
+
+ test("pivot courses with literals") {
+ checkAnswer(
+ courseSales.groupBy($"year").pivot($"course", lit("dotNET"), lit("Java"))
+ .agg(sum($"earnings")),
+ Row(2012, 15000.0, 20000.0) :: Row(2013, 48000.0, 30000.0) :: Nil
+ )
+ }
+
+ test("pivot year with literals") {
+ checkAnswer(
+ courseSales.groupBy($"course").pivot($"year", lit(2012), lit(2013)).agg(sum($"earnings")),
+ Row("dotNET", 15000.0, 48000.0) :: Row("Java", 20000.0, 30000.0) :: Nil
+ )
+ }
+
+ test("pivot courses with literals and multiple aggregations") {
+ checkAnswer(
+ courseSales.groupBy($"year").pivot($"course", lit("dotNET"), lit("Java"))
+ .agg(sum($"earnings"), avg($"earnings")),
+ Row(2012, 15000.0, 7500.0, 20000.0, 20000.0) ::
+ Row(2013, 48000.0, 48000.0, 30000.0, 30000.0) :: Nil
+ )
+ }
+
+ test("pivot year with string values (cast)") {
+ checkAnswer(
+ courseSales.groupBy("course").pivot("year", "2012", "2013").sum("earnings"),
+ Row("dotNET", 15000.0, 48000.0) :: Row("Java", 20000.0, 30000.0) :: Nil
+ )
+ }
+
+ test("pivot year with int values") {
+ checkAnswer(
+ courseSales.groupBy("course").pivot("year", 2012, 2013).sum("earnings"),
+ Row("dotNET", 15000.0, 48000.0) :: Row("Java", 20000.0, 30000.0) :: Nil
+ )
+ }
+
+ test("pivot courses with no values") {
+ // Note Java comes before dotNet in sorted order
+ checkAnswer(
+ courseSales.groupBy($"year").pivot($"course").agg(sum($"earnings")),
+ Row(2012, 20000.0, 15000.0) :: Row(2013, 30000.0, 48000.0) :: Nil
+ )
+ }
+
+ test("pivot year with no values") {
+ checkAnswer(
+ courseSales.groupBy($"course").pivot($"year").agg(sum($"earnings")),
+ Row("dotNET", 15000.0, 48000.0) :: Row("Java", 20000.0, 30000.0) :: Nil
+ )
+ }
+
+ test("pivot max values inforced") {
+ sqlContext.conf.setConf(SQLConf.DATAFRAME_PIVOT_MAX_VALUES, 1)
+ intercept[RuntimeException](
+ courseSales.groupBy($"year").pivot($"course")
+ )
+ sqlContext.conf.setConf(SQLConf.DATAFRAME_PIVOT_MAX_VALUES,
+ SQLConf.DATAFRAME_PIVOT_MAX_VALUES.defaultValue.get)
+ }
+}
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/test/SQLTestData.scala b/sql/core/src/test/scala/org/apache/spark/sql/test/SQLTestData.scala
index 520dea7f7d..abad0d7eaa 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/test/SQLTestData.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/test/SQLTestData.scala
@@ -242,6 +242,17 @@ private[sql] trait SQLTestData { self =>
df
}
+ protected lazy val courseSales: DataFrame = {
+ val df = sqlContext.sparkContext.parallelize(
+ CourseSales("dotNET", 2012, 10000) ::
+ CourseSales("Java", 2012, 20000) ::
+ CourseSales("dotNET", 2012, 5000) ::
+ CourseSales("dotNET", 2013, 48000) ::
+ CourseSales("Java", 2013, 30000) :: Nil).toDF()
+ df.registerTempTable("courseSales")
+ df
+ }
+
/**
* Initialize all test data such that all temp tables are properly registered.
*/
@@ -295,4 +306,5 @@ private[sql] object SQLTestData {
case class Person(id: Int, name: String, age: Int)
case class Salary(personId: Int, salary: Double)
case class ComplexData(m: Map[String, Int], s: TestData, a: Seq[Int], b: Boolean)
+ case class CourseSales(course: String, year: Int, earnings: Double)
}