about summary refs log tree commit diff
diff options
context:
space:
mode:
authorReynold Xin <rxin@databricks.com>2016-02-03 23:43:48 -0800
committerReynold Xin <rxin@databricks.com>2016-02-03 23:43:48 -0800
commitdee801adb78d6abd0abbf76b4dfa71aa296b4f0b (patch)
tree705efd2edac21ab651e7badcaa452735a2c31552
parentd39087147ff1052b623cdba69ffbde28b266745f (diff)
downloadspark-dee801adb78d6abd0abbf76b4dfa71aa296b4f0b.tar.gz
spark-dee801adb78d6abd0abbf76b4dfa71aa296b4f0b.tar.bz2
spark-dee801adb78d6abd0abbf76b4dfa71aa296b4f0b.zip
[SPARK-12828][SQL] Natural join follow-up
This is a small addendum to #10762 to make the code more robust against future changes. Author: Reynold Xin <rxin@databricks.com> Closes #11070 from rxin/SPARK-12828-natural-join.
-rw-r--r--sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala21
-rw-r--r--sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/joinTypes.scala2
-rw-r--r--sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/ResolveNaturalJoinSuite.scala6
3 files changed, 17 insertions, 12 deletions
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
index b30ed5928f..b59eb12419 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
@@ -1239,21 +1239,23 @@ class Analyzer(
*/
object ResolveNaturalJoin extends Rule[LogicalPlan] {
override def apply(plan: LogicalPlan): LogicalPlan = plan resolveOperators {
- // Should not skip unresolved nodes because natural join is always unresolved.
case j @ Join(left, right, NaturalJoin(joinType), condition) if j.resolvedExceptNatural =>
- // find common column names from both sides, should be treated like usingColumns
+ // find common column names from both sides
val joinNames = left.output.map(_.name).intersect(right.output.map(_.name))
val leftKeys = joinNames.map(keyName => left.output.find(_.name == keyName).get)
val rightKeys = joinNames.map(keyName => right.output.find(_.name == keyName).get)
val joinPairs = leftKeys.zip(rightKeys)
+
// Add joinPairs to joinConditions
val newCondition = (condition ++ joinPairs.map {
case (l, r) => EqualTo(l, r)
- }).reduceLeftOption(And)
+ }).reduceOption(And)
+
// columns not in joinPairs
val lUniqueOutput = left.output.filterNot(att => leftKeys.contains(att))
val rUniqueOutput = right.output.filterNot(att => rightKeys.contains(att))
- // we should only keep unique columns(depends on joinType) for joinCols
+
+ // the output list looks like: join keys, columns from left, columns from right
val projectList = joinType match {
case LeftOuter =>
leftKeys ++ lUniqueOutput ++ rUniqueOutput.map(_.withNullability(true))
@@ -1261,13 +1263,14 @@ class Analyzer(
rightKeys ++ lUniqueOutput.map(_.withNullability(true)) ++ rUniqueOutput
case FullOuter =>
// in full outer join, joinCols should be non-null if there is.
- val joinedCols = joinPairs.map {
- case (l, r) => Alias(Coalesce(Seq(l, r)), l.name)()
- }
- joinedCols ++ lUniqueOutput.map(_.withNullability(true)) ++
+ val joinedCols = joinPairs.map { case (l, r) => Alias(Coalesce(Seq(l, r)), l.name)() }
+ joinedCols ++
+ lUniqueOutput.map(_.withNullability(true)) ++
rUniqueOutput.map(_.withNullability(true))
- case _ =>
+ case Inner =>
rightKeys ++ lUniqueOutput ++ rUniqueOutput
+ case _ =>
+ sys.error("Unsupported natural join type " + joinType)
}
// use Project to trim unnecessary fields
Project(projectList, Join(left, right, joinType, newCondition))
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/joinTypes.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/joinTypes.scala
index b10f1e63a7..27a75326eb 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/joinTypes.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/joinTypes.scala
@@ -62,5 +62,7 @@ case object LeftSemi extends JoinType {
}
case class NaturalJoin(tpe: JoinType) extends JoinType {
+ require(Seq(Inner, LeftOuter, RightOuter, FullOuter).contains(tpe),
+ "Unsupported natural join type " + tpe)
override def sql: String = "NATURAL " + tpe.sql
}
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/ResolveNaturalJoinSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/ResolveNaturalJoinSuite.scala
index a6554fbc41..fcf4ac1967 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/ResolveNaturalJoinSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/ResolveNaturalJoinSuite.scala
@@ -30,10 +30,10 @@ class ResolveNaturalJoinSuite extends AnalysisTest {
lazy val aNotNull = a.notNull
lazy val bNotNull = b.notNull
lazy val cNotNull = c.notNull
- lazy val r1 = LocalRelation(a, b)
- lazy val r2 = LocalRelation(a, c)
+ lazy val r1 = LocalRelation(b, a)
+ lazy val r2 = LocalRelation(c, a)
lazy val r3 = LocalRelation(aNotNull, bNotNull)
- lazy val r4 = LocalRelation(bNotNull, cNotNull)
+ lazy val r4 = LocalRelation(cNotNull, bNotNull)
test("natural inner join") {
val plan = r1.join(r2, NaturalJoin(Inner), None)