diff options
Diffstat (limited to 'sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/ParseDriver.scala')
-rw-r--r-- | sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/ParseDriver.scala | 245 |
1 files changed, 153 insertions, 92 deletions
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/ParseDriver.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/ParseDriver.scala index 51cfc50130..d0132529f1 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/ParseDriver.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/ParseDriver.scala @@ -16,91 +16,106 @@ */ package org.apache.spark.sql.catalyst.parser -import scala.annotation.tailrec - -import org.antlr.runtime._ -import org.antlr.runtime.tree.CommonTree +import org.antlr.v4.runtime._ +import org.antlr.v4.runtime.atn.PredictionMode +import org.antlr.v4.runtime.misc.ParseCancellationException import org.apache.spark.internal.Logging import org.apache.spark.sql.AnalysisException +import org.apache.spark.sql.catalyst.TableIdentifier +import org.apache.spark.sql.catalyst.expressions.Expression +import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan +import org.apache.spark.sql.catalyst.trees.Origin +import org.apache.spark.sql.types.DataType /** - * The ParseDriver takes a SQL command and turns this into an AST. - * - * This is based on Hive's org.apache.hadoop.hive.ql.parse.ParseDriver + * Base SQL parsing infrastructure. */ -object ParseDriver extends Logging { - /** Create an LogicalPlan ASTNode from a SQL command. */ - def parsePlan(command: String, conf: ParserConf): ASTNode = parse(command, conf) { parser => - parser.statement().getTree - } +abstract class AbstractSqlParser extends ParserInterface with Logging { - /** Create an Expression ASTNode from a SQL command. */ - def parseExpression(command: String, conf: ParserConf): ASTNode = parse(command, conf) { parser => - parser.singleNamedExpression().getTree + /** Creates/Resolves DataType for a given SQL string. */ + def parseDataType(sqlText: String): DataType = parse(sqlText) { parser => + // TODO add this to the parser interface. 
+ astBuilder.visitSingleDataType(parser.singleDataType()) } - /** Create an TableIdentifier ASTNode from a SQL command. */ - def parseTableName(command: String, conf: ParserConf): ASTNode = parse(command, conf) { parser => - parser.singleTableName().getTree + /** Creates Expression for a given SQL string. */ + override def parseExpression(sqlText: String): Expression = parse(sqlText) { parser => + astBuilder.visitSingleExpression(parser.singleExpression()) } - private def parse( - command: String, - conf: ParserConf)( - toTree: SparkSqlParser => CommonTree): ASTNode = { - logInfo(s"Parsing command: $command") + /** Creates TableIdentifier for a given SQL string. */ + override def parseTableIdentifier(sqlText: String): TableIdentifier = parse(sqlText) { parser => + astBuilder.visitSingleTableIdentifier(parser.singleTableIdentifier()) + } - // Setup error collection. - val reporter = new ParseErrorReporter() + /** Creates LogicalPlan for a given SQL string. */ + override def parsePlan(sqlText: String): LogicalPlan = parse(sqlText) { parser => + astBuilder.visitSingleStatement(parser.singleStatement()) match { + case plan: LogicalPlan => plan + case _ => nativeCommand(sqlText) + } + } - // Create lexer. - val lexer = new SparkSqlLexer(new ANTLRNoCaseStringStream(command)) - val tokens = new TokenRewriteStream(lexer) - lexer.configure(conf, reporter) + /** Get the builder (visitor) which converts a ParseTree into an AST. */ + protected def astBuilder: AstBuilder - // Create the parser. - val parser = new SparkSqlParser(tokens) - parser.configure(conf, reporter) + /** Create a native command, or fail when this is not supported. 
*/ + protected def nativeCommand(sqlText: String): LogicalPlan = { + val position = Origin(None, None) + throw new ParseException(Option(sqlText), "Unsupported SQL statement", position, position) + } - try { - val result = toTree(parser) + protected def parse[T](command: String)(toResult: SqlBaseParser => T): T = { + logInfo(s"Parsing command: $command") - // Check errors. - reporter.checkForErrors() + val lexer = new SqlBaseLexer(new ANTLRNoCaseStringStream(command)) + lexer.removeErrorListeners() + lexer.addErrorListener(ParseErrorListener) - // Return the AST node from the result. - logInfo(s"Parse completed.") + val tokenStream = new CommonTokenStream(lexer) + val parser = new SqlBaseParser(tokenStream) + parser.addParseListener(PostProcessor) + parser.removeErrorListeners() + parser.addErrorListener(ParseErrorListener) - // Find the non null token tree in the result. - @tailrec - def nonNullToken(tree: CommonTree): CommonTree = { - if (tree.token != null || tree.getChildCount == 0) tree - else nonNullToken(tree.getChild(0).asInstanceOf[CommonTree]) + try { + try { + // first, try parsing with potentially faster SLL mode + parser.getInterpreter.setPredictionMode(PredictionMode.SLL) + toResult(parser) } - val tree = nonNullToken(result) - - // Make sure all boundaries are set. - tree.setUnknownTokenBoundaries() - - // Construct the immutable AST. - def createASTNode(tree: CommonTree): ASTNode = { - val children = (0 until tree.getChildCount).map { i => - createASTNode(tree.getChild(i).asInstanceOf[CommonTree]) - }.toList - ASTNode(tree.token, tree.getTokenStartIndex, tree.getTokenStopIndex, children, tokens) + catch { + case e: ParseCancellationException => + // if we fail, parse with LL mode + tokenStream.reset() // rewind input stream + parser.reset() + + // Try Again. 
+ parser.getInterpreter.setPredictionMode(PredictionMode.LL) + toResult(parser) } - createASTNode(tree) } catch { - case e: RecognitionException => - logInfo(s"Parse failed.") - reporter.throwError(e) + case e: ParseException if e.command.isDefined => + throw e + case e: ParseException => + throw e.withCommand(command) + case e: AnalysisException => + val position = Origin(e.line, e.startPosition) + throw new ParseException(Option(command), e.message, position, position) } } } /** + * Concrete SQL parser for Catalyst-only SQL statements. + */ +object CatalystSqlParser extends AbstractSqlParser { + val astBuilder = new AstBuilder +} + +/** * This string stream provides the lexer with upper case characters only. This greatly simplifies * lexing the stream, while we can maintain the original command. * @@ -120,58 +135,104 @@ object ParseDriver extends Logging { * have the ANTLRNoCaseStringStream implementation. */ -private[parser] class ANTLRNoCaseStringStream(input: String) extends ANTLRStringStream(input) { +private[parser] class ANTLRNoCaseStringStream(input: String) extends ANTLRInputStream(input) { override def LA(i: Int): Int = { val la = super.LA(i) - if (la == 0 || la == CharStream.EOF) la + if (la == 0 || la == IntStream.EOF) la else Character.toUpperCase(la) } } /** - * Utility used by the Parser and the Lexer for error collection and reporting. + * The ParseErrorListener converts parse errors into AnalysisExceptions. 
*/ -private[parser] class ParseErrorReporter { - val errors = scala.collection.mutable.Buffer.empty[ParseError] - - def report(br: BaseRecognizer, re: RecognitionException, tokenNames: Array[String]): Unit = { - errors += ParseError(br, re, tokenNames) +case object ParseErrorListener extends BaseErrorListener { + override def syntaxError( + recognizer: Recognizer[_, _], + offendingSymbol: scala.Any, + line: Int, + charPositionInLine: Int, + msg: String, + e: RecognitionException): Unit = { + val position = Origin(Some(line), Some(charPositionInLine)) + throw new ParseException(None, msg, position, position) } +} - def checkForErrors(): Unit = { - if (errors.nonEmpty) { - val first = errors.head - val e = first.re - throwError(e.line, e.charPositionInLine, first.buildMessage().toString, errors.tail) - } +/** + * A [[ParseException]] is an [[AnalysisException]] that is thrown during the parse process. It + * contains fields and an extended error message that make reporting and diagnosing errors easier. 
+ */ +class ParseException( + val command: Option[String], + message: String, + val start: Origin, + val stop: Origin) extends AnalysisException(message, start.line, start.startPosition) { + + def this(message: String, ctx: ParserRuleContext) = { + this(Option(ParserUtils.command(ctx)), + message, + ParserUtils.position(ctx.getStart), + ParserUtils.position(ctx.getStop)) } - def throwError(e: RecognitionException): Nothing = { - throwError(e.line, e.charPositionInLine, e.toString, errors) + override def getMessage: String = { + val builder = new StringBuilder + builder ++= "\n" ++= message + start match { + case Origin(Some(l), Some(p)) => + builder ++= s"(line $l, pos $p)\n" + command.foreach { cmd => + val (above, below) = cmd.split("\n").splitAt(l) + builder ++= "\n== SQL ==\n" + above.foreach(builder ++= _ += '\n') + builder ++= (0 until p).map(_ => "-").mkString("") ++= "^^^\n" + below.foreach(builder ++= _ += '\n') + } + case _ => + command.foreach { cmd => + builder ++= "\n== SQL ==\n" ++= cmd + } + } + builder.toString } - private def throwError( - line: Int, - startPosition: Int, - msg: String, - errors: Seq[ParseError]): Nothing = { - val b = new StringBuilder - b.append(msg).append("\n") - errors.foreach(error => error.buildMessage(b).append("\n")) - throw new AnalysisException(b.toString, Option(line), Option(startPosition)) + def withCommand(cmd: String): ParseException = { + new ParseException(Option(cmd), message, start, stop) } } /** - * Error collected during the parsing process. - * - * This is based on Hive's org.apache.hadoop.hive.ql.parse.ParseError + * The post-processor validates & cleans-up the parse tree during the parse process. 
*/ -private[parser] case class ParseError( - br: BaseRecognizer, - re: RecognitionException, - tokenNames: Array[String]) { - def buildMessage(s: StringBuilder = new StringBuilder): StringBuilder = { - s.append(br.getErrorHeader(re)).append(" ").append(br.getErrorMessage(re, tokenNames)) +case object PostProcessor extends SqlBaseBaseListener { + + /** Remove the back ticks from an Identifier. */ + override def exitQuotedIdentifier(ctx: SqlBaseParser.QuotedIdentifierContext): Unit = { + replaceTokenByIdentifier(ctx, 1) { token => + // Remove the double back ticks in the string. + token.setText(token.getText.replace("``", "`")) + token + } + } + + /** Treat non-reserved keywords as Identifiers. */ + override def exitNonReserved(ctx: SqlBaseParser.NonReservedContext): Unit = { + replaceTokenByIdentifier(ctx, 0)(identity) + } + + private def replaceTokenByIdentifier( + ctx: ParserRuleContext, + stripMargins: Int)( + f: CommonToken => CommonToken = identity): Unit = { + val parent = ctx.getParent + parent.removeLastChild() + val token = ctx.getChild(0).getPayload.asInstanceOf[Token] + parent.addChild(f(new CommonToken( + new org.antlr.v4.runtime.misc.Pair(token.getTokenSource, token.getInputStream), + SqlBaseParser.IDENTIFIER, + token.getChannel, + token.getStartIndex + stripMargins, + token.getStopIndex - stripMargins))) } } |