-rw-r--r--  sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala              | 75
-rw-r--r--  sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeCreator.scala |  2
-rw-r--r--  sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala            |  9
-rw-r--r--  sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LogicalPlan.scala      |  8
-rw-r--r--  sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala   |  2
-rw-r--r--  sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisSuite.scala         | 17
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/Column.scala                                      | 16
-rw-r--r--  sql/core/src/test/scala/org/apache/spark/sql/ColumnExpressionSuite.scala                       |  9
-rw-r--r--  sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala                              |  6
9 files changed, 120 insertions(+), 24 deletions(-)
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
index a684dbc3af..4bc1c1af40 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
@@ -82,7 +82,9 @@ class Analyzer(
HiveTypeCoercion.typeCoercionRules ++
extendedResolutionRules : _*),
Batch("Nondeterministic", Once,
- PullOutNondeterministic)
+ PullOutNondeterministic),
+ Batch("Cleanup", fixedPoint,
+ CleanupAliases)
)
/**
@@ -146,8 +148,6 @@ class Analyzer(
child match {
case _: UnresolvedAttribute => u
case ne: NamedExpression => ne
- case g: GetStructField => Alias(g, g.field.name)()
- case g: GetArrayStructFields => Alias(g, g.field.name)()
case g: Generator if g.resolved && g.elementTypes.size > 1 => MultiAlias(g, Nil)
case e if !e.resolved => u
case other => Alias(other, s"_c$i")()
@@ -384,9 +384,7 @@ class Analyzer(
case u @ UnresolvedAttribute(nameParts) =>
// Leave unchanged if resolution fails. Hopefully will be resolved next round.
val result =
- withPosition(u) {
- q.resolveChildren(nameParts, resolver).map(trimUnresolvedAlias).getOrElse(u)
- }
+ withPosition(u) { q.resolveChildren(nameParts, resolver).getOrElse(u) }
logDebug(s"Resolving $u to $result")
result
case UnresolvedExtractValue(child, fieldExpr) if child.resolved =>
@@ -412,11 +410,6 @@ class Analyzer(
exprs.exists(_.collect { case _: Star => true }.nonEmpty)
}
- private def trimUnresolvedAlias(ne: NamedExpression) = ne match {
- case UnresolvedAlias(child) => child
- case other => other
- }
-
private def resolveSortOrders(ordering: Seq[SortOrder], plan: LogicalPlan, throws: Boolean) = {
ordering.map { order =>
// Resolve SortOrder in one round.
@@ -426,7 +419,7 @@ class Analyzer(
try {
val newOrder = order transformUp {
case u @ UnresolvedAttribute(nameParts) =>
- plan.resolve(nameParts, resolver).map(trimUnresolvedAlias).getOrElse(u)
+ plan.resolve(nameParts, resolver).getOrElse(u)
case UnresolvedExtractValue(child, fieldName) if child.resolved =>
ExtractValue(child, fieldName, resolver)
}
@@ -968,3 +961,61 @@ object EliminateSubQueries extends Rule[LogicalPlan] {
case Subquery(_, child) => child
}
}
+
+/**
+ * Cleans up unnecessary Aliases inside the plan. Basically, we only need an Alias as a
+ * top-level expression in a Project (project list), an Aggregate (aggregate expressions),
+ * or a Window (window expressions).
+ */
+object CleanupAliases extends Rule[LogicalPlan] {
+ private def trimAliases(e: Expression): Expression = {
+ var stop = false
+ e.transformDown {
+      // CreateStruct is a special case: we need to retain its top-level Aliases, as they
+      // decide the names of the StructFields. We also need to stop transforming down into
+      // this expression, or the Aliases under CreateStruct would be mistakenly trimmed.
+ case c: CreateStruct if !stop =>
+ stop = true
+ c.copy(children = c.children.map(trimNonTopLevelAliases))
+ case c: CreateStructUnsafe if !stop =>
+ stop = true
+ c.copy(children = c.children.map(trimNonTopLevelAliases))
+ case Alias(child, _) if !stop => child
+ }
+ }
+
+ def trimNonTopLevelAliases(e: Expression): Expression = e match {
+ case a: Alias =>
+ Alias(trimAliases(a.child), a.name)(a.exprId, a.qualifiers, a.explicitMetadata)
+ case other => trimAliases(other)
+ }
+
+ override def apply(plan: LogicalPlan): LogicalPlan = plan resolveOperators {
+ case Project(projectList, child) =>
+ val cleanedProjectList =
+ projectList.map(trimNonTopLevelAliases(_).asInstanceOf[NamedExpression])
+ Project(cleanedProjectList, child)
+
+ case Aggregate(grouping, aggs, child) =>
+ val cleanedAggs = aggs.map(trimNonTopLevelAliases(_).asInstanceOf[NamedExpression])
+ Aggregate(grouping.map(trimAliases), cleanedAggs, child)
+
+ case w @ Window(projectList, windowExprs, partitionSpec, orderSpec, child) =>
+ val cleanedWindowExprs =
+ windowExprs.map(e => trimNonTopLevelAliases(e).asInstanceOf[NamedExpression])
+ Window(projectList, cleanedWindowExprs, partitionSpec.map(trimAliases),
+ orderSpec.map(trimAliases(_).asInstanceOf[SortOrder]), child)
+
+ case other =>
+ var stop = false
+ other transformExpressionsDown {
+ case c: CreateStruct if !stop =>
+ stop = true
+ c.copy(children = c.children.map(trimNonTopLevelAliases))
+ case c: CreateStructUnsafe if !stop =>
+ stop = true
+ c.copy(children = c.children.map(trimNonTopLevelAliases))
+ case Alias(child, _) if !stop => child
+ }
+ }
+}
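For intuition, here is a minimal sketch of the rule's effect using the Catalyst DSL, mirroring the new AnalysisSuite test further down; `LocalRelation('a.int)` is an illustrative stand-in for any resolved relation:

    import org.apache.spark.sql.catalyst.dsl.expressions._
    import org.apache.spark.sql.catalyst.dsl.plans._
    import org.apache.spark.sql.catalyst.plans.logical.LocalRelation

    val rel = LocalRelation('a.int)
    val a = rel.output.head

    // Before cleanup: Project [((a AS `a+1`) + 2) AS col] -- once resolved, the
    // inner Alias carries no information the executed plan needs.
    val plan = rel.select(((a + 1).as("a+1") + 2).as("col"))

    // After CleanupAliases, the plan is equivalent to:
    val cleaned = rel.select((a + 1 + 2).as("col"))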
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeCreator.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeCreator.scala
index 4a071e663e..298aee3499 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeCreator.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeCreator.scala
@@ -75,8 +75,6 @@ case class CreateStruct(children: Seq[Expression]) extends Expression {
override def foldable: Boolean = children.forall(_.foldable)
- override lazy val resolved: Boolean = childrenResolved
-
override lazy val dataType: StructType = {
val fields = children.zipWithIndex.map { case (child, idx) =>
child match {
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala
index 4ab5ac2c61..47b06cae15 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala
@@ -18,7 +18,7 @@
package org.apache.spark.sql.catalyst.optimizer
import scala.collection.immutable.HashSet
-import org.apache.spark.sql.catalyst.analysis.EliminateSubQueries
+import org.apache.spark.sql.catalyst.analysis.{CleanupAliases, EliminateSubQueries}
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.plans.Inner
import org.apache.spark.sql.catalyst.plans.FullOuter
@@ -260,8 +260,11 @@ object ProjectCollapsing extends Rule[LogicalPlan] {
val substitutedProjection = projectList1.map(_.transform {
case a: Attribute => aliasMap.getOrElse(a, a)
}).asInstanceOf[Seq[NamedExpression]]
-
- Project(substitutedProjection, child)
+    // Collapsing two Projects may introduce unnecessary Aliases; trim them here.
+ val cleanedProjection = substitutedProjection.map(p =>
+ CleanupAliases.trimNonTopLevelAliases(p).asInstanceOf[NamedExpression]
+ )
+ Project(cleanedProjection, child)
}
}
}
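A hedged before/after for the ProjectCollapsing concern (plan shapes are illustrative, not verbatim optimizer output):

    // Project [(a1 + 1) AS b]               Project [((a AS a1) + 1) AS b]
    //   +- Project [a AS a1]        ==>       +- relation
    //        +- relation
    //
    // Substituting a1 with its definition (a AS a1) buries an Alias inside the
    // arithmetic; trimNonTopLevelAliases reduces this to Project [(a + 1) AS b].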
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LogicalPlan.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LogicalPlan.scala
index c290e6acb3..9bb466ac2d 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LogicalPlan.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LogicalPlan.scala
@@ -259,13 +259,13 @@ abstract class LogicalPlan extends QueryPlan[LogicalPlan] with Logging {
// One match, but we also need to extract the requested nested field.
case Seq((a, nestedFields)) =>
// The foldLeft adds ExtractValues for every remaining part of the identifier,
-      // and wrap it with UnresolvedAlias which will be removed later.
+      // and aliases it with the last part of the name.
// For example, consider "a.b.c", where "a" is resolved to an existing attribute.
- // Then this will add ExtractValue("c", ExtractValue("b", a)), and wrap it as
- // UnresolvedAlias(ExtractValue("c", ExtractValue("b", a))).
+ // Then this will add ExtractValue("c", ExtractValue("b", a)), and alias the final
+ // expression as "c".
val fieldExprs = nestedFields.foldLeft(a: Expression)((expr, fieldName) =>
ExtractValue(expr, Literal(fieldName), resolver))
- Some(UnresolvedAlias(fieldExprs))
+ Some(Alias(fieldExprs, nestedFields.last)())
// No matches.
case Seq() =>
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala
index 7c404722d8..73b8261260 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala
@@ -228,7 +228,7 @@ case class Window(
child: LogicalPlan) extends UnaryNode {
override def output: Seq[Attribute] =
- (projectList ++ windowExpressions).map(_.toAttribute)
+ projectList ++ windowExpressions.map(_.toAttribute)
}
/**
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisSuite.scala
index c944bc69e2..1e0cc81dae 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisSuite.scala
@@ -119,4 +119,21 @@ class AnalysisSuite extends AnalysisTest {
Project(testRelation.output :+ projected, testRelation)))
checkAnalysis(plan, expected)
}
+
+ test("SPARK-9634: cleanup unnecessary Aliases in LogicalPlan") {
+ val a = testRelation.output.head
+ var plan = testRelation.select(((a + 1).as("a+1") + 2).as("col"))
+ var expected = testRelation.select((a + 1 + 2).as("col"))
+ checkAnalysis(plan, expected)
+
+ plan = testRelation.groupBy(a.as("a1").as("a2"))((min(a).as("min_a") + 1).as("col"))
+ expected = testRelation.groupBy(a)((min(a) + 1).as("col"))
+ checkAnalysis(plan, expected)
+
+    // CreateStruct is a special case: we should not trim the Aliases directly under it.
+ plan = testRelation.select(CreateStruct(Seq(a, (a + 1).as("a+1"))).as("col"))
+ checkAnalysis(plan, plan)
+ plan = testRelation.select(CreateStructUnsafe(Seq(a, (a + 1).as("a+1"))).as("col"))
+ checkAnalysis(plan, plan)
+ }
}
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/Column.scala b/sql/core/src/main/scala/org/apache/spark/sql/Column.scala
index 27bd084847..807bc8c30c 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/Column.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/Column.scala
@@ -753,10 +753,16 @@ class Column(protected[sql] val expr: Expression) extends Logging {
* df.select($"colA".as("colB"))
* }}}
*
+ * If the current column has metadata associated with it, this metadata will be propagated
+   * to the new column. If this is not desired, use `as` with explicitly empty metadata.
+ *
* @group expr_ops
* @since 1.3.0
*/
- def as(alias: String): Column = Alias(expr, alias)()
+ def as(alias: String): Column = expr match {
+ case ne: NamedExpression => Alias(expr, alias)(explicitMetadata = Some(ne.metadata))
+ case other => Alias(other, alias)()
+ }
/**
* (Scala-specific) Assigns the given aliases to the results of a table generating function.
@@ -789,10 +795,16 @@ class Column(protected[sql] val expr: Expression) extends Logging {
* df.select($"colA".as('colB))
* }}}
*
+ * If the current column has metadata associated with it, this metadata will be propagated
+   * to the new column. If this is not desired, use `as` with explicitly empty metadata.
+ *
* @group expr_ops
* @since 1.3.0
*/
- def as(alias: Symbol): Column = Alias(expr, alias.name)()
+ def as(alias: Symbol): Column = expr match {
+ case ne: NamedExpression => Alias(expr, alias.name)(explicitMetadata = Some(ne.metadata))
+ case other => Alias(other, alias.name)()
+ }
/**
* Gives the column an alias with metadata.
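A short usage sketch of the new propagation behavior (`df` is an assumed DataFrame with a column `a`; `as(alias, metadata)` is the existing metadata overload referenced in the trailing context above):

    import org.apache.spark.sql.types.{Metadata, MetadataBuilder}

    val md = new MetadataBuilder().putString("comment", "user id").build()
    val withMeta = df("a").as("b", md)  // attach metadata explicitly
    val renamed = withMeta.as("c")      // "comment" now travels to the new column
    // To drop the metadata instead, alias with explicitly empty metadata:
    val bare = withMeta.as("c", Metadata.empty)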
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/ColumnExpressionSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/ColumnExpressionSuite.scala
index ee74e3e83d..37738ec5b3 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/ColumnExpressionSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/ColumnExpressionSuite.scala
@@ -17,6 +17,7 @@
package org.apache.spark.sql
+import org.apache.spark.sql.catalyst.expressions.NamedExpression
import org.scalatest.Matchers._
import org.apache.spark.sql.execution.{Project, TungstenProject}
@@ -110,6 +111,14 @@ class ColumnExpressionSuite extends QueryTest with SharedSQLContext {
assert(df.select(df("a").alias("b")).columns.head === "b")
}
+ test("as propagates metadata") {
+ val metadata = new MetadataBuilder
+ metadata.putString("key", "value")
+ val origCol = $"a".as("b", metadata.build())
+ val newCol = origCol.as("c")
+ assert(newCol.expr.asInstanceOf[NamedExpression].metadata.getString("key") === "value")
+ }
+
test("single explode") {
val df = Seq((1, Seq(1, 2, 3))).toDF("a", "intList")
checkAnswer(
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala
index 10bfa9b64f..cf22797752 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala
@@ -867,4 +867,10 @@ class DataFrameSuite extends QueryTest with SharedSQLContext {
val actual = df.sort(rand(seed)).collect().map(_.getInt(0))
assert(expected === actual)
}
+
+ test("SPARK-9323: DataFrame.orderBy should support nested column name") {
+ val df = sqlContext.read.json(sqlContext.sparkContext.makeRDD(
+ """{"a": {"b": 1}}""" :: Nil))
+ checkAnswer(df.orderBy("a.b"), Row(Row(1)))
+ }
}
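End to end, the SPARK-9323 fix means a nested name passed to orderBy now resolves to an aliased ExtractValue rather than an UnresolvedAlias the sort could not handle; a minimal sketch, assuming a SQLContext named sqlContext:

    val df = sqlContext.read.json(sqlContext.sparkContext.makeRDD(
      """{"a": {"b": 1}}""" :: Nil))
    // "a.b" resolves to Alias(ExtractValue(a, "b"), "b")(), so the sort works.
    df.orderBy("a.b").show()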