aboutsummaryrefslogtreecommitdiff
path: root/sql
diff options
context:
space:
mode:
authorAndrew Ray <ray.andrew@gmail.com>2016-12-07 04:44:14 -0800
committerHerman van Hovell <hvanhovell@databricks.com>2016-12-07 04:44:14 -0800
commitf1fca81b165c5a673f7d86b268e04ea42a6c267e (patch)
treee688ad2a71b46b09fafe04318160513899f424f1 /sql
parentc496d03b5289f7c604661a12af86f6accddcf125 (diff)
downloadspark-f1fca81b165c5a673f7d86b268e04ea42a6c267e.tar.gz
spark-f1fca81b165c5a673f7d86b268e04ea42a6c267e.tar.bz2
spark-f1fca81b165c5a673f7d86b268e04ea42a6c267e.zip
[SPARK-17760][SQL] AnalysisException with dataframe pivot when groupBy column is not attribute
## What changes were proposed in this pull request?

Fixes AnalysisException for pivot queries that have group by columns that are expressions and not attributes by substituting the expressions output attribute in the second aggregation and final projection.

## How was this patch tested?

existing and additional unit tests

Author: Andrew Ray <ray.andrew@gmail.com>

Closes #16177 from aray/SPARK-17760.
Diffstat (limited to 'sql')
-rw-r--r--sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala5
-rw-r--r--sql/core/src/test/scala/org/apache/spark/sql/DataFramePivotSuite.scala8
2 files changed, 11 insertions, 2 deletions
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
index ed6e17a8eb..58f98d529a 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
@@ -463,14 +463,15 @@ class Analyzer(
.toAggregateExpression()
, "__pivot_" + a.sql)()
}
- val secondAgg = Aggregate(groupByExprs, groupByExprs ++ pivotAggs, firstAgg)
+ val groupByExprsAttr = groupByExprs.map(_.toAttribute)
+ val secondAgg = Aggregate(groupByExprsAttr, groupByExprsAttr ++ pivotAggs, firstAgg)
val pivotAggAttribute = pivotAggs.map(_.toAttribute)
val pivotOutputs = pivotValues.zipWithIndex.flatMap { case (value, i) =>
aggregates.zip(pivotAggAttribute).map { case (aggregate, pivotAtt) =>
Alias(ExtractValue(pivotAtt, Literal(i), resolver), outputName(value, aggregate))()
}
}
- Project(groupByExprs ++ pivotOutputs, secondAgg)
+ Project(groupByExprsAttr ++ pivotOutputs, secondAgg)
} else {
val pivotAggregates: Seq[NamedExpression] = pivotValues.flatMap { value =>
def ifExpr(expr: Expression) = {
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFramePivotSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFramePivotSuite.scala
index 1bbe1354d5..a8d854ccbc 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/DataFramePivotSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFramePivotSuite.scala
@@ -208,4 +208,12 @@ class DataFramePivotSuite extends QueryTest with SharedSQLContext{
)
}
+ test("pivot with column definition in groupby") {
+ checkAnswer(
+ courseSales.groupBy(substring(col("course"), 0, 1).as("foo"))
+ .pivot("year", Seq(2012, 2013))
+ .sum("earnings"),
+ Row("d", 15000.0, 48000.0) :: Row("J", 20000.0, 30000.0) :: Nil
+ )
+ }
}