author    Wenchen Fan <cloud0fan@outlook.com>    2015-08-14 20:59:54 -0700
committer Reynold Xin <rxin@databricks.com>      2015-08-14 20:59:54 -0700
commit    ec29f2034a3306cc0afdc4c160b42c2eefa0897c (patch)
tree      ab85cd0a650c7f804fd2af7768156ce4f318aab9
parent    37586e5449ff8f892d41f0b6b8fa1de83dd3849e (diff)
[SPARK-9634] [SPARK-9323] [SQL] cleanup unnecessary Aliases in LogicalPlan at the end of analysis
Also alias the ExtractValue instead of wrapping it with UnresolvedAlias when resolving attributes in LogicalPlan, as this alias will be trimmed if it is unnecessary.

Based on #7957 without the changes to mllib, but instead maintaining earlier behavior when using `withColumn` on expressions that already have metadata.

Author: Wenchen Fan <cloud0fan@outlook.com>
Author: Michael Armbrust <michael@databricks.com>

Closes #8215 from marmbrus/pr/7957.
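For a concrete feel of the cleanup, here is a minimal sketch (illustrative only; it assumes a 1.5-era `sqlContext` and the usual implicits, none of which are part of this patch):

    import sqlContext.implicits._

    val df = sqlContext.range(3).toDF("a")
    // The inner alias "a+1" carries no information once the expression is
    // nested inside another one; after this patch the analyzed plan keeps
    // only the top-level alias "col".
    val out = df.select((($"a" + 1).as("a+1") + 2).as("col"))
    out.explain(true)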
-rw-r--r--  sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala              | 75
-rw-r--r--  sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeCreator.scala |  2
-rw-r--r--  sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala            |  9
-rw-r--r--  sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LogicalPlan.scala      |  8
-rw-r--r--  sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala   |  2
-rw-r--r--  sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisSuite.scala         | 17
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/Column.scala                                      | 16
-rw-r--r--  sql/core/src/test/scala/org/apache/spark/sql/ColumnExpressionSuite.scala                       |  9
-rw-r--r--  sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala                              |  6
9 files changed, 120 insertions(+), 24 deletions(-)
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
index a684dbc3af..4bc1c1af40 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
@@ -82,7 +82,9 @@ class Analyzer(
HiveTypeCoercion.typeCoercionRules ++
extendedResolutionRules : _*),
Batch("Nondeterministic", Once,
- PullOutNondeterministic)
+ PullOutNondeterministic),
+ Batch("Cleanup", fixedPoint,
+ CleanupAliases)
)
/**
@@ -146,8 +148,6 @@ class Analyzer(
child match {
case _: UnresolvedAttribute => u
case ne: NamedExpression => ne
- case g: GetStructField => Alias(g, g.field.name)()
- case g: GetArrayStructFields => Alias(g, g.field.name)()
case g: Generator if g.resolved && g.elementTypes.size > 1 => MultiAlias(g, Nil)
case e if !e.resolved => u
case other => Alias(other, s"_c$i")()
@@ -384,9 +384,7 @@ class Analyzer(
case u @ UnresolvedAttribute(nameParts) =>
// Leave unchanged if resolution fails. Hopefully will be resolved next round.
val result =
- withPosition(u) {
- q.resolveChildren(nameParts, resolver).map(trimUnresolvedAlias).getOrElse(u)
- }
+ withPosition(u) { q.resolveChildren(nameParts, resolver).getOrElse(u) }
logDebug(s"Resolving $u to $result")
result
case UnresolvedExtractValue(child, fieldExpr) if child.resolved =>
@@ -412,11 +410,6 @@ class Analyzer(
exprs.exists(_.collect { case _: Star => true }.nonEmpty)
}
- private def trimUnresolvedAlias(ne: NamedExpression) = ne match {
- case UnresolvedAlias(child) => child
- case other => other
- }
-
private def resolveSortOrders(ordering: Seq[SortOrder], plan: LogicalPlan, throws: Boolean) = {
ordering.map { order =>
// Resolve SortOrder in one round.
@@ -426,7 +419,7 @@ class Analyzer(
try {
val newOrder = order transformUp {
case u @ UnresolvedAttribute(nameParts) =>
- plan.resolve(nameParts, resolver).map(trimUnresolvedAlias).getOrElse(u)
+ plan.resolve(nameParts, resolver).getOrElse(u)
case UnresolvedExtractValue(child, fieldName) if child.resolved =>
ExtractValue(child, fieldName, resolver)
}
@@ -968,3 +961,61 @@ object EliminateSubQueries extends Rule[LogicalPlan] {
case Subquery(_, child) => child
}
}
+
+/**
+ * Cleans up unnecessary Aliases inside the plan. Basically we only need an Alias as a top-level
+ * expression in Project (project list), Aggregate (aggregate expressions), or
+ * Window (window expressions).
+ */
+object CleanupAliases extends Rule[LogicalPlan] {
+ private def trimAliases(e: Expression): Expression = {
+ var stop = false
+ e.transformDown {
+ // CreateStruct is a special case: we need to retain its top-level Aliases, as they decide the
+ // names of the StructFields. We also need to stop transforming down into this expression, or
+ // the Aliases under CreateStruct will be mistakenly trimmed.
+ case c: CreateStruct if !stop =>
+ stop = true
+ c.copy(children = c.children.map(trimNonTopLevelAliases))
+ case c: CreateStructUnsafe if !stop =>
+ stop = true
+ c.copy(children = c.children.map(trimNonTopLevelAliases))
+ case Alias(child, _) if !stop => child
+ }
+ }
+
+ def trimNonTopLevelAliases(e: Expression): Expression = e match {
+ case a: Alias =>
+ Alias(trimAliases(a.child), a.name)(a.exprId, a.qualifiers, a.explicitMetadata)
+ case other => trimAliases(other)
+ }
+
+ override def apply(plan: LogicalPlan): LogicalPlan = plan resolveOperators {
+ case Project(projectList, child) =>
+ val cleanedProjectList =
+ projectList.map(trimNonTopLevelAliases(_).asInstanceOf[NamedExpression])
+ Project(cleanedProjectList, child)
+
+ case Aggregate(grouping, aggs, child) =>
+ val cleanedAggs = aggs.map(trimNonTopLevelAliases(_).asInstanceOf[NamedExpression])
+ Aggregate(grouping.map(trimAliases), cleanedAggs, child)
+
+ case w @ Window(projectList, windowExprs, partitionSpec, orderSpec, child) =>
+ val cleanedWindowExprs =
+ windowExprs.map(e => trimNonTopLevelAliases(e).asInstanceOf[NamedExpression])
+ Window(projectList, cleanedWindowExprs, partitionSpec.map(trimAliases),
+ orderSpec.map(trimAliases(_).asInstanceOf[SortOrder]), child)
+
+ case other =>
+ var stop = false
+ other transformExpressionsDown {
+ case c: CreateStruct if !stop =>
+ stop = true
+ c.copy(children = c.children.map(trimNonTopLevelAliases))
+ case c: CreateStructUnsafe if !stop =>
+ stop = true
+ c.copy(children = c.children.map(trimNonTopLevelAliases))
+ case Alias(child, _) if !stop => child
+ }
+ }
+}
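As a rough sketch of what `trimNonTopLevelAliases` does to a raw Catalyst expression (illustrative; `a` is assumed to be an already-resolved AttributeReference, e.g. from a test relation):

    // The inner Alias "a1" carries no information once analysis is complete.
    val nested = Alias(Add(Alias(a, "a1")(), Literal(2)), "col")()
    // Trimming keeps the top-level Alias (same name, exprId, and metadata)
    // but strips the one buried inside the Add:
    val trimmed = CleanupAliases.trimNonTopLevelAliases(nested)
    // trimmed is Alias(Add(a, Literal(2)), "col") with the original exprId.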
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeCreator.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeCreator.scala
index 4a071e663e..298aee3499 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeCreator.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeCreator.scala
@@ -75,8 +75,6 @@ case class CreateStruct(children: Seq[Expression]) extends Expression {
override def foldable: Boolean = children.forall(_.foldable)
- override lazy val resolved: Boolean = childrenResolved
-
override lazy val dataType: StructType = {
val fields = children.zipWithIndex.map { case (child, idx) =>
child match {
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala
index 4ab5ac2c61..47b06cae15 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala
@@ -18,7 +18,7 @@
package org.apache.spark.sql.catalyst.optimizer
import scala.collection.immutable.HashSet
-import org.apache.spark.sql.catalyst.analysis.EliminateSubQueries
+import org.apache.spark.sql.catalyst.analysis.{CleanupAliases, EliminateSubQueries}
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.plans.Inner
import org.apache.spark.sql.catalyst.plans.FullOuter
@@ -260,8 +260,11 @@ object ProjectCollapsing extends Rule[LogicalPlan] {
val substitutedProjection = projectList1.map(_.transform {
case a: Attribute => aliasMap.getOrElse(a, a)
}).asInstanceOf[Seq[NamedExpression]]
-
- Project(substitutedProjection, child)
+ // Collapsing two Projects may introduce unnecessary Aliases; trim them here.
+ val cleanedProjection = substitutedProjection.map(p =>
+ CleanupAliases.trimNonTopLevelAliases(p).asInstanceOf[NamedExpression]
+ )
+ Project(cleanedProjection, child)
}
}
}
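To see why the trim is needed here, consider two stacked projections written in the Catalyst test DSL (sketch only; the DSL imports and `testRelation` are assumed):

    val stacked = testRelation.select(('a + 1).as("b")).select(('b + 1).as("c"))
    // Substituting b with its definition while collapsing yields
    //   ((a + 1) AS b + 1) AS c
    // i.e. an Alias nested inside c's child expression;
    // trimNonTopLevelAliases reduces it to ((a + 1) + 1) AS c.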
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LogicalPlan.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LogicalPlan.scala
index c290e6acb3..9bb466ac2d 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LogicalPlan.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LogicalPlan.scala
@@ -259,13 +259,13 @@ abstract class LogicalPlan extends QueryPlan[LogicalPlan] with Logging {
// One match, but we also need to extract the requested nested field.
case Seq((a, nestedFields)) =>
// The foldLeft adds ExtractValues for every remaining part of the identifier,
- // and wrap it with UnresolvedAlias which will be removed later.
+ // and aliases the result with the last part of the name.
// For example, consider "a.b.c", where "a" is resolved to an existing attribute.
- // Then this will add ExtractValue("c", ExtractValue("b", a)), and wrap it as
- // UnresolvedAlias(ExtractValue("c", ExtractValue("b", a))).
+ // Then this will add ExtractValue("c", ExtractValue("b", a)), and alias the final
+ // expression as "c".
val fieldExprs = nestedFields.foldLeft(a: Expression)((expr, fieldName) =>
ExtractValue(expr, Literal(fieldName), resolver))
- Some(UnresolvedAlias(fieldExprs))
+ Some(Alias(fieldExprs, nestedFields.last)())
// No matches.
case Seq() =>
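A worked sketch of the new resolution result for the name "a.b.c", assuming "a" resolves to an attribute `a` and the remaining nested fields are Seq("b", "c"):

    val fieldExprs = Seq("b", "c").foldLeft(a: Expression) { (expr, fieldName) =>
      ExtractValue(expr, Literal(fieldName), resolver)
    }
    // fieldExprs is ExtractValue(ExtractValue(a, "b"), "c"); it is now wrapped
    // in a real Alias named after the last part, not an UnresolvedAlias:
    val named = Alias(fieldExprs, "c")()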
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala
index 7c404722d8..73b8261260 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala
@@ -228,7 +228,7 @@ case class Window(
child: LogicalPlan) extends UnaryNode {
override def output: Seq[Attribute] =
- (projectList ++ windowExpressions).map(_.toAttribute)
+ projectList ++ windowExpressions.map(_.toAttribute)
}
/**
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisSuite.scala
index c944bc69e2..1e0cc81dae 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisSuite.scala
@@ -119,4 +119,21 @@ class AnalysisSuite extends AnalysisTest {
Project(testRelation.output :+ projected, testRelation)))
checkAnalysis(plan, expected)
}
+
+ test("SPARK-9634: cleanup unnecessary Aliases in LogicalPlan") {
+ val a = testRelation.output.head
+ var plan = testRelation.select(((a + 1).as("a+1") + 2).as("col"))
+ var expected = testRelation.select((a + 1 + 2).as("col"))
+ checkAnalysis(plan, expected)
+
+ plan = testRelation.groupBy(a.as("a1").as("a2"))((min(a).as("min_a") + 1).as("col"))
+ expected = testRelation.groupBy(a)((min(a) + 1).as("col"))
+ checkAnalysis(plan, expected)
+
+ // CreateStruct is a special case: we should not trim the Aliases inside it.
+ plan = testRelation.select(CreateStruct(Seq(a, (a + 1).as("a+1"))).as("col"))
+ checkAnalysis(plan, plan)
+ plan = testRelation.select(CreateStructUnsafe(Seq(a, (a + 1).as("a+1"))).as("col"))
+ checkAnalysis(plan, plan)
+ }
}
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/Column.scala b/sql/core/src/main/scala/org/apache/spark/sql/Column.scala
index 27bd084847..807bc8c30c 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/Column.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/Column.scala
@@ -753,10 +753,16 @@ class Column(protected[sql] val expr: Expression) extends Logging {
* df.select($"colA".as("colB"))
* }}}
*
+ * If the current column has metadata associated with it, this metadata will be propagated
+ * to the new column. If this is not desired, use `as` with explicitly empty metadata.
+ *
* @group expr_ops
* @since 1.3.0
*/
- def as(alias: String): Column = Alias(expr, alias)()
+ def as(alias: String): Column = expr match {
+ case ne: NamedExpression => Alias(expr, alias)(explicitMetadata = Some(ne.metadata))
+ case other => Alias(other, alias)()
+ }
/**
* (Scala-specific) Assigns the given aliases to the results of a table generating function.
@@ -789,10 +795,16 @@ class Column(protected[sql] val expr: Expression) extends Logging {
* df.select($"colA".as('colB))
* }}}
*
+ * If the current column has metadata associated with it, this metadata will be propagated
+ * to the new column. If this is not desired, use `as` with explicitly empty metadata.
+ *
* @group expr_ops
* @since 1.3.0
*/
- def as(alias: Symbol): Column = Alias(expr, alias.name)()
+ def as(alias: Symbol): Column = expr match {
+ case ne: NamedExpression => Alias(expr, alias.name)(explicitMetadata = Some(ne.metadata))
+ case other => Alias(other, alias.name)()
+ }
/**
* Gives the column an alias with metadata.
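A hedged usage sketch of the metadata propagation (the DataFrame `df` and its column names are assumed, not part of this patch):

    import org.apache.spark.sql.types.MetadataBuilder

    val md = new MetadataBuilder().putString("comment", "original").build()
    val aliased = df("a").as("b", md)  // attach metadata explicitly
    aliased.as("c")                    // "comment" -> "original" is carried along
    // To drop the metadata instead, re-alias with explicitly empty metadata:
    aliased.as("c", new MetadataBuilder().build())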
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/ColumnExpressionSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/ColumnExpressionSuite.scala
index ee74e3e83d..37738ec5b3 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/ColumnExpressionSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/ColumnExpressionSuite.scala
@@ -17,6 +17,7 @@
package org.apache.spark.sql
+import org.apache.spark.sql.catalyst.expressions.NamedExpression
import org.scalatest.Matchers._
import org.apache.spark.sql.execution.{Project, TungstenProject}
@@ -110,6 +111,14 @@ class ColumnExpressionSuite extends QueryTest with SharedSQLContext {
assert(df.select(df("a").alias("b")).columns.head === "b")
}
+ test("as propagates metadata") {
+ val metadata = new MetadataBuilder
+ metadata.putString("key", "value")
+ val origCol = $"a".as("b", metadata.build())
+ val newCol = origCol.as("c")
+ assert(newCol.expr.asInstanceOf[NamedExpression].metadata.getString("key") === "value")
+ }
+
test("single explode") {
val df = Seq((1, Seq(1, 2, 3))).toDF("a", "intList")
checkAnswer(
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala
index 10bfa9b64f..cf22797752 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala
@@ -867,4 +867,10 @@ class DataFrameSuite extends QueryTest with SharedSQLContext {
val actual = df.sort(rand(seed)).collect().map(_.getInt(0))
assert(expected === actual)
}
+
+ test("SPARK-9323: DataFrame.orderBy should support nested column name") {
+ val df = sqlContext.read.json(sqlContext.sparkContext.makeRDD(
+ """{"a": {"b": 1}}""" :: Nil))
+ checkAnswer(df.orderBy("a.b"), Row(Row(1)))
+ }
}