diff options
29 files changed, 4127 insertions, 66 deletions
@@ -238,6 +238,7 @@ The text of each license is also included at licenses/LICENSE-[project].txt. (BSD 3 Clause) netlib core (com.github.fommil.netlib:core:1.1.2 - https://github.com/fommil/netlib-java/core) (BSD 3 Clause) JPMML-Model (org.jpmml:pmml-model:1.2.7 - https://github.com/jpmml/jpmml-model) (BSD License) AntLR Parser Generator (antlr:antlr:2.7.7 - http://www.antlr.org/) + (BSD License) ANTLR 4.5.2-1 (org.antlr:antlr4:4.5.2-1 - http://wwww.antlr.org/) (BSD licence) ANTLR ST4 4.0.4 (org.antlr:ST4:4.0.4 - http://www.stringtemplate.org) (BSD licence) ANTLR StringTemplate (org.antlr:stringtemplate:3.2.1 - http://www.stringtemplate.org) (BSD License) Javolution (javolution:javolution:5.5.1 - http://javolution.org) diff --git a/dev/deps/spark-deps-hadoop-2.2 b/dev/deps/spark-deps-hadoop-2.2 index 512675a599..7c2f88bdb1 100644 --- a/dev/deps/spark-deps-hadoop-2.2 +++ b/dev/deps/spark-deps-hadoop-2.2 @@ -3,6 +3,7 @@ RoaringBitmap-0.5.11.jar ST4-4.0.4.jar activation-1.1.jar antlr-runtime-3.5.2.jar +antlr4-runtime-4.5.2-1.jar aopalliance-1.0.jar apache-log4j-extras-1.2.17.jar arpack_combined_all-0.1.jar diff --git a/dev/deps/spark-deps-hadoop-2.3 b/dev/deps/spark-deps-hadoop-2.3 index 31f8694fed..f4d600038d 100644 --- a/dev/deps/spark-deps-hadoop-2.3 +++ b/dev/deps/spark-deps-hadoop-2.3 @@ -3,6 +3,7 @@ RoaringBitmap-0.5.11.jar ST4-4.0.4.jar activation-1.1.1.jar antlr-runtime-3.5.2.jar +antlr4-runtime-4.5.2-1.jar aopalliance-1.0.jar apache-log4j-extras-1.2.17.jar arpack_combined_all-0.1.jar diff --git a/dev/deps/spark-deps-hadoop-2.4 b/dev/deps/spark-deps-hadoop-2.4 index 0fa8bccab0..7c5e2c35bd 100644 --- a/dev/deps/spark-deps-hadoop-2.4 +++ b/dev/deps/spark-deps-hadoop-2.4 @@ -3,6 +3,7 @@ RoaringBitmap-0.5.11.jar ST4-4.0.4.jar activation-1.1.1.jar antlr-runtime-3.5.2.jar +antlr4-runtime-4.5.2-1.jar aopalliance-1.0.jar apache-log4j-extras-1.2.17.jar arpack_combined_all-0.1.jar diff --git a/dev/deps/spark-deps-hadoop-2.6 b/dev/deps/spark-deps-hadoop-2.6 index 8d2f6e6e32..03d9a51057 100644 --- a/dev/deps/spark-deps-hadoop-2.6 +++ b/dev/deps/spark-deps-hadoop-2.6 @@ -3,6 +3,7 @@ RoaringBitmap-0.5.11.jar ST4-4.0.4.jar activation-1.1.1.jar antlr-runtime-3.5.2.jar +antlr4-runtime-4.5.2-1.jar aopalliance-1.0.jar apache-log4j-extras-1.2.17.jar apacheds-i18n-2.0.0-M15.jar diff --git a/dev/deps/spark-deps-hadoop-2.7 b/dev/deps/spark-deps-hadoop-2.7 index a114c4ae8d..5765071a1c 100644 --- a/dev/deps/spark-deps-hadoop-2.7 +++ b/dev/deps/spark-deps-hadoop-2.7 @@ -3,6 +3,7 @@ RoaringBitmap-0.5.11.jar ST4-4.0.4.jar activation-1.1.1.jar antlr-runtime-3.5.2.jar +antlr4-runtime-4.5.2-1.jar aopalliance-1.0.jar apache-log4j-extras-1.2.17.jar apacheds-i18n-2.0.0-M15.jar @@ -178,6 +178,7 @@ <jsr305.version>1.3.9</jsr305.version> <libthrift.version>0.9.2</libthrift.version> <antlr.version>3.5.2</antlr.version> + <antlr4.version>4.5.2-1</antlr4.version> <test.java.home>${java.home}</test.java.home> <test.exclude.tags></test.exclude.tags> @@ -1759,6 +1760,11 @@ <artifactId>antlr-runtime</artifactId> <version>${antlr.version}</version> </dependency> + <dependency> + <groupId>org.antlr</groupId> + <artifactId>antlr4-runtime</artifactId> + <version>${antlr4.version}</version> + </dependency> </dependencies> </dependencyManagement> diff --git a/project/SparkBuild.scala b/project/SparkBuild.scala index fb229b979d..39a9e16f7e 100644 --- a/project/SparkBuild.scala +++ b/project/SparkBuild.scala @@ -25,6 +25,7 @@ import sbt._ import sbt.Classpaths.publishTask import sbt.Keys._ import sbtunidoc.Plugin.UnidocKeys.unidocGenjavadocVersion +import com.simplytyped.Antlr4Plugin._ import com.typesafe.sbt.pom.{PomBuild, SbtPomKeys} import com.typesafe.tools.mima.plugin.MimaKeys @@ -401,7 +402,10 @@ object OldDeps { } object Catalyst { - lazy val settings = Seq( + lazy val settings = antlr4Settings ++ Seq( + antlr4PackageName in Antlr4 := Some("org.apache.spark.sql.catalyst.parser.ng"), + antlr4GenListener in Antlr4 := true, + antlr4GenVisitor in Antlr4 := true, // ANTLR code-generation step. // // This has been heavily inspired by com.github.stefri.sbt-antlr (0.5.3). It fixes a number of @@ -414,7 +418,7 @@ object Catalyst { "SparkSqlLexer.g", "SparkSqlParser.g") val sourceDir = (sourceDirectory in Compile).value / "antlr3" - val targetDir = (sourceManaged in Compile).value + val targetDir = (sourceManaged in Compile).value / "antlr3" // Create default ANTLR Tool. val antlr = new org.antlr.Tool diff --git a/project/plugins.sbt b/project/plugins.sbt index eeca94a47c..d9ed7962bf 100644 --- a/project/plugins.sbt +++ b/project/plugins.sbt @@ -23,3 +23,9 @@ libraryDependencies += "org.ow2.asm" % "asm" % "5.0.3" libraryDependencies += "org.ow2.asm" % "asm-commons" % "5.0.3" libraryDependencies += "org.antlr" % "antlr" % "3.5.2" + + +// TODO I am not sure we want such a dep. +resolvers += "simplytyped" at "http://simplytyped.github.io/repo/releases" + +addSbtPlugin("com.simplytyped" % "sbt-antlr4" % "0.7.10") diff --git a/python/pyspark/sql/tests.py b/python/pyspark/sql/tests.py index 83ef76c13c..1a5d422af9 100644 --- a/python/pyspark/sql/tests.py +++ b/python/pyspark/sql/tests.py @@ -51,7 +51,7 @@ from pyspark.sql.types import UserDefinedType, _infer_type from pyspark.tests import ReusedPySparkTestCase from pyspark.sql.functions import UserDefinedFunction, sha2 from pyspark.sql.window import Window -from pyspark.sql.utils import AnalysisException, IllegalArgumentException +from pyspark.sql.utils import AnalysisException, ParseException, IllegalArgumentException class UTCOffsetTimezone(datetime.tzinfo): @@ -1130,7 +1130,9 @@ class SQLTests(ReusedPySparkTestCase): def test_capture_analysis_exception(self): self.assertRaises(AnalysisException, lambda: self.sqlCtx.sql("select abc")) self.assertRaises(AnalysisException, lambda: self.df.selectExpr("a + b")) - self.assertRaises(AnalysisException, lambda: self.sqlCtx.sql("abc")) + + def test_capture_parse_exception(self): + self.assertRaises(ParseException, lambda: self.sqlCtx.sql("abc")) def test_capture_illegalargument_exception(self): self.assertRaisesRegexp(IllegalArgumentException, "Setting negative mapred.reduce.tasks", diff --git a/python/pyspark/sql/utils.py b/python/pyspark/sql/utils.py index b0a0373372..b89ea8c6e0 100644 --- a/python/pyspark/sql/utils.py +++ b/python/pyspark/sql/utils.py @@ -33,6 +33,12 @@ class AnalysisException(CapturedException): """ +class ParseException(CapturedException): + """ + Failed to parse a SQL command. + """ + + class IllegalArgumentException(CapturedException): """ Passed an illegal or inappropriate argument. @@ -49,6 +55,8 @@ def capture_sql_exception(f): e.java_exception.getStackTrace())) if s.startswith('org.apache.spark.sql.AnalysisException: '): raise AnalysisException(s.split(': ', 1)[1], stackTrace) + if s.startswith('org.apache.spark.sql.catalyst.parser.ng.ParseException: '): + raise ParseException(s.split(': ', 1)[1], stackTrace) if s.startswith('java.lang.IllegalArgumentException: '): raise IllegalArgumentException(s.split(': ', 1)[1], stackTrace) raise diff --git a/sql/catalyst/pom.xml b/sql/catalyst/pom.xml index 5d1d9edd25..c834a011f1 100644 --- a/sql/catalyst/pom.xml +++ b/sql/catalyst/pom.xml @@ -76,6 +76,10 @@ <artifactId>antlr-runtime</artifactId> </dependency> <dependency> + <groupId>org.antlr</groupId> + <artifactId>antlr4-runtime</artifactId> + </dependency> + <dependency> <groupId>commons-codec</groupId> <artifactId>commons-codec</artifactId> </dependency> diff --git a/sql/catalyst/src/main/antlr4/org/apache/spark/sql/catalyst/parser/ng/SqlBase.g4 b/sql/catalyst/src/main/antlr4/org/apache/spark/sql/catalyst/parser/ng/SqlBase.g4 new file mode 100644 index 0000000000..e46fd9bed5 --- /dev/null +++ b/sql/catalyst/src/main/antlr4/org/apache/spark/sql/catalyst/parser/ng/SqlBase.g4 @@ -0,0 +1,911 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + * This file is an adaptation of Presto's presto-parser/src/main/antlr4/com/facebook/presto/sql/parser/SqlBase.g4 grammar. + */ + +grammar SqlBase; + +tokens { + DELIMITER +} + +singleStatement + : statement EOF + ; + +singleExpression + : namedExpression EOF + ; + +singleTableIdentifier + : tableIdentifier EOF + ; + +singleDataType + : dataType EOF + ; + +statement + : query #statementDefault + | USE db=identifier #use + | CREATE DATABASE (IF NOT EXISTS)? identifier + (COMMENT comment=STRING)? locationSpec? + (WITH DBPROPERTIES tablePropertyList)? #createDatabase + | ALTER DATABASE identifier SET DBPROPERTIES tablePropertyList #setDatabaseProperties + | DROP DATABASE (IF EXISTS)? identifier (RESTRICT | CASCADE)? #dropDatabase + | createTableHeader ('(' colTypeList ')')? tableProvider + (OPTIONS tablePropertyList)? #createTableUsing + | createTableHeader tableProvider + (OPTIONS tablePropertyList)? AS? query #createTableUsing + | createTableHeader ('(' colTypeList ')')? (COMMENT STRING)? + (PARTITIONED BY identifierList)? bucketSpec? skewSpec? + rowFormat? createFileFormat? locationSpec? + (TBLPROPERTIES tablePropertyList)? + (AS? query)? #createTable + | ANALYZE TABLE tableIdentifier partitionSpec? COMPUTE STATISTICS + (identifier | FOR COLUMNS identifierSeq?) #analyze + | ALTER TABLE from=tableIdentifier RENAME TO to=tableIdentifier #renameTable + | ALTER TABLE tableIdentifier + SET TBLPROPERTIES tablePropertyList #setTableProperties + | ALTER TABLE tableIdentifier + UNSET TBLPROPERTIES (IF EXISTS)? tablePropertyList #unsetTableProperties + | ALTER TABLE tableIdentifier (partitionSpec)? + SET SERDE STRING (WITH SERDEPROPERTIES tablePropertyList)? #setTableSerDe + | ALTER TABLE tableIdentifier (partitionSpec)? + SET SERDEPROPERTIES tablePropertyList #setTableSerDe + | ALTER TABLE tableIdentifier bucketSpec #bucketTable + | ALTER TABLE tableIdentifier NOT CLUSTERED #unclusterTable + | ALTER TABLE tableIdentifier NOT SORTED #unsortTable + | ALTER TABLE tableIdentifier skewSpec #skewTable + | ALTER TABLE tableIdentifier NOT SKEWED #unskewTable + | ALTER TABLE tableIdentifier NOT STORED AS DIRECTORIES #unstoreTable + | ALTER TABLE tableIdentifier + SET SKEWED LOCATION skewedLocationList #setTableSkewLocations + | ALTER TABLE tableIdentifier ADD (IF NOT EXISTS)? + partitionSpecLocation+ #addTablePartition + | ALTER TABLE tableIdentifier + from=partitionSpec RENAME TO to=partitionSpec #renameTablePartition + | ALTER TABLE from=tableIdentifier + EXCHANGE partitionSpec WITH TABLE to=tableIdentifier #exchangeTablePartition + | ALTER TABLE tableIdentifier + DROP (IF EXISTS)? partitionSpec (',' partitionSpec)* PURGE? #dropTablePartitions + | ALTER TABLE tableIdentifier ARCHIVE partitionSpec #archiveTablePartition + | ALTER TABLE tableIdentifier UNARCHIVE partitionSpec #unarchiveTablePartition + | ALTER TABLE tableIdentifier partitionSpec? + SET FILEFORMAT fileFormat #setTableFileFormat + | ALTER TABLE tableIdentifier partitionSpec? SET locationSpec #setTableLocation + | ALTER TABLE tableIdentifier TOUCH partitionSpec? #touchTable + | ALTER TABLE tableIdentifier partitionSpec? COMPACT STRING #compactTable + | ALTER TABLE tableIdentifier partitionSpec? CONCATENATE #concatenateTable + | ALTER TABLE tableIdentifier partitionSpec? + CHANGE COLUMN? oldName=identifier colType + (FIRST | AFTER after=identifier)? (CASCADE | RESTRICT)? #changeColumn + | ALTER TABLE tableIdentifier partitionSpec? + ADD COLUMNS '(' colTypeList ')' (CASCADE | RESTRICT)? #addColumns + | ALTER TABLE tableIdentifier partitionSpec? + REPLACE COLUMNS '(' colTypeList ')' (CASCADE | RESTRICT)? #replaceColumns + | DROP TABLE (IF EXISTS)? tableIdentifier PURGE? + (FOR METADATA? REPLICATION '(' STRING ')')? #dropTable + | CREATE (OR REPLACE)? VIEW (IF NOT EXISTS)? tableIdentifier + identifierCommentList? (COMMENT STRING)? + (PARTITIONED ON identifierList)? + (TBLPROPERTIES tablePropertyList)? AS query #createView + | ALTER VIEW tableIdentifier AS? query #alterViewQuery + | CREATE TEMPORARY? FUNCTION qualifiedName AS className=STRING + (USING resource (',' resource)*)? #createFunction + | DROP TEMPORARY? FUNCTION (IF EXISTS)? qualifiedName #dropFunction + | EXPLAIN explainOption* statement #explain + | SHOW TABLES ((FROM | IN) db=identifier)? + (LIKE (qualifiedName | pattern=STRING))? #showTables + | SHOW FUNCTIONS (LIKE? (qualifiedName | pattern=STRING))? #showFunctions + | (DESC | DESCRIBE) FUNCTION EXTENDED? qualifiedName #describeFunction + | (DESC | DESCRIBE) option=(EXTENDED | FORMATTED)? + tableIdentifier partitionSpec? describeColName? #describeTable + | (DESC | DESCRIBE) DATABASE EXTENDED? identifier #describeDatabase + | REFRESH TABLE tableIdentifier #refreshTable + | CACHE LAZY? TABLE identifier (AS? query)? #cacheTable + | UNCACHE TABLE identifier #uncacheTable + | CLEAR CACHE #clearCache + | ADD identifier .*? #addResource + | SET .*? #setConfiguration + | hiveNativeCommands #executeNativeCommand + ; + +hiveNativeCommands + : createTableHeader LIKE tableIdentifier + rowFormat? createFileFormat? locationSpec? + (TBLPROPERTIES tablePropertyList)? + | DELETE FROM tableIdentifier (WHERE booleanExpression)? + | TRUNCATE TABLE tableIdentifier partitionSpec? + (COLUMNS identifierList)? + | ALTER VIEW from=tableIdentifier AS? RENAME TO to=tableIdentifier + | ALTER VIEW from=tableIdentifier AS? + SET TBLPROPERTIES tablePropertyList + | ALTER VIEW from=tableIdentifier AS? + UNSET TBLPROPERTIES (IF EXISTS)? tablePropertyList + | ALTER VIEW from=tableIdentifier AS? + ADD (IF NOT EXISTS)? partitionSpecLocation+ + | ALTER VIEW from=tableIdentifier AS? + DROP (IF EXISTS)? partitionSpec (',' partitionSpec)* PURGE? + | DROP VIEW (IF EXISTS)? qualifiedName + | SHOW COLUMNS (FROM | IN) tableIdentifier ((FROM|IN) identifier)? + | START TRANSACTION (transactionMode (',' transactionMode)*)? + | COMMIT WORK? + | ROLLBACK WORK? + | SHOW PARTITIONS tableIdentifier partitionSpec? + | DFS .*? + | (CREATE | ALTER | DROP | SHOW | DESC | DESCRIBE | REVOKE | GRANT | LOCK | UNLOCK | MSCK | EXPORT | IMPORT | LOAD) .*? + ; + +createTableHeader + : CREATE TEMPORARY? EXTERNAL? TABLE (IF NOT EXISTS)? tableIdentifier + ; + +bucketSpec + : CLUSTERED BY identifierList + (SORTED BY orderedIdentifierList)? + INTO INTEGER_VALUE BUCKETS + ; + +skewSpec + : SKEWED BY identifierList + ON (constantList | nestedConstantList) + (STORED AS DIRECTORIES)? + ; + +locationSpec + : LOCATION STRING + ; + +query + : ctes? queryNoWith + ; + +insertInto + : INSERT OVERWRITE TABLE tableIdentifier partitionSpec? (IF NOT EXISTS)? + | INSERT INTO TABLE? tableIdentifier partitionSpec? + ; + +partitionSpecLocation + : partitionSpec locationSpec? + ; + +partitionSpec + : PARTITION '(' partitionVal (',' partitionVal)* ')' + ; + +partitionVal + : identifier (EQ constant)? + ; + +describeColName + : identifier ('.' (identifier | STRING))* + ; + +ctes + : WITH namedQuery (',' namedQuery)* + ; + +namedQuery + : name=identifier AS? '(' queryNoWith ')' + ; + +tableProvider + : USING qualifiedName + ; + +tablePropertyList + : '(' tableProperty (',' tableProperty)* ')' + ; + +tableProperty + : key=tablePropertyKey (EQ? value=STRING)? + ; + +tablePropertyKey + : looseIdentifier ('.' looseIdentifier)* + | STRING + ; + +constantList + : '(' constant (',' constant)* ')' + ; + +nestedConstantList + : '(' constantList (',' constantList)* ')' + ; + +skewedLocation + : (constant | constantList) EQ STRING + ; + +skewedLocationList + : '(' skewedLocation (',' skewedLocation)* ')' + ; + +createFileFormat + : STORED AS fileFormat + | STORED BY storageHandler + ; + +fileFormat + : INPUTFORMAT inFmt=STRING OUTPUTFORMAT outFmt=STRING (SERDE serdeCls=STRING)? + (INPUTDRIVER inDriver=STRING OUTPUTDRIVER outDriver=STRING)? #tableFileFormat + | identifier #genericFileFormat + ; + +storageHandler + : STRING (WITH SERDEPROPERTIES tablePropertyList)? + ; + +resource + : identifier STRING + ; + +queryNoWith + : insertInto? queryTerm queryOrganization #singleInsertQuery + | fromClause multiInsertQueryBody+ #multiInsertQuery + ; + +queryOrganization + : (ORDER BY order+=sortItem (',' order+=sortItem)*)? + (CLUSTER BY clusterBy+=expression (',' clusterBy+=expression)*)? + (DISTRIBUTE BY distributeBy+=expression (',' distributeBy+=expression)*)? + (SORT BY sort+=sortItem (',' sort+=sortItem)*)? + windows? + (LIMIT limit=expression)? + ; + +multiInsertQueryBody + : insertInto? + querySpecification + queryOrganization + ; + +queryTerm + : queryPrimary #queryTermDefault + | left=queryTerm operator=(INTERSECT | UNION | EXCEPT) setQuantifier? right=queryTerm #setOperation + ; + +queryPrimary + : querySpecification #queryPrimaryDefault + | TABLE tableIdentifier #table + | inlineTable #inlineTableDefault1 + | '(' queryNoWith ')' #subquery + ; + +sortItem + : expression ordering=(ASC | DESC)? + ; + +querySpecification + : (((SELECT kind=TRANSFORM '(' namedExpressionSeq ')' + | kind=MAP namedExpressionSeq + | kind=REDUCE namedExpressionSeq)) + inRowFormat=rowFormat? + (RECORDWRITER recordWriter=STRING)? + USING script=STRING + (AS (identifierSeq | colTypeList | ('(' (identifierSeq | colTypeList) ')')))? + outRowFormat=rowFormat? + (RECORDREADER recordReader=STRING)? + fromClause? + (WHERE where=booleanExpression)?) + | ((kind=SELECT setQuantifier? namedExpressionSeq fromClause? + | fromClause (kind=SELECT setQuantifier? namedExpressionSeq)?) + lateralView* + (WHERE where=booleanExpression)? + aggregation? + (HAVING having=booleanExpression)? + windows?) + ; + +fromClause + : FROM relation (',' relation)* lateralView* + ; + +aggregation + : GROUP BY groupingExpressions+=expression (',' groupingExpressions+=expression)* ( + WITH kind=ROLLUP + | WITH kind=CUBE + | kind=GROUPING SETS '(' groupingSet (',' groupingSet)* ')')? + ; + +groupingSet + : '(' (expression (',' expression)*)? ')' + | expression + ; + +lateralView + : LATERAL VIEW (OUTER)? qualifiedName '(' (expression (',' expression)*)? ')' tblName=identifier (AS? colName+=identifier (',' colName+=identifier)*)? + ; + +setQuantifier + : DISTINCT + | ALL + ; + +relation + : left=relation + ((CROSS | joinType) JOIN right=relation joinCriteria? + | NATURAL joinType JOIN right=relation + ) #joinRelation + | relationPrimary #relationDefault + ; + +joinType + : INNER? + | LEFT OUTER? + | LEFT SEMI + | RIGHT OUTER? + | FULL OUTER? + ; + +joinCriteria + : ON booleanExpression + | USING '(' identifier (',' identifier)* ')' + ; + +sample + : TABLESAMPLE '(' + ( (percentage=(INTEGER_VALUE | DECIMAL_VALUE) sampleType=PERCENTLIT) + | (expression sampleType=ROWS) + | (sampleType=BUCKET numerator=INTEGER_VALUE OUT OF denominator=INTEGER_VALUE (ON identifier)?)) + ')' + ; + +identifierList + : '(' identifierSeq ')' + ; + +identifierSeq + : identifier (',' identifier)* + ; + +orderedIdentifierList + : '(' orderedIdentifier (',' orderedIdentifier)* ')' + ; + +orderedIdentifier + : identifier ordering=(ASC | DESC)? + ; + +identifierCommentList + : '(' identifierComment (',' identifierComment)* ')' + ; + +identifierComment + : identifier (COMMENT STRING)? + ; + +relationPrimary + : tableIdentifier sample? (AS? identifier)? #tableName + | '(' queryNoWith ')' sample? (AS? identifier)? #aliasedQuery + | '(' relation ')' sample? (AS? identifier)? #aliasedRelation + | inlineTable #inlineTableDefault2 + ; + +inlineTable + : VALUES expression (',' expression)* (AS? identifier identifierList?)? + ; + +rowFormat + : ROW FORMAT SERDE name=STRING (WITH SERDEPROPERTIES props=tablePropertyList)? #rowFormatSerde + | ROW FORMAT DELIMITED + (FIELDS TERMINATED BY fieldsTerminatedBy=STRING (ESCAPED BY escapedBy=STRING)?)? + (COLLECTION ITEMS TERMINATED BY collectionItemsTerminatedBy=STRING)? + (MAP KEYS TERMINATED BY keysTerminatedBy=STRING)? + (LINES TERMINATED BY linesSeparatedBy=STRING)? + (NULL DEFINED AS nullDefinedAs=STRING)? #rowFormatDelimited + ; + +tableIdentifier + : (db=identifier '.')? table=identifier + ; + +namedExpression + : expression (AS? (identifier | identifierList))? + ; + +namedExpressionSeq + : namedExpression (',' namedExpression)* + ; + +expression + : booleanExpression + ; + +booleanExpression + : predicated #booleanDefault + | NOT booleanExpression #logicalNot + | left=booleanExpression operator=AND right=booleanExpression #logicalBinary + | left=booleanExpression operator=OR right=booleanExpression #logicalBinary + | EXISTS '(' query ')' #exists + ; + +// workaround for: +// https://github.com/antlr/antlr4/issues/780 +// https://github.com/antlr/antlr4/issues/781 +predicated + : valueExpression predicate[$valueExpression.ctx]? + ; + +predicate[ParserRuleContext value] + : NOT? BETWEEN lower=valueExpression AND upper=valueExpression #between + | NOT? IN '(' expression (',' expression)* ')' #inList + | NOT? IN '(' query ')' #inSubquery + | NOT? like=(RLIKE | LIKE) pattern=valueExpression #like + | IS NOT? NULL #nullPredicate + ; + +valueExpression + : primaryExpression #valueExpressionDefault + | operator=(MINUS | PLUS | TILDE) valueExpression #arithmeticUnary + | left=valueExpression operator=(ASTERISK | SLASH | PERCENT | DIV) right=valueExpression #arithmeticBinary + | left=valueExpression operator=(PLUS | MINUS) right=valueExpression #arithmeticBinary + | left=valueExpression operator=AMPERSAND right=valueExpression #arithmeticBinary + | left=valueExpression operator=HAT right=valueExpression #arithmeticBinary + | left=valueExpression operator=PIPE right=valueExpression #arithmeticBinary + | left=valueExpression comparisonOperator right=valueExpression #comparison + ; + +primaryExpression + : constant #constantDefault + | ASTERISK #star + | qualifiedName '.' ASTERISK #star + | '(' expression (',' expression)+ ')' #rowConstructor + | qualifiedName '(' (setQuantifier? expression (',' expression)*)? ')' (OVER windowSpec)? #functionCall + | '(' query ')' #subqueryExpression + | CASE valueExpression whenClause+ (ELSE elseExpression=expression)? END #simpleCase + | CASE whenClause+ (ELSE elseExpression=expression)? END #searchedCase + | CAST '(' expression AS dataType ')' #cast + | value=primaryExpression '[' index=valueExpression ']' #subscript + | identifier #columnReference + | base=primaryExpression '.' fieldName=identifier #dereference + | '(' expression ')' #parenthesizedExpression + ; + +constant + : NULL #nullLiteral + | interval #intervalLiteral + | identifier STRING #typeConstructor + | number #numericLiteral + | booleanValue #booleanLiteral + | STRING+ #stringLiteral + ; + +comparisonOperator + : EQ | NEQ | NEQJ | LT | LTE | GT | GTE | NSEQ + ; + +booleanValue + : TRUE | FALSE + ; + +interval + : INTERVAL intervalField* + ; + +intervalField + : value=intervalValue unit=identifier (TO to=identifier)? + ; + +intervalValue + : (PLUS | MINUS)? (INTEGER_VALUE | DECIMAL_VALUE) + | STRING + ; + +dataType + : complex=ARRAY '<' dataType '>' #complexDataType + | complex=MAP '<' dataType ',' dataType '>' #complexDataType + | complex=STRUCT ('<' colTypeList? '>' | NEQ) #complexDataType + | identifier ('(' INTEGER_VALUE (',' INTEGER_VALUE)* ')')? #primitiveDataType + ; + +colTypeList + : colType (',' colType)* + ; + +colType + : identifier ':'? dataType (COMMENT STRING)? + ; + +whenClause + : WHEN condition=expression THEN result=expression + ; + +windows + : WINDOW namedWindow (',' namedWindow)* + ; + +namedWindow + : identifier AS windowSpec + ; + +windowSpec + : name=identifier #windowRef + | '(' + ( CLUSTER BY partition+=expression (',' partition+=expression)* + | ((PARTITION | DISTRIBUTE) BY partition+=expression (',' partition+=expression)*)? + ((ORDER | SORT) BY sortItem (',' sortItem)*)?) + windowFrame? + ')' #windowDef + ; + +windowFrame + : frameType=RANGE start=frameBound + | frameType=ROWS start=frameBound + | frameType=RANGE BETWEEN start=frameBound AND end=frameBound + | frameType=ROWS BETWEEN start=frameBound AND end=frameBound + ; + +frameBound + : UNBOUNDED boundType=(PRECEDING | FOLLOWING) + | boundType=CURRENT ROW + | expression boundType=(PRECEDING | FOLLOWING) + ; + + +explainOption + : LOGICAL | FORMATTED | EXTENDED + ; + +transactionMode + : ISOLATION LEVEL SNAPSHOT #isolationLevel + | READ accessMode=(ONLY | WRITE) #transactionAccessMode + ; + +qualifiedName + : identifier ('.' identifier)* + ; + +// Identifier that also allows the use of a number of SQL keywords (mainly for backwards compatibility). +looseIdentifier + : identifier + | FROM + | TO + | TABLE + | WITH + ; + +identifier + : IDENTIFIER #unquotedIdentifier + | quotedIdentifier #quotedIdentifierAlternative + | nonReserved #unquotedIdentifier + ; + +quotedIdentifier + : BACKQUOTED_IDENTIFIER + ; + +number + : DECIMAL_VALUE #decimalLiteral + | SCIENTIFIC_DECIMAL_VALUE #scientificDecimalLiteral + | INTEGER_VALUE #integerLiteral + | BIGINT_LITERAL #bigIntLiteral + | SMALLINT_LITERAL #smallIntLiteral + | TINYINT_LITERAL #tinyIntLiteral + | DOUBLE_LITERAL #doubleLiteral + ; + +nonReserved + : SHOW | TABLES | COLUMNS | COLUMN | PARTITIONS | FUNCTIONS + | ADD + | OVER | PARTITION | RANGE | ROWS | PRECEDING | FOLLOWING | CURRENT | ROW | MAP | ARRAY | STRUCT + | LATERAL | WINDOW | REDUCE | TRANSFORM | USING | SERDE | SERDEPROPERTIES | RECORDREADER + | DELIMITED | FIELDS | TERMINATED | COLLECTION | ITEMS | KEYS | ESCAPED | LINES | SEPARATED + | EXTENDED | REFRESH | CLEAR | CACHE | UNCACHE | LAZY | TEMPORARY | OPTIONS + | GROUPING | CUBE | ROLLUP + | EXPLAIN | FORMAT | LOGICAL | FORMATTED + | TABLESAMPLE | USE | TO | BUCKET | PERCENTLIT | OUT | OF + | SET + | VIEW | REPLACE + | IF + | NO | DATA + | START | TRANSACTION | COMMIT | ROLLBACK | WORK | ISOLATION | LEVEL + | SNAPSHOT | READ | WRITE | ONLY + | SORT | CLUSTER | DISTRIBUTE UNSET | TBLPROPERTIES | SKEWED | STORED | DIRECTORIES | LOCATION + | EXCHANGE | ARCHIVE | UNARCHIVE | FILEFORMAT | TOUCH | COMPACT | CONCATENATE | CHANGE | FIRST + | AFTER | CASCADE | RESTRICT | BUCKETS | CLUSTERED | SORTED | PURGE | INPUTFORMAT | OUTPUTFORMAT + | INPUTDRIVER | OUTPUTDRIVER | DBPROPERTIES | DFS | TRUNCATE | METADATA | REPLICATION | COMPUTE + | STATISTICS | ANALYZE | PARTITIONED | EXTERNAL | DEFINED | RECORDWRITER + | REVOKE | GRANT | LOCK | UNLOCK | MSCK | EXPORT | IMPORT | LOAD | VALUES | COMMENT + ; + +SELECT: 'SELECT'; +FROM: 'FROM'; +ADD: 'ADD'; +AS: 'AS'; +ALL: 'ALL'; +DISTINCT: 'DISTINCT'; +WHERE: 'WHERE'; +GROUP: 'GROUP'; +BY: 'BY'; +GROUPING: 'GROUPING'; +SETS: 'SETS'; +CUBE: 'CUBE'; +ROLLUP: 'ROLLUP'; +ORDER: 'ORDER'; +HAVING: 'HAVING'; +LIMIT: 'LIMIT'; +AT: 'AT'; +OR: 'OR'; +AND: 'AND'; +IN: 'IN'; +NOT: 'NOT' | '!'; +NO: 'NO'; +EXISTS: 'EXISTS'; +BETWEEN: 'BETWEEN'; +LIKE: 'LIKE'; +RLIKE: 'RLIKE' | 'REGEXP'; +IS: 'IS'; +NULL: 'NULL'; +TRUE: 'TRUE'; +FALSE: 'FALSE'; +NULLS: 'NULLS'; +ASC: 'ASC'; +DESC: 'DESC'; +FOR: 'FOR'; +INTERVAL: 'INTERVAL'; +CASE: 'CASE'; +WHEN: 'WHEN'; +THEN: 'THEN'; +ELSE: 'ELSE'; +END: 'END'; +JOIN: 'JOIN'; +CROSS: 'CROSS'; +OUTER: 'OUTER'; +INNER: 'INNER'; +LEFT: 'LEFT'; +SEMI: 'SEMI'; +RIGHT: 'RIGHT'; +FULL: 'FULL'; +NATURAL: 'NATURAL'; +ON: 'ON'; +LATERAL: 'LATERAL'; +WINDOW: 'WINDOW'; +OVER: 'OVER'; +PARTITION: 'PARTITION'; +RANGE: 'RANGE'; +ROWS: 'ROWS'; +UNBOUNDED: 'UNBOUNDED'; +PRECEDING: 'PRECEDING'; +FOLLOWING: 'FOLLOWING'; +CURRENT: 'CURRENT'; +ROW: 'ROW'; +WITH: 'WITH'; +VALUES: 'VALUES'; +CREATE: 'CREATE'; +TABLE: 'TABLE'; +VIEW: 'VIEW'; +REPLACE: 'REPLACE'; +INSERT: 'INSERT'; +DELETE: 'DELETE'; +INTO: 'INTO'; +DESCRIBE: 'DESCRIBE'; +EXPLAIN: 'EXPLAIN'; +FORMAT: 'FORMAT'; +LOGICAL: 'LOGICAL'; +CAST: 'CAST'; +SHOW: 'SHOW'; +TABLES: 'TABLES'; +COLUMNS: 'COLUMNS'; +COLUMN: 'COLUMN'; +USE: 'USE'; +PARTITIONS: 'PARTITIONS'; +FUNCTIONS: 'FUNCTIONS'; +DROP: 'DROP'; +UNION: 'UNION'; +EXCEPT: 'EXCEPT'; +INTERSECT: 'INTERSECT'; +TO: 'TO'; +TABLESAMPLE: 'TABLESAMPLE'; +STRATIFY: 'STRATIFY'; +ALTER: 'ALTER'; +RENAME: 'RENAME'; +ARRAY: 'ARRAY'; +MAP: 'MAP'; +STRUCT: 'STRUCT'; +COMMENT: 'COMMENT'; +SET: 'SET'; +DATA: 'DATA'; +START: 'START'; +TRANSACTION: 'TRANSACTION'; +COMMIT: 'COMMIT'; +ROLLBACK: 'ROLLBACK'; +WORK: 'WORK'; +ISOLATION: 'ISOLATION'; +LEVEL: 'LEVEL'; +SNAPSHOT: 'SNAPSHOT'; +READ: 'READ'; +WRITE: 'WRITE'; +ONLY: 'ONLY'; + +IF: 'IF'; + +EQ : '=' | '=='; +NSEQ: '<=>'; +NEQ : '<>'; +NEQJ: '!='; +LT : '<'; +LTE : '<='; +GT : '>'; +GTE : '>='; + +PLUS: '+'; +MINUS: '-'; +ASTERISK: '*'; +SLASH: '/'; +PERCENT: '%'; +DIV: 'DIV'; +TILDE: '~'; +AMPERSAND: '&'; +PIPE: '|'; +HAT: '^'; + +PERCENTLIT: 'PERCENT'; +BUCKET: 'BUCKET'; +OUT: 'OUT'; +OF: 'OF'; + +SORT: 'SORT'; +CLUSTER: 'CLUSTER'; +DISTRIBUTE: 'DISTRIBUTE'; +OVERWRITE: 'OVERWRITE'; +TRANSFORM: 'TRANSFORM'; +REDUCE: 'REDUCE'; +USING: 'USING'; +SERDE: 'SERDE'; +SERDEPROPERTIES: 'SERDEPROPERTIES'; +RECORDREADER: 'RECORDREADER'; +RECORDWRITER: 'RECORDWRITER'; +DELIMITED: 'DELIMITED'; +FIELDS: 'FIELDS'; +TERMINATED: 'TERMINATED'; +COLLECTION: 'COLLECTION'; +ITEMS: 'ITEMS'; +KEYS: 'KEYS'; +ESCAPED: 'ESCAPED'; +LINES: 'LINES'; +SEPARATED: 'SEPARATED'; +FUNCTION: 'FUNCTION'; +EXTENDED: 'EXTENDED'; +REFRESH: 'REFRESH'; +CLEAR: 'CLEAR'; +CACHE: 'CACHE'; +UNCACHE: 'UNCACHE'; +LAZY: 'LAZY'; +FORMATTED: 'FORMATTED'; +TEMPORARY: 'TEMPORARY' | 'TEMP'; +OPTIONS: 'OPTIONS'; +UNSET: 'UNSET'; +TBLPROPERTIES: 'TBLPROPERTIES'; +DBPROPERTIES: 'DBPROPERTIES'; +BUCKETS: 'BUCKETS'; +SKEWED: 'SKEWED'; +STORED: 'STORED'; +DIRECTORIES: 'DIRECTORIES'; +LOCATION: 'LOCATION'; +EXCHANGE: 'EXCHANGE'; +ARCHIVE: 'ARCHIVE'; +UNARCHIVE: 'UNARCHIVE'; +FILEFORMAT: 'FILEFORMAT'; +TOUCH: 'TOUCH'; +COMPACT: 'COMPACT'; +CONCATENATE: 'CONCATENATE'; +CHANGE: 'CHANGE'; +FIRST: 'FIRST'; +AFTER: 'AFTER'; +CASCADE: 'CASCADE'; +RESTRICT: 'RESTRICT'; +CLUSTERED: 'CLUSTERED'; +SORTED: 'SORTED'; +PURGE: 'PURGE'; +INPUTFORMAT: 'INPUTFORMAT'; +OUTPUTFORMAT: 'OUTPUTFORMAT'; +INPUTDRIVER: 'INPUTDRIVER'; +OUTPUTDRIVER: 'OUTPUTDRIVER'; +DATABASE: 'DATABASE' | 'SCHEMA'; +DFS: 'DFS'; +TRUNCATE: 'TRUNCATE'; +METADATA: 'METADATA'; +REPLICATION: 'REPLICATION'; +ANALYZE: 'ANALYZE'; +COMPUTE: 'COMPUTE'; +STATISTICS: 'STATISTICS'; +PARTITIONED: 'PARTITIONED'; +EXTERNAL: 'EXTERNAL'; +DEFINED: 'DEFINED'; +REVOKE: 'REVOKE'; +GRANT: 'GRANT'; +LOCK: 'LOCK'; +UNLOCK: 'UNLOCK'; +MSCK: 'MSCK'; +EXPORT: 'EXPORT'; +IMPORT: 'IMPORT'; +LOAD: 'LOAD'; + +STRING + : '\'' ( ~('\''|'\\') | ('\\' .) )* '\'' + | '\"' ( ~('\"'|'\\') | ('\\' .) )* '\"' + ; + +BIGINT_LITERAL + : DIGIT+ 'L' + ; + +SMALLINT_LITERAL + : DIGIT+ 'S' + ; + +TINYINT_LITERAL + : DIGIT+ 'Y' + ; + +INTEGER_VALUE + : DIGIT+ + ; + +DECIMAL_VALUE + : DIGIT+ '.' DIGIT* + | '.' DIGIT+ + ; + +SCIENTIFIC_DECIMAL_VALUE + : DIGIT+ ('.' DIGIT*)? EXPONENT + | '.' DIGIT+ EXPONENT + ; + +DOUBLE_LITERAL + : + (INTEGER_VALUE | DECIMAL_VALUE | SCIENTIFIC_DECIMAL_VALUE) 'D' + ; + +IDENTIFIER + : (LETTER | DIGIT | '_')+ + ; + +BACKQUOTED_IDENTIFIER + : '`' ( ~'`' | '``' )* '`' + ; + +fragment EXPONENT + : 'E' [+-]? DIGIT+ + ; + +fragment DIGIT + : [0-9] + ; + +fragment LETTER + : [A-Z] + ; + +SIMPLE_COMMENT + : '--' ~[\r\n]* '\r'? '\n'? -> channel(HIDDEN) + ; + +BRACKETED_COMMENT + : '/*' .*? '*/' -> channel(HIDDEN) + ; + +WS + : [ \r\n\t]+ -> channel(HIDDEN) + ; + +// Catch-all for anything we can't recognize. +// We use this to be able to ignore and recover all the text +// when splitting statements with DelimiterLexer +UNRECOGNIZED + : . + ; diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/dsl/package.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/dsl/package.scala index 3540014c3e..105947028d 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/dsl/package.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/dsl/package.scala @@ -21,7 +21,7 @@ import java.sql.{Date, Timestamp} import scala.language.implicitConversions -import org.apache.spark.sql.catalyst.analysis.{EliminateSubqueryAliases, UnresolvedAttribute, UnresolvedExtractValue} +import org.apache.spark.sql.catalyst.analysis._ import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.expressions.aggregate._ import org.apache.spark.sql.catalyst.plans.{Inner, JoinType} @@ -161,6 +161,10 @@ package object dsl { def lower(e: Expression): Expression = Lower(e) def sqrt(e: Expression): Expression = Sqrt(e) def abs(e: Expression): Expression = Abs(e) + def star(names: String*): Expression = names match { + case Seq() => UnresolvedStar(None) + case target => UnresolvedStar(Option(target)) + } implicit class DslSymbol(sym: Symbol) extends ImplicitAttribute { def s: String = sym.name } // TODO more implicit class for literal? @@ -231,6 +235,12 @@ package object dsl { AttributeReference(s, structType, nullable = true)() def struct(attrs: AttributeReference*): AttributeReference = struct(StructType.fromAttributes(attrs)) + + /** Create a function. */ + def function(exprs: Expression*): UnresolvedFunction = + UnresolvedFunction(s, exprs, isDistinct = false) + def distinctFunction(exprs: Expression*): UnresolvedFunction = + UnresolvedFunction(s, exprs, isDistinct = true) } implicit class DslAttribute(a: AttributeReference) { @@ -243,8 +253,20 @@ package object dsl { object expressions extends ExpressionConversions // scalastyle:ignore object plans { // scalastyle:ignore + def table(ref: String): LogicalPlan = + UnresolvedRelation(TableIdentifier(ref), None) + + def table(db: String, ref: String): LogicalPlan = + UnresolvedRelation(TableIdentifier(ref, Option(db)), None) + implicit class DslLogicalPlan(val logicalPlan: LogicalPlan) { - def select(exprs: NamedExpression*): LogicalPlan = Project(exprs, logicalPlan) + def select(exprs: Expression*): LogicalPlan = { + val namedExpressions = exprs.map { + case e: NamedExpression => e + case e => UnresolvedAlias(e) + } + Project(namedExpressions, logicalPlan) + } def where(condition: Expression): LogicalPlan = Filter(condition, logicalPlan) @@ -296,6 +318,14 @@ package object dsl { analysis.UnresolvedRelation(TableIdentifier(tableName)), Map.empty, logicalPlan, overwrite, false) + def as(alias: String): LogicalPlan = logicalPlan match { + case UnresolvedRelation(tbl, _) => UnresolvedRelation(tbl, Option(alias)) + case plan => SubqueryAlias(alias, plan) + } + + def distribute(exprs: Expression*): LogicalPlan = + RepartitionByExpression(exprs, logicalPlan) + def analyze: LogicalPlan = EliminateSubqueryAliases(analysis.SimpleAnalyzer.execute(logicalPlan)) } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/ng/AstBuilder.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/ng/AstBuilder.scala new file mode 100644 index 0000000000..5a64c414fb --- /dev/null +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/ng/AstBuilder.scala @@ -0,0 +1,1452 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.spark.sql.catalyst.parser.ng + +import java.sql.{Date, Timestamp} + +import scala.collection.JavaConverters._ +import scala.collection.mutable.ArrayBuffer + +import org.antlr.v4.runtime.{ParserRuleContext, Token} +import org.antlr.v4.runtime.tree.{ParseTree, TerminalNode} + +import org.apache.spark.internal.Logging +import org.apache.spark.sql.catalyst.{InternalRow, TableIdentifier} +import org.apache.spark.sql.catalyst.analysis._ +import org.apache.spark.sql.catalyst.expressions._ +import org.apache.spark.sql.catalyst.parser.ng.SqlBaseParser._ +import org.apache.spark.sql.catalyst.plans._ +import org.apache.spark.sql.catalyst.plans.logical._ +import org.apache.spark.sql.types._ +import org.apache.spark.unsafe.types.CalendarInterval +import org.apache.spark.util.random.RandomSampler + +/** + * The AstBuilder converts an ANTLR4 ParseTree into a catalyst Expression, LogicalPlan or + * TableIdentifier. + */ +class AstBuilder extends SqlBaseBaseVisitor[AnyRef] with Logging { + import ParserUtils._ + + protected def typedVisit[T](ctx: ParseTree): T = { + ctx.accept(this).asInstanceOf[T] + } + + override def visitSingleStatement(ctx: SingleStatementContext): LogicalPlan = withOrigin(ctx) { + visit(ctx.statement).asInstanceOf[LogicalPlan] + } + + override def visitSingleExpression(ctx: SingleExpressionContext): Expression = withOrigin(ctx) { + visitNamedExpression(ctx.namedExpression) + } + + override def visitSingleTableIdentifier( + ctx: SingleTableIdentifierContext): TableIdentifier = withOrigin(ctx) { + visitTableIdentifier(ctx.tableIdentifier) + } + + override def visitSingleDataType(ctx: SingleDataTypeContext): DataType = withOrigin(ctx) { + visit(ctx.dataType).asInstanceOf[DataType] + } + + /* ******************************************************************************************** + * Plan parsing + * ******************************************************************************************** */ + protected def plan(tree: ParserRuleContext): LogicalPlan = typedVisit(tree) + + /** + * Make sure we do not try to create a plan for a native command. + */ + override def visitExecuteNativeCommand(ctx: ExecuteNativeCommandContext): LogicalPlan = null + + /** + * Create a plan for a SHOW FUNCTIONS command. + */ + override def visitShowFunctions(ctx: ShowFunctionsContext): LogicalPlan = withOrigin(ctx) { + import ctx._ + if (qualifiedName != null) { + val names = qualifiedName().identifier().asScala.map(_.getText).toList + names match { + case db :: name :: Nil => + ShowFunctions(Some(db), Some(name)) + case name :: Nil => + ShowFunctions(None, Some(name)) + case _ => + throw new ParseException("SHOW FUNCTIONS unsupported name", ctx) + } + } else if (pattern != null) { + ShowFunctions(None, Some(string(pattern))) + } else { + ShowFunctions(None, None) + } + } + + /** + * Create a plan for a DESCRIBE FUNCTION command. + */ + override def visitDescribeFunction(ctx: DescribeFunctionContext): LogicalPlan = withOrigin(ctx) { + val functionName = ctx.qualifiedName().identifier().asScala.map(_.getText).mkString(".") + DescribeFunction(functionName, ctx.EXTENDED != null) + } + + /** + * Create a top-level plan with Common Table Expressions. + */ + override def visitQuery(ctx: QueryContext): LogicalPlan = withOrigin(ctx) { + val query = plan(ctx.queryNoWith) + + // Apply CTEs + query.optional(ctx.ctes) { + val ctes = ctx.ctes.namedQuery.asScala.map { + case nCtx => + val namedQuery = visitNamedQuery(nCtx) + (namedQuery.alias, namedQuery) + } + + // Check for duplicate names. + ctes.groupBy(_._1).filter(_._2.size > 1).foreach { + case (name, _) => + throw new ParseException( + s"Name '$name' is used for multiple common table expressions", ctx) + } + + With(query, ctes.toMap) + } + } + + /** + * Create a named logical plan. + * + * This is only used for Common Table Expressions. + */ + override def visitNamedQuery(ctx: NamedQueryContext): SubqueryAlias = withOrigin(ctx) { + SubqueryAlias(ctx.name.getText, plan(ctx.queryNoWith)) + } + + /** + * Create a logical plan which allows for multiple inserts using one 'from' statement. These + * queries have the following SQL form: + * {{{ + * [WITH cte...]? + * FROM src + * [INSERT INTO tbl1 SELECT *]+ + * }}} + * For example: + * {{{ + * FROM db.tbl1 A + * INSERT INTO dbo.tbl1 SELECT * WHERE A.value = 10 LIMIT 5 + * INSERT INTO dbo.tbl2 SELECT * WHERE A.value = 12 + * }}} + * This (Hive) feature cannot be combined with set-operators. + */ + override def visitMultiInsertQuery(ctx: MultiInsertQueryContext): LogicalPlan = withOrigin(ctx) { + val from = visitFromClause(ctx.fromClause) + + // Build the insert clauses. + val inserts = ctx.multiInsertQueryBody.asScala.map { + body => + assert(body.querySpecification.fromClause == null, + "Multi-Insert queries cannot have a FROM clause in their individual SELECT statements", + body) + + withQuerySpecification(body.querySpecification, from). + // Add organization statements. + optionalMap(body.queryOrganization)(withQueryResultClauses). + // Add insert. + optionalMap(body.insertInto())(withInsertInto) + } + + // If there are multiple INSERTS just UNION them together into one query. + inserts match { + case Seq(query) => query + case queries => Union(queries) + } + } + + /** + * Create a logical plan for a regular (single-insert) query. + */ + override def visitSingleInsertQuery( + ctx: SingleInsertQueryContext): LogicalPlan = withOrigin(ctx) { + plan(ctx.queryTerm). + // Add organization statements. + optionalMap(ctx.queryOrganization)(withQueryResultClauses). + // Add insert. + optionalMap(ctx.insertInto())(withInsertInto) + } + + /** + * Add an INSERT INTO [TABLE]/INSERT OVERWRITE TABLE operation to the logical plan. + */ + private def withInsertInto( + ctx: InsertIntoContext, + query: LogicalPlan): LogicalPlan = withOrigin(ctx) { + val tableIdent = visitTableIdentifier(ctx.tableIdentifier) + val partitionKeys = Option(ctx.partitionSpec).map(visitPartitionSpec).getOrElse(Map.empty) + + InsertIntoTable( + UnresolvedRelation(tableIdent, None), + partitionKeys, + query, + ctx.OVERWRITE != null, + ctx.EXISTS != null) + } + + /** + * Create a partition specification map. + */ + override def visitPartitionSpec( + ctx: PartitionSpecContext): Map[String, Option[String]] = withOrigin(ctx) { + ctx.partitionVal.asScala.map { pVal => + val name = pVal.identifier.getText.toLowerCase + val value = Option(pVal.constant).map(visitStringConstant) + name -> value + }.toMap + } + + /** + * Create a partition specification map without optional values. + */ + protected def visitNonOptionalPartitionSpec( + ctx: PartitionSpecContext): Map[String, String] = withOrigin(ctx) { + visitPartitionSpec(ctx).mapValues(_.orNull).map(identity) + } + + /** + * Convert a constant of any type into a string. This is typically used in DDL commands, and its + * main purpose is to prevent slight differences due to back to back conversions i.e.: + * String -> Literal -> String. + */ + protected def visitStringConstant(ctx: ConstantContext): String = withOrigin(ctx) { + ctx match { + case s: StringLiteralContext => createString(s) + case o => o.getText + } + } + + /** + * Add ORDER BY/SORT BY/CLUSTER BY/DISTRIBUTE BY/LIMIT/WINDOWS clauses to the logical plan. These + * clauses determine the shape (ordering/partitioning/rows) of the query result. + */ + private def withQueryResultClauses( + ctx: QueryOrganizationContext, + query: LogicalPlan): LogicalPlan = withOrigin(ctx) { + import ctx._ + + // Handle ORDER BY, SORT BY, DISTRIBUTE BY, and CLUSTER BY clause. + val withOrder = if ( + !order.isEmpty && sort.isEmpty && distributeBy.isEmpty && clusterBy.isEmpty) { + // ORDER BY ... + Sort(order.asScala.map(visitSortItem), global = true, query) + } else if (order.isEmpty && !sort.isEmpty && distributeBy.isEmpty && clusterBy.isEmpty) { + // SORT BY ... + Sort(sort.asScala.map(visitSortItem), global = false, query) + } else if (order.isEmpty && sort.isEmpty && !distributeBy.isEmpty && clusterBy.isEmpty) { + // DISTRIBUTE BY ... + RepartitionByExpression(expressionList(distributeBy), query) + } else if (order.isEmpty && !sort.isEmpty && !distributeBy.isEmpty && clusterBy.isEmpty) { + // SORT BY ... DISTRIBUTE BY ... + Sort( + sort.asScala.map(visitSortItem), + global = false, + RepartitionByExpression(expressionList(distributeBy), query)) + } else if (order.isEmpty && sort.isEmpty && distributeBy.isEmpty && !clusterBy.isEmpty) { + // CLUSTER BY ... + val expressions = expressionList(clusterBy) + Sort( + expressions.map(SortOrder(_, Ascending)), + global = false, + RepartitionByExpression(expressions, query)) + } else if (order.isEmpty && sort.isEmpty && distributeBy.isEmpty && clusterBy.isEmpty) { + // [EMPTY] + query + } else { + throw new ParseException( + "Combination of ORDER BY/SORT BY/DISTRIBUTE BY/CLUSTER BY is not supported", ctx) + } + + // WINDOWS + val withWindow = withOrder.optionalMap(windows)(withWindows) + + // LIMIT + withWindow.optional(limit) { + Limit(typedVisit(limit), withWindow) + } + } + + /** + * Create a logical plan using a query specification. + */ + override def visitQuerySpecification( + ctx: QuerySpecificationContext): LogicalPlan = withOrigin(ctx) { + val from = OneRowRelation.optional(ctx.fromClause) { + visitFromClause(ctx.fromClause) + } + withQuerySpecification(ctx, from) + } + + /** + * Add a query specification to a logical plan. The query specification is the core of the logical + * plan, this is where sourcing (FROM clause), transforming (SELECT TRANSFORM/MAP/REDUCE), + * projection (SELECT), aggregation (GROUP BY ... HAVING ...) and filtering (WHERE) takes place. + * + * Note that query hints are ignored (both by the parser and the builder). + */ + private def withQuerySpecification( + ctx: QuerySpecificationContext, + relation: LogicalPlan): LogicalPlan = withOrigin(ctx) { + import ctx._ + + // WHERE + def filter(ctx: BooleanExpressionContext, plan: LogicalPlan): LogicalPlan = { + Filter(expression(ctx), plan) + } + + // Expressions. + val expressions = Option(namedExpressionSeq).toSeq + .flatMap(_.namedExpression.asScala) + .map(typedVisit[Expression]) + + // Create either a transform or a regular query. + val specType = Option(kind).map(_.getType).getOrElse(SqlBaseParser.SELECT) + specType match { + case SqlBaseParser.MAP | SqlBaseParser.REDUCE | SqlBaseParser.TRANSFORM => + // Transform + + // Add where. + val withFilter = relation.optionalMap(where)(filter) + + // Create the attributes. + val (attributes, schemaLess) = if (colTypeList != null) { + // Typed return columns. + (createStructType(colTypeList).toAttributes, false) + } else if (identifierSeq != null) { + // Untyped return columns. + val attrs = visitIdentifierSeq(identifierSeq).map { name => + AttributeReference(name, StringType, nullable = true)() + } + (attrs, false) + } else { + (Seq(AttributeReference("key", StringType)(), + AttributeReference("value", StringType)()), true) + } + + // Create the transform. + ScriptTransformation( + expressions, + string(script), + attributes, + withFilter, + withScriptIOSchema(inRowFormat, recordWriter, outRowFormat, recordReader, schemaLess)) + + case SqlBaseParser.SELECT => + // Regular select + + // Add lateral views. + val withLateralView = ctx.lateralView.asScala.foldLeft(relation)(withGenerate) + + // Add where. + val withFilter = withLateralView.optionalMap(where)(filter) + + // Add aggregation or a project. + val namedExpressions = expressions.map { + case e: NamedExpression => e + case e: Expression => UnresolvedAlias(e) + } + val withProject = if (aggregation != null) { + withAggregation(aggregation, namedExpressions, withFilter) + } else if (namedExpressions.nonEmpty) { + Project(namedExpressions, withFilter) + } else { + withFilter + } + + // Having + val withHaving = withProject.optional(having) { + // Note that we added a cast to boolean. If the expression itself is already boolean, + // the optimizer will get rid of the unnecessary cast. + Filter(Cast(expression(having), BooleanType), withProject) + } + + // Distinct + val withDistinct = if (setQuantifier() != null && setQuantifier().DISTINCT() != null) { + Distinct(withHaving) + } else { + withHaving + } + + // Window + withDistinct.optionalMap(windows)(withWindows) + } + } + + /** + * Create a (Hive based) [[ScriptInputOutputSchema]]. + */ + protected def withScriptIOSchema( + inRowFormat: RowFormatContext, + recordWriter: Token, + outRowFormat: RowFormatContext, + recordReader: Token, + schemaLess: Boolean): ScriptInputOutputSchema = null + + /** + * Create a logical plan for a given 'FROM' clause. Note that we support multiple (comma + * separated) relations here, these get converted into a single plan by condition-less inner join. + */ + override def visitFromClause(ctx: FromClauseContext): LogicalPlan = withOrigin(ctx) { + val from = ctx.relation.asScala.map(plan).reduceLeft(Join(_, _, Inner, None)) + ctx.lateralView.asScala.foldLeft(from)(withGenerate) + } + + /** + * Connect two queries by a Set operator. + * + * Supported Set operators are: + * - UNION [DISTINCT] + * - UNION ALL + * - EXCEPT [DISTINCT] + * - INTERSECT [DISTINCT] + */ + override def visitSetOperation(ctx: SetOperationContext): LogicalPlan = withOrigin(ctx) { + val left = plan(ctx.left) + val right = plan(ctx.right) + val all = Option(ctx.setQuantifier()).exists(_.ALL != null) + ctx.operator.getType match { + case SqlBaseParser.UNION if all => + Union(left, right) + case SqlBaseParser.UNION => + Distinct(Union(left, right)) + case SqlBaseParser.INTERSECT if all => + throw new ParseException("INTERSECT ALL is not supported.", ctx) + case SqlBaseParser.INTERSECT => + Intersect(left, right) + case SqlBaseParser.EXCEPT if all => + throw new ParseException("EXCEPT ALL is not supported.", ctx) + case SqlBaseParser.EXCEPT => + Except(left, right) + } + } + + /** + * Add a [[WithWindowDefinition]] operator to a logical plan. + */ + private def withWindows( + ctx: WindowsContext, + query: LogicalPlan): LogicalPlan = withOrigin(ctx) { + // Collect all window specifications defined in the WINDOW clause. + val baseWindowMap = ctx.namedWindow.asScala.map { + wCtx => + (wCtx.identifier.getText, typedVisit[WindowSpec](wCtx.windowSpec)) + }.toMap + + // Handle cases like + // window w1 as (partition by p_mfgr order by p_name + // range between 2 preceding and 2 following), + // w2 as w1 + val windowMapView = baseWindowMap.mapValues { + case WindowSpecReference(name) => + baseWindowMap.get(name) match { + case Some(spec: WindowSpecDefinition) => + spec + case Some(ref) => + throw new ParseException(s"Window reference '$name' is not a window specification", ctx) + case None => + throw new ParseException(s"Cannot resolve window reference '$name'", ctx) + } + case spec: WindowSpecDefinition => spec + } + + // Note that mapValues creates a view instead of materialized map. We force materialization by + // mapping over identity. + WithWindowDefinition(windowMapView.map(identity), query) + } + + /** + * Add an [[Aggregate]] to a logical plan. + */ + private def withAggregation( + ctx: AggregationContext, + selectExpressions: Seq[NamedExpression], + query: LogicalPlan): LogicalPlan = withOrigin(ctx) { + import ctx._ + val groupByExpressions = expressionList(groupingExpressions) + + if (GROUPING != null) { + // GROUP BY .... GROUPING SETS (...) + val expressionMap = groupByExpressions.zipWithIndex.toMap + val numExpressions = expressionMap.size + val mask = (1 << numExpressions) - 1 + val masks = ctx.groupingSet.asScala.map { + _.expression.asScala.foldLeft(mask) { + case (bitmap, eCtx) => + // Find the index of the expression. + val e = typedVisit[Expression](eCtx) + val index = expressionMap.find(_._1.semanticEquals(e)).map(_._2).getOrElse( + throw new ParseException( + s"$e doesn't show up in the GROUP BY list", ctx)) + // 0 means that the column at the given index is a grouping column, 1 means it is not, + // so we unset the bit in bitmap. + bitmap & ~(1 << (numExpressions - 1 - index)) + } + } + GroupingSets(masks, groupByExpressions, query, selectExpressions) + } else { + // GROUP BY .... (WITH CUBE | WITH ROLLUP)? + val mappedGroupByExpressions = if (CUBE != null) { + Seq(Cube(groupByExpressions)) + } else if (ROLLUP != null) { + Seq(Rollup(groupByExpressions)) + } else { + groupByExpressions + } + Aggregate(mappedGroupByExpressions, selectExpressions, query) + } + } + + /** + * Add a [[Generate]] (Lateral View) to a logical plan. + */ + private def withGenerate( + query: LogicalPlan, + ctx: LateralViewContext): LogicalPlan = withOrigin(ctx) { + val expressions = expressionList(ctx.expression) + + // Create the generator. + val generator = ctx.qualifiedName.getText.toLowerCase match { + case "explode" if expressions.size == 1 => + Explode(expressions.head) + case "json_tuple" => + JsonTuple(expressions) + case other => + withGenerator(other, expressions, ctx) + } + + Generate( + generator, + join = true, + outer = ctx.OUTER != null, + Some(ctx.tblName.getText.toLowerCase), + ctx.colName.asScala.map(_.getText).map(UnresolvedAttribute.apply), + query) + } + + /** + * Create a [[Generator]]. Override this method in order to support custom Generators. + */ + protected def withGenerator( + name: String, + expressions: Seq[Expression], + ctx: LateralViewContext): Generator = { + throw new ParseException(s"Generator function '$name' is not supported", ctx) + } + + /** + * Create a joins between two or more logical plans. + */ + override def visitJoinRelation(ctx: JoinRelationContext): LogicalPlan = withOrigin(ctx) { + /** Build a join between two plans. */ + def join(ctx: JoinRelationContext, left: LogicalPlan, right: LogicalPlan): Join = { + val baseJoinType = ctx.joinType match { + case null => Inner + case jt if jt.FULL != null => FullOuter + case jt if jt.SEMI != null => LeftSemi + case jt if jt.LEFT != null => LeftOuter + case jt if jt.RIGHT != null => RightOuter + case _ => Inner + } + + // Resolve the join type and join condition + val (joinType, condition) = Option(ctx.joinCriteria) match { + case Some(c) if c.USING != null => + val columns = c.identifier.asScala.map { column => + UnresolvedAttribute.quoted(column.getText) + } + (UsingJoin(baseJoinType, columns), None) + case Some(c) if c.booleanExpression != null => + (baseJoinType, Option(expression(c.booleanExpression))) + case None if ctx.NATURAL != null => + (NaturalJoin(baseJoinType), None) + case None => + (baseJoinType, None) + } + Join(left, right, joinType, condition) + } + + // Handle all consecutive join clauses. ANTLR produces a right nested tree in which the the + // first join clause is at the top. However fields of previously referenced tables can be used + // in following join clauses. The tree needs to be reversed in order to make this work. + var result = plan(ctx.left) + var current = ctx + while (current != null) { + current.right match { + case right: JoinRelationContext => + result = join(current, result, plan(right.left)) + current = right + case right => + result = join(current, result, plan(right)) + current = null + } + } + result + } + + /** + * Add a [[Sample]] to a logical plan. + * + * This currently supports the following sampling methods: + * - TABLESAMPLE(x ROWS): Sample the table down to the given number of rows. + * - TABLESAMPLE(x PERCENT): Sample the table down to the given percentage. Note that percentages + * are defined as a number between 0 and 100. + * - TABLESAMPLE(BUCKET x OUT OF y): Sample the table down to a 'x' divided by 'y' fraction. + */ + private def withSample(ctx: SampleContext, query: LogicalPlan): LogicalPlan = withOrigin(ctx) { + // Create a sampled plan if we need one. + def sample(fraction: Double): Sample = { + // The range of fraction accepted by Sample is [0, 1]. Because Hive's block sampling + // function takes X PERCENT as the input and the range of X is [0, 100], we need to + // adjust the fraction. + val eps = RandomSampler.roundingEpsilon + assert(fraction >= 0.0 - eps && fraction <= 1.0 + eps, + s"Sampling fraction ($fraction) must be on interval [0, 1]", + ctx) + Sample(0.0, fraction, withReplacement = false, (math.random * 1000).toInt, query)(true) + } + + ctx.sampleType.getType match { + case SqlBaseParser.ROWS => + Limit(expression(ctx.expression), query) + + case SqlBaseParser.PERCENTLIT => + val fraction = ctx.percentage.getText.toDouble + sample(fraction / 100.0d) + + case SqlBaseParser.BUCKET if ctx.ON != null => + throw new ParseException("TABLESAMPLE(BUCKET x OUT OF y ON id) is not supported", ctx) + + case SqlBaseParser.BUCKET => + sample(ctx.numerator.getText.toDouble / ctx.denominator.getText.toDouble) + } + } + + /** + * Create a logical plan for a sub-query. + */ + override def visitSubquery(ctx: SubqueryContext): LogicalPlan = withOrigin(ctx) { + plan(ctx.queryNoWith) + } + + /** + * Create an un-aliased table reference. This is typically used for top-level table references, + * for example: + * {{{ + * INSERT INTO db.tbl2 + * TABLE db.tbl1 + * }}} + */ + override def visitTable(ctx: TableContext): LogicalPlan = withOrigin(ctx) { + UnresolvedRelation(visitTableIdentifier(ctx.tableIdentifier), None) + } + + /** + * Create an aliased table reference. This is typically used in FROM clauses. + */ + override def visitTableName(ctx: TableNameContext): LogicalPlan = withOrigin(ctx) { + val table = UnresolvedRelation( + visitTableIdentifier(ctx.tableIdentifier), + Option(ctx.identifier).map(_.getText)) + table.optionalMap(ctx.sample)(withSample) + } + + /** + * Create an inline table (a virtual table in Hive parlance). + */ + override def visitInlineTable(ctx: InlineTableContext): LogicalPlan = withOrigin(ctx) { + // Get the backing expressions. + val expressions = ctx.expression.asScala.map { eCtx => + val e = expression(eCtx) + assert(e.foldable, "All expressions in an inline table must be constants.", eCtx) + e + } + + // Validate and evaluate the rows. + val (structType, structConstructor) = expressions.head.dataType match { + case st: StructType => + (st, (e: Expression) => e) + case dt => + val st = CreateStruct(Seq(expressions.head)).dataType + (st, (e: Expression) => CreateStruct(Seq(e))) + } + val rows = expressions.map { + case expression => + val safe = Cast(structConstructor(expression), structType) + safe.eval().asInstanceOf[InternalRow] + } + + // Construct attributes. + val baseAttributes = structType.toAttributes.map(_.withNullability(true)) + val attributes = if (ctx.identifierList != null) { + val aliases = visitIdentifierList(ctx.identifierList) + assert(aliases.size == baseAttributes.size, + "Number of aliases must match the number of fields in an inline table.", ctx) + baseAttributes.zip(aliases).map(p => p._1.withName(p._2)) + } else { + baseAttributes + } + + // Create plan and add an alias if a name has been defined. + LocalRelation(attributes, rows).optionalMap(ctx.identifier)(aliasPlan) + } + + /** + * Create an alias (SubqueryAlias) for a join relation. This is practically the same as + * visitAliasedQuery and visitNamedExpression, ANTLR4 however requires us to use 3 different + * hooks. + */ + override def visitAliasedRelation(ctx: AliasedRelationContext): LogicalPlan = withOrigin(ctx) { + plan(ctx.relation).optionalMap(ctx.sample)(withSample).optionalMap(ctx.identifier)(aliasPlan) + } + + /** + * Create an alias (SubqueryAlias) for a sub-query. This is practically the same as + * visitAliasedRelation and visitNamedExpression, ANTLR4 however requires us to use 3 different + * hooks. + */ + override def visitAliasedQuery(ctx: AliasedQueryContext): LogicalPlan = withOrigin(ctx) { + plan(ctx.queryNoWith).optionalMap(ctx.sample)(withSample).optionalMap(ctx.identifier)(aliasPlan) + } + + /** + * Create an alias (SubqueryAlias) for a LogicalPlan. + */ + private def aliasPlan(alias: IdentifierContext, plan: LogicalPlan): LogicalPlan = { + SubqueryAlias(alias.getText, plan) + } + + /** + * Create a Sequence of Strings for a parenthesis enclosed alias list. + */ + override def visitIdentifierList(ctx: IdentifierListContext): Seq[String] = withOrigin(ctx) { + visitIdentifierSeq(ctx.identifierSeq) + } + + /** + * Create a Sequence of Strings for an identifier list. + */ + override def visitIdentifierSeq(ctx: IdentifierSeqContext): Seq[String] = withOrigin(ctx) { + ctx.identifier.asScala.map(_.getText) + } + + /* ******************************************************************************************** + * Table Identifier parsing + * ******************************************************************************************** */ + /** + * Create a [[TableIdentifier]] from a 'tableName' or 'databaseName'.'tableName' pattern. + */ + override def visitTableIdentifier( + ctx: TableIdentifierContext): TableIdentifier = withOrigin(ctx) { + TableIdentifier(ctx.table.getText, Option(ctx.db).map(_.getText)) + } + + /* ******************************************************************************************** + * Expression parsing + * ******************************************************************************************** */ + /** + * Create an expression from the given context. This method just passes the context on to the + * vistor and only takes care of typing (We assume that the visitor returns an Expression here). + */ + protected def expression(ctx: ParserRuleContext): Expression = typedVisit(ctx) + + /** + * Create sequence of expressions from the given sequence of contexts. + */ + private def expressionList(trees: java.util.List[ExpressionContext]): Seq[Expression] = { + trees.asScala.map(expression) + } + + /** + * Invert a boolean expression if it has a valid NOT clause. + */ + private def invertIfNotDefined(expression: Expression, not: TerminalNode): Expression = { + if (not != null) { + Not(expression) + } else { + expression + } + } + + /** + * Create a star (i.e. all) expression; this selects all elements (in the specified object). + * Both un-targeted (global) and targeted aliases are supported. + */ + override def visitStar(ctx: StarContext): Expression = withOrigin(ctx) { + UnresolvedStar(Option(ctx.qualifiedName()).map(_.identifier.asScala.map(_.getText))) + } + + /** + * Create an aliased expression if an alias is specified. Both single and multi-aliases are + * supported. + */ + override def visitNamedExpression(ctx: NamedExpressionContext): Expression = withOrigin(ctx) { + val e = expression(ctx.expression) + if (ctx.identifier != null) { + Alias(e, ctx.identifier.getText)() + } else if (ctx.identifierList != null) { + MultiAlias(e, visitIdentifierList(ctx.identifierList)) + } else { + e + } + } + + /** + * Combine a number of boolean expressions into a balanced expression tree. These expressions are + * either combined by a logical [[And]] or a logical [[Or]]. + * + * A balanced binary tree is created because regular left recursive trees cause considerable + * performance degradations and can cause stack overflows. + */ + override def visitLogicalBinary(ctx: LogicalBinaryContext): Expression = withOrigin(ctx) { + val expressionType = ctx.operator.getType + val expressionCombiner = expressionType match { + case SqlBaseParser.AND => And.apply _ + case SqlBaseParser.OR => Or.apply _ + } + + // Collect all similar left hand contexts. + val contexts = ArrayBuffer(ctx.right) + var current = ctx.left + def collectContexts: Boolean = current match { + case lbc: LogicalBinaryContext if lbc.operator.getType == expressionType => + contexts += lbc.right + current = lbc.left + true + case _ => + contexts += current + false + } + while (collectContexts) { + // No body - all updates take place in the collectContexts. + } + + // Reverse the contexts to have them in the same sequence as in the SQL statement & turn them + // into expressions. + val expressions = contexts.reverse.map(expression) + + // Create a balanced tree. + def reduceToExpressionTree(low: Int, high: Int): Expression = high - low match { + case 0 => + expressions(low) + case 1 => + expressionCombiner(expressions(low), expressions(high)) + case x => + val mid = low + x / 2 + expressionCombiner( + reduceToExpressionTree(low, mid), + reduceToExpressionTree(mid + 1, high)) + } + reduceToExpressionTree(0, expressions.size - 1) + } + + /** + * Invert a boolean expression. + */ + override def visitLogicalNot(ctx: LogicalNotContext): Expression = withOrigin(ctx) { + Not(expression(ctx.booleanExpression())) + } + + /** + * Create a filtering correlated sub-query. This is not supported yet. + */ + override def visitExists(ctx: ExistsContext): Expression = { + throw new ParseException("EXISTS clauses are not supported.", ctx) + } + + /** + * Create a comparison expression. This compares two expressions. The following comparison + * operators are supported: + * - Equal: '=' or '==' + * - Null-safe Equal: '<=>' + * - Not Equal: '<>' or '!=' + * - Less than: '<' + * - Less then or Equal: '<=' + * - Greater than: '>' + * - Greater then or Equal: '>=' + */ + override def visitComparison(ctx: ComparisonContext): Expression = withOrigin(ctx) { + val left = expression(ctx.left) + val right = expression(ctx.right) + val operator = ctx.comparisonOperator().getChild(0).asInstanceOf[TerminalNode] + operator.getSymbol.getType match { + case SqlBaseParser.EQ => + EqualTo(left, right) + case SqlBaseParser.NSEQ => + EqualNullSafe(left, right) + case SqlBaseParser.NEQ | SqlBaseParser.NEQJ => + Not(EqualTo(left, right)) + case SqlBaseParser.LT => + LessThan(left, right) + case SqlBaseParser.LTE => + LessThanOrEqual(left, right) + case SqlBaseParser.GT => + GreaterThan(left, right) + case SqlBaseParser.GTE => + GreaterThanOrEqual(left, right) + } + } + + /** + * Create a BETWEEN expression. This tests if an expression lies with in the bounds set by two + * other expressions. The inverse can also be created. + */ + override def visitBetween(ctx: BetweenContext): Expression = withOrigin(ctx) { + val value = expression(ctx.value) + val between = And( + GreaterThanOrEqual(value, expression(ctx.lower)), + LessThanOrEqual(value, expression(ctx.upper))) + invertIfNotDefined(between, ctx.NOT) + } + + /** + * Create an IN expression. This tests if the value of the left hand side expression is + * contained by the sequence of expressions on the right hand side. + */ + override def visitInList(ctx: InListContext): Expression = withOrigin(ctx) { + val in = In(expression(ctx.value), ctx.expression().asScala.map(expression)) + invertIfNotDefined(in, ctx.NOT) + } + + /** + * Create an IN expression, where the the right hand side is a query. This is unsupported. + */ + override def visitInSubquery(ctx: InSubqueryContext): Expression = { + throw new ParseException("IN with a Sub-query is currently not supported.", ctx) + } + + /** + * Create a (R)LIKE/REGEXP expression. + */ + override def visitLike(ctx: LikeContext): Expression = { + val left = expression(ctx.value) + val right = expression(ctx.pattern) + val like = ctx.like.getType match { + case SqlBaseParser.LIKE => + Like(left, right) + case SqlBaseParser.RLIKE => + RLike(left, right) + } + invertIfNotDefined(like, ctx.NOT) + } + + /** + * Create an IS (NOT) NULL expression. + */ + override def visitNullPredicate(ctx: NullPredicateContext): Expression = withOrigin(ctx) { + val value = expression(ctx.value) + if (ctx.NOT != null) { + IsNotNull(value) + } else { + IsNull(value) + } + } + + /** + * Create a binary arithmetic expression. The following arithmetic operators are supported: + * - Mulitplication: '*' + * - Division: '/' + * - Hive Long Division: 'DIV' + * - Modulo: '%' + * - Addition: '+' + * - Subtraction: '-' + * - Binary AND: '&' + * - Binary XOR + * - Binary OR: '|' + */ + override def visitArithmeticBinary(ctx: ArithmeticBinaryContext): Expression = withOrigin(ctx) { + val left = expression(ctx.left) + val right = expression(ctx.right) + ctx.operator.getType match { + case SqlBaseParser.ASTERISK => + Multiply(left, right) + case SqlBaseParser.SLASH => + Divide(left, right) + case SqlBaseParser.PERCENT => + Remainder(left, right) + case SqlBaseParser.DIV => + Cast(Divide(left, right), LongType) + case SqlBaseParser.PLUS => + Add(left, right) + case SqlBaseParser.MINUS => + Subtract(left, right) + case SqlBaseParser.AMPERSAND => + BitwiseAnd(left, right) + case SqlBaseParser.HAT => + BitwiseXor(left, right) + case SqlBaseParser.PIPE => + BitwiseOr(left, right) + } + } + + /** + * Create a unary arithmetic expression. The following arithmetic operators are supported: + * - Plus: '+' + * - Minus: '-' + * - Bitwise Not: '~' + */ + override def visitArithmeticUnary(ctx: ArithmeticUnaryContext): Expression = withOrigin(ctx) { + val value = expression(ctx.valueExpression) + ctx.operator.getType match { + case SqlBaseParser.PLUS => + value + case SqlBaseParser.MINUS => + UnaryMinus(value) + case SqlBaseParser.TILDE => + BitwiseNot(value) + } + } + + /** + * Create a [[Cast]] expression. + */ + override def visitCast(ctx: CastContext): Expression = withOrigin(ctx) { + Cast(expression(ctx.expression), typedVisit(ctx.dataType)) + } + + /** + * Create a (windowed) Function expression. + */ + override def visitFunctionCall(ctx: FunctionCallContext): Expression = withOrigin(ctx) { + // Create the function call. + val name = ctx.qualifiedName.getText + val isDistinct = Option(ctx.setQuantifier()).exists(_.DISTINCT != null) + val arguments = ctx.expression().asScala.map(expression) match { + case Seq(UnresolvedStar(None)) if name.toLowerCase == "count" && !isDistinct => + // Transform COUNT(*) into COUNT(1). Move this to analysis? + Seq(Literal(1)) + case expressions => + expressions + } + val function = UnresolvedFunction(name, arguments, isDistinct) + + // Check if the function is evaluated in a windowed context. + ctx.windowSpec match { + case spec: WindowRefContext => + UnresolvedWindowExpression(function, visitWindowRef(spec)) + case spec: WindowDefContext => + WindowExpression(function, visitWindowDef(spec)) + case _ => function + } + } + + /** + * Create a reference to a window frame, i.e. [[WindowSpecReference]]. + */ + override def visitWindowRef(ctx: WindowRefContext): WindowSpecReference = withOrigin(ctx) { + WindowSpecReference(ctx.identifier.getText) + } + + /** + * Create a window definition, i.e. [[WindowSpecDefinition]]. + */ + override def visitWindowDef(ctx: WindowDefContext): WindowSpecDefinition = withOrigin(ctx) { + // CLUSTER BY ... | PARTITION BY ... ORDER BY ... + val partition = ctx.partition.asScala.map(expression) + val order = ctx.sortItem.asScala.map(visitSortItem) + + // RANGE/ROWS BETWEEN ... + val frameSpecOption = Option(ctx.windowFrame).map { frame => + val frameType = frame.frameType.getType match { + case SqlBaseParser.RANGE => RangeFrame + case SqlBaseParser.ROWS => RowFrame + } + + SpecifiedWindowFrame( + frameType, + visitFrameBound(frame.start), + Option(frame.end).map(visitFrameBound).getOrElse(CurrentRow)) + } + + WindowSpecDefinition( + partition, + order, + frameSpecOption.getOrElse(UnspecifiedFrame)) + } + + /** + * Create or resolve a [[FrameBoundary]]. Simple math expressions are allowed for Value + * Preceding/Following boundaries. These expressions must be constant (foldable) and return an + * integer value. + */ + override def visitFrameBound(ctx: FrameBoundContext): FrameBoundary = withOrigin(ctx) { + // We currently only allow foldable integers. + def value: Int = { + val e = expression(ctx.expression) + assert(e.resolved && e.foldable && e.dataType == IntegerType, + "Frame bound value must be a constant integer.", + ctx) + e.eval().asInstanceOf[Int] + } + + // Create the FrameBoundary + ctx.boundType.getType match { + case SqlBaseParser.PRECEDING if ctx.UNBOUNDED != null => + UnboundedPreceding + case SqlBaseParser.PRECEDING => + ValuePreceding(value) + case SqlBaseParser.CURRENT => + CurrentRow + case SqlBaseParser.FOLLOWING if ctx.UNBOUNDED != null => + UnboundedFollowing + case SqlBaseParser.FOLLOWING => + ValueFollowing(value) + } + } + + /** + * Create a [[CreateStruct]] expression. + */ + override def visitRowConstructor(ctx: RowConstructorContext): Expression = withOrigin(ctx) { + CreateStruct(ctx.expression.asScala.map(expression)) + } + + /** + * Create a [[ScalarSubquery]] expression. + */ + override def visitSubqueryExpression( + ctx: SubqueryExpressionContext): Expression = withOrigin(ctx) { + ScalarSubquery(plan(ctx.query)) + } + + /** + * Create a value based [[CaseWhen]] expression. This has the following SQL form: + * {{{ + * CASE [expression] + * WHEN [value] THEN [expression] + * ... + * ELSE [expression] + * END + * }}} + */ + override def visitSimpleCase(ctx: SimpleCaseContext): Expression = withOrigin(ctx) { + val e = expression(ctx.valueExpression) + val branches = ctx.whenClause.asScala.map { wCtx => + (EqualTo(e, expression(wCtx.condition)), expression(wCtx.result)) + } + CaseWhen(branches, Option(ctx.elseExpression).map(expression)) + } + + /** + * Create a condition based [[CaseWhen]] expression. This has the following SQL syntax: + * {{{ + * CASE + * WHEN [predicate] THEN [expression] + * ... + * ELSE [expression] + * END + * }}} + * + * @param ctx the parse tree + * */ + override def visitSearchedCase(ctx: SearchedCaseContext): Expression = withOrigin(ctx) { + val branches = ctx.whenClause.asScala.map { wCtx => + (expression(wCtx.condition), expression(wCtx.result)) + } + CaseWhen(branches, Option(ctx.elseExpression).map(expression)) + } + + /** + * Create a dereference expression. The return type depends on the type of the parent, this can + * either be a [[UnresolvedAttribute]] (if the parent is an [[UnresolvedAttribute]]), or an + * [[UnresolvedExtractValue]] if the parent is some expression. + */ + override def visitDereference(ctx: DereferenceContext): Expression = withOrigin(ctx) { + val attr = ctx.fieldName.getText + expression(ctx.base) match { + case UnresolvedAttribute(nameParts) => + UnresolvedAttribute(nameParts :+ attr) + case e => + UnresolvedExtractValue(e, Literal(attr)) + } + } + + /** + * Create an [[UnresolvedAttribute]] expression. + */ + override def visitColumnReference(ctx: ColumnReferenceContext): Expression = withOrigin(ctx) { + UnresolvedAttribute.quoted(ctx.getText) + } + + /** + * Create an [[UnresolvedExtractValue]] expression, this is used for subscript access to an array. + */ + override def visitSubscript(ctx: SubscriptContext): Expression = withOrigin(ctx) { + UnresolvedExtractValue(expression(ctx.value), expression(ctx.index)) + } + + /** + * Create an expression for an expression between parentheses. This is need because the ANTLR + * visitor cannot automatically convert the nested context into an expression. + */ + override def visitParenthesizedExpression( + ctx: ParenthesizedExpressionContext): Expression = withOrigin(ctx) { + expression(ctx.expression) + } + + /** + * Create a [[SortOrder]] expression. + */ + override def visitSortItem(ctx: SortItemContext): SortOrder = withOrigin(ctx) { + if (ctx.DESC != null) { + SortOrder(expression(ctx.expression), Descending) + } else { + SortOrder(expression(ctx.expression), Ascending) + } + } + + /** + * Create a typed Literal expression. A typed literal has the following SQL syntax: + * {{{ + * [TYPE] '[VALUE]' + * }}} + * Currently Date and Timestamp typed literals are supported. + * + * TODO what the added value of this over casting? + */ + override def visitTypeConstructor(ctx: TypeConstructorContext): Literal = withOrigin(ctx) { + val value = string(ctx.STRING) + ctx.identifier.getText.toUpperCase match { + case "DATE" => + Literal(Date.valueOf(value)) + case "TIMESTAMP" => + Literal(Timestamp.valueOf(value)) + case other => + throw new ParseException(s"Literals of type '$other' are currently not supported.", ctx) + } + } + + /** + * Create a NULL literal expression. + */ + override def visitNullLiteral(ctx: NullLiteralContext): Literal = withOrigin(ctx) { + Literal(null) + } + + /** + * Create a Boolean literal expression. + */ + override def visitBooleanLiteral(ctx: BooleanLiteralContext): Literal = withOrigin(ctx) { + if (ctx.getText.toBoolean) { + Literal.TrueLiteral + } else { + Literal.FalseLiteral + } + } + + /** + * Create an integral literal expression. The code selects the most narrow integral type + * possible, either a BigDecimal, a Long or an Integer is returned. + */ + override def visitIntegerLiteral(ctx: IntegerLiteralContext): Literal = withOrigin(ctx) { + BigDecimal(ctx.getText) match { + case v if v.isValidInt => + Literal(v.intValue()) + case v if v.isValidLong => + Literal(v.longValue()) + case v => Literal(v.underlying()) + } + } + + /** + * Create a double literal for a number denoted in scientifc notation. + */ + override def visitScientificDecimalLiteral( + ctx: ScientificDecimalLiteralContext): Literal = withOrigin(ctx) { + Literal(ctx.getText.toDouble) + } + + /** + * Create a decimal literal for a regular decimal number. + */ + override def visitDecimalLiteral(ctx: DecimalLiteralContext): Literal = withOrigin(ctx) { + Literal(BigDecimal(ctx.getText).underlying()) + } + + /** Create a numeric literal expression. */ + private def numericLiteral(ctx: NumberContext)(f: String => Any): Literal = withOrigin(ctx) { + val raw = ctx.getText + try { + Literal(f(raw.substring(0, raw.length - 1))) + } catch { + case e: NumberFormatException => + throw new ParseException(e.getMessage, ctx) + } + } + + /** + * Create a Byte Literal expression. + */ + override def visitTinyIntLiteral(ctx: TinyIntLiteralContext): Literal = numericLiteral(ctx) { + _.toByte + } + + /** + * Create a Short Literal expression. + */ + override def visitSmallIntLiteral(ctx: SmallIntLiteralContext): Literal = numericLiteral(ctx) { + _.toShort + } + + /** + * Create a Long Literal expression. + */ + override def visitBigIntLiteral(ctx: BigIntLiteralContext): Literal = numericLiteral(ctx) { + _.toLong + } + + /** + * Create a Double Literal expression. + */ + override def visitDoubleLiteral(ctx: DoubleLiteralContext): Literal = numericLiteral(ctx) { + _.toDouble + } + + /** + * Create a String literal expression. + */ + override def visitStringLiteral(ctx: StringLiteralContext): Literal = withOrigin(ctx) { + Literal(createString(ctx)) + } + + /** + * Create a String from a string literal context. This supports multiple consecutive string + * literals, these are concatenated, for example this expression "'hello' 'world'" will be + * converted into "helloworld". + * + * Special characters can be escaped by using Hive/C-style escaping. + */ + private def createString(ctx: StringLiteralContext): String = { + ctx.STRING().asScala.map(string).mkString + } + + /** + * Create a [[CalendarInterval]] literal expression. An interval expression can contain multiple + * unit value pairs, for instance: interval 2 months 2 days. + */ + override def visitInterval(ctx: IntervalContext): Literal = withOrigin(ctx) { + val intervals = ctx.intervalField.asScala.map(visitIntervalField) + assert(intervals.nonEmpty, "at least one time unit should be given for interval literal", ctx) + Literal(intervals.reduce(_.add(_))) + } + + /** + * Create a [[CalendarInterval]] for a unit value pair. Two unit configuration types are + * supported: + * - Single unit. + * - From-To unit (only 'YEAR TO MONTH' and 'DAY TO SECOND' are supported). + */ + override def visitIntervalField(ctx: IntervalFieldContext): CalendarInterval = withOrigin(ctx) { + import ctx._ + val s = value.getText + val interval = (unit.getText.toLowerCase, Option(to).map(_.getText.toLowerCase)) match { + case (u, None) if u.endsWith("s") => + // Handle plural forms, e.g: yearS/monthS/weekS/dayS/hourS/minuteS/hourS/... + CalendarInterval.fromSingleUnitString(u.substring(0, u.length - 1), s) + case (u, None) => + CalendarInterval.fromSingleUnitString(u, s) + case ("year", Some("month")) => + CalendarInterval.fromYearMonthString(s) + case ("day", Some("second")) => + CalendarInterval.fromDayTimeString(s) + case (from, Some(t)) => + throw new ParseException(s"Intervals FROM $from TO $t are not supported.", ctx) + } + assert(interval != null, "No interval can be constructed", ctx) + interval + } + + /* ******************************************************************************************** + * DataType parsing + * ******************************************************************************************** */ + /** + * Resolve/create a primitive type. + */ + override def visitPrimitiveDataType(ctx: PrimitiveDataTypeContext): DataType = withOrigin(ctx) { + (ctx.identifier.getText.toLowerCase, ctx.INTEGER_VALUE().asScala.toList) match { + case ("boolean", Nil) => BooleanType + case ("tinyint" | "byte", Nil) => ByteType + case ("smallint" | "short", Nil) => ShortType + case ("int" | "integer", Nil) => IntegerType + case ("bigint" | "long", Nil) => LongType + case ("float", Nil) => FloatType + case ("double", Nil) => DoubleType + case ("date", Nil) => DateType + case ("timestamp", Nil) => TimestampType + case ("char" | "varchar" | "string", Nil) => StringType + case ("char" | "varchar", _ :: Nil) => StringType + case ("binary", Nil) => BinaryType + case ("decimal", Nil) => DecimalType.USER_DEFAULT + case ("decimal", precision :: Nil) => DecimalType(precision.getText.toInt, 0) + case ("decimal", precision :: scale :: Nil) => + DecimalType(precision.getText.toInt, scale.getText.toInt) + case (dt, params) => + throw new ParseException( + s"DataType $dt${params.mkString("(", ",", ")")} is not supported.", ctx) + } + } + + /** + * Create a complex DataType. Arrays, Maps and Structures are supported. + */ + override def visitComplexDataType(ctx: ComplexDataTypeContext): DataType = withOrigin(ctx) { + ctx.complex.getType match { + case SqlBaseParser.ARRAY => + ArrayType(typedVisit(ctx.dataType(0))) + case SqlBaseParser.MAP => + MapType(typedVisit(ctx.dataType(0)), typedVisit(ctx.dataType(1))) + case SqlBaseParser.STRUCT => + createStructType(ctx.colTypeList()) + } + } + + /** + * Create a [[StructType]] from a sequence of [[StructField]]s. + */ + protected def createStructType(ctx: ColTypeListContext): StructType = { + StructType(Option(ctx).toSeq.flatMap(visitColTypeList)) + } + + /** + * Create a [[StructType]] from a number of column definitions. + */ + override def visitColTypeList(ctx: ColTypeListContext): Seq[StructField] = withOrigin(ctx) { + ctx.colType().asScala.map(visitColType) + } + + /** + * Create a [[StructField]] from a column definition. + */ + override def visitColType(ctx: ColTypeContext): StructField = withOrigin(ctx) { + import ctx._ + + // Add the comment to the metadata. + val builder = new MetadataBuilder + if (STRING != null) { + builder.putString("comment", string(STRING)) + } + + StructField(identifier.getText, typedVisit(dataType), nullable = true, builder.build()) + } +} diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/ng/ParseDriver.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/ng/ParseDriver.scala new file mode 100644 index 0000000000..c9a286374c --- /dev/null +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/ng/ParseDriver.scala @@ -0,0 +1,240 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.spark.sql.catalyst.parser.ng + +import org.antlr.v4.runtime._ +import org.antlr.v4.runtime.atn.PredictionMode +import org.antlr.v4.runtime.misc.ParseCancellationException + +import org.apache.spark.internal.Logging +import org.apache.spark.sql.AnalysisException +import org.apache.spark.sql.catalyst.TableIdentifier +import org.apache.spark.sql.catalyst.expressions.Expression +import org.apache.spark.sql.catalyst.parser.ParserInterface +import org.apache.spark.sql.catalyst.parser.ng.SqlBaseParser._ +import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan +import org.apache.spark.sql.catalyst.trees.Origin +import org.apache.spark.sql.types.DataType + +/** + * Base SQL parsing infrastructure. + */ +abstract class AbstractSqlParser extends ParserInterface with Logging { + + /** Creates/Resolves DataType for a given SQL string. */ + def parseDataType(sqlText: String): DataType = parse(sqlText) { parser => + // TODO add this to the parser interface. + astBuilder.visitSingleDataType(parser.singleDataType()) + } + + /** Creates Expression for a given SQL string. */ + override def parseExpression(sqlText: String): Expression = parse(sqlText) { parser => + astBuilder.visitSingleExpression(parser.singleExpression()) + } + + /** Creates TableIdentifier for a given SQL string. */ + override def parseTableIdentifier(sqlText: String): TableIdentifier = parse(sqlText) { parser => + astBuilder.visitSingleTableIdentifier(parser.singleTableIdentifier()) + } + + /** Creates LogicalPlan for a given SQL string. */ + override def parsePlan(sqlText: String): LogicalPlan = parse(sqlText) { parser => + astBuilder.visitSingleStatement(parser.singleStatement()) match { + case plan: LogicalPlan => plan + case _ => nativeCommand(sqlText) + } + } + + /** Get the builder (visitor) which converts a ParseTree into a AST. */ + protected def astBuilder: AstBuilder + + /** Create a native command, or fail when this is not supported. */ + protected def nativeCommand(sqlText: String): LogicalPlan = { + val position = Origin(None, None) + throw new ParseException(Option(sqlText), "Unsupported SQL statement", position, position) + } + + protected def parse[T](command: String)(toResult: SqlBaseParser => T): T = { + logInfo(s"Parsing command: $command") + + val lexer = new SqlBaseLexer(new ANTLRNoCaseStringStream(command)) + lexer.removeErrorListeners() + lexer.addErrorListener(ParseErrorListener) + + val tokenStream = new CommonTokenStream(lexer) + val parser = new SqlBaseParser(tokenStream) + parser.addParseListener(PostProcessor) + parser.removeErrorListeners() + parser.addErrorListener(ParseErrorListener) + + try { + try { + // first, try parsing with potentially faster SLL mode + parser.getInterpreter.setPredictionMode(PredictionMode.SLL) + toResult(parser) + } + catch { + case e: ParseCancellationException => + // if we fail, parse with LL mode + tokenStream.reset() // rewind input stream + parser.reset() + + // Try Again. + parser.getInterpreter.setPredictionMode(PredictionMode.LL) + toResult(parser) + } + } + catch { + case e: ParseException if e.command.isDefined => + throw e + case e: ParseException => + throw e.withCommand(command) + case e: AnalysisException => + val position = Origin(e.line, e.startPosition) + throw new ParseException(Option(command), e.message, position, position) + } + } +} + +/** + * Concrete SQL parser for Catalyst-only SQL statements. + */ +object CatalystSqlParser extends AbstractSqlParser { + val astBuilder = new AstBuilder +} + +/** + * This string stream provides the lexer with upper case characters only. This greatly simplifies + * lexing the stream, while we can maintain the original command. + * + * This is based on Hive's org.apache.hadoop.hive.ql.parse.ParseDriver.ANTLRNoCaseStringStream + * + * The comment below (taken from the original class) describes the rationale for doing this: + * + * This class provides and implementation for a case insensitive token checker for the lexical + * analysis part of antlr. By converting the token stream into upper case at the time when lexical + * rules are checked, this class ensures that the lexical rules need to just match the token with + * upper case letters as opposed to combination of upper case and lower case characters. This is + * purely used for matching lexical rules. The actual token text is stored in the same way as the + * user input without actually converting it into an upper case. The token values are generated by + * the consume() function of the super class ANTLRStringStream. The LA() function is the lookahead + * function and is purely used for matching lexical rules. This also means that the grammar will + * only accept capitalized tokens in case it is run from other tools like antlrworks which do not + * have the ANTLRNoCaseStringStream implementation. + */ + +private[parser] class ANTLRNoCaseStringStream(input: String) extends ANTLRInputStream(input) { + override def LA(i: Int): Int = { + val la = super.LA(i) + if (la == 0 || la == IntStream.EOF) la + else Character.toUpperCase(la) + } +} + +/** + * The ParseErrorListener converts parse errors into AnalysisExceptions. + */ +case object ParseErrorListener extends BaseErrorListener { + override def syntaxError( + recognizer: Recognizer[_, _], + offendingSymbol: scala.Any, + line: Int, + charPositionInLine: Int, + msg: String, + e: RecognitionException): Unit = { + val position = Origin(Some(line), Some(charPositionInLine)) + throw new ParseException(None, msg, position, position) + } +} + +/** + * A [[ParseException]] is an [[AnalysisException]] that is thrown during the parse process. It + * contains fields and an extended error message that make reporting and diagnosing errors easier. + */ +class ParseException( + val command: Option[String], + message: String, + val start: Origin, + val stop: Origin) extends AnalysisException(message, start.line, start.startPosition) { + + def this(message: String, ctx: ParserRuleContext) = { + this(Option(ParserUtils.command(ctx)), + message, + ParserUtils.position(ctx.getStart), + ParserUtils.position(ctx.getStop)) + } + + override def getMessage: String = { + val builder = new StringBuilder + builder ++= "\n" ++= message + start match { + case Origin(Some(l), Some(p)) => + builder ++= s"(line $l, pos $p)\n" + command.foreach { cmd => + val (above, below) = cmd.split("\n").splitAt(l) + builder ++= "\n== SQL ==\n" + above.foreach(builder ++= _ += '\n') + builder ++= (0 until p).map(_ => "-").mkString("") ++= "^^^\n" + below.foreach(builder ++= _ += '\n') + } + case _ => + command.foreach { cmd => + builder ++= "\n== SQL ==\n" ++= cmd + } + } + builder.toString + } + + def withCommand(cmd: String): ParseException = { + new ParseException(Option(cmd), message, start, stop) + } +} + +/** + * The post-processor validates & cleans-up the parse tree during the parse process. + */ +case object PostProcessor extends SqlBaseBaseListener { + + /** Remove the back ticks from an Identifier. */ + override def exitQuotedIdentifier(ctx: QuotedIdentifierContext): Unit = { + replaceTokenByIdentifier(ctx, 1) { token => + // Remove the double back ticks in the string. + token.setText(token.getText.replace("``", "`")) + token + } + } + + /** Treat non-reserved keywords as Identifiers. */ + override def exitNonReserved(ctx: NonReservedContext): Unit = { + replaceTokenByIdentifier(ctx, 0)(identity) + } + + private def replaceTokenByIdentifier( + ctx: ParserRuleContext, + stripMargins: Int)( + f: CommonToken => CommonToken = identity): Unit = { + val parent = ctx.getParent + parent.removeLastChild() + val token = ctx.getChild(0).getPayload.asInstanceOf[Token] + parent.addChild(f(new CommonToken( + new org.antlr.v4.runtime.misc.Pair(token.getTokenSource, token.getInputStream), + SqlBaseParser.IDENTIFIER, + token.getChannel, + token.getStartIndex + stripMargins, + token.getStopIndex - stripMargins))) + } +} diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/ng/ParserUtils.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/ng/ParserUtils.scala new file mode 100644 index 0000000000..1fbfa763b4 --- /dev/null +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/ng/ParserUtils.scala @@ -0,0 +1,118 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.spark.sql.catalyst.parser.ng + +import org.antlr.v4.runtime.{CharStream, ParserRuleContext, Token} +import org.antlr.v4.runtime.misc.Interval +import org.antlr.v4.runtime.tree.TerminalNode + +import org.apache.spark.sql.catalyst.parser.ParseUtils.unescapeSQLString +import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan +import org.apache.spark.sql.catalyst.trees.{CurrentOrigin, Origin} + +/** + * A collection of utility methods for use during the parsing process. + */ +object ParserUtils { + /** Get the command which created the token. */ + def command(ctx: ParserRuleContext): String = { + command(ctx.getStart.getInputStream) + } + + /** Get the command which created the token. */ + def command(stream: CharStream): String = { + stream.getText(Interval.of(0, stream.size())) + } + + /** Get the code that creates the given node. */ + def source(ctx: ParserRuleContext): String = { + val stream = ctx.getStart.getInputStream + stream.getText(Interval.of(ctx.getStart.getStartIndex, ctx.getStop.getStopIndex)) + } + + /** Get all the text which comes after the given rule. */ + def remainder(ctx: ParserRuleContext): String = remainder(ctx.getStop) + + /** Get all the text which comes after the given token. */ + def remainder(token: Token): String = { + val stream = token.getInputStream + val interval = Interval.of(token.getStopIndex + 1, stream.size()) + stream.getText(interval) + } + + /** Convert a string token into a string. */ + def string(token: Token): String = unescapeSQLString(token.getText) + + /** Convert a string node into a string. */ + def string(node: TerminalNode): String = unescapeSQLString(node.getText) + + /** Get the origin (line and position) of the token. */ + def position(token: Token): Origin = { + Origin(Option(token.getLine), Option(token.getCharPositionInLine)) + } + + /** Assert if a condition holds. If it doesn't throw a parse exception. */ + def assert(f: => Boolean, message: String, ctx: ParserRuleContext): Unit = { + if (!f) { + throw new ParseException(message, ctx) + } + } + + /** + * Register the origin of the context. Any TreeNode created in the closure will be assigned the + * registered origin. This method restores the previously set origin after completion of the + * closure. + */ + def withOrigin[T](ctx: ParserRuleContext)(f: => T): T = { + val current = CurrentOrigin.get + CurrentOrigin.set(position(ctx.getStart)) + try { + f + } finally { + CurrentOrigin.set(current) + } + } + + /** Some syntactic sugar which makes it easier to work with optional clauses for LogicalPlans. */ + implicit class EnhancedLogicalPlan(val plan: LogicalPlan) extends AnyVal { + /** + * Create a plan using the block of code when the given context exists. Otherwise return the + * original plan. + */ + def optional(ctx: AnyRef)(f: => LogicalPlan): LogicalPlan = { + if (ctx != null) { + f + } else { + plan + } + } + + /** + * Map a [[LogicalPlan]] to another [[LogicalPlan]] if the passed context exists using the + * passed function. The original plan is returned when the context does not exist. + */ + def optionalMap[C <: ParserRuleContext]( + ctx: C)( + f: (C, LogicalPlan) => LogicalPlan): LogicalPlan = { + if (ctx != null) { + f(ctx, plan) + } else { + plan + } + } + } +} diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/CatalystQlSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/CatalystQlSuite.scala index c068e895b6..223485e292 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/CatalystQlSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/CatalystQlSuite.scala @@ -21,15 +21,20 @@ import org.apache.spark.sql.AnalysisException import org.apache.spark.sql.catalyst.TableIdentifier import org.apache.spark.sql.catalyst.analysis._ import org.apache.spark.sql.catalyst.expressions._ +import org.apache.spark.sql.catalyst.parser.ng.CatalystSqlParser import org.apache.spark.sql.catalyst.plans.PlanTest import org.apache.spark.sql.catalyst.plans.logical._ import org.apache.spark.unsafe.types.CalendarInterval class CatalystQlSuite extends PlanTest { val parser = new CatalystQl() + import org.apache.spark.sql.catalyst.dsl.expressions._ + import org.apache.spark.sql.catalyst.dsl.plans._ + + val star = UnresolvedAlias(UnresolvedStar(None)) test("test case insensitive") { - val result = Project(UnresolvedAlias(Literal(1)):: Nil, OneRowRelation) + val result = OneRowRelation.select(1) assert(result === parser.parsePlan("seLect 1")) assert(result === parser.parsePlan("select 1")) assert(result === parser.parsePlan("SELECT 1")) @@ -37,52 +42,31 @@ class CatalystQlSuite extends PlanTest { test("test NOT operator with comparison operations") { val parsed = parser.parsePlan("SELECT NOT TRUE > TRUE") - val expected = Project( - UnresolvedAlias( - Not( - GreaterThan(Literal(true), Literal(true))) - ) :: Nil, - OneRowRelation) + val expected = OneRowRelation.select(Not(GreaterThan(true, true))) comparePlans(parsed, expected) } test("test Union Distinct operator") { - val parsed1 = parser.parsePlan("SELECT * FROM t0 UNION SELECT * FROM t1") - val parsed2 = parser.parsePlan("SELECT * FROM t0 UNION DISTINCT SELECT * FROM t1") - val expected = - Project(UnresolvedAlias(UnresolvedStar(None)) :: Nil, - SubqueryAlias("u_1", - Distinct( - Union( - Project(UnresolvedAlias(UnresolvedStar(None)) :: Nil, - UnresolvedRelation(TableIdentifier("t0"), None)), - Project(UnresolvedAlias(UnresolvedStar(None)) :: Nil, - UnresolvedRelation(TableIdentifier("t1"), None)))))) + val parsed1 = parser.parsePlan( + "SELECT * FROM t0 UNION SELECT * FROM t1") + val parsed2 = parser.parsePlan( + "SELECT * FROM t0 UNION DISTINCT SELECT * FROM t1") + val expected = Distinct(Union(table("t0").select(star), table("t1").select(star))) + .as("u_1").select(star) comparePlans(parsed1, expected) comparePlans(parsed2, expected) } test("test Union All operator") { val parsed = parser.parsePlan("SELECT * FROM t0 UNION ALL SELECT * FROM t1") - val expected = - Project(UnresolvedAlias(UnresolvedStar(None)) :: Nil, - SubqueryAlias("u_1", - Union( - Project(UnresolvedAlias(UnresolvedStar(None)) :: Nil, - UnresolvedRelation(TableIdentifier("t0"), None)), - Project(UnresolvedAlias(UnresolvedStar(None)) :: Nil, - UnresolvedRelation(TableIdentifier("t1"), None))))) + val expected = Union(table("t0").select(star), table("t1").select(star)).as("u_1").select(star) comparePlans(parsed, expected) } test("support hive interval literal") { def checkInterval(sql: String, result: CalendarInterval): Unit = { val parsed = parser.parsePlan(sql) - val expected = Project( - UnresolvedAlias( - Literal(result) - ) :: Nil, - OneRowRelation) + val expected = OneRowRelation.select(Literal(result)) comparePlans(parsed, expected) } @@ -129,11 +113,7 @@ class CatalystQlSuite extends PlanTest { test("support scientific notation") { def assertRight(input: String, output: Double): Unit = { val parsed = parser.parsePlan("SELECT " + input) - val expected = Project( - UnresolvedAlias( - Literal(output) - ) :: Nil, - OneRowRelation) + val expected = OneRowRelation.select(Literal(output)) comparePlans(parsed, expected) } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/DataTypeParserSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/DataTypeParserSuite.scala index 7d3608033b..d9bd33c50a 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/DataTypeParserSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/DataTypeParserSuite.scala @@ -18,19 +18,24 @@ package org.apache.spark.sql.catalyst.parser import org.apache.spark.SparkFunSuite +import org.apache.spark.sql.catalyst.parser.ng.{CatalystSqlParser, ParseException} import org.apache.spark.sql.types._ -class DataTypeParserSuite extends SparkFunSuite { +abstract class AbstractDataTypeParserSuite extends SparkFunSuite { + + def parse(sql: String): DataType def checkDataType(dataTypeString: String, expectedDataType: DataType): Unit = { test(s"parse ${dataTypeString.replace("\n", "")}") { - assert(DataTypeParser.parse(dataTypeString) === expectedDataType) + assert(parse(dataTypeString) === expectedDataType) } } + def intercept(sql: String) + def unsupported(dataTypeString: String): Unit = { test(s"$dataTypeString is not supported") { - intercept[DataTypeException](DataTypeParser.parse(dataTypeString)) + intercept(dataTypeString) } } @@ -97,13 +102,6 @@ class DataTypeParserSuite extends SparkFunSuite { StructField("arrAy", ArrayType(DoubleType, true), true) :: StructField("anotherArray", ArrayType(StringType, true), true) :: Nil) ) - // A column name can be a reserved word in our DDL parser and SqlParser. - checkDataType( - "Struct<TABLE: string, CASE:boolean>", - StructType( - StructField("TABLE", StringType, true) :: - StructField("CASE", BooleanType, true) :: Nil) - ) // Use backticks to quote column names having special characters. checkDataType( "struct<`x+y`:int, `!@#$%^&*()`:string, `1_2.345<>:\"`:varchar(20)>", @@ -118,6 +116,43 @@ class DataTypeParserSuite extends SparkFunSuite { unsupported("it is not a data type") unsupported("struct<x+y: int, 1.1:timestamp>") unsupported("struct<x: int") +} + +class DataTypeParserSuite extends AbstractDataTypeParserSuite { + override def intercept(sql: String): Unit = + intercept[DataTypeException](DataTypeParser.parse(sql)) + + override def parse(sql: String): DataType = + DataTypeParser.parse(sql) + + // A column name can be a reserved word in our DDL parser and SqlParser. + checkDataType( + "Struct<TABLE: string, CASE:boolean>", + StructType( + StructField("TABLE", StringType, true) :: + StructField("CASE", BooleanType, true) :: Nil) + ) + unsupported("struct<x int, y string>") + unsupported("struct<`x``y` int>") } + +class CatalystQlDataTypeParserSuite extends AbstractDataTypeParserSuite { + override def intercept(sql: String): Unit = + intercept[ParseException](CatalystSqlParser.parseDataType(sql)) + + override def parse(sql: String): DataType = + CatalystSqlParser.parseDataType(sql) + + // A column name can be a reserved word in our DDL parser and SqlParser. + unsupported("Struct<TABLE: string, CASE:boolean>") + + checkDataType( + "struct<x int, y string>", + (new StructType).add("x", IntegerType).add("y", StringType)) + + checkDataType( + "struct<`x``y` int>", + (new StructType).add("x`y", IntegerType)) +} diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/ng/ErrorParserSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/ng/ErrorParserSuite.scala new file mode 100644 index 0000000000..1963fc368f --- /dev/null +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/ng/ErrorParserSuite.scala @@ -0,0 +1,67 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.spark.sql.catalyst.parser.ng + +import org.apache.spark.SparkFunSuite + +/** + * Test various parser errors. + */ +class ErrorParserSuite extends SparkFunSuite { + def intercept(sql: String, line: Int, startPosition: Int, messages: String*): Unit = { + val e = intercept[ParseException](CatalystSqlParser.parsePlan(sql)) + + // Check position. + assert(e.line.isDefined) + assert(e.line.get === line) + assert(e.startPosition.isDefined) + assert(e.startPosition.get === startPosition) + + // Check messages. + val error = e.getMessage + messages.foreach { message => + assert(error.contains(message)) + } + } + + test("no viable input") { + intercept("select from tbl", 1, 7, "no viable alternative at input", "-------^^^") + intercept("select\nfrom tbl", 2, 0, "no viable alternative at input", "^^^") + intercept("select ((r + 1) ", 1, 16, "no viable alternative at input", "----------------^^^") + } + + test("extraneous input") { + intercept("select 1 1", 1, 9, "extraneous input '1' expecting", "---------^^^") + intercept("select *\nfrom r as q t", 2, 12, "extraneous input", "------------^^^") + } + + test("mismatched input") { + intercept("select * from r order by q from t", 1, 27, + "mismatched input", + "---------------------------^^^") + intercept("select *\nfrom r\norder by q\nfrom t", 4, 0, "mismatched input", "^^^") + } + + test("semantic errors") { + intercept("select *\nfrom r\norder by q\ncluster by q", 3, 0, + "Combination of ORDER BY/SORT BY/DISTRIBUTE BY/CLUSTER BY is not supported", + "^^^") + intercept("select * from r where a in (select * from t)", 1, 24, + "IN with a Sub-query is currently not supported", + "------------------------^^^") + } +} diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/ng/ExpressionParserSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/ng/ExpressionParserSuite.scala new file mode 100644 index 0000000000..32311a5a66 --- /dev/null +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/ng/ExpressionParserSuite.scala @@ -0,0 +1,497 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.spark.sql.catalyst.parser.ng + +import java.sql.{Date, Timestamp} + +import org.apache.spark.sql.catalyst.analysis.{UnresolvedAttribute, _} +import org.apache.spark.sql.catalyst.expressions._ +import org.apache.spark.sql.catalyst.plans.PlanTest +import org.apache.spark.sql.types._ +import org.apache.spark.unsafe.types.CalendarInterval + +/** + * Test basic expression parsing. If a type of expression is supported it should be tested here. + * + * Please note that some of the expressions test don't have to be sound expressions, only their + * structure needs to be valid. Unsound expressions should be caught by the Analyzer or + * CheckAnalysis classes. + */ +class ExpressionParserSuite extends PlanTest { + import CatalystSqlParser._ + import org.apache.spark.sql.catalyst.dsl.expressions._ + import org.apache.spark.sql.catalyst.dsl.plans._ + + def assertEqual(sqlCommand: String, e: Expression): Unit = { + compareExpressions(parseExpression(sqlCommand), e) + } + + def intercept(sqlCommand: String, messages: String*): Unit = { + val e = intercept[ParseException](parseExpression(sqlCommand)) + messages.foreach { message => + assert(e.message.contains(message)) + } + } + + test("star expressions") { + // Global Star + assertEqual("*", UnresolvedStar(None)) + + // Targeted Star + assertEqual("a.b.*", UnresolvedStar(Option(Seq("a", "b")))) + } + + // NamedExpression (Alias/Multialias) + test("named expressions") { + // No Alias + val r0 = 'a + assertEqual("a", r0) + + // Single Alias. + val r1 = 'a as "b" + assertEqual("a as b", r1) + assertEqual("a b", r1) + + // Multi-Alias + assertEqual("a as (b, c)", MultiAlias('a, Seq("b", "c"))) + assertEqual("a() (b, c)", MultiAlias('a.function(), Seq("b", "c"))) + + // Numeric literals without a space between the literal qualifier and the alias, should not be + // interpreted as such. An unresolved reference should be returned instead. + // TODO add the JIRA-ticket number. + assertEqual("1SL", Symbol("1SL")) + + // Aliased star is allowed. + assertEqual("a.* b", UnresolvedStar(Option(Seq("a"))) as 'b) + } + + test("binary logical expressions") { + // And + assertEqual("a and b", 'a && 'b) + + // Or + assertEqual("a or b", 'a || 'b) + + // Combination And/Or check precedence + assertEqual("a and b or c and d", ('a && 'b) || ('c && 'd)) + assertEqual("a or b or c and d", 'a || 'b || ('c && 'd)) + + // Multiple AND/OR get converted into a balanced tree + assertEqual("a or b or c or d or e or f", (('a || 'b) || 'c) || (('d || 'e) || 'f)) + assertEqual("a and b and c and d and e and f", (('a && 'b) && 'c) && (('d && 'e) && 'f)) + } + + test("long binary logical expressions") { + def testVeryBinaryExpression(op: String, clazz: Class[_]): Unit = { + val sql = (1 to 1000).map(x => s"$x == $x").mkString(op) + val e = parseExpression(sql) + assert(e.collect { case _: EqualTo => true }.size === 1000) + assert(e.collect { case x if clazz.isInstance(x) => true }.size === 999) + } + testVeryBinaryExpression(" AND ", classOf[And]) + testVeryBinaryExpression(" OR ", classOf[Or]) + } + + test("not expressions") { + assertEqual("not a", !'a) + assertEqual("!a", !'a) + assertEqual("not true > true", Not(GreaterThan(true, true))) + } + + test("exists expression") { + intercept("exists (select 1 from b where b.x = a.x)", "EXISTS clauses are not supported") + } + + test("comparison expressions") { + assertEqual("a = b", 'a === 'b) + assertEqual("a == b", 'a === 'b) + assertEqual("a <=> b", 'a <=> 'b) + assertEqual("a <> b", 'a =!= 'b) + assertEqual("a != b", 'a =!= 'b) + assertEqual("a < b", 'a < 'b) + assertEqual("a <= b", 'a <= 'b) + assertEqual("a > b", 'a > 'b) + assertEqual("a >= b", 'a >= 'b) + } + + test("between expressions") { + assertEqual("a between b and c", 'a >= 'b && 'a <= 'c) + assertEqual("a not between b and c", !('a >= 'b && 'a <= 'c)) + } + + test("in expressions") { + assertEqual("a in (b, c, d)", 'a in ('b, 'c, 'd)) + assertEqual("a not in (b, c, d)", !('a in ('b, 'c, 'd))) + } + + test("in sub-query") { + intercept("a in (select b from c)", "IN with a Sub-query is currently not supported") + } + + test("like expressions") { + assertEqual("a like 'pattern%'", 'a like "pattern%") + assertEqual("a not like 'pattern%'", !('a like "pattern%")) + assertEqual("a rlike 'pattern%'", 'a rlike "pattern%") + assertEqual("a not rlike 'pattern%'", !('a rlike "pattern%")) + assertEqual("a regexp 'pattern%'", 'a rlike "pattern%") + assertEqual("a not regexp 'pattern%'", !('a rlike "pattern%")) + } + + test("is null expressions") { + assertEqual("a is null", 'a.isNull) + assertEqual("a is not null", 'a.isNotNull) + assertEqual("a = b is null", ('a === 'b).isNull) + assertEqual("a = b is not null", ('a === 'b).isNotNull) + } + + test("binary arithmetic expressions") { + // Simple operations + assertEqual("a * b", 'a * 'b) + assertEqual("a / b", 'a / 'b) + assertEqual("a DIV b", ('a / 'b).cast(LongType)) + assertEqual("a % b", 'a % 'b) + assertEqual("a + b", 'a + 'b) + assertEqual("a - b", 'a - 'b) + assertEqual("a & b", 'a & 'b) + assertEqual("a ^ b", 'a ^ 'b) + assertEqual("a | b", 'a | 'b) + + // Check precedences + assertEqual( + "a * t | b ^ c & d - e + f % g DIV h / i * k", + 'a * 't | ('b ^ ('c & ('d - 'e + (('f % 'g / 'h).cast(LongType) / 'i * 'k))))) + } + + test("unary arithmetic expressions") { + assertEqual("+a", 'a) + assertEqual("-a", -'a) + assertEqual("~a", ~'a) + assertEqual("-+~~a", -(~(~'a))) + } + + test("cast expressions") { + // Note that DataType parsing is tested elsewhere. + assertEqual("cast(a as int)", 'a.cast(IntegerType)) + assertEqual("cast(a as timestamp)", 'a.cast(TimestampType)) + assertEqual("cast(a as array<int>)", 'a.cast(ArrayType(IntegerType))) + assertEqual("cast(cast(a as int) as long)", 'a.cast(IntegerType).cast(LongType)) + } + + test("function expressions") { + assertEqual("foo()", 'foo.function()) + assertEqual("foo.bar()", Symbol("foo.bar").function()) + assertEqual("foo(*)", 'foo.function(star())) + assertEqual("count(*)", 'count.function(1)) + assertEqual("foo(a, b)", 'foo.function('a, 'b)) + assertEqual("foo(all a, b)", 'foo.function('a, 'b)) + assertEqual("foo(distinct a, b)", 'foo.distinctFunction('a, 'b)) + assertEqual("grouping(distinct a, b)", 'grouping.distinctFunction('a, 'b)) + assertEqual("`select`(all a, b)", 'select.function('a, 'b)) + } + + test("window function expressions") { + val func = 'foo.function(star()) + def windowed( + partitioning: Seq[Expression] = Seq.empty, + ordering: Seq[SortOrder] = Seq.empty, + frame: WindowFrame = UnspecifiedFrame): Expression = { + WindowExpression(func, WindowSpecDefinition(partitioning, ordering, frame)) + } + + // Basic window testing. + assertEqual("foo(*) over w1", UnresolvedWindowExpression(func, WindowSpecReference("w1"))) + assertEqual("foo(*) over ()", windowed()) + assertEqual("foo(*) over (partition by a, b)", windowed(Seq('a, 'b))) + assertEqual("foo(*) over (distribute by a, b)", windowed(Seq('a, 'b))) + assertEqual("foo(*) over (cluster by a, b)", windowed(Seq('a, 'b))) + assertEqual("foo(*) over (order by a desc, b asc)", windowed(Seq.empty, Seq('a.desc, 'b.asc ))) + assertEqual("foo(*) over (sort by a desc, b asc)", windowed(Seq.empty, Seq('a.desc, 'b.asc ))) + assertEqual("foo(*) over (partition by a, b order by c)", windowed(Seq('a, 'b), Seq('c.asc))) + assertEqual("foo(*) over (distribute by a, b sort by c)", windowed(Seq('a, 'b), Seq('c.asc))) + + // Test use of expressions in window functions. + assertEqual( + "sum(product + 1) over (partition by ((product) + (1)) order by 2)", + WindowExpression('sum.function('product + 1), + WindowSpecDefinition(Seq('product + 1), Seq(Literal(2).asc), UnspecifiedFrame))) + assertEqual( + "sum(product + 1) over (partition by ((product / 2) + 1) order by 2)", + WindowExpression('sum.function('product + 1), + WindowSpecDefinition(Seq('product / 2 + 1), Seq(Literal(2).asc), UnspecifiedFrame))) + + // Range/Row + val frameTypes = Seq(("rows", RowFrame), ("range", RangeFrame)) + val boundaries = Seq( + ("10 preceding", ValuePreceding(10), CurrentRow), + ("3 + 1 following", ValueFollowing(4), CurrentRow), // Will fail during analysis + ("unbounded preceding", UnboundedPreceding, CurrentRow), + ("unbounded following", UnboundedFollowing, CurrentRow), // Will fail during analysis + ("between unbounded preceding and current row", UnboundedPreceding, CurrentRow), + ("between unbounded preceding and unbounded following", + UnboundedPreceding, UnboundedFollowing), + ("between 10 preceding and current row", ValuePreceding(10), CurrentRow), + ("between current row and 5 following", CurrentRow, ValueFollowing(5)), + ("between 10 preceding and 5 following", ValuePreceding(10), ValueFollowing(5)) + ) + frameTypes.foreach { + case (frameTypeSql, frameType) => + boundaries.foreach { + case (boundarySql, begin, end) => + val query = s"foo(*) over (partition by a order by b $frameTypeSql $boundarySql)" + val expr = windowed(Seq('a), Seq('b.asc), SpecifiedWindowFrame(frameType, begin, end)) + assertEqual(query, expr) + } + } + + // We cannot use non integer constants. + intercept("foo(*) over (partition by a order by b rows 10.0 preceding)", + "Frame bound value must be a constant integer.") + + // We cannot use an arbitrary expression. + intercept("foo(*) over (partition by a order by b rows exp(b) preceding)", + "Frame bound value must be a constant integer.") + } + + test("row constructor") { + // Note that '(a)' will be interpreted as a nested expression. + assertEqual("(a, b)", CreateStruct(Seq('a, 'b))) + assertEqual("(a, b, c)", CreateStruct(Seq('a, 'b, 'c))) + } + + test("scalar sub-query") { + assertEqual( + "(select max(val) from tbl) > current", + ScalarSubquery(table("tbl").select('max.function('val))) > 'current) + assertEqual( + "a = (select b from s)", + 'a === ScalarSubquery(table("s").select('b))) + } + + test("case when") { + assertEqual("case a when 1 then b when 2 then c else d end", + CaseKeyWhen('a, Seq(1, 'b, 2, 'c, 'd))) + assertEqual("case when a = 1 then b when a = 2 then c else d end", + CaseWhen(Seq(('a === 1, 'b.expr), ('a === 2, 'c.expr)), 'd)) + } + + test("dereference") { + assertEqual("a.b", UnresolvedAttribute("a.b")) + assertEqual("`select`.b", UnresolvedAttribute("select.b")) + assertEqual("(a + b).b", ('a + 'b).getField("b")) // This will fail analysis. + assertEqual("struct(a, b).b", 'struct.function('a, 'b).getField("b")) + } + + test("reference") { + // Regular + assertEqual("a", 'a) + + // Starting with a digit. + assertEqual("1a", Symbol("1a")) + + // Quoted using a keyword. + assertEqual("`select`", 'select) + + // Unquoted using an unreserved keyword. + assertEqual("columns", 'columns) + } + + test("subscript") { + assertEqual("a[b]", 'a.getItem('b)) + assertEqual("a[1 + 1]", 'a.getItem(Literal(1) + 1)) + assertEqual("`c`.a[b]", UnresolvedAttribute("c.a").getItem('b)) + } + + test("parenthesis") { + assertEqual("(a)", 'a) + assertEqual("r * (a + b)", 'r * ('a + 'b)) + } + + test("type constructors") { + // Dates. + assertEqual("dAte '2016-03-11'", Literal(Date.valueOf("2016-03-11"))) + intercept[IllegalArgumentException] { + parseExpression("DAtE 'mar 11 2016'") + } + + // Timestamps. + assertEqual("tImEstAmp '2016-03-11 20:54:00.000'", + Literal(Timestamp.valueOf("2016-03-11 20:54:00.000"))) + intercept[IllegalArgumentException] { + parseExpression("timestamP '2016-33-11 20:54:00.000'") + } + + // Unsupported datatype. + intercept("GEO '(10,-6)'", "Literals of type 'GEO' are currently not supported.") + } + + test("literals") { + // NULL + assertEqual("null", Literal(null)) + + // Boolean + assertEqual("trUe", Literal(true)) + assertEqual("False", Literal(false)) + + // Integral should have the narrowest possible type + assertEqual("787324", Literal(787324)) + assertEqual("7873247234798249234", Literal(7873247234798249234L)) + assertEqual("78732472347982492793712334", + Literal(BigDecimal("78732472347982492793712334").underlying())) + + // Decimal + assertEqual("7873247234798249279371.2334", + Literal(BigDecimal("7873247234798249279371.2334").underlying())) + + // Scientific Decimal + assertEqual("9.0e1", 90d) + assertEqual(".9e+2", 90d) + assertEqual("0.9e+2", 90d) + assertEqual("900e-1", 90d) + assertEqual("900.0E-1", 90d) + assertEqual("9.e+1", 90d) + intercept(".e3") + + // Tiny Int Literal + assertEqual("10Y", Literal(10.toByte)) + intercept("-1000Y") + + // Small Int Literal + assertEqual("10S", Literal(10.toShort)) + intercept("40000S") + + // Long Int Literal + assertEqual("10L", Literal(10L)) + intercept("78732472347982492793712334L") + + // Double Literal + assertEqual("10.0D", Literal(10.0D)) + // TODO we need to figure out if we should throw an exception here! + assertEqual("1E309", Literal(Double.PositiveInfinity)) + } + + test("strings") { + // Single Strings. + assertEqual("\"hello\"", "hello") + assertEqual("'hello'", "hello") + + // Multi-Strings. + assertEqual("\"hello\" 'world'", "helloworld") + assertEqual("'hello' \" \" 'world'", "hello world") + + // 'LIKE' string literals. Notice that an escaped '%' is the same as an escaped '\' and a + // regular '%'; to get the correct result you need to add another escaped '\'. + // TODO figure out if we shouldn't change the ParseUtils.unescapeSQLString method? + assertEqual("'pattern%'", "pattern%") + assertEqual("'no-pattern\\%'", "no-pattern\\%") + assertEqual("'pattern\\\\%'", "pattern\\%") + assertEqual("'pattern\\\\\\%'", "pattern\\\\%") + + // Escaped characters. + // See: http://dev.mysql.com/doc/refman/5.7/en/string-literals.html + assertEqual("'\\0'", "\u0000") // ASCII NUL (X'00') + assertEqual("'\\''", "\'") // Single quote + assertEqual("'\\\"'", "\"") // Double quote + assertEqual("'\\b'", "\b") // Backspace + assertEqual("'\\n'", "\n") // Newline + assertEqual("'\\r'", "\r") // Carriage return + assertEqual("'\\t'", "\t") // Tab character + assertEqual("'\\Z'", "\u001A") // ASCII 26 - CTRL + Z (EOF on windows) + + // Octals + assertEqual("'\\110\\145\\154\\154\\157\\041'", "Hello!") + + // Unicode + assertEqual("'\\u0087\\u0111\\u0114\\u0108\\u0100\\u0032\\u0058\\u0041'", "World :)") + } + + test("intervals") { + def intervalLiteral(u: String, s: String): Literal = { + Literal(CalendarInterval.fromSingleUnitString(u, s)) + } + + // Empty interval statement + intercept("interval", "at least one time unit should be given for interval literal") + + // Single Intervals. + val units = Seq( + "year", + "month", + "week", + "day", + "hour", + "minute", + "second", + "millisecond", + "microsecond") + val forms = Seq("", "s") + val values = Seq("0", "10", "-7", "21") + units.foreach { unit => + forms.foreach { form => + values.foreach { value => + val expected = intervalLiteral(unit, value) + assertEqual(s"interval $value $unit$form", expected) + assertEqual(s"interval '$value' $unit$form", expected) + } + } + } + + // Hive nanosecond notation. + assertEqual("interval 13.123456789 seconds", intervalLiteral("second", "13.123456789")) + assertEqual("interval -13.123456789 second", intervalLiteral("second", "-13.123456789")) + + // Non Existing unit + intercept("interval 10 nanoseconds", "No interval can be constructed") + + // Year-Month intervals. + val yearMonthValues = Seq("123-10", "496-0", "-2-3", "-123-0") + yearMonthValues.foreach { value => + val result = Literal(CalendarInterval.fromYearMonthString(value)) + assertEqual(s"interval '$value' year to month", result) + } + + // Day-Time intervals. + val datTimeValues = Seq( + "99 11:22:33.123456789", + "-99 11:22:33.123456789", + "10 9:8:7.123456789", + "1 0:0:0", + "-1 0:0:0", + "1 0:0:1") + datTimeValues.foreach { value => + val result = Literal(CalendarInterval.fromDayTimeString(value)) + assertEqual(s"interval '$value' day to second", result) + } + + // Unknown FROM TO intervals + intercept("interval 10 month to second", "Intervals FROM month TO second are not supported.") + + // Composed intervals. + assertEqual( + "interval 3 months 22 seconds 1 millisecond", + Literal(new CalendarInterval(3, 22001000L))) + assertEqual( + "interval 3 years '-1-10' year to month 3 weeks '1 0:0:2' day to second", + Literal(new CalendarInterval(14, + 22 * CalendarInterval.MICROS_PER_DAY + 2 * CalendarInterval.MICROS_PER_SECOND))) + } + + test("composed expressions") { + assertEqual("1 + r.r As q", (Literal(1) + UnresolvedAttribute("r.r")).as("q")) + assertEqual("1 - f('o', o(bar))", Literal(1) - 'f.function("o", 'o.function('bar))) + intercept("1 - f('o', o(bar)) hello * world", "mismatched input '*'") + } +} diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/ng/PlanParserSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/ng/PlanParserSuite.scala new file mode 100644 index 0000000000..4206d22ca7 --- /dev/null +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/ng/PlanParserSuite.scala @@ -0,0 +1,429 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.spark.sql.catalyst.parser.ng + +import org.apache.spark.sql.Row +import org.apache.spark.sql.catalyst.expressions._ +import org.apache.spark.sql.catalyst.plans._ +import org.apache.spark.sql.catalyst.plans.logical._ +import org.apache.spark.sql.types.{BooleanType, IntegerType} + +class PlanParserSuite extends PlanTest { + import CatalystSqlParser._ + import org.apache.spark.sql.catalyst.dsl.expressions._ + import org.apache.spark.sql.catalyst.dsl.plans._ + + def assertEqual(sqlCommand: String, plan: LogicalPlan): Unit = { + comparePlans(parsePlan(sqlCommand), plan) + } + + def intercept(sqlCommand: String, messages: String*): Unit = { + val e = intercept[ParseException](parsePlan(sqlCommand)) + messages.foreach { message => + assert(e.message.contains(message)) + } + } + + test("case insensitive") { + val plan = table("a").select(star()) + assertEqual("sELEct * FroM a", plan) + assertEqual("select * fRoM a", plan) + assertEqual("SELECT * FROM a", plan) + } + + test("show functions") { + assertEqual("show functions", ShowFunctions(None, None)) + assertEqual("show functions foo", ShowFunctions(None, Some("foo"))) + assertEqual("show functions foo.bar", ShowFunctions(Some("foo"), Some("bar"))) + assertEqual("show functions 'foo\\\\.*'", ShowFunctions(None, Some("foo\\.*"))) + intercept("show functions foo.bar.baz", "SHOW FUNCTIONS unsupported name") + } + + test("describe function") { + assertEqual("describe function bar", DescribeFunction("bar", isExtended = false)) + assertEqual("describe function extended bar", DescribeFunction("bar", isExtended = true)) + assertEqual("describe function foo.bar", DescribeFunction("foo.bar", isExtended = false)) + assertEqual("describe function extended f.bar", DescribeFunction("f.bar", isExtended = true)) + } + + test("set operations") { + val a = table("a").select(star()) + val b = table("b").select(star()) + + assertEqual("select * from a union select * from b", Distinct(a.union(b))) + assertEqual("select * from a union distinct select * from b", Distinct(a.union(b))) + assertEqual("select * from a union all select * from b", a.union(b)) + assertEqual("select * from a except select * from b", a.except(b)) + intercept("select * from a except all select * from b", "EXCEPT ALL is not supported.") + assertEqual("select * from a except distinct select * from b", a.except(b)) + assertEqual("select * from a intersect select * from b", a.intersect(b)) + intercept("select * from a intersect all select * from b", "INTERSECT ALL is not supported.") + assertEqual("select * from a intersect distinct select * from b", a.intersect(b)) + } + + test("common table expressions") { + def cte(plan: LogicalPlan, namedPlans: (String, LogicalPlan)*): With = { + val ctes = namedPlans.map { + case (name, cte) => + name -> SubqueryAlias(name, cte) + }.toMap + With(plan, ctes) + } + assertEqual( + "with cte1 as (select * from a) select * from cte1", + cte(table("cte1").select(star()), "cte1" -> table("a").select(star()))) + assertEqual( + "with cte1 (select 1) select * from cte1", + cte(table("cte1").select(star()), "cte1" -> OneRowRelation.select(1))) + assertEqual( + "with cte1 (select 1), cte2 as (select * from cte1) select * from cte2", + cte(table("cte2").select(star()), + "cte1" -> OneRowRelation.select(1), + "cte2" -> table("cte1").select(star()))) + intercept( + "with cte1 (select 1), cte1 as (select 1 from cte1) select * from cte1", + "Name 'cte1' is used for multiple common table expressions") + } + + test("simple select query") { + assertEqual("select 1", OneRowRelation.select(1)) + assertEqual("select a, b", OneRowRelation.select('a, 'b)) + assertEqual("select a, b from db.c", table("db", "c").select('a, 'b)) + assertEqual("select a, b from db.c where x < 1", table("db", "c").where('x < 1).select('a, 'b)) + assertEqual( + "select a, b from db.c having x < 1", + table("db", "c").select('a, 'b).where(('x < 1).cast(BooleanType))) + assertEqual("select distinct a, b from db.c", Distinct(table("db", "c").select('a, 'b))) + assertEqual("select all a, b from db.c", table("db", "c").select('a, 'b)) + } + + test("reverse select query") { + assertEqual("from a", table("a")) + assertEqual("from a select b, c", table("a").select('b, 'c)) + assertEqual( + "from db.a select b, c where d < 1", table("db", "a").where('d < 1).select('b, 'c)) + assertEqual("from a select distinct b, c", Distinct(table("a").select('b, 'c))) + assertEqual( + "from (from a union all from b) c select *", + table("a").union(table("b")).as("c").select(star())) + } + + test("transform query spec") { + val p = ScriptTransformation(Seq('a, 'b), "func", Seq.empty, table("e"), null) + assertEqual("select transform(a, b) using 'func' from e where f < 10", + p.copy(child = p.child.where('f < 10), output = Seq('key.string, 'value.string))) + assertEqual("map a, b using 'func' as c, d from e", + p.copy(output = Seq('c.string, 'd.string))) + assertEqual("reduce a, b using 'func' as (c: int, d decimal(10, 0)) from e", + p.copy(output = Seq('c.int, 'd.decimal(10, 0)))) + } + + test("multi select query") { + assertEqual( + "from a select * select * where s < 10", + table("a").select(star()).union(table("a").where('s < 10).select(star()))) + intercept( + "from a select * select * from x where a.s < 10", + "Multi-Insert queries cannot have a FROM clause in their individual SELECT statements") + assertEqual( + "from a insert into tbl1 select * insert into tbl2 select * where s < 10", + table("a").select(star()).insertInto("tbl1").union( + table("a").where('s < 10).select(star()).insertInto("tbl2"))) + } + + test("query organization") { + // Test all valid combinations of order by/sort by/distribute by/cluster by/limit/windows + val baseSql = "select * from t" + val basePlan = table("t").select(star()) + + val ws = Map("w1" -> WindowSpecDefinition(Seq.empty, Seq.empty, UnspecifiedFrame)) + val limitWindowClauses = Seq( + ("", (p: LogicalPlan) => p), + (" limit 10", (p: LogicalPlan) => p.limit(10)), + (" window w1 as ()", (p: LogicalPlan) => WithWindowDefinition(ws, p)), + (" window w1 as () limit 10", (p: LogicalPlan) => WithWindowDefinition(ws, p).limit(10)) + ) + + val orderSortDistrClusterClauses = Seq( + ("", basePlan), + (" order by a, b desc", basePlan.orderBy('a.asc, 'b.desc)), + (" sort by a, b desc", basePlan.sortBy('a.asc, 'b.desc)), + (" distribute by a, b", basePlan.distribute('a, 'b)), + (" distribute by a sort by b", basePlan.distribute('a).sortBy('b.asc)), + (" cluster by a, b", basePlan.distribute('a, 'b).sortBy('a.asc, 'b.asc)) + ) + + orderSortDistrClusterClauses.foreach { + case (s1, p1) => + limitWindowClauses.foreach { + case (s2, pf2) => + assertEqual(baseSql + s1 + s2, pf2(p1)) + } + } + + val msg = "Combination of ORDER BY/SORT BY/DISTRIBUTE BY/CLUSTER BY is not supported" + intercept(s"$baseSql order by a sort by a", msg) + intercept(s"$baseSql cluster by a distribute by a", msg) + intercept(s"$baseSql order by a cluster by a", msg) + intercept(s"$baseSql order by a distribute by a", msg) + } + + test("insert into") { + val sql = "select * from t" + val plan = table("t").select(star()) + def insert( + partition: Map[String, Option[String]], + overwrite: Boolean = false, + ifNotExists: Boolean = false): LogicalPlan = + InsertIntoTable(table("s"), partition, plan, overwrite, ifNotExists) + + // Single inserts + assertEqual(s"insert overwrite table s $sql", + insert(Map.empty, overwrite = true)) + assertEqual(s"insert overwrite table s if not exists $sql", + insert(Map.empty, overwrite = true, ifNotExists = true)) + assertEqual(s"insert into s $sql", + insert(Map.empty)) + assertEqual(s"insert into table s partition (c = 'd', e = 1) $sql", + insert(Map("c" -> Option("d"), "e" -> Option("1")))) + assertEqual(s"insert overwrite table s partition (c = 'd', x) if not exists $sql", + insert(Map("c" -> Option("d"), "x" -> None), overwrite = true, ifNotExists = true)) + + // Multi insert + val plan2 = table("t").where('x > 5).select(star()) + assertEqual("from t insert into s select * limit 1 insert into u select * where x > 5", + InsertIntoTable( + table("s"), Map.empty, plan.limit(1), overwrite = false, ifNotExists = false).union( + InsertIntoTable( + table("u"), Map.empty, plan2, overwrite = false, ifNotExists = false))) + } + + test("aggregation") { + val sql = "select a, b, sum(c) as c from d group by a, b" + + // Normal + assertEqual(sql, table("d").groupBy('a, 'b)('a, 'b, 'sum.function('c).as("c"))) + + // Cube + assertEqual(s"$sql with cube", + table("d").groupBy(Cube(Seq('a, 'b)))('a, 'b, 'sum.function('c).as("c"))) + + // Rollup + assertEqual(s"$sql with rollup", + table("d").groupBy(Rollup(Seq('a, 'b)))('a, 'b, 'sum.function('c).as("c"))) + + // Grouping Sets + assertEqual(s"$sql grouping sets((a, b), (a), ())", + GroupingSets(Seq(0, 1, 3), Seq('a, 'b), table("d"), Seq('a, 'b, 'sum.function('c).as("c")))) + intercept(s"$sql grouping sets((a, b), (c), ())", + "c doesn't show up in the GROUP BY list") + } + + test("limit") { + val sql = "select * from t" + val plan = table("t").select(star()) + assertEqual(s"$sql limit 10", plan.limit(10)) + assertEqual(s"$sql limit cast(9 / 4 as int)", plan.limit(Cast(Literal(9) / 4, IntegerType))) + } + + test("window spec") { + // Note that WindowSpecs are testing in the ExpressionParserSuite + val sql = "select * from t" + val plan = table("t").select(star()) + val spec = WindowSpecDefinition(Seq('a, 'b), Seq('c.asc), + SpecifiedWindowFrame(RowFrame, ValuePreceding(1), ValueFollowing(1))) + + // Test window resolution. + val ws1 = Map("w1" -> spec, "w2" -> spec, "w3" -> spec) + assertEqual( + s"""$sql + |window w1 as (partition by a, b order by c rows between 1 preceding and 1 following), + | w2 as w1, + | w3 as w1""".stripMargin, + WithWindowDefinition(ws1, plan)) + + // Fail with no reference. + intercept(s"$sql window w2 as w1", "Cannot resolve window reference 'w1'") + + // Fail when resolved reference is not a window spec. + intercept( + s"""$sql + |window w1 as (partition by a, b order by c rows between 1 preceding and 1 following), + | w2 as w1, + | w3 as w2""".stripMargin, + "Window reference 'w2' is not a window specification" + ) + } + + test("lateral view") { + // Single lateral view + assertEqual( + "select * from t lateral view explode(x) expl as x", + table("t") + .generate(Explode('x), join = true, outer = false, Some("expl"), Seq("x")) + .select(star())) + + // Multiple lateral views + assertEqual( + """select * + |from t + |lateral view explode(x) expl + |lateral view outer json_tuple(x, y) jtup q, z""".stripMargin, + table("t") + .generate(Explode('x), join = true, outer = false, Some("expl"), Seq.empty) + .generate(JsonTuple(Seq('x, 'y)), join = true, outer = true, Some("jtup"), Seq("q", "z")) + .select(star())) + + // Multi-Insert lateral views. + val from = table("t1").generate(Explode('x), join = true, outer = false, Some("expl"), Seq("x")) + assertEqual( + """from t1 + |lateral view explode(x) expl as x + |insert into t2 + |select * + |lateral view json_tuple(x, y) jtup q, z + |insert into t3 + |select * + |where s < 10 + """.stripMargin, + Union(from + .generate(JsonTuple(Seq('x, 'y)), join = true, outer = false, Some("jtup"), Seq("q", "z")) + .select(star()) + .insertInto("t2"), + from.where('s < 10).select(star()).insertInto("t3"))) + + // Unsupported generator. + intercept( + "select * from t lateral view posexplode(x) posexpl as x, y", + "Generator function 'posexplode' is not supported") + } + + test("joins") { + // Test single joins. + val testUnconditionalJoin = (sql: String, jt: JoinType) => { + assertEqual( + s"select * from t as tt $sql u", + table("t").as("tt").join(table("u"), jt, None).select(star())) + } + val testConditionalJoin = (sql: String, jt: JoinType) => { + assertEqual( + s"select * from t $sql u as uu on a = b", + table("t").join(table("u").as("uu"), jt, Option('a === 'b)).select(star())) + } + val testNaturalJoin = (sql: String, jt: JoinType) => { + assertEqual( + s"select * from t tt natural $sql u as uu", + table("t").as("tt").join(table("u").as("uu"), NaturalJoin(jt), None).select(star())) + } + val testUsingJoin = (sql: String, jt: JoinType) => { + assertEqual( + s"select * from t $sql u using(a, b)", + table("t").join(table("u"), UsingJoin(jt, Seq('a.attr, 'b.attr)), None).select(star())) + } + val testAll = Seq(testUnconditionalJoin, testConditionalJoin, testNaturalJoin, testUsingJoin) + + def test(sql: String, jt: JoinType, tests: Seq[(String, JoinType) => Unit]): Unit = { + tests.foreach(_(sql, jt)) + } + test("cross join", Inner, Seq(testUnconditionalJoin)) + test(",", Inner, Seq(testUnconditionalJoin)) + test("join", Inner, testAll) + test("inner join", Inner, testAll) + test("left join", LeftOuter, testAll) + test("left outer join", LeftOuter, testAll) + test("right join", RightOuter, testAll) + test("right outer join", RightOuter, testAll) + test("full join", FullOuter, testAll) + test("full outer join", FullOuter, testAll) + + // Test multiple consecutive joins + assertEqual( + "select * from a join b join c right join d", + table("a").join(table("b")).join(table("c")).join(table("d"), RightOuter).select(star())) + } + + test("sampled relations") { + val sql = "select * from t" + assertEqual(s"$sql tablesample(100 rows)", + table("t").limit(100).select(star())) + assertEqual(s"$sql tablesample(43 percent) as x", + Sample(0, .43d, withReplacement = false, 10L, table("t").as("x"))(true).select(star())) + assertEqual(s"$sql tablesample(bucket 4 out of 10) as x", + Sample(0, .4d, withReplacement = false, 10L, table("t").as("x"))(true).select(star())) + intercept(s"$sql tablesample(bucket 4 out of 10 on x) as x", + "TABLESAMPLE(BUCKET x OUT OF y ON id) is not supported") + intercept(s"$sql tablesample(bucket 11 out of 10) as x", + s"Sampling fraction (${11.0/10.0}) must be on interval [0, 1]") + } + + test("sub-query") { + val plan = table("t0").select('id) + assertEqual("select id from (t0)", plan) + assertEqual("select id from ((((((t0))))))", plan) + assertEqual( + "(select * from t1) union distinct (select * from t2)", + Distinct(table("t1").select(star()).union(table("t2").select(star())))) + assertEqual( + "select * from ((select * from t1) union (select * from t2)) t", + Distinct( + table("t1").select(star()).union(table("t2").select(star()))).as("t").select(star())) + assertEqual( + """select id + |from (((select id from t0) + | union all + | (select id from t0)) + | union all + | (select id from t0)) as u_1 + """.stripMargin, + plan.union(plan).union(plan).as("u_1").select('id)) + } + + test("scalar sub-query") { + assertEqual( + "select (select max(b) from s) ss from t", + table("t").select(ScalarSubquery(table("s").select('max.function('b))).as("ss"))) + assertEqual( + "select * from t where a = (select b from s)", + table("t").where('a === ScalarSubquery(table("s").select('b))).select(star())) + assertEqual( + "select g from t group by g having a > (select b from s)", + table("t") + .groupBy('g)('g) + .where(('a > ScalarSubquery(table("s").select('b))).cast(BooleanType))) + } + + test("table reference") { + assertEqual("table t", table("t")) + assertEqual("table d.t", table("d", "t")) + } + + test("inline table") { + assertEqual("values 1, 2, 3, 4", LocalRelation.fromExternalRows( + Seq('col1.int), + Seq(1, 2, 3, 4).map(x => Row(x)))) + assertEqual( + "values (1, 'a'), (2, 'b'), (3, 'c') as tbl(a, b)", + LocalRelation.fromExternalRows( + Seq('a.int, 'b.string), + Seq((1, "a"), (2, "b"), (3, "c")).map(x => Row(x._1, x._2))).as("tbl")) + intercept("values (a, 'a'), (b, 'b')", + "All expressions in an inline table must be constants.") + intercept("values (1, 'a'), (2, 'b') as tbl(a, b, c)", + "Number of aliases must match the number of fields in an inline table.") + intercept[ArrayIndexOutOfBoundsException](parsePlan("values (1, 'a'), (2, 'b', 5Y)")) + } +} diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/ng/TableIdentifierParserSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/ng/TableIdentifierParserSuite.scala new file mode 100644 index 0000000000..0874322187 --- /dev/null +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/ng/TableIdentifierParserSuite.scala @@ -0,0 +1,42 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.spark.sql.catalyst.parser.ng + +import org.apache.spark.SparkFunSuite +import org.apache.spark.sql.catalyst.TableIdentifier + +class TableIdentifierParserSuite extends SparkFunSuite { + import CatalystSqlParser._ + + test("table identifier") { + // Regular names. + assert(TableIdentifier("q") === parseTableIdentifier("q")) + assert(TableIdentifier("q", Option("d")) === parseTableIdentifier("d.q")) + + // Illegal names. + intercept[ParseException](parseTableIdentifier("")) + intercept[ParseException](parseTableIdentifier("d.q.g")) + + // SQL Keywords. + val keywords = Seq("select", "from", "where", "left", "right") + keywords.foreach { keyword => + intercept[ParseException](parseTableIdentifier(keyword)) + assert(TableIdentifier(keyword) === parseTableIdentifier(s"`$keyword`")) + assert(TableIdentifier(keyword, Option("db")) === parseTableIdentifier(s"db.`$keyword`")) + } + } +} diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/plans/PlanTest.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/plans/PlanTest.scala index 0541844e0b..aa5d4330d3 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/plans/PlanTest.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/plans/PlanTest.scala @@ -19,7 +19,7 @@ package org.apache.spark.sql.catalyst.plans import org.apache.spark.SparkFunSuite import org.apache.spark.sql.catalyst.expressions._ -import org.apache.spark.sql.catalyst.plans.logical.{Filter, LogicalPlan, OneRowRelation} +import org.apache.spark.sql.catalyst.plans.logical.{Filter, LogicalPlan, OneRowRelation, Sample} import org.apache.spark.sql.catalyst.util._ /** @@ -32,6 +32,8 @@ abstract class PlanTest extends SparkFunSuite with PredicateHelper { */ protected def normalizeExprIds(plan: LogicalPlan) = { plan transformAllExpressions { + case s: ScalarSubquery => + ScalarSubquery(s.query, ExprId(0)) case a: AttributeReference => AttributeReference(a.name, a.dataType, a.nullable)(exprId = ExprId(0)) case a: Alias => @@ -40,21 +42,25 @@ abstract class PlanTest extends SparkFunSuite with PredicateHelper { } /** - * Normalizes the filter conditions that appear in the plan. For instance, - * ((expr 1 && expr 2) && expr 3), (expr 1 && expr 2 && expr 3), (expr 3 && (expr 1 && expr 2) - * etc., will all now be equivalent. + * Normalizes plans: + * - Filter the filter conditions that appear in a plan. For instance, + * ((expr 1 && expr 2) && expr 3), (expr 1 && expr 2 && expr 3), (expr 3 && (expr 1 && expr 2) + * etc., will all now be equivalent. + * - Sample the seed will replaced by 0L. */ - private def normalizeFilters(plan: LogicalPlan) = { + private def normalizePlan(plan: LogicalPlan): LogicalPlan = { plan transform { case filter @ Filter(condition: Expression, child: LogicalPlan) => Filter(splitConjunctivePredicates(condition).sortBy(_.hashCode()).reduce(And), child) + case sample: Sample => + sample.copy(seed = 0L)(true) } } /** Fails the test if the two plans do not match */ protected def comparePlans(plan1: LogicalPlan, plan2: LogicalPlan) { - val normalized1 = normalizeFilters(normalizeExprIds(plan1)) - val normalized2 = normalizeFilters(normalizeExprIds(plan2)) + val normalized1 = normalizePlan(normalizeExprIds(plan1)) + val normalized2 = normalizePlan(normalizeExprIds(plan2)) if (normalized1 != normalized2) { fail( s""" diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkSqlParser.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkSqlParser.scala new file mode 100644 index 0000000000..c098fa99c2 --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkSqlParser.scala @@ -0,0 +1,219 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.spark.sql.execution + +import scala.collection.JavaConverters._ + +import org.apache.spark.sql.SaveMode +import org.apache.spark.sql.catalyst.TableIdentifier +import org.apache.spark.sql.catalyst.parser.ng.{AbstractSqlParser, AstBuilder} +import org.apache.spark.sql.catalyst.parser.ng.SqlBaseParser._ +import org.apache.spark.sql.catalyst.plans.logical.{LogicalPlan, OneRowRelation} +import org.apache.spark.sql.execution.command.{DescribeCommand => _, _} +import org.apache.spark.sql.execution.datasources._ + +/** + * Concrete parser for Spark SQL statements. + */ +object SparkSqlParser extends AbstractSqlParser{ + val astBuilder = new SparkSqlAstBuilder +} + +/** + * Builder that converts an ANTLR ParseTree into a LogicalPlan/Expression/TableIdentifier. + */ +class SparkSqlAstBuilder extends AstBuilder { + import org.apache.spark.sql.catalyst.parser.ng.ParserUtils._ + + /** + * Create a [[SetCommand]] logical plan. + * + * Note that we assume that everything after the SET keyword is assumed to be a part of the + * key-value pair. The split between key and value is made by searching for the first `=` + * character in the raw string. + */ + override def visitSetConfiguration(ctx: SetConfigurationContext): LogicalPlan = withOrigin(ctx) { + // Construct the command. + val raw = remainder(ctx.SET.getSymbol) + val keyValueSeparatorIndex = raw.indexOf('=') + if (keyValueSeparatorIndex >= 0) { + val key = raw.substring(0, keyValueSeparatorIndex).trim + val value = raw.substring(keyValueSeparatorIndex + 1).trim + SetCommand(Some(key -> Option(value))) + } else if (raw.nonEmpty) { + SetCommand(Some(raw.trim -> None)) + } else { + SetCommand(None) + } + } + + /** + * Create a [[SetDatabaseCommand]] logical plan. + */ + override def visitUse(ctx: UseContext): LogicalPlan = withOrigin(ctx) { + SetDatabaseCommand(ctx.db.getText) + } + + /** + * Create a [[ShowTablesCommand]] logical plan. + */ + override def visitShowTables(ctx: ShowTablesContext): LogicalPlan = withOrigin(ctx) { + if (ctx.LIKE != null) { + logWarning("SHOW TABLES LIKE option is ignored.") + } + ShowTablesCommand(Option(ctx.db).map(_.getText)) + } + + /** + * Create a [[RefreshTable]] logical plan. + */ + override def visitRefreshTable(ctx: RefreshTableContext): LogicalPlan = withOrigin(ctx) { + RefreshTable(visitTableIdentifier(ctx.tableIdentifier)) + } + + /** + * Create a [[CacheTableCommand]] logical plan. + */ + override def visitCacheTable(ctx: CacheTableContext): LogicalPlan = withOrigin(ctx) { + val query = Option(ctx.query).map(plan) + CacheTableCommand(ctx.identifier.getText, query, ctx.LAZY != null) + } + + /** + * Create an [[UncacheTableCommand]] logical plan. + */ + override def visitUncacheTable(ctx: UncacheTableContext): LogicalPlan = withOrigin(ctx) { + UncacheTableCommand(ctx.identifier.getText) + } + + /** + * Create a [[ClearCacheCommand]] logical plan. + */ + override def visitClearCache(ctx: ClearCacheContext): LogicalPlan = withOrigin(ctx) { + ClearCacheCommand + } + + /** + * Create an [[ExplainCommand]] logical plan. + */ + override def visitExplain(ctx: ExplainContext): LogicalPlan = withOrigin(ctx) { + val options = ctx.explainOption.asScala + if (options.exists(_.FORMATTED != null)) { + logWarning("EXPLAIN FORMATTED option is ignored.") + } + if (options.exists(_.LOGICAL != null)) { + logWarning("EXPLAIN LOGICAL option is ignored.") + } + + // Create the explain comment. + val statement = plan(ctx.statement) + if (isExplainableStatement(statement)) { + ExplainCommand(statement, extended = options.exists(_.EXTENDED != null)) + } else { + ExplainCommand(OneRowRelation) + } + } + + /** + * Determine if a plan should be explained at all. + */ + protected def isExplainableStatement(plan: LogicalPlan): Boolean = plan match { + case _: datasources.DescribeCommand => false + case _ => true + } + + /** + * Create a [[DescribeCommand]] logical plan. + */ + override def visitDescribeTable(ctx: DescribeTableContext): LogicalPlan = withOrigin(ctx) { + // FORMATTED and columns are not supported. Return null and let the parser decide what to do + // with this (create an exception or pass it on to a different system). + if (ctx.describeColName != null || ctx.FORMATTED != null || ctx.partitionSpec != null) { + null + } else { + datasources.DescribeCommand( + visitTableIdentifier(ctx.tableIdentifier), + ctx.EXTENDED != null) + } + } + + /** Type to keep track of a table header. */ + type TableHeader = (TableIdentifier, Boolean, Boolean, Boolean) + + /** + * Validate a create table statement and return the [[TableIdentifier]]. + */ + override def visitCreateTableHeader( + ctx: CreateTableHeaderContext): TableHeader = withOrigin(ctx) { + val temporary = ctx.TEMPORARY != null + val ifNotExists = ctx.EXISTS != null + assert(!temporary || !ifNotExists, + "a CREATE TEMPORARY TABLE statement does not allow IF NOT EXISTS clause.", + ctx) + (visitTableIdentifier(ctx.tableIdentifier), temporary, ifNotExists, ctx.EXTERNAL != null) + } + + /** + * Create a [[CreateTableUsing]] or a [[CreateTableUsingAsSelect]] logical plan. + * + * TODO add bucketing and partitioning. + */ + override def visitCreateTableUsing(ctx: CreateTableUsingContext): LogicalPlan = withOrigin(ctx) { + val (table, temp, ifNotExists, external) = visitCreateTableHeader(ctx.createTableHeader) + if (external) { + logWarning("EXTERNAL option is not supported.") + } + val options = Option(ctx.tablePropertyList).map(visitTablePropertyList).getOrElse(Map.empty) + val provider = ctx.tableProvider.qualifiedName.getText + + if (ctx.query != null) { + // Get the backing query. + val query = plan(ctx.query) + + // Determine the storage mode. + val mode = if (ifNotExists) { + SaveMode.Ignore + } else if (temp) { + SaveMode.Overwrite + } else { + SaveMode.ErrorIfExists + } + CreateTableUsingAsSelect(table, provider, temp, Array.empty, None, mode, options, query) + } else { + val struct = Option(ctx.colTypeList).map(createStructType) + CreateTableUsing(table, struct, provider, temp, options, ifNotExists, managedIfNoPath = false) + } + } + + /** + * Convert a table property list into a key-value map. + */ + override def visitTablePropertyList( + ctx: TablePropertyListContext): Map[String, String] = withOrigin(ctx) { + ctx.tableProperty.asScala.map { property => + // A key can either be a String or a collection of dot separated elements. We need to treat + // these differently. + val key = if (property.key.STRING != null) { + string(property.key.STRING) + } else { + property.key.getText + } + val value = Option(property.value).map(string).orNull + key -> value + }.toMap + } +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala index 8abb9d7e4a..7ce15e3f35 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala @@ -27,8 +27,8 @@ import org.apache.spark.sql.catalyst.analysis.{Star, UnresolvedFunction} import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.expressions.aggregate._ -import org.apache.spark.sql.catalyst.parser.CatalystQl import org.apache.spark.sql.catalyst.plans.logical.BroadcastHint +import org.apache.spark.sql.execution.SparkSqlParser import org.apache.spark.sql.expressions.UserDefinedFunction import org.apache.spark.sql.types._ import org.apache.spark.util.Utils @@ -1172,8 +1172,7 @@ object functions { * @group normal_funcs */ def expr(expr: String): Column = { - val parser = SQLContext.getActive().map(_.sessionState.sqlParser).getOrElse(new CatalystQl()) - Column(parser.parseExpression(expr)) + Column(SparkSqlParser.parseExpression(expr)) } ////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/sql/core/src/main/scala/org/apache/spark/sql/internal/SessionState.scala b/sql/core/src/main/scala/org/apache/spark/sql/internal/SessionState.scala index e5f02caabc..9bc640763f 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/internal/SessionState.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/internal/SessionState.scala @@ -81,7 +81,7 @@ private[sql] class SessionState(ctx: SQLContext) { /** * Parser that extracts expressions, plans, table identifiers etc. from SQL texts. */ - lazy val sqlParser: ParserInterface = new SparkQl(conf) + lazy val sqlParser: ParserInterface = SparkSqlParser /** * Planner that converts optimized logical plans to physical plans. diff --git a/sql/core/src/test/scala/org/apache/spark/sql/JoinSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/JoinSuite.scala index 5af1a4fcd7..a5a4ff13de 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/JoinSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/JoinSuite.scala @@ -329,8 +329,8 @@ class JoinSuite extends QueryTest with SharedSQLContext { } test("full outer join") { - upperCaseData.where('N <= 4).registerTempTable("left") - upperCaseData.where('N >= 3).registerTempTable("right") + upperCaseData.where('N <= 4).registerTempTable("`left`") + upperCaseData.where('N >= 3).registerTempTable("`right`") val left = UnresolvedRelation(TableIdentifier("left"), None) val right = UnresolvedRelation(TableIdentifier("right"), None) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala index c958eac266..b727e88668 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala @@ -1656,7 +1656,7 @@ class SQLQuerySuite extends QueryTest with SharedSQLContext { val e2 = intercept[AnalysisException] { sql("select interval 23 nanosecond") } - assert(e2.message.contains("cannot recognize input near")) + assert(e2.message.contains("No interval can be constructed")) } test("SPARK-8945: add and subtract expressions for interval type") { |