path: root/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/ParseDriver.scala
diff options
Diffstat (limited to 'sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/ParseDriver.scala')
1 files changed, 153 insertions, 92 deletions
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/ParseDriver.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/ParseDriver.scala
index 51cfc50130..d0132529f1 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/ParseDriver.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/ParseDriver.scala
@@ -16,91 +16,106 @@
package org.apache.spark.sql.catalyst.parser
-import scala.annotation.tailrec
-import org.antlr.runtime._
-import org.antlr.runtime.tree.CommonTree
+import org.antlr.v4.runtime._
+import org.antlr.v4.runtime.atn.PredictionMode
+import org.antlr.v4.runtime.misc.ParseCancellationException
import org.apache.spark.internal.Logging
import org.apache.spark.sql.AnalysisException
+import org.apache.spark.sql.catalyst.TableIdentifier
+import org.apache.spark.sql.catalyst.expressions.Expression
+import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
+import org.apache.spark.sql.catalyst.trees.Origin
+import org.apache.spark.sql.types.DataType
- * The ParseDriver takes a SQL command and turns this into an AST.
- *
- * This is based on Hive's org.apache.hadoop.hive.ql.parse.ParseDriver
+ * Base SQL parsing infrastructure.
-object ParseDriver extends Logging {
- /** Create an LogicalPlan ASTNode from a SQL command. */
- def parsePlan(command: String, conf: ParserConf): ASTNode = parse(command, conf) { parser =>
- parser.statement().getTree
- }
+abstract class AbstractSqlParser extends ParserInterface with Logging {
- /** Create an Expression ASTNode from a SQL command. */
- def parseExpression(command: String, conf: ParserConf): ASTNode = parse(command, conf) { parser =>
- parser.singleNamedExpression().getTree
+ /** Creates/Resolves DataType for a given SQL string. */
+ def parseDataType(sqlText: String): DataType = parse(sqlText) { parser =>
+ // TODO add this to the parser interface.
+ astBuilder.visitSingleDataType(parser.singleDataType())
- /** Create an TableIdentifier ASTNode from a SQL command. */
- def parseTableName(command: String, conf: ParserConf): ASTNode = parse(command, conf) { parser =>
- parser.singleTableName().getTree
+ /** Creates Expression for a given SQL string. */
+ override def parseExpression(sqlText: String): Expression = parse(sqlText) { parser =>
+ astBuilder.visitSingleExpression(parser.singleExpression())
- private def parse(
- command: String,
- conf: ParserConf)(
- toTree: SparkSqlParser => CommonTree): ASTNode = {
- logInfo(s"Parsing command: $command")
+ /** Creates TableIdentifier for a given SQL string. */
+ override def parseTableIdentifier(sqlText: String): TableIdentifier = parse(sqlText) { parser =>
+ astBuilder.visitSingleTableIdentifier(parser.singleTableIdentifier())
+ }
- // Setup error collection.
- val reporter = new ParseErrorReporter()
+ /** Creates LogicalPlan for a given SQL string. */
+ override def parsePlan(sqlText: String): LogicalPlan = parse(sqlText) { parser =>
+ astBuilder.visitSingleStatement(parser.singleStatement()) match {
+ case plan: LogicalPlan => plan
+ case _ => nativeCommand(sqlText)
+ }
+ }
- // Create lexer.
- val lexer = new SparkSqlLexer(new ANTLRNoCaseStringStream(command))
- val tokens = new TokenRewriteStream(lexer)
- lexer.configure(conf, reporter)
+ /** Get the builder (visitor) which converts a ParseTree into an AST. */
+ protected def astBuilder: AstBuilder
- // Create the parser.
- val parser = new SparkSqlParser(tokens)
- parser.configure(conf, reporter)
+ /** Create a native command, or fail when this is not supported. */
+ protected def nativeCommand(sqlText: String): LogicalPlan = {
+ val position = Origin(None, None)
+ throw new ParseException(Option(sqlText), "Unsupported SQL statement", position, position)
+ }
- try {
- val result = toTree(parser)
+ protected def parse[T](command: String)(toResult: SqlBaseParser => T): T = {
+ logInfo(s"Parsing command: $command")
- // Check errors.
- reporter.checkForErrors()
+ val lexer = new SqlBaseLexer(new ANTLRNoCaseStringStream(command))
+ lexer.removeErrorListeners()
+ lexer.addErrorListener(ParseErrorListener)
- // Return the AST node from the result.
- logInfo(s"Parse completed.")
+ val tokenStream = new CommonTokenStream(lexer)
+ val parser = new SqlBaseParser(tokenStream)
+ parser.addParseListener(PostProcessor)
+ parser.removeErrorListeners()
+ parser.addErrorListener(ParseErrorListener)
- // Find the non null token tree in the result.
- @tailrec
- def nonNullToken(tree: CommonTree): CommonTree = {
- if (tree.token != null || tree.getChildCount == 0) tree
- else nonNullToken(tree.getChild(0).asInstanceOf[CommonTree])
+ try {
+ try {
+ // first, try parsing with potentially faster SLL mode
+ parser.getInterpreter.setPredictionMode(PredictionMode.SLL)
+ toResult(parser)
- val tree = nonNullToken(result)
- // Make sure all boundaries are set.
- tree.setUnknownTokenBoundaries()
- // Construct the immutable AST.
- def createASTNode(tree: CommonTree): ASTNode = {
- val children = (0 until tree.getChildCount).map { i =>
- createASTNode(tree.getChild(i).asInstanceOf[CommonTree])
- }.toList
- ASTNode(tree.token, tree.getTokenStartIndex, tree.getTokenStopIndex, children, tokens)
+ catch {
+ case e: ParseCancellationException =>
+ // if we fail, parse with LL mode
+ tokenStream.reset() // rewind input stream
+ parser.reset()
+ // Try Again.
+ parser.getInterpreter.setPredictionMode(PredictionMode.LL)
+ toResult(parser)
- createASTNode(tree)
catch {
- case e: RecognitionException =>
- logInfo(s"Parse failed.")
- reporter.throwError(e)
+ case e: ParseException if e.command.isDefined =>
+ throw e
+ case e: ParseException =>
+ throw e.withCommand(command)
+ case e: AnalysisException =>
+ val position = Origin(e.line, e.startPosition)
+ throw new ParseException(Option(command), e.message, position, position)
+ * Concrete SQL parser for Catalyst-only SQL statements.
+ */
+object CatalystSqlParser extends AbstractSqlParser {
+ val astBuilder = new AstBuilder
* This string stream provides the lexer with upper case characters only. This greatly simplifies
* lexing the stream, while we can maintain the original command.
@@ -120,58 +135,104 @@ object ParseDriver extends Logging {
* have the ANTLRNoCaseStringStream implementation.
-private[parser] class ANTLRNoCaseStringStream(input: String) extends ANTLRStringStream(input) {
+private[parser] class ANTLRNoCaseStringStream(input: String) extends ANTLRInputStream(input) {
override def LA(i: Int): Int = {
val la = super.LA(i)
- if (la == 0 || la == CharStream.EOF) la
+ if (la == 0 || la == IntStream.EOF) la
else Character.toUpperCase(la)
- * Utility used by the Parser and the Lexer for error collection and reporting.
+ * The ParseErrorListener converts parse errors into AnalysisExceptions.
-private[parser] class ParseErrorReporter {
- val errors = scala.collection.mutable.Buffer.empty[ParseError]
- def report(br: BaseRecognizer, re: RecognitionException, tokenNames: Array[String]): Unit = {
- errors += ParseError(br, re, tokenNames)
+case object ParseErrorListener extends BaseErrorListener {
+ override def syntaxError(
+ recognizer: Recognizer[_, _],
+ offendingSymbol: scala.Any,
+ line: Int,
+ charPositionInLine: Int,
+ msg: String,
+ e: RecognitionException): Unit = {
+ val position = Origin(Some(line), Some(charPositionInLine))
+ throw new ParseException(None, msg, position, position)
- def checkForErrors(): Unit = {
- if (errors.nonEmpty) {
- val first = errors.head
- val e = first.re
- throwError(e.line, e.charPositionInLine, first.buildMessage().toString, errors.tail)
- }
+ * A [[ParseException]] is an [[AnalysisException]] that is thrown during the parse process. It
+ * contains fields and an extended error message that make reporting and diagnosing errors easier.
+ */
+class ParseException(
+ val command: Option[String],
+ message: String,
+ val start: Origin,
+ val stop: Origin) extends AnalysisException(message, start.line, start.startPosition) {
+ def this(message: String, ctx: ParserRuleContext) = {
+ this(Option(ParserUtils.command(ctx)),
+ message,
+ ParserUtils.position(ctx.getStart),
+ ParserUtils.position(ctx.getStop))
- def throwError(e: RecognitionException): Nothing = {
- throwError(e.line, e.charPositionInLine, e.toString, errors)
+ override def getMessage: String = {
+ val builder = new StringBuilder
+ builder ++= "\n" ++= message
+ start match {
+ case Origin(Some(l), Some(p)) =>
+ builder ++= s"(line $l, pos $p)\n"
+ command.foreach { cmd =>
+ val (above, below) = cmd.split("\n").splitAt(l)
+ builder ++= "\n== SQL ==\n"
+ above.foreach(builder ++= _ += '\n')
+ builder ++= (0 until p).map(_ => "-").mkString("") ++= "^^^\n"
+ below.foreach(builder ++= _ += '\n')
+ }
+ case _ =>
+ command.foreach { cmd =>
+ builder ++= "\n== SQL ==\n" ++= cmd
+ }
+ }
+ builder.toString
- private def throwError(
- line: Int,
- startPosition: Int,
- msg: String,
- errors: Seq[ParseError]): Nothing = {
- val b = new StringBuilder
- b.append(msg).append("\n")
- errors.foreach(error => error.buildMessage(b).append("\n"))
- throw new AnalysisException(b.toString, Option(line), Option(startPosition))
+ def withCommand(cmd: String): ParseException = {
+ new ParseException(Option(cmd), message, start, stop)
- * Error collected during the parsing process.
- *
- * This is based on Hive's org.apache.hadoop.hive.ql.parse.ParseError
+ * The post-processor validates & cleans up the parse tree during the parse process.
-private[parser] case class ParseError(
- br: BaseRecognizer,
- re: RecognitionException,
- tokenNames: Array[String]) {
- def buildMessage(s: StringBuilder = new StringBuilder): StringBuilder = {
- s.append(br.getErrorHeader(re)).append(" ").append(br.getErrorMessage(re, tokenNames))
+case object PostProcessor extends SqlBaseBaseListener {
+ /** Remove the back ticks from an Identifier. */
+ override def exitQuotedIdentifier(ctx: SqlBaseParser.QuotedIdentifierContext): Unit = {
+ replaceTokenByIdentifier(ctx, 1) { token =>
+ // Replace each escaped (doubled) back tick in the string with a single back tick.
+ token.setText(token.getText.replace("``", "`"))
+ token
+ }
+ }
+ /** Treat non-reserved keywords as Identifiers. */
+ override def exitNonReserved(ctx: SqlBaseParser.NonReservedContext): Unit = {
+ replaceTokenByIdentifier(ctx, 0)(identity)
+ }
+ private def replaceTokenByIdentifier(
+ ctx: ParserRuleContext,
+ stripMargins: Int)(
+ f: CommonToken => CommonToken = identity): Unit = {
+ val parent = ctx.getParent
+ parent.removeLastChild()
+ val token = ctx.getChild(0).getPayload.asInstanceOf[Token]
+ parent.addChild(f(new CommonToken(
+ new org.antlr.v4.runtime.misc.Pair(token.getTokenSource, token.getInputStream),
+ SqlBaseParser.IDENTIFIER,
+ token.getChannel,
+ token.getStartIndex + stripMargins,
+ token.getStopIndex - stripMargins)))