diff options
10 files changed, 340 insertions, 4 deletions
diff --git a/sql/catalyst/src/main/antlr4/org/apache/spark/sql/catalyst/parser/SqlBase.g4 b/sql/catalyst/src/main/antlr4/org/apache/spark/sql/catalyst/parser/SqlBase.g4 index 312399861e..0ac986d137 100644 --- a/sql/catalyst/src/main/antlr4/org/apache/spark/sql/catalyst/parser/SqlBase.g4 +++ b/sql/catalyst/src/main/antlr4/org/apache/spark/sql/catalyst/parser/SqlBase.g4 @@ -365,7 +365,7 @@ querySpecification (RECORDREADER recordReader=STRING)? fromClause? (WHERE where=booleanExpression)?) - | ((kind=SELECT setQuantifier? namedExpressionSeq fromClause? + | ((kind=SELECT hint? setQuantifier? namedExpressionSeq fromClause? | fromClause (kind=SELECT setQuantifier? namedExpressionSeq)?) lateralView* (WHERE where=booleanExpression)? @@ -374,6 +374,16 @@ querySpecification windows?) ; +hint + : '/*+' hintStatement '*/' + ; + +hintStatement + : hintName=identifier + | hintName=identifier '(' parameters+=identifier parameters+=identifier ')' + | hintName=identifier '(' parameters+=identifier (',' parameters+=identifier)* ')' + ; + fromClause : FROM relation (',' relation)* lateralView* ; @@ -1002,8 +1012,12 @@ SIMPLE_COMMENT : '--' ~[\r\n]* '\r'? '\n'? -> channel(HIDDEN) ; +BRACKETED_EMPTY_COMMENT + : '/**/' -> channel(HIDDEN) + ; + BRACKETED_COMMENT - : '/*' .*? '*/' -> channel(HIDDEN) + : '/*' ~[+] .*? 
'*/' -> channel(HIDDEN) ; WS diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala index 4913dccf4b..8348cb5012 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala @@ -114,6 +114,9 @@ class Analyzer( val postHocResolutionRules: Seq[Rule[LogicalPlan]] = Nil lazy val batches: Seq[Batch] = Seq( + Batch("Hints", fixedPoint, + new SubstituteHints.SubstituteBroadcastHints(conf), + SubstituteHints.RemoveAllHints), Batch("Substitution", fixedPoint, CTESubstitution, WindowsSubstitution, diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala index 532ecb8757..36ab8b8527 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala @@ -387,6 +387,10 @@ trait CheckAnalysis extends PredicateHelper { |in operator ${operator.simpleString} """.stripMargin) + case _: Hint => + throw new IllegalStateException( + "Internal error: logical hint operator should have been removed during analysis") + case _ => // Analysis successful! } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/SubstituteHints.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/SubstituteHints.scala new file mode 100644 index 0000000000..fda4d1b612 --- /dev/null +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/SubstituteHints.scala @@ -0,0 +1,104 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. 
See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.catalyst.analysis + +import org.apache.spark.sql.catalyst.CatalystConf +import org.apache.spark.sql.catalyst.plans.logical._ +import org.apache.spark.sql.catalyst.rules.Rule +import org.apache.spark.sql.catalyst.trees.CurrentOrigin + + +/** + * Collection of rules related to hints. The only hint currently available is broadcast join hint. + * + * Note that this is separated into two rules because in the future we might introduce new hint + * rules that have different ordering requirements from broadcast. + */ +object SubstituteHints { + + /** + * Substitute Hints. + * + * The only hint currently available is broadcast join hint. + * + * For broadcast hint, we accept "BROADCAST", "BROADCASTJOIN", and "MAPJOIN", and a sequence of + * relation aliases can be specified in the hint. A broadcast hint plan node will be inserted + * on top of any relation (that is not aliased differently), subquery, or common table expression + * that matches the specified name. + * + * The hint resolution works by recursively traversing down the query plan to find a relation or + * subquery that matches one of the specified broadcast aliases. The traversal does not go + * beyond any existing broadcast hints or subquery aliases. 
+ * + * This rule must happen before common table expressions. + */ + class SubstituteBroadcastHints(conf: CatalystConf) extends Rule[LogicalPlan] { + private val BROADCAST_HINT_NAMES = Set("BROADCAST", "BROADCASTJOIN", "MAPJOIN") + + def resolver: Resolver = conf.resolver + + private def applyBroadcastHint(plan: LogicalPlan, toBroadcast: Set[String]): LogicalPlan = { + // Whether to continue recursing down the tree + var recurse = true + + val newNode = CurrentOrigin.withOrigin(plan.origin) { + plan match { + case r: UnresolvedRelation => + val alias = r.alias.getOrElse(r.tableIdentifier.table) + if (toBroadcast.exists(resolver(_, alias))) BroadcastHint(plan) else plan + case r: SubqueryAlias => + if (toBroadcast.exists(resolver(_, r.alias))) { + BroadcastHint(plan) + } else { + // Don't recurse down subquery aliases if there are no match. + recurse = false + plan + } + case _: BroadcastHint => + // Found a broadcast hint; don't change the plan but also don't recurse down. + recurse = false + plan + case _ => + plan + } + } + + if ((plan fastEquals newNode) && recurse) { + newNode.mapChildren(child => applyBroadcastHint(child, toBroadcast)) + } else { + newNode + } + } + + def apply(plan: LogicalPlan): LogicalPlan = plan transformUp { + case h: Hint if BROADCAST_HINT_NAMES.contains(h.name.toUpperCase) => + applyBroadcastHint(h.child, h.parameters.toSet) + } + } + + /** + * Removes all the hints, used to remove invalid hints provided by the user. + * This must be executed after all the other hint rules are executed. 
+ */ + object RemoveAllHints extends Rule[LogicalPlan] { + def apply(plan: LogicalPlan): LogicalPlan = plan transformUp { + case h: Hint => h.child + } + } + +} diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala index bb07558c81..bbb9922c18 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala @@ -380,7 +380,10 @@ class AstBuilder extends SqlBaseBaseVisitor[AnyRef] with Logging { } // Window - withDistinct.optionalMap(windows)(withWindows) + val withWindow = withDistinct.optionalMap(windows)(withWindows) + + // Hint + withWindow.optionalMap(hint)(withHints) } } @@ -506,6 +509,16 @@ class AstBuilder extends SqlBaseBaseVisitor[AnyRef] with Logging { } /** + * Add a Hint to a logical plan. + */ + private def withHints( + ctx: HintContext, + query: LogicalPlan): LogicalPlan = withOrigin(ctx) { + val stmt = ctx.hintStatement + Hint(stmt.hintName.getText, stmt.parameters.asScala.map(_.getText), query) + } + + /** * Add a [[Generate]] (Lateral View) to a logical plan. */ private def withGenerate( diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala index 8d7a6bc4b5..4d696c0a3f 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala @@ -363,6 +363,15 @@ case class BroadcastHint(child: LogicalPlan) extends UnaryNode { } /** + * A general hint for the child. This node will be eliminated post analysis. + * A pair of (name, parameters). 
+ */ +case class Hint(name: String, parameters: Seq[String], child: LogicalPlan) extends UnaryNode { + override lazy val resolved: Boolean = false + override def output: Seq[Attribute] = child.output +} + +/** * Insert some data into a table. Note that this plan is unresolved and has to be replaced by the * concrete implementations during analysis. * diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisTest.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisTest.scala index 3acb261800..0f059b9591 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisTest.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisTest.scala @@ -32,6 +32,7 @@ trait AnalysisTest extends PlanTest { val conf = new SimpleCatalystConf(caseSensitive) val catalog = new SessionCatalog(new InMemoryCatalog, EmptyFunctionRegistry, conf) catalog.createTempView("TaBlE", TestRelations.testRelation, overrideIfExists = true) + catalog.createTempView("TaBlE2", TestRelations.testRelation2, overrideIfExists = true) new Analyzer(catalog, conf) { override val extendedResolutionRules = EliminateSubqueryAliases :: Nil } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/SubstituteHintsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/SubstituteHintsSuite.scala new file mode 100644 index 0000000000..9d671f3121 --- /dev/null +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/SubstituteHintsSuite.scala @@ -0,0 +1,121 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.catalyst.analysis + +import org.apache.spark.sql.catalyst.TableIdentifier +import org.apache.spark.sql.catalyst.dsl.expressions._ +import org.apache.spark.sql.catalyst.dsl.plans._ +import org.apache.spark.sql.catalyst.parser.CatalystSqlParser +import org.apache.spark.sql.catalyst.plans.Inner +import org.apache.spark.sql.catalyst.plans.logical._ + +class SubstituteHintsSuite extends AnalysisTest { + import org.apache.spark.sql.catalyst.analysis.TestRelations._ + + test("invalid hints should be ignored") { + checkAnalysis( + Hint("some_random_hint_that_does_not_exist", Seq("TaBlE"), table("TaBlE")), + testRelation, + caseSensitive = false) + } + + test("case-sensitive or insensitive parameters") { + checkAnalysis( + Hint("MAPJOIN", Seq("TaBlE"), table("TaBlE")), + BroadcastHint(testRelation), + caseSensitive = false) + + checkAnalysis( + Hint("MAPJOIN", Seq("table"), table("TaBlE")), + BroadcastHint(testRelation), + caseSensitive = false) + + checkAnalysis( + Hint("MAPJOIN", Seq("TaBlE"), table("TaBlE")), + BroadcastHint(testRelation), + caseSensitive = true) + + checkAnalysis( + Hint("MAPJOIN", Seq("table"), table("TaBlE")), + testRelation, + caseSensitive = true) + } + + test("multiple broadcast hint aliases") { + checkAnalysis( + Hint("MAPJOIN", Seq("table", "table2"), table("table").join(table("table2"))), + Join(BroadcastHint(testRelation), BroadcastHint(testRelation2), Inner, None), + caseSensitive = false) + } + + test("do not traverse past existing broadcast hints") { + checkAnalysis( + Hint("MAPJOIN", 
Seq("table"), BroadcastHint(table("table").where('a > 1))), + BroadcastHint(testRelation.where('a > 1)).analyze, + caseSensitive = false) + } + + test("should work for subqueries") { + checkAnalysis( + Hint("MAPJOIN", Seq("tableAlias"), table("table").as("tableAlias")), + BroadcastHint(testRelation), + caseSensitive = false) + + checkAnalysis( + Hint("MAPJOIN", Seq("tableAlias"), table("table").subquery('tableAlias)), + BroadcastHint(testRelation), + caseSensitive = false) + + // Negative case: if the alias doesn't match, don't match the original table name. + checkAnalysis( + Hint("MAPJOIN", Seq("table"), table("table").as("tableAlias")), + testRelation, + caseSensitive = false) + } + + test("do not traverse past subquery alias") { + checkAnalysis( + Hint("MAPJOIN", Seq("table"), table("table").where('a > 1).subquery('tableAlias)), + testRelation.where('a > 1).analyze, + caseSensitive = false) + } + + test("should work for CTE") { + checkAnalysis( + CatalystSqlParser.parsePlan( + """ + |WITH ctetable AS (SELECT * FROM table WHERE a > 1) + |SELECT /*+ BROADCAST(ctetable) */ * FROM ctetable + """.stripMargin + ), + BroadcastHint(testRelation.where('a > 1).select('a)).select('a).analyze, + caseSensitive = false) + } + + test("should not traverse down CTE") { + checkAnalysis( + CatalystSqlParser.parsePlan( + """ + |WITH ctetable AS (SELECT * FROM table WHERE a > 1) + |SELECT /*+ BROADCAST(table) */ * FROM ctetable + """.stripMargin + ), + testRelation.where('a > 1).select('a).select('a).analyze, + caseSensitive = false) + } +} diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/PlanParserSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/PlanParserSuite.scala index f408ba99d1..13a84b465b 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/PlanParserSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/PlanParserSuite.scala @@ -493,4 +493,46 @@ class 
PlanParserSuite extends PlanTest { assertEqual("select a, b from db.c where x !> 1", table("db", "c").where('x <= 1).select('a, 'b)) } + + test("select hint syntax") { + // Hive compatibility: Missing parameter raises ParseException. + val m = intercept[ParseException] { + parsePlan("SELECT /*+ HINT() */ * FROM t") + }.getMessage + assert(m.contains("no viable alternative at input")) + + // Hive compatibility: No database. + val m2 = intercept[ParseException] { + parsePlan("SELECT /*+ MAPJOIN(default.t) */ * from default.t") + }.getMessage + assert(m2.contains("no viable alternative at input")) + + comparePlans( + parsePlan("SELECT /*+ HINT */ * FROM t"), + Hint("HINT", Seq.empty, table("t").select(star()))) + + comparePlans( + parsePlan("SELECT /*+ BROADCASTJOIN(u) */ * FROM t"), + Hint("BROADCASTJOIN", Seq("u"), table("t").select(star()))) + + comparePlans( + parsePlan("SELECT /*+ MAPJOIN(u) */ * FROM t"), + Hint("MAPJOIN", Seq("u"), table("t").select(star()))) + + comparePlans( + parsePlan("SELECT /*+ STREAMTABLE(a,b,c) */ * FROM t"), + Hint("STREAMTABLE", Seq("a", "b", "c"), table("t").select(star()))) + + comparePlans( + parsePlan("SELECT /*+ INDEX(t emp_job_ix) */ * FROM t"), + Hint("INDEX", Seq("t", "emp_job_ix"), table("t").select(star()))) + + comparePlans( + parsePlan("SELECT /*+ MAPJOIN(`default.t`) */ * from `default.t`"), + Hint("MAPJOIN", Seq("default.t"), table("default.t").select(star()))) + + comparePlans( + parsePlan("SELECT /*+ MAPJOIN(t) */ a from t where true group by a order by a"), + Hint("MAPJOIN", Seq("t"), table("t").where(Literal(true)).groupBy('a)('a)).orderBy('a.asc)) + } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/BroadcastJoinSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/BroadcastJoinSuite.scala index 119d6e25df..9c55357ab9 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/BroadcastJoinSuite.scala +++ 
b/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/BroadcastJoinSuite.scala @@ -22,6 +22,8 @@ import scala.reflect.ClassTag import org.apache.spark.AccumulatorSuite import org.apache.spark.sql.{Dataset, QueryTest, Row, SparkSession} import org.apache.spark.sql.catalyst.expressions.{BitwiseAnd, BitwiseOr, Cast, Literal, ShiftLeft} +import org.apache.spark.sql.catalyst.plans.logical.SubqueryAlias +import org.apache.spark.sql.catalyst.TableIdentifier import org.apache.spark.sql.execution.exchange.EnsureRequirements import org.apache.spark.sql.execution.SparkPlan import org.apache.spark.sql.functions._ @@ -137,7 +139,7 @@ class BroadcastJoinSuite extends QueryTest with SQLTestUtils { assert(plan.collect { case p: BroadcastHashJoinExec => p }.size === 1) } - test("broadcast hint is propagated correctly") { + test("broadcast hint programming API") { withSQLConf(SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> "-1") { val df2 = spark.createDataFrame(Seq((1, "1"), (2, "2"), (3, "2"))).toDF("key", "value") val broadcasted = broadcast(df2) @@ -157,6 +159,29 @@ class BroadcastJoinSuite extends QueryTest with SQLTestUtils { } } + test("broadcast hint in SQL") { + import org.apache.spark.sql.catalyst.plans.logical.{BroadcastHint, Join} + + spark.range(10).createOrReplaceTempView("t") + spark.range(10).createOrReplaceTempView("u") + + for (name <- Seq("BROADCAST", "BROADCASTJOIN", "MAPJOIN")) { + val plan1 = sql(s"SELECT /*+ $name(t) */ * FROM t JOIN u ON t.id = u.id").queryExecution + .optimizedPlan + val plan2 = sql(s"SELECT /*+ $name(u) */ * FROM t JOIN u ON t.id = u.id").queryExecution + .optimizedPlan + val plan3 = sql(s"SELECT /*+ $name(v) */ * FROM t JOIN u ON t.id = u.id").queryExecution + .optimizedPlan + + assert(plan1.asInstanceOf[Join].left.isInstanceOf[BroadcastHint]) + assert(!plan1.asInstanceOf[Join].right.isInstanceOf[BroadcastHint]) + assert(!plan2.asInstanceOf[Join].left.isInstanceOf[BroadcastHint]) + 
assert(plan2.asInstanceOf[Join].right.isInstanceOf[BroadcastHint]) + assert(!plan3.asInstanceOf[Join].left.isInstanceOf[BroadcastHint]) + assert(!plan3.asInstanceOf[Join].right.isInstanceOf[BroadcastHint]) + } + } + test("join key rewritten") { val l = Literal(1L) val i = Literal(2) |