-rw-r--r--  LICENSE | 1
-rw-r--r--  dev/deps/spark-deps-hadoop-2.2 | 1
-rw-r--r--  dev/deps/spark-deps-hadoop-2.3 | 1
-rw-r--r--  dev/deps/spark-deps-hadoop-2.4 | 1
-rw-r--r--  dev/deps/spark-deps-hadoop-2.6 | 1
-rw-r--r--  dev/deps/spark-deps-hadoop-2.7 | 1
-rw-r--r--  pom.xml | 6
-rw-r--r--  project/SparkBuild.scala | 8
-rw-r--r--  project/plugins.sbt | 6
-rw-r--r--  python/pyspark/sql/tests.py | 6
-rw-r--r--  python/pyspark/sql/utils.py | 8
-rw-r--r--  sql/catalyst/pom.xml | 4
-rw-r--r--  sql/catalyst/src/main/antlr4/org/apache/spark/sql/catalyst/parser/ng/SqlBase.g4 | 911
-rw-r--r--  sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/dsl/package.scala | 34
-rw-r--r--  sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/ng/AstBuilder.scala | 1452
-rw-r--r--  sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/ng/ParseDriver.scala | 240
-rw-r--r--  sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/ng/ParserUtils.scala | 118
-rw-r--r--  sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/CatalystQlSuite.scala | 52
-rw-r--r--  sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/DataTypeParserSuite.scala | 55
-rw-r--r--  sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/ng/ErrorParserSuite.scala | 67
-rw-r--r--  sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/ng/ExpressionParserSuite.scala | 497
-rw-r--r--  sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/ng/PlanParserSuite.scala | 429
-rw-r--r--  sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/ng/TableIdentifierParserSuite.scala | 42
-rw-r--r--  sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/plans/PlanTest.scala | 20
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/execution/SparkSqlParser.scala | 219
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/functions.scala | 5
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/internal/SessionState.scala | 2
-rw-r--r--  sql/core/src/test/scala/org/apache/spark/sql/JoinSuite.scala | 4
-rw-r--r--  sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala | 2
29 files changed, 4127 insertions(+), 66 deletions(-)
diff --git a/LICENSE b/LICENSE
index d7a790a628..5a8c78b98b 100644
--- a/LICENSE
+++ b/LICENSE
@@ -238,6 +238,7 @@ The text of each license is also included at licenses/LICENSE-[project].txt.
(BSD 3 Clause) netlib core (com.github.fommil.netlib:core:1.1.2 - https://github.com/fommil/netlib-java/core)
(BSD 3 Clause) JPMML-Model (org.jpmml:pmml-model:1.2.7 - https://github.com/jpmml/jpmml-model)
(BSD License) AntLR Parser Generator (antlr:antlr:2.7.7 - http://www.antlr.org/)
+ (BSD License) ANTLR 4.5.2-1 (org.antlr:antlr4:4.5.2-1 - http://www.antlr.org/)
(BSD licence) ANTLR ST4 4.0.4 (org.antlr:ST4:4.0.4 - http://www.stringtemplate.org)
(BSD licence) ANTLR StringTemplate (org.antlr:stringtemplate:3.2.1 - http://www.stringtemplate.org)
(BSD License) Javolution (javolution:javolution:5.5.1 - http://javolution.org)
diff --git a/dev/deps/spark-deps-hadoop-2.2 b/dev/deps/spark-deps-hadoop-2.2
index 512675a599..7c2f88bdb1 100644
--- a/dev/deps/spark-deps-hadoop-2.2
+++ b/dev/deps/spark-deps-hadoop-2.2
@@ -3,6 +3,7 @@ RoaringBitmap-0.5.11.jar
ST4-4.0.4.jar
activation-1.1.jar
antlr-runtime-3.5.2.jar
+antlr4-runtime-4.5.2-1.jar
aopalliance-1.0.jar
apache-log4j-extras-1.2.17.jar
arpack_combined_all-0.1.jar
diff --git a/dev/deps/spark-deps-hadoop-2.3 b/dev/deps/spark-deps-hadoop-2.3
index 31f8694fed..f4d600038d 100644
--- a/dev/deps/spark-deps-hadoop-2.3
+++ b/dev/deps/spark-deps-hadoop-2.3
@@ -3,6 +3,7 @@ RoaringBitmap-0.5.11.jar
ST4-4.0.4.jar
activation-1.1.1.jar
antlr-runtime-3.5.2.jar
+antlr4-runtime-4.5.2-1.jar
aopalliance-1.0.jar
apache-log4j-extras-1.2.17.jar
arpack_combined_all-0.1.jar
diff --git a/dev/deps/spark-deps-hadoop-2.4 b/dev/deps/spark-deps-hadoop-2.4
index 0fa8bccab0..7c5e2c35bd 100644
--- a/dev/deps/spark-deps-hadoop-2.4
+++ b/dev/deps/spark-deps-hadoop-2.4
@@ -3,6 +3,7 @@ RoaringBitmap-0.5.11.jar
ST4-4.0.4.jar
activation-1.1.1.jar
antlr-runtime-3.5.2.jar
+antlr4-runtime-4.5.2-1.jar
aopalliance-1.0.jar
apache-log4j-extras-1.2.17.jar
arpack_combined_all-0.1.jar
diff --git a/dev/deps/spark-deps-hadoop-2.6 b/dev/deps/spark-deps-hadoop-2.6
index 8d2f6e6e32..03d9a51057 100644
--- a/dev/deps/spark-deps-hadoop-2.6
+++ b/dev/deps/spark-deps-hadoop-2.6
@@ -3,6 +3,7 @@ RoaringBitmap-0.5.11.jar
ST4-4.0.4.jar
activation-1.1.1.jar
antlr-runtime-3.5.2.jar
+antlr4-runtime-4.5.2-1.jar
aopalliance-1.0.jar
apache-log4j-extras-1.2.17.jar
apacheds-i18n-2.0.0-M15.jar
diff --git a/dev/deps/spark-deps-hadoop-2.7 b/dev/deps/spark-deps-hadoop-2.7
index a114c4ae8d..5765071a1c 100644
--- a/dev/deps/spark-deps-hadoop-2.7
+++ b/dev/deps/spark-deps-hadoop-2.7
@@ -3,6 +3,7 @@ RoaringBitmap-0.5.11.jar
ST4-4.0.4.jar
activation-1.1.1.jar
antlr-runtime-3.5.2.jar
+antlr4-runtime-4.5.2-1.jar
aopalliance-1.0.jar
apache-log4j-extras-1.2.17.jar
apacheds-i18n-2.0.0-M15.jar
diff --git a/pom.xml b/pom.xml
index b4cfa3a598..475f0544bd 100644
--- a/pom.xml
+++ b/pom.xml
@@ -178,6 +178,7 @@
<jsr305.version>1.3.9</jsr305.version>
<libthrift.version>0.9.2</libthrift.version>
<antlr.version>3.5.2</antlr.version>
+ <antlr4.version>4.5.2-1</antlr4.version>
<test.java.home>${java.home}</test.java.home>
<test.exclude.tags></test.exclude.tags>
@@ -1759,6 +1760,11 @@
<artifactId>antlr-runtime</artifactId>
<version>${antlr.version}</version>
</dependency>
+ <dependency>
+ <groupId>org.antlr</groupId>
+ <artifactId>antlr4-runtime</artifactId>
+ <version>${antlr4.version}</version>
+ </dependency>
</dependencies>
</dependencyManagement>
diff --git a/project/SparkBuild.scala b/project/SparkBuild.scala
index fb229b979d..39a9e16f7e 100644
--- a/project/SparkBuild.scala
+++ b/project/SparkBuild.scala
@@ -25,6 +25,7 @@ import sbt._
import sbt.Classpaths.publishTask
import sbt.Keys._
import sbtunidoc.Plugin.UnidocKeys.unidocGenjavadocVersion
+import com.simplytyped.Antlr4Plugin._
import com.typesafe.sbt.pom.{PomBuild, SbtPomKeys}
import com.typesafe.tools.mima.plugin.MimaKeys
@@ -401,7 +402,10 @@ object OldDeps {
}
object Catalyst {
- lazy val settings = Seq(
+ lazy val settings = antlr4Settings ++ Seq(
+ antlr4PackageName in Antlr4 := Some("org.apache.spark.sql.catalyst.parser.ng"),
+ antlr4GenListener in Antlr4 := true,
+ antlr4GenVisitor in Antlr4 := true,
// ANTLR code-generation step.
//
// This has been heavily inspired by com.github.stefri.sbt-antlr (0.5.3). It fixes a number of
@@ -414,7 +418,7 @@ object Catalyst {
"SparkSqlLexer.g",
"SparkSqlParser.g")
val sourceDir = (sourceDirectory in Compile).value / "antlr3"
- val targetDir = (sourceManaged in Compile).value
+ val targetDir = (sourceManaged in Compile).value / "antlr3"
// Create default ANTLR Tool.
val antlr = new org.antlr.Tool
diff --git a/project/plugins.sbt b/project/plugins.sbt
index eeca94a47c..d9ed7962bf 100644
--- a/project/plugins.sbt
+++ b/project/plugins.sbt
@@ -23,3 +23,9 @@ libraryDependencies += "org.ow2.asm" % "asm" % "5.0.3"
libraryDependencies += "org.ow2.asm" % "asm-commons" % "5.0.3"
libraryDependencies += "org.antlr" % "antlr" % "3.5.2"
+
+
+// TODO I am not sure we want such a dep.
+resolvers += "simplytyped" at "http://simplytyped.github.io/repo/releases"
+
+addSbtPlugin("com.simplytyped" % "sbt-antlr4" % "0.7.10")
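The two build changes above work together: project/plugins.sbt pulls in the sbt-antlr4 plugin, and the Catalyst settings in project/SparkBuild.scala point it at the SqlBase.g4 grammar so that the lexer/parser plus a listener and a visitor are generated into the org.apache.spark.sql.catalyst.parser.ng package. As a rough sketch only (the key names are taken from the diff above; this is not a drop-in snippet from the patch), the equivalent configuration for a standalone sbt project would look like:

// project/plugins.sbt -- plugin resolution (mirrors the lines added above)
resolvers += "simplytyped" at "http://simplytyped.github.io/repo/releases"
addSbtPlugin("com.simplytyped" % "sbt-antlr4" % "0.7.10")

// build.sbt -- enable generation for the module that owns the .g4 file
antlr4Settings
antlr4PackageName in Antlr4 := Some("org.apache.spark.sql.catalyst.parser.ng")
antlr4GenListener in Antlr4 := true
antlr4GenVisitor in Antlr4 := true
libraryDependencies += "org.antlr" % "antlr4-runtime" % "4.5.2-1"

The generated sources land under sourceManaged, which is also why the existing ANTLR 3 code-generation step now writes to its own antlr3 subdirectory in the SparkBuild.scala hunk above.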
diff --git a/python/pyspark/sql/tests.py b/python/pyspark/sql/tests.py
index 83ef76c13c..1a5d422af9 100644
--- a/python/pyspark/sql/tests.py
+++ b/python/pyspark/sql/tests.py
@@ -51,7 +51,7 @@ from pyspark.sql.types import UserDefinedType, _infer_type
from pyspark.tests import ReusedPySparkTestCase
from pyspark.sql.functions import UserDefinedFunction, sha2
from pyspark.sql.window import Window
-from pyspark.sql.utils import AnalysisException, IllegalArgumentException
+from pyspark.sql.utils import AnalysisException, ParseException, IllegalArgumentException
class UTCOffsetTimezone(datetime.tzinfo):
@@ -1130,7 +1130,9 @@ class SQLTests(ReusedPySparkTestCase):
def test_capture_analysis_exception(self):
self.assertRaises(AnalysisException, lambda: self.sqlCtx.sql("select abc"))
self.assertRaises(AnalysisException, lambda: self.df.selectExpr("a + b"))
- self.assertRaises(AnalysisException, lambda: self.sqlCtx.sql("abc"))
+
+ def test_capture_parse_exception(self):
+ self.assertRaises(ParseException, lambda: self.sqlCtx.sql("abc"))
def test_capture_illegalargument_exception(self):
self.assertRaisesRegexp(IllegalArgumentException, "Setting negative mapred.reduce.tasks",
diff --git a/python/pyspark/sql/utils.py b/python/pyspark/sql/utils.py
index b0a0373372..b89ea8c6e0 100644
--- a/python/pyspark/sql/utils.py
+++ b/python/pyspark/sql/utils.py
@@ -33,6 +33,12 @@ class AnalysisException(CapturedException):
"""
+class ParseException(CapturedException):
+ """
+ Failed to parse a SQL command.
+ """
+
+
class IllegalArgumentException(CapturedException):
"""
Passed an illegal or inappropriate argument.
@@ -49,6 +55,8 @@ def capture_sql_exception(f):
e.java_exception.getStackTrace()))
if s.startswith('org.apache.spark.sql.AnalysisException: '):
raise AnalysisException(s.split(': ', 1)[1], stackTrace)
+ if s.startswith('org.apache.spark.sql.catalyst.parser.ng.ParseException: '):
+ raise ParseException(s.split(': ', 1)[1], stackTrace)
if s.startswith('java.lang.IllegalArgumentException: '):
raise IllegalArgumentException(s.split(': ', 1)[1], stackTrace)
raise
diff --git a/sql/catalyst/pom.xml b/sql/catalyst/pom.xml
index 5d1d9edd25..c834a011f1 100644
--- a/sql/catalyst/pom.xml
+++ b/sql/catalyst/pom.xml
@@ -76,6 +76,10 @@
<artifactId>antlr-runtime</artifactId>
</dependency>
<dependency>
+ <groupId>org.antlr</groupId>
+ <artifactId>antlr4-runtime</artifactId>
+ </dependency>
+ <dependency>
<groupId>commons-codec</groupId>
<artifactId>commons-codec</artifactId>
</dependency>
diff --git a/sql/catalyst/src/main/antlr4/org/apache/spark/sql/catalyst/parser/ng/SqlBase.g4 b/sql/catalyst/src/main/antlr4/org/apache/spark/sql/catalyst/parser/ng/SqlBase.g4
new file mode 100644
index 0000000000..e46fd9bed5
--- /dev/null
+++ b/sql/catalyst/src/main/antlr4/org/apache/spark/sql/catalyst/parser/ng/SqlBase.g4
@@ -0,0 +1,911 @@
+/*
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ *
+ * This file is an adaptation of Presto's presto-parser/src/main/antlr4/com/facebook/presto/sql/parser/SqlBase.g4 grammar.
+ */
+
+grammar SqlBase;
+
+tokens {
+ DELIMITER
+}
+
+singleStatement
+ : statement EOF
+ ;
+
+singleExpression
+ : namedExpression EOF
+ ;
+
+singleTableIdentifier
+ : tableIdentifier EOF
+ ;
+
+singleDataType
+ : dataType EOF
+ ;
+
+statement
+ : query #statementDefault
+ | USE db=identifier #use
+ | CREATE DATABASE (IF NOT EXISTS)? identifier
+ (COMMENT comment=STRING)? locationSpec?
+ (WITH DBPROPERTIES tablePropertyList)? #createDatabase
+ | ALTER DATABASE identifier SET DBPROPERTIES tablePropertyList #setDatabaseProperties
+ | DROP DATABASE (IF EXISTS)? identifier (RESTRICT | CASCADE)? #dropDatabase
+ | createTableHeader ('(' colTypeList ')')? tableProvider
+ (OPTIONS tablePropertyList)? #createTableUsing
+ | createTableHeader tableProvider
+ (OPTIONS tablePropertyList)? AS? query #createTableUsing
+ | createTableHeader ('(' colTypeList ')')? (COMMENT STRING)?
+ (PARTITIONED BY identifierList)? bucketSpec? skewSpec?
+ rowFormat? createFileFormat? locationSpec?
+ (TBLPROPERTIES tablePropertyList)?
+ (AS? query)? #createTable
+ | ANALYZE TABLE tableIdentifier partitionSpec? COMPUTE STATISTICS
+ (identifier | FOR COLUMNS identifierSeq?) #analyze
+ | ALTER TABLE from=tableIdentifier RENAME TO to=tableIdentifier #renameTable
+ | ALTER TABLE tableIdentifier
+ SET TBLPROPERTIES tablePropertyList #setTableProperties
+ | ALTER TABLE tableIdentifier
+ UNSET TBLPROPERTIES (IF EXISTS)? tablePropertyList #unsetTableProperties
+ | ALTER TABLE tableIdentifier (partitionSpec)?
+ SET SERDE STRING (WITH SERDEPROPERTIES tablePropertyList)? #setTableSerDe
+ | ALTER TABLE tableIdentifier (partitionSpec)?
+ SET SERDEPROPERTIES tablePropertyList #setTableSerDe
+ | ALTER TABLE tableIdentifier bucketSpec #bucketTable
+ | ALTER TABLE tableIdentifier NOT CLUSTERED #unclusterTable
+ | ALTER TABLE tableIdentifier NOT SORTED #unsortTable
+ | ALTER TABLE tableIdentifier skewSpec #skewTable
+ | ALTER TABLE tableIdentifier NOT SKEWED #unskewTable
+ | ALTER TABLE tableIdentifier NOT STORED AS DIRECTORIES #unstoreTable
+ | ALTER TABLE tableIdentifier
+ SET SKEWED LOCATION skewedLocationList #setTableSkewLocations
+ | ALTER TABLE tableIdentifier ADD (IF NOT EXISTS)?
+ partitionSpecLocation+ #addTablePartition
+ | ALTER TABLE tableIdentifier
+ from=partitionSpec RENAME TO to=partitionSpec #renameTablePartition
+ | ALTER TABLE from=tableIdentifier
+ EXCHANGE partitionSpec WITH TABLE to=tableIdentifier #exchangeTablePartition
+ | ALTER TABLE tableIdentifier
+ DROP (IF EXISTS)? partitionSpec (',' partitionSpec)* PURGE? #dropTablePartitions
+ | ALTER TABLE tableIdentifier ARCHIVE partitionSpec #archiveTablePartition
+ | ALTER TABLE tableIdentifier UNARCHIVE partitionSpec #unarchiveTablePartition
+ | ALTER TABLE tableIdentifier partitionSpec?
+ SET FILEFORMAT fileFormat #setTableFileFormat
+ | ALTER TABLE tableIdentifier partitionSpec? SET locationSpec #setTableLocation
+ | ALTER TABLE tableIdentifier TOUCH partitionSpec? #touchTable
+ | ALTER TABLE tableIdentifier partitionSpec? COMPACT STRING #compactTable
+ | ALTER TABLE tableIdentifier partitionSpec? CONCATENATE #concatenateTable
+ | ALTER TABLE tableIdentifier partitionSpec?
+ CHANGE COLUMN? oldName=identifier colType
+ (FIRST | AFTER after=identifier)? (CASCADE | RESTRICT)? #changeColumn
+ | ALTER TABLE tableIdentifier partitionSpec?
+ ADD COLUMNS '(' colTypeList ')' (CASCADE | RESTRICT)? #addColumns
+ | ALTER TABLE tableIdentifier partitionSpec?
+ REPLACE COLUMNS '(' colTypeList ')' (CASCADE | RESTRICT)? #replaceColumns
+ | DROP TABLE (IF EXISTS)? tableIdentifier PURGE?
+ (FOR METADATA? REPLICATION '(' STRING ')')? #dropTable
+ | CREATE (OR REPLACE)? VIEW (IF NOT EXISTS)? tableIdentifier
+ identifierCommentList? (COMMENT STRING)?
+ (PARTITIONED ON identifierList)?
+ (TBLPROPERTIES tablePropertyList)? AS query #createView
+ | ALTER VIEW tableIdentifier AS? query #alterViewQuery
+ | CREATE TEMPORARY? FUNCTION qualifiedName AS className=STRING
+ (USING resource (',' resource)*)? #createFunction
+ | DROP TEMPORARY? FUNCTION (IF EXISTS)? qualifiedName #dropFunction
+ | EXPLAIN explainOption* statement #explain
+ | SHOW TABLES ((FROM | IN) db=identifier)?
+ (LIKE (qualifiedName | pattern=STRING))? #showTables
+ | SHOW FUNCTIONS (LIKE? (qualifiedName | pattern=STRING))? #showFunctions
+ | (DESC | DESCRIBE) FUNCTION EXTENDED? qualifiedName #describeFunction
+ | (DESC | DESCRIBE) option=(EXTENDED | FORMATTED)?
+ tableIdentifier partitionSpec? describeColName? #describeTable
+ | (DESC | DESCRIBE) DATABASE EXTENDED? identifier #describeDatabase
+ | REFRESH TABLE tableIdentifier #refreshTable
+ | CACHE LAZY? TABLE identifier (AS? query)? #cacheTable
+ | UNCACHE TABLE identifier #uncacheTable
+ | CLEAR CACHE #clearCache
+ | ADD identifier .*? #addResource
+ | SET .*? #setConfiguration
+ | hiveNativeCommands #executeNativeCommand
+ ;
+
+hiveNativeCommands
+ : createTableHeader LIKE tableIdentifier
+ rowFormat? createFileFormat? locationSpec?
+ (TBLPROPERTIES tablePropertyList)?
+ | DELETE FROM tableIdentifier (WHERE booleanExpression)?
+ | TRUNCATE TABLE tableIdentifier partitionSpec?
+ (COLUMNS identifierList)?
+ | ALTER VIEW from=tableIdentifier AS? RENAME TO to=tableIdentifier
+ | ALTER VIEW from=tableIdentifier AS?
+ SET TBLPROPERTIES tablePropertyList
+ | ALTER VIEW from=tableIdentifier AS?
+ UNSET TBLPROPERTIES (IF EXISTS)? tablePropertyList
+ | ALTER VIEW from=tableIdentifier AS?
+ ADD (IF NOT EXISTS)? partitionSpecLocation+
+ | ALTER VIEW from=tableIdentifier AS?
+ DROP (IF EXISTS)? partitionSpec (',' partitionSpec)* PURGE?
+ | DROP VIEW (IF EXISTS)? qualifiedName
+ | SHOW COLUMNS (FROM | IN) tableIdentifier ((FROM|IN) identifier)?
+ | START TRANSACTION (transactionMode (',' transactionMode)*)?
+ | COMMIT WORK?
+ | ROLLBACK WORK?
+ | SHOW PARTITIONS tableIdentifier partitionSpec?
+ | DFS .*?
+ | (CREATE | ALTER | DROP | SHOW | DESC | DESCRIBE | REVOKE | GRANT | LOCK | UNLOCK | MSCK | EXPORT | IMPORT | LOAD) .*?
+ ;
+
+createTableHeader
+ : CREATE TEMPORARY? EXTERNAL? TABLE (IF NOT EXISTS)? tableIdentifier
+ ;
+
+bucketSpec
+ : CLUSTERED BY identifierList
+ (SORTED BY orderedIdentifierList)?
+ INTO INTEGER_VALUE BUCKETS
+ ;
+
+skewSpec
+ : SKEWED BY identifierList
+ ON (constantList | nestedConstantList)
+ (STORED AS DIRECTORIES)?
+ ;
+
+locationSpec
+ : LOCATION STRING
+ ;
+
+query
+ : ctes? queryNoWith
+ ;
+
+insertInto
+ : INSERT OVERWRITE TABLE tableIdentifier partitionSpec? (IF NOT EXISTS)?
+ | INSERT INTO TABLE? tableIdentifier partitionSpec?
+ ;
+
+partitionSpecLocation
+ : partitionSpec locationSpec?
+ ;
+
+partitionSpec
+ : PARTITION '(' partitionVal (',' partitionVal)* ')'
+ ;
+
+partitionVal
+ : identifier (EQ constant)?
+ ;
+
+describeColName
+ : identifier ('.' (identifier | STRING))*
+ ;
+
+ctes
+ : WITH namedQuery (',' namedQuery)*
+ ;
+
+namedQuery
+ : name=identifier AS? '(' queryNoWith ')'
+ ;
+
+tableProvider
+ : USING qualifiedName
+ ;
+
+tablePropertyList
+ : '(' tableProperty (',' tableProperty)* ')'
+ ;
+
+tableProperty
+ : key=tablePropertyKey (EQ? value=STRING)?
+ ;
+
+tablePropertyKey
+ : looseIdentifier ('.' looseIdentifier)*
+ | STRING
+ ;
+
+constantList
+ : '(' constant (',' constant)* ')'
+ ;
+
+nestedConstantList
+ : '(' constantList (',' constantList)* ')'
+ ;
+
+skewedLocation
+ : (constant | constantList) EQ STRING
+ ;
+
+skewedLocationList
+ : '(' skewedLocation (',' skewedLocation)* ')'
+ ;
+
+createFileFormat
+ : STORED AS fileFormat
+ | STORED BY storageHandler
+ ;
+
+fileFormat
+ : INPUTFORMAT inFmt=STRING OUTPUTFORMAT outFmt=STRING (SERDE serdeCls=STRING)?
+ (INPUTDRIVER inDriver=STRING OUTPUTDRIVER outDriver=STRING)? #tableFileFormat
+ | identifier #genericFileFormat
+ ;
+
+storageHandler
+ : STRING (WITH SERDEPROPERTIES tablePropertyList)?
+ ;
+
+resource
+ : identifier STRING
+ ;
+
+queryNoWith
+ : insertInto? queryTerm queryOrganization #singleInsertQuery
+ | fromClause multiInsertQueryBody+ #multiInsertQuery
+ ;
+
+queryOrganization
+ : (ORDER BY order+=sortItem (',' order+=sortItem)*)?
+ (CLUSTER BY clusterBy+=expression (',' clusterBy+=expression)*)?
+ (DISTRIBUTE BY distributeBy+=expression (',' distributeBy+=expression)*)?
+ (SORT BY sort+=sortItem (',' sort+=sortItem)*)?
+ windows?
+ (LIMIT limit=expression)?
+ ;
+
+multiInsertQueryBody
+ : insertInto?
+ querySpecification
+ queryOrganization
+ ;
+
+queryTerm
+ : queryPrimary #queryTermDefault
+ | left=queryTerm operator=(INTERSECT | UNION | EXCEPT) setQuantifier? right=queryTerm #setOperation
+ ;
+
+queryPrimary
+ : querySpecification #queryPrimaryDefault
+ | TABLE tableIdentifier #table
+ | inlineTable #inlineTableDefault1
+ | '(' queryNoWith ')' #subquery
+ ;
+
+sortItem
+ : expression ordering=(ASC | DESC)?
+ ;
+
+querySpecification
+ : (((SELECT kind=TRANSFORM '(' namedExpressionSeq ')'
+ | kind=MAP namedExpressionSeq
+ | kind=REDUCE namedExpressionSeq))
+ inRowFormat=rowFormat?
+ (RECORDWRITER recordWriter=STRING)?
+ USING script=STRING
+ (AS (identifierSeq | colTypeList | ('(' (identifierSeq | colTypeList) ')')))?
+ outRowFormat=rowFormat?
+ (RECORDREADER recordReader=STRING)?
+ fromClause?
+ (WHERE where=booleanExpression)?)
+ | ((kind=SELECT setQuantifier? namedExpressionSeq fromClause?
+ | fromClause (kind=SELECT setQuantifier? namedExpressionSeq)?)
+ lateralView*
+ (WHERE where=booleanExpression)?
+ aggregation?
+ (HAVING having=booleanExpression)?
+ windows?)
+ ;
+
+fromClause
+ : FROM relation (',' relation)* lateralView*
+ ;
+
+aggregation
+ : GROUP BY groupingExpressions+=expression (',' groupingExpressions+=expression)* (
+ WITH kind=ROLLUP
+ | WITH kind=CUBE
+ | kind=GROUPING SETS '(' groupingSet (',' groupingSet)* ')')?
+ ;
+
+groupingSet
+ : '(' (expression (',' expression)*)? ')'
+ | expression
+ ;
+
+lateralView
+ : LATERAL VIEW (OUTER)? qualifiedName '(' (expression (',' expression)*)? ')' tblName=identifier (AS? colName+=identifier (',' colName+=identifier)*)?
+ ;
+
+setQuantifier
+ : DISTINCT
+ | ALL
+ ;
+
+relation
+ : left=relation
+ ((CROSS | joinType) JOIN right=relation joinCriteria?
+ | NATURAL joinType JOIN right=relation
+ ) #joinRelation
+ | relationPrimary #relationDefault
+ ;
+
+joinType
+ : INNER?
+ | LEFT OUTER?
+ | LEFT SEMI
+ | RIGHT OUTER?
+ | FULL OUTER?
+ ;
+
+joinCriteria
+ : ON booleanExpression
+ | USING '(' identifier (',' identifier)* ')'
+ ;
+
+sample
+ : TABLESAMPLE '('
+ ( (percentage=(INTEGER_VALUE | DECIMAL_VALUE) sampleType=PERCENTLIT)
+ | (expression sampleType=ROWS)
+ | (sampleType=BUCKET numerator=INTEGER_VALUE OUT OF denominator=INTEGER_VALUE (ON identifier)?))
+ ')'
+ ;
+
+identifierList
+ : '(' identifierSeq ')'
+ ;
+
+identifierSeq
+ : identifier (',' identifier)*
+ ;
+
+orderedIdentifierList
+ : '(' orderedIdentifier (',' orderedIdentifier)* ')'
+ ;
+
+orderedIdentifier
+ : identifier ordering=(ASC | DESC)?
+ ;
+
+identifierCommentList
+ : '(' identifierComment (',' identifierComment)* ')'
+ ;
+
+identifierComment
+ : identifier (COMMENT STRING)?
+ ;
+
+relationPrimary
+ : tableIdentifier sample? (AS? identifier)? #tableName
+ | '(' queryNoWith ')' sample? (AS? identifier)? #aliasedQuery
+ | '(' relation ')' sample? (AS? identifier)? #aliasedRelation
+ | inlineTable #inlineTableDefault2
+ ;
+
+inlineTable
+ : VALUES expression (',' expression)* (AS? identifier identifierList?)?
+ ;
+
+rowFormat
+ : ROW FORMAT SERDE name=STRING (WITH SERDEPROPERTIES props=tablePropertyList)? #rowFormatSerde
+ | ROW FORMAT DELIMITED
+ (FIELDS TERMINATED BY fieldsTerminatedBy=STRING (ESCAPED BY escapedBy=STRING)?)?
+ (COLLECTION ITEMS TERMINATED BY collectionItemsTerminatedBy=STRING)?
+ (MAP KEYS TERMINATED BY keysTerminatedBy=STRING)?
+ (LINES TERMINATED BY linesSeparatedBy=STRING)?
+ (NULL DEFINED AS nullDefinedAs=STRING)? #rowFormatDelimited
+ ;
+
+tableIdentifier
+ : (db=identifier '.')? table=identifier
+ ;
+
+namedExpression
+ : expression (AS? (identifier | identifierList))?
+ ;
+
+namedExpressionSeq
+ : namedExpression (',' namedExpression)*
+ ;
+
+expression
+ : booleanExpression
+ ;
+
+booleanExpression
+ : predicated #booleanDefault
+ | NOT booleanExpression #logicalNot
+ | left=booleanExpression operator=AND right=booleanExpression #logicalBinary
+ | left=booleanExpression operator=OR right=booleanExpression #logicalBinary
+ | EXISTS '(' query ')' #exists
+ ;
+
+// workaround for:
+// https://github.com/antlr/antlr4/issues/780
+// https://github.com/antlr/antlr4/issues/781
+predicated
+ : valueExpression predicate[$valueExpression.ctx]?
+ ;
+
+predicate[ParserRuleContext value]
+ : NOT? BETWEEN lower=valueExpression AND upper=valueExpression #between
+ | NOT? IN '(' expression (',' expression)* ')' #inList
+ | NOT? IN '(' query ')' #inSubquery
+ | NOT? like=(RLIKE | LIKE) pattern=valueExpression #like
+ | IS NOT? NULL #nullPredicate
+ ;
+
+valueExpression
+ : primaryExpression #valueExpressionDefault
+ | operator=(MINUS | PLUS | TILDE) valueExpression #arithmeticUnary
+ | left=valueExpression operator=(ASTERISK | SLASH | PERCENT | DIV) right=valueExpression #arithmeticBinary
+ | left=valueExpression operator=(PLUS | MINUS) right=valueExpression #arithmeticBinary
+ | left=valueExpression operator=AMPERSAND right=valueExpression #arithmeticBinary
+ | left=valueExpression operator=HAT right=valueExpression #arithmeticBinary
+ | left=valueExpression operator=PIPE right=valueExpression #arithmeticBinary
+ | left=valueExpression comparisonOperator right=valueExpression #comparison
+ ;
+
+primaryExpression
+ : constant #constantDefault
+ | ASTERISK #star
+ | qualifiedName '.' ASTERISK #star
+ | '(' expression (',' expression)+ ')' #rowConstructor
+ | qualifiedName '(' (setQuantifier? expression (',' expression)*)? ')' (OVER windowSpec)? #functionCall
+ | '(' query ')' #subqueryExpression
+ | CASE valueExpression whenClause+ (ELSE elseExpression=expression)? END #simpleCase
+ | CASE whenClause+ (ELSE elseExpression=expression)? END #searchedCase
+ | CAST '(' expression AS dataType ')' #cast
+ | value=primaryExpression '[' index=valueExpression ']' #subscript
+ | identifier #columnReference
+ | base=primaryExpression '.' fieldName=identifier #dereference
+ | '(' expression ')' #parenthesizedExpression
+ ;
+
+constant
+ : NULL #nullLiteral
+ | interval #intervalLiteral
+ | identifier STRING #typeConstructor
+ | number #numericLiteral
+ | booleanValue #booleanLiteral
+ | STRING+ #stringLiteral
+ ;
+
+comparisonOperator
+ : EQ | NEQ | NEQJ | LT | LTE | GT | GTE | NSEQ
+ ;
+
+booleanValue
+ : TRUE | FALSE
+ ;
+
+interval
+ : INTERVAL intervalField*
+ ;
+
+intervalField
+ : value=intervalValue unit=identifier (TO to=identifier)?
+ ;
+
+intervalValue
+ : (PLUS | MINUS)? (INTEGER_VALUE | DECIMAL_VALUE)
+ | STRING
+ ;
+
+dataType
+ : complex=ARRAY '<' dataType '>' #complexDataType
+ | complex=MAP '<' dataType ',' dataType '>' #complexDataType
+ | complex=STRUCT ('<' colTypeList? '>' | NEQ) #complexDataType
+ | identifier ('(' INTEGER_VALUE (',' INTEGER_VALUE)* ')')? #primitiveDataType
+ ;
+
+colTypeList
+ : colType (',' colType)*
+ ;
+
+colType
+ : identifier ':'? dataType (COMMENT STRING)?
+ ;
+
+whenClause
+ : WHEN condition=expression THEN result=expression
+ ;
+
+windows
+ : WINDOW namedWindow (',' namedWindow)*
+ ;
+
+namedWindow
+ : identifier AS windowSpec
+ ;
+
+windowSpec
+ : name=identifier #windowRef
+ | '('
+ ( CLUSTER BY partition+=expression (',' partition+=expression)*
+ | ((PARTITION | DISTRIBUTE) BY partition+=expression (',' partition+=expression)*)?
+ ((ORDER | SORT) BY sortItem (',' sortItem)*)?)
+ windowFrame?
+ ')' #windowDef
+ ;
+
+windowFrame
+ : frameType=RANGE start=frameBound
+ | frameType=ROWS start=frameBound
+ | frameType=RANGE BETWEEN start=frameBound AND end=frameBound
+ | frameType=ROWS BETWEEN start=frameBound AND end=frameBound
+ ;
+
+frameBound
+ : UNBOUNDED boundType=(PRECEDING | FOLLOWING)
+ | boundType=CURRENT ROW
+ | expression boundType=(PRECEDING | FOLLOWING)
+ ;
+
+
+explainOption
+ : LOGICAL | FORMATTED | EXTENDED
+ ;
+
+transactionMode
+ : ISOLATION LEVEL SNAPSHOT #isolationLevel
+ | READ accessMode=(ONLY | WRITE) #transactionAccessMode
+ ;
+
+qualifiedName
+ : identifier ('.' identifier)*
+ ;
+
+// Identifier that also allows the use of a number of SQL keywords (mainly for backwards compatibility).
+looseIdentifier
+ : identifier
+ | FROM
+ | TO
+ | TABLE
+ | WITH
+ ;
+
+identifier
+ : IDENTIFIER #unquotedIdentifier
+ | quotedIdentifier #quotedIdentifierAlternative
+ | nonReserved #unquotedIdentifier
+ ;
+
+quotedIdentifier
+ : BACKQUOTED_IDENTIFIER
+ ;
+
+number
+ : DECIMAL_VALUE #decimalLiteral
+ | SCIENTIFIC_DECIMAL_VALUE #scientificDecimalLiteral
+ | INTEGER_VALUE #integerLiteral
+ | BIGINT_LITERAL #bigIntLiteral
+ | SMALLINT_LITERAL #smallIntLiteral
+ | TINYINT_LITERAL #tinyIntLiteral
+ | DOUBLE_LITERAL #doubleLiteral
+ ;
+
+nonReserved
+ : SHOW | TABLES | COLUMNS | COLUMN | PARTITIONS | FUNCTIONS
+ | ADD
+ | OVER | PARTITION | RANGE | ROWS | PRECEDING | FOLLOWING | CURRENT | ROW | MAP | ARRAY | STRUCT
+ | LATERAL | WINDOW | REDUCE | TRANSFORM | USING | SERDE | SERDEPROPERTIES | RECORDREADER
+ | DELIMITED | FIELDS | TERMINATED | COLLECTION | ITEMS | KEYS | ESCAPED | LINES | SEPARATED
+ | EXTENDED | REFRESH | CLEAR | CACHE | UNCACHE | LAZY | TEMPORARY | OPTIONS
+ | GROUPING | CUBE | ROLLUP
+ | EXPLAIN | FORMAT | LOGICAL | FORMATTED
+ | TABLESAMPLE | USE | TO | BUCKET | PERCENTLIT | OUT | OF
+ | SET
+ | VIEW | REPLACE
+ | IF
+ | NO | DATA
+ | START | TRANSACTION | COMMIT | ROLLBACK | WORK | ISOLATION | LEVEL
+ | SNAPSHOT | READ | WRITE | ONLY
+ | SORT | CLUSTER | DISTRIBUTE | UNSET | TBLPROPERTIES | SKEWED | STORED | DIRECTORIES | LOCATION
+ | EXCHANGE | ARCHIVE | UNARCHIVE | FILEFORMAT | TOUCH | COMPACT | CONCATENATE | CHANGE | FIRST
+ | AFTER | CASCADE | RESTRICT | BUCKETS | CLUSTERED | SORTED | PURGE | INPUTFORMAT | OUTPUTFORMAT
+ | INPUTDRIVER | OUTPUTDRIVER | DBPROPERTIES | DFS | TRUNCATE | METADATA | REPLICATION | COMPUTE
+ | STATISTICS | ANALYZE | PARTITIONED | EXTERNAL | DEFINED | RECORDWRITER
+ | REVOKE | GRANT | LOCK | UNLOCK | MSCK | EXPORT | IMPORT | LOAD | VALUES | COMMENT
+ ;
+
+SELECT: 'SELECT';
+FROM: 'FROM';
+ADD: 'ADD';
+AS: 'AS';
+ALL: 'ALL';
+DISTINCT: 'DISTINCT';
+WHERE: 'WHERE';
+GROUP: 'GROUP';
+BY: 'BY';
+GROUPING: 'GROUPING';
+SETS: 'SETS';
+CUBE: 'CUBE';
+ROLLUP: 'ROLLUP';
+ORDER: 'ORDER';
+HAVING: 'HAVING';
+LIMIT: 'LIMIT';
+AT: 'AT';
+OR: 'OR';
+AND: 'AND';
+IN: 'IN';
+NOT: 'NOT' | '!';
+NO: 'NO';
+EXISTS: 'EXISTS';
+BETWEEN: 'BETWEEN';
+LIKE: 'LIKE';
+RLIKE: 'RLIKE' | 'REGEXP';
+IS: 'IS';
+NULL: 'NULL';
+TRUE: 'TRUE';
+FALSE: 'FALSE';
+NULLS: 'NULLS';
+ASC: 'ASC';
+DESC: 'DESC';
+FOR: 'FOR';
+INTERVAL: 'INTERVAL';
+CASE: 'CASE';
+WHEN: 'WHEN';
+THEN: 'THEN';
+ELSE: 'ELSE';
+END: 'END';
+JOIN: 'JOIN';
+CROSS: 'CROSS';
+OUTER: 'OUTER';
+INNER: 'INNER';
+LEFT: 'LEFT';
+SEMI: 'SEMI';
+RIGHT: 'RIGHT';
+FULL: 'FULL';
+NATURAL: 'NATURAL';
+ON: 'ON';
+LATERAL: 'LATERAL';
+WINDOW: 'WINDOW';
+OVER: 'OVER';
+PARTITION: 'PARTITION';
+RANGE: 'RANGE';
+ROWS: 'ROWS';
+UNBOUNDED: 'UNBOUNDED';
+PRECEDING: 'PRECEDING';
+FOLLOWING: 'FOLLOWING';
+CURRENT: 'CURRENT';
+ROW: 'ROW';
+WITH: 'WITH';
+VALUES: 'VALUES';
+CREATE: 'CREATE';
+TABLE: 'TABLE';
+VIEW: 'VIEW';
+REPLACE: 'REPLACE';
+INSERT: 'INSERT';
+DELETE: 'DELETE';
+INTO: 'INTO';
+DESCRIBE: 'DESCRIBE';
+EXPLAIN: 'EXPLAIN';
+FORMAT: 'FORMAT';
+LOGICAL: 'LOGICAL';
+CAST: 'CAST';
+SHOW: 'SHOW';
+TABLES: 'TABLES';
+COLUMNS: 'COLUMNS';
+COLUMN: 'COLUMN';
+USE: 'USE';
+PARTITIONS: 'PARTITIONS';
+FUNCTIONS: 'FUNCTIONS';
+DROP: 'DROP';
+UNION: 'UNION';
+EXCEPT: 'EXCEPT';
+INTERSECT: 'INTERSECT';
+TO: 'TO';
+TABLESAMPLE: 'TABLESAMPLE';
+STRATIFY: 'STRATIFY';
+ALTER: 'ALTER';
+RENAME: 'RENAME';
+ARRAY: 'ARRAY';
+MAP: 'MAP';
+STRUCT: 'STRUCT';
+COMMENT: 'COMMENT';
+SET: 'SET';
+DATA: 'DATA';
+START: 'START';
+TRANSACTION: 'TRANSACTION';
+COMMIT: 'COMMIT';
+ROLLBACK: 'ROLLBACK';
+WORK: 'WORK';
+ISOLATION: 'ISOLATION';
+LEVEL: 'LEVEL';
+SNAPSHOT: 'SNAPSHOT';
+READ: 'READ';
+WRITE: 'WRITE';
+ONLY: 'ONLY';
+
+IF: 'IF';
+
+EQ : '=' | '==';
+NSEQ: '<=>';
+NEQ : '<>';
+NEQJ: '!=';
+LT : '<';
+LTE : '<=';
+GT : '>';
+GTE : '>=';
+
+PLUS: '+';
+MINUS: '-';
+ASTERISK: '*';
+SLASH: '/';
+PERCENT: '%';
+DIV: 'DIV';
+TILDE: '~';
+AMPERSAND: '&';
+PIPE: '|';
+HAT: '^';
+
+PERCENTLIT: 'PERCENT';
+BUCKET: 'BUCKET';
+OUT: 'OUT';
+OF: 'OF';
+
+SORT: 'SORT';
+CLUSTER: 'CLUSTER';
+DISTRIBUTE: 'DISTRIBUTE';
+OVERWRITE: 'OVERWRITE';
+TRANSFORM: 'TRANSFORM';
+REDUCE: 'REDUCE';
+USING: 'USING';
+SERDE: 'SERDE';
+SERDEPROPERTIES: 'SERDEPROPERTIES';
+RECORDREADER: 'RECORDREADER';
+RECORDWRITER: 'RECORDWRITER';
+DELIMITED: 'DELIMITED';
+FIELDS: 'FIELDS';
+TERMINATED: 'TERMINATED';
+COLLECTION: 'COLLECTION';
+ITEMS: 'ITEMS';
+KEYS: 'KEYS';
+ESCAPED: 'ESCAPED';
+LINES: 'LINES';
+SEPARATED: 'SEPARATED';
+FUNCTION: 'FUNCTION';
+EXTENDED: 'EXTENDED';
+REFRESH: 'REFRESH';
+CLEAR: 'CLEAR';
+CACHE: 'CACHE';
+UNCACHE: 'UNCACHE';
+LAZY: 'LAZY';
+FORMATTED: 'FORMATTED';
+TEMPORARY: 'TEMPORARY' | 'TEMP';
+OPTIONS: 'OPTIONS';
+UNSET: 'UNSET';
+TBLPROPERTIES: 'TBLPROPERTIES';
+DBPROPERTIES: 'DBPROPERTIES';
+BUCKETS: 'BUCKETS';
+SKEWED: 'SKEWED';
+STORED: 'STORED';
+DIRECTORIES: 'DIRECTORIES';
+LOCATION: 'LOCATION';
+EXCHANGE: 'EXCHANGE';
+ARCHIVE: 'ARCHIVE';
+UNARCHIVE: 'UNARCHIVE';
+FILEFORMAT: 'FILEFORMAT';
+TOUCH: 'TOUCH';
+COMPACT: 'COMPACT';
+CONCATENATE: 'CONCATENATE';
+CHANGE: 'CHANGE';
+FIRST: 'FIRST';
+AFTER: 'AFTER';
+CASCADE: 'CASCADE';
+RESTRICT: 'RESTRICT';
+CLUSTERED: 'CLUSTERED';
+SORTED: 'SORTED';
+PURGE: 'PURGE';
+INPUTFORMAT: 'INPUTFORMAT';
+OUTPUTFORMAT: 'OUTPUTFORMAT';
+INPUTDRIVER: 'INPUTDRIVER';
+OUTPUTDRIVER: 'OUTPUTDRIVER';
+DATABASE: 'DATABASE' | 'SCHEMA';
+DFS: 'DFS';
+TRUNCATE: 'TRUNCATE';
+METADATA: 'METADATA';
+REPLICATION: 'REPLICATION';
+ANALYZE: 'ANALYZE';
+COMPUTE: 'COMPUTE';
+STATISTICS: 'STATISTICS';
+PARTITIONED: 'PARTITIONED';
+EXTERNAL: 'EXTERNAL';
+DEFINED: 'DEFINED';
+REVOKE: 'REVOKE';
+GRANT: 'GRANT';
+LOCK: 'LOCK';
+UNLOCK: 'UNLOCK';
+MSCK: 'MSCK';
+EXPORT: 'EXPORT';
+IMPORT: 'IMPORT';
+LOAD: 'LOAD';
+
+STRING
+ : '\'' ( ~('\''|'\\') | ('\\' .) )* '\''
+ | '\"' ( ~('\"'|'\\') | ('\\' .) )* '\"'
+ ;
+
+BIGINT_LITERAL
+ : DIGIT+ 'L'
+ ;
+
+SMALLINT_LITERAL
+ : DIGIT+ 'S'
+ ;
+
+TINYINT_LITERAL
+ : DIGIT+ 'Y'
+ ;
+
+INTEGER_VALUE
+ : DIGIT+
+ ;
+
+DECIMAL_VALUE
+ : DIGIT+ '.' DIGIT*
+ | '.' DIGIT+
+ ;
+
+SCIENTIFIC_DECIMAL_VALUE
+ : DIGIT+ ('.' DIGIT*)? EXPONENT
+ | '.' DIGIT+ EXPONENT
+ ;
+
+DOUBLE_LITERAL
+ :
+ (INTEGER_VALUE | DECIMAL_VALUE | SCIENTIFIC_DECIMAL_VALUE) 'D'
+ ;
+
+IDENTIFIER
+ : (LETTER | DIGIT | '_')+
+ ;
+
+BACKQUOTED_IDENTIFIER
+ : '`' ( ~'`' | '``' )* '`'
+ ;
+
+fragment EXPONENT
+ : 'E' [+-]? DIGIT+
+ ;
+
+fragment DIGIT
+ : [0-9]
+ ;
+
+fragment LETTER
+ : [A-Z]
+ ;
+
+SIMPLE_COMMENT
+ : '--' ~[\r\n]* '\r'? '\n'? -> channel(HIDDEN)
+ ;
+
+BRACKETED_COMMENT
+ : '/*' .*? '*/' -> channel(HIDDEN)
+ ;
+
+WS
+ : [ \r\n\t]+ -> channel(HIDDEN)
+ ;
+
+// Catch-all for anything we can't recognize.
+// We use this to be able to ignore and recover all the text
+// when splitting statements with DelimiterLexer
+UNRECOGNIZED
+ : .
+ ;
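From this grammar ANTLR 4 generates SqlBaseLexer, SqlBaseParser and (since the build enables visitor generation) SqlBaseBaseVisitor; the AstBuilder introduced further down is a visitor over those generated classes, and ParseDriver.scala wraps the whole pipeline. Purely as an illustrative sketch — using only the standard ANTLR 4 runtime API plus the generated class names, and assuming upper-case input since the lexer rules above only match upper-case keywords — driving the parser by hand looks roughly like this:

import org.antlr.v4.runtime.{ANTLRInputStream, CommonTokenStream}
import org.apache.spark.sql.catalyst.parser.ng.{AstBuilder, SqlBaseLexer, SqlBaseParser}
import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan

// Sketch only: tokenize, parse the 'singleStatement' entry rule, then build a Catalyst plan.
def parsePlan(sql: String): LogicalPlan = {
  val lexer = new SqlBaseLexer(new ANTLRInputStream(sql))
  val parser = new SqlBaseParser(new CommonTokenStream(lexer))
  val tree = parser.singleStatement()          // entry rule defined at the top of the grammar
  new AstBuilder().visitSingleStatement(tree)  // visitor implemented in AstBuilder.scala below
}

parsePlan("SELECT A FROM T WHERE A > 1")

The other entry rules (singleExpression, singleTableIdentifier, singleDataType) are driven the same way and map onto the corresponding visitSingle* methods of AstBuilder.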
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/dsl/package.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/dsl/package.scala
index 3540014c3e..105947028d 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/dsl/package.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/dsl/package.scala
@@ -21,7 +21,7 @@ import java.sql.{Date, Timestamp}
import scala.language.implicitConversions
-import org.apache.spark.sql.catalyst.analysis.{EliminateSubqueryAliases, UnresolvedAttribute, UnresolvedExtractValue}
+import org.apache.spark.sql.catalyst.analysis._
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.expressions.aggregate._
import org.apache.spark.sql.catalyst.plans.{Inner, JoinType}
@@ -161,6 +161,10 @@ package object dsl {
def lower(e: Expression): Expression = Lower(e)
def sqrt(e: Expression): Expression = Sqrt(e)
def abs(e: Expression): Expression = Abs(e)
+ def star(names: String*): Expression = names match {
+ case Seq() => UnresolvedStar(None)
+ case target => UnresolvedStar(Option(target))
+ }
implicit class DslSymbol(sym: Symbol) extends ImplicitAttribute { def s: String = sym.name }
// TODO more implicit class for literal?
@@ -231,6 +235,12 @@ package object dsl {
AttributeReference(s, structType, nullable = true)()
def struct(attrs: AttributeReference*): AttributeReference =
struct(StructType.fromAttributes(attrs))
+
+ /** Create a function. */
+ def function(exprs: Expression*): UnresolvedFunction =
+ UnresolvedFunction(s, exprs, isDistinct = false)
+ def distinctFunction(exprs: Expression*): UnresolvedFunction =
+ UnresolvedFunction(s, exprs, isDistinct = true)
}
implicit class DslAttribute(a: AttributeReference) {
@@ -243,8 +253,20 @@ package object dsl {
object expressions extends ExpressionConversions // scalastyle:ignore
object plans { // scalastyle:ignore
+ def table(ref: String): LogicalPlan =
+ UnresolvedRelation(TableIdentifier(ref), None)
+
+ def table(db: String, ref: String): LogicalPlan =
+ UnresolvedRelation(TableIdentifier(ref, Option(db)), None)
+
implicit class DslLogicalPlan(val logicalPlan: LogicalPlan) {
- def select(exprs: NamedExpression*): LogicalPlan = Project(exprs, logicalPlan)
+ def select(exprs: Expression*): LogicalPlan = {
+ val namedExpressions = exprs.map {
+ case e: NamedExpression => e
+ case e => UnresolvedAlias(e)
+ }
+ Project(namedExpressions, logicalPlan)
+ }
def where(condition: Expression): LogicalPlan = Filter(condition, logicalPlan)
@@ -296,6 +318,14 @@ package object dsl {
analysis.UnresolvedRelation(TableIdentifier(tableName)),
Map.empty, logicalPlan, overwrite, false)
+ def as(alias: String): LogicalPlan = logicalPlan match {
+ case UnresolvedRelation(tbl, _) => UnresolvedRelation(tbl, Option(alias))
+ case plan => SubqueryAlias(alias, plan)
+ }
+
+ def distribute(exprs: Expression*): LogicalPlan =
+ RepartitionByExpression(exprs, logicalPlan)
+
def analyze: LogicalPlan =
EliminateSubqueryAliases(analysis.SimpleAnalyzer.execute(logicalPlan))
}
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/ng/AstBuilder.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/ng/AstBuilder.scala
new file mode 100644
index 0000000000..5a64c414fb
--- /dev/null
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/ng/AstBuilder.scala
@@ -0,0 +1,1452 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.spark.sql.catalyst.parser.ng
+
+import java.sql.{Date, Timestamp}
+
+import scala.collection.JavaConverters._
+import scala.collection.mutable.ArrayBuffer
+
+import org.antlr.v4.runtime.{ParserRuleContext, Token}
+import org.antlr.v4.runtime.tree.{ParseTree, TerminalNode}
+
+import org.apache.spark.internal.Logging
+import org.apache.spark.sql.catalyst.{InternalRow, TableIdentifier}
+import org.apache.spark.sql.catalyst.analysis._
+import org.apache.spark.sql.catalyst.expressions._
+import org.apache.spark.sql.catalyst.parser.ng.SqlBaseParser._
+import org.apache.spark.sql.catalyst.plans._
+import org.apache.spark.sql.catalyst.plans.logical._
+import org.apache.spark.sql.types._
+import org.apache.spark.unsafe.types.CalendarInterval
+import org.apache.spark.util.random.RandomSampler
+
+/**
+ * The AstBuilder converts an ANTLR4 ParseTree into a catalyst Expression, LogicalPlan or
+ * TableIdentifier.
+ */
+class AstBuilder extends SqlBaseBaseVisitor[AnyRef] with Logging {
+ import ParserUtils._
+
+ protected def typedVisit[T](ctx: ParseTree): T = {
+ ctx.accept(this).asInstanceOf[T]
+ }
+
+ override def visitSingleStatement(ctx: SingleStatementContext): LogicalPlan = withOrigin(ctx) {
+ visit(ctx.statement).asInstanceOf[LogicalPlan]
+ }
+
+ override def visitSingleExpression(ctx: SingleExpressionContext): Expression = withOrigin(ctx) {
+ visitNamedExpression(ctx.namedExpression)
+ }
+
+ override def visitSingleTableIdentifier(
+ ctx: SingleTableIdentifierContext): TableIdentifier = withOrigin(ctx) {
+ visitTableIdentifier(ctx.tableIdentifier)
+ }
+
+ override def visitSingleDataType(ctx: SingleDataTypeContext): DataType = withOrigin(ctx) {
+ visit(ctx.dataType).asInstanceOf[DataType]
+ }
+
+ /* ********************************************************************************************
+ * Plan parsing
+ * ******************************************************************************************** */
+ protected def plan(tree: ParserRuleContext): LogicalPlan = typedVisit(tree)
+
+ /**
+ * Make sure we do not try to create a plan for a native command.
+ */
+ override def visitExecuteNativeCommand(ctx: ExecuteNativeCommandContext): LogicalPlan = null
+
+ /**
+ * Create a plan for a SHOW FUNCTIONS command.
+ */
+ override def visitShowFunctions(ctx: ShowFunctionsContext): LogicalPlan = withOrigin(ctx) {
+ import ctx._
+ if (qualifiedName != null) {
+ val names = qualifiedName().identifier().asScala.map(_.getText).toList
+ names match {
+ case db :: name :: Nil =>
+ ShowFunctions(Some(db), Some(name))
+ case name :: Nil =>
+ ShowFunctions(None, Some(name))
+ case _ =>
+ throw new ParseException("SHOW FUNCTIONS unsupported name", ctx)
+ }
+ } else if (pattern != null) {
+ ShowFunctions(None, Some(string(pattern)))
+ } else {
+ ShowFunctions(None, None)
+ }
+ }
+
+ /**
+ * Create a plan for a DESCRIBE FUNCTION command.
+ */
+ override def visitDescribeFunction(ctx: DescribeFunctionContext): LogicalPlan = withOrigin(ctx) {
+ val functionName = ctx.qualifiedName().identifier().asScala.map(_.getText).mkString(".")
+ DescribeFunction(functionName, ctx.EXTENDED != null)
+ }
+
+ /**
+ * Create a top-level plan with Common Table Expressions.
+ */
+ override def visitQuery(ctx: QueryContext): LogicalPlan = withOrigin(ctx) {
+ val query = plan(ctx.queryNoWith)
+
+ // Apply CTEs
+ query.optional(ctx.ctes) {
+ val ctes = ctx.ctes.namedQuery.asScala.map {
+ case nCtx =>
+ val namedQuery = visitNamedQuery(nCtx)
+ (namedQuery.alias, namedQuery)
+ }
+
+ // Check for duplicate names.
+ ctes.groupBy(_._1).filter(_._2.size > 1).foreach {
+ case (name, _) =>
+ throw new ParseException(
+ s"Name '$name' is used for multiple common table expressions", ctx)
+ }
+
+ With(query, ctes.toMap)
+ }
+ }
+
+ /**
+ * Create a named logical plan.
+ *
+ * This is only used for Common Table Expressions.
+ */
+ override def visitNamedQuery(ctx: NamedQueryContext): SubqueryAlias = withOrigin(ctx) {
+ SubqueryAlias(ctx.name.getText, plan(ctx.queryNoWith))
+ }
+
+ /**
+ * Create a logical plan which allows for multiple inserts using one 'from' statement. These
+ * queries have the following SQL form:
+ * {{{
+ * [WITH cte...]?
+ * FROM src
+ * [INSERT INTO tbl1 SELECT *]+
+ * }}}
+ * For example:
+ * {{{
+ * FROM db.tbl1 A
+ * INSERT INTO dbo.tbl1 SELECT * WHERE A.value = 10 LIMIT 5
+ * INSERT INTO dbo.tbl2 SELECT * WHERE A.value = 12
+ * }}}
+ * This (Hive) feature cannot be combined with set-operators.
+ */
+ override def visitMultiInsertQuery(ctx: MultiInsertQueryContext): LogicalPlan = withOrigin(ctx) {
+ val from = visitFromClause(ctx.fromClause)
+
+ // Build the insert clauses.
+ val inserts = ctx.multiInsertQueryBody.asScala.map {
+ body =>
+ assert(body.querySpecification.fromClause == null,
+ "Multi-Insert queries cannot have a FROM clause in their individual SELECT statements",
+ body)
+
+ withQuerySpecification(body.querySpecification, from).
+ // Add organization statements.
+ optionalMap(body.queryOrganization)(withQueryResultClauses).
+ // Add insert.
+ optionalMap(body.insertInto())(withInsertInto)
+ }
+
+ // If there are multiple INSERTS just UNION them together into one query.
+ inserts match {
+ case Seq(query) => query
+ case queries => Union(queries)
+ }
+ }
+
+ /**
+ * Create a logical plan for a regular (single-insert) query.
+ */
+ override def visitSingleInsertQuery(
+ ctx: SingleInsertQueryContext): LogicalPlan = withOrigin(ctx) {
+ plan(ctx.queryTerm).
+ // Add organization statements.
+ optionalMap(ctx.queryOrganization)(withQueryResultClauses).
+ // Add insert.
+ optionalMap(ctx.insertInto())(withInsertInto)
+ }
+
+ /**
+ * Add an INSERT INTO [TABLE]/INSERT OVERWRITE TABLE operation to the logical plan.
+ */
+ private def withInsertInto(
+ ctx: InsertIntoContext,
+ query: LogicalPlan): LogicalPlan = withOrigin(ctx) {
+ val tableIdent = visitTableIdentifier(ctx.tableIdentifier)
+ val partitionKeys = Option(ctx.partitionSpec).map(visitPartitionSpec).getOrElse(Map.empty)
+
+ InsertIntoTable(
+ UnresolvedRelation(tableIdent, None),
+ partitionKeys,
+ query,
+ ctx.OVERWRITE != null,
+ ctx.EXISTS != null)
+ }
+
+ /**
+ * Create a partition specification map.
+ */
+ override def visitPartitionSpec(
+ ctx: PartitionSpecContext): Map[String, Option[String]] = withOrigin(ctx) {
+ ctx.partitionVal.asScala.map { pVal =>
+ val name = pVal.identifier.getText.toLowerCase
+ val value = Option(pVal.constant).map(visitStringConstant)
+ name -> value
+ }.toMap
+ }
+
+ /**
+ * Create a partition specification map without optional values.
+ */
+ protected def visitNonOptionalPartitionSpec(
+ ctx: PartitionSpecContext): Map[String, String] = withOrigin(ctx) {
+ visitPartitionSpec(ctx).mapValues(_.orNull).map(identity)
+ }
+
+ /**
+ * Convert a constant of any type into a string. This is typically used in DDL commands, and its
+ * main purpose is to prevent slight differences due to back-to-back conversions, i.e.:
+ * String -> Literal -> String.
+ */
+ protected def visitStringConstant(ctx: ConstantContext): String = withOrigin(ctx) {
+ ctx match {
+ case s: StringLiteralContext => createString(s)
+ case o => o.getText
+ }
+ }
+
+ /**
+ * Add ORDER BY/SORT BY/CLUSTER BY/DISTRIBUTE BY/LIMIT/WINDOWS clauses to the logical plan. These
+ * clauses determine the shape (ordering/partitioning/rows) of the query result.
+ */
+ private def withQueryResultClauses(
+ ctx: QueryOrganizationContext,
+ query: LogicalPlan): LogicalPlan = withOrigin(ctx) {
+ import ctx._
+
+ // Handle ORDER BY, SORT BY, DISTRIBUTE BY, and CLUSTER BY clause.
+ val withOrder = if (
+ !order.isEmpty && sort.isEmpty && distributeBy.isEmpty && clusterBy.isEmpty) {
+ // ORDER BY ...
+ Sort(order.asScala.map(visitSortItem), global = true, query)
+ } else if (order.isEmpty && !sort.isEmpty && distributeBy.isEmpty && clusterBy.isEmpty) {
+ // SORT BY ...
+ Sort(sort.asScala.map(visitSortItem), global = false, query)
+ } else if (order.isEmpty && sort.isEmpty && !distributeBy.isEmpty && clusterBy.isEmpty) {
+ // DISTRIBUTE BY ...
+ RepartitionByExpression(expressionList(distributeBy), query)
+ } else if (order.isEmpty && !sort.isEmpty && !distributeBy.isEmpty && clusterBy.isEmpty) {
+ // SORT BY ... DISTRIBUTE BY ...
+ Sort(
+ sort.asScala.map(visitSortItem),
+ global = false,
+ RepartitionByExpression(expressionList(distributeBy), query))
+ } else if (order.isEmpty && sort.isEmpty && distributeBy.isEmpty && !clusterBy.isEmpty) {
+ // CLUSTER BY ...
+ val expressions = expressionList(clusterBy)
+ Sort(
+ expressions.map(SortOrder(_, Ascending)),
+ global = false,
+ RepartitionByExpression(expressions, query))
+ } else if (order.isEmpty && sort.isEmpty && distributeBy.isEmpty && clusterBy.isEmpty) {
+ // [EMPTY]
+ query
+ } else {
+ throw new ParseException(
+ "Combination of ORDER BY/SORT BY/DISTRIBUTE BY/CLUSTER BY is not supported", ctx)
+ }
+
+ // WINDOWS
+ val withWindow = withOrder.optionalMap(windows)(withWindows)
+
+ // LIMIT
+ withWindow.optional(limit) {
+ Limit(typedVisit(limit), withWindow)
+ }
+ }
+
+ /**
+ * Create a logical plan using a query specification.
+ */
+ override def visitQuerySpecification(
+ ctx: QuerySpecificationContext): LogicalPlan = withOrigin(ctx) {
+ val from = OneRowRelation.optional(ctx.fromClause) {
+ visitFromClause(ctx.fromClause)
+ }
+ withQuerySpecification(ctx, from)
+ }
+
+ /**
+ * Add a query specification to a logical plan. The query specification is the core of the logical
+ * plan; this is where sourcing (FROM clause), transforming (SELECT TRANSFORM/MAP/REDUCE),
+ * projection (SELECT), aggregation (GROUP BY ... HAVING ...) and filtering (WHERE) take place.
+ *
+ * Note that query hints are ignored (both by the parser and the builder).
+ */
+ private def withQuerySpecification(
+ ctx: QuerySpecificationContext,
+ relation: LogicalPlan): LogicalPlan = withOrigin(ctx) {
+ import ctx._
+
+ // WHERE
+ def filter(ctx: BooleanExpressionContext, plan: LogicalPlan): LogicalPlan = {
+ Filter(expression(ctx), plan)
+ }
+
+ // Expressions.
+ val expressions = Option(namedExpressionSeq).toSeq
+ .flatMap(_.namedExpression.asScala)
+ .map(typedVisit[Expression])
+
+ // Create either a transform or a regular query.
+ val specType = Option(kind).map(_.getType).getOrElse(SqlBaseParser.SELECT)
+ specType match {
+ case SqlBaseParser.MAP | SqlBaseParser.REDUCE | SqlBaseParser.TRANSFORM =>
+ // Transform
+
+ // Add where.
+ val withFilter = relation.optionalMap(where)(filter)
+
+ // Create the attributes.
+ val (attributes, schemaLess) = if (colTypeList != null) {
+ // Typed return columns.
+ (createStructType(colTypeList).toAttributes, false)
+ } else if (identifierSeq != null) {
+ // Untyped return columns.
+ val attrs = visitIdentifierSeq(identifierSeq).map { name =>
+ AttributeReference(name, StringType, nullable = true)()
+ }
+ (attrs, false)
+ } else {
+ (Seq(AttributeReference("key", StringType)(),
+ AttributeReference("value", StringType)()), true)
+ }
+
+ // Create the transform.
+ ScriptTransformation(
+ expressions,
+ string(script),
+ attributes,
+ withFilter,
+ withScriptIOSchema(inRowFormat, recordWriter, outRowFormat, recordReader, schemaLess))
+
+ case SqlBaseParser.SELECT =>
+ // Regular select
+
+ // Add lateral views.
+ val withLateralView = ctx.lateralView.asScala.foldLeft(relation)(withGenerate)
+
+ // Add where.
+ val withFilter = withLateralView.optionalMap(where)(filter)
+
+ // Add aggregation or a project.
+ val namedExpressions = expressions.map {
+ case e: NamedExpression => e
+ case e: Expression => UnresolvedAlias(e)
+ }
+ val withProject = if (aggregation != null) {
+ withAggregation(aggregation, namedExpressions, withFilter)
+ } else if (namedExpressions.nonEmpty) {
+ Project(namedExpressions, withFilter)
+ } else {
+ withFilter
+ }
+
+ // Having
+ val withHaving = withProject.optional(having) {
+ // Note that we added a cast to boolean. If the expression itself is already boolean,
+ // the optimizer will get rid of the unnecessary cast.
+ Filter(Cast(expression(having), BooleanType), withProject)
+ }
+
+ // Distinct
+ val withDistinct = if (setQuantifier() != null && setQuantifier().DISTINCT() != null) {
+ Distinct(withHaving)
+ } else {
+ withHaving
+ }
+
+ // Window
+ withDistinct.optionalMap(windows)(withWindows)
+ }
+ }
+
+ /**
+ * Create a (Hive based) [[ScriptInputOutputSchema]].
+ */
+ protected def withScriptIOSchema(
+ inRowFormat: RowFormatContext,
+ recordWriter: Token,
+ outRowFormat: RowFormatContext,
+ recordReader: Token,
+ schemaLess: Boolean): ScriptInputOutputSchema = null
+
+ /**
+ * Create a logical plan for a given 'FROM' clause. Note that we support multiple (comma
+ * separated) relations here; these get converted into a single plan by a condition-less inner join.
+ */
+ override def visitFromClause(ctx: FromClauseContext): LogicalPlan = withOrigin(ctx) {
+ val from = ctx.relation.asScala.map(plan).reduceLeft(Join(_, _, Inner, None))
+ ctx.lateralView.asScala.foldLeft(from)(withGenerate)
+ }
+
+ /**
+ * Connect two queries by a Set operator.
+ *
+ * Supported Set operators are:
+ * - UNION [DISTINCT]
+ * - UNION ALL
+ * - EXCEPT [DISTINCT]
+ * - INTERSECT [DISTINCT]
+ */
+ override def visitSetOperation(ctx: SetOperationContext): LogicalPlan = withOrigin(ctx) {
+ val left = plan(ctx.left)
+ val right = plan(ctx.right)
+ val all = Option(ctx.setQuantifier()).exists(_.ALL != null)
+ ctx.operator.getType match {
+ case SqlBaseParser.UNION if all =>
+ Union(left, right)
+ case SqlBaseParser.UNION =>
+ Distinct(Union(left, right))
+ case SqlBaseParser.INTERSECT if all =>
+ throw new ParseException("INTERSECT ALL is not supported.", ctx)
+ case SqlBaseParser.INTERSECT =>
+ Intersect(left, right)
+ case SqlBaseParser.EXCEPT if all =>
+ throw new ParseException("EXCEPT ALL is not supported.", ctx)
+ case SqlBaseParser.EXCEPT =>
+ Except(left, right)
+ }
+ }
+
+ /**
+ * Add a [[WithWindowDefinition]] operator to a logical plan.
+ */
+ private def withWindows(
+ ctx: WindowsContext,
+ query: LogicalPlan): LogicalPlan = withOrigin(ctx) {
+ // Collect all window specifications defined in the WINDOW clause.
+ val baseWindowMap = ctx.namedWindow.asScala.map {
+ wCtx =>
+ (wCtx.identifier.getText, typedVisit[WindowSpec](wCtx.windowSpec))
+ }.toMap
+
+ // Handle cases like
+ // window w1 as (partition by p_mfgr order by p_name
+ // range between 2 preceding and 2 following),
+ // w2 as w1
+ val windowMapView = baseWindowMap.mapValues {
+ case WindowSpecReference(name) =>
+ baseWindowMap.get(name) match {
+ case Some(spec: WindowSpecDefinition) =>
+ spec
+ case Some(ref) =>
+ throw new ParseException(s"Window reference '$name' is not a window specification", ctx)
+ case None =>
+ throw new ParseException(s"Cannot resolve window reference '$name'", ctx)
+ }
+ case spec: WindowSpecDefinition => spec
+ }
+
+ // Note that mapValues creates a view instead of a materialized map. We force materialization by
+ // mapping over identity.
+ WithWindowDefinition(windowMapView.map(identity), query)
+ }
+
+ /**
+ * Add an [[Aggregate]] to a logical plan.
+ */
+ private def withAggregation(
+ ctx: AggregationContext,
+ selectExpressions: Seq[NamedExpression],
+ query: LogicalPlan): LogicalPlan = withOrigin(ctx) {
+ import ctx._
+ val groupByExpressions = expressionList(groupingExpressions)
+
+ if (GROUPING != null) {
+ // GROUP BY .... GROUPING SETS (...)
+ val expressionMap = groupByExpressions.zipWithIndex.toMap
+ val numExpressions = expressionMap.size
+ val mask = (1 << numExpressions) - 1
+ val masks = ctx.groupingSet.asScala.map {
+ _.expression.asScala.foldLeft(mask) {
+ case (bitmap, eCtx) =>
+ // Find the index of the expression.
+ val e = typedVisit[Expression](eCtx)
+ val index = expressionMap.find(_._1.semanticEquals(e)).map(_._2).getOrElse(
+ throw new ParseException(
+ s"$e doesn't show up in the GROUP BY list", ctx))
+ // 0 means that the column at the given index is a grouping column, 1 means it is not,
+ // so we unset the bit in bitmap.
+ bitmap & ~(1 << (numExpressions - 1 - index))
+ }
+ }
+ GroupingSets(masks, groupByExpressions, query, selectExpressions)
+ } else {
+ // GROUP BY .... (WITH CUBE | WITH ROLLUP)?
+ val mappedGroupByExpressions = if (CUBE != null) {
+ Seq(Cube(groupByExpressions))
+ } else if (ROLLUP != null) {
+ Seq(Rollup(groupByExpressions))
+ } else {
+ groupByExpressions
+ }
+ Aggregate(mappedGroupByExpressions, selectExpressions, query)
+ }
+ }
+
+ /**
+ * Add a [[Generate]] (Lateral View) to a logical plan.
+ */
+ private def withGenerate(
+ query: LogicalPlan,
+ ctx: LateralViewContext): LogicalPlan = withOrigin(ctx) {
+ val expressions = expressionList(ctx.expression)
+
+ // Create the generator.
+ val generator = ctx.qualifiedName.getText.toLowerCase match {
+ case "explode" if expressions.size == 1 =>
+ Explode(expressions.head)
+ case "json_tuple" =>
+ JsonTuple(expressions)
+ case other =>
+ withGenerator(other, expressions, ctx)
+ }
+
+ Generate(
+ generator,
+ join = true,
+ outer = ctx.OUTER != null,
+ Some(ctx.tblName.getText.toLowerCase),
+ ctx.colName.asScala.map(_.getText).map(UnresolvedAttribute.apply),
+ query)
+ }
+
+ /**
+ * Create a [[Generator]]. Override this method in order to support custom Generators.
+ */
+ protected def withGenerator(
+ name: String,
+ expressions: Seq[Expression],
+ ctx: LateralViewContext): Generator = {
+ throw new ParseException(s"Generator function '$name' is not supported", ctx)
+ }
+
+ /**
+ * Create a join between two or more logical plans.
+ */
+ override def visitJoinRelation(ctx: JoinRelationContext): LogicalPlan = withOrigin(ctx) {
+ /** Build a join between two plans. */
+ def join(ctx: JoinRelationContext, left: LogicalPlan, right: LogicalPlan): Join = {
+ val baseJoinType = ctx.joinType match {
+ case null => Inner
+ case jt if jt.FULL != null => FullOuter
+ case jt if jt.SEMI != null => LeftSemi
+ case jt if jt.LEFT != null => LeftOuter
+ case jt if jt.RIGHT != null => RightOuter
+ case _ => Inner
+ }
+
+ // Resolve the join type and join condition
+ val (joinType, condition) = Option(ctx.joinCriteria) match {
+ case Some(c) if c.USING != null =>
+ val columns = c.identifier.asScala.map { column =>
+ UnresolvedAttribute.quoted(column.getText)
+ }
+ (UsingJoin(baseJoinType, columns), None)
+ case Some(c) if c.booleanExpression != null =>
+ (baseJoinType, Option(expression(c.booleanExpression)))
+ case None if ctx.NATURAL != null =>
+ (NaturalJoin(baseJoinType), None)
+ case None =>
+ (baseJoinType, None)
+ }
+ Join(left, right, joinType, condition)
+ }
+
+ // Handle all consecutive join clauses. ANTLR produces a right-nested tree in which the
+ // first join clause is at the top. However, fields of previously referenced tables can be used
+ // in following join clauses. The tree needs to be reversed in order to make this work.
+ var result = plan(ctx.left)
+ var current = ctx
+ while (current != null) {
+ current.right match {
+ case right: JoinRelationContext =>
+ result = join(current, result, plan(right.left))
+ current = right
+ case right =>
+ result = join(current, result, plan(right))
+ current = null
+ }
+ }
+ result
+ }
+
+ /**
+ * Add a [[Sample]] to a logical plan.
+ *
+ * This currently supports the following sampling methods:
+ * - TABLESAMPLE(x ROWS): Sample the table down to the given number of rows.
+ * - TABLESAMPLE(x PERCENT): Sample the table down to the given percentage. Note that percentages
+ * are defined as a number between 0 and 100.
+   * - TABLESAMPLE(BUCKET x OUT OF y): Sample the table down to a fraction given by 'x' divided by 'y'.
+ */
+ private def withSample(ctx: SampleContext, query: LogicalPlan): LogicalPlan = withOrigin(ctx) {
+ // Create a sampled plan if we need one.
+ def sample(fraction: Double): Sample = {
+ // The range of fraction accepted by Sample is [0, 1]. Because Hive's block sampling
+ // function takes X PERCENT as the input and the range of X is [0, 100], we need to
+ // adjust the fraction.
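+      // For example, TABLESAMPLE(50 PERCENT) arrives here as the already adjusted fraction 0.5.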
+ val eps = RandomSampler.roundingEpsilon
+ assert(fraction >= 0.0 - eps && fraction <= 1.0 + eps,
+ s"Sampling fraction ($fraction) must be on interval [0, 1]",
+ ctx)
+ Sample(0.0, fraction, withReplacement = false, (math.random * 1000).toInt, query)(true)
+ }
+
+ ctx.sampleType.getType match {
+ case SqlBaseParser.ROWS =>
+ Limit(expression(ctx.expression), query)
+
+ case SqlBaseParser.PERCENTLIT =>
+ val fraction = ctx.percentage.getText.toDouble
+ sample(fraction / 100.0d)
+
+ case SqlBaseParser.BUCKET if ctx.ON != null =>
+ throw new ParseException("TABLESAMPLE(BUCKET x OUT OF y ON id) is not supported", ctx)
+
+ case SqlBaseParser.BUCKET =>
+ sample(ctx.numerator.getText.toDouble / ctx.denominator.getText.toDouble)
+ }
+ }
+
+ /**
+ * Create a logical plan for a sub-query.
+ */
+ override def visitSubquery(ctx: SubqueryContext): LogicalPlan = withOrigin(ctx) {
+ plan(ctx.queryNoWith)
+ }
+
+ /**
+ * Create an un-aliased table reference. This is typically used for top-level table references,
+ * for example:
+ * {{{
+ * INSERT INTO db.tbl2
+ * TABLE db.tbl1
+ * }}}
+ */
+ override def visitTable(ctx: TableContext): LogicalPlan = withOrigin(ctx) {
+ UnresolvedRelation(visitTableIdentifier(ctx.tableIdentifier), None)
+ }
+
+ /**
+ * Create an aliased table reference. This is typically used in FROM clauses.
+ */
+ override def visitTableName(ctx: TableNameContext): LogicalPlan = withOrigin(ctx) {
+ val table = UnresolvedRelation(
+ visitTableIdentifier(ctx.tableIdentifier),
+ Option(ctx.identifier).map(_.getText))
+ table.optionalMap(ctx.sample)(withSample)
+ }
+
+ /**
+ * Create an inline table (a virtual table in Hive parlance).
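+   * For example (illustrative): VALUES (1, 'a'), (2, 'b') AS tbl(x, y).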
+ */
+ override def visitInlineTable(ctx: InlineTableContext): LogicalPlan = withOrigin(ctx) {
+ // Get the backing expressions.
+ val expressions = ctx.expression.asScala.map { eCtx =>
+ val e = expression(eCtx)
+ assert(e.foldable, "All expressions in an inline table must be constants.", eCtx)
+ e
+ }
+
+ // Validate and evaluate the rows.
+ val (structType, structConstructor) = expressions.head.dataType match {
+ case st: StructType =>
+ (st, (e: Expression) => e)
+ case dt =>
+ val st = CreateStruct(Seq(expressions.head)).dataType
+ (st, (e: Expression) => CreateStruct(Seq(e)))
+ }
+ val rows = expressions.map {
+ case expression =>
+ val safe = Cast(structConstructor(expression), structType)
+ safe.eval().asInstanceOf[InternalRow]
+ }
+
+ // Construct attributes.
+ val baseAttributes = structType.toAttributes.map(_.withNullability(true))
+ val attributes = if (ctx.identifierList != null) {
+ val aliases = visitIdentifierList(ctx.identifierList)
+ assert(aliases.size == baseAttributes.size,
+ "Number of aliases must match the number of fields in an inline table.", ctx)
+ baseAttributes.zip(aliases).map(p => p._1.withName(p._2))
+ } else {
+ baseAttributes
+ }
+
+ // Create plan and add an alias if a name has been defined.
+ LocalRelation(attributes, rows).optionalMap(ctx.identifier)(aliasPlan)
+ }
+
+ /**
+ * Create an alias (SubqueryAlias) for a join relation. This is practically the same as
+   * visitAliasedQuery and visitNamedExpression; ANTLR4, however, requires us to use 3 different
+ * hooks.
+ */
+ override def visitAliasedRelation(ctx: AliasedRelationContext): LogicalPlan = withOrigin(ctx) {
+ plan(ctx.relation).optionalMap(ctx.sample)(withSample).optionalMap(ctx.identifier)(aliasPlan)
+ }
+
+ /**
+ * Create an alias (SubqueryAlias) for a sub-query. This is practically the same as
+   * visitAliasedRelation and visitNamedExpression; ANTLR4, however, requires us to use 3 different
+ * hooks.
+ */
+ override def visitAliasedQuery(ctx: AliasedQueryContext): LogicalPlan = withOrigin(ctx) {
+ plan(ctx.queryNoWith).optionalMap(ctx.sample)(withSample).optionalMap(ctx.identifier)(aliasPlan)
+ }
+
+ /**
+ * Create an alias (SubqueryAlias) for a LogicalPlan.
+ */
+ private def aliasPlan(alias: IdentifierContext, plan: LogicalPlan): LogicalPlan = {
+ SubqueryAlias(alias.getText, plan)
+ }
+
+ /**
+   * Create a Sequence of Strings for a parenthesis-enclosed alias list.
+ */
+ override def visitIdentifierList(ctx: IdentifierListContext): Seq[String] = withOrigin(ctx) {
+ visitIdentifierSeq(ctx.identifierSeq)
+ }
+
+ /**
+ * Create a Sequence of Strings for an identifier list.
+ */
+ override def visitIdentifierSeq(ctx: IdentifierSeqContext): Seq[String] = withOrigin(ctx) {
+ ctx.identifier.asScala.map(_.getText)
+ }
+
+ /* ********************************************************************************************
+ * Table Identifier parsing
+ * ******************************************************************************************** */
+ /**
+ * Create a [[TableIdentifier]] from a 'tableName' or 'databaseName'.'tableName' pattern.
+ */
+ override def visitTableIdentifier(
+ ctx: TableIdentifierContext): TableIdentifier = withOrigin(ctx) {
+ TableIdentifier(ctx.table.getText, Option(ctx.db).map(_.getText))
+ }
+
+ /* ********************************************************************************************
+ * Expression parsing
+ * ******************************************************************************************** */
+ /**
+ * Create an expression from the given context. This method just passes the context on to the
+   * visitor and only takes care of typing (we assume that the visitor returns an Expression here).
+ */
+ protected def expression(ctx: ParserRuleContext): Expression = typedVisit(ctx)
+
+ /**
+ * Create sequence of expressions from the given sequence of contexts.
+ */
+ private def expressionList(trees: java.util.List[ExpressionContext]): Seq[Expression] = {
+ trees.asScala.map(expression)
+ }
+
+ /**
+ * Invert a boolean expression if it has a valid NOT clause.
+ */
+ private def invertIfNotDefined(expression: Expression, not: TerminalNode): Expression = {
+ if (not != null) {
+ Not(expression)
+ } else {
+ expression
+ }
+ }
+
+ /**
+ * Create a star (i.e. all) expression; this selects all elements (in the specified object).
+ * Both un-targeted (global) and targeted aliases are supported.
+ */
+ override def visitStar(ctx: StarContext): Expression = withOrigin(ctx) {
+ UnresolvedStar(Option(ctx.qualifiedName()).map(_.identifier.asScala.map(_.getText)))
+ }
+
+ /**
+ * Create an aliased expression if an alias is specified. Both single and multi-aliases are
+ * supported.
+ */
+ override def visitNamedExpression(ctx: NamedExpressionContext): Expression = withOrigin(ctx) {
+ val e = expression(ctx.expression)
+ if (ctx.identifier != null) {
+ Alias(e, ctx.identifier.getText)()
+ } else if (ctx.identifierList != null) {
+ MultiAlias(e, visitIdentifierList(ctx.identifierList))
+ } else {
+ e
+ }
+ }
+
+ /**
+ * Combine a number of boolean expressions into a balanced expression tree. These expressions are
+   * combined by either a logical [[And]] or a logical [[Or]].
+ *
+ * A balanced binary tree is created because regular left recursive trees cause considerable
+ * performance degradations and can cause stack overflows.
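+   * For example, "a AND b AND c AND d" is combined as And(And(a, b), And(c, d)) rather than the
+   * left-deep And(And(And(a, b), c), d).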
+ */
+ override def visitLogicalBinary(ctx: LogicalBinaryContext): Expression = withOrigin(ctx) {
+ val expressionType = ctx.operator.getType
+ val expressionCombiner = expressionType match {
+ case SqlBaseParser.AND => And.apply _
+ case SqlBaseParser.OR => Or.apply _
+ }
+
+ // Collect all similar left hand contexts.
+ val contexts = ArrayBuffer(ctx.right)
+ var current = ctx.left
+ def collectContexts: Boolean = current match {
+ case lbc: LogicalBinaryContext if lbc.operator.getType == expressionType =>
+ contexts += lbc.right
+ current = lbc.left
+ true
+ case _ =>
+ contexts += current
+ false
+ }
+ while (collectContexts) {
+ // No body - all updates take place in the collectContexts.
+ }
+
+ // Reverse the contexts to have them in the same sequence as in the SQL statement & turn them
+ // into expressions.
+ val expressions = contexts.reverse.map(expression)
+
+ // Create a balanced tree.
+ def reduceToExpressionTree(low: Int, high: Int): Expression = high - low match {
+ case 0 =>
+ expressions(low)
+ case 1 =>
+ expressionCombiner(expressions(low), expressions(high))
+ case x =>
+ val mid = low + x / 2
+ expressionCombiner(
+ reduceToExpressionTree(low, mid),
+ reduceToExpressionTree(mid + 1, high))
+ }
+ reduceToExpressionTree(0, expressions.size - 1)
+ }
+
+ /**
+ * Invert a boolean expression.
+ */
+ override def visitLogicalNot(ctx: LogicalNotContext): Expression = withOrigin(ctx) {
+ Not(expression(ctx.booleanExpression()))
+ }
+
+ /**
+ * Create a filtering correlated sub-query. This is not supported yet.
+ */
+ override def visitExists(ctx: ExistsContext): Expression = {
+ throw new ParseException("EXISTS clauses are not supported.", ctx)
+ }
+
+ /**
+ * Create a comparison expression. This compares two expressions. The following comparison
+ * operators are supported:
+ * - Equal: '=' or '=='
+ * - Null-safe Equal: '<=>'
+ * - Not Equal: '<>' or '!='
+ * - Less than: '<'
+   * - Less than or Equal: '<='
+   * - Greater than: '>'
+   * - Greater than or Equal: '>='
+ */
+ override def visitComparison(ctx: ComparisonContext): Expression = withOrigin(ctx) {
+ val left = expression(ctx.left)
+ val right = expression(ctx.right)
+ val operator = ctx.comparisonOperator().getChild(0).asInstanceOf[TerminalNode]
+ operator.getSymbol.getType match {
+ case SqlBaseParser.EQ =>
+ EqualTo(left, right)
+ case SqlBaseParser.NSEQ =>
+ EqualNullSafe(left, right)
+ case SqlBaseParser.NEQ | SqlBaseParser.NEQJ =>
+ Not(EqualTo(left, right))
+ case SqlBaseParser.LT =>
+ LessThan(left, right)
+ case SqlBaseParser.LTE =>
+ LessThanOrEqual(left, right)
+ case SqlBaseParser.GT =>
+ GreaterThan(left, right)
+ case SqlBaseParser.GTE =>
+ GreaterThanOrEqual(left, right)
+ }
+ }
+
+ /**
+   * Create a BETWEEN expression. This tests if an expression lies within the bounds set by two
+ * other expressions. The inverse can also be created.
+ */
+ override def visitBetween(ctx: BetweenContext): Expression = withOrigin(ctx) {
+ val value = expression(ctx.value)
+ val between = And(
+ GreaterThanOrEqual(value, expression(ctx.lower)),
+ LessThanOrEqual(value, expression(ctx.upper)))
+ invertIfNotDefined(between, ctx.NOT)
+ }
+
+ /**
+ * Create an IN expression. This tests if the value of the left hand side expression is
+ * contained by the sequence of expressions on the right hand side.
+ */
+ override def visitInList(ctx: InListContext): Expression = withOrigin(ctx) {
+ val in = In(expression(ctx.value), ctx.expression().asScala.map(expression))
+ invertIfNotDefined(in, ctx.NOT)
+ }
+
+ /**
+   * Create an IN expression, where the right hand side is a query. This is unsupported.
+ */
+ override def visitInSubquery(ctx: InSubqueryContext): Expression = {
+ throw new ParseException("IN with a Sub-query is currently not supported.", ctx)
+ }
+
+ /**
+ * Create a (R)LIKE/REGEXP expression.
+ */
+ override def visitLike(ctx: LikeContext): Expression = {
+ val left = expression(ctx.value)
+ val right = expression(ctx.pattern)
+ val like = ctx.like.getType match {
+ case SqlBaseParser.LIKE =>
+ Like(left, right)
+ case SqlBaseParser.RLIKE =>
+ RLike(left, right)
+ }
+ invertIfNotDefined(like, ctx.NOT)
+ }
+
+ /**
+ * Create an IS (NOT) NULL expression.
+ */
+ override def visitNullPredicate(ctx: NullPredicateContext): Expression = withOrigin(ctx) {
+ val value = expression(ctx.value)
+ if (ctx.NOT != null) {
+ IsNotNull(value)
+ } else {
+ IsNull(value)
+ }
+ }
+
+ /**
+ * Create a binary arithmetic expression. The following arithmetic operators are supported:
+   * - Multiplication: '*'
+ * - Division: '/'
+ * - Hive Long Division: 'DIV'
+ * - Modulo: '%'
+ * - Addition: '+'
+ * - Subtraction: '-'
+ * - Binary AND: '&'
+   * - Binary XOR: '^'
+ * - Binary OR: '|'
+ */
+ override def visitArithmeticBinary(ctx: ArithmeticBinaryContext): Expression = withOrigin(ctx) {
+ val left = expression(ctx.left)
+ val right = expression(ctx.right)
+ ctx.operator.getType match {
+ case SqlBaseParser.ASTERISK =>
+ Multiply(left, right)
+ case SqlBaseParser.SLASH =>
+ Divide(left, right)
+ case SqlBaseParser.PERCENT =>
+ Remainder(left, right)
+ case SqlBaseParser.DIV =>
+ Cast(Divide(left, right), LongType)
+ case SqlBaseParser.PLUS =>
+ Add(left, right)
+ case SqlBaseParser.MINUS =>
+ Subtract(left, right)
+ case SqlBaseParser.AMPERSAND =>
+ BitwiseAnd(left, right)
+ case SqlBaseParser.HAT =>
+ BitwiseXor(left, right)
+ case SqlBaseParser.PIPE =>
+ BitwiseOr(left, right)
+ }
+ }
+
+ /**
+ * Create a unary arithmetic expression. The following arithmetic operators are supported:
+ * - Plus: '+'
+ * - Minus: '-'
+ * - Bitwise Not: '~'
+ */
+ override def visitArithmeticUnary(ctx: ArithmeticUnaryContext): Expression = withOrigin(ctx) {
+ val value = expression(ctx.valueExpression)
+ ctx.operator.getType match {
+ case SqlBaseParser.PLUS =>
+ value
+ case SqlBaseParser.MINUS =>
+ UnaryMinus(value)
+ case SqlBaseParser.TILDE =>
+ BitwiseNot(value)
+ }
+ }
+
+ /**
+ * Create a [[Cast]] expression.
+ */
+ override def visitCast(ctx: CastContext): Expression = withOrigin(ctx) {
+ Cast(expression(ctx.expression), typedVisit(ctx.dataType))
+ }
+
+ /**
+ * Create a (windowed) Function expression.
+ */
+ override def visitFunctionCall(ctx: FunctionCallContext): Expression = withOrigin(ctx) {
+ // Create the function call.
+ val name = ctx.qualifiedName.getText
+ val isDistinct = Option(ctx.setQuantifier()).exists(_.DISTINCT != null)
+ val arguments = ctx.expression().asScala.map(expression) match {
+ case Seq(UnresolvedStar(None)) if name.toLowerCase == "count" && !isDistinct =>
+ // Transform COUNT(*) into COUNT(1). Move this to analysis?
+ Seq(Literal(1))
+ case expressions =>
+ expressions
+ }
+ val function = UnresolvedFunction(name, arguments, isDistinct)
+
+ // Check if the function is evaluated in a windowed context.
+ ctx.windowSpec match {
+ case spec: WindowRefContext =>
+ UnresolvedWindowExpression(function, visitWindowRef(spec))
+ case spec: WindowDefContext =>
+ WindowExpression(function, visitWindowDef(spec))
+ case _ => function
+ }
+ }
+
+ /**
+ * Create a reference to a window frame, i.e. [[WindowSpecReference]].
+ */
+ override def visitWindowRef(ctx: WindowRefContext): WindowSpecReference = withOrigin(ctx) {
+ WindowSpecReference(ctx.identifier.getText)
+ }
+
+ /**
+ * Create a window definition, i.e. [[WindowSpecDefinition]].
+ */
+ override def visitWindowDef(ctx: WindowDefContext): WindowSpecDefinition = withOrigin(ctx) {
+ // CLUSTER BY ... | PARTITION BY ... ORDER BY ...
+ val partition = ctx.partition.asScala.map(expression)
+ val order = ctx.sortItem.asScala.map(visitSortItem)
+
+ // RANGE/ROWS BETWEEN ...
+ val frameSpecOption = Option(ctx.windowFrame).map { frame =>
+ val frameType = frame.frameType.getType match {
+ case SqlBaseParser.RANGE => RangeFrame
+ case SqlBaseParser.ROWS => RowFrame
+ }
+
+ SpecifiedWindowFrame(
+ frameType,
+ visitFrameBound(frame.start),
+ Option(frame.end).map(visitFrameBound).getOrElse(CurrentRow))
+ }
+
+ WindowSpecDefinition(
+ partition,
+ order,
+ frameSpecOption.getOrElse(UnspecifiedFrame))
+ }
+
+ /**
+ * Create or resolve a [[FrameBoundary]]. Simple math expressions are allowed for Value
+ * Preceding/Following boundaries. These expressions must be constant (foldable) and return an
+ * integer value.
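+   * For example (illustrative), ROWS BETWEEN 3 PRECEDING AND CURRENT ROW yields ValuePreceding(3);
+   * a foldable expression such as 2 + 1 would likewise evaluate to 3.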
+ */
+ override def visitFrameBound(ctx: FrameBoundContext): FrameBoundary = withOrigin(ctx) {
+ // We currently only allow foldable integers.
+ def value: Int = {
+ val e = expression(ctx.expression)
+ assert(e.resolved && e.foldable && e.dataType == IntegerType,
+ "Frame bound value must be a constant integer.",
+ ctx)
+ e.eval().asInstanceOf[Int]
+ }
+
+ // Create the FrameBoundary
+ ctx.boundType.getType match {
+ case SqlBaseParser.PRECEDING if ctx.UNBOUNDED != null =>
+ UnboundedPreceding
+ case SqlBaseParser.PRECEDING =>
+ ValuePreceding(value)
+ case SqlBaseParser.CURRENT =>
+ CurrentRow
+ case SqlBaseParser.FOLLOWING if ctx.UNBOUNDED != null =>
+ UnboundedFollowing
+ case SqlBaseParser.FOLLOWING =>
+ ValueFollowing(value)
+ }
+ }
+
+ /**
+ * Create a [[CreateStruct]] expression.
+ */
+ override def visitRowConstructor(ctx: RowConstructorContext): Expression = withOrigin(ctx) {
+ CreateStruct(ctx.expression.asScala.map(expression))
+ }
+
+ /**
+ * Create a [[ScalarSubquery]] expression.
+ */
+ override def visitSubqueryExpression(
+ ctx: SubqueryExpressionContext): Expression = withOrigin(ctx) {
+ ScalarSubquery(plan(ctx.query))
+ }
+
+ /**
+ * Create a value based [[CaseWhen]] expression. This has the following SQL form:
+ * {{{
+ * CASE [expression]
+ * WHEN [value] THEN [expression]
+ * ...
+ * ELSE [expression]
+ * END
+ * }}}
+ */
+ override def visitSimpleCase(ctx: SimpleCaseContext): Expression = withOrigin(ctx) {
+ val e = expression(ctx.valueExpression)
+ val branches = ctx.whenClause.asScala.map { wCtx =>
+ (EqualTo(e, expression(wCtx.condition)), expression(wCtx.result))
+ }
+ CaseWhen(branches, Option(ctx.elseExpression).map(expression))
+ }
+
+ /**
+ * Create a condition based [[CaseWhen]] expression. This has the following SQL syntax:
+ * {{{
+ * CASE
+ * WHEN [predicate] THEN [expression]
+ * ...
+ * ELSE [expression]
+ * END
+ * }}}
+ *
+ * @param ctx the parse tree
+ * */
+ override def visitSearchedCase(ctx: SearchedCaseContext): Expression = withOrigin(ctx) {
+ val branches = ctx.whenClause.asScala.map { wCtx =>
+ (expression(wCtx.condition), expression(wCtx.result))
+ }
+ CaseWhen(branches, Option(ctx.elseExpression).map(expression))
+ }
+
+ /**
+ * Create a dereference expression. The return type depends on the type of the parent, this can
+ * either be a [[UnresolvedAttribute]] (if the parent is an [[UnresolvedAttribute]]), or an
+ * [[UnresolvedExtractValue]] if the parent is some expression.
+ */
+ override def visitDereference(ctx: DereferenceContext): Expression = withOrigin(ctx) {
+ val attr = ctx.fieldName.getText
+ expression(ctx.base) match {
+ case UnresolvedAttribute(nameParts) =>
+ UnresolvedAttribute(nameParts :+ attr)
+ case e =>
+ UnresolvedExtractValue(e, Literal(attr))
+ }
+ }
+
+ /**
+ * Create an [[UnresolvedAttribute]] expression.
+ */
+ override def visitColumnReference(ctx: ColumnReferenceContext): Expression = withOrigin(ctx) {
+ UnresolvedAttribute.quoted(ctx.getText)
+ }
+
+ /**
+   * Create an [[UnresolvedExtractValue]] expression; this is used for subscript access to an array.
+ */
+ override def visitSubscript(ctx: SubscriptContext): Expression = withOrigin(ctx) {
+ UnresolvedExtractValue(expression(ctx.value), expression(ctx.index))
+ }
+
+ /**
+   * Create an expression for an expression between parentheses. This is needed because the ANTLR
+ * visitor cannot automatically convert the nested context into an expression.
+ */
+ override def visitParenthesizedExpression(
+ ctx: ParenthesizedExpressionContext): Expression = withOrigin(ctx) {
+ expression(ctx.expression)
+ }
+
+ /**
+ * Create a [[SortOrder]] expression.
+ */
+ override def visitSortItem(ctx: SortItemContext): SortOrder = withOrigin(ctx) {
+ if (ctx.DESC != null) {
+ SortOrder(expression(ctx.expression), Descending)
+ } else {
+ SortOrder(expression(ctx.expression), Ascending)
+ }
+ }
+
+ /**
+ * Create a typed Literal expression. A typed literal has the following SQL syntax:
+ * {{{
+ * [TYPE] '[VALUE]'
+ * }}}
+ * Currently Date and Timestamp typed literals are supported.
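+   * For example (illustrative): DATE '2016-03-11' or TIMESTAMP '2016-03-11 20:54:00'.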
+ *
+   * TODO what is the added value of this over casting?
+ */
+ override def visitTypeConstructor(ctx: TypeConstructorContext): Literal = withOrigin(ctx) {
+ val value = string(ctx.STRING)
+ ctx.identifier.getText.toUpperCase match {
+ case "DATE" =>
+ Literal(Date.valueOf(value))
+ case "TIMESTAMP" =>
+ Literal(Timestamp.valueOf(value))
+ case other =>
+ throw new ParseException(s"Literals of type '$other' are currently not supported.", ctx)
+ }
+ }
+
+ /**
+ * Create a NULL literal expression.
+ */
+ override def visitNullLiteral(ctx: NullLiteralContext): Literal = withOrigin(ctx) {
+ Literal(null)
+ }
+
+ /**
+ * Create a Boolean literal expression.
+ */
+ override def visitBooleanLiteral(ctx: BooleanLiteralContext): Literal = withOrigin(ctx) {
+ if (ctx.getText.toBoolean) {
+ Literal.TrueLiteral
+ } else {
+ Literal.FalseLiteral
+ }
+ }
+
+ /**
+   * Create an integral literal expression. The code selects the narrowest integral type
+   * possible; either a BigDecimal, a Long or an Integer is returned.
+ */
+ override def visitIntegerLiteral(ctx: IntegerLiteralContext): Literal = withOrigin(ctx) {
+ BigDecimal(ctx.getText) match {
+ case v if v.isValidInt =>
+ Literal(v.intValue())
+ case v if v.isValidLong =>
+ Literal(v.longValue())
+ case v => Literal(v.underlying())
+ }
+ }
+
+ /**
+   * Create a double literal for a number denoted in scientific notation.
+ */
+ override def visitScientificDecimalLiteral(
+ ctx: ScientificDecimalLiteralContext): Literal = withOrigin(ctx) {
+ Literal(ctx.getText.toDouble)
+ }
+
+ /**
+ * Create a decimal literal for a regular decimal number.
+ */
+ override def visitDecimalLiteral(ctx: DecimalLiteralContext): Literal = withOrigin(ctx) {
+ Literal(BigDecimal(ctx.getText).underlying())
+ }
+
+ /** Create a numeric literal expression. */
+ private def numericLiteral(ctx: NumberContext)(f: String => Any): Literal = withOrigin(ctx) {
+ val raw = ctx.getText
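+    // The literal ends with a type suffix (e.g. 'Y' for TINYINT or 'L' for BIGINT); it is dropped
+    // below before the remaining digits are converted.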
+ try {
+ Literal(f(raw.substring(0, raw.length - 1)))
+ } catch {
+ case e: NumberFormatException =>
+ throw new ParseException(e.getMessage, ctx)
+ }
+ }
+
+ /**
+ * Create a Byte Literal expression.
+ */
+ override def visitTinyIntLiteral(ctx: TinyIntLiteralContext): Literal = numericLiteral(ctx) {
+ _.toByte
+ }
+
+ /**
+ * Create a Short Literal expression.
+ */
+ override def visitSmallIntLiteral(ctx: SmallIntLiteralContext): Literal = numericLiteral(ctx) {
+ _.toShort
+ }
+
+ /**
+ * Create a Long Literal expression.
+ */
+ override def visitBigIntLiteral(ctx: BigIntLiteralContext): Literal = numericLiteral(ctx) {
+ _.toLong
+ }
+
+ /**
+ * Create a Double Literal expression.
+ */
+ override def visitDoubleLiteral(ctx: DoubleLiteralContext): Literal = numericLiteral(ctx) {
+ _.toDouble
+ }
+
+ /**
+ * Create a String literal expression.
+ */
+ override def visitStringLiteral(ctx: StringLiteralContext): Literal = withOrigin(ctx) {
+ Literal(createString(ctx))
+ }
+
+ /**
+   * Create a String from a string literal context. This supports multiple consecutive string
+   * literals; these are concatenated. For example, the expression "'hello' 'world'" will be
+   * converted into "helloworld".
+ *
+ * Special characters can be escaped by using Hive/C-style escaping.
+ */
+ private def createString(ctx: StringLiteralContext): String = {
+ ctx.STRING().asScala.map(string).mkString
+ }
+
+ /**
+ * Create a [[CalendarInterval]] literal expression. An interval expression can contain multiple
+ * unit value pairs, for instance: interval 2 months 2 days.
+ */
+ override def visitInterval(ctx: IntervalContext): Literal = withOrigin(ctx) {
+ val intervals = ctx.intervalField.asScala.map(visitIntervalField)
+ assert(intervals.nonEmpty, "at least one time unit should be given for interval literal", ctx)
+ Literal(intervals.reduce(_.add(_)))
+ }
+
+ /**
+ * Create a [[CalendarInterval]] for a unit value pair. Two unit configuration types are
+ * supported:
+ * - Single unit.
+ * - From-To unit (only 'YEAR TO MONTH' and 'DAY TO SECOND' are supported).
+ */
+ override def visitIntervalField(ctx: IntervalFieldContext): CalendarInterval = withOrigin(ctx) {
+ import ctx._
+ val s = value.getText
+ val interval = (unit.getText.toLowerCase, Option(to).map(_.getText.toLowerCase)) match {
+ case (u, None) if u.endsWith("s") =>
+        // Handle plural forms, e.g.: yearS/monthS/weekS/dayS/hourS/minuteS/secondS/...
+ CalendarInterval.fromSingleUnitString(u.substring(0, u.length - 1), s)
+ case (u, None) =>
+ CalendarInterval.fromSingleUnitString(u, s)
+ case ("year", Some("month")) =>
+ CalendarInterval.fromYearMonthString(s)
+ case ("day", Some("second")) =>
+ CalendarInterval.fromDayTimeString(s)
+ case (from, Some(t)) =>
+ throw new ParseException(s"Intervals FROM $from TO $t are not supported.", ctx)
+ }
+ assert(interval != null, "No interval can be constructed", ctx)
+ interval
+ }
+
+ /* ********************************************************************************************
+ * DataType parsing
+ * ******************************************************************************************** */
+ /**
+ * Resolve/create a primitive type.
+ */
+ override def visitPrimitiveDataType(ctx: PrimitiveDataTypeContext): DataType = withOrigin(ctx) {
+ (ctx.identifier.getText.toLowerCase, ctx.INTEGER_VALUE().asScala.toList) match {
+ case ("boolean", Nil) => BooleanType
+ case ("tinyint" | "byte", Nil) => ByteType
+ case ("smallint" | "short", Nil) => ShortType
+ case ("int" | "integer", Nil) => IntegerType
+ case ("bigint" | "long", Nil) => LongType
+ case ("float", Nil) => FloatType
+ case ("double", Nil) => DoubleType
+ case ("date", Nil) => DateType
+ case ("timestamp", Nil) => TimestampType
+ case ("char" | "varchar" | "string", Nil) => StringType
+ case ("char" | "varchar", _ :: Nil) => StringType
+ case ("binary", Nil) => BinaryType
+ case ("decimal", Nil) => DecimalType.USER_DEFAULT
+ case ("decimal", precision :: Nil) => DecimalType(precision.getText.toInt, 0)
+ case ("decimal", precision :: scale :: Nil) =>
+ DecimalType(precision.getText.toInt, scale.getText.toInt)
+ case (dt, params) =>
+ throw new ParseException(
+ s"DataType $dt${params.mkString("(", ",", ")")} is not supported.", ctx)
+ }
+ }
+
+ /**
+ * Create a complex DataType. Arrays, Maps and Structures are supported.
+ */
+ override def visitComplexDataType(ctx: ComplexDataTypeContext): DataType = withOrigin(ctx) {
+ ctx.complex.getType match {
+ case SqlBaseParser.ARRAY =>
+ ArrayType(typedVisit(ctx.dataType(0)))
+ case SqlBaseParser.MAP =>
+ MapType(typedVisit(ctx.dataType(0)), typedVisit(ctx.dataType(1)))
+ case SqlBaseParser.STRUCT =>
+ createStructType(ctx.colTypeList())
+ }
+ }
+
+ /**
+ * Create a [[StructType]] from a sequence of [[StructField]]s.
+ */
+ protected def createStructType(ctx: ColTypeListContext): StructType = {
+ StructType(Option(ctx).toSeq.flatMap(visitColTypeList))
+ }
+
+ /**
+ * Create a [[StructType]] from a number of column definitions.
+ */
+ override def visitColTypeList(ctx: ColTypeListContext): Seq[StructField] = withOrigin(ctx) {
+ ctx.colType().asScala.map(visitColType)
+ }
+
+ /**
+ * Create a [[StructField]] from a column definition.
+ */
+ override def visitColType(ctx: ColTypeContext): StructField = withOrigin(ctx) {
+ import ctx._
+
+ // Add the comment to the metadata.
+ val builder = new MetadataBuilder
+ if (STRING != null) {
+ builder.putString("comment", string(STRING))
+ }
+
+ StructField(identifier.getText, typedVisit(dataType), nullable = true, builder.build())
+ }
+}
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/ng/ParseDriver.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/ng/ParseDriver.scala
new file mode 100644
index 0000000000..c9a286374c
--- /dev/null
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/ng/ParseDriver.scala
@@ -0,0 +1,240 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.spark.sql.catalyst.parser.ng
+
+import org.antlr.v4.runtime._
+import org.antlr.v4.runtime.atn.PredictionMode
+import org.antlr.v4.runtime.misc.ParseCancellationException
+
+import org.apache.spark.internal.Logging
+import org.apache.spark.sql.AnalysisException
+import org.apache.spark.sql.catalyst.TableIdentifier
+import org.apache.spark.sql.catalyst.expressions.Expression
+import org.apache.spark.sql.catalyst.parser.ParserInterface
+import org.apache.spark.sql.catalyst.parser.ng.SqlBaseParser._
+import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
+import org.apache.spark.sql.catalyst.trees.Origin
+import org.apache.spark.sql.types.DataType
+
+/**
+ * Base SQL parsing infrastructure.
+ */
+abstract class AbstractSqlParser extends ParserInterface with Logging {
+
+ /** Creates/Resolves DataType for a given SQL string. */
+ def parseDataType(sqlText: String): DataType = parse(sqlText) { parser =>
+ // TODO add this to the parser interface.
+ astBuilder.visitSingleDataType(parser.singleDataType())
+ }
+
+ /** Creates Expression for a given SQL string. */
+ override def parseExpression(sqlText: String): Expression = parse(sqlText) { parser =>
+ astBuilder.visitSingleExpression(parser.singleExpression())
+ }
+
+ /** Creates TableIdentifier for a given SQL string. */
+ override def parseTableIdentifier(sqlText: String): TableIdentifier = parse(sqlText) { parser =>
+ astBuilder.visitSingleTableIdentifier(parser.singleTableIdentifier())
+ }
+
+ /** Creates LogicalPlan for a given SQL string. */
+ override def parsePlan(sqlText: String): LogicalPlan = parse(sqlText) { parser =>
+ astBuilder.visitSingleStatement(parser.singleStatement()) match {
+ case plan: LogicalPlan => plan
+ case _ => nativeCommand(sqlText)
+ }
+ }
+
+  /** Get the builder (visitor) which converts a ParseTree into an AST. */
+ protected def astBuilder: AstBuilder
+
+ /** Create a native command, or fail when this is not supported. */
+ protected def nativeCommand(sqlText: String): LogicalPlan = {
+ val position = Origin(None, None)
+ throw new ParseException(Option(sqlText), "Unsupported SQL statement", position, position)
+ }
+
+ protected def parse[T](command: String)(toResult: SqlBaseParser => T): T = {
+ logInfo(s"Parsing command: $command")
+
+ val lexer = new SqlBaseLexer(new ANTLRNoCaseStringStream(command))
+ lexer.removeErrorListeners()
+ lexer.addErrorListener(ParseErrorListener)
+
+ val tokenStream = new CommonTokenStream(lexer)
+ val parser = new SqlBaseParser(tokenStream)
+ parser.addParseListener(PostProcessor)
+ parser.removeErrorListeners()
+ parser.addErrorListener(ParseErrorListener)
+
+ try {
+ try {
+ // first, try parsing with potentially faster SLL mode
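+        // (SLL resolves the vast majority of statements and is much faster; the full LL mode
+        // below is only needed for the rare ambiguities SLL cannot handle.)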
+ parser.getInterpreter.setPredictionMode(PredictionMode.SLL)
+ toResult(parser)
+ }
+ catch {
+ case e: ParseCancellationException =>
+ // if we fail, parse with LL mode
+ tokenStream.reset() // rewind input stream
+ parser.reset()
+
+ // Try Again.
+ parser.getInterpreter.setPredictionMode(PredictionMode.LL)
+ toResult(parser)
+ }
+ }
+ catch {
+ case e: ParseException if e.command.isDefined =>
+ throw e
+ case e: ParseException =>
+ throw e.withCommand(command)
+ case e: AnalysisException =>
+ val position = Origin(e.line, e.startPosition)
+ throw new ParseException(Option(command), e.message, position, position)
+ }
+ }
+}
+
+/**
+ * Concrete SQL parser for Catalyst-only SQL statements.
+ */
+object CatalystSqlParser extends AbstractSqlParser {
+ val astBuilder = new AstBuilder
+}
+
+/**
+ * This string stream provides the lexer with upper case characters only. This greatly simplifies
+ * lexing the stream, while still preserving the original command.
+ *
+ * This is based on Hive's org.apache.hadoop.hive.ql.parse.ParseDriver.ANTLRNoCaseStringStream
+ *
+ * The comment below (taken from the original class) describes the rationale for doing this:
+ *
+ * This class provides an implementation of a case-insensitive token checker for the lexical
+ * analysis part of ANTLR. By converting the token stream into upper case at the time when lexical
+ * rules are checked, this class ensures that the lexical rules need to just match the token with
+ * upper case letters as opposed to a combination of upper and lower case characters. This is
+ * purely used for matching lexical rules. The actual token text is stored in the same way as the
+ * user input without actually converting it into upper case. The token values are generated by
+ * the consume() function of the super class ANTLRStringStream. The LA() function is the lookahead
+ * function and is purely used for matching lexical rules. This also means that the grammar will
+ * only accept capitalized tokens in case it is run from other tools like antlrworks which do not
+ * have the ANTLRNoCaseStringStream implementation.
+ */
+
+private[parser] class ANTLRNoCaseStringStream(input: String) extends ANTLRInputStream(input) {
+ override def LA(i: Int): Int = {
+ val la = super.LA(i)
+ if (la == 0 || la == IntStream.EOF) la
+ else Character.toUpperCase(la)
+ }
+}
+
+/**
+ * The ParseErrorListener converts parse errors into AnalysisExceptions.
+ */
+case object ParseErrorListener extends BaseErrorListener {
+ override def syntaxError(
+ recognizer: Recognizer[_, _],
+ offendingSymbol: scala.Any,
+ line: Int,
+ charPositionInLine: Int,
+ msg: String,
+ e: RecognitionException): Unit = {
+ val position = Origin(Some(line), Some(charPositionInLine))
+ throw new ParseException(None, msg, position, position)
+ }
+}
+
+/**
+ * A [[ParseException]] is an [[AnalysisException]] that is thrown during the parse process. It
+ * contains fields and an extended error message that make reporting and diagnosing errors easier.
+ */
+class ParseException(
+ val command: Option[String],
+ message: String,
+ val start: Origin,
+ val stop: Origin) extends AnalysisException(message, start.line, start.startPosition) {
+
+ def this(message: String, ctx: ParserRuleContext) = {
+ this(Option(ParserUtils.command(ctx)),
+ message,
+ ParserUtils.position(ctx.getStart),
+ ParserUtils.position(ctx.getStop))
+ }
+
+ override def getMessage: String = {
+ val builder = new StringBuilder
+ builder ++= "\n" ++= message
+ start match {
+ case Origin(Some(l), Some(p)) =>
+ builder ++= s"(line $l, pos $p)\n"
+ command.foreach { cmd =>
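+          // Echo the SQL text and mark the offending position with a row of dashes followed by
+          // "^^^", as asserted in ErrorParserSuite.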
+ val (above, below) = cmd.split("\n").splitAt(l)
+ builder ++= "\n== SQL ==\n"
+ above.foreach(builder ++= _ += '\n')
+ builder ++= (0 until p).map(_ => "-").mkString("") ++= "^^^\n"
+ below.foreach(builder ++= _ += '\n')
+ }
+ case _ =>
+ command.foreach { cmd =>
+ builder ++= "\n== SQL ==\n" ++= cmd
+ }
+ }
+ builder.toString
+ }
+
+ def withCommand(cmd: String): ParseException = {
+ new ParseException(Option(cmd), message, start, stop)
+ }
+}
+
+/**
+ * The post-processor validates & cleans up the parse tree during the parse process.
+ */
+case object PostProcessor extends SqlBaseBaseListener {
+
+ /** Remove the back ticks from an Identifier. */
+ override def exitQuotedIdentifier(ctx: QuotedIdentifierContext): Unit = {
+ replaceTokenByIdentifier(ctx, 1) { token =>
+ // Remove the double back ticks in the string.
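+      // For example, the quoted identifier `a``b` becomes the identifier a`b: the outer ticks are
+      // dropped by replaceTokenByIdentifier and the doubled tick is collapsed here.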
+ token.setText(token.getText.replace("``", "`"))
+ token
+ }
+ }
+
+ /** Treat non-reserved keywords as Identifiers. */
+ override def exitNonReserved(ctx: NonReservedContext): Unit = {
+ replaceTokenByIdentifier(ctx, 0)(identity)
+ }
+
+ private def replaceTokenByIdentifier(
+ ctx: ParserRuleContext,
+ stripMargins: Int)(
+ f: CommonToken => CommonToken = identity): Unit = {
+ val parent = ctx.getParent
+ parent.removeLastChild()
+ val token = ctx.getChild(0).getPayload.asInstanceOf[Token]
+ parent.addChild(f(new CommonToken(
+ new org.antlr.v4.runtime.misc.Pair(token.getTokenSource, token.getInputStream),
+ SqlBaseParser.IDENTIFIER,
+ token.getChannel,
+ token.getStartIndex + stripMargins,
+ token.getStopIndex - stripMargins)))
+ }
+}
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/ng/ParserUtils.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/ng/ParserUtils.scala
new file mode 100644
index 0000000000..1fbfa763b4
--- /dev/null
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/ng/ParserUtils.scala
@@ -0,0 +1,118 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.spark.sql.catalyst.parser.ng
+
+import org.antlr.v4.runtime.{CharStream, ParserRuleContext, Token}
+import org.antlr.v4.runtime.misc.Interval
+import org.antlr.v4.runtime.tree.TerminalNode
+
+import org.apache.spark.sql.catalyst.parser.ParseUtils.unescapeSQLString
+import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
+import org.apache.spark.sql.catalyst.trees.{CurrentOrigin, Origin}
+
+/**
+ * A collection of utility methods for use during the parsing process.
+ */
+object ParserUtils {
+ /** Get the command which created the token. */
+ def command(ctx: ParserRuleContext): String = {
+ command(ctx.getStart.getInputStream)
+ }
+
+ /** Get the command which created the token. */
+ def command(stream: CharStream): String = {
+ stream.getText(Interval.of(0, stream.size()))
+ }
+
+ /** Get the code that creates the given node. */
+ def source(ctx: ParserRuleContext): String = {
+ val stream = ctx.getStart.getInputStream
+ stream.getText(Interval.of(ctx.getStart.getStartIndex, ctx.getStop.getStopIndex))
+ }
+
+ /** Get all the text which comes after the given rule. */
+ def remainder(ctx: ParserRuleContext): String = remainder(ctx.getStop)
+
+ /** Get all the text which comes after the given token. */
+ def remainder(token: Token): String = {
+ val stream = token.getInputStream
+ val interval = Interval.of(token.getStopIndex + 1, stream.size())
+ stream.getText(interval)
+ }
+
+ /** Convert a string token into a string. */
+ def string(token: Token): String = unescapeSQLString(token.getText)
+
+ /** Convert a string node into a string. */
+ def string(node: TerminalNode): String = unescapeSQLString(node.getText)
+
+ /** Get the origin (line and position) of the token. */
+ def position(token: Token): Origin = {
+ Origin(Option(token.getLine), Option(token.getCharPositionInLine))
+ }
+
+  /** Assert that a condition holds. If it doesn't, throw a parse exception. */
+ def assert(f: => Boolean, message: String, ctx: ParserRuleContext): Unit = {
+ if (!f) {
+ throw new ParseException(message, ctx)
+ }
+ }
+
+ /**
+ * Register the origin of the context. Any TreeNode created in the closure will be assigned the
+ * registered origin. This method restores the previously set origin after completion of the
+ * closure.
+ */
+ def withOrigin[T](ctx: ParserRuleContext)(f: => T): T = {
+ val current = CurrentOrigin.get
+ CurrentOrigin.set(position(ctx.getStart))
+ try {
+ f
+ } finally {
+ CurrentOrigin.set(current)
+ }
+ }
+
+ /** Some syntactic sugar which makes it easier to work with optional clauses for LogicalPlans. */
+ implicit class EnhancedLogicalPlan(val plan: LogicalPlan) extends AnyVal {
+ /**
+ * Create a plan using the block of code when the given context exists. Otherwise return the
+ * original plan.
+ */
+ def optional(ctx: AnyRef)(f: => LogicalPlan): LogicalPlan = {
+ if (ctx != null) {
+ f
+ } else {
+ plan
+ }
+ }
+
+ /**
+     * Map a [[LogicalPlan]] to another [[LogicalPlan]] using the passed function, if the passed
+     * context exists. The original plan is returned when the context does not exist.
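+     * For example, `plan.optionalMap(ctx.sample)(withSample)` only applies the sampling
+     * transformation when a TABLESAMPLE clause is present.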
+ */
+ def optionalMap[C <: ParserRuleContext](
+ ctx: C)(
+ f: (C, LogicalPlan) => LogicalPlan): LogicalPlan = {
+ if (ctx != null) {
+ f(ctx, plan)
+ } else {
+ plan
+ }
+ }
+ }
+}
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/CatalystQlSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/CatalystQlSuite.scala
index c068e895b6..223485e292 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/CatalystQlSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/CatalystQlSuite.scala
@@ -21,15 +21,20 @@ import org.apache.spark.sql.AnalysisException
import org.apache.spark.sql.catalyst.TableIdentifier
import org.apache.spark.sql.catalyst.analysis._
import org.apache.spark.sql.catalyst.expressions._
+import org.apache.spark.sql.catalyst.parser.ng.CatalystSqlParser
import org.apache.spark.sql.catalyst.plans.PlanTest
import org.apache.spark.sql.catalyst.plans.logical._
import org.apache.spark.unsafe.types.CalendarInterval
class CatalystQlSuite extends PlanTest {
val parser = new CatalystQl()
+ import org.apache.spark.sql.catalyst.dsl.expressions._
+ import org.apache.spark.sql.catalyst.dsl.plans._
+
+ val star = UnresolvedAlias(UnresolvedStar(None))
test("test case insensitive") {
- val result = Project(UnresolvedAlias(Literal(1)):: Nil, OneRowRelation)
+ val result = OneRowRelation.select(1)
assert(result === parser.parsePlan("seLect 1"))
assert(result === parser.parsePlan("select 1"))
assert(result === parser.parsePlan("SELECT 1"))
@@ -37,52 +42,31 @@ class CatalystQlSuite extends PlanTest {
test("test NOT operator with comparison operations") {
val parsed = parser.parsePlan("SELECT NOT TRUE > TRUE")
- val expected = Project(
- UnresolvedAlias(
- Not(
- GreaterThan(Literal(true), Literal(true)))
- ) :: Nil,
- OneRowRelation)
+ val expected = OneRowRelation.select(Not(GreaterThan(true, true)))
comparePlans(parsed, expected)
}
test("test Union Distinct operator") {
- val parsed1 = parser.parsePlan("SELECT * FROM t0 UNION SELECT * FROM t1")
- val parsed2 = parser.parsePlan("SELECT * FROM t0 UNION DISTINCT SELECT * FROM t1")
- val expected =
- Project(UnresolvedAlias(UnresolvedStar(None)) :: Nil,
- SubqueryAlias("u_1",
- Distinct(
- Union(
- Project(UnresolvedAlias(UnresolvedStar(None)) :: Nil,
- UnresolvedRelation(TableIdentifier("t0"), None)),
- Project(UnresolvedAlias(UnresolvedStar(None)) :: Nil,
- UnresolvedRelation(TableIdentifier("t1"), None))))))
+ val parsed1 = parser.parsePlan(
+ "SELECT * FROM t0 UNION SELECT * FROM t1")
+ val parsed2 = parser.parsePlan(
+ "SELECT * FROM t0 UNION DISTINCT SELECT * FROM t1")
+ val expected = Distinct(Union(table("t0").select(star), table("t1").select(star)))
+ .as("u_1").select(star)
comparePlans(parsed1, expected)
comparePlans(parsed2, expected)
}
test("test Union All operator") {
val parsed = parser.parsePlan("SELECT * FROM t0 UNION ALL SELECT * FROM t1")
- val expected =
- Project(UnresolvedAlias(UnresolvedStar(None)) :: Nil,
- SubqueryAlias("u_1",
- Union(
- Project(UnresolvedAlias(UnresolvedStar(None)) :: Nil,
- UnresolvedRelation(TableIdentifier("t0"), None)),
- Project(UnresolvedAlias(UnresolvedStar(None)) :: Nil,
- UnresolvedRelation(TableIdentifier("t1"), None)))))
+ val expected = Union(table("t0").select(star), table("t1").select(star)).as("u_1").select(star)
comparePlans(parsed, expected)
}
test("support hive interval literal") {
def checkInterval(sql: String, result: CalendarInterval): Unit = {
val parsed = parser.parsePlan(sql)
- val expected = Project(
- UnresolvedAlias(
- Literal(result)
- ) :: Nil,
- OneRowRelation)
+ val expected = OneRowRelation.select(Literal(result))
comparePlans(parsed, expected)
}
@@ -129,11 +113,7 @@ class CatalystQlSuite extends PlanTest {
test("support scientific notation") {
def assertRight(input: String, output: Double): Unit = {
val parsed = parser.parsePlan("SELECT " + input)
- val expected = Project(
- UnresolvedAlias(
- Literal(output)
- ) :: Nil,
- OneRowRelation)
+ val expected = OneRowRelation.select(Literal(output))
comparePlans(parsed, expected)
}
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/DataTypeParserSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/DataTypeParserSuite.scala
index 7d3608033b..d9bd33c50a 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/DataTypeParserSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/DataTypeParserSuite.scala
@@ -18,19 +18,24 @@
package org.apache.spark.sql.catalyst.parser
import org.apache.spark.SparkFunSuite
+import org.apache.spark.sql.catalyst.parser.ng.{CatalystSqlParser, ParseException}
import org.apache.spark.sql.types._
-class DataTypeParserSuite extends SparkFunSuite {
+abstract class AbstractDataTypeParserSuite extends SparkFunSuite {
+
+ def parse(sql: String): DataType
def checkDataType(dataTypeString: String, expectedDataType: DataType): Unit = {
test(s"parse ${dataTypeString.replace("\n", "")}") {
- assert(DataTypeParser.parse(dataTypeString) === expectedDataType)
+ assert(parse(dataTypeString) === expectedDataType)
}
}
+ def intercept(sql: String)
+
def unsupported(dataTypeString: String): Unit = {
test(s"$dataTypeString is not supported") {
- intercept[DataTypeException](DataTypeParser.parse(dataTypeString))
+ intercept(dataTypeString)
}
}
@@ -97,13 +102,6 @@ class DataTypeParserSuite extends SparkFunSuite {
StructField("arrAy", ArrayType(DoubleType, true), true) ::
StructField("anotherArray", ArrayType(StringType, true), true) :: Nil)
)
- // A column name can be a reserved word in our DDL parser and SqlParser.
- checkDataType(
- "Struct<TABLE: string, CASE:boolean>",
- StructType(
- StructField("TABLE", StringType, true) ::
- StructField("CASE", BooleanType, true) :: Nil)
- )
// Use backticks to quote column names having special characters.
checkDataType(
"struct<`x+y`:int, `!@#$%^&*()`:string, `1_2.345<>:\"`:varchar(20)>",
@@ -118,6 +116,43 @@ class DataTypeParserSuite extends SparkFunSuite {
unsupported("it is not a data type")
unsupported("struct<x+y: int, 1.1:timestamp>")
unsupported("struct<x: int")
+}
+
+class DataTypeParserSuite extends AbstractDataTypeParserSuite {
+ override def intercept(sql: String): Unit =
+ intercept[DataTypeException](DataTypeParser.parse(sql))
+
+ override def parse(sql: String): DataType =
+ DataTypeParser.parse(sql)
+
+ // A column name can be a reserved word in our DDL parser and SqlParser.
+ checkDataType(
+ "Struct<TABLE: string, CASE:boolean>",
+ StructType(
+ StructField("TABLE", StringType, true) ::
+ StructField("CASE", BooleanType, true) :: Nil)
+ )
+
unsupported("struct<x int, y string>")
+
unsupported("struct<`x``y` int>")
}
+
+class CatalystQlDataTypeParserSuite extends AbstractDataTypeParserSuite {
+ override def intercept(sql: String): Unit =
+ intercept[ParseException](CatalystSqlParser.parseDataType(sql))
+
+ override def parse(sql: String): DataType =
+ CatalystSqlParser.parseDataType(sql)
+
+ // A column name can be a reserved word in our DDL parser and SqlParser.
+ unsupported("Struct<TABLE: string, CASE:boolean>")
+
+ checkDataType(
+ "struct<x int, y string>",
+ (new StructType).add("x", IntegerType).add("y", StringType))
+
+ checkDataType(
+ "struct<`x``y` int>",
+ (new StructType).add("x`y", IntegerType))
+}
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/ng/ErrorParserSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/ng/ErrorParserSuite.scala
new file mode 100644
index 0000000000..1963fc368f
--- /dev/null
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/ng/ErrorParserSuite.scala
@@ -0,0 +1,67 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.spark.sql.catalyst.parser.ng
+
+import org.apache.spark.SparkFunSuite
+
+/**
+ * Test various parser errors.
+ */
+class ErrorParserSuite extends SparkFunSuite {
+ def intercept(sql: String, line: Int, startPosition: Int, messages: String*): Unit = {
+ val e = intercept[ParseException](CatalystSqlParser.parsePlan(sql))
+
+ // Check position.
+ assert(e.line.isDefined)
+ assert(e.line.get === line)
+ assert(e.startPosition.isDefined)
+ assert(e.startPosition.get === startPosition)
+
+ // Check messages.
+ val error = e.getMessage
+ messages.foreach { message =>
+ assert(error.contains(message))
+ }
+ }
+
+ test("no viable input") {
+ intercept("select from tbl", 1, 7, "no viable alternative at input", "-------^^^")
+ intercept("select\nfrom tbl", 2, 0, "no viable alternative at input", "^^^")
+ intercept("select ((r + 1) ", 1, 16, "no viable alternative at input", "----------------^^^")
+ }
+
+ test("extraneous input") {
+ intercept("select 1 1", 1, 9, "extraneous input '1' expecting", "---------^^^")
+ intercept("select *\nfrom r as q t", 2, 12, "extraneous input", "------------^^^")
+ }
+
+ test("mismatched input") {
+ intercept("select * from r order by q from t", 1, 27,
+ "mismatched input",
+ "---------------------------^^^")
+ intercept("select *\nfrom r\norder by q\nfrom t", 4, 0, "mismatched input", "^^^")
+ }
+
+ test("semantic errors") {
+ intercept("select *\nfrom r\norder by q\ncluster by q", 3, 0,
+ "Combination of ORDER BY/SORT BY/DISTRIBUTE BY/CLUSTER BY is not supported",
+ "^^^")
+ intercept("select * from r where a in (select * from t)", 1, 24,
+ "IN with a Sub-query is currently not supported",
+ "------------------------^^^")
+ }
+}
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/ng/ExpressionParserSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/ng/ExpressionParserSuite.scala
new file mode 100644
index 0000000000..32311a5a66
--- /dev/null
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/ng/ExpressionParserSuite.scala
@@ -0,0 +1,497 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.spark.sql.catalyst.parser.ng
+
+import java.sql.{Date, Timestamp}
+
+import org.apache.spark.sql.catalyst.analysis.{UnresolvedAttribute, _}
+import org.apache.spark.sql.catalyst.expressions._
+import org.apache.spark.sql.catalyst.plans.PlanTest
+import org.apache.spark.sql.types._
+import org.apache.spark.unsafe.types.CalendarInterval
+
+/**
+ * Test basic expression parsing. If a type of expression is supported it should be tested here.
+ *
+ * Please note that some of the tested expressions don't have to be sound expressions; only
+ * their structure needs to be valid. Unsound expressions should be caught by the Analyzer or
+ * CheckAnalysis classes.
+ */
+class ExpressionParserSuite extends PlanTest {
+ import CatalystSqlParser._
+ import org.apache.spark.sql.catalyst.dsl.expressions._
+ import org.apache.spark.sql.catalyst.dsl.plans._
+
+ def assertEqual(sqlCommand: String, e: Expression): Unit = {
+ compareExpressions(parseExpression(sqlCommand), e)
+ }
+
+ def intercept(sqlCommand: String, messages: String*): Unit = {
+ val e = intercept[ParseException](parseExpression(sqlCommand))
+ messages.foreach { message =>
+ assert(e.message.contains(message))
+ }
+ }
+
+ test("star expressions") {
+ // Global Star
+ assertEqual("*", UnresolvedStar(None))
+
+ // Targeted Star
+ assertEqual("a.b.*", UnresolvedStar(Option(Seq("a", "b"))))
+ }
+
+ // NamedExpression (Alias/Multialias)
+ test("named expressions") {
+ // No Alias
+ val r0 = 'a
+ assertEqual("a", r0)
+
+ // Single Alias.
+ val r1 = 'a as "b"
+ assertEqual("a as b", r1)
+ assertEqual("a b", r1)
+
+ // Multi-Alias
+ assertEqual("a as (b, c)", MultiAlias('a, Seq("b", "c")))
+ assertEqual("a() (b, c)", MultiAlias('a.function(), Seq("b", "c")))
+
+    // Numeric literals without a space between the literal qualifier and the alias should not be
+    // interpreted as such. An unresolved reference should be returned instead.
+ // TODO add the JIRA-ticket number.
+ assertEqual("1SL", Symbol("1SL"))
+
+ // Aliased star is allowed.
+ assertEqual("a.* b", UnresolvedStar(Option(Seq("a"))) as 'b)
+ }
+
+ test("binary logical expressions") {
+ // And
+ assertEqual("a and b", 'a && 'b)
+
+ // Or
+ assertEqual("a or b", 'a || 'b)
+
+ // Combination And/Or check precedence
+ assertEqual("a and b or c and d", ('a && 'b) || ('c && 'd))
+ assertEqual("a or b or c and d", 'a || 'b || ('c && 'd))
+
+ // Multiple AND/OR get converted into a balanced tree
+ assertEqual("a or b or c or d or e or f", (('a || 'b) || 'c) || (('d || 'e) || 'f))
+ assertEqual("a and b and c and d and e and f", (('a && 'b) && 'c) && (('d && 'e) && 'f))
+ }
+
+ test("long binary logical expressions") {
+ def testVeryBinaryExpression(op: String, clazz: Class[_]): Unit = {
+ val sql = (1 to 1000).map(x => s"$x == $x").mkString(op)
+ val e = parseExpression(sql)
+ assert(e.collect { case _: EqualTo => true }.size === 1000)
+ assert(e.collect { case x if clazz.isInstance(x) => true }.size === 999)
+ }
+ testVeryBinaryExpression(" AND ", classOf[And])
+ testVeryBinaryExpression(" OR ", classOf[Or])
+ }
+
+ test("not expressions") {
+ assertEqual("not a", !'a)
+ assertEqual("!a", !'a)
+ assertEqual("not true > true", Not(GreaterThan(true, true)))
+ }
+
+ test("exists expression") {
+ intercept("exists (select 1 from b where b.x = a.x)", "EXISTS clauses are not supported")
+ }
+
+ test("comparison expressions") {
+ assertEqual("a = b", 'a === 'b)
+ assertEqual("a == b", 'a === 'b)
+ assertEqual("a <=> b", 'a <=> 'b)
+ assertEqual("a <> b", 'a =!= 'b)
+ assertEqual("a != b", 'a =!= 'b)
+ assertEqual("a < b", 'a < 'b)
+ assertEqual("a <= b", 'a <= 'b)
+ assertEqual("a > b", 'a > 'b)
+ assertEqual("a >= b", 'a >= 'b)
+ }
+
+ test("between expressions") {
+ assertEqual("a between b and c", 'a >= 'b && 'a <= 'c)
+ assertEqual("a not between b and c", !('a >= 'b && 'a <= 'c))
+ }
+
+ test("in expressions") {
+ assertEqual("a in (b, c, d)", 'a in ('b, 'c, 'd))
+ assertEqual("a not in (b, c, d)", !('a in ('b, 'c, 'd)))
+ }
+
+ test("in sub-query") {
+ intercept("a in (select b from c)", "IN with a Sub-query is currently not supported")
+ }
+
+ test("like expressions") {
+ assertEqual("a like 'pattern%'", 'a like "pattern%")
+ assertEqual("a not like 'pattern%'", !('a like "pattern%"))
+ assertEqual("a rlike 'pattern%'", 'a rlike "pattern%")
+ assertEqual("a not rlike 'pattern%'", !('a rlike "pattern%"))
+ assertEqual("a regexp 'pattern%'", 'a rlike "pattern%")
+ assertEqual("a not regexp 'pattern%'", !('a rlike "pattern%"))
+ }
+
+ test("is null expressions") {
+ assertEqual("a is null", 'a.isNull)
+ assertEqual("a is not null", 'a.isNotNull)
+ assertEqual("a = b is null", ('a === 'b).isNull)
+ assertEqual("a = b is not null", ('a === 'b).isNotNull)
+ }
+
+ test("binary arithmetic expressions") {
+ // Simple operations
+ assertEqual("a * b", 'a * 'b)
+ assertEqual("a / b", 'a / 'b)
+ assertEqual("a DIV b", ('a / 'b).cast(LongType))
+ assertEqual("a % b", 'a % 'b)
+ assertEqual("a + b", 'a + 'b)
+ assertEqual("a - b", 'a - 'b)
+ assertEqual("a & b", 'a & 'b)
+ assertEqual("a ^ b", 'a ^ 'b)
+ assertEqual("a | b", 'a | 'b)
+
+ // Check precedences
+ assertEqual(
+ "a * t | b ^ c & d - e + f % g DIV h / i * k",
+ 'a * 't | ('b ^ ('c & ('d - 'e + (('f % 'g / 'h).cast(LongType) / 'i * 'k)))))
+ }
+
+ test("unary arithmetic expressions") {
+ assertEqual("+a", 'a)
+ assertEqual("-a", -'a)
+ assertEqual("~a", ~'a)
+ assertEqual("-+~~a", -(~(~'a)))
+ }
+
+ test("cast expressions") {
+ // Note that DataType parsing is tested elsewhere.
+ assertEqual("cast(a as int)", 'a.cast(IntegerType))
+ assertEqual("cast(a as timestamp)", 'a.cast(TimestampType))
+ assertEqual("cast(a as array<int>)", 'a.cast(ArrayType(IntegerType)))
+ assertEqual("cast(cast(a as int) as long)", 'a.cast(IntegerType).cast(LongType))
+ }
+
+ test("function expressions") {
+ assertEqual("foo()", 'foo.function())
+ assertEqual("foo.bar()", Symbol("foo.bar").function())
+ assertEqual("foo(*)", 'foo.function(star()))
+ assertEqual("count(*)", 'count.function(1))
+ assertEqual("foo(a, b)", 'foo.function('a, 'b))
+ assertEqual("foo(all a, b)", 'foo.function('a, 'b))
+ assertEqual("foo(distinct a, b)", 'foo.distinctFunction('a, 'b))
+ assertEqual("grouping(distinct a, b)", 'grouping.distinctFunction('a, 'b))
+ assertEqual("`select`(all a, b)", 'select.function('a, 'b))
+ }
+
+ test("window function expressions") {
+ val func = 'foo.function(star())
+ def windowed(
+ partitioning: Seq[Expression] = Seq.empty,
+ ordering: Seq[SortOrder] = Seq.empty,
+ frame: WindowFrame = UnspecifiedFrame): Expression = {
+ WindowExpression(func, WindowSpecDefinition(partitioning, ordering, frame))
+ }
+
+ // Basic window testing.
+ assertEqual("foo(*) over w1", UnresolvedWindowExpression(func, WindowSpecReference("w1")))
+ assertEqual("foo(*) over ()", windowed())
+ assertEqual("foo(*) over (partition by a, b)", windowed(Seq('a, 'b)))
+ assertEqual("foo(*) over (distribute by a, b)", windowed(Seq('a, 'b)))
+ assertEqual("foo(*) over (cluster by a, b)", windowed(Seq('a, 'b)))
+ assertEqual("foo(*) over (order by a desc, b asc)", windowed(Seq.empty, Seq('a.desc, 'b.asc )))
+ assertEqual("foo(*) over (sort by a desc, b asc)", windowed(Seq.empty, Seq('a.desc, 'b.asc )))
+ assertEqual("foo(*) over (partition by a, b order by c)", windowed(Seq('a, 'b), Seq('c.asc)))
+ assertEqual("foo(*) over (distribute by a, b sort by c)", windowed(Seq('a, 'b), Seq('c.asc)))
+
+ // Test use of expressions in window functions.
+ assertEqual(
+ "sum(product + 1) over (partition by ((product) + (1)) order by 2)",
+ WindowExpression('sum.function('product + 1),
+ WindowSpecDefinition(Seq('product + 1), Seq(Literal(2).asc), UnspecifiedFrame)))
+ assertEqual(
+ "sum(product + 1) over (partition by ((product / 2) + 1) order by 2)",
+ WindowExpression('sum.function('product + 1),
+ WindowSpecDefinition(Seq('product / 2 + 1), Seq(Literal(2).asc), UnspecifiedFrame)))
+
+ // Range/Row
+ val frameTypes = Seq(("rows", RowFrame), ("range", RangeFrame))
+ val boundaries = Seq(
+ ("10 preceding", ValuePreceding(10), CurrentRow),
+ ("3 + 1 following", ValueFollowing(4), CurrentRow), // Will fail during analysis
+ ("unbounded preceding", UnboundedPreceding, CurrentRow),
+ ("unbounded following", UnboundedFollowing, CurrentRow), // Will fail during analysis
+ ("between unbounded preceding and current row", UnboundedPreceding, CurrentRow),
+ ("between unbounded preceding and unbounded following",
+ UnboundedPreceding, UnboundedFollowing),
+ ("between 10 preceding and current row", ValuePreceding(10), CurrentRow),
+ ("between current row and 5 following", CurrentRow, ValueFollowing(5)),
+ ("between 10 preceding and 5 following", ValuePreceding(10), ValueFollowing(5))
+ )
+ frameTypes.foreach {
+ case (frameTypeSql, frameType) =>
+ boundaries.foreach {
+ case (boundarySql, begin, end) =>
+ val query = s"foo(*) over (partition by a order by b $frameTypeSql $boundarySql)"
+ val expr = windowed(Seq('a), Seq('b.asc), SpecifiedWindowFrame(frameType, begin, end))
+ assertEqual(query, expr)
+ }
+ }
+
+ // We cannot use non-integer constants.
+ intercept("foo(*) over (partition by a order by b rows 10.0 preceding)",
+ "Frame bound value must be a constant integer.")
+
+ // We cannot use an arbitrary expression.
+ intercept("foo(*) over (partition by a order by b rows exp(b) preceding)",
+ "Frame bound value must be a constant integer.")
+ }
+
+ test("row constructor") {
+ // Note that '(a)' will be interpreted as a nested expression.
+ assertEqual("(a, b)", CreateStruct(Seq('a, 'b)))
+ assertEqual("(a, b, c)", CreateStruct(Seq('a, 'b, 'c)))
+ }
+
+ test("scalar sub-query") {
+ assertEqual(
+ "(select max(val) from tbl) > current",
+ ScalarSubquery(table("tbl").select('max.function('val))) > 'current)
+ assertEqual(
+ "a = (select b from s)",
+ 'a === ScalarSubquery(table("s").select('b)))
+ }
+
+ test("case when") {
+ assertEqual("case a when 1 then b when 2 then c else d end",
+ CaseKeyWhen('a, Seq(1, 'b, 2, 'c, 'd)))
+ assertEqual("case when a = 1 then b when a = 2 then c else d end",
+ CaseWhen(Seq(('a === 1, 'b.expr), ('a === 2, 'c.expr)), 'd))
+ }
+
+ test("dereference") {
+ assertEqual("a.b", UnresolvedAttribute("a.b"))
+ assertEqual("`select`.b", UnresolvedAttribute("select.b"))
+ assertEqual("(a + b).b", ('a + 'b).getField("b")) // This will fail analysis.
+ assertEqual("struct(a, b).b", 'struct.function('a, 'b).getField("b"))
+ }
+
+ test("reference") {
+ // Regular
+ assertEqual("a", 'a)
+
+ // Starting with a digit.
+ assertEqual("1a", Symbol("1a"))
+
+ // Quoted using a keyword.
+ assertEqual("`select`", 'select)
+
+ // Unquoted using an unreserved keyword.
+ assertEqual("columns", 'columns)
+ }
+
+ test("subscript") {
+ assertEqual("a[b]", 'a.getItem('b))
+ assertEqual("a[1 + 1]", 'a.getItem(Literal(1) + 1))
+ assertEqual("`c`.a[b]", UnresolvedAttribute("c.a").getItem('b))
+ }
+
+ test("parenthesis") {
+ assertEqual("(a)", 'a)
+ assertEqual("r * (a + b)", 'r * ('a + 'b))
+ }
+
+ test("type constructors") {
+ // Dates.
+ assertEqual("dAte '2016-03-11'", Literal(Date.valueOf("2016-03-11")))
+ intercept[IllegalArgumentException] {
+ parseExpression("DAtE 'mar 11 2016'")
+ }
+
+ // Timestamps.
+ assertEqual("tImEstAmp '2016-03-11 20:54:00.000'",
+ Literal(Timestamp.valueOf("2016-03-11 20:54:00.000")))
+ intercept[IllegalArgumentException] {
+ parseExpression("timestamP '2016-33-11 20:54:00.000'")
+ }
+
+ // Unsupported datatype.
+ intercept("GEO '(10,-6)'", "Literals of type 'GEO' are currently not supported.")
+ }
+
+ test("literals") {
+ // NULL
+ assertEqual("null", Literal(null))
+
+ // Boolean
+ assertEqual("trUe", Literal(true))
+ assertEqual("False", Literal(false))
+
+ // Integral literals should have the narrowest possible type
+ assertEqual("787324", Literal(787324))
+ assertEqual("7873247234798249234", Literal(7873247234798249234L))
+ assertEqual("78732472347982492793712334",
+ Literal(BigDecimal("78732472347982492793712334").underlying()))
+
+ // Decimal
+ assertEqual("7873247234798249279371.2334",
+ Literal(BigDecimal("7873247234798249279371.2334").underlying()))
+
+ // Scientific Decimal
+ assertEqual("9.0e1", 90d)
+ assertEqual(".9e+2", 90d)
+ assertEqual("0.9e+2", 90d)
+ assertEqual("900e-1", 90d)
+ assertEqual("900.0E-1", 90d)
+ assertEqual("9.e+1", 90d)
+ intercept(".e3")
+
+ // Tiny Int Literal
+ assertEqual("10Y", Literal(10.toByte))
+ intercept("-1000Y")
+
+ // Small Int Literal
+ assertEqual("10S", Literal(10.toShort))
+ intercept("40000S")
+
+ // Long Int Literal
+ assertEqual("10L", Literal(10L))
+ intercept("78732472347982492793712334L")
+
+ // Double Literal
+ assertEqual("10.0D", Literal(10.0D))
+ // TODO we need to figure out if we should throw an exception here!
+ assertEqual("1E309", Literal(Double.PositiveInfinity))
+ }
+
+ test("strings") {
+ // Single Strings.
+ assertEqual("\"hello\"", "hello")
+ assertEqual("'hello'", "hello")
+
+ // Multi-Strings.
+ assertEqual("\"hello\" 'world'", "helloworld")
+ assertEqual("'hello' \" \" 'world'", "hello world")
+
+ // 'LIKE' string literals. Note that an escaped '%' unescapes to the same result as an escaped
+ // '\' followed by a regular '%'; to keep the extra backslash in the result you need to add
+ // another escaped '\'.
+ // TODO figure out whether the ParseUtils.unescapeSQLString method should be changed.
+ assertEqual("'pattern%'", "pattern%")
+ assertEqual("'no-pattern\\%'", "no-pattern\\%")
+ assertEqual("'pattern\\\\%'", "pattern\\%")
+ assertEqual("'pattern\\\\\\%'", "pattern\\\\%")
+
+ // Escaped characters.
+ // See: http://dev.mysql.com/doc/refman/5.7/en/string-literals.html
+ assertEqual("'\\0'", "\u0000") // ASCII NUL (X'00')
+ assertEqual("'\\''", "\'") // Single quote
+ assertEqual("'\\\"'", "\"") // Double quote
+ assertEqual("'\\b'", "\b") // Backspace
+ assertEqual("'\\n'", "\n") // Newline
+ assertEqual("'\\r'", "\r") // Carriage return
+ assertEqual("'\\t'", "\t") // Tab character
+ assertEqual("'\\Z'", "\u001A") // ASCII 26 - CTRL + Z (EOF on windows)
+
+ // Octals
+ assertEqual("'\\110\\145\\154\\154\\157\\041'", "Hello!")
+
+ // Unicode
+ assertEqual("'\\u0087\\u0111\\u0114\\u0108\\u0100\\u0032\\u0058\\u0041'", "World :)")
+ }
+
+ test("intervals") {
+ def intervalLiteral(u: String, s: String): Literal = {
+ Literal(CalendarInterval.fromSingleUnitString(u, s))
+ }
+
+ // Empty interval statement
+ intercept("interval", "at least one time unit should be given for interval literal")
+
+ // Single Intervals.
+ val units = Seq(
+ "year",
+ "month",
+ "week",
+ "day",
+ "hour",
+ "minute",
+ "second",
+ "millisecond",
+ "microsecond")
+ val forms = Seq("", "s")
+ val values = Seq("0", "10", "-7", "21")
+ units.foreach { unit =>
+ forms.foreach { form =>
+ values.foreach { value =>
+ val expected = intervalLiteral(unit, value)
+ assertEqual(s"interval $value $unit$form", expected)
+ assertEqual(s"interval '$value' $unit$form", expected)
+ }
+ }
+ }
+
+ // Hive nanosecond notation.
+ assertEqual("interval 13.123456789 seconds", intervalLiteral("second", "13.123456789"))
+ assertEqual("interval -13.123456789 second", intervalLiteral("second", "-13.123456789"))
+
+ // Non-existing unit
+ intercept("interval 10 nanoseconds", "No interval can be constructed")
+
+ // Year-Month intervals.
+ val yearMonthValues = Seq("123-10", "496-0", "-2-3", "-123-0")
+ yearMonthValues.foreach { value =>
+ val result = Literal(CalendarInterval.fromYearMonthString(value))
+ assertEqual(s"interval '$value' year to month", result)
+ }
+
+ // Day-Time intervals.
+ val dayTimeValues = Seq(
+ "99 11:22:33.123456789",
+ "-99 11:22:33.123456789",
+ "10 9:8:7.123456789",
+ "1 0:0:0",
+ "-1 0:0:0",
+ "1 0:0:1")
+ dayTimeValues.foreach { value =>
+ val result = Literal(CalendarInterval.fromDayTimeString(value))
+ assertEqual(s"interval '$value' day to second", result)
+ }
+
+ // Unknown FROM TO intervals
+ intercept("interval 10 month to second", "Intervals FROM month TO second are not supported.")
+
+ // Composed intervals. In the first case 22 seconds + 1 millisecond = 22001000 microseconds;
+ // in the second, 3 years - '1-10' = 14 months and 3 weeks + '1 0:0:2' = 22 days and 2 seconds.
+ assertEqual(
+ "interval 3 months 22 seconds 1 millisecond",
+ Literal(new CalendarInterval(3, 22001000L)))
+ assertEqual(
+ "interval 3 years '-1-10' year to month 3 weeks '1 0:0:2' day to second",
+ Literal(new CalendarInterval(14,
+ 22 * CalendarInterval.MICROS_PER_DAY + 2 * CalendarInterval.MICROS_PER_SECOND)))
+ }
+
+ test("composed expressions") {
+ assertEqual("1 + r.r As q", (Literal(1) + UnresolvedAttribute("r.r")).as("q"))
+ assertEqual("1 - f('o', o(bar))", Literal(1) - 'f.function("o", 'o.function('bar)))
+ intercept("1 - f('o', o(bar)) hello * world", "mismatched input '*'")
+ }
+}
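
For reference, a minimal standalone sketch of what the DSL shorthands in the suite above
expand to (assuming the parser.ng package layout added by this patch; the object name and
the sample expression are illustrative only):

import org.apache.spark.sql.catalyst.analysis.UnresolvedAttribute
import org.apache.spark.sql.catalyst.expressions.{Add, Literal}
import org.apache.spark.sql.catalyst.parser.ng.CatalystSqlParser

object ExpressionParsingSketch {
  def main(args: Array[String]): Unit = {
    // "a + 1" should parse into Add(UnresolvedAttribute("a"), Literal(1)), the same tree
    // the DSL shorthand 'a + 1 builds in ExpressionParserSuite.
    val parsed = CatalystSqlParser.parseExpression("a + 1")
    val expected = Add(UnresolvedAttribute("a"), Literal(1))
    println(parsed)
    println(parsed == expected) // expected to print true
  }
}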
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/ng/PlanParserSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/ng/PlanParserSuite.scala
new file mode 100644
index 0000000000..4206d22ca7
--- /dev/null
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/ng/PlanParserSuite.scala
@@ -0,0 +1,429 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.spark.sql.catalyst.parser.ng
+
+import org.apache.spark.sql.Row
+import org.apache.spark.sql.catalyst.expressions._
+import org.apache.spark.sql.catalyst.plans._
+import org.apache.spark.sql.catalyst.plans.logical._
+import org.apache.spark.sql.types.{BooleanType, IntegerType}
+
+class PlanParserSuite extends PlanTest {
+ import CatalystSqlParser._
+ import org.apache.spark.sql.catalyst.dsl.expressions._
+ import org.apache.spark.sql.catalyst.dsl.plans._
+
+ def assertEqual(sqlCommand: String, plan: LogicalPlan): Unit = {
+ comparePlans(parsePlan(sqlCommand), plan)
+ }
+
+ def intercept(sqlCommand: String, messages: String*): Unit = {
+ val e = intercept[ParseException](parsePlan(sqlCommand))
+ messages.foreach { message =>
+ assert(e.message.contains(message))
+ }
+ }
+
+ test("case insensitive") {
+ val plan = table("a").select(star())
+ assertEqual("sELEct * FroM a", plan)
+ assertEqual("select * fRoM a", plan)
+ assertEqual("SELECT * FROM a", plan)
+ }
+
+ test("show functions") {
+ assertEqual("show functions", ShowFunctions(None, None))
+ assertEqual("show functions foo", ShowFunctions(None, Some("foo")))
+ assertEqual("show functions foo.bar", ShowFunctions(Some("foo"), Some("bar")))
+ assertEqual("show functions 'foo\\\\.*'", ShowFunctions(None, Some("foo\\.*")))
+ intercept("show functions foo.bar.baz", "SHOW FUNCTIONS unsupported name")
+ }
+
+ test("describe function") {
+ assertEqual("describe function bar", DescribeFunction("bar", isExtended = false))
+ assertEqual("describe function extended bar", DescribeFunction("bar", isExtended = true))
+ assertEqual("describe function foo.bar", DescribeFunction("foo.bar", isExtended = false))
+ assertEqual("describe function extended f.bar", DescribeFunction("f.bar", isExtended = true))
+ }
+
+ test("set operations") {
+ val a = table("a").select(star())
+ val b = table("b").select(star())
+
+ assertEqual("select * from a union select * from b", Distinct(a.union(b)))
+ assertEqual("select * from a union distinct select * from b", Distinct(a.union(b)))
+ assertEqual("select * from a union all select * from b", a.union(b))
+ assertEqual("select * from a except select * from b", a.except(b))
+ intercept("select * from a except all select * from b", "EXCEPT ALL is not supported.")
+ assertEqual("select * from a except distinct select * from b", a.except(b))
+ assertEqual("select * from a intersect select * from b", a.intersect(b))
+ intercept("select * from a intersect all select * from b", "INTERSECT ALL is not supported.")
+ assertEqual("select * from a intersect distinct select * from b", a.intersect(b))
+ }
+
+ test("common table expressions") {
+ def cte(plan: LogicalPlan, namedPlans: (String, LogicalPlan)*): With = {
+ val ctes = namedPlans.map {
+ case (name, cte) =>
+ name -> SubqueryAlias(name, cte)
+ }.toMap
+ With(plan, ctes)
+ }
+ assertEqual(
+ "with cte1 as (select * from a) select * from cte1",
+ cte(table("cte1").select(star()), "cte1" -> table("a").select(star())))
+ assertEqual(
+ "with cte1 (select 1) select * from cte1",
+ cte(table("cte1").select(star()), "cte1" -> OneRowRelation.select(1)))
+ assertEqual(
+ "with cte1 (select 1), cte2 as (select * from cte1) select * from cte2",
+ cte(table("cte2").select(star()),
+ "cte1" -> OneRowRelation.select(1),
+ "cte2" -> table("cte1").select(star())))
+ intercept(
+ "with cte1 (select 1), cte1 as (select 1 from cte1) select * from cte1",
+ "Name 'cte1' is used for multiple common table expressions")
+ }
+
+ test("simple select query") {
+ assertEqual("select 1", OneRowRelation.select(1))
+ assertEqual("select a, b", OneRowRelation.select('a, 'b))
+ assertEqual("select a, b from db.c", table("db", "c").select('a, 'b))
+ assertEqual("select a, b from db.c where x < 1", table("db", "c").where('x < 1).select('a, 'b))
+ assertEqual(
+ "select a, b from db.c having x < 1",
+ table("db", "c").select('a, 'b).where(('x < 1).cast(BooleanType)))
+ assertEqual("select distinct a, b from db.c", Distinct(table("db", "c").select('a, 'b)))
+ assertEqual("select all a, b from db.c", table("db", "c").select('a, 'b))
+ }
+
+ test("reverse select query") {
+ assertEqual("from a", table("a"))
+ assertEqual("from a select b, c", table("a").select('b, 'c))
+ assertEqual(
+ "from db.a select b, c where d < 1", table("db", "a").where('d < 1).select('b, 'c))
+ assertEqual("from a select distinct b, c", Distinct(table("a").select('b, 'c)))
+ assertEqual(
+ "from (from a union all from b) c select *",
+ table("a").union(table("b")).as("c").select(star()))
+ }
+
+ test("transform query spec") {
+ val p = ScriptTransformation(Seq('a, 'b), "func", Seq.empty, table("e"), null)
+ assertEqual("select transform(a, b) using 'func' from e where f < 10",
+ p.copy(child = p.child.where('f < 10), output = Seq('key.string, 'value.string)))
+ assertEqual("map a, b using 'func' as c, d from e",
+ p.copy(output = Seq('c.string, 'd.string)))
+ assertEqual("reduce a, b using 'func' as (c: int, d decimal(10, 0)) from e",
+ p.copy(output = Seq('c.int, 'd.decimal(10, 0))))
+ }
+
+ test("multi select query") {
+ assertEqual(
+ "from a select * select * where s < 10",
+ table("a").select(star()).union(table("a").where('s < 10).select(star())))
+ intercept(
+ "from a select * select * from x where a.s < 10",
+ "Multi-Insert queries cannot have a FROM clause in their individual SELECT statements")
+ assertEqual(
+ "from a insert into tbl1 select * insert into tbl2 select * where s < 10",
+ table("a").select(star()).insertInto("tbl1").union(
+ table("a").where('s < 10).select(star()).insertInto("tbl2")))
+ }
+
+ test("query organization") {
+ // Test all valid combinations of order by/sort by/distribute by/cluster by/limit/windows
+ val baseSql = "select * from t"
+ val basePlan = table("t").select(star())
+
+ val ws = Map("w1" -> WindowSpecDefinition(Seq.empty, Seq.empty, UnspecifiedFrame))
+ val limitWindowClauses = Seq(
+ ("", (p: LogicalPlan) => p),
+ (" limit 10", (p: LogicalPlan) => p.limit(10)),
+ (" window w1 as ()", (p: LogicalPlan) => WithWindowDefinition(ws, p)),
+ (" window w1 as () limit 10", (p: LogicalPlan) => WithWindowDefinition(ws, p).limit(10))
+ )
+
+ val orderSortDistrClusterClauses = Seq(
+ ("", basePlan),
+ (" order by a, b desc", basePlan.orderBy('a.asc, 'b.desc)),
+ (" sort by a, b desc", basePlan.sortBy('a.asc, 'b.desc)),
+ (" distribute by a, b", basePlan.distribute('a, 'b)),
+ (" distribute by a sort by b", basePlan.distribute('a).sortBy('b.asc)),
+ (" cluster by a, b", basePlan.distribute('a, 'b).sortBy('a.asc, 'b.asc))
+ )
+
+ orderSortDistrClusterClauses.foreach {
+ case (s1, p1) =>
+ limitWindowClauses.foreach {
+ case (s2, pf2) =>
+ assertEqual(baseSql + s1 + s2, pf2(p1))
+ }
+ }
+
+ val msg = "Combination of ORDER BY/SORT BY/DISTRIBUTE BY/CLUSTER BY is not supported"
+ intercept(s"$baseSql order by a sort by a", msg)
+ intercept(s"$baseSql cluster by a distribute by a", msg)
+ intercept(s"$baseSql order by a cluster by a", msg)
+ intercept(s"$baseSql order by a distribute by a", msg)
+ }
+
+ test("insert into") {
+ val sql = "select * from t"
+ val plan = table("t").select(star())
+ def insert(
+ partition: Map[String, Option[String]],
+ overwrite: Boolean = false,
+ ifNotExists: Boolean = false): LogicalPlan =
+ InsertIntoTable(table("s"), partition, plan, overwrite, ifNotExists)
+
+ // Single inserts
+ assertEqual(s"insert overwrite table s $sql",
+ insert(Map.empty, overwrite = true))
+ assertEqual(s"insert overwrite table s if not exists $sql",
+ insert(Map.empty, overwrite = true, ifNotExists = true))
+ assertEqual(s"insert into s $sql",
+ insert(Map.empty))
+ assertEqual(s"insert into table s partition (c = 'd', e = 1) $sql",
+ insert(Map("c" -> Option("d"), "e" -> Option("1"))))
+ assertEqual(s"insert overwrite table s partition (c = 'd', x) if not exists $sql",
+ insert(Map("c" -> Option("d"), "x" -> None), overwrite = true, ifNotExists = true))
+
+ // Multi insert
+ val plan2 = table("t").where('x > 5).select(star())
+ assertEqual("from t insert into s select * limit 1 insert into u select * where x > 5",
+ InsertIntoTable(
+ table("s"), Map.empty, plan.limit(1), overwrite = false, ifNotExists = false).union(
+ InsertIntoTable(
+ table("u"), Map.empty, plan2, overwrite = false, ifNotExists = false)))
+ }
+
+ test("aggregation") {
+ val sql = "select a, b, sum(c) as c from d group by a, b"
+
+ // Normal
+ assertEqual(sql, table("d").groupBy('a, 'b)('a, 'b, 'sum.function('c).as("c")))
+
+ // Cube
+ assertEqual(s"$sql with cube",
+ table("d").groupBy(Cube(Seq('a, 'b)))('a, 'b, 'sum.function('c).as("c")))
+
+ // Rollup
+ assertEqual(s"$sql with rollup",
+ table("d").groupBy(Rollup(Seq('a, 'b)))('a, 'b, 'sum.function('c).as("c")))
+
+ // Grouping Sets
+ assertEqual(s"$sql grouping sets((a, b), (a), ())",
+ GroupingSets(Seq(0, 1, 3), Seq('a, 'b), table("d"), Seq('a, 'b, 'sum.function('c).as("c"))))
+ intercept(s"$sql grouping sets((a, b), (c), ())",
+ "c doesn't show up in the GROUP BY list")
+ }
+
+ test("limit") {
+ val sql = "select * from t"
+ val plan = table("t").select(star())
+ assertEqual(s"$sql limit 10", plan.limit(10))
+ assertEqual(s"$sql limit cast(9 / 4 as int)", plan.limit(Cast(Literal(9) / 4, IntegerType)))
+ }
+
+ test("window spec") {
+ // Note that WindowSpecs are tested in the ExpressionParserSuite
+ val sql = "select * from t"
+ val plan = table("t").select(star())
+ val spec = WindowSpecDefinition(Seq('a, 'b), Seq('c.asc),
+ SpecifiedWindowFrame(RowFrame, ValuePreceding(1), ValueFollowing(1)))
+
+ // Test window resolution.
+ val ws1 = Map("w1" -> spec, "w2" -> spec, "w3" -> spec)
+ assertEqual(
+ s"""$sql
+ |window w1 as (partition by a, b order by c rows between 1 preceding and 1 following),
+ | w2 as w1,
+ | w3 as w1""".stripMargin,
+ WithWindowDefinition(ws1, plan))
+
+ // Fail with no reference.
+ intercept(s"$sql window w2 as w1", "Cannot resolve window reference 'w1'")
+
+ // Fail when resolved reference is not a window spec.
+ intercept(
+ s"""$sql
+ |window w1 as (partition by a, b order by c rows between 1 preceding and 1 following),
+ | w2 as w1,
+ | w3 as w2""".stripMargin,
+ "Window reference 'w2' is not a window specification"
+ )
+ }
+
+ test("lateral view") {
+ // Single lateral view
+ assertEqual(
+ "select * from t lateral view explode(x) expl as x",
+ table("t")
+ .generate(Explode('x), join = true, outer = false, Some("expl"), Seq("x"))
+ .select(star()))
+
+ // Multiple lateral views
+ assertEqual(
+ """select *
+ |from t
+ |lateral view explode(x) expl
+ |lateral view outer json_tuple(x, y) jtup q, z""".stripMargin,
+ table("t")
+ .generate(Explode('x), join = true, outer = false, Some("expl"), Seq.empty)
+ .generate(JsonTuple(Seq('x, 'y)), join = true, outer = true, Some("jtup"), Seq("q", "z"))
+ .select(star()))
+
+ // Multi-Insert lateral views.
+ val from = table("t1").generate(Explode('x), join = true, outer = false, Some("expl"), Seq("x"))
+ assertEqual(
+ """from t1
+ |lateral view explode(x) expl as x
+ |insert into t2
+ |select *
+ |lateral view json_tuple(x, y) jtup q, z
+ |insert into t3
+ |select *
+ |where s < 10
+ """.stripMargin,
+ Union(from
+ .generate(JsonTuple(Seq('x, 'y)), join = true, outer = false, Some("jtup"), Seq("q", "z"))
+ .select(star())
+ .insertInto("t2"),
+ from.where('s < 10).select(star()).insertInto("t3")))
+
+ // Unsupported generator.
+ intercept(
+ "select * from t lateral view posexplode(x) posexpl as x, y",
+ "Generator function 'posexplode' is not supported")
+ }
+
+ test("joins") {
+ // Test single joins.
+ val testUnconditionalJoin = (sql: String, jt: JoinType) => {
+ assertEqual(
+ s"select * from t as tt $sql u",
+ table("t").as("tt").join(table("u"), jt, None).select(star()))
+ }
+ val testConditionalJoin = (sql: String, jt: JoinType) => {
+ assertEqual(
+ s"select * from t $sql u as uu on a = b",
+ table("t").join(table("u").as("uu"), jt, Option('a === 'b)).select(star()))
+ }
+ val testNaturalJoin = (sql: String, jt: JoinType) => {
+ assertEqual(
+ s"select * from t tt natural $sql u as uu",
+ table("t").as("tt").join(table("u").as("uu"), NaturalJoin(jt), None).select(star()))
+ }
+ val testUsingJoin = (sql: String, jt: JoinType) => {
+ assertEqual(
+ s"select * from t $sql u using(a, b)",
+ table("t").join(table("u"), UsingJoin(jt, Seq('a.attr, 'b.attr)), None).select(star()))
+ }
+ val testAll = Seq(testUnconditionalJoin, testConditionalJoin, testNaturalJoin, testUsingJoin)
+
+ def test(sql: String, jt: JoinType, tests: Seq[(String, JoinType) => Unit]): Unit = {
+ tests.foreach(_(sql, jt))
+ }
+ test("cross join", Inner, Seq(testUnconditionalJoin))
+ test(",", Inner, Seq(testUnconditionalJoin))
+ test("join", Inner, testAll)
+ test("inner join", Inner, testAll)
+ test("left join", LeftOuter, testAll)
+ test("left outer join", LeftOuter, testAll)
+ test("right join", RightOuter, testAll)
+ test("right outer join", RightOuter, testAll)
+ test("full join", FullOuter, testAll)
+ test("full outer join", FullOuter, testAll)
+
+ // Test multiple consecutive joins
+ assertEqual(
+ "select * from a join b join c right join d",
+ table("a").join(table("b")).join(table("c")).join(table("d"), RightOuter).select(star()))
+ }
+
+ test("sampled relations") {
+ val sql = "select * from t"
+ assertEqual(s"$sql tablesample(100 rows)",
+ table("t").limit(100).select(star()))
+ assertEqual(s"$sql tablesample(43 percent) as x",
+ Sample(0, .43d, withReplacement = false, 10L, table("t").as("x"))(true).select(star()))
+ assertEqual(s"$sql tablesample(bucket 4 out of 10) as x",
+ Sample(0, .4d, withReplacement = false, 10L, table("t").as("x"))(true).select(star()))
+ intercept(s"$sql tablesample(bucket 4 out of 10 on x) as x",
+ "TABLESAMPLE(BUCKET x OUT OF y ON id) is not supported")
+ intercept(s"$sql tablesample(bucket 11 out of 10) as x",
+ s"Sampling fraction (${11.0/10.0}) must be on interval [0, 1]")
+ }
+
+ test("sub-query") {
+ val plan = table("t0").select('id)
+ assertEqual("select id from (t0)", plan)
+ assertEqual("select id from ((((((t0))))))", plan)
+ assertEqual(
+ "(select * from t1) union distinct (select * from t2)",
+ Distinct(table("t1").select(star()).union(table("t2").select(star()))))
+ assertEqual(
+ "select * from ((select * from t1) union (select * from t2)) t",
+ Distinct(
+ table("t1").select(star()).union(table("t2").select(star()))).as("t").select(star()))
+ assertEqual(
+ """select id
+ |from (((select id from t0)
+ | union all
+ | (select id from t0))
+ | union all
+ | (select id from t0)) as u_1
+ """.stripMargin,
+ plan.union(plan).union(plan).as("u_1").select('id))
+ }
+
+ test("scalar sub-query") {
+ assertEqual(
+ "select (select max(b) from s) ss from t",
+ table("t").select(ScalarSubquery(table("s").select('max.function('b))).as("ss")))
+ assertEqual(
+ "select * from t where a = (select b from s)",
+ table("t").where('a === ScalarSubquery(table("s").select('b))).select(star()))
+ assertEqual(
+ "select g from t group by g having a > (select b from s)",
+ table("t")
+ .groupBy('g)('g)
+ .where(('a > ScalarSubquery(table("s").select('b))).cast(BooleanType)))
+ }
+
+ test("table reference") {
+ assertEqual("table t", table("t"))
+ assertEqual("table d.t", table("d", "t"))
+ }
+
+ test("inline table") {
+ assertEqual("values 1, 2, 3, 4", LocalRelation.fromExternalRows(
+ Seq('col1.int),
+ Seq(1, 2, 3, 4).map(x => Row(x))))
+ assertEqual(
+ "values (1, 'a'), (2, 'b'), (3, 'c') as tbl(a, b)",
+ LocalRelation.fromExternalRows(
+ Seq('a.int, 'b.string),
+ Seq((1, "a"), (2, "b"), (3, "c")).map(x => Row(x._1, x._2))).as("tbl"))
+ intercept("values (a, 'a'), (b, 'b')",
+ "All expressions in an inline table must be constants.")
+ intercept("values (1, 'a'), (2, 'b') as tbl(a, b, c)",
+ "Number of aliases must match the number of fields in an inline table.")
+ intercept[ArrayIndexOutOfBoundsException](parsePlan("values (1, 'a'), (2, 'b', 5Y)"))
+ }
+}
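
Similarly, a minimal sketch (same assumed parser.ng layout; the object name and sample query
are illustrative) showing the plan shape parsePlan produces for a simple query next to the
equivalent DSL construction used throughout the suite above:

import org.apache.spark.sql.catalyst.dsl.expressions._
import org.apache.spark.sql.catalyst.dsl.plans._
import org.apache.spark.sql.catalyst.parser.ng.CatalystSqlParser

object PlanParsingSketch {
  def main(args: Array[String]): Unit = {
    // "select a, b from t where c < 1" should correspond to a Project over a Filter over an
    // UnresolvedRelation, which is the shape the DSL expression below builds as well.
    val parsed = CatalystSqlParser.parsePlan("select a, b from t where c < 1")
    val viaDsl = table("t").where('c < 1).select('a, 'b)
    println(parsed)
    println(viaDsl)
  }
}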
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/ng/TableIdentifierParserSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/ng/TableIdentifierParserSuite.scala
new file mode 100644
index 0000000000..0874322187
--- /dev/null
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/ng/TableIdentifierParserSuite.scala
@@ -0,0 +1,42 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.spark.sql.catalyst.parser.ng
+
+import org.apache.spark.SparkFunSuite
+import org.apache.spark.sql.catalyst.TableIdentifier
+
+class TableIdentifierParserSuite extends SparkFunSuite {
+ import CatalystSqlParser._
+
+ test("table identifier") {
+ // Regular names.
+ assert(TableIdentifier("q") === parseTableIdentifier("q"))
+ assert(TableIdentifier("q", Option("d")) === parseTableIdentifier("d.q"))
+
+ // Illegal names.
+ intercept[ParseException](parseTableIdentifier(""))
+ intercept[ParseException](parseTableIdentifier("d.q.g"))
+
+ // SQL Keywords.
+ val keywords = Seq("select", "from", "where", "left", "right")
+ keywords.foreach { keyword =>
+ intercept[ParseException](parseTableIdentifier(keyword))
+ assert(TableIdentifier(keyword) === parseTableIdentifier(s"`$keyword`"))
+ assert(TableIdentifier(keyword, Option("db")) === parseTableIdentifier(s"db.`$keyword`"))
+ }
+ }
+}
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/plans/PlanTest.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/plans/PlanTest.scala
index 0541844e0b..aa5d4330d3 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/plans/PlanTest.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/plans/PlanTest.scala
@@ -19,7 +19,7 @@ package org.apache.spark.sql.catalyst.plans
import org.apache.spark.SparkFunSuite
import org.apache.spark.sql.catalyst.expressions._
-import org.apache.spark.sql.catalyst.plans.logical.{Filter, LogicalPlan, OneRowRelation}
+import org.apache.spark.sql.catalyst.plans.logical.{Filter, LogicalPlan, OneRowRelation, Sample}
import org.apache.spark.sql.catalyst.util._
/**
@@ -32,6 +32,8 @@ abstract class PlanTest extends SparkFunSuite with PredicateHelper {
*/
protected def normalizeExprIds(plan: LogicalPlan) = {
plan transformAllExpressions {
+ case s: ScalarSubquery =>
+ ScalarSubquery(s.query, ExprId(0))
case a: AttributeReference =>
AttributeReference(a.name, a.dataType, a.nullable)(exprId = ExprId(0))
case a: Alias =>
@@ -40,21 +42,25 @@ abstract class PlanTest extends SparkFunSuite with PredicateHelper {
}
/**
- * Normalizes the filter conditions that appear in the plan. For instance,
- * ((expr 1 && expr 2) && expr 3), (expr 1 && expr 2 && expr 3), (expr 3 && (expr 1 && expr 2)
- * etc., will all now be equivalent.
+ * Normalizes plans:
+ * - Filter: the filter conditions that appear in a plan are normalized. For instance,
+ * ((expr 1 && expr 2) && expr 3), (expr 1 && expr 2 && expr 3), (expr 3 && (expr 1 && expr 2)),
+ * etc., will all become equivalent.
+ * - Sample: the seed is replaced by 0L.
*/
- private def normalizeFilters(plan: LogicalPlan) = {
+ private def normalizePlan(plan: LogicalPlan): LogicalPlan = {
plan transform {
case filter @ Filter(condition: Expression, child: LogicalPlan) =>
Filter(splitConjunctivePredicates(condition).sortBy(_.hashCode()).reduce(And), child)
+ case sample: Sample =>
+ sample.copy(seed = 0L)(true)
}
}
/** Fails the test if the two plans do not match */
protected def comparePlans(plan1: LogicalPlan, plan2: LogicalPlan) {
- val normalized1 = normalizeFilters(normalizeExprIds(plan1))
- val normalized2 = normalizeFilters(normalizeExprIds(plan2))
+ val normalized1 = normalizePlan(normalizeExprIds(plan1))
+ val normalized2 = normalizePlan(normalizeExprIds(plan2))
if (normalized1 != normalized2) {
fail(
s"""
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkSqlParser.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkSqlParser.scala
new file mode 100644
index 0000000000..c098fa99c2
--- /dev/null
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkSqlParser.scala
@@ -0,0 +1,219 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.spark.sql.execution
+
+import scala.collection.JavaConverters._
+
+import org.apache.spark.sql.SaveMode
+import org.apache.spark.sql.catalyst.TableIdentifier
+import org.apache.spark.sql.catalyst.parser.ng.{AbstractSqlParser, AstBuilder}
+import org.apache.spark.sql.catalyst.parser.ng.SqlBaseParser._
+import org.apache.spark.sql.catalyst.plans.logical.{LogicalPlan, OneRowRelation}
+import org.apache.spark.sql.execution.command.{DescribeCommand => _, _}
+import org.apache.spark.sql.execution.datasources._
+
+/**
+ * Concrete parser for Spark SQL statements.
+ */
+object SparkSqlParser extends AbstractSqlParser {
+ val astBuilder = new SparkSqlAstBuilder
+}
+
+/**
+ * Builder that converts an ANTLR ParseTree into a LogicalPlan/Expression/TableIdentifier.
+ */
+class SparkSqlAstBuilder extends AstBuilder {
+ import org.apache.spark.sql.catalyst.parser.ng.ParserUtils._
+
+ /**
+ * Create a [[SetCommand]] logical plan.
+ *
+ * Note that everything after the SET keyword is assumed to be part of the key-value pair.
+ * The split between key and value is made by searching for the first `=`
+ * character in the raw string.
+ */
+ override def visitSetConfiguration(ctx: SetConfigurationContext): LogicalPlan = withOrigin(ctx) {
+ // Construct the command.
+ val raw = remainder(ctx.SET.getSymbol)
+ val keyValueSeparatorIndex = raw.indexOf('=')
+ if (keyValueSeparatorIndex >= 0) {
+ val key = raw.substring(0, keyValueSeparatorIndex).trim
+ val value = raw.substring(keyValueSeparatorIndex + 1).trim
+ SetCommand(Some(key -> Option(value)))
+ } else if (raw.nonEmpty) {
+ SetCommand(Some(raw.trim -> None))
+ } else {
+ SetCommand(None)
+ }
+ }
+
+ /**
+ * Create a [[SetDatabaseCommand]] logical plan.
+ */
+ override def visitUse(ctx: UseContext): LogicalPlan = withOrigin(ctx) {
+ SetDatabaseCommand(ctx.db.getText)
+ }
+
+ /**
+ * Create a [[ShowTablesCommand]] logical plan.
+ */
+ override def visitShowTables(ctx: ShowTablesContext): LogicalPlan = withOrigin(ctx) {
+ if (ctx.LIKE != null) {
+ logWarning("SHOW TABLES LIKE option is ignored.")
+ }
+ ShowTablesCommand(Option(ctx.db).map(_.getText))
+ }
+
+ /**
+ * Create a [[RefreshTable]] logical plan.
+ */
+ override def visitRefreshTable(ctx: RefreshTableContext): LogicalPlan = withOrigin(ctx) {
+ RefreshTable(visitTableIdentifier(ctx.tableIdentifier))
+ }
+
+ /**
+ * Create a [[CacheTableCommand]] logical plan.
+ */
+ override def visitCacheTable(ctx: CacheTableContext): LogicalPlan = withOrigin(ctx) {
+ val query = Option(ctx.query).map(plan)
+ CacheTableCommand(ctx.identifier.getText, query, ctx.LAZY != null)
+ }
+
+ /**
+ * Create an [[UncacheTableCommand]] logical plan.
+ */
+ override def visitUncacheTable(ctx: UncacheTableContext): LogicalPlan = withOrigin(ctx) {
+ UncacheTableCommand(ctx.identifier.getText)
+ }
+
+ /**
+ * Create a [[ClearCacheCommand]] logical plan.
+ */
+ override def visitClearCache(ctx: ClearCacheContext): LogicalPlan = withOrigin(ctx) {
+ ClearCacheCommand
+ }
+
+ /**
+ * Create an [[ExplainCommand]] logical plan.
+ */
+ override def visitExplain(ctx: ExplainContext): LogicalPlan = withOrigin(ctx) {
+ val options = ctx.explainOption.asScala
+ if (options.exists(_.FORMATTED != null)) {
+ logWarning("EXPLAIN FORMATTED option is ignored.")
+ }
+ if (options.exists(_.LOGICAL != null)) {
+ logWarning("EXPLAIN LOGICAL option is ignored.")
+ }
+
+ // Create the explain command.
+ val statement = plan(ctx.statement)
+ if (isExplainableStatement(statement)) {
+ ExplainCommand(statement, extended = options.exists(_.EXTENDED != null))
+ } else {
+ ExplainCommand(OneRowRelation)
+ }
+ }
+
+ /**
+ * Determine if a plan should be explained at all.
+ */
+ protected def isExplainableStatement(plan: LogicalPlan): Boolean = plan match {
+ case _: datasources.DescribeCommand => false
+ case _ => true
+ }
+
+ /**
+ * Create a [[DescribeCommand]] logical plan.
+ */
+ override def visitDescribeTable(ctx: DescribeTableContext): LogicalPlan = withOrigin(ctx) {
+ // FORMATTED and columns are not supported. Return null and let the parser decide what to do
+ // with this (raise an exception or pass it on to a different system).
+ if (ctx.describeColName != null || ctx.FORMATTED != null || ctx.partitionSpec != null) {
+ null
+ } else {
+ datasources.DescribeCommand(
+ visitTableIdentifier(ctx.tableIdentifier),
+ ctx.EXTENDED != null)
+ }
+ }
+
+ /** Type to keep track of a table header. */
+ type TableHeader = (TableIdentifier, Boolean, Boolean, Boolean)
+
+ /**
+ * Validate a create table statement and return the [[TableIdentifier]].
+ */
+ override def visitCreateTableHeader(
+ ctx: CreateTableHeaderContext): TableHeader = withOrigin(ctx) {
+ val temporary = ctx.TEMPORARY != null
+ val ifNotExists = ctx.EXISTS != null
+ assert(!temporary || !ifNotExists,
+ "a CREATE TEMPORARY TABLE statement does not allow IF NOT EXISTS clause.",
+ ctx)
+ (visitTableIdentifier(ctx.tableIdentifier), temporary, ifNotExists, ctx.EXTERNAL != null)
+ }
+
+ /**
+ * Create a [[CreateTableUsing]] or a [[CreateTableUsingAsSelect]] logical plan.
+ *
+ * TODO add bucketing and partitioning.
+ */
+ override def visitCreateTableUsing(ctx: CreateTableUsingContext): LogicalPlan = withOrigin(ctx) {
+ val (table, temp, ifNotExists, external) = visitCreateTableHeader(ctx.createTableHeader)
+ if (external) {
+ logWarning("EXTERNAL option is not supported.")
+ }
+ val options = Option(ctx.tablePropertyList).map(visitTablePropertyList).getOrElse(Map.empty)
+ val provider = ctx.tableProvider.qualifiedName.getText
+
+ if (ctx.query != null) {
+ // Get the backing query.
+ val query = plan(ctx.query)
+
+ // Determine the storage mode.
+ val mode = if (ifNotExists) {
+ SaveMode.Ignore
+ } else if (temp) {
+ SaveMode.Overwrite
+ } else {
+ SaveMode.ErrorIfExists
+ }
+ CreateTableUsingAsSelect(table, provider, temp, Array.empty, None, mode, options, query)
+ } else {
+ val struct = Option(ctx.colTypeList).map(createStructType)
+ CreateTableUsing(table, struct, provider, temp, options, ifNotExists, managedIfNoPath = false)
+ }
+ }
+
+ /**
+ * Convert a table property list into a key-value map.
+ */
+ override def visitTablePropertyList(
+ ctx: TablePropertyListContext): Map[String, String] = withOrigin(ctx) {
+ ctx.tableProperty.asScala.map { property =>
+ // A key can either be a String or a collection of dot-separated elements. We need to treat
+ // these differently.
+ val key = if (property.key.STRING != null) {
+ string(property.key.STRING)
+ } else {
+ property.key.getText
+ }
+ val value = Option(property.value).map(string).orNull
+ key -> value
+ }.toMap
+ }
+}
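
The SET handling in visitSetConfiguration above amounts to the following self-contained
sketch (the helper name and the sample keys are illustrative, not part of the patch):
everything after the SET keyword is treated as a single key-value pair and split on the
first '='.

object SetCommandSplitSketch {
  // Split the remainder of a SET statement into an optional (key, optional value) pair.
  def split(raw: String): Option[(String, Option[String])] = {
    val i = raw.indexOf('=')
    if (i >= 0) {
      Some(raw.substring(0, i).trim -> Some(raw.substring(i + 1).trim))
    } else if (raw.trim.nonEmpty) {
      Some(raw.trim -> None)
    } else {
      None
    }
  }

  def main(args: Array[String]): Unit = {
    println(split("spark.sql.shuffle.partitions=10")) // Some((spark.sql.shuffle.partitions,Some(10)))
    println(split("spark.sql.shuffle.partitions"))    // Some((spark.sql.shuffle.partitions,None))
    println(split(""))                                // None
  }
}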
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala
index 8abb9d7e4a..7ce15e3f35 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala
@@ -27,8 +27,8 @@ import org.apache.spark.sql.catalyst.analysis.{Star, UnresolvedFunction}
import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.expressions.aggregate._
-import org.apache.spark.sql.catalyst.parser.CatalystQl
import org.apache.spark.sql.catalyst.plans.logical.BroadcastHint
+import org.apache.spark.sql.execution.SparkSqlParser
import org.apache.spark.sql.expressions.UserDefinedFunction
import org.apache.spark.sql.types._
import org.apache.spark.util.Utils
@@ -1172,8 +1172,7 @@ object functions {
* @group normal_funcs
*/
def expr(expr: String): Column = {
- val parser = SQLContext.getActive().map(_.sessionState.sqlParser).getOrElse(new CatalystQl())
- Column(parser.parseExpression(expr))
+ Column(SparkSqlParser.parseExpression(expr))
}
//////////////////////////////////////////////////////////////////////////////////////////////
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/internal/SessionState.scala b/sql/core/src/main/scala/org/apache/spark/sql/internal/SessionState.scala
index e5f02caabc..9bc640763f 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/internal/SessionState.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/internal/SessionState.scala
@@ -81,7 +81,7 @@ private[sql] class SessionState(ctx: SQLContext) {
/**
* Parser that extracts expressions, plans, table identifiers etc. from SQL texts.
*/
- lazy val sqlParser: ParserInterface = new SparkQl(conf)
+ lazy val sqlParser: ParserInterface = SparkSqlParser
/**
* Planner that converts optimized logical plans to physical plans.
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/JoinSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/JoinSuite.scala
index 5af1a4fcd7..a5a4ff13de 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/JoinSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/JoinSuite.scala
@@ -329,8 +329,8 @@ class JoinSuite extends QueryTest with SharedSQLContext {
}
test("full outer join") {
- upperCaseData.where('N <= 4).registerTempTable("left")
- upperCaseData.where('N >= 3).registerTempTable("right")
+ upperCaseData.where('N <= 4).registerTempTable("`left`")
+ upperCaseData.where('N >= 3).registerTempTable("`right`")
val left = UnresolvedRelation(TableIdentifier("left"), None)
val right = UnresolvedRelation(TableIdentifier("right"), None)
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala
index c958eac266..b727e88668 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala
@@ -1656,7 +1656,7 @@ class SQLQuerySuite extends QueryTest with SharedSQLContext {
val e2 = intercept[AnalysisException] {
sql("select interval 23 nanosecond")
}
- assert(e2.message.contains("cannot recognize input near"))
+ assert(e2.message.contains("No interval can be constructed"))
}
test("SPARK-8945: add and subtract expressions for interval type") {