diff options
3 files changed, 116 insertions, 70 deletions
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala
index 9c45b19624..e42f0b9a24 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala
@@ -40,7 +40,7 @@ object DefaultOptimizer extends Optimizer {
       ReplaceDistinctWithAggregate) ::
     Batch("Operator Optimizations", FixedPoint(100),
       // Operator push down
-      UnionPushDown,
+      SetOperationPushDown,
       SamplePushDown,
       PushPredicateThroughJoin,
       PushPredicateThroughProject,
@@ -84,23 +84,28 @@ object SamplePushDown extends Rule[LogicalPlan] {
 }
 
 /**
- * Pushes operations to either side of a Union.
+ * Pushes operations to either side of a Union, Intersect or Except.
+ *
+ * Filters are pushed to both children of all three operators. Projections are pushed only
+ * through Union: Intersect and Except compare entire rows, so projecting their children
+ * first would change which rows match.
  */
-object UnionPushDown extends Rule[LogicalPlan] {
+object SetOperationPushDown extends Rule[LogicalPlan] {
 
   /**
    * Maps Attributes from the left side to the corresponding Attribute on the right side.
    */
-  private def buildRewrites(union: Union): AttributeMap[Attribute] = {
-    assert(union.left.output.size == union.right.output.size)
+  private def buildRewrites(bn: BinaryNode): AttributeMap[Attribute] = {
+    assert(bn.isInstanceOf[Union] || bn.isInstanceOf[Intersect] || bn.isInstanceOf[Except])
+    assert(bn.left.output.size == bn.right.output.size)
 
-    AttributeMap(union.left.output.zip(union.right.output))
+    AttributeMap(bn.left.output.zip(bn.right.output))
   }
 
   /**
-   * Rewrites an expression so that it can be pushed to the right side of a Union operator.
-   * This method relies on the fact that the output attributes of a union are always equal
-   * to the left child's output.
+   * Rewrites an expression so that it can be pushed to the right side of a
+   * Union, Intersect or Except operator. This method relies on the fact that the output attributes
+   * of a union/intersect/except are always equal to the left child's output.
    */
   private def pushToRight[A <: Expression](e: A, rewrites: AttributeMap[Attribute]) = {
     val result = e transform {
@@ -126,6 +131,25 @@ object UnionPushDown extends Rule[LogicalPlan] {
       Union(
         Project(projectList, left),
         Project(projectList.map(pushToRight(_, rewrites)), right))
+
+    // Push down filter into intersect
+    case Filter(condition, i @ Intersect(left, right)) =>
+      val rewrites = buildRewrites(i)
+      Intersect(
+        Filter(condition, left),
+        Filter(pushToRight(condition, rewrites), right))
+
+    // Push down filter into except
+    case Filter(condition, e @ Except(left, right)) =>
+      val rewrites = buildRewrites(e)
+      Except(
+        Filter(condition, left),
+        Filter(pushToRight(condition, rewrites), right))
+
+    // Projections are deliberately NOT pushed through Intersect or Except: both compare
+    // whole rows, so projecting the children first changes which rows match. E.g. with
+    // left = {(1, 2)} and right = {(1, 3)}, projecting the first column of the Intersect
+    // yields no rows, but intersecting the projected children yields {1}.
   }
 }
 
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/SetOperationPushDownSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/SetOperationPushDownSuite.scala
new file mode 100644
index 0000000000..49c979bc7d
--- /dev/null
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/SetOperationPushDownSuite.scala
@@ -0,0 +1,83 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.catalyst.optimizer
+
+import org.apache.spark.sql.catalyst.analysis.EliminateSubQueries
+import org.apache.spark.sql.catalyst.plans.PlanTest
+import org.apache.spark.sql.catalyst.plans.logical._
+import org.apache.spark.sql.catalyst.rules._
+import org.apache.spark.sql.catalyst.dsl.plans._
+import org.apache.spark.sql.catalyst.dsl.expressions._
+
+class SetOperationPushDownSuite extends PlanTest {
+  object Optimize extends RuleExecutor[LogicalPlan] {
+    val batches =
+      Batch("Subqueries", Once,
+        EliminateSubQueries) ::
+      Batch("Union Pushdown", Once,
+        SetOperationPushDown) :: Nil
+  }
+
+  val testRelation = LocalRelation('a.int, 'b.int, 'c.int)
+  val testRelation2 = LocalRelation('d.int, 'e.int, 'f.int)
+  val testUnion = Union(testRelation, testRelation2)
+  val testIntersect = Intersect(testRelation, testRelation2)
+  val testExcept = Except(testRelation, testRelation2)
+
+  test("union/intersect/except: filter to each side") {
+    val unionQuery = testUnion.where('a === 1)
+    val intersectQuery = testIntersect.where('b < 10)
+    val exceptQuery = testExcept.where('c >= 5)
+
+    val unionOptimized = Optimize.execute(unionQuery.analyze)
+    val intersectOptimized = Optimize.execute(intersectQuery.analyze)
+    val exceptOptimized = Optimize.execute(exceptQuery.analyze)
+
+    val unionCorrectAnswer =
+      Union(testRelation.where('a === 1), testRelation2.where('d === 1)).analyze
+    val intersectCorrectAnswer =
+      Intersect(testRelation.where('b < 10), testRelation2.where('e < 10)).analyze
+    val exceptCorrectAnswer =
+      Except(testRelation.where('c >= 5), testRelation2.where('f >= 5)).analyze
+
+    comparePlans(unionOptimized, unionCorrectAnswer)
+    comparePlans(intersectOptimized, intersectCorrectAnswer)
+    comparePlans(exceptOptimized, exceptCorrectAnswer)
+  }
+
+  test("union: project to each side") {
+    val unionQuery = testUnion.select('a)
+
+    val unionOptimized = Optimize.execute(unionQuery.analyze)
+
+    val unionCorrectAnswer =
+      Union(testRelation.select('a), testRelation2.select('d)).analyze
+
+    comparePlans(unionOptimized, unionCorrectAnswer)
+  }
+
+  test("intersect/except: project is not pushed to either side") {
+    // Intersect and Except compare whole rows, so projecting their children first would
+    // change which rows match; the rule must leave these plans untouched.
+    val intersectQuery = testIntersect.select('b, 'c).analyze
+    val exceptQuery = testExcept.select('a, 'b, 'c).analyze
+
+    comparePlans(Optimize.execute(intersectQuery), intersectQuery)
+    comparePlans(Optimize.execute(exceptQuery), exceptQuery)
+  }
+}
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/UnionPushdownSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/UnionPushdownSuite.scala
deleted file mode 100644
index ec379489a6..0000000000
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/UnionPushdownSuite.scala
+++ /dev/null
@@ -1,61 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one or more
- * contributor license agreements.  See the NOTICE file distributed with
- * this work for additional information regarding copyright ownership.
- * The ASF licenses this file to You under the Apache License, Version 2.0
- * (the "License"); you may not use this file except in compliance with
- * the License.  You may obtain a copy of the License at
- *
- *    http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-package org.apache.spark.sql.catalyst.optimizer
-
-import org.apache.spark.sql.catalyst.analysis.EliminateSubQueries
-import org.apache.spark.sql.catalyst.plans.PlanTest
-import org.apache.spark.sql.catalyst.plans.logical._
-import org.apache.spark.sql.catalyst.rules._
-import org.apache.spark.sql.catalyst.dsl.plans._
-import org.apache.spark.sql.catalyst.dsl.expressions._
-
-class UnionPushDownSuite extends PlanTest {
-  object Optimize extends RuleExecutor[LogicalPlan] {
-    val batches =
-      Batch("Subqueries", Once,
-        EliminateSubQueries) ::
-      Batch("Union Pushdown", Once,
-        UnionPushDown) :: Nil
-  }
-
-  val testRelation = LocalRelation('a.int, 'b.int, 'c.int)
-  val testRelation2 = LocalRelation('d.int, 'e.int, 'f.int)
-  val testUnion = Union(testRelation, testRelation2)
-
-  test("union: filter to each side") {
-    val query = testUnion.where('a === 1)
-
-    val optimized = Optimize.execute(query.analyze)
-
-    val correctAnswer =
-      Union(testRelation.where('a === 1), testRelation2.where('d === 1)).analyze
-
-    comparePlans(optimized, correctAnswer)
-  }
-
-  test("union: project to each side") {
-    val query = testUnion.select('b)
-
-    val optimized = Optimize.execute(query.analyze)
-
-    val correctAnswer =
-      Union(testRelation.select('b), testRelation2.select('e)).analyze
-
-    comparePlans(optimized, correctAnswer)
-  }
-}