author    Andrew Ray <ray.andrew@gmail.com>    2015-11-11 16:23:24 -0800
committer Yin Huai <yhuai@databricks.com>      2015-11-11 16:23:24 -0800
commit b8ff6888e76b437287d7d6bf2d4b9c759710a195 (patch)
tree   3f6a821b341a99dd3534603c88dfa6845d8982a8 /sql/core/src/main
parent 1a21be15f655b9696ddac80aac629445a465f621 (diff)
[SPARK-8992][SQL] Add pivot to dataframe api
This adds a pivot method to the DataFrame API. Following the lead of cube and rollup, this adds a Pivot operator that is translated into an Aggregate by the analyzer.

~~Currently the syntax is like: courseSales.pivot(Seq($"year"), $"course", Seq("dotNET", "Java"), sum($"earnings"))~~

~~Would we be interested in the following syntax also/alternatively?~~

courseSales.groupBy($"year").pivot($"course", "dotNET", "Java").agg(sum($"earnings"))
// or
courseSales.groupBy($"year").pivot($"course").agg(sum($"earnings"))

Later we can add it to `SQLParser`, but as Hive doesn't support it we can't add it there, right?

~~Also, what would be the suggested Java-friendly method signature for this?~~

Author: Andrew Ray <ray.andrew@gmail.com>

Closes #7841 from aray/sql-pivot.
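For anyone trying the merged API, here is a minimal sketch of both call forms, mirroring the examples above and assuming a `courseSales` DataFrame with `year`, `course`, and `earnings` columns:

```scala
import org.apache.spark.sql.functions._
import sqlContext.implicits._ // enables the $"..." column syntax

// Pivot with explicit values: one output column per listed course.
courseSales.groupBy($"year").pivot($"course", "dotNET", "Java").agg(sum($"earnings"))

// Pivot without explicit values: the distinct values of `course` are
// collected and sorted first, capped by spark.sql.pivotMaxValues.
courseSales.groupBy($"year").pivot($"course").agg(sum($"earnings"))
```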
Diffstat (limited to 'sql/core/src/main')
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/GroupedData.scala | 103
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/SQLConf.scala     |   7
2 files changed, 100 insertions(+), 10 deletions(-)
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/GroupedData.scala b/sql/core/src/main/scala/org/apache/spark/sql/GroupedData.scala
index 5babf2cc0c..63dd7fbcbe 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/GroupedData.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/GroupedData.scala
@@ -24,8 +24,8 @@ import org.apache.spark.annotation.Experimental
import org.apache.spark.sql.catalyst.analysis.{UnresolvedFunction, UnresolvedAlias, UnresolvedAttribute, Star}
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.expressions.aggregate._
-import org.apache.spark.sql.catalyst.plans.logical.{Rollup, Cube, Aggregate}
-import org.apache.spark.sql.types.NumericType
+import org.apache.spark.sql.catalyst.plans.logical.{Pivot, Rollup, Cube, Aggregate}
+import org.apache.spark.sql.types.{StringType, NumericType}
/**
@@ -50,14 +50,8 @@ class GroupedData protected[sql](
aggExprs
}
- val aliasedAgg = aggregates.map {
- // Wrap UnresolvedAttribute with UnresolvedAlias, as when we resolve UnresolvedAttribute, we
- // will remove intermediate Alias for ExtractValue chain, and we need to alias it again to
- // make it a NamedExpression.
- case u: UnresolvedAttribute => UnresolvedAlias(u)
- case expr: NamedExpression => expr
- case expr: Expression => Alias(expr, expr.prettyString)()
- }
+ val aliasedAgg = aggregates.map(alias)
+
groupType match {
case GroupedData.GroupByType =>
DataFrame(
@@ -68,9 +62,22 @@ class GroupedData protected[sql](
case GroupedData.CubeType =>
DataFrame(
df.sqlContext, Cube(groupingExprs, df.logicalPlan, aliasedAgg))
+ case GroupedData.PivotType(pivotCol, values) =>
+ val aliasedGrps = groupingExprs.map(alias)
+ DataFrame(
+ df.sqlContext, Pivot(aliasedGrps, pivotCol, values, aggExprs, df.logicalPlan))
}
}
+ // Wrap UnresolvedAttribute with UnresolvedAlias, as when we resolve UnresolvedAttribute, we
+ // will remove intermediate Alias for ExtractValue chain, and we need to alias it again to
+ // make it a NamedExpression.
+ private[this] def alias(expr: Expression): NamedExpression = expr match {
+ case u: UnresolvedAttribute => UnresolvedAlias(u)
+ case expr: NamedExpression => expr
+ case expr: Expression => Alias(expr, expr.prettyString)()
+ }
+
private[this] def aggregateNumericColumns(colNames: String*)(f: Expression => AggregateFunction)
: DataFrame = {
@@ -273,6 +280,77 @@ class GroupedData protected[sql](
def sum(colNames: String*): DataFrame = {
aggregateNumericColumns(colNames : _*)(Sum)
}
+
+ /**
+ * (Scala-specific) Pivots a column of the current [[DataFrame]] and performs the specified
+ * aggregation.
+ * {{{
+ * // Compute the sum of earnings for each year by course with each course as a separate column
+ * df.groupBy($"year").pivot($"course", "dotNET", "Java").agg(sum($"earnings"))
+ * // Or without specifying column values
+ * df.groupBy($"year").pivot($"course").agg(sum($"earnings"))
+ * }}}
+ * @param pivotColumn Column to pivot
+ * @param values Optional list of values of pivotColumn that will be translated to columns in the
+ * output data frame. If values are not provided, the method will do an immediate
+ * call to .distinct() on the pivot column.
+ * @since 1.6.0
+ */
+ @scala.annotation.varargs
+ def pivot(pivotColumn: Column, values: Column*): GroupedData = groupType match {
+ case _: GroupedData.PivotType =>
+ throw new UnsupportedOperationException("repeated pivots are not supported")
+ case GroupedData.GroupByType =>
+ val pivotValues = if (values.nonEmpty) {
+ values.map {
+ case Column(literal: Literal) => literal
+ case other =>
+ throw new UnsupportedOperationException(
+ s"The values of a pivot must be literals, found $other")
+ }
+ } else {
+ // This is to prevent unintended OOM errors when the number of distinct values is large
+ val maxValues = df.sqlContext.conf.getConf(SQLConf.DATAFRAME_PIVOT_MAX_VALUES)
+ // Get the distinct values of the column and sort them so it's consistent
+ val values = df.select(pivotColumn)
+ .distinct()
+ .sort(pivotColumn)
+ .map(_.get(0))
+ .take(maxValues + 1)
+ .map(Literal(_)).toSeq
+ if (values.length > maxValues) {
+ throw new RuntimeException(
+ s"The pivot column $pivotColumn has more than $maxValues distinct values, " +
+ "this could indicate an error. " +
+ "If this was intended, set \"" + SQLConf.DATAFRAME_PIVOT_MAX_VALUES.key + "\" " +
+ s"to at least the number of distinct values of the pivot column.")
+ }
+ values
+ }
+ new GroupedData(df, groupingExprs, GroupedData.PivotType(pivotColumn.expr, pivotValues))
+ case _ =>
+ throw new UnsupportedOperationException("pivot is only supported after a groupBy")
+ }
+
+ /**
+ * Pivots a column of the current [[DataFrame]] and performs the specified aggregation.
+ * {{{
+ * // Compute the sum of earnings for each year by course with each course as a separate column
+ * df.groupBy("year").pivot("course", "dotNET", "Java").sum("earnings")
+ * // Or without specifying column values
+ * df.groupBy("year").pivot("course").sum("earnings")
+ * }}}
+ * @param pivotColumn Column to pivot
+ * @param values Optional list of values of pivotColumn that will be translated to columns in the
+ * output data frame. If values are not provided, the method will do an immediate
+ * call to .distinct() on the pivot column.
+ * @since 1.6.0
+ */
+ @scala.annotation.varargs
+ def pivot(pivotColumn: String, values: Any*): GroupedData = {
+ val resolvedPivotColumn = Column(df.resolve(pivotColumn))
+ pivot(resolvedPivotColumn, values.map(functions.lit): _*)
+ }
}
@@ -307,4 +385,9 @@ private[sql] object GroupedData {
* To indicate it's the ROLLUP
*/
private[sql] object RollupType extends GroupType
+
+ /**
+ * To indicate it's the PIVOT
+ */
+ private[sql] case class PivotType(pivotCol: Expression, values: Seq[Literal]) extends GroupType
}
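To make the new guards in `GroupedData.pivot` concrete, a hypothetical session (the column names `year` and `course` are borrowed from the doc examples; exception messages are quoted from the code above):

```scala
// Repeated pivots are rejected by the PivotType match arm.
df.groupBy($"year").pivot($"course").pivot($"course")
// => UnsupportedOperationException: repeated pivots are not supported

// Cube/rollup groupings fall through to the catch-all arm.
df.rollup($"year").pivot($"course")
// => UnsupportedOperationException: pivot is only supported after a groupBy

// Non-literal pivot values are rejected.
df.groupBy($"year").pivot($"course", $"course" + 1)
// => UnsupportedOperationException: The values of a pivot must be literals, ...
```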
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/SQLConf.scala b/sql/core/src/main/scala/org/apache/spark/sql/SQLConf.scala
index e02b502b7b..41d28d448c 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/SQLConf.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/SQLConf.scala
@@ -437,6 +437,13 @@ private[spark] object SQLConf {
defaultValue = Some(true),
isPublic = false)
+ val DATAFRAME_PIVOT_MAX_VALUES = intConf(
+ "spark.sql.pivotMaxValues",
+ defaultValue = Some(10000),
+ doc = "When doing a pivot without specifying values for the pivot column this is the maximum " +
+ "number of (distinct) values that will be collected without error."
+ )
+
val RUN_SQL_ON_FILES = booleanConf("spark.sql.runSQLOnFiles",
defaultValue = Some(true),
isPublic = false,
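If a pivot column legitimately has more than the default 10000 distinct values, the new setting can be raised before pivoting; a sketch (50000 is an arbitrary example value):

```scala
// Raise the cap on auto-collected distinct pivot values.
sqlContext.setConf("spark.sql.pivotMaxValues", "50000")
df.groupBy("year").pivot("course").sum("earnings")
```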