diff options
author | Reynold Xin <rxin@databricks.com> | 2016-02-03 23:43:48 -0800 |
---|---|---|
committer | Reynold Xin <rxin@databricks.com> | 2016-02-03 23:43:48 -0800 |
commit | dee801adb78d6abd0abbf76b4dfa71aa296b4f0b (patch) | |
tree | 705efd2edac21ab651e7badcaa452735a2c31552 /sql | |
parent | d39087147ff1052b623cdba69ffbde28b266745f (diff) | |
download | spark-dee801adb78d6abd0abbf76b4dfa71aa296b4f0b.tar.gz spark-dee801adb78d6abd0abbf76b4dfa71aa296b4f0b.tar.bz2 spark-dee801adb78d6abd0abbf76b4dfa71aa296b4f0b.zip |
[SPARK-12828][SQL] Natural join follow-up
This is a small addendum to #10762 to make the code more robust against future changes.
Author: Reynold Xin <rxin@databricks.com>
Closes #11070 from rxin/SPARK-12828-natural-join.
Diffstat (limited to 'sql')
3 files changed, 17 insertions, 12 deletions
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala index b30ed5928f..b59eb12419 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala @@ -1239,21 +1239,23 @@ class Analyzer( */ object ResolveNaturalJoin extends Rule[LogicalPlan] { override def apply(plan: LogicalPlan): LogicalPlan = plan resolveOperators { - // Should not skip unresolved nodes because natural join is always unresolved. case j @ Join(left, right, NaturalJoin(joinType), condition) if j.resolvedExceptNatural => - // find common column names from both sides, should be treated like usingColumns + // find common column names from both sides val joinNames = left.output.map(_.name).intersect(right.output.map(_.name)) val leftKeys = joinNames.map(keyName => left.output.find(_.name == keyName).get) val rightKeys = joinNames.map(keyName => right.output.find(_.name == keyName).get) val joinPairs = leftKeys.zip(rightKeys) + // Add joinPairs to joinConditions val newCondition = (condition ++ joinPairs.map { case (l, r) => EqualTo(l, r) - }).reduceLeftOption(And) + }).reduceOption(And) + // columns not in joinPairs val lUniqueOutput = left.output.filterNot(att => leftKeys.contains(att)) val rUniqueOutput = right.output.filterNot(att => rightKeys.contains(att)) - // we should only keep unique columns(depends on joinType) for joinCols + + // the output list looks like: join keys, columns from left, columns from right val projectList = joinType match { case LeftOuter => leftKeys ++ lUniqueOutput ++ rUniqueOutput.map(_.withNullability(true)) @@ -1261,13 +1263,14 @@ class Analyzer( rightKeys ++ lUniqueOutput.map(_.withNullability(true)) ++ rUniqueOutput case FullOuter => // in full outer join, joinCols should be non-null if there is. 
- val joinedCols = joinPairs.map { - case (l, r) => Alias(Coalesce(Seq(l, r)), l.name)() - } - joinedCols ++ lUniqueOutput.map(_.withNullability(true)) ++ + val joinedCols = joinPairs.map { case (l, r) => Alias(Coalesce(Seq(l, r)), l.name)() } + joinedCols ++ + lUniqueOutput.map(_.withNullability(true)) ++ rUniqueOutput.map(_.withNullability(true)) - case _ => + case Inner => rightKeys ++ lUniqueOutput ++ rUniqueOutput + case _ => + sys.error("Unsupported natural join type " + joinType) } // use Project to trim unnecessary fields Project(projectList, Join(left, right, joinType, newCondition)) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/joinTypes.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/joinTypes.scala index b10f1e63a7..27a75326eb 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/joinTypes.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/joinTypes.scala @@ -62,5 +62,7 @@ case object LeftSemi extends JoinType { } case class NaturalJoin(tpe: JoinType) extends JoinType { + require(Seq(Inner, LeftOuter, RightOuter, FullOuter).contains(tpe), + "Unsupported natural join type " + tpe) override def sql: String = "NATURAL " + tpe.sql } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/ResolveNaturalJoinSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/ResolveNaturalJoinSuite.scala index a6554fbc41..fcf4ac1967 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/ResolveNaturalJoinSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/ResolveNaturalJoinSuite.scala @@ -30,10 +30,10 @@ class ResolveNaturalJoinSuite extends AnalysisTest { lazy val aNotNull = a.notNull lazy val bNotNull = b.notNull lazy val cNotNull = c.notNull - lazy val r1 = LocalRelation(a, b) - lazy val r2 = LocalRelation(a, c) + lazy val r1 = LocalRelation(b, a) + lazy val r2 = 
LocalRelation(c, a) lazy val r3 = LocalRelation(aNotNull, bNotNull) - lazy val r4 = LocalRelation(bNotNull, cNotNull) + lazy val r4 = LocalRelation(cNotNull, bNotNull) test("natural inner join") { val plan = r1.join(r2, NaturalJoin(Inner), None) |