author     Herman van Hovell <hvanhovell@databricks.com>  2016-09-07 00:44:07 +0200
committer  Herman van Hovell <hvanhovell@databricks.com>  2016-09-07 00:44:07 +0200
commit     4f769b903bc9822c262f0a15f5933cc05c67923f (patch)
tree       89e4e98fc53f256e1f8064e05041e4e7e7c402ec /sql/catalyst
parent     29cfab3f1524c5690be675d24dda0a9a1806d6ff (diff)
[SPARK-17296][SQL] Simplify parser join processing.
## What changes were proposed in this pull request?

Join processing in the parser relies on the grammar producing right-nested trees; for instance, the parse tree for `select * from a join b join c` is expected to look like `JOIN(a, JOIN(b, c))`. However, there are cases in which this invariant is violated, for example:

```sql
SELECT COUNT(1)
FROM test T1
     CROSS JOIN test T2
     JOIN test T3 ON T3.col = T1.col
     JOIN test T4 ON T4.col = T1.col
```

In this case the parser returns a tree in which joins appear on both the left and the right side of the parent join node. This PR introduces a different grammar rule which does not make that assumption: the new rule takes a relation and then matches zero or more joined relations. As a bonus, processing becomes much simpler.

## How was this patch tested?

Existing tests, plus a regression test added to the plan parser suite.

Author: Herman van Hovell <hvanhovell@databricks.com>

Closes #14867 from hvanhovell/SPARK-17296.
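The idea behind the new grammar rule is easiest to see in isolation. Below is a minimal, self-contained Scala sketch (simplified stand-in types, not the actual `AstBuilder` code) of what `relationPrimary joinRelation*` buys us: the parser yields one base relation plus a flat list of joined relations, and the visitor folds that list into a left-deep join tree:

```scala
// Minimal sketch with simplified stand-ins for Catalyst plans and parser contexts.
object JoinFoldSketch extends App {
  sealed trait Plan
  case class Table(name: String) extends Plan
  case class Join(left: Plan, right: Plan, condition: Option[String]) extends Plan

  // One entry per `joinRelation` matched after the base `relationPrimary`.
  case class JoinedRelation(right: Plan, condition: Option[String])

  // Fold the flat list into a left-deep tree: a JOIN b JOIN c => Join(Join(a, b), c).
  def buildJoinTree(base: Plan, joins: Seq[JoinedRelation]): Plan =
    joins.foldLeft(base) { (left, j) => Join(left, j.right, j.condition) }

  // The query from the description above.
  val plan = buildJoinTree(
    Table("T1"),
    Seq(
      JoinedRelation(Table("T2"), None),                     // CROSS JOIN test T2
      JoinedRelation(Table("T3"), Some("T3.col = T1.col")),  // JOIN test T3 ON ...
      JoinedRelation(Table("T4"), Some("T4.col = T1.col")))) // JOIN test T4 ON ...

  println(plan)
  // Join(Join(Join(Table(T1),Table(T2),None),Table(T3),Some(T3.col = T1.col)),
  //      Table(T4),Some(T4.col = T1.col))
}
```

Because every step joins onto the left-hand accumulator, columns of previously referenced tables are naturally in scope for later `ON` clauses, without the tree reversal the old visitor needed.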
Diffstat (limited to 'sql/catalyst')
-rw-r--r--  sql/catalyst/src/main/antlr4/org/apache/spark/sql/catalyst/parser/SqlBase.g4            |  11
-rw-r--r--  sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala       |  99
-rw-r--r--  sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/ParserUtils.scala      |   6
-rw-r--r--  sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/PlanParserSuite.scala  |  44
4 files changed, 102 insertions, 58 deletions
diff --git a/sql/catalyst/src/main/antlr4/org/apache/spark/sql/catalyst/parser/SqlBase.g4 b/sql/catalyst/src/main/antlr4/org/apache/spark/sql/catalyst/parser/SqlBase.g4
index 0447436ea7..9a643465a9 100644
--- a/sql/catalyst/src/main/antlr4/org/apache/spark/sql/catalyst/parser/SqlBase.g4
+++ b/sql/catalyst/src/main/antlr4/org/apache/spark/sql/catalyst/parser/SqlBase.g4
@@ -374,11 +374,12 @@ setQuantifier
;
relation
- : left=relation
- (joinType JOIN right=relation joinCriteria?
- | NATURAL joinType JOIN right=relation
- ) #joinRelation
- | relationPrimary #relationDefault
+ : relationPrimary joinRelation*
+ ;
+
+joinRelation
+ : (joinType) JOIN right=relationPrimary joinCriteria?
+ | NATURAL joinType JOIN right=relationPrimary
;
joinType
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala
index e4cb9f0161..bbbb14df88 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala
@@ -92,10 +92,9 @@ class AstBuilder extends SqlBaseBaseVisitor[AnyRef] with Logging {
// Apply CTEs
query.optional(ctx.ctes) {
- val ctes = ctx.ctes.namedQuery.asScala.map {
- case nCtx =>
- val namedQuery = visitNamedQuery(nCtx)
- (namedQuery.alias, namedQuery)
+ val ctes = ctx.ctes.namedQuery.asScala.map { nCtx =>
+ val namedQuery = visitNamedQuery(nCtx)
+ (namedQuery.alias, namedQuery)
}
// Check for duplicate names.
checkDuplicateKeys(ctes, ctx)
@@ -401,7 +400,11 @@ class AstBuilder extends SqlBaseBaseVisitor[AnyRef] with Logging {
* separated) relations here, these get converted into a single plan by condition-less inner join.
*/
override def visitFromClause(ctx: FromClauseContext): LogicalPlan = withOrigin(ctx) {
- val from = ctx.relation.asScala.map(plan).reduceLeft(Join(_, _, Inner, None))
+ val from = ctx.relation.asScala.foldLeft(null: LogicalPlan) { (left, relation) =>
+ val right = plan(relation.relationPrimary)
+ val join = right.optionalMap(left)(Join(_, _, Inner, None))
+ withJoinRelations(join, relation)
+ }
ctx.lateralView.asScala.foldLeft(from)(withGenerate)
}
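As an aside on the fold above: the accumulator is seeded with `null`, and `optionalMap` only inserts a condition-less inner join once there is a previous plan to join to. A small standalone sketch of that pattern (simplified types, not the real `ParserUtils.optionalMap`):

```scala
// Sketch of the null-seeded fold in visitFromClause (simplified stand-in types).
object FromClauseFoldSketch extends App {
  sealed trait Plan
  case class Table(name: String) extends Plan
  case class InnerJoin(left: Plan, right: Plan) extends Plan

  // Mirrors the role of optionalMap here: only wrap when a previous plan exists.
  def joinToLeft(right: Plan, left: Plan): Plan =
    if (left != null) InnerJoin(left, right) else right

  // FROM a, b, c  =>  InnerJoin(InnerJoin(a, b), c)
  val from = Seq(Table("a"), Table("b"), Table("c"))
    .foldLeft(null: Plan)((left, right) => joinToLeft(right, left))

  println(from) // InnerJoin(InnerJoin(Table(a),Table(b)),Table(c))
}
```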
@@ -532,55 +535,53 @@ class AstBuilder extends SqlBaseBaseVisitor[AnyRef] with Logging {
}
/**
- * Create a joins between two or more logical plans.
+ * Create a single relation referenced in a FROM clause. This method is used when a part of the
+ * join condition is nested, for example:
+ * {{{
+ * select * from t1 join (t2 cross join t3) on col1 = col2
+ * }}}
*/
- override def visitJoinRelation(ctx: JoinRelationContext): LogicalPlan = withOrigin(ctx) {
- /** Build a join between two plans. */
- def join(ctx: JoinRelationContext, left: LogicalPlan, right: LogicalPlan): Join = {
- val baseJoinType = ctx.joinType match {
- case null => Inner
- case jt if jt.CROSS != null => Cross
- case jt if jt.FULL != null => FullOuter
- case jt if jt.SEMI != null => LeftSemi
- case jt if jt.ANTI != null => LeftAnti
- case jt if jt.LEFT != null => LeftOuter
- case jt if jt.RIGHT != null => RightOuter
- case _ => Inner
- }
+ override def visitRelation(ctx: RelationContext): LogicalPlan = withOrigin(ctx) {
+ withJoinRelations(plan(ctx.relationPrimary), ctx)
+ }
- // Resolve the join type and join condition
- val (joinType, condition) = Option(ctx.joinCriteria) match {
- case Some(c) if c.USING != null =>
- val columns = c.identifier.asScala.map { column =>
- UnresolvedAttribute.quoted(column.getText)
- }
- (UsingJoin(baseJoinType, columns), None)
- case Some(c) if c.booleanExpression != null =>
- (baseJoinType, Option(expression(c.booleanExpression)))
- case None if ctx.NATURAL != null =>
- (NaturalJoin(baseJoinType), None)
- case None =>
- (baseJoinType, None)
- }
- Join(left, right, joinType, condition)
- }
+ /**
+ * Join one or more [[LogicalPlan]]s to the current logical plan.
+ */
+ private def withJoinRelations(base: LogicalPlan, ctx: RelationContext): LogicalPlan = {
+ ctx.joinRelation.asScala.foldLeft(base) { (left, join) =>
+ withOrigin(join) {
+ val baseJoinType = join.joinType match {
+ case null => Inner
+ case jt if jt.CROSS != null => Cross
+ case jt if jt.FULL != null => FullOuter
+ case jt if jt.SEMI != null => LeftSemi
+ case jt if jt.ANTI != null => LeftAnti
+ case jt if jt.LEFT != null => LeftOuter
+ case jt if jt.RIGHT != null => RightOuter
+ case _ => Inner
+ }
- // Handle all consecutive join clauses. ANTLR produces a right nested tree in which the the
- // first join clause is at the top. However fields of previously referenced tables can be used
- // in following join clauses. The tree needs to be reversed in order to make this work.
- var result = plan(ctx.left)
- var current = ctx
- while (current != null) {
- current.right match {
- case right: JoinRelationContext =>
- result = join(current, result, plan(right.left))
- current = right
- case right =>
- result = join(current, result, plan(right))
- current = null
+ // Resolve the join type and join condition
+ val (joinType, condition) = Option(join.joinCriteria) match {
+ case Some(c) if c.USING != null =>
+ val columns = c.identifier.asScala.map { column =>
+ UnresolvedAttribute.quoted(column.getText)
+ }
+ (UsingJoin(baseJoinType, columns), None)
+ case Some(c) if c.booleanExpression != null =>
+ (baseJoinType, Option(expression(c.booleanExpression)))
+ case None if join.NATURAL != null =>
+ if (baseJoinType == Cross) {
+ throw new ParseException("NATURAL CROSS JOIN is not supported", ctx)
+ }
+ (NaturalJoin(baseJoinType), None)
+ case None =>
+ (baseJoinType, None)
+ }
+ Join(left, plan(join.right), joinType, condition)
}
}
- result
}
/**
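The join-type and criteria resolution added in `withJoinRelations` above distinguishes `USING`, `ON`, `NATURAL`, and bare joins, and now rejects `NATURAL CROSS JOIN`. A rough standalone sketch of that resolution step (hypothetical simplified types; the real code returns Catalyst join types and `Expression` conditions):

```scala
// Rough sketch of the criteria resolution (hypothetical simplified types).
object JoinCriteriaSketch extends App {
  sealed trait Criteria
  case class Using(columns: Seq[String]) extends Criteria  // JOIN ... USING (c1, c2)
  case class On(condition: String) extends Criteria        // JOIN ... ON <expr>
  case object NaturalKw extends Criteria                   // NATURAL JOIN
  case object Bare extends Criteria                        // plain JOIN / CROSS JOIN

  def resolve(baseJoinType: String, criteria: Criteria): (String, Option[String]) =
    criteria match {
      case Using(cols) => (s"UsingJoin($baseJoinType, ${cols.mkString(", ")})", None)
      case On(cond)    => (baseJoinType, Some(cond))
      case NaturalKw if baseJoinType == "Cross" =>
        throw new IllegalArgumentException("NATURAL CROSS JOIN is not supported")
      case NaturalKw   => (s"NaturalJoin($baseJoinType)", None)
      case Bare        => (baseJoinType, None)
    }

  println(resolve("Inner", On("T3.col = T1.col")))  // (Inner,Some(T3.col = T1.col))
  println(resolve("Inner", Using(Seq("id"))))       // (UsingJoin(Inner, id),None)
}
```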
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/ParserUtils.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/ParserUtils.scala
index cb89a9679a..6fbc33fad7 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/ParserUtils.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/ParserUtils.scala
@@ -18,7 +18,7 @@ package org.apache.spark.sql.catalyst.parser
import scala.collection.mutable.StringBuilder
-import org.antlr.v4.runtime.{CharStream, ParserRuleContext, Token}
+import org.antlr.v4.runtime.{ParserRuleContext, Token}
import org.antlr.v4.runtime.misc.Interval
import org.antlr.v4.runtime.tree.TerminalNode
@@ -189,9 +189,7 @@ object ParserUtils {
* Map a [[LogicalPlan]] to another [[LogicalPlan]] if the passed context exists using the
* passed function. The original plan is returned when the context does not exist.
*/
- def optionalMap[C <: ParserRuleContext](
- ctx: C)(
- f: (C, LogicalPlan) => LogicalPlan): LogicalPlan = {
+ def optionalMap[C](ctx: C)(f: (C, LogicalPlan) => LogicalPlan): LogicalPlan = {
if (ctx != null) {
f(ctx, plan)
} else {
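Dropping the `C <: ParserRuleContext` bound matters because `visitFromClause` now passes the previously accumulated `LogicalPlan` (which may be `null` for the first relation) as the "context" argument. A toy sketch of the relaxed shape, with illustrative types only (not the real `ParserUtils`):

```scala
// Toy sketch: the "context" no longer has to be an ANTLR ParserRuleContext,
// so it can just as well be a previously built plan, or null.
object OptionalMapSketch extends App {
  def optionalMap[C](plan: String)(ctx: C)(f: (C, String) => String): String =
    if (ctx != null) f(ctx, plan) else plan

  // Non-null "context": the mapping function is applied.
  println(optionalMap("b")("a")((l, r) => s"Join($l, $r)")) // Join(a, b)
  // Null "context" (first relation in the FROM clause): the plan is returned as-is.
  println(optionalMap("b")(null: String)((l, r) => s"Join($l, $r)")) // b
}
```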
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/PlanParserSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/PlanParserSuite.scala
index faaea17b64..ca86304d4d 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/PlanParserSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/PlanParserSuite.scala
@@ -360,10 +360,54 @@ class PlanParserSuite extends PlanTest {
test("left anti join", LeftAnti, testExistence)
test("anti join", LeftAnti, testExistence)
+ // Test natural cross join
+ intercept("select * from a natural cross join b")
+
+ // Test natural join with a condition
+ intercept("select * from a natural join b on a.id = b.id")
+
// Test multiple consecutive joins
assertEqual(
"select * from a join b join c right join d",
table("a").join(table("b")).join(table("c")).join(table("d"), RightOuter).select(star()))
+
+ // SPARK-17296
+ assertEqual(
+ "select * from t1 cross join t2 join t3 on t3.id = t1.id join t4 on t4.id = t1.id",
+ table("t1")
+ .join(table("t2"), Cross)
+ .join(table("t3"), Inner, Option(Symbol("t3.id") === Symbol("t1.id")))
+ .join(table("t4"), Inner, Option(Symbol("t4.id") === Symbol("t1.id")))
+ .select(star()))
+
+ // Test multiple on clauses.
+ intercept("select * from t1 inner join t2 inner join t3 on col3 = col2 on col3 = col1")
+
+ // Parenthesis
+ assertEqual(
+ "select * from t1 inner join (t2 inner join t3 on col3 = col2) on col3 = col1",
+ table("t1")
+ .join(table("t2")
+ .join(table("t3"), Inner, Option('col3 === 'col2)), Inner, Option('col3 === 'col1))
+ .select(star()))
+ assertEqual(
+ "select * from t1 inner join (t2 inner join t3) on col3 = col2",
+ table("t1")
+ .join(table("t2").join(table("t3"), Inner, None), Inner, Option('col3 === 'col2))
+ .select(star()))
+ assertEqual(
+ "select * from t1 inner join (t2 inner join t3 on col3 = col2)",
+ table("t1")
+ .join(table("t2").join(table("t3"), Inner, Option('col3 === 'col2)), Inner, None)
+ .select(star()))
+
+ // Implicit joins.
+ assertEqual(
+ "select * from t1, t3 join t2 on t1.col1 = t2.col2",
+ table("t1")
+ .join(table("t3"))
+ .join(table("t2"), Inner, Option(Symbol("t1.col1") === Symbol("t2.col2")))
+ .select(star()))
}
test("sampled relations") {