diff options
author | Andrew Ray <ray.andrew@gmail.com> | 2015-11-11 16:23:24 -0800 |
---|---|---|
committer | Yin Huai <yhuai@databricks.com> | 2015-11-11 16:23:24 -0800 |
commit | b8ff6888e76b437287d7d6bf2d4b9c759710a195 (patch) | |
tree | 3f6a821b341a99dd3534603c88dfa6845d8982a8 /sql/catalyst | |
parent | 1a21be15f655b9696ddac80aac629445a465f621 (diff) | |
download | spark-b8ff6888e76b437287d7d6bf2d4b9c759710a195.tar.gz spark-b8ff6888e76b437287d7d6bf2d4b9c759710a195.tar.bz2 spark-b8ff6888e76b437287d7d6bf2d4b9c759710a195.zip |
[SPARK-8992][SQL] Add pivot to dataframe api
This adds a pivot method to the dataframe api.
Following the lead of cube and rollup this adds a Pivot operator that is translated into an Aggregate by the analyzer.
Currently the syntax is like:
~~courseSales.pivot(Seq($"year"), $"course", Seq("dotNET", "Java"), sum($"earnings"))~~
~~Would we be interested in the following syntax also/alternatively? and~~
courseSales.groupBy($"year").pivot($"course", "dotNET", "Java").agg(sum($"earnings"))
//or
courseSales.groupBy($"year").pivot($"course").agg(sum($"earnings"))
Later we can add it to `SQLParser`, but as Hive doesn't support it we cant add it there, right?
~~Also what would be the suggested Java friendly method signature for this?~~
Author: Andrew Ray <ray.andrew@gmail.com>
Closes #7841 from aray/sql-pivot.
Diffstat (limited to 'sql/catalyst')
-rw-r--r-- | sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala | 42 | ||||
-rw-r--r-- | sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala | 14 |
2 files changed, 56 insertions, 0 deletions
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala index a9cd9a7703..2f4670b55b 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala @@ -72,6 +72,7 @@ class Analyzer( ResolveRelations :: ResolveReferences :: ResolveGroupingAnalytics :: + ResolvePivot :: ResolveSortReferences :: ResolveGenerate :: ResolveFunctions :: @@ -166,6 +167,10 @@ class Analyzer( case g: GroupingAnalytics if g.child.resolved && hasUnresolvedAlias(g.aggregations) => g.withNewAggs(assignAliases(g.aggregations)) + case Pivot(groupByExprs, pivotColumn, pivotValues, aggregates, child) + if child.resolved && hasUnresolvedAlias(groupByExprs) => + Pivot(assignAliases(groupByExprs), pivotColumn, pivotValues, aggregates, child) + case Project(projectList, child) if child.resolved && hasUnresolvedAlias(projectList) => Project(assignAliases(projectList), child) } @@ -248,6 +253,43 @@ class Analyzer( } } + object ResolvePivot extends Rule[LogicalPlan] { + def apply(plan: LogicalPlan): LogicalPlan = plan transform { + case p: Pivot if !p.childrenResolved => p + case Pivot(groupByExprs, pivotColumn, pivotValues, aggregates, child) => + val singleAgg = aggregates.size == 1 + val pivotAggregates: Seq[NamedExpression] = pivotValues.flatMap { value => + def ifExpr(expr: Expression) = { + If(EqualTo(pivotColumn, value), expr, Literal(null)) + } + aggregates.map { aggregate => + val filteredAggregate = aggregate.transformDown { + // Assumption is the aggregate function ignores nulls. This is true for all current + // AggregateFunction's with the exception of First and Last in their default mode + // (which we handle) and possibly some Hive UDAF's. + case First(expr, _) => + First(ifExpr(expr), Literal(true)) + case Last(expr, _) => + Last(ifExpr(expr), Literal(true)) + case a: AggregateFunction => + a.withNewChildren(a.children.map(ifExpr)) + } + if (filteredAggregate.fastEquals(aggregate)) { + throw new AnalysisException( + s"Aggregate expression required for pivot, found '$aggregate'") + } + val name = if (singleAgg) value.toString else value + "_" + aggregate.prettyString + Alias(filteredAggregate, name)() + } + } + val newGroupByExprs = groupByExprs.map { + case UnresolvedAlias(e) => e + case e => e + } + Aggregate(newGroupByExprs, groupByExprs ++ pivotAggregates, child) + } + } + /** * Replaces [[UnresolvedRelation]]s with concrete relations from the catalog. */ diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala index 597f03e752..32b09b59af 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala @@ -386,6 +386,20 @@ case class Rollup( this.copy(aggregations = aggs) } +case class Pivot( + groupByExprs: Seq[NamedExpression], + pivotColumn: Expression, + pivotValues: Seq[Literal], + aggregates: Seq[Expression], + child: LogicalPlan) extends UnaryNode { + override def output: Seq[Attribute] = groupByExprs.map(_.toAttribute) ++ aggregates match { + case agg :: Nil => pivotValues.map(value => AttributeReference(value.toString, agg.dataType)()) + case _ => pivotValues.flatMap{ value => + aggregates.map(agg => AttributeReference(value + "_" + agg.prettyString, agg.dataType)()) + } + } +} + case class Limit(limitExpr: Expression, child: LogicalPlan) extends UnaryNode { override def output: Seq[Attribute] = child.output |