author     Wenchen Fan <cloud0fan@outlook.com>        2015-06-22 12:13:00 -0700
committer  Michael Armbrust <michael@databricks.com>  2015-06-22 12:13:00 -0700
commit     da7bbb9435dae9a3bedad578599d96ea858f349e (patch)
tree       6f16ac50aefadb46ef8e2b2f8acb890984cba0fc
parent     5d89d9f00ba4d6d0767a4c4964d3af324bf6f14b (diff)
[SPARK-8104] [SQL] auto alias expressions in analyzer
Currently we auto-alias expressions in the parser, but at parse time we don't have enough information to choose the right alias. For example, a Generator that produces more than one kind of element needs a MultiAlias, and an ExtractValue doesn't need an Alias at all when it sits in the middle of an ExtractValue chain. This patch moves the auto-aliasing into the analyzer, where the child expressions are resolved and that information is available.

Author: Wenchen Fan <cloud0fan@outlook.com>

Closes #6647 from cloud-fan/alias and squashes the following commits:

552eba4 [Wenchen Fan] fix python
5b5786d [Wenchen Fan] fix agg
73a90cb [Wenchen Fan] fix case-preserve of ExtractValue
4cfd23c [Wenchen Fan] fix order by
d18f401 [Wenchen Fan] refine
9f07359 [Wenchen Fan] address comments
39c1aef [Wenchen Fan] small fix
33640ec [Wenchen Fan] auto alias expressions in analyzer
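The core of the change is a placeholder expression plus an analyzer rule. The sketch below is a simplified, self-contained illustration of that idea, not the actual Catalyst classes: a parser-side UnresolvedAlias wrapper, and an analyzer-side rule that picks the right alias once the child expression is understood (the field name for struct access, a MultiAlias for multi-column generators, a positional _cN default otherwise).

// --- Illustration only: a simplified model of the UnresolvedAlias / ResolveAliases idea ---
sealed trait Expr { def resolved: Boolean = true }
case class Attr(name: String) extends Expr                        // a named column
case class Lit(value: Any) extends Expr                           // a literal
case class Add(left: Expr, right: Expr) extends Expr              // an unnamed computation
case class GetField(child: Expr, field: String) extends Expr      // struct field access
case class Explode(child: Expr, elementTypes: Int) extends Expr   // generator with N output columns
case class Alias(child: Expr, name: String) extends Expr
case class MultiAlias(child: Expr, names: Seq[String]) extends Expr
case class UnresolvedAlias(child: Expr) extends Expr { override def resolved: Boolean = false }

// Analyzer-side rule: only here do we know enough to choose the right alias.
def assignAliases(exprs: Seq[Expr]): Seq[Expr] = exprs.zipWithIndex.map {
  case (UnresolvedAlias(child), i) => child match {
    case a: Attr                          => a                      // already named, keep it
    case g: GetField                      => Alias(g, g.field)      // name after the extracted field
    case g: Explode if g.elementTypes > 1 => MultiAlias(g, Nil)     // generator names its own columns
    case other                            => Alias(other, s"_c$i")  // positional default
  }
  case (other, _) => other
}

// Roughly: SELECT a.b, key + 1, explode(m) FROM t
val projectList = Seq(
  UnresolvedAlias(GetField(Attr("a"), "b")),
  UnresolvedAlias(Add(Attr("key"), Lit(1))),
  UnresolvedAlias(Explode(Attr("m"), elementTypes = 2)))

assignAliases(projectList).foreach(println)
// Alias(GetField(Attr(a),b),b)
// Alias(Add(Attr(key),Lit(1)),_c1)
// MultiAlias(Explode(Attr(m),2),List())

In the real patch the placeholder is UnresolvedAlias in unresolved.scala and the rule is ResolveAliases in Analyzer.scala below; the rule only fires once the child plan is resolved, which is exactly the information the parser lacked.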
-rw-r--r--  python/pyspark/sql/context.py                                                               |  9
-rw-r--r--  sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/SqlParser.scala                   | 11
-rw-r--r--  sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala           | 77
-rw-r--r--  sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala      |  9
-rw-r--r--  sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/unresolved.scala         | 20
-rw-r--r--  sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ExtractValue.scala    | 36
-rw-r--r--  sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/planning/patterns.scala           |  6
-rw-r--r--  sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LogicalPlan.scala   | 11
-rw-r--r--  sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala| 20
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/Column.scala                                   |  1
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala                                |  6
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/GroupedData.scala                              | 43
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/execution/pythonUdfs.scala                     |  2
-rw-r--r--  sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala                            |  6
-rw-r--r--  sql/core/src/test/scala/org/apache/spark/sql/TestData.scala                                 |  1
-rw-r--r--  sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveQl.scala                              |  9
16 files changed, 150 insertions, 117 deletions
diff --git a/python/pyspark/sql/context.py b/python/pyspark/sql/context.py
index 599c9ac579..dc239226e6 100644
--- a/python/pyspark/sql/context.py
+++ b/python/pyspark/sql/context.py
@@ -86,7 +86,8 @@ class SQLContext(object):
>>> df.registerTempTable("allTypes")
>>> sqlContext.sql('select i+1, d+1, not b, list[1], dict["s"], time, row.a '
... 'from allTypes where b and i > 0').collect()
- [Row(c0=2, c1=2.0, c2=False, c3=2, c4=0, time=datetime.datetime(2014, 8, 1, 14, 1, 5), a=1)]
+ [Row(_c0=2, _c1=2.0, _c2=False, _c3=2, _c4=0, \
+ time=datetime.datetime(2014, 8, 1, 14, 1, 5), a=1)]
>>> df.map(lambda x: (x.i, x.s, x.d, x.l, x.b, x.time, x.row.a, x.list)).collect()
[(1, u'string', 1.0, 1, True, datetime.datetime(2014, 8, 1, 14, 1, 5), 1, [1, 2, 3])]
"""
@@ -176,17 +177,17 @@ class SQLContext(object):
>>> sqlContext.registerFunction("stringLengthString", lambda x: len(x))
>>> sqlContext.sql("SELECT stringLengthString('test')").collect()
- [Row(c0=u'4')]
+ [Row(_c0=u'4')]
>>> from pyspark.sql.types import IntegerType
>>> sqlContext.registerFunction("stringLengthInt", lambda x: len(x), IntegerType())
>>> sqlContext.sql("SELECT stringLengthInt('test')").collect()
- [Row(c0=4)]
+ [Row(_c0=4)]
>>> from pyspark.sql.types import IntegerType
>>> sqlContext.udf.register("stringLengthInt", lambda x: len(x), IntegerType())
>>> sqlContext.sql("SELECT stringLengthInt('test')").collect()
- [Row(c0=4)]
+ [Row(_c0=4)]
"""
func = lambda _, it: map(lambda x: f(*x), it)
ser = AutoBatchedSerializer(PickleSerializer())
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/SqlParser.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/SqlParser.scala
index da3a717f90..79f526e823 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/SqlParser.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/SqlParser.scala
@@ -99,13 +99,6 @@ class SqlParser extends AbstractSparkSQLParser with DataTypeParser {
protected val WHERE = Keyword("WHERE")
protected val WITH = Keyword("WITH")
- protected def assignAliases(exprs: Seq[Expression]): Seq[NamedExpression] = {
- exprs.zipWithIndex.map {
- case (ne: NamedExpression, _) => ne
- case (e, i) => Alias(e, s"c$i")()
- }
- }
-
protected lazy val start: Parser[LogicalPlan] =
start1 | insert | cte
@@ -130,8 +123,8 @@ class SqlParser extends AbstractSparkSQLParser with DataTypeParser {
val base = r.getOrElse(OneRowRelation)
val withFilter = f.map(Filter(_, base)).getOrElse(base)
val withProjection = g
- .map(Aggregate(_, assignAliases(p), withFilter))
- .getOrElse(Project(assignAliases(p), withFilter))
+ .map(Aggregate(_, p.map(UnresolvedAlias(_)), withFilter))
+ .getOrElse(Project(p.map(UnresolvedAlias(_)), withFilter))
val withDistinct = d.map(_ => Distinct(withProjection)).getOrElse(withProjection)
val withHaving = h.map(Filter(_, withDistinct)).getOrElse(withDistinct)
val withOrder = o.map(_(withHaving)).getOrElse(withHaving)
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
index 21b0576025..6311784422 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
@@ -17,8 +17,6 @@
package org.apache.spark.sql.catalyst.analysis
-import scala.collection.mutable.ArrayBuffer
-
import org.apache.spark.sql.AnalysisException
import org.apache.spark.sql.catalyst.{SimpleCatalystConf, CatalystConf}
import org.apache.spark.sql.catalyst.expressions._
@@ -74,10 +72,10 @@ class Analyzer(
ResolveSortReferences ::
ResolveGenerate ::
ResolveFunctions ::
+ ResolveAliases ::
ExtractWindowExpressions ::
GlobalAggregates ::
UnresolvedHavingClauseAttributes ::
- TrimGroupingAliases ::
typeCoercionRules ++
extendedResolutionRules : _*)
)
@@ -132,12 +130,38 @@ class Analyzer(
}
/**
- * Removes no-op Alias expressions from the plan.
+ * Replaces [[UnresolvedAlias]]s with concrete aliases.
*/
- object TrimGroupingAliases extends Rule[LogicalPlan] {
- def apply(plan: LogicalPlan): LogicalPlan = plan transform {
- case Aggregate(groups, aggs, child) =>
- Aggregate(groups.map(_.transform { case Alias(c, _) => c }), aggs, child)
+ object ResolveAliases extends Rule[LogicalPlan] {
+ private def assignAliases(exprs: Seq[NamedExpression]) = {
+ // The `UnresolvedAlias`s will appear only at root of a expression tree, we don't need
+ // to transform down the whole tree.
+ exprs.zipWithIndex.map {
+ case (u @ UnresolvedAlias(child), i) =>
+ child match {
+ case _: UnresolvedAttribute => u
+ case ne: NamedExpression => ne
+ case ev: ExtractValueWithStruct => Alias(ev, ev.field.name)()
+ case g: Generator if g.resolved && g.elementTypes.size > 1 => MultiAlias(g, Nil)
+ case e if !e.resolved => u
+ case other => Alias(other, s"_c$i")()
+ }
+ case (other, _) => other
+ }
+ }
+
+ def apply(plan: LogicalPlan): LogicalPlan = plan transformUp {
+ case Aggregate(groups, aggs, child)
+ if child.resolved && aggs.exists(_.isInstanceOf[UnresolvedAlias]) =>
+ Aggregate(groups, assignAliases(aggs), child)
+
+ case g: GroupingAnalytics
+ if g.child.resolved && g.aggregations.exists(_.isInstanceOf[UnresolvedAlias]) =>
+ g.withNewAggs(assignAliases(g.aggregations))
+
+ case Project(projectList, child)
+ if child.resolved && projectList.exists(_.isInstanceOf[UnresolvedAlias]) =>
+ Project(assignAliases(projectList), child)
}
}
@@ -228,7 +252,7 @@ class Analyzer(
}
def apply(plan: LogicalPlan): LogicalPlan = plan transform {
- case i@InsertIntoTable(u: UnresolvedRelation, _, _, _, _) =>
+ case i @ InsertIntoTable(u: UnresolvedRelation, _, _, _, _) =>
i.copy(table = EliminateSubQueries(getTable(u)))
case u: UnresolvedRelation =>
getTable(u)
@@ -248,24 +272,24 @@ class Analyzer(
Project(
projectList.flatMap {
case s: Star => s.expand(child.output, resolver)
- case Alias(f @ UnresolvedFunction(_, args), name) if containsStar(args) =>
+ case UnresolvedAlias(f @ UnresolvedFunction(_, args)) if containsStar(args) =>
val expandedArgs = args.flatMap {
case s: Star => s.expand(child.output, resolver)
case o => o :: Nil
}
- Alias(child = f.copy(children = expandedArgs), name)() :: Nil
- case Alias(c @ CreateArray(args), name) if containsStar(args) =>
+ UnresolvedAlias(child = f.copy(children = expandedArgs)) :: Nil
+ case UnresolvedAlias(c @ CreateArray(args)) if containsStar(args) =>
val expandedArgs = args.flatMap {
case s: Star => s.expand(child.output, resolver)
case o => o :: Nil
}
- Alias(c.copy(children = expandedArgs), name)() :: Nil
- case Alias(c @ CreateStruct(args), name) if containsStar(args) =>
+ UnresolvedAlias(c.copy(children = expandedArgs)) :: Nil
+ case UnresolvedAlias(c @ CreateStruct(args)) if containsStar(args) =>
val expandedArgs = args.flatMap {
case s: Star => s.expand(child.output, resolver)
case o => o :: Nil
}
- Alias(c.copy(children = expandedArgs), name)() :: Nil
+ UnresolvedAlias(c.copy(children = expandedArgs)) :: Nil
case o => o :: Nil
},
child)
@@ -353,7 +377,9 @@ class Analyzer(
case u @ UnresolvedAttribute(nameParts) =>
// Leave unchanged if resolution fails. Hopefully will be resolved next round.
val result =
- withPosition(u) { q.resolveChildren(nameParts, resolver).getOrElse(u) }
+ withPosition(u) {
+ q.resolveChildren(nameParts, resolver).map(trimUnresolvedAlias).getOrElse(u)
+ }
logDebug(s"Resolving $u to $result")
result
case UnresolvedExtractValue(child, fieldExpr) if child.resolved =>
@@ -379,6 +405,11 @@ class Analyzer(
exprs.exists(_.collect { case _: Star => true }.nonEmpty)
}
+ private def trimUnresolvedAlias(ne: NamedExpression) = ne match {
+ case UnresolvedAlias(child) => child
+ case other => other
+ }
+
private def resolveSortOrders(ordering: Seq[SortOrder], plan: LogicalPlan, throws: Boolean) = {
ordering.map { order =>
// Resolve SortOrder in one round.
@@ -388,7 +419,7 @@ class Analyzer(
try {
val newOrder = order transformUp {
case u @ UnresolvedAttribute(nameParts) =>
- plan.resolve(nameParts, resolver).getOrElse(u)
+ plan.resolve(nameParts, resolver).map(trimUnresolvedAlias).getOrElse(u)
case UnresolvedExtractValue(child, fieldName) if child.resolved =>
ExtractValue(child, fieldName, resolver)
}
@@ -586,18 +617,6 @@ class Analyzer(
/** Extracts a [[Generator]] expression and any names assigned by aliases to their output. */
private object AliasedGenerator {
def unapply(e: Expression): Option[(Generator, Seq[String])] = e match {
- case Alias(g: Generator, name)
- if g.resolved &&
- g.elementTypes.size > 1 &&
- java.util.regex.Pattern.matches("_c[0-9]+", name) => {
- // Assume the default name given by parser is "_c[0-9]+",
- // TODO in long term, move the naming logic from Parser to Analyzer.
- // In projection, Parser gave default name for TGF as does for normal UDF,
- // but the TGF probably have multiple output columns/names.
- // e.g. SELECT explode(map(key, value)) FROM src;
- // Let's simply ignore the default given name for this case.
- Some((g, Nil))
- }
case Alias(g: Generator, name) if g.resolved && g.elementTypes.size > 1 =>
// If not given the default names, and the TGF with multiple output columns
failAnalysis(
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala
index 7fabd2bfc8..c5a1437be6 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala
@@ -95,14 +95,7 @@ trait CheckAnalysis {
case e => e.children.foreach(checkValidAggregateExpression)
}
- val cleaned = aggregateExprs.map(_.transform {
- // Should trim aliases around `GetField`s. These aliases are introduced while
- // resolving struct field accesses, because `GetField` is not a `NamedExpression`.
- // (Should we just turn `GetField` into a `NamedExpression`?)
- case Alias(g, _) => g
- })
-
- cleaned.foreach(checkValidAggregateExpression)
+ aggregateExprs.foreach(checkValidAggregateExpression)
case _ => // Fallbacks to the following checks
}
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/unresolved.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/unresolved.scala
index c9d9142578..ae3adbab05 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/unresolved.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/unresolved.scala
@@ -17,7 +17,6 @@
package org.apache.spark.sql.catalyst.analysis
-import org.apache.spark.sql.catalyst
import org.apache.spark.sql.catalyst.{errors, trees}
import org.apache.spark.sql.catalyst.errors.TreeNodeException
import org.apache.spark.sql.catalyst.expressions._
@@ -206,3 +205,22 @@ case class UnresolvedExtractValue(child: Expression, extraction: Expression)
override def toString: String = s"$child[$extraction]"
}
+
+/**
+ * Holds the expression that has yet to be aliased.
+ */
+case class UnresolvedAlias(child: Expression) extends NamedExpression
+ with trees.UnaryNode[Expression] {
+
+ override def toAttribute: Attribute = throw new UnresolvedException(this, "toAttribute")
+ override def qualifiers: Seq[String] = throw new UnresolvedException(this, "qualifiers")
+ override def exprId: ExprId = throw new UnresolvedException(this, "exprId")
+ override def nullable: Boolean = throw new UnresolvedException(this, "nullable")
+ override def dataType: DataType = throw new UnresolvedException(this, "dataType")
+ override def name: String = throw new UnresolvedException(this, "name")
+
+ override lazy val resolved = false
+
+ override def eval(input: InternalRow = null): Any =
+ throw new TreeNodeException(this, s"No function to evaluate expression. type: ${this.nodeName}")
+}
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ExtractValue.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ExtractValue.scala
index 4aaabff15b..013027b199 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ExtractValue.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ExtractValue.scala
@@ -19,7 +19,7 @@ package org.apache.spark.sql.catalyst.expressions
import scala.collection.Map
-import org.apache.spark.sql.{catalyst, AnalysisException}
+import org.apache.spark.sql.AnalysisException
import org.apache.spark.sql.catalyst.analysis._
import org.apache.spark.sql.types._
@@ -41,16 +41,22 @@ object ExtractValue {
resolver: Resolver): ExtractValue = {
(child.dataType, extraction) match {
- case (StructType(fields), Literal(fieldName, StringType)) =>
- val ordinal = findField(fields, fieldName.toString, resolver)
- GetStructField(child, fields(ordinal), ordinal)
- case (ArrayType(StructType(fields), containsNull), Literal(fieldName, StringType)) =>
- val ordinal = findField(fields, fieldName.toString, resolver)
- GetArrayStructFields(child, fields(ordinal), ordinal, containsNull)
+ case (StructType(fields), NonNullLiteral(v, StringType)) =>
+ val fieldName = v.toString
+ val ordinal = findField(fields, fieldName, resolver)
+ GetStructField(child, fields(ordinal).copy(name = fieldName), ordinal)
+
+ case (ArrayType(StructType(fields), containsNull), NonNullLiteral(v, StringType)) =>
+ val fieldName = v.toString
+ val ordinal = findField(fields, fieldName, resolver)
+ GetArrayStructFields(child, fields(ordinal).copy(name = fieldName), ordinal, containsNull)
+
case (_: ArrayType, _) if extraction.dataType.isInstanceOf[IntegralType] =>
GetArrayItem(child, extraction)
+
case (_: MapType, _) =>
GetMapValue(child, extraction)
+
case (otherType, _) =>
val errorMsg = otherType match {
case StructType(_) | ArrayType(StructType(_), _) =>
@@ -94,16 +100,21 @@ trait ExtractValue extends UnaryExpression {
self: Product =>
}
+abstract class ExtractValueWithStruct extends ExtractValue {
+ self: Product =>
+
+ def field: StructField
+ override def toString: String = s"$child.${field.name}"
+}
+
/**
* Returns the value of fields in the Struct `child`.
*/
case class GetStructField(child: Expression, field: StructField, ordinal: Int)
- extends ExtractValue {
+ extends ExtractValueWithStruct {
override def dataType: DataType = field.dataType
override def nullable: Boolean = child.nullable || field.nullable
- override def foldable: Boolean = child.foldable
- override def toString: String = s"$child.${field.name}"
override def eval(input: InternalRow): Any = {
val baseValue = child.eval(input).asInstanceOf[InternalRow]
@@ -118,12 +129,9 @@ case class GetArrayStructFields(
child: Expression,
field: StructField,
ordinal: Int,
- containsNull: Boolean) extends ExtractValue {
+ containsNull: Boolean) extends ExtractValueWithStruct {
override def dataType: DataType = ArrayType(field.dataType, containsNull)
- override def nullable: Boolean = child.nullable
- override def foldable: Boolean = child.foldable
- override def toString: String = s"$child.${field.name}"
override def eval(input: InternalRow): Any = {
val baseValue = child.eval(input).asInstanceOf[Seq[InternalRow]]
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/planning/patterns.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/planning/patterns.scala
index 3b6f8bfd9f..179a348d5b 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/planning/patterns.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/planning/patterns.scala
@@ -156,12 +156,8 @@ object PartialAggregation {
partialEvaluations(new TreeNodeRef(e)).finalEvaluation
case e: Expression =>
- // Should trim aliases around `GetField`s. These aliases are introduced while
- // resolving struct field accesses, because `GetField` is not a `NamedExpression`.
- // (Should we just turn `GetField` into a `NamedExpression`?)
- val trimmed = e.transform { case Alias(g: ExtractValue, _) => g }
namedGroupingExpressions.collectFirst {
- case (expr, ne) if expr semanticEquals trimmed => ne.toAttribute
+ case (expr, ne) if expr semanticEquals e => ne.toAttribute
}.getOrElse(e)
}).asInstanceOf[Seq[NamedExpression]]
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LogicalPlan.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LogicalPlan.scala
index a853e27c12..b009a200b9 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LogicalPlan.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LogicalPlan.scala
@@ -19,7 +19,7 @@ package org.apache.spark.sql.catalyst.plans.logical
import org.apache.spark.Logging
import org.apache.spark.sql.AnalysisException
-import org.apache.spark.sql.catalyst.analysis.{UnresolvedAttribute, EliminateSubQueries, Resolver}
+import org.apache.spark.sql.catalyst.analysis._
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.plans.QueryPlan
import org.apache.spark.sql.catalyst.trees.TreeNode
@@ -252,14 +252,13 @@ abstract class LogicalPlan extends QueryPlan[LogicalPlan] with Logging {
// One match, but we also need to extract the requested nested field.
case Seq((a, nestedFields)) =>
// The foldLeft adds ExtractValues for every remaining parts of the identifier,
- // and aliases it with the last part of the identifier.
+ // and wrap it with UnresolvedAlias which will be removed later.
// For example, consider "a.b.c", where "a" is resolved to an existing attribute.
- // Then this will add ExtractValue("c", ExtractValue("b", a)), and alias
- // the final expression as "c".
+ // Then this will add ExtractValue("c", ExtractValue("b", a)), and wrap it as
+ // UnresolvedAlias(ExtractValue("c", ExtractValue("b", a))).
val fieldExprs = nestedFields.foldLeft(a: Expression)((expr, fieldName) =>
ExtractValue(expr, Literal(fieldName), resolver))
- val aliasName = nestedFields.last
- Some(Alias(fieldExprs, aliasName)())
+ Some(UnresolvedAlias(fieldExprs))
// No matches.
case Seq() =>
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala
index 963c782091..f8e5916d69 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala
@@ -242,6 +242,8 @@ trait GroupingAnalytics extends UnaryNode {
def aggregations: Seq[NamedExpression]
override def output: Seq[Attribute] = aggregations.map(_.toAttribute)
+
+ def withNewAggs(aggs: Seq[NamedExpression]): GroupingAnalytics
}
/**
@@ -266,7 +268,11 @@ case class GroupingSets(
groupByExprs: Seq[Expression],
child: LogicalPlan,
aggregations: Seq[NamedExpression],
- gid: AttributeReference = VirtualColumn.newGroupingId) extends GroupingAnalytics
+ gid: AttributeReference = VirtualColumn.newGroupingId) extends GroupingAnalytics {
+
+ def withNewAggs(aggs: Seq[NamedExpression]): GroupingAnalytics =
+ this.copy(aggregations = aggs)
+}
/**
* Cube is a syntactic sugar for GROUPING SETS, and will be transformed to GroupingSets,
@@ -284,7 +290,11 @@ case class Cube(
groupByExprs: Seq[Expression],
child: LogicalPlan,
aggregations: Seq[NamedExpression],
- gid: AttributeReference = VirtualColumn.newGroupingId) extends GroupingAnalytics
+ gid: AttributeReference = VirtualColumn.newGroupingId) extends GroupingAnalytics {
+
+ def withNewAggs(aggs: Seq[NamedExpression]): GroupingAnalytics =
+ this.copy(aggregations = aggs)
+}
/**
* Rollup is a syntactic sugar for GROUPING SETS, and will be transformed to GroupingSets,
@@ -303,7 +313,11 @@ case class Rollup(
groupByExprs: Seq[Expression],
child: LogicalPlan,
aggregations: Seq[NamedExpression],
- gid: AttributeReference = VirtualColumn.newGroupingId) extends GroupingAnalytics
+ gid: AttributeReference = VirtualColumn.newGroupingId) extends GroupingAnalytics {
+
+ def withNewAggs(aggs: Seq[NamedExpression]): GroupingAnalytics =
+ this.copy(aggregations = aggs)
+}
case class Limit(limitExpr: Expression, child: LogicalPlan) extends UnaryNode {
override def output: Seq[Attribute] = child.output
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/Column.scala b/sql/core/src/main/scala/org/apache/spark/sql/Column.scala
index b4e008a6e8..f201c8ea8a 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/Column.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/Column.scala
@@ -21,7 +21,6 @@ import scala.language.implicitConversions
import org.apache.spark.annotation.Experimental
import org.apache.spark.Logging
-import org.apache.spark.sql.expressions.Window
import org.apache.spark.sql.functions.lit
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.analysis._
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala b/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala
index 466258e76f..492a3321bc 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala
@@ -32,7 +32,7 @@ import org.apache.spark.annotation.{DeveloperApi, Experimental}
import org.apache.spark.api.java.JavaRDD
import org.apache.spark.api.python.SerDeUtil
import org.apache.spark.rdd.RDD
-import org.apache.spark.sql.catalyst.analysis.{MultiAlias, ResolvedStar, UnresolvedAttribute}
+import org.apache.spark.sql.catalyst.analysis._
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.plans.logical.{Filter, _}
import org.apache.spark.sql.catalyst.plans.{Inner, JoinType}
@@ -629,6 +629,10 @@ class DataFrame private[sql](
@scala.annotation.varargs
def select(cols: Column*): DataFrame = {
val namedExpressions = cols.map {
+ // Wrap UnresolvedAttribute with UnresolvedAlias, as when we resolve UnresolvedAttribute, we
+ // will remove intermediate Alias for ExtractValue chain, and we need to alias it again to
+ // make it a NamedExpression.
+ case Column(u: UnresolvedAttribute) => UnresolvedAlias(u)
case Column(expr: NamedExpression) => expr
// Leave an unaliased explode with an empty list of names since the analzyer will generate the
// correct defaults after the nested expression's type has been resolved.
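For context (not part of the diff): the wrapping above matters when a nested field is selected by name, because the intermediate ExtractValue chain is no longer eagerly aliased during resolution. A hypothetical usage sketch, assuming a sqlContext in scope:

// Hypothetical sketch; `sqlContext` and these case classes are assumptions, not part of the patch.
case class Inner(c: Int)
case class Outer(b: Inner)
case class Record(a: Outer)

val nested = sqlContext.createDataFrame(Seq(Record(Outer(Inner(1)))))

// The select builds UnresolvedAlias(ExtractValue("c", ExtractValue("b", a))); ResolveAliases then
// names the output column after the extracted field, so we expect Array(c) rather than Array(_c0).
nested.select("a.b.c").columns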
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/GroupedData.scala b/sql/core/src/main/scala/org/apache/spark/sql/GroupedData.scala
index 45b3e1bc62..99d557b03a 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/GroupedData.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/GroupedData.scala
@@ -21,7 +21,7 @@ import scala.collection.JavaConversions._
import scala.language.implicitConversions
import org.apache.spark.annotation.Experimental
-import org.apache.spark.sql.catalyst.analysis.Star
+import org.apache.spark.sql.catalyst.analysis.{UnresolvedAlias, UnresolvedAttribute, Star}
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.plans.logical.{Rollup, Cube, Aggregate}
import org.apache.spark.sql.types.NumericType
@@ -70,27 +70,31 @@ class GroupedData protected[sql](
groupingExprs: Seq[Expression],
private val groupType: GroupedData.GroupType) {
- private[this] def toDF(aggExprs: Seq[NamedExpression]): DataFrame = {
+ private[this] def toDF(aggExprs: Seq[Expression]): DataFrame = {
val aggregates = if (df.sqlContext.conf.dataFrameRetainGroupColumns) {
- val retainedExprs = groupingExprs.map {
- case expr: NamedExpression => expr
- case expr: Expression => Alias(expr, expr.prettyString)()
- }
- retainedExprs ++ aggExprs
- } else {
- aggExprs
- }
+ groupingExprs ++ aggExprs
+ } else {
+ aggExprs
+ }
+ val aliasedAgg = aggregates.map {
+ // Wrap UnresolvedAttribute with UnresolvedAlias, as when we resolve UnresolvedAttribute, we
+ // will remove intermediate Alias for ExtractValue chain, and we need to alias it again to
+ // make it a NamedExpression.
+ case u: UnresolvedAttribute => UnresolvedAlias(u)
+ case expr: NamedExpression => expr
+ case expr: Expression => Alias(expr, expr.prettyString)()
+ }
groupType match {
case GroupedData.GroupByType =>
DataFrame(
- df.sqlContext, Aggregate(groupingExprs, aggregates, df.logicalPlan))
+ df.sqlContext, Aggregate(groupingExprs, aliasedAgg, df.logicalPlan))
case GroupedData.RollupType =>
DataFrame(
- df.sqlContext, Rollup(groupingExprs, df.logicalPlan, aggregates))
+ df.sqlContext, Rollup(groupingExprs, df.logicalPlan, aliasedAgg))
case GroupedData.CubeType =>
DataFrame(
- df.sqlContext, Cube(groupingExprs, df.logicalPlan, aggregates))
+ df.sqlContext, Cube(groupingExprs, df.logicalPlan, aliasedAgg))
}
}
@@ -112,10 +116,7 @@ class GroupedData protected[sql](
namedExpr
}
}
- toDF(columnExprs.map { c =>
- val a = f(c)
- Alias(a, a.prettyString)()
- })
+ toDF(columnExprs.map(f))
}
private[this] def strToExpr(expr: String): (Expression => Expression) = {
@@ -169,8 +170,7 @@ class GroupedData protected[sql](
*/
def agg(exprs: Map[String, String]): DataFrame = {
toDF(exprs.map { case (colName, expr) =>
- val a = strToExpr(expr)(df(colName).expr)
- Alias(a, a.prettyString)()
+ strToExpr(expr)(df(colName).expr)
}.toSeq)
}
@@ -224,10 +224,7 @@ class GroupedData protected[sql](
*/
@scala.annotation.varargs
def agg(expr: Column, exprs: Column*): DataFrame = {
- toDF((expr +: exprs).map(_.expr).map {
- case expr: NamedExpression => expr
- case expr: Expression => Alias(expr, expr.prettyString)()
- })
+ toDF((expr +: exprs).map(_.expr))
}
/**
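For context (not part of the diff): the GroupedData change defers naming to the analyzer but keeps the user-visible defaults. A hypothetical usage sketch, assuming a sqlContext in scope; the exact default name comes from the expression's prettyString and may differ across Spark versions:

// Hypothetical sketch; `sqlContext` is assumed, and the default column name shown is illustrative.
import org.apache.spark.sql.functions.max

val sales = sqlContext.createDataFrame(Seq(("a", 1), ("a", 2), ("b", 3))).toDF("key", "value")

// Unnamed aggregate: passed through here, then aliased by the analyzer as Alias(expr, expr.prettyString).
sales.groupBy("key").agg(max(sales("value"))).columns                   // e.g. Array(key, MAX(value))

// Explicit alias: already a NamedExpression, so it is kept untouched.
sales.groupBy("key").agg(max(sales("value")).as("maxValue")).columns    // Array(key, maxValue)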
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/pythonUdfs.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/pythonUdfs.scala
index 1ce150ceaf..c8c67ce334 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/pythonUdfs.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/pythonUdfs.scala
@@ -74,7 +74,7 @@ private[spark] object ExtractPythonUdfs extends Rule[LogicalPlan] {
// Skip EvaluatePython nodes.
case plan: EvaluatePython => plan
- case plan: LogicalPlan =>
+ case plan: LogicalPlan if plan.resolved =>
// Extract any PythonUDFs from the current operator.
val udfs = plan.expressions.flatMap(_.collect { case udf: PythonUDF => udf })
if (udfs.isEmpty) {
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala
index 4441afd6bd..73bc6c9991 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala
@@ -1367,9 +1367,9 @@ class SQLQuerySuite extends QueryTest with BeforeAndAfterAll with SQLTestUtils {
test("SPARK-6145: special cases") {
sqlContext.read.json(sqlContext.sparkContext.makeRDD(
- """{"a": {"b": [1]}, "b": [{"a": 1}], "c0": {"a": 1}}""" :: Nil)).registerTempTable("t")
- checkAnswer(sql("SELECT a.b[0] FROM t ORDER BY c0.a"), Row(1))
- checkAnswer(sql("SELECT b[0].a FROM t ORDER BY c0.a"), Row(1))
+ """{"a": {"b": [1]}, "b": [{"a": 1}], "_c0": {"a": 1}}""" :: Nil)).registerTempTable("t")
+ checkAnswer(sql("SELECT a.b[0] FROM t ORDER BY _c0.a"), Row(1))
+ checkAnswer(sql("SELECT b[0].a FROM t ORDER BY _c0.a"), Row(1))
}
test("SPARK-6898: complete support for special chars in column names") {
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/TestData.scala b/sql/core/src/test/scala/org/apache/spark/sql/TestData.scala
index 520a862ea0..207d7a352c 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/TestData.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/TestData.scala
@@ -19,7 +19,6 @@ package org.apache.spark.sql
import java.sql.Timestamp
-import org.apache.spark.sql.catalyst.plans.logical
import org.apache.spark.sql.test.TestSQLContext.implicits._
import org.apache.spark.sql.test._
diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveQl.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveQl.scala
index ca4b80b51b..7c4620952b 100644
--- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveQl.scala
+++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveQl.scala
@@ -415,13 +415,6 @@ private[hive] object HiveQl {
throw new NotImplementedError(s"No parse rules for StructField:\n ${dumpTree(a).toString} ")
}
- protected def nameExpressions(exprs: Seq[Expression]): Seq[NamedExpression] = {
- exprs.zipWithIndex.map {
- case (ne: NamedExpression, _) => ne
- case (e, i) => Alias(e, s"_c$i")()
- }
- }
-
protected def extractDbNameTableName(tableNameParts: Node): (Option[String], String) = {
val (db, tableName) =
tableNameParts.getChildren.map { case Token(part, Nil) => cleanIdentifier(part) } match {
@@ -942,7 +935,7 @@ https://cwiki.apache.org/confluence/display/Hive/Enhanced+Aggregation%2C+Cube%2C
// (if there is a group by) or a script transformation.
val withProject: LogicalPlan = transformation.getOrElse {
val selectExpressions =
- nameExpressions(select.getChildren.flatMap(selExprNodeToExpr).toSeq)
+ select.getChildren.flatMap(selExprNodeToExpr).map(UnresolvedAlias(_)).toSeq
Seq(
groupByClause.map(e => e match {
case Token("TOK_GROUPBY", children) =>