diff options
author | Wenchen Fan <cloud0fan@outlook.com> | 2015-06-22 12:13:00 -0700 |
---|---|---|
committer | Michael Armbrust <michael@databricks.com> | 2015-06-22 12:13:00 -0700 |
commit | da7bbb9435dae9a3bedad578599d96ea858f349e (patch) | |
tree | 6f16ac50aefadb46ef8e2b2f8acb890984cba0fc /sql/core | |
parent | 5d89d9f00ba4d6d0767a4c4964d3af324bf6f14b (diff) | |
download | spark-da7bbb9435dae9a3bedad578599d96ea858f349e.tar.gz spark-da7bbb9435dae9a3bedad578599d96ea858f349e.tar.bz2 spark-da7bbb9435dae9a3bedad578599d96ea858f349e.zip |
[SPARK-8104] [SQL] auto alias expressions in analyzer
Currently we auto alias expression in parser. However, during parser phase we don't have enough information to do the right alias. For example, Generator that has more than 1 kind of element need MultiAlias, ExtractValue don't need Alias if it's in middle of a ExtractValue chain.
Author: Wenchen Fan <cloud0fan@outlook.com>
Closes #6647 from cloud-fan/alias and squashes the following commits:
552eba4 [Wenchen Fan] fix python
5b5786d [Wenchen Fan] fix agg
73a90cb [Wenchen Fan] fix case-preserve of ExtractValue
4cfd23c [Wenchen Fan] fix order by
d18f401 [Wenchen Fan] refine
9f07359 [Wenchen Fan] address comments
39c1aef [Wenchen Fan] small fix
33640ec [Wenchen Fan] auto alias expressions in analyzer
Diffstat (limited to 'sql/core')
6 files changed, 29 insertions, 30 deletions
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/Column.scala b/sql/core/src/main/scala/org/apache/spark/sql/Column.scala index b4e008a6e8..f201c8ea8a 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/Column.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/Column.scala @@ -21,7 +21,6 @@ import scala.language.implicitConversions import org.apache.spark.annotation.Experimental import org.apache.spark.Logging -import org.apache.spark.sql.expressions.Window import org.apache.spark.sql.functions.lit import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.analysis._ diff --git a/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala b/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala index 466258e76f..492a3321bc 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala @@ -32,7 +32,7 @@ import org.apache.spark.annotation.{DeveloperApi, Experimental} import org.apache.spark.api.java.JavaRDD import org.apache.spark.api.python.SerDeUtil import org.apache.spark.rdd.RDD -import org.apache.spark.sql.catalyst.analysis.{MultiAlias, ResolvedStar, UnresolvedAttribute} +import org.apache.spark.sql.catalyst.analysis._ import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.plans.logical.{Filter, _} import org.apache.spark.sql.catalyst.plans.{Inner, JoinType} @@ -629,6 +629,10 @@ class DataFrame private[sql]( @scala.annotation.varargs def select(cols: Column*): DataFrame = { val namedExpressions = cols.map { + // Wrap UnresolvedAttribute with UnresolvedAlias, as when we resolve UnresolvedAttribute, we + // will remove intermediate Alias for ExtractValue chain, and we need to alias it again to + // make it a NamedExpression. + case Column(u: UnresolvedAttribute) => UnresolvedAlias(u) case Column(expr: NamedExpression) => expr // Leave an unaliased explode with an empty list of names since the analzyer will generate the // correct defaults after the nested expression's type has been resolved. diff --git a/sql/core/src/main/scala/org/apache/spark/sql/GroupedData.scala b/sql/core/src/main/scala/org/apache/spark/sql/GroupedData.scala index 45b3e1bc62..99d557b03a 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/GroupedData.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/GroupedData.scala @@ -21,7 +21,7 @@ import scala.collection.JavaConversions._ import scala.language.implicitConversions import org.apache.spark.annotation.Experimental -import org.apache.spark.sql.catalyst.analysis.Star +import org.apache.spark.sql.catalyst.analysis.{UnresolvedAlias, UnresolvedAttribute, Star} import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.plans.logical.{Rollup, Cube, Aggregate} import org.apache.spark.sql.types.NumericType @@ -70,27 +70,31 @@ class GroupedData protected[sql]( groupingExprs: Seq[Expression], private val groupType: GroupedData.GroupType) { - private[this] def toDF(aggExprs: Seq[NamedExpression]): DataFrame = { + private[this] def toDF(aggExprs: Seq[Expression]): DataFrame = { val aggregates = if (df.sqlContext.conf.dataFrameRetainGroupColumns) { - val retainedExprs = groupingExprs.map { - case expr: NamedExpression => expr - case expr: Expression => Alias(expr, expr.prettyString)() - } - retainedExprs ++ aggExprs - } else { - aggExprs - } + groupingExprs ++ aggExprs + } else { + aggExprs + } + val aliasedAgg = aggregates.map { + // Wrap UnresolvedAttribute with UnresolvedAlias, as when we resolve UnresolvedAttribute, we + // will remove intermediate Alias for ExtractValue chain, and we need to alias it again to + // make it a NamedExpression. + case u: UnresolvedAttribute => UnresolvedAlias(u) + case expr: NamedExpression => expr + case expr: Expression => Alias(expr, expr.prettyString)() + } groupType match { case GroupedData.GroupByType => DataFrame( - df.sqlContext, Aggregate(groupingExprs, aggregates, df.logicalPlan)) + df.sqlContext, Aggregate(groupingExprs, aliasedAgg, df.logicalPlan)) case GroupedData.RollupType => DataFrame( - df.sqlContext, Rollup(groupingExprs, df.logicalPlan, aggregates)) + df.sqlContext, Rollup(groupingExprs, df.logicalPlan, aliasedAgg)) case GroupedData.CubeType => DataFrame( - df.sqlContext, Cube(groupingExprs, df.logicalPlan, aggregates)) + df.sqlContext, Cube(groupingExprs, df.logicalPlan, aliasedAgg)) } } @@ -112,10 +116,7 @@ class GroupedData protected[sql]( namedExpr } } - toDF(columnExprs.map { c => - val a = f(c) - Alias(a, a.prettyString)() - }) + toDF(columnExprs.map(f)) } private[this] def strToExpr(expr: String): (Expression => Expression) = { @@ -169,8 +170,7 @@ class GroupedData protected[sql]( */ def agg(exprs: Map[String, String]): DataFrame = { toDF(exprs.map { case (colName, expr) => - val a = strToExpr(expr)(df(colName).expr) - Alias(a, a.prettyString)() + strToExpr(expr)(df(colName).expr) }.toSeq) } @@ -224,10 +224,7 @@ class GroupedData protected[sql]( */ @scala.annotation.varargs def agg(expr: Column, exprs: Column*): DataFrame = { - toDF((expr +: exprs).map(_.expr).map { - case expr: NamedExpression => expr - case expr: Expression => Alias(expr, expr.prettyString)() - }) + toDF((expr +: exprs).map(_.expr)) } /** diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/pythonUdfs.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/pythonUdfs.scala index 1ce150ceaf..c8c67ce334 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/pythonUdfs.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/pythonUdfs.scala @@ -74,7 +74,7 @@ private[spark] object ExtractPythonUdfs extends Rule[LogicalPlan] { // Skip EvaluatePython nodes. case plan: EvaluatePython => plan - case plan: LogicalPlan => + case plan: LogicalPlan if plan.resolved => // Extract any PythonUDFs from the current operator. val udfs = plan.expressions.flatMap(_.collect { case udf: PythonUDF => udf }) if (udfs.isEmpty) { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala index 4441afd6bd..73bc6c9991 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala @@ -1367,9 +1367,9 @@ class SQLQuerySuite extends QueryTest with BeforeAndAfterAll with SQLTestUtils { test("SPARK-6145: special cases") { sqlContext.read.json(sqlContext.sparkContext.makeRDD( - """{"a": {"b": [1]}, "b": [{"a": 1}], "c0": {"a": 1}}""" :: Nil)).registerTempTable("t") - checkAnswer(sql("SELECT a.b[0] FROM t ORDER BY c0.a"), Row(1)) - checkAnswer(sql("SELECT b[0].a FROM t ORDER BY c0.a"), Row(1)) + """{"a": {"b": [1]}, "b": [{"a": 1}], "_c0": {"a": 1}}""" :: Nil)).registerTempTable("t") + checkAnswer(sql("SELECT a.b[0] FROM t ORDER BY _c0.a"), Row(1)) + checkAnswer(sql("SELECT b[0].a FROM t ORDER BY _c0.a"), Row(1)) } test("SPARK-6898: complete support for special chars in column names") { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/TestData.scala b/sql/core/src/test/scala/org/apache/spark/sql/TestData.scala index 520a862ea0..207d7a352c 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/TestData.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/TestData.scala @@ -19,7 +19,6 @@ package org.apache.spark.sql import java.sql.Timestamp -import org.apache.spark.sql.catalyst.plans.logical import org.apache.spark.sql.test.TestSQLContext.implicits._ import org.apache.spark.sql.test._ |