aboutsummaryrefslogtreecommitdiff
path: root/sql
diff options
context:
space:
mode:
Diffstat (limited to 'sql')
-rw-r--r--sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala10
-rw-r--r--sql/core/src/test/scala/org/apache/spark/sql/DataFramePivotSuite.scala11
2 files changed, 20 insertions, 1 deletions
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
index 92bf8e0536..5210f42c55 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
@@ -373,7 +373,15 @@ class Analyzer(
case Pivot(groupByExprs, pivotColumn, pivotValues, aggregates, child) =>
val singleAgg = aggregates.size == 1
def outputName(value: Literal, aggregate: Expression): String = {
- if (singleAgg) value.toString else value + "_" + aggregate.sql
+ if (singleAgg) {
+ value.toString
+ } else {
+ val suffix = aggregate match {
+ case n: NamedExpression => n.name
+ case _ => aggregate.sql
+ }
+ value + "_" + suffix
+ }
}
if (aggregates.forall(a => PivotFirst.supportsDataType(a.dataType))) {
// Since evaluating |pivotValues| if statements for each input row can get slow this is an
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFramePivotSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFramePivotSuite.scala
index d5cb5e1568..1bbe1354d5 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/DataFramePivotSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFramePivotSuite.scala
@@ -197,4 +197,15 @@ class DataFramePivotSuite extends QueryTest with SharedSQLContext{
Row(2013, Seq(48000.0, 7.0), Seq(30000.0, 7.0)) :: Nil
)
}
+
+ test("pivot preserves aliases if given") {
+ assertResult(
+ Array("year", "dotNET_foo", "dotNET_avg(`earnings`)", "Java_foo", "Java_avg(`earnings`)")
+ )(
+ courseSales.groupBy($"year")
+ .pivot("course", Seq("dotNET", "Java"))
+ .agg(sum($"earnings").as("foo"), avg($"earnings")).columns
+ )
+ }
+
}