-rw-r--r--  python/pyspark/sql/context.py | 9
-rw-r--r--  sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/SqlParser.scala | 11
-rw-r--r--  sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala | 77
-rw-r--r--  sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala | 9
-rw-r--r--  sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/unresolved.scala | 20
-rw-r--r--  sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ExtractValue.scala | 36
-rw-r--r--  sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/planning/patterns.scala | 6
-rw-r--r--  sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LogicalPlan.scala | 11
-rw-r--r--  sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala | 20
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/Column.scala | 1
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala | 6
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/GroupedData.scala | 43
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/execution/pythonUdfs.scala | 2
-rw-r--r--  sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala | 6
-rw-r--r--  sql/core/src/test/scala/org/apache/spark/sql/TestData.scala | 1
-rw-r--r--  sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveQl.scala | 9
16 files changed, 150 insertions, 117 deletions
diff --git a/python/pyspark/sql/context.py b/python/pyspark/sql/context.py
index 599c9ac579..dc239226e6 100644
--- a/python/pyspark/sql/context.py
+++ b/python/pyspark/sql/context.py
@@ -86,7 +86,8 @@ class SQLContext(object):
>>> df.registerTempTable("allTypes")
>>> sqlContext.sql('select i+1, d+1, not b, list[1], dict["s"], time, row.a '
... 'from allTypes where b and i > 0').collect()
- [Row(c0=2, c1=2.0, c2=False, c3=2, c4=0, time=datetime.datetime(2014, 8, 1, 14, 1, 5), a=1)]
+ [Row(_c0=2, _c1=2.0, _c2=False, _c3=2, _c4=0, \
+ time=datetime.datetime(2014, 8, 1, 14, 1, 5), a=1)]
>>> df.map(lambda x: (x.i, x.s, x.d, x.l, x.b, x.time, x.row.a, x.list)).collect()
[(1, u'string', 1.0, 1, True, datetime.datetime(2014, 8, 1, 14, 1, 5), 1, [1, 2, 3])]
"""
@@ -176,17 +177,17 @@ class SQLContext(object):
>>> sqlContext.registerFunction("stringLengthString", lambda x: len(x))
>>> sqlContext.sql("SELECT stringLengthString('test')").collect()
- [Row(c0=u'4')]
+ [Row(_c0=u'4')]
>>> from pyspark.sql.types import IntegerType
>>> sqlContext.registerFunction("stringLengthInt", lambda x: len(x), IntegerType())
>>> sqlContext.sql("SELECT stringLengthInt('test')").collect()
- [Row(c0=4)]
+ [Row(_c0=4)]
>>> from pyspark.sql.types import IntegerType
>>> sqlContext.udf.register("stringLengthInt", lambda x: len(x), IntegerType())
>>> sqlContext.sql("SELECT stringLengthInt('test')").collect()
- [Row(c0=4)]
+ [Row(_c0=4)]
"""
func = lambda _, it: map(lambda x: f(*x), it)
ser = AutoBatchedSerializer(PickleSerializer())
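Below is a minimal, illustrative Scala sketch (not part of this patch; the local-mode setup and table name are assumptions) of the renamed defaults documented in the doctests above: unaliased projections now come back as _c0, _c1, ... rather than c0, c1, ....

import org.apache.spark.{SparkConf, SparkContext}
import org.apache.spark.sql.SQLContext

object DefaultColumnNameSketch {
  def main(args: Array[String]): Unit = {
    val sc = new SparkContext(new SparkConf().setMaster("local").setAppName("sketch"))
    val sqlContext = new SQLContext(sc)
    import sqlContext.implicits._

    // Register a tiny temp table to query against.
    sc.parallelize(Seq((1, "a"), (2, "b"))).toDF("key", "value").registerTempTable("src")

    // Unaliased expressions are named _c<i> by the analyzer (previously c<i> from the parser).
    val df = sqlContext.sql("SELECT key + 1, key * 2 FROM src")
    println(df.columns.toSeq)   // expected: List(_c0, _c1)

    sc.stop()
  }
}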
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/SqlParser.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/SqlParser.scala
index da3a717f90..79f526e823 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/SqlParser.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/SqlParser.scala
@@ -99,13 +99,6 @@ class SqlParser extends AbstractSparkSQLParser with DataTypeParser {
protected val WHERE = Keyword("WHERE")
protected val WITH = Keyword("WITH")
- protected def assignAliases(exprs: Seq[Expression]): Seq[NamedExpression] = {
- exprs.zipWithIndex.map {
- case (ne: NamedExpression, _) => ne
- case (e, i) => Alias(e, s"c$i")()
- }
- }
-
protected lazy val start: Parser[LogicalPlan] =
start1 | insert | cte
@@ -130,8 +123,8 @@ class SqlParser extends AbstractSparkSQLParser with DataTypeParser {
val base = r.getOrElse(OneRowRelation)
val withFilter = f.map(Filter(_, base)).getOrElse(base)
val withProjection = g
- .map(Aggregate(_, assignAliases(p), withFilter))
- .getOrElse(Project(assignAliases(p), withFilter))
+ .map(Aggregate(_, p.map(UnresolvedAlias(_)), withFilter))
+ .getOrElse(Project(p.map(UnresolvedAlias(_)), withFilter))
val withDistinct = d.map(_ => Distinct(withProjection)).getOrElse(withProjection)
val withHaving = h.map(Filter(_, withDistinct)).getOrElse(withDistinct)
val withOrder = o.map(_(withHaving)).getOrElse(withHaving)
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
index 21b0576025..6311784422 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
@@ -17,8 +17,6 @@
package org.apache.spark.sql.catalyst.analysis
-import scala.collection.mutable.ArrayBuffer
-
import org.apache.spark.sql.AnalysisException
import org.apache.spark.sql.catalyst.{SimpleCatalystConf, CatalystConf}
import org.apache.spark.sql.catalyst.expressions._
@@ -74,10 +72,10 @@ class Analyzer(
ResolveSortReferences ::
ResolveGenerate ::
ResolveFunctions ::
+ ResolveAliases ::
ExtractWindowExpressions ::
GlobalAggregates ::
UnresolvedHavingClauseAttributes ::
- TrimGroupingAliases ::
typeCoercionRules ++
extendedResolutionRules : _*)
)
@@ -132,12 +130,38 @@ class Analyzer(
}
/**
- * Removes no-op Alias expressions from the plan.
+ * Replaces [[UnresolvedAlias]]s with concrete aliases.
*/
- object TrimGroupingAliases extends Rule[LogicalPlan] {
- def apply(plan: LogicalPlan): LogicalPlan = plan transform {
- case Aggregate(groups, aggs, child) =>
- Aggregate(groups.map(_.transform { case Alias(c, _) => c }), aggs, child)
+ object ResolveAliases extends Rule[LogicalPlan] {
+ private def assignAliases(exprs: Seq[NamedExpression]) = {
+      // `UnresolvedAlias`s appear only at the root of an expression tree, so we don't need
+      // to transform down the whole tree.
+ exprs.zipWithIndex.map {
+ case (u @ UnresolvedAlias(child), i) =>
+ child match {
+ case _: UnresolvedAttribute => u
+ case ne: NamedExpression => ne
+ case ev: ExtractValueWithStruct => Alias(ev, ev.field.name)()
+ case g: Generator if g.resolved && g.elementTypes.size > 1 => MultiAlias(g, Nil)
+ case e if !e.resolved => u
+ case other => Alias(other, s"_c$i")()
+ }
+ case (other, _) => other
+ }
+ }
+
+ def apply(plan: LogicalPlan): LogicalPlan = plan transformUp {
+ case Aggregate(groups, aggs, child)
+ if child.resolved && aggs.exists(_.isInstanceOf[UnresolvedAlias]) =>
+ Aggregate(groups, assignAliases(aggs), child)
+
+ case g: GroupingAnalytics
+ if g.child.resolved && g.aggregations.exists(_.isInstanceOf[UnresolvedAlias]) =>
+ g.withNewAggs(assignAliases(g.aggregations))
+
+ case Project(projectList, child)
+ if child.resolved && projectList.exists(_.isInstanceOf[UnresolvedAlias]) =>
+ Project(assignAliases(projectList), child)
}
}
@@ -228,7 +252,7 @@ class Analyzer(
}
def apply(plan: LogicalPlan): LogicalPlan = plan transform {
- case i@InsertIntoTable(u: UnresolvedRelation, _, _, _, _) =>
+ case i @ InsertIntoTable(u: UnresolvedRelation, _, _, _, _) =>
i.copy(table = EliminateSubQueries(getTable(u)))
case u: UnresolvedRelation =>
getTable(u)
@@ -248,24 +272,24 @@ class Analyzer(
Project(
projectList.flatMap {
case s: Star => s.expand(child.output, resolver)
- case Alias(f @ UnresolvedFunction(_, args), name) if containsStar(args) =>
+ case UnresolvedAlias(f @ UnresolvedFunction(_, args)) if containsStar(args) =>
val expandedArgs = args.flatMap {
case s: Star => s.expand(child.output, resolver)
case o => o :: Nil
}
- Alias(child = f.copy(children = expandedArgs), name)() :: Nil
- case Alias(c @ CreateArray(args), name) if containsStar(args) =>
+ UnresolvedAlias(child = f.copy(children = expandedArgs)) :: Nil
+ case UnresolvedAlias(c @ CreateArray(args)) if containsStar(args) =>
val expandedArgs = args.flatMap {
case s: Star => s.expand(child.output, resolver)
case o => o :: Nil
}
- Alias(c.copy(children = expandedArgs), name)() :: Nil
- case Alias(c @ CreateStruct(args), name) if containsStar(args) =>
+ UnresolvedAlias(c.copy(children = expandedArgs)) :: Nil
+ case UnresolvedAlias(c @ CreateStruct(args)) if containsStar(args) =>
val expandedArgs = args.flatMap {
case s: Star => s.expand(child.output, resolver)
case o => o :: Nil
}
- Alias(c.copy(children = expandedArgs), name)() :: Nil
+ UnresolvedAlias(c.copy(children = expandedArgs)) :: Nil
case o => o :: Nil
},
child)
@@ -353,7 +377,9 @@ class Analyzer(
case u @ UnresolvedAttribute(nameParts) =>
// Leave unchanged if resolution fails. Hopefully will be resolved next round.
val result =
- withPosition(u) { q.resolveChildren(nameParts, resolver).getOrElse(u) }
+ withPosition(u) {
+ q.resolveChildren(nameParts, resolver).map(trimUnresolvedAlias).getOrElse(u)
+ }
logDebug(s"Resolving $u to $result")
result
case UnresolvedExtractValue(child, fieldExpr) if child.resolved =>
@@ -379,6 +405,11 @@ class Analyzer(
exprs.exists(_.collect { case _: Star => true }.nonEmpty)
}
+ private def trimUnresolvedAlias(ne: NamedExpression) = ne match {
+ case UnresolvedAlias(child) => child
+ case other => other
+ }
+
private def resolveSortOrders(ordering: Seq[SortOrder], plan: LogicalPlan, throws: Boolean) = {
ordering.map { order =>
// Resolve SortOrder in one round.
@@ -388,7 +419,7 @@ class Analyzer(
try {
val newOrder = order transformUp {
case u @ UnresolvedAttribute(nameParts) =>
- plan.resolve(nameParts, resolver).getOrElse(u)
+ plan.resolve(nameParts, resolver).map(trimUnresolvedAlias).getOrElse(u)
case UnresolvedExtractValue(child, fieldName) if child.resolved =>
ExtractValue(child, fieldName, resolver)
}
@@ -586,18 +617,6 @@ class Analyzer(
/** Extracts a [[Generator]] expression and any names assigned by aliases to their output. */
private object AliasedGenerator {
def unapply(e: Expression): Option[(Generator, Seq[String])] = e match {
- case Alias(g: Generator, name)
- if g.resolved &&
- g.elementTypes.size > 1 &&
- java.util.regex.Pattern.matches("_c[0-9]+", name) => {
- // Assume the default name given by parser is "_c[0-9]+",
- // TODO in long term, move the naming logic from Parser to Analyzer.
- // In projection, Parser gave default name for TGF as does for normal UDF,
- // but the TGF probably have multiple output columns/names.
- // e.g. SELECT explode(map(key, value)) FROM src;
- // Let's simply ignore the default given name for this case.
- Some((g, Nil))
- }
case Alias(g: Generator, name) if g.resolved && g.elementTypes.size > 1 =>
// If not given the default names, and the TGF with multiple output columns
failAnalysis(
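To make the naming rule concrete, here is a small standalone sketch (plain Scala, no Spark dependencies; all names are illustrative) that mirrors what assignAliases above does to a projection list.

object AliasNamingSketch {
  sealed trait Expr
  case class Named(name: String) extends Expr        // already a NamedExpression
  case class FieldAccess(field: String) extends Expr // stands in for ExtractValueWithStruct
  case class Other(desc: String) extends Expr        // any other resolved expression

  // Mirrors ResolveAliases.assignAliases: keep existing names, use the struct field
  // name for field accesses, and fall back to _c<i> for everything else.
  def assignNames(exprs: Seq[Expr]): Seq[String] = exprs.zipWithIndex.map {
    case (Named(n), _)       => n
    case (FieldAccess(f), _) => f
    case (Other(_), i)       => s"_c$i"
  }

  def main(args: Array[String]): Unit = {
    // e.g. SELECT key, struct_col.a, key + 1 FROM src
    println(assignNames(Seq(Named("key"), FieldAccess("a"), Other("key + 1"))))
    // List(key, a, _c2)
  }
}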
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala
index 7fabd2bfc8..c5a1437be6 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala
@@ -95,14 +95,7 @@ trait CheckAnalysis {
case e => e.children.foreach(checkValidAggregateExpression)
}
- val cleaned = aggregateExprs.map(_.transform {
- // Should trim aliases around `GetField`s. These aliases are introduced while
- // resolving struct field accesses, because `GetField` is not a `NamedExpression`.
- // (Should we just turn `GetField` into a `NamedExpression`?)
- case Alias(g, _) => g
- })
-
- cleaned.foreach(checkValidAggregateExpression)
+ aggregateExprs.foreach(checkValidAggregateExpression)
case _ => // Fallbacks to the following checks
}
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/unresolved.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/unresolved.scala
index c9d9142578..ae3adbab05 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/unresolved.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/unresolved.scala
@@ -17,7 +17,6 @@
package org.apache.spark.sql.catalyst.analysis
-import org.apache.spark.sql.catalyst
import org.apache.spark.sql.catalyst.{errors, trees}
import org.apache.spark.sql.catalyst.errors.TreeNodeException
import org.apache.spark.sql.catalyst.expressions._
@@ -206,3 +205,22 @@ case class UnresolvedExtractValue(child: Expression, extraction: Expression)
override def toString: String = s"$child[$extraction]"
}
+
+/**
+ * Holds the expression that has yet to be aliased.
+ */
+case class UnresolvedAlias(child: Expression) extends NamedExpression
+ with trees.UnaryNode[Expression] {
+
+ override def toAttribute: Attribute = throw new UnresolvedException(this, "toAttribute")
+ override def qualifiers: Seq[String] = throw new UnresolvedException(this, "qualifiers")
+ override def exprId: ExprId = throw new UnresolvedException(this, "exprId")
+ override def nullable: Boolean = throw new UnresolvedException(this, "nullable")
+ override def dataType: DataType = throw new UnresolvedException(this, "dataType")
+ override def name: String = throw new UnresolvedException(this, "name")
+
+ override lazy val resolved = false
+
+ override def eval(input: InternalRow = null): Any =
+ throw new TreeNodeException(this, s"No function to evaluate expression. type: ${this.nodeName}")
+}
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ExtractValue.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ExtractValue.scala
index 4aaabff15b..013027b199 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ExtractValue.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ExtractValue.scala
@@ -19,7 +19,7 @@ package org.apache.spark.sql.catalyst.expressions
import scala.collection.Map
-import org.apache.spark.sql.{catalyst, AnalysisException}
+import org.apache.spark.sql.AnalysisException
import org.apache.spark.sql.catalyst.analysis._
import org.apache.spark.sql.types._
@@ -41,16 +41,22 @@ object ExtractValue {
resolver: Resolver): ExtractValue = {
(child.dataType, extraction) match {
- case (StructType(fields), Literal(fieldName, StringType)) =>
- val ordinal = findField(fields, fieldName.toString, resolver)
- GetStructField(child, fields(ordinal), ordinal)
- case (ArrayType(StructType(fields), containsNull), Literal(fieldName, StringType)) =>
- val ordinal = findField(fields, fieldName.toString, resolver)
- GetArrayStructFields(child, fields(ordinal), ordinal, containsNull)
+ case (StructType(fields), NonNullLiteral(v, StringType)) =>
+ val fieldName = v.toString
+ val ordinal = findField(fields, fieldName, resolver)
+ GetStructField(child, fields(ordinal).copy(name = fieldName), ordinal)
+
+ case (ArrayType(StructType(fields), containsNull), NonNullLiteral(v, StringType)) =>
+ val fieldName = v.toString
+ val ordinal = findField(fields, fieldName, resolver)
+ GetArrayStructFields(child, fields(ordinal).copy(name = fieldName), ordinal, containsNull)
+
case (_: ArrayType, _) if extraction.dataType.isInstanceOf[IntegralType] =>
GetArrayItem(child, extraction)
+
case (_: MapType, _) =>
GetMapValue(child, extraction)
+
case (otherType, _) =>
val errorMsg = otherType match {
case StructType(_) | ArrayType(StructType(_), _) =>
@@ -94,16 +100,21 @@ trait ExtractValue extends UnaryExpression {
self: Product =>
}
+abstract class ExtractValueWithStruct extends ExtractValue {
+ self: Product =>
+
+ def field: StructField
+ override def toString: String = s"$child.${field.name}"
+}
+
/**
* Returns the value of fields in the Struct `child`.
*/
case class GetStructField(child: Expression, field: StructField, ordinal: Int)
- extends ExtractValue {
+ extends ExtractValueWithStruct {
override def dataType: DataType = field.dataType
override def nullable: Boolean = child.nullable || field.nullable
- override def foldable: Boolean = child.foldable
- override def toString: String = s"$child.${field.name}"
override def eval(input: InternalRow): Any = {
val baseValue = child.eval(input).asInstanceOf[InternalRow]
@@ -118,12 +129,9 @@ case class GetArrayStructFields(
child: Expression,
field: StructField,
ordinal: Int,
- containsNull: Boolean) extends ExtractValue {
+ containsNull: Boolean) extends ExtractValueWithStruct {
override def dataType: DataType = ArrayType(field.dataType, containsNull)
- override def nullable: Boolean = child.nullable
- override def foldable: Boolean = child.foldable
- override def toString: String = s"$child.${field.name}"
override def eval(input: InternalRow): Any = {
val baseValue = child.eval(input).asInstanceOf[Seq[InternalRow]]
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/planning/patterns.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/planning/patterns.scala
index 3b6f8bfd9f..179a348d5b 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/planning/patterns.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/planning/patterns.scala
@@ -156,12 +156,8 @@ object PartialAggregation {
partialEvaluations(new TreeNodeRef(e)).finalEvaluation
case e: Expression =>
- // Should trim aliases around `GetField`s. These aliases are introduced while
- // resolving struct field accesses, because `GetField` is not a `NamedExpression`.
- // (Should we just turn `GetField` into a `NamedExpression`?)
- val trimmed = e.transform { case Alias(g: ExtractValue, _) => g }
namedGroupingExpressions.collectFirst {
- case (expr, ne) if expr semanticEquals trimmed => ne.toAttribute
+ case (expr, ne) if expr semanticEquals e => ne.toAttribute
}.getOrElse(e)
}).asInstanceOf[Seq[NamedExpression]]
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LogicalPlan.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LogicalPlan.scala
index a853e27c12..b009a200b9 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LogicalPlan.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LogicalPlan.scala
@@ -19,7 +19,7 @@ package org.apache.spark.sql.catalyst.plans.logical
import org.apache.spark.Logging
import org.apache.spark.sql.AnalysisException
-import org.apache.spark.sql.catalyst.analysis.{UnresolvedAttribute, EliminateSubQueries, Resolver}
+import org.apache.spark.sql.catalyst.analysis._
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.plans.QueryPlan
import org.apache.spark.sql.catalyst.trees.TreeNode
@@ -252,14 +252,13 @@ abstract class LogicalPlan extends QueryPlan[LogicalPlan] with Logging {
// One match, but we also need to extract the requested nested field.
case Seq((a, nestedFields)) =>
// The foldLeft adds ExtractValues for every remaining parts of the identifier,
- // and aliases it with the last part of the identifier.
+          // and wraps it with an UnresolvedAlias, which will be removed later.
// For example, consider "a.b.c", where "a" is resolved to an existing attribute.
- // Then this will add ExtractValue("c", ExtractValue("b", a)), and alias
- // the final expression as "c".
+ // Then this will add ExtractValue("c", ExtractValue("b", a)), and wrap it as
+ // UnresolvedAlias(ExtractValue("c", ExtractValue("b", a))).
val fieldExprs = nestedFields.foldLeft(a: Expression)((expr, fieldName) =>
ExtractValue(expr, Literal(fieldName), resolver))
- val aliasName = nestedFields.last
- Some(Alias(fieldExprs, aliasName)())
+ Some(UnresolvedAlias(fieldExprs))
// No matches.
case Seq() =>
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala
index 963c782091..f8e5916d69 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala
@@ -242,6 +242,8 @@ trait GroupingAnalytics extends UnaryNode {
def aggregations: Seq[NamedExpression]
override def output: Seq[Attribute] = aggregations.map(_.toAttribute)
+
+ def withNewAggs(aggs: Seq[NamedExpression]): GroupingAnalytics
}
/**
@@ -266,7 +268,11 @@ case class GroupingSets(
groupByExprs: Seq[Expression],
child: LogicalPlan,
aggregations: Seq[NamedExpression],
- gid: AttributeReference = VirtualColumn.newGroupingId) extends GroupingAnalytics
+ gid: AttributeReference = VirtualColumn.newGroupingId) extends GroupingAnalytics {
+
+ def withNewAggs(aggs: Seq[NamedExpression]): GroupingAnalytics =
+ this.copy(aggregations = aggs)
+}
/**
* Cube is a syntactic sugar for GROUPING SETS, and will be transformed to GroupingSets,
@@ -284,7 +290,11 @@ case class Cube(
groupByExprs: Seq[Expression],
child: LogicalPlan,
aggregations: Seq[NamedExpression],
- gid: AttributeReference = VirtualColumn.newGroupingId) extends GroupingAnalytics
+ gid: AttributeReference = VirtualColumn.newGroupingId) extends GroupingAnalytics {
+
+ def withNewAggs(aggs: Seq[NamedExpression]): GroupingAnalytics =
+ this.copy(aggregations = aggs)
+}
/**
* Rollup is a syntactic sugar for GROUPING SETS, and will be transformed to GroupingSets,
@@ -303,7 +313,11 @@ case class Rollup(
groupByExprs: Seq[Expression],
child: LogicalPlan,
aggregations: Seq[NamedExpression],
- gid: AttributeReference = VirtualColumn.newGroupingId) extends GroupingAnalytics
+ gid: AttributeReference = VirtualColumn.newGroupingId) extends GroupingAnalytics {
+
+ def withNewAggs(aggs: Seq[NamedExpression]): GroupingAnalytics =
+ this.copy(aggregations = aggs)
+}
case class Limit(limitExpr: Expression, child: LogicalPlan) extends UnaryNode {
override def output: Seq[Attribute] = child.output
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/Column.scala b/sql/core/src/main/scala/org/apache/spark/sql/Column.scala
index b4e008a6e8..f201c8ea8a 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/Column.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/Column.scala
@@ -21,7 +21,6 @@ import scala.language.implicitConversions
import org.apache.spark.annotation.Experimental
import org.apache.spark.Logging
-import org.apache.spark.sql.expressions.Window
import org.apache.spark.sql.functions.lit
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.analysis._
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala b/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala
index 466258e76f..492a3321bc 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala
@@ -32,7 +32,7 @@ import org.apache.spark.annotation.{DeveloperApi, Experimental}
import org.apache.spark.api.java.JavaRDD
import org.apache.spark.api.python.SerDeUtil
import org.apache.spark.rdd.RDD
-import org.apache.spark.sql.catalyst.analysis.{MultiAlias, ResolvedStar, UnresolvedAttribute}
+import org.apache.spark.sql.catalyst.analysis._
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.plans.logical.{Filter, _}
import org.apache.spark.sql.catalyst.plans.{Inner, JoinType}
@@ -629,6 +629,10 @@ class DataFrame private[sql](
@scala.annotation.varargs
def select(cols: Column*): DataFrame = {
val namedExpressions = cols.map {
+      // Wrap UnresolvedAttribute in UnresolvedAlias: when an UnresolvedAttribute is resolved, the
+      // intermediate Alias around an ExtractValue chain is removed, so the result must be aliased
+      // again to make it a NamedExpression.
+ case Column(u: UnresolvedAttribute) => UnresolvedAlias(u)
case Column(expr: NamedExpression) => expr
// Leave an unaliased explode with an empty list of names since the analyzer will generate the
// correct defaults after the nested expression's type has been resolved.
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/GroupedData.scala b/sql/core/src/main/scala/org/apache/spark/sql/GroupedData.scala
index 45b3e1bc62..99d557b03a 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/GroupedData.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/GroupedData.scala
@@ -21,7 +21,7 @@ import scala.collection.JavaConversions._
import scala.language.implicitConversions
import org.apache.spark.annotation.Experimental
-import org.apache.spark.sql.catalyst.analysis.Star
+import org.apache.spark.sql.catalyst.analysis.{UnresolvedAlias, UnresolvedAttribute, Star}
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.plans.logical.{Rollup, Cube, Aggregate}
import org.apache.spark.sql.types.NumericType
@@ -70,27 +70,31 @@ class GroupedData protected[sql](
groupingExprs: Seq[Expression],
private val groupType: GroupedData.GroupType) {
- private[this] def toDF(aggExprs: Seq[NamedExpression]): DataFrame = {
+ private[this] def toDF(aggExprs: Seq[Expression]): DataFrame = {
val aggregates = if (df.sqlContext.conf.dataFrameRetainGroupColumns) {
- val retainedExprs = groupingExprs.map {
- case expr: NamedExpression => expr
- case expr: Expression => Alias(expr, expr.prettyString)()
- }
- retainedExprs ++ aggExprs
- } else {
- aggExprs
- }
+ groupingExprs ++ aggExprs
+ } else {
+ aggExprs
+ }
+ val aliasedAgg = aggregates.map {
+      // Wrap UnresolvedAttribute in UnresolvedAlias: when an UnresolvedAttribute is resolved, the
+      // intermediate Alias around an ExtractValue chain is removed, so the result must be aliased
+      // again to make it a NamedExpression.
+ case u: UnresolvedAttribute => UnresolvedAlias(u)
+ case expr: NamedExpression => expr
+ case expr: Expression => Alias(expr, expr.prettyString)()
+ }
groupType match {
case GroupedData.GroupByType =>
DataFrame(
- df.sqlContext, Aggregate(groupingExprs, aggregates, df.logicalPlan))
+ df.sqlContext, Aggregate(groupingExprs, aliasedAgg, df.logicalPlan))
case GroupedData.RollupType =>
DataFrame(
- df.sqlContext, Rollup(groupingExprs, df.logicalPlan, aggregates))
+ df.sqlContext, Rollup(groupingExprs, df.logicalPlan, aliasedAgg))
case GroupedData.CubeType =>
DataFrame(
- df.sqlContext, Cube(groupingExprs, df.logicalPlan, aggregates))
+ df.sqlContext, Cube(groupingExprs, df.logicalPlan, aliasedAgg))
}
}
@@ -112,10 +116,7 @@ class GroupedData protected[sql](
namedExpr
}
}
- toDF(columnExprs.map { c =>
- val a = f(c)
- Alias(a, a.prettyString)()
- })
+ toDF(columnExprs.map(f))
}
private[this] def strToExpr(expr: String): (Expression => Expression) = {
@@ -169,8 +170,7 @@ class GroupedData protected[sql](
*/
def agg(exprs: Map[String, String]): DataFrame = {
toDF(exprs.map { case (colName, expr) =>
- val a = strToExpr(expr)(df(colName).expr)
- Alias(a, a.prettyString)()
+ strToExpr(expr)(df(colName).expr)
}.toSeq)
}
@@ -224,10 +224,7 @@ class GroupedData protected[sql](
*/
@scala.annotation.varargs
def agg(expr: Column, exprs: Column*): DataFrame = {
- toDF((expr +: exprs).map(_.expr).map {
- case expr: NamedExpression => expr
- case expr: Expression => Alias(expr, expr.prettyString)()
- })
+ toDF((expr +: exprs).map(_.expr))
}
/**
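A rough usage sketch of the reworked toDF (assuming the illustrative sqlContext and src temp table from the earlier sketch): grouping columns are passed through unchanged, plain column references are wrapped in UnresolvedAlias, and other aggregate expressions are aliased by their prettyString.

import org.apache.spark.sql.functions.max

// Assumes sqlContext and the temp table "src" (columns key, value) from the sketch above.
val grouped = sqlContext.table("src").groupBy("key").agg(max("value"))
println(grouped.columns.toSeq)  // e.g. List(key, MAX(value)) -- the aggregate keeps its prettyString name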
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/pythonUdfs.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/pythonUdfs.scala
index 1ce150ceaf..c8c67ce334 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/pythonUdfs.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/pythonUdfs.scala
@@ -74,7 +74,7 @@ private[spark] object ExtractPythonUdfs extends Rule[LogicalPlan] {
// Skip EvaluatePython nodes.
case plan: EvaluatePython => plan
- case plan: LogicalPlan =>
+ case plan: LogicalPlan if plan.resolved =>
// Extract any PythonUDFs from the current operator.
val udfs = plan.expressions.flatMap(_.collect { case udf: PythonUDF => udf })
if (udfs.isEmpty) {
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala
index 4441afd6bd..73bc6c9991 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala
@@ -1367,9 +1367,9 @@ class SQLQuerySuite extends QueryTest with BeforeAndAfterAll with SQLTestUtils {
test("SPARK-6145: special cases") {
sqlContext.read.json(sqlContext.sparkContext.makeRDD(
- """{"a": {"b": [1]}, "b": [{"a": 1}], "c0": {"a": 1}}""" :: Nil)).registerTempTable("t")
- checkAnswer(sql("SELECT a.b[0] FROM t ORDER BY c0.a"), Row(1))
- checkAnswer(sql("SELECT b[0].a FROM t ORDER BY c0.a"), Row(1))
+ """{"a": {"b": [1]}, "b": [{"a": 1}], "_c0": {"a": 1}}""" :: Nil)).registerTempTable("t")
+ checkAnswer(sql("SELECT a.b[0] FROM t ORDER BY _c0.a"), Row(1))
+ checkAnswer(sql("SELECT b[0].a FROM t ORDER BY _c0.a"), Row(1))
}
test("SPARK-6898: complete support for special chars in column names") {
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/TestData.scala b/sql/core/src/test/scala/org/apache/spark/sql/TestData.scala
index 520a862ea0..207d7a352c 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/TestData.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/TestData.scala
@@ -19,7 +19,6 @@ package org.apache.spark.sql
import java.sql.Timestamp
-import org.apache.spark.sql.catalyst.plans.logical
import org.apache.spark.sql.test.TestSQLContext.implicits._
import org.apache.spark.sql.test._
diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveQl.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveQl.scala
index ca4b80b51b..7c4620952b 100644
--- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveQl.scala
+++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveQl.scala
@@ -415,13 +415,6 @@ private[hive] object HiveQl {
throw new NotImplementedError(s"No parse rules for StructField:\n ${dumpTree(a).toString} ")
}
- protected def nameExpressions(exprs: Seq[Expression]): Seq[NamedExpression] = {
- exprs.zipWithIndex.map {
- case (ne: NamedExpression, _) => ne
- case (e, i) => Alias(e, s"_c$i")()
- }
- }
-
protected def extractDbNameTableName(tableNameParts: Node): (Option[String], String) = {
val (db, tableName) =
tableNameParts.getChildren.map { case Token(part, Nil) => cleanIdentifier(part) } match {
@@ -942,7 +935,7 @@ https://cwiki.apache.org/confluence/display/Hive/Enhanced+Aggregation%2C+Cube%2C
// (if there is a group by) or a script transformation.
val withProject: LogicalPlan = transformation.getOrElse {
val selectExpressions =
- nameExpressions(select.getChildren.flatMap(selExprNodeToExpr).toSeq)
+ select.getChildren.flatMap(selExprNodeToExpr).map(UnresolvedAlias(_)).toSeq
Seq(
groupByClause.map(e => e match {
case Token("TOK_GROUPBY", children) =>