From fb0894b3a87331a731129ad3fc7ebe598d90a6ee Mon Sep 17 00:00:00 2001
From: Tejas Patil
Date: Thu, 20 Oct 2016 09:50:55 -0700
Subject: [SPARK-17698][SQL] Join predicates should not contain filter clauses

## What changes were proposed in this pull request?

Jira: https://issues.apache.org/jira/browse/SPARK-17698

`ExtractEquiJoinKeys` is incorrectly using filter predicates as the join condition for joins. `canEvaluate` [0] checks whether an `Expression` can be evaluated using the output of a given `Plan`. For a filter predicate (e.g. `a.id = '1'`), the `Expression` passed for the right hand side (i.e. `'1'`) is a `Literal` with no attribute references, so `expr.references` is an empty set, which is trivially a subset of any set. `canEvaluate` therefore returns `true`, and `a.id = '1'` is treated as a join predicate. This does not lead to incorrect results, but for bucketed + sorted tables it can prevent us from avoiding an unnecessary shuffle + sort. See the example below.

[0]: https://github.com/apache/spark/blob/master/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala#L91

e.g.

```
val df = (1 until 10).toDF("id").coalesce(1)
hc.sql("DROP TABLE IF EXISTS table1").collect
df.write.bucketBy(8, "id").sortBy("id").saveAsTable("table1")

hc.sql("DROP TABLE IF EXISTS table2").collect
df.write.bucketBy(8, "id").sortBy("id").saveAsTable("table2")

sqlContext.sql("""
  SELECT a.id, b.id
  FROM table1 a
  FULL OUTER JOIN table2 b
  ON a.id = b.id AND a.id = '1' AND b.id = '1'
""").explain(true)
```

BEFORE: the plan shuffles and sorts the table scan outputs, which is not needed because both tables are bucketed and sorted on the same columns and have the same number of buckets. This should be a single-stage job.

```
SortMergeJoin [id#38, cast(id#38 as double), 1.0], [id#39, 1.0, cast(id#39 as double)], FullOuter
:- *Sort [id#38 ASC NULLS FIRST, cast(id#38 as double) ASC NULLS FIRST, 1.0 ASC NULLS FIRST], false, 0
:  +- Exchange hashpartitioning(id#38, cast(id#38 as double), 1.0, 200)
:     +- *FileScan parquet default.table1[id#38] Batched: true, Format: ParquetFormat, InputPaths: file:spark-warehouse/table1, PartitionFilters: [], PushedFilters: [], ReadSchema: struct<id:int>
+- *Sort [id#39 ASC NULLS FIRST, 1.0 ASC NULLS FIRST, cast(id#39 as double) ASC NULLS FIRST], false, 0
   +- Exchange hashpartitioning(id#39, 1.0, cast(id#39 as double), 200)
      +- *FileScan parquet default.table2[id#39] Batched: true, Format: ParquetFormat, InputPaths: file:spark-warehouse/table2, PartitionFilters: [], PushedFilters: [], ReadSchema: struct<id:int>
```

AFTER:

```
SortMergeJoin [id#32], [id#33], FullOuter, ((cast(id#32 as double) = 1.0) && (cast(id#33 as double) = 1.0))
:- *FileScan parquet default.table1[id#32] Batched: true, Format: ParquetFormat, InputPaths: file:spark-warehouse/table1, PartitionFilters: [], PushedFilters: [], ReadSchema: struct<id:int>
+- *FileScan parquet default.table2[id#33] Batched: true, Format: ParquetFormat, InputPaths: file:spark-warehouse/table2, PartitionFilters: [], PushedFilters: [], ReadSchema: struct<id:int>
```

## How was this patch tested?

- Added a new test case for this scenario: `SPARK-17698 Join predicates should not contain filter clauses`
- Ran all the tests in `BucketedReadSuite`

Author: Tejas Patil

Closes #15272 from tejasapatil/SPARK-17698_join_predicate_filter_clause.
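As a side note for reviewers, here is a minimal, illustrative sketch of the vacuous-subset behavior described above. It is not part of this patch: the object name `CanEvaluateSketch` is made up, and it assumes `spark-catalyst` (2.x) is on the classpath.

```
import org.apache.spark.sql.catalyst.expressions.{AttributeReference, EqualTo, Literal}
import org.apache.spark.sql.catalyst.plans.logical.LocalRelation
import org.apache.spark.sql.types.IntegerType

// Hypothetical driver, for illustration only.
object CanEvaluateSketch extends App {
  // A one-column relation R(id), standing in for table1 above.
  val id = AttributeReference("id", IntegerType)()
  val r = LocalRelation(id)

  // canEvaluate(expr, plan) boils down to expr.references.subsetOf(plan.outputSet).
  // A Literal references no attributes, so the subset test is vacuously true.
  println(Literal(1).references.isEmpty)               // true
  println(Literal(1).references.subsetOf(r.outputSet)) // true

  // Hence a filter predicate like `a.id = 1` satisfies both canEvaluate checks
  // in ExtractEquiJoinKeys and, before this patch, was extracted as an
  // equi-join key pair instead of staying in otherPredicates.
  val filterPred = EqualTo(id, Literal(1))
  println(filterPred.references.subsetOf(r.outputSet)) // true
}
```

This is why the patch below short-circuits on empty `references` (`l.references.isEmpty || r.references.isEmpty` in `ExtractEquiJoinKeys`, `references.nonEmpty &&` in `ReorderJoin`) before consulting `canEvaluate`.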
---
 .../sql/catalyst/expressions/predicates.scala      |   5 +-
 .../spark/sql/catalyst/optimizer/joins.scala       |   4 +-
 .../spark/sql/catalyst/planning/patterns.scala     |   2 +
 .../spark/sql/sources/BucketedReadSuite.scala      | 124 +++++++++++++++++----
 4 files changed, 109 insertions(+), 26 deletions(-)

(limited to 'sql')

diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala
index 799858a686..9394e39aad 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala
@@ -84,8 +84,9 @@ trait PredicateHelper {
    *
    * For example consider a join between two relations R(a, b) and S(c, d).
    *
-   * `canEvaluate(EqualTo(a,b), R)` returns `true` where as `canEvaluate(EqualTo(a,c), R)` returns
-   * `false`.
+   * - `canEvaluate(EqualTo(a,b), R)` returns `true`
+   * - `canEvaluate(EqualTo(a,c), R)` returns `false`
+   * - `canEvaluate(Literal(1), R)` returns `true` as literals CAN be evaluated on any plan
    */
   protected def canEvaluate(expr: Expression, plan: LogicalPlan): Boolean =
     expr.references.subsetOf(plan.outputSet)
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/joins.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/joins.scala
index 2626057e49..180ad2e0ad 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/joins.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/joins.scala
@@ -65,7 +65,9 @@ object ReorderJoin extends Rule[LogicalPlan] with PredicateHelper {
       val conditionalJoin = rest.find { planJoinPair =>
         val plan = planJoinPair._1
         val refs = left.outputSet ++ plan.outputSet
-        conditions.filterNot(canEvaluate(_, left)).filterNot(canEvaluate(_, plan))
+        conditions
+          .filterNot(l => l.references.nonEmpty && canEvaluate(l, left))
+          .filterNot(r => r.references.nonEmpty && canEvaluate(r, plan))
           .exists(_.references.subsetOf(refs))
       }
       // pick the next one if no condition left
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/planning/patterns.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/planning/patterns.scala
index bdae56881b..c5f92c59c8 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/planning/patterns.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/planning/patterns.scala
@@ -112,6 +112,7 @@ object ExtractEquiJoinKeys extends Logging with PredicateHelper {
       // as join keys.
       val predicates = condition.map(splitConjunctivePredicates).getOrElse(Nil)
       val joinKeys = predicates.flatMap {
+        case EqualTo(l, r) if l.references.isEmpty || r.references.isEmpty => None
         case EqualTo(l, r) if canEvaluate(l, left) && canEvaluate(r, right) => Some((l, r))
         case EqualTo(l, r) if canEvaluate(l, right) && canEvaluate(r, left) => Some((r, l))
         // Replace null with default value for joining key, then those rows with null in it could
@@ -125,6 +126,7 @@ object ExtractEquiJoinKeys extends Logging with PredicateHelper {
         case other => None
       }
       val otherPredicates = predicates.filterNot {
+        case EqualTo(l, r) if l.references.isEmpty || r.references.isEmpty => false
         case EqualTo(l, r) =>
           canEvaluate(l, left) && canEvaluate(r, right) ||
             canEvaluate(l, right) && canEvaluate(r, left)
diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/sources/BucketedReadSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/sources/BucketedReadSuite.scala
index 3ff85176de..9ed454e578 100644
--- a/sql/hive/src/test/scala/org/apache/spark/sql/sources/BucketedReadSuite.scala
+++ b/sql/hive/src/test/scala/org/apache/spark/sql/sources/BucketedReadSuite.scala
@@ -235,7 +235,8 @@ class BucketedReadSuite extends QueryTest with SQLTestUtils with TestHiveSinglet
   private def testBucketing(
       bucketSpecLeft: Option[BucketSpec],
       bucketSpecRight: Option[BucketSpec],
-      joinColumns: Seq[String],
+      joinType: String = "inner",
+      joinCondition: (DataFrame, DataFrame) => Column,
       shuffleLeft: Boolean,
       shuffleRight: Boolean,
       sortLeft: Boolean = true,
@@ -268,12 +269,12 @@ class BucketedReadSuite extends QueryTest with SQLTestUtils with TestHiveSinglet
         SQLConf.WHOLESTAGE_CODEGEN_ENABLED.key -> "false") {
       val t1 = spark.table("bucketed_table1")
       val t2 = spark.table("bucketed_table2")
-      val joined = t1.join(t2, joinCondition(t1, t2, joinColumns))
+      val joined = t1.join(t2, joinCondition(t1, t2), joinType)

       // First check the result is corrected.
       checkAnswer(
         joined.sort("bucketed_table1.k", "bucketed_table2.k"),
-        df1.join(df2, joinCondition(df1, df2, joinColumns)).sort("df1.k", "df2.k"))
+        df1.join(df2, joinCondition(df1, df2), joinType).sort("df1.k", "df2.k"))

       assert(joined.queryExecution.executedPlan.isInstanceOf[SortMergeJoinExec])
       val joinOperator = joined.queryExecution.executedPlan.asInstanceOf[SortMergeJoinExec]
@@ -297,56 +298,102 @@ class BucketedReadSuite extends QueryTest with SQLTestUtils with TestHiveSinglet
     }
   }

-  private def joinCondition(left: DataFrame, right: DataFrame, joinCols: Seq[String]): Column = {
+  private def joinCondition(joinCols: Seq[String]) (left: DataFrame, right: DataFrame): Column = {
     joinCols.map(col => left(col) === right(col)).reduce(_ && _)
   }

   test("avoid shuffle when join 2 bucketed tables") {
     val bucketSpec = Some(BucketSpec(8, Seq("i", "j"), Nil))
-    testBucketing(bucketSpec, bucketSpec, Seq("i", "j"), shuffleLeft = false, shuffleRight = false)
+    testBucketing(
+      bucketSpecLeft = bucketSpec,
+      bucketSpecRight = bucketSpec,
+      joinCondition = joinCondition(Seq("i", "j")),
+      shuffleLeft = false,
+      shuffleRight = false
+    )
   }

   // Enable it after fix https://issues.apache.org/jira/browse/SPARK-12704
   ignore("avoid shuffle when join keys are a super-set of bucket keys") {
     val bucketSpec = Some(BucketSpec(8, Seq("i"), Nil))
-    testBucketing(bucketSpec, bucketSpec, Seq("i", "j"), shuffleLeft = false, shuffleRight = false)
+    testBucketing(
+      bucketSpecLeft = bucketSpec,
+      bucketSpecRight = bucketSpec,
+      joinCondition = joinCondition(Seq("i", "j")),
+      shuffleLeft = false,
+      shuffleRight = false
+    )
   }

   test("only shuffle one side when join bucketed table and non-bucketed table") {
     val bucketSpec = Some(BucketSpec(8, Seq("i", "j"), Nil))
-    testBucketing(bucketSpec, None, Seq("i", "j"), shuffleLeft = false, shuffleRight = true)
+    testBucketing(
+      bucketSpecLeft = bucketSpec,
+      bucketSpecRight = None,
+      joinCondition = joinCondition(Seq("i", "j")),
+      shuffleLeft = false,
+      shuffleRight = true
+    )
   }

   test("only shuffle one side when 2 bucketed tables have different bucket number") {
     val bucketSpec1 = Some(BucketSpec(8, Seq("i", "j"), Nil))
     val bucketSpec2 = Some(BucketSpec(5, Seq("i", "j"), Nil))
-    testBucketing(bucketSpec1, bucketSpec2, Seq("i", "j"), shuffleLeft = false, shuffleRight = true)
+    testBucketing(
+      bucketSpecLeft = bucketSpec1,
+      bucketSpecRight = bucketSpec2,
+      joinCondition = joinCondition(Seq("i", "j")),
+      shuffleLeft = false,
+      shuffleRight = true
+    )
   }

   test("only shuffle one side when 2 bucketed tables have different bucket keys") {
     val bucketSpec1 = Some(BucketSpec(8, Seq("i"), Nil))
     val bucketSpec2 = Some(BucketSpec(8, Seq("j"), Nil))
-    testBucketing(bucketSpec1, bucketSpec2, Seq("i"), shuffleLeft = false, shuffleRight = true)
+    testBucketing(
+      bucketSpecLeft = bucketSpec1,
+      bucketSpecRight = bucketSpec2,
+      joinCondition = joinCondition(Seq("i")),
+      shuffleLeft = false,
+      shuffleRight = true
+    )
   }

   test("shuffle when join keys are not equal to bucket keys") {
     val bucketSpec = Some(BucketSpec(8, Seq("i"), Nil))
-    testBucketing(bucketSpec, bucketSpec, Seq("j"), shuffleLeft = true, shuffleRight = true)
+    testBucketing(
+      bucketSpecLeft = bucketSpec,
+      bucketSpecRight = bucketSpec,
+      joinCondition = joinCondition(Seq("j")),
+      shuffleLeft = true,
+      shuffleRight = true
+    )
   }

   test("shuffle when join 2 bucketed tables with bucketing disabled") {
     val bucketSpec = Some(BucketSpec(8, Seq("i", "j"), Nil))
     withSQLConf(SQLConf.BUCKETING_ENABLED.key -> "false") {
-      testBucketing(bucketSpec, bucketSpec,
Seq("i", "j"), shuffleLeft = true, shuffleRight = true) + testBucketing( + bucketSpecLeft = bucketSpec, + bucketSpecRight = bucketSpec, + joinCondition = joinCondition(Seq("i", "j")), + shuffleLeft = true, + shuffleRight = true + ) } } test("avoid shuffle and sort when bucket and sort columns are join keys") { val bucketSpec = Some(BucketSpec(8, Seq("i", "j"), Seq("i", "j"))) testBucketing( - bucketSpec, bucketSpec, Seq("i", "j"), - shuffleLeft = false, shuffleRight = false, - sortLeft = false, sortRight = false + bucketSpecLeft = bucketSpec, + bucketSpecRight = bucketSpec, + joinCondition = joinCondition(Seq("i", "j")), + shuffleLeft = false, + shuffleRight = false, + sortLeft = false, + sortRight = false ) } @@ -354,9 +401,13 @@ class BucketedReadSuite extends QueryTest with SQLTestUtils with TestHiveSinglet val bucketSpec1 = Some(BucketSpec(8, Seq("i"), Seq("i", "j"))) val bucketSpec2 = Some(BucketSpec(8, Seq("i"), Seq("i", "k"))) testBucketing( - bucketSpec1, bucketSpec2, Seq("i"), - shuffleLeft = false, shuffleRight = false, - sortLeft = false, sortRight = false + bucketSpecLeft = bucketSpec1, + bucketSpecRight = bucketSpec2, + joinCondition = joinCondition(Seq("i")), + shuffleLeft = false, + shuffleRight = false, + sortLeft = false, + sortRight = false ) } @@ -364,9 +415,13 @@ class BucketedReadSuite extends QueryTest with SQLTestUtils with TestHiveSinglet val bucketSpec1 = Some(BucketSpec(8, Seq("i", "j"), Seq("i", "j"))) val bucketSpec2 = Some(BucketSpec(8, Seq("i", "j"), Seq("k"))) testBucketing( - bucketSpec1, bucketSpec2, Seq("i", "j"), - shuffleLeft = false, shuffleRight = false, - sortLeft = false, sortRight = true + bucketSpecLeft = bucketSpec1, + bucketSpecRight = bucketSpec2, + joinCondition = joinCondition(Seq("i", "j")), + shuffleLeft = false, + shuffleRight = false, + sortLeft = false, + sortRight = true ) } @@ -374,9 +429,13 @@ class BucketedReadSuite extends QueryTest with SQLTestUtils with TestHiveSinglet val bucketSpec1 = Some(BucketSpec(8, Seq("i", "j"), Seq("i", "j"))) val bucketSpec2 = Some(BucketSpec(8, Seq("i", "j"), Seq("j", "i"))) testBucketing( - bucketSpec1, bucketSpec2, Seq("i", "j"), - shuffleLeft = false, shuffleRight = false, - sortLeft = false, sortRight = true + bucketSpecLeft = bucketSpec1, + bucketSpecRight = bucketSpec2, + joinCondition = joinCondition(Seq("i", "j")), + shuffleLeft = false, + shuffleRight = false, + sortLeft = false, + sortRight = true ) } @@ -408,6 +467,25 @@ class BucketedReadSuite extends QueryTest with SQLTestUtils with TestHiveSinglet } } + test("SPARK-17698 Join predicates should not contain filter clauses") { + val bucketSpec = Some(BucketSpec(8, Seq("i"), Seq("i"))) + testBucketing( + bucketSpecLeft = bucketSpec, + bucketSpecRight = bucketSpec, + joinType = "fullouter", + joinCondition = (left: DataFrame, right: DataFrame) => { + val joinPredicates = Seq("i").map(col => left(col) === right(col)).reduce(_ && _) + val filterLeft = left("i") === Literal("1") + val filterRight = right("i") === Literal("1") + joinPredicates && filterLeft && filterRight + }, + shuffleLeft = false, + shuffleRight = false, + sortLeft = false, + sortRight = false + ) + } + test("error if there exists any malformed bucket files") { withTable("bucketed_table") { df1.write.format("parquet").bucketBy(8, "i").saveAsTable("bucketed_table") -- cgit v1.2.3