diff options
Diffstat (limited to 'sql/catalyst/src/main/scala')
-rw-r--r-- | sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala | 49 |
1 files changed, 31 insertions, 18 deletions
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala index 08fb0199fc..c8e9d8e2f9 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala @@ -165,36 +165,49 @@ object PushProjectThroughSample extends Rule[LogicalPlan] { * but can also benefit other operators. */ object RemoveAliasOnlyProject extends Rule[LogicalPlan] { - // Check if projectList in the Project node has the same attribute names and ordering - // as its child node. + /** + * Returns true if the project list is semantically same as child output, after strip alias on + * attribute. + */ private def isAliasOnly( projectList: Seq[NamedExpression], childOutput: Seq[Attribute]): Boolean = { - if (!projectList.forall(_.isInstanceOf[Alias]) || projectList.length != childOutput.length) { + if (projectList.length != childOutput.length) { false } else { - projectList.map(_.asInstanceOf[Alias]).zip(childOutput).forall { case (a, o) => - a.child match { - case attr: Attribute if a.name == attr.name && attr.semanticEquals(o) => true - case _ => false - } + stripAliasOnAttribute(projectList).zip(childOutput).forall { + case (a: Attribute, o) if a semanticEquals o => true + case _ => false } } } + private def stripAliasOnAttribute(projectList: Seq[NamedExpression]) = { + projectList.map { + // Alias with metadata can not be stripped, or the metadata will be lost. + // If the alias name is different from attribute name, we can't strip it either, or we may + // accidentally change the output schema name of the root plan. + case a @ Alias(attr: Attribute, name) if a.metadata == Metadata.empty && name == attr.name => + attr + case other => other + } + } + def apply(plan: LogicalPlan): LogicalPlan = { - val aliasOnlyProject = plan.find { - case Project(pList, child) if isAliasOnly(pList, child.output) => true - case _ => false + val aliasOnlyProject = plan.collectFirst { + case p @ Project(pList, child) if isAliasOnly(pList, child.output) => p } - aliasOnlyProject.map { case p: Project => - val aliases = p.projectList.map(_.asInstanceOf[Alias]) - val attrMap = AttributeMap(aliases.map(a => (a.toAttribute, a.child))) - plan.transformAllExpressions { - case a: Attribute if attrMap.contains(a) => attrMap(a) - }.transform { - case op: Project if op.eq(p) => op.child + aliasOnlyProject.map { case proj => + val attributesToReplace = proj.output.zip(proj.child.output).filterNot { + case (a1, a2) => a1 semanticEquals a2 + } + val attrMap = AttributeMap(attributesToReplace) + plan transform { + case plan: Project if plan eq proj => plan.child + case plan => plan transformExpressions { + case a: Attribute if attrMap.contains(a) => attrMap(a) + } } }.getOrElse(plan) } |