aboutsummaryrefslogtreecommitdiff
path: root/sql/catalyst/src/main/scala
diff options
context:
space:
mode:
Diffstat (limited to 'sql/catalyst/src/main/scala')
-rw-r--r-- sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala 49
1 files changed, 31 insertions, 18 deletions
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala
index 08fb0199fc..c8e9d8e2f9 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala
@@ -165,36 +165,49 @@ object PushProjectThroughSample extends Rule[LogicalPlan] {
* but can also benefit other operators.
*/
object RemoveAliasOnlyProject extends Rule[LogicalPlan] {
- // Check if projectList in the Project node has the same attribute names and ordering
- // as its child node.
+ /**
+ * Returns true if the project list is semantically the same as the child output, after
+ * stripping aliases on attributes.
+ */
private def isAliasOnly(
projectList: Seq[NamedExpression],
childOutput: Seq[Attribute]): Boolean = {
- if (!projectList.forall(_.isInstanceOf[Alias]) || projectList.length != childOutput.length) {
+ if (projectList.length != childOutput.length) {
false
} else {
- projectList.map(_.asInstanceOf[Alias]).zip(childOutput).forall { case (a, o) =>
- a.child match {
- case attr: Attribute if a.name == attr.name && attr.semanticEquals(o) => true
- case _ => false
- }
+ stripAliasOnAttribute(projectList).zip(childOutput).forall {
+ case (a: Attribute, o) if a semanticEquals o => true
+ case _ => false
}
}
}
+ private def stripAliasOnAttribute(projectList: Seq[NamedExpression]) = {
+ projectList.map {
+ // An alias with metadata cannot be stripped, or the metadata will be lost.
+ // If the alias name is different from attribute name, we can't strip it either, or we may
+ // accidentally change the output schema name of the root plan.
+ case a @ Alias(attr: Attribute, name) if a.metadata == Metadata.empty && name == attr.name =>
+ attr
+ case other => other
+ }
+ }
+
def apply(plan: LogicalPlan): LogicalPlan = {
- val aliasOnlyProject = plan.find {
- case Project(pList, child) if isAliasOnly(pList, child.output) => true
- case _ => false
+ val aliasOnlyProject = plan.collectFirst {
+ case p @ Project(pList, child) if isAliasOnly(pList, child.output) => p
}
- aliasOnlyProject.map { case p: Project =>
- val aliases = p.projectList.map(_.asInstanceOf[Alias])
- val attrMap = AttributeMap(aliases.map(a => (a.toAttribute, a.child)))
- plan.transformAllExpressions {
- case a: Attribute if attrMap.contains(a) => attrMap(a)
- }.transform {
- case op: Project if op.eq(p) => op.child
+ aliasOnlyProject.map { case proj =>
+ val attributesToReplace = proj.output.zip(proj.child.output).filterNot {
+ case (a1, a2) => a1 semanticEquals a2
+ }
+ val attrMap = AttributeMap(attributesToReplace)
+ plan transform {
+ case plan: Project if plan eq proj => plan.child
+ case plan => plan transformExpressions {
+ case a: Attribute if attrMap.contains(a) => attrMap(a)
+ }
}
}.getOrElse(plan)
}