diff options
Diffstat (limited to 'sql/catalyst')
2 files changed, 55 insertions, 30 deletions
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/expressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/expressions.scala index b7458910da..3a7004ef29 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/expressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/expressions.scala @@ -428,43 +428,49 @@ object FoldablePropagation extends Rule[LogicalPlan] { } case _ => Nil }) + val replaceFoldable: PartialFunction[Expression, Expression] = { + case a: AttributeReference if foldableMap.contains(a) => foldableMap(a) + } if (foldableMap.isEmpty) { plan } else { var stop = false CleanupAliases(plan.transformUp { - case u: Union => - stop = true - u - case c: Command => - stop = true - c - // For outer join, although its output attributes are derived from its children, they are - // actually different attributes: the output of outer join is not always picked from its - // children, but can also be null. + // A leaf node should not stop the folding process (note that we are traversing up the + // tree, starting at the leaf nodes); so we are allowing it. + case l: LeafNode => + l + + // Whitelist of all nodes we are allowed to apply this rule to. + case p @ (_: Project | _: Filter | _: SubqueryAlias | _: Aggregate | _: Window | + _: Sample | _: GlobalLimit | _: LocalLimit | _: Generate | _: Distinct | + _: AppendColumns | _: AppendColumnsWithObject | _: BroadcastHint | + _: RedistributeData | _: Repartition | _: Sort | _: TypedFilter) if !stop => + p.transformExpressions(replaceFoldable) + + // Allow inner joins. We do not allow outer join, although its output attributes are + // derived from its children, they are actually different attributes: the output of outer + // join is not always picked from its children, but can also be null. // TODO(cloud-fan): It seems more reasonable to use new attributes as the output attributes // of outer join. - case j @ Join(_, _, LeftOuter | RightOuter | FullOuter, _) => + case j @ Join(_, _, Inner, _) => + j.transformExpressions(replaceFoldable) + + // We can fold the projections an expand holds. However expand changes the output columns + // and often reuses the underlying attributes; so we cannot assume that a column is still + // foldable after the expand has been applied. + // TODO(hvanhovell): Expand should use new attributes as the output attributes. + case expand: Expand if !stop => + val newExpand = expand.copy(projections = expand.projections.map { projection => + projection.map(_.transform(replaceFoldable)) + }) stop = true - j + newExpand - // These 3 operators take attributes as constructor parameters, and these attributes - // can't be replaced by alias. - case m: MapGroups => - stop = true - m - case f: FlatMapGroupsInR => - stop = true - f - case c: CoGroup => + case other => stop = true - c - - case p: LogicalPlan if !stop => p.transformExpressions { - case a: AttributeReference if foldableMap.contains(a) => - foldableMap(a) - } + other }) } } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/FoldablePropagationSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/FoldablePropagationSuite.scala index 355b3fc4aa..82756f545a 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/FoldablePropagationSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/FoldablePropagationSuite.scala @@ -116,16 +116,35 @@ class FoldablePropagationSuite extends PlanTest { test("Propagate in subqueries of Union queries") { val query = Union( Seq( - testRelation.select(Literal(1).as('x), 'a).select('x + 'a), - testRelation.select(Literal(2).as('x), 'a).select('x + 'a))) + testRelation.select(Literal(1).as('x), 'a).select('x, 'x + 'a), + testRelation.select(Literal(2).as('x), 'a).select('x, 'x + 'a))) .select('x) val optimized = Optimize.execute(query.analyze) val correctAnswer = Union( Seq( - testRelation.select(Literal(1).as('x), 'a).select((Literal(1).as('x) + 'a).as("(x + a)")), - testRelation.select(Literal(2).as('x), 'a).select((Literal(2).as('x) + 'a).as("(x + a)")))) + testRelation.select(Literal(1).as('x), 'a) + .select(Literal(1).as('x), (Literal(1).as('x) + 'a).as("(x + a)")), + testRelation.select(Literal(2).as('x), 'a) + .select(Literal(2).as('x), (Literal(2).as('x) + 'a).as("(x + a)")))) .select('x).analyze + comparePlans(optimized, correctAnswer) + } + test("Propagate in expand") { + val c1 = Literal(1).as('a) + val c2 = Literal(2).as('b) + val a1 = c1.toAttribute.withNullability(true) + val a2 = c2.toAttribute.withNullability(true) + val expand = Expand( + Seq(Seq(Literal(null), 'b), Seq('a, Literal(null))), + Seq(a1, a2), + OneRowRelation.select(c1, c2)) + val query = expand.where(a1.isNotNull).select(a1, a2).analyze + val optimized = Optimize.execute(query) + val correctExpand = expand.copy(projections = Seq( + Seq(Literal(null), c2), + Seq(c1, Literal(null)))) + val correctAnswer = correctExpand.where(a1.isNotNull).select(a1, a2).analyze comparePlans(optimized, correctAnswer) } } |