-rw-r--r--  sql/catalyst/src/main/antlr3/org/apache/spark/sql/catalyst/parser/SelectClauseParser.g    7
-rw-r--r--  sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/CatalystQl.scala               59
-rw-r--r--  sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/ParseDriver.scala       24
-rw-r--r--  sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/CatalystQlSuite.scala         151
-rw-r--r--  sql/hive/src/main/scala/org/apache/spark/sql/hive/ExtendedHiveQlParser.scala              2
-rw-r--r--  sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveMetastoreCatalog.scala              4
-rw-r--r--  sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveQl.scala                           19
-rw-r--r--  sql/hive/src/test/scala/org/apache/spark/sql/hive/ErrorPositionSuite.scala                5
-rw-r--r--  sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveQlSuite.scala                       2
9 files changed, 217 insertions, 56 deletions
diff --git a/sql/catalyst/src/main/antlr3/org/apache/spark/sql/catalyst/parser/SelectClauseParser.g b/sql/catalyst/src/main/antlr3/org/apache/spark/sql/catalyst/parser/SelectClauseParser.g
index 2d2bafb1ee..f18b6ec496 100644
--- a/sql/catalyst/src/main/antlr3/org/apache/spark/sql/catalyst/parser/SelectClauseParser.g
+++ b/sql/catalyst/src/main/antlr3/org/apache/spark/sql/catalyst/parser/SelectClauseParser.g
@@ -131,6 +131,13 @@ selectItem
:
(tableAllColumns) => tableAllColumns -> ^(TOK_SELEXPR tableAllColumns)
|
+ namedExpression
+ ;
+
+namedExpression
+@init { gParent.pushMsg("select named expression", state); }
+@after { gParent.popMsg(state); }
+ :
( expression
((KW_AS? identifier) | (KW_AS LPAREN identifier (COMMA identifier)* RPAREN))?
) -> ^(TOK_SELEXPR expression identifier*)
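The point of this grammar change is to give the alias-bearing alternative of selectItem its own rule name: ANTLR generates one parser method per rule, so factoring out namedExpression makes it callable as a standalone entry point (parser.namedExpression(), wired up in ParseDriver.scala below). A minimal sketch of what this enables at the API level, using the CatalystQl methods added later in this diff (illustrative only, not part of the commit):

    import org.apache.spark.sql.catalyst.CatalystQl

    val parser = new CatalystQl()         // defaults to SimpleParserConf()
    parser.parseExpression("a + 1")       // bare expression alternative
    parser.parseExpression("a + 1 AS b")  // expression-with-alias, same rule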
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/CatalystQl.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/CatalystQl.scala
index 2e3cc0bfde..c87b6c8e95 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/CatalystQl.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/CatalystQl.scala
@@ -30,6 +30,12 @@ import org.apache.spark.sql.types._
import org.apache.spark.unsafe.types.CalendarInterval
import org.apache.spark.util.random.RandomSampler
+private[sql] object CatalystQl {
+ val parser = new CatalystQl
+ def parseExpression(sql: String): Expression = parser.parseExpression(sql)
+ def parseTableIdentifier(sql: String): TableIdentifier = parser.parseTableIdentifier(sql)
+}
+
/**
* This class translates an HQL String to a Catalyst [[LogicalPlan]] or [[Expression]].
*/
@@ -41,16 +47,13 @@ private[sql] class CatalystQl(val conf: ParserConf = SimpleParserConf()) {
}
}
-
/**
- * Returns the AST for the given SQL string.
+ * The safeParse method allows a user to focus on the parsing/AST transformation logic. This
+ * method will take care of possible errors during the parsing process.
*/
- protected def getAst(sql: String): ASTNode = ParseDriver.parse(sql, conf)
-
- /** Creates LogicalPlan for a given HiveQL string. */
- def createPlan(sql: String): LogicalPlan = {
+ protected def safeParse[T](sql: String, ast: ASTNode)(toResult: ASTNode => T): T = {
try {
- createPlan(sql, ParseDriver.parse(sql, conf))
+ toResult(ast)
} catch {
case e: MatchError => throw e
case e: AnalysisException => throw e
@@ -58,26 +61,39 @@ private[sql] class CatalystQl(val conf: ParserConf = SimpleParserConf()) {
throw new AnalysisException(e.getMessage)
case e: NotImplementedError =>
throw new AnalysisException(
- s"""
- |Unsupported language features in query: $sql
- |${getAst(sql).treeString}
+ s"""Unsupported language features in query
+ |== SQL ==
+ |$sql
+ |== AST ==
+ |${ast.treeString}
+ |== Error ==
|$e
+ |== Stacktrace ==
|${e.getStackTrace.head}
""".stripMargin)
}
}
- protected def createPlan(sql: String, tree: ASTNode): LogicalPlan = nodeToPlan(tree)
-
- def parseDdl(ddl: String): Seq[Attribute] = {
- val tree = getAst(ddl)
- assert(tree.text == "TOK_CREATETABLE", "Only CREATE TABLE supported.")
- val tableOps = tree.children
- val colList = tableOps
- .find(_.text == "TOK_TABCOLLIST")
- .getOrElse(sys.error("No columnList!"))
-
- colList.children.map(nodeToAttribute)
+ /** Creates a LogicalPlan for a given SQL string. */
+ def parsePlan(sql: String): LogicalPlan =
+ safeParse(sql, ParseDriver.parsePlan(sql, conf))(nodeToPlan)
+
+ /** Creates an Expression for a given SQL string. */
+ def parseExpression(sql: String): Expression =
+ safeParse(sql, ParseDriver.parseExpression(sql, conf))(selExprNodeToExpr(_).get)
+
+ /** Creates a TableIdentifier for a given SQL string. */
+ def parseTableIdentifier(sql: String): TableIdentifier =
+ safeParse(sql, ParseDriver.parseTableName(sql, conf))(extractTableIdent)
+
+ def parseDdl(sql: String): Seq[Attribute] = {
+ safeParse(sql, ParseDriver.parsePlan(sql, conf)) { ast =>
+ val Token("TOK_CREATETABLE", children) = ast
+ children
+ .find(_.text == "TOK_TABCOLLIST")
+ .getOrElse(sys.error("No columnList!"))
+ .children.map(nodeToAttribute)
+ }
}
protected def getClauses(
@@ -187,7 +203,6 @@ https://cwiki.apache.org/confluence/display/Hive/Enhanced+Aggregation%2C+Cube%2C
val keyMap = keyASTs.zipWithIndex.toMap
val bitmasks: Seq[Int] = setASTs.map {
- case Token("TOK_GROUPING_SETS_EXPRESSION", null) => 0
case Token("TOK_GROUPING_SETS_EXPRESSION", columns) =>
columns.foldLeft(0)((bitmap, col) => {
val keyIndex = keyMap.find(_._1.treeEquals(col)).map(_._2)
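The new companion object shares a default-configuration parser instance, and safeParse centralizes error translation so each parseXxx method stays a one-liner. A sketch of calling through the companion (members taken from the hunk above; illustrative only):

    import org.apache.spark.sql.catalyst.CatalystQl

    // Shared instance with the default ParserConf; no `new` at call sites.
    val expr = CatalystQl.parseExpression("1 + r.r AS q")
    val table = CatalystQl.parseTableIdentifier("d.q")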
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/ParseDriver.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/ParseDriver.scala
index 0e93af8b92..f8e4f21451 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/ParseDriver.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/ParseDriver.scala
@@ -28,7 +28,25 @@ import org.apache.spark.sql.AnalysisException
* This is based on Hive's org.apache.hadoop.hive.ql.parse.ParseDriver
*/
object ParseDriver extends Logging {
- def parse(command: String, conf: ParserConf): ASTNode = {
+ /** Create a LogicalPlan ASTNode from a SQL command. */
+ def parsePlan(command: String, conf: ParserConf): ASTNode = parse(command, conf) { parser =>
+ parser.statement().getTree
+ }
+
+ /** Create an Expression ASTNode from a SQL command. */
+ def parseExpression(command: String, conf: ParserConf): ASTNode = parse(command, conf) { parser =>
+ parser.namedExpression().getTree
+ }
+
+ /** Create a TableIdentifier ASTNode from a SQL command. */
+ def parseTableName(command: String, conf: ParserConf): ASTNode = parse(command, conf) { parser =>
+ parser.tableName().getTree
+ }
+
+ private def parse(
+ command: String,
+ conf: ParserConf)(
+ toTree: SparkSqlParser => CommonTree): ASTNode = {
logInfo(s"Parsing command: $command")
// Setup error collection.
@@ -44,7 +62,7 @@ object ParseDriver extends Logging {
parser.configure(conf, reporter)
try {
- val result = parser.statement()
+ val result = toTree(parser)
// Check errors.
reporter.checkForErrors()
@@ -57,7 +75,7 @@ object ParseDriver extends Logging {
if (tree.token != null || tree.getChildCount == 0) tree
else nonNullToken(tree.getChild(0).asInstanceOf[CommonTree])
}
- val tree = nonNullToken(result.getTree)
+ val tree = nonNullToken(result)
// Make sure all boundaries are set.
tree.setUnknownTokenBoundaries()
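ParseDriver now exposes one public entry point per top-level grammar rule, each passing a different toTree callback to the shared private parse helper. A direct-use sketch, assuming SimpleParserConf is the default ParserConf implementation in this package (illustrative only):

    import org.apache.spark.sql.catalyst.parser.{ParseDriver, SimpleParserConf}

    val conf = SimpleParserConf()
    val planAst = ParseDriver.parsePlan("SELECT 1 FROM t", conf)    // parser.statement()
    val exprAst = ParseDriver.parseExpression("a + 1 AS b", conf)   // parser.namedExpression()
    val nameAst = ParseDriver.parseTableName("d.q", conf)           // parser.tableName()
    println(planAst.treeString)  // same pretty-printed AST safeParse embeds in errors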
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/CatalystQlSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/CatalystQlSuite.scala
index d7204c3488..ba9d2524a9 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/CatalystQlSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/CatalystQlSuite.scala
@@ -17,36 +17,157 @@
package org.apache.spark.sql.catalyst
+import org.apache.spark.sql.AnalysisException
+import org.apache.spark.sql.catalyst.analysis.{UnresolvedAlias, UnresolvedAttribute, UnresolvedFunction}
+import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.plans.PlanTest
+import org.apache.spark.sql.catalyst.plans.logical.{OneRowRelation, Project}
+import org.apache.spark.unsafe.types.CalendarInterval
class CatalystQlSuite extends PlanTest {
val parser = new CatalystQl()
+ test("test case insensitive") {
+ val result = Project(UnresolvedAlias(Literal(1)) :: Nil, OneRowRelation)
+ assert(result === parser.parsePlan("seLect 1"))
+ assert(result === parser.parsePlan("select 1"))
+ assert(result === parser.parsePlan("SELECT 1"))
+ }
+
+ test("test NOT operator with comparison operations") {
+ val parsed = parser.parsePlan("SELECT NOT TRUE > TRUE")
+ val expected = Project(
+ UnresolvedAlias(
+ Not(
+ GreaterThan(Literal(true), Literal(true)))
+ ) :: Nil,
+ OneRowRelation)
+ comparePlans(parsed, expected)
+ }
+
+ test("support hive interval literal") {
+ def checkInterval(sql: String, result: CalendarInterval): Unit = {
+ val parsed = parser.parsePlan(sql)
+ val expected = Project(
+ UnresolvedAlias(
+ Literal(result)
+ ) :: Nil,
+ OneRowRelation)
+ comparePlans(parsed, expected)
+ }
+
+ def checkYearMonth(lit: String): Unit = {
+ checkInterval(
+ s"SELECT INTERVAL '$lit' YEAR TO MONTH",
+ CalendarInterval.fromYearMonthString(lit))
+ }
+
+ def checkDayTime(lit: String): Unit = {
+ checkInterval(
+ s"SELECT INTERVAL '$lit' DAY TO SECOND",
+ CalendarInterval.fromDayTimeString(lit))
+ }
+
+ def checkSingleUnit(lit: String, unit: String): Unit = {
+ checkInterval(
+ s"SELECT INTERVAL '$lit' $unit",
+ CalendarInterval.fromSingleUnitString(unit, lit))
+ }
+
+ checkYearMonth("123-10")
+ checkYearMonth("496-0")
+ checkYearMonth("-2-3")
+ checkYearMonth("-123-0")
+
+ checkDayTime("99 11:22:33.123456789")
+ checkDayTime("-99 11:22:33.123456789")
+ checkDayTime("10 9:8:7.123456789")
+ checkDayTime("1 0:0:0")
+ checkDayTime("-1 0:0:0")
+ checkDayTime("1 0:0:1")
+
+ for (unit <- Seq("year", "month", "day", "hour", "minute", "second")) {
+ checkSingleUnit("7", unit)
+ checkSingleUnit("-7", unit)
+ checkSingleUnit("0", unit)
+ }
+
+ checkSingleUnit("13.123456789", "second")
+ checkSingleUnit("-13.123456789", "second")
+ }
+
+ test("support scientific notation") {
+ def assertRight(input: String, output: Double): Unit = {
+ val parsed = parser.parsePlan("SELECT " + input)
+ val expected = Project(
+ UnresolvedAlias(
+ Literal(output)
+ ) :: Nil,
+ OneRowRelation)
+ comparePlans(parsed, expected)
+ }
+
+ assertRight("9.0e1", 90)
+ assertRight("0.9e+2", 90)
+ assertRight("900e-1", 90)
+ assertRight("900.0E-1", 90)
+ assertRight("9.e+1", 90)
+
+ intercept[AnalysisException](parser.parsePlan("SELECT .e3"))
+ }
+
+ test("parse expressions") {
+ compareExpressions(
+ parser.parseExpression("prinln('hello', 'world')"),
+ UnresolvedFunction(
+ "prinln", Literal("hello") :: Literal("world") :: Nil, false))
+
+ compareExpressions(
+ parser.parseExpression("1 + r.r As q"),
+ Alias(Add(Literal(1), UnresolvedAttribute("r.r")), "q")())
+
+ compareExpressions(
+ parser.parseExpression("1 - f('o', o(bar))"),
+ Subtract(Literal(1),
+ UnresolvedFunction("f",
+ Literal("o") ::
+ UnresolvedFunction("o", UnresolvedAttribute("bar") :: Nil, false) ::
+ Nil, false)))
+ }
+
+ test("table identifier") {
+ assert(TableIdentifier("q") === parser.parseTableIdentifier("q"))
+ assert(TableIdentifier("q", Some("d")) === parser.parseTableIdentifier("d.q"))
+ intercept[AnalysisException](parser.parseTableIdentifier(""))
+ // TODO parser swallows third identifier.
+ // intercept[AnalysisException](parser.parseTableIdentifier("d.q.g"))
+ }
+
test("parse union/except/intersect") {
- parser.createPlan("select * from t1 union all select * from t2")
- parser.createPlan("select * from t1 union distinct select * from t2")
- parser.createPlan("select * from t1 union select * from t2")
- parser.createPlan("select * from t1 except select * from t2")
- parser.createPlan("select * from t1 intersect select * from t2")
- parser.createPlan("(select * from t1) union all (select * from t2)")
- parser.createPlan("(select * from t1) union distinct (select * from t2)")
- parser.createPlan("(select * from t1) union (select * from t2)")
- parser.createPlan("select * from ((select * from t1) union (select * from t2)) t")
+ parser.parsePlan("select * from t1 union all select * from t2")
+ parser.parsePlan("select * from t1 union distinct select * from t2")
+ parser.parsePlan("select * from t1 union select * from t2")
+ parser.parsePlan("select * from t1 except select * from t2")
+ parser.parsePlan("select * from t1 intersect select * from t2")
+ parser.parsePlan("(select * from t1) union all (select * from t2)")
+ parser.parsePlan("(select * from t1) union distinct (select * from t2)")
+ parser.parsePlan("(select * from t1) union (select * from t2)")
+ parser.parsePlan("select * from ((select * from t1) union (select * from t2)) t")
}
test("window function: better support of parentheses") {
- parser.createPlan("select sum(product + 1) over (partition by ((1) + (product / 2)) " +
+ parser.parsePlan("select sum(product + 1) over (partition by ((1) + (product / 2)) " +
"order by 2) from windowData")
- parser.createPlan("select sum(product + 1) over (partition by (1 + (product / 2)) " +
+ parser.parsePlan("select sum(product + 1) over (partition by (1 + (product / 2)) " +
"order by 2) from windowData")
- parser.createPlan("select sum(product + 1) over (partition by ((product / 2) + 1) " +
+ parser.parsePlan("select sum(product + 1) over (partition by ((product / 2) + 1) " +
"order by 2) from windowData")
- parser.createPlan("select sum(product + 1) over (partition by ((product) + (1)) order by 2) " +
+ parser.parsePlan("select sum(product + 1) over (partition by ((product) + (1)) order by 2) " +
"from windowData")
- parser.createPlan("select sum(product + 1) over (partition by ((product) + 1) order by 2) " +
+ parser.parsePlan("select sum(product + 1) over (partition by ((product) + 1) order by 2) " +
"from windowData")
- parser.createPlan("select sum(product + 1) over (partition by (product + (1)) order by 2) " +
+ parser.parsePlan("select sum(product + 1) over (partition by (product + (1)) order by 2) " +
"from windowData")
}
}
diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/ExtendedHiveQlParser.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/ExtendedHiveQlParser.scala
index 395c8bff53..b22f424981 100644
--- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/ExtendedHiveQlParser.scala
+++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/ExtendedHiveQlParser.scala
@@ -38,7 +38,7 @@ private[hive] class ExtendedHiveQlParser extends AbstractSparkSQLParser {
protected lazy val hiveQl: Parser[LogicalPlan] =
restInput ^^ {
- case statement => HiveQl.createPlan(statement.trim)
+ case statement => HiveQl.parsePlan(statement.trim)
}
protected lazy val dfs: Parser[LogicalPlan] =
diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveMetastoreCatalog.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveMetastoreCatalog.scala
index 43d84d507b..67228f3f3c 100644
--- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveMetastoreCatalog.scala
+++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveMetastoreCatalog.scala
@@ -414,8 +414,8 @@ private[hive] class HiveMetastoreCatalog(val client: ClientInterface, hive: Hive
alias match {
// because hive uses things like `_c0` to build the expanded text
// currently we cannot support view from "create view v1(c1) as ..."
- case None => Subquery(table.name, HiveQl.createPlan(viewText))
- case Some(aliasText) => Subquery(aliasText, HiveQl.createPlan(viewText))
+ case None => Subquery(table.name, HiveQl.parsePlan(viewText))
+ case Some(aliasText) => Subquery(aliasText, HiveQl.parsePlan(viewText))
}
} else {
MetastoreRelation(qualifiedTableName.database, qualifiedTableName.name, alias)(table)(hive)
diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveQl.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveQl.scala
index d1b1c0d8d8..ca9ddf94c1 100644
--- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveQl.scala
+++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveQl.scala
@@ -230,15 +230,16 @@ private[hive] object HiveQl extends SparkQl with Logging {
CreateViewAsSelect(tableDesc, nodeToPlan(query), allowExist, replace, sql)
}
- protected override def createPlan(
- sql: String,
- node: ASTNode): LogicalPlan = {
- if (nativeCommands.contains(node.text)) {
- HiveNativeCommand(sql)
- } else {
- nodeToPlan(node) match {
- case NativePlaceholder => HiveNativeCommand(sql)
- case plan => plan
+ /** Creates a LogicalPlan for a given SQL string. */
+ override def parsePlan(sql: String): LogicalPlan = {
+ safeParse(sql, ParseDriver.parsePlan(sql, conf)) { ast =>
+ if (nativeCommands.contains(ast.text)) {
+ HiveNativeCommand(sql)
+ } else {
+ nodeToPlan(ast) match {
+ case NativePlaceholder => HiveNativeCommand(sql)
+ case plan => plan
+ }
}
}
}
diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/ErrorPositionSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/ErrorPositionSuite.scala
index e72a18a716..14a466cfe9 100644
--- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/ErrorPositionSuite.scala
+++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/ErrorPositionSuite.scala
@@ -117,9 +117,8 @@ class ErrorPositionSuite extends QueryTest with TestHiveSingleton with BeforeAnd
* @param token a unique token in the string that should be indicated by the exception
*/
def positionTest(name: String, query: String, token: String): Unit = {
- def ast = ParseDriver.parse(query, hiveContext.conf)
- def parseTree =
- Try(quietly(ast.treeString)).getOrElse("<failed to parse>")
+ def ast = ParseDriver.parsePlan(query, hiveContext.conf)
+ def parseTree = Try(quietly(ast.treeString)).getOrElse("<failed to parse>")
test(name) {
val error = intercept[AnalysisException] {
diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveQlSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveQlSuite.scala
index f4a1a17422..53d15c14cb 100644
--- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveQlSuite.scala
+++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveQlSuite.scala
@@ -28,7 +28,7 @@ import org.apache.spark.sql.hive.client.{ExternalTable, HiveColumn, HiveTable, M
class HiveQlSuite extends SparkFunSuite with BeforeAndAfterAll {
private def extractTableDesc(sql: String): (HiveTable, Boolean) = {
- HiveQl.createPlan(sql).collect {
+ HiveQl.parsePlan(sql).collect {
case CreateTableAsSelect(desc, child, allowExisting) => (desc, allowExisting)
}.head
}
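Taken together, the commit replaces the createPlan/getAst pair with purpose-specific entry points on CatalystQl (HiveQl.parsePlan layers native-command handling on top). A usage sketch of the post-commit API, mirroring the new tests (illustrative only, not part of the commit):

    import org.apache.spark.sql.catalyst.{CatalystQl, TableIdentifier}

    val parser = new CatalystQl()

    // Whole statement -> LogicalPlan (formerly createPlan).
    val plan = parser.parsePlan("select * from t1 union all select * from t2")

    // Single named expression -> Expression.
    val expr = parser.parseExpression("1 - f('o', o(bar))")

    // Dotted name -> TableIdentifier.
    assert(parser.parseTableIdentifier("d.q") == TableIdentifier("q", Some("d")))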