aboutsummaryrefslogtreecommitdiff
path: root/sql/catalyst
diff options
context:
space:
mode:
authorAndrew Ray <ray.andrew@gmail.com>2015-11-11 16:23:24 -0800
committerYin Huai <yhuai@databricks.com>2015-11-11 16:23:24 -0800
commitb8ff6888e76b437287d7d6bf2d4b9c759710a195 (patch)
tree3f6a821b341a99dd3534603c88dfa6845d8982a8 /sql/catalyst
parent1a21be15f655b9696ddac80aac629445a465f621 (diff)
downloadspark-b8ff6888e76b437287d7d6bf2d4b9c759710a195.tar.gz
spark-b8ff6888e76b437287d7d6bf2d4b9c759710a195.tar.bz2
spark-b8ff6888e76b437287d7d6bf2d4b9c759710a195.zip
[SPARK-8992][SQL] Add pivot to dataframe api
This adds a pivot method to the dataframe api. Following the lead of cube and rollup this adds a Pivot operator that is translated into an Aggregate by the analyzer. Currently the syntax is like: ~~courseSales.pivot(Seq($"year"), $"course", Seq("dotNET", "Java"), sum($"earnings"))~~ ~~Would we be interested in the following syntax also/alternatively? and~~ courseSales.groupBy($"year").pivot($"course", "dotNET", "Java").agg(sum($"earnings")) //or courseSales.groupBy($"year").pivot($"course").agg(sum($"earnings")) Later we can add it to `SQLParser`, but as Hive doesn't support it we can't add it there, right? ~~Also what would be the suggested Java friendly method signature for this?~~ Author: Andrew Ray <ray.andrew@gmail.com> Closes #7841 from aray/sql-pivot.
Diffstat (limited to 'sql/catalyst')
-rw-r--r--sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala42
-rw-r--r--sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala14
2 files changed, 56 insertions, 0 deletions
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
index a9cd9a7703..2f4670b55b 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
@@ -72,6 +72,7 @@ class Analyzer(
ResolveRelations ::
ResolveReferences ::
ResolveGroupingAnalytics ::
+ ResolvePivot ::
ResolveSortReferences ::
ResolveGenerate ::
ResolveFunctions ::
@@ -166,6 +167,10 @@ class Analyzer(
case g: GroupingAnalytics if g.child.resolved && hasUnresolvedAlias(g.aggregations) =>
g.withNewAggs(assignAliases(g.aggregations))
+ case Pivot(groupByExprs, pivotColumn, pivotValues, aggregates, child)
+ if child.resolved && hasUnresolvedAlias(groupByExprs) =>
+ Pivot(assignAliases(groupByExprs), pivotColumn, pivotValues, aggregates, child)
+
case Project(projectList, child) if child.resolved && hasUnresolvedAlias(projectList) =>
Project(assignAliases(projectList), child)
}
@@ -248,6 +253,43 @@ class Analyzer(
}
}
+ object ResolvePivot extends Rule[LogicalPlan] {
+ def apply(plan: LogicalPlan): LogicalPlan = plan transform {
+ case p: Pivot if !p.childrenResolved => p
+ case Pivot(groupByExprs, pivotColumn, pivotValues, aggregates, child) =>
+ val singleAgg = aggregates.size == 1
+ val pivotAggregates: Seq[NamedExpression] = pivotValues.flatMap { value =>
+ def ifExpr(expr: Expression) = {
+ If(EqualTo(pivotColumn, value), expr, Literal(null))
+ }
+ aggregates.map { aggregate =>
+ val filteredAggregate = aggregate.transformDown {
+ // Assumption is the aggregate function ignores nulls. This is true for all current
+ // AggregateFunction's with the exception of First and Last in their default mode
+ // (which we handle) and possibly some Hive UDAF's.
+ case First(expr, _) =>
+ First(ifExpr(expr), Literal(true))
+ case Last(expr, _) =>
+ Last(ifExpr(expr), Literal(true))
+ case a: AggregateFunction =>
+ a.withNewChildren(a.children.map(ifExpr))
+ }
+ if (filteredAggregate.fastEquals(aggregate)) {
+ throw new AnalysisException(
+ s"Aggregate expression required for pivot, found '$aggregate'")
+ }
+ val name = if (singleAgg) value.toString else value + "_" + aggregate.prettyString
+ Alias(filteredAggregate, name)()
+ }
+ }
+ val newGroupByExprs = groupByExprs.map {
+ case UnresolvedAlias(e) => e
+ case e => e
+ }
+ Aggregate(newGroupByExprs, groupByExprs ++ pivotAggregates, child)
+ }
+ }
+
/**
* Replaces [[UnresolvedRelation]]s with concrete relations from the catalog.
*/
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala
index 597f03e752..32b09b59af 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala
@@ -386,6 +386,20 @@ case class Rollup(
this.copy(aggregations = aggs)
}
+case class Pivot(
+ groupByExprs: Seq[NamedExpression],
+ pivotColumn: Expression,
+ pivotValues: Seq[Literal],
+ aggregates: Seq[Expression],
+ child: LogicalPlan) extends UnaryNode {
+ override def output: Seq[Attribute] = groupByExprs.map(_.toAttribute) ++ aggregates match {
+ case agg :: Nil => pivotValues.map(value => AttributeReference(value.toString, agg.dataType)())
+ case _ => pivotValues.flatMap{ value =>
+ aggregates.map(agg => AttributeReference(value + "_" + agg.prettyString, agg.dataType)())
+ }
+ }
+}
+
case class Limit(limitExpr: Expression, child: LogicalPlan) extends UnaryNode {
override def output: Seq[Attribute] = child.output