Diffstat (limited to 'sql')
-rw-r--r--  sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala |   5
-rw-r--r--  sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/joins.scala        |   4
-rw-r--r--  sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/planning/patterns.scala      |   2
-rw-r--r--  sql/hive/src/test/scala/org/apache/spark/sql/sources/BucketedReadSuite.scala           | 124
4 files changed, 109 insertions(+), 26 deletions(-)
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala
index 799858a686..9394e39aad 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala
@@ -84,8 +84,9 @@ trait PredicateHelper {
*
* For example consider a join between two relations R(a, b) and S(c, d).
*
- * `canEvaluate(EqualTo(a,b), R)` returns `true` where as `canEvaluate(EqualTo(a,c), R)` returns
- * `false`.
+ * - `canEvaluate(EqualTo(a,b), R)` returns `true`
+ * - `canEvaluate(EqualTo(a,c), R)` returns `false`
+ * - `canEvaluate(Literal(1), R)` returns `true` as literals CAN be evaluated on any plan
*/
protected def canEvaluate(expr: Expression, plan: LogicalPlan): Boolean =
expr.references.subsetOf(plan.outputSet)
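
The new bullet in the doc comment above follows directly from the implementation shown: a literal carries no attribute references, and the empty set is a subset of any plan's output set. A minimal sketch of that vacuous case (assuming only the Catalyst classes this file already uses):

    import org.apache.spark.sql.catalyst.expressions.Literal

    // Literal(1) is a leaf expression with no attribute references, so
    // references.subsetOf(plan.outputSet) holds for every plan, and
    // canEvaluate(Literal(1), plan) is vacuously true.
    val one = Literal(1)
    assert(one.references.isEmpty)
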
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/joins.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/joins.scala
index 2626057e49..180ad2e0ad 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/joins.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/joins.scala
@@ -65,7 +65,9 @@ object ReorderJoin extends Rule[LogicalPlan] with PredicateHelper {
val conditionalJoin = rest.find { planJoinPair =>
val plan = planJoinPair._1
val refs = left.outputSet ++ plan.outputSet
- conditions.filterNot(canEvaluate(_, left)).filterNot(canEvaluate(_, plan))
+ conditions
+ .filterNot(l => l.references.nonEmpty && canEvaluate(l, left))
+ .filterNot(r => r.references.nonEmpty && canEvaluate(r, plan))
.exists(_.references.subsetOf(refs))
}
// pick the next one if no condition left
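
The `references.nonEmpty` guards added here exist because of that vacuous case: without them, a reference-free predicate is "evaluable" on every side and gets filtered out before the `exists` check ever sees it. A hedged sketch of the combined test each `filterNot` now applies (the helper name is hypothetical, not part of the patch):

    import org.apache.spark.sql.catalyst.expressions.Expression
    import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan

    // A predicate only counts as evaluable on one side if it actually
    // references attributes; pure literal expressions fall through to the
    // exists(_.references.subsetOf(refs)) check instead of being discarded.
    def evaluableOnSide(expr: Expression, plan: LogicalPlan): Boolean =
      expr.references.nonEmpty && expr.references.subsetOf(plan.outputSet)
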
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/planning/patterns.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/planning/patterns.scala
index bdae56881b..c5f92c59c8 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/planning/patterns.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/planning/patterns.scala
@@ -112,6 +112,7 @@ object ExtractEquiJoinKeys extends Logging with PredicateHelper {
// as join keys.
val predicates = condition.map(splitConjunctivePredicates).getOrElse(Nil)
val joinKeys = predicates.flatMap {
+ case EqualTo(l, r) if l.references.isEmpty || r.references.isEmpty => None
case EqualTo(l, r) if canEvaluate(l, left) && canEvaluate(r, right) => Some((l, r))
case EqualTo(l, r) if canEvaluate(l, right) && canEvaluate(r, left) => Some((r, l))
// Replace null with default value for joining key, then those rows with null in it could
@@ -125,6 +126,7 @@ object ExtractEquiJoinKeys extends Logging with PredicateHelper {
case other => None
}
val otherPredicates = predicates.filterNot {
+ case EqualTo(l, r) if l.references.isEmpty || r.references.isEmpty => false
case EqualTo(l, r) =>
canEvaluate(l, left) && canEvaluate(r, right) ||
canEvaluate(l, right) && canEvaluate(r, left)
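
These two new cases target the SPARK-17698 misclassification directly: in `EqualTo(l, r)` where one side is a literal, `canEvaluate` is true for that side against either relation, so a filter clause embedded in the join condition used to be extracted as an equi-join key pair. A sketch of the triggering shape, assuming `df1` and `df2` are the DataFrames used in the test suite below:

    import org.apache.spark.sql.functions.lit

    // Before this patch, (df1("i") === lit("1")) matched the
    // `canEvaluate(l, left) && canEvaluate(r, right)` case and became the
    // join key pair (df1.i, "1"); with the new guard it is skipped here and
    // lands in otherPredicates as a plain filter.
    val cond = (df1("i") === df2("i")) &&
      (df1("i") === lit("1")) && (df2("i") === lit("1"))
    val joined = df1.join(df2, cond, "fullouter")
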
diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/sources/BucketedReadSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/sources/BucketedReadSuite.scala
index 3ff85176de..9ed454e578 100644
--- a/sql/hive/src/test/scala/org/apache/spark/sql/sources/BucketedReadSuite.scala
+++ b/sql/hive/src/test/scala/org/apache/spark/sql/sources/BucketedReadSuite.scala
@@ -235,7 +235,8 @@ class BucketedReadSuite extends QueryTest with SQLTestUtils with TestHiveSinglet
private def testBucketing(
bucketSpecLeft: Option[BucketSpec],
bucketSpecRight: Option[BucketSpec],
- joinColumns: Seq[String],
+ joinType: String = "inner",
+ joinCondition: (DataFrame, DataFrame) => Column,
shuffleLeft: Boolean,
shuffleRight: Boolean,
sortLeft: Boolean = true,
@@ -268,12 +269,12 @@ class BucketedReadSuite extends QueryTest with SQLTestUtils with TestHiveSinglet
SQLConf.WHOLESTAGE_CODEGEN_ENABLED.key -> "false") {
val t1 = spark.table("bucketed_table1")
val t2 = spark.table("bucketed_table2")
- val joined = t1.join(t2, joinCondition(t1, t2, joinColumns))
+ val joined = t1.join(t2, joinCondition(t1, t2), joinType)
// First check the result is corrected.
checkAnswer(
joined.sort("bucketed_table1.k", "bucketed_table2.k"),
- df1.join(df2, joinCondition(df1, df2, joinColumns)).sort("df1.k", "df2.k"))
+ df1.join(df2, joinCondition(df1, df2), joinType).sort("df1.k", "df2.k"))
assert(joined.queryExecution.executedPlan.isInstanceOf[SortMergeJoinExec])
val joinOperator = joined.queryExecution.executedPlan.asInstanceOf[SortMergeJoinExec]
@@ -297,56 +298,102 @@ class BucketedReadSuite extends QueryTest with SQLTestUtils with TestHiveSinglet
}
}
- private def joinCondition(left: DataFrame, right: DataFrame, joinCols: Seq[String]): Column = {
+ private def joinCondition(joinCols: Seq[String]) (left: DataFrame, right: DataFrame): Column = {
joinCols.map(col => left(col) === right(col)).reduce(_ && _)
}
test("avoid shuffle when join 2 bucketed tables") {
val bucketSpec = Some(BucketSpec(8, Seq("i", "j"), Nil))
- testBucketing(bucketSpec, bucketSpec, Seq("i", "j"), shuffleLeft = false, shuffleRight = false)
+ testBucketing(
+ bucketSpecLeft = bucketSpec,
+ bucketSpecRight = bucketSpec,
+ joinCondition = joinCondition(Seq("i", "j")),
+ shuffleLeft = false,
+ shuffleRight = false
+ )
}
// Enable it after fix https://issues.apache.org/jira/browse/SPARK-12704
ignore("avoid shuffle when join keys are a super-set of bucket keys") {
val bucketSpec = Some(BucketSpec(8, Seq("i"), Nil))
- testBucketing(bucketSpec, bucketSpec, Seq("i", "j"), shuffleLeft = false, shuffleRight = false)
+ testBucketing(
+ bucketSpecLeft = bucketSpec,
+ bucketSpecRight = bucketSpec,
+ joinCondition = joinCondition(Seq("i", "j")),
+ shuffleLeft = false,
+ shuffleRight = false
+ )
}
test("only shuffle one side when join bucketed table and non-bucketed table") {
val bucketSpec = Some(BucketSpec(8, Seq("i", "j"), Nil))
- testBucketing(bucketSpec, None, Seq("i", "j"), shuffleLeft = false, shuffleRight = true)
+ testBucketing(
+ bucketSpecLeft = bucketSpec,
+ bucketSpecRight = None,
+ joinCondition = joinCondition(Seq("i", "j")),
+ shuffleLeft = false,
+ shuffleRight = true
+ )
}
test("only shuffle one side when 2 bucketed tables have different bucket number") {
val bucketSpec1 = Some(BucketSpec(8, Seq("i", "j"), Nil))
val bucketSpec2 = Some(BucketSpec(5, Seq("i", "j"), Nil))
- testBucketing(bucketSpec1, bucketSpec2, Seq("i", "j"), shuffleLeft = false, shuffleRight = true)
+ testBucketing(
+ bucketSpecLeft = bucketSpec1,
+ bucketSpecRight = bucketSpec2,
+ joinCondition = joinCondition(Seq("i", "j")),
+ shuffleLeft = false,
+ shuffleRight = true
+ )
}
test("only shuffle one side when 2 bucketed tables have different bucket keys") {
val bucketSpec1 = Some(BucketSpec(8, Seq("i"), Nil))
val bucketSpec2 = Some(BucketSpec(8, Seq("j"), Nil))
- testBucketing(bucketSpec1, bucketSpec2, Seq("i"), shuffleLeft = false, shuffleRight = true)
+ testBucketing(
+ bucketSpecLeft = bucketSpec1,
+ bucketSpecRight = bucketSpec2,
+ joinCondition = joinCondition(Seq("i")),
+ shuffleLeft = false,
+ shuffleRight = true
+ )
}
test("shuffle when join keys are not equal to bucket keys") {
val bucketSpec = Some(BucketSpec(8, Seq("i"), Nil))
- testBucketing(bucketSpec, bucketSpec, Seq("j"), shuffleLeft = true, shuffleRight = true)
+ testBucketing(
+ bucketSpecLeft = bucketSpec,
+ bucketSpecRight = bucketSpec,
+ joinCondition = joinCondition(Seq("j")),
+ shuffleLeft = true,
+ shuffleRight = true
+ )
}
test("shuffle when join 2 bucketed tables with bucketing disabled") {
val bucketSpec = Some(BucketSpec(8, Seq("i", "j"), Nil))
withSQLConf(SQLConf.BUCKETING_ENABLED.key -> "false") {
- testBucketing(bucketSpec, bucketSpec, Seq("i", "j"), shuffleLeft = true, shuffleRight = true)
+ testBucketing(
+ bucketSpecLeft = bucketSpec,
+ bucketSpecRight = bucketSpec,
+ joinCondition = joinCondition(Seq("i", "j")),
+ shuffleLeft = true,
+ shuffleRight = true
+ )
}
}
test("avoid shuffle and sort when bucket and sort columns are join keys") {
val bucketSpec = Some(BucketSpec(8, Seq("i", "j"), Seq("i", "j")))
testBucketing(
- bucketSpec, bucketSpec, Seq("i", "j"),
- shuffleLeft = false, shuffleRight = false,
- sortLeft = false, sortRight = false
+ bucketSpecLeft = bucketSpec,
+ bucketSpecRight = bucketSpec,
+ joinCondition = joinCondition(Seq("i", "j")),
+ shuffleLeft = false,
+ shuffleRight = false,
+ sortLeft = false,
+ sortRight = false
)
}
@@ -354,9 +401,13 @@ class BucketedReadSuite extends QueryTest with SQLTestUtils with TestHiveSinglet
val bucketSpec1 = Some(BucketSpec(8, Seq("i"), Seq("i", "j")))
val bucketSpec2 = Some(BucketSpec(8, Seq("i"), Seq("i", "k")))
testBucketing(
- bucketSpec1, bucketSpec2, Seq("i"),
- shuffleLeft = false, shuffleRight = false,
- sortLeft = false, sortRight = false
+ bucketSpecLeft = bucketSpec1,
+ bucketSpecRight = bucketSpec2,
+ joinCondition = joinCondition(Seq("i")),
+ shuffleLeft = false,
+ shuffleRight = false,
+ sortLeft = false,
+ sortRight = false
)
}
@@ -364,9 +415,13 @@ class BucketedReadSuite extends QueryTest with SQLTestUtils with TestHiveSinglet
val bucketSpec1 = Some(BucketSpec(8, Seq("i", "j"), Seq("i", "j")))
val bucketSpec2 = Some(BucketSpec(8, Seq("i", "j"), Seq("k")))
testBucketing(
- bucketSpec1, bucketSpec2, Seq("i", "j"),
- shuffleLeft = false, shuffleRight = false,
- sortLeft = false, sortRight = true
+ bucketSpecLeft = bucketSpec1,
+ bucketSpecRight = bucketSpec2,
+ joinCondition = joinCondition(Seq("i", "j")),
+ shuffleLeft = false,
+ shuffleRight = false,
+ sortLeft = false,
+ sortRight = true
)
}
@@ -374,9 +429,13 @@ class BucketedReadSuite extends QueryTest with SQLTestUtils with TestHiveSinglet
val bucketSpec1 = Some(BucketSpec(8, Seq("i", "j"), Seq("i", "j")))
val bucketSpec2 = Some(BucketSpec(8, Seq("i", "j"), Seq("j", "i")))
testBucketing(
- bucketSpec1, bucketSpec2, Seq("i", "j"),
- shuffleLeft = false, shuffleRight = false,
- sortLeft = false, sortRight = true
+ bucketSpecLeft = bucketSpec1,
+ bucketSpecRight = bucketSpec2,
+ joinCondition = joinCondition(Seq("i", "j")),
+ shuffleLeft = false,
+ shuffleRight = false,
+ sortLeft = false,
+ sortRight = true
)
}
@@ -408,6 +467,25 @@ class BucketedReadSuite extends QueryTest with SQLTestUtils with TestHiveSinglet
}
}
+ test("SPARK-17698 Join predicates should not contain filter clauses") {
+ val bucketSpec = Some(BucketSpec(8, Seq("i"), Seq("i")))
+ testBucketing(
+ bucketSpecLeft = bucketSpec,
+ bucketSpecRight = bucketSpec,
+ joinType = "fullouter",
+ joinCondition = (left: DataFrame, right: DataFrame) => {
+ val joinPredicates = Seq("i").map(col => left(col) === right(col)).reduce(_ && _)
+ val filterLeft = left("i") === Literal("1")
+ val filterRight = right("i") === Literal("1")
+ joinPredicates && filterLeft && filterRight
+ },
+ shuffleLeft = false,
+ shuffleRight = false,
+ sortLeft = false,
+ sortRight = false
+ )
+ }
+
test("error if there exists any malformed bucket files") {
withTable("bucketed_table") {
df1.write.format("parquet").bucketBy(8, "i").saveAsTable("bucketed_table")
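
A closing note on the test refactor above: `joinCondition` is now curried, so each call site partially applies it to its columns and hands `testBucketing` a `(DataFrame, DataFrame) => Column`, while the SPARK-17698 test instead supplies a custom lambda mixing equi-join keys with literal filters. A minimal sketch of the partial application (assuming the helper and the Spark SQL types are in scope):

    import org.apache.spark.sql.{Column, DataFrame}

    // The expected function type triggers eta-expansion of the curried def,
    // yielding a closure over the chosen join columns that testBucketing
    // invokes once per table pair.
    val onIJ: (DataFrame, DataFrame) => Column = joinCondition(Seq("i", "j"))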