diff options
-rw-r--r-- | sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/ParserUtils.scala | 9 | ||||
-rw-r--r-- | sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/ParserUtilsSuite.scala | 126 |
2 files changed, 128 insertions, 7 deletions
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/ParserUtils.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/ParserUtils.scala index bc35ae2f55..cb89a9679a 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/ParserUtils.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/ParserUtils.scala @@ -31,11 +31,7 @@ import org.apache.spark.sql.catalyst.trees.{CurrentOrigin, Origin} object ParserUtils { /** Get the command which created the token. */ def command(ctx: ParserRuleContext): String = { - command(ctx.getStart.getInputStream) - } - - /** Get the command which created the token. */ - def command(stream: CharStream): String = { + val stream = ctx.getStart.getInputStream stream.getText(Interval.of(0, stream.size())) } @@ -74,7 +70,8 @@ object ParserUtils { /** Get the origin (line and position) of the token. */ def position(token: Token): Origin = { - Origin(Option(token.getLine), Option(token.getCharPositionInLine)) + val opt = Option(token) + Origin(opt.map(_.getLine), opt.map(_.getCharPositionInLine)) } /** Validate the condition. If it doesn't throw a parse exception. */ diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/ParserUtilsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/ParserUtilsSuite.scala index d090daf7b4..d5748a4ff1 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/ParserUtilsSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/ParserUtilsSuite.scala @@ -16,12 +16,53 @@ */ package org.apache.spark.sql.catalyst.parser +import org.antlr.v4.runtime.{CommonTokenStream, ParserRuleContext} + import org.apache.spark.SparkFunSuite +import org.apache.spark.sql.catalyst.parser.SqlBaseParser._ +import org.apache.spark.sql.catalyst.trees.{CurrentOrigin, Origin} class ParserUtilsSuite extends SparkFunSuite { import ParserUtils._ + val setConfContext = buildContext("set example.setting.name=setting.value") { parser => + parser.statement().asInstanceOf[SetConfigurationContext] + } + + val showFuncContext = buildContext("show functions foo.bar") { parser => + parser.statement().asInstanceOf[ShowFunctionsContext] + } + + val descFuncContext = buildContext("describe function extended bar") { parser => + parser.statement().asInstanceOf[DescribeFunctionContext] + } + + val showDbsContext = buildContext("show databases like 'identifier_with_wildcards'") { parser => + parser.statement().asInstanceOf[ShowDatabasesContext] + } + + val createDbContext = buildContext( + """ + |CREATE DATABASE IF NOT EXISTS database_name + |COMMENT 'database_comment' LOCATION '/home/user/db' + |WITH DBPROPERTIES ('a'='a', 'b'='b', 'c'='c') + """.stripMargin + ) { parser => + parser.statement().asInstanceOf[CreateDatabaseContext] + } + + val emptyContext = buildContext("") { parser => + parser.statement + } + + private def buildContext[T](command: String)(toResult: SqlBaseParser => T): T = { + val lexer = new SqlBaseLexer(new ANTLRNoCaseStringStream(command)) + val tokenStream = new CommonTokenStream(lexer) + val parser = new SqlBaseParser(tokenStream) + toResult(parser) + } + test("unescapeSQLString") { // scalastyle:off nonascii @@ -61,5 +102,88 @@ class ParserUtilsSuite extends SparkFunSuite { // scalastyle:on nonascii } - // TODO: Add test cases for other methods in ParserUtils + test("command") { + assert(command(setConfContext) == "set example.setting.name=setting.value") + assert(command(showFuncContext) == "show functions foo.bar") + assert(command(descFuncContext) == "describe function extended bar") + assert(command(showDbsContext) == "show databases like 'identifier_with_wildcards'") + } + + test("operationNotAllowed") { + val errorMessage = "parse.fail.operation.not.allowed.error.message" + val e = intercept[ParseException] { + operationNotAllowed(errorMessage, showFuncContext) + }.getMessage + assert(e.contains("Operation not allowed")) + assert(e.contains(errorMessage)) + } + + test("checkDuplicateKeys") { + val properties = Seq(("a", "a"), ("b", "b"), ("c", "c")) + checkDuplicateKeys[String](properties, createDbContext) + + val properties2 = Seq(("a", "a"), ("b", "b"), ("a", "c")) + val e = intercept[ParseException] { + checkDuplicateKeys(properties2, createDbContext) + }.getMessage + assert(e.contains("Found duplicate keys")) + } + + test("source") { + assert(source(setConfContext) == "set example.setting.name=setting.value") + assert(source(showFuncContext) == "show functions foo.bar") + assert(source(descFuncContext) == "describe function extended bar") + assert(source(showDbsContext) == "show databases like 'identifier_with_wildcards'") + } + + test("remainder") { + assert(remainder(setConfContext) == "") + assert(remainder(showFuncContext) == "") + assert(remainder(descFuncContext) == "") + assert(remainder(showDbsContext) == "") + + assert(remainder(setConfContext.SET.getSymbol) == " example.setting.name=setting.value") + assert(remainder(showFuncContext.FUNCTIONS.getSymbol) == " foo.bar") + assert(remainder(descFuncContext.EXTENDED.getSymbol) == " bar") + assert(remainder(showDbsContext.LIKE.getSymbol) == " 'identifier_with_wildcards'") + } + + test("string") { + assert(string(showDbsContext.pattern) == "identifier_with_wildcards") + assert(string(createDbContext.comment) == "database_comment") + + assert(string(createDbContext.locationSpec.STRING) == "/home/user/db") + } + + test("position") { + assert(position(setConfContext.start) == Origin(Some(1), Some(0))) + assert(position(showFuncContext.stop) == Origin(Some(1), Some(19))) + assert(position(descFuncContext.describeFuncName.start) == Origin(Some(1), Some(27))) + assert(position(createDbContext.locationSpec.start) == Origin(Some(3), Some(27))) + assert(position(emptyContext.stop) == Origin(None, None)) + } + + test("validate") { + val f1 = { ctx: ParserRuleContext => + ctx.children != null && !ctx.children.isEmpty + } + val message = "ParserRuleContext should not be empty." + validate(f1(showFuncContext), message, showFuncContext) + + val e = intercept[ParseException] { + validate(f1(emptyContext), message, emptyContext) + }.getMessage + assert(e.contains(message)) + } + + test("withOrigin") { + val ctx = createDbContext.locationSpec + val current = CurrentOrigin.get + val (location, origin) = withOrigin(ctx) { + (string(ctx.STRING), CurrentOrigin.get) + } + assert(location == "/home/user/db") + assert(origin == Origin(Some(3), Some(27))) + assert(CurrentOrigin.get == current) + } } |