author      Herman van Hovell <hvanhovell@questtec.nl>    2016-03-28 12:31:12 -0700
committer   Reynold Xin <rxin@databricks.com>             2016-03-28 12:31:12 -0700
commit      600c0b69cab4767e8e5a6f4284777d8b9d4bd40e (patch)
tree        bae635ab17a8b58400127f20bbbe5acaecc92f98
parent      1528ff4c9affe1df103c4b3abd56a86c71d8b753 (diff)
download    spark-600c0b69cab4767e8e5a6f4284777d8b9d4bd40e.tar.gz
            spark-600c0b69cab4767e8e5a6f4284777d8b9d4bd40e.tar.bz2
            spark-600c0b69cab4767e8e5a6f4284777d8b9d4bd40e.zip
[SPARK-13713][SQL] Migrate parser from ANTLR3 to ANTLR4
### What changes were proposed in this pull request?

The current ANTLR3 parser is quite complex to maintain and suffers from code blow-ups. This PR introduces a new parser based on ANTLR4, modeled on [Presto's SQL parser](https://github.com/facebook/presto/blob/master/presto-parser/src/main/antlr4/com/facebook/presto/sql/parser/SqlBase.g4). The current implementation can parse and create Catalyst and SQL plans. Large parts of the HiveQL DDL and some of the DML functionality are currently missing; the plan is to add these in follow-up PRs.

This PR is a work in progress, and work still needs to be done in the following areas:

- [x] Error handling should be improved.
- [x] Documentation should be improved.
- [x] Multi-Insert needs to be tested.
- [ ] Naming and package locations.

### How was this patch tested?

Catalyst and SQL unit tests.

Author: Herman van Hovell <hvanhovell@questtec.nl>

Closes #11557 from hvanhovell/ngParser.
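For orientation, the sketch below shows roughly how the pieces introduced in this patch fit together: the `SqlBase.g4` grammar is compiled into `SqlBaseLexer`/`SqlBaseParser`, and the new `AstBuilder` visits the resulting parse tree to produce a Catalyst `LogicalPlan`. This is an illustrative assumption, not code from the patch; the real entry points live in `ParseDriver.scala` (whose diff is not shown in this excerpt) and add error handling and case-insensitive lexing on top of the bare generated classes.

```scala
import org.antlr.v4.runtime.{ANTLRInputStream, CommonTokenStream}
import org.apache.spark.sql.catalyst.parser.ng.{AstBuilder, SqlBaseLexer, SqlBaseParser}
import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan

// Hypothetical driver, for illustration only (the patch's ParseDriver plays this role).
object NgParserSketch {
  def parsePlan(sql: String): LogicalPlan = {
    // Tokenize the SQL text; without a case-insensitive stream, keywords must be upper case.
    val lexer = new SqlBaseLexer(new ANTLRInputStream(sql))
    // Run the generated ANTLR4 parser starting at the grammar's top-level rule (statement EOF).
    val parser = new SqlBaseParser(new CommonTokenStream(lexer))
    val tree = parser.singleStatement()
    // Convert the parse tree into a Catalyst LogicalPlan with the new AstBuilder.
    new AstBuilder().visitSingleStatement(tree)
  }
}
```

With such a driver, `NgParserSketch.parsePlan("SELECT A FROM T WHERE A = 1")` would yield an unresolved `Project` over a `Filter` over an `UnresolvedRelation`, ready for analysis.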
-rw-r--r--  LICENSE  1
-rw-r--r--  dev/deps/spark-deps-hadoop-2.2  1
-rw-r--r--  dev/deps/spark-deps-hadoop-2.3  1
-rw-r--r--  dev/deps/spark-deps-hadoop-2.4  1
-rw-r--r--  dev/deps/spark-deps-hadoop-2.6  1
-rw-r--r--  dev/deps/spark-deps-hadoop-2.7  1
-rw-r--r--  pom.xml  6
-rw-r--r--  project/SparkBuild.scala  8
-rw-r--r--  project/plugins.sbt  6
-rw-r--r--  python/pyspark/sql/tests.py  6
-rw-r--r--  python/pyspark/sql/utils.py  8
-rw-r--r--  sql/catalyst/pom.xml  4
-rw-r--r--  sql/catalyst/src/main/antlr4/org/apache/spark/sql/catalyst/parser/ng/SqlBase.g4  911
-rw-r--r--  sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/dsl/package.scala  34
-rw-r--r--  sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/ng/AstBuilder.scala  1452
-rw-r--r--  sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/ng/ParseDriver.scala  240
-rw-r--r--  sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/ng/ParserUtils.scala  118
-rw-r--r--  sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/CatalystQlSuite.scala  52
-rw-r--r--  sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/DataTypeParserSuite.scala  55
-rw-r--r--  sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/ng/ErrorParserSuite.scala  67
-rw-r--r--  sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/ng/ExpressionParserSuite.scala  497
-rw-r--r--  sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/ng/PlanParserSuite.scala  429
-rw-r--r--  sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/ng/TableIdentifierParserSuite.scala  42
-rw-r--r--  sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/plans/PlanTest.scala  20
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/execution/SparkSqlParser.scala  219
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/functions.scala  5
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/internal/SessionState.scala  2
-rw-r--r--  sql/core/src/test/scala/org/apache/spark/sql/JoinSuite.scala  4
-rw-r--r--  sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala  2
29 files changed, 4127 insertions, 66 deletions
diff --git a/LICENSE b/LICENSE
index d7a790a628..5a8c78b98b 100644
--- a/LICENSE
+++ b/LICENSE
@@ -238,6 +238,7 @@ The text of each license is also included at licenses/LICENSE-[project].txt.
(BSD 3 Clause) netlib core (com.github.fommil.netlib:core:1.1.2 - https://github.com/fommil/netlib-java/core)
(BSD 3 Clause) JPMML-Model (org.jpmml:pmml-model:1.2.7 - https://github.com/jpmml/jpmml-model)
(BSD License) AntLR Parser Generator (antlr:antlr:2.7.7 - http://www.antlr.org/)
+ (BSD License) ANTLR 4.5.2-1 (org.antlr:antlr4:4.5.2-1 - http://www.antlr.org/)
(BSD licence) ANTLR ST4 4.0.4 (org.antlr:ST4:4.0.4 - http://www.stringtemplate.org)
(BSD licence) ANTLR StringTemplate (org.antlr:stringtemplate:3.2.1 - http://www.stringtemplate.org)
(BSD License) Javolution (javolution:javolution:5.5.1 - http://javolution.org)
diff --git a/dev/deps/spark-deps-hadoop-2.2 b/dev/deps/spark-deps-hadoop-2.2
index 512675a599..7c2f88bdb1 100644
--- a/dev/deps/spark-deps-hadoop-2.2
+++ b/dev/deps/spark-deps-hadoop-2.2
@@ -3,6 +3,7 @@ RoaringBitmap-0.5.11.jar
ST4-4.0.4.jar
activation-1.1.jar
antlr-runtime-3.5.2.jar
+antlr4-runtime-4.5.2-1.jar
aopalliance-1.0.jar
apache-log4j-extras-1.2.17.jar
arpack_combined_all-0.1.jar
diff --git a/dev/deps/spark-deps-hadoop-2.3 b/dev/deps/spark-deps-hadoop-2.3
index 31f8694fed..f4d600038d 100644
--- a/dev/deps/spark-deps-hadoop-2.3
+++ b/dev/deps/spark-deps-hadoop-2.3
@@ -3,6 +3,7 @@ RoaringBitmap-0.5.11.jar
ST4-4.0.4.jar
activation-1.1.1.jar
antlr-runtime-3.5.2.jar
+antlr4-runtime-4.5.2-1.jar
aopalliance-1.0.jar
apache-log4j-extras-1.2.17.jar
arpack_combined_all-0.1.jar
diff --git a/dev/deps/spark-deps-hadoop-2.4 b/dev/deps/spark-deps-hadoop-2.4
index 0fa8bccab0..7c5e2c35bd 100644
--- a/dev/deps/spark-deps-hadoop-2.4
+++ b/dev/deps/spark-deps-hadoop-2.4
@@ -3,6 +3,7 @@ RoaringBitmap-0.5.11.jar
ST4-4.0.4.jar
activation-1.1.1.jar
antlr-runtime-3.5.2.jar
+antlr4-runtime-4.5.2-1.jar
aopalliance-1.0.jar
apache-log4j-extras-1.2.17.jar
arpack_combined_all-0.1.jar
diff --git a/dev/deps/spark-deps-hadoop-2.6 b/dev/deps/spark-deps-hadoop-2.6
index 8d2f6e6e32..03d9a51057 100644
--- a/dev/deps/spark-deps-hadoop-2.6
+++ b/dev/deps/spark-deps-hadoop-2.6
@@ -3,6 +3,7 @@ RoaringBitmap-0.5.11.jar
ST4-4.0.4.jar
activation-1.1.1.jar
antlr-runtime-3.5.2.jar
+antlr4-runtime-4.5.2-1.jar
aopalliance-1.0.jar
apache-log4j-extras-1.2.17.jar
apacheds-i18n-2.0.0-M15.jar
diff --git a/dev/deps/spark-deps-hadoop-2.7 b/dev/deps/spark-deps-hadoop-2.7
index a114c4ae8d..5765071a1c 100644
--- a/dev/deps/spark-deps-hadoop-2.7
+++ b/dev/deps/spark-deps-hadoop-2.7
@@ -3,6 +3,7 @@ RoaringBitmap-0.5.11.jar
ST4-4.0.4.jar
activation-1.1.1.jar
antlr-runtime-3.5.2.jar
+antlr4-runtime-4.5.2-1.jar
aopalliance-1.0.jar
apache-log4j-extras-1.2.17.jar
apacheds-i18n-2.0.0-M15.jar
diff --git a/pom.xml b/pom.xml
index b4cfa3a598..475f0544bd 100644
--- a/pom.xml
+++ b/pom.xml
@@ -178,6 +178,7 @@
<jsr305.version>1.3.9</jsr305.version>
<libthrift.version>0.9.2</libthrift.version>
<antlr.version>3.5.2</antlr.version>
+ <antlr4.version>4.5.2-1</antlr4.version>
<test.java.home>${java.home}</test.java.home>
<test.exclude.tags></test.exclude.tags>
@@ -1759,6 +1760,11 @@
<artifactId>antlr-runtime</artifactId>
<version>${antlr.version}</version>
</dependency>
+ <dependency>
+ <groupId>org.antlr</groupId>
+ <artifactId>antlr4-runtime</artifactId>
+ <version>${antlr4.version}</version>
+ </dependency>
</dependencies>
</dependencyManagement>
diff --git a/project/SparkBuild.scala b/project/SparkBuild.scala
index fb229b979d..39a9e16f7e 100644
--- a/project/SparkBuild.scala
+++ b/project/SparkBuild.scala
@@ -25,6 +25,7 @@ import sbt._
import sbt.Classpaths.publishTask
import sbt.Keys._
import sbtunidoc.Plugin.UnidocKeys.unidocGenjavadocVersion
+import com.simplytyped.Antlr4Plugin._
import com.typesafe.sbt.pom.{PomBuild, SbtPomKeys}
import com.typesafe.tools.mima.plugin.MimaKeys
@@ -401,7 +402,10 @@ object OldDeps {
}
object Catalyst {
- lazy val settings = Seq(
+ lazy val settings = antlr4Settings ++ Seq(
+ antlr4PackageName in Antlr4 := Some("org.apache.spark.sql.catalyst.parser.ng"),
+ antlr4GenListener in Antlr4 := true,
+ antlr4GenVisitor in Antlr4 := true,
// ANTLR code-generation step.
//
// This has been heavily inspired by com.github.stefri.sbt-antlr (0.5.3). It fixes a number of
@@ -414,7 +418,7 @@ object Catalyst {
"SparkSqlLexer.g",
"SparkSqlParser.g")
val sourceDir = (sourceDirectory in Compile).value / "antlr3"
- val targetDir = (sourceManaged in Compile).value
+ val targetDir = (sourceManaged in Compile).value / "antlr3"
// Create default ANTLR Tool.
val antlr = new org.antlr.Tool
diff --git a/project/plugins.sbt b/project/plugins.sbt
index eeca94a47c..d9ed7962bf 100644
--- a/project/plugins.sbt
+++ b/project/plugins.sbt
@@ -23,3 +23,9 @@ libraryDependencies += "org.ow2.asm" % "asm" % "5.0.3"
libraryDependencies += "org.ow2.asm" % "asm-commons" % "5.0.3"
libraryDependencies += "org.antlr" % "antlr" % "3.5.2"
+
+
+// TODO I am not sure we want such a dep.
+resolvers += "simplytyped" at "http://simplytyped.github.io/repo/releases"
+
+addSbtPlugin("com.simplytyped" % "sbt-antlr4" % "0.7.10")
diff --git a/python/pyspark/sql/tests.py b/python/pyspark/sql/tests.py
index 83ef76c13c..1a5d422af9 100644
--- a/python/pyspark/sql/tests.py
+++ b/python/pyspark/sql/tests.py
@@ -51,7 +51,7 @@ from pyspark.sql.types import UserDefinedType, _infer_type
from pyspark.tests import ReusedPySparkTestCase
from pyspark.sql.functions import UserDefinedFunction, sha2
from pyspark.sql.window import Window
-from pyspark.sql.utils import AnalysisException, IllegalArgumentException
+from pyspark.sql.utils import AnalysisException, ParseException, IllegalArgumentException
class UTCOffsetTimezone(datetime.tzinfo):
@@ -1130,7 +1130,9 @@ class SQLTests(ReusedPySparkTestCase):
def test_capture_analysis_exception(self):
self.assertRaises(AnalysisException, lambda: self.sqlCtx.sql("select abc"))
self.assertRaises(AnalysisException, lambda: self.df.selectExpr("a + b"))
- self.assertRaises(AnalysisException, lambda: self.sqlCtx.sql("abc"))
+
+ def test_capture_parse_exception(self):
+ self.assertRaises(ParseException, lambda: self.sqlCtx.sql("abc"))
def test_capture_illegalargument_exception(self):
self.assertRaisesRegexp(IllegalArgumentException, "Setting negative mapred.reduce.tasks",
diff --git a/python/pyspark/sql/utils.py b/python/pyspark/sql/utils.py
index b0a0373372..b89ea8c6e0 100644
--- a/python/pyspark/sql/utils.py
+++ b/python/pyspark/sql/utils.py
@@ -33,6 +33,12 @@ class AnalysisException(CapturedException):
"""
+class ParseException(CapturedException):
+ """
+ Failed to parse a SQL command.
+ """
+
+
class IllegalArgumentException(CapturedException):
"""
Passed an illegal or inappropriate argument.
@@ -49,6 +55,8 @@ def capture_sql_exception(f):
e.java_exception.getStackTrace()))
if s.startswith('org.apache.spark.sql.AnalysisException: '):
raise AnalysisException(s.split(': ', 1)[1], stackTrace)
+ if s.startswith('org.apache.spark.sql.catalyst.parser.ng.ParseException: '):
+ raise ParseException(s.split(': ', 1)[1], stackTrace)
if s.startswith('java.lang.IllegalArgumentException: '):
raise IllegalArgumentException(s.split(': ', 1)[1], stackTrace)
raise
diff --git a/sql/catalyst/pom.xml b/sql/catalyst/pom.xml
index 5d1d9edd25..c834a011f1 100644
--- a/sql/catalyst/pom.xml
+++ b/sql/catalyst/pom.xml
@@ -76,6 +76,10 @@
<artifactId>antlr-runtime</artifactId>
</dependency>
<dependency>
+ <groupId>org.antlr</groupId>
+ <artifactId>antlr4-runtime</artifactId>
+ </dependency>
+ <dependency>
<groupId>commons-codec</groupId>
<artifactId>commons-codec</artifactId>
</dependency>
diff --git a/sql/catalyst/src/main/antlr4/org/apache/spark/sql/catalyst/parser/ng/SqlBase.g4 b/sql/catalyst/src/main/antlr4/org/apache/spark/sql/catalyst/parser/ng/SqlBase.g4
new file mode 100644
index 0000000000..e46fd9bed5
--- /dev/null
+++ b/sql/catalyst/src/main/antlr4/org/apache/spark/sql/catalyst/parser/ng/SqlBase.g4
@@ -0,0 +1,911 @@
+/*
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ *
+ * This file is an adaptation of Presto's presto-parser/src/main/antlr4/com/facebook/presto/sql/parser/SqlBase.g4 grammar.
+ */
+
+grammar SqlBase;
+
+tokens {
+ DELIMITER
+}
+
+singleStatement
+ : statement EOF
+ ;
+
+singleExpression
+ : namedExpression EOF
+ ;
+
+singleTableIdentifier
+ : tableIdentifier EOF
+ ;
+
+singleDataType
+ : dataType EOF
+ ;
+
+statement
+ : query #statementDefault
+ | USE db=identifier #use
+ | CREATE DATABASE (IF NOT EXISTS)? identifier
+ (COMMENT comment=STRING)? locationSpec?
+ (WITH DBPROPERTIES tablePropertyList)? #createDatabase
+ | ALTER DATABASE identifier SET DBPROPERTIES tablePropertyList #setDatabaseProperties
+ | DROP DATABASE (IF EXISTS)? identifier (RESTRICT | CASCADE)? #dropDatabase
+ | createTableHeader ('(' colTypeList ')')? tableProvider
+ (OPTIONS tablePropertyList)? #createTableUsing
+ | createTableHeader tableProvider
+ (OPTIONS tablePropertyList)? AS? query #createTableUsing
+ | createTableHeader ('(' colTypeList ')')? (COMMENT STRING)?
+ (PARTITIONED BY identifierList)? bucketSpec? skewSpec?
+ rowFormat? createFileFormat? locationSpec?
+ (TBLPROPERTIES tablePropertyList)?
+ (AS? query)? #createTable
+ | ANALYZE TABLE tableIdentifier partitionSpec? COMPUTE STATISTICS
+ (identifier | FOR COLUMNS identifierSeq?) #analyze
+ | ALTER TABLE from=tableIdentifier RENAME TO to=tableIdentifier #renameTable
+ | ALTER TABLE tableIdentifier
+ SET TBLPROPERTIES tablePropertyList #setTableProperties
+ | ALTER TABLE tableIdentifier
+ UNSET TBLPROPERTIES (IF EXISTS)? tablePropertyList #unsetTableProperties
+ | ALTER TABLE tableIdentifier (partitionSpec)?
+ SET SERDE STRING (WITH SERDEPROPERTIES tablePropertyList)? #setTableSerDe
+ | ALTER TABLE tableIdentifier (partitionSpec)?
+ SET SERDEPROPERTIES tablePropertyList #setTableSerDe
+ | ALTER TABLE tableIdentifier bucketSpec #bucketTable
+ | ALTER TABLE tableIdentifier NOT CLUSTERED #unclusterTable
+ | ALTER TABLE tableIdentifier NOT SORTED #unsortTable
+ | ALTER TABLE tableIdentifier skewSpec #skewTable
+ | ALTER TABLE tableIdentifier NOT SKEWED #unskewTable
+ | ALTER TABLE tableIdentifier NOT STORED AS DIRECTORIES #unstoreTable
+ | ALTER TABLE tableIdentifier
+ SET SKEWED LOCATION skewedLocationList #setTableSkewLocations
+ | ALTER TABLE tableIdentifier ADD (IF NOT EXISTS)?
+ partitionSpecLocation+ #addTablePartition
+ | ALTER TABLE tableIdentifier
+ from=partitionSpec RENAME TO to=partitionSpec #renameTablePartition
+ | ALTER TABLE from=tableIdentifier
+ EXCHANGE partitionSpec WITH TABLE to=tableIdentifier #exchangeTablePartition
+ | ALTER TABLE tableIdentifier
+ DROP (IF EXISTS)? partitionSpec (',' partitionSpec)* PURGE? #dropTablePartitions
+ | ALTER TABLE tableIdentifier ARCHIVE partitionSpec #archiveTablePartition
+ | ALTER TABLE tableIdentifier UNARCHIVE partitionSpec #unarchiveTablePartition
+ | ALTER TABLE tableIdentifier partitionSpec?
+ SET FILEFORMAT fileFormat #setTableFileFormat
+ | ALTER TABLE tableIdentifier partitionSpec? SET locationSpec #setTableLocation
+ | ALTER TABLE tableIdentifier TOUCH partitionSpec? #touchTable
+ | ALTER TABLE tableIdentifier partitionSpec? COMPACT STRING #compactTable
+ | ALTER TABLE tableIdentifier partitionSpec? CONCATENATE #concatenateTable
+ | ALTER TABLE tableIdentifier partitionSpec?
+ CHANGE COLUMN? oldName=identifier colType
+ (FIRST | AFTER after=identifier)? (CASCADE | RESTRICT)? #changeColumn
+ | ALTER TABLE tableIdentifier partitionSpec?
+ ADD COLUMNS '(' colTypeList ')' (CASCADE | RESTRICT)? #addColumns
+ | ALTER TABLE tableIdentifier partitionSpec?
+ REPLACE COLUMNS '(' colTypeList ')' (CASCADE | RESTRICT)? #replaceColumns
+ | DROP TABLE (IF EXISTS)? tableIdentifier PURGE?
+ (FOR METADATA? REPLICATION '(' STRING ')')? #dropTable
+ | CREATE (OR REPLACE)? VIEW (IF NOT EXISTS)? tableIdentifier
+ identifierCommentList? (COMMENT STRING)?
+ (PARTITIONED ON identifierList)?
+ (TBLPROPERTIES tablePropertyList)? AS query #createView
+ | ALTER VIEW tableIdentifier AS? query #alterViewQuery
+ | CREATE TEMPORARY? FUNCTION qualifiedName AS className=STRING
+ (USING resource (',' resource)*)? #createFunction
+ | DROP TEMPORARY? FUNCTION (IF EXISTS)? qualifiedName #dropFunction
+ | EXPLAIN explainOption* statement #explain
+ | SHOW TABLES ((FROM | IN) db=identifier)?
+ (LIKE (qualifiedName | pattern=STRING))? #showTables
+ | SHOW FUNCTIONS (LIKE? (qualifiedName | pattern=STRING))? #showFunctions
+ | (DESC | DESCRIBE) FUNCTION EXTENDED? qualifiedName #describeFunction
+ | (DESC | DESCRIBE) option=(EXTENDED | FORMATTED)?
+ tableIdentifier partitionSpec? describeColName? #describeTable
+ | (DESC | DESCRIBE) DATABASE EXTENDED? identifier #describeDatabase
+ | REFRESH TABLE tableIdentifier #refreshTable
+ | CACHE LAZY? TABLE identifier (AS? query)? #cacheTable
+ | UNCACHE TABLE identifier #uncacheTable
+ | CLEAR CACHE #clearCache
+ | ADD identifier .*? #addResource
+ | SET .*? #setConfiguration
+ | hiveNativeCommands #executeNativeCommand
+ ;
+
+hiveNativeCommands
+ : createTableHeader LIKE tableIdentifier
+ rowFormat? createFileFormat? locationSpec?
+ (TBLPROPERTIES tablePropertyList)?
+ | DELETE FROM tableIdentifier (WHERE booleanExpression)?
+ | TRUNCATE TABLE tableIdentifier partitionSpec?
+ (COLUMNS identifierList)?
+ | ALTER VIEW from=tableIdentifier AS? RENAME TO to=tableIdentifier
+ | ALTER VIEW from=tableIdentifier AS?
+ SET TBLPROPERTIES tablePropertyList
+ | ALTER VIEW from=tableIdentifier AS?
+ UNSET TBLPROPERTIES (IF EXISTS)? tablePropertyList
+ | ALTER VIEW from=tableIdentifier AS?
+ ADD (IF NOT EXISTS)? partitionSpecLocation+
+ | ALTER VIEW from=tableIdentifier AS?
+ DROP (IF EXISTS)? partitionSpec (',' partitionSpec)* PURGE?
+ | DROP VIEW (IF EXISTS)? qualifiedName
+ | SHOW COLUMNS (FROM | IN) tableIdentifier ((FROM|IN) identifier)?
+ | START TRANSACTION (transactionMode (',' transactionMode)*)?
+ | COMMIT WORK?
+ | ROLLBACK WORK?
+ | SHOW PARTITIONS tableIdentifier partitionSpec?
+ | DFS .*?
+ | (CREATE | ALTER | DROP | SHOW | DESC | DESCRIBE | REVOKE | GRANT | LOCK | UNLOCK | MSCK | EXPORT | IMPORT | LOAD) .*?
+ ;
+
+createTableHeader
+ : CREATE TEMPORARY? EXTERNAL? TABLE (IF NOT EXISTS)? tableIdentifier
+ ;
+
+bucketSpec
+ : CLUSTERED BY identifierList
+ (SORTED BY orderedIdentifierList)?
+ INTO INTEGER_VALUE BUCKETS
+ ;
+
+skewSpec
+ : SKEWED BY identifierList
+ ON (constantList | nestedConstantList)
+ (STORED AS DIRECTORIES)?
+ ;
+
+locationSpec
+ : LOCATION STRING
+ ;
+
+query
+ : ctes? queryNoWith
+ ;
+
+insertInto
+ : INSERT OVERWRITE TABLE tableIdentifier partitionSpec? (IF NOT EXISTS)?
+ | INSERT INTO TABLE? tableIdentifier partitionSpec?
+ ;
+
+partitionSpecLocation
+ : partitionSpec locationSpec?
+ ;
+
+partitionSpec
+ : PARTITION '(' partitionVal (',' partitionVal)* ')'
+ ;
+
+partitionVal
+ : identifier (EQ constant)?
+ ;
+
+describeColName
+ : identifier ('.' (identifier | STRING))*
+ ;
+
+ctes
+ : WITH namedQuery (',' namedQuery)*
+ ;
+
+namedQuery
+ : name=identifier AS? '(' queryNoWith ')'
+ ;
+
+tableProvider
+ : USING qualifiedName
+ ;
+
+tablePropertyList
+ : '(' tableProperty (',' tableProperty)* ')'
+ ;
+
+tableProperty
+ : key=tablePropertyKey (EQ? value=STRING)?
+ ;
+
+tablePropertyKey
+ : looseIdentifier ('.' looseIdentifier)*
+ | STRING
+ ;
+
+constantList
+ : '(' constant (',' constant)* ')'
+ ;
+
+nestedConstantList
+ : '(' constantList (',' constantList)* ')'
+ ;
+
+skewedLocation
+ : (constant | constantList) EQ STRING
+ ;
+
+skewedLocationList
+ : '(' skewedLocation (',' skewedLocation)* ')'
+ ;
+
+createFileFormat
+ : STORED AS fileFormat
+ | STORED BY storageHandler
+ ;
+
+fileFormat
+ : INPUTFORMAT inFmt=STRING OUTPUTFORMAT outFmt=STRING (SERDE serdeCls=STRING)?
+ (INPUTDRIVER inDriver=STRING OUTPUTDRIVER outDriver=STRING)? #tableFileFormat
+ | identifier #genericFileFormat
+ ;
+
+storageHandler
+ : STRING (WITH SERDEPROPERTIES tablePropertyList)?
+ ;
+
+resource
+ : identifier STRING
+ ;
+
+queryNoWith
+ : insertInto? queryTerm queryOrganization #singleInsertQuery
+ | fromClause multiInsertQueryBody+ #multiInsertQuery
+ ;
+
+queryOrganization
+ : (ORDER BY order+=sortItem (',' order+=sortItem)*)?
+ (CLUSTER BY clusterBy+=expression (',' clusterBy+=expression)*)?
+ (DISTRIBUTE BY distributeBy+=expression (',' distributeBy+=expression)*)?
+ (SORT BY sort+=sortItem (',' sort+=sortItem)*)?
+ windows?
+ (LIMIT limit=expression)?
+ ;
+
+multiInsertQueryBody
+ : insertInto?
+ querySpecification
+ queryOrganization
+ ;
+
+queryTerm
+ : queryPrimary #queryTermDefault
+ | left=queryTerm operator=(INTERSECT | UNION | EXCEPT) setQuantifier? right=queryTerm #setOperation
+ ;
+
+queryPrimary
+ : querySpecification #queryPrimaryDefault
+ | TABLE tableIdentifier #table
+ | inlineTable #inlineTableDefault1
+ | '(' queryNoWith ')' #subquery
+ ;
+
+sortItem
+ : expression ordering=(ASC | DESC)?
+ ;
+
+querySpecification
+ : (((SELECT kind=TRANSFORM '(' namedExpressionSeq ')'
+ | kind=MAP namedExpressionSeq
+ | kind=REDUCE namedExpressionSeq))
+ inRowFormat=rowFormat?
+ (RECORDWRITER recordWriter=STRING)?
+ USING script=STRING
+ (AS (identifierSeq | colTypeList | ('(' (identifierSeq | colTypeList) ')')))?
+ outRowFormat=rowFormat?
+ (RECORDREADER recordReader=STRING)?
+ fromClause?
+ (WHERE where=booleanExpression)?)
+ | ((kind=SELECT setQuantifier? namedExpressionSeq fromClause?
+ | fromClause (kind=SELECT setQuantifier? namedExpressionSeq)?)
+ lateralView*
+ (WHERE where=booleanExpression)?
+ aggregation?
+ (HAVING having=booleanExpression)?
+ windows?)
+ ;
+
+fromClause
+ : FROM relation (',' relation)* lateralView*
+ ;
+
+aggregation
+ : GROUP BY groupingExpressions+=expression (',' groupingExpressions+=expression)* (
+ WITH kind=ROLLUP
+ | WITH kind=CUBE
+ | kind=GROUPING SETS '(' groupingSet (',' groupingSet)* ')')?
+ ;
+
+groupingSet
+ : '(' (expression (',' expression)*)? ')'
+ | expression
+ ;
+
+lateralView
+ : LATERAL VIEW (OUTER)? qualifiedName '(' (expression (',' expression)*)? ')' tblName=identifier (AS? colName+=identifier (',' colName+=identifier)*)?
+ ;
+
+setQuantifier
+ : DISTINCT
+ | ALL
+ ;
+
+relation
+ : left=relation
+ ((CROSS | joinType) JOIN right=relation joinCriteria?
+ | NATURAL joinType JOIN right=relation
+ ) #joinRelation
+ | relationPrimary #relationDefault
+ ;
+
+joinType
+ : INNER?
+ | LEFT OUTER?
+ | LEFT SEMI
+ | RIGHT OUTER?
+ | FULL OUTER?
+ ;
+
+joinCriteria
+ : ON booleanExpression
+ | USING '(' identifier (',' identifier)* ')'
+ ;
+
+sample
+ : TABLESAMPLE '('
+ ( (percentage=(INTEGER_VALUE | DECIMAL_VALUE) sampleType=PERCENTLIT)
+ | (expression sampleType=ROWS)
+ | (sampleType=BUCKET numerator=INTEGER_VALUE OUT OF denominator=INTEGER_VALUE (ON identifier)?))
+ ')'
+ ;
+
+identifierList
+ : '(' identifierSeq ')'
+ ;
+
+identifierSeq
+ : identifier (',' identifier)*
+ ;
+
+orderedIdentifierList
+ : '(' orderedIdentifier (',' orderedIdentifier)* ')'
+ ;
+
+orderedIdentifier
+ : identifier ordering=(ASC | DESC)?
+ ;
+
+identifierCommentList
+ : '(' identifierComment (',' identifierComment)* ')'
+ ;
+
+identifierComment
+ : identifier (COMMENT STRING)?
+ ;
+
+relationPrimary
+ : tableIdentifier sample? (AS? identifier)? #tableName
+ | '(' queryNoWith ')' sample? (AS? identifier)? #aliasedQuery
+ | '(' relation ')' sample? (AS? identifier)? #aliasedRelation
+ | inlineTable #inlineTableDefault2
+ ;
+
+inlineTable
+ : VALUES expression (',' expression)* (AS? identifier identifierList?)?
+ ;
+
+rowFormat
+ : ROW FORMAT SERDE name=STRING (WITH SERDEPROPERTIES props=tablePropertyList)? #rowFormatSerde
+ | ROW FORMAT DELIMITED
+ (FIELDS TERMINATED BY fieldsTerminatedBy=STRING (ESCAPED BY escapedBy=STRING)?)?
+ (COLLECTION ITEMS TERMINATED BY collectionItemsTerminatedBy=STRING)?
+ (MAP KEYS TERMINATED BY keysTerminatedBy=STRING)?
+ (LINES TERMINATED BY linesSeparatedBy=STRING)?
+ (NULL DEFINED AS nullDefinedAs=STRING)? #rowFormatDelimited
+ ;
+
+tableIdentifier
+ : (db=identifier '.')? table=identifier
+ ;
+
+namedExpression
+ : expression (AS? (identifier | identifierList))?
+ ;
+
+namedExpressionSeq
+ : namedExpression (',' namedExpression)*
+ ;
+
+expression
+ : booleanExpression
+ ;
+
+booleanExpression
+ : predicated #booleanDefault
+ | NOT booleanExpression #logicalNot
+ | left=booleanExpression operator=AND right=booleanExpression #logicalBinary
+ | left=booleanExpression operator=OR right=booleanExpression #logicalBinary
+ | EXISTS '(' query ')' #exists
+ ;
+
+// workaround for:
+// https://github.com/antlr/antlr4/issues/780
+// https://github.com/antlr/antlr4/issues/781
+predicated
+ : valueExpression predicate[$valueExpression.ctx]?
+ ;
+
+predicate[ParserRuleContext value]
+ : NOT? BETWEEN lower=valueExpression AND upper=valueExpression #between
+ | NOT? IN '(' expression (',' expression)* ')' #inList
+ | NOT? IN '(' query ')' #inSubquery
+ | NOT? like=(RLIKE | LIKE) pattern=valueExpression #like
+ | IS NOT? NULL #nullPredicate
+ ;
+
+valueExpression
+ : primaryExpression #valueExpressionDefault
+ | operator=(MINUS | PLUS | TILDE) valueExpression #arithmeticUnary
+ | left=valueExpression operator=(ASTERISK | SLASH | PERCENT | DIV) right=valueExpression #arithmeticBinary
+ | left=valueExpression operator=(PLUS | MINUS) right=valueExpression #arithmeticBinary
+ | left=valueExpression operator=AMPERSAND right=valueExpression #arithmeticBinary
+ | left=valueExpression operator=HAT right=valueExpression #arithmeticBinary
+ | left=valueExpression operator=PIPE right=valueExpression #arithmeticBinary
+ | left=valueExpression comparisonOperator right=valueExpression #comparison
+ ;
+
+primaryExpression
+ : constant #constantDefault
+ | ASTERISK #star
+ | qualifiedName '.' ASTERISK #star
+ | '(' expression (',' expression)+ ')' #rowConstructor
+ | qualifiedName '(' (setQuantifier? expression (',' expression)*)? ')' (OVER windowSpec)? #functionCall
+ | '(' query ')' #subqueryExpression
+ | CASE valueExpression whenClause+ (ELSE elseExpression=expression)? END #simpleCase
+ | CASE whenClause+ (ELSE elseExpression=expression)? END #searchedCase
+ | CAST '(' expression AS dataType ')' #cast
+ | value=primaryExpression '[' index=valueExpression ']' #subscript
+ | identifier #columnReference
+ | base=primaryExpression '.' fieldName=identifier #dereference
+ | '(' expression ')' #parenthesizedExpression
+ ;
+
+constant
+ : NULL #nullLiteral
+ | interval #intervalLiteral
+ | identifier STRING #typeConstructor
+ | number #numericLiteral
+ | booleanValue #booleanLiteral
+ | STRING+ #stringLiteral
+ ;
+
+comparisonOperator
+ : EQ | NEQ | NEQJ | LT | LTE | GT | GTE | NSEQ
+ ;
+
+booleanValue
+ : TRUE | FALSE
+ ;
+
+interval
+ : INTERVAL intervalField*
+ ;
+
+intervalField
+ : value=intervalValue unit=identifier (TO to=identifier)?
+ ;
+
+intervalValue
+ : (PLUS | MINUS)? (INTEGER_VALUE | DECIMAL_VALUE)
+ | STRING
+ ;
+
+dataType
+ : complex=ARRAY '<' dataType '>' #complexDataType
+ | complex=MAP '<' dataType ',' dataType '>' #complexDataType
+ | complex=STRUCT ('<' colTypeList? '>' | NEQ) #complexDataType
+ | identifier ('(' INTEGER_VALUE (',' INTEGER_VALUE)* ')')? #primitiveDataType
+ ;
+
+colTypeList
+ : colType (',' colType)*
+ ;
+
+colType
+ : identifier ':'? dataType (COMMENT STRING)?
+ ;
+
+whenClause
+ : WHEN condition=expression THEN result=expression
+ ;
+
+windows
+ : WINDOW namedWindow (',' namedWindow)*
+ ;
+
+namedWindow
+ : identifier AS windowSpec
+ ;
+
+windowSpec
+ : name=identifier #windowRef
+ | '('
+ ( CLUSTER BY partition+=expression (',' partition+=expression)*
+ | ((PARTITION | DISTRIBUTE) BY partition+=expression (',' partition+=expression)*)?
+ ((ORDER | SORT) BY sortItem (',' sortItem)*)?)
+ windowFrame?
+ ')' #windowDef
+ ;
+
+windowFrame
+ : frameType=RANGE start=frameBound
+ | frameType=ROWS start=frameBound
+ | frameType=RANGE BETWEEN start=frameBound AND end=frameBound
+ | frameType=ROWS BETWEEN start=frameBound AND end=frameBound
+ ;
+
+frameBound
+ : UNBOUNDED boundType=(PRECEDING | FOLLOWING)
+ | boundType=CURRENT ROW
+ | expression boundType=(PRECEDING | FOLLOWING)
+ ;
+
+
+explainOption
+ : LOGICAL | FORMATTED | EXTENDED
+ ;
+
+transactionMode
+ : ISOLATION LEVEL SNAPSHOT #isolationLevel
+ | READ accessMode=(ONLY | WRITE) #transactionAccessMode
+ ;
+
+qualifiedName
+ : identifier ('.' identifier)*
+ ;
+
+// Identifier that also allows the use of a number of SQL keywords (mainly for backwards compatibility).
+looseIdentifier
+ : identifier
+ | FROM
+ | TO
+ | TABLE
+ | WITH
+ ;
+
+identifier
+ : IDENTIFIER #unquotedIdentifier
+ | quotedIdentifier #quotedIdentifierAlternative
+ | nonReserved #unquotedIdentifier
+ ;
+
+quotedIdentifier
+ : BACKQUOTED_IDENTIFIER
+ ;
+
+number
+ : DECIMAL_VALUE #decimalLiteral
+ | SCIENTIFIC_DECIMAL_VALUE #scientificDecimalLiteral
+ | INTEGER_VALUE #integerLiteral
+ | BIGINT_LITERAL #bigIntLiteral
+ | SMALLINT_LITERAL #smallIntLiteral
+ | TINYINT_LITERAL #tinyIntLiteral
+ | DOUBLE_LITERAL #doubleLiteral
+ ;
+
+nonReserved
+ : SHOW | TABLES | COLUMNS | COLUMN | PARTITIONS | FUNCTIONS
+ | ADD
+ | OVER | PARTITION | RANGE | ROWS | PRECEDING | FOLLOWING | CURRENT | ROW | MAP | ARRAY | STRUCT
+ | LATERAL | WINDOW | REDUCE | TRANSFORM | USING | SERDE | SERDEPROPERTIES | RECORDREADER
+ | DELIMITED | FIELDS | TERMINATED | COLLECTION | ITEMS | KEYS | ESCAPED | LINES | SEPARATED
+ | EXTENDED | REFRESH | CLEAR | CACHE | UNCACHE | LAZY | TEMPORARY | OPTIONS
+ | GROUPING | CUBE | ROLLUP
+ | EXPLAIN | FORMAT | LOGICAL | FORMATTED
+ | TABLESAMPLE | USE | TO | BUCKET | PERCENTLIT | OUT | OF
+ | SET
+ | VIEW | REPLACE
+ | IF
+ | NO | DATA
+ | START | TRANSACTION | COMMIT | ROLLBACK | WORK | ISOLATION | LEVEL
+ | SNAPSHOT | READ | WRITE | ONLY
+ | SORT | CLUSTER | DISTRIBUTE | UNSET | TBLPROPERTIES | SKEWED | STORED | DIRECTORIES | LOCATION
+ | EXCHANGE | ARCHIVE | UNARCHIVE | FILEFORMAT | TOUCH | COMPACT | CONCATENATE | CHANGE | FIRST
+ | AFTER | CASCADE | RESTRICT | BUCKETS | CLUSTERED | SORTED | PURGE | INPUTFORMAT | OUTPUTFORMAT
+ | INPUTDRIVER | OUTPUTDRIVER | DBPROPERTIES | DFS | TRUNCATE | METADATA | REPLICATION | COMPUTE
+ | STATISTICS | ANALYZE | PARTITIONED | EXTERNAL | DEFINED | RECORDWRITER
+ | REVOKE | GRANT | LOCK | UNLOCK | MSCK | EXPORT | IMPORT | LOAD | VALUES | COMMENT
+ ;
+
+SELECT: 'SELECT';
+FROM: 'FROM';
+ADD: 'ADD';
+AS: 'AS';
+ALL: 'ALL';
+DISTINCT: 'DISTINCT';
+WHERE: 'WHERE';
+GROUP: 'GROUP';
+BY: 'BY';
+GROUPING: 'GROUPING';
+SETS: 'SETS';
+CUBE: 'CUBE';
+ROLLUP: 'ROLLUP';
+ORDER: 'ORDER';
+HAVING: 'HAVING';
+LIMIT: 'LIMIT';
+AT: 'AT';
+OR: 'OR';
+AND: 'AND';
+IN: 'IN';
+NOT: 'NOT' | '!';
+NO: 'NO';
+EXISTS: 'EXISTS';
+BETWEEN: 'BETWEEN';
+LIKE: 'LIKE';
+RLIKE: 'RLIKE' | 'REGEXP';
+IS: 'IS';
+NULL: 'NULL';
+TRUE: 'TRUE';
+FALSE: 'FALSE';
+NULLS: 'NULLS';
+ASC: 'ASC';
+DESC: 'DESC';
+FOR: 'FOR';
+INTERVAL: 'INTERVAL';
+CASE: 'CASE';
+WHEN: 'WHEN';
+THEN: 'THEN';
+ELSE: 'ELSE';
+END: 'END';
+JOIN: 'JOIN';
+CROSS: 'CROSS';
+OUTER: 'OUTER';
+INNER: 'INNER';
+LEFT: 'LEFT';
+SEMI: 'SEMI';
+RIGHT: 'RIGHT';
+FULL: 'FULL';
+NATURAL: 'NATURAL';
+ON: 'ON';
+LATERAL: 'LATERAL';
+WINDOW: 'WINDOW';
+OVER: 'OVER';
+PARTITION: 'PARTITION';
+RANGE: 'RANGE';
+ROWS: 'ROWS';
+UNBOUNDED: 'UNBOUNDED';
+PRECEDING: 'PRECEDING';
+FOLLOWING: 'FOLLOWING';
+CURRENT: 'CURRENT';
+ROW: 'ROW';
+WITH: 'WITH';
+VALUES: 'VALUES';
+CREATE: 'CREATE';
+TABLE: 'TABLE';
+VIEW: 'VIEW';
+REPLACE: 'REPLACE';
+INSERT: 'INSERT';
+DELETE: 'DELETE';
+INTO: 'INTO';
+DESCRIBE: 'DESCRIBE';
+EXPLAIN: 'EXPLAIN';
+FORMAT: 'FORMAT';
+LOGICAL: 'LOGICAL';
+CAST: 'CAST';
+SHOW: 'SHOW';
+TABLES: 'TABLES';
+COLUMNS: 'COLUMNS';
+COLUMN: 'COLUMN';
+USE: 'USE';
+PARTITIONS: 'PARTITIONS';
+FUNCTIONS: 'FUNCTIONS';
+DROP: 'DROP';
+UNION: 'UNION';
+EXCEPT: 'EXCEPT';
+INTERSECT: 'INTERSECT';
+TO: 'TO';
+TABLESAMPLE: 'TABLESAMPLE';
+STRATIFY: 'STRATIFY';
+ALTER: 'ALTER';
+RENAME: 'RENAME';
+ARRAY: 'ARRAY';
+MAP: 'MAP';
+STRUCT: 'STRUCT';
+COMMENT: 'COMMENT';
+SET: 'SET';
+DATA: 'DATA';
+START: 'START';
+TRANSACTION: 'TRANSACTION';
+COMMIT: 'COMMIT';
+ROLLBACK: 'ROLLBACK';
+WORK: 'WORK';
+ISOLATION: 'ISOLATION';
+LEVEL: 'LEVEL';
+SNAPSHOT: 'SNAPSHOT';
+READ: 'READ';
+WRITE: 'WRITE';
+ONLY: 'ONLY';
+
+IF: 'IF';
+
+EQ : '=' | '==';
+NSEQ: '<=>';
+NEQ : '<>';
+NEQJ: '!=';
+LT : '<';
+LTE : '<=';
+GT : '>';
+GTE : '>=';
+
+PLUS: '+';
+MINUS: '-';
+ASTERISK: '*';
+SLASH: '/';
+PERCENT: '%';
+DIV: 'DIV';
+TILDE: '~';
+AMPERSAND: '&';
+PIPE: '|';
+HAT: '^';
+
+PERCENTLIT: 'PERCENT';
+BUCKET: 'BUCKET';
+OUT: 'OUT';
+OF: 'OF';
+
+SORT: 'SORT';
+CLUSTER: 'CLUSTER';
+DISTRIBUTE: 'DISTRIBUTE';
+OVERWRITE: 'OVERWRITE';
+TRANSFORM: 'TRANSFORM';
+REDUCE: 'REDUCE';
+USING: 'USING';
+SERDE: 'SERDE';
+SERDEPROPERTIES: 'SERDEPROPERTIES';
+RECORDREADER: 'RECORDREADER';
+RECORDWRITER: 'RECORDWRITER';
+DELIMITED: 'DELIMITED';
+FIELDS: 'FIELDS';
+TERMINATED: 'TERMINATED';
+COLLECTION: 'COLLECTION';
+ITEMS: 'ITEMS';
+KEYS: 'KEYS';
+ESCAPED: 'ESCAPED';
+LINES: 'LINES';
+SEPARATED: 'SEPARATED';
+FUNCTION: 'FUNCTION';
+EXTENDED: 'EXTENDED';
+REFRESH: 'REFRESH';
+CLEAR: 'CLEAR';
+CACHE: 'CACHE';
+UNCACHE: 'UNCACHE';
+LAZY: 'LAZY';
+FORMATTED: 'FORMATTED';
+TEMPORARY: 'TEMPORARY' | 'TEMP';
+OPTIONS: 'OPTIONS';
+UNSET: 'UNSET';
+TBLPROPERTIES: 'TBLPROPERTIES';
+DBPROPERTIES: 'DBPROPERTIES';
+BUCKETS: 'BUCKETS';
+SKEWED: 'SKEWED';
+STORED: 'STORED';
+DIRECTORIES: 'DIRECTORIES';
+LOCATION: 'LOCATION';
+EXCHANGE: 'EXCHANGE';
+ARCHIVE: 'ARCHIVE';
+UNARCHIVE: 'UNARCHIVE';
+FILEFORMAT: 'FILEFORMAT';
+TOUCH: 'TOUCH';
+COMPACT: 'COMPACT';
+CONCATENATE: 'CONCATENATE';
+CHANGE: 'CHANGE';
+FIRST: 'FIRST';
+AFTER: 'AFTER';
+CASCADE: 'CASCADE';
+RESTRICT: 'RESTRICT';
+CLUSTERED: 'CLUSTERED';
+SORTED: 'SORTED';
+PURGE: 'PURGE';
+INPUTFORMAT: 'INPUTFORMAT';
+OUTPUTFORMAT: 'OUTPUTFORMAT';
+INPUTDRIVER: 'INPUTDRIVER';
+OUTPUTDRIVER: 'OUTPUTDRIVER';
+DATABASE: 'DATABASE' | 'SCHEMA';
+DFS: 'DFS';
+TRUNCATE: 'TRUNCATE';
+METADATA: 'METADATA';
+REPLICATION: 'REPLICATION';
+ANALYZE: 'ANALYZE';
+COMPUTE: 'COMPUTE';
+STATISTICS: 'STATISTICS';
+PARTITIONED: 'PARTITIONED';
+EXTERNAL: 'EXTERNAL';
+DEFINED: 'DEFINED';
+REVOKE: 'REVOKE';
+GRANT: 'GRANT';
+LOCK: 'LOCK';
+UNLOCK: 'UNLOCK';
+MSCK: 'MSCK';
+EXPORT: 'EXPORT';
+IMPORT: 'IMPORT';
+LOAD: 'LOAD';
+
+STRING
+ : '\'' ( ~('\''|'\\') | ('\\' .) )* '\''
+ | '\"' ( ~('\"'|'\\') | ('\\' .) )* '\"'
+ ;
+
+BIGINT_LITERAL
+ : DIGIT+ 'L'
+ ;
+
+SMALLINT_LITERAL
+ : DIGIT+ 'S'
+ ;
+
+TINYINT_LITERAL
+ : DIGIT+ 'Y'
+ ;
+
+INTEGER_VALUE
+ : DIGIT+
+ ;
+
+DECIMAL_VALUE
+ : DIGIT+ '.' DIGIT*
+ | '.' DIGIT+
+ ;
+
+SCIENTIFIC_DECIMAL_VALUE
+ : DIGIT+ ('.' DIGIT*)? EXPONENT
+ | '.' DIGIT+ EXPONENT
+ ;
+
+DOUBLE_LITERAL
+ :
+ (INTEGER_VALUE | DECIMAL_VALUE | SCIENTIFIC_DECIMAL_VALUE) 'D'
+ ;
+
+IDENTIFIER
+ : (LETTER | DIGIT | '_')+
+ ;
+
+BACKQUOTED_IDENTIFIER
+ : '`' ( ~'`' | '``' )* '`'
+ ;
+
+fragment EXPONENT
+ : 'E' [+-]? DIGIT+
+ ;
+
+fragment DIGIT
+ : [0-9]
+ ;
+
+fragment LETTER
+ : [A-Z]
+ ;
+
+SIMPLE_COMMENT
+ : '--' ~[\r\n]* '\r'? '\n'? -> channel(HIDDEN)
+ ;
+
+BRACKETED_COMMENT
+ : '/*' .*? '*/' -> channel(HIDDEN)
+ ;
+
+WS
+ : [ \r\n\t]+ -> channel(HIDDEN)
+ ;
+
+// Catch-all for anything we can't recognize.
+// We use this to be able to ignore and recover all the text
+// when splitting statements with DelimiterLexer
+UNRECOGNIZED
+ : .
+ ;
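Since the build turns on `antlr4GenVisitor`, ANTLR also emits a `SqlBaseBaseVisitor[T]` for this grammar, which is exactly what the new `AstBuilder` extends. As a hedged illustration of that visitor pattern (not part of the patch), a small custom visitor over the same generated tree could collect every table reference like this:

```scala
import scala.collection.mutable

import org.apache.spark.sql.catalyst.parser.ng.{SqlBaseBaseVisitor, SqlBaseParser}

// Illustrative only: SqlBaseBaseVisitor and TableIdentifierContext are generated from SqlBase.g4.
class TableNameCollector extends SqlBaseBaseVisitor[AnyRef] {
  val tables = mutable.ArrayBuffer.empty[String]

  override def visitTableIdentifier(ctx: SqlBaseParser.TableIdentifierContext): AnyRef = {
    // The db= and table= labels in the tableIdentifier rule become fields on the context.
    val prefix = Option(ctx.db).map(_.getText + ".").getOrElse("")
    tables += prefix + ctx.table.getText
    super.visitTableIdentifier(ctx) // default implementation keeps visiting children
  }
}
```

Feeding a parsed `singleStatement()` tree to `tree.accept(new TableNameCollector)` walks the whole statement and records each `tableIdentifier` it encounters; the `AstBuilder` below uses the same mechanism to build plans and expressions.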
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/dsl/package.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/dsl/package.scala
index 3540014c3e..105947028d 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/dsl/package.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/dsl/package.scala
@@ -21,7 +21,7 @@ import java.sql.{Date, Timestamp}
import scala.language.implicitConversions
-import org.apache.spark.sql.catalyst.analysis.{EliminateSubqueryAliases, UnresolvedAttribute, UnresolvedExtractValue}
+import org.apache.spark.sql.catalyst.analysis._
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.expressions.aggregate._
import org.apache.spark.sql.catalyst.plans.{Inner, JoinType}
@@ -161,6 +161,10 @@ package object dsl {
def lower(e: Expression): Expression = Lower(e)
def sqrt(e: Expression): Expression = Sqrt(e)
def abs(e: Expression): Expression = Abs(e)
+ def star(names: String*): Expression = names match {
+ case Seq() => UnresolvedStar(None)
+ case target => UnresolvedStar(Option(target))
+ }
implicit class DslSymbol(sym: Symbol) extends ImplicitAttribute { def s: String = sym.name }
// TODO more implicit class for literal?
@@ -231,6 +235,12 @@ package object dsl {
AttributeReference(s, structType, nullable = true)()
def struct(attrs: AttributeReference*): AttributeReference =
struct(StructType.fromAttributes(attrs))
+
+ /** Create a function. */
+ def function(exprs: Expression*): UnresolvedFunction =
+ UnresolvedFunction(s, exprs, isDistinct = false)
+ def distinctFunction(exprs: Expression*): UnresolvedFunction =
+ UnresolvedFunction(s, exprs, isDistinct = true)
}
implicit class DslAttribute(a: AttributeReference) {
@@ -243,8 +253,20 @@ package object dsl {
object expressions extends ExpressionConversions // scalastyle:ignore
object plans { // scalastyle:ignore
+ def table(ref: String): LogicalPlan =
+ UnresolvedRelation(TableIdentifier(ref), None)
+
+ def table(db: String, ref: String): LogicalPlan =
+ UnresolvedRelation(TableIdentifier(ref, Option(db)), None)
+
implicit class DslLogicalPlan(val logicalPlan: LogicalPlan) {
- def select(exprs: NamedExpression*): LogicalPlan = Project(exprs, logicalPlan)
+ def select(exprs: Expression*): LogicalPlan = {
+ val namedExpressions = exprs.map {
+ case e: NamedExpression => e
+ case e => UnresolvedAlias(e)
+ }
+ Project(namedExpressions, logicalPlan)
+ }
def where(condition: Expression): LogicalPlan = Filter(condition, logicalPlan)
@@ -296,6 +318,14 @@ package object dsl {
analysis.UnresolvedRelation(TableIdentifier(tableName)),
Map.empty, logicalPlan, overwrite, false)
+ def as(alias: String): LogicalPlan = logicalPlan match {
+ case UnresolvedRelation(tbl, _) => UnresolvedRelation(tbl, Option(alias))
+ case plan => SubqueryAlias(alias, plan)
+ }
+
+ def distribute(exprs: Expression*): LogicalPlan =
+ RepartitionByExpression(exprs, logicalPlan)
+
def analyze: LogicalPlan =
EliminateSubqueryAliases(analysis.SimpleAnalyzer.execute(logicalPlan))
}
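The DSL additions above (`table`, `star`, `function`/`distinctFunction`, `as`, `distribute`, and the relaxed `select`) are mostly there so the new parser test suites can state expected plans tersely. A rough sketch of how they compose (assumed for illustration, not copied from the test files, which are outside this excerpt):

```scala
import org.apache.spark.sql.catalyst.dsl.expressions._
import org.apache.spark.sql.catalyst.dsl.plans._

object DslSketch {
  // Roughly the unresolved plan the new parser should produce for:
  //   SELECT a, b FROM db.t AS x WHERE a = 1
  val expected =
    table("db", "t")   // UnresolvedRelation for db.t
      .as("x")         // attach the alias to the relation
      .where('a === 1) // Filter
      .select('a, 'b)  // Project ('a, 'b become UnresolvedAttributes; other
                       //  expressions would be wrapped in an UnresolvedAlias)
}
```

The `star(...)`, `function(...)`, and `distribute(...)` helpers slot into the same chains when a test needs `SELECT *`, function calls, or a `DISTRIBUTE BY` clause.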
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/ng/AstBuilder.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/ng/AstBuilder.scala
new file mode 100644
index 0000000000..5a64c414fb
--- /dev/null
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/ng/AstBuilder.scala
@@ -0,0 +1,1452 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.spark.sql.catalyst.parser.ng
+
+import java.sql.{Date, Timestamp}
+
+import scala.collection.JavaConverters._
+import scala.collection.mutable.ArrayBuffer
+
+import org.antlr.v4.runtime.{ParserRuleContext, Token}
+import org.antlr.v4.runtime.tree.{ParseTree, TerminalNode}
+
+import org.apache.spark.internal.Logging
+import org.apache.spark.sql.catalyst.{InternalRow, TableIdentifier}
+import org.apache.spark.sql.catalyst.analysis._
+import org.apache.spark.sql.catalyst.expressions._
+import org.apache.spark.sql.catalyst.parser.ng.SqlBaseParser._
+import org.apache.spark.sql.catalyst.plans._
+import org.apache.spark.sql.catalyst.plans.logical._
+import org.apache.spark.sql.types._
+import org.apache.spark.unsafe.types.CalendarInterval
+import org.apache.spark.util.random.RandomSampler
+
+/**
+ * The AstBuilder converts an ANTLR4 ParseTree into a catalyst Expression, LogicalPlan or
+ * TableIdentifier.
+ */
+class AstBuilder extends SqlBaseBaseVisitor[AnyRef] with Logging {
+ import ParserUtils._
+
+ protected def typedVisit[T](ctx: ParseTree): T = {
+ ctx.accept(this).asInstanceOf[T]
+ }
+
+ override def visitSingleStatement(ctx: SingleStatementContext): LogicalPlan = withOrigin(ctx) {
+ visit(ctx.statement).asInstanceOf[LogicalPlan]
+ }
+
+ override def visitSingleExpression(ctx: SingleExpressionContext): Expression = withOrigin(ctx) {
+ visitNamedExpression(ctx.namedExpression)
+ }
+
+ override def visitSingleTableIdentifier(
+ ctx: SingleTableIdentifierContext): TableIdentifier = withOrigin(ctx) {
+ visitTableIdentifier(ctx.tableIdentifier)
+ }
+
+ override def visitSingleDataType(ctx: SingleDataTypeContext): DataType = withOrigin(ctx) {
+ visit(ctx.dataType).asInstanceOf[DataType]
+ }
+
+ /* ********************************************************************************************
+ * Plan parsing
+ * ******************************************************************************************** */
+ protected def plan(tree: ParserRuleContext): LogicalPlan = typedVisit(tree)
+
+ /**
+ * Make sure we do not try to create a plan for a native command.
+ */
+ override def visitExecuteNativeCommand(ctx: ExecuteNativeCommandContext): LogicalPlan = null
+
+ /**
+ * Create a plan for a SHOW FUNCTIONS command.
+ */
+ override def visitShowFunctions(ctx: ShowFunctionsContext): LogicalPlan = withOrigin(ctx) {
+ import ctx._
+ if (qualifiedName != null) {
+ val names = qualifiedName().identifier().asScala.map(_.getText).toList
+ names match {
+ case db :: name :: Nil =>
+ ShowFunctions(Some(db), Some(name))
+ case name :: Nil =>
+ ShowFunctions(None, Some(name))
+ case _ =>
+ throw new ParseException("SHOW FUNCTIONS unsupported name", ctx)
+ }
+ } else if (pattern != null) {
+ ShowFunctions(None, Some(string(pattern)))
+ } else {
+ ShowFunctions(None, None)
+ }
+ }
+
+ /**
+ * Create a plan for a DESCRIBE FUNCTION command.
+ */
+ override def visitDescribeFunction(ctx: DescribeFunctionContext): LogicalPlan = withOrigin(ctx) {
+ val functionName = ctx.qualifiedName().identifier().asScala.map(_.getText).mkString(".")
+ DescribeFunction(functionName, ctx.EXTENDED != null)
+ }
+
+ /**
+ * Create a top-level plan with Common Table Expressions.
+ */
+ override def visitQuery(ctx: QueryContext): LogicalPlan = withOrigin(ctx) {
+ val query = plan(ctx.queryNoWith)
+
+ // Apply CTEs
+ query.optional(ctx.ctes) {
+ val ctes = ctx.ctes.namedQuery.asScala.map {
+ case nCtx =>
+ val namedQuery = visitNamedQuery(nCtx)
+ (namedQuery.alias, namedQuery)
+ }
+
+ // Check for duplicate names.
+ ctes.groupBy(_._1).filter(_._2.size > 1).foreach {
+ case (name, _) =>
+ throw new ParseException(
+ s"Name '$name' is used for multiple common table expressions", ctx)
+ }
+
+ With(query, ctes.toMap)
+ }
+ }
+
+ /**
+ * Create a named logical plan.
+ *
+ * This is only used for Common Table Expressions.
+ */
+ override def visitNamedQuery(ctx: NamedQueryContext): SubqueryAlias = withOrigin(ctx) {
+ SubqueryAlias(ctx.name.getText, plan(ctx.queryNoWith))
+ }
+
+ /**
+ * Create a logical plan which allows for multiple inserts using one 'from' statement. These
+ * queries have the following SQL form:
+ * {{{
+ * [WITH cte...]?
+ * FROM src
+ * [INSERT INTO tbl1 SELECT *]+
+ * }}}
+ * For example:
+ * {{{
+ * FROM db.tbl1 A
+ * INSERT INTO dbo.tbl1 SELECT * WHERE A.value = 10 LIMIT 5
+ * INSERT INTO dbo.tbl2 SELECT * WHERE A.value = 12
+ * }}}
+ * This (Hive) feature cannot be combined with set-operators.
+ */
+ override def visitMultiInsertQuery(ctx: MultiInsertQueryContext): LogicalPlan = withOrigin(ctx) {
+ val from = visitFromClause(ctx.fromClause)
+
+ // Build the insert clauses.
+ val inserts = ctx.multiInsertQueryBody.asScala.map {
+ body =>
+ assert(body.querySpecification.fromClause == null,
+ "Multi-Insert queries cannot have a FROM clause in their individual SELECT statements",
+ body)
+
+ withQuerySpecification(body.querySpecification, from).
+ // Add organization statements.
+ optionalMap(body.queryOrganization)(withQueryResultClauses).
+ // Add insert.
+ optionalMap(body.insertInto())(withInsertInto)
+ }
+
+ // If there are multiple INSERTS just UNION them together into one query.
+ inserts match {
+ case Seq(query) => query
+ case queries => Union(queries)
+ }
+ }
+
+ /**
+ * Create a logical plan for a regular (single-insert) query.
+ */
+ override def visitSingleInsertQuery(
+ ctx: SingleInsertQueryContext): LogicalPlan = withOrigin(ctx) {
+ plan(ctx.queryTerm).
+ // Add organization statements.
+ optionalMap(ctx.queryOrganization)(withQueryResultClauses).
+ // Add insert.
+ optionalMap(ctx.insertInto())(withInsertInto)
+ }
+
+ /**
+ * Add an INSERT INTO [TABLE]/INSERT OVERWRITE TABLE operation to the logical plan.
+ */
+ private def withInsertInto(
+ ctx: InsertIntoContext,
+ query: LogicalPlan): LogicalPlan = withOrigin(ctx) {
+ val tableIdent = visitTableIdentifier(ctx.tableIdentifier)
+ val partitionKeys = Option(ctx.partitionSpec).map(visitPartitionSpec).getOrElse(Map.empty)
+
+ InsertIntoTable(
+ UnresolvedRelation(tableIdent, None),
+ partitionKeys,
+ query,
+ ctx.OVERWRITE != null,
+ ctx.EXISTS != null)
+ }
+
+ /**
+ * Create a partition specification map.
+ */
+ override def visitPartitionSpec(
+ ctx: PartitionSpecContext): Map[String, Option[String]] = withOrigin(ctx) {
+ ctx.partitionVal.asScala.map { pVal =>
+ val name = pVal.identifier.getText.toLowerCase
+ val value = Option(pVal.constant).map(visitStringConstant)
+ name -> value
+ }.toMap
+ }
+
+ /**
+ * Create a partition specification map without optional values.
+ */
+ protected def visitNonOptionalPartitionSpec(
+ ctx: PartitionSpecContext): Map[String, String] = withOrigin(ctx) {
+ visitPartitionSpec(ctx).mapValues(_.orNull).map(identity)
+ }
+
+ /**
+ * Convert a constant of any type into a string. This is typically used in DDL commands, and its
+ * main purpose is to prevent slight differences due to back to back conversions i.e.:
+ * String -> Literal -> String.
+ */
+ protected def visitStringConstant(ctx: ConstantContext): String = withOrigin(ctx) {
+ ctx match {
+ case s: StringLiteralContext => createString(s)
+ case o => o.getText
+ }
+ }
+
+ /**
+ * Add ORDER BY/SORT BY/CLUSTER BY/DISTRIBUTE BY/LIMIT/WINDOWS clauses to the logical plan. These
+ * clauses determine the shape (ordering/partitioning/rows) of the query result.
+ */
+ private def withQueryResultClauses(
+ ctx: QueryOrganizationContext,
+ query: LogicalPlan): LogicalPlan = withOrigin(ctx) {
+ import ctx._
+
+ // Handle ORDER BY, SORT BY, DISTRIBUTE BY, and CLUSTER BY clause.
+ val withOrder = if (
+ !order.isEmpty && sort.isEmpty && distributeBy.isEmpty && clusterBy.isEmpty) {
+ // ORDER BY ...
+ Sort(order.asScala.map(visitSortItem), global = true, query)
+ } else if (order.isEmpty && !sort.isEmpty && distributeBy.isEmpty && clusterBy.isEmpty) {
+ // SORT BY ...
+ Sort(sort.asScala.map(visitSortItem), global = false, query)
+ } else if (order.isEmpty && sort.isEmpty && !distributeBy.isEmpty && clusterBy.isEmpty) {
+ // DISTRIBUTE BY ...
+ RepartitionByExpression(expressionList(distributeBy), query)
+ } else if (order.isEmpty && !sort.isEmpty && !distributeBy.isEmpty && clusterBy.isEmpty) {
+ // SORT BY ... DISTRIBUTE BY ...
+ Sort(
+ sort.asScala.map(visitSortItem),
+ global = false,
+ RepartitionByExpression(expressionList(distributeBy), query))
+ } else if (order.isEmpty && sort.isEmpty && distributeBy.isEmpty && !clusterBy.isEmpty) {
+ // CLUSTER BY ...
+ val expressions = expressionList(clusterBy)
+ Sort(
+ expressions.map(SortOrder(_, Ascending)),
+ global = false,
+ RepartitionByExpression(expressions, query))
+ } else if (order.isEmpty && sort.isEmpty && distributeBy.isEmpty && clusterBy.isEmpty) {
+ // [EMPTY]
+ query
+ } else {
+ throw new ParseException(
+ "Combination of ORDER BY/SORT BY/DISTRIBUTE BY/CLUSTER BY is not supported", ctx)
+ }
+
+ // WINDOWS
+ val withWindow = withOrder.optionalMap(windows)(withWindows)
+
+ // LIMIT
+ withWindow.optional(limit) {
+ Limit(typedVisit(limit), withWindow)
+ }
+ }
+
+ /**
+ * Create a logical plan using a query specification.
+ */
+ override def visitQuerySpecification(
+ ctx: QuerySpecificationContext): LogicalPlan = withOrigin(ctx) {
+ val from = OneRowRelation.optional(ctx.fromClause) {
+ visitFromClause(ctx.fromClause)
+ }
+ withQuerySpecification(ctx, from)
+ }
+
+ /**
+ * Add a query specification to a logical plan. The query specification is the core of the logical
+ * plan; this is where sourcing (FROM clause), transforming (SELECT TRANSFORM/MAP/REDUCE),
+ * projection (SELECT), aggregation (GROUP BY ... HAVING ...) and filtering (WHERE) take place.
+ *
+ * Note that query hints are ignored (both by the parser and the builder).
+ */
+ private def withQuerySpecification(
+ ctx: QuerySpecificationContext,
+ relation: LogicalPlan): LogicalPlan = withOrigin(ctx) {
+ import ctx._
+
+ // WHERE
+ def filter(ctx: BooleanExpressionContext, plan: LogicalPlan): LogicalPlan = {
+ Filter(expression(ctx), plan)
+ }
+
+ // Expressions.
+ val expressions = Option(namedExpressionSeq).toSeq
+ .flatMap(_.namedExpression.asScala)
+ .map(typedVisit[Expression])
+
+ // Create either a transform or a regular query.
+ val specType = Option(kind).map(_.getType).getOrElse(SqlBaseParser.SELECT)
+ specType match {
+ case SqlBaseParser.MAP | SqlBaseParser.REDUCE | SqlBaseParser.TRANSFORM =>
+ // Transform
+
+ // Add where.
+ val withFilter = relation.optionalMap(where)(filter)
+
+ // Create the attributes.
+ val (attributes, schemaLess) = if (colTypeList != null) {
+ // Typed return columns.
+ (createStructType(colTypeList).toAttributes, false)
+ } else if (identifierSeq != null) {
+ // Untyped return columns.
+ val attrs = visitIdentifierSeq(identifierSeq).map { name =>
+ AttributeReference(name, StringType, nullable = true)()
+ }
+ (attrs, false)
+ } else {
+ (Seq(AttributeReference("key", StringType)(),
+ AttributeReference("value", StringType)()), true)
+ }
+
+ // Create the transform.
+ ScriptTransformation(
+ expressions,
+ string(script),
+ attributes,
+ withFilter,
+ withScriptIOSchema(inRowFormat, recordWriter, outRowFormat, recordReader, schemaLess))
+
+ case SqlBaseParser.SELECT =>
+ // Regular select
+
+ // Add lateral views.
+ val withLateralView = ctx.lateralView.asScala.foldLeft(relation)(withGenerate)
+
+ // Add where.
+ val withFilter = withLateralView.optionalMap(where)(filter)
+
+ // Add aggregation or a project.
+ val namedExpressions = expressions.map {
+ case e: NamedExpression => e
+ case e: Expression => UnresolvedAlias(e)
+ }
+ val withProject = if (aggregation != null) {
+ withAggregation(aggregation, namedExpressions, withFilter)
+ } else if (namedExpressions.nonEmpty) {
+ Project(namedExpressions, withFilter)
+ } else {
+ withFilter
+ }
+
+ // Having
+ val withHaving = withProject.optional(having) {
+ // Note that we added a cast to boolean. If the expression itself is already boolean,
+ // the optimizer will get rid of the unnecessary cast.
+ Filter(Cast(expression(having), BooleanType), withProject)
+ }
+
+ // Distinct
+ val withDistinct = if (setQuantifier() != null && setQuantifier().DISTINCT() != null) {
+ Distinct(withHaving)
+ } else {
+ withHaving
+ }
+
+ // Window
+ withDistinct.optionalMap(windows)(withWindows)
+ }
+ }
+
+ /**
+ * Create a (Hive based) [[ScriptInputOutputSchema]].
+ */
+ protected def withScriptIOSchema(
+ inRowFormat: RowFormatContext,
+ recordWriter: Token,
+ outRowFormat: RowFormatContext,
+ recordReader: Token,
+ schemaLess: Boolean): ScriptInputOutputSchema = null
+
+ /**
+ * Create a logical plan for a given 'FROM' clause. Note that we support multiple (comma
+ * separated) relations here; these get converted into a single plan by a condition-less inner join.
+ */
+ override def visitFromClause(ctx: FromClauseContext): LogicalPlan = withOrigin(ctx) {
+ val from = ctx.relation.asScala.map(plan).reduceLeft(Join(_, _, Inner, None))
+ ctx.lateralView.asScala.foldLeft(from)(withGenerate)
+ }
+
+ /**
+ * Connect two queries by a Set operator.
+ *
+ * Supported Set operators are:
+ * - UNION [DISTINCT]
+ * - UNION ALL
+ * - EXCEPT [DISTINCT]
+ * - INTERSECT [DISTINCT]
+ */
+ override def visitSetOperation(ctx: SetOperationContext): LogicalPlan = withOrigin(ctx) {
+ val left = plan(ctx.left)
+ val right = plan(ctx.right)
+ val all = Option(ctx.setQuantifier()).exists(_.ALL != null)
+ ctx.operator.getType match {
+ case SqlBaseParser.UNION if all =>
+ Union(left, right)
+ case SqlBaseParser.UNION =>
+ Distinct(Union(left, right))
+ case SqlBaseParser.INTERSECT if all =>
+ throw new ParseException("INTERSECT ALL is not supported.", ctx)
+ case SqlBaseParser.INTERSECT =>
+ Intersect(left, right)
+ case SqlBaseParser.EXCEPT if all =>
+ throw new ParseException("EXCEPT ALL is not supported.", ctx)
+ case SqlBaseParser.EXCEPT =>
+ Except(left, right)
+ }
+ }
+
+ /**
+ * Add a [[WithWindowDefinition]] operator to a logical plan.
+ */
+ private def withWindows(
+ ctx: WindowsContext,
+ query: LogicalPlan): LogicalPlan = withOrigin(ctx) {
+ // Collect all window specifications defined in the WINDOW clause.
+ val baseWindowMap = ctx.namedWindow.asScala.map {
+ wCtx =>
+ (wCtx.identifier.getText, typedVisit[WindowSpec](wCtx.windowSpec))
+ }.toMap
+
+ // Handle cases like
+ // window w1 as (partition by p_mfgr order by p_name
+ // range between 2 preceding and 2 following),
+ // w2 as w1
+ val windowMapView = baseWindowMap.mapValues {
+ case WindowSpecReference(name) =>
+ baseWindowMap.get(name) match {
+ case Some(spec: WindowSpecDefinition) =>
+ spec
+ case Some(ref) =>
+ throw new ParseException(s"Window reference '$name' is not a window specification", ctx)
+ case None =>
+ throw new ParseException(s"Cannot resolve window reference '$name'", ctx)
+ }
+ case spec: WindowSpecDefinition => spec
+ }
+
+ // Note that mapValues creates a view instead of materialized map. We force materialization by
+ // mapping over identity.
+ WithWindowDefinition(windowMapView.map(identity), query)
+ }
+
+ /**
+ * Add an [[Aggregate]] to a logical plan.
+ */
+ private def withAggregation(
+ ctx: AggregationContext,
+ selectExpressions: Seq[NamedExpression],
+ query: LogicalPlan): LogicalPlan = withOrigin(ctx) {
+ import ctx._
+ val groupByExpressions = expressionList(groupingExpressions)
+
+ if (GROUPING != null) {
+ // GROUP BY .... GROUPING SETS (...)
+ val expressionMap = groupByExpressions.zipWithIndex.toMap
+ val numExpressions = expressionMap.size
+ val mask = (1 << numExpressions) - 1
+ val masks = ctx.groupingSet.asScala.map {
+ _.expression.asScala.foldLeft(mask) {
+ case (bitmap, eCtx) =>
+ // Find the index of the expression.
+ val e = typedVisit[Expression](eCtx)
+ val index = expressionMap.find(_._1.semanticEquals(e)).map(_._2).getOrElse(
+ throw new ParseException(
+ s"$e doesn't show up in the GROUP BY list", ctx))
+ // 0 means that the column at the given index is a grouping column, 1 means it is not,
+ // so we unset the bit in bitmap.
+ bitmap & ~(1 << (numExpressions - 1 - index))
+ }
+ }
+ GroupingSets(masks, groupByExpressions, query, selectExpressions)
+ } else {
+ // GROUP BY .... (WITH CUBE | WITH ROLLUP)?
+ val mappedGroupByExpressions = if (CUBE != null) {
+ Seq(Cube(groupByExpressions))
+ } else if (ROLLUP != null) {
+ Seq(Rollup(groupByExpressions))
+ } else {
+ groupByExpressions
+ }
+ Aggregate(mappedGroupByExpressions, selectExpressions, query)
+ }
+ }
+
+ /**
+ * Add a [[Generate]] (Lateral View) to a logical plan.
+ */
+ private def withGenerate(
+ query: LogicalPlan,
+ ctx: LateralViewContext): LogicalPlan = withOrigin(ctx) {
+ val expressions = expressionList(ctx.expression)
+
+ // Create the generator.
+ val generator = ctx.qualifiedName.getText.toLowerCase match {
+ case "explode" if expressions.size == 1 =>
+ Explode(expressions.head)
+ case "json_tuple" =>
+ JsonTuple(expressions)
+ case other =>
+ withGenerator(other, expressions, ctx)
+ }
+
+ Generate(
+ generator,
+ join = true,
+ outer = ctx.OUTER != null,
+ Some(ctx.tblName.getText.toLowerCase),
+ ctx.colName.asScala.map(_.getText).map(UnresolvedAttribute.apply),
+ query)
+ }
+
+ /**
+ * Create a [[Generator]]. Override this method in order to support custom Generators.
+ */
+ protected def withGenerator(
+ name: String,
+ expressions: Seq[Expression],
+ ctx: LateralViewContext): Generator = {
+ throw new ParseException(s"Generator function '$name' is not supported", ctx)
+ }
+
+ /**
+ * Create a join between two or more logical plans.
+ */
+ override def visitJoinRelation(ctx: JoinRelationContext): LogicalPlan = withOrigin(ctx) {
+ /** Build a join between two plans. */
+ def join(ctx: JoinRelationContext, left: LogicalPlan, right: LogicalPlan): Join = {
+ val baseJoinType = ctx.joinType match {
+ case null => Inner
+ case jt if jt.FULL != null => FullOuter
+ case jt if jt.SEMI != null => LeftSemi
+ case jt if jt.LEFT != null => LeftOuter
+ case jt if jt.RIGHT != null => RightOuter
+ case _ => Inner
+ }
+
+ // Resolve the join type and join condition
+ val (joinType, condition) = Option(ctx.joinCriteria) match {
+ case Some(c) if c.USING != null =>
+ val columns = c.identifier.asScala.map { column =>
+ UnresolvedAttribute.quoted(column.getText)
+ }
+ (UsingJoin(baseJoinType, columns), None)
+ case Some(c) if c.booleanExpression != null =>
+ (baseJoinType, Option(expression(c.booleanExpression)))
+ case None if ctx.NATURAL != null =>
+ (NaturalJoin(baseJoinType), None)
+ case None =>
+ (baseJoinType, None)
+ }
+ Join(left, right, joinType, condition)
+ }
+
+ // Handle all consecutive join clauses. ANTLR produces a right-nested tree in which the
+ // first join clause is at the top. However, fields of previously referenced tables can be used
+ // in following join clauses. The tree needs to be reversed in order to make this work.
+ var result = plan(ctx.left)
+ var current = ctx
+ while (current != null) {
+ current.right match {
+ case right: JoinRelationContext =>
+ result = join(current, result, plan(right.left))
+ current = right
+ case right =>
+ result = join(current, result, plan(right))
+ current = null
+ }
+ }
+ result
+ }
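+
+ // For example, for `FROM a JOIN b ON ... JOIN c ON ...` ANTLR nests the second join clause
+ // under the first; the loop above flattens this into Join(Join(a, b), c), so that columns of
+ // both 'a' and 'b' are visible to the join condition on 'c'.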
+
+ /**
+ * Add a [[Sample]] to a logical plan.
+ *
+ * This currently supports the following sampling methods:
+ * - TABLESAMPLE(x ROWS): Sample the table down to the given number of rows.
+ * - TABLESAMPLE(x PERCENT): Sample the table down to the given percentage. Note that percentages
+ * are defined as a number between 0 and 100.
+ * - TABLESAMPLE(BUCKET x OUT OF y): Sample the table down to a fraction of 'x' divided by 'y'.
+ */
+ private def withSample(ctx: SampleContext, query: LogicalPlan): LogicalPlan = withOrigin(ctx) {
+ // Create a sampled plan if we need one.
+ def sample(fraction: Double): Sample = {
+ // The range of fraction accepted by Sample is [0, 1]. Because Hive's block sampling
+ // function takes X PERCENT as the input and the range of X is [0, 100], we need to
+ // adjust the fraction.
+ val eps = RandomSampler.roundingEpsilon
+ assert(fraction >= 0.0 - eps && fraction <= 1.0 + eps,
+ s"Sampling fraction ($fraction) must be on interval [0, 1]",
+ ctx)
+ Sample(0.0, fraction, withReplacement = false, (math.random * 1000).toInt, query)(true)
+ }
+
+ ctx.sampleType.getType match {
+ case SqlBaseParser.ROWS =>
+ Limit(expression(ctx.expression), query)
+
+ case SqlBaseParser.PERCENTLIT =>
+ val fraction = ctx.percentage.getText.toDouble
+ sample(fraction / 100.0d)
+
+ case SqlBaseParser.BUCKET if ctx.ON != null =>
+ throw new ParseException("TABLESAMPLE(BUCKET x OUT OF y ON id) is not supported", ctx)
+
+ case SqlBaseParser.BUCKET =>
+ sample(ctx.numerator.getText.toDouble / ctx.denominator.getText.toDouble)
+ }
+ }
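+
+ // For example:
+ //   TABLESAMPLE(5 ROWS)            => Limit(5, query)
+ //   TABLESAMPLE(10 PERCENT)        => Sample(0.0, 0.1, withReplacement = false, seed, query)
+ //   TABLESAMPLE(BUCKET 1 OUT OF 4) => Sample(0.0, 0.25, withReplacement = false, seed, query)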
+
+ /**
+ * Create a logical plan for a sub-query.
+ */
+ override def visitSubquery(ctx: SubqueryContext): LogicalPlan = withOrigin(ctx) {
+ plan(ctx.queryNoWith)
+ }
+
+ /**
+ * Create an un-aliased table reference. This is typically used for top-level table references,
+ * for example:
+ * {{{
+ * INSERT INTO db.tbl2
+ * TABLE db.tbl1
+ * }}}
+ */
+ override def visitTable(ctx: TableContext): LogicalPlan = withOrigin(ctx) {
+ UnresolvedRelation(visitTableIdentifier(ctx.tableIdentifier), None)
+ }
+
+ /**
+ * Create an aliased table reference. This is typically used in FROM clauses.
+ */
+ override def visitTableName(ctx: TableNameContext): LogicalPlan = withOrigin(ctx) {
+ val table = UnresolvedRelation(
+ visitTableIdentifier(ctx.tableIdentifier),
+ Option(ctx.identifier).map(_.getText))
+ table.optionalMap(ctx.sample)(withSample)
+ }
+
+ /**
+ * Create an inline table (a virtual table in Hive parlance).
+ */
+ override def visitInlineTable(ctx: InlineTableContext): LogicalPlan = withOrigin(ctx) {
+ // Get the backing expressions.
+ val expressions = ctx.expression.asScala.map { eCtx =>
+ val e = expression(eCtx)
+ assert(e.foldable, "All expressions in an inline table must be constants.", eCtx)
+ e
+ }
+
+ // Validate and evaluate the rows.
+ val (structType, structConstructor) = expressions.head.dataType match {
+ case st: StructType =>
+ (st, (e: Expression) => e)
+ case dt =>
+ val st = CreateStruct(Seq(expressions.head)).dataType
+ (st, (e: Expression) => CreateStruct(Seq(e)))
+ }
+ val rows = expressions.map {
+ case expression =>
+ val safe = Cast(structConstructor(expression), structType)
+ safe.eval().asInstanceOf[InternalRow]
+ }
+
+ // Construct attributes.
+ val baseAttributes = structType.toAttributes.map(_.withNullability(true))
+ val attributes = if (ctx.identifierList != null) {
+ val aliases = visitIdentifierList(ctx.identifierList)
+ assert(aliases.size == baseAttributes.size,
+ "Number of aliases must match the number of fields in an inline table.", ctx)
+ baseAttributes.zip(aliases).map(p => p._1.withName(p._2))
+ } else {
+ baseAttributes
+ }
+
+ // Create plan and add an alias if a name has been defined.
+ LocalRelation(attributes, rows).optionalMap(ctx.identifier)(aliasPlan)
+ }
+
+ /**
+ * Create an alias (SubqueryAlias) for a join relation. This is practically the same as
+ * visitAliasedQuery and visitNamedExpression; however, ANTLR4 requires us to use 3 different
+ * hooks.
+ */
+ override def visitAliasedRelation(ctx: AliasedRelationContext): LogicalPlan = withOrigin(ctx) {
+ plan(ctx.relation).optionalMap(ctx.sample)(withSample).optionalMap(ctx.identifier)(aliasPlan)
+ }
+
+ /**
+ * Create an alias (SubqueryAlias) for a sub-query. This is practically the same as
+ * visitAliasedRelation and visitNamedExpression; however, ANTLR4 requires us to use 3 different
+ * hooks.
+ */
+ override def visitAliasedQuery(ctx: AliasedQueryContext): LogicalPlan = withOrigin(ctx) {
+ plan(ctx.queryNoWith).optionalMap(ctx.sample)(withSample).optionalMap(ctx.identifier)(aliasPlan)
+ }
+
+ /**
+ * Create an alias (SubqueryAlias) for a LogicalPlan.
+ */
+ private def aliasPlan(alias: IdentifierContext, plan: LogicalPlan): LogicalPlan = {
+ SubqueryAlias(alias.getText, plan)
+ }
+
+ /**
+ * Create a Sequence of Strings for a parenthesis enclosed alias list.
+ */
+ override def visitIdentifierList(ctx: IdentifierListContext): Seq[String] = withOrigin(ctx) {
+ visitIdentifierSeq(ctx.identifierSeq)
+ }
+
+ /**
+ * Create a Sequence of Strings for an identifier list.
+ */
+ override def visitIdentifierSeq(ctx: IdentifierSeqContext): Seq[String] = withOrigin(ctx) {
+ ctx.identifier.asScala.map(_.getText)
+ }
+
+ /* ********************************************************************************************
+ * Table Identifier parsing
+ * ******************************************************************************************** */
+ /**
+ * Create a [[TableIdentifier]] from a 'tableName' or 'databaseName'.'tableName' pattern.
+ */
+ override def visitTableIdentifier(
+ ctx: TableIdentifierContext): TableIdentifier = withOrigin(ctx) {
+ TableIdentifier(ctx.table.getText, Option(ctx.db).map(_.getText))
+ }
+
+ /* ********************************************************************************************
+ * Expression parsing
+ * ******************************************************************************************** */
+ /**
+ * Create an expression from the given context. This method just passes the context on to the
+ * visitor and only takes care of typing (we assume that the visitor returns an Expression here).
+ */
+ protected def expression(ctx: ParserRuleContext): Expression = typedVisit(ctx)
+
+ /**
+ * Create sequence of expressions from the given sequence of contexts.
+ */
+ private def expressionList(trees: java.util.List[ExpressionContext]): Seq[Expression] = {
+ trees.asScala.map(expression)
+ }
+
+ /**
+ * Invert a boolean expression if it has a valid NOT clause.
+ */
+ private def invertIfNotDefined(expression: Expression, not: TerminalNode): Expression = {
+ if (not != null) {
+ Not(expression)
+ } else {
+ expression
+ }
+ }
+
+ /**
+ * Create a star (i.e. all) expression; this selects all elements (in the specified object).
+ * Both un-targeted (global) and targeted aliases are supported.
+ */
+ override def visitStar(ctx: StarContext): Expression = withOrigin(ctx) {
+ UnresolvedStar(Option(ctx.qualifiedName()).map(_.identifier.asScala.map(_.getText)))
+ }
+
+ /**
+ * Create an aliased expression if an alias is specified. Both single and multi-aliases are
+ * supported.
+ */
+ override def visitNamedExpression(ctx: NamedExpressionContext): Expression = withOrigin(ctx) {
+ val e = expression(ctx.expression)
+ if (ctx.identifier != null) {
+ Alias(e, ctx.identifier.getText)()
+ } else if (ctx.identifierList != null) {
+ MultiAlias(e, visitIdentifierList(ctx.identifierList))
+ } else {
+ e
+ }
+ }
+
+ /**
+ * Combine a number of boolean expressions into a balanced expression tree. These expressions are
+ * either combined by a logical [[And]] or a logical [[Or]].
+ *
+ * A balanced binary tree is created because regular left recursive trees cause considerable
+ * performance degradations and can cause stack overflows.
+ */
+ override def visitLogicalBinary(ctx: LogicalBinaryContext): Expression = withOrigin(ctx) {
+ val expressionType = ctx.operator.getType
+ val expressionCombiner = expressionType match {
+ case SqlBaseParser.AND => And.apply _
+ case SqlBaseParser.OR => Or.apply _
+ }
+
+ // Collect all similar left hand contexts.
+ val contexts = ArrayBuffer(ctx.right)
+ var current = ctx.left
+ def collectContexts: Boolean = current match {
+ case lbc: LogicalBinaryContext if lbc.operator.getType == expressionType =>
+ contexts += lbc.right
+ current = lbc.left
+ true
+ case _ =>
+ contexts += current
+ false
+ }
+ while (collectContexts) {
+ // No body - all updates take place in the collectContexts.
+ }
+
+ // Reverse the contexts to have them in the same sequence as in the SQL statement & turn them
+ // into expressions.
+ val expressions = contexts.reverse.map(expression)
+
+ // Create a balanced tree.
+ def reduceToExpressionTree(low: Int, high: Int): Expression = high - low match {
+ case 0 =>
+ expressions(low)
+ case 1 =>
+ expressionCombiner(expressions(low), expressions(high))
+ case x =>
+ val mid = low + x / 2
+ expressionCombiner(
+ reduceToExpressionTree(low, mid),
+ reduceToExpressionTree(mid + 1, high))
+ }
+ reduceToExpressionTree(0, expressions.size - 1)
+ }
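+
+ // For example, `a AND b AND c AND d` is reduced to And(And(a, b), And(c, d)) instead of the
+ // left-deep And(And(And(a, b), c), d) that a naive fold would produce.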
+
+ /**
+ * Invert a boolean expression.
+ */
+ override def visitLogicalNot(ctx: LogicalNotContext): Expression = withOrigin(ctx) {
+ Not(expression(ctx.booleanExpression()))
+ }
+
+ /**
+ * Create a filtering correlated sub-query. This is not supported yet.
+ */
+ override def visitExists(ctx: ExistsContext): Expression = {
+ throw new ParseException("EXISTS clauses are not supported.", ctx)
+ }
+
+ /**
+ * Create a comparison expression. This compares two expressions. The following comparison
+ * operators are supported:
+ * - Equal: '=' or '=='
+ * - Null-safe Equal: '<=>'
+ * - Not Equal: '<>' or '!='
+ * - Less than: '<'
+ * - Less than or Equal: '<='
+ * - Greater than: '>'
+ * - Greater than or Equal: '>='
+ */
+ override def visitComparison(ctx: ComparisonContext): Expression = withOrigin(ctx) {
+ val left = expression(ctx.left)
+ val right = expression(ctx.right)
+ val operator = ctx.comparisonOperator().getChild(0).asInstanceOf[TerminalNode]
+ operator.getSymbol.getType match {
+ case SqlBaseParser.EQ =>
+ EqualTo(left, right)
+ case SqlBaseParser.NSEQ =>
+ EqualNullSafe(left, right)
+ case SqlBaseParser.NEQ | SqlBaseParser.NEQJ =>
+ Not(EqualTo(left, right))
+ case SqlBaseParser.LT =>
+ LessThan(left, right)
+ case SqlBaseParser.LTE =>
+ LessThanOrEqual(left, right)
+ case SqlBaseParser.GT =>
+ GreaterThan(left, right)
+ case SqlBaseParser.GTE =>
+ GreaterThanOrEqual(left, right)
+ }
+ }
+
+ /**
+ * Create a BETWEEN expression. This tests if an expression lies within the bounds set by two
+ * other expressions. The inverse can also be created.
+ */
+ override def visitBetween(ctx: BetweenContext): Expression = withOrigin(ctx) {
+ val value = expression(ctx.value)
+ val between = And(
+ GreaterThanOrEqual(value, expression(ctx.lower)),
+ LessThanOrEqual(value, expression(ctx.upper)))
+ invertIfNotDefined(between, ctx.NOT)
+ }
+
+ /**
+ * Create an IN expression. This tests if the value of the left hand side expression is
+ * contained by the sequence of expressions on the right hand side.
+ */
+ override def visitInList(ctx: InListContext): Expression = withOrigin(ctx) {
+ val in = In(expression(ctx.value), ctx.expression().asScala.map(expression))
+ invertIfNotDefined(in, ctx.NOT)
+ }
+
+ /**
+ * Create an IN expression where the right hand side is a query. This is unsupported.
+ */
+ override def visitInSubquery(ctx: InSubqueryContext): Expression = {
+ throw new ParseException("IN with a Sub-query is currently not supported.", ctx)
+ }
+
+ /**
+ * Create a (R)LIKE/REGEXP expression.
+ */
+ override def visitLike(ctx: LikeContext): Expression = {
+ val left = expression(ctx.value)
+ val right = expression(ctx.pattern)
+ val like = ctx.like.getType match {
+ case SqlBaseParser.LIKE =>
+ Like(left, right)
+ case SqlBaseParser.RLIKE =>
+ RLike(left, right)
+ }
+ invertIfNotDefined(like, ctx.NOT)
+ }
+
+ /**
+ * Create an IS (NOT) NULL expression.
+ */
+ override def visitNullPredicate(ctx: NullPredicateContext): Expression = withOrigin(ctx) {
+ val value = expression(ctx.value)
+ if (ctx.NOT != null) {
+ IsNotNull(value)
+ } else {
+ IsNull(value)
+ }
+ }
+
+ /**
+ * Create a binary arithmetic expression. The following arithmetic operators are supported:
+ * - Multiplication: '*'
+ * - Division: '/'
+ * - Hive Long Division: 'DIV'
+ * - Modulo: '%'
+ * - Addition: '+'
+ * - Subtraction: '-'
+ * - Binary AND: '&'
+ * - Binary XOR: '^'
+ * - Binary OR: '|'
+ */
+ override def visitArithmeticBinary(ctx: ArithmeticBinaryContext): Expression = withOrigin(ctx) {
+ val left = expression(ctx.left)
+ val right = expression(ctx.right)
+ ctx.operator.getType match {
+ case SqlBaseParser.ASTERISK =>
+ Multiply(left, right)
+ case SqlBaseParser.SLASH =>
+ Divide(left, right)
+ case SqlBaseParser.PERCENT =>
+ Remainder(left, right)
+ case SqlBaseParser.DIV =>
+ Cast(Divide(left, right), LongType)
+ case SqlBaseParser.PLUS =>
+ Add(left, right)
+ case SqlBaseParser.MINUS =>
+ Subtract(left, right)
+ case SqlBaseParser.AMPERSAND =>
+ BitwiseAnd(left, right)
+ case SqlBaseParser.HAT =>
+ BitwiseXor(left, right)
+ case SqlBaseParser.PIPE =>
+ BitwiseOr(left, right)
+ }
+ }
+
+ /**
+ * Create a unary arithmetic expression. The following arithmetic operators are supported:
+ * - Plus: '+'
+ * - Minus: '-'
+ * - Bitwise Not: '~'
+ */
+ override def visitArithmeticUnary(ctx: ArithmeticUnaryContext): Expression = withOrigin(ctx) {
+ val value = expression(ctx.valueExpression)
+ ctx.operator.getType match {
+ case SqlBaseParser.PLUS =>
+ value
+ case SqlBaseParser.MINUS =>
+ UnaryMinus(value)
+ case SqlBaseParser.TILDE =>
+ BitwiseNot(value)
+ }
+ }
+
+ /**
+ * Create a [[Cast]] expression.
+ */
+ override def visitCast(ctx: CastContext): Expression = withOrigin(ctx) {
+ Cast(expression(ctx.expression), typedVisit(ctx.dataType))
+ }
+
+ /**
+ * Create a (windowed) Function expression.
+ */
+ override def visitFunctionCall(ctx: FunctionCallContext): Expression = withOrigin(ctx) {
+ // Create the function call.
+ val name = ctx.qualifiedName.getText
+ val isDistinct = Option(ctx.setQuantifier()).exists(_.DISTINCT != null)
+ val arguments = ctx.expression().asScala.map(expression) match {
+ case Seq(UnresolvedStar(None)) if name.toLowerCase == "count" && !isDistinct =>
+ // Transform COUNT(*) into COUNT(1). Move this to analysis?
+ Seq(Literal(1))
+ case expressions =>
+ expressions
+ }
+ val function = UnresolvedFunction(name, arguments, isDistinct)
+
+ // Check if the function is evaluated in a windowed context.
+ ctx.windowSpec match {
+ case spec: WindowRefContext =>
+ UnresolvedWindowExpression(function, visitWindowRef(spec))
+ case spec: WindowDefContext =>
+ WindowExpression(function, visitWindowDef(spec))
+ case _ => function
+ }
+ }
+
+ /**
+ * Create a reference to a window frame, i.e. [[WindowSpecReference]].
+ */
+ override def visitWindowRef(ctx: WindowRefContext): WindowSpecReference = withOrigin(ctx) {
+ WindowSpecReference(ctx.identifier.getText)
+ }
+
+ /**
+ * Create a window definition, i.e. [[WindowSpecDefinition]].
+ */
+ override def visitWindowDef(ctx: WindowDefContext): WindowSpecDefinition = withOrigin(ctx) {
+ // CLUSTER BY ... | PARTITION BY ... ORDER BY ...
+ val partition = ctx.partition.asScala.map(expression)
+ val order = ctx.sortItem.asScala.map(visitSortItem)
+
+ // RANGE/ROWS BETWEEN ...
+ val frameSpecOption = Option(ctx.windowFrame).map { frame =>
+ val frameType = frame.frameType.getType match {
+ case SqlBaseParser.RANGE => RangeFrame
+ case SqlBaseParser.ROWS => RowFrame
+ }
+
+ SpecifiedWindowFrame(
+ frameType,
+ visitFrameBound(frame.start),
+ Option(frame.end).map(visitFrameBound).getOrElse(CurrentRow))
+ }
+
+ WindowSpecDefinition(
+ partition,
+ order,
+ frameSpecOption.getOrElse(UnspecifiedFrame))
+ }
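+
+ // For example, `(PARTITION BY a ORDER BY b ROWS BETWEEN 1 PRECEDING AND CURRENT ROW)` becomes
+ // WindowSpecDefinition(Seq('a), Seq(SortOrder('b, Ascending)),
+ //   SpecifiedWindowFrame(RowFrame, ValuePreceding(1), CurrentRow)).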
+
+ /**
+ * Create or resolve a [[FrameBoundary]]. Simple math expressions are allowed for Value
+ * Preceding/Following boundaries. These expressions must be constant (foldable) and return an
+ * integer value.
+ */
+ override def visitFrameBound(ctx: FrameBoundContext): FrameBoundary = withOrigin(ctx) {
+ // We currently only allow foldable integers.
+ def value: Int = {
+ val e = expression(ctx.expression)
+ assert(e.resolved && e.foldable && e.dataType == IntegerType,
+ "Frame bound value must be a constant integer.",
+ ctx)
+ e.eval().asInstanceOf[Int]
+ }
+
+ // Create the FrameBoundary
+ ctx.boundType.getType match {
+ case SqlBaseParser.PRECEDING if ctx.UNBOUNDED != null =>
+ UnboundedPreceding
+ case SqlBaseParser.PRECEDING =>
+ ValuePreceding(value)
+ case SqlBaseParser.CURRENT =>
+ CurrentRow
+ case SqlBaseParser.FOLLOWING if ctx.UNBOUNDED != null =>
+ UnboundedFollowing
+ case SqlBaseParser.FOLLOWING =>
+ ValueFollowing(value)
+ }
+ }
+
+ /**
+ * Create a [[CreateStruct]] expression.
+ */
+ override def visitRowConstructor(ctx: RowConstructorContext): Expression = withOrigin(ctx) {
+ CreateStruct(ctx.expression.asScala.map(expression))
+ }
+
+ /**
+ * Create a [[ScalarSubquery]] expression.
+ */
+ override def visitSubqueryExpression(
+ ctx: SubqueryExpressionContext): Expression = withOrigin(ctx) {
+ ScalarSubquery(plan(ctx.query))
+ }
+
+ /**
+ * Create a value-based [[CaseWhen]] expression. This has the following SQL form:
+ * {{{
+ * CASE [expression]
+ * WHEN [value] THEN [expression]
+ * ...
+ * ELSE [expression]
+ * END
+ * }}}
+ */
+ override def visitSimpleCase(ctx: SimpleCaseContext): Expression = withOrigin(ctx) {
+ val e = expression(ctx.valueExpression)
+ val branches = ctx.whenClause.asScala.map { wCtx =>
+ (EqualTo(e, expression(wCtx.condition)), expression(wCtx.result))
+ }
+ CaseWhen(branches, Option(ctx.elseExpression).map(expression))
+ }
+
+ /**
+ * Create a condition-based [[CaseWhen]] expression. This has the following SQL syntax:
+ * {{{
+ * CASE
+ * WHEN [predicate] THEN [expression]
+ * ...
+ * ELSE [expression]
+ * END
+ * }}}
+ *
+ * @param ctx the parse tree
+ */
+ override def visitSearchedCase(ctx: SearchedCaseContext): Expression = withOrigin(ctx) {
+ val branches = ctx.whenClause.asScala.map { wCtx =>
+ (expression(wCtx.condition), expression(wCtx.result))
+ }
+ CaseWhen(branches, Option(ctx.elseExpression).map(expression))
+ }
+
+ /**
+ * Create a dereference expression. The return type depends on the type of the parent: this can
+ * either be an [[UnresolvedAttribute]] (if the parent is an [[UnresolvedAttribute]]), or an
+ * [[UnresolvedExtractValue]] if the parent is some other expression.
+ */
+ override def visitDereference(ctx: DereferenceContext): Expression = withOrigin(ctx) {
+ val attr = ctx.fieldName.getText
+ expression(ctx.base) match {
+ case UnresolvedAttribute(nameParts) =>
+ UnresolvedAttribute(nameParts :+ attr)
+ case e =>
+ UnresolvedExtractValue(e, Literal(attr))
+ }
+ }
+
+ /**
+ * Create an [[UnresolvedAttribute]] expression.
+ */
+ override def visitColumnReference(ctx: ColumnReferenceContext): Expression = withOrigin(ctx) {
+ UnresolvedAttribute.quoted(ctx.getText)
+ }
+
+ /**
+ * Create an [[UnresolvedExtractValue]] expression; this is used for subscript access to an array.
+ */
+ override def visitSubscript(ctx: SubscriptContext): Expression = withOrigin(ctx) {
+ UnresolvedExtractValue(expression(ctx.value), expression(ctx.index))
+ }
+
+ /**
+ * Create an expression for a parenthesized expression. This is needed because the ANTLR
+ * visitor cannot automatically convert the nested context into an expression.
+ */
+ override def visitParenthesizedExpression(
+ ctx: ParenthesizedExpressionContext): Expression = withOrigin(ctx) {
+ expression(ctx.expression)
+ }
+
+ /**
+ * Create a [[SortOrder]] expression.
+ */
+ override def visitSortItem(ctx: SortItemContext): SortOrder = withOrigin(ctx) {
+ if (ctx.DESC != null) {
+ SortOrder(expression(ctx.expression), Descending)
+ } else {
+ SortOrder(expression(ctx.expression), Ascending)
+ }
+ }
+
+ /**
+ * Create a typed Literal expression. A typed literal has the following SQL syntax:
+ * {{{
+ * [TYPE] '[VALUE]'
+ * }}}
+ * Currently Date and Timestamp typed literals are supported.
+ *
+ * TODO what is the added value of this over casting?
+ */
+ override def visitTypeConstructor(ctx: TypeConstructorContext): Literal = withOrigin(ctx) {
+ val value = string(ctx.STRING)
+ ctx.identifier.getText.toUpperCase match {
+ case "DATE" =>
+ Literal(Date.valueOf(value))
+ case "TIMESTAMP" =>
+ Literal(Timestamp.valueOf(value))
+ case other =>
+ throw new ParseException(s"Literals of type '$other' are currently not supported.", ctx)
+ }
+ }
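+
+ // For example, `DATE '2016-03-28'` becomes Literal(Date.valueOf("2016-03-28")), and
+ // `TIMESTAMP '2016-03-28 12:00:00'` becomes Literal(Timestamp.valueOf("2016-03-28 12:00:00")).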
+
+ /**
+ * Create a NULL literal expression.
+ */
+ override def visitNullLiteral(ctx: NullLiteralContext): Literal = withOrigin(ctx) {
+ Literal(null)
+ }
+
+ /**
+ * Create a Boolean literal expression.
+ */
+ override def visitBooleanLiteral(ctx: BooleanLiteralContext): Literal = withOrigin(ctx) {
+ if (ctx.getText.toBoolean) {
+ Literal.TrueLiteral
+ } else {
+ Literal.FalseLiteral
+ }
+ }
+
+ /**
+ * Create an integral literal expression. The code selects the narrowest integral type
+ * possible; either an Integer, a Long, or a BigDecimal is returned.
+ */
+ override def visitIntegerLiteral(ctx: IntegerLiteralContext): Literal = withOrigin(ctx) {
+ BigDecimal(ctx.getText) match {
+ case v if v.isValidInt =>
+ Literal(v.intValue())
+ case v if v.isValidLong =>
+ Literal(v.longValue())
+ case v => Literal(v.underlying())
+ }
+ }
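+
+ // For example, `123` becomes an IntegerType literal, `123456789012` a LongType literal, and
+ // any value outside the Long range a DecimalType literal.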
+
+ /**
+ * Create a double literal for a number denoted in scientific notation.
+ */
+ override def visitScientificDecimalLiteral(
+ ctx: ScientificDecimalLiteralContext): Literal = withOrigin(ctx) {
+ Literal(ctx.getText.toDouble)
+ }
+
+ /**
+ * Create a decimal literal for a regular decimal number.
+ */
+ override def visitDecimalLiteral(ctx: DecimalLiteralContext): Literal = withOrigin(ctx) {
+ Literal(BigDecimal(ctx.getText).underlying())
+ }
+
+ /** Create a numeric literal expression. */
+ private def numericLiteral(ctx: NumberContext)(f: String => Any): Literal = withOrigin(ctx) {
+ val raw = ctx.getText
+ try {
+ Literal(f(raw.substring(0, raw.length - 1)))
+ } catch {
+ case e: NumberFormatException =>
+ throw new ParseException(e.getMessage, ctx)
+ }
+ }
+
+ /**
+ * Create a Byte Literal expression.
+ */
+ override def visitTinyIntLiteral(ctx: TinyIntLiteralContext): Literal = numericLiteral(ctx) {
+ _.toByte
+ }
+
+ /**
+ * Create a Short Literal expression.
+ */
+ override def visitSmallIntLiteral(ctx: SmallIntLiteralContext): Literal = numericLiteral(ctx) {
+ _.toShort
+ }
+
+ /**
+ * Create a Long Literal expression.
+ */
+ override def visitBigIntLiteral(ctx: BigIntLiteralContext): Literal = numericLiteral(ctx) {
+ _.toLong
+ }
+
+ /**
+ * Create a Double Literal expression.
+ */
+ override def visitDoubleLiteral(ctx: DoubleLiteralContext): Literal = numericLiteral(ctx) {
+ _.toDouble
+ }
+
+ /**
+ * Create a String literal expression.
+ */
+ override def visitStringLiteral(ctx: StringLiteralContext): Literal = withOrigin(ctx) {
+ Literal(createString(ctx))
+ }
+
+ /**
+ * Create a String from a string literal context. This supports multiple consecutive string
+ * literals, which are concatenated; for example, the expression "'hello' 'world'" will be
+ * converted into "helloworld".
+ *
+ * Special characters can be escaped by using Hive/C-style escaping.
+ */
+ private def createString(ctx: StringLiteralContext): String = {
+ ctx.STRING().asScala.map(string).mkString
+ }
+
+ /**
+ * Create a [[CalendarInterval]] literal expression. An interval expression can contain multiple
+ * unit value pairs, for instance: interval 2 months 2 days.
+ */
+ override def visitInterval(ctx: IntervalContext): Literal = withOrigin(ctx) {
+ val intervals = ctx.intervalField.asScala.map(visitIntervalField)
+ assert(intervals.nonEmpty, "at least one time unit should be given for interval literal", ctx)
+ Literal(intervals.reduce(_.add(_)))
+ }
+
+ /**
+ * Create a [[CalendarInterval]] for a unit value pair. Two unit configuration types are
+ * supported:
+ * - Single unit.
+ * - From-To unit (only 'YEAR TO MONTH' and 'DAY TO SECOND' are supported).
+ */
+ override def visitIntervalField(ctx: IntervalFieldContext): CalendarInterval = withOrigin(ctx) {
+ import ctx._
+ val s = value.getText
+ val interval = (unit.getText.toLowerCase, Option(to).map(_.getText.toLowerCase)) match {
+ case (u, None) if u.endsWith("s") =>
+ // Handle plural forms, e.g.: yearS/monthS/weekS/dayS/hourS/minuteS/secondS/...
+ CalendarInterval.fromSingleUnitString(u.substring(0, u.length - 1), s)
+ case (u, None) =>
+ CalendarInterval.fromSingleUnitString(u, s)
+ case ("year", Some("month")) =>
+ CalendarInterval.fromYearMonthString(s)
+ case ("day", Some("second")) =>
+ CalendarInterval.fromDayTimeString(s)
+ case (from, Some(t)) =>
+ throw new ParseException(s"Intervals FROM $from TO $t are not supported.", ctx)
+ }
+ assert(interval != null, "No interval can be constructed", ctx)
+ interval
+ }
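+
+ // For example, `interval 2 months 2 days` produces two unit/value pairs ("month" -> "2" and
+ // "day" -> "2") whose CalendarIntervals are added together into a single literal.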
+
+ /* ********************************************************************************************
+ * DataType parsing
+ * ******************************************************************************************** */
+ /**
+ * Resolve/create a primitive type.
+ */
+ override def visitPrimitiveDataType(ctx: PrimitiveDataTypeContext): DataType = withOrigin(ctx) {
+ (ctx.identifier.getText.toLowerCase, ctx.INTEGER_VALUE().asScala.toList) match {
+ case ("boolean", Nil) => BooleanType
+ case ("tinyint" | "byte", Nil) => ByteType
+ case ("smallint" | "short", Nil) => ShortType
+ case ("int" | "integer", Nil) => IntegerType
+ case ("bigint" | "long", Nil) => LongType
+ case ("float", Nil) => FloatType
+ case ("double", Nil) => DoubleType
+ case ("date", Nil) => DateType
+ case ("timestamp", Nil) => TimestampType
+ case ("char" | "varchar" | "string", Nil) => StringType
+ case ("char" | "varchar", _ :: Nil) => StringType
+ case ("binary", Nil) => BinaryType
+ case ("decimal", Nil) => DecimalType.USER_DEFAULT
+ case ("decimal", precision :: Nil) => DecimalType(precision.getText.toInt, 0)
+ case ("decimal", precision :: scale :: Nil) =>
+ DecimalType(precision.getText.toInt, scale.getText.toInt)
+ case (dt, params) =>
+ throw new ParseException(
+ s"DataType $dt${params.mkString("(", ",", ")")} is not supported.", ctx)
+ }
+ }
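+
+ // For example, `int`, `decimal(10, 2)` and `varchar(32)` map to IntegerType, DecimalType(10, 2)
+ // and StringType respectively, while an unknown type such as `foo(1)` raises a ParseException.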
+
+ /**
+ * Create a complex DataType. Arrays, Maps and Structures are supported.
+ */
+ override def visitComplexDataType(ctx: ComplexDataTypeContext): DataType = withOrigin(ctx) {
+ ctx.complex.getType match {
+ case SqlBaseParser.ARRAY =>
+ ArrayType(typedVisit(ctx.dataType(0)))
+ case SqlBaseParser.MAP =>
+ MapType(typedVisit(ctx.dataType(0)), typedVisit(ctx.dataType(1)))
+ case SqlBaseParser.STRUCT =>
+ createStructType(ctx.colTypeList())
+ }
+ }
+
+ /**
+ * Create a [[StructType]] from a sequence of [[StructField]]s.
+ */
+ protected def createStructType(ctx: ColTypeListContext): StructType = {
+ StructType(Option(ctx).toSeq.flatMap(visitColTypeList))
+ }
+
+ /**
+ * Create a [[StructType]] from a number of column definitions.
+ */
+ override def visitColTypeList(ctx: ColTypeListContext): Seq[StructField] = withOrigin(ctx) {
+ ctx.colType().asScala.map(visitColType)
+ }
+
+ /**
+ * Create a [[StructField]] from a column definition.
+ */
+ override def visitColType(ctx: ColTypeContext): StructField = withOrigin(ctx) {
+ import ctx._
+
+ // Add the comment to the metadata.
+ val builder = new MetadataBuilder
+ if (STRING != null) {
+ builder.putString("comment", string(STRING))
+ }
+
+ StructField(identifier.getText, typedVisit(dataType), nullable = true, builder.build())
+ }
+}
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/ng/ParseDriver.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/ng/ParseDriver.scala
new file mode 100644
index 0000000000..c9a286374c
--- /dev/null
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/ng/ParseDriver.scala
@@ -0,0 +1,240 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.spark.sql.catalyst.parser.ng
+
+import org.antlr.v4.runtime._
+import org.antlr.v4.runtime.atn.PredictionMode
+import org.antlr.v4.runtime.misc.ParseCancellationException
+
+import org.apache.spark.internal.Logging
+import org.apache.spark.sql.AnalysisException
+import org.apache.spark.sql.catalyst.TableIdentifier
+import org.apache.spark.sql.catalyst.expressions.Expression
+import org.apache.spark.sql.catalyst.parser.ParserInterface
+import org.apache.spark.sql.catalyst.parser.ng.SqlBaseParser._
+import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
+import org.apache.spark.sql.catalyst.trees.Origin
+import org.apache.spark.sql.types.DataType
+
+/**
+ * Base SQL parsing infrastructure.
+ */
+abstract class AbstractSqlParser extends ParserInterface with Logging {
+
+ /** Creates/Resolves DataType for a given SQL string. */
+ def parseDataType(sqlText: String): DataType = parse(sqlText) { parser =>
+ // TODO add this to the parser interface.
+ astBuilder.visitSingleDataType(parser.singleDataType())
+ }
+
+ /** Creates Expression for a given SQL string. */
+ override def parseExpression(sqlText: String): Expression = parse(sqlText) { parser =>
+ astBuilder.visitSingleExpression(parser.singleExpression())
+ }
+
+ /** Creates TableIdentifier for a given SQL string. */
+ override def parseTableIdentifier(sqlText: String): TableIdentifier = parse(sqlText) { parser =>
+ astBuilder.visitSingleTableIdentifier(parser.singleTableIdentifier())
+ }
+
+ /** Creates LogicalPlan for a given SQL string. */
+ override def parsePlan(sqlText: String): LogicalPlan = parse(sqlText) { parser =>
+ astBuilder.visitSingleStatement(parser.singleStatement()) match {
+ case plan: LogicalPlan => plan
+ case _ => nativeCommand(sqlText)
+ }
+ }
+
+ /** Get the builder (visitor) which converts a ParseTree into an AST. */
+ protected def astBuilder: AstBuilder
+
+ /** Create a native command, or fail when this is not supported. */
+ protected def nativeCommand(sqlText: String): LogicalPlan = {
+ val position = Origin(None, None)
+ throw new ParseException(Option(sqlText), "Unsupported SQL statement", position, position)
+ }
+
+ protected def parse[T](command: String)(toResult: SqlBaseParser => T): T = {
+ logInfo(s"Parsing command: $command")
+
+ val lexer = new SqlBaseLexer(new ANTLRNoCaseStringStream(command))
+ lexer.removeErrorListeners()
+ lexer.addErrorListener(ParseErrorListener)
+
+ val tokenStream = new CommonTokenStream(lexer)
+ val parser = new SqlBaseParser(tokenStream)
+ parser.addParseListener(PostProcessor)
+ parser.removeErrorListeners()
+ parser.addErrorListener(ParseErrorListener)
+
+ try {
+ try {
+ // first, try parsing with potentially faster SLL mode
+ parser.getInterpreter.setPredictionMode(PredictionMode.SLL)
+ toResult(parser)
+ }
+ catch {
+ case e: ParseCancellationException =>
+ // if we fail, parse with LL mode
+ tokenStream.reset() // rewind input stream
+ parser.reset()
+
+ // Try Again.
+ parser.getInterpreter.setPredictionMode(PredictionMode.LL)
+ toResult(parser)
+ }
+ }
+ catch {
+ case e: ParseException if e.command.isDefined =>
+ throw e
+ case e: ParseException =>
+ throw e.withCommand(command)
+ case e: AnalysisException =>
+ val position = Origin(e.line, e.startPosition)
+ throw new ParseException(Option(command), e.message, position, position)
+ }
+ }
+}
+
+/**
+ * Concrete SQL parser for Catalyst-only SQL statements.
+ */
+object CatalystSqlParser extends AbstractSqlParser {
+ val astBuilder = new AstBuilder
+}
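+
+// Typical usage:
+//   CatalystSqlParser.parsePlan("SELECT * FROM t WHERE a > 1")   // returns a LogicalPlan
+//   CatalystSqlParser.parseExpression("a + 1")                   // returns an Expression
+//   CatalystSqlParser.parseDataType("array<int>")                // returns a DataType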
+
+/**
+ * This string stream provides the lexer with upper case characters only. This greatly simplifies
+ * lexing the stream, while the original command is still preserved.
+ *
+ * This is based on Hive's org.apache.hadoop.hive.ql.parse.ParseDriver.ANTLRNoCaseStringStream
+ *
+ * The comment below (taken from the original class) describes the rationale for doing this:
+ *
+ * This class provides an implementation for a case insensitive token checker for the lexical
+ * analysis part of antlr. By converting the token stream into upper case at the time when lexical
+ * rules are checked, this class ensures that the lexical rules need to just match the token with
+ * upper case letters as opposed to combination of upper case and lower case characters. This is
+ * purely used for matching lexical rules. The actual token text is stored in the same way as the
+ * user input without actually converting it into an upper case. The token values are generated by
+ * the consume() function of the super class ANTLRStringStream. The LA() function is the lookahead
+ * function and is purely used for matching lexical rules. This also means that the grammar will
+ * only accept capitalized tokens in case it is run from other tools like antlrworks which do not
+ * have the ANTLRNoCaseStringStream implementation.
+ */
+
+private[parser] class ANTLRNoCaseStringStream(input: String) extends ANTLRInputStream(input) {
+ override def LA(i: Int): Int = {
+ val la = super.LA(i)
+ if (la == 0 || la == IntStream.EOF) la
+ else Character.toUpperCase(la)
+ }
+}
+
+/**
+ * The ParseErrorListener converts parse errors into AnalysisExceptions.
+ */
+case object ParseErrorListener extends BaseErrorListener {
+ override def syntaxError(
+ recognizer: Recognizer[_, _],
+ offendingSymbol: scala.Any,
+ line: Int,
+ charPositionInLine: Int,
+ msg: String,
+ e: RecognitionException): Unit = {
+ val position = Origin(Some(line), Some(charPositionInLine))
+ throw new ParseException(None, msg, position, position)
+ }
+}
+
+/**
+ * A [[ParseException]] is an [[AnalysisException]] that is thrown during the parse process. It
+ * contains fields and an extended error message that make reporting and diagnosing errors easier.
+ */
+class ParseException(
+ val command: Option[String],
+ message: String,
+ val start: Origin,
+ val stop: Origin) extends AnalysisException(message, start.line, start.startPosition) {
+
+ def this(message: String, ctx: ParserRuleContext) = {
+ this(Option(ParserUtils.command(ctx)),
+ message,
+ ParserUtils.position(ctx.getStart),
+ ParserUtils.position(ctx.getStop))
+ }
+
+ override def getMessage: String = {
+ val builder = new StringBuilder
+ builder ++= "\n" ++= message
+ start match {
+ case Origin(Some(l), Some(p)) =>
+ builder ++= s"(line $l, pos $p)\n"
+ command.foreach { cmd =>
+ val (above, below) = cmd.split("\n").splitAt(l)
+ builder ++= "\n== SQL ==\n"
+ above.foreach(builder ++= _ += '\n')
+ builder ++= (0 until p).map(_ => "-").mkString("") ++= "^^^\n"
+ below.foreach(builder ++= _ += '\n')
+ }
+ case _ =>
+ command.foreach { cmd =>
+ builder ++= "\n== SQL ==\n" ++= cmd
+ }
+ }
+ builder.toString
+ }
+
+ def withCommand(cmd: String): ParseException = {
+ new ParseException(Option(cmd), message, start, stop)
+ }
+}
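+
+// For example, parsing `select from tbl` fails at line 1, position 7, and getMessage renders
+// roughly as:
+//
+//   no viable alternative at input ...(line 1, pos 7)
+//
+//   == SQL ==
+//   select from tbl
+//   -------^^^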
+
+/**
+ * The post-processor validates & cleans up the parse tree during the parse process.
+ */
+case object PostProcessor extends SqlBaseBaseListener {
+
+ /** Remove the back ticks from an Identifier. */
+ override def exitQuotedIdentifier(ctx: QuotedIdentifierContext): Unit = {
+ replaceTokenByIdentifier(ctx, 1) { token =>
+ // Remove the double back ticks in the string.
+ token.setText(token.getText.replace("``", "`"))
+ token
+ }
+ }
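+
+ // For example, the quoted identifier `weird``name` becomes a plain IDENTIFIER token with the
+ // text weird`name: the surrounding backticks are stripped via the shifted start/stop offsets
+ // and the escaped double backtick is collapsed into a single one.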
+
+ /** Treat non-reserved keywords as Identifiers. */
+ override def exitNonReserved(ctx: NonReservedContext): Unit = {
+ replaceTokenByIdentifier(ctx, 0)(identity)
+ }
+
+ private def replaceTokenByIdentifier(
+ ctx: ParserRuleContext,
+ stripMargins: Int)(
+ f: CommonToken => CommonToken = identity): Unit = {
+ val parent = ctx.getParent
+ parent.removeLastChild()
+ val token = ctx.getChild(0).getPayload.asInstanceOf[Token]
+ parent.addChild(f(new CommonToken(
+ new org.antlr.v4.runtime.misc.Pair(token.getTokenSource, token.getInputStream),
+ SqlBaseParser.IDENTIFIER,
+ token.getChannel,
+ token.getStartIndex + stripMargins,
+ token.getStopIndex - stripMargins)))
+ }
+}
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/ng/ParserUtils.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/ng/ParserUtils.scala
new file mode 100644
index 0000000000..1fbfa763b4
--- /dev/null
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/ng/ParserUtils.scala
@@ -0,0 +1,118 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.spark.sql.catalyst.parser.ng
+
+import org.antlr.v4.runtime.{CharStream, ParserRuleContext, Token}
+import org.antlr.v4.runtime.misc.Interval
+import org.antlr.v4.runtime.tree.TerminalNode
+
+import org.apache.spark.sql.catalyst.parser.ParseUtils.unescapeSQLString
+import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
+import org.apache.spark.sql.catalyst.trees.{CurrentOrigin, Origin}
+
+/**
+ * A collection of utility methods for use during the parsing process.
+ */
+object ParserUtils {
+ /** Get the command which created the token. */
+ def command(ctx: ParserRuleContext): String = {
+ command(ctx.getStart.getInputStream)
+ }
+
+ /** Get the command which created the token. */
+ def command(stream: CharStream): String = {
+ stream.getText(Interval.of(0, stream.size()))
+ }
+
+ /** Get the code that creates the given node. */
+ def source(ctx: ParserRuleContext): String = {
+ val stream = ctx.getStart.getInputStream
+ stream.getText(Interval.of(ctx.getStart.getStartIndex, ctx.getStop.getStopIndex))
+ }
+
+ /** Get all the text which comes after the given rule. */
+ def remainder(ctx: ParserRuleContext): String = remainder(ctx.getStop)
+
+ /** Get all the text which comes after the given token. */
+ def remainder(token: Token): String = {
+ val stream = token.getInputStream
+ val interval = Interval.of(token.getStopIndex + 1, stream.size())
+ stream.getText(interval)
+ }
+
+ /** Convert a string token into a string. */
+ def string(token: Token): String = unescapeSQLString(token.getText)
+
+ /** Convert a string node into a string. */
+ def string(node: TerminalNode): String = unescapeSQLString(node.getText)
+
+ /** Get the origin (line and position) of the token. */
+ def position(token: Token): Origin = {
+ Origin(Option(token.getLine), Option(token.getCharPositionInLine))
+ }
+
+ /** Assert that a condition holds; if it doesn't, throw a parse exception. */
+ def assert(f: => Boolean, message: String, ctx: ParserRuleContext): Unit = {
+ if (!f) {
+ throw new ParseException(message, ctx)
+ }
+ }
+
+ /**
+ * Register the origin of the context. Any TreeNode created in the closure will be assigned the
+ * registered origin. This method restores the previously set origin after completion of the
+ * closure.
+ */
+ def withOrigin[T](ctx: ParserRuleContext)(f: => T): T = {
+ val current = CurrentOrigin.get
+ CurrentOrigin.set(position(ctx.getStart))
+ try {
+ f
+ } finally {
+ CurrentOrigin.set(current)
+ }
+ }
+
+ /** Some syntactic sugar which makes it easier to work with optional clauses for LogicalPlans. */
+ implicit class EnhancedLogicalPlan(val plan: LogicalPlan) extends AnyVal {
+ /**
+ * Create a plan using the block of code when the given context exists. Otherwise return the
+ * original plan.
+ */
+ def optional(ctx: AnyRef)(f: => LogicalPlan): LogicalPlan = {
+ if (ctx != null) {
+ f
+ } else {
+ plan
+ }
+ }
+
+ /**
+ * Map a [[LogicalPlan]] to another [[LogicalPlan]] if the passed context exists using the
+ * passed function. The original plan is returned when the context does not exist.
+ */
+ def optionalMap[C <: ParserRuleContext](
+ ctx: C)(
+ f: (C, LogicalPlan) => LogicalPlan): LogicalPlan = {
+ if (ctx != null) {
+ f(ctx, plan)
+ } else {
+ plan
+ }
+ }
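+
+ // For example, plan(ctx.relation).optionalMap(ctx.sample)(withSample) in the AstBuilder only
+ // wraps the plan in a Sample operator when a TABLESAMPLE clause is actually present.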
+ }
+}
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/CatalystQlSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/CatalystQlSuite.scala
index c068e895b6..223485e292 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/CatalystQlSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/CatalystQlSuite.scala
@@ -21,15 +21,20 @@ import org.apache.spark.sql.AnalysisException
import org.apache.spark.sql.catalyst.TableIdentifier
import org.apache.spark.sql.catalyst.analysis._
import org.apache.spark.sql.catalyst.expressions._
+import org.apache.spark.sql.catalyst.parser.ng.CatalystSqlParser
import org.apache.spark.sql.catalyst.plans.PlanTest
import org.apache.spark.sql.catalyst.plans.logical._
import org.apache.spark.unsafe.types.CalendarInterval
class CatalystQlSuite extends PlanTest {
val parser = new CatalystQl()
+ import org.apache.spark.sql.catalyst.dsl.expressions._
+ import org.apache.spark.sql.catalyst.dsl.plans._
+
+ val star = UnresolvedAlias(UnresolvedStar(None))
test("test case insensitive") {
- val result = Project(UnresolvedAlias(Literal(1)):: Nil, OneRowRelation)
+ val result = OneRowRelation.select(1)
assert(result === parser.parsePlan("seLect 1"))
assert(result === parser.parsePlan("select 1"))
assert(result === parser.parsePlan("SELECT 1"))
@@ -37,52 +42,31 @@ class CatalystQlSuite extends PlanTest {
test("test NOT operator with comparison operations") {
val parsed = parser.parsePlan("SELECT NOT TRUE > TRUE")
- val expected = Project(
- UnresolvedAlias(
- Not(
- GreaterThan(Literal(true), Literal(true)))
- ) :: Nil,
- OneRowRelation)
+ val expected = OneRowRelation.select(Not(GreaterThan(true, true)))
comparePlans(parsed, expected)
}
test("test Union Distinct operator") {
- val parsed1 = parser.parsePlan("SELECT * FROM t0 UNION SELECT * FROM t1")
- val parsed2 = parser.parsePlan("SELECT * FROM t0 UNION DISTINCT SELECT * FROM t1")
- val expected =
- Project(UnresolvedAlias(UnresolvedStar(None)) :: Nil,
- SubqueryAlias("u_1",
- Distinct(
- Union(
- Project(UnresolvedAlias(UnresolvedStar(None)) :: Nil,
- UnresolvedRelation(TableIdentifier("t0"), None)),
- Project(UnresolvedAlias(UnresolvedStar(None)) :: Nil,
- UnresolvedRelation(TableIdentifier("t1"), None))))))
+ val parsed1 = parser.parsePlan(
+ "SELECT * FROM t0 UNION SELECT * FROM t1")
+ val parsed2 = parser.parsePlan(
+ "SELECT * FROM t0 UNION DISTINCT SELECT * FROM t1")
+ val expected = Distinct(Union(table("t0").select(star), table("t1").select(star)))
+ .as("u_1").select(star)
comparePlans(parsed1, expected)
comparePlans(parsed2, expected)
}
test("test Union All operator") {
val parsed = parser.parsePlan("SELECT * FROM t0 UNION ALL SELECT * FROM t1")
- val expected =
- Project(UnresolvedAlias(UnresolvedStar(None)) :: Nil,
- SubqueryAlias("u_1",
- Union(
- Project(UnresolvedAlias(UnresolvedStar(None)) :: Nil,
- UnresolvedRelation(TableIdentifier("t0"), None)),
- Project(UnresolvedAlias(UnresolvedStar(None)) :: Nil,
- UnresolvedRelation(TableIdentifier("t1"), None)))))
+ val expected = Union(table("t0").select(star), table("t1").select(star)).as("u_1").select(star)
comparePlans(parsed, expected)
}
test("support hive interval literal") {
def checkInterval(sql: String, result: CalendarInterval): Unit = {
val parsed = parser.parsePlan(sql)
- val expected = Project(
- UnresolvedAlias(
- Literal(result)
- ) :: Nil,
- OneRowRelation)
+ val expected = OneRowRelation.select(Literal(result))
comparePlans(parsed, expected)
}
@@ -129,11 +113,7 @@ class CatalystQlSuite extends PlanTest {
test("support scientific notation") {
def assertRight(input: String, output: Double): Unit = {
val parsed = parser.parsePlan("SELECT " + input)
- val expected = Project(
- UnresolvedAlias(
- Literal(output)
- ) :: Nil,
- OneRowRelation)
+ val expected = OneRowRelation.select(Literal(output))
comparePlans(parsed, expected)
}
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/DataTypeParserSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/DataTypeParserSuite.scala
index 7d3608033b..d9bd33c50a 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/DataTypeParserSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/DataTypeParserSuite.scala
@@ -18,19 +18,24 @@
package org.apache.spark.sql.catalyst.parser
import org.apache.spark.SparkFunSuite
+import org.apache.spark.sql.catalyst.parser.ng.{CatalystSqlParser, ParseException}
import org.apache.spark.sql.types._
-class DataTypeParserSuite extends SparkFunSuite {
+abstract class AbstractDataTypeParserSuite extends SparkFunSuite {
+
+ def parse(sql: String): DataType
def checkDataType(dataTypeString: String, expectedDataType: DataType): Unit = {
test(s"parse ${dataTypeString.replace("\n", "")}") {
- assert(DataTypeParser.parse(dataTypeString) === expectedDataType)
+ assert(parse(dataTypeString) === expectedDataType)
}
}
+ def intercept(sql: String)
+
def unsupported(dataTypeString: String): Unit = {
test(s"$dataTypeString is not supported") {
- intercept[DataTypeException](DataTypeParser.parse(dataTypeString))
+ intercept(dataTypeString)
}
}
@@ -97,13 +102,6 @@ class DataTypeParserSuite extends SparkFunSuite {
StructField("arrAy", ArrayType(DoubleType, true), true) ::
StructField("anotherArray", ArrayType(StringType, true), true) :: Nil)
)
- // A column name can be a reserved word in our DDL parser and SqlParser.
- checkDataType(
- "Struct<TABLE: string, CASE:boolean>",
- StructType(
- StructField("TABLE", StringType, true) ::
- StructField("CASE", BooleanType, true) :: Nil)
- )
// Use backticks to quote column names having special characters.
checkDataType(
"struct<`x+y`:int, `!@#$%^&*()`:string, `1_2.345<>:\"`:varchar(20)>",
@@ -118,6 +116,43 @@ class DataTypeParserSuite extends SparkFunSuite {
unsupported("it is not a data type")
unsupported("struct<x+y: int, 1.1:timestamp>")
unsupported("struct<x: int")
+}
+
+class DataTypeParserSuite extends AbstractDataTypeParserSuite {
+ override def intercept(sql: String): Unit =
+ intercept[DataTypeException](DataTypeParser.parse(sql))
+
+ override def parse(sql: String): DataType =
+ DataTypeParser.parse(sql)
+
+ // A column name can be a reserved word in our DDL parser and SqlParser.
+ checkDataType(
+ "Struct<TABLE: string, CASE:boolean>",
+ StructType(
+ StructField("TABLE", StringType, true) ::
+ StructField("CASE", BooleanType, true) :: Nil)
+ )
+
unsupported("struct<x int, y string>")
+
unsupported("struct<`x``y` int>")
}
+
+class CatalystQlDataTypeParserSuite extends AbstractDataTypeParserSuite {
+ override def intercept(sql: String): Unit =
+ intercept[ParseException](CatalystSqlParser.parseDataType(sql))
+
+ override def parse(sql: String): DataType =
+ CatalystSqlParser.parseDataType(sql)
+
+ // A column name can be a reserved word in our DDL parser and SqlParser.
+ unsupported("Struct<TABLE: string, CASE:boolean>")
+
+ checkDataType(
+ "struct<x int, y string>",
+ (new StructType).add("x", IntegerType).add("y", StringType))
+
+ checkDataType(
+ "struct<`x``y` int>",
+ (new StructType).add("x`y", IntegerType))
+}
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/ng/ErrorParserSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/ng/ErrorParserSuite.scala
new file mode 100644
index 0000000000..1963fc368f
--- /dev/null
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/ng/ErrorParserSuite.scala
@@ -0,0 +1,67 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.spark.sql.catalyst.parser.ng
+
+import org.apache.spark.SparkFunSuite
+
+/**
+ * Test various parser errors.
+ */
+class ErrorParserSuite extends SparkFunSuite {
+ def intercept(sql: String, line: Int, startPosition: Int, messages: String*): Unit = {
+ val e = intercept[ParseException](CatalystSqlParser.parsePlan(sql))
+
+ // Check position.
+ assert(e.line.isDefined)
+ assert(e.line.get === line)
+ assert(e.startPosition.isDefined)
+ assert(e.startPosition.get === startPosition)
+
+ // Check messages.
+ val error = e.getMessage
+ messages.foreach { message =>
+ assert(error.contains(message))
+ }
+ }
+
+ test("no viable input") {
+ intercept("select from tbl", 1, 7, "no viable alternative at input", "-------^^^")
+ intercept("select\nfrom tbl", 2, 0, "no viable alternative at input", "^^^")
+ intercept("select ((r + 1) ", 1, 16, "no viable alternative at input", "----------------^^^")
+ }
+
+ test("extraneous input") {
+ intercept("select 1 1", 1, 9, "extraneous input '1' expecting", "---------^^^")
+ intercept("select *\nfrom r as q t", 2, 12, "extraneous input", "------------^^^")
+ }
+
+ test("mismatched input") {
+ intercept("select * from r order by q from t", 1, 27,
+ "mismatched input",
+ "---------------------------^^^")
+ intercept("select *\nfrom r\norder by q\nfrom t", 4, 0, "mismatched input", "^^^")
+ }
+
+ test("semantic errors") {
+ intercept("select *\nfrom r\norder by q\ncluster by q", 3, 0,
+ "Combination of ORDER BY/SORT BY/DISTRIBUTE BY/CLUSTER BY is not supported",
+ "^^^")
+ intercept("select * from r where a in (select * from t)", 1, 24,
+ "IN with a Sub-query is currently not supported",
+ "------------------------^^^")
+ }
+}
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/ng/ExpressionParserSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/ng/ExpressionParserSuite.scala
new file mode 100644
index 0000000000..32311a5a66
--- /dev/null
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/ng/ExpressionParserSuite.scala
@@ -0,0 +1,497 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.spark.sql.catalyst.parser.ng
+
+import java.sql.{Date, Timestamp}
+
+import org.apache.spark.sql.catalyst.analysis.{UnresolvedAttribute, _}
+import org.apache.spark.sql.catalyst.expressions._
+import org.apache.spark.sql.catalyst.plans.PlanTest
+import org.apache.spark.sql.types._
+import org.apache.spark.unsafe.types.CalendarInterval
+
+/**
+ * Test basic expression parsing. If a type of expression is supported it should be tested here.
+ *
+ * Please note that some of the expressions tested don't have to be sound expressions; only their
+ * structure needs to be valid. Unsound expressions should be caught by the Analyzer or
+ * CheckAnalysis classes.
+ */
+class ExpressionParserSuite extends PlanTest {
+ import CatalystSqlParser._
+ import org.apache.spark.sql.catalyst.dsl.expressions._
+ import org.apache.spark.sql.catalyst.dsl.plans._
+
+ def assertEqual(sqlCommand: String, e: Expression): Unit = {
+ compareExpressions(parseExpression(sqlCommand), e)
+ }
+
+ def intercept(sqlCommand: String, messages: String*): Unit = {
+ val e = intercept[ParseException](parseExpression(sqlCommand))
+ messages.foreach { message =>
+ assert(e.message.contains(message))
+ }
+ }
+
+ test("star expressions") {
+ // Global Star
+ assertEqual("*", UnresolvedStar(None))
+
+ // Targeted Star
+ assertEqual("a.b.*", UnresolvedStar(Option(Seq("a", "b"))))
+ }
+
+ // NamedExpression (Alias/Multialias)
+ test("named expressions") {
+ // No Alias
+ val r0 = 'a
+ assertEqual("a", r0)
+
+ // Single Alias.
+ val r1 = 'a as "b"
+ assertEqual("a as b", r1)
+ assertEqual("a b", r1)
+
+ // Multi-Alias
+ assertEqual("a as (b, c)", MultiAlias('a, Seq("b", "c")))
+ assertEqual("a() (b, c)", MultiAlias('a.function(), Seq("b", "c")))
+
+ // Numeric literals without a space between the literal qualifier and the alias should not be
+ // interpreted as such. An unresolved reference should be returned instead.
+ // TODO add the JIRA-ticket number.
+ assertEqual("1SL", Symbol("1SL"))
+
+ // Aliased star is allowed.
+ assertEqual("a.* b", UnresolvedStar(Option(Seq("a"))) as 'b)
+ }
+
+ test("binary logical expressions") {
+ // And
+ assertEqual("a and b", 'a && 'b)
+
+ // Or
+ assertEqual("a or b", 'a || 'b)
+
+ // Combination And/Or check precedence
+ assertEqual("a and b or c and d", ('a && 'b) || ('c && 'd))
+ assertEqual("a or b or c and d", 'a || 'b || ('c && 'd))
+
+ // Multiple AND/OR get converted into a balanced tree
+ assertEqual("a or b or c or d or e or f", (('a || 'b) || 'c) || (('d || 'e) || 'f))
+ assertEqual("a and b and c and d and e and f", (('a && 'b) && 'c) && (('d && 'e) && 'f))
+ }
+
+ test("long binary logical expressions") {
+ def testVeryBinaryExpression(op: String, clazz: Class[_]): Unit = {
+ val sql = (1 to 1000).map(x => s"$x == $x").mkString(op)
+ val e = parseExpression(sql)
+ assert(e.collect { case _: EqualTo => true }.size === 1000)
+ assert(e.collect { case x if clazz.isInstance(x) => true }.size === 999)
+ }
+ testVeryBinaryExpression(" AND ", classOf[And])
+ testVeryBinaryExpression(" OR ", classOf[Or])
+ }
+
+ test("not expressions") {
+ assertEqual("not a", !'a)
+ assertEqual("!a", !'a)
+ assertEqual("not true > true", Not(GreaterThan(true, true)))
+ }
+
+ test("exists expression") {
+ intercept("exists (select 1 from b where b.x = a.x)", "EXISTS clauses are not supported")
+ }
+
+ test("comparison expressions") {
+ assertEqual("a = b", 'a === 'b)
+ assertEqual("a == b", 'a === 'b)
+ assertEqual("a <=> b", 'a <=> 'b)
+ assertEqual("a <> b", 'a =!= 'b)
+ assertEqual("a != b", 'a =!= 'b)
+ assertEqual("a < b", 'a < 'b)
+ assertEqual("a <= b", 'a <= 'b)
+ assertEqual("a > b", 'a > 'b)
+ assertEqual("a >= b", 'a >= 'b)
+ }
+
+ test("between expressions") {
+ assertEqual("a between b and c", 'a >= 'b && 'a <= 'c)
+ assertEqual("a not between b and c", !('a >= 'b && 'a <= 'c))
+ }
+
+ test("in expressions") {
+ assertEqual("a in (b, c, d)", 'a in ('b, 'c, 'd))
+ assertEqual("a not in (b, c, d)", !('a in ('b, 'c, 'd)))
+ }
+
+ test("in sub-query") {
+ intercept("a in (select b from c)", "IN with a Sub-query is currently not supported")
+ }
+
+ test("like expressions") {
+ assertEqual("a like 'pattern%'", 'a like "pattern%")
+ assertEqual("a not like 'pattern%'", !('a like "pattern%"))
+ assertEqual("a rlike 'pattern%'", 'a rlike "pattern%")
+ assertEqual("a not rlike 'pattern%'", !('a rlike "pattern%"))
+ assertEqual("a regexp 'pattern%'", 'a rlike "pattern%")
+ assertEqual("a not regexp 'pattern%'", !('a rlike "pattern%"))
+ }
+
+ test("is null expressions") {
+ assertEqual("a is null", 'a.isNull)
+ assertEqual("a is not null", 'a.isNotNull)
+ assertEqual("a = b is null", ('a === 'b).isNull)
+ assertEqual("a = b is not null", ('a === 'b).isNotNull)
+ }
+
+ test("binary arithmetic expressions") {
+ // Simple operations
+ assertEqual("a * b", 'a * 'b)
+ assertEqual("a / b", 'a / 'b)
+ assertEqual("a DIV b", ('a / 'b).cast(LongType))
+ assertEqual("a % b", 'a % 'b)
+ assertEqual("a + b", 'a + 'b)
+ assertEqual("a - b", 'a - 'b)
+ assertEqual("a & b", 'a & 'b)
+ assertEqual("a ^ b", 'a ^ 'b)
+ assertEqual("a | b", 'a | 'b)
+
+ // Check precedences
+ assertEqual(
+ "a * t | b ^ c & d - e + f % g DIV h / i * k",
+ 'a * 't | ('b ^ ('c & ('d - 'e + (('f % 'g / 'h).cast(LongType) / 'i * 'k)))))
+ }
+
+ test("unary arithmetic expressions") {
+ assertEqual("+a", 'a)
+ assertEqual("-a", -'a)
+ assertEqual("~a", ~'a)
+ assertEqual("-+~~a", -(~(~'a)))
+ }
+
+ test("cast expressions") {
+ // Note that DataType parsing is tested elsewhere.
+ assertEqual("cast(a as int)", 'a.cast(IntegerType))
+ assertEqual("cast(a as timestamp)", 'a.cast(TimestampType))
+ assertEqual("cast(a as array<int>)", 'a.cast(ArrayType(IntegerType)))
+ assertEqual("cast(cast(a as int) as long)", 'a.cast(IntegerType).cast(LongType))
+ }
+
+ test("function expressions") {
+ assertEqual("foo()", 'foo.function())
+ assertEqual("foo.bar()", Symbol("foo.bar").function())
+ assertEqual("foo(*)", 'foo.function(star()))
+ assertEqual("count(*)", 'count.function(1))
+ assertEqual("foo(a, b)", 'foo.function('a, 'b))
+ assertEqual("foo(all a, b)", 'foo.function('a, 'b))
+ assertEqual("foo(distinct a, b)", 'foo.distinctFunction('a, 'b))
+ assertEqual("grouping(distinct a, b)", 'grouping.distinctFunction('a, 'b))
+ assertEqual("`select`(all a, b)", 'select.function('a, 'b))
+ }
+
+ test("window function expressions") {
+ val func = 'foo.function(star())
+ def windowed(
+ partitioning: Seq[Expression] = Seq.empty,
+ ordering: Seq[SortOrder] = Seq.empty,
+ frame: WindowFrame = UnspecifiedFrame): Expression = {
+ WindowExpression(func, WindowSpecDefinition(partitioning, ordering, frame))
+ }
+
+ // Basic window testing.
+ assertEqual("foo(*) over w1", UnresolvedWindowExpression(func, WindowSpecReference("w1")))
+ assertEqual("foo(*) over ()", windowed())
+ assertEqual("foo(*) over (partition by a, b)", windowed(Seq('a, 'b)))
+ assertEqual("foo(*) over (distribute by a, b)", windowed(Seq('a, 'b)))
+ assertEqual("foo(*) over (cluster by a, b)", windowed(Seq('a, 'b)))
+ assertEqual("foo(*) over (order by a desc, b asc)", windowed(Seq.empty, Seq('a.desc, 'b.asc )))
+ assertEqual("foo(*) over (sort by a desc, b asc)", windowed(Seq.empty, Seq('a.desc, 'b.asc )))
+ assertEqual("foo(*) over (partition by a, b order by c)", windowed(Seq('a, 'b), Seq('c.asc)))
+ assertEqual("foo(*) over (distribute by a, b sort by c)", windowed(Seq('a, 'b), Seq('c.asc)))
+
+ // Test use of expressions in window functions.
+ assertEqual(
+ "sum(product + 1) over (partition by ((product) + (1)) order by 2)",
+ WindowExpression('sum.function('product + 1),
+ WindowSpecDefinition(Seq('product + 1), Seq(Literal(2).asc), UnspecifiedFrame)))
+ assertEqual(
+ "sum(product + 1) over (partition by ((product / 2) + 1) order by 2)",
+ WindowExpression('sum.function('product + 1),
+ WindowSpecDefinition(Seq('product / 2 + 1), Seq(Literal(2).asc), UnspecifiedFrame)))
+
+ // Range/Row
+ val frameTypes = Seq(("rows", RowFrame), ("range", RangeFrame))
+ val boundaries = Seq(
+ ("10 preceding", ValuePreceding(10), CurrentRow),
+ ("3 + 1 following", ValueFollowing(4), CurrentRow), // Will fail during analysis
+ ("unbounded preceding", UnboundedPreceding, CurrentRow),
+ ("unbounded following", UnboundedFollowing, CurrentRow), // Will fail during analysis
+ ("between unbounded preceding and current row", UnboundedPreceding, CurrentRow),
+ ("between unbounded preceding and unbounded following",
+ UnboundedPreceding, UnboundedFollowing),
+ ("between 10 preceding and current row", ValuePreceding(10), CurrentRow),
+ ("between current row and 5 following", CurrentRow, ValueFollowing(5)),
+ ("between 10 preceding and 5 following", ValuePreceding(10), ValueFollowing(5))
+ )
+ frameTypes.foreach {
+ case (frameTypeSql, frameType) =>
+ boundaries.foreach {
+ case (boundarySql, begin, end) =>
+ val query = s"foo(*) over (partition by a order by b $frameTypeSql $boundarySql)"
+ val expr = windowed(Seq('a), Seq('b.asc), SpecifiedWindowFrame(frameType, begin, end))
+ assertEqual(query, expr)
+ }
+ }
+
+ // We cannot use non-integer constants.
+ intercept("foo(*) over (partition by a order by b rows 10.0 preceding)",
+ "Frame bound value must be a constant integer.")
+
+ // We cannot use an arbitrary expression.
+ intercept("foo(*) over (partition by a order by b rows exp(b) preceding)",
+ "Frame bound value must be a constant integer.")
+ }
+
+ test("row constructor") {
+ // Note that '(a)' will be interpreted as a nested expression.
+ assertEqual("(a, b)", CreateStruct(Seq('a, 'b)))
+ assertEqual("(a, b, c)", CreateStruct(Seq('a, 'b, 'c)))
+ }
+
+ test("scalar sub-query") {
+ assertEqual(
+ "(select max(val) from tbl) > current",
+ ScalarSubquery(table("tbl").select('max.function('val))) > 'current)
+ assertEqual(
+ "a = (select b from s)",
+ 'a === ScalarSubquery(table("s").select('b)))
+ }
+
+ test("case when") {
+ assertEqual("case a when 1 then b when 2 then c else d end",
+ CaseKeyWhen('a, Seq(1, 'b, 2, 'c, 'd)))
+ assertEqual("case when a = 1 then b when a = 2 then c else d end",
+ CaseWhen(Seq(('a === 1, 'b.expr), ('a === 2, 'c.expr)), 'd))
+ }
+
+ test("dereference") {
+ assertEqual("a.b", UnresolvedAttribute("a.b"))
+ assertEqual("`select`.b", UnresolvedAttribute("select.b"))
+ assertEqual("(a + b).b", ('a + 'b).getField("b")) // This will fail analysis.
+ assertEqual("struct(a, b).b", 'struct.function('a, 'b).getField("b"))
+ }
+
+ test("reference") {
+ // Regular
+ assertEqual("a", 'a)
+
+ // Starting with a digit.
+ assertEqual("1a", Symbol("1a"))
+
+ // Quoted using a keyword.
+ assertEqual("`select`", 'select)
+
+ // Unquoted using an unreserved keyword.
+ assertEqual("columns", 'columns)
+ }
+
+ test("subscript") {
+ assertEqual("a[b]", 'a.getItem('b))
+ assertEqual("a[1 + 1]", 'a.getItem(Literal(1) + 1))
+ assertEqual("`c`.a[b]", UnresolvedAttribute("c.a").getItem('b))
+ }
+
+ test("parenthesis") {
+ assertEqual("(a)", 'a)
+ assertEqual("r * (a + b)", 'r * ('a + 'b))
+ }
+
+ test("type constructors") {
+ // Dates.
+ assertEqual("dAte '2016-03-11'", Literal(Date.valueOf("2016-03-11")))
+ intercept[IllegalArgumentException] {
+ parseExpression("DAtE 'mar 11 2016'")
+ }
+
+ // Timestamps.
+ assertEqual("tImEstAmp '2016-03-11 20:54:00.000'",
+ Literal(Timestamp.valueOf("2016-03-11 20:54:00.000")))
+ intercept[IllegalArgumentException] {
+ parseExpression("timestamP '2016-33-11 20:54:00.000'")
+ }
+
+ // Unsupported datatype.
+ intercept("GEO '(10,-6)'", "Literals of type 'GEO' are currently not supported.")
+ }
+
+ test("literals") {
+ // NULL
+ assertEqual("null", Literal(null))
+
+ // Boolean
+ assertEqual("trUe", Literal(true))
+ assertEqual("False", Literal(false))
+
+ // Integral literals should get the narrowest possible type
+ assertEqual("787324", Literal(787324))
+ assertEqual("7873247234798249234", Literal(7873247234798249234L))
+ assertEqual("78732472347982492793712334",
+ Literal(BigDecimal("78732472347982492793712334").underlying()))
+
+ // Decimal
+ assertEqual("7873247234798249279371.2334",
+ Literal(BigDecimal("7873247234798249279371.2334").underlying()))
+
+ // Scientific Decimal
+ assertEqual("9.0e1", 90d)
+ assertEqual(".9e+2", 90d)
+ assertEqual("0.9e+2", 90d)
+ assertEqual("900e-1", 90d)
+ assertEqual("900.0E-1", 90d)
+ assertEqual("9.e+1", 90d)
+ intercept(".e3")
+
+ // Tiny Int Literal
+ assertEqual("10Y", Literal(10.toByte))
+ intercept("-1000Y")
+
+ // Small Int Literal
+ assertEqual("10S", Literal(10.toShort))
+ intercept("40000S")
+
+ // Long Int Literal
+ assertEqual("10L", Literal(10L))
+ intercept("78732472347982492793712334L")
+
+ // Double Literal
+ assertEqual("10.0D", Literal(10.0D))
+ // TODO we need to figure out if we should throw an exception here!
+ assertEqual("1E309", Literal(Double.PositiveInfinity))
+ }
+
+ test("strings") {
+ // Single Strings.
+ assertEqual("\"hello\"", "hello")
+ assertEqual("'hello'", "hello")
+
+ // Multi-Strings.
+ assertEqual("\"hello\" 'world'", "helloworld")
+ assertEqual("'hello' \" \" 'world'", "hello world")
+
+ // 'LIKE' string literals. Notice that an escaped '%' is treated the same as an escaped '\'
+ // followed by a regular '%'; to get the correct result you need to add another escaped '\'.
+ // TODO figure out if we shouldn't change the ParseUtils.unescapeSQLString method?
+ assertEqual("'pattern%'", "pattern%")
+ assertEqual("'no-pattern\\%'", "no-pattern\\%")
+ assertEqual("'pattern\\\\%'", "pattern\\%")
+ assertEqual("'pattern\\\\\\%'", "pattern\\\\%")
+
+ // Escaped characters.
+ // See: http://dev.mysql.com/doc/refman/5.7/en/string-literals.html
+ assertEqual("'\\0'", "\u0000") // ASCII NUL (X'00')
+ assertEqual("'\\''", "\'") // Single quote
+ assertEqual("'\\\"'", "\"") // Double quote
+ assertEqual("'\\b'", "\b") // Backspace
+ assertEqual("'\\n'", "\n") // Newline
+ assertEqual("'\\r'", "\r") // Carriage return
+ assertEqual("'\\t'", "\t") // Tab character
+ assertEqual("'\\Z'", "\u001A") // ASCII 26 - CTRL + Z (EOF on windows)
+
+ // Octals
+ assertEqual("'\\110\\145\\154\\154\\157\\041'", "Hello!")
+
+ // Unicode
+ assertEqual("'\\u0087\\u0111\\u0114\\u0108\\u0100\\u0032\\u0058\\u0041'", "World :)")
+ }
+
+ test("intervals") {
+ def intervalLiteral(u: String, s: String): Literal = {
+ Literal(CalendarInterval.fromSingleUnitString(u, s))
+ }
+
+ // Empty interval statement
+ intercept("interval", "at least one time unit should be given for interval literal")
+
+ // Single Intervals.
+ val units = Seq(
+ "year",
+ "month",
+ "week",
+ "day",
+ "hour",
+ "minute",
+ "second",
+ "millisecond",
+ "microsecond")
+ val forms = Seq("", "s")
+ val values = Seq("0", "10", "-7", "21")
+ units.foreach { unit =>
+ forms.foreach { form =>
+ values.foreach { value =>
+ val expected = intervalLiteral(unit, value)
+ assertEqual(s"interval $value $unit$form", expected)
+ assertEqual(s"interval '$value' $unit$form", expected)
+ }
+ }
+ }
+
+ // Hive nanosecond notation.
+ assertEqual("interval 13.123456789 seconds", intervalLiteral("second", "13.123456789"))
+ assertEqual("interval -13.123456789 second", intervalLiteral("second", "-13.123456789"))
+
+ // Non-existent unit
+ intercept("interval 10 nanoseconds", "No interval can be constructed")
+
+ // Year-Month intervals.
+ val yearMonthValues = Seq("123-10", "496-0", "-2-3", "-123-0")
+ yearMonthValues.foreach { value =>
+ val result = Literal(CalendarInterval.fromYearMonthString(value))
+ assertEqual(s"interval '$value' year to month", result)
+ }
+
+ // Day-Time intervals.
+ val dayTimeValues = Seq(
+ "99 11:22:33.123456789",
+ "-99 11:22:33.123456789",
+ "10 9:8:7.123456789",
+ "1 0:0:0",
+ "-1 0:0:0",
+ "1 0:0:1")
+ dayTimeValues.foreach { value =>
+ val result = Literal(CalendarInterval.fromDayTimeString(value))
+ assertEqual(s"interval '$value' day to second", result)
+ }
+
+ // Unknown FROM TO intervals
+ intercept("interval 10 month to second", "Intervals FROM month TO second are not supported.")
+
+ // Composed intervals.
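+ // 22 seconds and 1 millisecond add up to 22,001,000 microseconds.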
+ assertEqual(
+ "interval 3 months 22 seconds 1 millisecond",
+ Literal(new CalendarInterval(3, 22001000L)))
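+ // 3 years minus '1-10' (1 year and 10 months) gives 14 months; 3 weeks plus 1 day gives
+ // 22 days, plus 2 seconds.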
+ assertEqual(
+ "interval 3 years '-1-10' year to month 3 weeks '1 0:0:2' day to second",
+ Literal(new CalendarInterval(14,
+ 22 * CalendarInterval.MICROS_PER_DAY + 2 * CalendarInterval.MICROS_PER_SECOND)))
+ }
+
+ test("composed expressions") {
+ assertEqual("1 + r.r As q", (Literal(1) + UnresolvedAttribute("r.r")).as("q"))
+ assertEqual("1 - f('o', o(bar))", Literal(1) - 'f.function("o", 'o.function('bar)))
+ intercept("1 - f('o', o(bar)) hello * world", "mismatched input '*'")
+ }
+}
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/ng/PlanParserSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/ng/PlanParserSuite.scala
new file mode 100644
index 0000000000..4206d22ca7
--- /dev/null
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/ng/PlanParserSuite.scala
@@ -0,0 +1,429 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.spark.sql.catalyst.parser.ng
+
+import org.apache.spark.sql.Row
+import org.apache.spark.sql.catalyst.expressions._
+import org.apache.spark.sql.catalyst.plans._
+import org.apache.spark.sql.catalyst.plans.logical._
+import org.apache.spark.sql.types.{BooleanType, IntegerType}
+
+class PlanParserSuite extends PlanTest {
+ import CatalystSqlParser._
+ import org.apache.spark.sql.catalyst.dsl.expressions._
+ import org.apache.spark.sql.catalyst.dsl.plans._
+
+ def assertEqual(sqlCommand: String, plan: LogicalPlan): Unit = {
+ comparePlans(parsePlan(sqlCommand), plan)
+ }
+
+ def intercept(sqlCommand: String, messages: String*): Unit = {
+ val e = intercept[ParseException](parsePlan(sqlCommand))
+ messages.foreach { message =>
+ assert(e.message.contains(message))
+ }
+ }
+
+ test("case insensitive") {
+ val plan = table("a").select(star())
+ assertEqual("sELEct * FroM a", plan)
+ assertEqual("select * fRoM a", plan)
+ assertEqual("SELECT * FROM a", plan)
+ }
+
+ test("show functions") {
+ assertEqual("show functions", ShowFunctions(None, None))
+ assertEqual("show functions foo", ShowFunctions(None, Some("foo")))
+ assertEqual("show functions foo.bar", ShowFunctions(Some("foo"), Some("bar")))
+ assertEqual("show functions 'foo\\\\.*'", ShowFunctions(None, Some("foo\\.*")))
+ intercept("show functions foo.bar.baz", "SHOW FUNCTIONS unsupported name")
+ }
+
+ test("describe function") {
+ assertEqual("describe function bar", DescribeFunction("bar", isExtended = false))
+ assertEqual("describe function extended bar", DescribeFunction("bar", isExtended = true))
+ assertEqual("describe function foo.bar", DescribeFunction("foo.bar", isExtended = false))
+ assertEqual("describe function extended f.bar", DescribeFunction("f.bar", isExtended = true))
+ }
+
+ test("set operations") {
+ val a = table("a").select(star())
+ val b = table("b").select(star())
+
+ assertEqual("select * from a union select * from b", Distinct(a.union(b)))
+ assertEqual("select * from a union distinct select * from b", Distinct(a.union(b)))
+ assertEqual("select * from a union all select * from b", a.union(b))
+ assertEqual("select * from a except select * from b", a.except(b))
+ intercept("select * from a except all select * from b", "EXCEPT ALL is not supported.")
+ assertEqual("select * from a except distinct select * from b", a.except(b))
+ assertEqual("select * from a intersect select * from b", a.intersect(b))
+ intercept("select * from a intersect all select * from b", "INTERSECT ALL is not supported.")
+ assertEqual("select * from a intersect distinct select * from b", a.intersect(b))
+ }
+
+ test("common table expressions") {
+ def cte(plan: LogicalPlan, namedPlans: (String, LogicalPlan)*): With = {
+ val ctes = namedPlans.map {
+ case (name, cte) =>
+ name -> SubqueryAlias(name, cte)
+ }.toMap
+ With(plan, ctes)
+ }
+ assertEqual(
+ "with cte1 as (select * from a) select * from cte1",
+ cte(table("cte1").select(star()), "cte1" -> table("a").select(star())))
+ assertEqual(
+ "with cte1 (select 1) select * from cte1",
+ cte(table("cte1").select(star()), "cte1" -> OneRowRelation.select(1)))
+ assertEqual(
+ "with cte1 (select 1), cte2 as (select * from cte1) select * from cte2",
+ cte(table("cte2").select(star()),
+ "cte1" -> OneRowRelation.select(1),
+ "cte2" -> table("cte1").select(star())))
+ intercept(
+ "with cte1 (select 1), cte1 as (select 1 from cte1) select * from cte1",
+ "Name 'cte1' is used for multiple common table expressions")
+ }
+
+ test("simple select query") {
+ assertEqual("select 1", OneRowRelation.select(1))
+ assertEqual("select a, b", OneRowRelation.select('a, 'b))
+ assertEqual("select a, b from db.c", table("db", "c").select('a, 'b))
+ assertEqual("select a, b from db.c where x < 1", table("db", "c").where('x < 1).select('a, 'b))
+ assertEqual(
+ "select a, b from db.c having x < 1",
+ table("db", "c").select('a, 'b).where(('x < 1).cast(BooleanType)))
+ assertEqual("select distinct a, b from db.c", Distinct(table("db", "c").select('a, 'b)))
+ assertEqual("select all a, b from db.c", table("db", "c").select('a, 'b))
+ }
+
+ test("reverse select query") {
+ assertEqual("from a", table("a"))
+ assertEqual("from a select b, c", table("a").select('b, 'c))
+ assertEqual(
+ "from db.a select b, c where d < 1", table("db", "a").where('d < 1).select('b, 'c))
+ assertEqual("from a select distinct b, c", Distinct(table("a").select('b, 'c)))
+ assertEqual(
+ "from (from a union all from b) c select *",
+ table("a").union(table("b")).as("c").select(star()))
+ }
+
+ test("transform query spec") {
+ val p = ScriptTransformation(Seq('a, 'b), "func", Seq.empty, table("e"), null)
+ assertEqual("select transform(a, b) using 'func' from e where f < 10",
+ p.copy(child = p.child.where('f < 10), output = Seq('key.string, 'value.string)))
+ assertEqual("map a, b using 'func' as c, d from e",
+ p.copy(output = Seq('c.string, 'd.string)))
+ assertEqual("reduce a, b using 'func' as (c: int, d decimal(10, 0)) from e",
+ p.copy(output = Seq('c.int, 'd.decimal(10, 0))))
+ }
+
+ test("multi select query") {
+ assertEqual(
+ "from a select * select * where s < 10",
+ table("a").select(star()).union(table("a").where('s < 10).select(star())))
+ intercept(
+ "from a select * select * from x where a.s < 10",
+ "Multi-Insert queries cannot have a FROM clause in their individual SELECT statements")
+ assertEqual(
+ "from a insert into tbl1 select * insert into tbl2 select * where s < 10",
+ table("a").select(star()).insertInto("tbl1").union(
+ table("a").where('s < 10).select(star()).insertInto("tbl2")))
+ }
+
+ test("query organization") {
+ // Test all valid combinations of order by/sort by/distribute by/cluster by/limit/windows
+ val baseSql = "select * from t"
+ val basePlan = table("t").select(star())
+
+ val ws = Map("w1" -> WindowSpecDefinition(Seq.empty, Seq.empty, UnspecifiedFrame))
+ val limitWindowClauses = Seq(
+ ("", (p: LogicalPlan) => p),
+ (" limit 10", (p: LogicalPlan) => p.limit(10)),
+ (" window w1 as ()", (p: LogicalPlan) => WithWindowDefinition(ws, p)),
+ (" window w1 as () limit 10", (p: LogicalPlan) => WithWindowDefinition(ws, p).limit(10))
+ )
+
+ val orderSortDistrClusterClauses = Seq(
+ ("", basePlan),
+ (" order by a, b desc", basePlan.orderBy('a.asc, 'b.desc)),
+ (" sort by a, b desc", basePlan.sortBy('a.asc, 'b.desc)),
+ (" distribute by a, b", basePlan.distribute('a, 'b)),
+ (" distribute by a sort by b", basePlan.distribute('a).sortBy('b.asc)),
+ (" cluster by a, b", basePlan.distribute('a, 'b).sortBy('a.asc, 'b.asc))
+ )
+
+ orderSortDistrClusterClauses.foreach {
+ case (s1, p1) =>
+ limitWindowClauses.foreach {
+ case (s2, pf2) =>
+ assertEqual(baseSql + s1 + s2, pf2(p1))
+ }
+ }
+
+ val msg = "Combination of ORDER BY/SORT BY/DISTRIBUTE BY/CLUSTER BY is not supported"
+ intercept(s"$baseSql order by a sort by a", msg)
+ intercept(s"$baseSql cluster by a distribute by a", msg)
+ intercept(s"$baseSql order by a cluster by a", msg)
+ intercept(s"$baseSql order by a distribute by a", msg)
+ }
+
+ test("insert into") {
+ val sql = "select * from t"
+ val plan = table("t").select(star())
+ def insert(
+ partition: Map[String, Option[String]],
+ overwrite: Boolean = false,
+ ifNotExists: Boolean = false): LogicalPlan =
+ InsertIntoTable(table("s"), partition, plan, overwrite, ifNotExists)
+
+ // Single inserts
+ assertEqual(s"insert overwrite table s $sql",
+ insert(Map.empty, overwrite = true))
+ assertEqual(s"insert overwrite table s if not exists $sql",
+ insert(Map.empty, overwrite = true, ifNotExists = true))
+ assertEqual(s"insert into s $sql",
+ insert(Map.empty))
+ assertEqual(s"insert into table s partition (c = 'd', e = 1) $sql",
+ insert(Map("c" -> Option("d"), "e" -> Option("1"))))
+ assertEqual(s"insert overwrite table s partition (c = 'd', x) if not exists $sql",
+ insert(Map("c" -> Option("d"), "x" -> None), overwrite = true, ifNotExists = true))
+
+ // Multi insert
+ val plan2 = table("t").where('x > 5).select(star())
+ assertEqual("from t insert into s select * limit 1 insert into u select * where x > 5",
+ InsertIntoTable(
+ table("s"), Map.empty, plan.limit(1), overwrite = false, ifNotExists = false).union(
+ InsertIntoTable(
+ table("u"), Map.empty, plan2, overwrite = false, ifNotExists = false)))
+ }
+
+ test("aggregation") {
+ val sql = "select a, b, sum(c) as c from d group by a, b"
+
+ // Normal
+ assertEqual(sql, table("d").groupBy('a, 'b)('a, 'b, 'sum.function('c).as("c")))
+
+ // Cube
+ assertEqual(s"$sql with cube",
+ table("d").groupBy(Cube(Seq('a, 'b)))('a, 'b, 'sum.function('c).as("c")))
+
+ // Rollup
+ assertEqual(s"$sql with rollup",
+ table("d").groupBy(Rollup(Seq('a, 'b)))('a, 'b, 'sum.function('c).as("c")))
+
+ // Grouping Sets
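+ // The bitmasks 0, 1 and 3 encode which grouping columns are dropped in each grouping set:
+ // none for (a, b), one for (a), and both for ().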
+ assertEqual(s"$sql grouping sets((a, b), (a), ())",
+ GroupingSets(Seq(0, 1, 3), Seq('a, 'b), table("d"), Seq('a, 'b, 'sum.function('c).as("c"))))
+ intercept(s"$sql grouping sets((a, b), (c), ())",
+ "c doesn't show up in the GROUP BY list")
+ }
+
+ test("limit") {
+ val sql = "select * from t"
+ val plan = table("t").select(star())
+ assertEqual(s"$sql limit 10", plan.limit(10))
+ assertEqual(s"$sql limit cast(9 / 4 as int)", plan.limit(Cast(Literal(9) / 4, IntegerType)))
+ }
+
+ test("window spec") {
+ // Note that WindowSpecs are tested in the ExpressionParserSuite.
+ val sql = "select * from t"
+ val plan = table("t").select(star())
+ val spec = WindowSpecDefinition(Seq('a, 'b), Seq('c.asc),
+ SpecifiedWindowFrame(RowFrame, ValuePreceding(1), ValueFollowing(1)))
+
+ // Test window resolution.
+ val ws1 = Map("w1" -> spec, "w2" -> spec, "w3" -> spec)
+ assertEqual(
+ s"""$sql
+ |window w1 as (partition by a, b order by c rows between 1 preceding and 1 following),
+ | w2 as w1,
+ | w3 as w1""".stripMargin,
+ WithWindowDefinition(ws1, plan))
+
+ // Fail with no reference.
+ intercept(s"$sql window w2 as w1", "Cannot resolve window reference 'w1'")
+
+ // Fail when resolved reference is not a window spec.
+ intercept(
+ s"""$sql
+ |window w1 as (partition by a, b order by c rows between 1 preceding and 1 following),
+ | w2 as w1,
+ | w3 as w2""".stripMargin,
+ "Window reference 'w2' is not a window specification"
+ )
+ }
+
+ test("lateral view") {
+ // Single lateral view
+ assertEqual(
+ "select * from t lateral view explode(x) expl as x",
+ table("t")
+ .generate(Explode('x), join = true, outer = false, Some("expl"), Seq("x"))
+ .select(star()))
+
+ // Multiple lateral views
+ assertEqual(
+ """select *
+ |from t
+ |lateral view explode(x) expl
+ |lateral view outer json_tuple(x, y) jtup q, z""".stripMargin,
+ table("t")
+ .generate(Explode('x), join = true, outer = false, Some("expl"), Seq.empty)
+ .generate(JsonTuple(Seq('x, 'y)), join = true, outer = true, Some("jtup"), Seq("q", "z"))
+ .select(star()))
+
+ // Multi-Insert lateral views.
+ val from = table("t1").generate(Explode('x), join = true, outer = false, Some("expl"), Seq("x"))
+ assertEqual(
+ """from t1
+ |lateral view explode(x) expl as x
+ |insert into t2
+ |select *
+ |lateral view json_tuple(x, y) jtup q, z
+ |insert into t3
+ |select *
+ |where s < 10
+ """.stripMargin,
+ Union(from
+ .generate(JsonTuple(Seq('x, 'y)), join = true, outer = false, Some("jtup"), Seq("q", "z"))
+ .select(star())
+ .insertInto("t2"),
+ from.where('s < 10).select(star()).insertInto("t3")))
+
+ // Unsupported generator.
+ intercept(
+ "select * from t lateral view posexplode(x) posexpl as x, y",
+ "Generator function 'posexplode' is not supported")
+ }
+
+ test("joins") {
+ // Test single joins.
+ val testUnconditionalJoin = (sql: String, jt: JoinType) => {
+ assertEqual(
+ s"select * from t as tt $sql u",
+ table("t").as("tt").join(table("u"), jt, None).select(star()))
+ }
+ val testConditionalJoin = (sql: String, jt: JoinType) => {
+ assertEqual(
+ s"select * from t $sql u as uu on a = b",
+ table("t").join(table("u").as("uu"), jt, Option('a === 'b)).select(star()))
+ }
+ val testNaturalJoin = (sql: String, jt: JoinType) => {
+ assertEqual(
+ s"select * from t tt natural $sql u as uu",
+ table("t").as("tt").join(table("u").as("uu"), NaturalJoin(jt), None).select(star()))
+ }
+ val testUsingJoin = (sql: String, jt: JoinType) => {
+ assertEqual(
+ s"select * from t $sql u using(a, b)",
+ table("t").join(table("u"), UsingJoin(jt, Seq('a.attr, 'b.attr)), None).select(star()))
+ }
+ val testAll = Seq(testUnconditionalJoin, testConditionalJoin, testNaturalJoin, testUsingJoin)
+
+ def test(sql: String, jt: JoinType, tests: Seq[(String, JoinType) => Unit]): Unit = {
+ tests.foreach(_(sql, jt))
+ }
+ test("cross join", Inner, Seq(testUnconditionalJoin))
+ test(",", Inner, Seq(testUnconditionalJoin))
+ test("join", Inner, testAll)
+ test("inner join", Inner, testAll)
+ test("left join", LeftOuter, testAll)
+ test("left outer join", LeftOuter, testAll)
+ test("right join", RightOuter, testAll)
+ test("right outer join", RightOuter, testAll)
+ test("full join", FullOuter, testAll)
+ test("full outer join", FullOuter, testAll)
+
+ // Test multiple consecutive joins
+ assertEqual(
+ "select * from a join b join c right join d",
+ table("a").join(table("b")).join(table("c")).join(table("d"), RightOuter).select(star()))
+ }
+
+ test("sampled relations") {
+ val sql = "select * from t"
+ assertEqual(s"$sql tablesample(100 rows)",
+ table("t").limit(100).select(star()))
+ assertEqual(s"$sql tablesample(43 percent) as x",
+ Sample(0, .43d, withReplacement = false, 10L, table("t").as("x"))(true).select(star()))
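+ // "tablesample(bucket x out of y)" is parsed as a sample with fraction x / y.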
+ assertEqual(s"$sql tablesample(bucket 4 out of 10) as x",
+ Sample(0, .4d, withReplacement = false, 10L, table("t").as("x"))(true).select(star()))
+ intercept(s"$sql tablesample(bucket 4 out of 10 on x) as x",
+ "TABLESAMPLE(BUCKET x OUT OF y ON id) is not supported")
+ intercept(s"$sql tablesample(bucket 11 out of 10) as x",
+ s"Sampling fraction (${11.0/10.0}) must be on interval [0, 1]")
+ }
+
+ test("sub-query") {
+ val plan = table("t0").select('id)
+ assertEqual("select id from (t0)", plan)
+ assertEqual("select id from ((((((t0))))))", plan)
+ assertEqual(
+ "(select * from t1) union distinct (select * from t2)",
+ Distinct(table("t1").select(star()).union(table("t2").select(star()))))
+ assertEqual(
+ "select * from ((select * from t1) union (select * from t2)) t",
+ Distinct(
+ table("t1").select(star()).union(table("t2").select(star()))).as("t").select(star()))
+ assertEqual(
+ """select id
+ |from (((select id from t0)
+ | union all
+ | (select id from t0))
+ | union all
+ | (select id from t0)) as u_1
+ """.stripMargin,
+ plan.union(plan).union(plan).as("u_1").select('id))
+ }
+
+ test("scalar sub-query") {
+ assertEqual(
+ "select (select max(b) from s) ss from t",
+ table("t").select(ScalarSubquery(table("s").select('max.function('b))).as("ss")))
+ assertEqual(
+ "select * from t where a = (select b from s)",
+ table("t").where('a === ScalarSubquery(table("s").select('b))).select(star()))
+ assertEqual(
+ "select g from t group by g having a > (select b from s)",
+ table("t")
+ .groupBy('g)('g)
+ .where(('a > ScalarSubquery(table("s").select('b))).cast(BooleanType)))
+ }
+
+ test("table reference") {
+ assertEqual("table t", table("t"))
+ assertEqual("table d.t", table("d", "t"))
+ }
+
+ test("inline table") {
+ assertEqual("values 1, 2, 3, 4", LocalRelation.fromExternalRows(
+ Seq('col1.int),
+ Seq(1, 2, 3, 4).map(x => Row(x))))
+ assertEqual(
+ "values (1, 'a'), (2, 'b'), (3, 'c') as tbl(a, b)",
+ LocalRelation.fromExternalRows(
+ Seq('a.int, 'b.string),
+ Seq((1, "a"), (2, "b"), (3, "c")).map(x => Row(x._1, x._2))).as("tbl"))
+ intercept("values (a, 'a'), (b, 'b')",
+ "All expressions in an inline table must be constants.")
+ intercept("values (1, 'a'), (2, 'b') as tbl(a, b, c)",
+ "Number of aliases must match the number of fields in an inline table.")
+ intercept[ArrayIndexOutOfBoundsException](parsePlan("values (1, 'a'), (2, 'b', 5Y)"))
+ }
+}
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/ng/TableIdentifierParserSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/ng/TableIdentifierParserSuite.scala
new file mode 100644
index 0000000000..0874322187
--- /dev/null
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/ng/TableIdentifierParserSuite.scala
@@ -0,0 +1,42 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.spark.sql.catalyst.parser.ng
+
+import org.apache.spark.SparkFunSuite
+import org.apache.spark.sql.catalyst.TableIdentifier
+
+class TableIdentifierParserSuite extends SparkFunSuite {
+ import CatalystSqlParser._
+
+ test("table identifier") {
+ // Regular names.
+ assert(TableIdentifier("q") === parseTableIdentifier("q"))
+ assert(TableIdentifier("q", Option("d")) === parseTableIdentifier("d.q"))
+
+ // Illegal names.
+ intercept[ParseException](parseTableIdentifier(""))
+ intercept[ParseException](parseTableIdentifier("d.q.g"))
+
+ // SQL Keywords.
+ val keywords = Seq("select", "from", "where", "left", "right")
+ keywords.foreach { keyword =>
+ intercept[ParseException](parseTableIdentifier(keyword))
+ assert(TableIdentifier(keyword) === parseTableIdentifier(s"`$keyword`"))
+ assert(TableIdentifier(keyword, Option("db")) === parseTableIdentifier(s"db.`$keyword`"))
+ }
+ }
+}
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/plans/PlanTest.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/plans/PlanTest.scala
index 0541844e0b..aa5d4330d3 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/plans/PlanTest.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/plans/PlanTest.scala
@@ -19,7 +19,7 @@ package org.apache.spark.sql.catalyst.plans
import org.apache.spark.SparkFunSuite
import org.apache.spark.sql.catalyst.expressions._
-import org.apache.spark.sql.catalyst.plans.logical.{Filter, LogicalPlan, OneRowRelation}
+import org.apache.spark.sql.catalyst.plans.logical.{Filter, LogicalPlan, OneRowRelation, Sample}
import org.apache.spark.sql.catalyst.util._
/**
@@ -32,6 +32,8 @@ abstract class PlanTest extends SparkFunSuite with PredicateHelper {
*/
protected def normalizeExprIds(plan: LogicalPlan) = {
plan transformAllExpressions {
+ case s: ScalarSubquery =>
+ ScalarSubquery(s.query, ExprId(0))
case a: AttributeReference =>
AttributeReference(a.name, a.dataType, a.nullable)(exprId = ExprId(0))
case a: Alias =>
@@ -40,21 +42,25 @@ abstract class PlanTest extends SparkFunSuite with PredicateHelper {
}
/**
- * Normalizes the filter conditions that appear in the plan. For instance,
- * ((expr 1 && expr 2) && expr 3), (expr 1 && expr 2 && expr 3), (expr 3 && (expr 1 && expr 2)
- * etc., will all now be equivalent.
+ * Normalizes plans:
+ * - Filter: the filter conditions are normalized. For instance,
+ * ((expr 1 && expr 2) && expr 3), (expr 1 && expr 2 && expr 3), (expr 3 && (expr 1 && expr 2)),
+ * etc., will all now be equivalent.
+ * - Sample: the seed is replaced by 0L.
*/
- private def normalizeFilters(plan: LogicalPlan) = {
+ private def normalizePlan(plan: LogicalPlan): LogicalPlan = {
plan transform {
case filter @ Filter(condition: Expression, child: LogicalPlan) =>
Filter(splitConjunctivePredicates(condition).sortBy(_.hashCode()).reduce(And), child)
+ case sample: Sample =>
+ sample.copy(seed = 0L)(true)
}
}
/** Fails the test if the two plans do not match */
protected def comparePlans(plan1: LogicalPlan, plan2: LogicalPlan) {
- val normalized1 = normalizeFilters(normalizeExprIds(plan1))
- val normalized2 = normalizeFilters(normalizeExprIds(plan2))
+ val normalized1 = normalizePlan(normalizeExprIds(plan1))
+ val normalized2 = normalizePlan(normalizeExprIds(plan2))
if (normalized1 != normalized2) {
fail(
s"""
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkSqlParser.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkSqlParser.scala
new file mode 100644
index 0000000000..c098fa99c2
--- /dev/null
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkSqlParser.scala
@@ -0,0 +1,219 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.spark.sql.execution
+
+import scala.collection.JavaConverters._
+
+import org.apache.spark.sql.SaveMode
+import org.apache.spark.sql.catalyst.TableIdentifier
+import org.apache.spark.sql.catalyst.parser.ng.{AbstractSqlParser, AstBuilder}
+import org.apache.spark.sql.catalyst.parser.ng.SqlBaseParser._
+import org.apache.spark.sql.catalyst.plans.logical.{LogicalPlan, OneRowRelation}
+import org.apache.spark.sql.execution.command.{DescribeCommand => _, _}
+import org.apache.spark.sql.execution.datasources._
+
+/**
+ * Concrete parser for Spark SQL statements.
+ */
+object SparkSqlParser extends AbstractSqlParser {
+ val astBuilder = new SparkSqlAstBuilder
+}
+
+/**
+ * Builder that converts an ANTLR ParseTree into a LogicalPlan/Expression/TableIdentifier.
+ */
+class SparkSqlAstBuilder extends AstBuilder {
+ import org.apache.spark.sql.catalyst.parser.ng.ParserUtils._
+
+ /**
+ * Create a [[SetCommand]] logical plan.
+ *
+ * Note that everything after the SET keyword is assumed to be part of the
+ * key-value pair. The split between key and value is made by searching for the first `=`
+ * character in the raw string.
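+ * For example, "SET a.b.c=10" yields the pair ("a.b.c", Some("10")), while "SET a.b.c"
+ * yields ("a.b.c", None).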
+ */
+ override def visitSetConfiguration(ctx: SetConfigurationContext): LogicalPlan = withOrigin(ctx) {
+ // Construct the command.
+ val raw = remainder(ctx.SET.getSymbol)
+ val keyValueSeparatorIndex = raw.indexOf('=')
+ if (keyValueSeparatorIndex >= 0) {
+ val key = raw.substring(0, keyValueSeparatorIndex).trim
+ val value = raw.substring(keyValueSeparatorIndex + 1).trim
+ SetCommand(Some(key -> Option(value)))
+ } else if (raw.nonEmpty) {
+ SetCommand(Some(raw.trim -> None))
+ } else {
+ SetCommand(None)
+ }
+ }
+
+ /**
+ * Create a [[SetDatabaseCommand]] logical plan.
+ */
+ override def visitUse(ctx: UseContext): LogicalPlan = withOrigin(ctx) {
+ SetDatabaseCommand(ctx.db.getText)
+ }
+
+ /**
+ * Create a [[ShowTablesCommand]] logical plan.
+ */
+ override def visitShowTables(ctx: ShowTablesContext): LogicalPlan = withOrigin(ctx) {
+ if (ctx.LIKE != null) {
+ logWarning("SHOW TABLES LIKE option is ignored.")
+ }
+ ShowTablesCommand(Option(ctx.db).map(_.getText))
+ }
+
+ /**
+ * Create a [[RefreshTable]] logical plan.
+ */
+ override def visitRefreshTable(ctx: RefreshTableContext): LogicalPlan = withOrigin(ctx) {
+ RefreshTable(visitTableIdentifier(ctx.tableIdentifier))
+ }
+
+ /**
+ * Create a [[CacheTableCommand]] logical plan.
+ */
+ override def visitCacheTable(ctx: CacheTableContext): LogicalPlan = withOrigin(ctx) {
+ val query = Option(ctx.query).map(plan)
+ CacheTableCommand(ctx.identifier.getText, query, ctx.LAZY != null)
+ }
+
+ /**
+ * Create an [[UncacheTableCommand]] logical plan.
+ */
+ override def visitUncacheTable(ctx: UncacheTableContext): LogicalPlan = withOrigin(ctx) {
+ UncacheTableCommand(ctx.identifier.getText)
+ }
+
+ /**
+ * Create a [[ClearCacheCommand]] logical plan.
+ */
+ override def visitClearCache(ctx: ClearCacheContext): LogicalPlan = withOrigin(ctx) {
+ ClearCacheCommand
+ }
+
+ /**
+ * Create an [[ExplainCommand]] logical plan.
+ */
+ override def visitExplain(ctx: ExplainContext): LogicalPlan = withOrigin(ctx) {
+ val options = ctx.explainOption.asScala
+ if (options.exists(_.FORMATTED != null)) {
+ logWarning("EXPLAIN FORMATTED option is ignored.")
+ }
+ if (options.exists(_.LOGICAL != null)) {
+ logWarning("EXPLAIN LOGICAL option is ignored.")
+ }
+
+ // Create the explain comment.
+ val statement = plan(ctx.statement)
+ if (isExplainableStatement(statement)) {
+ ExplainCommand(statement, extended = options.exists(_.EXTENDED != null))
+ } else {
+ ExplainCommand(OneRowRelation)
+ }
+ }
+
+ /**
+ * Determine if a plan should be explained at all.
+ */
+ protected def isExplainableStatement(plan: LogicalPlan): Boolean = plan match {
+ case _: datasources.DescribeCommand => false
+ case _ => true
+ }
+
+ /**
+ * Create a [[DescribeCommand]] logical plan.
+ */
+ override def visitDescribeTable(ctx: DescribeTableContext): LogicalPlan = withOrigin(ctx) {
+ // FORMATTED and columns are not supported. Return null and let the parser decide what to do
+ // with this (throw an exception or pass it on to a different system).
+ if (ctx.describeColName != null || ctx.FORMATTED != null || ctx.partitionSpec != null) {
+ null
+ } else {
+ datasources.DescribeCommand(
+ visitTableIdentifier(ctx.tableIdentifier),
+ ctx.EXTENDED != null)
+ }
+ }
+
+ /** Type to keep track of a table header: (identifier, isTemporary, ifNotExists, isExternal). */
+ type TableHeader = (TableIdentifier, Boolean, Boolean, Boolean)
+
+ /**
+ * Validate a create table statement and return the [[TableIdentifier]].
+ */
+ override def visitCreateTableHeader(
+ ctx: CreateTableHeaderContext): TableHeader = withOrigin(ctx) {
+ val temporary = ctx.TEMPORARY != null
+ val ifNotExists = ctx.EXISTS != null
+ assert(!temporary || !ifNotExists,
+ "a CREATE TEMPORARY TABLE statement does not allow IF NOT EXISTS clause.",
+ ctx)
+ (visitTableIdentifier(ctx.tableIdentifier), temporary, ifNotExists, ctx.EXTERNAL != null)
+ }
+
+ /**
+ * Create a [[CreateTableUsing]] or a [[CreateTableUsingAsSelect]] logical plan.
+ *
+ * TODO add bucketing and partitioning.
+ */
+ override def visitCreateTableUsing(ctx: CreateTableUsingContext): LogicalPlan = withOrigin(ctx) {
+ val (table, temp, ifNotExists, external) = visitCreateTableHeader(ctx.createTableHeader)
+ if (external) {
+ logWarning("EXTERNAL option is not supported.")
+ }
+ val options = Option(ctx.tablePropertyList).map(visitTablePropertyList).getOrElse(Map.empty)
+ val provider = ctx.tableProvider.qualifiedName.getText
+
+ if (ctx.query != null) {
+ // Get the backing query.
+ val query = plan(ctx.query)
+
+ // Determine the storage mode.
+ val mode = if (ifNotExists) {
+ SaveMode.Ignore
+ } else if (temp) {
+ SaveMode.Overwrite
+ } else {
+ SaveMode.ErrorIfExists
+ }
+ CreateTableUsingAsSelect(table, provider, temp, Array.empty, None, mode, options, query)
+ } else {
+ val struct = Option(ctx.colTypeList).map(createStructType)
+ CreateTableUsing(table, struct, provider, temp, options, ifNotExists, managedIfNoPath = false)
+ }
+ }
+
+ /**
+ * Convert a table property list into a key-value map.
+ */
+ override def visitTablePropertyList(
+ ctx: TablePropertyListContext): Map[String, String] = withOrigin(ctx) {
+ ctx.tableProperty.asScala.map { property =>
+ // A key can either be a String or a collection of dot-separated elements. We need to treat
+ // these differently.
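+ // For example, both ('a.b.c' = 'v') and (a.b.c = 'v') produce the key "a.b.c", but only the
+ // former is a STRING token that needs to be unquoted.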
+ val key = if (property.key.STRING != null) {
+ string(property.key.STRING)
+ } else {
+ property.key.getText
+ }
+ val value = Option(property.value).map(string).orNull
+ key -> value
+ }.toMap
+ }
+}
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala
index 8abb9d7e4a..7ce15e3f35 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala
@@ -27,8 +27,8 @@ import org.apache.spark.sql.catalyst.analysis.{Star, UnresolvedFunction}
import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.expressions.aggregate._
-import org.apache.spark.sql.catalyst.parser.CatalystQl
import org.apache.spark.sql.catalyst.plans.logical.BroadcastHint
+import org.apache.spark.sql.execution.SparkSqlParser
import org.apache.spark.sql.expressions.UserDefinedFunction
import org.apache.spark.sql.types._
import org.apache.spark.util.Utils
@@ -1172,8 +1172,7 @@ object functions {
* @group normal_funcs
*/
def expr(expr: String): Column = {
- val parser = SQLContext.getActive().map(_.sessionState.sqlParser).getOrElse(new CatalystQl())
- Column(parser.parseExpression(expr))
+ Column(SparkSqlParser.parseExpression(expr))
}
//////////////////////////////////////////////////////////////////////////////////////////////
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/internal/SessionState.scala b/sql/core/src/main/scala/org/apache/spark/sql/internal/SessionState.scala
index e5f02caabc..9bc640763f 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/internal/SessionState.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/internal/SessionState.scala
@@ -81,7 +81,7 @@ private[sql] class SessionState(ctx: SQLContext) {
/**
* Parser that extracts expressions, plans, table identifiers etc. from SQL texts.
*/
- lazy val sqlParser: ParserInterface = new SparkQl(conf)
+ lazy val sqlParser: ParserInterface = SparkSqlParser
/**
* Planner that converts optimized logical plans to physical plans.
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/JoinSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/JoinSuite.scala
index 5af1a4fcd7..a5a4ff13de 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/JoinSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/JoinSuite.scala
@@ -329,8 +329,8 @@ class JoinSuite extends QueryTest with SharedSQLContext {
}
test("full outer join") {
- upperCaseData.where('N <= 4).registerTempTable("left")
- upperCaseData.where('N >= 3).registerTempTable("right")
+ upperCaseData.where('N <= 4).registerTempTable("`left`")
+ upperCaseData.where('N >= 3).registerTempTable("`right`")
val left = UnresolvedRelation(TableIdentifier("left"), None)
val right = UnresolvedRelation(TableIdentifier("right"), None)
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala
index c958eac266..b727e88668 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala
@@ -1656,7 +1656,7 @@ class SQLQuerySuite extends QueryTest with SharedSQLContext {
val e2 = intercept[AnalysisException] {
sql("select interval 23 nanosecond")
}
- assert(e2.message.contains("cannot recognize input near"))
+ assert(e2.message.contains("No interval can be constructed"))
}
test("SPARK-8945: add and subtract expressions for interval type") {