author     Daoyuan Wang <daoyuan.wang@intel.com>    2016-02-03 21:05:53 -0800
committer  Reynold Xin <rxin@databricks.com>        2016-02-03 21:05:53 -0800
commit     0f81318ae217346c20894572795e1a9cee2ebc8f (patch)
tree       f364e5944f879c58fa67760239b4e927e85ae733 /sql
parent     a64831124c215f56f124747fa241560c70cf0a36 (diff)
[SPARK-12828][SQL] add natural join support
Jira: https://issues.apache.org/jira/browse/SPARK-12828

Author: Daoyuan Wang <daoyuan.wang@intel.com>

Closes #10762 from adrian-wang/naturaljoin.
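For context, a minimal sketch of what this patch enables (table and column names mirror the SQLQuerySuite test added below; they are illustrative, not part of the patch):

    // NATURAL JOIN equi-joins on every column name the two sides share and
    // keeps a single copy of each shared column in the output.
    val df1 = Seq(("one", 1), ("two", 2)).toDF("k", "v1")
    val df2 = Seq(("one", 5), ("two", 22)).toDF("k", "v2")
    df1.registerTempTable("nt1")
    df2.registerTempTable("nt2")
    // Roughly equivalent to: SELECT nt1.k, v1, v2 FROM nt1 JOIN nt2 ON nt1.k = nt2.k
    sql("SELECT * FROM nt1 NATURAL JOIN nt2").show()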
Diffstat (limited to 'sql')
-rw-r--r--  sql/catalyst/src/main/antlr3/org/apache/spark/sql/catalyst/parser/FromClauseParser.g             | 23
-rw-r--r--  sql/catalyst/src/main/antlr3/org/apache/spark/sql/catalyst/parser/SparkSqlLexer.g                |  2
-rw-r--r--  sql/catalyst/src/main/antlr3/org/apache/spark/sql/catalyst/parser/SparkSqlParser.g               |  4
-rw-r--r--  sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/CatalystQl.scala                       |  4
-rw-r--r--  sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala                | 43
-rw-r--r--  sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala              |  4
-rw-r--r--  sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/joinTypes.scala                  |  4
-rw-r--r--  sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala     | 10
-rw-r--r--  sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/ResolveNaturalJoinSuite.scala | 90
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala                                     |  1
-rw-r--r--  sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala                                 | 24
11 files changed, 198 insertions(+), 11 deletions(-)
diff --git a/sql/catalyst/src/main/antlr3/org/apache/spark/sql/catalyst/parser/FromClauseParser.g b/sql/catalyst/src/main/antlr3/org/apache/spark/sql/catalyst/parser/FromClauseParser.g
index 6d76afcd4a..e83f8a7cd1 100644
--- a/sql/catalyst/src/main/antlr3/org/apache/spark/sql/catalyst/parser/FromClauseParser.g
+++ b/sql/catalyst/src/main/antlr3/org/apache/spark/sql/catalyst/parser/FromClauseParser.g
@@ -117,15 +117,20 @@ joinToken
@init { gParent.pushMsg("join type specifier", state); }
@after { gParent.popMsg(state); }
:
- KW_JOIN -> TOK_JOIN
- | KW_INNER KW_JOIN -> TOK_JOIN
- | COMMA -> TOK_JOIN
- | KW_CROSS KW_JOIN -> TOK_CROSSJOIN
- | KW_LEFT (KW_OUTER)? KW_JOIN -> TOK_LEFTOUTERJOIN
- | KW_RIGHT (KW_OUTER)? KW_JOIN -> TOK_RIGHTOUTERJOIN
- | KW_FULL (KW_OUTER)? KW_JOIN -> TOK_FULLOUTERJOIN
- | KW_LEFT KW_SEMI KW_JOIN -> TOK_LEFTSEMIJOIN
- | KW_ANTI KW_JOIN -> TOK_ANTIJOIN
+ KW_JOIN -> TOK_JOIN
+ | KW_INNER KW_JOIN -> TOK_JOIN
+ | KW_NATURAL KW_JOIN -> TOK_NATURALJOIN
+ | KW_NATURAL KW_INNER KW_JOIN -> TOK_NATURALJOIN
+ | COMMA -> TOK_JOIN
+ | KW_CROSS KW_JOIN -> TOK_CROSSJOIN
+ | KW_LEFT (KW_OUTER)? KW_JOIN -> TOK_LEFTOUTERJOIN
+ | KW_RIGHT (KW_OUTER)? KW_JOIN -> TOK_RIGHTOUTERJOIN
+ | KW_FULL (KW_OUTER)? KW_JOIN -> TOK_FULLOUTERJOIN
+ | KW_NATURAL KW_LEFT (KW_OUTER)? KW_JOIN -> TOK_NATURALLEFTOUTERJOIN
+ | KW_NATURAL KW_RIGHT (KW_OUTER)? KW_JOIN -> TOK_NATURALRIGHTOUTERJOIN
+ | KW_NATURAL KW_FULL (KW_OUTER)? KW_JOIN -> TOK_NATURALFULLOUTERJOIN
+ | KW_LEFT KW_SEMI KW_JOIN -> TOK_LEFTSEMIJOIN
+ | KW_ANTI KW_JOIN -> TOK_ANTIJOIN
;
lateralView
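With the new alternatives in place, each of the following now parses (a sketch of accepted forms, using hypothetical tables t1 and t2; note that NATURAL deliberately combines only with inner and outer joins, not with CROSS, SEMI, or ANTI):

    sql("SELECT * FROM t1 NATURAL JOIN t2")             // TOK_NATURALJOIN
    sql("SELECT * FROM t1 NATURAL INNER JOIN t2")       // TOK_NATURALJOIN
    sql("SELECT * FROM t1 NATURAL LEFT OUTER JOIN t2")  // TOK_NATURALLEFTOUTERJOIN
    sql("SELECT * FROM t1 NATURAL RIGHT JOIN t2")       // TOK_NATURALRIGHTOUTERJOIN
    sql("SELECT * FROM t1 NATURAL FULL JOIN t2")        // TOK_NATURALFULLOUTERJOIN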
diff --git a/sql/catalyst/src/main/antlr3/org/apache/spark/sql/catalyst/parser/SparkSqlLexer.g b/sql/catalyst/src/main/antlr3/org/apache/spark/sql/catalyst/parser/SparkSqlLexer.g
index 1d07a27353..fd1ad59207 100644
--- a/sql/catalyst/src/main/antlr3/org/apache/spark/sql/catalyst/parser/SparkSqlLexer.g
+++ b/sql/catalyst/src/main/antlr3/org/apache/spark/sql/catalyst/parser/SparkSqlLexer.g
@@ -335,6 +335,8 @@ KW_CACHE: 'CACHE';
KW_UNCACHE: 'UNCACHE';
KW_DFS: 'DFS';
+KW_NATURAL: 'NATURAL';
+
// Operators
// NOTE: if you add a new function/operator, add it to sysFuncNames so that describe function _FUNC_ will work.
diff --git a/sql/catalyst/src/main/antlr3/org/apache/spark/sql/catalyst/parser/SparkSqlParser.g b/sql/catalyst/src/main/antlr3/org/apache/spark/sql/catalyst/parser/SparkSqlParser.g
index 6591f6b0f5..9935678ca2 100644
--- a/sql/catalyst/src/main/antlr3/org/apache/spark/sql/catalyst/parser/SparkSqlParser.g
+++ b/sql/catalyst/src/main/antlr3/org/apache/spark/sql/catalyst/parser/SparkSqlParser.g
@@ -96,6 +96,10 @@ TOK_RIGHTOUTERJOIN;
TOK_FULLOUTERJOIN;
TOK_UNIQUEJOIN;
TOK_CROSSJOIN;
+TOK_NATURALJOIN;
+TOK_NATURALLEFTOUTERJOIN;
+TOK_NATURALRIGHTOUTERJOIN;
+TOK_NATURALFULLOUTERJOIN;
TOK_LOAD;
TOK_EXPORT;
TOK_IMPORT;
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/CatalystQl.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/CatalystQl.scala
index 7ce2407913..a42360d562 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/CatalystQl.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/CatalystQl.scala
@@ -520,6 +520,10 @@ https://cwiki.apache.org/confluence/display/Hive/Enhanced+Aggregation%2C+Cube%2C
case "TOK_LEFTSEMIJOIN" => LeftSemi
case "TOK_UNIQUEJOIN" => noParseRule("Unique Join", node)
case "TOK_ANTIJOIN" => noParseRule("Anti Join", node)
+ case "TOK_NATURALJOIN" => NaturalJoin(Inner)
+ case "TOK_NATURALRIGHTOUTERJOIN" => NaturalJoin(RightOuter)
+ case "TOK_NATURALLEFTOUTERJOIN" => NaturalJoin(LeftOuter)
+ case "TOK_NATURALFULLOUTERJOIN" => NaturalJoin(FullOuter)
}
Join(nodeToRelation(relation1),
nodeToRelation(relation2),
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
index a983dc1cdf..b30ed5928f 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
@@ -24,6 +24,7 @@ import org.apache.spark.sql.AnalysisException
import org.apache.spark.sql.catalyst.{CatalystConf, ScalaReflection, SimpleCatalystConf}
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.expressions.aggregate._
+import org.apache.spark.sql.catalyst.plans._
import org.apache.spark.sql.catalyst.plans.logical._
import org.apache.spark.sql.catalyst.rules._
import org.apache.spark.sql.catalyst.trees.TreeNodeRef
@@ -81,6 +82,7 @@ class Analyzer(
ResolveAliases ::
ResolveWindowOrder ::
ResolveWindowFrame ::
+ ResolveNaturalJoin ::
ExtractWindowExpressions ::
GlobalAggregates ::
ResolveAggregateFunctions ::
@@ -1230,6 +1232,47 @@ class Analyzer(
}
}
}
+
+ /**
+ * Removes natural joins by computing the output columns from the outputs of the two sides,
+ * then applying a Project on top of a normal Join to eliminate the natural join.
+ */
+ object ResolveNaturalJoin extends Rule[LogicalPlan] {
+ override def apply(plan: LogicalPlan): LogicalPlan = plan resolveOperators {
+ // Should not skip unresolved nodes because natural join is always unresolved.
+ case j @ Join(left, right, NaturalJoin(joinType), condition) if j.resolvedExceptNatural =>
+ // Find the column names common to both sides; they are treated like usingColumns
+ val joinNames = left.output.map(_.name).intersect(right.output.map(_.name))
+ val leftKeys = joinNames.map(keyName => left.output.find(_.name == keyName).get)
+ val rightKeys = joinNames.map(keyName => right.output.find(_.name == keyName).get)
+ val joinPairs = leftKeys.zip(rightKeys)
+ // Add joinPairs to joinConditions
+ val newCondition = (condition ++ joinPairs.map {
+ case (l, r) => EqualTo(l, r)
+ }).reduceLeftOption(And)
+ // columns not in joinPairs
+ val lUniqueOutput = left.output.filterNot(att => leftKeys.contains(att))
+ val rUniqueOutput = right.output.filterNot(att => rightKeys.contains(att))
+ // We should keep only one copy of each join column; which side's copy is kept depends on joinType
+ val projectList = joinType match {
+ case LeftOuter =>
+ leftKeys ++ lUniqueOutput ++ rUniqueOutput.map(_.withNullability(true))
+ case RightOuter =>
+ rightKeys ++ lUniqueOutput.map(_.withNullability(true)) ++ rUniqueOutput
+ case FullOuter =>
+ // In a full outer join, a join column is non-null if either side is, so coalesce the pair
+ val joinedCols = joinPairs.map {
+ case (l, r) => Alias(Coalesce(Seq(l, r)), l.name)()
+ }
+ joinedCols ++ lUniqueOutput.map(_.withNullability(true)) ++
+ rUniqueOutput.map(_.withNullability(true))
+ case _ =>
+ rightKeys ++ lUniqueOutput ++ rUniqueOutput
+ }
+ // use Project to trim unnecessary fields
+ Project(projectList, Join(left, right, joinType, newCondition))
+ }
+ }
}
/**
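Concretely, the rule rewrites a natural join into a plain join plus a projection. For relations r1(a, b) and r2(a, c), the rewrite looks like this (taken from the ResolveNaturalJoinSuite added below):

    // before: unresolved natural join
    r1.join(r2, NaturalJoin(Inner), None)
    // after: equi-join on the shared column; the Project drops the duplicate `a`
    r1.join(r2, Inner, Some(EqualTo(a, a))).select(a, b, c)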
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala
index f156b5d10a..4ecee75048 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala
@@ -24,7 +24,7 @@ import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.expressions.Literal.{FalseLiteral, TrueLiteral}
import org.apache.spark.sql.catalyst.expressions.aggregate._
import org.apache.spark.sql.catalyst.planning.{ExtractFiltersAndInnerJoins, Unions}
-import org.apache.spark.sql.catalyst.plans.{FullOuter, Inner, LeftOuter, LeftSemi, RightOuter}
+import org.apache.spark.sql.catalyst.plans._
import org.apache.spark.sql.catalyst.plans.logical._
import org.apache.spark.sql.catalyst.rules._
import org.apache.spark.sql.types._
@@ -905,6 +905,7 @@ object PushPredicateThroughJoin extends Rule[LogicalPlan] with PredicateHelper {
(rightFilterConditions ++ commonFilterCondition).
reduceLeftOption(And).map(Filter(_, newJoin)).getOrElse(newJoin)
case FullOuter => f // DO Nothing for Full Outer Join
+ case NaturalJoin(_) => sys.error("Untransformed NaturalJoin node")
}
// push down the join filter into sub query scanning if applicable
@@ -939,6 +940,7 @@ object PushPredicateThroughJoin extends Rule[LogicalPlan] with PredicateHelper {
Join(newLeft, newRight, LeftOuter, newJoinCond)
case FullOuter => f
+ case NaturalJoin(_) => sys.error("Untransformed NaturalJoin node")
}
}
}
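These new cases exist only to keep the match on join types exhaustive: by the time the optimizer runs, ResolveNaturalJoin has rewritten every NaturalJoin node, so reaching them indicates an analyzer bug. A sketch of the invariant being relied on (analyzedPlan is a hypothetical name for a post-analysis plan):

    // Invariant after analysis: no NaturalJoin survives in the plan.
    assert(analyzedPlan.collect { case j @ Join(_, _, NaturalJoin(_), _) => j }.isEmpty)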
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/joinTypes.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/joinTypes.scala
index a5f6764aef..b10f1e63a7 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/joinTypes.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/joinTypes.scala
@@ -60,3 +60,7 @@ case object FullOuter extends JoinType {
case object LeftSemi extends JoinType {
override def sql: String = "LEFT SEMI"
}
+
+case class NaturalJoin(tpe: JoinType) extends JoinType {
+ override def sql: String = "NATURAL " + tpe.sql
+}
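Because NaturalJoin wraps the underlying join type, its SQL rendering simply composes with the wrapped type's, e.g. (a sketch):

    NaturalJoin(Inner).sql      // "NATURAL INNER"
    NaturalJoin(FullOuter).sql  // "NATURAL FULL OUTER"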
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala
index 8150ff8434..03a79520cb 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala
@@ -250,12 +250,20 @@ case class Join(
def duplicateResolved: Boolean = left.outputSet.intersect(right.outputSet).isEmpty
// Joins are only resolved if they don't introduce ambiguous expression ids.
- override lazy val resolved: Boolean = {
+ // A NaturalJoin is ready for the analyzer to resolve only once everything else here is resolved
+ lazy val resolvedExceptNatural: Boolean = {
childrenResolved &&
expressions.forall(_.resolved) &&
duplicateResolved &&
condition.forall(_.dataType == BooleanType)
}
+
+ // If this is not a natural join, `resolvedExceptNatural` is sufficient. If it is a natural
+ // join, we still need to eliminate the natural join before marking the node resolved.
+ override lazy val resolved: Boolean = joinType match {
+ case NaturalJoin(_) => false
+ case _ => resolvedExceptNatural
+ }
}
/**
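The net effect of the override is that a Join carrying a NaturalJoin type can never report itself resolved, which guarantees the analyzer's ResolveNaturalJoin rule fires and replaces it. A sketch of the two flags, assuming a Join node built as in the test suite:

    val j = Join(r1, r2, NaturalJoin(Inner), None)
    j.resolved               // always false while the join type is NaturalJoin
    j.resolvedExceptNatural  // true once children and expressions are resolved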
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/ResolveNaturalJoinSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/ResolveNaturalJoinSuite.scala
new file mode 100644
index 0000000000..a6554fbc41
--- /dev/null
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/ResolveNaturalJoinSuite.scala
@@ -0,0 +1,90 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.catalyst.analysis
+
+import org.apache.spark.sql.catalyst.dsl.expressions._
+import org.apache.spark.sql.catalyst.dsl.plans._
+import org.apache.spark.sql.catalyst.expressions._
+import org.apache.spark.sql.catalyst.plans._
+import org.apache.spark.sql.catalyst.plans.logical.LocalRelation
+
+class ResolveNaturalJoinSuite extends AnalysisTest {
+ lazy val a = 'a.string
+ lazy val b = 'b.string
+ lazy val c = 'c.string
+ lazy val aNotNull = a.notNull
+ lazy val bNotNull = b.notNull
+ lazy val cNotNull = c.notNull
+ lazy val r1 = LocalRelation(a, b)
+ lazy val r2 = LocalRelation(a, c)
+ lazy val r3 = LocalRelation(aNotNull, bNotNull)
+ lazy val r4 = LocalRelation(bNotNull, cNotNull)
+
+ test("natural inner join") {
+ val plan = r1.join(r2, NaturalJoin(Inner), None)
+ val expected = r1.join(r2, Inner, Some(EqualTo(a, a))).select(a, b, c)
+ checkAnalysis(plan, expected)
+ }
+
+ test("natural left join") {
+ val plan = r1.join(r2, NaturalJoin(LeftOuter), None)
+ val expected = r1.join(r2, LeftOuter, Some(EqualTo(a, a))).select(a, b, c)
+ checkAnalysis(plan, expected)
+ }
+
+ test("natural right join") {
+ val plan = r1.join(r2, NaturalJoin(RightOuter), None)
+ val expected = r1.join(r2, RightOuter, Some(EqualTo(a, a))).select(a, b, c)
+ checkAnalysis(plan, expected)
+ }
+
+ test("natural full outer join") {
+ val plan = r1.join(r2, NaturalJoin(FullOuter), None)
+ val expected = r1.join(r2, FullOuter, Some(EqualTo(a, a))).select(
+ Alias(Coalesce(Seq(a, a)), "a")(), b, c)
+ checkAnalysis(plan, expected)
+ }
+
+ test("natural inner join with no nullability") {
+ val plan = r3.join(r4, NaturalJoin(Inner), None)
+ val expected = r3.join(r4, Inner, Some(EqualTo(bNotNull, bNotNull))).select(
+ bNotNull, aNotNull, cNotNull)
+ checkAnalysis(plan, expected)
+ }
+
+ test("natural left join with no nullability") {
+ val plan = r3.join(r4, NaturalJoin(LeftOuter), None)
+ val expected = r3.join(r4, LeftOuter, Some(EqualTo(bNotNull, bNotNull))).select(
+ bNotNull, aNotNull, c)
+ checkAnalysis(plan, expected)
+ }
+
+ test("natural right join with no nullability") {
+ val plan = r3.join(r4, NaturalJoin(RightOuter), None)
+ val expected = r3.join(r4, RightOuter, Some(EqualTo(bNotNull, bNotNull))).select(
+ bNotNull, a, cNotNull)
+ checkAnalysis(plan, expected)
+ }
+
+ test("natural full outer join with no nullability") {
+ val plan = r3.join(r4, NaturalJoin(FullOuter), None)
+ val expected = r3.join(r4, FullOuter, Some(EqualTo(bNotNull, bNotNull))).select(
+ Alias(Coalesce(Seq(bNotNull, bNotNull)), "b")(), a, c)
+ checkAnalysis(plan, expected)
+ }
+}
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala b/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala
index 84203bbfef..f15b926bd2 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala
@@ -474,6 +474,7 @@ class DataFrame private[sql](
val rightCol = withPlan(joined.right).resolve(col).toAttribute.withNullability(true)
Alias(Coalesce(Seq(leftCol, rightCol)), col)()
}
+ case NaturalJoin(_) => sys.error("NaturalJoin with using clause is not supported.")
}
// The nullability of output of joined could be different than original column,
// so we can only compare them by exprId
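The new case here is defensive: the DataFrame using-column join API never constructs a NaturalJoin itself, so in this patch natural joins are reachable only through SQL, e.g. (a sketch; df1, df2, nt1, and nt2 are illustrative names):

    df1.join(df2, Seq("k"), "fullouter")                  // using-column join, handled above
    sql("SELECT * FROM nt1 NATURAL FULL OUTER JOIN nt2")  // natural join, rewritten by the analyzer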
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala
index 79bfd4b44b..8ef7b61314 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala
@@ -2075,4 +2075,28 @@ class SQLQuerySuite extends QueryTest with SharedSQLContext {
)
}
}
+
+ test("natural join") {
+ val df1 = Seq(("one", 1), ("two", 2), ("three", 3)).toDF("k", "v1")
+ val df2 = Seq(("one", 1), ("two", 22), ("one", 5)).toDF("k", "v2")
+ withTempTable("nt1", "nt2") {
+ df1.registerTempTable("nt1")
+ df2.registerTempTable("nt2")
+ checkAnswer(
+ sql("SELECT * FROM nt1 natural join nt2 where k = \"one\""),
+ Row("one", 1, 1) :: Row("one", 1, 5) :: Nil)
+
+ checkAnswer(
+ sql("SELECT * FROM nt1 natural left join nt2 order by v1, v2"),
+ Row("one", 1, 1) :: Row("one", 1, 5) :: Row("two", 2, 22) :: Row("three", 3, null) :: Nil)
+
+ checkAnswer(
+ sql("SELECT * FROM nt1 natural right join nt2 order by v1, v2"),
+ Row("one", 1, 1) :: Row("one", 1, 5) :: Row("two", 2, 22) :: Nil)
+
+ checkAnswer(
+ sql("SELECT count(*) FROM nt1 natural full outer join nt2"),
+ Row(4) :: Nil)
+ }
+ }
}