aboutsummaryrefslogtreecommitdiff
path: root/sql/catalyst
diff options
context:
space:
mode:
Diffstat (limited to 'sql/catalyst')
-rw-r--r--sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/planning/patterns.scala11
1 files changed, 6 insertions, 5 deletions
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/planning/patterns.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/planning/patterns.scala
index 0e3a8a6bd3..4544b32958 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/planning/patterns.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/planning/patterns.scala
@@ -129,11 +129,12 @@ object HashFilteredJoin extends Logging with PredicateHelper {
// as join keys.
def splitPredicates(allPredicates: Seq[Expression], join: Join): Option[ReturnType] = {
val Join(left, right, joinType, _) = join
- val (joinPredicates, otherPredicates) = allPredicates.partition {
- case Equals(l, r) if (canEvaluate(l, left) && canEvaluate(r, right)) ||
- (canEvaluate(l, right) && canEvaluate(r, left)) => true
- case _ => false
- }
+ val (joinPredicates, otherPredicates) =
+ allPredicates.flatMap(splitConjunctivePredicates).partition {
+ case Equals(l, r) if (canEvaluate(l, left) && canEvaluate(r, right)) ||
+ (canEvaluate(l, right) && canEvaluate(r, left)) => true
+ case _ => false
+ }
val joinKeys = joinPredicates.map {
case Equals(l, r) if canEvaluate(l, left) && canEvaluate(r, right) => (l, r)