path: root/sql/catalyst
author     Josh Rosen <joshrosen@databricks.com>   2016-04-14 16:43:28 -0700
committer  Josh Rosen <joshrosen@databricks.com>   2016-04-14 16:43:28 -0700
commit     ee4090b60e8b6a350913d1d5049f0770c251cd4a (patch)
tree       7e082fa815430c23e0387461be0726cc3e4d04b5 /sql/catalyst
parent     2407f5b14edcdcf750113766d82e78732f9852d6 (diff)
parent     d7e124edfe2578ecdf8e816a4dda3ce430a09172 (diff)
Merge remote-tracking branch 'origin/master' into build-for-2.12
Diffstat (limited to 'sql/catalyst')
-rw-r--r--  sql/catalyst/pom.xml | 13
-rw-r--r--  sql/catalyst/src/main/antlr3/org/apache/spark/sql/catalyst/parser/ExpressionParser.g | 400
-rw-r--r--  sql/catalyst/src/main/antlr3/org/apache/spark/sql/catalyst/parser/FromClauseParser.g | 341
-rw-r--r--  sql/catalyst/src/main/antlr3/org/apache/spark/sql/catalyst/parser/IdentifiersParser.g | 184
-rw-r--r--  sql/catalyst/src/main/antlr3/org/apache/spark/sql/catalyst/parser/KeywordParser.g | 244
-rw-r--r--  sql/catalyst/src/main/antlr3/org/apache/spark/sql/catalyst/parser/SelectClauseParser.g | 235
-rw-r--r--  sql/catalyst/src/main/antlr3/org/apache/spark/sql/catalyst/parser/SparkSqlLexer.g | 491
-rw-r--r--  sql/catalyst/src/main/antlr3/org/apache/spark/sql/catalyst/parser/SparkSqlParser.g | 2596
-rw-r--r--  sql/catalyst/src/main/antlr4/org/apache/spark/sql/catalyst/parser/SqlBase.g4 | 957
-rw-r--r--  sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/parser/ParseUtils.java | 135
-rw-r--r--  sql/catalyst/src/main/java/org/apache/spark/sql/execution/UnsafeExternalRowSorter.java | 2
-rw-r--r--  sql/catalyst/src/main/scala/org/apache/spark/sql/Encoder.scala | 231
-rw-r--r--  sql/catalyst/src/main/scala/org/apache/spark/sql/Encoders.scala | 314
-rw-r--r--  sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/CatalystConf.scala | 2
-rw-r--r--  sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/JavaTypeInference.scala | 32
-rw-r--r--  sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala | 58
-rw-r--r--  sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala | 494
-rw-r--r--  sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala | 4
-rw-r--r--  sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala | 77
-rw-r--r--  sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/NoSuchItemException.scala | 31
-rw-r--r--  sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/unresolved.scala | 53
-rw-r--r--  sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/InMemoryCatalog.scala | 50
-rw-r--r--  sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/SessionCatalog.scala | 321
-rw-r--r--  sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/functionResources.scala | 61
-rw-r--r--  sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/interface.scala | 62
-rw-r--r--  sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/dsl/package.scala | 53
-rw-r--r--  sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoder.scala | 91
-rw-r--r--  sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/RowEncoder.scala | 34
-rw-r--r--  sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/errors/package.scala | 7
-rw-r--r--  sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala | 5
-rw-r--r--  sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/EquivalentExpressions.scala | 4
-rw-r--r--  sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala | 4
-rw-r--r--  sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/InputFileName.scala | 8
-rw-r--r--  sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Projection.scala | 6
-rw-r--r--  sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/SpecificMutableRow.scala | 2
-rw-r--r--  sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/TimeWindow.scala | 168
-rw-r--r--  sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Average.scala | 2
-rw-r--r--  sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/CentralMomentAgg.scala | 14
-rw-r--r--  sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Corr.scala | 2
-rw-r--r--  sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Count.scala | 6
-rw-r--r--  sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Covariance.scala | 4
-rw-r--r--  sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/First.scala | 5
-rw-r--r--  sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/HyperLogLogPlusPlus.scala | 7
-rw-r--r--  sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Last.scala | 2
-rw-r--r--  sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Max.scala | 2
-rw-r--r--  sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Min.scala | 3
-rw-r--r--  sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Sum.scala | 2
-rw-r--r--  sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/interfaces.scala | 37
-rw-r--r--  sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala | 120
-rw-r--r--  sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/bitwiseExpressions.scala | 12
-rw-r--r--  sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeFormatter.scala | 38
-rw-r--r--  sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala | 42
-rw-r--r--  sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala | 10
-rw-r--r--  sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeCreator.scala | 10
-rw-r--r--  sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/conditionalExpressions.scala | 15
-rw-r--r--  sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/datetimeExpressions.scala | 82
-rw-r--r--  sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/generators.scala | 4
-rw-r--r--  sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/grouping.scala | 18
-rw-r--r--  sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/jsonExpressions.scala | 6
-rw-r--r--  sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/mathExpressions.scala | 129
-rw-r--r--  sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/misc.scala | 6
-rw-r--r--  sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/namedExpressions.scala | 8
-rw-r--r--  sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/nullExpressions.scala | 11
-rw-r--r--  sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects.scala | 86
-rw-r--r--  sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/package.scala | 7
-rw-r--r--  sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala | 124
-rw-r--r--  sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/randomExpressions.scala | 4
-rw-r--r--  sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/regexpExpressions.scala | 14
-rw-r--r--  sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/rows.scala | 2
-rw-r--r--  sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringExpressions.scala | 117
-rw-r--r--  sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/windowExpressions.scala | 6
-rw-r--r--  sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/identifiers.scala | 16
-rw-r--r--  sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala | 286
-rw-r--r--  sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/ASTNode.scala | 99
-rw-r--r--  sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AbstractSparkSQLParser.scala | 145
-rw-r--r--  sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala | 1455
-rw-r--r--  sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/CatalystQl.scala | 933
-rw-r--r--  sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/DataTypeParser.scala | 67
-rw-r--r--  sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/ParseDriver.scala | 245
-rw-r--r--  sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/ParserConf.scala | 26
-rw-r--r--  sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/ParserUtils.scala | 281
-rw-r--r--  sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/planning/patterns.scala | 104
-rw-r--r--  sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/QueryPlan.scala | 50
-rw-r--r--  sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/joinTypes.scala | 17
-rw-r--r--  sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LogicalPlan.scala | 4
-rw-r--r--  sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala | 25
-rw-r--r--  sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/object.scala | 111
-rw-r--r--  sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/physical/partitioning.scala | 6
-rw-r--r--  sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/trees/TreeNode.scala | 27
-rw-r--r--  sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/StringKeyHashMap.scala | 2
-rw-r--r--  sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/StringUtils.scala | 23
-rw-r--r--  sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/package.scala | 9
-rw-r--r--  sql/catalyst/src/main/scala/org/apache/spark/sql/types/UserDefinedType.scala | 6
-rw-r--r--  sql/catalyst/src/test/resources/log4j.properties | 3
-rw-r--r--  sql/catalyst/src/test/scala/org/apache/spark/sql/RandomDataGenerator.scala | 18
-rw-r--r--  sql/catalyst/src/test/scala/org/apache/spark/sql/RowTest.scala | 2
-rw-r--r--  sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/ScalaReflectionSuite.scala | 4
-rw-r--r--  sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisErrorSuite.scala | 56
-rw-r--r--  sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisTest.scala | 6
-rw-r--r--  sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/DecimalPrecisionSuite.scala | 6
-rw-r--r--  sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/catalog/CatalogTestCases.scala | 61
-rw-r--r--  sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/catalog/SessionCatalogSuite.scala | 275
-rw-r--r--  sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoderSuite.scala | 4
-rw-r--r--  sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/StringExpressionsSuite.scala | 1
-rw-r--r--  sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/TimeWindowSuite.scala | 111
-rw-r--r--  sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeFormatterSuite.scala | 55
-rw-r--r--  sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/BinaryComparisonSimplificationSuite.scala | 95
-rw-r--r--  sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/BooleanSimplificationSuite.scala | 3
-rw-r--r--  sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/ColumnPruningSuite.scala | 2
-rw-r--r--  sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/EliminateSortsSuite.scala | 4
-rw-r--r--  sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/FilterPushdownSuite.scala | 76
-rw-r--r--  sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/JoinOptimizationSuite.scala | 4
-rw-r--r--  sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/LikeSimplificationSuite.scala | 14
-rw-r--r--  sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/OptimizerExtendableSuite.scala | 14
-rw-r--r--  sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/PruneFiltersSuite.scala | 2
-rw-r--r--  sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/SimplifyConditionalSuite.scala | 16
-rw-r--r--  sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/TypedFilterOptimizationSuite.scala | 74
-rw-r--r--  sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/CatalystQlSuite.scala | 243
-rw-r--r--  sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/DataTypeParserSuite.scala | 54
-rw-r--r--  sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/ErrorParserSuite.scala | 67
-rw-r--r--  sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/ExpressionParserSuite.scala | 497
-rw-r--r--  sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/ParserUtilsSuite.scala | 65
-rw-r--r--  sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/PlanParserSuite.scala | 431
-rw-r--r--  sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/TableIdentifierParserSuite.scala (renamed from sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/ASTNodeSuite.scala) | 34
-rw-r--r--  sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/plans/ConstraintPropagationSuite.scala | 126
-rw-r--r--  sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/plans/PlanTest.scala | 23
-rw-r--r--  sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/plans/SameResultSuite.scala | 7
-rw-r--r--  sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/util/StringUtilsSuite.scala | 12
128 files changed, 7507 insertions, 7622 deletions
diff --git a/sql/catalyst/pom.xml b/sql/catalyst/pom.xml
index 5d1d9edd25..1748fa2778 100644
--- a/sql/catalyst/pom.xml
+++ b/sql/catalyst/pom.xml
@@ -73,7 +73,7 @@
</dependency>
<dependency>
<groupId>org.antlr</groupId>
- <artifactId>antlr-runtime</artifactId>
+ <artifactId>antlr4-runtime</artifactId>
</dependency>
<dependency>
<groupId>commons-codec</groupId>
@@ -113,20 +113,17 @@
</plugin>
<plugin>
<groupId>org.antlr</groupId>
- <artifactId>antlr3-maven-plugin</artifactId>
+ <artifactId>antlr4-maven-plugin</artifactId>
<executions>
<execution>
<goals>
- <goal>antlr</goal>
+ <goal>antlr4</goal>
</goals>
</execution>
</executions>
<configuration>
- <sourceDirectory>../catalyst/src/main/antlr3</sourceDirectory>
- <includes>
- <include>**/SparkSqlLexer.g</include>
- <include>**/SparkSqlParser.g</include>
- </includes>
+ <visitor>true</visitor>
+ <sourceDirectory>../catalyst/src/main/antlr4</sourceDirectory>
</configuration>
</plugin>
</plugins>
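
Note: the pom.xml hunk above swaps the ANTLR 3 build step for the antlr4-maven-plugin and enables visitor generation for the new SqlBase.g4 grammar listed in the diffstat. As a rough, hedged sketch of how the generated artifacts are typically driven (the class names SqlBaseLexer/SqlBaseParser, their package, and the entry rule singleStatement() are assumed from ANTLR 4 naming conventions for SqlBase.g4, not taken verbatim from this diff):

    import org.antlr.v4.runtime.{ANTLRInputStream, CommonTokenStream}
    // Assumed package/names of the classes the plugin generates from SqlBase.g4.
    import org.apache.spark.sql.catalyst.parser.{SqlBaseLexer, SqlBaseParser}

    object ParserSketch {
      def parse(sql: String): Unit = {
        val lexer  = new SqlBaseLexer(new ANTLRInputStream(sql)) // characters -> tokens
        val tokens = new CommonTokenStream(lexer)                // token buffer consumed by the parser
        val parser = new SqlBaseParser(tokens)                   // generated recursive-descent parser
        val tree   = parser.singleStatement()                    // parse tree for one SQL statement
        println(tree.toStringTree(parser))                       // print the tree for inspection
      }
    }

With <visitor>true</visitor> the plugin also emits a SqlBaseBaseVisitor base class, which the AstBuilder.scala added in this change set extends to convert the parse tree into Catalyst expressions and logical plans.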
diff --git a/sql/catalyst/src/main/antlr3/org/apache/spark/sql/catalyst/parser/ExpressionParser.g b/sql/catalyst/src/main/antlr3/org/apache/spark/sql/catalyst/parser/ExpressionParser.g
deleted file mode 100644
index 13a6a2d276..0000000000
--- a/sql/catalyst/src/main/antlr3/org/apache/spark/sql/catalyst/parser/ExpressionParser.g
+++ /dev/null
@@ -1,400 +0,0 @@
-/**
- Licensed to the Apache Software Foundation (ASF) under one or more
- contributor license agreements. See the NOTICE file distributed with
- this work for additional information regarding copyright ownership.
- The ASF licenses this file to You under the Apache License, Version 2.0
- (the "License"); you may not use this file except in compliance with
- the License. You may obtain a copy of the License at
-
- http://www.apache.org/licenses/LICENSE-2.0
-
- Unless required by applicable law or agreed to in writing, software
- distributed under the License is distributed on an "AS IS" BASIS,
- WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- See the License for the specific language governing permissions and
- limitations under the License.
-
- This file is an adaptation of Hive's org/apache/hadoop/hive/ql/IdentifiersParser.g grammar.
-*/
-
-parser grammar ExpressionParser;
-
-options
-{
-output=AST;
-ASTLabelType=CommonTree;
-backtrack=false;
-k=3;
-}
-
-@members {
- @Override
- public Object recoverFromMismatchedSet(IntStream input,
- RecognitionException re, BitSet follow) throws RecognitionException {
- throw re;
- }
- @Override
- public void displayRecognitionError(String[] tokenNames,
- RecognitionException e) {
- gParent.displayRecognitionError(tokenNames, e);
- }
- protected boolean useSQL11ReservedKeywordsForIdentifier() {
- return gParent.useSQL11ReservedKeywordsForIdentifier();
- }
-}
-
-@rulecatch {
-catch (RecognitionException e) {
- throw e;
-}
-}
-
-// fun(par1, par2, par3)
-function
-@init { gParent.pushMsg("function specification", state); }
-@after { gParent.popMsg(state); }
- :
- functionName
- LPAREN
- (
- (STAR) => (star=STAR)
- | (dist=KW_DISTINCT)? (selectExpression (COMMA selectExpression)*)?
- )
- RPAREN (KW_OVER ws=window_specification)?
- -> {$star != null}? ^(TOK_FUNCTIONSTAR functionName $ws?)
- -> {$dist == null}? ^(TOK_FUNCTION functionName (selectExpression+)? $ws?)
- -> ^(TOK_FUNCTIONDI functionName (selectExpression+)? $ws?)
- ;
-
-functionName
-@init { gParent.pushMsg("function name", state); }
-@after { gParent.popMsg(state); }
- : // Keyword IF is also a function name
- (KW_IF | KW_ARRAY | KW_MAP | KW_STRUCT | KW_UNIONTYPE) => (KW_IF | KW_ARRAY | KW_MAP | KW_STRUCT | KW_UNIONTYPE)
- |
- (functionIdentifier) => functionIdentifier
- |
- {!useSQL11ReservedKeywordsForIdentifier()}? sql11ReservedKeywordsUsedAsCastFunctionName -> Identifier[$sql11ReservedKeywordsUsedAsCastFunctionName.text]
- ;
-
-castExpression
-@init { gParent.pushMsg("cast expression", state); }
-@after { gParent.popMsg(state); }
- :
- KW_CAST
- LPAREN
- expression
- KW_AS
- primitiveType
- RPAREN -> ^(TOK_FUNCTION primitiveType expression)
- ;
-
-caseExpression
-@init { gParent.pushMsg("case expression", state); }
-@after { gParent.popMsg(state); }
- :
- KW_CASE expression
- (KW_WHEN expression KW_THEN expression)+
- (KW_ELSE expression)?
- KW_END -> ^(TOK_FUNCTION KW_CASE expression*)
- ;
-
-whenExpression
-@init { gParent.pushMsg("case expression", state); }
-@after { gParent.popMsg(state); }
- :
- KW_CASE
- ( KW_WHEN expression KW_THEN expression)+
- (KW_ELSE expression)?
- KW_END -> ^(TOK_FUNCTION KW_WHEN expression*)
- ;
-
-constant
-@init { gParent.pushMsg("constant", state); }
-@after { gParent.popMsg(state); }
- :
- Number
- | dateLiteral
- | timestampLiteral
- | intervalLiteral
- | StringLiteral
- | stringLiteralSequence
- | BigintLiteral
- | SmallintLiteral
- | TinyintLiteral
- | DoubleLiteral
- | booleanValue
- ;
-
-stringLiteralSequence
- :
- StringLiteral StringLiteral+ -> ^(TOK_STRINGLITERALSEQUENCE StringLiteral StringLiteral+)
- ;
-
-dateLiteral
- :
- KW_DATE StringLiteral ->
- {
- // Create DateLiteral token, but with the text of the string value
- // This makes the dateLiteral more consistent with the other type literals.
- adaptor.create(TOK_DATELITERAL, $StringLiteral.text)
- }
- |
- KW_CURRENT_DATE -> ^(TOK_FUNCTION KW_CURRENT_DATE)
- ;
-
-timestampLiteral
- :
- KW_TIMESTAMP StringLiteral ->
- {
- adaptor.create(TOK_TIMESTAMPLITERAL, $StringLiteral.text)
- }
- |
- KW_CURRENT_TIMESTAMP -> ^(TOK_FUNCTION KW_CURRENT_TIMESTAMP)
- ;
-
-intervalLiteral
- :
- (KW_INTERVAL intervalConstant KW_YEAR KW_TO KW_MONTH) => KW_INTERVAL intervalConstant KW_YEAR KW_TO KW_MONTH
- -> ^(TOK_INTERVAL_YEAR_MONTH_LITERAL intervalConstant)
- | (KW_INTERVAL intervalConstant KW_DAY KW_TO KW_SECOND) => KW_INTERVAL intervalConstant KW_DAY KW_TO KW_SECOND
- -> ^(TOK_INTERVAL_DAY_TIME_LITERAL intervalConstant)
- | KW_INTERVAL
- ((intervalConstant KW_YEAR)=> year=intervalConstant KW_YEAR)?
- ((intervalConstant KW_MONTH)=> month=intervalConstant KW_MONTH)?
- ((intervalConstant KW_WEEK)=> week=intervalConstant KW_WEEK)?
- ((intervalConstant KW_DAY)=> day=intervalConstant KW_DAY)?
- ((intervalConstant KW_HOUR)=> hour=intervalConstant KW_HOUR)?
- ((intervalConstant KW_MINUTE)=> minute=intervalConstant KW_MINUTE)?
- ((intervalConstant KW_SECOND)=> second=intervalConstant KW_SECOND)?
- ((intervalConstant KW_MILLISECOND)=> millisecond=intervalConstant KW_MILLISECOND)?
- ((intervalConstant KW_MICROSECOND)=> microsecond=intervalConstant KW_MICROSECOND)?
- -> ^(TOK_INTERVAL
- ^(TOK_INTERVAL_YEAR_LITERAL $year?)
- ^(TOK_INTERVAL_MONTH_LITERAL $month?)
- ^(TOK_INTERVAL_WEEK_LITERAL $week?)
- ^(TOK_INTERVAL_DAY_LITERAL $day?)
- ^(TOK_INTERVAL_HOUR_LITERAL $hour?)
- ^(TOK_INTERVAL_MINUTE_LITERAL $minute?)
- ^(TOK_INTERVAL_SECOND_LITERAL $second?)
- ^(TOK_INTERVAL_MILLISECOND_LITERAL $millisecond?)
- ^(TOK_INTERVAL_MICROSECOND_LITERAL $microsecond?))
- ;
-
-intervalConstant
- :
- sign=(MINUS|PLUS)? value=Number -> {
- adaptor.create(Number, ($sign != null ? $sign.getText() : "") + $value.getText())
- }
- | StringLiteral
- ;
-
-expression
-@init { gParent.pushMsg("expression specification", state); }
-@after { gParent.popMsg(state); }
- :
- precedenceOrExpression
- ;
-
-atomExpression
- :
- (KW_NULL) => KW_NULL -> TOK_NULL
- | (constant) => constant
- | castExpression
- | caseExpression
- | whenExpression
- | (functionName LPAREN) => function
- | tableOrColumn
- | (LPAREN KW_SELECT) => subQueryExpression
- -> ^(TOK_SUBQUERY_EXPR ^(TOK_SUBQUERY_OP) subQueryExpression)
- | LPAREN! expression RPAREN!
- ;
-
-
-precedenceFieldExpression
- :
- atomExpression ((LSQUARE^ expression RSQUARE!) | (DOT^ identifier))*
- ;
-
-precedenceUnaryOperator
- :
- PLUS | MINUS | TILDE
- ;
-
-nullCondition
- :
- KW_NULL -> ^(TOK_ISNULL)
- | KW_NOT KW_NULL -> ^(TOK_ISNOTNULL)
- ;
-
-precedenceUnaryPrefixExpression
- :
- (precedenceUnaryOperator+)=> precedenceUnaryOperator^ precedenceUnaryPrefixExpression
- | precedenceFieldExpression
- ;
-
-precedenceUnarySuffixExpression
- :
- (
- (LPAREN precedenceUnaryPrefixExpression RPAREN) => LPAREN precedenceUnaryPrefixExpression (a=KW_IS nullCondition)? RPAREN
- |
- precedenceUnaryPrefixExpression (a=KW_IS nullCondition)?
- )
- -> {$a != null}? ^(TOK_FUNCTION nullCondition precedenceUnaryPrefixExpression)
- -> precedenceUnaryPrefixExpression
- ;
-
-
-precedenceBitwiseXorOperator
- :
- BITWISEXOR
- ;
-
-precedenceBitwiseXorExpression
- :
- precedenceUnarySuffixExpression (precedenceBitwiseXorOperator^ precedenceUnarySuffixExpression)*
- ;
-
-
-precedenceStarOperator
- :
- STAR | DIVIDE | MOD | DIV
- ;
-
-precedenceStarExpression
- :
- precedenceBitwiseXorExpression (precedenceStarOperator^ precedenceBitwiseXorExpression)*
- ;
-
-
-precedencePlusOperator
- :
- PLUS | MINUS
- ;
-
-precedencePlusExpression
- :
- precedenceStarExpression (precedencePlusOperator^ precedenceStarExpression)*
- ;
-
-
-precedenceAmpersandOperator
- :
- AMPERSAND
- ;
-
-precedenceAmpersandExpression
- :
- precedencePlusExpression (precedenceAmpersandOperator^ precedencePlusExpression)*
- ;
-
-
-precedenceBitwiseOrOperator
- :
- BITWISEOR
- ;
-
-precedenceBitwiseOrExpression
- :
- precedenceAmpersandExpression (precedenceBitwiseOrOperator^ precedenceAmpersandExpression)*
- ;
-
-
-// Equal operators supporting NOT prefix
-precedenceEqualNegatableOperator
- :
- KW_LIKE | KW_RLIKE | KW_REGEXP
- ;
-
-precedenceEqualOperator
- :
- precedenceEqualNegatableOperator | EQUAL | EQUAL_NS | NOTEQUAL | LESSTHANOREQUALTO | LESSTHAN | GREATERTHANOREQUALTO | GREATERTHAN
- ;
-
-subQueryExpression
- :
- LPAREN! selectStatement[true] RPAREN!
- ;
-
-precedenceEqualExpression
- :
- (LPAREN precedenceBitwiseOrExpression COMMA) => precedenceEqualExpressionMutiple
- |
- precedenceEqualExpressionSingle
- ;
-
-precedenceEqualExpressionSingle
- :
- (left=precedenceBitwiseOrExpression -> $left)
- (
- (KW_NOT precedenceEqualNegatableOperator notExpr=precedenceBitwiseOrExpression)
- -> ^(KW_NOT ^(precedenceEqualNegatableOperator $precedenceEqualExpressionSingle $notExpr))
- | (precedenceEqualOperator equalExpr=precedenceBitwiseOrExpression)
- -> ^(precedenceEqualOperator $precedenceEqualExpressionSingle $equalExpr)
- | (KW_NOT KW_IN LPAREN KW_SELECT)=> (KW_NOT KW_IN subQueryExpression)
- -> ^(KW_NOT ^(TOK_SUBQUERY_EXPR ^(TOK_SUBQUERY_OP KW_IN) subQueryExpression $precedenceEqualExpressionSingle))
- | (KW_NOT KW_IN expressions)
- -> ^(KW_NOT ^(TOK_FUNCTION KW_IN $precedenceEqualExpressionSingle expressions))
- | (KW_IN LPAREN KW_SELECT)=> (KW_IN subQueryExpression)
- -> ^(TOK_SUBQUERY_EXPR ^(TOK_SUBQUERY_OP KW_IN) subQueryExpression $precedenceEqualExpressionSingle)
- | (KW_IN expressions)
- -> ^(TOK_FUNCTION KW_IN $precedenceEqualExpressionSingle expressions)
- | ( KW_NOT KW_BETWEEN (min=precedenceBitwiseOrExpression) KW_AND (max=precedenceBitwiseOrExpression) )
- -> ^(TOK_FUNCTION Identifier["between"] KW_TRUE $left $min $max)
- | ( KW_BETWEEN (min=precedenceBitwiseOrExpression) KW_AND (max=precedenceBitwiseOrExpression) )
- -> ^(TOK_FUNCTION Identifier["between"] KW_FALSE $left $min $max)
- )*
- | (KW_EXISTS LPAREN KW_SELECT)=> (KW_EXISTS subQueryExpression) -> ^(TOK_SUBQUERY_EXPR ^(TOK_SUBQUERY_OP KW_EXISTS) subQueryExpression)
- ;
-
-expressions
- :
- LPAREN expression (COMMA expression)* RPAREN -> expression+
- ;
-
-//we transform the (col0, col1) in ((v00,v01),(v10,v11)) into struct(col0, col1) in (struct(v00,v01),struct(v10,v11))
-precedenceEqualExpressionMutiple
- :
- (LPAREN precedenceBitwiseOrExpression (COMMA precedenceBitwiseOrExpression)+ RPAREN -> ^(TOK_FUNCTION Identifier["struct"] precedenceBitwiseOrExpression+))
- ( (KW_IN LPAREN expressionsToStruct (COMMA expressionsToStruct)+ RPAREN)
- -> ^(TOK_FUNCTION KW_IN $precedenceEqualExpressionMutiple expressionsToStruct+)
- | (KW_NOT KW_IN LPAREN expressionsToStruct (COMMA expressionsToStruct)+ RPAREN)
- -> ^(KW_NOT ^(TOK_FUNCTION KW_IN $precedenceEqualExpressionMutiple expressionsToStruct+)))
- ;
-
-expressionsToStruct
- :
- LPAREN expression (COMMA expression)* RPAREN -> ^(TOK_FUNCTION Identifier["struct"] expression+)
- ;
-
-precedenceNotOperator
- :
- KW_NOT
- ;
-
-precedenceNotExpression
- :
- (precedenceNotOperator^)* precedenceEqualExpression
- ;
-
-
-precedenceAndOperator
- :
- KW_AND
- ;
-
-precedenceAndExpression
- :
- precedenceNotExpression (precedenceAndOperator^ precedenceNotExpression)*
- ;
-
-
-precedenceOrOperator
- :
- KW_OR
- ;
-
-precedenceOrExpression
- :
- precedenceAndExpression (precedenceOrOperator^ precedenceAndExpression)*
- ;
diff --git a/sql/catalyst/src/main/antlr3/org/apache/spark/sql/catalyst/parser/FromClauseParser.g b/sql/catalyst/src/main/antlr3/org/apache/spark/sql/catalyst/parser/FromClauseParser.g
deleted file mode 100644
index 1bf461c912..0000000000
--- a/sql/catalyst/src/main/antlr3/org/apache/spark/sql/catalyst/parser/FromClauseParser.g
+++ /dev/null
@@ -1,341 +0,0 @@
-/**
- Licensed to the Apache Software Foundation (ASF) under one or more
- contributor license agreements. See the NOTICE file distributed with
- this work for additional information regarding copyright ownership.
- The ASF licenses this file to You under the Apache License, Version 2.0
- (the "License"); you may not use this file except in compliance with
- the License. You may obtain a copy of the License at
-
- http://www.apache.org/licenses/LICENSE-2.0
-
- Unless required by applicable law or agreed to in writing, software
- distributed under the License is distributed on an "AS IS" BASIS,
- WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- See the License for the specific language governing permissions and
- limitations under the License.
-
- This file is an adaptation of Hive's org/apache/hadoop/hive/ql/FromClauseParser.g grammar.
-*/
-parser grammar FromClauseParser;
-
-options
-{
-output=AST;
-ASTLabelType=CommonTree;
-backtrack=false;
-k=3;
-}
-
-@members {
- @Override
- public Object recoverFromMismatchedSet(IntStream input,
- RecognitionException re, BitSet follow) throws RecognitionException {
- throw re;
- }
- @Override
- public void displayRecognitionError(String[] tokenNames,
- RecognitionException e) {
- gParent.displayRecognitionError(tokenNames, e);
- }
- protected boolean useSQL11ReservedKeywordsForIdentifier() {
- return gParent.useSQL11ReservedKeywordsForIdentifier();
- }
-}
-
-@rulecatch {
-catch (RecognitionException e) {
- throw e;
-}
-}
-
-//-----------------------------------------------------------------------------------
-
-tableAllColumns
- : STAR
- -> ^(TOK_ALLCOLREF)
- | tableName DOT STAR
- -> ^(TOK_ALLCOLREF tableName)
- ;
-
-// (table|column)
-tableOrColumn
-@init { gParent.pushMsg("table or column identifier", state); }
-@after { gParent.popMsg(state); }
- :
- identifier -> ^(TOK_TABLE_OR_COL identifier)
- ;
-
-expressionList
-@init { gParent.pushMsg("expression list", state); }
-@after { gParent.popMsg(state); }
- :
- expression (COMMA expression)* -> ^(TOK_EXPLIST expression+)
- ;
-
-aliasList
-@init { gParent.pushMsg("alias list", state); }
-@after { gParent.popMsg(state); }
- :
- identifier (COMMA identifier)* -> ^(TOK_ALIASLIST identifier+)
- ;
-
-//----------------------- Rules for parsing fromClause ------------------------------
-// from [col1, col2, col3] table1, [col4, col5] table2
-fromClause
-@init { gParent.pushMsg("from clause", state); }
-@after { gParent.popMsg(state); }
- :
- KW_FROM joinSource -> ^(TOK_FROM joinSource)
- ;
-
-joinSource
-@init { gParent.pushMsg("join source", state); }
-@after { gParent.popMsg(state); }
- : fromSource ( joinToken^ fromSource ( joinCond {$joinToken.start.getType() != COMMA}? )? )*
- | uniqueJoinToken^ uniqueJoinSource (COMMA! uniqueJoinSource)+
- ;
-
-joinCond
-@init { gParent.pushMsg("join expression list", state); }
-@after { gParent.popMsg(state); }
- : KW_ON! expression
- | KW_USING LPAREN columnNameList RPAREN -> ^(TOK_USING columnNameList)
- ;
-
-uniqueJoinSource
-@init { gParent.pushMsg("unique join source", state); }
-@after { gParent.popMsg(state); }
- : KW_PRESERVE? fromSource uniqueJoinExpr
- ;
-
-uniqueJoinExpr
-@init { gParent.pushMsg("unique join expression list", state); }
-@after { gParent.popMsg(state); }
- : LPAREN e1+=expression (COMMA e1+=expression)* RPAREN
- -> ^(TOK_EXPLIST $e1*)
- ;
-
-uniqueJoinToken
-@init { gParent.pushMsg("unique join", state); }
-@after { gParent.popMsg(state); }
- : KW_UNIQUEJOIN -> TOK_UNIQUEJOIN;
-
-joinToken
-@init { gParent.pushMsg("join type specifier", state); }
-@after { gParent.popMsg(state); }
- :
- KW_JOIN -> TOK_JOIN
- | KW_INNER KW_JOIN -> TOK_JOIN
- | KW_NATURAL KW_JOIN -> TOK_NATURALJOIN
- | KW_NATURAL KW_INNER KW_JOIN -> TOK_NATURALJOIN
- | COMMA -> TOK_JOIN
- | KW_CROSS KW_JOIN -> TOK_CROSSJOIN
- | KW_LEFT (KW_OUTER)? KW_JOIN -> TOK_LEFTOUTERJOIN
- | KW_RIGHT (KW_OUTER)? KW_JOIN -> TOK_RIGHTOUTERJOIN
- | KW_FULL (KW_OUTER)? KW_JOIN -> TOK_FULLOUTERJOIN
- | KW_NATURAL KW_LEFT (KW_OUTER)? KW_JOIN -> TOK_NATURALLEFTOUTERJOIN
- | KW_NATURAL KW_RIGHT (KW_OUTER)? KW_JOIN -> TOK_NATURALRIGHTOUTERJOIN
- | KW_NATURAL KW_FULL (KW_OUTER)? KW_JOIN -> TOK_NATURALFULLOUTERJOIN
- | KW_LEFT KW_SEMI KW_JOIN -> TOK_LEFTSEMIJOIN
- | KW_ANTI KW_JOIN -> TOK_ANTIJOIN
- ;
-
-lateralView
-@init {gParent.pushMsg("lateral view", state); }
-@after {gParent.popMsg(state); }
- :
- (KW_LATERAL KW_VIEW KW_OUTER) => KW_LATERAL KW_VIEW KW_OUTER function tableAlias (KW_AS identifier ((COMMA)=> COMMA identifier)*)?
- -> ^(TOK_LATERAL_VIEW_OUTER ^(TOK_SELECT ^(TOK_SELEXPR function identifier* tableAlias)))
- |
- KW_LATERAL KW_VIEW function tableAlias (KW_AS identifier ((COMMA)=> COMMA identifier)*)?
- -> ^(TOK_LATERAL_VIEW ^(TOK_SELECT ^(TOK_SELEXPR function identifier* tableAlias)))
- ;
-
-tableAlias
-@init {gParent.pushMsg("table alias", state); }
-@after {gParent.popMsg(state); }
- :
- identifier -> ^(TOK_TABALIAS identifier)
- ;
-
-fromSource
-@init { gParent.pushMsg("from source", state); }
-@after { gParent.popMsg(state); }
- :
- (LPAREN KW_VALUES) => fromSource0
- | fromSource0
- | (LPAREN joinSource) => LPAREN joinSource RPAREN -> joinSource
- ;
-
-
-fromSource0
-@init { gParent.pushMsg("from source 0", state); }
-@after { gParent.popMsg(state); }
- :
- ((Identifier LPAREN)=> partitionedTableFunction | tableSource | subQuerySource | virtualTableSource) (lateralView^)*
- ;
-
-tableBucketSample
-@init { gParent.pushMsg("table bucket sample specification", state); }
-@after { gParent.popMsg(state); }
- :
- KW_TABLESAMPLE LPAREN KW_BUCKET (numerator=Number) KW_OUT KW_OF (denominator=Number) (KW_ON expr+=expression (COMMA expr+=expression)*)? RPAREN -> ^(TOK_TABLEBUCKETSAMPLE $numerator $denominator $expr*)
- ;
-
-splitSample
-@init { gParent.pushMsg("table split sample specification", state); }
-@after { gParent.popMsg(state); }
- :
- KW_TABLESAMPLE LPAREN (numerator=Number) (percent=KW_PERCENT|KW_ROWS) RPAREN
- -> {percent != null}? ^(TOK_TABLESPLITSAMPLE TOK_PERCENT $numerator)
- -> ^(TOK_TABLESPLITSAMPLE TOK_ROWCOUNT $numerator)
- |
- KW_TABLESAMPLE LPAREN (numerator=ByteLengthLiteral) RPAREN
- -> ^(TOK_TABLESPLITSAMPLE TOK_LENGTH $numerator)
- ;
-
-tableSample
-@init { gParent.pushMsg("table sample specification", state); }
-@after { gParent.popMsg(state); }
- :
- tableBucketSample |
- splitSample
- ;
-
-tableSource
-@init { gParent.pushMsg("table source", state); }
-@after { gParent.popMsg(state); }
- : tabname=tableName
- ((tableProperties) => props=tableProperties)?
- ((tableSample) => ts=tableSample)?
- ((KW_AS) => (KW_AS alias=Identifier)
- |
- (Identifier) => (alias=Identifier))?
- -> ^(TOK_TABREF $tabname $props? $ts? $alias?)
- ;
-
-tableName
-@init { gParent.pushMsg("table name", state); }
-@after { gParent.popMsg(state); }
- :
- id1=identifier (DOT id2=identifier)?
- -> ^(TOK_TABNAME $id1 $id2?)
- ;
-
-viewName
-@init { gParent.pushMsg("view name", state); }
-@after { gParent.popMsg(state); }
- :
- (db=identifier DOT)? view=identifier
- -> ^(TOK_TABNAME $db? $view)
- ;
-
-subQuerySource
-@init { gParent.pushMsg("subquery source", state); }
-@after { gParent.popMsg(state); }
- :
- LPAREN queryStatementExpression[false] RPAREN KW_AS? identifier -> ^(TOK_SUBQUERY queryStatementExpression identifier)
- ;
-
-//---------------------- Rules for parsing PTF clauses -----------------------------
-partitioningSpec
-@init { gParent.pushMsg("partitioningSpec clause", state); }
-@after { gParent.popMsg(state); }
- :
- partitionByClause orderByClause? -> ^(TOK_PARTITIONINGSPEC partitionByClause orderByClause?) |
- orderByClause -> ^(TOK_PARTITIONINGSPEC orderByClause) |
- distributeByClause sortByClause? -> ^(TOK_PARTITIONINGSPEC distributeByClause sortByClause?) |
- sortByClause -> ^(TOK_PARTITIONINGSPEC sortByClause) |
- clusterByClause -> ^(TOK_PARTITIONINGSPEC clusterByClause)
- ;
-
-partitionTableFunctionSource
-@init { gParent.pushMsg("partitionTableFunctionSource clause", state); }
-@after { gParent.popMsg(state); }
- :
- subQuerySource |
- tableSource |
- partitionedTableFunction
- ;
-
-partitionedTableFunction
-@init { gParent.pushMsg("ptf clause", state); }
-@after { gParent.popMsg(state); }
- :
- name=Identifier LPAREN KW_ON
- ((partitionTableFunctionSource) => (ptfsrc=partitionTableFunctionSource spec=partitioningSpec?))
- ((Identifier LPAREN expression RPAREN ) => Identifier LPAREN expression RPAREN ( COMMA Identifier LPAREN expression RPAREN)*)?
- ((RPAREN) => (RPAREN)) ((Identifier) => alias=Identifier)?
- -> ^(TOK_PTBLFUNCTION $name $alias? $ptfsrc $spec? expression*)
- ;
-
-//----------------------- Rules for parsing whereClause -----------------------------
-// where a=b and ...
-whereClause
-@init { gParent.pushMsg("where clause", state); }
-@after { gParent.popMsg(state); }
- :
- KW_WHERE searchCondition -> ^(TOK_WHERE searchCondition)
- ;
-
-searchCondition
-@init { gParent.pushMsg("search condition", state); }
-@after { gParent.popMsg(state); }
- :
- expression
- ;
-
-//-----------------------------------------------------------------------------------
-
-//-------- Row Constructor ----------------------------------------------------------
-//in support of SELECT * FROM (VALUES(1,2,3),(4,5,6),...) as FOO(a,b,c) and
-// INSERT INTO <table> (col1,col2,...) VALUES(...),(...),...
-// INSERT INTO <table> (col1,col2,...) SELECT * FROM (VALUES(1,2,3),(4,5,6),...) as Foo(a,b,c)
-valueRowConstructor
-@init { gParent.pushMsg("value row constructor", state); }
-@after { gParent.popMsg(state); }
- :
- LPAREN precedenceUnaryPrefixExpression (COMMA precedenceUnaryPrefixExpression)* RPAREN -> ^(TOK_VALUE_ROW precedenceUnaryPrefixExpression+)
- ;
-
-valuesTableConstructor
-@init { gParent.pushMsg("values table constructor", state); }
-@after { gParent.popMsg(state); }
- :
- valueRowConstructor (COMMA valueRowConstructor)* -> ^(TOK_VALUES_TABLE valueRowConstructor+)
- ;
-
-/*
-VALUES(1),(2) means 2 rows, 1 column each.
-VALUES(1,2),(3,4) means 2 rows, 2 columns each.
-VALUES(1,2,3) means 1 row, 3 columns
-*/
-valuesClause
-@init { gParent.pushMsg("values clause", state); }
-@after { gParent.popMsg(state); }
- :
- KW_VALUES valuesTableConstructor -> valuesTableConstructor
- ;
-
-/*
-This represents a clause like this:
-(VALUES(1,2),(2,3)) as VirtTable(col1,col2)
-*/
-virtualTableSource
-@init { gParent.pushMsg("virtual table source", state); }
-@after { gParent.popMsg(state); }
- :
- LPAREN valuesClause RPAREN tableNameColList -> ^(TOK_VIRTUAL_TABLE tableNameColList valuesClause)
- ;
-/*
-e.g. as VirtTable(col1,col2)
-Note that we only want literals as column names
-*/
-tableNameColList
-@init { gParent.pushMsg("from source", state); }
-@after { gParent.popMsg(state); }
- :
- KW_AS? identifier LPAREN identifier (COMMA identifier)* RPAREN -> ^(TOK_VIRTUAL_TABREF ^(TOK_TABNAME identifier) ^(TOK_COL_NAME identifier+))
- ;
-
-//-----------------------------------------------------------------------------------
diff --git a/sql/catalyst/src/main/antlr3/org/apache/spark/sql/catalyst/parser/IdentifiersParser.g b/sql/catalyst/src/main/antlr3/org/apache/spark/sql/catalyst/parser/IdentifiersParser.g
deleted file mode 100644
index 916eb6a7ac..0000000000
--- a/sql/catalyst/src/main/antlr3/org/apache/spark/sql/catalyst/parser/IdentifiersParser.g
+++ /dev/null
@@ -1,184 +0,0 @@
-/**
- Licensed to the Apache Software Foundation (ASF) under one or more
- contributor license agreements. See the NOTICE file distributed with
- this work for additional information regarding copyright ownership.
- The ASF licenses this file to You under the Apache License, Version 2.0
- (the "License"); you may not use this file except in compliance with
- the License. You may obtain a copy of the License at
-
- http://www.apache.org/licenses/LICENSE-2.0
-
- Unless required by applicable law or agreed to in writing, software
- distributed under the License is distributed on an "AS IS" BASIS,
- WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- See the License for the specific language governing permissions and
- limitations under the License.
-
- This file is an adaptation of Hive's org/apache/hadoop/hive/ql/IdentifiersParser.g grammar.
-*/
-parser grammar IdentifiersParser;
-
-options
-{
-output=AST;
-ASTLabelType=CommonTree;
-backtrack=false;
-k=3;
-}
-
-@members {
- @Override
- public Object recoverFromMismatchedSet(IntStream input,
- RecognitionException re, BitSet follow) throws RecognitionException {
- throw re;
- }
- @Override
- public void displayRecognitionError(String[] tokenNames,
- RecognitionException e) {
- gParent.displayRecognitionError(tokenNames, e);
- }
- protected boolean useSQL11ReservedKeywordsForIdentifier() {
- return gParent.useSQL11ReservedKeywordsForIdentifier();
- }
-}
-
-@rulecatch {
-catch (RecognitionException e) {
- throw e;
-}
-}
-
-//-----------------------------------------------------------------------------------
-
-// group by a,b
-groupByClause
-@init { gParent.pushMsg("group by clause", state); }
-@after { gParent.popMsg(state); }
- :
- KW_GROUP KW_BY
- expression
- ( COMMA expression)*
- ((rollup=KW_WITH KW_ROLLUP) | (cube=KW_WITH KW_CUBE)) ?
- (sets=KW_GROUPING KW_SETS
- LPAREN groupingSetExpression ( COMMA groupingSetExpression)* RPAREN ) ?
- -> {rollup != null}? ^(TOK_ROLLUP_GROUPBY expression+)
- -> {cube != null}? ^(TOK_CUBE_GROUPBY expression+)
- -> {sets != null}? ^(TOK_GROUPING_SETS expression+ groupingSetExpression+)
- -> ^(TOK_GROUPBY expression+)
- ;
-
-groupingSetExpression
-@init {gParent.pushMsg("grouping set expression", state); }
-@after {gParent.popMsg(state); }
- :
- (LPAREN) => groupingSetExpressionMultiple
- |
- groupingExpressionSingle
- ;
-
-groupingSetExpressionMultiple
-@init {gParent.pushMsg("grouping set part expression", state); }
-@after {gParent.popMsg(state); }
- :
- LPAREN
- expression? (COMMA expression)*
- RPAREN
- -> ^(TOK_GROUPING_SETS_EXPRESSION expression*)
- ;
-
-groupingExpressionSingle
-@init { gParent.pushMsg("groupingExpression expression", state); }
-@after { gParent.popMsg(state); }
- :
- expression -> ^(TOK_GROUPING_SETS_EXPRESSION expression)
- ;
-
-havingClause
-@init { gParent.pushMsg("having clause", state); }
-@after { gParent.popMsg(state); }
- :
- KW_HAVING havingCondition -> ^(TOK_HAVING havingCondition)
- ;
-
-havingCondition
-@init { gParent.pushMsg("having condition", state); }
-@after { gParent.popMsg(state); }
- :
- expression
- ;
-
-expressionsInParenthese
- :
- LPAREN expression (COMMA expression)* RPAREN -> expression+
- ;
-
-expressionsNotInParenthese
- :
- expression (COMMA expression)* -> expression+
- ;
-
-columnRefOrderInParenthese
- :
- LPAREN columnRefOrder (COMMA columnRefOrder)* RPAREN -> columnRefOrder+
- ;
-
-columnRefOrderNotInParenthese
- :
- columnRefOrder (COMMA columnRefOrder)* -> columnRefOrder+
- ;
-
-// order by a,b
-orderByClause
-@init { gParent.pushMsg("order by clause", state); }
-@after { gParent.popMsg(state); }
- :
- KW_ORDER KW_BY columnRefOrder ( COMMA columnRefOrder)* -> ^(TOK_ORDERBY columnRefOrder+)
- ;
-
-clusterByClause
-@init { gParent.pushMsg("cluster by clause", state); }
-@after { gParent.popMsg(state); }
- :
- KW_CLUSTER KW_BY
- (
- (LPAREN) => expressionsInParenthese -> ^(TOK_CLUSTERBY expressionsInParenthese)
- |
- expressionsNotInParenthese -> ^(TOK_CLUSTERBY expressionsNotInParenthese)
- )
- ;
-
-partitionByClause
-@init { gParent.pushMsg("partition by clause", state); }
-@after { gParent.popMsg(state); }
- :
- KW_PARTITION KW_BY
- (
- (LPAREN) => expressionsInParenthese -> ^(TOK_DISTRIBUTEBY expressionsInParenthese)
- |
- expressionsNotInParenthese -> ^(TOK_DISTRIBUTEBY expressionsNotInParenthese)
- )
- ;
-
-distributeByClause
-@init { gParent.pushMsg("distribute by clause", state); }
-@after { gParent.popMsg(state); }
- :
- KW_DISTRIBUTE KW_BY
- (
- (LPAREN) => expressionsInParenthese -> ^(TOK_DISTRIBUTEBY expressionsInParenthese)
- |
- expressionsNotInParenthese -> ^(TOK_DISTRIBUTEBY expressionsNotInParenthese)
- )
- ;
-
-sortByClause
-@init { gParent.pushMsg("sort by clause", state); }
-@after { gParent.popMsg(state); }
- :
- KW_SORT KW_BY
- (
- (LPAREN) => columnRefOrderInParenthese -> ^(TOK_SORTBY columnRefOrderInParenthese)
- |
- columnRefOrderNotInParenthese -> ^(TOK_SORTBY columnRefOrderNotInParenthese)
- )
- ;
diff --git a/sql/catalyst/src/main/antlr3/org/apache/spark/sql/catalyst/parser/KeywordParser.g b/sql/catalyst/src/main/antlr3/org/apache/spark/sql/catalyst/parser/KeywordParser.g
deleted file mode 100644
index 12cd5f54a0..0000000000
--- a/sql/catalyst/src/main/antlr3/org/apache/spark/sql/catalyst/parser/KeywordParser.g
+++ /dev/null
@@ -1,244 +0,0 @@
-/**
- Licensed to the Apache Software Foundation (ASF) under one or more
- contributor license agreements. See the NOTICE file distributed with
- this work for additional information regarding copyright ownership.
- The ASF licenses this file to You under the Apache License, Version 2.0
- (the "License"); you may not use this file except in compliance with
- the License. You may obtain a copy of the License at
-
- http://www.apache.org/licenses/LICENSE-2.0
-
- Unless required by applicable law or agreed to in writing, software
- distributed under the License is distributed on an "AS IS" BASIS,
- WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- See the License for the specific language governing permissions and
- limitations under the License.
-
- This file is an adaptation of Hive's org/apache/hadoop/hive/ql/IdentifiersParser.g grammar.
-*/
-
-parser grammar KeywordParser;
-
-options
-{
-output=AST;
-ASTLabelType=CommonTree;
-backtrack=false;
-k=3;
-}
-
-@members {
- @Override
- public Object recoverFromMismatchedSet(IntStream input,
- RecognitionException re, BitSet follow) throws RecognitionException {
- throw re;
- }
- @Override
- public void displayRecognitionError(String[] tokenNames,
- RecognitionException e) {
- gParent.displayRecognitionError(tokenNames, e);
- }
- protected boolean useSQL11ReservedKeywordsForIdentifier() {
- return gParent.useSQL11ReservedKeywordsForIdentifier();
- }
-}
-
-@rulecatch {
-catch (RecognitionException e) {
- throw e;
-}
-}
-
-booleanValue
- :
- KW_TRUE^ | KW_FALSE^
- ;
-
-booleanValueTok
- :
- KW_TRUE -> TOK_TRUE
- | KW_FALSE -> TOK_FALSE
- ;
-
-tableOrPartition
- :
- tableName partitionSpec? -> ^(TOK_TAB tableName partitionSpec?)
- ;
-
-partitionSpec
- :
- KW_PARTITION
- LPAREN partitionVal (COMMA partitionVal )* RPAREN -> ^(TOK_PARTSPEC partitionVal +)
- ;
-
-partitionVal
- :
- identifier (EQUAL constant)? -> ^(TOK_PARTVAL identifier constant?)
- ;
-
-dropPartitionSpec
- :
- KW_PARTITION
- LPAREN dropPartitionVal (COMMA dropPartitionVal )* RPAREN -> ^(TOK_PARTSPEC dropPartitionVal +)
- ;
-
-dropPartitionVal
- :
- identifier dropPartitionOperator constant -> ^(TOK_PARTVAL identifier dropPartitionOperator constant)
- ;
-
-dropPartitionOperator
- :
- EQUAL | NOTEQUAL | LESSTHANOREQUALTO | LESSTHAN | GREATERTHANOREQUALTO | GREATERTHAN
- ;
-
-sysFuncNames
- :
- KW_AND
- | KW_OR
- | KW_NOT
- | KW_LIKE
- | KW_IF
- | KW_CASE
- | KW_WHEN
- | KW_TINYINT
- | KW_SMALLINT
- | KW_INT
- | KW_BIGINT
- | KW_FLOAT
- | KW_DOUBLE
- | KW_BOOLEAN
- | KW_STRING
- | KW_BINARY
- | KW_ARRAY
- | KW_MAP
- | KW_STRUCT
- | KW_UNIONTYPE
- | EQUAL
- | EQUAL_NS
- | NOTEQUAL
- | LESSTHANOREQUALTO
- | LESSTHAN
- | GREATERTHANOREQUALTO
- | GREATERTHAN
- | DIVIDE
- | PLUS
- | MINUS
- | STAR
- | MOD
- | DIV
- | AMPERSAND
- | TILDE
- | BITWISEOR
- | BITWISEXOR
- | KW_RLIKE
- | KW_REGEXP
- | KW_IN
- | KW_BETWEEN
- ;
-
-descFuncNames
- :
- (sysFuncNames) => sysFuncNames
- | StringLiteral
- | functionIdentifier
- ;
-
-//We are allowed to use From and To in CreateTableUsing command's options (actually seems we can use any string as the option key). But we can't simply add them into nonReserved because by doing that we mess other existing rules. So we create a looseIdentifier and looseNonReserved here.
-looseIdentifier
- :
- Identifier
- | looseNonReserved -> Identifier[$looseNonReserved.text]
- // If it decides to support SQL11 reserved keywords, i.e., useSQL11ReservedKeywordsForIdentifier()=false,
- // the sql11keywords in existing q tests will NOT be added back.
- | {useSQL11ReservedKeywordsForIdentifier()}? sql11ReservedKeywordsUsedAsIdentifier -> Identifier[$sql11ReservedKeywordsUsedAsIdentifier.text]
- ;
-
-identifier
- :
- Identifier
- | nonReserved -> Identifier[$nonReserved.text]
- // If it decides to support SQL11 reserved keywords, i.e., useSQL11ReservedKeywordsForIdentifier()=false,
- // the sql11keywords in existing q tests will NOT be added back.
- | {useSQL11ReservedKeywordsForIdentifier()}? sql11ReservedKeywordsUsedAsIdentifier -> Identifier[$sql11ReservedKeywordsUsedAsIdentifier.text]
- ;
-
-functionIdentifier
-@init { gParent.pushMsg("function identifier", state); }
-@after { gParent.popMsg(state); }
- :
- identifier (DOT identifier)? -> identifier+
- ;
-
-principalIdentifier
-@init { gParent.pushMsg("identifier for principal spec", state); }
-@after { gParent.popMsg(state); }
- : identifier
- | QuotedIdentifier
- ;
-
-looseNonReserved
- : nonReserved | KW_FROM | KW_TO
- ;
-
-//The new version of nonReserved + sql11ReservedKeywordsUsedAsIdentifier = old version of nonReserved
-//Non reserved keywords are basically the keywords that can be used as identifiers.
-//All the KW_* are automatically not only keywords, but also reserved keywords.
-//That means, they can NOT be used as identifiers.
-//If you would like to use them as identifiers, put them in the nonReserved list below.
-//If you are not sure, please refer to the SQL2011 column in
-//http://www.postgresql.org/docs/9.5/static/sql-keywords-appendix.html
-nonReserved
- :
- KW_ADD | KW_ADMIN | KW_AFTER | KW_ANALYZE | KW_ARCHIVE | KW_ASC | KW_BEFORE | KW_BUCKET | KW_BUCKETS
- | KW_CASCADE | KW_CHANGE | KW_CLUSTER | KW_CLUSTERED | KW_CLUSTERSTATUS | KW_COLLECTION | KW_COLUMNS
- | KW_COMMENT | KW_COMPACT | KW_COMPACTIONS | KW_COMPUTE | KW_CONCATENATE | KW_CONTINUE | KW_DATA | KW_DAY
- | KW_DATABASES | KW_DATETIME | KW_DBPROPERTIES | KW_DEFERRED | KW_DEFINED | KW_DELIMITED | KW_DEPENDENCY
- | KW_DESC | KW_DIRECTORIES | KW_DIRECTORY | KW_DISABLE | KW_DISTRIBUTE | KW_ELEM_TYPE
- | KW_ENABLE | KW_ESCAPED | KW_EXCLUSIVE | KW_EXPLAIN | KW_EXPORT | KW_FIELDS | KW_FILE | KW_FILEFORMAT
- | KW_FIRST | KW_FORMAT | KW_FORMATTED | KW_FUNCTIONS | KW_HOLD_DDLTIME | KW_HOUR | KW_IDXPROPERTIES | KW_IGNORE
- | KW_INDEX | KW_INDEXES | KW_INPATH | KW_INPUTDRIVER | KW_INPUTFORMAT | KW_ITEMS | KW_JAR
- | KW_KEYS | KW_KEY_TYPE | KW_LIMIT | KW_LINES | KW_LOAD | KW_LOCATION | KW_LOCK | KW_LOCKS | KW_LOGICAL | KW_LONG
- | KW_MAPJOIN | KW_MATERIALIZED | KW_METADATA | KW_MINUS | KW_MINUTE | KW_MONTH | KW_MSCK | KW_NOSCAN | KW_NO_DROP | KW_OFFLINE
- | KW_OPTION | KW_OUTPUTDRIVER | KW_OUTPUTFORMAT | KW_OVERWRITE | KW_OWNER | KW_PARTITIONED | KW_PARTITIONS | KW_PLUS | KW_PRETTY
- | KW_PRINCIPALS | KW_PROTECTION | KW_PURGE | KW_READ | KW_READONLY | KW_REBUILD | KW_RECORDREADER | KW_RECORDWRITER
- | KW_RELOAD | KW_RENAME | KW_REPAIR | KW_REPLACE | KW_REPLICATION | KW_RESTRICT | KW_REWRITE
- | KW_ROLE | KW_ROLES | KW_SCHEMA | KW_SCHEMAS | KW_SECOND | KW_SEMI | KW_SERDE | KW_SERDEPROPERTIES | KW_SERVER | KW_SETS | KW_SHARED
- | KW_SHOW | KW_SHOW_DATABASE | KW_SKEWED | KW_SORT | KW_SORTED | KW_SSL | KW_STATISTICS | KW_STORED
- | KW_STREAMTABLE | KW_STRING | KW_STRUCT | KW_TABLES | KW_TBLPROPERTIES | KW_TEMPORARY | KW_TERMINATED
- | KW_TINYINT | KW_TOUCH | KW_TRANSACTIONS | KW_UNARCHIVE | KW_UNDO | KW_UNIONTYPE | KW_UNLOCK | KW_UNSET
- | KW_UNSIGNED | KW_URI | KW_USE | KW_UTC | KW_UTCTIMESTAMP | KW_VALUE_TYPE | KW_VIEW | KW_WHILE | KW_YEAR
- | KW_WORK
- | KW_TRANSACTION
- | KW_WRITE
- | KW_ISOLATION
- | KW_LEVEL
- | KW_SNAPSHOT
- | KW_AUTOCOMMIT
- | KW_ANTI
- | KW_WEEK | KW_MILLISECOND | KW_MICROSECOND
- | KW_CLEAR | KW_LAZY | KW_CACHE | KW_UNCACHE | KW_DFS
-;
-
-//The following SQL2011 reserved keywords are used as cast function name only, but not as identifiers.
-sql11ReservedKeywordsUsedAsCastFunctionName
- :
- KW_BIGINT | KW_BINARY | KW_BOOLEAN | KW_CURRENT_DATE | KW_CURRENT_TIMESTAMP | KW_DATE | KW_DOUBLE | KW_FLOAT | KW_INT | KW_SMALLINT | KW_TIMESTAMP
- ;
-
-//The following SQL2011 reserved keywords are used as identifiers in many q tests, they may be added back due to backward compatibility.
-//We are planning to remove the following whole list after several releases.
-//Thus, please do not change the following list unless you know what to do.
-sql11ReservedKeywordsUsedAsIdentifier
- :
- KW_ALL | KW_ALTER | KW_ARRAY | KW_AS | KW_AUTHORIZATION | KW_BETWEEN | KW_BIGINT | KW_BINARY | KW_BOOLEAN
- | KW_BOTH | KW_BY | KW_CREATE | KW_CUBE | KW_CURRENT_DATE | KW_CURRENT_TIMESTAMP | KW_CURSOR | KW_DATE | KW_DECIMAL | KW_DELETE | KW_DESCRIBE
- | KW_DOUBLE | KW_DROP | KW_EXISTS | KW_EXTERNAL | KW_FALSE | KW_FETCH | KW_FLOAT | KW_FOR | KW_FULL | KW_GRANT
- | KW_GROUP | KW_GROUPING | KW_IMPORT | KW_IN | KW_INNER | KW_INSERT | KW_INT | KW_INTERSECT | KW_INTO | KW_IS | KW_LATERAL
- | KW_LEFT | KW_LIKE | KW_LOCAL | KW_NONE | KW_NULL | KW_OF | KW_ORDER | KW_OUT | KW_OUTER | KW_PARTITION
- | KW_PERCENT | KW_PROCEDURE | KW_RANGE | KW_READS | KW_REVOKE | KW_RIGHT
- | KW_ROLLUP | KW_ROW | KW_ROWS | KW_SET | KW_SMALLINT | KW_TABLE | KW_TIMESTAMP | KW_TO | KW_TRIGGER | KW_TRUE
- | KW_TRUNCATE | KW_UNION | KW_UPDATE | KW_USER | KW_USING | KW_VALUES | KW_WITH
-//The following two keywords come from MySQL. Although they are not keywords in SQL2011, they are reserved keywords in MySQL.
- | KW_REGEXP | KW_RLIKE
- ;
diff --git a/sql/catalyst/src/main/antlr3/org/apache/spark/sql/catalyst/parser/SelectClauseParser.g b/sql/catalyst/src/main/antlr3/org/apache/spark/sql/catalyst/parser/SelectClauseParser.g
deleted file mode 100644
index f18b6ec496..0000000000
--- a/sql/catalyst/src/main/antlr3/org/apache/spark/sql/catalyst/parser/SelectClauseParser.g
+++ /dev/null
@@ -1,235 +0,0 @@
-/**
- Licensed to the Apache Software Foundation (ASF) under one or more
- contributor license agreements. See the NOTICE file distributed with
- this work for additional information regarding copyright ownership.
- The ASF licenses this file to You under the Apache License, Version 2.0
- (the "License"); you may not use this file except in compliance with
- the License. You may obtain a copy of the License at
-
- http://www.apache.org/licenses/LICENSE-2.0
-
- Unless required by applicable law or agreed to in writing, software
- distributed under the License is distributed on an "AS IS" BASIS,
- WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- See the License for the specific language governing permissions and
- limitations under the License.
-
- This file is an adaptation of Hive's org/apache/hadoop/hive/ql/SelectClauseParser.g grammar.
-*/
-parser grammar SelectClauseParser;
-
-options
-{
-output=AST;
-ASTLabelType=CommonTree;
-backtrack=false;
-k=3;
-}
-
-@members {
- @Override
- public Object recoverFromMismatchedSet(IntStream input,
- RecognitionException re, BitSet follow) throws RecognitionException {
- throw re;
- }
- @Override
- public void displayRecognitionError(String[] tokenNames,
- RecognitionException e) {
- gParent.displayRecognitionError(tokenNames, e);
- }
- protected boolean useSQL11ReservedKeywordsForIdentifier() {
- return gParent.useSQL11ReservedKeywordsForIdentifier();
- }
-}
-
-@rulecatch {
-catch (RecognitionException e) {
- throw e;
-}
-}
-
-//----------------------- Rules for parsing selectClause -----------------------------
-// select a,b,c ...
-selectClause
-@init { gParent.pushMsg("select clause", state); }
-@after { gParent.popMsg(state); }
- :
- KW_SELECT hintClause? (((KW_ALL | dist=KW_DISTINCT)? selectList)
- | (transform=KW_TRANSFORM selectTrfmClause))
- -> {$transform == null && $dist == null}? ^(TOK_SELECT hintClause? selectList)
- -> {$transform == null && $dist != null}? ^(TOK_SELECTDI hintClause? selectList)
- -> ^(TOK_SELECT hintClause? ^(TOK_SELEXPR selectTrfmClause) )
- |
- trfmClause ->^(TOK_SELECT ^(TOK_SELEXPR trfmClause))
- ;
-
-selectList
-@init { gParent.pushMsg("select list", state); }
-@after { gParent.popMsg(state); }
- :
- selectItem ( COMMA selectItem )* -> selectItem+
- ;
-
-selectTrfmClause
-@init { gParent.pushMsg("transform clause", state); }
-@after { gParent.popMsg(state); }
- :
- LPAREN selectExpressionList RPAREN
- inSerde=rowFormat inRec=recordWriter
- KW_USING StringLiteral
- ( KW_AS ((LPAREN (aliasList | columnNameTypeList) RPAREN) | (aliasList | columnNameTypeList)))?
- outSerde=rowFormat outRec=recordReader
- -> ^(TOK_TRANSFORM selectExpressionList $inSerde $inRec StringLiteral $outSerde $outRec aliasList? columnNameTypeList?)
- ;
-
-hintClause
-@init { gParent.pushMsg("hint clause", state); }
-@after { gParent.popMsg(state); }
- :
- DIVIDE STAR PLUS hintList STAR DIVIDE -> ^(TOK_HINTLIST hintList)
- ;
-
-hintList
-@init { gParent.pushMsg("hint list", state); }
-@after { gParent.popMsg(state); }
- :
- hintItem (COMMA hintItem)* -> hintItem+
- ;
-
-hintItem
-@init { gParent.pushMsg("hint item", state); }
-@after { gParent.popMsg(state); }
- :
- hintName (LPAREN hintArgs RPAREN)? -> ^(TOK_HINT hintName hintArgs?)
- ;
-
-hintName
-@init { gParent.pushMsg("hint name", state); }
-@after { gParent.popMsg(state); }
- :
- KW_MAPJOIN -> TOK_MAPJOIN
- | KW_STREAMTABLE -> TOK_STREAMTABLE
- ;
-
-hintArgs
-@init { gParent.pushMsg("hint arguments", state); }
-@after { gParent.popMsg(state); }
- :
- hintArgName (COMMA hintArgName)* -> ^(TOK_HINTARGLIST hintArgName+)
- ;
-
-hintArgName
-@init { gParent.pushMsg("hint argument name", state); }
-@after { gParent.popMsg(state); }
- :
- identifier
- ;
-
-selectItem
-@init { gParent.pushMsg("selection target", state); }
-@after { gParent.popMsg(state); }
- :
- (tableAllColumns) => tableAllColumns -> ^(TOK_SELEXPR tableAllColumns)
- |
- namedExpression
- ;
-
-namedExpression
-@init { gParent.pushMsg("select named expression", state); }
-@after { gParent.popMsg(state); }
- :
- ( expression
- ((KW_AS? identifier) | (KW_AS LPAREN identifier (COMMA identifier)* RPAREN))?
- ) -> ^(TOK_SELEXPR expression identifier*)
- ;
-
-trfmClause
-@init { gParent.pushMsg("transform clause", state); }
-@after { gParent.popMsg(state); }
- :
- ( KW_MAP selectExpressionList
- | KW_REDUCE selectExpressionList )
- inSerde=rowFormat inRec=recordWriter
- KW_USING StringLiteral
- ( KW_AS ((LPAREN (aliasList | columnNameTypeList) RPAREN) | (aliasList | columnNameTypeList)))?
- outSerde=rowFormat outRec=recordReader
- -> ^(TOK_TRANSFORM selectExpressionList $inSerde $inRec StringLiteral $outSerde $outRec aliasList? columnNameTypeList?)
- ;
-
-selectExpression
-@init { gParent.pushMsg("select expression", state); }
-@after { gParent.popMsg(state); }
- :
- (tableAllColumns) => tableAllColumns
- |
- expression
- ;
-
-selectExpressionList
-@init { gParent.pushMsg("select expression list", state); }
-@after { gParent.popMsg(state); }
- :
- selectExpression (COMMA selectExpression)* -> ^(TOK_EXPLIST selectExpression+)
- ;
-
-//---------------------- Rules for windowing clauses -------------------------------
-window_clause
-@init { gParent.pushMsg("window_clause", state); }
-@after { gParent.popMsg(state); }
-:
- KW_WINDOW window_defn (COMMA window_defn)* -> ^(KW_WINDOW window_defn+)
-;
-
-window_defn
-@init { gParent.pushMsg("window_defn", state); }
-@after { gParent.popMsg(state); }
-:
- Identifier KW_AS window_specification -> ^(TOK_WINDOWDEF Identifier window_specification)
-;
-
-window_specification
-@init { gParent.pushMsg("window_specification", state); }
-@after { gParent.popMsg(state); }
-:
- (Identifier | ( LPAREN Identifier? partitioningSpec? window_frame? RPAREN)) -> ^(TOK_WINDOWSPEC Identifier? partitioningSpec? window_frame?)
-;
-
-window_frame :
- window_range_expression |
- window_value_expression
-;
-
-window_range_expression
-@init { gParent.pushMsg("window_range_expression", state); }
-@after { gParent.popMsg(state); }
-:
- KW_ROWS sb=window_frame_start_boundary -> ^(TOK_WINDOWRANGE $sb) |
- KW_ROWS KW_BETWEEN s=window_frame_boundary KW_AND end=window_frame_boundary -> ^(TOK_WINDOWRANGE $s $end)
-;
-
-window_value_expression
-@init { gParent.pushMsg("window_value_expression", state); }
-@after { gParent.popMsg(state); }
-:
- KW_RANGE sb=window_frame_start_boundary -> ^(TOK_WINDOWVALUES $sb) |
- KW_RANGE KW_BETWEEN s=window_frame_boundary KW_AND end=window_frame_boundary -> ^(TOK_WINDOWVALUES $s $end)
-;
-
-window_frame_start_boundary
-@init { gParent.pushMsg("windowframestartboundary", state); }
-@after { gParent.popMsg(state); }
-:
- KW_UNBOUNDED KW_PRECEDING -> ^(KW_PRECEDING KW_UNBOUNDED) |
- KW_CURRENT KW_ROW -> ^(KW_CURRENT) |
- Number KW_PRECEDING -> ^(KW_PRECEDING Number)
-;
-
-window_frame_boundary
-@init { gParent.pushMsg("windowframeboundary", state); }
-@after { gParent.popMsg(state); }
-:
- KW_UNBOUNDED (r=KW_PRECEDING|r=KW_FOLLOWING) -> ^($r KW_UNBOUNDED) |
- KW_CURRENT KW_ROW -> ^(KW_CURRENT) |
- Number (d=KW_PRECEDING | d=KW_FOLLOWING ) -> ^($d Number)
-;
-
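For orientation, the "->" arrows in the rules above are ANTLR 3 tree rewrites: they emit the TOK_* AST nodes that downstream code consumed. Below is a minimal sketch of driving the classes ANTLR 3 generates from the grammars in this diff; the wrapper class name and the sample query are illustrative only, and the file is assumed to sit in the org.apache.spark.sql.catalyst.parser package next to the generated SparkSqlLexer and SparkSqlParser.

import org.antlr.runtime.ANTLRStringStream;
import org.antlr.runtime.CommonTokenStream;
import org.antlr.runtime.RecognitionException;
import org.antlr.runtime.tree.CommonTree;

// Sketch only: parse one statement and print the rewritten AST.
public class ParseSketch {
  public static void main(String[] args) throws RecognitionException {
    String sql = "SELECT DISTINCT a, b AS c FROM t";
    SparkSqlLexer lexer = new SparkSqlLexer(new ANTLRStringStream(sql));
    SparkSqlParser parser = new SparkSqlParser(new CommonTokenStream(lexer));
    // statement is the starting rule; with output=AST its return scope carries the tree.
    CommonTree tree = (CommonTree) parser.statement().getTree();
    // For this input, selectClause rewrites to ^(TOK_SELECTDI (TOK_SELEXPR ...) ...).
    System.out.println(tree.toStringTree());
  }
}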
diff --git a/sql/catalyst/src/main/antlr3/org/apache/spark/sql/catalyst/parser/SparkSqlLexer.g b/sql/catalyst/src/main/antlr3/org/apache/spark/sql/catalyst/parser/SparkSqlLexer.g
deleted file mode 100644
index fd1ad59207..0000000000
--- a/sql/catalyst/src/main/antlr3/org/apache/spark/sql/catalyst/parser/SparkSqlLexer.g
+++ /dev/null
@@ -1,491 +0,0 @@
-/**
- Licensed to the Apache Software Foundation (ASF) under one or more
- contributor license agreements. See the NOTICE file distributed with
- this work for additional information regarding copyright ownership.
- The ASF licenses this file to You under the Apache License, Version 2.0
- (the "License"); you may not use this file except in compliance with
- the License. You may obtain a copy of the License at
-
- http://www.apache.org/licenses/LICENSE-2.0
-
- Unless required by applicable law or agreed to in writing, software
- distributed under the License is distributed on an "AS IS" BASIS,
- WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- See the License for the specific language governing permissions and
- limitations under the License.
-
- This file is an adaptation of Hive's org/apache/hadoop/hive/ql/HiveLexer.g grammar.
-*/
-lexer grammar SparkSqlLexer;
-
-@lexer::header {
-package org.apache.spark.sql.catalyst.parser;
-
-}
-
-@lexer::members {
- private ParserConf parserConf;
- private ParseErrorReporter reporter;
-
- public void configure(ParserConf parserConf, ParseErrorReporter reporter) {
- this.parserConf = parserConf;
- this.reporter = reporter;
- }
-
- protected boolean allowQuotedId() {
- if (parserConf == null) {
- return true;
- }
- return parserConf.supportQuotedId();
- }
-
- @Override
- public void displayRecognitionError(String[] tokenNames, RecognitionException e) {
- if (reporter != null) {
- reporter.report(this, e, tokenNames);
- }
- }
-}
-
-// Keywords
-
-KW_TRUE : 'TRUE';
-KW_FALSE : 'FALSE';
-KW_ALL : 'ALL';
-KW_NONE: 'NONE';
-KW_AND : 'AND';
-KW_OR : 'OR';
-KW_NOT : 'NOT' | '!';
-KW_LIKE : 'LIKE';
-
-KW_IF : 'IF';
-KW_EXISTS : 'EXISTS';
-
-KW_ASC : 'ASC';
-KW_DESC : 'DESC';
-KW_ORDER : 'ORDER';
-KW_GROUP : 'GROUP';
-KW_BY : 'BY';
-KW_HAVING : 'HAVING';
-KW_WHERE : 'WHERE';
-KW_FROM : 'FROM';
-KW_AS : 'AS';
-KW_SELECT : 'SELECT';
-KW_DISTINCT : 'DISTINCT';
-KW_INSERT : 'INSERT';
-KW_OVERWRITE : 'OVERWRITE';
-KW_OUTER : 'OUTER';
-KW_UNIQUEJOIN : 'UNIQUEJOIN';
-KW_PRESERVE : 'PRESERVE';
-KW_JOIN : 'JOIN';
-KW_LEFT : 'LEFT';
-KW_RIGHT : 'RIGHT';
-KW_FULL : 'FULL';
-KW_ANTI : 'ANTI';
-KW_ON : 'ON';
-KW_PARTITION : 'PARTITION';
-KW_PARTITIONS : 'PARTITIONS';
-KW_TABLE: 'TABLE';
-KW_TABLES: 'TABLES';
-KW_COLUMNS: 'COLUMNS';
-KW_INDEX: 'INDEX';
-KW_INDEXES: 'INDEXES';
-KW_REBUILD: 'REBUILD';
-KW_FUNCTIONS: 'FUNCTIONS';
-KW_SHOW: 'SHOW';
-KW_MSCK: 'MSCK';
-KW_REPAIR: 'REPAIR';
-KW_DIRECTORY: 'DIRECTORY';
-KW_LOCAL: 'LOCAL';
-KW_TRANSFORM : 'TRANSFORM';
-KW_USING: 'USING';
-KW_CLUSTER: 'CLUSTER';
-KW_DISTRIBUTE: 'DISTRIBUTE';
-KW_SORT: 'SORT';
-KW_UNION: 'UNION';
-KW_EXCEPT: 'EXCEPT';
-KW_LOAD: 'LOAD';
-KW_EXPORT: 'EXPORT';
-KW_IMPORT: 'IMPORT';
-KW_REPLICATION: 'REPLICATION';
-KW_METADATA: 'METADATA';
-KW_DATA: 'DATA';
-KW_INPATH: 'INPATH';
-KW_IS: 'IS';
-KW_NULL: 'NULL';
-KW_CREATE: 'CREATE';
-KW_EXTERNAL: 'EXTERNAL';
-KW_ALTER: 'ALTER';
-KW_CHANGE: 'CHANGE';
-KW_COLUMN: 'COLUMN';
-KW_FIRST: 'FIRST';
-KW_AFTER: 'AFTER';
-KW_DESCRIBE: 'DESCRIBE';
-KW_DROP: 'DROP';
-KW_RENAME: 'RENAME';
-KW_TO: 'TO';
-KW_COMMENT: 'COMMENT';
-KW_BOOLEAN: 'BOOLEAN';
-KW_TINYINT: 'TINYINT';
-KW_SMALLINT: 'SMALLINT';
-KW_INT: 'INT';
-KW_BIGINT: 'BIGINT';
-KW_FLOAT: 'FLOAT';
-KW_DOUBLE: 'DOUBLE';
-KW_DATE: 'DATE';
-KW_DATETIME: 'DATETIME';
-KW_TIMESTAMP: 'TIMESTAMP';
-KW_INTERVAL: 'INTERVAL';
-KW_DECIMAL: 'DECIMAL';
-KW_STRING: 'STRING';
-KW_CHAR: 'CHAR';
-KW_VARCHAR: 'VARCHAR';
-KW_ARRAY: 'ARRAY';
-KW_STRUCT: 'STRUCT';
-KW_MAP: 'MAP';
-KW_UNIONTYPE: 'UNIONTYPE';
-KW_REDUCE: 'REDUCE';
-KW_PARTITIONED: 'PARTITIONED';
-KW_CLUSTERED: 'CLUSTERED';
-KW_SORTED: 'SORTED';
-KW_INTO: 'INTO';
-KW_BUCKETS: 'BUCKETS';
-KW_ROW: 'ROW';
-KW_ROWS: 'ROWS';
-KW_FORMAT: 'FORMAT';
-KW_DELIMITED: 'DELIMITED';
-KW_FIELDS: 'FIELDS';
-KW_TERMINATED: 'TERMINATED';
-KW_ESCAPED: 'ESCAPED';
-KW_COLLECTION: 'COLLECTION';
-KW_ITEMS: 'ITEMS';
-KW_KEYS: 'KEYS';
-KW_KEY_TYPE: '$KEY$';
-KW_LINES: 'LINES';
-KW_STORED: 'STORED';
-KW_FILEFORMAT: 'FILEFORMAT';
-KW_INPUTFORMAT: 'INPUTFORMAT';
-KW_OUTPUTFORMAT: 'OUTPUTFORMAT';
-KW_INPUTDRIVER: 'INPUTDRIVER';
-KW_OUTPUTDRIVER: 'OUTPUTDRIVER';
-KW_ENABLE: 'ENABLE';
-KW_DISABLE: 'DISABLE';
-KW_LOCATION: 'LOCATION';
-KW_TABLESAMPLE: 'TABLESAMPLE';
-KW_BUCKET: 'BUCKET';
-KW_OUT: 'OUT';
-KW_OF: 'OF';
-KW_PERCENT: 'PERCENT';
-KW_CAST: 'CAST';
-KW_ADD: 'ADD';
-KW_REPLACE: 'REPLACE';
-KW_RLIKE: 'RLIKE';
-KW_REGEXP: 'REGEXP';
-KW_TEMPORARY: 'TEMPORARY';
-KW_FUNCTION: 'FUNCTION';
-KW_MACRO: 'MACRO';
-KW_FILE: 'FILE';
-KW_JAR: 'JAR';
-KW_EXPLAIN: 'EXPLAIN';
-KW_EXTENDED: 'EXTENDED';
-KW_FORMATTED: 'FORMATTED';
-KW_PRETTY: 'PRETTY';
-KW_DEPENDENCY: 'DEPENDENCY';
-KW_LOGICAL: 'LOGICAL';
-KW_SERDE: 'SERDE';
-KW_WITH: 'WITH';
-KW_DEFERRED: 'DEFERRED';
-KW_SERDEPROPERTIES: 'SERDEPROPERTIES';
-KW_DBPROPERTIES: 'DBPROPERTIES';
-KW_LIMIT: 'LIMIT';
-KW_SET: 'SET';
-KW_UNSET: 'UNSET';
-KW_TBLPROPERTIES: 'TBLPROPERTIES';
-KW_IDXPROPERTIES: 'IDXPROPERTIES';
-KW_VALUE_TYPE: '$VALUE$';
-KW_ELEM_TYPE: '$ELEM$';
-KW_DEFINED: 'DEFINED';
-KW_CASE: 'CASE';
-KW_WHEN: 'WHEN';
-KW_THEN: 'THEN';
-KW_ELSE: 'ELSE';
-KW_END: 'END';
-KW_MAPJOIN: 'MAPJOIN';
-KW_STREAMTABLE: 'STREAMTABLE';
-KW_CLUSTERSTATUS: 'CLUSTERSTATUS';
-KW_UTC: 'UTC';
-KW_UTCTIMESTAMP: 'UTC_TIMESTAMP';
-KW_LONG: 'LONG';
-KW_DELETE: 'DELETE';
-KW_PLUS: 'PLUS';
-KW_MINUS: 'MINUS';
-KW_FETCH: 'FETCH';
-KW_INTERSECT: 'INTERSECT';
-KW_VIEW: 'VIEW';
-KW_IN: 'IN';
-KW_DATABASE: 'DATABASE';
-KW_DATABASES: 'DATABASES';
-KW_MATERIALIZED: 'MATERIALIZED';
-KW_SCHEMA: 'SCHEMA';
-KW_SCHEMAS: 'SCHEMAS';
-KW_GRANT: 'GRANT';
-KW_REVOKE: 'REVOKE';
-KW_SSL: 'SSL';
-KW_UNDO: 'UNDO';
-KW_LOCK: 'LOCK';
-KW_LOCKS: 'LOCKS';
-KW_UNLOCK: 'UNLOCK';
-KW_SHARED: 'SHARED';
-KW_EXCLUSIVE: 'EXCLUSIVE';
-KW_PROCEDURE: 'PROCEDURE';
-KW_UNSIGNED: 'UNSIGNED';
-KW_WHILE: 'WHILE';
-KW_READ: 'READ';
-KW_READS: 'READS';
-KW_PURGE: 'PURGE';
-KW_RANGE: 'RANGE';
-KW_ANALYZE: 'ANALYZE';
-KW_BEFORE: 'BEFORE';
-KW_BETWEEN: 'BETWEEN';
-KW_BOTH: 'BOTH';
-KW_BINARY: 'BINARY';
-KW_CROSS: 'CROSS';
-KW_CONTINUE: 'CONTINUE';
-KW_CURSOR: 'CURSOR';
-KW_TRIGGER: 'TRIGGER';
-KW_RECORDREADER: 'RECORDREADER';
-KW_RECORDWRITER: 'RECORDWRITER';
-KW_SEMI: 'SEMI';
-KW_LATERAL: 'LATERAL';
-KW_TOUCH: 'TOUCH';
-KW_ARCHIVE: 'ARCHIVE';
-KW_UNARCHIVE: 'UNARCHIVE';
-KW_COMPUTE: 'COMPUTE';
-KW_STATISTICS: 'STATISTICS';
-KW_USE: 'USE';
-KW_OPTION: 'OPTION';
-KW_CONCATENATE: 'CONCATENATE';
-KW_SHOW_DATABASE: 'SHOW_DATABASE';
-KW_UPDATE: 'UPDATE';
-KW_RESTRICT: 'RESTRICT';
-KW_CASCADE: 'CASCADE';
-KW_SKEWED: 'SKEWED';
-KW_ROLLUP: 'ROLLUP';
-KW_CUBE: 'CUBE';
-KW_DIRECTORIES: 'DIRECTORIES';
-KW_FOR: 'FOR';
-KW_WINDOW: 'WINDOW';
-KW_UNBOUNDED: 'UNBOUNDED';
-KW_PRECEDING: 'PRECEDING';
-KW_FOLLOWING: 'FOLLOWING';
-KW_CURRENT: 'CURRENT';
-KW_CURRENT_DATE: 'CURRENT_DATE';
-KW_CURRENT_TIMESTAMP: 'CURRENT_TIMESTAMP';
-KW_LESS: 'LESS';
-KW_MORE: 'MORE';
-KW_OVER: 'OVER';
-KW_GROUPING: 'GROUPING';
-KW_SETS: 'SETS';
-KW_TRUNCATE: 'TRUNCATE';
-KW_NOSCAN: 'NOSCAN';
-KW_PARTIALSCAN: 'PARTIALSCAN';
-KW_USER: 'USER';
-KW_ROLE: 'ROLE';
-KW_ROLES: 'ROLES';
-KW_INNER: 'INNER';
-KW_EXCHANGE: 'EXCHANGE';
-KW_URI: 'URI';
-KW_SERVER : 'SERVER';
-KW_ADMIN: 'ADMIN';
-KW_OWNER: 'OWNER';
-KW_PRINCIPALS: 'PRINCIPALS';
-KW_COMPACT: 'COMPACT';
-KW_COMPACTIONS: 'COMPACTIONS';
-KW_TRANSACTIONS: 'TRANSACTIONS';
-KW_REWRITE : 'REWRITE';
-KW_AUTHORIZATION: 'AUTHORIZATION';
-KW_CONF: 'CONF';
-KW_VALUES: 'VALUES';
-KW_RELOAD: 'RELOAD';
-KW_YEAR: 'YEAR'|'YEARS';
-KW_MONTH: 'MONTH'|'MONTHS';
-KW_DAY: 'DAY'|'DAYS';
-KW_HOUR: 'HOUR'|'HOURS';
-KW_MINUTE: 'MINUTE'|'MINUTES';
-KW_SECOND: 'SECOND'|'SECONDS';
-KW_START: 'START';
-KW_TRANSACTION: 'TRANSACTION';
-KW_COMMIT: 'COMMIT';
-KW_ROLLBACK: 'ROLLBACK';
-KW_WORK: 'WORK';
-KW_ONLY: 'ONLY';
-KW_WRITE: 'WRITE';
-KW_ISOLATION: 'ISOLATION';
-KW_LEVEL: 'LEVEL';
-KW_SNAPSHOT: 'SNAPSHOT';
-KW_AUTOCOMMIT: 'AUTOCOMMIT';
-KW_REFRESH: 'REFRESH';
-KW_OPTIONS: 'OPTIONS';
-KW_WEEK: 'WEEK'|'WEEKS';
-KW_MILLISECOND: 'MILLISECOND'|'MILLISECONDS';
-KW_MICROSECOND: 'MICROSECOND'|'MICROSECONDS';
-KW_CLEAR: 'CLEAR';
-KW_LAZY: 'LAZY';
-KW_CACHE: 'CACHE';
-KW_UNCACHE: 'UNCACHE';
-KW_DFS: 'DFS';
-
-KW_NATURAL: 'NATURAL';
-
-// Operators
-// NOTE: if you add a new function/operator, add it to sysFuncNames so that describe function _FUNC_ will work.
-
-DOT : '.'; // generated as a part of Number rule
-COLON : ':' ;
-COMMA : ',' ;
-SEMICOLON : ';' ;
-
-LPAREN : '(' ;
-RPAREN : ')' ;
-LSQUARE : '[' ;
-RSQUARE : ']' ;
-LCURLY : '{';
-RCURLY : '}';
-
-EQUAL : '=' | '==';
-EQUAL_NS : '<=>';
-NOTEQUAL : '<>' | '!=';
-LESSTHANOREQUALTO : '<=';
-LESSTHAN : '<';
-GREATERTHANOREQUALTO : '>=';
-GREATERTHAN : '>';
-
-DIVIDE : '/';
-PLUS : '+';
-MINUS : '-';
-STAR : '*';
-MOD : '%';
-DIV : 'DIV';
-
-AMPERSAND : '&';
-TILDE : '~';
-BITWISEOR : '|';
-BITWISEXOR : '^';
-QUESTION : '?';
-DOLLAR : '$';
-
-// LITERALS
-fragment
-Letter
- : 'a'..'z' | 'A'..'Z'
- ;
-
-fragment
-HexDigit
- : 'a'..'f' | 'A'..'F'
- ;
-
-fragment
-Digit
- :
- '0'..'9'
- ;
-
-fragment
-Exponent
- :
- ('e' | 'E') ( PLUS|MINUS )? (Digit)+
- ;
-
-fragment
-RegexComponent
- : 'a'..'z' | 'A'..'Z' | '0'..'9' | '_'
- | PLUS | STAR | QUESTION | MINUS | DOT
- | LPAREN | RPAREN | LSQUARE | RSQUARE | LCURLY | RCURLY
- | BITWISEXOR | BITWISEOR | DOLLAR | '!'
- ;
-
-StringLiteral
- :
- ( '\'' ( ~('\''|'\\') | ('\\' .) )* '\''
- | '\"' ( ~('\"'|'\\') | ('\\' .) )* '\"'
- )+
- ;
-
-BigintLiteral
- :
- (Digit)+ 'L'
- ;
-
-SmallintLiteral
- :
- (Digit)+ 'S'
- ;
-
-TinyintLiteral
- :
- (Digit)+ 'Y'
- ;
-
-DoubleLiteral
- :
- Number 'D'
- ;
-
-ByteLengthLiteral
- :
- (Digit)+ ('b' | 'B' | 'k' | 'K' | 'm' | 'M' | 'g' | 'G')
- ;
-
-Number
- :
- ((Digit+ (DOT Digit*)?) | (DOT Digit+)) Exponent?
- ;
-
-/*
-An Identifier can be:
-- tableName
-- columnName
-- select expr alias
-- lateral view aliases
-- database name
-- view name
-- subquery alias
-- function name
-- ptf argument identifier
-- index name
-- property name for: db,tbl,partition...
-- fileFormat
-- role name
-- privilege name
-- principal name
-- macro name
-- hint name
-- window name
-*/
-Identifier
- :
- (Letter | Digit | '_')+
- | {allowQuotedId()}? QuotedIdentifier /* though at the language level we allow all Identifiers to be QuotedIdentifiers;
- at the API level only columns are allowed to be of this form */
- | '`' RegexComponent+ '`'
- ;
-
-fragment
-QuotedIdentifier
- :
- '`' ( '``' | ~('`') )* '`' { setText(getText().replaceAll("``", "`")); }
- ;
-
-WS : (' '|'\r'|'\t'|'\n') {$channel=HIDDEN;}
- ;
-
-COMMENT
- : '--' (~('\n'|'\r'))*
- { $channel=HIDDEN; }
- ;
-
-/* Prevent the lexer from swallowing unknown characters. */
-ANY
- :.
- ;
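Two details of the lexer above are easy to miss: WS and COMMENT are routed to the hidden channel rather than discarded, and the trailing ANY rule keeps the lexer from silently swallowing characters it does not recognize. A minimal sketch that makes both visible, again assuming the generated SparkSqlLexer and the ANTLR 3 runtime (the wrapper class and the input string are made up, and the file is assumed to live in the same package as the generated lexer):

import org.antlr.runtime.ANTLRStringStream;
import org.antlr.runtime.Token;

// Sketch only: dump the token stream for a small input, marking hidden-channel tokens.
public class LexSketch {
  public static void main(String[] args) {
    SparkSqlLexer lexer = new SparkSqlLexer(new ANTLRStringStream("SELECT x -- trailing comment"));
    for (Token t = lexer.nextToken(); t.getType() != Token.EOF; t = lexer.nextToken()) {
      boolean hidden = t.getChannel() == Token.HIDDEN_CHANNEL;  // WS and COMMENT land here
      System.out.println((hidden ? "[hidden] " : "") + "<" + t.getType() + "> " + t.getText());
    }
  }
}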
diff --git a/sql/catalyst/src/main/antlr3/org/apache/spark/sql/catalyst/parser/SparkSqlParser.g b/sql/catalyst/src/main/antlr3/org/apache/spark/sql/catalyst/parser/SparkSqlParser.g
deleted file mode 100644
index f0c236859d..0000000000
--- a/sql/catalyst/src/main/antlr3/org/apache/spark/sql/catalyst/parser/SparkSqlParser.g
+++ /dev/null
@@ -1,2596 +0,0 @@
-/**
- Licensed to the Apache Software Foundation (ASF) under one or more
- contributor license agreements. See the NOTICE file distributed with
- this work for additional information regarding copyright ownership.
- The ASF licenses this file to You under the Apache License, Version 2.0
- (the "License"); you may not use this file except in compliance with
- the License. You may obtain a copy of the License at
-
- http://www.apache.org/licenses/LICENSE-2.0
-
- Unless required by applicable law or agreed to in writing, software
- distributed under the License is distributed on an "AS IS" BASIS,
- WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- See the License for the specific language governing permissions and
- limitations under the License.
-
- This file is an adaptation of Hive's org/apache/hadoop/hive/ql/HiveParser.g grammar.
-*/
-parser grammar SparkSqlParser;
-
-options
-{
-tokenVocab=SparkSqlLexer;
-output=AST;
-ASTLabelType=CommonTree;
-backtrack=false;
-k=3;
-}
-import SelectClauseParser, FromClauseParser, IdentifiersParser, KeywordParser, ExpressionParser;
-
-tokens {
-TOK_INSERT;
-TOK_QUERY;
-TOK_SELECT;
-TOK_SELECTDI;
-TOK_SELEXPR;
-TOK_FROM;
-TOK_TAB;
-TOK_PARTSPEC;
-TOK_PARTVAL;
-TOK_DIR;
-TOK_TABREF;
-TOK_SUBQUERY;
-TOK_INSERT_INTO;
-TOK_DESTINATION;
-TOK_ALLCOLREF;
-TOK_TABLE_OR_COL;
-TOK_FUNCTION;
-TOK_FUNCTIONDI;
-TOK_FUNCTIONSTAR;
-TOK_WHERE;
-TOK_OP_EQ;
-TOK_OP_NE;
-TOK_OP_LE;
-TOK_OP_LT;
-TOK_OP_GE;
-TOK_OP_GT;
-TOK_OP_DIV;
-TOK_OP_ADD;
-TOK_OP_SUB;
-TOK_OP_MUL;
-TOK_OP_MOD;
-TOK_OP_BITAND;
-TOK_OP_BITNOT;
-TOK_OP_BITOR;
-TOK_OP_BITXOR;
-TOK_OP_AND;
-TOK_OP_OR;
-TOK_OP_NOT;
-TOK_OP_LIKE;
-TOK_TRUE;
-TOK_FALSE;
-TOK_TRANSFORM;
-TOK_SERDE;
-TOK_SERDENAME;
-TOK_SERDEPROPS;
-TOK_EXPLIST;
-TOK_ALIASLIST;
-TOK_GROUPBY;
-TOK_ROLLUP_GROUPBY;
-TOK_CUBE_GROUPBY;
-TOK_GROUPING_SETS;
-TOK_GROUPING_SETS_EXPRESSION;
-TOK_HAVING;
-TOK_ORDERBY;
-TOK_CLUSTERBY;
-TOK_DISTRIBUTEBY;
-TOK_SORTBY;
-TOK_UNIONALL;
-TOK_UNIONDISTINCT;
-TOK_EXCEPT;
-TOK_INTERSECT;
-TOK_JOIN;
-TOK_LEFTOUTERJOIN;
-TOK_RIGHTOUTERJOIN;
-TOK_FULLOUTERJOIN;
-TOK_UNIQUEJOIN;
-TOK_CROSSJOIN;
-TOK_NATURALJOIN;
-TOK_NATURALLEFTOUTERJOIN;
-TOK_NATURALRIGHTOUTERJOIN;
-TOK_NATURALFULLOUTERJOIN;
-TOK_LOAD;
-TOK_EXPORT;
-TOK_IMPORT;
-TOK_REPLICATION;
-TOK_METADATA;
-TOK_NULL;
-TOK_ISNULL;
-TOK_ISNOTNULL;
-TOK_TINYINT;
-TOK_SMALLINT;
-TOK_INT;
-TOK_BIGINT;
-TOK_BOOLEAN;
-TOK_FLOAT;
-TOK_DOUBLE;
-TOK_DATE;
-TOK_DATELITERAL;
-TOK_DATETIME;
-TOK_TIMESTAMP;
-TOK_TIMESTAMPLITERAL;
-TOK_INTERVAL;
-TOK_INTERVAL_YEAR_MONTH;
-TOK_INTERVAL_YEAR_MONTH_LITERAL;
-TOK_INTERVAL_DAY_TIME;
-TOK_INTERVAL_DAY_TIME_LITERAL;
-TOK_INTERVAL_YEAR_LITERAL;
-TOK_INTERVAL_MONTH_LITERAL;
-TOK_INTERVAL_WEEK_LITERAL;
-TOK_INTERVAL_DAY_LITERAL;
-TOK_INTERVAL_HOUR_LITERAL;
-TOK_INTERVAL_MINUTE_LITERAL;
-TOK_INTERVAL_SECOND_LITERAL;
-TOK_INTERVAL_MILLISECOND_LITERAL;
-TOK_INTERVAL_MICROSECOND_LITERAL;
-TOK_STRING;
-TOK_CHAR;
-TOK_VARCHAR;
-TOK_BINARY;
-TOK_DECIMAL;
-TOK_LIST;
-TOK_STRUCT;
-TOK_MAP;
-TOK_UNIONTYPE;
-TOK_COLTYPELIST;
-TOK_CREATEDATABASE;
-TOK_CREATETABLE;
-TOK_CREATETABLEUSING;
-TOK_TRUNCATETABLE;
-TOK_CREATEINDEX;
-TOK_CREATEINDEX_INDEXTBLNAME;
-TOK_DEFERRED_REBUILDINDEX;
-TOK_DROPINDEX;
-TOK_LIKETABLE;
-TOK_DESCTABLE;
-TOK_DESCFUNCTION;
-TOK_ALTERTABLE;
-TOK_ALTERTABLE_RENAME;
-TOK_ALTERTABLE_ADDCOLS;
-TOK_ALTERTABLE_RENAMECOL;
-TOK_ALTERTABLE_RENAMEPART;
-TOK_ALTERTABLE_REPLACECOLS;
-TOK_ALTERTABLE_ADDPARTS;
-TOK_ALTERTABLE_DROPPARTS;
-TOK_ALTERTABLE_PARTCOLTYPE;
-TOK_ALTERTABLE_MERGEFILES;
-TOK_ALTERTABLE_TOUCH;
-TOK_ALTERTABLE_ARCHIVE;
-TOK_ALTERTABLE_UNARCHIVE;
-TOK_ALTERTABLE_SERDEPROPERTIES;
-TOK_ALTERTABLE_SERIALIZER;
-TOK_ALTERTABLE_UPDATECOLSTATS;
-TOK_TABLE_PARTITION;
-TOK_ALTERTABLE_FILEFORMAT;
-TOK_ALTERTABLE_LOCATION;
-TOK_ALTERTABLE_PROPERTIES;
-TOK_ALTERTABLE_CHANGECOL_AFTER_POSITION;
-TOK_ALTERTABLE_DROPPROPERTIES;
-TOK_ALTERTABLE_SKEWED;
-TOK_ALTERTABLE_EXCHANGEPARTITION;
-TOK_ALTERTABLE_SKEWED_LOCATION;
-TOK_ALTERTABLE_BUCKETS;
-TOK_ALTERTABLE_CLUSTER_SORT;
-TOK_ALTERTABLE_COMPACT;
-TOK_ALTERINDEX_REBUILD;
-TOK_ALTERINDEX_PROPERTIES;
-TOK_MSCK;
-TOK_SHOWDATABASES;
-TOK_SHOWTABLES;
-TOK_SHOWCOLUMNS;
-TOK_SHOWFUNCTIONS;
-TOK_SHOWPARTITIONS;
-TOK_SHOW_CREATEDATABASE;
-TOK_SHOW_CREATETABLE;
-TOK_SHOW_TABLESTATUS;
-TOK_SHOW_TBLPROPERTIES;
-TOK_SHOWLOCKS;
-TOK_SHOWCONF;
-TOK_LOCKTABLE;
-TOK_UNLOCKTABLE;
-TOK_LOCKDB;
-TOK_UNLOCKDB;
-TOK_SWITCHDATABASE;
-TOK_DROPDATABASE;
-TOK_DROPTABLE;
-TOK_DATABASECOMMENT;
-TOK_TABCOLLIST;
-TOK_TABCOL;
-TOK_TABLECOMMENT;
-TOK_TABLEPARTCOLS;
-TOK_TABLEROWFORMAT;
-TOK_TABLEROWFORMATFIELD;
-TOK_TABLEROWFORMATCOLLITEMS;
-TOK_TABLEROWFORMATMAPKEYS;
-TOK_TABLEROWFORMATLINES;
-TOK_TABLEROWFORMATNULL;
-TOK_TABLEFILEFORMAT;
-TOK_FILEFORMAT_GENERIC;
-TOK_OFFLINE;
-TOK_ENABLE;
-TOK_DISABLE;
-TOK_READONLY;
-TOK_NO_DROP;
-TOK_STORAGEHANDLER;
-TOK_NOT_CLUSTERED;
-TOK_NOT_SORTED;
-TOK_TABCOLNAME;
-TOK_TABLELOCATION;
-TOK_PARTITIONLOCATION;
-TOK_TABLEBUCKETSAMPLE;
-TOK_TABLESPLITSAMPLE;
-TOK_PERCENT;
-TOK_LENGTH;
-TOK_ROWCOUNT;
-TOK_TMP_FILE;
-TOK_TABSORTCOLNAMEASC;
-TOK_TABSORTCOLNAMEDESC;
-TOK_STRINGLITERALSEQUENCE;
-TOK_CREATEFUNCTION;
-TOK_DROPFUNCTION;
-TOK_RELOADFUNCTION;
-TOK_CREATEMACRO;
-TOK_DROPMACRO;
-TOK_TEMPORARY;
-TOK_CREATEVIEW;
-TOK_DROPVIEW;
-TOK_ALTERVIEW;
-TOK_ALTERVIEW_PROPERTIES;
-TOK_ALTERVIEW_DROPPROPERTIES;
-TOK_ALTERVIEW_ADDPARTS;
-TOK_ALTERVIEW_DROPPARTS;
-TOK_ALTERVIEW_RENAME;
-TOK_VIEWPARTCOLS;
-TOK_EXPLAIN;
-TOK_EXPLAIN_SQ_REWRITE;
-TOK_TABLESERIALIZER;
-TOK_TABLEPROPERTIES;
-TOK_TABLEPROPLIST;
-TOK_INDEXPROPERTIES;
-TOK_INDEXPROPLIST;
-TOK_TABTYPE;
-TOK_LIMIT;
-TOK_TABLEPROPERTY;
-TOK_IFEXISTS;
-TOK_IFNOTEXISTS;
-TOK_ORREPLACE;
-TOK_HINTLIST;
-TOK_HINT;
-TOK_MAPJOIN;
-TOK_STREAMTABLE;
-TOK_HINTARGLIST;
-TOK_USERSCRIPTCOLNAMES;
-TOK_USERSCRIPTCOLSCHEMA;
-TOK_RECORDREADER;
-TOK_RECORDWRITER;
-TOK_LEFTSEMIJOIN;
-TOK_ANTIJOIN;
-TOK_LATERAL_VIEW;
-TOK_LATERAL_VIEW_OUTER;
-TOK_TABALIAS;
-TOK_ANALYZE;
-TOK_CREATEROLE;
-TOK_DROPROLE;
-TOK_GRANT;
-TOK_REVOKE;
-TOK_SHOW_GRANT;
-TOK_PRIVILEGE_LIST;
-TOK_PRIVILEGE;
-TOK_PRINCIPAL_NAME;
-TOK_USER;
-TOK_GROUP;
-TOK_ROLE;
-TOK_RESOURCE_ALL;
-TOK_GRANT_WITH_OPTION;
-TOK_GRANT_WITH_ADMIN_OPTION;
-TOK_ADMIN_OPTION_FOR;
-TOK_GRANT_OPTION_FOR;
-TOK_PRIV_ALL;
-TOK_PRIV_ALTER_METADATA;
-TOK_PRIV_ALTER_DATA;
-TOK_PRIV_DELETE;
-TOK_PRIV_DROP;
-TOK_PRIV_INDEX;
-TOK_PRIV_INSERT;
-TOK_PRIV_LOCK;
-TOK_PRIV_SELECT;
-TOK_PRIV_SHOW_DATABASE;
-TOK_PRIV_CREATE;
-TOK_PRIV_OBJECT;
-TOK_PRIV_OBJECT_COL;
-TOK_GRANT_ROLE;
-TOK_REVOKE_ROLE;
-TOK_SHOW_ROLE_GRANT;
-TOK_SHOW_ROLES;
-TOK_SHOW_SET_ROLE;
-TOK_SHOW_ROLE_PRINCIPALS;
-TOK_SHOWINDEXES;
-TOK_SHOWDBLOCKS;
-TOK_INDEXCOMMENT;
-TOK_DESCDATABASE;
-TOK_DATABASEPROPERTIES;
-TOK_DATABASELOCATION;
-TOK_DBPROPLIST;
-TOK_ALTERDATABASE_PROPERTIES;
-TOK_ALTERDATABASE_OWNER;
-TOK_TABNAME;
-TOK_TABSRC;
-TOK_RESTRICT;
-TOK_CASCADE;
-TOK_TABLESKEWED;
-TOK_TABCOLVALUE;
-TOK_TABCOLVALUE_PAIR;
-TOK_TABCOLVALUES;
-TOK_SKEWED_LOCATIONS;
-TOK_SKEWED_LOCATION_LIST;
-TOK_SKEWED_LOCATION_MAP;
-TOK_STOREDASDIRS;
-TOK_PARTITIONINGSPEC;
-TOK_PTBLFUNCTION;
-TOK_WINDOWDEF;
-TOK_WINDOWSPEC;
-TOK_WINDOWVALUES;
-TOK_WINDOWRANGE;
-TOK_SUBQUERY_EXPR;
-TOK_SUBQUERY_OP;
-TOK_SUBQUERY_OP_NOTIN;
-TOK_SUBQUERY_OP_NOTEXISTS;
-TOK_DB_TYPE;
-TOK_TABLE_TYPE;
-TOK_CTE;
-TOK_ARCHIVE;
-TOK_FILE;
-TOK_JAR;
-TOK_RESOURCE_URI;
-TOK_RESOURCE_LIST;
-TOK_SHOW_COMPACTIONS;
-TOK_SHOW_TRANSACTIONS;
-TOK_DELETE_FROM;
-TOK_UPDATE_TABLE;
-TOK_SET_COLUMNS_CLAUSE;
-TOK_VALUE_ROW;
-TOK_VALUES_TABLE;
-TOK_VIRTUAL_TABLE;
-TOK_VIRTUAL_TABREF;
-TOK_ANONYMOUS;
-TOK_COL_NAME;
-TOK_URI_TYPE;
-TOK_SERVER_TYPE;
-TOK_START_TRANSACTION;
-TOK_ISOLATION_LEVEL;
-TOK_ISOLATION_SNAPSHOT;
-TOK_TXN_ACCESS_MODE;
-TOK_TXN_READ_ONLY;
-TOK_TXN_READ_WRITE;
-TOK_COMMIT;
-TOK_ROLLBACK;
-TOK_SET_AUTOCOMMIT;
-TOK_REFRESHTABLE;
-TOK_TABLEPROVIDER;
-TOK_TABLEOPTIONS;
-TOK_TABLEOPTION;
-TOK_CACHETABLE;
-TOK_UNCACHETABLE;
-TOK_CLEARCACHE;
-TOK_SETCONFIG;
-TOK_DFS;
-TOK_ADDFILE;
-TOK_ADDJAR;
-TOK_USING;
-}
-
-
-// Package headers
-@header {
-package org.apache.spark.sql.catalyst.parser;
-
-import java.util.Arrays;
-import java.util.Collection;
-import java.util.HashMap;
-}
-
-
-@members {
- Stack msgs = new Stack<String>();
-
- private static HashMap<String, String> xlateMap;
- static {
-    // this is used to support auto-completion in the CLI
- xlateMap = new HashMap<String, String>();
-
- // Keywords
- xlateMap.put("KW_TRUE", "TRUE");
- xlateMap.put("KW_FALSE", "FALSE");
- xlateMap.put("KW_ALL", "ALL");
- xlateMap.put("KW_NONE", "NONE");
- xlateMap.put("KW_AND", "AND");
- xlateMap.put("KW_OR", "OR");
- xlateMap.put("KW_NOT", "NOT");
- xlateMap.put("KW_LIKE", "LIKE");
-
- xlateMap.put("KW_ASC", "ASC");
- xlateMap.put("KW_DESC", "DESC");
- xlateMap.put("KW_ORDER", "ORDER");
- xlateMap.put("KW_BY", "BY");
- xlateMap.put("KW_GROUP", "GROUP");
- xlateMap.put("KW_WHERE", "WHERE");
- xlateMap.put("KW_FROM", "FROM");
- xlateMap.put("KW_AS", "AS");
- xlateMap.put("KW_SELECT", "SELECT");
- xlateMap.put("KW_DISTINCT", "DISTINCT");
- xlateMap.put("KW_INSERT", "INSERT");
- xlateMap.put("KW_OVERWRITE", "OVERWRITE");
- xlateMap.put("KW_OUTER", "OUTER");
- xlateMap.put("KW_JOIN", "JOIN");
- xlateMap.put("KW_LEFT", "LEFT");
- xlateMap.put("KW_RIGHT", "RIGHT");
- xlateMap.put("KW_FULL", "FULL");
- xlateMap.put("KW_ON", "ON");
- xlateMap.put("KW_PARTITION", "PARTITION");
- xlateMap.put("KW_PARTITIONS", "PARTITIONS");
- xlateMap.put("KW_TABLE", "TABLE");
- xlateMap.put("KW_TABLES", "TABLES");
- xlateMap.put("KW_TBLPROPERTIES", "TBLPROPERTIES");
- xlateMap.put("KW_SHOW", "SHOW");
- xlateMap.put("KW_MSCK", "MSCK");
- xlateMap.put("KW_DIRECTORY", "DIRECTORY");
- xlateMap.put("KW_LOCAL", "LOCAL");
- xlateMap.put("KW_TRANSFORM", "TRANSFORM");
- xlateMap.put("KW_USING", "USING");
- xlateMap.put("KW_CLUSTER", "CLUSTER");
- xlateMap.put("KW_DISTRIBUTE", "DISTRIBUTE");
- xlateMap.put("KW_SORT", "SORT");
- xlateMap.put("KW_UNION", "UNION");
- xlateMap.put("KW_LOAD", "LOAD");
- xlateMap.put("KW_DATA", "DATA");
- xlateMap.put("KW_INPATH", "INPATH");
- xlateMap.put("KW_IS", "IS");
- xlateMap.put("KW_NULL", "NULL");
- xlateMap.put("KW_CREATE", "CREATE");
- xlateMap.put("KW_EXTERNAL", "EXTERNAL");
- xlateMap.put("KW_ALTER", "ALTER");
- xlateMap.put("KW_DESCRIBE", "DESCRIBE");
- xlateMap.put("KW_DROP", "DROP");
- xlateMap.put("KW_RENAME", "RENAME");
- xlateMap.put("KW_TO", "TO");
- xlateMap.put("KW_COMMENT", "COMMENT");
- xlateMap.put("KW_BOOLEAN", "BOOLEAN");
- xlateMap.put("KW_TINYINT", "TINYINT");
- xlateMap.put("KW_SMALLINT", "SMALLINT");
- xlateMap.put("KW_INT", "INT");
- xlateMap.put("KW_BIGINT", "BIGINT");
- xlateMap.put("KW_FLOAT", "FLOAT");
- xlateMap.put("KW_DOUBLE", "DOUBLE");
- xlateMap.put("KW_DATE", "DATE");
- xlateMap.put("KW_DATETIME", "DATETIME");
- xlateMap.put("KW_TIMESTAMP", "TIMESTAMP");
- xlateMap.put("KW_STRING", "STRING");
- xlateMap.put("KW_BINARY", "BINARY");
- xlateMap.put("KW_ARRAY", "ARRAY");
- xlateMap.put("KW_MAP", "MAP");
- xlateMap.put("KW_REDUCE", "REDUCE");
- xlateMap.put("KW_PARTITIONED", "PARTITIONED");
- xlateMap.put("KW_CLUSTERED", "CLUSTERED");
- xlateMap.put("KW_SORTED", "SORTED");
- xlateMap.put("KW_INTO", "INTO");
- xlateMap.put("KW_BUCKETS", "BUCKETS");
- xlateMap.put("KW_ROW", "ROW");
- xlateMap.put("KW_FORMAT", "FORMAT");
- xlateMap.put("KW_DELIMITED", "DELIMITED");
- xlateMap.put("KW_FIELDS", "FIELDS");
- xlateMap.put("KW_TERMINATED", "TERMINATED");
- xlateMap.put("KW_COLLECTION", "COLLECTION");
- xlateMap.put("KW_ITEMS", "ITEMS");
- xlateMap.put("KW_KEYS", "KEYS");
- xlateMap.put("KW_KEY_TYPE", "\$KEY\$");
- xlateMap.put("KW_LINES", "LINES");
- xlateMap.put("KW_STORED", "STORED");
- xlateMap.put("KW_SEQUENCEFILE", "SEQUENCEFILE");
- xlateMap.put("KW_TEXTFILE", "TEXTFILE");
- xlateMap.put("KW_INPUTFORMAT", "INPUTFORMAT");
- xlateMap.put("KW_OUTPUTFORMAT", "OUTPUTFORMAT");
- xlateMap.put("KW_LOCATION", "LOCATION");
- xlateMap.put("KW_TABLESAMPLE", "TABLESAMPLE");
- xlateMap.put("KW_BUCKET", "BUCKET");
- xlateMap.put("KW_OUT", "OUT");
- xlateMap.put("KW_OF", "OF");
- xlateMap.put("KW_CAST", "CAST");
- xlateMap.put("KW_ADD", "ADD");
- xlateMap.put("KW_REPLACE", "REPLACE");
- xlateMap.put("KW_COLUMNS", "COLUMNS");
- xlateMap.put("KW_RLIKE", "RLIKE");
- xlateMap.put("KW_REGEXP", "REGEXP");
- xlateMap.put("KW_TEMPORARY", "TEMPORARY");
- xlateMap.put("KW_FUNCTION", "FUNCTION");
- xlateMap.put("KW_EXPLAIN", "EXPLAIN");
- xlateMap.put("KW_EXTENDED", "EXTENDED");
- xlateMap.put("KW_SERDE", "SERDE");
- xlateMap.put("KW_WITH", "WITH");
- xlateMap.put("KW_SERDEPROPERTIES", "SERDEPROPERTIES");
- xlateMap.put("KW_LIMIT", "LIMIT");
- xlateMap.put("KW_SET", "SET");
- xlateMap.put("KW_PROPERTIES", "TBLPROPERTIES");
- xlateMap.put("KW_VALUE_TYPE", "\$VALUE\$");
- xlateMap.put("KW_ELEM_TYPE", "\$ELEM\$");
- xlateMap.put("KW_DEFINED", "DEFINED");
- xlateMap.put("KW_SUBQUERY", "SUBQUERY");
- xlateMap.put("KW_REWRITE", "REWRITE");
- xlateMap.put("KW_UPDATE", "UPDATE");
- xlateMap.put("KW_VALUES", "VALUES");
- xlateMap.put("KW_PURGE", "PURGE");
- xlateMap.put("KW_WEEK", "WEEK");
- xlateMap.put("KW_MILLISECOND", "MILLISECOND");
- xlateMap.put("KW_MICROSECOND", "MICROSECOND");
- xlateMap.put("KW_CLEAR", "CLEAR");
- xlateMap.put("KW_LAZY", "LAZY");
- xlateMap.put("KW_CACHE", "CACHE");
- xlateMap.put("KW_UNCACHE", "UNCACHE");
- xlateMap.put("KW_DFS", "DFS");
-
- // Operators
- xlateMap.put("DOT", ".");
- xlateMap.put("COLON", ":");
- xlateMap.put("COMMA", ",");
-    xlateMap.put("SEMICOLON", ";");
-
- xlateMap.put("LPAREN", "(");
- xlateMap.put("RPAREN", ")");
- xlateMap.put("LSQUARE", "[");
- xlateMap.put("RSQUARE", "]");
-
- xlateMap.put("EQUAL", "=");
- xlateMap.put("NOTEQUAL", "<>");
- xlateMap.put("EQUAL_NS", "<=>");
- xlateMap.put("LESSTHANOREQUALTO", "<=");
- xlateMap.put("LESSTHAN", "<");
- xlateMap.put("GREATERTHANOREQUALTO", ">=");
- xlateMap.put("GREATERTHAN", ">");
-
- xlateMap.put("DIVIDE", "/");
- xlateMap.put("PLUS", "+");
- xlateMap.put("MINUS", "-");
- xlateMap.put("STAR", "*");
- xlateMap.put("MOD", "\%");
-
- xlateMap.put("AMPERSAND", "&");
- xlateMap.put("TILDE", "~");
- xlateMap.put("BITWISEOR", "|");
- xlateMap.put("BITWISEXOR", "^");
- xlateMap.put("CharSetLiteral", "\\'");
- }
-
- public static Collection<String> getKeywords() {
- return xlateMap.values();
- }
-
- private static String xlate(String name) {
-
- String ret = xlateMap.get(name);
- if (ret == null) {
- ret = name;
- }
-
- return ret;
- }
-
- @Override
- public Object recoverFromMismatchedSet(IntStream input,
- RecognitionException re, BitSet follow) throws RecognitionException {
- throw re;
- }
-
- @Override
- public void displayRecognitionError(String[] tokenNames, RecognitionException e) {
- if (reporter != null) {
- reporter.report(this, e, tokenNames);
- }
- }
-
- @Override
- public String getErrorHeader(RecognitionException e) {
- String header = null;
- if (e.charPositionInLine < 0 && input.LT(-1) != null) {
- Token t = input.LT(-1);
- header = "line " + t.getLine() + ":" + t.getCharPositionInLine();
- } else {
- header = super.getErrorHeader(e);
- }
-
- return header;
- }
-
- @Override
- public String getErrorMessage(RecognitionException e, String[] tokenNames) {
- String msg = null;
-
- // Translate the token names to something that the user can understand
- String[] xlateNames = new String[tokenNames.length];
- for (int i = 0; i < tokenNames.length; ++i) {
- xlateNames[i] = SparkSqlParser.xlate(tokenNames[i]);
- }
-
- if (e instanceof NoViableAltException) {
- @SuppressWarnings("unused")
- NoViableAltException nvae = (NoViableAltException) e;
- // for development, can add
- // "decision=<<"+nvae.grammarDecisionDescription+">>"
- // and "(decision="+nvae.decisionNumber+") and
- // "state "+nvae.stateNumber
- msg = "cannot recognize input near"
- + (input.LT(1) != null ? " " + getTokenErrorDisplay(input.LT(1)) : "")
- + (input.LT(2) != null ? " " + getTokenErrorDisplay(input.LT(2)) : "")
- + (input.LT(3) != null ? " " + getTokenErrorDisplay(input.LT(3)) : "");
- } else if (e instanceof MismatchedTokenException) {
- MismatchedTokenException mte = (MismatchedTokenException) e;
- msg = super.getErrorMessage(e, xlateNames) + (input.LT(-1) == null ? "":" near '" + input.LT(-1).getText()) + "'";
- } else if (e instanceof FailedPredicateException) {
- FailedPredicateException fpe = (FailedPredicateException) e;
- msg = "Failed to recognize predicate '" + fpe.token.getText() + "'. Failed rule: '" + fpe.ruleName + "'";
- } else {
- msg = super.getErrorMessage(e, xlateNames);
- }
-
- if (msgs.size() > 0) {
- msg = msg + " in " + msgs.peek();
- }
- return msg;
- }
-
- public void pushMsg(String msg, RecognizerSharedState state) {
-    // ANTLR generated code does not wrap the @init code with this backtracking check,
-    // even if the matching @after has it. If we have parser rules that are doing
-    // some lookahead with syntactic predicates, this can cause the push() and pop() calls
- // to become unbalanced, so make sure both push/pop check the backtracking state.
- if (state.backtracking == 0) {
- msgs.push(msg);
- }
- }
-
- public void popMsg(RecognizerSharedState state) {
- if (state.backtracking == 0) {
- Object o = msgs.pop();
- }
- }
-
- // counter to generate unique union aliases
- private int aliasCounter;
- private String generateUnionAlias() {
- return "u_" + (++aliasCounter);
- }
- private char [] excludedCharForColumnName = {'.', ':'};
- private boolean containExcludedCharForCreateTableColumnName(String input) {
- if (input.length() > 0) {
- if (input.charAt(0) == '`' && input.charAt(input.length() - 1) == '`') {
- // When column name is backquoted, we don't care about excluded chars.
- return false;
- }
- }
- for(char c : excludedCharForColumnName) {
- if(input.indexOf(c)>-1) {
- return true;
- }
- }
- return false;
- }
- private CommonTree throwSetOpException() throws RecognitionException {
- throw new FailedPredicateException(input, "orderByClause clusterByClause distributeByClause sortByClause limitClause can only be applied to the whole union.", "");
- }
- private CommonTree throwColumnNameException() throws RecognitionException {
- throw new FailedPredicateException(input, Arrays.toString(excludedCharForColumnName) + " can not be used in column name in create table statement.", "");
- }
-
- private ParserConf parserConf;
- private ParseErrorReporter reporter;
-
- public void configure(ParserConf parserConf, ParseErrorReporter reporter) {
- this.parserConf = parserConf;
- this.reporter = reporter;
- }
-
- protected boolean useSQL11ReservedKeywordsForIdentifier() {
- if (parserConf == null) {
- return true;
- }
- return !parserConf.supportSQL11ReservedKeywords();
- }
-}
-
-@rulecatch {
-catch (RecognitionException e) {
- reportError(e);
- throw e;
-}
-}
-
-// starting rule
-statement
- : explainStatement EOF
- | execStatement EOF
- | KW_ADD KW_JAR -> ^(TOK_ADDJAR)
- | KW_ADD KW_FILE -> ^(TOK_ADDFILE)
- | KW_DFS -> ^(TOK_DFS)
- | (KW_SET)=> KW_SET -> ^(TOK_SETCONFIG)
- ;
-
-// Rule for expression parsing
-singleNamedExpression
- :
- namedExpression EOF
- ;
-
-// Rule for table name parsing
-singleTableName
- :
- tableName EOF
- ;
-
-explainStatement
-@init { pushMsg("explain statement", state); }
-@after { popMsg(state); }
- : KW_EXPLAIN (
- explainOption* execStatement -> ^(TOK_EXPLAIN execStatement explainOption*)
- |
- KW_REWRITE queryStatementExpression[true] -> ^(TOK_EXPLAIN_SQ_REWRITE queryStatementExpression))
- ;
-
-explainOption
-@init { msgs.push("explain option"); }
-@after { msgs.pop(); }
- : KW_EXTENDED|KW_FORMATTED|KW_DEPENDENCY|KW_LOGICAL|KW_AUTHORIZATION
- ;
-
-execStatement
-@init { pushMsg("statement", state); }
-@after { popMsg(state); }
- : queryStatementExpression[true]
- | loadStatement
- | exportStatement
- | importStatement
- | ddlStatement
- | deleteStatement
- | updateStatement
- | sqlTransactionStatement
- | cacheStatement
- ;
-
-loadStatement
-@init { pushMsg("load statement", state); }
-@after { popMsg(state); }
- : KW_LOAD KW_DATA (islocal=KW_LOCAL)? KW_INPATH (path=StringLiteral) (isoverwrite=KW_OVERWRITE)? KW_INTO KW_TABLE (tab=tableOrPartition)
- -> ^(TOK_LOAD $path $tab $islocal? $isoverwrite?)
- ;
-
-replicationClause
-@init { pushMsg("replication clause", state); }
-@after { popMsg(state); }
- : KW_FOR (isMetadataOnly=KW_METADATA)? KW_REPLICATION LPAREN (replId=StringLiteral) RPAREN
- -> ^(TOK_REPLICATION $replId $isMetadataOnly?)
- ;
-
-exportStatement
-@init { pushMsg("export statement", state); }
-@after { popMsg(state); }
- : KW_EXPORT
- KW_TABLE (tab=tableOrPartition)
- KW_TO (path=StringLiteral)
- replicationClause?
- -> ^(TOK_EXPORT $tab $path replicationClause?)
- ;
-
-importStatement
-@init { pushMsg("import statement", state); }
-@after { popMsg(state); }
- : KW_IMPORT
- ((ext=KW_EXTERNAL)? KW_TABLE (tab=tableOrPartition))?
- KW_FROM (path=StringLiteral)
- tableLocation?
- -> ^(TOK_IMPORT $path $tab? $ext? tableLocation?)
- ;
-
-ddlStatement
-@init { pushMsg("ddl statement", state); }
-@after { popMsg(state); }
- : createDatabaseStatement
- | switchDatabaseStatement
- | dropDatabaseStatement
- | createTableStatement
- | dropTableStatement
- | truncateTableStatement
- | alterStatement
- | descStatement
- | refreshStatement
- | showStatement
- | metastoreCheck
- | createViewStatement
- | dropViewStatement
- | createFunctionStatement
- | createMacroStatement
- | createIndexStatement
- | dropIndexStatement
- | dropFunctionStatement
- | reloadFunctionStatement
- | dropMacroStatement
- | analyzeStatement
- | lockStatement
- | unlockStatement
- | lockDatabase
- | unlockDatabase
- | createRoleStatement
- | dropRoleStatement
- | (grantPrivileges) => grantPrivileges
- | (revokePrivileges) => revokePrivileges
- | showGrants
- | showRoleGrants
- | showRolePrincipals
- | showRoles
- | grantRole
- | revokeRole
- | setRole
- | showCurrentRole
- ;
-
-ifExists
-@init { pushMsg("if exists clause", state); }
-@after { popMsg(state); }
- : KW_IF KW_EXISTS
- -> ^(TOK_IFEXISTS)
- ;
-
-restrictOrCascade
-@init { pushMsg("restrict or cascade clause", state); }
-@after { popMsg(state); }
- : KW_RESTRICT
- -> ^(TOK_RESTRICT)
- | KW_CASCADE
- -> ^(TOK_CASCADE)
- ;
-
-ifNotExists
-@init { pushMsg("if not exists clause", state); }
-@after { popMsg(state); }
- : KW_IF KW_NOT KW_EXISTS
- -> ^(TOK_IFNOTEXISTS)
- ;
-
-storedAsDirs
-@init { pushMsg("stored as directories", state); }
-@after { popMsg(state); }
- : KW_STORED KW_AS KW_DIRECTORIES
- -> ^(TOK_STOREDASDIRS)
- ;
-
-orReplace
-@init { pushMsg("or replace clause", state); }
-@after { popMsg(state); }
- : KW_OR KW_REPLACE
- -> ^(TOK_ORREPLACE)
- ;
-
-createDatabaseStatement
-@init { pushMsg("create database statement", state); }
-@after { popMsg(state); }
- : KW_CREATE (KW_DATABASE|KW_SCHEMA)
- ifNotExists?
- name=identifier
- databaseComment?
- dbLocation?
- (KW_WITH KW_DBPROPERTIES dbprops=dbProperties)?
- -> ^(TOK_CREATEDATABASE $name ifNotExists? dbLocation? databaseComment? $dbprops?)
- ;
-
-dbLocation
-@init { pushMsg("database location specification", state); }
-@after { popMsg(state); }
- :
- KW_LOCATION locn=StringLiteral -> ^(TOK_DATABASELOCATION $locn)
- ;
-
-dbProperties
-@init { pushMsg("dbproperties", state); }
-@after { popMsg(state); }
- :
- LPAREN dbPropertiesList RPAREN -> ^(TOK_DATABASEPROPERTIES dbPropertiesList)
- ;
-
-dbPropertiesList
-@init { pushMsg("database properties list", state); }
-@after { popMsg(state); }
- :
- keyValueProperty (COMMA keyValueProperty)* -> ^(TOK_DBPROPLIST keyValueProperty+)
- ;
-
-
-switchDatabaseStatement
-@init { pushMsg("switch database statement", state); }
-@after { popMsg(state); }
- : KW_USE identifier
- -> ^(TOK_SWITCHDATABASE identifier)
- ;
-
-dropDatabaseStatement
-@init { pushMsg("drop database statement", state); }
-@after { popMsg(state); }
- : KW_DROP (KW_DATABASE|KW_SCHEMA) ifExists? identifier restrictOrCascade?
- -> ^(TOK_DROPDATABASE identifier ifExists? restrictOrCascade?)
- ;
-
-databaseComment
-@init { pushMsg("database's comment", state); }
-@after { popMsg(state); }
- : KW_COMMENT comment=StringLiteral
- -> ^(TOK_DATABASECOMMENT $comment)
- ;
-
-createTableStatement
-@init { pushMsg("create table statement", state); }
-@after { popMsg(state); }
- : KW_CREATE (temp=KW_TEMPORARY)? (ext=KW_EXTERNAL)? KW_TABLE ifNotExists? name=tableName
- (
- like=KW_LIKE likeName=tableName
- tableRowFormat?
- tableFileFormat?
- tableLocation?
- tablePropertiesPrefixed?
- -> ^(TOK_CREATETABLE $name $temp? $ext? ifNotExists?
- ^(TOK_LIKETABLE $likeName?)
- tableRowFormat?
- tableFileFormat?
- tableLocation?
- tablePropertiesPrefixed?
- )
- |
- (tableProvider) => tableProvider
- tableOpts?
- (KW_AS selectStatementWithCTE)?
- -> ^(TOK_CREATETABLEUSING $name $temp? ifNotExists?
- tableProvider
- tableOpts?
- selectStatementWithCTE?
- )
- | (LPAREN columnNameTypeList RPAREN)?
- (p=tableProvider?)
- tableOpts?
- tableComment?
- tablePartition?
- tableBuckets?
- tableSkewed?
- tableRowFormat?
- tableFileFormat?
- tableLocation?
- tablePropertiesPrefixed?
- (KW_AS selectStatementWithCTE)?
- -> {p != null}?
- ^(TOK_CREATETABLEUSING $name $temp? ifNotExists?
- columnNameTypeList?
- $p
- tableOpts?
- selectStatementWithCTE?
- )
- ->
- ^(TOK_CREATETABLE $name $temp? $ext? ifNotExists?
- ^(TOK_LIKETABLE $likeName?)
- columnNameTypeList?
- tableComment?
- tablePartition?
- tableBuckets?
- tableSkewed?
- tableRowFormat?
- tableFileFormat?
- tableLocation?
- tablePropertiesPrefixed?
- selectStatementWithCTE?
- )
- )
- ;
-
-truncateTableStatement
-@init { pushMsg("truncate table statement", state); }
-@after { popMsg(state); }
- : KW_TRUNCATE KW_TABLE tablePartitionPrefix (KW_COLUMNS LPAREN columnNameList RPAREN)? -> ^(TOK_TRUNCATETABLE tablePartitionPrefix columnNameList?);
-
-createIndexStatement
-@init { pushMsg("create index statement", state);}
-@after {popMsg(state);}
- : KW_CREATE KW_INDEX indexName=identifier
- KW_ON KW_TABLE tab=tableName LPAREN indexedCols=columnNameList RPAREN
- KW_AS typeName=StringLiteral
- autoRebuild?
- indexPropertiesPrefixed?
- indexTblName?
- tableRowFormat?
- tableFileFormat?
- tableLocation?
- tablePropertiesPrefixed?
- indexComment?
- ->^(TOK_CREATEINDEX $indexName $typeName $tab $indexedCols
- autoRebuild?
- indexPropertiesPrefixed?
- indexTblName?
- tableRowFormat?
- tableFileFormat?
- tableLocation?
- tablePropertiesPrefixed?
- indexComment?)
- ;
-
-indexComment
-@init { pushMsg("comment on an index", state);}
-@after {popMsg(state);}
- :
- KW_COMMENT comment=StringLiteral -> ^(TOK_INDEXCOMMENT $comment)
- ;
-
-autoRebuild
-@init { pushMsg("auto rebuild index", state);}
-@after {popMsg(state);}
- : KW_WITH KW_DEFERRED KW_REBUILD
- ->^(TOK_DEFERRED_REBUILDINDEX)
- ;
-
-indexTblName
-@init { pushMsg("index table name", state);}
-@after {popMsg(state);}
- : KW_IN KW_TABLE indexTbl=tableName
- ->^(TOK_CREATEINDEX_INDEXTBLNAME $indexTbl)
- ;
-
-indexPropertiesPrefixed
-@init { pushMsg("table properties with prefix", state); }
-@after { popMsg(state); }
- :
- KW_IDXPROPERTIES! indexProperties
- ;
-
-indexProperties
-@init { pushMsg("index properties", state); }
-@after { popMsg(state); }
- :
- LPAREN indexPropertiesList RPAREN -> ^(TOK_INDEXPROPERTIES indexPropertiesList)
- ;
-
-indexPropertiesList
-@init { pushMsg("index properties list", state); }
-@after { popMsg(state); }
- :
- keyValueProperty (COMMA keyValueProperty)* -> ^(TOK_INDEXPROPLIST keyValueProperty+)
- ;
-
-dropIndexStatement
-@init { pushMsg("drop index statement", state);}
-@after {popMsg(state);}
- : KW_DROP KW_INDEX ifExists? indexName=identifier KW_ON tab=tableName
- ->^(TOK_DROPINDEX $indexName $tab ifExists?)
- ;
-
-dropTableStatement
-@init { pushMsg("drop statement", state); }
-@after { popMsg(state); }
- : KW_DROP KW_TABLE ifExists? tableName KW_PURGE? replicationClause?
- -> ^(TOK_DROPTABLE tableName ifExists? KW_PURGE? replicationClause?)
- ;
-
-alterStatement
-@init { pushMsg("alter statement", state); }
-@after { popMsg(state); }
- : KW_ALTER KW_TABLE tableName alterTableStatementSuffix -> ^(TOK_ALTERTABLE tableName alterTableStatementSuffix)
- | KW_ALTER KW_VIEW tableName KW_AS? alterViewStatementSuffix -> ^(TOK_ALTERVIEW tableName alterViewStatementSuffix)
- | KW_ALTER KW_INDEX alterIndexStatementSuffix -> alterIndexStatementSuffix
- | KW_ALTER (KW_DATABASE|KW_SCHEMA) alterDatabaseStatementSuffix -> alterDatabaseStatementSuffix
- ;
-
-alterTableStatementSuffix
-@init { pushMsg("alter table statement", state); }
-@after { popMsg(state); }
- : (alterStatementSuffixRename[true]) => alterStatementSuffixRename[true]
- | alterStatementSuffixDropPartitions[true]
- | alterStatementSuffixAddPartitions[true]
- | alterStatementSuffixTouch
- | alterStatementSuffixArchive
- | alterStatementSuffixUnArchive
- | alterStatementSuffixProperties
- | alterStatementSuffixSkewedby
- | alterStatementSuffixExchangePartition
- | alterStatementPartitionKeyType
- | partitionSpec? alterTblPartitionStatementSuffix -> alterTblPartitionStatementSuffix partitionSpec?
- ;
-
-alterTblPartitionStatementSuffix
-@init {pushMsg("alter table partition statement suffix", state);}
-@after {popMsg(state);}
- : alterStatementSuffixFileFormat
- | alterStatementSuffixLocation
- | alterStatementSuffixMergeFiles
- | alterStatementSuffixSerdeProperties
- | alterStatementSuffixRenamePart
- | alterStatementSuffixBucketNum
- | alterTblPartitionStatementSuffixSkewedLocation
- | alterStatementSuffixClusterbySortby
- | alterStatementSuffixCompact
- | alterStatementSuffixUpdateStatsCol
- | alterStatementSuffixRenameCol
- | alterStatementSuffixAddCol
- ;
-
-alterStatementPartitionKeyType
-@init {msgs.push("alter partition key type"); }
-@after {msgs.pop();}
- : KW_PARTITION KW_COLUMN LPAREN columnNameType RPAREN
- -> ^(TOK_ALTERTABLE_PARTCOLTYPE columnNameType)
- ;
-
-alterViewStatementSuffix
-@init { pushMsg("alter view statement", state); }
-@after { popMsg(state); }
- : alterViewSuffixProperties
- | alterStatementSuffixRename[false]
- | alterStatementSuffixAddPartitions[false]
- | alterStatementSuffixDropPartitions[false]
- | selectStatementWithCTE
- ;
-
-alterIndexStatementSuffix
-@init { pushMsg("alter index statement", state); }
-@after { popMsg(state); }
- : indexName=identifier KW_ON tableName partitionSpec?
- (
- KW_REBUILD
- ->^(TOK_ALTERINDEX_REBUILD tableName $indexName partitionSpec?)
- |
- KW_SET KW_IDXPROPERTIES
- indexProperties
- ->^(TOK_ALTERINDEX_PROPERTIES tableName $indexName indexProperties)
- )
- ;
-
-alterDatabaseStatementSuffix
-@init { pushMsg("alter database statement", state); }
-@after { popMsg(state); }
- : alterDatabaseSuffixProperties
- | alterDatabaseSuffixSetOwner
- ;
-
-alterDatabaseSuffixProperties
-@init { pushMsg("alter database properties statement", state); }
-@after { popMsg(state); }
- : name=identifier KW_SET KW_DBPROPERTIES dbProperties
- -> ^(TOK_ALTERDATABASE_PROPERTIES $name dbProperties)
- ;
-
-alterDatabaseSuffixSetOwner
-@init { pushMsg("alter database set owner", state); }
-@after { popMsg(state); }
- : dbName=identifier KW_SET KW_OWNER principalName
- -> ^(TOK_ALTERDATABASE_OWNER $dbName principalName)
- ;
-
-alterStatementSuffixRename[boolean table]
-@init { pushMsg("rename statement", state); }
-@after { popMsg(state); }
- : KW_RENAME KW_TO tableName
- -> { table }? ^(TOK_ALTERTABLE_RENAME tableName)
- -> ^(TOK_ALTERVIEW_RENAME tableName)
- ;
-
-alterStatementSuffixAddCol
-@init { pushMsg("add column statement", state); }
-@after { popMsg(state); }
- : (add=KW_ADD | replace=KW_REPLACE) KW_COLUMNS LPAREN columnNameTypeList RPAREN restrictOrCascade?
- -> {$add != null}? ^(TOK_ALTERTABLE_ADDCOLS columnNameTypeList restrictOrCascade?)
- -> ^(TOK_ALTERTABLE_REPLACECOLS columnNameTypeList restrictOrCascade?)
- ;
-
-alterStatementSuffixRenameCol
-@init { pushMsg("rename column name", state); }
-@after { popMsg(state); }
- : KW_CHANGE KW_COLUMN? oldName=identifier newName=identifier colType (KW_COMMENT comment=StringLiteral)? alterStatementChangeColPosition? restrictOrCascade?
- ->^(TOK_ALTERTABLE_RENAMECOL $oldName $newName colType $comment? alterStatementChangeColPosition? restrictOrCascade?)
- ;
-
-alterStatementSuffixUpdateStatsCol
-@init { pushMsg("update column statistics", state); }
-@after { popMsg(state); }
- : KW_UPDATE KW_STATISTICS KW_FOR KW_COLUMN? colName=identifier KW_SET tableProperties (KW_COMMENT comment=StringLiteral)?
- ->^(TOK_ALTERTABLE_UPDATECOLSTATS $colName tableProperties $comment?)
- ;
-
-alterStatementChangeColPosition
- : first=KW_FIRST|KW_AFTER afterCol=identifier
- ->{$first != null}? ^(TOK_ALTERTABLE_CHANGECOL_AFTER_POSITION )
- -> ^(TOK_ALTERTABLE_CHANGECOL_AFTER_POSITION $afterCol)
- ;
-
-alterStatementSuffixAddPartitions[boolean table]
-@init { pushMsg("add partition statement", state); }
-@after { popMsg(state); }
- : KW_ADD ifNotExists? alterStatementSuffixAddPartitionsElement+
- -> { table }? ^(TOK_ALTERTABLE_ADDPARTS ifNotExists? alterStatementSuffixAddPartitionsElement+)
- -> ^(TOK_ALTERVIEW_ADDPARTS ifNotExists? alterStatementSuffixAddPartitionsElement+)
- ;
-
-alterStatementSuffixAddPartitionsElement
- : partitionSpec partitionLocation?
- ;
-
-alterStatementSuffixTouch
-@init { pushMsg("touch statement", state); }
-@after { popMsg(state); }
- : KW_TOUCH (partitionSpec)*
- -> ^(TOK_ALTERTABLE_TOUCH (partitionSpec)*)
- ;
-
-alterStatementSuffixArchive
-@init { pushMsg("archive statement", state); }
-@after { popMsg(state); }
- : KW_ARCHIVE (partitionSpec)*
- -> ^(TOK_ALTERTABLE_ARCHIVE (partitionSpec)*)
- ;
-
-alterStatementSuffixUnArchive
-@init { pushMsg("unarchive statement", state); }
-@after { popMsg(state); }
- : KW_UNARCHIVE (partitionSpec)*
- -> ^(TOK_ALTERTABLE_UNARCHIVE (partitionSpec)*)
- ;
-
-partitionLocation
-@init { pushMsg("partition location", state); }
-@after { popMsg(state); }
- :
- KW_LOCATION locn=StringLiteral -> ^(TOK_PARTITIONLOCATION $locn)
- ;
-
-alterStatementSuffixDropPartitions[boolean table]
-@init { pushMsg("drop partition statement", state); }
-@after { popMsg(state); }
- : KW_DROP ifExists? dropPartitionSpec (COMMA dropPartitionSpec)* KW_PURGE? replicationClause?
- -> { table }? ^(TOK_ALTERTABLE_DROPPARTS dropPartitionSpec+ ifExists? KW_PURGE? replicationClause?)
- -> ^(TOK_ALTERVIEW_DROPPARTS dropPartitionSpec+ ifExists? replicationClause?)
- ;
-
-alterStatementSuffixProperties
-@init { pushMsg("alter properties statement", state); }
-@after { popMsg(state); }
- : KW_SET KW_TBLPROPERTIES tableProperties
- -> ^(TOK_ALTERTABLE_PROPERTIES tableProperties)
- | KW_UNSET KW_TBLPROPERTIES ifExists? tableProperties
- -> ^(TOK_ALTERTABLE_DROPPROPERTIES tableProperties ifExists?)
- ;
-
-alterViewSuffixProperties
-@init { pushMsg("alter view properties statement", state); }
-@after { popMsg(state); }
- : KW_SET KW_TBLPROPERTIES tableProperties
- -> ^(TOK_ALTERVIEW_PROPERTIES tableProperties)
- | KW_UNSET KW_TBLPROPERTIES ifExists? tableProperties
- -> ^(TOK_ALTERVIEW_DROPPROPERTIES tableProperties ifExists?)
- ;
-
-alterStatementSuffixSerdeProperties
-@init { pushMsg("alter serdes statement", state); }
-@after { popMsg(state); }
- : KW_SET KW_SERDE serdeName=StringLiteral (KW_WITH KW_SERDEPROPERTIES tableProperties)?
- -> ^(TOK_ALTERTABLE_SERIALIZER $serdeName tableProperties?)
- | KW_SET KW_SERDEPROPERTIES tableProperties
- -> ^(TOK_ALTERTABLE_SERDEPROPERTIES tableProperties)
- ;
-
-tablePartitionPrefix
-@init {pushMsg("table partition prefix", state);}
-@after {popMsg(state);}
- : tableName partitionSpec?
- ->^(TOK_TABLE_PARTITION tableName partitionSpec?)
- ;
-
-alterStatementSuffixFileFormat
-@init {pushMsg("alter fileformat statement", state); }
-@after {popMsg(state);}
- : KW_SET KW_FILEFORMAT fileFormat
- -> ^(TOK_ALTERTABLE_FILEFORMAT fileFormat)
- ;
-
-alterStatementSuffixClusterbySortby
-@init {pushMsg("alter partition cluster by sort by statement", state);}
-@after {popMsg(state);}
- : KW_NOT KW_CLUSTERED -> ^(TOK_ALTERTABLE_CLUSTER_SORT TOK_NOT_CLUSTERED)
- | KW_NOT KW_SORTED -> ^(TOK_ALTERTABLE_CLUSTER_SORT TOK_NOT_SORTED)
- | tableBuckets -> ^(TOK_ALTERTABLE_CLUSTER_SORT tableBuckets)
- ;
-
-alterTblPartitionStatementSuffixSkewedLocation
-@init {pushMsg("alter partition skewed location", state);}
-@after {popMsg(state);}
- : KW_SET KW_SKEWED KW_LOCATION skewedLocations
- -> ^(TOK_ALTERTABLE_SKEWED_LOCATION skewedLocations)
- ;
-
-skewedLocations
-@init { pushMsg("skewed locations", state); }
-@after { popMsg(state); }
- :
- LPAREN skewedLocationsList RPAREN -> ^(TOK_SKEWED_LOCATIONS skewedLocationsList)
- ;
-
-skewedLocationsList
-@init { pushMsg("skewed locations list", state); }
-@after { popMsg(state); }
- :
- skewedLocationMap (COMMA skewedLocationMap)* -> ^(TOK_SKEWED_LOCATION_LIST skewedLocationMap+)
- ;
-
-skewedLocationMap
-@init { pushMsg("specifying skewed location map", state); }
-@after { popMsg(state); }
- :
- key=skewedValueLocationElement EQUAL value=StringLiteral -> ^(TOK_SKEWED_LOCATION_MAP $key $value)
- ;
-
-alterStatementSuffixLocation
-@init {pushMsg("alter location", state);}
-@after {popMsg(state);}
- : KW_SET KW_LOCATION newLoc=StringLiteral
- -> ^(TOK_ALTERTABLE_LOCATION $newLoc)
- ;
-
-
-alterStatementSuffixSkewedby
-@init {pushMsg("alter skewed by statement", state);}
-@after{popMsg(state);}
- : tableSkewed
- ->^(TOK_ALTERTABLE_SKEWED tableSkewed)
- |
- KW_NOT KW_SKEWED
- ->^(TOK_ALTERTABLE_SKEWED)
- |
- KW_NOT storedAsDirs
- ->^(TOK_ALTERTABLE_SKEWED storedAsDirs)
- ;
-
-alterStatementSuffixExchangePartition
-@init {pushMsg("alter exchange partition", state);}
-@after{popMsg(state);}
- : KW_EXCHANGE partitionSpec KW_WITH KW_TABLE exchangename=tableName
- -> ^(TOK_ALTERTABLE_EXCHANGEPARTITION partitionSpec $exchangename)
- ;
-
-alterStatementSuffixRenamePart
-@init { pushMsg("alter table rename partition statement", state); }
-@after { popMsg(state); }
- : KW_RENAME KW_TO partitionSpec
- ->^(TOK_ALTERTABLE_RENAMEPART partitionSpec)
- ;
-
-alterStatementSuffixStatsPart
-@init { pushMsg("alter table stats partition statement", state); }
-@after { popMsg(state); }
- : KW_UPDATE KW_STATISTICS KW_FOR KW_COLUMN? colName=identifier KW_SET tableProperties (KW_COMMENT comment=StringLiteral)?
- ->^(TOK_ALTERTABLE_UPDATECOLSTATS $colName tableProperties $comment?)
- ;
-
-alterStatementSuffixMergeFiles
-@init { pushMsg("", state); }
-@after { popMsg(state); }
- : KW_CONCATENATE
- -> ^(TOK_ALTERTABLE_MERGEFILES)
- ;
-
-alterStatementSuffixBucketNum
-@init { pushMsg("", state); }
-@after { popMsg(state); }
- : KW_INTO num=Number KW_BUCKETS
- -> ^(TOK_ALTERTABLE_BUCKETS $num)
- ;
-
-alterStatementSuffixCompact
-@init { msgs.push("compaction request"); }
-@after { msgs.pop(); }
- : KW_COMPACT compactType=StringLiteral
- -> ^(TOK_ALTERTABLE_COMPACT $compactType)
- ;
-
-
-fileFormat
-@init { pushMsg("file format specification", state); }
-@after { popMsg(state); }
- : KW_INPUTFORMAT inFmt=StringLiteral KW_OUTPUTFORMAT outFmt=StringLiteral KW_SERDE serdeCls=StringLiteral (KW_INPUTDRIVER inDriver=StringLiteral KW_OUTPUTDRIVER outDriver=StringLiteral)?
- -> ^(TOK_TABLEFILEFORMAT $inFmt $outFmt $serdeCls $inDriver? $outDriver?)
- | genericSpec=identifier -> ^(TOK_FILEFORMAT_GENERIC $genericSpec)
- ;
-
-tabTypeExpr
-@init { pushMsg("specifying table types", state); }
-@after { popMsg(state); }
- : identifier (DOT^ identifier)?
- (identifier (DOT^
- (
- (KW_ELEM_TYPE) => KW_ELEM_TYPE
- |
- (KW_KEY_TYPE) => KW_KEY_TYPE
- |
- (KW_VALUE_TYPE) => KW_VALUE_TYPE
- | identifier
- ))*
- )?
- ;
-
-partTypeExpr
-@init { pushMsg("specifying table partitions", state); }
-@after { popMsg(state); }
- : tabTypeExpr partitionSpec? -> ^(TOK_TABTYPE tabTypeExpr partitionSpec?)
- ;
-
-tabPartColTypeExpr
-@init { pushMsg("specifying table partitions columnName", state); }
-@after { popMsg(state); }
- : tableName partitionSpec? extColumnName? -> ^(TOK_TABTYPE tableName partitionSpec? extColumnName?)
- ;
-
-refreshStatement
-@init { pushMsg("refresh statement", state); }
-@after { popMsg(state); }
- :
- KW_REFRESH KW_TABLE tableName -> ^(TOK_REFRESHTABLE tableName)
- ;
-
-descStatement
-@init { pushMsg("describe statement", state); }
-@after { popMsg(state); }
- :
- (KW_DESCRIBE|KW_DESC)
- (
- (KW_DATABASE|KW_SCHEMA) => (KW_DATABASE|KW_SCHEMA) KW_EXTENDED? (dbName=identifier) -> ^(TOK_DESCDATABASE $dbName KW_EXTENDED?)
- |
- (KW_FUNCTION) => KW_FUNCTION KW_EXTENDED? (name=descFuncNames) -> ^(TOK_DESCFUNCTION $name KW_EXTENDED?)
- |
- (KW_FORMATTED|KW_EXTENDED|KW_PRETTY) => ((descOptions=KW_FORMATTED|descOptions=KW_EXTENDED|descOptions=KW_PRETTY) parttype=tabPartColTypeExpr) -> ^(TOK_DESCTABLE $parttype $descOptions)
- |
- parttype=tabPartColTypeExpr -> ^(TOK_DESCTABLE $parttype)
- )
- ;
-
-analyzeStatement
-@init { pushMsg("analyze statement", state); }
-@after { popMsg(state); }
- : KW_ANALYZE KW_TABLE (parttype=tableOrPartition) KW_COMPUTE KW_STATISTICS ((noscan=KW_NOSCAN) | (partialscan=KW_PARTIALSCAN)
- | (KW_FOR KW_COLUMNS (statsColumnName=columnNameList)?))?
- -> ^(TOK_ANALYZE $parttype $noscan? $partialscan? KW_COLUMNS? $statsColumnName?)
- ;
-
-showStatement
-@init { pushMsg("show statement", state); }
-@after { popMsg(state); }
- : KW_SHOW (KW_DATABASES|KW_SCHEMAS) (KW_LIKE showStmtIdentifier)? -> ^(TOK_SHOWDATABASES showStmtIdentifier?)
- | KW_SHOW KW_TABLES ((KW_FROM|KW_IN) db_name=identifier)? (KW_LIKE showStmtIdentifier|showStmtIdentifier)? -> ^(TOK_SHOWTABLES ^(TOK_FROM $db_name)? showStmtIdentifier?)
- | KW_SHOW KW_COLUMNS (KW_FROM|KW_IN) tableName ((KW_FROM|KW_IN) db_name=identifier)?
- -> ^(TOK_SHOWCOLUMNS tableName $db_name?)
- | KW_SHOW KW_FUNCTIONS (KW_LIKE showFunctionIdentifier|showFunctionIdentifier)? -> ^(TOK_SHOWFUNCTIONS KW_LIKE? showFunctionIdentifier?)
- | KW_SHOW KW_PARTITIONS tabName=tableName partitionSpec? -> ^(TOK_SHOWPARTITIONS $tabName partitionSpec?)
- | KW_SHOW KW_CREATE (
- (KW_DATABASE|KW_SCHEMA) => (KW_DATABASE|KW_SCHEMA) db_name=identifier -> ^(TOK_SHOW_CREATEDATABASE $db_name)
- |
- KW_TABLE tabName=tableName -> ^(TOK_SHOW_CREATETABLE $tabName)
- )
- | KW_SHOW KW_TABLE KW_EXTENDED ((KW_FROM|KW_IN) db_name=identifier)? KW_LIKE showStmtIdentifier partitionSpec?
- -> ^(TOK_SHOW_TABLESTATUS showStmtIdentifier $db_name? partitionSpec?)
- | KW_SHOW KW_TBLPROPERTIES tableName (LPAREN prptyName=StringLiteral RPAREN)? -> ^(TOK_SHOW_TBLPROPERTIES tableName $prptyName?)
- | KW_SHOW KW_LOCKS
- (
- (KW_DATABASE|KW_SCHEMA) => (KW_DATABASE|KW_SCHEMA) (dbName=Identifier) (isExtended=KW_EXTENDED)? -> ^(TOK_SHOWDBLOCKS $dbName $isExtended?)
- |
- (parttype=partTypeExpr)? (isExtended=KW_EXTENDED)? -> ^(TOK_SHOWLOCKS $parttype? $isExtended?)
- )
- | KW_SHOW (showOptions=KW_FORMATTED)? (KW_INDEX|KW_INDEXES) KW_ON showStmtIdentifier ((KW_FROM|KW_IN) db_name=identifier)?
- -> ^(TOK_SHOWINDEXES showStmtIdentifier $showOptions? $db_name?)
- | KW_SHOW KW_COMPACTIONS -> ^(TOK_SHOW_COMPACTIONS)
- | KW_SHOW KW_TRANSACTIONS -> ^(TOK_SHOW_TRANSACTIONS)
- | KW_SHOW KW_CONF StringLiteral -> ^(TOK_SHOWCONF StringLiteral)
- ;
-
-lockStatement
-@init { pushMsg("lock statement", state); }
-@after { popMsg(state); }
- : KW_LOCK KW_TABLE tableName partitionSpec? lockMode -> ^(TOK_LOCKTABLE tableName lockMode partitionSpec?)
- ;
-
-lockDatabase
-@init { pushMsg("lock database statement", state); }
-@after { popMsg(state); }
- : KW_LOCK (KW_DATABASE|KW_SCHEMA) (dbName=Identifier) lockMode -> ^(TOK_LOCKDB $dbName lockMode)
- ;
-
-lockMode
-@init { pushMsg("lock mode", state); }
-@after { popMsg(state); }
- : KW_SHARED | KW_EXCLUSIVE
- ;
-
-unlockStatement
-@init { pushMsg("unlock statement", state); }
-@after { popMsg(state); }
- : KW_UNLOCK KW_TABLE tableName partitionSpec? -> ^(TOK_UNLOCKTABLE tableName partitionSpec?)
- ;
-
-unlockDatabase
-@init { pushMsg("unlock database statement", state); }
-@after { popMsg(state); }
- : KW_UNLOCK (KW_DATABASE|KW_SCHEMA) (dbName=Identifier) -> ^(TOK_UNLOCKDB $dbName)
- ;
-
-createRoleStatement
-@init { pushMsg("create role", state); }
-@after { popMsg(state); }
- : KW_CREATE KW_ROLE roleName=identifier
- -> ^(TOK_CREATEROLE $roleName)
- ;
-
-dropRoleStatement
-@init {pushMsg("drop role", state);}
-@after {popMsg(state);}
- : KW_DROP KW_ROLE roleName=identifier
- -> ^(TOK_DROPROLE $roleName)
- ;
-
-grantPrivileges
-@init {pushMsg("grant privileges", state);}
-@after {popMsg(state);}
- : KW_GRANT privList=privilegeList
- privilegeObject?
- KW_TO principalSpecification
- withGrantOption?
- -> ^(TOK_GRANT $privList principalSpecification privilegeObject? withGrantOption?)
- ;
-
-revokePrivileges
-@init {pushMsg("revoke privileges", state);}
-@after {popMsg(state);}
- : KW_REVOKE grantOptionFor? privilegeList privilegeObject? KW_FROM principalSpecification
- -> ^(TOK_REVOKE privilegeList principalSpecification privilegeObject? grantOptionFor?)
- ;
-
-grantRole
-@init {pushMsg("grant role", state);}
-@after {popMsg(state);}
- : KW_GRANT KW_ROLE? identifier (COMMA identifier)* KW_TO principalSpecification withAdminOption?
- -> ^(TOK_GRANT_ROLE principalSpecification withAdminOption? identifier+)
- ;
-
-revokeRole
-@init {pushMsg("revoke role", state);}
-@after {popMsg(state);}
- : KW_REVOKE adminOptionFor? KW_ROLE? identifier (COMMA identifier)* KW_FROM principalSpecification
- -> ^(TOK_REVOKE_ROLE principalSpecification adminOptionFor? identifier+)
- ;
-
-showRoleGrants
-@init {pushMsg("show role grants", state);}
-@after {popMsg(state);}
- : KW_SHOW KW_ROLE KW_GRANT principalName
- -> ^(TOK_SHOW_ROLE_GRANT principalName)
- ;
-
-
-showRoles
-@init {pushMsg("show roles", state);}
-@after {popMsg(state);}
- : KW_SHOW KW_ROLES
- -> ^(TOK_SHOW_ROLES)
- ;
-
-showCurrentRole
-@init {pushMsg("show current role", state);}
-@after {popMsg(state);}
- : KW_SHOW KW_CURRENT KW_ROLES
- -> ^(TOK_SHOW_SET_ROLE)
- ;
-
-setRole
-@init {pushMsg("set role", state);}
-@after {popMsg(state);}
- : KW_SET KW_ROLE
- (
- (KW_ALL) => (all=KW_ALL) -> ^(TOK_SHOW_SET_ROLE Identifier[$all.text])
- |
- (KW_NONE) => (none=KW_NONE) -> ^(TOK_SHOW_SET_ROLE Identifier[$none.text])
- |
- identifier -> ^(TOK_SHOW_SET_ROLE identifier)
- )
- ;
-
-showGrants
-@init {pushMsg("show grants", state);}
-@after {popMsg(state);}
- : KW_SHOW KW_GRANT principalName? (KW_ON privilegeIncludeColObject)?
- -> ^(TOK_SHOW_GRANT principalName? privilegeIncludeColObject?)
- ;
-
-showRolePrincipals
-@init {pushMsg("show role principals", state);}
-@after {popMsg(state);}
- : KW_SHOW KW_PRINCIPALS roleName=identifier
- -> ^(TOK_SHOW_ROLE_PRINCIPALS $roleName)
- ;
-
-
-privilegeIncludeColObject
-@init {pushMsg("privilege object including columns", state);}
-@after {popMsg(state);}
- : (KW_ALL) => KW_ALL -> ^(TOK_RESOURCE_ALL)
- | privObjectCols -> ^(TOK_PRIV_OBJECT_COL privObjectCols)
- ;
-
-privilegeObject
-@init {pushMsg("privilege object", state);}
-@after {popMsg(state);}
- : KW_ON privObject -> ^(TOK_PRIV_OBJECT privObject)
- ;
-
-// database or table type. Type is optional, default type is table
-privObject
- : (KW_DATABASE|KW_SCHEMA) identifier -> ^(TOK_DB_TYPE identifier)
- | KW_TABLE? tableName partitionSpec? -> ^(TOK_TABLE_TYPE tableName partitionSpec?)
- | KW_URI (path=StringLiteral) -> ^(TOK_URI_TYPE $path)
- | KW_SERVER identifier -> ^(TOK_SERVER_TYPE identifier)
- ;
-
-privObjectCols
- : (KW_DATABASE|KW_SCHEMA) identifier -> ^(TOK_DB_TYPE identifier)
- | KW_TABLE? tableName (LPAREN cols=columnNameList RPAREN)? partitionSpec? -> ^(TOK_TABLE_TYPE tableName $cols? partitionSpec?)
- | KW_URI (path=StringLiteral) -> ^(TOK_URI_TYPE $path)
- | KW_SERVER identifier -> ^(TOK_SERVER_TYPE identifier)
- ;
-
-privilegeList
-@init {pushMsg("grant privilege list", state);}
-@after {popMsg(state);}
- : privlegeDef (COMMA privlegeDef)*
- -> ^(TOK_PRIVILEGE_LIST privlegeDef+)
- ;
-
-privlegeDef
-@init {pushMsg("grant privilege", state);}
-@after {popMsg(state);}
- : privilegeType (LPAREN cols=columnNameList RPAREN)?
- -> ^(TOK_PRIVILEGE privilegeType $cols?)
- ;
-
-privilegeType
-@init {pushMsg("privilege type", state);}
-@after {popMsg(state);}
- : KW_ALL -> ^(TOK_PRIV_ALL)
- | KW_ALTER -> ^(TOK_PRIV_ALTER_METADATA)
- | KW_UPDATE -> ^(TOK_PRIV_ALTER_DATA)
- | KW_CREATE -> ^(TOK_PRIV_CREATE)
- | KW_DROP -> ^(TOK_PRIV_DROP)
- | KW_INDEX -> ^(TOK_PRIV_INDEX)
- | KW_LOCK -> ^(TOK_PRIV_LOCK)
- | KW_SELECT -> ^(TOK_PRIV_SELECT)
- | KW_SHOW_DATABASE -> ^(TOK_PRIV_SHOW_DATABASE)
- | KW_INSERT -> ^(TOK_PRIV_INSERT)
- | KW_DELETE -> ^(TOK_PRIV_DELETE)
- ;
-
-principalSpecification
-@init { pushMsg("user/group/role name list", state); }
-@after { popMsg(state); }
- : principalName (COMMA principalName)* -> ^(TOK_PRINCIPAL_NAME principalName+)
- ;
-
-principalName
-@init {pushMsg("user|group|role name", state);}
-@after {popMsg(state);}
- : KW_USER principalIdentifier -> ^(TOK_USER principalIdentifier)
- | KW_GROUP principalIdentifier -> ^(TOK_GROUP principalIdentifier)
- | KW_ROLE identifier -> ^(TOK_ROLE identifier)
- ;
-
-withGrantOption
-@init {pushMsg("with grant option", state);}
-@after {popMsg(state);}
- : KW_WITH KW_GRANT KW_OPTION
- -> ^(TOK_GRANT_WITH_OPTION)
- ;
-
-grantOptionFor
-@init {pushMsg("grant option for", state);}
-@after {popMsg(state);}
- : KW_GRANT KW_OPTION KW_FOR
- -> ^(TOK_GRANT_OPTION_FOR)
-;
-
-adminOptionFor
-@init {pushMsg("admin option for", state);}
-@after {popMsg(state);}
- : KW_ADMIN KW_OPTION KW_FOR
- -> ^(TOK_ADMIN_OPTION_FOR)
-;
-
-withAdminOption
-@init {pushMsg("with admin option", state);}
-@after {popMsg(state);}
- : KW_WITH KW_ADMIN KW_OPTION
- -> ^(TOK_GRANT_WITH_ADMIN_OPTION)
- ;
-
-metastoreCheck
-@init { pushMsg("metastore check statement", state); }
-@after { popMsg(state); }
- : KW_MSCK (repair=KW_REPAIR)? (KW_TABLE tableName partitionSpec? (COMMA partitionSpec)*)?
- -> ^(TOK_MSCK $repair? (tableName partitionSpec*)?)
- ;
-
-resourceList
-@init { pushMsg("resource list", state); }
-@after { popMsg(state); }
- :
- resource (COMMA resource)* -> ^(TOK_RESOURCE_LIST resource+)
- ;
-
-resource
-@init { pushMsg("resource", state); }
-@after { popMsg(state); }
- :
- resType=resourceType resPath=StringLiteral -> ^(TOK_RESOURCE_URI $resType $resPath)
- ;
-
-resourceType
-@init { pushMsg("resource type", state); }
-@after { popMsg(state); }
- :
- KW_JAR -> ^(TOK_JAR)
- |
- KW_FILE -> ^(TOK_FILE)
- |
- KW_ARCHIVE -> ^(TOK_ARCHIVE)
- ;
-
-createFunctionStatement
-@init { pushMsg("create function statement", state); }
-@after { popMsg(state); }
- : KW_CREATE (temp=KW_TEMPORARY)? KW_FUNCTION functionIdentifier KW_AS StringLiteral
- (KW_USING rList=resourceList)?
- -> {$temp != null}? ^(TOK_CREATEFUNCTION functionIdentifier StringLiteral $rList? TOK_TEMPORARY)
- -> ^(TOK_CREATEFUNCTION functionIdentifier StringLiteral $rList?)
- ;
-
-dropFunctionStatement
-@init { pushMsg("drop function statement", state); }
-@after { popMsg(state); }
- : KW_DROP (temp=KW_TEMPORARY)? KW_FUNCTION ifExists? functionIdentifier
- -> {$temp != null}? ^(TOK_DROPFUNCTION functionIdentifier ifExists? TOK_TEMPORARY)
- -> ^(TOK_DROPFUNCTION functionIdentifier ifExists?)
- ;
-
-reloadFunctionStatement
-@init { pushMsg("reload function statement", state); }
-@after { popMsg(state); }
- : KW_RELOAD KW_FUNCTION -> ^(TOK_RELOADFUNCTION);
-
-createMacroStatement
-@init { pushMsg("create macro statement", state); }
-@after { popMsg(state); }
- : KW_CREATE KW_TEMPORARY KW_MACRO Identifier
- LPAREN columnNameTypeList? RPAREN expression
- -> ^(TOK_CREATEMACRO Identifier columnNameTypeList? expression)
- ;
-
-dropMacroStatement
-@init { pushMsg("drop macro statement", state); }
-@after { popMsg(state); }
- : KW_DROP KW_TEMPORARY KW_MACRO ifExists? Identifier
- -> ^(TOK_DROPMACRO Identifier ifExists?)
- ;
-
-createViewStatement
-@init {
- pushMsg("create view statement", state);
-}
-@after { popMsg(state); }
- : KW_CREATE (orReplace)? KW_VIEW (ifNotExists)? name=tableName
- (LPAREN columnNameCommentList RPAREN)? tableComment? viewPartition?
- tablePropertiesPrefixed?
- KW_AS
- selectStatementWithCTE
- -> ^(TOK_CREATEVIEW $name orReplace?
- ifNotExists?
- columnNameCommentList?
- tableComment?
- viewPartition?
- tablePropertiesPrefixed?
- selectStatementWithCTE
- )
- ;
-
-viewPartition
-@init { pushMsg("view partition specification", state); }
-@after { popMsg(state); }
- : KW_PARTITIONED KW_ON LPAREN columnNameList RPAREN
- -> ^(TOK_VIEWPARTCOLS columnNameList)
- ;
-
-dropViewStatement
-@init { pushMsg("drop view statement", state); }
-@after { popMsg(state); }
- : KW_DROP KW_VIEW ifExists? viewName -> ^(TOK_DROPVIEW viewName ifExists?)
- ;
-
-showFunctionIdentifier
-@init { pushMsg("identifier for show function statement", state); }
-@after { popMsg(state); }
- : functionIdentifier
- | StringLiteral
- ;
-
-showStmtIdentifier
-@init { pushMsg("identifier for show statement", state); }
-@after { popMsg(state); }
- : identifier
- | StringLiteral
- ;
-
-tableProvider
-@init { pushMsg("table's provider", state); }
-@after { popMsg(state); }
- :
- KW_USING Identifier (DOT Identifier)*
- -> ^(TOK_TABLEPROVIDER Identifier+)
- ;
-
-optionKeyValue
-@init { pushMsg("table's option specification", state); }
-@after { popMsg(state); }
- :
- (looseIdentifier (DOT looseIdentifier)*) StringLiteral
- -> ^(TOK_TABLEOPTION looseIdentifier+ StringLiteral)
- ;
-
-tableOpts
-@init { pushMsg("table's options", state); }
-@after { popMsg(state); }
- :
- KW_OPTIONS LPAREN optionKeyValue (COMMA optionKeyValue)* RPAREN
- -> ^(TOK_TABLEOPTIONS optionKeyValue+)
- ;
-
-tableComment
-@init { pushMsg("table's comment", state); }
-@after { popMsg(state); }
- :
- KW_COMMENT comment=StringLiteral -> ^(TOK_TABLECOMMENT $comment)
- ;
-
-tablePartition
-@init { pushMsg("table partition specification", state); }
-@after { popMsg(state); }
- : KW_PARTITIONED KW_BY LPAREN columnNameTypeList RPAREN
- -> ^(TOK_TABLEPARTCOLS columnNameTypeList)
- ;
-
-tableBuckets
-@init { pushMsg("table buckets specification", state); }
-@after { popMsg(state); }
- :
- KW_CLUSTERED KW_BY LPAREN bucketCols=columnNameList RPAREN (KW_SORTED KW_BY LPAREN sortCols=columnNameOrderList RPAREN)? KW_INTO num=Number KW_BUCKETS
- -> ^(TOK_ALTERTABLE_BUCKETS $bucketCols $sortCols? $num)
- ;
-
-tableSkewed
-@init { pushMsg("table skewed specification", state); }
-@after { popMsg(state); }
- :
- KW_SKEWED KW_BY LPAREN skewedCols=columnNameList RPAREN KW_ON LPAREN (skewedValues=skewedValueElement) RPAREN ((storedAsDirs) => storedAsDirs)?
- -> ^(TOK_TABLESKEWED $skewedCols $skewedValues storedAsDirs?)
- ;
-
-rowFormat
-@init { pushMsg("serde specification", state); }
-@after { popMsg(state); }
- : rowFormatSerde -> ^(TOK_SERDE rowFormatSerde)
- | rowFormatDelimited -> ^(TOK_SERDE rowFormatDelimited)
- | -> ^(TOK_SERDE)
- ;
-
-recordReader
-@init { pushMsg("record reader specification", state); }
-@after { popMsg(state); }
- : KW_RECORDREADER StringLiteral -> ^(TOK_RECORDREADER StringLiteral)
- | -> ^(TOK_RECORDREADER)
- ;
-
-recordWriter
-@init { pushMsg("record writer specification", state); }
-@after { popMsg(state); }
- : KW_RECORDWRITER StringLiteral -> ^(TOK_RECORDWRITER StringLiteral)
- | -> ^(TOK_RECORDWRITER)
- ;
-
-rowFormatSerde
-@init { pushMsg("serde format specification", state); }
-@after { popMsg(state); }
- : KW_ROW KW_FORMAT KW_SERDE name=StringLiteral (KW_WITH KW_SERDEPROPERTIES serdeprops=tableProperties)?
- -> ^(TOK_SERDENAME $name $serdeprops?)
- ;
-
-rowFormatDelimited
-@init { pushMsg("serde properties specification", state); }
-@after { popMsg(state); }
- :
- KW_ROW KW_FORMAT KW_DELIMITED tableRowFormatFieldIdentifier? tableRowFormatCollItemsIdentifier? tableRowFormatMapKeysIdentifier? tableRowFormatLinesIdentifier? tableRowNullFormat?
- -> ^(TOK_SERDEPROPS tableRowFormatFieldIdentifier? tableRowFormatCollItemsIdentifier? tableRowFormatMapKeysIdentifier? tableRowFormatLinesIdentifier? tableRowNullFormat?)
- ;
-
-tableRowFormat
-@init { pushMsg("table row format specification", state); }
-@after { popMsg(state); }
- :
- rowFormatDelimited
- -> ^(TOK_TABLEROWFORMAT rowFormatDelimited)
- | rowFormatSerde
- -> ^(TOK_TABLESERIALIZER rowFormatSerde)
- ;
-
-tablePropertiesPrefixed
-@init { pushMsg("table properties with prefix", state); }
-@after { popMsg(state); }
- :
- KW_TBLPROPERTIES! tableProperties
- ;
-
-tableProperties
-@init { pushMsg("table properties", state); }
-@after { popMsg(state); }
- :
- LPAREN tablePropertiesList RPAREN -> ^(TOK_TABLEPROPERTIES tablePropertiesList)
- ;
-
-tablePropertiesList
-@init { pushMsg("table properties list", state); }
-@after { popMsg(state); }
- :
- keyValueProperty (COMMA keyValueProperty)* -> ^(TOK_TABLEPROPLIST keyValueProperty+)
- |
- keyProperty (COMMA keyProperty)* -> ^(TOK_TABLEPROPLIST keyProperty+)
- ;
-
-keyValueProperty
-@init { pushMsg("specifying key/value property", state); }
-@after { popMsg(state); }
- :
- key=StringLiteral EQUAL value=StringLiteral -> ^(TOK_TABLEPROPERTY $key $value)
- ;
-
-keyProperty
-@init { pushMsg("specifying key property", state); }
-@after { popMsg(state); }
- :
- key=StringLiteral -> ^(TOK_TABLEPROPERTY $key TOK_NULL)
- ;
-
-tableRowFormatFieldIdentifier
-@init { pushMsg("table row format's field separator", state); }
-@after { popMsg(state); }
- :
- KW_FIELDS KW_TERMINATED KW_BY fldIdnt=StringLiteral (KW_ESCAPED KW_BY fldEscape=StringLiteral)?
- -> ^(TOK_TABLEROWFORMATFIELD $fldIdnt $fldEscape?)
- ;
-
-tableRowFormatCollItemsIdentifier
-@init { pushMsg("table row format's column separator", state); }
-@after { popMsg(state); }
- :
- KW_COLLECTION KW_ITEMS KW_TERMINATED KW_BY collIdnt=StringLiteral
- -> ^(TOK_TABLEROWFORMATCOLLITEMS $collIdnt)
- ;
-
-tableRowFormatMapKeysIdentifier
-@init { pushMsg("table row format's map key separator", state); }
-@after { popMsg(state); }
- :
- KW_MAP KW_KEYS KW_TERMINATED KW_BY mapKeysIdnt=StringLiteral
- -> ^(TOK_TABLEROWFORMATMAPKEYS $mapKeysIdnt)
- ;
-
-tableRowFormatLinesIdentifier
-@init { pushMsg("table row format's line separator", state); }
-@after { popMsg(state); }
- :
- KW_LINES KW_TERMINATED KW_BY linesIdnt=StringLiteral
- -> ^(TOK_TABLEROWFORMATLINES $linesIdnt)
- ;
-
-tableRowNullFormat
-@init { pushMsg("table row format's null specifier", state); }
-@after { popMsg(state); }
- :
- KW_NULL KW_DEFINED KW_AS nullIdnt=StringLiteral
- -> ^(TOK_TABLEROWFORMATNULL $nullIdnt)
- ;
-tableFileFormat
-@init { pushMsg("table file format specification", state); }
-@after { popMsg(state); }
- :
- (KW_STORED KW_AS KW_INPUTFORMAT) => KW_STORED KW_AS KW_INPUTFORMAT inFmt=StringLiteral KW_OUTPUTFORMAT outFmt=StringLiteral (KW_INPUTDRIVER inDriver=StringLiteral KW_OUTPUTDRIVER outDriver=StringLiteral)?
- -> ^(TOK_TABLEFILEFORMAT $inFmt $outFmt $inDriver? $outDriver?)
- | KW_STORED KW_BY storageHandler=StringLiteral
- (KW_WITH KW_SERDEPROPERTIES serdeprops=tableProperties)?
- -> ^(TOK_STORAGEHANDLER $storageHandler $serdeprops?)
- | KW_STORED KW_AS genericSpec=identifier
- -> ^(TOK_FILEFORMAT_GENERIC $genericSpec)
- ;
-
-tableLocation
-@init { pushMsg("table location specification", state); }
-@after { popMsg(state); }
- :
- KW_LOCATION locn=StringLiteral -> ^(TOK_TABLELOCATION $locn)
- ;
-
-columnNameTypeList
-@init { pushMsg("column name type list", state); }
-@after { popMsg(state); }
- : columnNameType (COMMA columnNameType)* -> ^(TOK_TABCOLLIST columnNameType+)
- ;
-
-columnNameColonTypeList
-@init { pushMsg("column name type list", state); }
-@after { popMsg(state); }
- : columnNameColonType (COMMA columnNameColonType)* -> ^(TOK_TABCOLLIST columnNameColonType+)
- ;
-
-columnNameList
-@init { pushMsg("column name list", state); }
-@after { popMsg(state); }
- : columnName (COMMA columnName)* -> ^(TOK_TABCOLNAME columnName+)
- ;
-
-columnName
-@init { pushMsg("column name", state); }
-@after { popMsg(state); }
- :
- identifier
- ;
-
-extColumnName
-@init { pushMsg("column name for complex types", state); }
-@after { popMsg(state); }
- :
- identifier (DOT^ ((KW_ELEM_TYPE) => KW_ELEM_TYPE | (KW_KEY_TYPE) => KW_KEY_TYPE | (KW_VALUE_TYPE) => KW_VALUE_TYPE | identifier))*
- ;
-
-columnNameOrderList
-@init { pushMsg("column name order list", state); }
-@after { popMsg(state); }
- : columnNameOrder (COMMA columnNameOrder)* -> ^(TOK_TABCOLNAME columnNameOrder+)
- ;
-
-skewedValueElement
-@init { pushMsg("skewed value element", state); }
-@after { popMsg(state); }
- :
- skewedColumnValues
- | skewedColumnValuePairList
- ;
-
-skewedColumnValuePairList
-@init { pushMsg("column value pair list", state); }
-@after { popMsg(state); }
- : skewedColumnValuePair (COMMA skewedColumnValuePair)* -> ^(TOK_TABCOLVALUE_PAIR skewedColumnValuePair+)
- ;
-
-skewedColumnValuePair
-@init { pushMsg("column value pair", state); }
-@after { popMsg(state); }
- :
- LPAREN colValues=skewedColumnValues RPAREN
- -> ^(TOK_TABCOLVALUES $colValues)
- ;
-
-skewedColumnValues
-@init { pushMsg("column values", state); }
-@after { popMsg(state); }
- : skewedColumnValue (COMMA skewedColumnValue)* -> ^(TOK_TABCOLVALUE skewedColumnValue+)
- ;
-
-skewedColumnValue
-@init { pushMsg("column value", state); }
-@after { popMsg(state); }
- :
- constant
- ;
-
-skewedValueLocationElement
-@init { pushMsg("skewed value location element", state); }
-@after { popMsg(state); }
- :
- skewedColumnValue
- | skewedColumnValuePair
- ;
-
-columnNameOrder
-@init { pushMsg("column name order", state); }
-@after { popMsg(state); }
- : identifier (asc=KW_ASC | desc=KW_DESC)?
- -> {$desc == null}? ^(TOK_TABSORTCOLNAMEASC identifier)
- -> ^(TOK_TABSORTCOLNAMEDESC identifier)
- ;
-
-columnNameCommentList
-@init { pushMsg("column name comment list", state); }
-@after { popMsg(state); }
- : columnNameComment (COMMA columnNameComment)* -> ^(TOK_TABCOLNAME columnNameComment+)
- ;
-
-columnNameComment
-@init { pushMsg("column name comment", state); }
-@after { popMsg(state); }
- : colName=identifier (KW_COMMENT comment=StringLiteral)?
- -> ^(TOK_TABCOL $colName TOK_NULL $comment?)
- ;
-
-columnRefOrder
-@init { pushMsg("column order", state); }
-@after { popMsg(state); }
- : expression (asc=KW_ASC | desc=KW_DESC)?
- -> {$desc == null}? ^(TOK_TABSORTCOLNAMEASC expression)
- -> ^(TOK_TABSORTCOLNAMEDESC expression)
- ;
-
-columnNameType
-@init { pushMsg("column specification", state); }
-@after { popMsg(state); }
- : colName=identifier colType (KW_COMMENT comment=StringLiteral)?
- -> {containExcludedCharForCreateTableColumnName($colName.text)}? {throwColumnNameException()}
- -> {$comment == null}? ^(TOK_TABCOL $colName colType)
- -> ^(TOK_TABCOL $colName colType $comment)
- ;
-
-columnNameColonType
-@init { pushMsg("column specification", state); }
-@after { popMsg(state); }
- : colName=identifier COLON colType (KW_COMMENT comment=StringLiteral)?
- -> {$comment == null}? ^(TOK_TABCOL $colName colType)
- -> ^(TOK_TABCOL $colName colType $comment)
- ;
-
-colType
-@init { pushMsg("column type", state); }
-@after { popMsg(state); }
- : type
- ;
-
-colTypeList
-@init { pushMsg("column type list", state); }
-@after { popMsg(state); }
- : colType (COMMA colType)* -> ^(TOK_COLTYPELIST colType+)
- ;
-
-type
- : primitiveType
- | listType
- | structType
- | mapType
- | unionType;
-
-primitiveType
-@init { pushMsg("primitive type specification", state); }
-@after { popMsg(state); }
- : KW_TINYINT -> TOK_TINYINT
- | KW_SMALLINT -> TOK_SMALLINT
- | KW_INT -> TOK_INT
- | KW_BIGINT -> TOK_BIGINT
- | KW_LONG -> TOK_BIGINT
- | KW_BOOLEAN -> TOK_BOOLEAN
- | KW_FLOAT -> TOK_FLOAT
- | KW_DOUBLE -> TOK_DOUBLE
- | KW_DATE -> TOK_DATE
- | KW_DATETIME -> TOK_DATETIME
- | KW_TIMESTAMP -> TOK_TIMESTAMP
- // Uncomment to allow intervals as table column types
- //| KW_INTERVAL KW_YEAR KW_TO KW_MONTH -> TOK_INTERVAL_YEAR_MONTH
- //| KW_INTERVAL KW_DAY KW_TO KW_SECOND -> TOK_INTERVAL_DAY_TIME
- | KW_STRING -> TOK_STRING
- | KW_BINARY -> TOK_BINARY
- | KW_DECIMAL (LPAREN prec=Number (COMMA scale=Number)? RPAREN)? -> ^(TOK_DECIMAL $prec? $scale?)
- | KW_VARCHAR LPAREN length=Number RPAREN -> ^(TOK_VARCHAR $length)
- | KW_CHAR LPAREN length=Number RPAREN -> ^(TOK_CHAR $length)
- ;
-
-listType
-@init { pushMsg("list type", state); }
-@after { popMsg(state); }
- : KW_ARRAY LESSTHAN type GREATERTHAN -> ^(TOK_LIST type)
- ;
-
-structType
-@init { pushMsg("struct type", state); }
-@after { popMsg(state); }
- : KW_STRUCT LESSTHAN columnNameColonTypeList GREATERTHAN -> ^(TOK_STRUCT columnNameColonTypeList)
- ;
-
-mapType
-@init { pushMsg("map type", state); }
-@after { popMsg(state); }
- : KW_MAP LESSTHAN left=type COMMA right=type GREATERTHAN
- -> ^(TOK_MAP $left $right)
- ;
-
-unionType
-@init { pushMsg("uniontype type", state); }
-@after { popMsg(state); }
- : KW_UNIONTYPE LESSTHAN colTypeList GREATERTHAN -> ^(TOK_UNIONTYPE colTypeList)
- ;
-
-setOperator
-@init { pushMsg("set operator", state); }
-@after { popMsg(state); }
- : KW_UNION KW_ALL -> ^(TOK_UNIONALL)
- | KW_UNION KW_DISTINCT? -> ^(TOK_UNIONDISTINCT)
- | KW_EXCEPT -> ^(TOK_EXCEPT)
- | KW_INTERSECT -> ^(TOK_INTERSECT)
- ;
-
-queryStatementExpression[boolean topLevel]
- :
-    /* It would be nice to do this as a gated semantic predicate,
-       but the predicate gets pushed as a lookahead decision.
-       The calling rule does not know about topLevel.
- */
- (w=withClause {topLevel}?)?
- queryStatementExpressionBody[topLevel] {
- if ($w.tree != null) {
- $queryStatementExpressionBody.tree.insertChild(0, $w.tree);
- }
- }
- -> queryStatementExpressionBody
- ;
-
-queryStatementExpressionBody[boolean topLevel]
- :
- fromStatement[topLevel]
- | regularBody[topLevel]
- ;
-
-withClause
- :
- KW_WITH cteStatement (COMMA cteStatement)* -> ^(TOK_CTE cteStatement+)
-;
-
-cteStatement
- :
- identifier KW_AS LPAREN queryStatementExpression[false] RPAREN
- -> ^(TOK_SUBQUERY queryStatementExpression identifier)
-;
-
-fromStatement[boolean topLevel]
-: (singleFromStatement -> singleFromStatement)
- (u=setOperator r=singleFromStatement
- -> ^($u {$fromStatement.tree} $r)
- )*
- -> {u != null && topLevel}? ^(TOK_QUERY
- ^(TOK_FROM
- ^(TOK_SUBQUERY
- {$fromStatement.tree}
- {adaptor.create(Identifier, generateUnionAlias())}
- )
- )
- ^(TOK_INSERT
- ^(TOK_DESTINATION ^(TOK_DIR TOK_TMP_FILE))
- ^(TOK_SELECT ^(TOK_SELEXPR TOK_ALLCOLREF))
- )
- )
- -> {$fromStatement.tree}
- ;
-
-
-singleFromStatement
- :
- fromClause
- ( b+=body )+ -> ^(TOK_QUERY fromClause body+)
- ;
-
-/*
-The valuesClause rule below ensures that the parse tree for
-"insert into table FOO values (1,2),(3,4)" looks the same as
-"insert into table FOO select a,b from (values(1,2),(3,4)) as BAR(a,b)" which itself is made to look
-very similar to the tree for "insert into table FOO select a,b from BAR". Since virtual table name
-is implicit, it's represented as TOK_ANONYMOUS.
-*/
-regularBody[boolean topLevel]
- :
- i=insertClause
- (
- s=selectStatement[topLevel]
- {$s.tree.getFirstChildWithType(TOK_INSERT).replaceChildren(0, 0, $i.tree);} -> {$s.tree}
- |
- valuesClause
- -> ^(TOK_QUERY
- ^(TOK_FROM
- ^(TOK_VIRTUAL_TABLE ^(TOK_VIRTUAL_TABREF ^(TOK_ANONYMOUS)) valuesClause)
- )
- ^(TOK_INSERT {$i.tree} ^(TOK_SELECT ^(TOK_SELEXPR TOK_ALLCOLREF)))
- )
- )
- |
- selectStatement[topLevel]
- ;
-
-selectStatement[boolean topLevel]
- :
- (
- (
- LPAREN
- s=selectClause
- f=fromClause?
- w=whereClause?
- g=groupByClause?
- h=havingClause?
- o=orderByClause?
- c=clusterByClause?
- d=distributeByClause?
- sort=sortByClause?
- win=window_clause?
- l=limitClause?
- RPAREN
- |
- s=selectClause
- f=fromClause?
- w=whereClause?
- g=groupByClause?
- h=havingClause?
- o=orderByClause?
- c=clusterByClause?
- d=distributeByClause?
- sort=sortByClause?
- win=window_clause?
- l=limitClause?
- )
- -> ^(TOK_QUERY $f? ^(TOK_INSERT ^(TOK_DESTINATION ^(TOK_DIR TOK_TMP_FILE))
- $s $w? $g? $h? $o? $c?
- $d? $sort? $win? $l?))
- )
- (set=setOpSelectStatement[$selectStatement.tree, topLevel])?
- -> {set == null}?
- {$selectStatement.tree}
- -> {o==null && c==null && d==null && sort==null && l==null}?
- {$set.tree}
- -> {throwSetOpException()}
- ;
-
-setOpSelectStatement[CommonTree t, boolean topLevel]
- :
- ((
- u=setOperator LPAREN b=simpleSelectStatement RPAREN
- |
- u=setOperator b=simpleSelectStatement)
- -> {$setOpSelectStatement.tree != null}?
- ^($u {$setOpSelectStatement.tree} $b)
- -> ^($u {$t} $b)
- )+
- o=orderByClause?
- c=clusterByClause?
- d=distributeByClause?
- sort=sortByClause?
- win=window_clause?
- l=limitClause?
- -> {o==null && c==null && d==null && sort==null && win==null && l==null && !topLevel}?
- {$setOpSelectStatement.tree}
- -> ^(TOK_QUERY
- ^(TOK_FROM
- ^(TOK_SUBQUERY
- {$setOpSelectStatement.tree}
- {adaptor.create(Identifier, generateUnionAlias())}
- )
- )
- ^(TOK_INSERT
- ^(TOK_DESTINATION ^(TOK_DIR TOK_TMP_FILE))
- ^(TOK_SELECT ^(TOK_SELEXPR TOK_ALLCOLREF))
- $o? $c? $d? $sort? $win? $l?
- )
- )
- ;
-
-simpleSelectStatement
- :
- selectClause
- fromClause?
- whereClause?
- groupByClause?
- havingClause?
- ((window_clause) => window_clause)?
- -> ^(TOK_QUERY fromClause? ^(TOK_INSERT ^(TOK_DESTINATION ^(TOK_DIR TOK_TMP_FILE))
- selectClause whereClause? groupByClause? havingClause? window_clause?))
- ;
-
-selectStatementWithCTE
- :
- (w=withClause)?
- selectStatement[true] {
- if ($w.tree != null) {
- $selectStatement.tree.insertChild(0, $w.tree);
- }
- }
- -> selectStatement
- ;
-
-body
- :
- insertClause
- selectClause
- lateralView?
- whereClause?
- groupByClause?
- havingClause?
- orderByClause?
- clusterByClause?
- distributeByClause?
- sortByClause?
- window_clause?
- limitClause? -> ^(TOK_INSERT insertClause
- selectClause lateralView? whereClause? groupByClause? havingClause? orderByClause? clusterByClause?
- distributeByClause? sortByClause? window_clause? limitClause?)
- |
- selectClause
- lateralView?
- whereClause?
- groupByClause?
- havingClause?
- orderByClause?
- clusterByClause?
- distributeByClause?
- sortByClause?
- window_clause?
- limitClause? -> ^(TOK_INSERT ^(TOK_DESTINATION ^(TOK_DIR TOK_TMP_FILE))
- selectClause lateralView? whereClause? groupByClause? havingClause? orderByClause? clusterByClause?
- distributeByClause? sortByClause? window_clause? limitClause?)
- ;
-
-insertClause
-@init { pushMsg("insert clause", state); }
-@after { popMsg(state); }
- :
- KW_INSERT KW_OVERWRITE destination ifNotExists? -> ^(TOK_DESTINATION destination ifNotExists?)
- | KW_INSERT KW_INTO KW_TABLE? tableOrPartition (LPAREN targetCols=columnNameList RPAREN)?
- -> ^(TOK_INSERT_INTO tableOrPartition $targetCols?)
- ;
-
-destination
-@init { pushMsg("destination specification", state); }
-@after { popMsg(state); }
- :
- (local = KW_LOCAL)? KW_DIRECTORY StringLiteral tableRowFormat? tableFileFormat?
- -> ^(TOK_DIR StringLiteral $local? tableRowFormat? tableFileFormat?)
- | KW_TABLE tableOrPartition -> tableOrPartition
- ;
-
-limitClause
-@init { pushMsg("limit clause", state); }
-@after { popMsg(state); }
- :
- KW_LIMIT num=Number -> ^(TOK_LIMIT $num)
- ;
-
-//DELETE FROM <tableName> WHERE ...;
-deleteStatement
-@init { pushMsg("delete statement", state); }
-@after { popMsg(state); }
- :
- KW_DELETE KW_FROM tableName (whereClause)? -> ^(TOK_DELETE_FROM tableName whereClause?)
- ;
-
-/*SET <columnName> = (3 + col2)*/
-columnAssignmentClause
- :
- tableOrColumn EQUAL^ precedencePlusExpression
- ;
-
-/*SET col1 = 5, col2 = (4 + col4), ...*/
-setColumnsClause
- :
- KW_SET columnAssignmentClause (COMMA columnAssignmentClause)* -> ^(TOK_SET_COLUMNS_CLAUSE columnAssignmentClause* )
- ;
-
-/*
- UPDATE <table>
- SET col1 = val1, col2 = val2... WHERE ...
-*/
-updateStatement
-@init { pushMsg("update statement", state); }
-@after { popMsg(state); }
- :
- KW_UPDATE tableName setColumnsClause whereClause? -> ^(TOK_UPDATE_TABLE tableName setColumnsClause whereClause?)
- ;
-
-/*
-BEGIN user defined transaction boundaries; follows the SQL 2003 standard exactly except for the addition of
-"setAutoCommitStatement", which is not in the standard doc but is supported by most SQL engines.
-*/
-sqlTransactionStatement
-@init { pushMsg("transaction statement", state); }
-@after { popMsg(state); }
- : startTransactionStatement
- | commitStatement
- | rollbackStatement
- | setAutoCommitStatement
- ;
-
-startTransactionStatement
- :
- KW_START KW_TRANSACTION ( transactionMode ( COMMA transactionMode )* )? -> ^(TOK_START_TRANSACTION transactionMode*)
- ;
-
-transactionMode
- :
- isolationLevel
- | transactionAccessMode -> ^(TOK_TXN_ACCESS_MODE transactionAccessMode)
- ;
-
-transactionAccessMode
- :
- KW_READ KW_ONLY -> TOK_TXN_READ_ONLY
- | KW_READ KW_WRITE -> TOK_TXN_READ_WRITE
- ;
-
-isolationLevel
- :
- KW_ISOLATION KW_LEVEL levelOfIsolation -> ^(TOK_ISOLATION_LEVEL levelOfIsolation)
- ;
-
-/*READ UNCOMMITTED | READ COMMITTED | REPEATABLE READ | SERIALIZABLE may be supported later*/
-levelOfIsolation
- :
- KW_SNAPSHOT -> TOK_ISOLATION_SNAPSHOT
- ;
-
-commitStatement
- :
- KW_COMMIT ( KW_WORK )? -> TOK_COMMIT
- ;
-
-rollbackStatement
- :
- KW_ROLLBACK ( KW_WORK )? -> TOK_ROLLBACK
- ;
-setAutoCommitStatement
- :
- KW_SET KW_AUTOCOMMIT booleanValueTok -> ^(TOK_SET_AUTOCOMMIT booleanValueTok)
- ;
-/*
-END user defined transaction boundaries
-*/
-
-/*
-Table Caching statements.
- */
-cacheStatement
-@init { pushMsg("cache statement", state); }
-@after { popMsg(state); }
- :
- cacheTableStatement
- | uncacheTableStatement
- | clearCacheStatement
- ;
-
-cacheTableStatement
- :
- KW_CACHE (lazy=KW_LAZY)? KW_TABLE identifier (KW_AS selectStatementWithCTE)? -> ^(TOK_CACHETABLE identifier $lazy? selectStatementWithCTE?)
- ;
-
-uncacheTableStatement
- :
- KW_UNCACHE KW_TABLE identifier -> ^(TOK_UNCACHETABLE identifier)
- ;
-
-clearCacheStatement
- :
- KW_CLEAR KW_CACHE -> ^(TOK_CLEARCACHE)
- ;
-
diff --git a/sql/catalyst/src/main/antlr4/org/apache/spark/sql/catalyst/parser/SqlBase.g4 b/sql/catalyst/src/main/antlr4/org/apache/spark/sql/catalyst/parser/SqlBase.g4
new file mode 100644
index 0000000000..9cf2dd257e
--- /dev/null
+++ b/sql/catalyst/src/main/antlr4/org/apache/spark/sql/catalyst/parser/SqlBase.g4
@@ -0,0 +1,957 @@
+/*
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ *
+ * This file is an adaptation of Presto's presto-parser/src/main/antlr4/com/facebook/presto/sql/parser/SqlBase.g4 grammar.
+ */
+
+grammar SqlBase;
+
+tokens {
+ DELIMITER
+}
+
+singleStatement
+ : statement EOF
+ ;
+
+singleExpression
+ : namedExpression EOF
+ ;
+
+singleTableIdentifier
+ : tableIdentifier EOF
+ ;
+
+singleDataType
+ : dataType EOF
+ ;
+
+statement
+ : query #statementDefault
+ | USE db=identifier #use
+ | CREATE DATABASE (IF NOT EXISTS)? identifier
+ (COMMENT comment=STRING)? locationSpec?
+ (WITH DBPROPERTIES tablePropertyList)? #createDatabase
+ | ALTER DATABASE identifier SET DBPROPERTIES tablePropertyList #setDatabaseProperties
+ | DROP DATABASE (IF EXISTS)? identifier (RESTRICT | CASCADE)? #dropDatabase
+ | createTableHeader ('(' colTypeList ')')? tableProvider
+ (OPTIONS tablePropertyList)? #createTableUsing
+ | createTableHeader tableProvider
+ (OPTIONS tablePropertyList)? AS? query #createTableUsing
+ | createTableHeader ('(' columns=colTypeList ')')?
+ (COMMENT STRING)?
+ (PARTITIONED BY '(' partitionColumns=colTypeList ')')?
+ bucketSpec? skewSpec?
+ rowFormat? createFileFormat? locationSpec?
+ (TBLPROPERTIES tablePropertyList)?
+ (AS? query)? #createTable
+ | CREATE TABLE (IF NOT EXISTS)? target=tableIdentifier
+ LIKE source=tableIdentifier #createTableLike
+ | ANALYZE TABLE tableIdentifier partitionSpec? COMPUTE STATISTICS
+ (identifier | FOR COLUMNS identifierSeq?)? #analyze
+ | ALTER (TABLE | VIEW) from=tableIdentifier
+ RENAME TO to=tableIdentifier #renameTable
+ | ALTER (TABLE | VIEW) tableIdentifier
+ SET TBLPROPERTIES tablePropertyList #setTableProperties
+ | ALTER (TABLE | VIEW) tableIdentifier
+ UNSET TBLPROPERTIES (IF EXISTS)? tablePropertyList #unsetTableProperties
+ | ALTER TABLE tableIdentifier (partitionSpec)?
+ SET SERDE STRING (WITH SERDEPROPERTIES tablePropertyList)? #setTableSerDe
+ | ALTER TABLE tableIdentifier (partitionSpec)?
+ SET SERDEPROPERTIES tablePropertyList #setTableSerDe
+ | ALTER TABLE tableIdentifier bucketSpec #bucketTable
+ | ALTER TABLE tableIdentifier NOT CLUSTERED #unclusterTable
+ | ALTER TABLE tableIdentifier NOT SORTED #unsortTable
+ | ALTER TABLE tableIdentifier skewSpec #skewTable
+ | ALTER TABLE tableIdentifier NOT SKEWED #unskewTable
+ | ALTER TABLE tableIdentifier NOT STORED AS DIRECTORIES #unstoreTable
+ | ALTER TABLE tableIdentifier
+ SET SKEWED LOCATION skewedLocationList #setTableSkewLocations
+ | ALTER TABLE tableIdentifier ADD (IF NOT EXISTS)?
+ partitionSpecLocation+ #addTablePartition
+ | ALTER VIEW tableIdentifier ADD (IF NOT EXISTS)?
+ partitionSpec+ #addTablePartition
+ | ALTER TABLE tableIdentifier
+ from=partitionSpec RENAME TO to=partitionSpec #renameTablePartition
+ | ALTER TABLE from=tableIdentifier
+ EXCHANGE partitionSpec WITH TABLE to=tableIdentifier #exchangeTablePartition
+ | ALTER TABLE tableIdentifier
+ DROP (IF EXISTS)? partitionSpec (',' partitionSpec)* PURGE? #dropTablePartitions
+ | ALTER VIEW tableIdentifier
+ DROP (IF EXISTS)? partitionSpec (',' partitionSpec)* #dropTablePartitions
+ | ALTER TABLE tableIdentifier ARCHIVE partitionSpec #archiveTablePartition
+ | ALTER TABLE tableIdentifier UNARCHIVE partitionSpec #unarchiveTablePartition
+ | ALTER TABLE tableIdentifier partitionSpec?
+ SET FILEFORMAT fileFormat #setTableFileFormat
+ | ALTER TABLE tableIdentifier partitionSpec? SET locationSpec #setTableLocation
+ | ALTER TABLE tableIdentifier TOUCH partitionSpec? #touchTable
+ | ALTER TABLE tableIdentifier partitionSpec? COMPACT STRING #compactTable
+ | ALTER TABLE tableIdentifier partitionSpec? CONCATENATE #concatenateTable
+ | ALTER TABLE tableIdentifier partitionSpec?
+ CHANGE COLUMN? oldName=identifier colType
+ (FIRST | AFTER after=identifier)? (CASCADE | RESTRICT)? #changeColumn
+ | ALTER TABLE tableIdentifier partitionSpec?
+ ADD COLUMNS '(' colTypeList ')' (CASCADE | RESTRICT)? #addColumns
+ | ALTER TABLE tableIdentifier partitionSpec?
+ REPLACE COLUMNS '(' colTypeList ')' (CASCADE | RESTRICT)? #replaceColumns
+ | DROP TABLE (IF EXISTS)? tableIdentifier PURGE?
+ (FOR METADATA? REPLICATION '(' STRING ')')? #dropTable
+ | DROP VIEW (IF EXISTS)? tableIdentifier #dropTable
+ | CREATE (OR REPLACE)? VIEW (IF NOT EXISTS)? tableIdentifier
+ identifierCommentList? (COMMENT STRING)?
+ (PARTITIONED ON identifierList)?
+ (TBLPROPERTIES tablePropertyList)? AS query #createView
+ | ALTER VIEW tableIdentifier AS? query #alterViewQuery
+ | CREATE TEMPORARY? FUNCTION qualifiedName AS className=STRING
+ (USING resource (',' resource)*)? #createFunction
+ | DROP TEMPORARY? FUNCTION (IF EXISTS)? qualifiedName #dropFunction
+ | EXPLAIN explainOption* statement #explain
+ | SHOW TABLES ((FROM | IN) db=identifier)?
+ (LIKE? pattern=STRING)? #showTables
+ | SHOW DATABASES (LIKE pattern=STRING)? #showDatabases
+ | SHOW TBLPROPERTIES table=tableIdentifier
+ ('(' key=tablePropertyKey ')')? #showTblProperties
+ | SHOW FUNCTIONS (LIKE? (qualifiedName | pattern=STRING))? #showFunctions
+ | (DESC | DESCRIBE) FUNCTION EXTENDED? qualifiedName #describeFunction
+ | (DESC | DESCRIBE) option=(EXTENDED | FORMATTED)?
+ tableIdentifier partitionSpec? describeColName? #describeTable
+ | (DESC | DESCRIBE) DATABASE EXTENDED? identifier #describeDatabase
+ | REFRESH TABLE tableIdentifier #refreshTable
+ | CACHE LAZY? TABLE identifier (AS? query)? #cacheTable
+ | UNCACHE TABLE identifier #uncacheTable
+ | CLEAR CACHE #clearCache
+ | ADD identifier .*? #addResource
+ | SET ROLE .*? #failNativeCommand
+ | SET .*? #setConfiguration
+ | kws=unsupportedHiveNativeCommands .*? #failNativeCommand
+ | hiveNativeCommands #executeNativeCommand
+ ;
+
+hiveNativeCommands
+ : DELETE FROM tableIdentifier (WHERE booleanExpression)?
+ | TRUNCATE TABLE tableIdentifier partitionSpec?
+ (COLUMNS identifierList)?
+ | SHOW COLUMNS (FROM | IN) tableIdentifier ((FROM|IN) identifier)?
+ | START TRANSACTION (transactionMode (',' transactionMode)*)?
+ | COMMIT WORK?
+ | ROLLBACK WORK?
+ | SHOW PARTITIONS tableIdentifier partitionSpec?
+ | DFS .*?
+ | (CREATE | ALTER | DROP | SHOW | DESC | DESCRIBE | LOAD) .*?
+ ;
+
+unsupportedHiveNativeCommands
+ : kw1=CREATE kw2=ROLE
+ | kw1=DROP kw2=ROLE
+ | kw1=GRANT kw2=ROLE?
+ | kw1=REVOKE kw2=ROLE?
+ | kw1=SHOW kw2=GRANT
+ | kw1=SHOW kw2=ROLE kw3=GRANT?
+ | kw1=SHOW kw2=PRINCIPALS
+ | kw1=SHOW kw2=ROLES
+ | kw1=SHOW kw2=CURRENT kw3=ROLES
+ | kw1=EXPORT kw2=TABLE
+ | kw1=IMPORT kw2=TABLE
+ | kw1=SHOW kw2=COMPACTIONS
+ | kw1=SHOW kw2=CREATE kw3=TABLE
+ | kw1=SHOW kw2=TRANSACTIONS
+ | kw1=SHOW kw2=INDEXES
+ | kw1=SHOW kw2=LOCKS
+ | kw1=CREATE kw2=INDEX
+ | kw1=DROP kw2=INDEX
+ | kw1=ALTER kw2=INDEX
+ | kw1=LOCK kw2=TABLE
+ | kw1=LOCK kw2=DATABASE
+ | kw1=UNLOCK kw2=TABLE
+ | kw1=UNLOCK kw2=DATABASE
+ | kw1=CREATE kw2=TEMPORARY kw3=MACRO
+ | kw1=DROP kw2=TEMPORARY kw3=MACRO
+ | kw1=MSCK kw2=REPAIR kw3=TABLE
+ ;
+
+createTableHeader
+ : CREATE TEMPORARY? EXTERNAL? TABLE (IF NOT EXISTS)? tableIdentifier
+ ;
+
+bucketSpec
+ : CLUSTERED BY identifierList
+ (SORTED BY orderedIdentifierList)?
+ INTO INTEGER_VALUE BUCKETS
+ ;
+
+skewSpec
+ : SKEWED BY identifierList
+ ON (constantList | nestedConstantList)
+ (STORED AS DIRECTORIES)?
+ ;
+
+locationSpec
+ : LOCATION STRING
+ ;
+
+query
+ : ctes? queryNoWith
+ ;
+
+insertInto
+ : INSERT OVERWRITE TABLE tableIdentifier partitionSpec? (IF NOT EXISTS)?
+ | INSERT INTO TABLE? tableIdentifier partitionSpec?
+ ;
+
+partitionSpecLocation
+ : partitionSpec locationSpec?
+ ;
+
+partitionSpec
+ : PARTITION '(' partitionVal (',' partitionVal)* ')'
+ ;
+
+partitionVal
+ : identifier (EQ constant)?
+ ;
+
+describeColName
+ : identifier ('.' (identifier | STRING))*
+ ;
+
+ctes
+ : WITH namedQuery (',' namedQuery)*
+ ;
+
+namedQuery
+ : name=identifier AS? '(' queryNoWith ')'
+ ;
+
+tableProvider
+ : USING qualifiedName
+ ;
+
+tablePropertyList
+ : '(' tableProperty (',' tableProperty)* ')'
+ ;
+
+tableProperty
+ : key=tablePropertyKey (EQ? value=STRING)?
+ ;
+
+tablePropertyKey
+ : looseIdentifier ('.' looseIdentifier)*
+ | STRING
+ ;
+
+constantList
+ : '(' constant (',' constant)* ')'
+ ;
+
+nestedConstantList
+ : '(' constantList (',' constantList)* ')'
+ ;
+
+skewedLocation
+ : (constant | constantList) EQ STRING
+ ;
+
+skewedLocationList
+ : '(' skewedLocation (',' skewedLocation)* ')'
+ ;
+
+createFileFormat
+ : STORED AS fileFormat
+ | STORED BY storageHandler
+ ;
+
+fileFormat
+ : INPUTFORMAT inFmt=STRING OUTPUTFORMAT outFmt=STRING (SERDE serdeCls=STRING)? #tableFileFormat
+ | identifier #genericFileFormat
+ ;
+
+storageHandler
+ : STRING (WITH SERDEPROPERTIES tablePropertyList)?
+ ;
+
+resource
+ : identifier STRING
+ ;
+
+queryNoWith
+ : insertInto? queryTerm queryOrganization #singleInsertQuery
+ | fromClause multiInsertQueryBody+ #multiInsertQuery
+ ;
+
+queryOrganization
+ : (ORDER BY order+=sortItem (',' order+=sortItem)*)?
+ (CLUSTER BY clusterBy+=expression (',' clusterBy+=expression)*)?
+ (DISTRIBUTE BY distributeBy+=expression (',' distributeBy+=expression)*)?
+ (SORT BY sort+=sortItem (',' sort+=sortItem)*)?
+ windows?
+ (LIMIT limit=expression)?
+ ;
+
+multiInsertQueryBody
+ : insertInto?
+ querySpecification
+ queryOrganization
+ ;
+
+queryTerm
+ : queryPrimary #queryTermDefault
+ | left=queryTerm operator=(INTERSECT | UNION | EXCEPT) setQuantifier? right=queryTerm #setOperation
+ ;
+
+queryPrimary
+ : querySpecification #queryPrimaryDefault
+ | TABLE tableIdentifier #table
+ | inlineTable #inlineTableDefault1
+ | '(' queryNoWith ')' #subquery
+ ;
+
+sortItem
+ : expression ordering=(ASC | DESC)?
+ ;
+
+querySpecification
+ : (((SELECT kind=TRANSFORM '(' namedExpressionSeq ')'
+ | kind=MAP namedExpressionSeq
+ | kind=REDUCE namedExpressionSeq))
+ inRowFormat=rowFormat?
+ (RECORDWRITER recordWriter=STRING)?
+ USING script=STRING
+ (AS (identifierSeq | colTypeList | ('(' (identifierSeq | colTypeList) ')')))?
+ outRowFormat=rowFormat?
+ (RECORDREADER recordReader=STRING)?
+ fromClause?
+ (WHERE where=booleanExpression)?)
+ | ((kind=SELECT setQuantifier? namedExpressionSeq fromClause?
+ | fromClause (kind=SELECT setQuantifier? namedExpressionSeq)?)
+ lateralView*
+ (WHERE where=booleanExpression)?
+ aggregation?
+ (HAVING having=booleanExpression)?
+ windows?)
+ ;
+
+fromClause
+ : FROM relation (',' relation)* lateralView*
+ ;
+
+aggregation
+ : GROUP BY groupingExpressions+=expression (',' groupingExpressions+=expression)* (
+ WITH kind=ROLLUP
+ | WITH kind=CUBE
+ | kind=GROUPING SETS '(' groupingSet (',' groupingSet)* ')')?
+ ;
+
+groupingSet
+ : '(' (expression (',' expression)*)? ')'
+ | expression
+ ;
+
+lateralView
+ : LATERAL VIEW (OUTER)? qualifiedName '(' (expression (',' expression)*)? ')' tblName=identifier (AS? colName+=identifier (',' colName+=identifier)*)?
+ ;
+
+setQuantifier
+ : DISTINCT
+ | ALL
+ ;
+
+relation
+ : left=relation
+ ((CROSS | joinType) JOIN right=relation joinCriteria?
+ | NATURAL joinType JOIN right=relation
+ ) #joinRelation
+ | relationPrimary #relationDefault
+ ;
+
+joinType
+ : INNER?
+ | LEFT OUTER?
+ | LEFT SEMI
+ | RIGHT OUTER?
+ | FULL OUTER?
+ | LEFT? ANTI
+ ;
+
+joinCriteria
+ : ON booleanExpression
+ | USING '(' identifier (',' identifier)* ')'
+ ;
+
+sample
+ : TABLESAMPLE '('
+ ( (percentage=(INTEGER_VALUE | DECIMAL_VALUE) sampleType=PERCENTLIT)
+ | (expression sampleType=ROWS)
+ | (sampleType=BUCKET numerator=INTEGER_VALUE OUT OF denominator=INTEGER_VALUE (ON identifier)?))
+ ')'
+ ;
+
+identifierList
+ : '(' identifierSeq ')'
+ ;
+
+identifierSeq
+ : identifier (',' identifier)*
+ ;
+
+orderedIdentifierList
+ : '(' orderedIdentifier (',' orderedIdentifier)* ')'
+ ;
+
+orderedIdentifier
+ : identifier ordering=(ASC | DESC)?
+ ;
+
+identifierCommentList
+ : '(' identifierComment (',' identifierComment)* ')'
+ ;
+
+identifierComment
+ : identifier (COMMENT STRING)?
+ ;
+
+relationPrimary
+ : tableIdentifier sample? (AS? identifier)? #tableName
+ | '(' queryNoWith ')' sample? (AS? identifier)? #aliasedQuery
+ | '(' relation ')' sample? (AS? identifier)? #aliasedRelation
+ | inlineTable #inlineTableDefault2
+ ;
+
+inlineTable
+ : VALUES expression (',' expression)* (AS? identifier identifierList?)?
+ ;
+
+rowFormat
+ : ROW FORMAT SERDE name=STRING (WITH SERDEPROPERTIES props=tablePropertyList)? #rowFormatSerde
+ | ROW FORMAT DELIMITED
+ (FIELDS TERMINATED BY fieldsTerminatedBy=STRING (ESCAPED BY escapedBy=STRING)?)?
+ (COLLECTION ITEMS TERMINATED BY collectionItemsTerminatedBy=STRING)?
+ (MAP KEYS TERMINATED BY keysTerminatedBy=STRING)?
+ (LINES TERMINATED BY linesSeparatedBy=STRING)?
+ (NULL DEFINED AS nullDefinedAs=STRING)? #rowFormatDelimited
+ ;
+
+tableIdentifier
+ : (db=identifier '.')? table=identifier
+ ;
+
+namedExpression
+ : expression (AS? (identifier | identifierList))?
+ ;
+
+namedExpressionSeq
+ : namedExpression (',' namedExpression)*
+ ;
+
+expression
+ : booleanExpression
+ ;
+
+booleanExpression
+ : predicated #booleanDefault
+ | NOT booleanExpression #logicalNot
+ | left=booleanExpression operator=AND right=booleanExpression #logicalBinary
+ | left=booleanExpression operator=OR right=booleanExpression #logicalBinary
+ | EXISTS '(' query ')' #exists
+ ;
+
+// workaround for:
+// https://github.com/antlr/antlr4/issues/780
+// https://github.com/antlr/antlr4/issues/781
+predicated
+ : valueExpression predicate?
+ ;
+
+predicate
+ : NOT? kind=BETWEEN lower=valueExpression AND upper=valueExpression
+ | NOT? kind=IN '(' expression (',' expression)* ')'
+ | NOT? kind=IN '(' query ')'
+ | NOT? kind=(RLIKE | LIKE) pattern=valueExpression
+ | IS NOT? kind=NULL
+ ;
+
+valueExpression
+ : primaryExpression #valueExpressionDefault
+ | operator=(MINUS | PLUS | TILDE) valueExpression #arithmeticUnary
+ | left=valueExpression operator=(ASTERISK | SLASH | PERCENT | DIV) right=valueExpression #arithmeticBinary
+ | left=valueExpression operator=(PLUS | MINUS) right=valueExpression #arithmeticBinary
+ | left=valueExpression operator=AMPERSAND right=valueExpression #arithmeticBinary
+ | left=valueExpression operator=HAT right=valueExpression #arithmeticBinary
+ | left=valueExpression operator=PIPE right=valueExpression #arithmeticBinary
+ | left=valueExpression comparisonOperator right=valueExpression #comparison
+ ;
+
+primaryExpression
+ : constant #constantDefault
+ | ASTERISK #star
+ | qualifiedName '.' ASTERISK #star
+ | '(' expression (',' expression)+ ')' #rowConstructor
+ | qualifiedName '(' (setQuantifier? expression (',' expression)*)? ')' (OVER windowSpec)? #functionCall
+ | '(' query ')' #subqueryExpression
+ | CASE valueExpression whenClause+ (ELSE elseExpression=expression)? END #simpleCase
+ | CASE whenClause+ (ELSE elseExpression=expression)? END #searchedCase
+ | CAST '(' expression AS dataType ')' #cast
+ | value=primaryExpression '[' index=valueExpression ']' #subscript
+ | identifier #columnReference
+ | base=primaryExpression '.' fieldName=identifier #dereference
+ | '(' expression ')' #parenthesizedExpression
+ ;
+
+constant
+ : NULL #nullLiteral
+ | interval #intervalLiteral
+ | identifier STRING #typeConstructor
+ | number #numericLiteral
+ | booleanValue #booleanLiteral
+ | STRING+ #stringLiteral
+ ;
+
+comparisonOperator
+ : EQ | NEQ | NEQJ | LT | LTE | GT | GTE | NSEQ
+ ;
+
+booleanValue
+ : TRUE | FALSE
+ ;
+
+interval
+ : INTERVAL intervalField*
+ ;
+
+intervalField
+ : value=intervalValue unit=identifier (TO to=identifier)?
+ ;
+
+intervalValue
+ : (PLUS | MINUS)? (INTEGER_VALUE | DECIMAL_VALUE)
+ | STRING
+ ;
+
+dataType
+ : complex=ARRAY '<' dataType '>' #complexDataType
+ | complex=MAP '<' dataType ',' dataType '>' #complexDataType
+ | complex=STRUCT ('<' colTypeList? '>' | NEQ) #complexDataType
+ | identifier ('(' INTEGER_VALUE (',' INTEGER_VALUE)* ')')? #primitiveDataType
+ ;
+
+colTypeList
+ : colType (',' colType)*
+ ;
+
+colType
+ : identifier ':'? dataType (COMMENT STRING)?
+ ;
+
+whenClause
+ : WHEN condition=expression THEN result=expression
+ ;
+
+windows
+ : WINDOW namedWindow (',' namedWindow)*
+ ;
+
+namedWindow
+ : identifier AS windowSpec
+ ;
+
+windowSpec
+ : name=identifier #windowRef
+ | '('
+ ( CLUSTER BY partition+=expression (',' partition+=expression)*
+ | ((PARTITION | DISTRIBUTE) BY partition+=expression (',' partition+=expression)*)?
+ ((ORDER | SORT) BY sortItem (',' sortItem)*)?)
+ windowFrame?
+ ')' #windowDef
+ ;
+
+windowFrame
+ : frameType=RANGE start=frameBound
+ | frameType=ROWS start=frameBound
+ | frameType=RANGE BETWEEN start=frameBound AND end=frameBound
+ | frameType=ROWS BETWEEN start=frameBound AND end=frameBound
+ ;
+
+frameBound
+ : UNBOUNDED boundType=(PRECEDING | FOLLOWING)
+ | boundType=CURRENT ROW
+ | expression boundType=(PRECEDING | FOLLOWING)
+ ;
+
+
+explainOption
+ : LOGICAL | FORMATTED | EXTENDED | CODEGEN
+ ;
+
+transactionMode
+ : ISOLATION LEVEL SNAPSHOT #isolationLevel
+ | READ accessMode=(ONLY | WRITE) #transactionAccessMode
+ ;
+
+qualifiedName
+ : identifier ('.' identifier)*
+ ;
+
+// Identifier that also allows the use of a number of SQL keywords (mainly for backwards compatibility).
+looseIdentifier
+ : identifier
+ | FROM
+ | TO
+ | TABLE
+ | WITH
+ ;
+
+identifier
+ : IDENTIFIER #unquotedIdentifier
+ | quotedIdentifier #quotedIdentifierAlternative
+ | nonReserved #unquotedIdentifier
+ ;
+
+quotedIdentifier
+ : BACKQUOTED_IDENTIFIER
+ ;
+
+number
+ : DECIMAL_VALUE #decimalLiteral
+ | SCIENTIFIC_DECIMAL_VALUE #scientificDecimalLiteral
+ | INTEGER_VALUE #integerLiteral
+ | BIGINT_LITERAL #bigIntLiteral
+ | SMALLINT_LITERAL #smallIntLiteral
+ | TINYINT_LITERAL #tinyIntLiteral
+ | DOUBLE_LITERAL #doubleLiteral
+ ;
+
+nonReserved
+ : SHOW | TABLES | COLUMNS | COLUMN | PARTITIONS | FUNCTIONS | DATABASES
+ | ADD
+ | OVER | PARTITION | RANGE | ROWS | PRECEDING | FOLLOWING | CURRENT | ROW | MAP | ARRAY | STRUCT
+ | LATERAL | WINDOW | REDUCE | TRANSFORM | USING | SERDE | SERDEPROPERTIES | RECORDREADER
+ | DELIMITED | FIELDS | TERMINATED | COLLECTION | ITEMS | KEYS | ESCAPED | LINES | SEPARATED
+ | EXTENDED | REFRESH | CLEAR | CACHE | UNCACHE | LAZY | TEMPORARY | OPTIONS
+ | GROUPING | CUBE | ROLLUP
+ | EXPLAIN | FORMAT | LOGICAL | FORMATTED | CODEGEN
+ | TABLESAMPLE | USE | TO | BUCKET | PERCENTLIT | OUT | OF
+ | SET
+ | VIEW | REPLACE
+ | IF
+ | NO | DATA
+ | START | TRANSACTION | COMMIT | ROLLBACK | WORK | ISOLATION | LEVEL
+ | SNAPSHOT | READ | WRITE | ONLY
+ | SORT | CLUSTER | DISTRIBUTE | UNSET | TBLPROPERTIES | SKEWED | STORED | DIRECTORIES | LOCATION
+ | EXCHANGE | ARCHIVE | UNARCHIVE | FILEFORMAT | TOUCH | COMPACT | CONCATENATE | CHANGE | FIRST
+ | AFTER | CASCADE | RESTRICT | BUCKETS | CLUSTERED | SORTED | PURGE | INPUTFORMAT | OUTPUTFORMAT
+ | INPUTDRIVER | OUTPUTDRIVER | DBPROPERTIES | DFS | TRUNCATE | METADATA | REPLICATION | COMPUTE
+ | STATISTICS | ANALYZE | PARTITIONED | EXTERNAL | DEFINED | RECORDWRITER
+ | REVOKE | GRANT | LOCK | UNLOCK | MSCK | REPAIR | EXPORT | IMPORT | LOAD | VALUES | COMMENT | ROLE
+ | ROLES | COMPACTIONS | PRINCIPALS | TRANSACTIONS | INDEX | INDEXES | LOCKS | OPTION
+ ;
+
+SELECT: 'SELECT';
+FROM: 'FROM';
+ADD: 'ADD';
+AS: 'AS';
+ALL: 'ALL';
+DISTINCT: 'DISTINCT';
+WHERE: 'WHERE';
+GROUP: 'GROUP';
+BY: 'BY';
+GROUPING: 'GROUPING';
+SETS: 'SETS';
+CUBE: 'CUBE';
+ROLLUP: 'ROLLUP';
+ORDER: 'ORDER';
+HAVING: 'HAVING';
+LIMIT: 'LIMIT';
+AT: 'AT';
+OR: 'OR';
+AND: 'AND';
+IN: 'IN';
+NOT: 'NOT' | '!';
+NO: 'NO';
+EXISTS: 'EXISTS';
+BETWEEN: 'BETWEEN';
+LIKE: 'LIKE';
+RLIKE: 'RLIKE' | 'REGEXP';
+IS: 'IS';
+NULL: 'NULL';
+TRUE: 'TRUE';
+FALSE: 'FALSE';
+NULLS: 'NULLS';
+ASC: 'ASC';
+DESC: 'DESC';
+FOR: 'FOR';
+INTERVAL: 'INTERVAL';
+CASE: 'CASE';
+WHEN: 'WHEN';
+THEN: 'THEN';
+ELSE: 'ELSE';
+END: 'END';
+JOIN: 'JOIN';
+CROSS: 'CROSS';
+OUTER: 'OUTER';
+INNER: 'INNER';
+LEFT: 'LEFT';
+SEMI: 'SEMI';
+RIGHT: 'RIGHT';
+FULL: 'FULL';
+NATURAL: 'NATURAL';
+ON: 'ON';
+LATERAL: 'LATERAL';
+WINDOW: 'WINDOW';
+OVER: 'OVER';
+PARTITION: 'PARTITION';
+RANGE: 'RANGE';
+ROWS: 'ROWS';
+UNBOUNDED: 'UNBOUNDED';
+PRECEDING: 'PRECEDING';
+FOLLOWING: 'FOLLOWING';
+CURRENT: 'CURRENT';
+ROW: 'ROW';
+WITH: 'WITH';
+VALUES: 'VALUES';
+CREATE: 'CREATE';
+TABLE: 'TABLE';
+VIEW: 'VIEW';
+REPLACE: 'REPLACE';
+INSERT: 'INSERT';
+DELETE: 'DELETE';
+INTO: 'INTO';
+DESCRIBE: 'DESCRIBE';
+EXPLAIN: 'EXPLAIN';
+FORMAT: 'FORMAT';
+LOGICAL: 'LOGICAL';
+CODEGEN: 'CODEGEN';
+CAST: 'CAST';
+SHOW: 'SHOW';
+TABLES: 'TABLES';
+COLUMNS: 'COLUMNS';
+COLUMN: 'COLUMN';
+USE: 'USE';
+PARTITIONS: 'PARTITIONS';
+FUNCTIONS: 'FUNCTIONS';
+DROP: 'DROP';
+UNION: 'UNION';
+EXCEPT: 'EXCEPT';
+INTERSECT: 'INTERSECT';
+TO: 'TO';
+TABLESAMPLE: 'TABLESAMPLE';
+STRATIFY: 'STRATIFY';
+ALTER: 'ALTER';
+RENAME: 'RENAME';
+ARRAY: 'ARRAY';
+MAP: 'MAP';
+STRUCT: 'STRUCT';
+COMMENT: 'COMMENT';
+SET: 'SET';
+DATA: 'DATA';
+START: 'START';
+TRANSACTION: 'TRANSACTION';
+COMMIT: 'COMMIT';
+ROLLBACK: 'ROLLBACK';
+WORK: 'WORK';
+ISOLATION: 'ISOLATION';
+LEVEL: 'LEVEL';
+SNAPSHOT: 'SNAPSHOT';
+READ: 'READ';
+WRITE: 'WRITE';
+ONLY: 'ONLY';
+MACRO: 'MACRO';
+
+IF: 'IF';
+
+EQ : '=' | '==';
+NSEQ: '<=>';
+NEQ : '<>';
+NEQJ: '!=';
+LT : '<';
+LTE : '<=';
+GT : '>';
+GTE : '>=';
+
+PLUS: '+';
+MINUS: '-';
+ASTERISK: '*';
+SLASH: '/';
+PERCENT: '%';
+DIV: 'DIV';
+TILDE: '~';
+AMPERSAND: '&';
+PIPE: '|';
+HAT: '^';
+
+PERCENTLIT: 'PERCENT';
+BUCKET: 'BUCKET';
+OUT: 'OUT';
+OF: 'OF';
+
+SORT: 'SORT';
+CLUSTER: 'CLUSTER';
+DISTRIBUTE: 'DISTRIBUTE';
+OVERWRITE: 'OVERWRITE';
+TRANSFORM: 'TRANSFORM';
+REDUCE: 'REDUCE';
+USING: 'USING';
+SERDE: 'SERDE';
+SERDEPROPERTIES: 'SERDEPROPERTIES';
+RECORDREADER: 'RECORDREADER';
+RECORDWRITER: 'RECORDWRITER';
+DELIMITED: 'DELIMITED';
+FIELDS: 'FIELDS';
+TERMINATED: 'TERMINATED';
+COLLECTION: 'COLLECTION';
+ITEMS: 'ITEMS';
+KEYS: 'KEYS';
+ESCAPED: 'ESCAPED';
+LINES: 'LINES';
+SEPARATED: 'SEPARATED';
+FUNCTION: 'FUNCTION';
+EXTENDED: 'EXTENDED';
+REFRESH: 'REFRESH';
+CLEAR: 'CLEAR';
+CACHE: 'CACHE';
+UNCACHE: 'UNCACHE';
+LAZY: 'LAZY';
+FORMATTED: 'FORMATTED';
+TEMPORARY: 'TEMPORARY' | 'TEMP';
+OPTIONS: 'OPTIONS';
+UNSET: 'UNSET';
+TBLPROPERTIES: 'TBLPROPERTIES';
+DBPROPERTIES: 'DBPROPERTIES';
+BUCKETS: 'BUCKETS';
+SKEWED: 'SKEWED';
+STORED: 'STORED';
+DIRECTORIES: 'DIRECTORIES';
+LOCATION: 'LOCATION';
+EXCHANGE: 'EXCHANGE';
+ARCHIVE: 'ARCHIVE';
+UNARCHIVE: 'UNARCHIVE';
+FILEFORMAT: 'FILEFORMAT';
+TOUCH: 'TOUCH';
+COMPACT: 'COMPACT';
+CONCATENATE: 'CONCATENATE';
+CHANGE: 'CHANGE';
+FIRST: 'FIRST';
+AFTER: 'AFTER';
+CASCADE: 'CASCADE';
+RESTRICT: 'RESTRICT';
+CLUSTERED: 'CLUSTERED';
+SORTED: 'SORTED';
+PURGE: 'PURGE';
+INPUTFORMAT: 'INPUTFORMAT';
+OUTPUTFORMAT: 'OUTPUTFORMAT';
+INPUTDRIVER: 'INPUTDRIVER';
+OUTPUTDRIVER: 'OUTPUTDRIVER';
+DATABASE: 'DATABASE' | 'SCHEMA';
+DATABASES: 'DATABASES' | 'SCHEMAS';
+DFS: 'DFS';
+TRUNCATE: 'TRUNCATE';
+METADATA: 'METADATA';
+REPLICATION: 'REPLICATION';
+ANALYZE: 'ANALYZE';
+COMPUTE: 'COMPUTE';
+STATISTICS: 'STATISTICS';
+PARTITIONED: 'PARTITIONED';
+EXTERNAL: 'EXTERNAL';
+DEFINED: 'DEFINED';
+REVOKE: 'REVOKE';
+GRANT: 'GRANT';
+LOCK: 'LOCK';
+UNLOCK: 'UNLOCK';
+MSCK: 'MSCK';
+REPAIR: 'REPAIR';
+EXPORT: 'EXPORT';
+IMPORT: 'IMPORT';
+LOAD: 'LOAD';
+ROLE: 'ROLE';
+ROLES: 'ROLES';
+COMPACTIONS: 'COMPACTIONS';
+PRINCIPALS: 'PRINCIPALS';
+TRANSACTIONS: 'TRANSACTIONS';
+INDEX: 'INDEX';
+INDEXES: 'INDEXES';
+LOCKS: 'LOCKS';
+OPTION: 'OPTION';
+ANTI: 'ANTI';
+
+STRING
+ : '\'' ( ~('\''|'\\') | ('\\' .) )* '\''
+ | '\"' ( ~('\"'|'\\') | ('\\' .) )* '\"'
+ ;
+
+BIGINT_LITERAL
+ : DIGIT+ 'L'
+ ;
+
+SMALLINT_LITERAL
+ : DIGIT+ 'S'
+ ;
+
+TINYINT_LITERAL
+ : DIGIT+ 'Y'
+ ;
+
+INTEGER_VALUE
+ : DIGIT+
+ ;
+
+DECIMAL_VALUE
+ : DIGIT+ '.' DIGIT*
+ | '.' DIGIT+
+ ;
+
+SCIENTIFIC_DECIMAL_VALUE
+ : DIGIT+ ('.' DIGIT*)? EXPONENT
+ | '.' DIGIT+ EXPONENT
+ ;
+
+DOUBLE_LITERAL
+ :
+ (INTEGER_VALUE | DECIMAL_VALUE | SCIENTIFIC_DECIMAL_VALUE) 'D'
+ ;
+
+IDENTIFIER
+ : (LETTER | DIGIT | '_')+
+ ;
+
+BACKQUOTED_IDENTIFIER
+ : '`' ( ~'`' | '``' )* '`'
+ ;
+
+fragment EXPONENT
+ : 'E' [+-]? DIGIT+
+ ;
+
+fragment DIGIT
+ : [0-9]
+ ;
+
+fragment LETTER
+ : [A-Z]
+ ;
+
+SIMPLE_COMMENT
+ : '--' ~[\r\n]* '\r'? '\n'? -> channel(HIDDEN)
+ ;
+
+BRACKETED_COMMENT
+ : '/*' .*? '*/' -> channel(HIDDEN)
+ ;
+
+WS
+ : [ \r\n\t]+ -> channel(HIDDEN)
+ ;
+
+// Catch-all for anything we can't recognize.
+// We use this to be able to ignore and recover all the text
+// when splitting statements with DelimiterLexer
+UNRECOGNIZED
+ : .
+ ;
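For orientation, ANTLR 4 generates SqlBaseLexer and SqlBaseParser classes from the SqlBase.g4 grammar above, and callers drive them through one of the EOF-terminated entry rules (singleStatement, singleExpression, singleTableIdentifier, singleDataType). The Scala sketch below is a minimal, hypothetical illustration of that wiring using only the standard ANTLR 4 runtime; it is not the parser driver added in this change, the object name SqlBaseSmoke is invented, the package of the generated classes is inferred from the grammar's directory, and the input is upper-cased because the LETTER fragment in the grammar as written only covers [A-Z] (the real driver is expected to supply a case-insensitive character stream).

    // Hypothetical usage sketch; SqlBaseLexer/SqlBaseParser are the classes ANTLR 4
    // generates from SqlBase.g4 (package inferred from the grammar's directory).
    import org.antlr.v4.runtime.{ANTLRInputStream, CommonTokenStream}
    import org.apache.spark.sql.catalyst.parser.{SqlBaseLexer, SqlBaseParser}

    object SqlBaseSmoke {
      def main(args: Array[String]): Unit = {
        // Upper-case input: the LETTER fragment above only matches [A-Z].
        val lexer = new SqlBaseLexer(new ANTLRInputStream("SELECT A, B FROM T WHERE A > 1"))
        val parser = new SqlBaseParser(new CommonTokenStream(lexer))
        // singleStatement is the EOF-terminated entry rule defined at the top of the grammar.
        val tree = parser.singleStatement()
        // Prints the parse tree in LISP form, e.g. (singleStatement (statement ...) <EOF>).
        println(tree.toStringTree(parser))
      }
    }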
diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/parser/ParseUtils.java b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/parser/ParseUtils.java
deleted file mode 100644
index 01f89112a7..0000000000
--- a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/parser/ParseUtils.java
+++ /dev/null
@@ -1,135 +0,0 @@
-/**
- * Licensed to the Apache Software Foundation (ASF) under one
- * or more contributor license agreements. See the NOTICE file
- * distributed with this work for additional information
- * regarding copyright ownership. The ASF licenses this file
- * to you under the Apache License, Version 2.0 (the
- * "License"); you may not use this file except in compliance
- * with the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-package org.apache.spark.sql.catalyst.parser;
-
-import java.nio.charset.StandardCharsets;
-
-/**
- * A couple of utility methods that help with parsing ASTs.
- *
- * The 'unescapeSQLString' method in this class was taken from the SemanticAnalyzer in Hive:
- * ql/src/java/org/apache/hadoop/hive/ql/parse/BaseSemanticAnalyzer.java
- */
-public final class ParseUtils {
- private ParseUtils() {
- super();
- }
-
- private static final int[] multiplier = new int[] {1000, 100, 10, 1};
-
- @SuppressWarnings("nls")
- public static String unescapeSQLString(String b) {
- Character enclosure = null;
-
- // Some of the strings can be passed in as unicode. For example, the
- // delimiter can be passed in as \002 - So, we first check if the
- // string is a unicode number, else go back to the old behavior
- StringBuilder sb = new StringBuilder(b.length());
- for (int i = 0; i < b.length(); i++) {
-
- char currentChar = b.charAt(i);
- if (enclosure == null) {
- if (currentChar == '\'' || b.charAt(i) == '\"') {
- enclosure = currentChar;
- }
- // ignore all other chars outside the enclosure
- continue;
- }
-
- if (enclosure.equals(currentChar)) {
- enclosure = null;
- continue;
- }
-
- if (currentChar == '\\' && (i + 6 < b.length()) && b.charAt(i + 1) == 'u') {
- int code = 0;
- int base = i + 2;
- for (int j = 0; j < 4; j++) {
- int digit = Character.digit(b.charAt(j + base), 16);
- code += digit * multiplier[j];
- }
- sb.append((char)code);
- i += 5;
- continue;
- }
-
- if (currentChar == '\\' && (i + 4 < b.length())) {
- char i1 = b.charAt(i + 1);
- char i2 = b.charAt(i + 2);
- char i3 = b.charAt(i + 3);
- if ((i1 >= '0' && i1 <= '1') && (i2 >= '0' && i2 <= '7')
- && (i3 >= '0' && i3 <= '7')) {
- byte bVal = (byte) ((i3 - '0') + ((i2 - '0') * 8) + ((i1 - '0') * 8 * 8));
- byte[] bValArr = new byte[1];
- bValArr[0] = bVal;
- String tmp = new String(bValArr, StandardCharsets.UTF_8);
- sb.append(tmp);
- i += 3;
- continue;
- }
- }
-
- if (currentChar == '\\' && (i + 2 < b.length())) {
- char n = b.charAt(i + 1);
- switch (n) {
- case '0':
- sb.append("\0");
- break;
- case '\'':
- sb.append("'");
- break;
- case '"':
- sb.append("\"");
- break;
- case 'b':
- sb.append("\b");
- break;
- case 'n':
- sb.append("\n");
- break;
- case 'r':
- sb.append("\r");
- break;
- case 't':
- sb.append("\t");
- break;
- case 'Z':
- sb.append("\u001A");
- break;
- case '\\':
- sb.append("\\");
- break;
- // The following 2 lines are exactly what MySQL does TODO: why do we do this?
- case '%':
- sb.append("\\%");
- break;
- case '_':
- sb.append("\\_");
- break;
- default:
- sb.append(n);
- }
- i++;
- } else {
- sb.append(currentChar);
- }
- }
- return sb.toString();
- }
-}
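For readers tracking where this logic went, the behaviour of the removed method is roughly the following Scala sketch — a simplification written for this summary only: it omits the enclosing-quote tracking and the \uXXXX and octal branches shown above, and keeps just the two-character escapes.

object UnescapeSketch {
  // Simplified re-creation of the single-character escape handling from unescapeSQLString.
  def unescapeSimple(s: String): String = {
    val out = new StringBuilder(s.length)
    var i = 0
    while (i < s.length) {
      val c = s.charAt(i)
      if (c == '\\' && i + 1 < s.length) {
        out.append(s.charAt(i + 1) match {
          case '0'   => "\u0000"
          case 'b'   => "\b"
          case 'n'   => "\n"
          case 'r'   => "\r"
          case 't'   => "\t"
          case 'Z'   => "\u001A"
          case '\\'  => "\\"
          case '%'   => "\\%"            // kept escaped, mirroring the MySQL-style LIKE handling above
          case '_'   => "\\_"
          case other => other.toString   // \' and \" fall through to the raw character
        })
        i += 2
      } else {
        out.append(c)
        i += 1
      }
    }
    out.toString
  }

  def main(args: Array[String]): Unit = {
    println(unescapeSimple("""a\tb\nc"""))   // prints "a", tab, "b", newline, "c"
  }
}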
diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/execution/UnsafeExternalRowSorter.java b/sql/catalyst/src/main/java/org/apache/spark/sql/execution/UnsafeExternalRowSorter.java
index aa7fc2121e..7784345a7a 100644
--- a/sql/catalyst/src/main/java/org/apache/spark/sql/execution/UnsafeExternalRowSorter.java
+++ b/sql/catalyst/src/main/java/org/apache/spark/sql/execution/UnsafeExternalRowSorter.java
@@ -151,7 +151,7 @@ public final class UnsafeExternalRowSorter {
Platform.throwException(e);
}
throw new RuntimeException("Exception should have been re-thrown in next()");
- };
+ }
};
} catch (IOException e) {
cleanupResources();
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/Encoder.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/Encoder.scala
index b19538a23f..ffa694fcdc 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/Encoder.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/Encoder.scala
@@ -17,22 +17,20 @@
package org.apache.spark.sql
-import java.lang.reflect.Modifier
-
import scala.annotation.implicitNotFound
-import scala.reflect.{classTag, ClassTag}
+import scala.reflect.ClassTag
import org.apache.spark.annotation.Experimental
-import org.apache.spark.sql.catalyst.encoders.{encoderFor, ExpressionEncoder}
-import org.apache.spark.sql.catalyst.expressions.{BoundReference, DecodeUsingSerializer, EncodeUsingSerializer}
import org.apache.spark.sql.types._
+
/**
* :: Experimental ::
* Used to convert a JVM object of type `T` to and from the internal Spark SQL representation.
*
* == Scala ==
- * Encoders are generally created automatically through implicits from a `SQLContext`.
+ * Encoders are generally created automatically through implicits from a `SQLContext`, or can be
+ * explicitly created by calling static methods on [[Encoders]].
*
* {{{
* import sqlContext.implicits._
@@ -81,224 +79,3 @@ trait Encoder[T] extends Serializable {
   /** A ClassTag that can be used to construct an Array to contain a collection of `T`. */
def clsTag: ClassTag[T]
}
-
-/**
- * :: Experimental ::
- * Methods for creating an [[Encoder]].
- *
- * @since 1.6.0
- */
-@Experimental
-object Encoders {
-
- /**
- * An encoder for nullable boolean type.
- * @since 1.6.0
- */
- def BOOLEAN: Encoder[java.lang.Boolean] = ExpressionEncoder()
-
- /**
- * An encoder for nullable byte type.
- * @since 1.6.0
- */
- def BYTE: Encoder[java.lang.Byte] = ExpressionEncoder()
-
- /**
- * An encoder for nullable short type.
- * @since 1.6.0
- */
- def SHORT: Encoder[java.lang.Short] = ExpressionEncoder()
-
- /**
- * An encoder for nullable int type.
- * @since 1.6.0
- */
- def INT: Encoder[java.lang.Integer] = ExpressionEncoder()
-
- /**
- * An encoder for nullable long type.
- * @since 1.6.0
- */
- def LONG: Encoder[java.lang.Long] = ExpressionEncoder()
-
- /**
- * An encoder for nullable float type.
- * @since 1.6.0
- */
- def FLOAT: Encoder[java.lang.Float] = ExpressionEncoder()
-
- /**
- * An encoder for nullable double type.
- * @since 1.6.0
- */
- def DOUBLE: Encoder[java.lang.Double] = ExpressionEncoder()
-
- /**
- * An encoder for nullable string type.
- * @since 1.6.0
- */
- def STRING: Encoder[java.lang.String] = ExpressionEncoder()
-
- /**
- * An encoder for nullable decimal type.
- * @since 1.6.0
- */
- def DECIMAL: Encoder[java.math.BigDecimal] = ExpressionEncoder()
-
- /**
- * An encoder for nullable date type.
- * @since 1.6.0
- */
- def DATE: Encoder[java.sql.Date] = ExpressionEncoder()
-
- /**
- * An encoder for nullable timestamp type.
- * @since 1.6.0
- */
- def TIMESTAMP: Encoder[java.sql.Timestamp] = ExpressionEncoder()
-
- /**
- * An encoder for arrays of bytes.
- * @since 1.6.1
- */
- def BINARY: Encoder[Array[Byte]] = ExpressionEncoder()
-
- /**
- * Creates an encoder for Java Bean of type T.
- *
- * T must be publicly accessible.
- *
- * supported types for java bean field:
- * - primitive types: boolean, int, double, etc.
- * - boxed types: Boolean, Integer, Double, etc.
- * - String
- * - java.math.BigDecimal
- * - time related: java.sql.Date, java.sql.Timestamp
- * - collection types: only array and java.util.List currently, map support is in progress
- * - nested java bean.
- *
- * @since 1.6.0
- */
- def bean[T](beanClass: Class[T]): Encoder[T] = ExpressionEncoder.javaBean(beanClass)
-
- /**
- * (Scala-specific) Creates an encoder that serializes objects of type T using Kryo.
- * This encoder maps T into a single byte array (binary) field.
- *
- * T must be publicly accessible.
- *
- * @since 1.6.0
- */
- def kryo[T: ClassTag]: Encoder[T] = genericSerializer(useKryo = true)
-
- /**
- * Creates an encoder that serializes objects of type T using Kryo.
- * This encoder maps T into a single byte array (binary) field.
- *
- * T must be publicly accessible.
- *
- * @since 1.6.0
- */
- def kryo[T](clazz: Class[T]): Encoder[T] = kryo(ClassTag[T](clazz))
-
- /**
- * (Scala-specific) Creates an encoder that serializes objects of type T using generic Java
- * serialization. This encoder maps T into a single byte array (binary) field.
- *
- * Note that this is extremely inefficient and should only be used as the last resort.
- *
- * T must be publicly accessible.
- *
- * @since 1.6.0
- */
- def javaSerialization[T: ClassTag]: Encoder[T] = genericSerializer(useKryo = false)
-
- /**
- * Creates an encoder that serializes objects of type T using generic Java serialization.
- * This encoder maps T into a single byte array (binary) field.
- *
- * Note that this is extremely inefficient and should only be used as the last resort.
- *
- * T must be publicly accessible.
- *
- * @since 1.6.0
- */
- def javaSerialization[T](clazz: Class[T]): Encoder[T] = javaSerialization(ClassTag[T](clazz))
-
- /** Throws an exception if T is not a public class. */
- private def validatePublicClass[T: ClassTag](): Unit = {
- if (!Modifier.isPublic(classTag[T].runtimeClass.getModifiers)) {
- throw new UnsupportedOperationException(
- s"${classTag[T].runtimeClass.getName} is not a public class. " +
- "Only public classes are supported.")
- }
- }
-
- /** A way to construct encoders using generic serializers. */
- private def genericSerializer[T: ClassTag](useKryo: Boolean): Encoder[T] = {
- if (classTag[T].runtimeClass.isPrimitive) {
- throw new UnsupportedOperationException("Primitive types are not supported.")
- }
-
- validatePublicClass[T]()
-
- ExpressionEncoder[T](
- schema = new StructType().add("value", BinaryType),
- flat = true,
- toRowExpressions = Seq(
- EncodeUsingSerializer(
- BoundReference(0, ObjectType(classOf[AnyRef]), nullable = true), kryo = useKryo)),
- fromRowExpression =
- DecodeUsingSerializer[T](
- BoundReference(0, BinaryType, nullable = true), classTag[T], kryo = useKryo),
- clsTag = classTag[T]
- )
- }
-
- /**
- * An encoder for 2-ary tuples.
- * @since 1.6.0
- */
- def tuple[T1, T2](
- e1: Encoder[T1],
- e2: Encoder[T2]): Encoder[(T1, T2)] = {
- ExpressionEncoder.tuple(encoderFor(e1), encoderFor(e2))
- }
-
- /**
- * An encoder for 3-ary tuples.
- * @since 1.6.0
- */
- def tuple[T1, T2, T3](
- e1: Encoder[T1],
- e2: Encoder[T2],
- e3: Encoder[T3]): Encoder[(T1, T2, T3)] = {
- ExpressionEncoder.tuple(encoderFor(e1), encoderFor(e2), encoderFor(e3))
- }
-
- /**
- * An encoder for 4-ary tuples.
- * @since 1.6.0
- */
- def tuple[T1, T2, T3, T4](
- e1: Encoder[T1],
- e2: Encoder[T2],
- e3: Encoder[T3],
- e4: Encoder[T4]): Encoder[(T1, T2, T3, T4)] = {
- ExpressionEncoder.tuple(encoderFor(e1), encoderFor(e2), encoderFor(e3), encoderFor(e4))
- }
-
- /**
- * An encoder for 5-ary tuples.
- * @since 1.6.0
- */
- def tuple[T1, T2, T3, T4, T5](
- e1: Encoder[T1],
- e2: Encoder[T2],
- e3: Encoder[T3],
- e4: Encoder[T4],
- e5: Encoder[T5]): Encoder[(T1, T2, T3, T4, T5)] = {
- ExpressionEncoder.tuple(
- encoderFor(e1), encoderFor(e2), encoderFor(e3), encoderFor(e4), encoderFor(e5))
- }
-}
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/Encoders.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/Encoders.scala
new file mode 100644
index 0000000000..3f4df704db
--- /dev/null
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/Encoders.scala
@@ -0,0 +1,314 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql
+
+import java.lang.reflect.Modifier
+
+import scala.reflect.{classTag, ClassTag}
+import scala.reflect.runtime.universe.TypeTag
+
+import org.apache.spark.annotation.Experimental
+import org.apache.spark.sql.catalyst.encoders.{encoderFor, ExpressionEncoder}
+import org.apache.spark.sql.catalyst.expressions.{BoundReference, DecodeUsingSerializer, EncodeUsingSerializer}
+import org.apache.spark.sql.types._
+
+/**
+ * :: Experimental ::
+ * Methods for creating an [[Encoder]].
+ *
+ * @since 1.6.0
+ */
+@Experimental
+object Encoders {
+
+ /**
+ * An encoder for nullable boolean type.
+ * The Scala primitive encoder is available as [[scalaBoolean]].
+ * @since 1.6.0
+ */
+ def BOOLEAN: Encoder[java.lang.Boolean] = ExpressionEncoder()
+
+ /**
+ * An encoder for nullable byte type.
+ * The Scala primitive encoder is available as [[scalaByte]].
+ * @since 1.6.0
+ */
+ def BYTE: Encoder[java.lang.Byte] = ExpressionEncoder()
+
+ /**
+ * An encoder for nullable short type.
+ * The Scala primitive encoder is available as [[scalaShort]].
+ * @since 1.6.0
+ */
+ def SHORT: Encoder[java.lang.Short] = ExpressionEncoder()
+
+ /**
+ * An encoder for nullable int type.
+ * The Scala primitive encoder is available as [[scalaInt]].
+ * @since 1.6.0
+ */
+ def INT: Encoder[java.lang.Integer] = ExpressionEncoder()
+
+ /**
+ * An encoder for nullable long type.
+ * The Scala primitive encoder is available as [[scalaLong]].
+ * @since 1.6.0
+ */
+ def LONG: Encoder[java.lang.Long] = ExpressionEncoder()
+
+ /**
+ * An encoder for nullable float type.
+ * The Scala primitive encoder is available as [[scalaFloat]].
+ * @since 1.6.0
+ */
+ def FLOAT: Encoder[java.lang.Float] = ExpressionEncoder()
+
+ /**
+ * An encoder for nullable double type.
+ * The Scala primitive encoder is available as [[scalaDouble]].
+ * @since 1.6.0
+ */
+ def DOUBLE: Encoder[java.lang.Double] = ExpressionEncoder()
+
+ /**
+ * An encoder for nullable string type.
+ *
+ * @since 1.6.0
+ */
+ def STRING: Encoder[java.lang.String] = ExpressionEncoder()
+
+ /**
+ * An encoder for nullable decimal type.
+ *
+ * @since 1.6.0
+ */
+ def DECIMAL: Encoder[java.math.BigDecimal] = ExpressionEncoder()
+
+ /**
+ * An encoder for nullable date type.
+ *
+ * @since 1.6.0
+ */
+ def DATE: Encoder[java.sql.Date] = ExpressionEncoder()
+
+ /**
+ * An encoder for nullable timestamp type.
+ *
+ * @since 1.6.0
+ */
+ def TIMESTAMP: Encoder[java.sql.Timestamp] = ExpressionEncoder()
+
+ /**
+ * An encoder for arrays of bytes.
+ *
+ * @since 1.6.1
+ */
+ def BINARY: Encoder[Array[Byte]] = ExpressionEncoder()
+
+ /**
+ * Creates an encoder for Java Bean of type T.
+ *
+ * T must be publicly accessible.
+ *
+   * Supported types for Java bean fields:
+ * - primitive types: boolean, int, double, etc.
+ * - boxed types: Boolean, Integer, Double, etc.
+ * - String
+ * - java.math.BigDecimal
+ * - time related: java.sql.Date, java.sql.Timestamp
+ * - collection types: only array and java.util.List currently, map support is in progress
+ * - nested java bean.
+ *
+ * @since 1.6.0
+ */
+ def bean[T](beanClass: Class[T]): Encoder[T] = ExpressionEncoder.javaBean(beanClass)
+
+ /**
+ * (Scala-specific) Creates an encoder that serializes objects of type T using Kryo.
+ * This encoder maps T into a single byte array (binary) field.
+ *
+ * T must be publicly accessible.
+ *
+ * @since 1.6.0
+ */
+ def kryo[T: ClassTag]: Encoder[T] = genericSerializer(useKryo = true)
+
+ /**
+ * Creates an encoder that serializes objects of type T using Kryo.
+ * This encoder maps T into a single byte array (binary) field.
+ *
+ * T must be publicly accessible.
+ *
+ * @since 1.6.0
+ */
+ def kryo[T](clazz: Class[T]): Encoder[T] = kryo(ClassTag[T](clazz))
+
+ /**
+ * (Scala-specific) Creates an encoder that serializes objects of type T using generic Java
+ * serialization. This encoder maps T into a single byte array (binary) field.
+ *
+ * Note that this is extremely inefficient and should only be used as the last resort.
+ *
+ * T must be publicly accessible.
+ *
+ * @since 1.6.0
+ */
+ def javaSerialization[T: ClassTag]: Encoder[T] = genericSerializer(useKryo = false)
+
+ /**
+ * Creates an encoder that serializes objects of type T using generic Java serialization.
+ * This encoder maps T into a single byte array (binary) field.
+ *
+ * Note that this is extremely inefficient and should only be used as the last resort.
+ *
+ * T must be publicly accessible.
+ *
+ * @since 1.6.0
+ */
+ def javaSerialization[T](clazz: Class[T]): Encoder[T] = javaSerialization(ClassTag[T](clazz))
+
+ /** Throws an exception if T is not a public class. */
+ private def validatePublicClass[T: ClassTag](): Unit = {
+ if (!Modifier.isPublic(classTag[T].runtimeClass.getModifiers)) {
+ throw new UnsupportedOperationException(
+ s"${classTag[T].runtimeClass.getName} is not a public class. " +
+ "Only public classes are supported.")
+ }
+ }
+
+ /** A way to construct encoders using generic serializers. */
+ private def genericSerializer[T: ClassTag](useKryo: Boolean): Encoder[T] = {
+ if (classTag[T].runtimeClass.isPrimitive) {
+ throw new UnsupportedOperationException("Primitive types are not supported.")
+ }
+
+ validatePublicClass[T]()
+
+ ExpressionEncoder[T](
+ schema = new StructType().add("value", BinaryType),
+ flat = true,
+ serializer = Seq(
+ EncodeUsingSerializer(
+ BoundReference(0, ObjectType(classOf[AnyRef]), nullable = true), kryo = useKryo)),
+ deserializer =
+ DecodeUsingSerializer[T](
+ BoundReference(0, BinaryType, nullable = true), classTag[T], kryo = useKryo),
+ clsTag = classTag[T]
+ )
+ }
+
+ /**
+ * An encoder for 2-ary tuples.
+ *
+ * @since 1.6.0
+ */
+ def tuple[T1, T2](
+ e1: Encoder[T1],
+ e2: Encoder[T2]): Encoder[(T1, T2)] = {
+ ExpressionEncoder.tuple(encoderFor(e1), encoderFor(e2))
+ }
+
+ /**
+ * An encoder for 3-ary tuples.
+ *
+ * @since 1.6.0
+ */
+ def tuple[T1, T2, T3](
+ e1: Encoder[T1],
+ e2: Encoder[T2],
+ e3: Encoder[T3]): Encoder[(T1, T2, T3)] = {
+ ExpressionEncoder.tuple(encoderFor(e1), encoderFor(e2), encoderFor(e3))
+ }
+
+ /**
+ * An encoder for 4-ary tuples.
+ *
+ * @since 1.6.0
+ */
+ def tuple[T1, T2, T3, T4](
+ e1: Encoder[T1],
+ e2: Encoder[T2],
+ e3: Encoder[T3],
+ e4: Encoder[T4]): Encoder[(T1, T2, T3, T4)] = {
+ ExpressionEncoder.tuple(encoderFor(e1), encoderFor(e2), encoderFor(e3), encoderFor(e4))
+ }
+
+ /**
+ * An encoder for 5-ary tuples.
+ *
+ * @since 1.6.0
+ */
+ def tuple[T1, T2, T3, T4, T5](
+ e1: Encoder[T1],
+ e2: Encoder[T2],
+ e3: Encoder[T3],
+ e4: Encoder[T4],
+ e5: Encoder[T5]): Encoder[(T1, T2, T3, T4, T5)] = {
+ ExpressionEncoder.tuple(
+ encoderFor(e1), encoderFor(e2), encoderFor(e3), encoderFor(e4), encoderFor(e5))
+ }
+
+ /**
+ * An encoder for Scala's product type (tuples, case classes, etc).
+ * @since 2.0.0
+ */
+ def product[T <: Product : TypeTag]: Encoder[T] = ExpressionEncoder()
+
+ /**
+ * An encoder for Scala's primitive int type.
+ * @since 2.0.0
+ */
+ def scalaInt: Encoder[Int] = ExpressionEncoder()
+
+ /**
+ * An encoder for Scala's primitive long type.
+ * @since 2.0.0
+ */
+ def scalaLong: Encoder[Long] = ExpressionEncoder()
+
+ /**
+ * An encoder for Scala's primitive double type.
+ * @since 2.0.0
+ */
+ def scalaDouble: Encoder[Double] = ExpressionEncoder()
+
+ /**
+ * An encoder for Scala's primitive float type.
+ * @since 2.0.0
+ */
+ def scalaFloat: Encoder[Float] = ExpressionEncoder()
+
+ /**
+ * An encoder for Scala's primitive byte type.
+ * @since 2.0.0
+ */
+ def scalaByte: Encoder[Byte] = ExpressionEncoder()
+
+ /**
+ * An encoder for Scala's primitive short type.
+ * @since 2.0.0
+ */
+ def scalaShort: Encoder[Short] = ExpressionEncoder()
+
+ /**
+ * An encoder for Scala's primitive boolean type.
+ * @since 2.0.0
+ */
+ def scalaBoolean: Encoder[Boolean] = ExpressionEncoder()
+
+}
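A minimal, purely illustrative sketch of the surface this new file exposes (the object and method names come from the code above; the sample class is hypothetical):

import org.apache.spark.sql.{Encoder, Encoders}

object EncodersSketch {
  // Boxed, Java-friendly encoders vs. the new Scala primitive encoders:
  val boxedInt: Encoder[java.lang.Integer] = Encoders.INT
  val primInt: Encoder[Int]                = Encoders.scalaInt

  // Encoders compose: a tuple encoder built from per-field encoders.
  val pair: Encoder[(Long, String)] = Encoders.tuple(Encoders.scalaLong, Encoders.STRING)

  // Classes without a built-in encoder can fall back to Kryo (stored as a single binary column).
  class Opaque(val payload: Map[String, Int]) extends Serializable
  val opaque: Encoder[Opaque] = Encoders.kryo[Opaque]
}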
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/CatalystConf.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/CatalystConf.scala
index d5ac01500b..2b98aacdd7 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/CatalystConf.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/CatalystConf.scala
@@ -26,7 +26,7 @@ private[spark] trait CatalystConf {
def groupByOrdinal: Boolean
/**
- * Returns the [[Resolver]] for the current configuration, which can be used to determin if two
+ * Returns the [[Resolver]] for the current configuration, which can be used to determine if two
* identifiers are equal.
*/
def resolver: Resolver = {
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/JavaTypeInference.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/JavaTypeInference.scala
index 59ee41d02f..6f9fbbbead 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/JavaTypeInference.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/JavaTypeInference.scala
@@ -155,16 +155,16 @@ object JavaTypeInference {
}
/**
- * Returns an expression that can be used to construct an object of java bean `T` given an input
- * row with a compatible schema. Fields of the row will be extracted using UnresolvedAttributes
+ * Returns an expression that can be used to deserialize an internal row to an object of java bean
+ * `T` with a compatible schema. Fields of the row will be extracted using UnresolvedAttributes
* of the same name as the constructor arguments. Nested classes will have their fields accessed
* using UnresolvedExtractValue.
*/
- def constructorFor(beanClass: Class[_]): Expression = {
- constructorFor(TypeToken.of(beanClass), None)
+ def deserializerFor(beanClass: Class[_]): Expression = {
+ deserializerFor(TypeToken.of(beanClass), None)
}
- private def constructorFor(typeToken: TypeToken[_], path: Option[Expression]): Expression = {
+ private def deserializerFor(typeToken: TypeToken[_], path: Option[Expression]): Expression = {
/** Returns the current path with a sub-field extracted. */
def addToPath(part: String): Expression = path
.map(p => UnresolvedExtractValue(p, expressions.Literal(part)))
@@ -231,7 +231,7 @@ object JavaTypeInference {
}.getOrElse {
Invoke(
MapObjects(
- p => constructorFor(typeToken.getComponentType, Some(p)),
+ p => deserializerFor(typeToken.getComponentType, Some(p)),
getPath,
inferDataType(elementType)._1),
"array",
@@ -243,7 +243,7 @@ object JavaTypeInference {
val array =
Invoke(
MapObjects(
- p => constructorFor(et, Some(p)),
+ p => deserializerFor(et, Some(p)),
getPath,
inferDataType(et)._1),
"array",
@@ -259,7 +259,7 @@ object JavaTypeInference {
val keyData =
Invoke(
MapObjects(
- p => constructorFor(keyType, Some(p)),
+ p => deserializerFor(keyType, Some(p)),
Invoke(getPath, "keyArray", ArrayType(keyDataType)),
keyDataType),
"array",
@@ -268,7 +268,7 @@ object JavaTypeInference {
val valueData =
Invoke(
MapObjects(
- p => constructorFor(valueType, Some(p)),
+ p => deserializerFor(valueType, Some(p)),
Invoke(getPath, "valueArray", ArrayType(valueDataType)),
valueDataType),
"array",
@@ -288,7 +288,7 @@ object JavaTypeInference {
val fieldName = p.getName
val fieldType = typeToken.method(p.getReadMethod).getReturnType
val (_, nullable) = inferDataType(fieldType)
- val constructor = constructorFor(fieldType, Some(addToPath(fieldName)))
+ val constructor = deserializerFor(fieldType, Some(addToPath(fieldName)))
val setter = if (nullable) {
constructor
} else {
@@ -313,14 +313,14 @@ object JavaTypeInference {
}
/**
- * Returns expressions for extracting all the fields from the given type.
+ * Returns an expression for serializing an object of the given type to an internal row.
*/
- def extractorsFor(beanClass: Class[_]): CreateNamedStruct = {
+ def serializerFor(beanClass: Class[_]): CreateNamedStruct = {
val inputObject = BoundReference(0, ObjectType(beanClass), nullable = true)
- extractorFor(inputObject, TypeToken.of(beanClass)).asInstanceOf[CreateNamedStruct]
+ serializerFor(inputObject, TypeToken.of(beanClass)).asInstanceOf[CreateNamedStruct]
}
- private def extractorFor(inputObject: Expression, typeToken: TypeToken[_]): Expression = {
+ private def serializerFor(inputObject: Expression, typeToken: TypeToken[_]): Expression = {
def toCatalystArray(input: Expression, elementType: TypeToken[_]): Expression = {
val (dataType, nullable) = inferDataType(elementType)
@@ -330,7 +330,7 @@ object JavaTypeInference {
input :: Nil,
dataType = ArrayType(dataType, nullable))
} else {
- MapObjects(extractorFor(_, elementType), input, ObjectType(elementType.getRawType))
+ MapObjects(serializerFor(_, elementType), input, ObjectType(elementType.getRawType))
}
}
@@ -403,7 +403,7 @@ object JavaTypeInference {
inputObject,
p.getReadMethod.getName,
inferExternalType(fieldType.getRawType))
- expressions.Literal(fieldName) :: extractorFor(fieldValue, fieldType) :: Nil
+ expressions.Literal(fieldName) :: serializerFor(fieldValue, fieldType) :: Nil
})
} else {
throw new UnsupportedOperationException(
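To make the rename concrete, here is a hedged sketch of how the two directions pair up after this change. PersonBean is a hypothetical bean invented for illustration, and the snippet assumes it sits inside Spark's own source tree, since these helpers are internal to catalyst.

// Placed in Spark's own source tree because these helpers are internal to catalyst.
package org.apache.spark.sql.catalyst

import scala.beans.BeanProperty

class PersonBean {                 // hypothetical bean, illustration only
  @BeanProperty var name: String = _
  @BeanProperty var age: Int = 0
}

object JavaTypeInferenceSketch {
  // After the rename the two directions are named symmetrically:
  val serializer   = JavaTypeInference.serializerFor(classOf[PersonBean])   // bean -> internal row fields
  val deserializer = JavaTypeInference.deserializerFor(classOf[PersonBean]) // internal row -> bean
}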
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala
index f208401160..4795fc2557 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala
@@ -110,8 +110,8 @@ object ScalaReflection extends ScalaReflection {
}
/**
- * Returns an expression that can be used to construct an object of type `T` given an input
- * row with a compatible schema. Fields of the row will be extracted using UnresolvedAttributes
+ * Returns an expression that can be used to deserialize an input row to an object of type `T`
+ * with a compatible schema. Fields of the row will be extracted using UnresolvedAttributes
* of the same name as the constructor arguments. Nested classes will have their fields accessed
* using UnresolvedExtractValue.
*
@@ -119,14 +119,14 @@ object ScalaReflection extends ScalaReflection {
* from ordinal 0 (since there are no names to map to). The actual location can be moved by
* calling resolve/bind with a new schema.
*/
- def constructorFor[T : TypeTag]: Expression = {
+ def deserializerFor[T : TypeTag]: Expression = {
val tpe = localTypeOf[T]
val clsName = getClassNameFromType(tpe)
val walkedTypePath = s"""- root class: "${clsName}"""" :: Nil
- constructorFor(tpe, None, walkedTypePath)
+ deserializerFor(tpe, None, walkedTypePath)
}
- private def constructorFor(
+ private def deserializerFor(
tpe: `Type`,
path: Option[Expression],
walkedTypePath: Seq[String]): Expression = ScalaReflectionLock.synchronized {
@@ -161,7 +161,7 @@ object ScalaReflection extends ScalaReflection {
}
/**
- * When we build the `fromRowExpression` for an encoder, we set up a lot of "unresolved" stuff
+ * When we build the `deserializer` for an encoder, we set up a lot of "unresolved" stuff
* and lost the required data type, which may lead to runtime error if the real type doesn't
* match the encoder's schema.
* For example, we build an encoder for `case class Data(a: Int, b: String)` and the real type
@@ -188,7 +188,7 @@ object ScalaReflection extends ScalaReflection {
val TypeRef(_, _, Seq(optType)) = t
val className = getClassNameFromType(optType)
val newTypePath = s"""- option value class: "$className"""" +: walkedTypePath
- WrapOption(constructorFor(optType, path, newTypePath), dataTypeFor(optType))
+ WrapOption(deserializerFor(optType, path, newTypePath), dataTypeFor(optType))
case t if t <:< localTypeOf[java.lang.Integer] =>
val boxedType = classOf[java.lang.Integer]
@@ -272,7 +272,7 @@ object ScalaReflection extends ScalaReflection {
val newTypePath = s"""- array element class: "$className"""" +: walkedTypePath
Invoke(
MapObjects(
- p => constructorFor(elementType, Some(p), newTypePath),
+ p => deserializerFor(elementType, Some(p), newTypePath),
getPath,
schemaFor(elementType).dataType),
"array",
@@ -286,7 +286,7 @@ object ScalaReflection extends ScalaReflection {
val newTypePath = s"""- array element class: "$className"""" +: walkedTypePath
val mapFunction: Expression => Expression = p => {
- val converter = constructorFor(elementType, Some(p), newTypePath)
+ val converter = deserializerFor(elementType, Some(p), newTypePath)
if (nullable) {
converter
} else {
@@ -312,7 +312,7 @@ object ScalaReflection extends ScalaReflection {
val keyData =
Invoke(
MapObjects(
- p => constructorFor(keyType, Some(p), walkedTypePath),
+ p => deserializerFor(keyType, Some(p), walkedTypePath),
Invoke(getPath, "keyArray", ArrayType(schemaFor(keyType).dataType)),
schemaFor(keyType).dataType),
"array",
@@ -321,7 +321,7 @@ object ScalaReflection extends ScalaReflection {
val valueData =
Invoke(
MapObjects(
- p => constructorFor(valueType, Some(p), walkedTypePath),
+ p => deserializerFor(valueType, Some(p), walkedTypePath),
Invoke(getPath, "valueArray", ArrayType(schemaFor(valueType).dataType)),
schemaFor(valueType).dataType),
"array",
@@ -344,12 +344,12 @@ object ScalaReflection extends ScalaReflection {
val newTypePath = s"""- field (class: "$clsName", name: "$fieldName")""" +: walkedTypePath
         // For tuples, we grab the inner fields by ordinal instead of name.
if (cls.getName startsWith "scala.Tuple") {
- constructorFor(
+ deserializerFor(
fieldType,
Some(addToPathOrdinal(i, dataType, newTypePath)),
newTypePath)
} else {
- val constructor = constructorFor(
+ val constructor = deserializerFor(
fieldType,
Some(addToPath(fieldName, dataType, newTypePath)),
newTypePath)
@@ -387,7 +387,7 @@ object ScalaReflection extends ScalaReflection {
}
/**
- * Returns expressions for extracting all the fields from the given type.
+ * Returns an expression for serializing an object of type T to an internal row.
*
* If the given type is not supported, i.e. there is no encoder can be built for this type,
* an [[UnsupportedOperationException]] will be thrown with detailed error message to explain
@@ -398,18 +398,18 @@ object ScalaReflection extends ScalaReflection {
* * the element type of [[Array]] or [[Seq]]: `array element class: "abc.xyz.MyClass"`
* * the field of [[Product]]: `field (class: "abc.xyz.MyClass", name: "myField")`
*/
- def extractorsFor[T : TypeTag](inputObject: Expression): CreateNamedStruct = {
+ def serializerFor[T : TypeTag](inputObject: Expression): CreateNamedStruct = {
val tpe = localTypeOf[T]
val clsName = getClassNameFromType(tpe)
val walkedTypePath = s"""- root class: "${clsName}"""" :: Nil
- extractorFor(inputObject, tpe, walkedTypePath) match {
+ serializerFor(inputObject, tpe, walkedTypePath) match {
case expressions.If(_, _, s: CreateNamedStruct) if tpe <:< localTypeOf[Product] => s
case other => CreateNamedStruct(expressions.Literal("value") :: other :: Nil)
}
}
/** Helper for extracting internal fields from a case class. */
- private def extractorFor(
+ private def serializerFor(
inputObject: Expression,
tpe: `Type`,
walkedTypePath: Seq[String]): Expression = ScalaReflectionLock.synchronized {
@@ -425,7 +425,7 @@ object ScalaReflection extends ScalaReflection {
} else {
val clsName = getClassNameFromType(elementType)
val newPath = s"""- array element class: "$clsName"""" +: walkedTypePath
- MapObjects(extractorFor(_, elementType, newPath), input, externalDataType)
+ MapObjects(serializerFor(_, elementType, newPath), input, externalDataType)
}
}
@@ -491,7 +491,7 @@ object ScalaReflection extends ScalaReflection {
expressions.If(
IsNull(unwrapped),
expressions.Literal.create(null, silentSchemaFor(optType).dataType),
- extractorFor(unwrapped, optType, newPath))
+ serializerFor(unwrapped, optType, newPath))
}
case t if t <:< localTypeOf[Product] =>
@@ -500,7 +500,7 @@ object ScalaReflection extends ScalaReflection {
val fieldValue = Invoke(inputObject, fieldName, dataTypeFor(fieldType))
val clsName = getClassNameFromType(fieldType)
val newPath = s"""- field (class: "$clsName", name: "$fieldName")""" +: walkedTypePath
- expressions.Literal(fieldName) :: extractorFor(fieldValue, fieldType, newPath) :: Nil
+ expressions.Literal(fieldName) :: serializerFor(fieldValue, fieldType, newPath) :: Nil
})
val nullOutput = expressions.Literal.create(null, nonNullOutput.dataType)
expressions.If(IsNull(inputObject), nullOutput, nonNullOutput)
@@ -762,15 +762,15 @@ trait ScalaReflection {
}
/**
- * Returns the full class name for a type. The returned name is the canonical
- * Scala name, where each component is separated by a period. It is NOT the
- * Java-equivalent runtime name (no dollar signs).
- *
- * In simple cases, both the Scala and Java names are the same, however when Scala
- * generates constructs that do not map to a Java equivalent, such as singleton objects
- * or nested classes in package objects, it uses the dollar sign ($) to create
- * synthetic classes, emulating behaviour in Java bytecode.
- */
+ * Returns the full class name for a type. The returned name is the canonical
+ * Scala name, where each component is separated by a period. It is NOT the
+ * Java-equivalent runtime name (no dollar signs).
+ *
+ * In simple cases, both the Scala and Java names are the same, however when Scala
+ * generates constructs that do not map to a Java equivalent, such as singleton objects
+ * or nested classes in package objects, it uses the dollar sign ($) to create
+ * synthetic classes, emulating behaviour in Java bytecode.
+ */
def getClassNameFromType(tpe: `Type`): String = {
tpe.erasure.typeSymbol.asClass.fullName
}
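The same rename applies on the Scala side. A sketch, under the assumption that it is called from code with access to the catalyst internals; Person is a hypothetical case class added only for illustration.

object ScalaReflectionSketch {
  import org.apache.spark.sql.catalyst.ScalaReflection
  import org.apache.spark.sql.catalyst.expressions.BoundReference
  import org.apache.spark.sql.types.ObjectType

  case class Person(name: String, age: Int)   // hypothetical case class, illustration only

  // Serialization direction (was extractorsFor): expressions that flatten a Person into row fields.
  val serializer = ScalaReflection.serializerFor[Person](
    BoundReference(0, ObjectType(classOf[Person]), nullable = true))

  // Deserialization direction (was constructorFor): rebuilds a Person from a compatible row.
  val deserializer = ScalaReflection.deserializerFor[Person]
}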
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
index 3b83e68018..de40ddde1b 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
@@ -17,8 +17,6 @@
package org.apache.spark.sql.catalyst.analysis
-import java.lang.reflect.Modifier
-
import scala.annotation.tailrec
import scala.collection.mutable.ArrayBuffer
@@ -42,9 +40,12 @@ import org.apache.spark.sql.types._
* to resolve attribute references.
*/
object SimpleAnalyzer
- extends SimpleAnalyzer(new SimpleCatalystConf(caseSensitiveAnalysis = true))
-class SimpleAnalyzer(conf: CatalystConf)
- extends Analyzer(new SessionCatalog(new InMemoryCatalog, conf), EmptyFunctionRegistry, conf)
+ extends SimpleAnalyzer(
+ EmptyFunctionRegistry,
+ new SimpleCatalystConf(caseSensitiveAnalysis = true))
+
+class SimpleAnalyzer(functionRegistry: FunctionRegistry, conf: CatalystConf)
+ extends Analyzer(new SessionCatalog(new InMemoryCatalog, functionRegistry, conf), conf)
/**
* Provides a logical query plan analyzer, which translates [[UnresolvedAttribute]]s and
@@ -53,7 +54,6 @@ class SimpleAnalyzer(conf: CatalystConf)
*/
class Analyzer(
catalog: SessionCatalog,
- registry: FunctionRegistry,
conf: CatalystConf,
maxIterations: Int = 100)
extends RuleExecutor[LogicalPlan] with CheckAnalysis {
@@ -81,11 +81,13 @@ class Analyzer(
Batch("Resolution", fixedPoint,
ResolveRelations ::
ResolveReferences ::
+ ResolveDeserializer ::
+ ResolveNewInstance ::
+ ResolveUpCast ::
ResolveGroupingAnalytics ::
ResolvePivot ::
- ResolveUpCast ::
ResolveOrdinalInOrderByAndGroupBy ::
- ResolveSortReferences ::
+ ResolveMissingReferences ::
ResolveGenerate ::
ResolveFunctions ::
ResolveAliases ::
@@ -96,6 +98,7 @@ class Analyzer(
ExtractWindowExpressions ::
GlobalAggregates ::
ResolveAggregateFunctions ::
+ TimeWindowing ::
HiveTypeCoercion.typeCoercionRules ++
extendedResolutionRules : _*),
Batch("Nondeterministic", Once,
@@ -225,21 +228,56 @@ class Analyzer(
Seq.tabulate(1 << c.groupByExprs.length)(i => i)
}
- private def hasGroupingId(expr: Seq[Expression]): Boolean = {
- expr.exists(_.collectFirst {
- case u: UnresolvedAttribute if resolver(u.name, VirtualColumn.groupingIdName) => u
- }.isDefined)
+ private def hasGroupingAttribute(expr: Expression): Boolean = {
+ expr.collectFirst {
+ case u: UnresolvedAttribute if resolver(u.name, VirtualColumn.hiveGroupingIdName) => u
+ }.isDefined
}
- def apply(plan: LogicalPlan): LogicalPlan = plan resolveOperators {
+ private def hasGroupingFunction(e: Expression): Boolean = {
+ e.collectFirst {
+ case g: Grouping => g
+ case g: GroupingID => g
+ }.isDefined
+ }
+
+ private def replaceGroupingFunc(
+ expr: Expression,
+ groupByExprs: Seq[Expression],
+ gid: Expression): Expression = {
+ expr transform {
+ case e: GroupingID =>
+ if (e.groupByExprs.isEmpty || e.groupByExprs == groupByExprs) {
+ gid
+ } else {
+ throw new AnalysisException(
+ s"Columns of grouping_id (${e.groupByExprs.mkString(",")}) does not match " +
+ s"grouping columns (${groupByExprs.mkString(",")})")
+ }
+ case Grouping(col: Expression) =>
+ val idx = groupByExprs.indexOf(col)
+ if (idx >= 0) {
+ Cast(BitwiseAnd(ShiftRight(gid, Literal(groupByExprs.length - 1 - idx)),
+ Literal(1)), ByteType)
+ } else {
+ throw new AnalysisException(s"Column of grouping ($col) can't be found " +
+ s"in grouping columns ${groupByExprs.mkString(",")}")
+ }
+ }
+ }
+
+    // This requires transformUp to replace grouping()/grouping_id() in resolved Filter/Sort
+ def apply(plan: LogicalPlan): LogicalPlan = plan transformUp {
case a if !a.childrenResolved => a // be sure all of the children are resolved.
+ case p if p.expressions.exists(hasGroupingAttribute) =>
+ failAnalysis(
+ s"${VirtualColumn.hiveGroupingIdName} is deprecated; use grouping_id() instead")
+
case Aggregate(Seq(c @ Cube(groupByExprs)), aggregateExpressions, child) =>
GroupingSets(bitmasks(c), groupByExprs, child, aggregateExpressions)
case Aggregate(Seq(r @ Rollup(groupByExprs)), aggregateExpressions, child) =>
GroupingSets(bitmasks(r), groupByExprs, child, aggregateExpressions)
- case g: GroupingSets if g.expressions.exists(!_.resolved) && hasGroupingId(g.expressions) =>
- failAnalysis(
- s"${VirtualColumn.groupingIdName} is deprecated; use grouping_id() instead")
+
// Ensure all the expressions have been resolved.
case x: GroupingSets if x.expressions.forall(_.resolved) =>
val gid = AttributeReference(VirtualColumn.groupingIdName, IntegerType, false)()
@@ -267,7 +305,7 @@ class Analyzer(
def isPartOfAggregation(e: Expression): Boolean = {
aggsBuffer.exists(a => a.find(_ eq e).isDefined)
}
- expr.transformDown {
+ replaceGroupingFunc(expr, x.groupByExprs, gid).transformDown {
// AggregateExpression should be computed on the unmodified value of its argument
// expressions, so we should not replace any references to grouping expression
// inside it.
@@ -275,23 +313,6 @@ class Analyzer(
aggsBuffer += e
e
case e if isPartOfAggregation(e) => e
- case e: GroupingID =>
- if (e.groupByExprs.isEmpty || e.groupByExprs == x.groupByExprs) {
- gid
- } else {
- throw new AnalysisException(
- s"Columns of grouping_id (${e.groupByExprs.mkString(",")}) does not match " +
- s"grouping columns (${x.groupByExprs.mkString(",")})")
- }
- case Grouping(col: Expression) =>
- val idx = x.groupByExprs.indexOf(col)
- if (idx >= 0) {
- Cast(BitwiseAnd(ShiftRight(gid, Literal(x.groupByExprs.length - 1 - idx)),
- Literal(1)), ByteType)
- } else {
- throw new AnalysisException(s"Column of grouping ($col) can't be found " +
- s"in grouping columns ${x.groupByExprs.mkString(",")}")
- }
case e =>
val index = groupByAliases.indexWhere(_.child.semanticEquals(e))
if (index == -1) {
@@ -303,9 +324,37 @@ class Analyzer(
}
Aggregate(
- groupByAttributes :+ VirtualColumn.groupingIdAttribute,
+ groupByAttributes :+ gid,
aggregations,
Expand(x.bitmasks, groupByAliases, groupByAttributes, gid, x.child))
+
+ case f @ Filter(cond, child) if hasGroupingFunction(cond) =>
+ val groupingExprs = findGroupingExprs(child)
+ // The unresolved grouping id will be resolved by ResolveMissingReferences
+ val newCond = replaceGroupingFunc(cond, groupingExprs, VirtualColumn.groupingIdAttribute)
+ f.copy(condition = newCond)
+
+ case s @ Sort(order, _, child) if order.exists(hasGroupingFunction) =>
+ val groupingExprs = findGroupingExprs(child)
+ val gid = VirtualColumn.groupingIdAttribute
+ // The unresolved grouping id will be resolved by ResolveMissingReferences
+ val newOrder = order.map(replaceGroupingFunc(_, groupingExprs, gid).asInstanceOf[SortOrder])
+ s.copy(order = newOrder)
+ }
+
+ private def findGroupingExprs(plan: LogicalPlan): Seq[Expression] = {
+ plan.collectFirst {
+ case a: Aggregate =>
+ // this Aggregate should have grouping id as the last grouping key.
+ val gid = a.groupingExpressions.last
+ if (!gid.isInstanceOf[AttributeReference]
+ || gid.asInstanceOf[AttributeReference].name != VirtualColumn.groupingIdName) {
+ failAnalysis(s"grouping()/grouping_id() can only be used with GroupingSets/Cube/Rollup")
+ }
+ a.groupingExpressions.take(a.groupingExpressions.length - 1)
+ }.getOrElse {
+ failAnalysis(s"grouping()/grouping_id() can only be used with GroupingSets/Cube/Rollup")
+ }
}
}
@@ -329,6 +378,11 @@ class Analyzer(
Last(ifExpr(expr), Literal(true))
case a: AggregateFunction =>
a.withNewChildren(a.children.map(ifExpr))
+ }.transform {
+ // We are duplicating aggregates that are now computing a different value for each
+ // pivot value.
+ // TODO: Don't construct the physical container until after analysis.
+ case ae: AggregateExpression => ae.copy(resultId = NamedExpression.newExprId)
}
if (filteredAggregate.fastEquals(aggregate)) {
throw new AnalysisException(
@@ -355,7 +409,7 @@ class Analyzer(
catalog.lookupRelation(u.tableIdentifier, u.alias)
} catch {
case _: NoSuchTableException =>
- u.failAnalysis(s"Table not found: ${u.tableName}")
+ u.failAnalysis(s"Table or View not found: ${u.tableName}")
}
}
@@ -487,18 +541,9 @@ class Analyzer(
Generate(newG.asInstanceOf[Generator], join, outer, qualifier, output, child)
}
- // A special case for ObjectOperator, because the deserializer expressions in ObjectOperator
- // should be resolved by their corresponding attributes instead of children's output.
- case o: ObjectOperator if containsUnresolvedDeserializer(o.deserializers.map(_._1)) =>
- val deserializerToAttributes = o.deserializers.map {
- case (deserializer, attributes) => new TreeNodeRef(deserializer) -> attributes
- }.toMap
-
- o.transformExpressions {
- case expr => deserializerToAttributes.get(new TreeNodeRef(expr)).map { attributes =>
- resolveDeserializer(expr, attributes)
- }.getOrElse(expr)
- }
+ // Skips plan which contains deserializer expressions, as they should be resolved by another
+ // rule: ResolveDeserializer.
+ case plan if containsDeserializer(plan.expressions) => plan
case q: LogicalPlan =>
logTrace(s"Attempting to resolve ${q.simpleString}")
@@ -514,38 +559,6 @@ class Analyzer(
}
}
- private def containsUnresolvedDeserializer(exprs: Seq[Expression]): Boolean = {
- exprs.exists { expr =>
- !expr.resolved || expr.find(_.isInstanceOf[BoundReference]).isDefined
- }
- }
-
- def resolveDeserializer(
- deserializer: Expression,
- attributes: Seq[Attribute]): Expression = {
- val unbound = deserializer transform {
- case b: BoundReference => attributes(b.ordinal)
- }
-
- resolveExpression(unbound, LocalRelation(attributes), throws = true) transform {
- case n: NewInstance
- // If this is an inner class of another class, register the outer object in `OuterScopes`.
- // Note that static inner classes (e.g., inner classes within Scala objects) don't need
- // outer pointer registration.
- if n.outerPointer.isEmpty &&
- n.cls.isMemberClass &&
- !Modifier.isStatic(n.cls.getModifiers) =>
- val outer = OuterScopes.getOuterScope(n.cls)
- if (outer == null) {
- throw new AnalysisException(
- s"Unable to generate an encoder for inner class `${n.cls.getName}` without " +
- "access to the scope that this class was defined in.\n" +
- "Try moving this class out of its parent class.")
- }
- n.copy(outerPointer = Some(outer))
- }
- }
-
def newAliases(expressions: Seq[NamedExpression]): Seq[NamedExpression] = {
expressions.map {
case a: Alias => Alias(a.child, a.name)(isGenerated = a.isGenerated)
@@ -611,6 +624,10 @@ class Analyzer(
}
}
+ private def containsDeserializer(exprs: Seq[Expression]): Boolean = {
+ exprs.exists(_.find(_.isInstanceOf[UnresolvedDeserializer]).isDefined)
+ }
+
protected[sql] def resolveExpression(
expr: Expression,
plan: LogicalPlan,
@@ -692,13 +709,15 @@ class Analyzer(
* clause. This rule detects such queries and adds the required attributes to the original
* projection, so that they will be available during sorting. Another projection is added to
* remove these attributes after sorting.
+ *
+   * The HAVING clause could also use grouping columns that are not present in the SELECT.
*/
- object ResolveSortReferences extends Rule[LogicalPlan] {
+ object ResolveMissingReferences extends Rule[LogicalPlan] {
def apply(plan: LogicalPlan): LogicalPlan = plan resolveOperators {
// Skip sort with aggregate. This will be handled in ResolveAggregateFunctions
case sa @ Sort(_, _, child: Aggregate) => sa
- case s @ Sort(order, _, child) if !s.resolved && child.resolved =>
+ case s @ Sort(order, _, child) if child.resolved =>
try {
val newOrder = order.map(resolveExpressionRecursively(_, child).asInstanceOf[SortOrder])
val requiredAttrs = AttributeSet(newOrder).filter(_.resolved)
@@ -718,12 +737,32 @@ class Analyzer(
// in Sort
case ae: AnalysisException => s
}
+
+ case f @ Filter(cond, child) if child.resolved =>
+ try {
+ val newCond = resolveExpressionRecursively(cond, child)
+ val requiredAttrs = newCond.references.filter(_.resolved)
+ val missingAttrs = requiredAttrs -- child.outputSet
+ if (missingAttrs.nonEmpty) {
+ // Add missing attributes and then project them away.
+ Project(child.output,
+ Filter(newCond, addMissingAttr(child, missingAttrs)))
+ } else if (newCond != cond) {
+ f.copy(condition = newCond)
+ } else {
+ f
+ }
+ } catch {
+ // Attempting to resolve it might fail. When this happens, return the original plan.
+ // Users will see an AnalysisException for resolution failure of missing attributes
+ case ae: AnalysisException => f
+ }
}
/**
- * Add the missing attributes into projectList of Project/Window or aggregateExpressions of
- * Aggregate.
- */
+ * Add the missing attributes into projectList of Project/Window or aggregateExpressions of
+ * Aggregate.
+ */
private def addMissingAttr(plan: LogicalPlan, missingAttrs: AttributeSet): LogicalPlan = {
if (missingAttrs.isEmpty) {
return plan
@@ -755,9 +794,9 @@ class Analyzer(
}
/**
- * Resolve the expression on a specified logical plan and it's child (recursively), until
- * the expression is resolved or meet a non-unary node or Subquery.
- */
+    * Resolve the expression on a specified logical plan and its child (recursively), until
+    * the expression is resolved or meets a non-unary node or Subquery.
+ */
@tailrec
private def resolveExpressionRecursively(expr: Expression, plan: LogicalPlan): Expression = {
val resolved = resolveExpression(expr, plan)
@@ -781,9 +820,18 @@ class Analyzer(
case q: LogicalPlan =>
q transformExpressions {
case u if !u.childrenResolved => u // Skip until children are resolved.
+ case u @ UnresolvedGenerator(name, children) =>
+ withPosition(u) {
+ catalog.lookupFunction(name, children) match {
+ case generator: Generator => generator
+ case other =>
+ failAnalysis(s"$name is expected to be a generator. However, " +
+ s"its class is ${other.getClass.getCanonicalName}, which is not a generator.")
+ }
+ }
case u @ UnresolvedFunction(name, children, isDistinct) =>
withPosition(u) {
- registry.lookupFunction(name, children) match {
+ catalog.lookupFunction(name, children) match {
// DISTINCT is not meaningful for a Max or a Min.
case max: Max if isDistinct =>
AggregateExpression(max, Complete, isDistinct = false)
@@ -863,27 +911,33 @@ class Analyzer(
if aggregate.resolved =>
// Try resolving the condition of the filter as though it is in the aggregate clause
- val aggregatedCondition =
- Aggregate(
- grouping,
- Alias(havingCondition, "havingCondition")(isGenerated = true) :: Nil,
- child)
- val resolvedOperator = execute(aggregatedCondition)
- def resolvedAggregateFilter =
- resolvedOperator
- .asInstanceOf[Aggregate]
- .aggregateExpressions.head
-
- // If resolution was successful and we see the filter has an aggregate in it, add it to
- // the original aggregate operator.
- if (resolvedOperator.resolved && containsAggregate(resolvedAggregateFilter)) {
- val aggExprsWithHaving = resolvedAggregateFilter +: originalAggExprs
-
- Project(aggregate.output,
- Filter(resolvedAggregateFilter.toAttribute,
- aggregate.copy(aggregateExpressions = aggExprsWithHaving)))
- } else {
- filter
+ try {
+ val aggregatedCondition =
+ Aggregate(
+ grouping,
+ Alias(havingCondition, "havingCondition")(isGenerated = true) :: Nil,
+ child)
+ val resolvedOperator = execute(aggregatedCondition)
+ def resolvedAggregateFilter =
+ resolvedOperator
+ .asInstanceOf[Aggregate]
+ .aggregateExpressions.head
+
+ // If resolution was successful and we see the filter has an aggregate in it, add it to
+ // the original aggregate operator.
+ if (resolvedOperator.resolved && containsAggregate(resolvedAggregateFilter)) {
+ val aggExprsWithHaving = resolvedAggregateFilter +: originalAggExprs
+
+ Project(aggregate.output,
+ Filter(resolvedAggregateFilter.toAttribute,
+ aggregate.copy(aggregateExpressions = aggExprsWithHaving)))
+ } else {
+ filter
+ }
+ } catch {
+ // Attempting to resolve in the aggregate can result in ambiguity. When this happens,
+ // just return the original plan.
+ case ae: AnalysisException => filter
}
case sort @ Sort(sortOrder, global, aggregate: Aggregate) if aggregate.resolved =>
@@ -947,11 +1001,8 @@ class Analyzer(
}
}
- private def isAggregateExpression(e: Expression): Boolean = {
- e.isInstanceOf[AggregateExpression] || e.isInstanceOf[Grouping] || e.isInstanceOf[GroupingID]
- }
def containsAggregate(condition: Expression): Boolean = {
- condition.find(isAggregateExpression).isDefined
+ condition.find(_.isInstanceOf[AggregateExpression]).isDefined
}
}
@@ -1146,11 +1197,11 @@ class Analyzer(
// Extract Windowed AggregateExpression
case we @ WindowExpression(
- AggregateExpression(function, mode, isDistinct),
+ ae @ AggregateExpression(function, _, _, _),
spec: WindowSpecDefinition) =>
val newChildren = function.children.map(extractExpr)
val newFunction = function.withNewChildren(newChildren).asInstanceOf[AggregateFunction]
- val newAgg = AggregateExpression(newFunction, mode, isDistinct)
+ val newAgg = ae.copy(aggregateFunction = newFunction)
seenWindowAggregates += newAgg
WindowExpression(newAgg, spec)
@@ -1386,8 +1437,8 @@ class Analyzer(
}
/**
- * Check and add order to [[AggregateWindowFunction]]s.
- */
+ * Check and add order to [[AggregateWindowFunction]]s.
+ */
object ResolveWindowOrder extends Rule[LogicalPlan] {
def apply(plan: LogicalPlan): LogicalPlan = plan transform {
case logical: LogicalPlan => logical transformExpressions {
@@ -1444,7 +1495,7 @@ class Analyzer(
val projectList = joinType match {
case LeftOuter =>
leftKeys ++ lUniqueOutput ++ rUniqueOutput.map(_.withNullability(true))
- case LeftSemi =>
+ case LeftExistence(_) =>
leftKeys ++ lUniqueOutput
case RightOuter =>
rightKeys ++ lUniqueOutput.map(_.withNullability(true)) ++ rUniqueOutput
@@ -1463,7 +1514,94 @@ class Analyzer(
Project(projectList, Join(left, right, joinType, newCondition))
}
+ /**
+ * Replaces [[UnresolvedDeserializer]] with the deserialization expression that has been resolved
+ * to the given input attributes.
+ */
+ object ResolveDeserializer extends Rule[LogicalPlan] {
+ def apply(plan: LogicalPlan): LogicalPlan = plan resolveOperators {
+ case p if !p.childrenResolved => p
+ case p if p.resolved => p
+ case p => p transformExpressions {
+ case UnresolvedDeserializer(deserializer, inputAttributes) =>
+ val inputs = if (inputAttributes.isEmpty) {
+ p.children.flatMap(_.output)
+ } else {
+ inputAttributes
+ }
+ val unbound = deserializer transform {
+ case b: BoundReference => inputs(b.ordinal)
+ }
+ resolveExpression(unbound, LocalRelation(inputs), throws = true)
+ }
+ }
+ }
+
+ /**
+ * Resolves [[NewInstance]] by finding and adding the outer scope to it if the object being
+ * constructed is an inner class.
+ */
+ object ResolveNewInstance extends Rule[LogicalPlan] {
+ def apply(plan: LogicalPlan): LogicalPlan = plan resolveOperators {
+ case p if !p.childrenResolved => p
+ case p if p.resolved => p
+
+ case p => p transformExpressions {
+ case n: NewInstance if n.childrenResolved && !n.resolved =>
+ val outer = OuterScopes.getOuterScope(n.cls)
+ if (outer == null) {
+ throw new AnalysisException(
+ s"Unable to generate an encoder for inner class `${n.cls.getName}` without " +
+ "access to the scope that this class was defined in.\n" +
+ "Try moving this class out of its parent class.")
+ }
+ n.copy(outerPointer = Some(outer))
+ }
+ }
+ }
+
+ /**
+ * Replace the [[UpCast]] expression by [[Cast]], and throw exceptions if the cast may truncate.
+ */
+ object ResolveUpCast extends Rule[LogicalPlan] {
+ private def fail(from: Expression, to: DataType, walkedTypePath: Seq[String]) = {
+ throw new AnalysisException(s"Cannot up cast ${from.sql} from " +
+ s"${from.dataType.simpleString} to ${to.simpleString} as it may truncate\n" +
+ "The type path of the target object is:\n" + walkedTypePath.mkString("", "\n", "\n") +
+ "You can either add an explicit cast to the input data or choose a higher precision " +
+ "type of the field in the target object")
+ }
+
+ private def illegalNumericPrecedence(from: DataType, to: DataType): Boolean = {
+ val fromPrecedence = HiveTypeCoercion.numericPrecedence.indexOf(from)
+ val toPrecedence = HiveTypeCoercion.numericPrecedence.indexOf(to)
+ toPrecedence > 0 && fromPrecedence > toPrecedence
+ }
+
+ def apply(plan: LogicalPlan): LogicalPlan = plan resolveOperators {
+ case p if !p.childrenResolved => p
+ case p if p.resolved => p
+
+ case p => p transformExpressions {
+ case u @ UpCast(child, _, _) if !child.resolved => u
+
+ case UpCast(child, dataType, walkedTypePath) => (child.dataType, dataType) match {
+ case (from: NumericType, to: DecimalType) if !to.isWiderThan(from) =>
+ fail(child, to, walkedTypePath)
+ case (from: DecimalType, to: NumericType) if !from.isTighterThan(to) =>
+ fail(child, to, walkedTypePath)
+ case (from, to) if illegalNumericPrecedence(from, to) =>
+ fail(child, to, walkedTypePath)
+ case (TimestampType, DateType) =>
+ fail(child, DateType, walkedTypePath)
+ case (StringType, to: NumericType) =>
+ fail(child, to, walkedTypePath)
+ case _ => Cast(child, dataType.asNullable)
+ }
+ }
+ }
+ }
}
/**
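At the user level the relocated ResolveUpCast rule behaves as before. A sketch of the effect, assuming a SQLContext named sqlContext with its implicits imported; the Dataset calls are standard API, and the error text is paraphrased from the fail() message above.

import sqlContext.implicits._

val df = sqlContext.range(10)     // single column `id` of type bigint (LongType)

val longs = df.as[Long]           // bigint -> Long needs no narrowing, analyzes fine
// df.as[Int]                     // would need LongType -> IntegerType, which may truncate,
//                                // so ResolveUpCast fails analysis ("Cannot up cast ... as it may truncate")
val ints = df.selectExpr("cast(id as int) as id").as[Int]   // an explicit cast opts in and is accepted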
@@ -1477,8 +1615,8 @@ object EliminateSubqueryAliases extends Rule[LogicalPlan] {
}
/**
- * Removes [[Union]] operators from the plan if it just has one child.
- */
+ * Removes [[Union]] operators from the plan if it just has one child.
+ */
object EliminateUnions extends Rule[LogicalPlan] {
def apply(plan: LogicalPlan): LogicalPlan = plan transform {
case Union(children) if children.size == 1 => children.head
@@ -1532,6 +1670,8 @@ object CleanupAliases extends Rule[LogicalPlan] {
// Operators that operate on objects should only have expressions from encoders, which should
// never have extra aliases.
case o: ObjectOperator => o
+ case d: DeserializeToObject => d
+ case s: SerializeFromObject => s
case other =>
var stop = false
@@ -1548,40 +1688,90 @@ object CleanupAliases extends Rule[LogicalPlan] {
}
/**
- * Replace the `UpCast` expression by `Cast`, and throw exceptions if the cast may truncate.
+ * Maps a time column to multiple time windows using the Expand operator. Since it's non-trivial to
+ * figure out how many windows a time column can map to, we over-estimate the number of windows and
+ * filter out the rows where the time column is not inside the time window.
*/
-object ResolveUpCast extends Rule[LogicalPlan] {
- private def fail(from: Expression, to: DataType, walkedTypePath: Seq[String]) = {
- throw new AnalysisException(s"Cannot up cast ${from.sql} from " +
- s"${from.dataType.simpleString} to ${to.simpleString} as it may truncate\n" +
- "The type path of the target object is:\n" + walkedTypePath.mkString("", "\n", "\n") +
- "You can either add an explicit cast to the input data or choose a higher precision " +
- "type of the field in the target object")
- }
+object TimeWindowing extends Rule[LogicalPlan] {
+ import org.apache.spark.sql.catalyst.dsl.expressions._
- private def illegalNumericPrecedence(from: DataType, to: DataType): Boolean = {
- val fromPrecedence = HiveTypeCoercion.numericPrecedence.indexOf(from)
- val toPrecedence = HiveTypeCoercion.numericPrecedence.indexOf(to)
- toPrecedence > 0 && fromPrecedence > toPrecedence
- }
+ private final val WINDOW_START = "start"
+ private final val WINDOW_END = "end"
- def apply(plan: LogicalPlan): LogicalPlan = {
- plan transformAllExpressions {
- case u @ UpCast(child, _, _) if !child.resolved => u
-
- case UpCast(child, dataType, walkedTypePath) => (child.dataType, dataType) match {
- case (from: NumericType, to: DecimalType) if !to.isWiderThan(from) =>
- fail(child, to, walkedTypePath)
- case (from: DecimalType, to: NumericType) if !from.isTighterThan(to) =>
- fail(child, to, walkedTypePath)
- case (from, to) if illegalNumericPrecedence(from, to) =>
- fail(child, to, walkedTypePath)
- case (TimestampType, DateType) =>
- fail(child, DateType, walkedTypePath)
- case (StringType, to: NumericType) =>
- fail(child, to, walkedTypePath)
- case _ => Cast(child, dataType.asNullable)
+ /**
+ * Generates the logical plan for generating window ranges on a timestamp column. Without
+ * knowing what the timestamp value is, it's non-trivial to figure out deterministically how many
+ * window ranges a timestamp will map to given all possible combinations of a window duration,
+ * slide duration and start time (offset). Therefore, we over-estimate the number of windows
+ * there may be, and filter out the invalid ones. We use a final Project operator to group
+ * the window columns into a struct so they can be accessed as `window.start` and `window.end`.
+ *
+ * The windows are calculated as below:
+ * maxNumOverlapping <- ceil(windowDuration / slideDuration)
+ * for (i <- 0 until maxNumOverlapping)
+ * windowId <- ceil((timestamp - startTime) / slideDuration)
+ * windowStart <- windowId * slideDuration + (i - maxNumOverlapping) * slideDuration + startTime
+ * windowEnd <- windowStart + windowDuration
+ * return windowStart, windowEnd
+ *
+ * For example, with the parameters below and a timestamp of 12:05, the valid windows are
+ * marked with a +, and invalid ones are marked with an x. The invalid ones are filtered out by the
+ * Filter operator.
+ * window: 12m, slide: 5m, start: 0m :: window: 12m, slide: 5m, start: 2m
+ * 11:55 - 12:07 + 11:52 - 12:04 x
+ * 12:00 - 12:12 + 11:57 - 12:09 +
+ * 12:05 - 12:17 + 12:02 - 12:14 +
+ *
+ * @param plan The logical plan
+ * @return the logical plan that will generate the time windows using the Expand operator, with
+ * the Filter operator for correctness and Project for usability.
+ */
+ def apply(plan: LogicalPlan): LogicalPlan = plan resolveOperators {
+ case p: LogicalPlan if p.children.size == 1 =>
+ val child = p.children.head
+ val windowExpressions =
+ p.expressions.flatMap(_.collect { case t: TimeWindow => t }).distinct.toList // Not correct.
+
+ // Only support a single window expression for now
+ if (windowExpressions.size == 1 &&
+ windowExpressions.head.timeColumn.resolved &&
+ windowExpressions.head.checkInputDataTypes().isSuccess) {
+ val window = windowExpressions.head
+ val windowAttr = AttributeReference("window", window.dataType)()
+
+ val maxNumOverlapping = math.ceil(window.windowDuration * 1.0 / window.slideDuration).toInt
+ val windows = Seq.tabulate(maxNumOverlapping + 1) { i =>
+ val windowId = Ceil((PreciseTimestamp(window.timeColumn) - window.startTime) /
+ window.slideDuration)
+ val windowStart = (windowId + i - maxNumOverlapping) *
+ window.slideDuration + window.startTime
+ val windowEnd = windowStart + window.windowDuration
+
+ CreateNamedStruct(
+ Literal(WINDOW_START) :: windowStart ::
+ Literal(WINDOW_END) :: windowEnd :: Nil)
+ }
+
+ val projections = windows.map(_ +: p.children.head.output)
+
+ val filterExpr =
+ window.timeColumn >= windowAttr.getField(WINDOW_START) &&
+ window.timeColumn < windowAttr.getField(WINDOW_END)
+
+ val expandedPlan =
+ Filter(filterExpr,
+ Expand(projections, windowAttr +: child.output, child))
+
+ val substitutedPlan = p transformExpressions {
+ case t: TimeWindow => windowAttr
+ }
+
+ substitutedPlan.withNewChildren(expandedPlan :: Nil)
+ } else if (windowExpressions.size > 1) {
+ p.failAnalysis("Multiple time window expressions would result in a cartesian product " +
+ "of rows, therefore they are not currently not supported.")
+ } else {
+ p // Return unchanged. The analyzer will throw an exception later.
}
- }
}
}
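
To make the window arithmetic in the TimeWindowing comment above concrete, here is a standalone sketch (plain Scala, not part of this patch; all names are illustrative) that reproduces the maxNumOverlapping / windowId calculation for the 12:05 example with a 12-minute window, 5-minute slide and 0-minute start:

    // Standalone sketch: window candidates for one timestamp, all values in minutes since midnight.
    object TimeWindowSketch {
      def windowsFor(ts: Long, windowDuration: Long, slideDuration: Long, startTime: Long): Seq[(Long, Long)] = {
        val maxNumOverlapping = math.ceil(windowDuration.toDouble / slideDuration).toInt
        val windowId = math.ceil((ts - startTime).toDouble / slideDuration).toLong
        // One candidate per possible overlap; invalid candidates are dropped, mirroring the
        // Filter operator that the rule places on top of Expand.
        (0 to maxNumOverlapping).map { i =>
          val windowStart = (windowId + i - maxNumOverlapping) * slideDuration + startTime
          (windowStart, windowStart + windowDuration)
        }.filter { case (start, end) => ts >= start && ts < end }
      }

      def main(args: Array[String]): Unit = {
        // 12:05 expressed as minutes since midnight.
        println(windowsFor(ts = 725, windowDuration = 12, slideDuration = 5, startTime = 0))
      }
    }

Running this prints the three '+' windows from the comment (11:55-12:07, 12:00-12:12, 12:05-12:17) and drops the extra over-estimated candidate, just as the generated Filter does.
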
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala
index 1d1e892e32..d6a8c3eec8 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala
@@ -52,7 +52,7 @@ trait CheckAnalysis {
case p if p.analyzed => // Skip already analyzed sub-plans
case u: UnresolvedRelation =>
- u.failAnalysis(s"Table not found: ${u.tableIdentifier}")
+ u.failAnalysis(s"Table or View not found: ${u.tableIdentifier}")
case operator: LogicalPlan =>
operator transformExpressionsUp {
@@ -76,7 +76,7 @@ trait CheckAnalysis {
case g: GroupingID =>
failAnalysis(s"grouping_id() can only be used with GroupingSets/Cube/Rollup")
- case w @ WindowExpression(AggregateExpression(_, _, true), _) =>
+ case w @ WindowExpression(AggregateExpression(_, _, true, _), _) =>
failAnalysis(s"Distinct window functions are not supported: $w")
case w @ WindowExpression(_: OffsetWindowFunction, WindowSpecDefinition(_, order,
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala
index f584a4b73a..f2abf136da 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala
@@ -45,6 +45,19 @@ trait FunctionRegistry {
/* Get the class of the registered function by specified name. */
def lookupFunction(name: String): Option[ExpressionInfo]
+
+ /* Get the builder of the registered function by specified name. */
+ def lookupFunctionBuilder(name: String): Option[FunctionBuilder]
+
+ /** Drop a function and return whether the function existed. */
+ def dropFunction(name: String): Boolean
+
+ /** Checks if a function with a given name exists. */
+ def functionExists(name: String): Boolean = lookupFunction(name).isDefined
+
+ /** Clear all registered functions. */
+ def clear(): Unit
+
}
class SimpleFunctionRegistry extends FunctionRegistry {
@@ -76,6 +89,18 @@ class SimpleFunctionRegistry extends FunctionRegistry {
functionBuilders.get(name).map(_._1)
}
+ override def lookupFunctionBuilder(name: String): Option[FunctionBuilder] = synchronized {
+ functionBuilders.get(name).map(_._2)
+ }
+
+ override def dropFunction(name: String): Boolean = synchronized {
+ functionBuilders.remove(name).isDefined
+ }
+
+ override def clear(): Unit = {
+ functionBuilders.clear()
+ }
+
def copy(): SimpleFunctionRegistry = synchronized {
val registry = new SimpleFunctionRegistry
functionBuilders.iterator.foreach { case (name, (info, builder)) =>
@@ -106,6 +131,19 @@ object EmptyFunctionRegistry extends FunctionRegistry {
override def lookupFunction(name: String): Option[ExpressionInfo] = {
throw new UnsupportedOperationException
}
+
+ override def lookupFunctionBuilder(name: String): Option[FunctionBuilder] = {
+ throw new UnsupportedOperationException
+ }
+
+ override def dropFunction(name: String): Boolean = {
+ throw new UnsupportedOperationException
+ }
+
+ override def clear(): Unit = {
+ throw new UnsupportedOperationException
+ }
+
}
@@ -133,6 +171,7 @@ object FunctionRegistry {
expression[Rand]("rand"),
expression[Randn]("randn"),
expression[CreateStruct]("struct"),
+ expression[CaseWhen]("when"),
// math functions
expression[Acos]("acos"),
@@ -179,6 +218,12 @@ object FunctionRegistry {
expression[Tan]("tan"),
expression[Tanh]("tanh"),
+ expression[Add]("+"),
+ expression[Subtract]("-"),
+ expression[Multiply]("*"),
+ expression[Divide]("/"),
+ expression[Remainder]("%"),
+
// aggregate functions
expression[HyperLogLogPlusPlus]("approx_count_distinct"),
expression[Average]("avg"),
@@ -219,6 +264,7 @@ object FunctionRegistry {
expression[Lower]("lcase"),
expression[Length]("length"),
expression[Levenshtein]("levenshtein"),
+ expression[Like]("like"),
expression[Lower]("lower"),
expression[StringLocate]("locate"),
expression[StringLPad]("lpad"),
@@ -229,6 +275,7 @@ object FunctionRegistry {
expression[RegExpReplace]("regexp_replace"),
expression[StringRepeat]("repeat"),
expression[StringReverse]("reverse"),
+ expression[RLike]("rlike"),
expression[StringRPad]("rpad"),
expression[StringTrimRight]("rtrim"),
expression[SoundEx]("soundex"),
@@ -273,6 +320,7 @@ object FunctionRegistry {
expression[UnixTimestamp]("unix_timestamp"),
expression[WeekOfYear]("weekofyear"),
expression[Year]("year"),
+ expression[TimeWindow]("window"),
// collection functions
expression[ArrayContains]("array_contains"),
@@ -304,7 +352,29 @@ object FunctionRegistry {
expression[NTile]("ntile"),
expression[Rank]("rank"),
expression[DenseRank]("dense_rank"),
- expression[PercentRank]("percent_rank")
+ expression[PercentRank]("percent_rank"),
+
+ // predicates
+ expression[And]("and"),
+ expression[In]("in"),
+ expression[Not]("not"),
+ expression[Or]("or"),
+
+ expression[EqualNullSafe]("<=>"),
+ expression[EqualTo]("="),
+ expression[EqualTo]("=="),
+ expression[GreaterThan](">"),
+ expression[GreaterThanOrEqual](">="),
+ expression[LessThan]("<"),
+ expression[LessThanOrEqual]("<="),
+ expression[Not]("!"),
+
+ // bitwise
+ expression[BitwiseAnd]("&"),
+ expression[BitwiseNot]("~"),
+ expression[BitwiseOr]("|"),
+ expression[BitwiseXor]("^")
+
)
val builtin: SimpleFunctionRegistry = {
@@ -337,7 +407,10 @@ object FunctionRegistry {
}
Try(f.newInstance(expressions : _*).asInstanceOf[Expression]) match {
case Success(e) => e
- case Failure(e) => throw new AnalysisException(e.getMessage)
+ case Failure(e) =>
+ // The exception is an invocation exception. To get a meaningful message, we need the
+ // cause.
+ throw new AnalysisException(e.getCause.getMessage)
}
}
}
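
Because this hunk also registers arithmetic operators, predicates and bitwise operators under their symbolic names, they can now be resolved through the registry like any other function. A minimal sketch, assuming the builtin registry and the catalyst Literal/Add expressions from this code base:

    import org.apache.spark.sql.catalyst.analysis.FunctionRegistry
    import org.apache.spark.sql.catalyst.expressions.Literal

    // Resolves "+" through the reflective FunctionBuilder created by expression[Add]("+") above;
    // the result is an Add(Literal(1), Literal(2)) expression.
    val plus = FunctionRegistry.builtin.lookupFunction("+", Seq(Literal(1), Literal(2)))
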
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/NoSuchItemException.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/NoSuchItemException.scala
index e9f04eecf8..5e18316c94 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/NoSuchItemException.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/NoSuchItemException.scala
@@ -17,6 +17,7 @@
package org.apache.spark.sql.catalyst.analysis
+import org.apache.spark.sql.AnalysisException
import org.apache.spark.sql.catalyst.catalog.ExternalCatalog.TablePartitionSpec
@@ -24,29 +25,13 @@ import org.apache.spark.sql.catalyst.catalog.ExternalCatalog.TablePartitionSpec
* Thrown by a catalog when an item cannot be found. The analyzer will rethrow the exception
* as an [[org.apache.spark.sql.AnalysisException]] with the correct position information.
*/
-abstract class NoSuchItemException extends Exception {
- override def getMessage: String
-}
+class NoSuchDatabaseException(db: String) extends AnalysisException(s"Database $db not found")
-class NoSuchDatabaseException(db: String) extends NoSuchItemException {
- override def getMessage: String = s"Database $db not found"
-}
+class NoSuchTableException(db: String, table: String)
+ extends AnalysisException(s"Table or View $table not found in database $db")
-class NoSuchTableException(db: String, table: String) extends NoSuchItemException {
- override def getMessage: String = s"Table $table not found in database $db"
-}
+class NoSuchPartitionException(db: String, table: String, spec: TablePartitionSpec) extends
+ AnalysisException(s"Partition not found in table $table database $db:\n" + spec.mkString("\n"))
-class NoSuchPartitionException(
- db: String,
- table: String,
- spec: TablePartitionSpec)
- extends NoSuchItemException {
-
- override def getMessage: String = {
- s"Partition not found in table $table database $db:\n" + spec.mkString("\n")
- }
-}
-
-class NoSuchFunctionException(db: String, func: String) extends NoSuchItemException {
- override def getMessage: String = s"Function $func not found in database $db"
-}
+class NoSuchFunctionException(db: String, func: String)
+ extends AnalysisException(s"Function $func not found in database $db")
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/unresolved.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/unresolved.scala
index e73d367a73..4ec43aba02 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/unresolved.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/unresolved.scala
@@ -18,9 +18,9 @@
package org.apache.spark.sql.catalyst.analysis
import org.apache.spark.sql.AnalysisException
-import org.apache.spark.sql.catalyst.{errors, TableIdentifier}
+import org.apache.spark.sql.catalyst.{errors, InternalRow, TableIdentifier}
import org.apache.spark.sql.catalyst.expressions._
-import org.apache.spark.sql.catalyst.expressions.codegen.CodegenFallback
+import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, CodegenFallback, ExprCode}
import org.apache.spark.sql.catalyst.plans.logical.{LeafNode, LogicalPlan}
import org.apache.spark.sql.catalyst.trees.TreeNode
import org.apache.spark.sql.catalyst.util.quoteIdentifier
@@ -133,6 +133,33 @@ object UnresolvedAttribute {
}
}
+/**
+ * Represents an unresolved generator, which will be created by the parser for
+ * the [[org.apache.spark.sql.catalyst.plans.logical.Generate]] operator.
+ * The analyzer will resolve this generator.
+ */
+case class UnresolvedGenerator(name: String, children: Seq[Expression]) extends Generator {
+
+ override def elementTypes: Seq[(DataType, Boolean, String)] =
+ throw new UnresolvedException(this, "elementTypes")
+ override def dataType: DataType = throw new UnresolvedException(this, "dataType")
+ override def foldable: Boolean = throw new UnresolvedException(this, "foldable")
+ override def nullable: Boolean = throw new UnresolvedException(this, "nullable")
+ override lazy val resolved = false
+
+ override def prettyName: String = name
+ override def toString: String = s"'$name(${children.mkString(", ")})"
+
+ override def eval(input: InternalRow = null): TraversableOnce[InternalRow] =
+ throw new UnsupportedOperationException(s"Cannot evaluate expression: $this")
+
+ override protected def genCode(ctx: CodegenContext, ev: ExprCode): String =
+ throw new UnsupportedOperationException(s"Cannot evaluate expression: $this")
+
+ override def terminate(): TraversableOnce[InternalRow] =
+ throw new UnsupportedOperationException(s"Cannot evaluate expression: $this")
+}
+
case class UnresolvedFunction(
name: String,
children: Seq[Expression],
@@ -307,3 +334,25 @@ case class UnresolvedAlias(child: Expression, aliasName: Option[String] = None)
override lazy val resolved = false
}
+
+/**
+ * Holds the deserializer expression and the attributes that are available during the resolution
+ * for it. Deserializer expression is a special kind of expression that is not always resolved by
+ * children output, but by given attributes, e.g. the `keyDeserializer` in `MapGroups` should be
+ * resolved by `groupingAttributes` instead of children output.
+ *
+ * @param deserializer The unresolved deserializer expression
+ * @param inputAttributes The input attributes used to resolve deserializer expression, can be empty
+ * if we want to resolve deserializer by children output.
+ */
+case class UnresolvedDeserializer(deserializer: Expression, inputAttributes: Seq[Attribute] = Nil)
+ extends UnaryExpression with Unevaluable with NonSQLExpression {
+ // The input attributes used to resolve deserializer expression must be all resolved.
+ require(inputAttributes.forall(_.resolved), "Input attributes must all be resolved.")
+
+ override def child: Expression = deserializer
+ override def dataType: DataType = throw new UnresolvedException(this, "dataType")
+ override def foldable: Boolean = throw new UnresolvedException(this, "foldable")
+ override def nullable: Boolean = throw new UnresolvedException(this, "nullable")
+ override lazy val resolved = false
+}
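
A minimal sketch of how the new UnresolvedDeserializer is meant to be used; the attribute and the wrapped expression below are stand-ins, not names from this patch:

    import org.apache.spark.sql.catalyst.analysis.{UnresolvedAttribute, UnresolvedDeserializer}
    import org.apache.spark.sql.catalyst.expressions.AttributeReference
    import org.apache.spark.sql.types.StringType

    // A resolved attribute standing in for e.g. a grouping attribute of MapGroups.
    val groupingAttr = AttributeReference("key", StringType)()

    // The wrapped expression is resolved against `groupingAttr`, not the child's output;
    // the UnresolvedAttribute here is only a stand-in for a real deserializer expression.
    val keyDeserializer = UnresolvedDeserializer(UnresolvedAttribute("key"), Seq(groupingAttr))

    // With inputAttributes left empty, resolution falls back to the children's output.
    val valueDeserializer = UnresolvedDeserializer(UnresolvedAttribute("value"))
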
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/InMemoryCatalog.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/InMemoryCatalog.scala
index e216fa5528..f8a6fb74cc 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/InMemoryCatalog.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/InMemoryCatalog.scala
@@ -21,7 +21,7 @@ import scala.collection.mutable
import org.apache.spark.sql.AnalysisException
import org.apache.spark.sql.catalyst.{FunctionIdentifier, TableIdentifier}
-
+import org.apache.spark.sql.catalyst.util.StringUtils
/**
* An in-memory (ephemeral) implementation of the system catalog.
@@ -47,16 +47,6 @@ class InMemoryCatalog extends ExternalCatalog {
// Database name -> description
private val catalog = new scala.collection.mutable.HashMap[String, DatabaseDesc]
- private def filterPattern(names: Seq[String], pattern: String): Seq[String] = {
- val regex = pattern.replaceAll("\\*", ".*").r
- names.filter { funcName => regex.pattern.matcher(funcName).matches() }
- }
-
- private def functionExists(db: String, funcName: String): Boolean = {
- requireDbExists(db)
- catalog(db).functions.contains(funcName)
- }
-
private def partitionExists(db: String, table: String, spec: TablePartitionSpec): Boolean = {
requireTableExists(db, table)
catalog(db).tables(table).partitions.contains(spec)
@@ -72,7 +62,7 @@ class InMemoryCatalog extends ExternalCatalog {
private def requireTableExists(db: String, table: String): Unit = {
if (!tableExists(db, table)) {
throw new AnalysisException(
- s"Table not found: '$table' does not exist in database '$db'")
+ s"Table or View not found: '$table' does not exist in database '$db'")
}
}
@@ -141,7 +131,7 @@ class InMemoryCatalog extends ExternalCatalog {
}
override def listDatabases(pattern: String): Seq[String] = synchronized {
- filterPattern(listDatabases(), pattern)
+ StringUtils.filterPattern(listDatabases(), pattern)
}
override def setCurrentDatabase(db: String): Unit = { /* no-op */ }
@@ -155,7 +145,7 @@ class InMemoryCatalog extends ExternalCatalog {
tableDefinition: CatalogTable,
ignoreIfExists: Boolean): Unit = synchronized {
requireDbExists(db)
- val table = tableDefinition.name.table
+ val table = tableDefinition.identifier.table
if (tableExists(db, table)) {
if (!ignoreIfExists) {
throw new AnalysisException(s"Table '$table' already exists in database '$db'")
@@ -174,7 +164,7 @@ class InMemoryCatalog extends ExternalCatalog {
catalog(db).tables.remove(table)
} else {
if (!ignoreIfNotExists) {
- throw new AnalysisException(s"Table '$table' does not exist in database '$db'")
+ throw new AnalysisException(s"Table or View '$table' does not exist in database '$db'")
}
}
}
@@ -182,14 +172,14 @@ class InMemoryCatalog extends ExternalCatalog {
override def renameTable(db: String, oldName: String, newName: String): Unit = synchronized {
requireTableExists(db, oldName)
val oldDesc = catalog(db).tables(oldName)
- oldDesc.table = oldDesc.table.copy(name = TableIdentifier(newName, Some(db)))
+ oldDesc.table = oldDesc.table.copy(identifier = TableIdentifier(newName, Some(db)))
catalog(db).tables.put(newName, oldDesc)
catalog(db).tables.remove(oldName)
}
override def alterTable(db: String, tableDefinition: CatalogTable): Unit = synchronized {
- requireTableExists(db, tableDefinition.name.table)
- catalog(db).tables(tableDefinition.name.table).table = tableDefinition
+ requireTableExists(db, tableDefinition.identifier.table)
+ catalog(db).tables(tableDefinition.identifier.table).table = tableDefinition
}
override def getTable(db: String, table: String): CatalogTable = synchronized {
@@ -197,6 +187,10 @@ class InMemoryCatalog extends ExternalCatalog {
catalog(db).tables(table).table
}
+ override def getTableOption(db: String, table: String): Option[CatalogTable] = synchronized {
+ if (!tableExists(db, table)) None else Option(catalog(db).tables(table).table)
+ }
+
override def tableExists(db: String, table: String): Boolean = synchronized {
requireDbExists(db)
catalog(db).tables.contains(table)
@@ -208,7 +202,7 @@ class InMemoryCatalog extends ExternalCatalog {
}
override def listTables(db: String, pattern: String): Seq[String] = synchronized {
- filterPattern(listTables(db), pattern)
+ StringUtils.filterPattern(listTables(db), pattern)
}
// --------------------------------------------------------------------------
@@ -296,10 +290,10 @@ class InMemoryCatalog extends ExternalCatalog {
override def createFunction(db: String, func: CatalogFunction): Unit = synchronized {
requireDbExists(db)
- if (functionExists(db, func.name.funcName)) {
+ if (functionExists(db, func.identifier.funcName)) {
throw new AnalysisException(s"Function '$func' already exists in '$db' database")
} else {
- catalog(db).functions.put(func.name.funcName, func)
+ catalog(db).functions.put(func.identifier.funcName, func)
}
}
@@ -310,24 +304,24 @@ class InMemoryCatalog extends ExternalCatalog {
override def renameFunction(db: String, oldName: String, newName: String): Unit = synchronized {
requireFunctionExists(db, oldName)
- val newFunc = getFunction(db, oldName).copy(name = FunctionIdentifier(newName, Some(db)))
+ val newFunc = getFunction(db, oldName).copy(identifier = FunctionIdentifier(newName, Some(db)))
catalog(db).functions.remove(oldName)
catalog(db).functions.put(newName, newFunc)
}
- override def alterFunction(db: String, funcDefinition: CatalogFunction): Unit = synchronized {
- requireFunctionExists(db, funcDefinition.name.funcName)
- catalog(db).functions.put(funcDefinition.name.funcName, funcDefinition)
- }
-
override def getFunction(db: String, funcName: String): CatalogFunction = synchronized {
requireFunctionExists(db, funcName)
catalog(db).functions(funcName)
}
+ override def functionExists(db: String, funcName: String): Boolean = {
+ requireDbExists(db)
+ catalog(db).functions.contains(funcName)
+ }
+
override def listFunctions(db: String, pattern: String): Seq[String] = synchronized {
requireDbExists(db)
- filterPattern(catalog(db).functions.keysIterator.toSeq, pattern)
+ StringUtils.filterPattern(catalog(db).functions.keysIterator.toSeq, pattern)
}
}
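
The listDatabases/listTables/listFunctions call sites above now delegate to StringUtils.filterPattern, which is assumed to keep the glob-style semantics of the removed local helper: '*' becomes '.*' and a name is kept only if the whole name matches. A minimal sketch of that behaviour:

    // Same logic as the removed private filterPattern helper shown in this hunk.
    def filterPattern(names: Seq[String], pattern: String): Seq[String] = {
      val regex = pattern.replaceAll("\\*", ".*").r
      names.filter(name => regex.pattern.matcher(name).matches())
    }

    filterPattern(Seq("sales", "sales_2016", "users"), "sales*")  // Seq("sales", "sales_2016")
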
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/SessionCatalog.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/SessionCatalog.scala
index 34265faa74..34e1cb7315 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/SessionCatalog.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/SessionCatalog.scala
@@ -17,30 +17,47 @@
package org.apache.spark.sql.catalyst.catalog
-import java.util.concurrent.ConcurrentHashMap
+import java.io.File
-import scala.collection.JavaConverters._
+import scala.collection.mutable
+import org.apache.spark.internal.Logging
import org.apache.spark.sql.AnalysisException
import org.apache.spark.sql.catalyst.{CatalystConf, SimpleCatalystConf}
import org.apache.spark.sql.catalyst.{FunctionIdentifier, TableIdentifier}
+import org.apache.spark.sql.catalyst.analysis.{FunctionRegistry, NoSuchFunctionException, SimpleFunctionRegistry}
+import org.apache.spark.sql.catalyst.analysis.FunctionRegistry.FunctionBuilder
+import org.apache.spark.sql.catalyst.expressions.{Expression, ExpressionInfo}
import org.apache.spark.sql.catalyst.plans.logical.{LogicalPlan, SubqueryAlias}
-
+import org.apache.spark.sql.catalyst.util.StringUtils
/**
* An internal catalog that is used by a Spark Session. This internal catalog serves as a
* proxy to the underlying metastore (e.g. Hive Metastore) and it also manages temporary
* tables and functions of the Spark Session that it belongs to.
+ *
+ * This class is not thread-safe.
*/
-class SessionCatalog(externalCatalog: ExternalCatalog, conf: CatalystConf) {
+class SessionCatalog(
+ externalCatalog: ExternalCatalog,
+ functionResourceLoader: FunctionResourceLoader,
+ functionRegistry: FunctionRegistry,
+ conf: CatalystConf) extends Logging {
import ExternalCatalog._
+ def this(
+ externalCatalog: ExternalCatalog,
+ functionRegistry: FunctionRegistry,
+ conf: CatalystConf) {
+ this(externalCatalog, DummyFunctionResourceLoader, functionRegistry, conf)
+ }
+
+ // For testing only.
def this(externalCatalog: ExternalCatalog) {
- this(externalCatalog, new SimpleCatalystConf(true))
+ this(externalCatalog, new SimpleFunctionRegistry, new SimpleCatalystConf(true))
}
- protected[this] val tempTables = new ConcurrentHashMap[String, LogicalPlan]
- protected[this] val tempFunctions = new ConcurrentHashMap[String, CatalogFunction]
+ protected[this] val tempTables = new mutable.HashMap[String, LogicalPlan]
// Note: we track current database here because certain operations do not explicitly
// specify the database (e.g. DROP TABLE my_table). In these cases we must first
@@ -79,7 +96,7 @@ class SessionCatalog(externalCatalog: ExternalCatalog, conf: CatalystConf) {
externalCatalog.alterDatabase(dbDefinition)
}
- def getDatabase(db: String): CatalogDatabase = {
+ def getDatabaseMetadata(db: String): CatalogDatabase = {
externalCatalog.getDatabase(db)
}
@@ -104,6 +121,10 @@ class SessionCatalog(externalCatalog: ExternalCatalog, conf: CatalystConf) {
currentDb = db
}
+ def getDefaultDBPath(db: String): String = {
+ System.getProperty("java.io.tmpdir") + File.separator + db + ".db"
+ }
+
// ----------------------------------------------------------------------------
// Tables
// ----------------------------------------------------------------------------
@@ -122,9 +143,9 @@ class SessionCatalog(externalCatalog: ExternalCatalog, conf: CatalystConf) {
* If no such database is specified, create it in the current database.
*/
def createTable(tableDefinition: CatalogTable, ignoreIfExists: Boolean): Unit = {
- val db = tableDefinition.name.database.getOrElse(currentDb)
- val table = formatTableName(tableDefinition.name.table)
- val newTableDefinition = tableDefinition.copy(name = TableIdentifier(table, Some(db)))
+ val db = tableDefinition.identifier.database.getOrElse(currentDb)
+ val table = formatTableName(tableDefinition.identifier.table)
+ val newTableDefinition = tableDefinition.copy(identifier = TableIdentifier(table, Some(db)))
externalCatalog.createTable(db, newTableDefinition, ignoreIfExists)
}
@@ -138,22 +159,34 @@ class SessionCatalog(externalCatalog: ExternalCatalog, conf: CatalystConf) {
* this becomes a no-op.
*/
def alterTable(tableDefinition: CatalogTable): Unit = {
- val db = tableDefinition.name.database.getOrElse(currentDb)
- val table = formatTableName(tableDefinition.name.table)
- val newTableDefinition = tableDefinition.copy(name = TableIdentifier(table, Some(db)))
+ val db = tableDefinition.identifier.database.getOrElse(currentDb)
+ val table = formatTableName(tableDefinition.identifier.table)
+ val newTableDefinition = tableDefinition.copy(identifier = TableIdentifier(table, Some(db)))
externalCatalog.alterTable(db, newTableDefinition)
}
/**
* Retrieve the metadata of an existing metastore table.
* If no database is specified, assume the table is in the current database.
+ * If the specified table is not found in the database then an [[AnalysisException]] is thrown.
*/
- def getTable(name: TableIdentifier): CatalogTable = {
+ def getTableMetadata(name: TableIdentifier): CatalogTable = {
val db = name.database.getOrElse(currentDb)
val table = formatTableName(name.table)
externalCatalog.getTable(db, table)
}
+ /**
+ * Retrieve the metadata of an existing metastore table.
+ * If no database is specified, assume the table is in the current database.
+ * If the specified table is not found in the database, return None.
+ */
+ def getTableMetadataOption(name: TableIdentifier): Option[CatalogTable] = {
+ val db = name.database.getOrElse(currentDb)
+ val table = formatTableName(name.table)
+ externalCatalog.getTableOption(db, table)
+ }
+
// -------------------------------------------------------------
// | Methods that interact with temporary and metastore tables |
// -------------------------------------------------------------
@@ -164,9 +197,9 @@ class SessionCatalog(externalCatalog: ExternalCatalog, conf: CatalystConf) {
def createTempTable(
name: String,
tableDefinition: LogicalPlan,
- ignoreIfExists: Boolean): Unit = {
+ overrideIfExists: Boolean): Unit = {
val table = formatTableName(name)
- if (tempTables.containsKey(table) && !ignoreIfExists) {
+ if (tempTables.contains(table) && !overrideIfExists) {
throw new AnalysisException(s"Temporary table '$name' already exists.")
}
tempTables.put(table, tableDefinition)
@@ -188,10 +221,11 @@ class SessionCatalog(externalCatalog: ExternalCatalog, conf: CatalystConf) {
val db = oldName.database.getOrElse(currentDb)
val oldTableName = formatTableName(oldName.table)
val newTableName = formatTableName(newName.table)
- if (oldName.database.isDefined || !tempTables.containsKey(oldTableName)) {
+ if (oldName.database.isDefined || !tempTables.contains(oldTableName)) {
externalCatalog.renameTable(db, oldTableName, newTableName)
} else {
- val table = tempTables.remove(oldTableName)
+ val table = tempTables(oldTableName)
+ tempTables.remove(oldTableName)
tempTables.put(newTableName, table)
}
}
@@ -206,8 +240,14 @@ class SessionCatalog(externalCatalog: ExternalCatalog, conf: CatalystConf) {
def dropTable(name: TableIdentifier, ignoreIfNotExists: Boolean): Unit = {
val db = name.database.getOrElse(currentDb)
val table = formatTableName(name.table)
- if (name.database.isDefined || !tempTables.containsKey(table)) {
- externalCatalog.dropTable(db, table, ignoreIfNotExists)
+ if (name.database.isDefined || !tempTables.contains(table)) {
+ // When ignoreIfNotExists is false and the table does not exist, no exception is thrown;
+ // instead, an error message is logged.
+ if (externalCatalog.tableExists(db, table)) {
+ externalCatalog.dropTable(db, table, ignoreIfNotExists = true)
+ } else if (!ignoreIfNotExists) {
+ logError(s"Table or View '${name.quotedString}' does not exist")
+ }
} else {
tempTables.remove(table)
}
@@ -224,11 +264,11 @@ class SessionCatalog(externalCatalog: ExternalCatalog, conf: CatalystConf) {
val db = name.database.getOrElse(currentDb)
val table = formatTableName(name.table)
val relation =
- if (name.database.isDefined || !tempTables.containsKey(table)) {
+ if (name.database.isDefined || !tempTables.contains(table)) {
val metadata = externalCatalog.getTable(db, table)
CatalogRelation(db, metadata, alias)
} else {
- tempTables.get(table)
+ tempTables(table)
}
val qualifiedTable = SubqueryAlias(table, relation)
// If an alias was specified by the lookup, wrap the plan in a subquery so that
@@ -247,7 +287,7 @@ class SessionCatalog(externalCatalog: ExternalCatalog, conf: CatalystConf) {
def tableExists(name: TableIdentifier): Boolean = {
val db = name.database.getOrElse(currentDb)
val table = formatTableName(name.table)
- if (name.database.isDefined || !tempTables.containsKey(table)) {
+ if (name.database.isDefined || !tempTables.contains(table)) {
externalCatalog.tableExists(db, table)
} else {
true // it's a temporary table
@@ -255,6 +295,16 @@ class SessionCatalog(externalCatalog: ExternalCatalog, conf: CatalystConf) {
}
/**
+ * Return whether a table with the specified name is a temporary table.
+ *
+ * Note: The temporary table cache is checked only when the database is not
+ * explicitly specified.
+ */
+ def isTemporaryTable(name: TableIdentifier): Boolean = {
+ name.database.isEmpty && tempTables.contains(formatTableName(name.table))
+ }
+
+ /**
* List all tables in the specified database, including temporary tables.
*/
def listTables(db: String): Seq[TableIdentifier] = listTables(db, "*")
@@ -265,19 +315,24 @@ class SessionCatalog(externalCatalog: ExternalCatalog, conf: CatalystConf) {
def listTables(db: String, pattern: String): Seq[TableIdentifier] = {
val dbTables =
externalCatalog.listTables(db, pattern).map { t => TableIdentifier(t, Some(db)) }
- val regex = pattern.replaceAll("\\*", ".*").r
- val _tempTables = tempTables.keys().asScala
- .filter { t => regex.pattern.matcher(t).matches() }
+ val _tempTables = StringUtils.filterPattern(tempTables.keys.toSeq, pattern)
.map { t => TableIdentifier(t) }
dbTables ++ _tempTables
}
+ // TODO: It's strange that we have both refresh and invalidate here.
+
/**
* Refresh the cache entry for a metastore table, if any.
*/
def refreshTable(name: TableIdentifier): Unit = { /* no-op */ }
/**
+ * Invalidate the cache entry for a metastore table, if any.
+ */
+ def invalidateTable(name: TableIdentifier): Unit = { /* no-op */ }
+
+ /**
* Drop all existing temporary tables.
* For testing only.
*/
@@ -290,7 +345,7 @@ class SessionCatalog(externalCatalog: ExternalCatalog, conf: CatalystConf) {
* For testing only.
*/
private[catalog] def getTempTable(name: String): Option[LogicalPlan] = {
- Option(tempTables.get(name))
+ tempTables.get(name)
}
// ----------------------------------------------------------------------------
@@ -398,36 +453,57 @@ class SessionCatalog(externalCatalog: ExternalCatalog, conf: CatalystConf) {
* Create a metastore function in the database specified in `funcDefinition`.
* If no such database is specified, create it in the current database.
*/
- def createFunction(funcDefinition: CatalogFunction): Unit = {
- val db = funcDefinition.name.database.getOrElse(currentDb)
- val newFuncDefinition = funcDefinition.copy(
- name = FunctionIdentifier(funcDefinition.name.funcName, Some(db)))
- externalCatalog.createFunction(db, newFuncDefinition)
+ def createFunction(funcDefinition: CatalogFunction, ignoreIfExists: Boolean): Unit = {
+ val db = funcDefinition.identifier.database.getOrElse(currentDb)
+ val identifier = FunctionIdentifier(funcDefinition.identifier.funcName, Some(db))
+ val newFuncDefinition = funcDefinition.copy(identifier = identifier)
+ if (!functionExists(identifier)) {
+ externalCatalog.createFunction(db, newFuncDefinition)
+ } else if (!ignoreIfExists) {
+ throw new AnalysisException(s"function '$identifier' already exists in database '$db'")
+ }
}
/**
* Drop a metastore function.
* If no database is specified, assume the function is in the current database.
*/
- def dropFunction(name: FunctionIdentifier): Unit = {
+ def dropFunction(name: FunctionIdentifier, ignoreIfNotExists: Boolean): Unit = {
val db = name.database.getOrElse(currentDb)
- externalCatalog.dropFunction(db, name.funcName)
+ val identifier = name.copy(database = Some(db))
+ if (functionExists(identifier)) {
+ // TODO: registry should just take in FunctionIdentifier for type safety
+ if (functionRegistry.functionExists(identifier.unquotedString)) {
+ // If we have loaded this function into the FunctionRegistry,
+ // also drop it from there.
+ // For a permanent function, because we loaded it to the FunctionRegistry
+ // when it's first used, we also need to drop it from the FunctionRegistry.
+ functionRegistry.dropFunction(identifier.unquotedString)
+ }
+ externalCatalog.dropFunction(db, name.funcName)
+ } else if (!ignoreIfNotExists) {
+ throw new AnalysisException(s"function '$identifier' does not exist in database '$db'")
+ }
}
/**
- * Alter a metastore function whose name that matches the one specified in `funcDefinition`.
- *
- * If no database is specified in `funcDefinition`, assume the function is in the
- * current database.
+ * Retrieve the metadata of a metastore function.
*
- * Note: If the underlying implementation does not support altering a certain field,
- * this becomes a no-op.
+ * If a database is specified in `name`, this will return the function in that database.
+ * If no database is specified, this will return the function in the current database.
*/
- def alterFunction(funcDefinition: CatalogFunction): Unit = {
- val db = funcDefinition.name.database.getOrElse(currentDb)
- val newFuncDefinition = funcDefinition.copy(
- name = FunctionIdentifier(funcDefinition.name.funcName, Some(db)))
- externalCatalog.alterFunction(db, newFuncDefinition)
+ def getFunctionMetadata(name: FunctionIdentifier): CatalogFunction = {
+ val db = name.database.getOrElse(currentDb)
+ externalCatalog.getFunction(db, name.funcName)
+ }
+
+ /**
+ * Check if the specified function exists.
+ */
+ def functionExists(name: FunctionIdentifier): Boolean = {
+ val db = name.database.getOrElse(currentDb)
+ functionRegistry.functionExists(name.unquotedString) ||
+ externalCatalog.functionExists(db, name.funcName)
}
// ----------------------------------------------------------------
@@ -435,17 +511,40 @@ class SessionCatalog(externalCatalog: ExternalCatalog, conf: CatalystConf) {
// ----------------------------------------------------------------
/**
+ * Construct a [[FunctionBuilder]] based on the provided class that represents a function.
+ *
+ * This performs reflection to decide what type of [[Expression]] to return in the builder.
+ */
+ private[sql] def makeFunctionBuilder(name: String, functionClassName: String): FunctionBuilder = {
+ // TODO: at least support UDAFs here
+ throw new UnsupportedOperationException("Use sqlContext.udf.register(...) instead.")
+ }
+
+ /**
+ * Loads resources such as JARs and Files for a function. Every resource is represented
+ * by a tuple (resource type, resource uri).
+ */
+ def loadFunctionResources(resources: Seq[(String, String)]): Unit = {
+ resources.foreach { case (resourceType, uri) =>
+ val functionResource =
+ FunctionResource(FunctionResourceType.fromString(resourceType.toLowerCase), uri)
+ functionResourceLoader.loadResource(functionResource)
+ }
+ }
+
+ /**
* Create a temporary function.
* This assumes no database is specified in `funcDefinition`.
*/
- def createTempFunction(funcDefinition: CatalogFunction, ignoreIfExists: Boolean): Unit = {
- require(funcDefinition.name.database.isEmpty,
- "attempted to create a temporary function while specifying a database")
- val name = funcDefinition.name.funcName
- if (tempFunctions.containsKey(name) && !ignoreIfExists) {
+ def createTempFunction(
+ name: String,
+ info: ExpressionInfo,
+ funcDefinition: FunctionBuilder,
+ ignoreIfExists: Boolean): Unit = {
+ if (functionRegistry.lookupFunctionBuilder(name).isDefined && !ignoreIfExists) {
throw new AnalysisException(s"Temporary function '$name' already exists.")
}
- tempFunctions.put(name, funcDefinition)
+ functionRegistry.registerFunction(name, info, funcDefinition)
}
/**
@@ -455,53 +554,71 @@ class SessionCatalog(externalCatalog: ExternalCatalog, conf: CatalystConf) {
// Hive has DROP FUNCTION and DROP TEMPORARY FUNCTION. We may want to consolidate
// dropFunction and dropTempFunction.
def dropTempFunction(name: String, ignoreIfNotExists: Boolean): Unit = {
- if (!tempFunctions.containsKey(name) && !ignoreIfNotExists) {
+ if (!functionRegistry.dropFunction(name) && !ignoreIfNotExists) {
throw new AnalysisException(
s"Temporary function '$name' cannot be dropped because it does not exist!")
}
- tempFunctions.remove(name)
+ }
+
+ protected def failFunctionLookup(name: String): Nothing = {
+ throw new AnalysisException(s"Undefined function: $name. This function is " +
+ s"neither a registered temporary function nor " +
+ s"a permanent function registered in the database $currentDb.")
}
/**
- * Rename a function.
+ * Return an [[Expression]] that represents the specified function, assuming it exists.
*
- * If a database is specified in `oldName`, this will rename the function in that database.
- * If no database is specified, this will first attempt to rename a temporary function with
- * the same name, then, if that does not exist, rename the function in the current database.
+ * For a temporary function or a permanent function that has been loaded,
+ * this method will simply lookup the function through the
+ * FunctionRegistry and create an expression based on the builder.
*
- * This assumes the database specified in `oldName` matches the one specified in `newName`.
+ * For a permanent function that has not been loaded, we will first fetch its metadata
+ * from the underlying external catalog. Then, we will load all resources associated
+ * with this function (i.e. jars and files). Finally, we create a function builder
+ * based on the function class and put the builder into the FunctionRegistry.
+ * The name of this function in the FunctionRegistry will be `databaseName.functionName`.
*/
- def renameFunction(oldName: FunctionIdentifier, newName: FunctionIdentifier): Unit = {
- if (oldName.database != newName.database) {
- throw new AnalysisException("rename does not support moving functions across databases")
- }
- val db = oldName.database.getOrElse(currentDb)
- if (oldName.database.isDefined || !tempFunctions.containsKey(oldName.funcName)) {
- externalCatalog.renameFunction(db, oldName.funcName, newName.funcName)
+ def lookupFunction(name: String, children: Seq[Expression]): Expression = {
+ // TODO: Right now, the name can be qualified or not qualified.
+ // It will be better to get a FunctionIdentifier.
+ // TODO: Right now, we assume that name is not qualified!
+ val qualifiedName = FunctionIdentifier(name, Some(currentDb)).unquotedString
+ if (functionRegistry.functionExists(name)) {
+ // This function has been already loaded into the function registry.
+ functionRegistry.lookupFunction(name, children)
+ } else if (functionRegistry.functionExists(qualifiedName)) {
+ // This function has been already loaded into the function registry.
+ // Unlike the above block, we find this function by using the qualified name.
+ functionRegistry.lookupFunction(qualifiedName, children)
} else {
- val func = tempFunctions.remove(oldName.funcName)
- val newFunc = func.copy(name = func.name.copy(funcName = newName.funcName))
- tempFunctions.put(newName.funcName, newFunc)
+ // The function has not been loaded to the function registry, which means
+ // that the function is a permanent function (if it actually has been registered
+ // in the metastore). We need to first put the function in the FunctionRegistry.
+ val catalogFunction = try {
+ externalCatalog.getFunction(currentDb, name)
+ } catch {
+ case e: AnalysisException => failFunctionLookup(name)
+ case e: NoSuchFunctionException => failFunctionLookup(name)
+ }
+ loadFunctionResources(catalogFunction.resources)
+ // Please note that qualifiedName is provided by the user. However,
+ // catalogFunction.identifier.unquotedString is returned by the underlying
+ // catalog. So, it is possible that qualifiedName is not exactly the same as
+ // catalogFunction.identifier.unquotedString (the difference is in case sensitivity).
+ // Here, we preserve the input from the user.
+ val info = new ExpressionInfo(catalogFunction.className, qualifiedName)
+ val builder = makeFunctionBuilder(qualifiedName, catalogFunction.className)
+ createTempFunction(qualifiedName, info, builder, ignoreIfExists = false)
+ // Now, we need to create the Expression.
+ functionRegistry.lookupFunction(qualifiedName, children)
}
}
/**
- * Retrieve the metadata of an existing function.
- *
- * If a database is specified in `name`, this will return the function in that database.
- * If no database is specified, this will first attempt to return a temporary function with
- * the same name, then, if that does not exist, return the function in the current database.
+ * List all functions in the specified database, including temporary functions.
*/
- def getFunction(name: FunctionIdentifier): CatalogFunction = {
- val db = name.database.getOrElse(currentDb)
- if (name.database.isDefined || !tempFunctions.containsKey(name.funcName)) {
- externalCatalog.getFunction(db, name.funcName)
- } else {
- tempFunctions.get(name.funcName)
- }
- }
-
- // TODO: implement lookupFunction that returns something from the registry itself
+ def listFunctions(db: String): Seq[FunctionIdentifier] = listFunctions(db, "*")
/**
* List all matching functions in the specified database, including temporary functions.
@@ -509,18 +626,40 @@ class SessionCatalog(externalCatalog: ExternalCatalog, conf: CatalystConf) {
def listFunctions(db: String, pattern: String): Seq[FunctionIdentifier] = {
val dbFunctions =
externalCatalog.listFunctions(db, pattern).map { f => FunctionIdentifier(f, Some(db)) }
- val regex = pattern.replaceAll("\\*", ".*").r
- val _tempFunctions = tempFunctions.keys().asScala
- .filter { f => regex.pattern.matcher(f).matches() }
+ val loadedFunctions = StringUtils.filterPattern(functionRegistry.listFunction(), pattern)
.map { f => FunctionIdentifier(f) }
- dbFunctions ++ _tempFunctions
+ // TODO: Actually, there will be dbFunctions that have been loaded into the FunctionRegistry.
+ // So, the returned list may have two entries for the same function.
+ dbFunctions ++ loadedFunctions
}
+
+ // -----------------
+ // | Other methods |
+ // -----------------
+
/**
- * Return a temporary function. For testing only.
+ * Drop all existing databases (except "default") along with all associated tables,
+ * partitions and functions, and set the current database to "default".
+ *
+ * This is mainly used for tests.
*/
- private[catalog] def getTempFunction(name: String): Option[CatalogFunction] = {
- Option(tempFunctions.get(name))
+ private[sql] def reset(): Unit = {
+ val default = "default"
+ listDatabases().filter(_ != default).foreach { db =>
+ dropDatabase(db, ignoreIfNotExists = false, cascade = true)
+ }
+ tempTables.clear()
+ functionRegistry.clear()
+ // restore built-in functions
+ FunctionRegistry.builtin.listFunction().foreach { f =>
+ val expressionInfo = FunctionRegistry.builtin.lookupFunction(f)
+ val functionBuilder = FunctionRegistry.builtin.lookupFunctionBuilder(f)
+ require(expressionInfo.isDefined, s"built-in function '$f' is missing expression info")
+ require(functionBuilder.isDefined, s"built-in function '$f' is missing function builder")
+ functionRegistry.registerFunction(f, expressionInfo.get, functionBuilder.get)
+ }
+ setCurrentDatabase(default)
}
}
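
A minimal usage sketch of the registry-backed function management added above (test-style code; the class name, function name and input literal are illustrative, and the constructor used is the testing-only one from this patch):

    import org.apache.spark.sql.catalyst.catalog.{InMemoryCatalog, SessionCatalog}
    import org.apache.spark.sql.catalyst.expressions.{Expression, ExpressionInfo, Literal, Upper}

    // Testing-only constructor: in-memory external catalog + SimpleFunctionRegistry.
    val catalog = new SessionCatalog(new InMemoryCatalog)

    // Register a temporary function backed directly by a FunctionBuilder.
    val info = new ExpressionInfo("com.example.MyUpper", "my_upper")
    catalog.createTempFunction("my_upper", info, children => Upper(children.head), ignoreIfExists = false)

    // lookupFunction consults the FunctionRegistry first; names not found there fall back to the
    // external catalog, have their resources loaded, and are then cached in the registry.
    val expr: Expression = catalog.lookupFunction("my_upper", Seq(Literal("spark")))
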
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/functionResources.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/functionResources.scala
new file mode 100644
index 0000000000..5adcc892cf
--- /dev/null
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/functionResources.scala
@@ -0,0 +1,61 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.catalyst.catalog
+
+import org.apache.spark.sql.AnalysisException
+
+/** A trait that represents the type of a resource needed by a function. */
+sealed trait FunctionResourceType
+
+object JarResource extends FunctionResourceType
+
+object FileResource extends FunctionResourceType
+
+// We do not allow users to specify an archive because it is YARN specific.
+// When loading resources, we will throw an exception and ask users to
+// use --archives with spark-submit.
+object ArchiveResource extends FunctionResourceType
+
+object FunctionResourceType {
+ def fromString(resourceType: String): FunctionResourceType = {
+ resourceType.toLowerCase match {
+ case "jar" => JarResource
+ case "file" => FileResource
+ case "archive" => ArchiveResource
+ case other =>
+ throw new AnalysisException(s"Resource Type '$resourceType' is not supported.")
+ }
+ }
+}
+
+case class FunctionResource(resourceType: FunctionResourceType, uri: String)
+
+/**
+ * A simple trait representing a class that can be used to load resources used by
+ * a function. Because only a SQLContext can load resources, we create this trait
+ * to avoid explicitly passing SQLContext around.
+ */
+trait FunctionResourceLoader {
+ def loadResource(resource: FunctionResource): Unit
+}
+
+object DummyFunctionResourceLoader extends FunctionResourceLoader {
+ override def loadResource(resource: FunctionResource): Unit = {
+ throw new UnsupportedOperationException
+ }
+}
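
A small sketch of how these types are expected to be used (the JAR URI is illustrative): FunctionResourceType.fromString parses the case-insensitive resource-type strings that SessionCatalog.loadFunctionResources receives as (type, uri) pairs.

    import org.apache.spark.sql.catalyst.catalog.{FunctionResource, FunctionResourceType}

    // "JAR", "jar" and "Jar" all map to JarResource; unknown strings raise an AnalysisException.
    val jar = FunctionResource(FunctionResourceType.fromString("JAR"), "hdfs:///udfs/my_udf.jar")

    // The (resource type, uri) pairs carried on CatalogFunction.resources use the same strings.
    val resources: Seq[(String, String)] = Seq("jar" -> "hdfs:///udfs/my_udf.jar")
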
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/interface.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/interface.scala
index 34803133f6..ad989a97e4 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/interface.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/interface.scala
@@ -39,7 +39,7 @@ abstract class ExternalCatalog {
protected def requireDbExists(db: String): Unit = {
if (!databaseExists(db)) {
- throw new AnalysisException(s"Database $db does not exist")
+ throw new AnalysisException(s"Database '$db' does not exist")
}
}
@@ -91,6 +91,8 @@ abstract class ExternalCatalog {
def getTable(db: String, table: String): CatalogTable
+ def getTableOption(db: String, table: String): Option[CatalogTable]
+
def tableExists(db: String, table: String): Boolean
def listTables(db: String): Seq[String]
@@ -150,17 +152,10 @@ abstract class ExternalCatalog {
def renameFunction(db: String, oldName: String, newName: String): Unit
- /**
- * Alter a function whose name that matches the one specified in `funcDefinition`,
- * assuming the function exists.
- *
- * Note: If the underlying implementation does not support altering a certain field,
- * this becomes a no-op.
- */
- def alterFunction(db: String, funcDefinition: CatalogFunction): Unit
-
def getFunction(db: String, funcName: String): CatalogFunction
+ def functionExists(db: String, funcName: String): Boolean
+
def listFunctions(db: String, pattern: String): Seq[String]
}
@@ -169,10 +164,15 @@ abstract class ExternalCatalog {
/**
* A function defined in the catalog.
*
- * @param name name of the function
+ * @param identifier name of the function
* @param className fully qualified class name, e.g. "org.apache.spark.util.MyFunc"
+ * @param resources resource types and URIs used by the function
*/
-case class CatalogFunction(name: FunctionIdentifier, className: String)
+// TODO: Use FunctionResource instead of (String, String) as the element type of resources.
+case class CatalogFunction(
+ identifier: FunctionIdentifier,
+ className: String,
+ resources: Seq[(String, String)])
/**
@@ -216,26 +216,42 @@ case class CatalogTablePartition(
* future once we have a better understanding of how we want to handle skewed columns.
*/
case class CatalogTable(
- name: TableIdentifier,
+ identifier: TableIdentifier,
tableType: CatalogTableType,
storage: CatalogStorageFormat,
schema: Seq[CatalogColumn],
- partitionColumns: Seq[CatalogColumn] = Seq.empty,
- sortColumns: Seq[CatalogColumn] = Seq.empty,
- numBuckets: Int = 0,
+ partitionColumnNames: Seq[String] = Seq.empty,
+ sortColumnNames: Seq[String] = Seq.empty,
+ bucketColumnNames: Seq[String] = Seq.empty,
+ numBuckets: Int = -1,
createTime: Long = System.currentTimeMillis,
- lastAccessTime: Long = System.currentTimeMillis,
+ lastAccessTime: Long = -1,
properties: Map[String, String] = Map.empty,
viewOriginalText: Option[String] = None,
- viewText: Option[String] = None) {
+ viewText: Option[String] = None,
+ comment: Option[String] = None) {
+
+ // Verify that the provided columns are part of the schema
+ private val colNames = schema.map(_.name).toSet
+ private def requireSubsetOfSchema(cols: Seq[String], colType: String): Unit = {
+ require(cols.toSet.subsetOf(colNames), s"$colType columns (${cols.mkString(", ")}) " +
+ s"must be a subset of schema (${colNames.mkString(", ")}) in table '$identifier'")
+ }
+ requireSubsetOfSchema(partitionColumnNames, "partition")
+ requireSubsetOfSchema(sortColumnNames, "sort")
+ requireSubsetOfSchema(bucketColumnNames, "bucket")
+
+ /** Columns this table is partitioned by. */
+ def partitionColumns: Seq[CatalogColumn] =
+ schema.filter { c => partitionColumnNames.contains(c.name) }
/** Return the database this table was specified to belong to, assuming it exists. */
- def database: String = name.database.getOrElse {
- throw new AnalysisException(s"table $name did not specify database")
+ def database: String = identifier.database.getOrElse {
+ throw new AnalysisException(s"table $identifier did not specify database")
}
/** Return the fully qualified name of this table, assuming the database was specified. */
- def qualifiedName: String = name.unquotedString
+ def qualifiedName: String = identifier.unquotedString
/** Syntactic sugar to update a field in `storage`. */
def withNewStorage(
@@ -290,6 +306,6 @@ case class CatalogRelation(
// TODO: implement this
override def output: Seq[Attribute] = Seq.empty
- require(metadata.name.database == Some(db),
- "provided database does not much the one specified in the table definition")
+ require(metadata.identifier.database == Some(db),
+ "provided database does not match the one specified in the table definition")
}
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/dsl/package.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/dsl/package.scala
index 3540014c3e..1e7296664b 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/dsl/package.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/dsl/package.scala
@@ -21,7 +21,8 @@ import java.sql.{Date, Timestamp}
import scala.language.implicitConversions
-import org.apache.spark.sql.catalyst.analysis.{EliminateSubqueryAliases, UnresolvedAttribute, UnresolvedExtractValue}
+import org.apache.spark.sql.Encoder
+import org.apache.spark.sql.catalyst.analysis._
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.expressions.aggregate._
import org.apache.spark.sql.catalyst.plans.{Inner, JoinType}
@@ -161,6 +162,18 @@ package object dsl {
def lower(e: Expression): Expression = Lower(e)
def sqrt(e: Expression): Expression = Sqrt(e)
def abs(e: Expression): Expression = Abs(e)
+ def star(names: String*): Expression = names match {
+ case Seq() => UnresolvedStar(None)
+ case target => UnresolvedStar(Option(target))
+ }
+
+ def callFunction[T, U](
+ func: T => U,
+ returnType: DataType,
+ argument: Expression): Expression = {
+ val function = Literal.create(func, ObjectType(classOf[T => U]))
+ Invoke(function, "apply", returnType, argument :: Nil)
+ }
implicit class DslSymbol(sym: Symbol) extends ImplicitAttribute { def s: String = sym.name }
// TODO more implicit class for literal?
@@ -231,6 +244,12 @@ package object dsl {
AttributeReference(s, structType, nullable = true)()
def struct(attrs: AttributeReference*): AttributeReference =
struct(StructType.fromAttributes(attrs))
+
+ /** Create a function. */
+ def function(exprs: Expression*): UnresolvedFunction =
+ UnresolvedFunction(s, exprs, isDistinct = false)
+ def distinctFunction(exprs: Expression*): UnresolvedFunction =
+ UnresolvedFunction(s, exprs, isDistinct = true)
}
implicit class DslAttribute(a: AttributeReference) {
@@ -243,11 +262,33 @@ package object dsl {
object expressions extends ExpressionConversions // scalastyle:ignore
object plans { // scalastyle:ignore
+ def table(ref: String): LogicalPlan =
+ UnresolvedRelation(TableIdentifier(ref), None)
+
+ def table(db: String, ref: String): LogicalPlan =
+ UnresolvedRelation(TableIdentifier(ref, Option(db)), None)
+
implicit class DslLogicalPlan(val logicalPlan: LogicalPlan) {
- def select(exprs: NamedExpression*): LogicalPlan = Project(exprs, logicalPlan)
+ def select(exprs: Expression*): LogicalPlan = {
+ val namedExpressions = exprs.map {
+ case e: NamedExpression => e
+ case e => UnresolvedAlias(e)
+ }
+ Project(namedExpressions, logicalPlan)
+ }
def where(condition: Expression): LogicalPlan = Filter(condition, logicalPlan)
+ def filter[T : Encoder](func: T => Boolean): LogicalPlan = {
+ val deserialized = logicalPlan.deserialize[T]
+ val condition = expressions.callFunction(func, BooleanType, deserialized.output.head)
+ Filter(condition, deserialized).serialize[T]
+ }
+
+ def serialize[T : Encoder]: LogicalPlan = CatalystSerde.serialize[T](logicalPlan)
+
+ def deserialize[T : Encoder]: LogicalPlan = CatalystSerde.deserialize[T](logicalPlan)
+
def limit(limitExpr: Expression): LogicalPlan = Limit(limitExpr, logicalPlan)
def join(
@@ -296,6 +337,14 @@ package object dsl {
analysis.UnresolvedRelation(TableIdentifier(tableName)),
Map.empty, logicalPlan, overwrite, false)
+ def as(alias: String): LogicalPlan = logicalPlan match {
+ case UnresolvedRelation(tbl, _) => UnresolvedRelation(tbl, Option(alias))
+ case plan => SubqueryAlias(alias, plan)
+ }
+
+ def distribute(exprs: Expression*): LogicalPlan =
+ RepartitionByExpression(exprs, logicalPlan)
+
def analyze: LogicalPlan =
EliminateSubqueryAliases(analysis.SimpleAnalyzer.execute(logicalPlan))
}
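
The DSL additions above allow test plans like the following to be written directly (an illustrative snippet assuming the usual dsl imports; the table and column names are made up):

    import org.apache.spark.sql.catalyst.dsl.expressions._
    import org.apache.spark.sql.catalyst.dsl.plans._

    // table()/as() build unresolved relations; select() now accepts any Expression and wraps
    // non-named expressions in UnresolvedAlias; star() expands to an UnresolvedStar.
    val plan = table("db", "events")
      .as("e")
      .where('ts > 0)
      .select(star())
      .distribute('userId)
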
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoder.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoder.scala
index 918233ddcd..56d29cfbe1 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoder.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoder.scala
@@ -24,7 +24,7 @@ import scala.reflect.runtime.universe.{typeTag, TypeTag}
import org.apache.spark.sql.{AnalysisException, Encoder}
import org.apache.spark.sql.catalyst.{InternalRow, JavaTypeInference, ScalaReflection}
-import org.apache.spark.sql.catalyst.analysis.{SimpleAnalyzer, UnresolvedAttribute, UnresolvedExtractValue}
+import org.apache.spark.sql.catalyst.analysis.{SimpleAnalyzer, UnresolvedAttribute, UnresolvedDeserializer, UnresolvedExtractValue}
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.expressions.codegen.{GenerateSafeProjection, GenerateUnsafeProjection}
import org.apache.spark.sql.catalyst.optimizer.SimplifyCasts
@@ -51,8 +51,8 @@ object ExpressionEncoder {
val flat = !classOf[Product].isAssignableFrom(cls)
val inputObject = BoundReference(0, ScalaReflection.dataTypeFor[T], nullable = false)
- val toRowExpression = ScalaReflection.extractorsFor[T](inputObject)
- val fromRowExpression = ScalaReflection.constructorFor[T]
+ val serializer = ScalaReflection.serializerFor[T](inputObject)
+ val deserializer = ScalaReflection.deserializerFor[T]
val schema = ScalaReflection.schemaFor[T] match {
case ScalaReflection.Schema(s: StructType, _) => s
@@ -62,8 +62,8 @@ object ExpressionEncoder {
new ExpressionEncoder[T](
schema,
flat,
- toRowExpression.flatten,
- fromRowExpression,
+ serializer.flatten,
+ deserializer,
ClassTag[T](cls))
}
@@ -72,14 +72,14 @@ object ExpressionEncoder {
val schema = JavaTypeInference.inferDataType(beanClass)._1
assert(schema.isInstanceOf[StructType])
- val toRowExpression = JavaTypeInference.extractorsFor(beanClass)
- val fromRowExpression = JavaTypeInference.constructorFor(beanClass)
+ val serializer = JavaTypeInference.serializerFor(beanClass)
+ val deserializer = JavaTypeInference.deserializerFor(beanClass)
new ExpressionEncoder[T](
schema.asInstanceOf[StructType],
flat = false,
- toRowExpression.flatten,
- fromRowExpression,
+ serializer.flatten,
+ deserializer,
ClassTag[T](beanClass))
}
@@ -103,9 +103,9 @@ object ExpressionEncoder {
val cls = Utils.getContextOrSparkClassLoader.loadClass(s"scala.Tuple${encoders.size}")
- val toRowExpressions = encoders.map {
- case e if e.flat => e.toRowExpressions.head
- case other => CreateStruct(other.toRowExpressions)
+ val serializer = encoders.map {
+ case e if e.flat => e.serializer.head
+ case other => CreateStruct(other.serializer)
}.zipWithIndex.map { case (expr, index) =>
expr.transformUp {
case BoundReference(0, t, _) =>
@@ -116,14 +116,14 @@ object ExpressionEncoder {
}
}
- val fromRowExpressions = encoders.zipWithIndex.map { case (enc, index) =>
+ val childrenDeserializers = encoders.zipWithIndex.map { case (enc, index) =>
if (enc.flat) {
- enc.fromRowExpression.transform {
+ enc.deserializer.transform {
case b: BoundReference => b.copy(ordinal = index)
}
} else {
val input = BoundReference(index, enc.schema, nullable = true)
- enc.fromRowExpression.transformUp {
+ enc.deserializer.transformUp {
case UnresolvedAttribute(nameParts) =>
assert(nameParts.length == 1)
UnresolvedExtractValue(input, Literal(nameParts.head))
@@ -132,14 +132,14 @@ object ExpressionEncoder {
}
}
- val fromRowExpression =
- NewInstance(cls, fromRowExpressions, ObjectType(cls), propagateNull = false)
+ val deserializer =
+ NewInstance(cls, childrenDeserializers, ObjectType(cls), propagateNull = false)
new ExpressionEncoder[Any](
schema,
flat = false,
- toRowExpressions,
- fromRowExpression,
+ serializer,
+ deserializer,
ClassTag(cls))
}
@@ -174,29 +174,29 @@ object ExpressionEncoder {
* A generic encoder for JVM objects.
*
* @param schema The schema after converting `T` to a Spark SQL row.
- * @param toRowExpressions A set of expressions, one for each top-level field that can be used to
- * extract the values from a raw object into an [[InternalRow]].
- * @param fromRowExpression An expression that will construct an object given an [[InternalRow]].
+ * @param serializer A set of expressions, one for each top-level field that can be used to
+ * extract the values from a raw object into an [[InternalRow]].
+ * @param deserializer An expression that will construct an object given an [[InternalRow]].
* @param clsTag A classtag for `T`.
*/
case class ExpressionEncoder[T](
schema: StructType,
flat: Boolean,
- toRowExpressions: Seq[Expression],
- fromRowExpression: Expression,
+ serializer: Seq[Expression],
+ deserializer: Expression,
clsTag: ClassTag[T])
extends Encoder[T] {
- if (flat) require(toRowExpressions.size == 1)
+ if (flat) require(serializer.size == 1)
@transient
- private lazy val extractProjection = GenerateUnsafeProjection.generate(toRowExpressions)
+ private lazy val extractProjection = GenerateUnsafeProjection.generate(serializer)
@transient
private lazy val inputRow = new GenericMutableRow(1)
@transient
- private lazy val constructProjection = GenerateSafeProjection.generate(fromRowExpression :: Nil)
+ private lazy val constructProjection = GenerateSafeProjection.generate(deserializer :: Nil)
/**
* Returns this encoder where it has been bound to its own output (i.e. no remapping of columns
@@ -212,7 +212,7 @@ case class ExpressionEncoder[T](
* Returns a new set (with unique ids) of [[NamedExpression]] that represent the serialized form
* of this object.
*/
- def namedExpressions: Seq[NamedExpression] = schema.map(_.name).zip(toRowExpressions).map {
+ def namedExpressions: Seq[NamedExpression] = schema.map(_.name).zip(serializer).map {
case (_, ne: NamedExpression) => ne.newInstance()
case (name, e) => Alias(e, name)()
}
@@ -228,7 +228,7 @@ case class ExpressionEncoder[T](
} catch {
case e: Exception =>
throw new RuntimeException(
- s"Error while encoding: $e\n${toRowExpressions.map(_.treeString).mkString("\n")}", e)
+ s"Error while encoding: $e\n${serializer.map(_.treeString).mkString("\n")}", e)
}
/**
@@ -240,7 +240,7 @@ case class ExpressionEncoder[T](
constructProjection(row).get(0, ObjectType(clsTag.runtimeClass)).asInstanceOf[T]
} catch {
case e: Exception =>
- throw new RuntimeException(s"Error while decoding: $e\n${fromRowExpression.treeString}", e)
+ throw new RuntimeException(s"Error while decoding: $e\n${deserializer.treeString}", e)
}
/**
@@ -249,7 +249,7 @@ case class ExpressionEncoder[T](
* has not been done already in places where we plan to do later composition of encoders.
*/
def assertUnresolved(): Unit = {
- (fromRowExpression +: toRowExpressions).foreach(_.foreach {
+ (deserializer +: serializer).foreach(_.foreach {
case a: AttributeReference if a.name != "loopVar" =>
sys.error(s"Unresolved encoder expected, but $a was found.")
case _ =>
@@ -257,7 +257,7 @@ case class ExpressionEncoder[T](
}
/**
- * Validates `fromRowExpression` to make sure it can be resolved by given schema, and produce
+ * Validates `deserializer` to make sure it can be resolved by given schema, and produce
* friendly error messages to explain why it fails to resolve if there is something wrong.
*/
def validate(schema: Seq[Attribute]): Unit = {
@@ -271,7 +271,7 @@ case class ExpressionEncoder[T](
// If this is a tuple encoder or tupled encoder, which means its leaf nodes are all
// `BoundReference`, make sure their ordinals are all valid.
var maxOrdinal = -1
- fromRowExpression.foreach {
+ deserializer.foreach {
case b: BoundReference => if (b.ordinal > maxOrdinal) maxOrdinal = b.ordinal
case _ =>
}
@@ -285,7 +285,7 @@ case class ExpressionEncoder[T](
// we unbound it by the given `schema` and propagate the actual type to `GetStructField`, after
// we resolve the `fromRowExpression`.
val resolved = SimpleAnalyzer.resolveExpression(
- fromRowExpression,
+ deserializer,
LocalRelation(schema),
throws = true)
@@ -312,42 +312,39 @@ case class ExpressionEncoder[T](
}
/**
- * Returns a new copy of this encoder, where the expressions used by `fromRow` are resolved to the
- * given schema.
+ * Returns a new copy of this encoder, where the `deserializer` is resolved to the given schema.
*/
def resolve(
schema: Seq[Attribute],
outerScopes: ConcurrentMap[String, AnyRef]): ExpressionEncoder[T] = {
- val deserializer = SimpleAnalyzer.ResolveReferences.resolveDeserializer(
- fromRowExpression, schema)
-
// Make a fake plan to wrap the deserializer, so that we can go though the whole analyzer, check
// analysis, go through optimizer, etc.
- val plan = Project(Alias(deserializer, "")() :: Nil, LocalRelation(schema))
+ val plan = Project(
+ Alias(UnresolvedDeserializer(deserializer, schema), "")() :: Nil,
+ LocalRelation(schema))
val analyzedPlan = SimpleAnalyzer.execute(plan)
SimpleAnalyzer.checkAnalysis(analyzedPlan)
- copy(fromRowExpression = SimplifyCasts(analyzedPlan).expressions.head.children.head)
+ copy(deserializer = SimplifyCasts(analyzedPlan).expressions.head.children.head)
}
/**
- * Returns a copy of this encoder where the expressions used to construct an object from an input
- * row have been bound to the ordinals of the given schema. Note that you need to first call
- * resolve before bind.
+ * Returns a copy of this encoder where the `deserializer` has been bound to the
+ * ordinals of the given schema. Note that you need to first call resolve before bind.
*/
def bind(schema: Seq[Attribute]): ExpressionEncoder[T] = {
- copy(fromRowExpression = BindReferences.bindReference(fromRowExpression, schema))
+ copy(deserializer = BindReferences.bindReference(deserializer, schema))
}
/**
* Returns a new encoder with input columns shifted by `delta` ordinals
*/
def shift(delta: Int): ExpressionEncoder[T] = {
- copy(fromRowExpression = fromRowExpression transform {
+ copy(deserializer = deserializer transform {
case r: BoundReference => r.copy(ordinal = r.ordinal + delta)
})
}
- protected val attrs = toRowExpressions.flatMap(_.collect {
+ protected val attrs = serializer.flatMap(_.collect {
case _: UnresolvedAttribute => ""
case a: Attribute => s"#${a.exprId}"
case b: BoundReference => s"[${b.ordinal}]"
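
As a rough usage sketch of the renamed fields (not from the patch; `Person` is a hypothetical case class and an empty map stands in for the outer scopes), the `serializer` drives `toRow` while the resolved, bound `deserializer` drives `fromRow`:

    import java.util.concurrent.ConcurrentHashMap
    import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder

    case class Person(name: String, age: Int)

    val enc = ExpressionEncoder[Person]()
    val row = enc.toRow(Person("Ada", 36))          // runs the serializer expressions

    // The deserializer must be resolved against a schema and bound to ordinals first.
    val attrs = enc.schema.toAttributes
    val bound = enc.resolve(attrs, new ConcurrentHashMap[String, AnyRef]()).bind(attrs)
    val back  = bound.fromRow(row)                  // Person("Ada", 36)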
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/RowEncoder.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/RowEncoder.scala
index 30f56d8c2f..a8397aa5e5 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/RowEncoder.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/RowEncoder.scala
@@ -36,23 +36,23 @@ object RowEncoder {
val cls = classOf[Row]
val inputObject = BoundReference(0, ObjectType(cls), nullable = true)
// We use an If expression to wrap extractorsFor result of StructType
- val extractExpressions = extractorsFor(inputObject, schema).asInstanceOf[If].falseValue
- val constructExpression = constructorFor(schema)
+ val serializer = serializerFor(inputObject, schema).asInstanceOf[If].falseValue
+ val deserializer = deserializerFor(schema)
new ExpressionEncoder[Row](
schema,
flat = false,
- extractExpressions.asInstanceOf[CreateStruct].children,
- constructExpression,
+ serializer.asInstanceOf[CreateStruct].children,
+ deserializer,
ClassTag(cls))
}
- private def extractorsFor(
+ private def serializerFor(
inputObject: Expression,
inputType: DataType): Expression = inputType match {
case NullType | BooleanType | ByteType | ShortType | IntegerType | LongType |
FloatType | DoubleType | BinaryType | CalendarIntervalType => inputObject
- case p: PythonUserDefinedType => extractorsFor(inputObject, p.sqlType)
+ case p: PythonUserDefinedType => serializerFor(inputObject, p.sqlType)
case udt: UserDefinedType[_] =>
val obj = NewInstance(
@@ -95,7 +95,7 @@ object RowEncoder {
classOf[GenericArrayData],
inputObject :: Nil,
dataType = t)
- case _ => MapObjects(extractorsFor(_, et), inputObject, externalDataTypeForInput(et))
+ case _ => MapObjects(serializerFor(_, et), inputObject, externalDataTypeForInput(et))
}
case t @ MapType(kt, vt, valueNullable) =>
@@ -104,14 +104,14 @@ object RowEncoder {
Invoke(inputObject, "keysIterator", ObjectType(classOf[scala.collection.Iterator[_]])),
"toSeq",
ObjectType(classOf[scala.collection.Seq[_]]))
- val convertedKeys = extractorsFor(keys, ArrayType(kt, false))
+ val convertedKeys = serializerFor(keys, ArrayType(kt, false))
val values =
Invoke(
Invoke(inputObject, "valuesIterator", ObjectType(classOf[scala.collection.Iterator[_]])),
"toSeq",
ObjectType(classOf[scala.collection.Seq[_]]))
- val convertedValues = extractorsFor(values, ArrayType(vt, valueNullable))
+ val convertedValues = serializerFor(values, ArrayType(vt, valueNullable))
NewInstance(
classOf[ArrayBasedMapData],
@@ -128,7 +128,7 @@ object RowEncoder {
If(
Invoke(inputObject, "isNullAt", BooleanType, Literal(i) :: Nil),
Literal.create(null, f.dataType),
- extractorsFor(
+ serializerFor(
Invoke(inputObject, method, externalDataTypeForInput(f.dataType), Literal(i) :: Nil),
f.dataType))
}
@@ -166,7 +166,7 @@ object RowEncoder {
case _: NullType => ObjectType(classOf[java.lang.Object])
}
- private def constructorFor(schema: StructType): Expression = {
+ private def deserializerFor(schema: StructType): Expression = {
val fields = schema.zipWithIndex.map { case (f, i) =>
val dt = f.dataType match {
case p: PythonUserDefinedType => p.sqlType
@@ -176,13 +176,13 @@ object RowEncoder {
If(
IsNull(field),
Literal.create(null, externalDataTypeFor(dt)),
- constructorFor(field)
+ deserializerFor(field)
)
}
CreateExternalRow(fields, schema)
}
- private def constructorFor(input: Expression): Expression = input.dataType match {
+ private def deserializerFor(input: Expression): Expression = input.dataType match {
case NullType | BooleanType | ByteType | ShortType | IntegerType | LongType |
FloatType | DoubleType | BinaryType | CalendarIntervalType => input
@@ -216,7 +216,7 @@ object RowEncoder {
case ArrayType(et, nullable) =>
val arrayData =
Invoke(
- MapObjects(constructorFor(_), input, et),
+ MapObjects(deserializerFor(_), input, et),
"array",
ObjectType(classOf[Array[_]]))
StaticInvoke(
@@ -227,10 +227,10 @@ object RowEncoder {
case MapType(kt, vt, valueNullable) =>
val keyArrayType = ArrayType(kt, false)
- val keyData = constructorFor(Invoke(input, "keyArray", keyArrayType))
+ val keyData = deserializerFor(Invoke(input, "keyArray", keyArrayType))
val valueArrayType = ArrayType(vt, valueNullable)
- val valueData = constructorFor(Invoke(input, "valueArray", valueArrayType))
+ val valueData = deserializerFor(Invoke(input, "valueArray", valueArrayType))
StaticInvoke(
ArrayBasedMapData.getClass,
@@ -243,7 +243,7 @@ object RowEncoder {
If(
Invoke(input, "isNullAt", BooleanType, Literal(i) :: Nil),
Literal.create(null, externalDataTypeFor(f.dataType)),
- constructorFor(GetStructField(input, i)))
+ deserializerFor(GetStructField(input, i)))
}
If(IsNull(input),
Literal.create(null, externalDataTypeFor(input.dataType)),
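
A small sketch of the behaviour the renamed helpers implement (not part of the patch; the schema is made up): `RowEncoder` builds a serializer that converts an external `Row` into an `InternalRow` and a deserializer for the reverse direction, recursing into arrays, maps and nested structs:

    import org.apache.spark.sql.Row
    import org.apache.spark.sql.catalyst.encoders.RowEncoder
    import org.apache.spark.sql.types._

    val schema = new StructType()
      .add("id", LongType)
      .add("tags", ArrayType(StringType))

    val encoder  = RowEncoder(schema)
    val internal = encoder.toRow(Row(1L, Seq("a", "b")))   // external Row -> InternalRow
    val external = encoder.fromRow(internal)               // InternalRow -> external Row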
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/errors/package.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/errors/package.scala
index 0d44d1dd96..0420b4b538 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/errors/package.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/errors/package.scala
@@ -25,15 +25,18 @@ import org.apache.spark.sql.catalyst.trees.TreeNode
package object errors {
class TreeNodeException[TreeType <: TreeNode[_]](
- tree: TreeType, msg: String, cause: Throwable)
+ @transient val tree: TreeType,
+ msg: String,
+ cause: Throwable)
extends Exception(msg, cause) {
+ val treeString = tree.toString
+
// Yes, this is the same as a default parameter, but... those don't seem to work with SBT
// external project dependencies for some reason.
def this(tree: TreeType, msg: String) = this(tree, msg, null)
override def getMessage: String = {
- val treeString = tree.toString
s"${super.getMessage}, tree:${if (treeString contains "\n") "\n" else " "}$tree"
}
}
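
The point of making `tree` transient and capturing `treeString` eagerly is that the exception may be serialized (for example when it is shipped back to the driver), after which `tree` is null; a brief sketch of the intended behaviour (assuming the usual catalyst plan types):

    import org.apache.spark.sql.catalyst.errors.TreeNodeException
    import org.apache.spark.sql.catalyst.plans.logical.{LogicalPlan, OneRowRelation}

    val ex = new TreeNodeException[LogicalPlan](OneRowRelation, "execution failed", null)
    // treeString was computed in the constructor, so the message stays meaningful even
    // after a serialization round trip nulls out the @transient tree field.
    println(ex.getMessage)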
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala
index a965cc8d53..0f8876a9e6 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala
@@ -112,7 +112,7 @@ object Cast {
}
/** Cast the child expression to the target data type. */
-case class Cast(child: Expression, dataType: DataType) extends UnaryExpression {
+case class Cast(child: Expression, dataType: DataType) extends UnaryExpression with NullIntolerant {
override def toString: String = s"cast($child as ${dataType.simpleString})"
@@ -898,7 +898,7 @@ case class Cast(child: Expression, dataType: DataType) extends UnaryExpression {
val result = ctx.freshName("result")
val tmpRow = ctx.freshName("tmpRow")
- val fieldsEvalCode = fieldsCasts.zipWithIndex.map { case (cast, i) => {
+ val fieldsEvalCode = fieldsCasts.zipWithIndex.map { case (cast, i) =>
val fromFieldPrim = ctx.freshName("ffp")
val fromFieldNull = ctx.freshName("ffn")
val toFieldPrim = ctx.freshName("tfp")
@@ -920,7 +920,6 @@ case class Cast(child: Expression, dataType: DataType) extends UnaryExpression {
}
}
"""
- }
}.mkString("\n")
(c, evPrim, evNull) =>
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/EquivalentExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/EquivalentExpressions.scala
index affd1bdb32..8d8cc152ff 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/EquivalentExpressions.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/EquivalentExpressions.scala
@@ -97,11 +97,11 @@ class EquivalentExpressions {
def debugString(all: Boolean = false): String = {
val sb: mutable.StringBuilder = new StringBuilder()
sb.append("Equivalent expressions:\n")
- equivalenceMap.foreach { case (k, v) => {
+ equivalenceMap.foreach { case (k, v) =>
if (all || v.length > 1) {
sb.append(" " + v.mkString(", ")).append("\n")
}
- }}
+ }
sb.toString()
}
}
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala
index 5f8899d599..a24a5db8d4 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala
@@ -153,8 +153,8 @@ abstract class Expression extends TreeNode[Expression] {
* evaluate to the same result.
*/
lazy val canonicalized: Expression = {
- val canonicalizedChildred = children.map(_.canonicalized)
- Canonicalize.execute(withNewChildren(canonicalizedChildred))
+ val canonicalizedChildren = children.map(_.canonicalized)
+ Canonicalize.execute(withNewChildren(canonicalizedChildren))
}
/**
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/InputFileName.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/InputFileName.scala
index dbd0acf06c..2ed6fc0d38 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/InputFileName.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/InputFileName.scala
@@ -17,14 +17,14 @@
package org.apache.spark.sql.catalyst.expressions
-import org.apache.spark.rdd.SqlNewHadoopRDDState
+import org.apache.spark.rdd.InputFileNameHolder
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, ExprCode}
import org.apache.spark.sql.types.{DataType, StringType}
import org.apache.spark.unsafe.types.UTF8String
/**
- * Expression that returns the name of the current file being read in using [[SqlNewHadoopRDD]]
+ * Expression that returns the name of the current file being read.
*/
@ExpressionDescription(
usage = "_FUNC_() - Returns the name of the current file being read if available",
@@ -40,12 +40,12 @@ case class InputFileName() extends LeafExpression with Nondeterministic {
override protected def initInternal(): Unit = {}
override protected def evalInternal(input: InternalRow): UTF8String = {
- SqlNewHadoopRDDState.getInputFileName()
+ InputFileNameHolder.getInputFileName()
}
override def genCode(ctx: CodegenContext, ev: ExprCode): String = {
ev.isNull = "false"
s"final ${ctx.javaType(dataType)} ${ev.value} = " +
- "org.apache.spark.rdd.SqlNewHadoopRDDState.getInputFileName();"
+ "org.apache.spark.rdd.InputFileNameHolder.getInputFileName();"
}
}
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Projection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Projection.scala
index 053e612f3e..354311c5e7 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Projection.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Projection.scala
@@ -136,9 +136,9 @@ object UnsafeProjection {
}
/**
- * Same as other create()'s but allowing enabling/disabling subexpression elimination.
- * TODO: refactor the plumbing and clean this up.
- */
+ * Same as other create()'s but allowing enabling/disabling subexpression elimination.
+ * TODO: refactor the plumbing and clean this up.
+ */
def create(
exprs: Seq[Expression],
inputSchema: Seq[Attribute],
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/SpecificMutableRow.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/SpecificMutableRow.scala
index 4615c55d67..61ca7272df 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/SpecificMutableRow.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/SpecificMutableRow.scala
@@ -62,7 +62,7 @@ import org.apache.spark.sql.types._
abstract class MutableValue extends Serializable {
var isNull: Boolean = true
def boxed: Any
- def update(v: Any)
+ def update(v: Any): Unit
def copy(): MutableValue
}
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/TimeWindow.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/TimeWindow.scala
new file mode 100644
index 0000000000..daf3de95dd
--- /dev/null
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/TimeWindow.scala
@@ -0,0 +1,168 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.catalyst.expressions
+
+import org.apache.commons.lang.StringUtils
+
+import org.apache.spark.sql.AnalysisException
+import org.apache.spark.sql.catalyst.analysis.TypeCheckResult
+import org.apache.spark.sql.catalyst.analysis.TypeCheckResult.TypeCheckFailure
+import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, ExprCode}
+import org.apache.spark.sql.types._
+import org.apache.spark.unsafe.types.CalendarInterval
+
+case class TimeWindow(
+ timeColumn: Expression,
+ windowDuration: Long,
+ slideDuration: Long,
+ startTime: Long) extends UnaryExpression
+ with ImplicitCastInputTypes
+ with Unevaluable
+ with NonSQLExpression {
+
+ //////////////////////////
+ // SQL Constructors
+ //////////////////////////
+
+ def this(
+ timeColumn: Expression,
+ windowDuration: Expression,
+ slideDuration: Expression,
+ startTime: Expression) = {
+ this(timeColumn, TimeWindow.parseExpression(windowDuration),
+ TimeWindow.parseExpression(slideDuration), TimeWindow.parseExpression(startTime))
+ }
+
+ def this(timeColumn: Expression, windowDuration: Expression, slideDuration: Expression) = {
+ this(timeColumn, TimeWindow.parseExpression(windowDuration),
+ TimeWindow.parseExpression(slideDuration), 0)
+ }
+
+ def this(timeColumn: Expression, windowDuration: Expression) = {
+ this(timeColumn, windowDuration, windowDuration)
+ }
+
+ override def child: Expression = timeColumn
+ override def inputTypes: Seq[AbstractDataType] = Seq(TimestampType)
+ override def dataType: DataType = new StructType()
+ .add(StructField("start", TimestampType))
+ .add(StructField("end", TimestampType))
+
+ // This expression is replaced in the analyzer.
+ override lazy val resolved = false
+
+ /**
+ * Validate the inputs for the window duration, slide duration, and start time in addition to
+ * the input data type.
+ */
+ override def checkInputDataTypes(): TypeCheckResult = {
+ val dataTypeCheck = super.checkInputDataTypes()
+ if (dataTypeCheck.isSuccess) {
+ if (windowDuration <= 0) {
+ return TypeCheckFailure(s"The window duration ($windowDuration) must be greater than 0.")
+ }
+ if (slideDuration <= 0) {
+ return TypeCheckFailure(s"The slide duration ($slideDuration) must be greater than 0.")
+ }
+ if (startTime < 0) {
+ return TypeCheckFailure(s"The start time ($startTime) must be greater than or equal to 0.")
+ }
+ if (slideDuration > windowDuration) {
+ return TypeCheckFailure(s"The slide duration ($slideDuration) must be less than or equal" +
+ s" to the windowDuration ($windowDuration).")
+ }
+ if (startTime >= slideDuration) {
+ return TypeCheckFailure(s"The start time ($startTime) must be less than the " +
+ s"slideDuration ($slideDuration).")
+ }
+ }
+ dataTypeCheck
+ }
+}
+
+object TimeWindow {
+ /**
+ * Parses the interval string for a valid time duration. CalendarInterval expects interval
+ * strings to start with the string `interval`. For usability, we prepend `interval` to the string
+ * if the user omitted it.
+ *
+ * @param interval The interval string
+ * @return The interval duration in microseconds. Spark SQL's TimestampType has microsecond
+ * precision.
+ */
+ private def getIntervalInMicroSeconds(interval: String): Long = {
+ if (StringUtils.isBlank(interval)) {
+ throw new IllegalArgumentException(
+ "The window duration, slide duration and start time cannot be null or blank.")
+ }
+ val intervalString = if (interval.startsWith("interval")) {
+ interval
+ } else {
+ "interval " + interval
+ }
+ val cal = CalendarInterval.fromString(intervalString)
+ if (cal == null) {
+ throw new IllegalArgumentException(
+ s"The provided interval ($interval) did not correspond to a valid interval string.")
+ }
+ if (cal.months > 0) {
+ throw new IllegalArgumentException(
+ s"Intervals greater than a month is not supported ($interval).")
+ }
+ cal.microseconds
+ }
+
+ /**
+ * Parses the duration expression to generate the long value for the original constructor so
+ * that we can use `window` in SQL.
+ */
+ private def parseExpression(expr: Expression): Long = expr match {
+ case NonNullLiteral(s, StringType) => getIntervalInMicroSeconds(s.toString)
+ case IntegerLiteral(i) => i.toLong
+ case NonNullLiteral(l, LongType) => l.toString.toLong
+ case _ => throw new AnalysisException("The duration and time inputs to window must be " +
+ "an integer, long or string literal.")
+ }
+
+ def apply(
+ timeColumn: Expression,
+ windowDuration: String,
+ slideDuration: String,
+ startTime: String): TimeWindow = {
+ TimeWindow(timeColumn,
+ getIntervalInMicroSeconds(windowDuration),
+ getIntervalInMicroSeconds(slideDuration),
+ getIntervalInMicroSeconds(startTime))
+ }
+}
+
+/**
+ * Expression used internally to convert the TimestampType to Long without losing
+ * precision, i.e. in microseconds. Used in time windowing.
+ */
+case class PreciseTimestamp(child: Expression) extends UnaryExpression with ExpectsInputTypes {
+ override def inputTypes: Seq[AbstractDataType] = Seq(TimestampType)
+ override def dataType: DataType = LongType
+ override def genCode(ctx: CodegenContext, ev: ExprCode): String = {
+ val eval = child.gen(ctx)
+ eval.code +
+ s"""boolean ${ev.isNull} = ${eval.isNull};
+ |${ctx.javaType(dataType)} ${ev.value} = ${eval.value};
+ """.stripMargin
+ }
+}
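
As a rough illustration of the duration parsing (a sketch, not part of the file; the column name `ts` is made up), durations may be given with or without the `interval` prefix and are converted to microseconds, while month-based intervals are rejected:

    import org.apache.spark.sql.catalyst.analysis.UnresolvedAttribute
    import org.apache.spark.sql.catalyst.expressions.TimeWindow

    // "10 seconds" and "interval 10 seconds" parse identically.
    val w = TimeWindow(UnresolvedAttribute("ts"), "10 seconds", "5 seconds", "0 seconds")
    assert(w.windowDuration == 10L * 1000 * 1000)   // microseconds
    assert(w.slideDuration  ==  5L * 1000 * 1000)
    // TimeWindow(..., "1 month", ...) would throw IllegalArgumentException.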
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Average.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Average.scala
index 94ac4bf09b..ff70774847 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Average.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Average.scala
@@ -23,6 +23,8 @@ import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.util.TypeUtils
import org.apache.spark.sql.types._
+@ExpressionDescription(
+ usage = "_FUNC_(x) - Returns the mean calculated from values of a group.")
case class Average(child: Expression) extends DeclarativeAggregate {
override def prettyName: String = "avg"
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/CentralMomentAgg.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/CentralMomentAgg.scala
index 9d2db45144..17a7c6dce8 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/CentralMomentAgg.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/CentralMomentAgg.scala
@@ -130,6 +130,10 @@ abstract class CentralMomentAgg(child: Expression) extends DeclarativeAggregate
}
// Compute the population standard deviation of a column
+// scalastyle:off line.size.limit
+@ExpressionDescription(
+ usage = "_FUNC_(x) - Returns the population standard deviation calculated from values of a group.")
+// scalastyle:on line.size.limit
case class StddevPop(child: Expression) extends CentralMomentAgg(child) {
override protected def momentOrder = 2
@@ -143,6 +147,8 @@ case class StddevPop(child: Expression) extends CentralMomentAgg(child) {
}
// Compute the sample standard deviation of a column
+@ExpressionDescription(
+ usage = "_FUNC_(x) - Returns the sample standard deviation calculated from values of a group.")
case class StddevSamp(child: Expression) extends CentralMomentAgg(child) {
override protected def momentOrder = 2
@@ -157,6 +163,8 @@ case class StddevSamp(child: Expression) extends CentralMomentAgg(child) {
}
// Compute the population variance of a column
+@ExpressionDescription(
+ usage = "_FUNC_(x) - Returns the population variance calculated from values of a group.")
case class VariancePop(child: Expression) extends CentralMomentAgg(child) {
override protected def momentOrder = 2
@@ -170,6 +178,8 @@ case class VariancePop(child: Expression) extends CentralMomentAgg(child) {
}
// Compute the sample variance of a column
+@ExpressionDescription(
+ usage = "_FUNC_(x) - Returns the sample variance calculated from values of a group.")
case class VarianceSamp(child: Expression) extends CentralMomentAgg(child) {
override protected def momentOrder = 2
@@ -183,6 +193,8 @@ case class VarianceSamp(child: Expression) extends CentralMomentAgg(child) {
override def prettyName: String = "var_samp"
}
+@ExpressionDescription(
+ usage = "_FUNC_(x) - Returns the Skewness value calculated from values of a group.")
case class Skewness(child: Expression) extends CentralMomentAgg(child) {
override def prettyName: String = "skewness"
@@ -196,6 +208,8 @@ case class Skewness(child: Expression) extends CentralMomentAgg(child) {
}
}
+@ExpressionDescription(
+ usage = "_FUNC_(x) - Returns the Kurtosis value calculated from values of a group.")
case class Kurtosis(child: Expression) extends CentralMomentAgg(child) {
override protected def momentOrder = 4
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Corr.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Corr.scala
index e6b8214ef2..e29265e2f4 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Corr.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Corr.scala
@@ -28,6 +28,8 @@ import org.apache.spark.sql.types._
* Definition of Pearson correlation can be found at
* http://en.wikipedia.org/wiki/Pearson_product-moment_correlation_coefficient
*/
+@ExpressionDescription(
+ usage = "_FUNC_(x,y) - Returns Pearson coefficient of correlation between a set of number pairs.")
case class Corr(x: Expression, y: Expression) extends DeclarativeAggregate {
override def children: Seq[Expression] = Seq(x, y)
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Count.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Count.scala
index 663c69e799..17ae012af7 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Count.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Count.scala
@@ -21,6 +21,12 @@ import org.apache.spark.sql.catalyst.dsl.expressions._
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.types._
+// scalastyle:off line.size.limit
+@ExpressionDescription(
+ usage = """_FUNC_(*) - Returns the total number of retrieved rows, including rows containing NULL values.
+ _FUNC_(expr) - Returns the number of rows for which the supplied expression is non-NULL.
+ _FUNC_(DISTINCT expr[, expr...]) - Returns the number of rows for which the supplied expression(s) are unique and non-NULL.""")
+// scalastyle:on line.size.limit
case class Count(children: Seq[Expression]) extends DeclarativeAggregate {
override def nullable: Boolean = false
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Covariance.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Covariance.scala
index c175a8c4c7..d80afbebf7 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Covariance.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Covariance.scala
@@ -76,6 +76,8 @@ abstract class Covariance(x: Expression, y: Expression) extends DeclarativeAggre
}
}
+@ExpressionDescription(
+ usage = "_FUNC_(x,y) - Returns the population covariance of a set of number pairs.")
case class CovPopulation(left: Expression, right: Expression) extends Covariance(left, right) {
override val evaluateExpression: Expression = {
If(n === Literal(0.0), Literal.create(null, DoubleType),
@@ -85,6 +87,8 @@ case class CovPopulation(left: Expression, right: Expression) extends Covariance
}
+@ExpressionDescription(
+ usage = "_FUNC_(x,y) - Returns the sample covariance of a set of number pairs.")
case class CovSample(left: Expression, right: Expression) extends Covariance(left, right) {
override val evaluateExpression: Expression = {
If(n === Literal(0.0), Literal.create(null, DoubleType),
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/First.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/First.scala
index 35f57426fe..b8ab0364dd 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/First.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/First.scala
@@ -28,6 +28,11 @@ import org.apache.spark.sql.types._
* is used) its result will not be deterministic (unless the input table is sorted and has
* a single partition, and we use a single reducer to do the aggregation.).
*/
+@ExpressionDescription(
+ usage = """_FUNC_(expr) - Returns the first value of `child` for a group of rows.
+ _FUNC_(expr,isIgnoreNull=false) - Returns the first value of `child` for a group of rows.
+ If isIgnoreNull is true, returns only non-null values.
+ """)
case class First(child: Expression, ignoreNullsExpr: Expression) extends DeclarativeAggregate {
def this(child: Expression) = this(child, Literal.create(false, BooleanType))
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/HyperLogLogPlusPlus.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/HyperLogLogPlusPlus.scala
index b6bd56cff6..1d218da6db 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/HyperLogLogPlusPlus.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/HyperLogLogPlusPlus.scala
@@ -20,8 +20,6 @@ package org.apache.spark.sql.catalyst.expressions.aggregate
import java.lang.{Long => JLong}
import java.util
-import com.clearspring.analytics.hash.MurmurHash
-
import org.apache.spark.sql.AnalysisException
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions._
@@ -48,6 +46,11 @@ import org.apache.spark.sql.types._
* @param relativeSD the maximum estimation error allowed.
*/
// scalastyle:on
+@ExpressionDescription(
+ usage = """_FUNC_(expr) - Returns the estimated cardinality by HyperLogLog++.
+ _FUNC_(expr, relativeSD=0.05) - Returns the estimated cardinality by HyperLogLog++
+ with relativeSD, the maximum estimation error allowed.
+ """)
case class HyperLogLogPlusPlus(
child: Expression,
relativeSD: Double = 0.05,
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Last.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Last.scala
index be7e12d7a2..b05d74b49b 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Last.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Last.scala
@@ -28,6 +28,8 @@ import org.apache.spark.sql.types._
* is used) its result will not be deterministic (unless the input table is sorted and has
* a single partition, and we use a single reducer to do the aggregation.).
*/
+@ExpressionDescription(
+ usage = "_FUNC_(expr,isIgnoreNull) - Returns the last value of `child` for a group of rows.")
case class Last(child: Expression, ignoreNullsExpr: Expression) extends DeclarativeAggregate {
def this(child: Expression) = this(child, Literal.create(false, BooleanType))
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Max.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Max.scala
index 906003188d..c534fe495f 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Max.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Max.scala
@@ -22,6 +22,8 @@ import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.util.TypeUtils
import org.apache.spark.sql.types._
+@ExpressionDescription(
+ usage = "_FUNC_(expr) - Returns the maximum value of expr.")
case class Max(child: Expression) extends DeclarativeAggregate {
override def children: Seq[Expression] = child :: Nil
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Min.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Min.scala
index 39f7afbd08..35289b4681 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Min.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Min.scala
@@ -22,7 +22,8 @@ import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.util.TypeUtils
import org.apache.spark.sql.types._
-
+@ExpressionDescription(
+ usage = "_FUNC_(expr) - Returns the minimum value of expr.")
case class Min(child: Expression) extends DeclarativeAggregate {
override def children: Seq[Expression] = child :: Nil
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Sum.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Sum.scala
index 08a67ea3df..ad217f25b5 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Sum.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Sum.scala
@@ -22,6 +22,8 @@ import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.util.TypeUtils
import org.apache.spark.sql.types._
+@ExpressionDescription(
+ usage = "_FUNC_(x) - Returns the sum calculated from values of a group.")
case class Sum(child: Expression) extends DeclarativeAggregate {
override def children: Seq[Expression] = child :: Nil
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/interfaces.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/interfaces.scala
index ff3064ac66..d31ccf9985 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/interfaces.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/interfaces.scala
@@ -18,6 +18,7 @@
package org.apache.spark.sql.catalyst.expressions.aggregate
import org.apache.spark.sql.catalyst.InternalRow
+import org.apache.spark.sql.catalyst.analysis.UnresolvedAttribute
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.expressions.codegen.CodegenFallback
import org.apache.spark.sql.types._
@@ -66,6 +67,19 @@ private[sql] case object NoOp extends Expression with Unevaluable {
override def children: Seq[Expression] = Nil
}
+object AggregateExpression {
+ def apply(
+ aggregateFunction: AggregateFunction,
+ mode: AggregateMode,
+ isDistinct: Boolean): AggregateExpression = {
+ AggregateExpression(
+ aggregateFunction,
+ mode,
+ isDistinct,
+ NamedExpression.newExprId)
+ }
+}
+
/**
* A container for an [[AggregateFunction]] with its [[AggregateMode]] and a field
* (`isDistinct`) indicating if DISTINCT keyword is specified for this function.
@@ -73,10 +87,31 @@ private[sql] case object NoOp extends Expression with Unevaluable {
private[sql] case class AggregateExpression(
aggregateFunction: AggregateFunction,
mode: AggregateMode,
- isDistinct: Boolean)
+ isDistinct: Boolean,
+ resultId: ExprId)
extends Expression
with Unevaluable {
+ lazy val resultAttribute: Attribute = if (aggregateFunction.resolved) {
+ AttributeReference(
+ aggregateFunction.toString,
+ aggregateFunction.dataType,
+ aggregateFunction.nullable)(exprId = resultId)
+ } else {
+ // This is a bit of a hack. Really we should not be constructing this container and reasoning
+ // about datatypes / aggregation mode until after we have finished analysis and made it to
+ // planning.
+ UnresolvedAttribute(aggregateFunction.toString)
+ }
+
+ // We compute the same thing regardless of our final result.
+ override lazy val canonicalized: Expression =
+ AggregateExpression(
+ aggregateFunction.canonicalized.asInstanceOf[AggregateFunction],
+ mode,
+ isDistinct,
+ ExprId(0))
+
override def children: Seq[Expression] = aggregateFunction :: Nil
override def dataType: DataType = aggregateFunction.dataType
override def foldable: Boolean = false
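
A short sketch (hypothetical column `x`; not from the patch, and exercised from Spark's own packages since AggregateExpression is private[sql]) of what the new companion object and `canonicalized` override provide: each expression gets a fresh `resultId`, but canonicalization pins the id to `ExprId(0)`, so semantically equal aggregates still compare equal:

    import org.apache.spark.sql.catalyst.analysis.UnresolvedAttribute
    import org.apache.spark.sql.catalyst.expressions.aggregate.{AggregateExpression, Average, Complete}

    val a = AggregateExpression(Average(UnresolvedAttribute("x")), Complete, isDistinct = false)
    val b = AggregateExpression(Average(UnresolvedAttribute("x")), Complete, isDistinct = false)

    assert(a != b)                              // different resultIds
    assert(a.canonicalized == b.canonicalized)  // recognized as the same aggregate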
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala
index ed812e0679..f3d42fc0b2 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala
@@ -23,8 +23,10 @@ import org.apache.spark.sql.catalyst.util.TypeUtils
import org.apache.spark.sql.types._
import org.apache.spark.unsafe.types.CalendarInterval
-
-case class UnaryMinus(child: Expression) extends UnaryExpression with ExpectsInputTypes {
+@ExpressionDescription(
+ usage = "_FUNC_(a) - Returns -a.")
+case class UnaryMinus(child: Expression) extends UnaryExpression
+ with ExpectsInputTypes with NullIntolerant {
override def inputTypes: Seq[AbstractDataType] = Seq(TypeCollection.NumericAndInterval)
@@ -58,7 +60,10 @@ case class UnaryMinus(child: Expression) extends UnaryExpression with ExpectsInp
override def sql: String = s"(-${child.sql})"
}
-case class UnaryPositive(child: Expression) extends UnaryExpression with ExpectsInputTypes {
+@ExpressionDescription(
+ usage = "_FUNC_(a) - Returns a.")
+case class UnaryPositive(child: Expression)
+ extends UnaryExpression with ExpectsInputTypes with NullIntolerant {
override def prettyName: String = "positive"
override def inputTypes: Seq[AbstractDataType] = Seq(TypeCollection.NumericAndInterval)
@@ -77,9 +82,10 @@ case class UnaryPositive(child: Expression) extends UnaryExpression with Expects
* A function that get the absolute value of the numeric value.
*/
@ExpressionDescription(
- usage = "_FUNC_(expr) - Returns the absolute value of the numeric value",
- extended = "> SELECT _FUNC_('-1');\n1")
-case class Abs(child: Expression) extends UnaryExpression with ExpectsInputTypes {
+ usage = "_FUNC_(expr) - Returns the absolute value of the numeric value.",
+ extended = "> SELECT _FUNC_('-1');\n 1")
+case class Abs(child: Expression)
+ extends UnaryExpression with ExpectsInputTypes with NullIntolerant {
override def inputTypes: Seq[AbstractDataType] = Seq(NumericType)
@@ -123,7 +129,9 @@ private[sql] object BinaryArithmetic {
def unapply(e: BinaryArithmetic): Option[(Expression, Expression)] = Some((e.left, e.right))
}
-case class Add(left: Expression, right: Expression) extends BinaryArithmetic {
+@ExpressionDescription(
+ usage = "a _FUNC_ b - Returns a+b.")
+case class Add(left: Expression, right: Expression) extends BinaryArithmetic with NullIntolerant {
override def inputType: AbstractDataType = TypeCollection.NumericAndInterval
@@ -152,7 +160,10 @@ case class Add(left: Expression, right: Expression) extends BinaryArithmetic {
}
}
-case class Subtract(left: Expression, right: Expression) extends BinaryArithmetic {
+@ExpressionDescription(
+ usage = "a _FUNC_ b - Returns a-b.")
+case class Subtract(left: Expression, right: Expression)
+ extends BinaryArithmetic with NullIntolerant {
override def inputType: AbstractDataType = TypeCollection.NumericAndInterval
@@ -181,7 +192,10 @@ case class Subtract(left: Expression, right: Expression) extends BinaryArithmeti
}
}
-case class Multiply(left: Expression, right: Expression) extends BinaryArithmetic {
+@ExpressionDescription(
+ usage = "a _FUNC_ b - Multiplies a by b.")
+case class Multiply(left: Expression, right: Expression)
+ extends BinaryArithmetic with NullIntolerant {
override def inputType: AbstractDataType = NumericType
@@ -193,7 +207,11 @@ case class Multiply(left: Expression, right: Expression) extends BinaryArithmeti
protected override def nullSafeEval(input1: Any, input2: Any): Any = numeric.times(input1, input2)
}
-case class Divide(left: Expression, right: Expression) extends BinaryArithmetic {
+@ExpressionDescription(
+ usage = "a _FUNC_ b - Divides a by b.",
+ extended = "> SELECT 3 _FUNC_ 2;\n 1.5")
+case class Divide(left: Expression, right: Expression)
+ extends BinaryArithmetic with NullIntolerant {
override def inputType: AbstractDataType = NumericType
@@ -237,25 +255,42 @@ case class Divide(left: Expression, right: Expression) extends BinaryArithmetic
} else {
s"($javaType)(${eval1.value} $symbol ${eval2.value})"
}
- s"""
- ${eval2.code}
- boolean ${ev.isNull} = false;
- $javaType ${ev.value} = ${ctx.defaultValue(javaType)};
- if (${eval2.isNull} || $isZero) {
- ${ev.isNull} = true;
- } else {
- ${eval1.code}
- if (${eval1.isNull}) {
+ if (!left.nullable && !right.nullable) {
+ s"""
+ ${eval2.code}
+ boolean ${ev.isNull} = false;
+ $javaType ${ev.value} = ${ctx.defaultValue(javaType)};
+ if ($isZero) {
${ev.isNull} = true;
} else {
+ ${eval1.code}
${ev.value} = $divide;
}
- }
- """
+ """
+ } else {
+ s"""
+ ${eval2.code}
+ boolean ${ev.isNull} = false;
+ $javaType ${ev.value} = ${ctx.defaultValue(javaType)};
+ if (${eval2.isNull} || $isZero) {
+ ${ev.isNull} = true;
+ } else {
+ ${eval1.code}
+ if (${eval1.isNull}) {
+ ${ev.isNull} = true;
+ } else {
+ ${ev.value} = $divide;
+ }
+ }
+ """
+ }
}
}
-case class Remainder(left: Expression, right: Expression) extends BinaryArithmetic {
+@ExpressionDescription(
+ usage = "a _FUNC_ b - Returns the remainder when dividing a by b.")
+case class Remainder(left: Expression, right: Expression)
+ extends BinaryArithmetic with NullIntolerant {
override def inputType: AbstractDataType = NumericType
@@ -299,21 +334,35 @@ case class Remainder(left: Expression, right: Expression) extends BinaryArithmet
} else {
s"($javaType)(${eval1.value} $symbol ${eval2.value})"
}
- s"""
- ${eval2.code}
- boolean ${ev.isNull} = false;
- $javaType ${ev.value} = ${ctx.defaultValue(javaType)};
- if (${eval2.isNull} || $isZero) {
- ${ev.isNull} = true;
- } else {
- ${eval1.code}
- if (${eval1.isNull}) {
+ if (!left.nullable && !right.nullable) {
+ s"""
+ ${eval2.code}
+ boolean ${ev.isNull} = false;
+ $javaType ${ev.value} = ${ctx.defaultValue(javaType)};
+ if ($isZero) {
${ev.isNull} = true;
} else {
+ ${eval1.code}
${ev.value} = $remainder;
}
- }
- """
+ """
+ } else {
+ s"""
+ ${eval2.code}
+ boolean ${ev.isNull} = false;
+ $javaType ${ev.value} = ${ctx.defaultValue(javaType)};
+ if (${eval2.isNull} || $isZero) {
+ ${ev.isNull} = true;
+ } else {
+ ${eval1.code}
+ if (${eval1.isNull}) {
+ ${ev.isNull} = true;
+ } else {
+ ${ev.value} = $remainder;
+ }
+ }
+ """
+ }
}
}
@@ -429,7 +478,10 @@ case class MinOf(left: Expression, right: Expression)
override def symbol: String = "min"
}
-case class Pmod(left: Expression, right: Expression) extends BinaryArithmetic {
+@ExpressionDescription(
+ usage = "_FUNC_(a, b) - Returns the positive modulo",
+ extended = "> SELECT _FUNC_(10,3);\n 1")
+case class Pmod(left: Expression, right: Expression) extends BinaryArithmetic with NullIntolerant {
override def toString: String = s"pmod($left, $right)"
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/bitwiseExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/bitwiseExpressions.scala
index 4c90b3f7d3..a7e1cd66f2 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/bitwiseExpressions.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/bitwiseExpressions.scala
@@ -26,6 +26,9 @@ import org.apache.spark.sql.types._
*
* Code generation inherited from BinaryArithmetic.
*/
+@ExpressionDescription(
+ usage = "a _FUNC_ b - Bitwise AND.",
+ extended = "> SELECT 3 _FUNC_ 5; 1")
case class BitwiseAnd(left: Expression, right: Expression) extends BinaryArithmetic {
override def inputType: AbstractDataType = IntegralType
@@ -51,6 +54,9 @@ case class BitwiseAnd(left: Expression, right: Expression) extends BinaryArithme
*
* Code generation inherited from BinaryArithmetic.
*/
+@ExpressionDescription(
+ usage = "a _FUNC_ b - Bitwise OR.",
+ extended = "> SELECT 3 _FUNC_ 5; 7")
case class BitwiseOr(left: Expression, right: Expression) extends BinaryArithmetic {
override def inputType: AbstractDataType = IntegralType
@@ -76,6 +82,9 @@ case class BitwiseOr(left: Expression, right: Expression) extends BinaryArithmet
*
* Code generation inherited from BinaryArithmetic.
*/
+@ExpressionDescription(
+ usage = "a _FUNC_ b - Bitwise exclusive OR.",
+ extended = "> SELECT 3 _FUNC_ 5; 2")
case class BitwiseXor(left: Expression, right: Expression) extends BinaryArithmetic {
override def inputType: AbstractDataType = IntegralType
@@ -99,6 +108,9 @@ case class BitwiseXor(left: Expression, right: Expression) extends BinaryArithme
/**
* A function that calculates bitwise not(~) of a number.
*/
+@ExpressionDescription(
+ usage = "_FUNC_ b - Bitwise NOT.",
+ extended = "> SELECT _FUNC_ 0; -1")
case class BitwiseNot(child: Expression) extends UnaryExpression with ExpectsInputTypes {
override def inputTypes: Seq[AbstractDataType] = Seq(IntegralType)
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeFormatter.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeFormatter.scala
index 9d99bbffbe..ab4831f7ab 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeFormatter.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeFormatter.scala
@@ -43,15 +43,45 @@ object CodeFormatter {
private class CodeFormatter {
private val code = new StringBuilder
- private var indentLevel = 0
private val indentSize = 2
+
+ // Tracks the level of indentation in the current line.
+ private var indentLevel = 0
private var indentString = ""
private var currentLine = 1
+ // Tracks whether we are inside a multi-line comment block, and the indentation level to restore once it ends.
+ private var inCommentBlock = false
+ private var indentLevelOutsideCommentBlock = indentLevel
+
private def addLine(line: String): Unit = {
- val indentChange =
- line.count(c => "({".indexOf(c) >= 0) - line.count(c => ")}".indexOf(c) >= 0)
- val newIndentLevel = math.max(0, indentLevel + indentChange)
+
+ // We currently infer the level of indentation of a given line based on a simple heuristic that
+ // examines the number of parentheses and braces in that line. This isn't the most robust
+ // implementation but works for all code that we generate.
+ val indentChange = line.count(c => "({".indexOf(c) >= 0) - line.count(c => ")}".indexOf(c) >= 0)
+ var newIndentLevel = math.max(0, indentLevel + indentChange)
+
+ // Please note that while we try to format the comment blocks in exactly the same way as the
+ // rest of the code, once the block ends, we reset the next line's indentation level to what it
+ // was immediately before entering the comment block.
+ if (!inCommentBlock) {
+ if (line.startsWith("/*")) {
+ // Handle multi-line comments
+ inCommentBlock = true
+ indentLevelOutsideCommentBlock = indentLevel
+ } else if (line.startsWith("//")) {
+ // Handle single line comments
+ newIndentLevel = indentLevel
+ }
+ }
+ if (inCommentBlock) {
+ if (line.endsWith("*/")) {
+ inCommentBlock = false
+ newIndentLevel = indentLevelOutsideCommentBlock
+ }
+ }
+
// Lines starting with '}' should be de-indented even if they contain '{' after;
// in addition, lines ending with ':' are typically labels
val thisLineIndent = if (line.startsWith("}") || line.startsWith(")") || line.endsWith(":")) {
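
A small sketch of the formatter's brace-counting heuristic and the new comment handling (not from the patch; the snippet being formatted is made up):

    import org.apache.spark.sql.catalyst.expressions.codegen.CodeFormatter

    val generated =
      """class Example {
        |/* a multi-line
        | * comment */
        |int inc(int x) {
        |// a single-line comment
        |return x + 1;
        |}
        |}""".stripMargin

    // Indentation is inferred from '(' '{' versus ')' '}' counts; comment lines are indented
    // like code but do not change the indentation of the lines that follow them.
    println(CodeFormatter.format(generated))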
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala
index b511b4b3a0..f43626ca81 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala
@@ -58,10 +58,10 @@ class CodegenContext {
val references: mutable.ArrayBuffer[Any] = new mutable.ArrayBuffer[Any]()
/**
- * Add an object to `references`, create a class member to access it.
- *
- * Returns the name of class member.
- */
+ * Add an object to `references`, create a class member to access it.
+ *
+ * Returns the name of class member.
+ */
def addReferenceObj(name: String, obj: Any, className: String = null): String = {
val term = freshName(name)
val idx = references.length
@@ -72,9 +72,9 @@ class CodegenContext {
}
/**
- * Holding a list of generated columns as input of current operator, will be used by
- * BoundReference to generate code.
- */
+ * Holding a list of generated columns as input of current operator, will be used by
+ * BoundReference to generate code.
+ */
var currentVars: Seq[ExprCode] = null
/**
@@ -169,14 +169,14 @@ class CodegenContext {
final var INPUT_ROW = "i"
/**
- * The map from a variable name to it's next ID.
- */
+ * The map from a variable name to it's next ID.
+ */
private val freshNameIds = new mutable.HashMap[String, Int]
freshNameIds += INPUT_ROW -> 1
/**
- * A prefix used to generate fresh name.
- */
+ * A prefix used to generate fresh name.
+ */
var freshNamePrefix = ""
/**
@@ -234,8 +234,8 @@ class CodegenContext {
}
/**
- * Update a column in MutableRow from ExprCode.
- */
+ * Update a column in MutableRow from ExprCode.
+ */
def updateColumn(
row: String,
dataType: DataType,
@@ -509,7 +509,7 @@ class CodegenContext {
/**
* Checks and sets up the state and codegen for subexpression elimination. This finds the
- * common subexpresses, generates the functions that evaluate those expressions and populates
+ * common subexpressions, generates the functions that evaluate those expressions and populates
* the mapping of common subexpressions to the generated functions.
*/
private def subexpressionElimination(expressions: Seq[Expression]) = {
@@ -519,7 +519,7 @@ class CodegenContext {
// Get all the expressions that appear at least twice and set up the state for subexpression
// elimination.
val commonExprs = equivalentExpressions.getAllEquivalentExprs.filter(_.size > 1)
- commonExprs.foreach(e => {
+ commonExprs.foreach { e =>
val expr = e.head
val fnName = freshName("evalExpr")
val isNull = s"${fnName}IsNull"
@@ -561,7 +561,7 @@ class CodegenContext {
subexprFunctions += s"$fnName($INPUT_ROW);"
val state = SubExprEliminationState(isNull, value)
e.foreach(subExprEliminationExprs.put(_, state))
- })
+ }
}
/**
@@ -626,15 +626,15 @@ abstract class CodeGenerator[InType <: AnyRef, OutType <: AnyRef] extends Loggin
object CodeGenerator extends Logging {
/**
- * Compile the Java source code into a Java class, using Janino.
- */
+ * Compile the Java source code into a Java class, using Janino.
+ */
def compile(code: String): GeneratedClass = {
cache.get(code)
}
/**
- * Compile the Java source code into a Java class, using Janino.
- */
+ * Compile the Java source code into a Java class, using Janino.
+ */
private[this] def doCompile(code: String): GeneratedClass = {
val evaluator = new ClassBodyEvaluator()
evaluator.setParentClassLoader(Utils.getContextOrSparkClassLoader)
@@ -661,7 +661,7 @@ object CodeGenerator extends Logging {
logDebug({
// Only add extra debugging info to byte code when we are going to print the source code.
evaluator.setDebuggingInformation(true, true, false)
- formatted
+ s"\n$formatted"
})
try {
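The `commonExprs.foreach { e => ... }` change earlier in this file is purely stylistic; both forms invoke the same `foreach`. A tiny standalone illustration of the two spellings:

    object ForeachStyle {
      def main(args: Array[String]): Unit = {
        val commonExprs = Seq(Seq(1, 1), Seq(2, 2, 2))

        // Older style: a function literal wrapped in parentheses and braces.
        commonExprs.foreach(e => {
          println(s"head=${e.head} size=${e.size}")
        })

        // Preferred Scala style for multi-line bodies: a single brace block.
        commonExprs.foreach { e =>
          println(s"head=${e.head} size=${e.size}")
        }
      }
    }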
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala
index e36c985249..ab790cf372 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala
@@ -26,6 +26,8 @@ import org.apache.spark.sql.types._
/**
* Given an array or map, returns its size.
*/
+@ExpressionDescription(
+ usage = "_FUNC_(expr) - Returns the size of an array or a map.")
case class Size(child: Expression) extends UnaryExpression with ExpectsInputTypes {
override def dataType: DataType = IntegerType
override def inputTypes: Seq[AbstractDataType] = Seq(TypeCollection(ArrayType, MapType))
@@ -44,6 +46,11 @@ case class Size(child: Expression) extends UnaryExpression with ExpectsInputType
* Sorts the input array in ascending / descending order according to the natural ordering of
* the array elements and returns it.
*/
+// scalastyle:off line.size.limit
+@ExpressionDescription(
+ usage = "_FUNC_(array(obj1, obj2,...)) - Sorts the input array in ascending order according to the natural ordering of the array elements.",
+ extended = " > SELECT _FUNC_(array('b', 'd', 'c', 'a'));\n 'a', 'b', 'c', 'd'")
+// scalastyle:on line.size.limit
case class SortArray(base: Expression, ascendingOrder: Expression)
extends BinaryExpression with ExpectsInputTypes with CodegenFallback {
@@ -125,6 +132,9 @@ case class SortArray(base: Expression, ascendingOrder: Expression)
/**
* Checks if the array (left) has the element (right)
*/
+@ExpressionDescription(
+ usage = "_FUNC_(array, value) - Returns TRUE if the array contains value.",
+ extended = " > SELECT _FUNC_(array(1, 2, 3), 2);\n true")
case class ArrayContains(left: Expression, right: Expression)
extends BinaryExpression with ImplicitCastInputTypes {
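All of the annotations added in this patch follow the same shape: a `usage` string in which `_FUNC_` stands in for the SQL function name, plus an optional `extended` string with a worked SELECT example. As a hedged sketch, annotating a new expression would look roughly like the following; `CountNulls` is hypothetical and not part of this patch, and the snippet assumes the same imports (and eventual function registration) as the expressions above.

    // Hypothetical expression, shown only to illustrate the annotation pattern.
    @ExpressionDescription(
      usage = "_FUNC_(array) - Returns the number of NULL elements in the array.",
      extended = " > SELECT _FUNC_(array(1, NULL, 3, NULL));\n 2")
    case class CountNulls(child: Expression)
      extends UnaryExpression with ExpectsInputTypes with CodegenFallback {
      override def inputTypes: Seq[AbstractDataType] = Seq(ArrayType)
      override def dataType: DataType = IntegerType
      override protected def nullSafeEval(array: Any): Any = {
        val data = array.asInstanceOf[ArrayData]
        (0 until data.numElements()).count(i => data.isNullAt(i))
      }
    }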
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeCreator.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeCreator.scala
index c299586dde..74de4a776d 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeCreator.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeCreator.scala
@@ -27,6 +27,8 @@ import org.apache.spark.unsafe.types.UTF8String
/**
* Returns an Array containing the evaluation of all children expressions.
*/
+@ExpressionDescription(
+ usage = "_FUNC_(n0, ...) - Returns an array with the given elements.")
case class CreateArray(children: Seq[Expression]) extends Expression {
override def foldable: Boolean = children.forall(_.foldable)
@@ -73,6 +75,8 @@ case class CreateArray(children: Seq[Expression]) extends Expression {
* Returns a catalyst Map containing the evaluation of all children expressions as keys and values.
* The children are a flatted sequence of kv pairs, e.g. (key1, value1, key2, value2, ...)
*/
+@ExpressionDescription(
+ usage = "_FUNC_(key0, value0, key1, value1...) - Creates a map with the given key/value pairs.")
case class CreateMap(children: Seq[Expression]) extends Expression {
private[sql] lazy val keys = children.indices.filter(_ % 2 == 0).map(children)
private[sql] lazy val values = children.indices.filter(_ % 2 != 0).map(children)
@@ -153,6 +157,8 @@ case class CreateMap(children: Seq[Expression]) extends Expression {
/**
* Returns a Row containing the evaluation of all children expressions.
*/
+@ExpressionDescription(
+ usage = "_FUNC_(col1, col2, col3, ...) - Creates a struct with the given field values.")
case class CreateStruct(children: Seq[Expression]) extends Expression {
override def foldable: Boolean = children.forall(_.foldable)
@@ -204,6 +210,10 @@ case class CreateStruct(children: Seq[Expression]) extends Expression {
*
* @param children Seq(name1, val1, name2, val2, ...)
*/
+// scalastyle:off line.size.limit
+@ExpressionDescription(
+ usage = "_FUNC_(name1, val1, name2, val2, ...) - Creates a struct with the given field names and values.")
+// scalastyle:on line.size.limit
case class CreateNamedStruct(children: Seq[Expression]) extends Expression {
/**
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/conditionalExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/conditionalExpressions.scala
index 103ab365e3..ae6a94842f 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/conditionalExpressions.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/conditionalExpressions.scala
@@ -23,7 +23,10 @@ import org.apache.spark.sql.catalyst.expressions.codegen._
import org.apache.spark.sql.catalyst.util.TypeUtils
import org.apache.spark.sql.types._
-
+// scalastyle:off line.size.limit
+@ExpressionDescription(
+ usage = "_FUNC_(expr1,expr2,expr3) - If expr1 is TRUE then IF() returns expr2; otherwise it returns expr3.")
+// scalastyle:on line.size.limit
case class If(predicate: Expression, trueValue: Expression, falseValue: Expression)
extends Expression {
@@ -85,6 +88,10 @@ case class If(predicate: Expression, trueValue: Expression, falseValue: Expressi
* @param branches seq of (branch condition, branch value)
* @param elseValue optional value for the else branch
*/
+// scalastyle:off line.size.limit
+@ExpressionDescription(
+ usage = "CASE WHEN a THEN b [WHEN c THEN d]* [ELSE e] END - When a = true, returns b; when c = true, return d; else return e.")
+// scalastyle:on line.size.limit
case class CaseWhen(branches: Seq[(Expression, Expression)], elseValue: Option[Expression] = None)
extends Expression with CodegenFallback {
@@ -222,7 +229,7 @@ object CaseWhen {
}
/**
- * A factory method to faciliate the creation of this expression when used in parsers.
+ * A factory method to facilitate the creation of this expression when used in parsers.
* @param branches Expressions at even position are the branch conditions, and expressions at odd
* position are branch values.
*/
@@ -256,6 +263,8 @@ object CaseKeyWhen {
* A function that returns the least value of all parameters, skipping null values.
* It takes at least 2 parameters, and returns null iff all parameters are null.
*/
+@ExpressionDescription(
+ usage = "_FUNC_(n1, ...) - Returns the least value of all parameters, skipping null values.")
case class Least(children: Seq[Expression]) extends Expression {
override def nullable: Boolean = children.forall(_.nullable)
@@ -315,6 +324,8 @@ case class Least(children: Seq[Expression]) extends Expression {
* A function that returns the greatest value of all parameters, skipping null values.
* It takes at least 2 parameters, and returns null iff all parameters are null.
*/
+@ExpressionDescription(
+ usage = "_FUNC_(n1, ...) - Returns the greatest value of all parameters, skipping null values.")
case class Greatest(children: Seq[Expression]) extends Expression {
override def nullable: Boolean = children.forall(_.nullable)
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/datetimeExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/datetimeExpressions.scala
index 1d0ea68d7a..9135753041 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/datetimeExpressions.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/datetimeExpressions.scala
@@ -35,6 +35,8 @@ import org.apache.spark.unsafe.types.{CalendarInterval, UTF8String}
*
* There is no code generation since this expression should get constant folded by the optimizer.
*/
+@ExpressionDescription(
+ usage = "_FUNC_() - Returns the current date at the start of query evaluation.")
case class CurrentDate() extends LeafExpression with CodegenFallback {
override def foldable: Boolean = true
override def nullable: Boolean = false
@@ -54,6 +56,8 @@ case class CurrentDate() extends LeafExpression with CodegenFallback {
*
* There is no code generation since this expression should get constant folded by the optimizer.
*/
+@ExpressionDescription(
+ usage = "_FUNC_() - Returns the current timestamp at the start of query evaluation.")
case class CurrentTimestamp() extends LeafExpression with CodegenFallback {
override def foldable: Boolean = true
override def nullable: Boolean = false
@@ -70,6 +74,9 @@ case class CurrentTimestamp() extends LeafExpression with CodegenFallback {
/**
* Adds a number of days to startdate.
*/
+@ExpressionDescription(
+ usage = "_FUNC_(start_date, num_days) - Returns the date that is num_days after start_date.",
+ extended = "> SELECT _FUNC_('2016-07-30', 1);\n '2016-07-31'")
case class DateAdd(startDate: Expression, days: Expression)
extends BinaryExpression with ImplicitCastInputTypes {
@@ -96,6 +103,9 @@ case class DateAdd(startDate: Expression, days: Expression)
/**
* Subtracts a number of days to startdate.
*/
+@ExpressionDescription(
+ usage = "_FUNC_(start_date, num_days) - Returns the date that is num_days before start_date.",
+ extended = "> SELECT _FUNC_('2016-07-30', 1);\n '2016-07-29'")
case class DateSub(startDate: Expression, days: Expression)
extends BinaryExpression with ImplicitCastInputTypes {
override def left: Expression = startDate
@@ -118,6 +128,9 @@ case class DateSub(startDate: Expression, days: Expression)
override def prettyName: String = "date_sub"
}
+@ExpressionDescription(
+ usage = "_FUNC_(param) - Returns the hour component of the string/timestamp/interval.",
+ extended = "> SELECT _FUNC_('2009-07-30 12:58:59');\n 12")
case class Hour(child: Expression) extends UnaryExpression with ImplicitCastInputTypes {
override def inputTypes: Seq[AbstractDataType] = Seq(TimestampType)
@@ -134,6 +147,9 @@ case class Hour(child: Expression) extends UnaryExpression with ImplicitCastInpu
}
}
+@ExpressionDescription(
+ usage = "_FUNC_(param) - Returns the minute component of the string/timestamp/interval.",
+ extended = "> SELECT _FUNC_('2009-07-30 12:58:59');\n 58")
case class Minute(child: Expression) extends UnaryExpression with ImplicitCastInputTypes {
override def inputTypes: Seq[AbstractDataType] = Seq(TimestampType)
@@ -150,6 +166,9 @@ case class Minute(child: Expression) extends UnaryExpression with ImplicitCastIn
}
}
+@ExpressionDescription(
+ usage = "_FUNC_(param) - Returns the second component of the string/timestamp/interval.",
+ extended = "> SELECT _FUNC_('2009-07-30 12:58:59');\n 59")
case class Second(child: Expression) extends UnaryExpression with ImplicitCastInputTypes {
override def inputTypes: Seq[AbstractDataType] = Seq(TimestampType)
@@ -166,6 +185,9 @@ case class Second(child: Expression) extends UnaryExpression with ImplicitCastIn
}
}
+@ExpressionDescription(
+ usage = "_FUNC_(param) - Returns the day of year of date/timestamp.",
+ extended = "> SELECT _FUNC_('2016-04-09');\n 100")
case class DayOfYear(child: Expression) extends UnaryExpression with ImplicitCastInputTypes {
override def inputTypes: Seq[AbstractDataType] = Seq(DateType)
@@ -182,7 +204,9 @@ case class DayOfYear(child: Expression) extends UnaryExpression with ImplicitCas
}
}
-
+@ExpressionDescription(
+ usage = "_FUNC_(param) - Returns the year component of the date/timestamp/interval.",
+ extended = "> SELECT _FUNC_('2016-07-30');\n 2016")
case class Year(child: Expression) extends UnaryExpression with ImplicitCastInputTypes {
override def inputTypes: Seq[AbstractDataType] = Seq(DateType)
@@ -199,6 +223,8 @@ case class Year(child: Expression) extends UnaryExpression with ImplicitCastInpu
}
}
+@ExpressionDescription(
+ usage = "_FUNC_(param) - Returns the quarter of the year for date, in the range 1 to 4.")
case class Quarter(child: Expression) extends UnaryExpression with ImplicitCastInputTypes {
override def inputTypes: Seq[AbstractDataType] = Seq(DateType)
@@ -215,6 +241,9 @@ case class Quarter(child: Expression) extends UnaryExpression with ImplicitCastI
}
}
+@ExpressionDescription(
+ usage = "_FUNC_(param) - Returns the month component of the date/timestamp/interval",
+ extended = "> SELECT _FUNC_('2016-07-30');\n 7")
case class Month(child: Expression) extends UnaryExpression with ImplicitCastInputTypes {
override def inputTypes: Seq[AbstractDataType] = Seq(DateType)
@@ -231,6 +260,9 @@ case class Month(child: Expression) extends UnaryExpression with ImplicitCastInp
}
}
+@ExpressionDescription(
+ usage = "_FUNC_(param) - Returns the day of month of date/timestamp, or the day of interval.",
+ extended = "> SELECT _FUNC_('2009-07-30');\n 30")
case class DayOfMonth(child: Expression) extends UnaryExpression with ImplicitCastInputTypes {
override def inputTypes: Seq[AbstractDataType] = Seq(DateType)
@@ -247,6 +279,9 @@ case class DayOfMonth(child: Expression) extends UnaryExpression with ImplicitCa
}
}
+@ExpressionDescription(
+ usage = "_FUNC_(param) - Returns the week of the year of the given date.",
+ extended = "> SELECT _FUNC_('2008-02-20');\n 8")
case class WeekOfYear(child: Expression) extends UnaryExpression with ImplicitCastInputTypes {
override def inputTypes: Seq[AbstractDataType] = Seq(DateType)
@@ -283,6 +318,11 @@ case class WeekOfYear(child: Expression) extends UnaryExpression with ImplicitCa
}
}
+// scalastyle:off line.size.limit
+@ExpressionDescription(
+ usage = "_FUNC_(date/timestamp/string, fmt) - Converts a date/timestamp/string to a value of string in the format specified by the date format fmt.",
+ extended = "> SELECT _FUNC_('2016-04-08', 'y')\n '2016'")
+// scalastyle:on line.size.limit
case class DateFormatClass(left: Expression, right: Expression) extends BinaryExpression
with ImplicitCastInputTypes {
@@ -310,6 +350,8 @@ case class DateFormatClass(left: Expression, right: Expression) extends BinaryEx
* Converts time string with given pattern.
* Deterministic version of [[UnixTimestamp]], must have at least one parameter.
*/
+@ExpressionDescription(
+ usage = "_FUNC_(date[, pattern]) - Returns the UNIX timestamp of the give time.")
case class ToUnixTimestamp(timeExp: Expression, format: Expression) extends UnixTime {
override def left: Expression = timeExp
override def right: Expression = format
@@ -331,6 +373,8 @@ case class ToUnixTimestamp(timeExp: Expression, format: Expression) extends Unix
* If the first parameter is a Date or Timestamp instead of String, we will ignore the
* second parameter.
*/
+@ExpressionDescription(
+ usage = "_FUNC_([date[, pattern]]) - Returns the UNIX timestamp of current or specified time.")
case class UnixTimestamp(timeExp: Expression, format: Expression) extends UnixTime {
override def left: Expression = timeExp
override def right: Expression = format
@@ -459,6 +503,9 @@ abstract class UnixTime extends BinaryExpression with ExpectsInputTypes {
* format. If the format is missing, using format like "1970-01-01 00:00:00".
* Note that hive Language Manual says it returns 0 if fail, but in fact it returns null.
*/
+@ExpressionDescription(
+ usage = "_FUNC_(unix_time, format) - Returns unix_time in the specified format",
+ extended = "> SELECT _FUNC_(0, 'yyyy-MM-dd HH:mm:ss');\n '1970-01-01 00:00:00'")
case class FromUnixTime(sec: Expression, format: Expression)
extends BinaryExpression with ImplicitCastInputTypes {
@@ -544,6 +591,9 @@ case class FromUnixTime(sec: Expression, format: Expression)
/**
* Returns the last day of the month which the date belongs to.
*/
+@ExpressionDescription(
+ usage = "_FUNC_(date) - Returns the last day of the month which the date belongs to.",
+ extended = "> SELECT _FUNC_('2009-01-12');\n '2009-01-31'")
case class LastDay(startDate: Expression) extends UnaryExpression with ImplicitCastInputTypes {
override def child: Expression = startDate
@@ -570,6 +620,11 @@ case class LastDay(startDate: Expression) extends UnaryExpression with ImplicitC
*
* Allowed "dayOfWeek" is defined in [[DateTimeUtils.getDayOfWeekFromString]].
*/
+// scalastyle:off line.size.limit
+@ExpressionDescription(
+ usage = "_FUNC_(start_date, day_of_week) - Returns the first date which is later than start_date and named as indicated.",
+ extended = "> SELECT _FUNC_('2015-01-14', 'TU');\n '2015-01-20'")
+// scalastyle:on line.size.limit
case class NextDay(startDate: Expression, dayOfWeek: Expression)
extends BinaryExpression with ImplicitCastInputTypes {
@@ -654,6 +709,10 @@ case class TimeAdd(start: Expression, interval: Expression)
/**
* Assumes given timestamp is UTC and converts to given timezone.
*/
+// scalastyle:off line.size.limit
+@ExpressionDescription(
+ usage = "_FUNC_(timestamp, string timezone) - Assumes given timestamp is UTC and converts to given timezone.")
+// scalastyle:on line.size.limit
case class FromUTCTimestamp(left: Expression, right: Expression)
extends BinaryExpression with ImplicitCastInputTypes {
@@ -729,6 +788,9 @@ case class TimeSub(start: Expression, interval: Expression)
/**
* Returns the date that is num_months after start_date.
*/
+@ExpressionDescription(
+ usage = "_FUNC_(start_date, num_months) - Returns the date that is num_months after start_date.",
+ extended = "> SELECT _FUNC_('2016-08-31', 1);\n '2016-09-30'")
case class AddMonths(startDate: Expression, numMonths: Expression)
extends BinaryExpression with ImplicitCastInputTypes {
@@ -756,6 +818,9 @@ case class AddMonths(startDate: Expression, numMonths: Expression)
/**
* Returns number of months between dates date1 and date2.
*/
+@ExpressionDescription(
+ usage = "_FUNC_(date1, date2) - returns number of months between dates date1 and date2.",
+ extended = "> SELECT _FUNC_('1997-02-28 10:30:00', '1996-10-30');\n 3.94959677")
case class MonthsBetween(date1: Expression, date2: Expression)
extends BinaryExpression with ImplicitCastInputTypes {
@@ -783,6 +848,10 @@ case class MonthsBetween(date1: Expression, date2: Expression)
/**
* Assumes given timestamp is in given timezone and converts to UTC.
*/
+// scalastyle:off line.size.limit
+@ExpressionDescription(
+ usage = "_FUNC_(timestamp, string timezone) - Assumes given timestamp is in given timezone and converts to UTC.")
+// scalastyle:on line.size.limit
case class ToUTCTimestamp(left: Expression, right: Expression)
extends BinaryExpression with ImplicitCastInputTypes {
@@ -830,6 +899,9 @@ case class ToUTCTimestamp(left: Expression, right: Expression)
/**
* Returns the date part of a timestamp or string.
*/
+@ExpressionDescription(
+ usage = "_FUNC_(expr) - Extracts the date part of the date or datetime expression expr.",
+ extended = "> SELECT _FUNC_('2009-07-30 04:17:52');\n '2009-07-30'")
case class ToDate(child: Expression) extends UnaryExpression with ImplicitCastInputTypes {
// Implicit casting of spark will accept string in both date and timestamp format, as
@@ -850,6 +922,11 @@ case class ToDate(child: Expression) extends UnaryExpression with ImplicitCastIn
/**
* Returns date truncated to the unit specified by the format.
*/
+// scalastyle:off line.size.limit
+@ExpressionDescription(
+ usage = "_FUNC_(date, fmt) - Returns returns date with the time portion of the day truncated to the unit specified by the format model fmt.",
+ extended = "> SELECT _FUNC_('2009-02-12', 'MM')\n '2009-02-01'\n> SELECT _FUNC_('2015-10-27', 'YEAR');\n '2015-01-01'")
+// scalastyle:on line.size.limit
case class TruncDate(date: Expression, format: Expression)
extends BinaryExpression with ImplicitCastInputTypes {
override def left: Expression = date
@@ -921,6 +998,9 @@ case class TruncDate(date: Expression, format: Expression)
/**
* Returns the number of days from startDate to endDate.
*/
+@ExpressionDescription(
+ usage = "_FUNC_(date1, date2) - Returns the number of days between date1 and date2.",
+ extended = "> SELECT _FUNC_('2009-07-30', '2009-07-31');\n 1")
case class DateDiff(endDate: Expression, startDate: Expression)
extends BinaryExpression with ImplicitCastInputTypes {
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/generators.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/generators.scala
index e7ef21aa85..65d7a1d5a0 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/generators.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/generators.scala
@@ -99,6 +99,10 @@ case class UserDefinedGenerator(
/**
* Given an input array produces a sequence of rows for each value in the array.
*/
+// scalastyle:off line.size.limit
+@ExpressionDescription(
+ usage = "_FUNC_(a) - Separates the elements of array a into multiple rows, or the elements of a map into multiple rows and columns.")
+// scalastyle:on line.size.limit
case class Explode(child: Expression) extends UnaryExpression with Generator with CodegenFallback {
override def children: Seq[Expression] = child :: Nil
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/grouping.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/grouping.scala
index 437e417266..3be761c867 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/grouping.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/grouping.scala
@@ -22,8 +22,8 @@ import org.apache.spark.sql.catalyst.expressions.codegen.CodegenFallback
import org.apache.spark.sql.types._
/**
- * A placeholder expression for cube/rollup, which will be replaced by analyzer
- */
+ * A placeholder expression for cube/rollup, which will be replaced by analyzer
+ */
trait GroupingSet extends Expression with CodegenFallback {
def groupByExprs: Seq[Expression]
@@ -43,9 +43,9 @@ case class Cube(groupByExprs: Seq[Expression]) extends GroupingSet {}
case class Rollup(groupByExprs: Seq[Expression]) extends GroupingSet {}
/**
- * Indicates whether a specified column expression in a GROUP BY list is aggregated or not.
- * GROUPING returns 1 for aggregated or 0 for not aggregated in the result set.
- */
+ * Indicates whether a specified column expression in a GROUP BY list is aggregated or not.
+ * GROUPING returns 1 for aggregated or 0 for not aggregated in the result set.
+ */
case class Grouping(child: Expression) extends Expression with Unevaluable {
override def references: AttributeSet = AttributeSet(VirtualColumn.groupingIdAttribute :: Nil)
override def children: Seq[Expression] = child :: Nil
@@ -54,10 +54,10 @@ case class Grouping(child: Expression) extends Expression with Unevaluable {
}
/**
- * GroupingID is a function that computes the level of grouping.
- *
- * If groupByExprs is empty, it means all grouping expressions in GroupingSets.
- */
+ * GroupingID is a function that computes the level of grouping.
+ *
+ * If groupByExprs is empty, it means all grouping expressions in GroupingSets.
+ */
case class GroupingID(groupByExprs: Seq[Expression]) extends Expression with Unevaluable {
override def references: AttributeSet = AttributeSet(VirtualColumn.groupingIdAttribute :: Nil)
override def children: Seq[Expression] = groupByExprs
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/jsonExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/jsonExpressions.scala
index 72b323587c..ecd09b7083 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/jsonExpressions.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/jsonExpressions.scala
@@ -106,6 +106,8 @@ private[this] object SharedFactory {
* Extracts json object from a json string based on json path specified, and returns json string
* of the extracted json object. It will return null if the input json string is invalid.
*/
+@ExpressionDescription(
+ usage = "_FUNC_(json_txt, path) - Extract a json object from path")
case class GetJsonObject(json: Expression, path: Expression)
extends BinaryExpression with ExpectsInputTypes with CodegenFallback {
@@ -319,6 +321,10 @@ case class GetJsonObject(json: Expression, path: Expression)
}
}
+// scalastyle:off line.size.limit
+@ExpressionDescription(
+ usage = "_FUNC_(jsonStr, p1, p2, ..., pn) - like get_json_object, but it takes multiple names and return a tuple. All the input parameters and output column types are string.")
+// scalastyle:on line.size.limit
case class JsonTuple(children: Seq[Expression])
extends Generator with CodegenFallback {
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/mathExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/mathExpressions.scala
index e3d1bc127d..c8a28e8477 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/mathExpressions.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/mathExpressions.scala
@@ -50,6 +50,7 @@ abstract class LeafMathExpression(c: Double, name: String)
/**
* A unary expression specifically for math functions. Math Functions expect a specific type of
* input format, therefore these functions extend `ExpectsInputTypes`.
+ *
* @param f The math function.
* @param name The short name of the function
*/
@@ -103,6 +104,7 @@ abstract class UnaryLogExpression(f: Double => Double, name: String)
/**
* A binary expression specifically for math functions that take two `Double`s as input and returns
* a `Double`.
+ *
* @param f The math function.
* @param name The short name of the function
*/
@@ -136,12 +138,18 @@ abstract class BinaryMathExpression(f: (Double, Double) => Double, name: String)
* Euler's number. Note that there is no code generation because this is only
* evaluated by the optimizer during constant folding.
*/
+@ExpressionDescription(
+ usage = "_FUNC_() - Returns Euler's number, E.",
+ extended = "> SELECT _FUNC_();\n 2.718281828459045")
case class EulerNumber() extends LeafMathExpression(math.E, "E")
/**
* Pi. Note that there is no code generation because this is only
* evaluated by the optimizer during constant folding.
*/
+@ExpressionDescription(
+ usage = "_FUNC_() - Returns PI.",
+ extended = "> SELECT _FUNC_();\n 3.141592653589793")
case class Pi() extends LeafMathExpression(math.Pi, "PI")
////////////////////////////////////////////////////////////////////////////////////////////////////
@@ -150,14 +158,29 @@ case class Pi() extends LeafMathExpression(math.Pi, "PI")
////////////////////////////////////////////////////////////////////////////////////////////////////
////////////////////////////////////////////////////////////////////////////////////////////////////
+@ExpressionDescription(
+ usage = "_FUNC_(x) - Returns the arc cosine of x if -1<=x<=1 or NaN otherwise.",
+ extended = "> SELECT _FUNC_(1);\n 0.0\n> SELECT _FUNC_(2);\n NaN")
case class Acos(child: Expression) extends UnaryMathExpression(math.acos, "ACOS")
+@ExpressionDescription(
+ usage = "_FUNC_(x) - Returns the arc sin of x if -1<=x<=1 or NaN otherwise.",
+ extended = "> SELECT _FUNC_(0);\n 0.0\n> SELECT _FUNC_(2);\n NaN")
case class Asin(child: Expression) extends UnaryMathExpression(math.asin, "ASIN")
+@ExpressionDescription(
+ usage = "_FUNC_(x) - Returns the arc tangent.",
+ extended = "> SELECT _FUNC_(0);\n 0.0")
case class Atan(child: Expression) extends UnaryMathExpression(math.atan, "ATAN")
+@ExpressionDescription(
+ usage = "_FUNC_(x) - Returns the cube root of a double value.",
+ extended = "> SELECT _FUNC_(27.0);\n 3.0")
case class Cbrt(child: Expression) extends UnaryMathExpression(math.cbrt, "CBRT")
+@ExpressionDescription(
+ usage = "_FUNC_(x) - Returns the smallest integer not smaller than x.",
+ extended = "> SELECT _FUNC_(-0.1);\n 0\n> SELECT _FUNC_(5);\n 5")
case class Ceil(child: Expression) extends UnaryMathExpression(math.ceil, "CEIL") {
override def dataType: DataType = child.dataType match {
case dt @ DecimalType.Fixed(_, 0) => dt
@@ -184,16 +207,26 @@ case class Ceil(child: Expression) extends UnaryMathExpression(math.ceil, "CEIL"
}
}
+@ExpressionDescription(
+ usage = "_FUNC_(x) - Returns the cosine of x.",
+ extended = "> SELECT _FUNC_(0);\n 1.0")
case class Cos(child: Expression) extends UnaryMathExpression(math.cos, "COS")
+@ExpressionDescription(
+ usage = "_FUNC_(x) - Returns the hyperbolic cosine of x.",
+ extended = "> SELECT _FUNC_(0);\n 1.0")
case class Cosh(child: Expression) extends UnaryMathExpression(math.cosh, "COSH")
/**
* Convert a num from one base to another
+ *
* @param numExpr the number to be converted
* @param fromBaseExpr from which base
* @param toBaseExpr to which base
*/
+@ExpressionDescription(
+ usage = "_FUNC_(num, from_base, to_base) - Convert num from from_base to to_base.",
+ extended = "> SELECT _FUNC_('100', 2, 10);\n '4'\n> SELECT _FUNC_(-10, 16, -10);\n '16'")
case class Conv(numExpr: Expression, fromBaseExpr: Expression, toBaseExpr: Expression)
extends TernaryExpression with ImplicitCastInputTypes {
@@ -222,10 +255,19 @@ case class Conv(numExpr: Expression, fromBaseExpr: Expression, toBaseExpr: Expre
}
}
+@ExpressionDescription(
+ usage = "_FUNC_(x) - Returns e to the power of x.",
+ extended = "> SELECT _FUNC_(0);\n 1.0")
case class Exp(child: Expression) extends UnaryMathExpression(math.exp, "EXP")
+@ExpressionDescription(
+ usage = "_FUNC_(x) - Returns exp(x) - 1.",
+ extended = "> SELECT _FUNC_(0);\n 0.0")
case class Expm1(child: Expression) extends UnaryMathExpression(math.expm1, "EXPM1")
+@ExpressionDescription(
+ usage = "_FUNC_(x) - Returns the largest integer not greater than x.",
+ extended = "> SELECT _FUNC_(-0.1);\n -1\n> SELECT _FUNC_(5);\n 5")
case class Floor(child: Expression) extends UnaryMathExpression(math.floor, "FLOOR") {
override def dataType: DataType = child.dataType match {
case dt @ DecimalType.Fixed(_, 0) => dt
@@ -283,6 +325,9 @@ object Factorial {
)
}
+@ExpressionDescription(
+ usage = "_FUNC_(n) - Returns n factorial for n is [0..20]. Otherwise, NULL.",
+ extended = "> SELECT _FUNC_(5);\n 120")
case class Factorial(child: Expression) extends UnaryExpression with ImplicitCastInputTypes {
override def inputTypes: Seq[DataType] = Seq(IntegerType)
@@ -315,8 +360,14 @@ case class Factorial(child: Expression) extends UnaryExpression with ImplicitCas
}
}
+@ExpressionDescription(
+ usage = "_FUNC_(x) - Returns the natural logarithm of x with base e.",
+ extended = "> SELECT _FUNC_(1);\n 0.0")
case class Log(child: Expression) extends UnaryLogExpression(math.log, "LOG")
+@ExpressionDescription(
+ usage = "_FUNC_(x) - Returns the logarithm of x with base 2.",
+ extended = "> SELECT _FUNC_(2);\n 1.0")
case class Log2(child: Expression)
extends UnaryLogExpression((x: Double) => math.log(x) / math.log(2), "LOG2") {
override def genCode(ctx: CodegenContext, ev: ExprCode): String = {
@@ -332,36 +383,72 @@ case class Log2(child: Expression)
}
}
+@ExpressionDescription(
+ usage = "_FUNC_(x) - Returns the logarithm of x with base 10.",
+ extended = "> SELECT _FUNC_(10);\n 1.0")
case class Log10(child: Expression) extends UnaryLogExpression(math.log10, "LOG10")
+@ExpressionDescription(
+ usage = "_FUNC_(x) - Returns log(1 + x).",
+ extended = "> SELECT _FUNC_(0);\n 0.0")
case class Log1p(child: Expression) extends UnaryLogExpression(math.log1p, "LOG1P") {
protected override val yAsymptote: Double = -1.0
}
+@ExpressionDescription(
+ usage = "_FUNC_(x, d) - Return the rounded x at d decimal places.",
+ extended = "> SELECT _FUNC_(12.3456, 1);\n 12.3")
case class Rint(child: Expression) extends UnaryMathExpression(math.rint, "ROUND") {
override def funcName: String = "rint"
}
+@ExpressionDescription(
+ usage = "_FUNC_(x) - Returns the sign of x.",
+ extended = "> SELECT _FUNC_(40);\n 1.0")
case class Signum(child: Expression) extends UnaryMathExpression(math.signum, "SIGNUM")
+@ExpressionDescription(
+ usage = "_FUNC_(x) - Returns the sine of x.",
+ extended = "> SELECT _FUNC_(0);\n 0.0")
case class Sin(child: Expression) extends UnaryMathExpression(math.sin, "SIN")
+@ExpressionDescription(
+ usage = "_FUNC_(x) - Returns the hyperbolic sine of x.",
+ extended = "> SELECT _FUNC_(0);\n 0.0")
case class Sinh(child: Expression) extends UnaryMathExpression(math.sinh, "SINH")
+@ExpressionDescription(
+ usage = "_FUNC_(x) - Returns the square root of x.",
+ extended = "> SELECT _FUNC_(4);\n 2.0")
case class Sqrt(child: Expression) extends UnaryMathExpression(math.sqrt, "SQRT")
+@ExpressionDescription(
+ usage = "_FUNC_(x) - Returns the tangent of x.",
+ extended = "> SELECT _FUNC_(0);\n 0.0")
case class Tan(child: Expression) extends UnaryMathExpression(math.tan, "TAN")
+@ExpressionDescription(
+ usage = "_FUNC_(x) - Returns the hyperbolic tangent of x.",
+ extended = "> SELECT _FUNC_(0);\n 0.0")
case class Tanh(child: Expression) extends UnaryMathExpression(math.tanh, "TANH")
+@ExpressionDescription(
+ usage = "_FUNC_(x) - Converts radians to degrees.",
+ extended = "> SELECT _FUNC_(3.141592653589793);\n 180.0")
case class ToDegrees(child: Expression) extends UnaryMathExpression(math.toDegrees, "DEGREES") {
override def funcName: String = "toDegrees"
}
+@ExpressionDescription(
+ usage = "_FUNC_(x) - Converts degrees to radians.",
+ extended = "> SELECT _FUNC_(180);\n 3.141592653589793")
case class ToRadians(child: Expression) extends UnaryMathExpression(math.toRadians, "RADIANS") {
override def funcName: String = "toRadians"
}
+@ExpressionDescription(
+ usage = "_FUNC_(x) - Returns x in binary.",
+ extended = "> SELECT _FUNC_(13);\n '1101'")
case class Bin(child: Expression)
extends UnaryExpression with Serializable with ImplicitCastInputTypes {
@@ -453,6 +540,9 @@ object Hex {
* Otherwise if the number is a STRING, it converts each character into its hex representation
* and returns the resulting STRING. Negative numbers would be treated as two's complement.
*/
+@ExpressionDescription(
+ usage = "_FUNC_(x) - Convert the argument to hexadecimal.",
+ extended = "> SELECT _FUNC_(17);\n '11'\n> SELECT _FUNC_('Spark SQL');\n '537061726B2053514C'")
case class Hex(child: Expression) extends UnaryExpression with ImplicitCastInputTypes {
override def inputTypes: Seq[AbstractDataType] =
@@ -481,6 +571,9 @@ case class Hex(child: Expression) extends UnaryExpression with ImplicitCastInput
* Performs the inverse operation of HEX.
* Resulting characters are returned as a byte array.
*/
+@ExpressionDescription(
+ usage = "_FUNC_(x) - Converts hexadecimal argument to binary.",
+ extended = "> SELECT decode(_FUNC_('537061726B2053514C'),'UTF-8');\n 'Spark SQL'")
case class Unhex(child: Expression) extends UnaryExpression with ImplicitCastInputTypes {
override def inputTypes: Seq[AbstractDataType] = Seq(StringType)
@@ -509,7 +602,9 @@ case class Unhex(child: Expression) extends UnaryExpression with ImplicitCastInp
////////////////////////////////////////////////////////////////////////////////////////////////////
////////////////////////////////////////////////////////////////////////////////////////////////////
-
+@ExpressionDescription(
+ usage = "_FUNC_(x,y) - Returns the arc tangent2.",
+ extended = "> SELECT _FUNC_(0, 0);\n 0.0")
case class Atan2(left: Expression, right: Expression)
extends BinaryMathExpression(math.atan2, "ATAN2") {
@@ -523,6 +618,9 @@ case class Atan2(left: Expression, right: Expression)
}
}
+@ExpressionDescription(
+ usage = "_FUNC_(x1, x2) - Raise x1 to the power of x2.",
+ extended = "> SELECT _FUNC_(2, 3);\n 8.0")
case class Pow(left: Expression, right: Expression)
extends BinaryMathExpression(math.pow, "POWER") {
override def genCode(ctx: CodegenContext, ev: ExprCode): String = {
@@ -532,10 +630,14 @@ case class Pow(left: Expression, right: Expression)
/**
- * Bitwise unsigned left shift.
+ * Bitwise left shift.
+ *
* @param left the base number to shift.
* @param right number of bits to left shift.
*/
+@ExpressionDescription(
+ usage = "_FUNC_(a, b) - Bitwise left shift.",
+ extended = "> SELECT _FUNC_(2, 1);\n 4")
case class ShiftLeft(left: Expression, right: Expression)
extends BinaryExpression with ImplicitCastInputTypes {
@@ -558,10 +660,14 @@ case class ShiftLeft(left: Expression, right: Expression)
/**
- * Bitwise unsigned left shift.
+ * Bitwise right shift.
+ *
* @param left the base number to shift.
- * @param right number of bits to left shift.
+ * @param right number of bits to right shift.
*/
+@ExpressionDescription(
+ usage = "_FUNC_(a, b) - Bitwise right shift.",
+ extended = "> SELECT _FUNC_(4, 1);\n 2")
case class ShiftRight(left: Expression, right: Expression)
extends BinaryExpression with ImplicitCastInputTypes {
@@ -585,9 +691,13 @@ case class ShiftRight(left: Expression, right: Expression)
/**
* Bitwise unsigned right shift, for integer and long data type.
+ *
* @param left the base number.
* @param right the number of bits to right shift.
*/
+@ExpressionDescription(
+ usage = "_FUNC_(a, b) - Bitwise unsigned right shift.",
+ extended = "> SELECT _FUNC_(4, 1);\n 2")
case class ShiftRightUnsigned(left: Expression, right: Expression)
extends BinaryExpression with ImplicitCastInputTypes {
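The corrected Scaladoc above distinguishes the three shift flavours; on the JVM they correspond to `<<`, the arithmetic `>>`, and the logical `>>>`. A quick standalone illustration of the difference:

    object ShiftSemantics {
      def main(args: Array[String]): Unit = {
        println(2 << 1)    // 4           : shiftleft
        println(4 >> 1)    // 2           : shiftright, sign-propagating
        println(-4 >> 1)   // -2
        println(-4 >>> 1)  // 2147483646  : shiftrightunsigned, zero-filling
      }
    }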
@@ -608,16 +718,22 @@ case class ShiftRightUnsigned(left: Expression, right: Expression)
}
}
-
+@ExpressionDescription(
+ usage = "_FUNC_(a, b) - Returns sqrt(a**2 + b**2).",
+ extended = "> SELECT _FUNC_(3, 4);\n 5.0")
case class Hypot(left: Expression, right: Expression)
extends BinaryMathExpression(math.hypot, "HYPOT")
/**
* Computes the logarithm of a number.
+ *
* @param left the logarithm base, default to e.
* @param right the number to compute the logarithm of.
*/
+@ExpressionDescription(
+ usage = "_FUNC_(b, x) - Returns the logarithm of x with base b.",
+ extended = "> SELECT _FUNC_(10, 100);\n 2.0")
case class Logarithm(left: Expression, right: Expression)
extends BinaryMathExpression((c1, c2) => math.log(c2) / math.log(c1), "LOG") {
@@ -674,6 +790,9 @@ case class Logarithm(left: Expression, right: Expression)
* @param child expr to be round, all [[NumericType]] is allowed as Input
* @param scale new scale to be round to, this should be a constant int at runtime
*/
+@ExpressionDescription(
+ usage = "_FUNC_(x, d) - Round x to d decimal places.",
+ extended = "> SELECT _FUNC_(12.3456, 1);\n 12.3")
case class Round(child: Expression, scale: Expression)
extends BinaryExpression with ImplicitCastInputTypes {
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/misc.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/misc.scala
index e8a3e129b4..4bd918ed01 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/misc.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/misc.scala
@@ -438,6 +438,8 @@ abstract class InterpretedHashFunction {
* We should use this hash function for both shuffle and bucket, so that we can guarantee shuffle
* and bucketing have same data distribution.
*/
+@ExpressionDescription(
+ usage = "_FUNC_(a1, a2, ...) - Returns a hash value of the arguments.")
case class Murmur3Hash(children: Seq[Expression], seed: Int) extends HashExpression[Int] {
def this(arguments: Seq[Expression]) = this(arguments, 42)
@@ -467,8 +469,8 @@ object Murmur3HashFunction extends InterpretedHashFunction {
}
/**
- * Print the result of an expression to stderr (used for debugging codegen).
- */
+ * Print the result of an expression to stderr (used for debugging codegen).
+ */
case class PrintToStderr(child: Expression) extends UnaryExpression {
override def dataType: DataType = child.dataType
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/namedExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/namedExpressions.scala
index a5b5758167..78310fb2f1 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/namedExpressions.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/namedExpressions.scala
@@ -97,7 +97,7 @@ trait NamedExpression extends Expression {
}
}
-abstract class Attribute extends LeafExpression with NamedExpression {
+abstract class Attribute extends LeafExpression with NamedExpression with NullIntolerant {
override def references: AttributeSet = AttributeSet(this)
@@ -329,10 +329,12 @@ case class PrettyAttribute(
override def withName(newName: String): Attribute = throw new UnsupportedOperationException
override def qualifier: Option[String] = throw new UnsupportedOperationException
override def exprId: ExprId = throw new UnsupportedOperationException
- override def nullable: Boolean = throw new UnsupportedOperationException
+ override def nullable: Boolean = true
}
object VirtualColumn {
- val groupingIdName: String = "grouping__id"
+ // The attribute name used by Hive, which produces a different result than Spark; deprecated.
+ val hiveGroupingIdName: String = "grouping__id"
+ val groupingIdName: String = "spark_grouping_id"
val groupingIdAttribute: UnresolvedAttribute = UnresolvedAttribute(groupingIdName)
}
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/nullExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/nullExpressions.scala
index e22026d584..6a45249943 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/nullExpressions.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/nullExpressions.scala
@@ -34,6 +34,9 @@ import org.apache.spark.sql.types._
* coalesce(null, null, null) => null
* }}}
*/
+@ExpressionDescription(
+ usage = "_FUNC_(a1, a2, ...) - Returns the first non-null argument if exists. Otherwise, NULL.",
+ extended = "> SELECT _FUNC_(NULL, 1, NULL);\n 1")
case class Coalesce(children: Seq[Expression]) extends Expression {
/** Coalesce is nullable if all of its children are nullable, or if it has no children. */
@@ -89,6 +92,8 @@ case class Coalesce(children: Seq[Expression]) extends Expression {
/**
* Evaluates to `true` iff it's NaN.
*/
+@ExpressionDescription(
+ usage = "_FUNC_(a) - Returns true if a is NaN and false otherwise.")
case class IsNaN(child: Expression) extends UnaryExpression
with Predicate with ImplicitCastInputTypes {
@@ -126,6 +131,8 @@ case class IsNaN(child: Expression) extends UnaryExpression
* An Expression evaluates to `left` iff it's not NaN, or evaluates to `right` otherwise.
* This Expression is useful for mapping NaN values to null.
*/
+@ExpressionDescription(
+ usage = "_FUNC_(a,b) - Returns a iff it's not NaN, or b otherwise.")
case class NaNvl(left: Expression, right: Expression)
extends BinaryExpression with ImplicitCastInputTypes {
@@ -180,6 +187,8 @@ case class NaNvl(left: Expression, right: Expression)
/**
* An expression that is evaluated to true if the input is null.
*/
+@ExpressionDescription(
+ usage = "_FUNC_(a) - Returns true if a is NULL and false otherwise.")
case class IsNull(child: Expression) extends UnaryExpression with Predicate {
override def nullable: Boolean = false
@@ -201,6 +210,8 @@ case class IsNull(child: Expression) extends UnaryExpression with Predicate {
/**
* An expression that is evaluated to true if the input is not null.
*/
+@ExpressionDescription(
+ usage = "_FUNC_(a) - Returns true if a is not NULL and false otherwise.")
case class IsNotNull(child: Expression) extends UnaryExpression with Predicate {
override def nullable: Boolean = false
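The usage strings added in this hunk describe standard null-handling semantics. A plain-Scala reading of them, purely illustrative and independent of Catalyst:

    object NullSemantics {
      // coalesce: first non-null argument, or null if all arguments are null.
      def coalesce(xs: Any*): Any = xs.find(_ != null).orNull

      // isnull / isnotnull never return null themselves (hence nullable = false above).
      def isNull(x: Any): Boolean = x == null
      def isNotNull(x: Any): Boolean = x != null

      def main(args: Array[String]): Unit = {
        println(coalesce(null, 1, null))   // 1
        println(isNull(null))              // true
        println(isNotNull(null))           // false
      }
    }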
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects.scala
index 07b67a0240..26b1ff39b3 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects.scala
@@ -17,6 +17,8 @@
package org.apache.spark.sql.catalyst.expressions
+import java.lang.reflect.Modifier
+
import scala.annotation.tailrec
import scala.language.existentials
import scala.reflect.ClassTag
@@ -112,23 +114,23 @@ case class Invoke(
arguments: Seq[Expression] = Nil) extends Expression with NonSQLExpression {
override def nullable: Boolean = true
- override def children: Seq[Expression] = arguments.+:(targetObject)
+ override def children: Seq[Expression] = targetObject +: arguments
override def eval(input: InternalRow): Any =
throw new UnsupportedOperationException("Only code-generated evaluation is supported.")
- lazy val method = targetObject.dataType match {
+ @transient lazy val method = targetObject.dataType match {
case ObjectType(cls) =>
- cls
- .getMethods
- .find(_.getName == functionName)
- .getOrElse(sys.error(s"Couldn't find $functionName on $cls"))
- .getReturnType
- .getName
- case _ => ""
+ val m = cls.getMethods.find(_.getName == functionName)
+ if (m.isEmpty) {
+ sys.error(s"Couldn't find $functionName on $cls")
+ } else {
+ m
+ }
+ case _ => None
}
- lazy val unboxer = (dataType, method) match {
+ lazy val unboxer = (dataType, method.map(_.getReturnType.getName).getOrElse("")) match {
case (IntegerType, "java.lang.Object") => (s: String) =>
s"((java.lang.Integer)$s).intValue()"
case (LongType, "java.lang.Object") => (s: String) =>
@@ -155,21 +157,31 @@ case class Invoke(
// If the function can return null, we do an extra check to make sure our null bit is still set
// correctly.
val objNullCheck = if (ctx.defaultValue(dataType) == "null") {
- s"${ev.isNull} = ${ev.value} == null;"
+ s"boolean ${ev.isNull} = ${ev.value} == null;"
} else {
+ ev.isNull = obj.isNull
""
}
val value = unboxer(s"${obj.value}.$functionName($argString)")
+ val evaluate = if (method.forall(_.getExceptionTypes.isEmpty)) {
+ s"$javaType ${ev.value} = ${obj.isNull} ? ${ctx.defaultValue(dataType)} : ($javaType) $value;"
+ } else {
+ s"""
+ $javaType ${ev.value} = ${ctx.defaultValue(javaType)};
+ try {
+ ${ev.value} = ${obj.isNull} ? ${ctx.defaultValue(javaType)} : ($javaType) $value;
+ } catch (Exception e) {
+ org.apache.spark.unsafe.Platform.throwException(e);
+ }
+ """
+ }
+
s"""
${obj.code}
${argGen.map(_.code).mkString("\n")}
-
- boolean ${ev.isNull} = ${obj.isNull};
- $javaType ${ev.value} =
- ${ev.isNull} ?
- ${ctx.defaultValue(dataType)} : ($javaType) $value;
+ $evaluate
$objNullCheck
"""
}
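By keeping the reflected `Method` as an `Option` instead of only its return-type name, the generator above can later inspect `getExceptionTypes` and wrap the call in a try/catch only when the target method declares checked exceptions. A minimal sketch of that lookup pattern in plain Scala (standalone, not the Invoke expression itself):

    import java.lang.reflect.Method

    object ReflectiveLookup {
      // Find a method by name, mirroring the Option-based lookup in Invoke.
      def findMethod(cls: Class[_], functionName: String): Option[Method] =
        cls.getMethods.find(_.getName == functionName)

      def main(args: Array[String]): Unit = {
        val toStringM = findMethod(classOf[Object], "toString")
        val waitM = findMethod(classOf[Object], "wait")
        // Object#toString declares no checked exceptions; Object#wait declares
        // InterruptedException, so a generator following the same rule would emit a try/catch.
        println(toStringM.exists(_.getExceptionTypes.nonEmpty))  // false
        println(waitM.exists(_.getExceptionTypes.nonEmpty))      // true
      }
    }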
@@ -214,6 +226,16 @@ case class NewInstance(
override def children: Seq[Expression] = arguments
+ override lazy val resolved: Boolean = {
+ // If the class to construct is an inner class, we need to get its outer pointer, or this
+ // expression should be regarded as unresolved.
+ // Note that static inner classes (e.g., inner classes within Scala objects) don't need
+ // outer pointer registration.
+ val needOuterPointer =
+ outerPointer.isEmpty && cls.isMemberClass && !Modifier.isStatic(cls.getModifiers)
+ childrenResolved && !needOuterPointer
+ }
+
override def eval(input: InternalRow): Any =
throw new UnsupportedOperationException("Only code-generated evaluation is supported.")
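The new `resolved` override relies on plain Java reflection to decide whether the class being constructed is a non-static member class and therefore needs an outer pointer. A small self-contained sketch of that check (toy class names, not Catalyst code):

    import java.lang.reflect.Modifier

    class Outer { class Member }   // Member is a non-static member class of Outer
    class TopLevel                  // a top-level class, no enclosing instance required

    object OuterPointerCheck {
      // Mirrors the check in NewInstance.resolved: an outer pointer is required only for
      // member classes that are not static.
      def needsOuterPointer(cls: Class[_]): Boolean =
        cls.isMemberClass && !Modifier.isStatic(cls.getModifiers)

      def main(args: Array[String]): Unit = {
        val outer = new Outer
        val member = new outer.Member
        println(needsOuterPointer(member.getClass))    // true: needs an enclosing Outer
        println(needsOuterPointer(classOf[TopLevel]))  // false
      }
    }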
@@ -424,6 +446,8 @@ case class MapObjects private(
override def genCode(ctx: CodegenContext, ev: ExprCode): String = {
val javaType = ctx.javaType(dataType)
val elementJavaType = ctx.javaType(loopVar.dataType)
+ ctx.addMutableState("boolean", loopVar.isNull, "")
+ ctx.addMutableState(elementJavaType, loopVar.value, "")
val genInputData = inputData.gen(ctx)
val genFunction = lambdaFunction.gen(ctx)
val dataLength = ctx.freshName("dataLength")
@@ -444,9 +468,9 @@ case class MapObjects private(
}
val loopNullCheck = if (primitiveElement) {
- s"boolean ${loopVar.isNull} = ${genInputData.value}.isNullAt($loopIndex);"
+ s"${loopVar.isNull} = ${genInputData.value}.isNullAt($loopIndex);"
} else {
- s"boolean ${loopVar.isNull} = ${genInputData.isNull} || ${loopVar.value} == null;"
+ s"${loopVar.isNull} = ${genInputData.isNull} || ${loopVar.value} == null;"
}
s"""
@@ -462,7 +486,7 @@ case class MapObjects private(
int $loopIndex = 0;
while ($loopIndex < $dataLength) {
- $elementJavaType ${loopVar.value} =
+ ${loopVar.value} =
($elementJavaType)${genInputData.value}${itemAccessor(loopIndex)};
$loopNullCheck
@@ -502,22 +526,26 @@ case class CreateExternalRow(children: Seq[Expression], schema: StructType)
override def genCode(ctx: CodegenContext, ev: ExprCode): String = {
val rowClass = classOf[GenericRowWithSchema].getName
val values = ctx.freshName("values")
- val schemaField = ctx.addReferenceObj("schema", schema)
- s"""
- boolean ${ev.isNull} = false;
- final Object[] $values = new Object[${children.size}];
- """ +
- children.zipWithIndex.map { case (e, i) =>
- val eval = e.gen(ctx)
- eval.code + s"""
+ ctx.addMutableState("Object[]", values, "")
+
+ val childrenCodes = children.zipWithIndex.map { case (e, i) =>
+ val eval = e.gen(ctx)
+ eval.code + s"""
if (${eval.isNull}) {
$values[$i] = null;
} else {
$values[$i] = ${eval.value};
}
"""
- }.mkString("\n") +
- s"final ${classOf[Row].getName} ${ev.value} = new $rowClass($values, this.$schemaField);"
+ }
+ val childrenCode = ctx.splitExpressions(ctx.INPUT_ROW, childrenCodes)
+ val schemaField = ctx.addReferenceObj("schema", schema)
+ s"""
+ boolean ${ev.isNull} = false;
+ $values = new Object[${children.size}];
+ $childrenCode
+ final ${classOf[Row].getName} ${ev.value} = new $rowClass($values, this.$schemaField);
+ """
}
}
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/package.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/package.scala
index f1fa13daa7..23baa6f783 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/package.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/package.scala
@@ -92,4 +92,11 @@ package object expressions {
StructType(attrs.map(a => StructField(a.name, a.dataType, a.nullable)))
}
}
+
+ /**
+ * An expression that inherits this trait is null intolerant, i.e. any null input will result
+ * in a null output. We will use this information when constructing IsNotNull
+ * constraints.
+ */
+ trait NullIntolerant
}
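NullIntolerant is a pure marker: expressions opt in by mixing it in (as the comparison operators in the next file do), and, per the comment above, constraint inference can then propose IsNotNull on their children. A toy sketch of the marker-trait pattern with stand-in names, not Catalyst's:

    // A toy model: the planner can infer not-null requirements for the children of any
    // expression that mixes in NullIntolerant.
    trait Expr { def children: Seq[Expr] }
    trait NullIntolerant

    case class Col(name: String) extends Expr { def children: Seq[Expr] = Nil }
    case class Eq(left: Expr, right: Expr) extends Expr with NullIntolerant {
      def children: Seq[Expr] = Seq(left, right)
    }

    object ConstraintSketch {
      // Children of a null-intolerant expression must be non-null for the result to be non-null.
      def notNullCandidates(e: Expr): Seq[Expr] = e match {
        case _: NullIntolerant => e.children
        case _ => Nil
      }

      def main(args: Array[String]): Unit = {
        println(notNullCandidates(Eq(Col("a"), Col("b"))))  // List(Col(a), Col(b))
      }
    }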
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala
index 20818bfb1a..38f1210a4e 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala
@@ -88,9 +88,10 @@ trait PredicateHelper {
expr.references.subsetOf(plan.outputSet)
}
-
+@ExpressionDescription(
+ usage = "_FUNC_ a - Logical not")
case class Not(child: Expression)
- extends UnaryExpression with Predicate with ImplicitCastInputTypes {
+ extends UnaryExpression with Predicate with ImplicitCastInputTypes with NullIntolerant {
override def toString: String = s"NOT $child"
@@ -109,6 +110,8 @@ case class Not(child: Expression)
/**
* Evaluates to `true` if `list` contains `value`.
*/
+@ExpressionDescription(
+ usage = "expr _FUNC_(val1, val2, ...) - Returns true if expr equals to any valN.")
case class In(value: Expression, list: Seq[Expression]) extends Predicate
with ImplicitCastInputTypes {
@@ -243,6 +246,8 @@ case class InSet(child: Expression, hset: Set[Any]) extends UnaryExpression with
}
}
+@ExpressionDescription(
+ usage = "a _FUNC_ b - Logical AND.")
case class And(left: Expression, right: Expression) extends BinaryOperator with Predicate {
override def inputType: AbstractDataType = BooleanType
@@ -274,26 +279,40 @@ case class And(left: Expression, right: Expression) extends BinaryOperator with
val eval2 = right.gen(ctx)
// The result should be `false`, if any of them is `false` whenever the other is null or not.
- s"""
- ${eval1.code}
- boolean ${ev.isNull} = false;
- boolean ${ev.value} = false;
+ if (!left.nullable && !right.nullable) {
+ ev.isNull = "false"
+ s"""
+ ${eval1.code}
+ boolean ${ev.value} = false;
- if (!${eval1.isNull} && !${eval1.value}) {
- } else {
- ${eval2.code}
- if (!${eval2.isNull} && !${eval2.value}) {
- } else if (!${eval1.isNull} && !${eval2.isNull}) {
- ${ev.value} = true;
+ if (${eval1.value}) {
+ ${eval2.code}
+ ${ev.value} = ${eval2.value};
+ }
+ """
+ } else {
+ s"""
+ ${eval1.code}
+ boolean ${ev.isNull} = false;
+ boolean ${ev.value} = false;
+
+ if (!${eval1.isNull} && !${eval1.value}) {
} else {
- ${ev.isNull} = true;
+ ${eval2.code}
+ if (!${eval2.isNull} && !${eval2.value}) {
+ } else if (!${eval1.isNull} && !${eval2.isNull}) {
+ ${ev.value} = true;
+ } else {
+ ${ev.isNull} = true;
+ }
}
- }
- """
+ """
+ }
}
}
-
+@ExpressionDescription(
+ usage = "a _FUNC_ b - Logical OR.")
case class Or(left: Expression, right: Expression) extends BinaryOperator with Predicate {
override def inputType: AbstractDataType = BooleanType
@@ -325,22 +344,35 @@ case class Or(left: Expression, right: Expression) extends BinaryOperator with P
val eval2 = right.gen(ctx)
// The result should be `true`, if any of them is `true` whenever the other is null or not.
- s"""
- ${eval1.code}
- boolean ${ev.isNull} = false;
- boolean ${ev.value} = true;
+ if (!left.nullable && !right.nullable) {
+ ev.isNull = "false"
+ s"""
+ ${eval1.code}
+ boolean ${ev.value} = true;
- if (!${eval1.isNull} && ${eval1.value}) {
- } else {
- ${eval2.code}
- if (!${eval2.isNull} && ${eval2.value}) {
- } else if (!${eval1.isNull} && !${eval2.isNull}) {
- ${ev.value} = false;
+ if (!${eval1.value}) {
+ ${eval2.code}
+ ${ev.value} = ${eval2.value};
+ }
+ """
+ } else {
+ s"""
+ ${eval1.code}
+ boolean ${ev.isNull} = false;
+ boolean ${ev.value} = true;
+
+ if (!${eval1.isNull} && ${eval1.value}) {
} else {
- ${ev.isNull} = true;
+ ${eval2.code}
+ if (!${eval2.isNull} && ${eval2.value}) {
+ } else if (!${eval1.isNull} && !${eval2.isNull}) {
+ ${ev.value} = false;
+ } else {
+ ${ev.isNull} = true;
+ }
}
- }
- """
+ """
+ }
}
}
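The rewritten codegen above produces two shapes: a fast path with no null bookkeeping when neither operand is nullable, and the original three-valued-logic path otherwise. For reference, the three-valued AND/OR semantics that both paths must preserve, sketched in plain Scala with `Option[Boolean]` standing in for a nullable boolean:

    object ThreeValuedLogic {
      // None models SQL NULL. AND is false if either side is false, even when the other is NULL.
      def and(l: Option[Boolean], r: Option[Boolean]): Option[Boolean] = (l, r) match {
        case (Some(false), _) | (_, Some(false)) => Some(false)
        case (Some(true), Some(true))            => Some(true)
        case _                                   => None
      }

      // OR is true if either side is true, even when the other is NULL.
      def or(l: Option[Boolean], r: Option[Boolean]): Option[Boolean] = (l, r) match {
        case (Some(true), _) | (_, Some(true)) => Some(true)
        case (Some(false), Some(false))        => Some(false)
        case _                                 => None
      }

      def main(args: Array[String]): Unit = {
        println(and(Some(false), None))  // Some(false)
        println(and(Some(true), None))   // None
        println(or(Some(true), None))    // Some(true)
        println(or(Some(false), None))   // None
      }
    }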
@@ -375,8 +407,10 @@ private[sql] object Equality {
}
}
-
-case class EqualTo(left: Expression, right: Expression) extends BinaryComparison {
+@ExpressionDescription(
+ usage = "a _FUNC_ b - Returns TRUE if a equals b and false otherwise.")
+case class EqualTo(left: Expression, right: Expression)
+ extends BinaryComparison with NullIntolerant {
override def inputType: AbstractDataType = AnyDataType
@@ -399,7 +433,9 @@ case class EqualTo(left: Expression, right: Expression) extends BinaryComparison
}
}
-
+@ExpressionDescription(
+ usage = """a _FUNC_ b - Returns same result with EQUAL(=) operator for non-null operands,
+ but returns TRUE if both are NULL, FALSE if one of the them is NULL.""")
case class EqualNullSafe(left: Expression, right: Expression) extends BinaryComparison {
override def inputType: AbstractDataType = AnyDataType
@@ -440,8 +476,10 @@ case class EqualNullSafe(left: Expression, right: Expression) extends BinaryComp
}
}
-
-case class LessThan(left: Expression, right: Expression) extends BinaryComparison {
+@ExpressionDescription(
+ usage = "a _FUNC_ b - Returns TRUE if a is less than b.")
+case class LessThan(left: Expression, right: Expression)
+ extends BinaryComparison with NullIntolerant {
override def inputType: AbstractDataType = TypeCollection.Ordered
@@ -452,8 +490,10 @@ case class LessThan(left: Expression, right: Expression) extends BinaryCompariso
protected override def nullSafeEval(input1: Any, input2: Any): Any = ordering.lt(input1, input2)
}
-
-case class LessThanOrEqual(left: Expression, right: Expression) extends BinaryComparison {
+@ExpressionDescription(
+ usage = "a _FUNC_ b - Returns TRUE if a is not greater than b.")
+case class LessThanOrEqual(left: Expression, right: Expression)
+ extends BinaryComparison with NullIntolerant {
override def inputType: AbstractDataType = TypeCollection.Ordered
@@ -464,8 +504,10 @@ case class LessThanOrEqual(left: Expression, right: Expression) extends BinaryCo
protected override def nullSafeEval(input1: Any, input2: Any): Any = ordering.lteq(input1, input2)
}
-
-case class GreaterThan(left: Expression, right: Expression) extends BinaryComparison {
+@ExpressionDescription(
+ usage = "a _FUNC_ b - Returns TRUE if a is greater than b.")
+case class GreaterThan(left: Expression, right: Expression)
+ extends BinaryComparison with NullIntolerant {
override def inputType: AbstractDataType = TypeCollection.Ordered
@@ -476,8 +518,10 @@ case class GreaterThan(left: Expression, right: Expression) extends BinaryCompar
protected override def nullSafeEval(input1: Any, input2: Any): Any = ordering.gt(input1, input2)
}
-
-case class GreaterThanOrEqual(left: Expression, right: Expression) extends BinaryComparison {
+@ExpressionDescription(
+ usage = "a _FUNC_ b - Returns TRUE if a is not smaller than b.")
+case class GreaterThanOrEqual(left: Expression, right: Expression)
+ extends BinaryComparison with NullIntolerant {
override def inputType: AbstractDataType = TypeCollection.Ordered
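The usage strings above describe how '=' (EqualTo, now NullIntolerant) and '<=>' (EqualNullSafe) differ on NULL inputs; a small illustrative Scala model, with None standing in for NULL (not Spark code):

    def equalTo(a: Option[Int], b: Option[Int]): Option[Boolean] =
      for (x <- a; y <- b) yield x == y       // any NULL operand makes the result NULL

    def equalNullSafe(a: Option[Int], b: Option[Int]): Option[Boolean] =
      Some(a == b)                            // NULL <=> NULL is TRUE, NULL <=> 1 is FALSE

    equalTo(Some(1), None)        // None
    equalNullSafe(None, None)     // Some(true)
    equalNullSafe(Some(1), None)  // Some(false)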
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/randomExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/randomExpressions.scala
index 6be3cbcae6..1ec092a5be 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/randomExpressions.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/randomExpressions.scala
@@ -55,6 +55,8 @@ abstract class RDG extends LeafExpression with Nondeterministic {
}
/** Generate a random column with i.i.d. uniformly distributed values in [0, 1). */
+@ExpressionDescription(
+ usage = "_FUNC_(a) - Returns a random column with i.i.d. uniformly distributed values in [0, 1).")
case class Rand(seed: Long) extends RDG {
override protected def evalInternal(input: InternalRow): Double = rng.nextDouble()
@@ -78,6 +80,8 @@ case class Rand(seed: Long) extends RDG {
}
/** Generate a random column with i.i.d. gaussian random distribution. */
+@ExpressionDescription(
+ usage = "_FUNC_(a) - Returns a random column with i.i.d. gaussian random distribution.")
case class Randn(seed: Long) extends RDG {
override protected def evalInternal(input: InternalRow): Double = rng.nextGaussian()
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/regexpExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/regexpExpressions.scala
index b68009331b..85a5429263 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/regexpExpressions.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/regexpExpressions.scala
@@ -67,6 +67,8 @@ trait StringRegexExpression extends ImplicitCastInputTypes {
/**
* Simple RegEx pattern matching function
*/
+@ExpressionDescription(
+ usage = "str _FUNC_ pattern - Returns true if str matches pattern and false otherwise.")
case class Like(left: Expression, right: Expression)
extends BinaryExpression with StringRegexExpression {
@@ -117,7 +119,8 @@ case class Like(left: Expression, right: Expression)
}
}
-
+@ExpressionDescription(
+ usage = "str _FUNC_ regexp - Returns true if str matches regexp and false otherwise.")
case class RLike(left: Expression, right: Expression)
extends BinaryExpression with StringRegexExpression {
@@ -169,6 +172,9 @@ case class RLike(left: Expression, right: Expression)
/**
* Splits str around pat (pattern is a regular expression).
*/
+@ExpressionDescription(
+ usage = "_FUNC_(str, regex) - Splits str around occurrences that match regex",
+ extended = "> SELECT _FUNC_('oneAtwoBthreeC', '[ABC]');\n ['one', 'two', 'three']")
case class StringSplit(str: Expression, pattern: Expression)
extends BinaryExpression with ImplicitCastInputTypes {
@@ -198,6 +204,9 @@ case class StringSplit(str: Expression, pattern: Expression)
*
* NOTE: this expression is not THREAD-SAFE, as it has some internal mutable state.
*/
+@ExpressionDescription(
+ usage = "_FUNC_(str, regexp, rep) - replace all substrings of str that match regexp with rep.",
+ extended = "> SELECT _FUNC_('100-200', '(\\d+)', 'num');\n 'num-num'")
case class RegExpReplace(subject: Expression, regexp: Expression, rep: Expression)
extends TernaryExpression with ImplicitCastInputTypes {
@@ -289,6 +298,9 @@ case class RegExpReplace(subject: Expression, regexp: Expression, rep: Expressio
*
* NOTE: this expression is not THREAD-SAFE, as it has some internal mutable state.
*/
+@ExpressionDescription(
+ usage = "_FUNC_(str, regexp[, idx]) - extracts a group that matches regexp.",
+ extended = "> SELECT _FUNC_('100-200', '(\\d+)-(\\d+)', 1);\n '100'")
case class RegExpExtract(subject: Expression, regexp: Expression, idx: Expression)
extends TernaryExpression with ImplicitCastInputTypes {
def this(s: Expression, r: Expression) = this(s, r, Literal(1))
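The split / regexp_replace / regexp_extract behaviour documented in the usage strings above can be reproduced with the plain JVM regex API; an illustrative (non-Spark) sketch:

    import java.util.regex.Pattern

    Pattern.compile("[ABC]").split("oneAtwoBthreeC").toSeq   // Seq(one, two, three)

    "100-200".replaceAll("(\\d+)", "num")                    // "num-num"

    val m = Pattern.compile("(\\d+)-(\\d+)").matcher("100-200")
    if (m.find()) m.group(1) else ""                         // "100" (group index 1, as in the example)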
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/rows.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/rows.scala
index be6b2530ef..93a8278528 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/rows.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/rows.scala
@@ -164,7 +164,7 @@ trait BaseGenericInternalRow extends InternalRow {
abstract class MutableRow extends InternalRow {
def setNullAt(i: Int): Unit
- def update(i: Int, value: Any)
+ def update(i: Int, value: Any): Unit
// default implementation (slow)
def setBoolean(i: Int, value: Boolean): Unit = { update(i, value) }
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringExpressions.scala
index 3ee19cc4ad..a17482697d 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringExpressions.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringExpressions.scala
@@ -35,6 +35,9 @@ import org.apache.spark.unsafe.types.{ByteArray, UTF8String}
* An expression that concatenates multiple input strings into a single string.
* If any input is null, concat returns null.
*/
+@ExpressionDescription(
+ usage = "_FUNC_(str1, str2, ..., strN) - Returns the concatenation of str1, str2, ..., strN",
+ extended = "> SELECT _FUNC_('Spark','SQL');\n 'SparkSQL'")
case class Concat(children: Seq[Expression]) extends Expression with ImplicitCastInputTypes {
override def inputTypes: Seq[AbstractDataType] = Seq.fill(children.size)(StringType)
@@ -70,6 +73,10 @@ case class Concat(children: Seq[Expression]) extends Expression with ImplicitCas
*
* Returns null if the separator is null. Otherwise, concat_ws skips all null values.
*/
+@ExpressionDescription(
+ usage =
+ "_FUNC_(sep, [str | array(str)]+) - Returns the concatenation of the strings separated by sep.",
+ extended = "> SELECT _FUNC_(' ', Spark', 'SQL');\n 'Spark SQL'")
case class ConcatWs(children: Seq[Expression])
extends Expression with ImplicitCastInputTypes {
@@ -188,7 +195,7 @@ case class Upper(child: Expression)
*/
@ExpressionDescription(
usage = "_FUNC_(str) - Returns str with all characters changed to lowercase",
- extended = "> SELECT _FUNC_('SparkSql');\n'sparksql'")
+ extended = "> SELECT _FUNC_('SparkSql');\n 'sparksql'")
case class Lower(child: Expression) extends UnaryExpression with String2StringExpression {
override def convert(v: UTF8String): UTF8String = v.toLowerCase
@@ -270,6 +277,11 @@ object StringTranslate {
* The translate will happen when any character in the string matching with the character
* in the `matchingExpr`.
*/
+// scalastyle:off line.size.limit
+@ExpressionDescription(
+ usage = """_FUNC_(input, from, to) - Translates the input string by replacing the characters present in the from string with the corresponding characters in the to string""",
+ extended = "> SELECT _FUNC_('AaBbCc', 'abc', '123');\n 'A1B2C3'")
+// scalastyle:on line.size.limit
case class StringTranslate(srcExpr: Expression, matchingExpr: Expression, replaceExpr: Expression)
extends TernaryExpression with ImplicitCastInputTypes {
@@ -325,6 +337,12 @@ case class StringTranslate(srcExpr: Expression, matchingExpr: Expression, replac
* delimited list (right). Returns 0, if the string wasn't found or if the given
* string (left) contains a comma.
*/
+// scalastyle:off line.size.limit
+@ExpressionDescription(
+ usage = """_FUNC_(str, str_array) - Returns the index (1-based) of the given string (left) in the comma-delimited list (right).
+ Returns 0, if the string wasn't found or if the given string (left) contains a comma.""",
+ extended = "> SELECT _FUNC_('ab','abc,b,ab,c,def');\n 3")
+// scalastyle:on
case class FindInSet(left: Expression, right: Expression) extends BinaryExpression
with ImplicitCastInputTypes {
@@ -347,6 +365,9 @@ case class FindInSet(left: Expression, right: Expression) extends BinaryExpressi
/**
* A function that trims the spaces from both ends of the specified string.
*/
+@ExpressionDescription(
+ usage = "_FUNC_(str) - Removes the leading and trailing space characters from str.",
+ extended = "> SELECT _FUNC_(' SparkSQL ');\n 'SparkSQL'")
case class StringTrim(child: Expression)
extends UnaryExpression with String2StringExpression {
@@ -362,6 +383,9 @@ case class StringTrim(child: Expression)
/**
* A function that trims the spaces from the left end of the given string.
*/
+@ExpressionDescription(
+ usage = "_FUNC_(str) - Removes the leading space characters from str.",
+ extended = "> SELECT _FUNC_(' SparkSQL ');\n 'SparkSQL '")
case class StringTrimLeft(child: Expression)
extends UnaryExpression with String2StringExpression {
@@ -377,6 +401,9 @@ case class StringTrimLeft(child: Expression)
/**
* A function that trims the spaces from the right end of the given string.
*/
+@ExpressionDescription(
+ usage = "_FUNC_(str) - Removes the trailing space characters from str.",
+ extended = "> SELECT _FUNC_(' SparkSQL ');\n ' SparkSQL'")
case class StringTrimRight(child: Expression)
extends UnaryExpression with String2StringExpression {
@@ -396,6 +423,9 @@ case class StringTrimRight(child: Expression)
*
* NOTE: this is not a zero-based but a 1-based index. The first character in str has index 1.
*/
+@ExpressionDescription(
+ usage = "_FUNC_(str, substr) - Returns the (1-based) index of the first occurrence of substr in str.",
+ extended = "> SELECT _FUNC_('SparkSQL', 'SQL');\n 6")
case class StringInstr(str: Expression, substr: Expression)
extends BinaryExpression with ImplicitCastInputTypes {
@@ -422,6 +452,15 @@ case class StringInstr(str: Expression, substr: Expression)
* returned. If count is negative, everything to the right of the final delimiter (counting from the
* right) is returned. substring_index performs a case-sensitive match when searching for delim.
*/
+// scalastyle:off line.size.limit
+@ExpressionDescription(
+ usage = """_FUNC_(str, delim, count) - Returns the substring from str before count occurrences of the delimiter delim.
+ If count is positive, everything to the left of the final delimiter (counting from the
+ left) is returned. If count is negative, everything to the right of the final delimiter
+ (counting from the right) is returned. Substring_index performs a case-sensitive match
+ when searching for delim.""",
+ extended = "> SELECT _FUNC_('www.apache.org', '.', 2);\n 'www.apache'")
+// scalastyle:on line.size.limit
case class SubstringIndex(strExpr: Expression, delimExpr: Expression, countExpr: Expression)
extends TernaryExpression with ImplicitCastInputTypes {
@@ -445,6 +484,12 @@ case class SubstringIndex(strExpr: Expression, delimExpr: Expression, countExpr:
* A function that returns the position of the first occurrence of substr
* in the given string after position pos.
*/
+// scalastyle:off line.size.limit
+@ExpressionDescription(
+ usage = """_FUNC_(substr, str[, pos]) - Returns the position of the first occurrence of substr in str after position pos.
+ The given pos and return value are 1-based.""",
+ extended = "> SELECT _FUNC_('bar', 'foobarbar', 5);\n 7")
+// scalastyle:on line.size.limit
case class StringLocate(substr: Expression, str: Expression, start: Expression)
extends TernaryExpression with ImplicitCastInputTypes {
@@ -510,6 +555,11 @@ case class StringLocate(substr: Expression, str: Expression, start: Expression)
/**
* Returns str, left-padded with pad to a length of len.
*/
+@ExpressionDescription(
+ usage = """_FUNC_(str, len, pad) - Returns str, left-padded with pad to a length of len.
+ If str is longer than len, the return value is shortened to len characters.""",
+ extended = "> SELECT _FUNC_('hi', 5, '??');\n '???hi'\n" +
+ "> SELECT _FUNC_('hi', 1, '??');\n 'h'")
case class StringLPad(str: Expression, len: Expression, pad: Expression)
extends TernaryExpression with ImplicitCastInputTypes {
@@ -531,6 +581,11 @@ case class StringLPad(str: Expression, len: Expression, pad: Expression)
/**
* Returns str, right-padded with pad to a length of len.
*/
+@ExpressionDescription(
+ usage = """_FUNC_(str, len, pad) - Returns str, right-padded with pad to a length of len.
+ If str is longer than len, the return value is shortened to len characters.""",
+ extended = "> SELECT _FUNC_('hi', 5, '??');\n 'hi???'\n" +
+ "> SELECT _FUNC_('hi', 1, '??');\n 'h'")
case class StringRPad(str: Expression, len: Expression, pad: Expression)
extends TernaryExpression with ImplicitCastInputTypes {
@@ -552,6 +607,11 @@ case class StringRPad(str: Expression, len: Expression, pad: Expression)
/**
* Returns the input formatted according to printf-style format strings
*/
+// scalastyle:off line.size.limit
+@ExpressionDescription(
+ usage = "_FUNC_(String format, Obj... args) - Returns a formatted string from printf-style format strings.",
+ extended = "> SELECT _FUNC_(\"Hello World %d %s\", 100, \"days\");\n 'Hello World 100 days'")
+// scalastyle:on line.size.limit
case class FormatString(children: Expression*) extends Expression with ImplicitCastInputTypes {
require(children.nonEmpty, "format_string() should take at least 1 argument")
@@ -618,25 +678,33 @@ case class FormatString(children: Expression*) extends Expression with ImplicitC
}
/**
- * Returns string, with the first letter of each word in uppercase.
+ * Returns the string, with the first letter of each word in uppercase and all other letters in lowercase.
* Words are delimited by whitespace.
*/
+@ExpressionDescription(
+ usage =
+ """_FUNC_(str) - Returns str with the first letter of each word in uppercase.
+ All other letters are in lowercase. Words are delimited by white space.""",
+ extended = "> SELECT initcap('sPark sql');\n 'Spark Sql'")
case class InitCap(child: Expression) extends UnaryExpression with ImplicitCastInputTypes {
override def inputTypes: Seq[DataType] = Seq(StringType)
override def dataType: DataType = StringType
override def nullSafeEval(string: Any): Any = {
- string.asInstanceOf[UTF8String].toTitleCase
+ string.asInstanceOf[UTF8String].toLowerCase.toTitleCase
}
override def genCode(ctx: CodegenContext, ev: ExprCode): String = {
- defineCodeGen(ctx, ev, str => s"$str.toTitleCase()")
+ defineCodeGen(ctx, ev, str => s"$str.toLowerCase().toTitleCase()")
}
}
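The behavioural change above (lower-casing before title-casing) matters because title-casing only touches the first letter of each word; a plain-Scala illustration under that assumption (not the UTF8String implementation):

    def titleCase(s: String): String =
      s.split(" ").map(w => if (w.isEmpty) w else w.substring(0, 1).toUpperCase + w.substring(1)).mkString(" ")

    titleCase("sPark sql")                // "SPark Sql" -- the old behaviour
    titleCase("sPark sql".toLowerCase)    // "Spark Sql" -- the new behaviour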
/**
* Returns the string which repeats the given string value n times.
*/
+@ExpressionDescription(
+ usage = "_FUNC_(str, n) - Returns the string which repeat the given string value n times.",
+ extended = "> SELECT _FUNC_('123', 2);\n '123123'")
case class StringRepeat(str: Expression, times: Expression)
extends BinaryExpression with ImplicitCastInputTypes {
@@ -659,6 +727,9 @@ case class StringRepeat(str: Expression, times: Expression)
/**
* Returns the reversed given string.
*/
+@ExpressionDescription(
+ usage = "_FUNC_(str) - Returns the reversed given string.",
+ extended = "> SELECT _FUNC_('Spark SQL');\n 'LQS krapS'")
case class StringReverse(child: Expression) extends UnaryExpression with String2StringExpression {
override def convert(v: UTF8String): UTF8String = v.reverse()
@@ -672,6 +743,9 @@ case class StringReverse(child: Expression) extends UnaryExpression with String2
/**
* Returns a string consisting of n spaces.
*/
+@ExpressionDescription(
+ usage = "_FUNC_(n) - Returns a n spaces string.",
+ extended = "> SELECT _FUNC_(2);\n ' '")
case class StringSpace(child: Expression)
extends UnaryExpression with ImplicitCastInputTypes {
@@ -694,7 +768,14 @@ case class StringSpace(child: Expression)
/**
* A function that takes a substring of its first argument starting at a given position.
* Defined for String and Binary types.
+ *
+ * NOTE: this is not a zero-based but a 1-based index. The first character in str has index 1.
*/
+// scalastyle:off line.size.limit
+@ExpressionDescription(
+ usage = "_FUNC_(str, pos[, len]) - Returns the substring of str that starts at pos and is of length len or the slice of byte array that starts at pos and is of length len.",
+ extended = "> SELECT _FUNC_('Spark SQL', 5);\n 'k SQL'\n> SELECT _FUNC_('Spark SQL', -3);\n 'SQL'\n> SELECT _FUNC_('Spark SQL', 5, 1);\n 'k'")
+// scalastyle:on line.size.limit
case class Substring(str: Expression, pos: Expression, len: Expression)
extends TernaryExpression with ImplicitCastInputTypes {
@@ -732,6 +813,9 @@ case class Substring(str: Expression, pos: Expression, len: Expression)
/**
* A function that returns the length of the given string or binary expression.
*/
+@ExpressionDescription(
+ usage = "_FUNC_(str | binary) - Returns the length of str or number of bytes in binary data.",
+ extended = "> SELECT _FUNC_('Spark SQL');\n 9")
case class Length(child: Expression) extends UnaryExpression with ExpectsInputTypes {
override def dataType: DataType = IntegerType
override def inputTypes: Seq[AbstractDataType] = Seq(TypeCollection(StringType, BinaryType))
@@ -752,6 +836,9 @@ case class Length(child: Expression) extends UnaryExpression with ExpectsInputTy
/**
* A function that returns the Levenshtein distance between the two given strings.
*/
+@ExpressionDescription(
+ usage = "_FUNC_(str1, str2) - Returns the Levenshtein distance between the two given strings.",
+ extended = "> SELECT _FUNC_('kitten', 'sitting');\n 3")
case class Levenshtein(left: Expression, right: Expression) extends BinaryExpression
with ImplicitCastInputTypes {
@@ -770,6 +857,9 @@ case class Levenshtein(left: Expression, right: Expression) extends BinaryExpres
/**
* A function that returns the soundex code of the given string expression.
*/
+@ExpressionDescription(
+ usage = "_FUNC_(str) - Returns soundex code of the string.",
+ extended = "> SELECT _FUNC_('Miller');\n 'M460'")
case class SoundEx(child: Expression) extends UnaryExpression with ExpectsInputTypes {
override def dataType: DataType = StringType
@@ -786,6 +876,10 @@ case class SoundEx(child: Expression) extends UnaryExpression with ExpectsInputT
/**
* Returns the numeric value of the first character of str.
*/
+@ExpressionDescription(
+ usage = "_FUNC_(str) - Returns the numeric value of the first character of str.",
+ extended = "> SELECT _FUNC_('222');\n 50\n" +
+ "> SELECT _FUNC_(2);\n 50")
case class Ascii(child: Expression) extends UnaryExpression with ImplicitCastInputTypes {
override def dataType: DataType = IntegerType
@@ -817,6 +911,8 @@ case class Ascii(child: Expression) extends UnaryExpression with ImplicitCastInp
/**
* Converts the argument from binary to a base 64 string.
*/
+@ExpressionDescription(
+ usage = "_FUNC_(bin) - Convert the argument from binary to a base 64 string.")
case class Base64(child: Expression) extends UnaryExpression with ImplicitCastInputTypes {
override def dataType: DataType = StringType
@@ -839,6 +935,8 @@ case class Base64(child: Expression) extends UnaryExpression with ImplicitCastIn
/**
* Converts the argument from a base 64 string to BINARY.
*/
+@ExpressionDescription(
+ usage = "_FUNC_(str) - Convert the argument from a base 64 string to binary.")
case class UnBase64(child: Expression) extends UnaryExpression with ImplicitCastInputTypes {
override def dataType: DataType = BinaryType
@@ -860,6 +958,8 @@ case class UnBase64(child: Expression) extends UnaryExpression with ImplicitCast
* (one of 'US-ASCII', 'ISO-8859-1', 'UTF-8', 'UTF-16BE', 'UTF-16LE', 'UTF-16').
* If either argument is null, the result will also be null.
*/
+@ExpressionDescription(
+ usage = "_FUNC_(bin, str) - Decode the first argument using the second argument character set.")
case class Decode(bin: Expression, charset: Expression)
extends BinaryExpression with ImplicitCastInputTypes {
@@ -889,7 +989,9 @@ case class Decode(bin: Expression, charset: Expression)
* Encodes the first argument into a BINARY using the provided character set
* (one of 'US-ASCII', 'ISO-8859-1', 'UTF-8', 'UTF-16BE', 'UTF-16LE', 'UTF-16').
* If either argument is null, the result will also be null.
-*/
+ */
+@ExpressionDescription(
+ usage = "_FUNC_(str, str) - Encode the first argument using the second argument character set.")
case class Encode(value: Expression, charset: Expression)
extends BinaryExpression with ImplicitCastInputTypes {
@@ -919,6 +1021,11 @@ case class Encode(value: Expression, charset: Expression)
* and returns the result as a string. If D is 0, the result has no decimal point or
* fractional part.
*/
+@ExpressionDescription(
+ usage = """_FUNC_(X, D) - Formats the number X like '#,###,###.##', rounded to D decimal places.
+ If D is 0, the result has no decimal point or fractional part.
+ This is supposed to function like MySQL's FORMAT.""",
+ extended = "> SELECT _FUNC_(12332.123456, 4);\n '12,332.1235'")
case class FormatNumber(x: Expression, d: Expression)
extends BinaryExpression with ExpectsInputTypes {
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/windowExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/windowExpressions.scala
index b8679474cf..c0b453dccf 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/windowExpressions.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/windowExpressions.scala
@@ -451,7 +451,11 @@ abstract class RowNumberLike extends AggregateWindowFunction {
* A [[SizeBasedWindowFunction]] needs the size of the current window for its calculation.
*/
trait SizeBasedWindowFunction extends AggregateWindowFunction {
- protected def n: AttributeReference = SizeBasedWindowFunction.n
+ // It's made a val so that the attribute created on the driver side is serialized to the
+ // executor side. Otherwise, if it were defined as a def, calling it on the executor side
+ // would return the singleton value instantiated there, which has a different expression ID
+ // from the one created on the driver side.
+ val n: AttributeReference = SizeBasedWindowFunction.n
}
object SizeBasedWindowFunction {
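The def-to-val change matters because a def is re-evaluated wherever it is called, while a val is fixed when the instance is constructed and therefore travels with it when serialized. A tiny illustrative sketch, with a counter standing in for expression-ID generation (hypothetical, not Spark code):

    object Ids { private var i = 0; def next(): Int = { i += 1; i } }

    class WithDef { def n: Int = Ids.next() }   // recomputed on every access -- a new id each time
    class WithVal { val n: Int = Ids.next() }   // captured once, at construction time

    val d = new WithDef
    val v = new WithVal
    (d.n, d.n)   // two different ids
    (v.n, v.n)   // the same id twice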
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/identifiers.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/identifiers.scala
index 87f4d1b007..aae75956ea 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/identifiers.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/identifiers.scala
@@ -25,10 +25,10 @@ package org.apache.spark.sql.catalyst
* Format (quoted): "`name`" or "`db`.`name`"
*/
sealed trait IdentifierWithDatabase {
- val name: String
+ val identifier: String
def database: Option[String]
- def quotedString: String = database.map(db => s"`$db`.`$name`").getOrElse(s"`$name`")
- def unquotedString: String = database.map(db => s"$db.$name").getOrElse(name)
+ def quotedString: String = database.map(db => s"`$db`.`$identifier`").getOrElse(s"`$identifier`")
+ def unquotedString: String = database.map(db => s"$db.$identifier").getOrElse(identifier)
override def toString: String = quotedString
}
@@ -36,13 +36,15 @@ sealed trait IdentifierWithDatabase {
/**
* Identifies a table in a database.
* If `database` is not defined, the current database is used.
+ * When we register a permanent function in the FunctionRegistry, we use
+ * unquotedString as the function name.
*/
case class TableIdentifier(table: String, database: Option[String])
extends IdentifierWithDatabase {
- override val name: String = table
+ override val identifier: String = table
- def this(name: String) = this(name, None)
+ def this(table: String) = this(table, None)
}
@@ -58,9 +60,9 @@ object TableIdentifier {
case class FunctionIdentifier(funcName: String, database: Option[String])
extends IdentifierWithDatabase {
- override val name: String = funcName
+ override val identifier: String = funcName
- def this(name: String) = this(name, None)
+ def this(funcName: String) = this(funcName, None)
}
object FunctionIdentifier {
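Given the renamed `identifier` field and the two formatting helpers above, the quoting behaviour looks like this (illustrative usage, assuming the catalyst classes are on the classpath):

    import org.apache.spark.sql.catalyst.{FunctionIdentifier, TableIdentifier}

    TableIdentifier("t", Some("db")).quotedString     // "`db`.`t`"
    TableIdentifier("t", Some("db")).unquotedString   // "db.t"
    new TableIdentifier("t").quotedString             // "`t`"
    new FunctionIdentifier("lower").unquotedString    // "lower"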
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala
index a7a948ef1b..f5172b213a 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala
@@ -31,9 +31,9 @@ import org.apache.spark.sql.catalyst.rules._
import org.apache.spark.sql.types._
/**
- * Abstract class all optimizers should inherit of, contains the standard batches (extending
- * Optimizers can override this.
- */
+ * Abstract class that all optimizers should inherit from. It contains the standard batches
+ * (extending optimizers can override this).
+ */
abstract class Optimizer extends RuleExecutor[LogicalPlan] {
def batches: Seq[Batch] = {
// Technically some of the rules in Finish Analysis are not optimizer rules and belong more
@@ -66,9 +66,7 @@ abstract class Optimizer extends RuleExecutor[LogicalPlan] {
ReorderJoin,
OuterJoinElimination,
PushPredicateThroughJoin,
- PushPredicateThroughProject,
- PushPredicateThroughGenerate,
- PushPredicateThroughAggregate,
+ PushDownPredicate,
LimitPushDown,
ColumnPruning,
InferFiltersFromConstraints,
@@ -86,6 +84,7 @@ abstract class Optimizer extends RuleExecutor[LogicalPlan] {
BooleanSimplification,
SimplifyConditionals,
RemoveDispensableExpressions,
+ BinaryComparisonSimplification,
PruneFilters,
EliminateSorts,
SimplifyCasts,
@@ -93,6 +92,8 @@ abstract class Optimizer extends RuleExecutor[LogicalPlan] {
EliminateSerialization) ::
Batch("Decimal Optimizations", FixedPoint(100),
DecimalAggregates) ::
+ Batch("Typed Filter Optimization", FixedPoint(100),
+ EmbedSerializerInFilter) ::
Batch("LocalRelation", FixedPoint(100),
ConvertToLocalRelation) ::
Batch("Subquery", Once,
@@ -111,11 +112,11 @@ abstract class Optimizer extends RuleExecutor[LogicalPlan] {
}
/**
- * Non-abstract representation of the standard Spark optimizing strategies
- *
- * To ensure extendability, we leave the standard rules in the abstract optimizer rules, while
- * specific rules go to the subclasses
- */
+ * Non-abstract representation of the standard Spark optimizing strategies.
+ *
+ * To ensure extensibility, we leave the standard rules in the abstract optimizer, while
+ * specific rules go into the subclasses.
+ */
object DefaultOptimizer extends Optimizer
/**
@@ -136,6 +137,7 @@ object SamplePushDown extends Rule[LogicalPlan] {
* representation of data item. For example back to back map operations.
*/
object EliminateSerialization extends Rule[LogicalPlan] {
+ // TODO: find a more general way to do this optimization.
def apply(plan: LogicalPlan): LogicalPlan = plan transform {
case m @ MapPartitions(_, deserializer, _, child: ObjectOperator)
if !deserializer.isInstanceOf[Attribute] &&
@@ -144,6 +146,20 @@ object EliminateSerialization extends Rule[LogicalPlan] {
m.copy(
deserializer = childWithoutSerialization.output.head,
child = childWithoutSerialization)
+
+ case m @ MapElements(_, deserializer, _, child: ObjectOperator)
+ if !deserializer.isInstanceOf[Attribute] &&
+ deserializer.dataType == child.outputObject.dataType =>
+ val childWithoutSerialization = child.withObjectOutput
+ m.copy(
+ deserializer = childWithoutSerialization.output.head,
+ child = childWithoutSerialization)
+
+ case d @ DeserializeToObject(_, s: SerializeFromObject)
+ if d.outputObjectType == s.inputObjectType =>
+ // Adds an extra Project here, to preserve the output expr id of `DeserializeToObject`.
+ val objAttr = Alias(s.child.output.head, "obj")(exprId = d.output.head.exprId)
+ Project(objAttr :: Nil, s.child)
}
}
@@ -270,10 +286,10 @@ object SetOperationPushDown extends Rule[LogicalPlan] with PredicateHelper {
assert(children.nonEmpty)
if (projectList.forall(_.deterministic)) {
val newFirstChild = Project(projectList, children.head)
- val newOtherChildren = children.tail.map ( child => {
+ val newOtherChildren = children.tail.map { child =>
val rewrites = buildRewrites(children.head, child)
Project(projectList.map(pushToRight(_, rewrites)), child)
- } )
+ }
Union(newFirstChild +: newOtherChildren)
} else {
p
@@ -352,8 +368,8 @@ object ColumnPruning extends Rule[LogicalPlan] {
case p @ Project(_, g: Generate) if g.join && p.references.subsetOf(g.generatedSet) =>
p.copy(child = g.copy(join = false))
- // Eliminate unneeded attributes from right side of a LeftSemiJoin.
- case j @ Join(left, right, LeftSemi, condition) =>
+ // Eliminate unneeded attributes from right side of a Left Existence Join.
+ case j @ Join(left, right, LeftExistence(_), condition) =>
j.copy(right = prunedChild(right, j.references))
// all the columns will be used to compare, so we can't prune them
@@ -501,22 +517,28 @@ object LikeSimplification extends Rule[LogicalPlan] {
// Cases like "something\%" are not optimized, but this does not affect correctness.
private val startsWith = "([^_%]+)%".r
private val endsWith = "%([^_%]+)".r
+ private val startsAndEndsWith = "([^_%]+)%([^_%]+)".r
private val contains = "%([^_%]+)%".r
private val equalTo = "([^_%]*)".r
def apply(plan: LogicalPlan): LogicalPlan = plan transformAllExpressions {
- case Like(l, Literal(utf, StringType)) =>
- utf.toString match {
- case startsWith(pattern) if !pattern.endsWith("\\") =>
- StartsWith(l, Literal(pattern))
- case endsWith(pattern) =>
- EndsWith(l, Literal(pattern))
- case contains(pattern) if !pattern.endsWith("\\") =>
- Contains(l, Literal(pattern))
- case equalTo(pattern) =>
- EqualTo(l, Literal(pattern))
+ case Like(input, Literal(pattern, StringType)) =>
+ pattern.toString match {
+ case startsWith(prefix) if !prefix.endsWith("\\") =>
+ StartsWith(input, Literal(prefix))
+ case endsWith(postfix) =>
+ EndsWith(input, Literal(postfix))
+ // The 'a%a' pattern is basically the same as 'a%' && '%a'.
+ // However, the additional `Length` condition is required to prevent 'a' from matching 'a%a'.
+ case startsAndEndsWith(prefix, postfix) if !prefix.endsWith("\\") =>
+ And(GreaterThanOrEqual(Length(input), Literal(prefix.size + postfix.size)),
+ And(StartsWith(input, Literal(prefix)), EndsWith(input, Literal(postfix))))
+ case contains(infix) if !infix.endsWith("\\") =>
+ Contains(input, Literal(infix))
+ case equalTo(str) =>
+ EqualTo(input, Literal(str))
case _ =>
- Like(l, Literal.create(utf, StringType))
+ Like(input, Literal.create(pattern, StringType))
}
}
}
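The Length guard added for the new startsAndEndsWith branch is what keeps overlapping prefix/suffix matches honest; a plain-Scala illustration of the predicate it produces (not the Catalyst expressions themselves):

    // col LIKE 'a%a' must not match "a": the prefix and suffix may not overlap.
    def matchesPrefixSuffix(s: String, prefix: String, suffix: String): Boolean =
      s.length >= prefix.length + suffix.length && s.startsWith(prefix) && s.endsWith(suffix)

    matchesPrefixSuffix("a", "a", "a")     // false, as required by LIKE 'a%a'
    matchesPrefixSuffix("aba", "a", "a")   // true
    matchesPrefixSuffix("aa", "a", "a")    // true ('a%a' matches "aa" since % matches the empty string)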
@@ -527,14 +549,14 @@ object LikeSimplification extends Rule[LogicalPlan] {
* Null value propagation from bottom to top of the expression tree.
*/
object NullPropagation extends Rule[LogicalPlan] {
- def nonNullLiteral(e: Expression): Boolean = e match {
+ private def nonNullLiteral(e: Expression): Boolean = e match {
case Literal(null, _) => false
case _ => true
}
def apply(plan: LogicalPlan): LogicalPlan = plan transform {
case q: LogicalPlan => q transformExpressionsUp {
- case e @ AggregateExpression(Count(exprs), _, _) if !exprs.exists(nonNullLiteral) =>
+ case e @ AggregateExpression(Count(exprs), _, _, _) if !exprs.exists(nonNullLiteral) =>
Cast(Literal(0L), e.dataType)
case e @ IsNull(c) if !c.nullable => Literal.create(false, BooleanType)
case e @ IsNotNull(c) if !c.nullable => Literal.create(true, BooleanType)
@@ -547,9 +569,9 @@ object NullPropagation extends Rule[LogicalPlan] {
Literal.create(null, e.dataType)
case e @ EqualNullSafe(Literal(null, _), r) => IsNull(r)
case e @ EqualNullSafe(l, Literal(null, _)) => IsNull(l)
- case e @ AggregateExpression(Count(exprs), mode, false) if !exprs.exists(_.nullable) =>
+ case ae @ AggregateExpression(Count(exprs), _, false, _) if !exprs.exists(_.nullable) =>
// This rule should be only triggered when isDistinct field is false.
- AggregateExpression(Count(Literal(1)), mode, isDistinct = false)
+ ae.copy(aggregateFunction = Count(Literal(1)))
// For Coalesce, remove null literals.
case e @ Coalesce(children) =>
@@ -770,20 +792,50 @@ object BooleanSimplification extends Rule[LogicalPlan] with PredicateHelper {
}
/**
+ * Simplifies binary comparisons with semantically-equal expressions:
+ * 1) Replace '<=>' with 'true' literal.
+ * 2) Replace '=', '<=', and '>=' with 'true' literal if both operands are non-nullable.
+ * 3) Replace '<' and '>' with 'false' literal if both operands are non-nullable.
+ */
+object BinaryComparisonSimplification extends Rule[LogicalPlan] with PredicateHelper {
+ def apply(plan: LogicalPlan): LogicalPlan = plan transform {
+ case q: LogicalPlan => q transformExpressionsUp {
+ // True with equality
+ case a EqualNullSafe b if a.semanticEquals(b) => TrueLiteral
+ case a EqualTo b if !a.nullable && !b.nullable && a.semanticEquals(b) => TrueLiteral
+ case a GreaterThanOrEqual b if !a.nullable && !b.nullable && a.semanticEquals(b) =>
+ TrueLiteral
+ case a LessThanOrEqual b if !a.nullable && !b.nullable && a.semanticEquals(b) => TrueLiteral
+
+ // False with inequality
+ case a GreaterThan b if !a.nullable && !b.nullable && a.semanticEquals(b) => FalseLiteral
+ case a LessThan b if !a.nullable && !b.nullable && a.semanticEquals(b) => FalseLiteral
+ }
+ }
+}
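The new rule's three cases can be summarized as a small decision function; the sketch below is a toy model (hypothetical, not Catalyst types), keyed on the operator and on whether the semantically-equal operands are nullable:

    def fold(op: String, nullable: Boolean): Option[Boolean] = op match {
      case "<=>"                          => Some(true)   // always TRUE, even for NULL operands
      case "=" | "<=" | ">=" if !nullable => Some(true)
      case "<" | ">" if !nullable         => Some(false)
      case _                              => None         // leave the comparison as-is
    }

    fold("=", nullable = true)    // None: NULL = NULL is NULL, so it cannot be folded to TRUE
    fold("<=>", nullable = true)  // Some(true)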
+
+/**
* Simplifies conditional expressions (if / case).
*/
object SimplifyConditionals extends Rule[LogicalPlan] with PredicateHelper {
+ private def falseOrNullLiteral(e: Expression): Boolean = e match {
+ case FalseLiteral => true
+ case Literal(null, _) => true
+ case _ => false
+ }
+
def apply(plan: LogicalPlan): LogicalPlan = plan transform {
case q: LogicalPlan => q transformExpressionsUp {
case If(TrueLiteral, trueValue, _) => trueValue
case If(FalseLiteral, _, falseValue) => falseValue
+ case If(Literal(null, _), _, falseValue) => falseValue
- case e @ CaseWhen(branches, elseValue) if branches.exists(_._1 == FalseLiteral) =>
+ case e @ CaseWhen(branches, elseValue) if branches.exists(x => falseOrNullLiteral(x._1)) =>
// If there are branches that are always false, remove them.
// If there are no more branches left, just use the else value.
// Note that these two are handled together here in a single case statement because
// otherwise we cannot determine the data type for the elseValue if it is None (i.e. null).
- val newBranches = branches.filter(_._1 != FalseLiteral)
+ val newBranches = branches.filter(x => !falseOrNullLiteral(x._1))
if (newBranches.isEmpty) {
elseValue.getOrElse(Literal.create(null, e.dataType))
} else {
@@ -869,12 +921,13 @@ object PruneFilters extends Rule[LogicalPlan] with PredicateHelper {
}
/**
- * Pushes [[Filter]] operators through [[Project]] operators, in-lining any [[Alias Aliases]]
- * that were defined in the projection.
+ * Pushes [[Filter]] operators through many operators iff:
+ * 1) the operator is deterministic
+ * 2) the predicate is deterministic and the operator will not change any of the rows.
*
* This heuristic is valid assuming the expression evaluation cost is minimal.
*/
-object PushPredicateThroughProject extends Rule[LogicalPlan] with PredicateHelper {
+object PushDownPredicate extends Rule[LogicalPlan] with PredicateHelper {
def apply(plan: LogicalPlan): LogicalPlan = plan transform {
// SPARK-13473: We can't push the predicate down when the underlying projection output non-
// deterministic field(s). Non-deterministic expressions are essentially stateful. This
@@ -891,41 +944,7 @@ object PushPredicateThroughProject extends Rule[LogicalPlan] with PredicateHelpe
})
project.copy(child = Filter(replaceAlias(condition, aliasMap), grandChild))
- }
-
-}
-
-/**
- * Push [[Filter]] operators through [[Generate]] operators. Parts of the predicate that reference
- * attributes generated in [[Generate]] will remain above, and the rest should be pushed beneath.
- */
-object PushPredicateThroughGenerate extends Rule[LogicalPlan] with PredicateHelper {
-
- def apply(plan: LogicalPlan): LogicalPlan = plan transform {
- case filter @ Filter(condition, g: Generate) =>
- // Predicates that reference attributes produced by the `Generate` operator cannot
- // be pushed below the operator.
- val (pushDown, stayUp) = splitConjunctivePredicates(condition).partition { cond =>
- cond.references.subsetOf(g.child.outputSet) && cond.deterministic
- }
- if (pushDown.nonEmpty) {
- val pushDownPredicate = pushDown.reduce(And)
- val newGenerate = Generate(g.generator, join = g.join, outer = g.outer,
- g.qualifier, g.generatorOutput, Filter(pushDownPredicate, g.child))
- if (stayUp.isEmpty) newGenerate else Filter(stayUp.reduce(And), newGenerate)
- } else {
- filter
- }
- }
-}
-
-/**
- * Push [[Filter]] operators through [[Aggregate]] operators, iff the filters reference only
- * non-aggregate attributes (typically literals or grouping expressions).
- */
-object PushPredicateThroughAggregate extends Rule[LogicalPlan] with PredicateHelper {
- def apply(plan: LogicalPlan): LogicalPlan = plan transform {
case filter @ Filter(condition, aggregate: Aggregate) =>
// Find all the aliased expressions in the aggregate list that don't include any actual
// AggregateExpression, and create a map from the alias to the expression
@@ -951,25 +970,91 @@ object PushPredicateThroughAggregate extends Rule[LogicalPlan] with PredicateHel
} else {
filter
}
+
+ case filter @ Filter(condition, child)
+ if child.isInstanceOf[Union] || child.isInstanceOf[Intersect] =>
+ // Union/Intersect could change the rows, so non-deterministic predicates can't be pushed down
+ val (pushDown, stayUp) = splitConjunctivePredicates(condition).partition { cond =>
+ cond.deterministic
+ }
+ if (pushDown.nonEmpty) {
+ val pushDownCond = pushDown.reduceLeft(And)
+ val output = child.output
+ val newGrandChildren = child.children.map { grandchild =>
+ val newCond = pushDownCond transform {
+ case e if output.exists(_.semanticEquals(e)) =>
+ grandchild.output(output.indexWhere(_.semanticEquals(e)))
+ }
+ assert(newCond.references.subsetOf(grandchild.outputSet))
+ Filter(newCond, grandchild)
+ }
+ val newChild = child.withNewChildren(newGrandChildren)
+ if (stayUp.nonEmpty) {
+ Filter(stayUp.reduceLeft(And), newChild)
+ } else {
+ newChild
+ }
+ } else {
+ filter
+ }
+
+ case filter @ Filter(condition, e @ Except(left, _)) =>
+ pushDownPredicate(filter, e.left) { predicate =>
+ e.copy(left = Filter(predicate, left))
+ }
+
+ // Two filters should be combined together by other rules
+ case filter @ Filter(_, f: Filter) => filter
+ // We should not push predicates through Sample, or it will generate different results.
+ case filter @ Filter(_, s: Sample) => filter
+ // TODO: push predicates through expand
+ case filter @ Filter(_, e: Expand) => filter
+
+ case filter @ Filter(condition, u: UnaryNode) if u.expressions.forall(_.deterministic) =>
+ pushDownPredicate(filter, u.child) { predicate =>
+ u.withNewChildren(Seq(Filter(predicate, u.child)))
+ }
+ }
+
+ private def pushDownPredicate(
+ filter: Filter,
+ grandchild: LogicalPlan)(insertFilter: Expression => LogicalPlan): LogicalPlan = {
+ // Only push down predicates that are deterministic and whose referenced attributes all
+ // come from the grandchild.
+ // TODO: non-deterministic predicates could be pushed through some operators that do not change
+ // the rows.
+ val (pushDown, stayUp) = splitConjunctivePredicates(filter.condition).partition { cond =>
+ cond.deterministic && cond.references.subsetOf(grandchild.outputSet)
+ }
+ if (pushDown.nonEmpty) {
+ val newChild = insertFilter(pushDown.reduceLeft(And))
+ if (stayUp.nonEmpty) {
+ Filter(stayUp.reduceLeft(And), newChild)
+ } else {
+ newChild
+ }
+ } else {
+ filter
+ }
}
}
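The helper above always splits the conjuncts of the filter condition into a part that can move below the child and a part that must stay on top; a toy illustration of that partitioning, using strings instead of Catalyst expressions (the real helper also checks that the pushed conjuncts reference only the grandchild's output):

    def split(conjuncts: Seq[String]): (Seq[String], Seq[String]) =
      conjuncts.partition(c => !c.contains("rand()"))   // deterministic conjuncts get pushed down

    split(Seq("a > 1", "rand() < 0.5", "b = 2"))
    // (List(a > 1, b = 2), List(rand() < 0.5)): pushed below the child vs. kept in a Filter above it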
/**
- * Reorder the joins and push all the conditions into join, so that the bottom ones have at least
- * one condition.
- *
- * The order of joins will not be changed if all of them already have at least one condition.
- */
+ * Reorder the joins and push all the conditions into join, so that the bottom ones have at least
+ * one condition.
+ *
+ * The order of joins will not be changed if all of them already have at least one condition.
+ */
object ReorderJoin extends Rule[LogicalPlan] with PredicateHelper {
/**
- * Join a list of plans together and push down the conditions into them.
- *
- * The joined plan are picked from left to right, prefer those has at least one join condition.
- *
- * @param input a list of LogicalPlans to join.
- * @param conditions a list of condition for join.
- */
+ * Join a list of plans together and push down the conditions into them.
+ *
+ * The joined plans are picked from left to right, preferring those that have at least one join condition.
+ *
+ * @param input a list of LogicalPlans to join.
+ * @param conditions a list of condition for join.
+ */
@tailrec
def createOrderedJoin(input: Seq[LogicalPlan], conditions: Seq[Expression]): LogicalPlan = {
assert(input.size >= 2)
@@ -1110,7 +1195,7 @@ object PushPredicateThroughJoin extends Rule[LogicalPlan] with PredicateHelper {
(leftFilterConditions ++ commonFilterCondition).
reduceLeftOption(And).map(Filter(_, newJoin)).getOrElse(newJoin)
- case _ @ (LeftOuter | LeftSemi) =>
+ case LeftOuter | LeftExistence(_) =>
// push down the left side only `where` condition
val newLeft = leftFilterConditions.
reduceLeftOption(And).map(Filter(_, left)).getOrElse(left)
@@ -1131,7 +1216,7 @@ object PushPredicateThroughJoin extends Rule[LogicalPlan] with PredicateHelper {
split(joinCondition.map(splitConjunctivePredicates).getOrElse(Nil), left, right)
joinType match {
- case _ @ (Inner | LeftSemi) =>
+ case Inner | LeftExistence(_) =>
// push down the single side only join filter for both sides sub queries
val newLeft = leftJoinConditions.
reduceLeftOption(And).map(Filter(_, left)).getOrElse(left)
@@ -1225,13 +1310,13 @@ object DecimalAggregates extends Rule[LogicalPlan] {
private val MAX_DOUBLE_DIGITS = 15
def apply(plan: LogicalPlan): LogicalPlan = plan transformAllExpressions {
- case AggregateExpression(Sum(e @ DecimalType.Expression(prec, scale)), mode, isDistinct)
+ case ae @ AggregateExpression(Sum(e @ DecimalType.Expression(prec, scale)), _, _, _)
if prec + 10 <= MAX_LONG_DIGITS =>
- MakeDecimal(AggregateExpression(Sum(UnscaledValue(e)), mode, isDistinct), prec + 10, scale)
+ MakeDecimal(ae.copy(aggregateFunction = Sum(UnscaledValue(e))), prec + 10, scale)
- case AggregateExpression(Average(e @ DecimalType.Expression(prec, scale)), mode, isDistinct)
+ case ae @ AggregateExpression(Average(e @ DecimalType.Expression(prec, scale)), _, _, _)
if prec + 4 <= MAX_DOUBLE_DIGITS =>
- val newAggExpr = AggregateExpression(Average(UnscaledValue(e)), mode, isDistinct)
+ val newAggExpr = ae.copy(aggregateFunction = Average(UnscaledValue(e)))
Cast(
Divide(newAggExpr, Literal.create(math.pow(10.0, scale), DoubleType)),
DecimalType(prec + 4, scale + 4))
@@ -1313,3 +1398,30 @@ object ComputeCurrentTime extends Rule[LogicalPlan] {
}
}
}
+
+/**
+ * Typed [[Filter]] is by default surrounded by a [[DeserializeToObject]] beneath it and a
+ * [[SerializeFromObject]] above it. If these serializations can't be eliminated, we should embed
+ * the deserializer in the filter condition to save the extra serialization at the end.
+ */
+object EmbedSerializerInFilter extends Rule[LogicalPlan] {
+ def apply(plan: LogicalPlan): LogicalPlan = plan transform {
+ case s @ SerializeFromObject(_, Filter(condition, d: DeserializeToObject)) =>
+ val numObjects = condition.collect {
+ case a: Attribute if a == d.output.head => a
+ }.length
+
+ if (numObjects > 1) {
+ // If the filter condition references the object more than once, we should not embed the
+ // deserializer in it, as the deserialization will happen many times and slow down the
+ // execution.
+ // TODO: we can still embed it if we can make sure subexpression elimination works here.
+ s
+ } else {
+ val newCondition = condition transform {
+ case a: Attribute if a == d.output.head => d.deserializer.child
+ }
+ Filter(newCondition, d.child)
+ }
+ }
+}
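To make the shape of this rewrite concrete, a hypothetical before/after sketch of the plan for a typed filter (plan fragments written as comments; they are not constructible as-is):

    // Before: every row is deserialized into an object, filtered, then serialized back.
    //   SerializeFromObject(serializer,
    //     Filter(condition(obj),
    //       DeserializeToObject(deserializer, child)))
    //
    // After (when `obj` is referenced at most once in the condition): the deserializer is
    // substituted into the condition and both object conversions disappear.
    //   Filter(condition(deserializer), child)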
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/ASTNode.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/ASTNode.scala
deleted file mode 100644
index 28f7b10ed6..0000000000
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/ASTNode.scala
+++ /dev/null
@@ -1,99 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one or more
- * contributor license agreements. See the NOTICE file distributed with
- * this work for additional information regarding copyright ownership.
- * The ASF licenses this file to You under the Apache License, Version 2.0
- * (the "License"); you may not use this file except in compliance with
- * the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-package org.apache.spark.sql.catalyst.parser
-
-import org.antlr.runtime.{Token, TokenRewriteStream}
-
-import org.apache.spark.sql.catalyst.trees.{Origin, TreeNode}
-
-case class ASTNode(
- token: Token,
- startIndex: Int,
- stopIndex: Int,
- children: List[ASTNode],
- stream: TokenRewriteStream) extends TreeNode[ASTNode] {
- /** Cache the number of children. */
- val numChildren: Int = children.size
-
- /** tuple used in pattern matching. */
- val pattern: Some[(String, List[ASTNode])] = Some((token.getText, children))
-
- /** Line in which the ASTNode starts. */
- lazy val line: Int = {
- val line = token.getLine
- if (line == 0) {
- if (children.nonEmpty) children.head.line
- else 0
- } else {
- line
- }
- }
-
- /** Position of the Character at which ASTNode starts. */
- lazy val positionInLine: Int = {
- val line = token.getCharPositionInLine
- if (line == -1) {
- if (children.nonEmpty) children.head.positionInLine
- else 0
- } else {
- line
- }
- }
-
- /** Origin of the ASTNode. */
- override val origin: Origin = Origin(Some(line), Some(positionInLine))
-
- /** Source text. */
- lazy val source: String = stream.toOriginalString(startIndex, stopIndex)
-
- /** Get the source text that remains after this token. */
- lazy val remainder: String = {
- stream.fill()
- stream.toOriginalString(stopIndex + 1, stream.size() - 1).trim()
- }
-
- def text: String = token.getText
-
- def tokenType: Int = token.getType
-
- /**
- * Checks if this node is equal to another node.
- *
- * Right now this function only checks the name, type, text and children of the node
- * for equality.
- */
- def treeEquals(other: ASTNode): Boolean = {
- def check(f: ASTNode => Any): Boolean = {
- val l = f(this)
- val r = f(other)
- (l == null && r == null) || l.equals(r)
- }
- if (other == null) {
- false
- } else if (!check(_.token.getType)
- || !check(_.token.getText)
- || !check(_.numChildren)) {
- false
- } else {
- children.zip(other.children).forall {
- case (l, r) => l treeEquals r
- }
- }
- }
-
- override def simpleString: String = s"$text $line, $startIndex, $stopIndex, $positionInLine "
-}
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AbstractSparkSQLParser.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AbstractSparkSQLParser.scala
deleted file mode 100644
index 7b456a6de3..0000000000
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AbstractSparkSQLParser.scala
+++ /dev/null
@@ -1,145 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one or more
- * contributor license agreements. See the NOTICE file distributed with
- * this work for additional information regarding copyright ownership.
- * The ASF licenses this file to You under the Apache License, Version 2.0
- * (the "License"); you may not use this file except in compliance with
- * the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-package org.apache.spark.sql.catalyst.parser
-
-import scala.language.implicitConversions
-import scala.util.parsing.combinator.lexical.StdLexical
-import scala.util.parsing.combinator.syntactical.StandardTokenParsers
-import scala.util.parsing.combinator.PackratParsers
-import scala.util.parsing.input.CharArrayReader.EofCh
-
-import org.apache.spark.sql.catalyst.plans.logical._
-
-private[sql] abstract class AbstractSparkSQLParser
- extends StandardTokenParsers with PackratParsers with ParserInterface {
-
- def parsePlan(input: String): LogicalPlan = synchronized {
- // Initialize the Keywords.
- initLexical
- phrase(start)(new lexical.Scanner(input)) match {
- case Success(plan, _) => plan
- case failureOrError => sys.error(failureOrError.toString)
- }
- }
- /* One time initialization of lexical.This avoid reinitialization of lexical in parse method */
- protected lazy val initLexical: Unit = lexical.initialize(reservedWords)
-
- protected case class Keyword(str: String) {
- def normalize: String = lexical.normalizeKeyword(str)
- def parser: Parser[String] = normalize
- }
-
- protected implicit def asParser(k: Keyword): Parser[String] = k.parser
-
- // By default, use Reflection to find the reserved words defined in the sub class.
- // NOTICE, Since the Keyword properties defined by sub class, we couldn't call this
- // method during the parent class instantiation, because the sub class instance
- // isn't created yet.
- protected lazy val reservedWords: Seq[String] =
- this
- .getClass
- .getMethods
- .filter(_.getReturnType == classOf[Keyword])
- .map(_.invoke(this).asInstanceOf[Keyword].normalize)
-
- // Set the keywords as empty by default, will change that later.
- override val lexical = new SqlLexical
-
- protected def start: Parser[LogicalPlan]
-
- // Returns the whole input string
- protected lazy val wholeInput: Parser[String] = new Parser[String] {
- def apply(in: Input): ParseResult[String] =
- Success(in.source.toString, in.drop(in.source.length()))
- }
-
- // Returns the rest of the input string that are not parsed yet
- protected lazy val restInput: Parser[String] = new Parser[String] {
- def apply(in: Input): ParseResult[String] =
- Success(
- in.source.subSequence(in.offset, in.source.length()).toString,
- in.drop(in.source.length()))
- }
-}
-
-class SqlLexical extends StdLexical {
- case class DecimalLit(chars: String) extends Token {
- override def toString: String = chars
- }
-
- /* This is a work around to support the lazy setting */
- def initialize(keywords: Seq[String]): Unit = {
- reserved.clear()
- reserved ++= keywords
- }
-
- /* Normal the keyword string */
- def normalizeKeyword(str: String): String = str.toLowerCase
-
- delimiters += (
- "@", "*", "+", "-", "<", "=", "<>", "!=", "<=", ">=", ">", "/", "(", ")",
- ",", ";", "%", "{", "}", ":", "[", "]", ".", "&", "|", "^", "~", "<=>"
- )
-
- protected override def processIdent(name: String) = {
- val token = normalizeKeyword(name)
- if (reserved contains token) Keyword(token) else Identifier(name)
- }
-
- override lazy val token: Parser[Token] =
- ( rep1(digit) ~ scientificNotation ^^ { case i ~ s => DecimalLit(i.mkString + s) }
- | '.' ~> (rep1(digit) ~ scientificNotation) ^^
- { case i ~ s => DecimalLit("0." + i.mkString + s) }
- | rep1(digit) ~ ('.' ~> digit.*) ~ scientificNotation ^^
- { case i1 ~ i2 ~ s => DecimalLit(i1.mkString + "." + i2.mkString + s) }
- | digit.* ~ identChar ~ (identChar | digit).* ^^
- { case first ~ middle ~ rest => processIdent((first ++ (middle :: rest)).mkString) }
- | rep1(digit) ~ ('.' ~> digit.*).? ^^ {
- case i ~ None => NumericLit(i.mkString)
- case i ~ Some(d) => DecimalLit(i.mkString + "." + d.mkString)
- }
- | '\'' ~> chrExcept('\'', '\n', EofCh).* <~ '\'' ^^
- { case chars => StringLit(chars mkString "") }
- | '"' ~> chrExcept('"', '\n', EofCh).* <~ '"' ^^
- { case chars => StringLit(chars mkString "") }
- | '`' ~> chrExcept('`', '\n', EofCh).* <~ '`' ^^
- { case chars => Identifier(chars mkString "") }
- | EofCh ^^^ EOF
- | '\'' ~> failure("unclosed string literal")
- | '"' ~> failure("unclosed string literal")
- | delim
- | failure("illegal character")
- )
-
- override def identChar: Parser[Elem] = letter | elem('_')
-
- private lazy val scientificNotation: Parser[String] =
- (elem('e') | elem('E')) ~> (elem('+') | elem('-')).? ~ rep1(digit) ^^ {
- case s ~ rest => "e" + s.mkString + rest.mkString
- }
-
- override def whitespace: Parser[Any] =
- ( whitespaceChar
- | '/' ~ '*' ~ comment
- | '/' ~ '/' ~ chrExcept(EofCh, '\n').*
- | '#' ~ chrExcept(EofCh, '\n').*
- | '-' ~ '-' ~ chrExcept(EofCh, '\n').*
- | '/' ~ '*' ~ failure("unclosed comment")
- ).*
-}
-
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala
new file mode 100644
index 0000000000..aa59f3fb2a
--- /dev/null
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala
@@ -0,0 +1,1455 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.spark.sql.catalyst.parser
+
+import java.sql.{Date, Timestamp}
+
+import scala.collection.JavaConverters._
+import scala.collection.mutable.ArrayBuffer
+
+import org.antlr.v4.runtime.{ParserRuleContext, Token}
+import org.antlr.v4.runtime.tree.{ParseTree, RuleNode, TerminalNode}
+
+import org.apache.spark.internal.Logging
+import org.apache.spark.sql.catalyst.{InternalRow, TableIdentifier}
+import org.apache.spark.sql.catalyst.analysis._
+import org.apache.spark.sql.catalyst.expressions._
+import org.apache.spark.sql.catalyst.parser.SqlBaseParser._
+import org.apache.spark.sql.catalyst.plans._
+import org.apache.spark.sql.catalyst.plans.logical._
+import org.apache.spark.sql.types._
+import org.apache.spark.unsafe.types.CalendarInterval
+import org.apache.spark.util.random.RandomSampler
+
+/**
+ * The AstBuilder converts an ANTLR4 ParseTree into a catalyst Expression, LogicalPlan or
+ * TableIdentifier.
+ */
+class AstBuilder extends SqlBaseBaseVisitor[AnyRef] with Logging {
+ import ParserUtils._
+
+ protected def typedVisit[T](ctx: ParseTree): T = {
+ ctx.accept(this).asInstanceOf[T]
+ }
+
+ /**
+ * Override the default behavior for all visit methods. This will only return a non-null result
+ * when the context has only one child. This is done because there is no generic method to
+ * combine the results of the context children. In all other cases null is returned.
+ */
+ override def visitChildren(node: RuleNode): AnyRef = {
+ if (node.getChildCount == 1) {
+ node.getChild(0).accept(this)
+ } else {
+ null
+ }
+ }
+
+ override def visitSingleStatement(ctx: SingleStatementContext): LogicalPlan = withOrigin(ctx) {
+ visit(ctx.statement).asInstanceOf[LogicalPlan]
+ }
+
+ override def visitSingleExpression(ctx: SingleExpressionContext): Expression = withOrigin(ctx) {
+ visitNamedExpression(ctx.namedExpression)
+ }
+
+ override def visitSingleTableIdentifier(
+ ctx: SingleTableIdentifierContext): TableIdentifier = withOrigin(ctx) {
+ visitTableIdentifier(ctx.tableIdentifier)
+ }
+
+ override def visitSingleDataType(ctx: SingleDataTypeContext): DataType = withOrigin(ctx) {
+ visit(ctx.dataType).asInstanceOf[DataType]
+ }
+
+ /* ********************************************************************************************
+ * Plan parsing
+ * ******************************************************************************************** */
+ protected def plan(tree: ParserRuleContext): LogicalPlan = typedVisit(tree)
+
+ /**
+ * Make sure we do not try to create a plan for a native command.
+ */
+ override def visitExecuteNativeCommand(ctx: ExecuteNativeCommandContext): LogicalPlan = null
+
+ /**
+ * Create a plan for a SHOW FUNCTIONS command.
+ */
+ override def visitShowFunctions(ctx: ShowFunctionsContext): LogicalPlan = withOrigin(ctx) {
+ import ctx._
+ if (qualifiedName != null) {
+ val names = qualifiedName().identifier().asScala.map(_.getText).toList
+ names match {
+ case db :: name :: Nil =>
+ ShowFunctions(Some(db), Some(name))
+ case name :: Nil =>
+ ShowFunctions(None, Some(name))
+ case _ =>
+ throw new ParseException("SHOW FUNCTIONS unsupported name", ctx)
+ }
+ } else if (pattern != null) {
+ ShowFunctions(None, Some(string(pattern)))
+ } else {
+ ShowFunctions(None, None)
+ }
+ }
+
+ /**
+ * Create a plan for a DESCRIBE FUNCTION command.
+ */
+ override def visitDescribeFunction(ctx: DescribeFunctionContext): LogicalPlan = withOrigin(ctx) {
+ val functionName = ctx.qualifiedName().identifier().asScala.map(_.getText).mkString(".")
+ DescribeFunction(functionName, ctx.EXTENDED != null)
+ }
+
+ /**
+ * Create a top-level plan with Common Table Expressions.
+ */
+ override def visitQuery(ctx: QueryContext): LogicalPlan = withOrigin(ctx) {
+ val query = plan(ctx.queryNoWith)
+
+ // Apply CTEs
+ query.optional(ctx.ctes) {
+ val ctes = ctx.ctes.namedQuery.asScala.map {
+ case nCtx =>
+ val namedQuery = visitNamedQuery(nCtx)
+ (namedQuery.alias, namedQuery)
+ }
+
+ // Check for duplicate names.
+ ctes.groupBy(_._1).filter(_._2.size > 1).foreach {
+ case (name, _) =>
+ throw new ParseException(
+ s"Name '$name' is used for multiple common table expressions", ctx)
+ }
+
+ With(query, ctes.toMap)
+ }
+ }
+
+ /**
+ * Create a named logical plan.
+ *
+ * This is only used for Common Table Expressions.
+ */
+ override def visitNamedQuery(ctx: NamedQueryContext): SubqueryAlias = withOrigin(ctx) {
+ SubqueryAlias(ctx.name.getText, plan(ctx.queryNoWith))
+ }
+
+ /**
+ * Create a logical plan which allows for multiple inserts using one 'from' statement. These
+ * queries have the following SQL form:
+ * {{{
+ * [WITH cte...]?
+ * FROM src
+ * [INSERT INTO tbl1 SELECT *]+
+ * }}}
+ * For example:
+ * {{{
+ * FROM db.tbl1 A
+ * INSERT INTO dbo.tbl1 SELECT * WHERE A.value = 10 LIMIT 5
+ * INSERT INTO dbo.tbl2 SELECT * WHERE A.value = 12
+ * }}}
+ * This (Hive) feature cannot be combined with set operators.
+ */
+ override def visitMultiInsertQuery(ctx: MultiInsertQueryContext): LogicalPlan = withOrigin(ctx) {
+ val from = visitFromClause(ctx.fromClause)
+
+ // Build the insert clauses.
+ val inserts = ctx.multiInsertQueryBody.asScala.map {
+ body =>
+ assert(body.querySpecification.fromClause == null,
+ "Multi-Insert queries cannot have a FROM clause in their individual SELECT statements",
+ body)
+
+ withQuerySpecification(body.querySpecification, from).
+ // Add organization statements.
+ optionalMap(body.queryOrganization)(withQueryResultClauses).
+ // Add insert.
+ optionalMap(body.insertInto())(withInsertInto)
+ }
+
+ // If there are multiple INSERTS just UNION them together into one query.
+ inserts match {
+ case Seq(query) => query
+ case queries => Union(queries)
+ }
+ }
+
+ /**
+ * Create a logical plan for a regular (single-insert) query.
+ */
+ override def visitSingleInsertQuery(
+ ctx: SingleInsertQueryContext): LogicalPlan = withOrigin(ctx) {
+ plan(ctx.queryTerm).
+ // Add organization statements.
+ optionalMap(ctx.queryOrganization)(withQueryResultClauses).
+ // Add insert.
+ optionalMap(ctx.insertInto())(withInsertInto)
+ }
+
+ /**
+ * Add an INSERT INTO [TABLE]/INSERT OVERWRITE TABLE operation to the logical plan.
+ */
+ private def withInsertInto(
+ ctx: InsertIntoContext,
+ query: LogicalPlan): LogicalPlan = withOrigin(ctx) {
+ val tableIdent = visitTableIdentifier(ctx.tableIdentifier)
+ val partitionKeys = Option(ctx.partitionSpec).map(visitPartitionSpec).getOrElse(Map.empty)
+
+ InsertIntoTable(
+ UnresolvedRelation(tableIdent, None),
+ partitionKeys,
+ query,
+ ctx.OVERWRITE != null,
+ ctx.EXISTS != null)
+ }
+
+ /**
+ * Create a partition specification map.
+ */
+ override def visitPartitionSpec(
+ ctx: PartitionSpecContext): Map[String, Option[String]] = withOrigin(ctx) {
+ ctx.partitionVal.asScala.map { pVal =>
+ val name = pVal.identifier.getText.toLowerCase
+ val value = Option(pVal.constant).map(visitStringConstant)
+ name -> value
+ }.toMap
+ }
+
+ /**
+ * Create a partition specification map without optional values.
+ */
+ protected def visitNonOptionalPartitionSpec(
+ ctx: PartitionSpecContext): Map[String, String] = withOrigin(ctx) {
+ visitPartitionSpec(ctx).mapValues(_.orNull).map(identity)
+ }
+
+ /**
+ * Convert a constant of any type into a string. This is typically used in DDL commands, and its
+ * main purpose is to prevent slight differences due to back-to-back conversions, i.e.:
+ * String -> Literal -> String.
+ */
+ protected def visitStringConstant(ctx: ConstantContext): String = withOrigin(ctx) {
+ ctx match {
+ case s: StringLiteralContext => createString(s)
+ case o => o.getText
+ }
+ }
+
+ /**
+ * Add ORDER BY/SORT BY/CLUSTER BY/DISTRIBUTE BY/LIMIT/WINDOWS clauses to the logical plan. These
+ * clauses determine the shape (ordering/partitioning/rows) of the query result.
+ */
+ private def withQueryResultClauses(
+ ctx: QueryOrganizationContext,
+ query: LogicalPlan): LogicalPlan = withOrigin(ctx) {
+ import ctx._
+
+ // Handle ORDER BY, SORT BY, DISTRIBUTE BY, and CLUSTER BY clause.
+ val withOrder = if (
+ !order.isEmpty && sort.isEmpty && distributeBy.isEmpty && clusterBy.isEmpty) {
+ // ORDER BY ...
+ Sort(order.asScala.map(visitSortItem), global = true, query)
+ } else if (order.isEmpty && !sort.isEmpty && distributeBy.isEmpty && clusterBy.isEmpty) {
+ // SORT BY ...
+ Sort(sort.asScala.map(visitSortItem), global = false, query)
+ } else if (order.isEmpty && sort.isEmpty && !distributeBy.isEmpty && clusterBy.isEmpty) {
+ // DISTRIBUTE BY ...
+ RepartitionByExpression(expressionList(distributeBy), query)
+ } else if (order.isEmpty && !sort.isEmpty && !distributeBy.isEmpty && clusterBy.isEmpty) {
+ // SORT BY ... DISTRIBUTE BY ...
+ Sort(
+ sort.asScala.map(visitSortItem),
+ global = false,
+ RepartitionByExpression(expressionList(distributeBy), query))
+ } else if (order.isEmpty && sort.isEmpty && distributeBy.isEmpty && !clusterBy.isEmpty) {
+ // CLUSTER BY ...
+ val expressions = expressionList(clusterBy)
+ Sort(
+ expressions.map(SortOrder(_, Ascending)),
+ global = false,
+ RepartitionByExpression(expressions, query))
+ } else if (order.isEmpty && sort.isEmpty && distributeBy.isEmpty && clusterBy.isEmpty) {
+ // [EMPTY]
+ query
+ } else {
+ throw new ParseException(
+ "Combination of ORDER BY/SORT BY/DISTRIBUTE BY/CLUSTER BY is not supported", ctx)
+ }
+
+ // WINDOWS
+ val withWindow = withOrder.optionalMap(windows)(withWindows)
+
+ // LIMIT
+ withWindow.optional(limit) {
+ Limit(typedVisit(limit), withWindow)
+ }
+ }
+
+ /**
+ * Create a logical plan using a query specification.
+ */
+ override def visitQuerySpecification(
+ ctx: QuerySpecificationContext): LogicalPlan = withOrigin(ctx) {
+ val from = OneRowRelation.optional(ctx.fromClause) {
+ visitFromClause(ctx.fromClause)
+ }
+ withQuerySpecification(ctx, from)
+ }
+
+ /**
+ * Add a query specification to a logical plan. The query specification is the core of the logical
+ * plan; this is where sourcing (FROM clause), transforming (SELECT TRANSFORM/MAP/REDUCE),
+ * projection (SELECT), aggregation (GROUP BY ... HAVING ...) and filtering (WHERE) take place.
+ *
+ * Note that query hints are ignored (both by the parser and the builder).
+ */
+ private def withQuerySpecification(
+ ctx: QuerySpecificationContext,
+ relation: LogicalPlan): LogicalPlan = withOrigin(ctx) {
+ import ctx._
+
+ // WHERE
+ def filter(ctx: BooleanExpressionContext, plan: LogicalPlan): LogicalPlan = {
+ Filter(expression(ctx), plan)
+ }
+
+ // Expressions.
+ val expressions = Option(namedExpressionSeq).toSeq
+ .flatMap(_.namedExpression.asScala)
+ .map(typedVisit[Expression])
+
+ // Create either a transform or a regular query.
+ val specType = Option(kind).map(_.getType).getOrElse(SqlBaseParser.SELECT)
+ specType match {
+ case SqlBaseParser.MAP | SqlBaseParser.REDUCE | SqlBaseParser.TRANSFORM =>
+ // Transform
+
+ // Add where.
+ val withFilter = relation.optionalMap(where)(filter)
+
+ // Create the attributes.
+ val (attributes, schemaLess) = if (colTypeList != null) {
+ // Typed return columns.
+ (createStructType(colTypeList).toAttributes, false)
+ } else if (identifierSeq != null) {
+ // Untyped return columns.
+ val attrs = visitIdentifierSeq(identifierSeq).map { name =>
+ AttributeReference(name, StringType, nullable = true)()
+ }
+ (attrs, false)
+ } else {
+ (Seq(AttributeReference("key", StringType)(),
+ AttributeReference("value", StringType)()), true)
+ }
+
+ // Create the transform.
+ ScriptTransformation(
+ expressions,
+ string(script),
+ attributes,
+ withFilter,
+ withScriptIOSchema(
+ ctx, inRowFormat, recordWriter, outRowFormat, recordReader, schemaLess))
+
+ case SqlBaseParser.SELECT =>
+ // Regular select
+
+ // Add lateral views.
+ val withLateralView = ctx.lateralView.asScala.foldLeft(relation)(withGenerate)
+
+ // Add where.
+ val withFilter = withLateralView.optionalMap(where)(filter)
+
+ // Add aggregation or a project.
+ val namedExpressions = expressions.map {
+ case e: NamedExpression => e
+ case e: Expression => UnresolvedAlias(e)
+ }
+ val withProject = if (aggregation != null) {
+ withAggregation(aggregation, namedExpressions, withFilter)
+ } else if (namedExpressions.nonEmpty) {
+ Project(namedExpressions, withFilter)
+ } else {
+ withFilter
+ }
+
+ // Having
+ val withHaving = withProject.optional(having) {
+ // Note that we added a cast to boolean. If the expression itself is already boolean,
+ // the optimizer will get rid of the unnecessary cast.
+ Filter(Cast(expression(having), BooleanType), withProject)
+ }
+
+ // Distinct
+ val withDistinct = if (setQuantifier() != null && setQuantifier().DISTINCT() != null) {
+ Distinct(withHaving)
+ } else {
+ withHaving
+ }
+
+ // Window
+ withDistinct.optionalMap(windows)(withWindows)
+ }
+ }
+
+ /**
+ * Create a (Hive based) [[ScriptInputOutputSchema]].
+ */
+ protected def withScriptIOSchema(
+ ctx: QuerySpecificationContext,
+ inRowFormat: RowFormatContext,
+ recordWriter: Token,
+ outRowFormat: RowFormatContext,
+ recordReader: Token,
+ schemaLess: Boolean): ScriptInputOutputSchema = {
+ throw new ParseException("Script Transform is not supported", ctx)
+ }
+
+ /**
+ * Create a logical plan for a given 'FROM' clause. Note that we support multiple
+ * (comma-separated) relations here; these are converted into a single plan using
+ * condition-less inner joins.
+ */
+ override def visitFromClause(ctx: FromClauseContext): LogicalPlan = withOrigin(ctx) {
+ val from = ctx.relation.asScala.map(plan).reduceLeft(Join(_, _, Inner, None))
+ ctx.lateralView.asScala.foldLeft(from)(withGenerate)
+ }
+
+ /**
+ * Connect two queries by a Set operator.
+ *
+ * Supported Set operators are:
+ * - UNION [DISTINCT]
+ * - UNION ALL
+ * - EXCEPT [DISTINCT]
+ * - INTERSECT [DISTINCT]
+ */
+ override def visitSetOperation(ctx: SetOperationContext): LogicalPlan = withOrigin(ctx) {
+ val left = plan(ctx.left)
+ val right = plan(ctx.right)
+ val all = Option(ctx.setQuantifier()).exists(_.ALL != null)
+ ctx.operator.getType match {
+ case SqlBaseParser.UNION if all =>
+ Union(left, right)
+ case SqlBaseParser.UNION =>
+ Distinct(Union(left, right))
+ case SqlBaseParser.INTERSECT if all =>
+ throw new ParseException("INTERSECT ALL is not supported.", ctx)
+ case SqlBaseParser.INTERSECT =>
+ Intersect(left, right)
+ case SqlBaseParser.EXCEPT if all =>
+ throw new ParseException("EXCEPT ALL is not supported.", ctx)
+ case SqlBaseParser.EXCEPT =>
+ Except(left, right)
+ }
+ }
+
+ /**
+ * Add a [[WithWindowDefinition]] operator to a logical plan.
+ */
+ private def withWindows(
+ ctx: WindowsContext,
+ query: LogicalPlan): LogicalPlan = withOrigin(ctx) {
+ // Collect all window specifications defined in the WINDOW clause.
+ val baseWindowMap = ctx.namedWindow.asScala.map {
+ wCtx =>
+ (wCtx.identifier.getText, typedVisit[WindowSpec](wCtx.windowSpec))
+ }.toMap
+
+ // Handle cases like
+ // window w1 as (partition by p_mfgr order by p_name
+ // range between 2 preceding and 2 following),
+ // w2 as w1
+ val windowMapView = baseWindowMap.mapValues {
+ case WindowSpecReference(name) =>
+ baseWindowMap.get(name) match {
+ case Some(spec: WindowSpecDefinition) =>
+ spec
+ case Some(ref) =>
+ throw new ParseException(s"Window reference '$name' is not a window specification", ctx)
+ case None =>
+ throw new ParseException(s"Cannot resolve window reference '$name'", ctx)
+ }
+ case spec: WindowSpecDefinition => spec
+ }
+
+ // Note that mapValues creates a view instead of a materialized map. We force materialization by
+ // mapping over identity.
+ WithWindowDefinition(windowMapView.map(identity), query)
+ }
+
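
The reference resolution performed here ("w2 AS w1" picking up w1's definition) can be illustrated with a small self-contained sketch. All names below (Definition, Reference, resolve) are invented for the example and are not part of Catalyst.

    // Standalone sketch (hypothetical names, not Spark code) of the reference
    // resolution performed by withWindows: a named window is either a concrete
    // specification or a reference to another named window, and references are
    // replaced by the definition they point to.
    object WindowRefSketch {
      sealed trait Spec
      final case class Definition(text: String) extends Spec
      final case class Reference(name: String) extends Spec

      def resolve(windows: Map[String, Spec]): Map[String, Definition] = windows.map {
        case (name, d: Definition) => name -> d
        case (name, Reference(other)) => windows.get(other) match {
          case Some(d: Definition) => name -> d
          case Some(_: Reference)  => sys.error(s"Window reference '$other' is not a window specification")
          case None                => sys.error(s"Cannot resolve window reference '$other'")
        }
      }

      def main(args: Array[String]): Unit = {
        val in: Map[String, Spec] = Map(
          "w1" -> Definition("partition by p_mfgr order by p_name range between 2 preceding and 2 following"),
          "w2" -> Reference("w1"))
        resolve(in).foreach(println)  // both w1 and w2 resolve to the same definition
      }
    }
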
+ /**
+ * Add an [[Aggregate]] to a logical plan.
+ */
+ private def withAggregation(
+ ctx: AggregationContext,
+ selectExpressions: Seq[NamedExpression],
+ query: LogicalPlan): LogicalPlan = withOrigin(ctx) {
+ import ctx._
+ val groupByExpressions = expressionList(groupingExpressions)
+
+ if (GROUPING != null) {
+ // GROUP BY .... GROUPING SETS (...)
+ val expressionMap = groupByExpressions.zipWithIndex.toMap
+ val numExpressions = expressionMap.size
+ val mask = (1 << numExpressions) - 1
+ val masks = ctx.groupingSet.asScala.map {
+ _.expression.asScala.foldLeft(mask) {
+ case (bitmap, eCtx) =>
+ // Find the index of the expression.
+ val e = typedVisit[Expression](eCtx)
+ val index = expressionMap.find(_._1.semanticEquals(e)).map(_._2).getOrElse(
+ throw new ParseException(
+ s"$e doesn't show up in the GROUP BY list", ctx))
+ // 0 means that the column at the given index is a grouping column, 1 means it is not,
+ // so we unset the bit in bitmap.
+ bitmap & ~(1 << (numExpressions - 1 - index))
+ }
+ }
+ GroupingSets(masks, groupByExpressions, query, selectExpressions)
+ } else {
+ // GROUP BY .... (WITH CUBE | WITH ROLLUP)?
+ val mappedGroupByExpressions = if (CUBE != null) {
+ Seq(Cube(groupByExpressions))
+ } else if (ROLLUP != null) {
+ Seq(Rollup(groupByExpressions))
+ } else {
+ groupByExpressions
+ }
+ Aggregate(mappedGroupByExpressions, selectExpressions, query)
+ }
+ }
+
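
The GROUPING SETS bitmask arithmetic used above is easiest to follow with a worked example. The toy snippet below (not Catalyst code) reproduces it for GROUP BY k1, k2, k3 GROUPING SETS ((k1, k2), (k2)): with bit 2 = k1, bit 1 = k2 and bit 0 = k3, the resulting grouping ids are 1 and 5.

    // Toy reproduction of the bitmask arithmetic used for GROUPING SETS above.
    // A set bit means the column is NOT part of the grouping set, so we start
    // from the all-ones mask and clear one bit per referenced column.
    object GroupingSetMaskSketch {
      def masks(groupBy: Seq[String], groupingSets: Seq[Seq[String]]): Seq[Int] = {
        val index = groupBy.zipWithIndex.toMap
        val all = (1 << groupBy.size) - 1
        groupingSets.map(_.foldLeft(all) { (bitmap, col) =>
          bitmap & ~(1 << (groupBy.size - 1 - index(col)))
        })
      }

      def main(args: Array[String]): Unit = {
        // GROUP BY k1, k2, k3 GROUPING SETS ((k1, k2), (k2)) => grouping ids 1 and 5
        println(masks(Seq("k1", "k2", "k3"), Seq(Seq("k1", "k2"), Seq("k2"))))  // List(1, 5)
      }
    }
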
+ /**
+ * Add a [[Generate]] (Lateral View) to a logical plan.
+ */
+ private def withGenerate(
+ query: LogicalPlan,
+ ctx: LateralViewContext): LogicalPlan = withOrigin(ctx) {
+ val expressions = expressionList(ctx.expression)
+
+ // Create the generator.
+ val generator = ctx.qualifiedName.getText.toLowerCase match {
+ case "explode" if expressions.size == 1 =>
+ Explode(expressions.head)
+ case "json_tuple" =>
+ JsonTuple(expressions)
+ case name =>
+ UnresolvedGenerator(name, expressions)
+ }
+
+ Generate(
+ generator,
+ join = true,
+ outer = ctx.OUTER != null,
+ Some(ctx.tblName.getText.toLowerCase),
+ ctx.colName.asScala.map(_.getText).map(UnresolvedAttribute.apply),
+ query)
+ }
+
+ /**
+ * Create a join between two or more logical plans.
+ */
+ override def visitJoinRelation(ctx: JoinRelationContext): LogicalPlan = withOrigin(ctx) {
+ /** Build a join between two plans. */
+ def join(ctx: JoinRelationContext, left: LogicalPlan, right: LogicalPlan): Join = {
+ val baseJoinType = ctx.joinType match {
+ case null => Inner
+ case jt if jt.FULL != null => FullOuter
+ case jt if jt.SEMI != null => LeftSemi
+ case jt if jt.ANTI != null => LeftAnti
+ case jt if jt.LEFT != null => LeftOuter
+ case jt if jt.RIGHT != null => RightOuter
+ case _ => Inner
+ }
+
+ // Resolve the join type and join condition
+ val (joinType, condition) = Option(ctx.joinCriteria) match {
+ case Some(c) if c.USING != null =>
+ val columns = c.identifier.asScala.map { column =>
+ UnresolvedAttribute.quoted(column.getText)
+ }
+ (UsingJoin(baseJoinType, columns), None)
+ case Some(c) if c.booleanExpression != null =>
+ (baseJoinType, Option(expression(c.booleanExpression)))
+ case None if ctx.NATURAL != null =>
+ (NaturalJoin(baseJoinType), None)
+ case None =>
+ (baseJoinType, None)
+ }
+ Join(left, right, joinType, condition)
+ }
+
+ // Handle all consecutive join clauses. ANTLR produces a right-nested tree in which the
+ // first join clause is at the top. However, fields of previously referenced tables can be
+ // used in following join clauses. The tree needs to be reversed in order to make this work.
+ var result = plan(ctx.left)
+ var current = ctx
+ while (current != null) {
+ current.right match {
+ case right: JoinRelationContext =>
+ result = join(current, result, plan(right.left))
+ current = right
+ case right =>
+ result = join(current, result, plan(right))
+ current = null
+ }
+ }
+ result
+ }
+
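
The while loop above flattens ANTLR's right-nested join contexts into a left-deep plan so that columns of earlier tables are visible to later join conditions. A standalone sketch of the same reversal, using toy stand-ins for the parser contexts:

    // Toy model of the join-tree reversal above: ANTLR nests `a JOIN b JOIN c`
    // to the right as (a, (b, c)), but the plan we want is left-deep
    // ((a JOIN b) JOIN c).
    object JoinReversalSketch {
      // Simplified stand-ins for the parser contexts: a table name or a nested join.
      sealed trait Rel
      final case class Table(name: String) extends Rel
      final case class JoinCtx(left: Rel, right: Rel) extends Rel

      def flatten(ctx: JoinCtx): String = {
        var result = render(ctx.left)
        var current: JoinCtx = ctx
        while (current != null) {
          current.right match {
            case r: JoinCtx =>
              result = s"($result JOIN ${render(r.left)})"
              current = r
            case r =>
              result = s"($result JOIN ${render(r)})"
              current = null
          }
        }
        result
      }

      private def render(r: Rel): String = r match {
        case Table(n)   => n
        case j: JoinCtx => flatten(j)
      }

      def main(args: Array[String]): Unit = {
        // a JOIN b JOIN c, right-nested as the parser would produce it.
        println(flatten(JoinCtx(Table("a"), JoinCtx(Table("b"), Table("c")))))  // ((a JOIN b) JOIN c)
      }
    }
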
+ /**
+ * Add a [[Sample]] to a logical plan.
+ *
+ * This currently supports the following sampling methods:
+ * - TABLESAMPLE(x ROWS): Sample the table down to the given number of rows.
+ * - TABLESAMPLE(x PERCENT): Sample the table down to the given percentage. Note that percentages
+ * are defined as a number between 0 and 100.
+ * - TABLESAMPLE(BUCKET x OUT OF y): Sample the table down to a fraction of 'x' divided by 'y'.
+ */
+ private def withSample(ctx: SampleContext, query: LogicalPlan): LogicalPlan = withOrigin(ctx) {
+ // Create a sampled plan if we need one.
+ def sample(fraction: Double): Sample = {
+ // The range of fraction accepted by Sample is [0, 1]. Because Hive's block sampling
+ // function takes X PERCENT as the input and the range of X is [0, 100], we need to
+ // adjust the fraction.
+ val eps = RandomSampler.roundingEpsilon
+ assert(fraction >= 0.0 - eps && fraction <= 1.0 + eps,
+ s"Sampling fraction ($fraction) must be on interval [0, 1]",
+ ctx)
+ Sample(0.0, fraction, withReplacement = false, (math.random * 1000).toInt, query)(true)
+ }
+
+ ctx.sampleType.getType match {
+ case SqlBaseParser.ROWS =>
+ Limit(expression(ctx.expression), query)
+
+ case SqlBaseParser.PERCENTLIT =>
+ val fraction = ctx.percentage.getText.toDouble
+ sample(fraction / 100.0d)
+
+ case SqlBaseParser.BUCKET if ctx.ON != null =>
+ throw new ParseException("TABLESAMPLE(BUCKET x OUT OF y ON id) is not supported", ctx)
+
+ case SqlBaseParser.BUCKET =>
+ sample(ctx.numerator.getText.toDouble / ctx.denominator.getText.toDouble)
+ }
+ }
+
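
Both the PERCENT and the BUCKET forms above reduce to a [0, 1] fraction before being handed to Sample. A quick arithmetic check (plain Scala, hypothetical helper names):

    // Quick check of the fraction handling in withSample: PERCENT values live on
    // [0, 100] and are divided by 100, while BUCKET x OUT OF y is simply x / y.
    object SampleFractionSketch {
      def percentToFraction(percent: Double): Double = percent / 100.0d
      def bucketToFraction(numerator: Double, denominator: Double): Double = numerator / denominator

      def main(args: Array[String]): Unit = {
        println(percentToFraction(30))    // TABLESAMPLE(30 PERCENT)         -> 0.3
        println(bucketToFraction(3, 10))  // TABLESAMPLE(BUCKET 3 OUT OF 10) -> 0.3
      }
    }
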
+ /**
+ * Create a logical plan for a sub-query.
+ */
+ override def visitSubquery(ctx: SubqueryContext): LogicalPlan = withOrigin(ctx) {
+ plan(ctx.queryNoWith)
+ }
+
+ /**
+ * Create an un-aliased table reference. This is typically used for top-level table references,
+ * for example:
+ * {{{
+ * INSERT INTO db.tbl2
+ * TABLE db.tbl1
+ * }}}
+ */
+ override def visitTable(ctx: TableContext): LogicalPlan = withOrigin(ctx) {
+ UnresolvedRelation(visitTableIdentifier(ctx.tableIdentifier), None)
+ }
+
+ /**
+ * Create an aliased table reference. This is typically used in FROM clauses.
+ */
+ override def visitTableName(ctx: TableNameContext): LogicalPlan = withOrigin(ctx) {
+ val table = UnresolvedRelation(
+ visitTableIdentifier(ctx.tableIdentifier),
+ Option(ctx.identifier).map(_.getText))
+ table.optionalMap(ctx.sample)(withSample)
+ }
+
+ /**
+ * Create an inline table (a virtual table in Hive parlance).
+ */
+ override def visitInlineTable(ctx: InlineTableContext): LogicalPlan = withOrigin(ctx) {
+ // Get the backing expressions.
+ val expressions = ctx.expression.asScala.map { eCtx =>
+ val e = expression(eCtx)
+ assert(e.foldable, "All expressions in an inline table must be constants.", eCtx)
+ e
+ }
+
+ // Validate and evaluate the rows.
+ val (structType, structConstructor) = expressions.head.dataType match {
+ case st: StructType =>
+ (st, (e: Expression) => e)
+ case dt =>
+ val st = CreateStruct(Seq(expressions.head)).dataType
+ (st, (e: Expression) => CreateStruct(Seq(e)))
+ }
+ val rows = expressions.map {
+ case expression =>
+ val safe = Cast(structConstructor(expression), structType)
+ safe.eval().asInstanceOf[InternalRow]
+ }
+
+ // Construct attributes.
+ val baseAttributes = structType.toAttributes.map(_.withNullability(true))
+ val attributes = if (ctx.identifierList != null) {
+ val aliases = visitIdentifierList(ctx.identifierList)
+ assert(aliases.size == baseAttributes.size,
+ "Number of aliases must match the number of fields in an inline table.", ctx)
+ baseAttributes.zip(aliases).map(p => p._1.withName(p._2))
+ } else {
+ baseAttributes
+ }
+
+ // Create plan and add an alias if a name has been defined.
+ LocalRelation(attributes, rows).optionalMap(ctx.identifier)(aliasPlan)
+ }
+
+ /**
+ * Create an alias (SubqueryAlias) for a join relation. This is practically the same as
+ * visitAliasedQuery and visitNamedExpression; however, ANTLR4 requires us to use 3 different
+ * hooks.
+ */
+ override def visitAliasedRelation(ctx: AliasedRelationContext): LogicalPlan = withOrigin(ctx) {
+ plan(ctx.relation).optionalMap(ctx.sample)(withSample).optionalMap(ctx.identifier)(aliasPlan)
+ }
+
+ /**
+ * Create an alias (SubqueryAlias) for a sub-query. This is practically the same as
+ * visitAliasedRelation and visitNamedExpression; however, ANTLR4 requires us to use 3 different
+ * hooks.
+ */
+ override def visitAliasedQuery(ctx: AliasedQueryContext): LogicalPlan = withOrigin(ctx) {
+ plan(ctx.queryNoWith).optionalMap(ctx.sample)(withSample).optionalMap(ctx.identifier)(aliasPlan)
+ }
+
+ /**
+ * Create an alias (SubqueryAlias) for a LogicalPlan.
+ */
+ private def aliasPlan(alias: IdentifierContext, plan: LogicalPlan): LogicalPlan = {
+ SubqueryAlias(alias.getText, plan)
+ }
+
+ /**
+ * Create a Sequence of Strings for a parenthesis-enclosed alias list.
+ */
+ override def visitIdentifierList(ctx: IdentifierListContext): Seq[String] = withOrigin(ctx) {
+ visitIdentifierSeq(ctx.identifierSeq)
+ }
+
+ /**
+ * Create a Sequence of Strings for an identifier list.
+ */
+ override def visitIdentifierSeq(ctx: IdentifierSeqContext): Seq[String] = withOrigin(ctx) {
+ ctx.identifier.asScala.map(_.getText)
+ }
+
+ /* ********************************************************************************************
+ * Table Identifier parsing
+ * ******************************************************************************************** */
+ /**
+ * Create a [[TableIdentifier]] from a 'tableName' or 'databaseName'.'tableName' pattern.
+ */
+ override def visitTableIdentifier(
+ ctx: TableIdentifierContext): TableIdentifier = withOrigin(ctx) {
+ TableIdentifier(ctx.table.getText, Option(ctx.db).map(_.getText))
+ }
+
+ /* ********************************************************************************************
+ * Expression parsing
+ * ******************************************************************************************** */
+ /**
+ * Create an expression from the given context. This method just passes the context on to the
+ * visitor and only takes care of typing (we assume that the visitor returns an Expression here).
+ */
+ protected def expression(ctx: ParserRuleContext): Expression = typedVisit(ctx)
+
+ /**
+ * Create sequence of expressions from the given sequence of contexts.
+ */
+ private def expressionList(trees: java.util.List[ExpressionContext]): Seq[Expression] = {
+ trees.asScala.map(expression)
+ }
+
+ /**
+ * Create a star (i.e. all) expression; this selects all elements (in the specified object).
+ * Both un-targeted (global) and targeted aliases are supported.
+ */
+ override def visitStar(ctx: StarContext): Expression = withOrigin(ctx) {
+ UnresolvedStar(Option(ctx.qualifiedName()).map(_.identifier.asScala.map(_.getText)))
+ }
+
+ /**
+ * Create an aliased expression if an alias is specified. Both single and multi-aliases are
+ * supported.
+ */
+ override def visitNamedExpression(ctx: NamedExpressionContext): Expression = withOrigin(ctx) {
+ val e = expression(ctx.expression)
+ if (ctx.identifier != null) {
+ Alias(e, ctx.identifier.getText)()
+ } else if (ctx.identifierList != null) {
+ MultiAlias(e, visitIdentifierList(ctx.identifierList))
+ } else {
+ e
+ }
+ }
+
+ /**
+ * Combine a number of boolean expressions into a balanced expression tree. These expressions are
+ * either combined by a logical [[And]] or a logical [[Or]].
+ *
+ * A balanced binary tree is created because regular left-recursive trees cause considerable
+ * performance degradation and can cause stack overflows.
+ */
+ override def visitLogicalBinary(ctx: LogicalBinaryContext): Expression = withOrigin(ctx) {
+ val expressionType = ctx.operator.getType
+ val expressionCombiner = expressionType match {
+ case SqlBaseParser.AND => And.apply _
+ case SqlBaseParser.OR => Or.apply _
+ }
+
+ // Collect all similar left hand contexts.
+ val contexts = ArrayBuffer(ctx.right)
+ var current = ctx.left
+ def collectContexts: Boolean = current match {
+ case lbc: LogicalBinaryContext if lbc.operator.getType == expressionType =>
+ contexts += lbc.right
+ current = lbc.left
+ true
+ case _ =>
+ contexts += current
+ false
+ }
+ while (collectContexts) {
+ // No body - all updates take place in the collectContexts.
+ }
+
+ // Reverse the contexts to have them in the same sequence as in the SQL statement & turn them
+ // into expressions.
+ val expressions = contexts.reverse.map(expression)
+
+ // Create a balanced tree.
+ def reduceToExpressionTree(low: Int, high: Int): Expression = high - low match {
+ case 0 =>
+ expressions(low)
+ case 1 =>
+ expressionCombiner(expressions(low), expressions(high))
+ case x =>
+ val mid = low + x / 2
+ expressionCombiner(
+ reduceToExpressionTree(low, mid),
+ reduceToExpressionTree(mid + 1, high))
+ }
+ reduceToExpressionTree(0, expressions.size - 1)
+ }
+
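
The effect of the balancing above: a chain such as a AND b AND c AND d becomes a tree of depth 2 rather than a left-nested chain of depth 3. A minimal standalone sketch with a toy expression ADT (the names below are illustrative, not Catalyst's):

    // Standalone illustration of the balanced combination strategy used by
    // visitLogicalBinary: instead of folding a long AND/OR chain into a deeply
    // nested tree, split the operand list in half recursively.
    object BalancedTreeSketch {
      sealed trait Expr
      final case class Leaf(name: String) extends Expr
      final case class And(left: Expr, right: Expr) extends Expr

      def balance(exprs: IndexedSeq[Expr]): Expr = {
        def build(low: Int, high: Int): Expr = high - low match {
          case 0 => exprs(low)
          case 1 => And(exprs(low), exprs(high))
          case x =>
            val mid = low + x / 2
            And(build(low, mid), build(mid + 1, high))
        }
        build(0, exprs.size - 1)
      }

      def depth(e: Expr): Int = e match {
        case Leaf(_)   => 0
        case And(l, r) => 1 + math.max(depth(l), depth(r))
      }

      def main(args: Array[String]): Unit = {
        val balanced = balance(Vector("a", "b", "c", "d").map(Leaf(_)))
        println(balanced)         // And(And(Leaf(a),Leaf(b)),And(Leaf(c),Leaf(d)))
        println(depth(balanced))  // 2, versus 3 for a left-nested chain
      }
    }
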
+ /**
+ * Invert a boolean expression.
+ */
+ override def visitLogicalNot(ctx: LogicalNotContext): Expression = withOrigin(ctx) {
+ Not(expression(ctx.booleanExpression()))
+ }
+
+ /**
+ * Create a filtering correlated sub-query. This is not supported yet.
+ */
+ override def visitExists(ctx: ExistsContext): Expression = {
+ throw new ParseException("EXISTS clauses are not supported.", ctx)
+ }
+
+ /**
+ * Create a comparison expression. This compares two expressions. The following comparison
+ * operators are supported:
+ * - Equal: '=' or '=='
+ * - Null-safe Equal: '<=>'
+ * - Not Equal: '<>' or '!='
+ * - Less than: '<'
+ * - Less than or Equal: '<='
+ * - Greater than: '>'
+ * - Greater than or Equal: '>='
+ */
+ override def visitComparison(ctx: ComparisonContext): Expression = withOrigin(ctx) {
+ val left = expression(ctx.left)
+ val right = expression(ctx.right)
+ val operator = ctx.comparisonOperator().getChild(0).asInstanceOf[TerminalNode]
+ operator.getSymbol.getType match {
+ case SqlBaseParser.EQ =>
+ EqualTo(left, right)
+ case SqlBaseParser.NSEQ =>
+ EqualNullSafe(left, right)
+ case SqlBaseParser.NEQ | SqlBaseParser.NEQJ =>
+ Not(EqualTo(left, right))
+ case SqlBaseParser.LT =>
+ LessThan(left, right)
+ case SqlBaseParser.LTE =>
+ LessThanOrEqual(left, right)
+ case SqlBaseParser.GT =>
+ GreaterThan(left, right)
+ case SqlBaseParser.GTE =>
+ GreaterThanOrEqual(left, right)
+ }
+ }
+
+ /**
+ * Create a predicated expression. A predicated expression is a normal expression with a
+ * predicate attached to it, for example:
+ * {{{
+ * a + 1 IS NULL
+ * }}}
+ */
+ override def visitPredicated(ctx: PredicatedContext): Expression = withOrigin(ctx) {
+ val e = expression(ctx.valueExpression)
+ if (ctx.predicate != null) {
+ withPredicate(e, ctx.predicate)
+ } else {
+ e
+ }
+ }
+
+ /**
+ * Add a predicate to the given expression. Supported expressions are:
+ * - (NOT) BETWEEN
+ * - (NOT) IN
+ * - (NOT) LIKE
+ * - (NOT) RLIKE
+ * - IS (NOT) NULL.
+ */
+ private def withPredicate(e: Expression, ctx: PredicateContext): Expression = withOrigin(ctx) {
+ // Invert a predicate if it has a valid NOT clause.
+ def invertIfNotDefined(e: Expression): Expression = ctx.NOT match {
+ case null => e
+ case not => Not(e)
+ }
+
+ // Create the predicate.
+ ctx.kind.getType match {
+ case SqlBaseParser.BETWEEN =>
+ // BETWEEN is translated to lower <= e && e <= upper
+ invertIfNotDefined(And(
+ GreaterThanOrEqual(e, expression(ctx.lower)),
+ LessThanOrEqual(e, expression(ctx.upper))))
+ case SqlBaseParser.IN if ctx.query != null =>
+ throw new ParseException("IN with a Sub-query is currently not supported.", ctx)
+ case SqlBaseParser.IN =>
+ invertIfNotDefined(In(e, ctx.expression.asScala.map(expression)))
+ case SqlBaseParser.LIKE =>
+ invertIfNotDefined(Like(e, expression(ctx.pattern)))
+ case SqlBaseParser.RLIKE =>
+ invertIfNotDefined(RLike(e, expression(ctx.pattern)))
+ case SqlBaseParser.NULL if ctx.NOT != null =>
+ IsNotNull(e)
+ case SqlBaseParser.NULL =>
+ IsNull(e)
+ }
+ }
+
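
Spelling out the BETWEEN case above: e BETWEEN lower AND upper becomes lower <= e AND e <= upper, and a leading NOT wraps the whole conjunction. A minimal string-level sketch (illustrative only; the real code builds Catalyst And/Not nodes):

    // Minimal, string-level sketch of the BETWEEN desugaring performed by
    // withPredicate above.
    object BetweenDesugarSketch {
      def between(e: String, lower: String, upper: String, negated: Boolean): String = {
        val core = s"(($lower <= $e) AND ($e <= $upper))"
        if (negated) s"NOT $core" else core
      }

      def main(args: Array[String]): Unit = {
        println(between("age", "18", "65", negated = false))  // ((18 <= age) AND (age <= 65))
        println(between("age", "18", "65", negated = true))   // NOT ((18 <= age) AND (age <= 65))
      }
    }
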
+ /**
+ * Create a binary arithmetic expression. The following arithmetic operators are supported:
+ * - Multiplication: '*'
+ * - Division: '/'
+ * - Hive Long Division: 'DIV'
+ * - Modulo: '%'
+ * - Addition: '+'
+ * - Subtraction: '-'
+ * - Binary AND: '&'
+ * - Binary XOR: '^'
+ * - Binary OR: '|'
+ */
+ override def visitArithmeticBinary(ctx: ArithmeticBinaryContext): Expression = withOrigin(ctx) {
+ val left = expression(ctx.left)
+ val right = expression(ctx.right)
+ ctx.operator.getType match {
+ case SqlBaseParser.ASTERISK =>
+ Multiply(left, right)
+ case SqlBaseParser.SLASH =>
+ Divide(left, right)
+ case SqlBaseParser.PERCENT =>
+ Remainder(left, right)
+ case SqlBaseParser.DIV =>
+ Cast(Divide(left, right), LongType)
+ case SqlBaseParser.PLUS =>
+ Add(left, right)
+ case SqlBaseParser.MINUS =>
+ Subtract(left, right)
+ case SqlBaseParser.AMPERSAND =>
+ BitwiseAnd(left, right)
+ case SqlBaseParser.HAT =>
+ BitwiseXor(left, right)
+ case SqlBaseParser.PIPE =>
+ BitwiseOr(left, right)
+ }
+ }
+
+ /**
+ * Create a unary arithmetic expression. The following arithmetic operators are supported:
+ * - Plus: '+'
+ * - Minus: '-'
+ * - Bitwise Not: '~'
+ */
+ override def visitArithmeticUnary(ctx: ArithmeticUnaryContext): Expression = withOrigin(ctx) {
+ val value = expression(ctx.valueExpression)
+ ctx.operator.getType match {
+ case SqlBaseParser.PLUS =>
+ value
+ case SqlBaseParser.MINUS =>
+ UnaryMinus(value)
+ case SqlBaseParser.TILDE =>
+ BitwiseNot(value)
+ }
+ }
+
+ /**
+ * Create a [[Cast]] expression.
+ */
+ override def visitCast(ctx: CastContext): Expression = withOrigin(ctx) {
+ Cast(expression(ctx.expression), typedVisit(ctx.dataType))
+ }
+
+ /**
+ * Create a (windowed) Function expression.
+ */
+ override def visitFunctionCall(ctx: FunctionCallContext): Expression = withOrigin(ctx) {
+ // Create the function call.
+ val name = ctx.qualifiedName.getText
+ val isDistinct = Option(ctx.setQuantifier()).exists(_.DISTINCT != null)
+ val arguments = ctx.expression().asScala.map(expression) match {
+ case Seq(UnresolvedStar(None)) if name.toLowerCase == "count" && !isDistinct =>
+ // Transform COUNT(*) into COUNT(1). Move this to analysis?
+ Seq(Literal(1))
+ case expressions =>
+ expressions
+ }
+ val function = UnresolvedFunction(name, arguments, isDistinct)
+
+ // Check if the function is evaluated in a windowed context.
+ ctx.windowSpec match {
+ case spec: WindowRefContext =>
+ UnresolvedWindowExpression(function, visitWindowRef(spec))
+ case spec: WindowDefContext =>
+ WindowExpression(function, visitWindowDef(spec))
+ case _ => function
+ }
+ }
+
+ /**
+ * Create a reference to a window frame, i.e. [[WindowSpecReference]].
+ */
+ override def visitWindowRef(ctx: WindowRefContext): WindowSpecReference = withOrigin(ctx) {
+ WindowSpecReference(ctx.identifier.getText)
+ }
+
+ /**
+ * Create a window definition, i.e. [[WindowSpecDefinition]].
+ */
+ override def visitWindowDef(ctx: WindowDefContext): WindowSpecDefinition = withOrigin(ctx) {
+ // CLUSTER BY ... | PARTITION BY ... ORDER BY ...
+ val partition = ctx.partition.asScala.map(expression)
+ val order = ctx.sortItem.asScala.map(visitSortItem)
+
+ // RANGE/ROWS BETWEEN ...
+ val frameSpecOption = Option(ctx.windowFrame).map { frame =>
+ val frameType = frame.frameType.getType match {
+ case SqlBaseParser.RANGE => RangeFrame
+ case SqlBaseParser.ROWS => RowFrame
+ }
+
+ SpecifiedWindowFrame(
+ frameType,
+ visitFrameBound(frame.start),
+ Option(frame.end).map(visitFrameBound).getOrElse(CurrentRow))
+ }
+
+ WindowSpecDefinition(
+ partition,
+ order,
+ frameSpecOption.getOrElse(UnspecifiedFrame))
+ }
+
+ /**
+ * Create or resolve a [[FrameBoundary]]. Simple math expressions are allowed for Value
+ * Preceding/Following boundaries. These expressions must be constant (foldable) and return an
+ * integer value.
+ */
+ override def visitFrameBound(ctx: FrameBoundContext): FrameBoundary = withOrigin(ctx) {
+ // We currently only allow foldable integers.
+ def value: Int = {
+ val e = expression(ctx.expression)
+ assert(e.resolved && e.foldable && e.dataType == IntegerType,
+ "Frame bound value must be a constant integer.",
+ ctx)
+ e.eval().asInstanceOf[Int]
+ }
+
+ // Create the FrameBoundary
+ ctx.boundType.getType match {
+ case SqlBaseParser.PRECEDING if ctx.UNBOUNDED != null =>
+ UnboundedPreceding
+ case SqlBaseParser.PRECEDING =>
+ ValuePreceding(value)
+ case SqlBaseParser.CURRENT =>
+ CurrentRow
+ case SqlBaseParser.FOLLOWING if ctx.UNBOUNDED != null =>
+ UnboundedFollowing
+ case SqlBaseParser.FOLLOWING =>
+ ValueFollowing(value)
+ }
+ }
+
+ /**
+ * Create a [[CreateStruct]] expression.
+ */
+ override def visitRowConstructor(ctx: RowConstructorContext): Expression = withOrigin(ctx) {
+ CreateStruct(ctx.expression.asScala.map(expression))
+ }
+
+ /**
+ * Create a [[ScalarSubquery]] expression.
+ */
+ override def visitSubqueryExpression(
+ ctx: SubqueryExpressionContext): Expression = withOrigin(ctx) {
+ ScalarSubquery(plan(ctx.query))
+ }
+
+ /**
+ * Create a value based [[CaseWhen]] expression. This has the following SQL form:
+ * {{{
+ * CASE [expression]
+ * WHEN [value] THEN [expression]
+ * ...
+ * ELSE [expression]
+ * END
+ * }}}
+ */
+ override def visitSimpleCase(ctx: SimpleCaseContext): Expression = withOrigin(ctx) {
+ val e = expression(ctx.valueExpression)
+ val branches = ctx.whenClause.asScala.map { wCtx =>
+ (EqualTo(e, expression(wCtx.condition)), expression(wCtx.result))
+ }
+ CaseWhen(branches, Option(ctx.elseExpression).map(expression))
+ }
+
+ /**
+ * Create a condition based [[CaseWhen]] expression. This has the following SQL syntax:
+ * {{{
+ * CASE
+ * WHEN [predicate] THEN [expression]
+ * ...
+ * ELSE [expression]
+ * END
+ * }}}
+ *
+ * @param ctx the parse tree
+ * */
+ override def visitSearchedCase(ctx: SearchedCaseContext): Expression = withOrigin(ctx) {
+ val branches = ctx.whenClause.asScala.map { wCtx =>
+ (expression(wCtx.condition), expression(wCtx.result))
+ }
+ CaseWhen(branches, Option(ctx.elseExpression).map(expression))
+ }
+
+ /**
+ * Create a dereference expression. The return type depends on the type of the parent: this can
+ * either be an [[UnresolvedAttribute]] (if the parent is an [[UnresolvedAttribute]]), or an
+ * [[UnresolvedExtractValue]] if the parent is some other expression.
+ */
+ override def visitDereference(ctx: DereferenceContext): Expression = withOrigin(ctx) {
+ val attr = ctx.fieldName.getText
+ expression(ctx.base) match {
+ case UnresolvedAttribute(nameParts) =>
+ UnresolvedAttribute(nameParts :+ attr)
+ case e =>
+ UnresolvedExtractValue(e, Literal(attr))
+ }
+ }
+
+ /**
+ * Create an [[UnresolvedAttribute]] expression.
+ */
+ override def visitColumnReference(ctx: ColumnReferenceContext): Expression = withOrigin(ctx) {
+ UnresolvedAttribute.quoted(ctx.getText)
+ }
+
+ /**
+ * Create an [[UnresolvedExtractValue]] expression; this is used for subscript access to an array.
+ */
+ override def visitSubscript(ctx: SubscriptContext): Expression = withOrigin(ctx) {
+ UnresolvedExtractValue(expression(ctx.value), expression(ctx.index))
+ }
+
+ /**
+ * Create an expression for a parenthesized expression. This is needed because the ANTLR
+ * visitor cannot automatically convert the nested context into an expression.
+ */
+ override def visitParenthesizedExpression(
+ ctx: ParenthesizedExpressionContext): Expression = withOrigin(ctx) {
+ expression(ctx.expression)
+ }
+
+ /**
+ * Create a [[SortOrder]] expression.
+ */
+ override def visitSortItem(ctx: SortItemContext): SortOrder = withOrigin(ctx) {
+ if (ctx.DESC != null) {
+ SortOrder(expression(ctx.expression), Descending)
+ } else {
+ SortOrder(expression(ctx.expression), Ascending)
+ }
+ }
+
+ /**
+ * Create a typed Literal expression. A typed literal has the following SQL syntax:
+ * {{{
+ * [TYPE] '[VALUE]'
+ * }}}
+ * Currently Date and Timestamp typed literals are supported.
+ *
+ * TODO: what is the added value of this over casting?
+ */
+ override def visitTypeConstructor(ctx: TypeConstructorContext): Literal = withOrigin(ctx) {
+ val value = string(ctx.STRING)
+ ctx.identifier.getText.toUpperCase match {
+ case "DATE" =>
+ Literal(Date.valueOf(value))
+ case "TIMESTAMP" =>
+ Literal(Timestamp.valueOf(value))
+ case other =>
+ throw new ParseException(s"Literals of type '$other' are currently not supported.", ctx)
+ }
+ }
+
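
For reference, the two typed literals accepted above hand the quoted value to the corresponding java.sql factory methods, so DATE '2016-04-14' and TIMESTAMP '2016-04-14 16:43:28' parse as follows:

    import java.sql.{Date, Timestamp}

    // The two typed literals currently supported delegate to java.sql factories.
    object TypedLiteralSketch {
      def main(args: Array[String]): Unit = {
        println(Date.valueOf("2016-04-14"))                 // DATE '2016-04-14'
        println(Timestamp.valueOf("2016-04-14 16:43:28"))   // TIMESTAMP '2016-04-14 16:43:28'
      }
    }
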
+ /**
+ * Create a NULL literal expression.
+ */
+ override def visitNullLiteral(ctx: NullLiteralContext): Literal = withOrigin(ctx) {
+ Literal(null)
+ }
+
+ /**
+ * Create a Boolean literal expression.
+ */
+ override def visitBooleanLiteral(ctx: BooleanLiteralContext): Literal = withOrigin(ctx) {
+ if (ctx.getText.toBoolean) {
+ Literal.TrueLiteral
+ } else {
+ Literal.FalseLiteral
+ }
+ }
+
+ /**
+ * Create an integral literal expression. The code selects the narrowest integral type
+ * possible; either a BigDecimal, a Long or an Integer is returned.
+ */
+ override def visitIntegerLiteral(ctx: IntegerLiteralContext): Literal = withOrigin(ctx) {
+ BigDecimal(ctx.getText) match {
+ case v if v.isValidInt =>
+ Literal(v.intValue())
+ case v if v.isValidLong =>
+ Literal(v.longValue())
+ case v => Literal(v.underlying())
+ }
+ }
+
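
The narrowing rule above can be checked in plain Scala (hypothetical helper, no Catalyst types): values that fit in an Int stay Int, values that fit in a Long become Long, and everything else stays a BigDecimal.

    // Plain-Scala check of the integral narrowing performed by visitIntegerLiteral.
    object IntegralNarrowingSketch {
      def narrow(text: String): Any = BigDecimal(text) match {
        case v if v.isValidInt  => v.intValue()
        case v if v.isValidLong => v.longValue()
        case v                  => v.underlying()
      }

      def main(args: Array[String]): Unit = {
        println(narrow("42").getClass.getSimpleName)                    // Integer
        println(narrow("3000000000").getClass.getSimpleName)            // Long
        println(narrow("12345678901234567890").getClass.getSimpleName)  // BigDecimal
      }
    }
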
+ /**
+ * Create a double literal for a number denoted in scientific notation.
+ */
+ override def visitScientificDecimalLiteral(
+ ctx: ScientificDecimalLiteralContext): Literal = withOrigin(ctx) {
+ Literal(ctx.getText.toDouble)
+ }
+
+ /**
+ * Create a decimal literal for a regular decimal number.
+ */
+ override def visitDecimalLiteral(ctx: DecimalLiteralContext): Literal = withOrigin(ctx) {
+ Literal(BigDecimal(ctx.getText).underlying())
+ }
+
+ /** Create a numeric literal expression. */
+ private def numericLiteral(ctx: NumberContext)(f: String => Any): Literal = withOrigin(ctx) {
+ val raw = ctx.getText
+ try {
+ Literal(f(raw.substring(0, raw.length - 1)))
+ } catch {
+ case e: NumberFormatException =>
+ throw new ParseException(e.getMessage, ctx)
+ }
+ }
+
+ /**
+ * Create a Byte Literal expression.
+ */
+ override def visitTinyIntLiteral(ctx: TinyIntLiteralContext): Literal = numericLiteral(ctx) {
+ _.toByte
+ }
+
+ /**
+ * Create a Short Literal expression.
+ */
+ override def visitSmallIntLiteral(ctx: SmallIntLiteralContext): Literal = numericLiteral(ctx) {
+ _.toShort
+ }
+
+ /**
+ * Create a Long Literal expression.
+ */
+ override def visitBigIntLiteral(ctx: BigIntLiteralContext): Literal = numericLiteral(ctx) {
+ _.toLong
+ }
+
+ /**
+ * Create a Double Literal expression.
+ */
+ override def visitDoubleLiteral(ctx: DoubleLiteralContext): Literal = numericLiteral(ctx) {
+ _.toDouble
+ }
+
+ /**
+ * Create a String literal expression.
+ */
+ override def visitStringLiteral(ctx: StringLiteralContext): Literal = withOrigin(ctx) {
+ Literal(createString(ctx))
+ }
+
+ /**
+ * Create a String from a string literal context. This supports multiple consecutive string
+ * literals; these are concatenated. For example, the expression "'hello' 'world'" will be
+ * converted into "helloworld".
+ *
+ * Special characters can be escaped by using Hive/C-style escaping.
+ */
+ private def createString(ctx: StringLiteralContext): String = {
+ ctx.STRING().asScala.map(string).mkString
+ }
+
+ /**
+ * Create a [[CalendarInterval]] literal expression. An interval expression can contain multiple
+ * unit value pairs, for instance: interval 2 months 2 days.
+ */
+ override def visitInterval(ctx: IntervalContext): Literal = withOrigin(ctx) {
+ val intervals = ctx.intervalField.asScala.map(visitIntervalField)
+ assert(intervals.nonEmpty, "at least one time unit should be given for interval literal", ctx)
+ Literal(intervals.reduce(_.add(_)))
+ }
+
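
The multi-unit case described above ('interval 2 months 2 days') is just a fold of the individual unit/value pairs with add. A toy model of that reduction (ToyInterval is invented for the example and is not Spark's CalendarInterval):

    // Toy stand-in showing how the unit/value pairs of an interval literal are
    // folded into one value with `add`, as in visitInterval above.
    final case class ToyInterval(months: Int, days: Int) {
      def add(other: ToyInterval): ToyInterval =
        ToyInterval(months + other.months, days + other.days)
    }

    object IntervalFoldSketch {
      def main(args: Array[String]): Unit = {
        // interval 2 months 2 days
        val parts = Seq(ToyInterval(months = 2, days = 0), ToyInterval(months = 0, days = 2))
        println(parts.reduce(_.add(_)))  // ToyInterval(2,2)
      }
    }
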
+ /**
+ * Create a [[CalendarInterval]] for a unit value pair. Two unit configuration types are
+ * supported:
+ * - Single unit.
+ * - From-To unit (only 'YEAR TO MONTH' and 'DAY TO SECOND' are supported).
+ */
+ override def visitIntervalField(ctx: IntervalFieldContext): CalendarInterval = withOrigin(ctx) {
+ import ctx._
+ val s = value.getText
+ try {
+ val interval = (unit.getText.toLowerCase, Option(to).map(_.getText.toLowerCase)) match {
+ case (u, None) if u.endsWith("s") =>
+ // Handle plural forms, e.g.: yearS/monthS/weekS/dayS/hourS/minuteS/secondS/...
+ CalendarInterval.fromSingleUnitString(u.substring(0, u.length - 1), s)
+ case (u, None) =>
+ CalendarInterval.fromSingleUnitString(u, s)
+ case ("year", Some("month")) =>
+ CalendarInterval.fromYearMonthString(s)
+ case ("day", Some("second")) =>
+ CalendarInterval.fromDayTimeString(s)
+ case (from, Some(t)) =>
+ throw new ParseException(s"Intervals FROM $from TO $t are not supported.", ctx)
+ }
+ assert(interval != null, "No interval can be constructed", ctx)
+ interval
+ } catch {
+ // Handle Exceptions thrown by CalendarInterval
+ case e: IllegalArgumentException =>
+ val pe = new ParseException(e.getMessage, ctx)
+ pe.setStackTrace(e.getStackTrace)
+ throw pe
+ }
+ }
+
+ /* ********************************************************************************************
+ * DataType parsing
+ * ******************************************************************************************** */
+ /**
+ * Resolve/create a primitive type.
+ */
+ override def visitPrimitiveDataType(ctx: PrimitiveDataTypeContext): DataType = withOrigin(ctx) {
+ (ctx.identifier.getText.toLowerCase, ctx.INTEGER_VALUE().asScala.toList) match {
+ case ("boolean", Nil) => BooleanType
+ case ("tinyint" | "byte", Nil) => ByteType
+ case ("smallint" | "short", Nil) => ShortType
+ case ("int" | "integer", Nil) => IntegerType
+ case ("bigint" | "long", Nil) => LongType
+ case ("float", Nil) => FloatType
+ case ("double", Nil) => DoubleType
+ case ("date", Nil) => DateType
+ case ("timestamp", Nil) => TimestampType
+ case ("char" | "varchar" | "string", Nil) => StringType
+ case ("char" | "varchar", _ :: Nil) => StringType
+ case ("binary", Nil) => BinaryType
+ case ("decimal", Nil) => DecimalType.USER_DEFAULT
+ case ("decimal", precision :: Nil) => DecimalType(precision.getText.toInt, 0)
+ case ("decimal", precision :: scale :: Nil) =>
+ DecimalType(precision.getText.toInt, scale.getText.toInt)
+ case (dt, params) =>
+ throw new ParseException(
+ s"DataType $dt${params.mkString("(", ",", ")")} is not supported.", ctx)
+ }
+ }
+
+ /**
+ * Create a complex DataType. Arrays, Maps and Structures are supported.
+ */
+ override def visitComplexDataType(ctx: ComplexDataTypeContext): DataType = withOrigin(ctx) {
+ ctx.complex.getType match {
+ case SqlBaseParser.ARRAY =>
+ ArrayType(typedVisit(ctx.dataType(0)))
+ case SqlBaseParser.MAP =>
+ MapType(typedVisit(ctx.dataType(0)), typedVisit(ctx.dataType(1)))
+ case SqlBaseParser.STRUCT =>
+ createStructType(ctx.colTypeList())
+ }
+ }
+
+ /**
+ * Create a [[StructType]] from a sequence of [[StructField]]s.
+ */
+ protected def createStructType(ctx: ColTypeListContext): StructType = {
+ StructType(Option(ctx).toSeq.flatMap(visitColTypeList))
+ }
+
+ /**
+ * Create a [[StructType]] from a number of column definitions.
+ */
+ override def visitColTypeList(ctx: ColTypeListContext): Seq[StructField] = withOrigin(ctx) {
+ ctx.colType().asScala.map(visitColType)
+ }
+
+ /**
+ * Create a [[StructField]] from a column definition.
+ */
+ override def visitColType(ctx: ColTypeContext): StructField = withOrigin(ctx) {
+ import ctx._
+
+ // Add the comment to the metadata.
+ val builder = new MetadataBuilder
+ if (STRING != null) {
+ builder.putString("comment", string(STRING))
+ }
+
+ StructField(identifier.getText, typedVisit(dataType), nullable = true, builder.build())
+ }
+}
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/CatalystQl.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/CatalystQl.scala
deleted file mode 100644
index c188c5b108..0000000000
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/CatalystQl.scala
+++ /dev/null
@@ -1,933 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one or more
- * contributor license agreements. See the NOTICE file distributed with
- * this work for additional information regarding copyright ownership.
- * The ASF licenses this file to You under the Apache License, Version 2.0
- * (the "License"); you may not use this file except in compliance with
- * the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-package org.apache.spark.sql.catalyst.parser
-
-import java.sql.Date
-
-import scala.collection.mutable.ArrayBuffer
-import scala.util.matching.Regex
-
-import org.apache.spark.sql.AnalysisException
-import org.apache.spark.sql.catalyst.TableIdentifier
-import org.apache.spark.sql.catalyst.analysis._
-import org.apache.spark.sql.catalyst.expressions._
-import org.apache.spark.sql.catalyst.expressions.aggregate.Count
-import org.apache.spark.sql.catalyst.plans._
-import org.apache.spark.sql.catalyst.plans.logical._
-import org.apache.spark.sql.types._
-import org.apache.spark.unsafe.types.CalendarInterval
-import org.apache.spark.util.random.RandomSampler
-
-
-/**
- * This class translates SQL to Catalyst [[LogicalPlan]]s or [[Expression]]s.
- */
-private[sql] class CatalystQl(val conf: ParserConf = SimpleParserConf()) extends ParserInterface {
- import ParserUtils._
-
- /**
- * The safeParse method allows a user to focus on the parsing/AST transformation logic. This
- * method will take care of possible errors during the parsing process.
- */
- protected def safeParse[T](sql: String, ast: ASTNode)(toResult: ASTNode => T): T = {
- try {
- toResult(ast)
- } catch {
- case e: MatchError => throw e
- case e: AnalysisException => throw e
- case e: Exception =>
- throw new AnalysisException(e.getMessage)
- case e: NotImplementedError =>
- throw new AnalysisException(
- s"""Unsupported language features in query
- |== SQL ==
- |$sql
- |== AST ==
- |${ast.treeString}
- |== Error ==
- |$e
- |== Stacktrace ==
- |${e.getStackTrace.head}
- """.stripMargin)
- }
- }
-
- /** Creates LogicalPlan for a given SQL string. */
- def parsePlan(sql: String): LogicalPlan =
- safeParse(sql, ParseDriver.parsePlan(sql, conf))(nodeToPlan)
-
- /** Creates Expression for a given SQL string. */
- def parseExpression(sql: String): Expression =
- safeParse(sql, ParseDriver.parseExpression(sql, conf))(selExprNodeToExpr(_).get)
-
- /** Creates TableIdentifier for a given SQL string. */
- def parseTableIdentifier(sql: String): TableIdentifier =
- safeParse(sql, ParseDriver.parseTableName(sql, conf))(extractTableIdent)
-
- /**
- * SELECT MAX(value) FROM src GROUP BY k1, k2, k3 GROUPING SETS((k1, k2), (k2))
- * is equivalent to
- * SELECT MAX(value) FROM src GROUP BY k1, k2 UNION SELECT MAX(value) FROM src GROUP BY k2
- * Check the following link for details.
- *
-https://cwiki.apache.org/confluence/display/Hive/Enhanced+Aggregation%2C+Cube%2C+Grouping+and+Rollup
- *
- * The bitmask denotes the validity of the grouping expressions for a grouping set;
- * the bitmask is also called the grouping id (`GROUPING__ID`, the virtual column in Hive).
- * E.g. in the superset (k1, k2, k3), (bit 2: k1, bit 1: k2, and bit 0: k3), the grouping id of
- * GROUPING SETS (k1, k2) and (k2) should be 1 and 5 respectively.
- */
- protected def extractGroupingSet(children: Seq[ASTNode]): (Seq[Expression], Seq[Int]) = {
- val (keyASTs, setASTs) = children.partition {
- case Token("TOK_GROUPING_SETS_EXPRESSION", _) => false // grouping sets
- case _ => true // grouping keys
- }
-
- val keys = keyASTs.map(nodeToExpr)
- val keyMap = keyASTs.zipWithIndex.toMap
-
- val mask = (1 << keys.length) - 1
- val bitmasks: Seq[Int] = setASTs.map {
- case Token("TOK_GROUPING_SETS_EXPRESSION", columns) =>
- columns.foldLeft(mask)((bitmap, col) => {
- val keyIndex = keyMap.find(_._1.treeEquals(col)).map(_._2).getOrElse(
- throw new AnalysisException(s"${col.treeString} doesn't show up in the GROUP BY list"))
- // 0 means that the column at the given index is a grouping column, 1 means it is not,
- // so we unset the bit in bitmap.
- bitmap & ~(1 << (keys.length - 1 - keyIndex))
- })
- case _ => sys.error("Expect GROUPING SETS clause")
- }
-
- (keys, bitmasks)
- }
-
- protected def nodeToPlan(node: ASTNode): LogicalPlan = node match {
- case Token("TOK_SHOWFUNCTIONS", args) =>
- // Skip LIKE.
- val pattern = args match {
- case like :: nodes if like.text.toUpperCase == "LIKE" => nodes
- case nodes => nodes
- }
-
- // Extract Database and Function name
- pattern match {
- case Nil =>
- ShowFunctions(None, None)
- case Token(name, Nil) :: Nil =>
- ShowFunctions(None, Some(unquoteString(cleanIdentifier(name))))
- case Token(db, Nil) :: Token(name, Nil) :: Nil =>
- ShowFunctions(Some(unquoteString(cleanIdentifier(db))),
- Some(unquoteString(cleanIdentifier(name))))
- case _ =>
- noParseRule("SHOW FUNCTIONS", node)
- }
-
- case Token("TOK_DESCFUNCTION", Token(functionName, Nil) :: isExtended) =>
- DescribeFunction(cleanIdentifier(functionName), isExtended.nonEmpty)
-
- case Token("TOK_QUERY", queryArgs @ Token("TOK_CTE" | "TOK_FROM" | "TOK_INSERT", _) :: _) =>
- val (fromClause: Option[ASTNode], insertClauses, cteRelations) =
- queryArgs match {
- case Token("TOK_CTE", ctes) :: Token("TOK_FROM", from) :: inserts =>
- val cteRelations = ctes.map { node =>
- val relation = nodeToRelation(node).asInstanceOf[SubqueryAlias]
- relation.alias -> relation
- }
- (Some(from.head), inserts, Some(cteRelations.toMap))
- case Token("TOK_FROM", from) :: inserts =>
- (Some(from.head), inserts, None)
- case Token("TOK_INSERT", _) :: Nil =>
- (None, queryArgs, None)
- }
-
- // Return one query for each insert clause.
- val queries = insertClauses.map {
- case Token("TOK_INSERT", singleInsert) =>
- val (
- intoClause ::
- destClause ::
- selectClause ::
- selectDistinctClause ::
- whereClause ::
- groupByClause ::
- rollupGroupByClause ::
- cubeGroupByClause ::
- groupingSetsClause ::
- orderByClause ::
- havingClause ::
- sortByClause ::
- clusterByClause ::
- distributeByClause ::
- limitClause ::
- lateralViewClause ::
- windowClause :: Nil) = {
- getClauses(
- Seq(
- "TOK_INSERT_INTO",
- "TOK_DESTINATION",
- "TOK_SELECT",
- "TOK_SELECTDI",
- "TOK_WHERE",
- "TOK_GROUPBY",
- "TOK_ROLLUP_GROUPBY",
- "TOK_CUBE_GROUPBY",
- "TOK_GROUPING_SETS",
- "TOK_ORDERBY",
- "TOK_HAVING",
- "TOK_SORTBY",
- "TOK_CLUSTERBY",
- "TOK_DISTRIBUTEBY",
- "TOK_LIMIT",
- "TOK_LATERAL_VIEW",
- "WINDOW"),
- singleInsert)
- }
-
- val relations = fromClause match {
- case Some(f) => nodeToRelation(f)
- case None => OneRowRelation
- }
-
- val withLateralView = lateralViewClause.map { lv =>
- nodeToGenerate(lv.children.head, outer = false, relations)
- }.getOrElse(relations)
-
- val withWhere = whereClause.map { whereNode =>
- val Seq(whereExpr) = whereNode.children
- Filter(nodeToExpr(whereExpr), withLateralView)
- }.getOrElse(withLateralView)
-
- val select = (selectClause orElse selectDistinctClause)
- .getOrElse(sys.error("No select clause."))
-
- val transformation = nodeToTransformation(select.children.head, withWhere)
-
- // The projection of the query can either be a normal projection, an aggregation
- // (if there is a group by) or a script transformation.
- val withProject: LogicalPlan = transformation.getOrElse {
- val selectExpressions =
- select.children.flatMap(selExprNodeToExpr).map(UnresolvedAlias(_))
- Seq(
- groupByClause.map(e => e match {
- case Token("TOK_GROUPBY", children) =>
- // Not a transformation so must be either project or aggregation.
- Aggregate(children.map(nodeToExpr), selectExpressions, withWhere)
- case _ => sys.error("Expect GROUP BY")
- }),
- groupingSetsClause.map(e => e match {
- case Token("TOK_GROUPING_SETS", children) =>
- val(groupByExprs, masks) = extractGroupingSet(children)
- GroupingSets(masks, groupByExprs, withWhere, selectExpressions)
- case _ => sys.error("Expect GROUPING SETS")
- }),
- rollupGroupByClause.map(e => e match {
- case Token("TOK_ROLLUP_GROUPBY", children) =>
- Aggregate(
- Seq(Rollup(children.map(nodeToExpr))),
- selectExpressions,
- withWhere)
- case _ => sys.error("Expect WITH ROLLUP")
- }),
- cubeGroupByClause.map(e => e match {
- case Token("TOK_CUBE_GROUPBY", children) =>
- Aggregate(
- Seq(Cube(children.map(nodeToExpr))),
- selectExpressions,
- withWhere)
- case _ => sys.error("Expect WITH CUBE")
- }),
- Some(Project(selectExpressions, withWhere))).flatten.head
- }
-
- // Handle HAVING clause.
- val withHaving = havingClause.map { h =>
- val havingExpr = h.children match { case Seq(hexpr) => nodeToExpr(hexpr) }
- // Note that we added a cast to boolean. If the expression itself is already boolean,
- // the optimizer will get rid of the unnecessary cast.
- Filter(Cast(havingExpr, BooleanType), withProject)
- }.getOrElse(withProject)
-
- // Handle SELECT DISTINCT
- val withDistinct =
- if (selectDistinctClause.isDefined) Distinct(withHaving) else withHaving
-
- // Handle ORDER BY, SORT BY, DISTRIBUTE BY, and CLUSTER BY clause.
- val withSort =
- (orderByClause, sortByClause, distributeByClause, clusterByClause) match {
- case (Some(totalOrdering), None, None, None) =>
- Sort(totalOrdering.children.map(nodeToSortOrder), global = true, withDistinct)
- case (None, Some(perPartitionOrdering), None, None) =>
- Sort(
- perPartitionOrdering.children.map(nodeToSortOrder),
- global = false, withDistinct)
- case (None, None, Some(partitionExprs), None) =>
- RepartitionByExpression(
- partitionExprs.children.map(nodeToExpr), withDistinct)
- case (None, Some(perPartitionOrdering), Some(partitionExprs), None) =>
- Sort(
- perPartitionOrdering.children.map(nodeToSortOrder), global = false,
- RepartitionByExpression(
- partitionExprs.children.map(nodeToExpr),
- withDistinct))
- case (None, None, None, Some(clusterExprs)) =>
- Sort(
- clusterExprs.children.map(nodeToExpr).map(SortOrder(_, Ascending)),
- global = false,
- RepartitionByExpression(
- clusterExprs.children.map(nodeToExpr),
- withDistinct))
- case (None, None, None, None) => withDistinct
- case _ => sys.error("Unsupported set of ordering / distribution clauses.")
- }
-
- val withLimit =
- limitClause.map(l => nodeToExpr(l.children.head))
- .map(Limit(_, withSort))
- .getOrElse(withSort)
-
- // Collect all window specifications defined in the WINDOW clause.
- val windowDefinitions = windowClause.map(_.children.collect {
- case Token("TOK_WINDOWDEF",
- Token(windowName, Nil) :: Token("TOK_WINDOWSPEC", spec) :: Nil) =>
- windowName -> nodesToWindowSpecification(spec)
- }.toMap)
- // Handle cases like
- // window w1 as (partition by p_mfgr order by p_name
- // range between 2 preceding and 2 following),
- // w2 as w1
- val resolvedCrossReference = windowDefinitions.map {
- windowDefMap => windowDefMap.map {
- case (windowName, WindowSpecReference(other)) =>
- (windowName, windowDefMap(other).asInstanceOf[WindowSpecDefinition])
- case o => o.asInstanceOf[(String, WindowSpecDefinition)]
- }
- }
-
- val withWindowDefinitions =
- resolvedCrossReference.map(WithWindowDefinition(_, withLimit)).getOrElse(withLimit)
-
- // TOK_INSERT_INTO means to add files to the table.
- // TOK_DESTINATION means to overwrite the table.
- val resultDestination =
- (intoClause orElse destClause).getOrElse(sys.error("No destination found."))
- val overwrite = intoClause.isEmpty
- nodeToDest(
- resultDestination,
- withWindowDefinitions,
- overwrite)
- }
-
- // If there are multiple INSERTS just UNION them together into one query.
- val query = if (queries.length == 1) queries.head else Union(queries)
-
- // return With plan if there is CTE
- cteRelations.map(With(query, _)).getOrElse(query)
-
- case Token("TOK_UNIONALL", left :: right :: Nil) =>
- Union(nodeToPlan(left), nodeToPlan(right))
- case Token("TOK_UNIONDISTINCT", left :: right :: Nil) =>
- Distinct(Union(nodeToPlan(left), nodeToPlan(right)))
- case Token("TOK_EXCEPT", left :: right :: Nil) =>
- Except(nodeToPlan(left), nodeToPlan(right))
- case Token("TOK_INTERSECT", left :: right :: Nil) =>
- Intersect(nodeToPlan(left), nodeToPlan(right))
-
- case _ =>
- noParseRule("Plan", node)
- }
-
- val allJoinTokens = "(TOK_.*JOIN)".r
- val laterViewToken = "TOK_LATERAL_VIEW(.*)".r
- protected def nodeToRelation(node: ASTNode): LogicalPlan = {
- node match {
- case Token("TOK_SUBQUERY", query :: Token(alias, Nil) :: Nil) =>
- SubqueryAlias(cleanIdentifier(alias), nodeToPlan(query))
-
- case Token(laterViewToken(isOuter), selectClause :: relationClause :: Nil) =>
- nodeToGenerate(
- selectClause,
- outer = isOuter.nonEmpty,
- nodeToRelation(relationClause))
-
- /* All relations, possibly with aliases or sampling clauses. */
- case Token("TOK_TABREF", clauses) =>
- // If the last clause is not a token then it's the alias of the table.
- val (nonAliasClauses, aliasClause) =
- if (clauses.last.text.startsWith("TOK")) {
- (clauses, None)
- } else {
- (clauses.dropRight(1), Some(clauses.last))
- }
-
- val (Some(tableNameParts) ::
- splitSampleClause ::
- bucketSampleClause :: Nil) = {
- getClauses(Seq("TOK_TABNAME", "TOK_TABLESPLITSAMPLE", "TOK_TABLEBUCKETSAMPLE"),
- nonAliasClauses)
- }
-
- val tableIdent = extractTableIdent(tableNameParts)
- val alias = aliasClause.map { case Token(a, Nil) => cleanIdentifier(a) }
- val relation = UnresolvedRelation(tableIdent, alias)
-
- // Apply sampling if requested.
- (bucketSampleClause orElse splitSampleClause).map {
- case Token("TOK_TABLESPLITSAMPLE",
- Token("TOK_ROWCOUNT", Nil) :: Token(count, Nil) :: Nil) =>
- Limit(Literal(count.toInt), relation)
- case Token("TOK_TABLESPLITSAMPLE",
- Token("TOK_PERCENT", Nil) :: Token(fraction, Nil) :: Nil) =>
- // The range of fraction accepted by Sample is [0, 1]. Because Hive's block sampling
- // function takes X PERCENT as the input and the range of X is [0, 100], we need to
- // adjust the fraction.
- require(
- fraction.toDouble >= (0.0 - RandomSampler.roundingEpsilon)
- && fraction.toDouble <= (100.0 + RandomSampler.roundingEpsilon),
- s"Sampling fraction ($fraction) must be on interval [0, 100]")
- Sample(0.0, fraction.toDouble / 100, withReplacement = false,
- (math.random * 1000).toInt,
- relation)(
- isTableSample = true)
- case Token("TOK_TABLEBUCKETSAMPLE",
- Token(numerator, Nil) ::
- Token(denominator, Nil) :: Nil) =>
- val fraction = numerator.toDouble / denominator.toDouble
- Sample(0.0, fraction, withReplacement = false, (math.random * 1000).toInt, relation)(
- isTableSample = true)
- case a =>
- noParseRule("Sampling", a)
- }.getOrElse(relation)
-
- case Token(allJoinTokens(joinToken), relation1 :: relation2 :: other) =>
- if (!(other.size <= 1)) {
- sys.error(s"Unsupported join operation: $other")
- }
-
- val (joinType, joinCondition) = getJoinInfo(joinToken, other, node)
-
- Join(nodeToRelation(relation1),
- nodeToRelation(relation2),
- joinType,
- joinCondition)
- case _ =>
- noParseRule("Relation", node)
- }
- }
-
- protected def getJoinInfo(
- joinToken: String,
- joinConditionToken: Seq[ASTNode],
- node: ASTNode): (JoinType, Option[Expression]) = {
- val joinType = joinToken match {
- case "TOK_JOIN" => Inner
- case "TOK_CROSSJOIN" => Inner
- case "TOK_RIGHTOUTERJOIN" => RightOuter
- case "TOK_LEFTOUTERJOIN" => LeftOuter
- case "TOK_FULLOUTERJOIN" => FullOuter
- case "TOK_LEFTSEMIJOIN" => LeftSemi
- case "TOK_UNIQUEJOIN" => noParseRule("Unique Join", node)
- case "TOK_ANTIJOIN" => noParseRule("Anti Join", node)
- case "TOK_NATURALJOIN" => NaturalJoin(Inner)
- case "TOK_NATURALRIGHTOUTERJOIN" => NaturalJoin(RightOuter)
- case "TOK_NATURALLEFTOUTERJOIN" => NaturalJoin(LeftOuter)
- case "TOK_NATURALFULLOUTERJOIN" => NaturalJoin(FullOuter)
- }
-
- joinConditionToken match {
- case Token("TOK_USING", columnList :: Nil) :: Nil =>
- val colNames = columnList.children.collect {
- case Token(name, Nil) => UnresolvedAttribute(name)
- }
- (UsingJoin(joinType, colNames), None)
- /* Join expression specified using ON clause */
- case _ => (joinType, joinConditionToken.headOption.map(nodeToExpr))
- }
- }
-
- protected def nodeToSortOrder(node: ASTNode): SortOrder = node match {
- case Token("TOK_TABSORTCOLNAMEASC", sortExpr :: Nil) =>
- SortOrder(nodeToExpr(sortExpr), Ascending)
- case Token("TOK_TABSORTCOLNAMEDESC", sortExpr :: Nil) =>
- SortOrder(nodeToExpr(sortExpr), Descending)
- case _ =>
- noParseRule("SortOrder", node)
- }
-
- val destinationToken = "TOK_DESTINATION|TOK_INSERT_INTO".r
- protected def nodeToDest(
- node: ASTNode,
- query: LogicalPlan,
- overwrite: Boolean): LogicalPlan = node match {
- case Token(destinationToken(),
- Token("TOK_DIR",
- Token("TOK_TMP_FILE", Nil) :: Nil) :: Nil) =>
- query
-
- case Token(destinationToken(),
- Token("TOK_TAB",
- tableArgs) :: Nil) =>
- val Some(tableNameParts) :: partitionClause :: Nil =
- getClauses(Seq("TOK_TABNAME", "TOK_PARTSPEC"), tableArgs)
-
- val tableIdent = extractTableIdent(tableNameParts)
-
- val partitionKeys = partitionClause.map(_.children.map {
- // Parse partitions. We also make keys case insensitive.
- case Token("TOK_PARTVAL", Token(key, Nil) :: Token(value, Nil) :: Nil) =>
- cleanIdentifier(key.toLowerCase) -> Some(unquoteString(value))
- case Token("TOK_PARTVAL", Token(key, Nil) :: Nil) =>
- cleanIdentifier(key.toLowerCase) -> None
- }.toMap).getOrElse(Map.empty)
-
- InsertIntoTable(
- UnresolvedRelation(tableIdent, None), partitionKeys, query, overwrite, ifNotExists = false)
-
- case Token(destinationToken(),
- Token("TOK_TAB",
- tableArgs) ::
- Token("TOK_IFNOTEXISTS",
- ifNotExists) :: Nil) =>
- val Some(tableNameParts) :: partitionClause :: Nil =
- getClauses(Seq("TOK_TABNAME", "TOK_PARTSPEC"), tableArgs)
-
- val tableIdent = extractTableIdent(tableNameParts)
-
- val partitionKeys = partitionClause.map(_.children.map {
- // Parse partitions. We also make keys case insensitive.
- case Token("TOK_PARTVAL", Token(key, Nil) :: Token(value, Nil) :: Nil) =>
- cleanIdentifier(key.toLowerCase) -> Some(unquoteString(value))
- case Token("TOK_PARTVAL", Token(key, Nil) :: Nil) =>
- cleanIdentifier(key.toLowerCase) -> None
- }.toMap).getOrElse(Map.empty)
-
- InsertIntoTable(
- UnresolvedRelation(tableIdent, None), partitionKeys, query, overwrite, ifNotExists = true)
-
- case _ =>
- noParseRule("Destination", node)
- }
-
- protected def selExprNodeToExpr(node: ASTNode): Option[Expression] = node match {
- case Token("TOK_SELEXPR", e :: Nil) =>
- Some(nodeToExpr(e))
-
- case Token("TOK_SELEXPR", e :: Token(alias, Nil) :: Nil) =>
- Some(Alias(nodeToExpr(e), cleanIdentifier(alias))())
-
- case Token("TOK_SELEXPR", e :: aliasChildren) =>
- val aliasNames = aliasChildren.collect {
- case Token(name, Nil) => cleanIdentifier(name)
- }
- Some(MultiAlias(nodeToExpr(e), aliasNames))
-
- /* Hints are ignored */
- case Token("TOK_HINTLIST", _) => None
-
- case _ =>
- noParseRule("Select", node)
- }
-
- /**
- * Flattens the left deep tree with the specified pattern into a list.
- */
- private def flattenLeftDeepTree(node: ASTNode, pattern: Regex): Seq[ASTNode] = {
- val collected = ArrayBuffer[ASTNode]()
- var rest = node
- while (rest match {
- case Token(pattern(), l :: r :: Nil) =>
- collected += r
- rest = l
- true
- case _ => false
- }) {
- // do nothing
- }
- collected += rest
- // keep them in the same order as in SQL
- collected.reverse
- }
-
- /**
- * Creates a balanced tree that has similar number of nodes on left and right.
- *
- * This help to reduce the depth of the tree to prevent StackOverflow in analyzer/optimizer.
- */
- private def balancedTree(
- expr: Seq[Expression],
- f: (Expression, Expression) => Expression): Expression = expr.length match {
- case 1 => expr.head
- case 2 => f(expr.head, expr(1))
- case l => f(balancedTree(expr.slice(0, l / 2), f), balancedTree(expr.slice(l / 2, l), f))
- }
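The balancedTree/flattenLeftDeepTree pair being removed above is easiest to see with a worked
example: for `a AND b AND c AND d` the flattener yields Seq(a, b, c, d) and the rebuild produces
And(And(a, b), And(c, d)) instead of the left-deep And(And(And(a, b), c), d), roughly halving the
tree depth for long predicate chains. A minimal sketch with stand-in types (Leaf/AndExpr are
illustrative, not the Catalyst classes):

  sealed trait Expr
  case class Leaf(name: String) extends Expr
  case class AndExpr(left: Expr, right: Expr) extends Expr

  def balanced(exprs: Seq[Expr]): Expr = exprs.length match {
    case 1 => exprs.head
    case 2 => AndExpr(exprs.head, exprs(1))
    case l => AndExpr(balanced(exprs.slice(0, l / 2)), balanced(exprs.slice(l / 2, l)))
  }

  // balanced(Seq(Leaf("a"), Leaf("b"), Leaf("c"), Leaf("d")))
  //   == AndExpr(AndExpr(Leaf("a"), Leaf("b")), AndExpr(Leaf("c"), Leaf("d")))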
-
- protected def nodeToExpr(node: ASTNode): Expression = node match {
- /* Attribute References */
- case Token("TOK_TABLE_OR_COL", Token(name, Nil) :: Nil) =>
- UnresolvedAttribute.quoted(cleanIdentifier(name))
- case Token(".", qualifier :: Token(attr, Nil) :: Nil) =>
- nodeToExpr(qualifier) match {
- case UnresolvedAttribute(nameParts) =>
- UnresolvedAttribute(nameParts :+ cleanIdentifier(attr))
- case other => UnresolvedExtractValue(other, Literal(cleanIdentifier(attr)))
- }
- case Token("TOK_SUBQUERY_EXPR", Token("TOK_SUBQUERY_OP", Nil) :: subquery :: Nil) =>
- ScalarSubquery(nodeToPlan(subquery))
-
- /* Stars (*) */
- case Token("TOK_ALLCOLREF", Nil) => UnresolvedStar(None)
- // The format of dbName.tableName.* cannot be parsed by HiveParser. TOK_TABNAME will only
- // has a single child which is tableName.
- case Token("TOK_ALLCOLREF", Token("TOK_TABNAME", target) :: Nil) if target.nonEmpty =>
- UnresolvedStar(Some(target.map(x => cleanIdentifier(x.text))))
-
- /* Aggregate Functions */
- case Token("TOK_FUNCTIONDI", Token(COUNT(), Nil) :: args) =>
- Count(args.map(nodeToExpr)).toAggregateExpression(isDistinct = true)
- case Token("TOK_FUNCTIONSTAR", Token(COUNT(), Nil) :: Nil) =>
- Count(Literal(1)).toAggregateExpression()
-
- /* Casts */
- case Token("TOK_FUNCTION", Token("TOK_STRING", Nil) :: arg :: Nil) =>
- Cast(nodeToExpr(arg), StringType)
- case Token("TOK_FUNCTION", Token("TOK_VARCHAR", _) :: arg :: Nil) =>
- Cast(nodeToExpr(arg), StringType)
- case Token("TOK_FUNCTION", Token("TOK_CHAR", _) :: arg :: Nil) =>
- Cast(nodeToExpr(arg), StringType)
- case Token("TOK_FUNCTION", Token("TOK_INT", Nil) :: arg :: Nil) =>
- Cast(nodeToExpr(arg), IntegerType)
- case Token("TOK_FUNCTION", Token("TOK_BIGINT", Nil) :: arg :: Nil) =>
- Cast(nodeToExpr(arg), LongType)
- case Token("TOK_FUNCTION", Token("TOK_FLOAT", Nil) :: arg :: Nil) =>
- Cast(nodeToExpr(arg), FloatType)
- case Token("TOK_FUNCTION", Token("TOK_DOUBLE", Nil) :: arg :: Nil) =>
- Cast(nodeToExpr(arg), DoubleType)
- case Token("TOK_FUNCTION", Token("TOK_SMALLINT", Nil) :: arg :: Nil) =>
- Cast(nodeToExpr(arg), ShortType)
- case Token("TOK_FUNCTION", Token("TOK_TINYINT", Nil) :: arg :: Nil) =>
- Cast(nodeToExpr(arg), ByteType)
- case Token("TOK_FUNCTION", Token("TOK_BINARY", Nil) :: arg :: Nil) =>
- Cast(nodeToExpr(arg), BinaryType)
- case Token("TOK_FUNCTION", Token("TOK_BOOLEAN", Nil) :: arg :: Nil) =>
- Cast(nodeToExpr(arg), BooleanType)
- case Token("TOK_FUNCTION", Token("TOK_DECIMAL", precision :: scale :: nil) :: arg :: Nil) =>
- Cast(nodeToExpr(arg), DecimalType(precision.text.toInt, scale.text.toInt))
- case Token("TOK_FUNCTION", Token("TOK_DECIMAL", precision :: Nil) :: arg :: Nil) =>
- Cast(nodeToExpr(arg), DecimalType(precision.text.toInt, 0))
- case Token("TOK_FUNCTION", Token("TOK_DECIMAL", Nil) :: arg :: Nil) =>
- Cast(nodeToExpr(arg), DecimalType.USER_DEFAULT)
- case Token("TOK_FUNCTION", Token("TOK_TIMESTAMP", Nil) :: arg :: Nil) =>
- Cast(nodeToExpr(arg), TimestampType)
- case Token("TOK_FUNCTION", Token("TOK_DATE", Nil) :: arg :: Nil) =>
- Cast(nodeToExpr(arg), DateType)
-
- /* Arithmetic */
- case Token("+", child :: Nil) => nodeToExpr(child)
- case Token("-", child :: Nil) => UnaryMinus(nodeToExpr(child))
- case Token("~", child :: Nil) => BitwiseNot(nodeToExpr(child))
- case Token("+", left :: right:: Nil) => Add(nodeToExpr(left), nodeToExpr(right))
- case Token("-", left :: right:: Nil) => Subtract(nodeToExpr(left), nodeToExpr(right))
- case Token("*", left :: right:: Nil) => Multiply(nodeToExpr(left), nodeToExpr(right))
- case Token("/", left :: right:: Nil) => Divide(nodeToExpr(left), nodeToExpr(right))
- case Token(DIV(), left :: right:: Nil) =>
- Cast(Divide(nodeToExpr(left), nodeToExpr(right)), LongType)
- case Token("%", left :: right:: Nil) => Remainder(nodeToExpr(left), nodeToExpr(right))
- case Token("&", left :: right:: Nil) => BitwiseAnd(nodeToExpr(left), nodeToExpr(right))
- case Token("|", left :: right:: Nil) => BitwiseOr(nodeToExpr(left), nodeToExpr(right))
- case Token("^", left :: right:: Nil) => BitwiseXor(nodeToExpr(left), nodeToExpr(right))
-
- /* Comparisons */
- case Token("=", left :: right:: Nil) => EqualTo(nodeToExpr(left), nodeToExpr(right))
- case Token("==", left :: right:: Nil) => EqualTo(nodeToExpr(left), nodeToExpr(right))
- case Token("<=>", left :: right:: Nil) => EqualNullSafe(nodeToExpr(left), nodeToExpr(right))
- case Token("!=", left :: right:: Nil) => Not(EqualTo(nodeToExpr(left), nodeToExpr(right)))
- case Token("<>", left :: right:: Nil) => Not(EqualTo(nodeToExpr(left), nodeToExpr(right)))
- case Token(">", left :: right:: Nil) => GreaterThan(nodeToExpr(left), nodeToExpr(right))
- case Token(">=", left :: right:: Nil) => GreaterThanOrEqual(nodeToExpr(left), nodeToExpr(right))
- case Token("<", left :: right:: Nil) => LessThan(nodeToExpr(left), nodeToExpr(right))
- case Token("<=", left :: right:: Nil) => LessThanOrEqual(nodeToExpr(left), nodeToExpr(right))
- case Token(LIKE(), left :: right:: Nil) => Like(nodeToExpr(left), nodeToExpr(right))
- case Token(RLIKE(), left :: right:: Nil) => RLike(nodeToExpr(left), nodeToExpr(right))
- case Token(REGEXP(), left :: right:: Nil) => RLike(nodeToExpr(left), nodeToExpr(right))
- case Token("TOK_FUNCTION", Token("TOK_ISNOTNULL", Nil) :: child :: Nil) =>
- IsNotNull(nodeToExpr(child))
- case Token("TOK_FUNCTION", Token("TOK_ISNULL", Nil) :: child :: Nil) =>
- IsNull(nodeToExpr(child))
- case Token("TOK_FUNCTION", Token(IN(), Nil) :: value :: list) =>
- In(nodeToExpr(value), list.map(nodeToExpr))
- case Token("TOK_FUNCTION",
- Token(BETWEEN(), Nil) ::
- kw ::
- target ::
- minValue ::
- maxValue :: Nil) =>
-
- val targetExpression = nodeToExpr(target)
- val betweenExpr =
- And(
- GreaterThanOrEqual(targetExpression, nodeToExpr(minValue)),
- LessThanOrEqual(targetExpression, nodeToExpr(maxValue)))
- kw match {
- case Token("KW_FALSE", Nil) => betweenExpr
- case Token("KW_TRUE", Nil) => Not(betweenExpr)
- }
-
- /* Boolean Logic */
- case Token(AND(), left :: right:: Nil) =>
- balancedTree(flattenLeftDeepTree(node, AND).map(nodeToExpr), And)
- case Token(OR(), left :: right:: Nil) =>
- balancedTree(flattenLeftDeepTree(node, OR).map(nodeToExpr), Or)
- case Token(NOT(), child :: Nil) => Not(nodeToExpr(child))
- case Token("!", child :: Nil) => Not(nodeToExpr(child))
-
- /* Case statements */
- case Token("TOK_FUNCTION", Token(WHEN(), Nil) :: branches) =>
- CaseWhen.createFromParser(branches.map(nodeToExpr))
- case Token("TOK_FUNCTION", Token(CASE(), Nil) :: branches) =>
- val keyExpr = nodeToExpr(branches.head)
- CaseKeyWhen(keyExpr, branches.drop(1).map(nodeToExpr))
-
- /* Complex datatype manipulation */
- case Token("[", child :: ordinal :: Nil) =>
- UnresolvedExtractValue(nodeToExpr(child), nodeToExpr(ordinal))
-
- /* Window Functions */
- case Token(text, args :+ Token("TOK_WINDOWSPEC", spec)) =>
- val function = nodeToExpr(node.copy(children = node.children.init))
- nodesToWindowSpecification(spec) match {
- case reference: WindowSpecReference =>
- UnresolvedWindowExpression(function, reference)
- case definition: WindowSpecDefinition =>
- WindowExpression(function, definition)
- }
-
- /* UDFs - Must be last otherwise will preempt built in functions */
- case Token("TOK_FUNCTION", Token(name, Nil) :: args) =>
- UnresolvedFunction(name, args.map(nodeToExpr), isDistinct = false)
- // Aggregate function with DISTINCT keyword.
- case Token("TOK_FUNCTIONDI", Token(name, Nil) :: args) =>
- UnresolvedFunction(name, args.map(nodeToExpr), isDistinct = true)
- case Token("TOK_FUNCTIONSTAR", Token(name, Nil) :: args) =>
- UnresolvedFunction(name, UnresolvedStar(None) :: Nil, isDistinct = false)
-
- /* Literals */
- case Token("TOK_NULL", Nil) => Literal.create(null, NullType)
- case Token(TRUE(), Nil) => Literal.create(true, BooleanType)
- case Token(FALSE(), Nil) => Literal.create(false, BooleanType)
- case Token("TOK_STRINGLITERALSEQUENCE", strings) =>
- Literal(strings.map(s => ParseUtils.unescapeSQLString(s.text)).mkString)
-
- case ast if ast.tokenType == SparkSqlParser.TinyintLiteral =>
- Literal.create(ast.text.substring(0, ast.text.length() - 1).toByte, ByteType)
-
- case ast if ast.tokenType == SparkSqlParser.SmallintLiteral =>
- Literal.create(ast.text.substring(0, ast.text.length() - 1).toShort, ShortType)
-
- case ast if ast.tokenType == SparkSqlParser.BigintLiteral =>
- Literal.create(ast.text.substring(0, ast.text.length() - 1).toLong, LongType)
-
- case ast if ast.tokenType == SparkSqlParser.DoubleLiteral =>
- Literal(ast.text.toDouble)
-
- case ast if ast.tokenType == SparkSqlParser.Number =>
- val text = ast.text
- text match {
- case INTEGRAL() =>
- BigDecimal(text) match {
- case v if v.isValidInt =>
- Literal(v.intValue())
- case v if v.isValidLong =>
- Literal(v.longValue())
- case v => Literal(v.underlying())
- }
- case DECIMAL(_*) =>
- Literal(BigDecimal(text).underlying())
- case _ =>
- // Convert a scientifically notated decimal into a double.
- Literal(text.toDouble)
- }
- case ast if ast.tokenType == SparkSqlParser.StringLiteral =>
- Literal(ParseUtils.unescapeSQLString(ast.text))
-
- case ast if ast.tokenType == SparkSqlParser.TOK_DATELITERAL =>
- Literal(Date.valueOf(ast.text.substring(1, ast.text.length - 1)))
-
- case ast if ast.tokenType == SparkSqlParser.TOK_INTERVAL_YEAR_MONTH_LITERAL =>
- Literal(CalendarInterval.fromYearMonthString(ast.children.head.text))
-
- case ast if ast.tokenType == SparkSqlParser.TOK_INTERVAL_DAY_TIME_LITERAL =>
- Literal(CalendarInterval.fromDayTimeString(ast.children.head.text))
-
- case Token("TOK_INTERVAL", elements) =>
- var interval = new CalendarInterval(0, 0)
- var updated = false
- elements.foreach {
- // The interval node will always contain children for all possible time units. A child node
- // is only useful when it contains exactly one (numeric) child.
- case e @ Token(name, Token(value, Nil) :: Nil) =>
- val unit = name match {
- case "TOK_INTERVAL_YEAR_LITERAL" => "year"
- case "TOK_INTERVAL_MONTH_LITERAL" => "month"
- case "TOK_INTERVAL_WEEK_LITERAL" => "week"
- case "TOK_INTERVAL_DAY_LITERAL" => "day"
- case "TOK_INTERVAL_HOUR_LITERAL" => "hour"
- case "TOK_INTERVAL_MINUTE_LITERAL" => "minute"
- case "TOK_INTERVAL_SECOND_LITERAL" => "second"
- case "TOK_INTERVAL_MILLISECOND_LITERAL" => "millisecond"
- case "TOK_INTERVAL_MICROSECOND_LITERAL" => "microsecond"
- case _ => noParseRule(s"Interval($name)", e)
- }
- interval = interval.add(CalendarInterval.fromSingleUnitString(unit, value))
- updated = true
- case _ =>
- }
- if (!updated) {
- throw new AnalysisException("at least one time unit should be given for interval literal")
- }
- Literal(interval)
-
- case _ =>
- noParseRule("Expression", node)
- }
-
- /* Case insensitive matches for Window Specification */
- val PRECEDING = "(?i)preceding".r
- val FOLLOWING = "(?i)following".r
- val CURRENT = "(?i)current".r
- protected def nodesToWindowSpecification(nodes: Seq[ASTNode]): WindowSpec = nodes match {
- case Token(windowName, Nil) :: Nil =>
- // Refer to a window spec defined in the window clause.
- WindowSpecReference(windowName)
- case Nil =>
- // OVER()
- WindowSpecDefinition(
- partitionSpec = Nil,
- orderSpec = Nil,
- frameSpecification = UnspecifiedFrame)
- case spec =>
- val (partitionClause :: rowFrame :: rangeFrame :: Nil) =
- getClauses(
- Seq(
- "TOK_PARTITIONINGSPEC",
- "TOK_WINDOWRANGE",
- "TOK_WINDOWVALUES"),
- spec)
-
- // Handle Partition By and Order By.
- val (partitionSpec, orderSpec) = partitionClause.map { partitionAndOrdering =>
- val (partitionByClause :: orderByClause :: sortByClause :: clusterByClause :: Nil) =
- getClauses(
- Seq("TOK_DISTRIBUTEBY", "TOK_ORDERBY", "TOK_SORTBY", "TOK_CLUSTERBY"),
- partitionAndOrdering.children)
-
- (partitionByClause, orderByClause.orElse(sortByClause), clusterByClause) match {
- case (Some(partitionByExpr), Some(orderByExpr), None) =>
- (partitionByExpr.children.map(nodeToExpr),
- orderByExpr.children.map(nodeToSortOrder))
- case (Some(partitionByExpr), None, None) =>
- (partitionByExpr.children.map(nodeToExpr), Nil)
- case (None, Some(orderByExpr), None) =>
- (Nil, orderByExpr.children.map(nodeToSortOrder))
- case (None, None, Some(clusterByExpr)) =>
- val expressions = clusterByExpr.children.map(nodeToExpr)
- (expressions, expressions.map(SortOrder(_, Ascending)))
- case _ =>
- noParseRule("Partition & Ordering", partitionAndOrdering)
- }
- }.getOrElse {
- (Nil, Nil)
- }
-
- // Handle Window Frame
- val windowFrame =
- if (rowFrame.isEmpty && rangeFrame.isEmpty) {
- UnspecifiedFrame
- } else {
- val frameType = rowFrame.map(_ => RowFrame).getOrElse(RangeFrame)
- def nodeToBoundary(node: ASTNode): FrameBoundary = node match {
- case Token(PRECEDING(), Token(count, Nil) :: Nil) =>
- if (count.toLowerCase() == "unbounded") {
- UnboundedPreceding
- } else {
- ValuePreceding(count.toInt)
- }
- case Token(FOLLOWING(), Token(count, Nil) :: Nil) =>
- if (count.toLowerCase() == "unbounded") {
- UnboundedFollowing
- } else {
- ValueFollowing(count.toInt)
- }
- case Token(CURRENT(), Nil) => CurrentRow
- case _ =>
- noParseRule("Window Frame Boundary", node)
- }
-
- rowFrame.orElse(rangeFrame).map { frame =>
- frame.children match {
- case precedingNode :: followingNode :: Nil =>
- SpecifiedWindowFrame(
- frameType,
- nodeToBoundary(precedingNode),
- nodeToBoundary(followingNode))
- case precedingNode :: Nil =>
- SpecifiedWindowFrame(frameType, nodeToBoundary(precedingNode), CurrentRow)
- case _ =>
- noParseRule("Window Frame", frame)
- }
- }.getOrElse(sys.error(s"If you see this, please file a bug report with your query."))
- }
-
- WindowSpecDefinition(partitionSpec, orderSpec, windowFrame)
- }
-
- protected def nodeToTransformation(
- node: ASTNode,
- child: LogicalPlan): Option[ScriptTransformation] = None
-
- val explode = "(?i)explode".r
- val jsonTuple = "(?i)json_tuple".r
- protected def nodeToGenerate(node: ASTNode, outer: Boolean, child: LogicalPlan): Generate = {
- val Token("TOK_SELECT", Token("TOK_SELEXPR", clauses) :: Nil) = node
-
- val alias = cleanIdentifier(getClause("TOK_TABALIAS", clauses).children.head.text)
-
- val generator = clauses.head match {
- case Token("TOK_FUNCTION", Token(explode(), Nil) :: childNode :: Nil) =>
- Explode(nodeToExpr(childNode))
- case Token("TOK_FUNCTION", Token(jsonTuple(), Nil) :: children) =>
- JsonTuple(children.map(nodeToExpr))
- case other =>
- nodeToGenerator(other)
- }
-
- val attributes = clauses.collect {
- case Token(a, Nil) => UnresolvedAttribute(cleanIdentifier(a.toLowerCase))
- }
-
- Generate(
- generator,
- join = true,
- outer = outer,
- Some(cleanIdentifier(alias.toLowerCase)),
- attributes,
- child)
- }
-
- protected def nodeToGenerator(node: ASTNode): Generator = noParseRule("Generator", node)
-
-}
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/DataTypeParser.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/DataTypeParser.scala
index 21deb82107..0b570c9e42 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/DataTypeParser.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/DataTypeParser.scala
@@ -20,6 +20,7 @@ package org.apache.spark.sql.catalyst.parser
import scala.language.implicitConversions
import scala.util.matching.Regex
import scala.util.parsing.combinator.syntactical.StandardTokenParsers
+import scala.util.parsing.input.CharArrayReader._
import org.apache.spark.sql.types._
@@ -117,3 +118,69 @@ private[sql] object DataTypeParser {
/** The exception thrown from the [[DataTypeParser]]. */
private[sql] class DataTypeException(message: String) extends Exception(message)
+
+class SqlLexical extends scala.util.parsing.combinator.lexical.StdLexical {
+ case class DecimalLit(chars: String) extends Token {
+ override def toString: String = chars
+ }
+
+  /* This is a workaround to support the lazy setting */
+ def initialize(keywords: Seq[String]): Unit = {
+ reserved.clear()
+ reserved ++= keywords
+ }
+
+  /* Normalize the keyword string */
+ def normalizeKeyword(str: String): String = str.toLowerCase
+
+ delimiters += (
+ "@", "*", "+", "-", "<", "=", "<>", "!=", "<=", ">=", ">", "/", "(", ")",
+ ",", ";", "%", "{", "}", ":", "[", "]", ".", "&", "|", "^", "~", "<=>"
+ )
+
+ protected override def processIdent(name: String) = {
+ val token = normalizeKeyword(name)
+ if (reserved contains token) Keyword(token) else Identifier(name)
+ }
+
+ override lazy val token: Parser[Token] =
+ ( rep1(digit) ~ scientificNotation ^^ { case i ~ s => DecimalLit(i.mkString + s) }
+ | '.' ~> (rep1(digit) ~ scientificNotation) ^^
+ { case i ~ s => DecimalLit("0." + i.mkString + s) }
+ | rep1(digit) ~ ('.' ~> digit.*) ~ scientificNotation ^^
+ { case i1 ~ i2 ~ s => DecimalLit(i1.mkString + "." + i2.mkString + s) }
+ | digit.* ~ identChar ~ (identChar | digit).* ^^
+ { case first ~ middle ~ rest => processIdent((first ++ (middle :: rest)).mkString) }
+ | rep1(digit) ~ ('.' ~> digit.*).? ^^ {
+ case i ~ None => NumericLit(i.mkString)
+ case i ~ Some(d) => DecimalLit(i.mkString + "." + d.mkString)
+ }
+ | '\'' ~> chrExcept('\'', '\n', EofCh).* <~ '\'' ^^
+ { case chars => StringLit(chars mkString "") }
+ | '"' ~> chrExcept('"', '\n', EofCh).* <~ '"' ^^
+ { case chars => StringLit(chars mkString "") }
+ | '`' ~> chrExcept('`', '\n', EofCh).* <~ '`' ^^
+ { case chars => Identifier(chars mkString "") }
+ | EofCh ^^^ EOF
+ | '\'' ~> failure("unclosed string literal")
+ | '"' ~> failure("unclosed string literal")
+ | delim
+ | failure("illegal character")
+ )
+
+ override def identChar: Parser[Elem] = letter | elem('_')
+
+ private lazy val scientificNotation: Parser[String] =
+ (elem('e') | elem('E')) ~> (elem('+') | elem('-')).? ~ rep1(digit) ^^ {
+ case s ~ rest => "e" + s.mkString + rest.mkString
+ }
+
+ override def whitespace: Parser[Any] =
+ ( whitespaceChar
+ | '/' ~ '*' ~ comment
+ | '/' ~ '/' ~ chrExcept(EofCh, '\n').*
+ | '#' ~ chrExcept(EofCh, '\n').*
+ | '-' ~ '-' ~ chrExcept(EofCh, '\n').*
+ | '/' ~ '*' ~ failure("unclosed comment")
+ ).*
+}
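The SqlLexical added above extends Scala's StdLexical with decimal and scientific-notation
literals, backticked identifiers, and SQL-style comments. A hedged sketch of driving the scanner
directly (the keyword list here is illustrative; the real one is installed by the parser that owns
this lexer):

  val lexical = new SqlLexical
  lexical.initialize(Seq("select", "from"))

  var s: lexical.Scanner = new lexical.Scanner("SELECT 1.5e2, `my col` FROM t -- trailing comment")
  while (!s.atEnd) {
    println(s.first)   // the select keyword, the decimal literal 1.5e2, the identifier my col, ...
    s = s.rest
  }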
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/ParseDriver.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/ParseDriver.scala
index 51cfc50130..d0132529f1 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/ParseDriver.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/ParseDriver.scala
@@ -16,91 +16,106 @@
*/
package org.apache.spark.sql.catalyst.parser
-import scala.annotation.tailrec
-
-import org.antlr.runtime._
-import org.antlr.runtime.tree.CommonTree
+import org.antlr.v4.runtime._
+import org.antlr.v4.runtime.atn.PredictionMode
+import org.antlr.v4.runtime.misc.ParseCancellationException
import org.apache.spark.internal.Logging
import org.apache.spark.sql.AnalysisException
+import org.apache.spark.sql.catalyst.TableIdentifier
+import org.apache.spark.sql.catalyst.expressions.Expression
+import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
+import org.apache.spark.sql.catalyst.trees.Origin
+import org.apache.spark.sql.types.DataType
/**
- * The ParseDriver takes a SQL command and turns this into an AST.
- *
- * This is based on Hive's org.apache.hadoop.hive.ql.parse.ParseDriver
+ * Base SQL parsing infrastructure.
*/
-object ParseDriver extends Logging {
- /** Create an LogicalPlan ASTNode from a SQL command. */
- def parsePlan(command: String, conf: ParserConf): ASTNode = parse(command, conf) { parser =>
- parser.statement().getTree
- }
+abstract class AbstractSqlParser extends ParserInterface with Logging {
- /** Create an Expression ASTNode from a SQL command. */
- def parseExpression(command: String, conf: ParserConf): ASTNode = parse(command, conf) { parser =>
- parser.singleNamedExpression().getTree
+ /** Creates/Resolves DataType for a given SQL string. */
+ def parseDataType(sqlText: String): DataType = parse(sqlText) { parser =>
+ // TODO add this to the parser interface.
+ astBuilder.visitSingleDataType(parser.singleDataType())
}
- /** Create an TableIdentifier ASTNode from a SQL command. */
- def parseTableName(command: String, conf: ParserConf): ASTNode = parse(command, conf) { parser =>
- parser.singleTableName().getTree
+ /** Creates Expression for a given SQL string. */
+ override def parseExpression(sqlText: String): Expression = parse(sqlText) { parser =>
+ astBuilder.visitSingleExpression(parser.singleExpression())
}
- private def parse(
- command: String,
- conf: ParserConf)(
- toTree: SparkSqlParser => CommonTree): ASTNode = {
- logInfo(s"Parsing command: $command")
+ /** Creates TableIdentifier for a given SQL string. */
+ override def parseTableIdentifier(sqlText: String): TableIdentifier = parse(sqlText) { parser =>
+ astBuilder.visitSingleTableIdentifier(parser.singleTableIdentifier())
+ }
- // Setup error collection.
- val reporter = new ParseErrorReporter()
+ /** Creates LogicalPlan for a given SQL string. */
+ override def parsePlan(sqlText: String): LogicalPlan = parse(sqlText) { parser =>
+ astBuilder.visitSingleStatement(parser.singleStatement()) match {
+ case plan: LogicalPlan => plan
+ case _ => nativeCommand(sqlText)
+ }
+ }
- // Create lexer.
- val lexer = new SparkSqlLexer(new ANTLRNoCaseStringStream(command))
- val tokens = new TokenRewriteStream(lexer)
- lexer.configure(conf, reporter)
+  /** Get the builder (visitor) which converts a ParseTree into an AST. */
+ protected def astBuilder: AstBuilder
- // Create the parser.
- val parser = new SparkSqlParser(tokens)
- parser.configure(conf, reporter)
+ /** Create a native command, or fail when this is not supported. */
+ protected def nativeCommand(sqlText: String): LogicalPlan = {
+ val position = Origin(None, None)
+ throw new ParseException(Option(sqlText), "Unsupported SQL statement", position, position)
+ }
- try {
- val result = toTree(parser)
+ protected def parse[T](command: String)(toResult: SqlBaseParser => T): T = {
+ logInfo(s"Parsing command: $command")
- // Check errors.
- reporter.checkForErrors()
+ val lexer = new SqlBaseLexer(new ANTLRNoCaseStringStream(command))
+ lexer.removeErrorListeners()
+ lexer.addErrorListener(ParseErrorListener)
- // Return the AST node from the result.
- logInfo(s"Parse completed.")
+ val tokenStream = new CommonTokenStream(lexer)
+ val parser = new SqlBaseParser(tokenStream)
+ parser.addParseListener(PostProcessor)
+ parser.removeErrorListeners()
+ parser.addErrorListener(ParseErrorListener)
- // Find the non null token tree in the result.
- @tailrec
- def nonNullToken(tree: CommonTree): CommonTree = {
- if (tree.token != null || tree.getChildCount == 0) tree
- else nonNullToken(tree.getChild(0).asInstanceOf[CommonTree])
+ try {
+ try {
+ // first, try parsing with potentially faster SLL mode
+ parser.getInterpreter.setPredictionMode(PredictionMode.SLL)
+ toResult(parser)
}
- val tree = nonNullToken(result)
-
- // Make sure all boundaries are set.
- tree.setUnknownTokenBoundaries()
-
- // Construct the immutable AST.
- def createASTNode(tree: CommonTree): ASTNode = {
- val children = (0 until tree.getChildCount).map { i =>
- createASTNode(tree.getChild(i).asInstanceOf[CommonTree])
- }.toList
- ASTNode(tree.token, tree.getTokenStartIndex, tree.getTokenStopIndex, children, tokens)
+ catch {
+ case e: ParseCancellationException =>
+ // if we fail, parse with LL mode
+ tokenStream.reset() // rewind input stream
+ parser.reset()
+
+ // Try Again.
+ parser.getInterpreter.setPredictionMode(PredictionMode.LL)
+ toResult(parser)
}
- createASTNode(tree)
}
catch {
- case e: RecognitionException =>
- logInfo(s"Parse failed.")
- reporter.throwError(e)
+ case e: ParseException if e.command.isDefined =>
+ throw e
+ case e: ParseException =>
+ throw e.withCommand(command)
+ case e: AnalysisException =>
+ val position = Origin(e.line, e.startPosition)
+ throw new ParseException(Option(command), e.message, position, position)
}
}
}
/**
+ * Concrete SQL parser for Catalyst-only SQL statements.
+ */
+object CatalystSqlParser extends AbstractSqlParser {
+ val astBuilder = new AstBuilder
+}
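CatalystSqlParser above is the concrete entry point for the new ANTLR4 pipeline; the calls below
are a sketch of the ParserInterface surface defined in AbstractSqlParser in this hunk (the SQL
strings are only illustrative):

  import org.apache.spark.sql.catalyst.parser.CatalystSqlParser

  val plan  = CatalystSqlParser.parsePlan("SELECT key, count(*) FROM t GROUP BY key")
  val expr  = CatalystSqlParser.parseExpression("a + 1")
  val ident = CatalystSqlParser.parseTableIdentifier("db.tbl")  // TableIdentifier("tbl", Some("db"))
  val dt    = CatalystSqlParser.parseDataType("map<string, array<int>>")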
+
+/**
* This string stream provides the lexer with upper case characters only. This greatly simplifies
* lexing the stream, while we can maintain the original command.
*
@@ -120,58 +135,104 @@ object ParseDriver extends Logging {
* have the ANTLRNoCaseStringStream implementation.
*/
-private[parser] class ANTLRNoCaseStringStream(input: String) extends ANTLRStringStream(input) {
+private[parser] class ANTLRNoCaseStringStream(input: String) extends ANTLRInputStream(input) {
override def LA(i: Int): Int = {
val la = super.LA(i)
- if (la == 0 || la == CharStream.EOF) la
+ if (la == 0 || la == IntStream.EOF) la
else Character.toUpperCase(la)
}
}
/**
- * Utility used by the Parser and the Lexer for error collection and reporting.
+ * The ParseErrorListener converts parse errors into AnalysisExceptions.
*/
-private[parser] class ParseErrorReporter {
- val errors = scala.collection.mutable.Buffer.empty[ParseError]
-
- def report(br: BaseRecognizer, re: RecognitionException, tokenNames: Array[String]): Unit = {
- errors += ParseError(br, re, tokenNames)
+case object ParseErrorListener extends BaseErrorListener {
+ override def syntaxError(
+ recognizer: Recognizer[_, _],
+ offendingSymbol: scala.Any,
+ line: Int,
+ charPositionInLine: Int,
+ msg: String,
+ e: RecognitionException): Unit = {
+ val position = Origin(Some(line), Some(charPositionInLine))
+ throw new ParseException(None, msg, position, position)
}
+}
- def checkForErrors(): Unit = {
- if (errors.nonEmpty) {
- val first = errors.head
- val e = first.re
- throwError(e.line, e.charPositionInLine, first.buildMessage().toString, errors.tail)
- }
+/**
+ * A [[ParseException]] is an [[AnalysisException]] that is thrown during the parse process. It
+ * contains fields and an extended error message that make reporting and diagnosing errors easier.
+ */
+class ParseException(
+ val command: Option[String],
+ message: String,
+ val start: Origin,
+ val stop: Origin) extends AnalysisException(message, start.line, start.startPosition) {
+
+ def this(message: String, ctx: ParserRuleContext) = {
+ this(Option(ParserUtils.command(ctx)),
+ message,
+ ParserUtils.position(ctx.getStart),
+ ParserUtils.position(ctx.getStop))
}
- def throwError(e: RecognitionException): Nothing = {
- throwError(e.line, e.charPositionInLine, e.toString, errors)
+ override def getMessage: String = {
+ val builder = new StringBuilder
+ builder ++= "\n" ++= message
+ start match {
+ case Origin(Some(l), Some(p)) =>
+ builder ++= s"(line $l, pos $p)\n"
+ command.foreach { cmd =>
+ val (above, below) = cmd.split("\n").splitAt(l)
+ builder ++= "\n== SQL ==\n"
+ above.foreach(builder ++= _ += '\n')
+ builder ++= (0 until p).map(_ => "-").mkString("") ++= "^^^\n"
+ below.foreach(builder ++= _ += '\n')
+ }
+ case _ =>
+ command.foreach { cmd =>
+ builder ++= "\n== SQL ==\n" ++= cmd
+ }
+ }
+ builder.toString
}
- private def throwError(
- line: Int,
- startPosition: Int,
- msg: String,
- errors: Seq[ParseError]): Nothing = {
- val b = new StringBuilder
- b.append(msg).append("\n")
- errors.foreach(error => error.buildMessage(b).append("\n"))
- throw new AnalysisException(b.toString, Option(line), Option(startPosition))
+ def withCommand(cmd: String): ParseException = {
+ new ParseException(Option(cmd), message, start, stop)
}
}
/**
- * Error collected during the parsing process.
- *
- * This is based on Hive's org.apache.hadoop.hive.ql.parse.ParseError
+ * The post-processor validates and cleans up the parse tree during the parse process.
*/
-private[parser] case class ParseError(
- br: BaseRecognizer,
- re: RecognitionException,
- tokenNames: Array[String]) {
- def buildMessage(s: StringBuilder = new StringBuilder): StringBuilder = {
- s.append(br.getErrorHeader(re)).append(" ").append(br.getErrorMessage(re, tokenNames))
+case object PostProcessor extends SqlBaseBaseListener {
+
+ /** Remove the back ticks from an Identifier. */
+ override def exitQuotedIdentifier(ctx: SqlBaseParser.QuotedIdentifierContext): Unit = {
+ replaceTokenByIdentifier(ctx, 1) { token =>
+ // Remove the double back ticks in the string.
+ token.setText(token.getText.replace("``", "`"))
+ token
+ }
+ }
+
+ /** Treat non-reserved keywords as Identifiers. */
+ override def exitNonReserved(ctx: SqlBaseParser.NonReservedContext): Unit = {
+ replaceTokenByIdentifier(ctx, 0)(identity)
+ }
+
+ private def replaceTokenByIdentifier(
+ ctx: ParserRuleContext,
+ stripMargins: Int)(
+ f: CommonToken => CommonToken = identity): Unit = {
+ val parent = ctx.getParent
+ parent.removeLastChild()
+ val token = ctx.getChild(0).getPayload.asInstanceOf[Token]
+ parent.addChild(f(new CommonToken(
+ new org.antlr.v4.runtime.misc.Pair(token.getTokenSource, token.getInputStream),
+ SqlBaseParser.IDENTIFIER,
+ token.getChannel,
+ token.getStartIndex + stripMargins,
+ token.getStopIndex - stripMargins)))
}
}
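Because ParseException keeps the original command and the start/stop origins, its getMessage
renders the failing position together with an "== SQL ==" excerpt of the query. A hedged usage
sketch:

  import org.apache.spark.sql.catalyst.parser.{CatalystSqlParser, ParseException}

  try {
    CatalystSqlParser.parsePlan("SELECT ^ FROM tbl")  // deliberately malformed
  } catch {
    case e: ParseException =>
      // Prints the message, the line/position, and the offending SQL with a marker
      // under the failing location.
      println(e.getMessage)
  }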
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/ParserConf.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/ParserConf.scala
deleted file mode 100644
index ce449b1143..0000000000
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/ParserConf.scala
+++ /dev/null
@@ -1,26 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one or more
- * contributor license agreements. See the NOTICE file distributed with
- * this work for additional information regarding copyright ownership.
- * The ASF licenses this file to You under the Apache License, Version 2.0
- * (the "License"); you may not use this file except in compliance with
- * the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-package org.apache.spark.sql.catalyst.parser
-
-trait ParserConf {
- def supportQuotedId: Boolean
- def supportSQL11ReservedKeywords: Boolean
-}
-
-case class SimpleParserConf(
- supportQuotedId: Boolean = true,
- supportSQL11ReservedKeywords: Boolean = false) extends ParserConf
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/ParserUtils.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/ParserUtils.scala
index 0c2e481954..cb9fefec8f 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/ParserUtils.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/ParserUtils.scala
@@ -14,166 +14,181 @@
* See the License for the specific language governing permissions and
* limitations under the License.
*/
-
package org.apache.spark.sql.catalyst.parser
-import org.apache.spark.sql.AnalysisException
-import org.apache.spark.sql.catalyst.TableIdentifier
-import org.apache.spark.sql.catalyst.trees.CurrentOrigin
-import org.apache.spark.sql.types._
+import scala.collection.mutable.StringBuilder
+
+import org.antlr.v4.runtime.{CharStream, ParserRuleContext, Token}
+import org.antlr.v4.runtime.misc.Interval
+import org.antlr.v4.runtime.tree.TerminalNode
+import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
+import org.apache.spark.sql.catalyst.trees.{CurrentOrigin, Origin}
/**
- * A collection of utility methods and patterns for parsing query texts.
+ * A collection of utility methods for use during the parsing process.
*/
-// TODO: merge with ParseUtils
object ParserUtils {
-
- object Token {
- // Match on (text, children)
- def unapply(node: ASTNode): Some[(String, List[ASTNode])] = {
- CurrentOrigin.setPosition(node.line, node.positionInLine)
- node.pattern
- }
+ /** Get the command which created the token. */
+ def command(ctx: ParserRuleContext): String = {
+ command(ctx.getStart.getInputStream)
}
- private val escapedIdentifier = "`(.+)`".r
- private val doubleQuotedString = "\"([^\"]+)\"".r
- private val singleQuotedString = "'([^']+)'".r
-
- // Token patterns
- val COUNT = "(?i)COUNT".r
- val SUM = "(?i)SUM".r
- val AND = "(?i)AND".r
- val OR = "(?i)OR".r
- val NOT = "(?i)NOT".r
- val TRUE = "(?i)TRUE".r
- val FALSE = "(?i)FALSE".r
- val LIKE = "(?i)LIKE".r
- val RLIKE = "(?i)RLIKE".r
- val REGEXP = "(?i)REGEXP".r
- val IN = "(?i)IN".r
- val DIV = "(?i)DIV".r
- val BETWEEN = "(?i)BETWEEN".r
- val WHEN = "(?i)WHEN".r
- val CASE = "(?i)CASE".r
- val INTEGRAL = "[+-]?\\d+".r
- val DECIMAL = "[+-]?((\\d+(\\.\\d*)?)|(\\.\\d+))".r
-
- /**
- * Strip quotes, if any, from the string.
- */
- def unquoteString(str: String): String = {
- str match {
- case singleQuotedString(s) => s
- case doubleQuotedString(s) => s
- case other => other
- }
+ /** Get the command which created the token. */
+ def command(stream: CharStream): String = {
+ stream.getText(Interval.of(0, stream.size()))
}
- /**
- * Strip backticks, if any, from the string.
- */
- def cleanIdentifier(ident: String): String = {
- ident match {
- case escapedIdentifier(i) => i
- case plainIdent => plainIdent
- }
+ /** Get the code that creates the given node. */
+ def source(ctx: ParserRuleContext): String = {
+ val stream = ctx.getStart.getInputStream
+ stream.getText(Interval.of(ctx.getStart.getStartIndex, ctx.getStop.getStopIndex))
}
- def getClauses(
- clauseNames: Seq[String],
- nodeList: Seq[ASTNode]): Seq[Option[ASTNode]] = {
- var remainingNodes = nodeList
- val clauses = clauseNames.map { clauseName =>
- val (matches, nonMatches) = remainingNodes.partition(_.text.toUpperCase == clauseName)
- remainingNodes = nonMatches ++ (if (matches.nonEmpty) matches.tail else Nil)
- matches.headOption
- }
+ /** Get all the text which comes after the given rule. */
+ def remainder(ctx: ParserRuleContext): String = remainder(ctx.getStop)
- if (remainingNodes.nonEmpty) {
- sys.error(
- s"""Unhandled clauses: ${remainingNodes.map(_.treeString).mkString("\n")}.
- |You are likely trying to use an unsupported Hive feature."""".stripMargin)
- }
- clauses
+ /** Get all the text which comes after the given token. */
+ def remainder(token: Token): String = {
+ val stream = token.getInputStream
+ val interval = Interval.of(token.getStopIndex + 1, stream.size())
+ stream.getText(interval)
}
- def getClause(clauseName: String, nodeList: Seq[ASTNode]): ASTNode = {
- getClauseOption(clauseName, nodeList).getOrElse(sys.error(
- s"Expected clause $clauseName missing from ${nodeList.map(_.treeString).mkString("\n")}"))
+ /** Convert a string token into a string. */
+ def string(token: Token): String = unescapeSQLString(token.getText)
+
+ /** Convert a string node into a string. */
+ def string(node: TerminalNode): String = unescapeSQLString(node.getText)
+
+ /** Get the origin (line and position) of the token. */
+ def position(token: Token): Origin = {
+ Origin(Option(token.getLine), Option(token.getCharPositionInLine))
}
- def getClauseOption(clauseName: String, nodeList: Seq[ASTNode]): Option[ASTNode] = {
- nodeList.filter { case ast: ASTNode => ast.text == clauseName } match {
- case Seq(oneMatch) => Some(oneMatch)
- case Seq() => None
- case _ => sys.error(s"Found multiple instances of clause $clauseName")
+  /** Assert that a condition holds. If it doesn't, throw a parse exception. */
+ def assert(f: => Boolean, message: String, ctx: ParserRuleContext): Unit = {
+ if (!f) {
+ throw new ParseException(message, ctx)
}
}
- def extractTableIdent(tableNameParts: ASTNode): TableIdentifier = {
- tableNameParts.children.map {
- case Token(part, Nil) => cleanIdentifier(part)
- } match {
- case Seq(tableOnly) => TableIdentifier(tableOnly)
- case Seq(databaseName, table) => TableIdentifier(table, Some(databaseName))
- case other => sys.error("Hive only supports tables names like 'tableName' " +
- s"or 'databaseName.tableName', found '$other'")
+ /**
+ * Register the origin of the context. Any TreeNode created in the closure will be assigned the
+ * registered origin. This method restores the previously set origin after completion of the
+ * closure.
+ */
+ def withOrigin[T](ctx: ParserRuleContext)(f: => T): T = {
+ val current = CurrentOrigin.get
+ CurrentOrigin.set(position(ctx.getStart))
+ try {
+ f
+ } finally {
+ CurrentOrigin.set(current)
}
}
- def nodeToDataType(node: ASTNode): DataType = node match {
- case Token("TOK_DECIMAL", precision :: scale :: Nil) =>
- DecimalType(precision.text.toInt, scale.text.toInt)
- case Token("TOK_DECIMAL", precision :: Nil) =>
- DecimalType(precision.text.toInt, 0)
- case Token("TOK_DECIMAL", Nil) => DecimalType.USER_DEFAULT
- case Token("TOK_BIGINT", Nil) => LongType
- case Token("TOK_INT", Nil) => IntegerType
- case Token("TOK_TINYINT", Nil) => ByteType
- case Token("TOK_SMALLINT", Nil) => ShortType
- case Token("TOK_BOOLEAN", Nil) => BooleanType
- case Token("TOK_STRING", Nil) => StringType
- case Token("TOK_VARCHAR", Token(_, Nil) :: Nil) => StringType
- case Token("TOK_CHAR", Token(_, Nil) :: Nil) => StringType
- case Token("TOK_FLOAT", Nil) => FloatType
- case Token("TOK_DOUBLE", Nil) => DoubleType
- case Token("TOK_DATE", Nil) => DateType
- case Token("TOK_TIMESTAMP", Nil) => TimestampType
- case Token("TOK_BINARY", Nil) => BinaryType
- case Token("TOK_LIST", elementType :: Nil) => ArrayType(nodeToDataType(elementType))
- case Token("TOK_STRUCT", Token("TOK_TABCOLLIST", fields) :: Nil) =>
- StructType(fields.map(nodeToStructField))
- case Token("TOK_MAP", keyType :: valueType :: Nil) =>
- MapType(nodeToDataType(keyType), nodeToDataType(valueType))
- case _ =>
- noParseRule("DataType", node)
- }
+  /** Unescape a backslash-escaped string enclosed by quotes. */
+ def unescapeSQLString(b: String): String = {
+ var enclosure: Character = null
+ val sb = new StringBuilder(b.length())
+
+ def appendEscapedChar(n: Char) {
+ n match {
+ case '0' => sb.append('\u0000')
+ case '\'' => sb.append('\'')
+ case '"' => sb.append('\"')
+ case 'b' => sb.append('\b')
+ case 'n' => sb.append('\n')
+ case 'r' => sb.append('\r')
+ case 't' => sb.append('\t')
+ case 'Z' => sb.append('\u001A')
+ case '\\' => sb.append('\\')
+        // The following 2 lines are exactly what MySQL does. TODO: why do we do this?
+ case '%' => sb.append("\\%")
+ case '_' => sb.append("\\_")
+ case _ => sb.append(n)
+ }
+ }
- def nodeToStructField(node: ASTNode): StructField = node match {
- case Token("TOK_TABCOL", Token(fieldName, Nil) :: dataType :: Nil) =>
- StructField(cleanIdentifier(fieldName), nodeToDataType(dataType), nullable = true)
- case Token("TOK_TABCOL", Token(fieldName, Nil) :: dataType :: comment :: Nil) =>
- val meta = new MetadataBuilder().putString("comment", unquoteString(comment.text)).build()
- StructField(cleanIdentifier(fieldName), nodeToDataType(dataType), nullable = true, meta)
- case _ =>
- noParseRule("StructField", node)
+ var i = 0
+ val strLength = b.length
+ while (i < strLength) {
+ val currentChar = b.charAt(i)
+ if (enclosure == null) {
+ if (currentChar == '\'' || currentChar == '\"') {
+ enclosure = currentChar
+ }
+ } else if (enclosure == currentChar) {
+ enclosure = null
+ } else if (currentChar == '\\') {
+
+ if ((i + 6 < strLength) && b.charAt(i + 1) == 'u') {
+ // \u0000 style character literals.
+
+ val base = i + 2
+ val code = (0 until 4).foldLeft(0) { (mid, j) =>
+ val digit = Character.digit(b.charAt(j + base), 16)
+ (mid << 4) + digit
+ }
+ sb.append(code.asInstanceOf[Char])
+ i += 5
+ } else if (i + 4 < strLength) {
+ // \000 style character literals.
+
+ val i1 = b.charAt(i + 1)
+ val i2 = b.charAt(i + 2)
+ val i3 = b.charAt(i + 3)
+
+ if ((i1 >= '0' && i1 <= '1') && (i2 >= '0' && i2 <= '7') && (i3 >= '0' && i3 <= '7')) {
+ val tmp = ((i3 - '0') + ((i2 - '0') << 3) + ((i1 - '0') << 6)).asInstanceOf[Char]
+ sb.append(tmp)
+ i += 3
+ } else {
+ appendEscapedChar(i1)
+ i += 1
+ }
+ } else if (i + 2 < strLength) {
+ // escaped character literals.
+ val n = b.charAt(i + 1)
+ appendEscapedChar(n)
+ i += 1
+ }
+ } else {
+ // non-escaped character literals.
+ sb.append(currentChar)
+ }
+ i += 1
+ }
+ sb.toString()
}
- /**
- * Throw an exception because we cannot parse the given node for some unexpected reason.
- */
- def parseFailed(msg: String, node: ASTNode): Nothing = {
- throw new AnalysisException(s"$msg: '${node.source}")
- }
+ /** Some syntactic sugar which makes it easier to work with optional clauses for LogicalPlans. */
+ implicit class EnhancedLogicalPlan(val plan: LogicalPlan) extends AnyVal {
+ /**
+ * Create a plan using the block of code when the given context exists. Otherwise return the
+ * original plan.
+ */
+ def optional(ctx: AnyRef)(f: => LogicalPlan): LogicalPlan = {
+ if (ctx != null) {
+ f
+ } else {
+ plan
+ }
+ }
- /**
- * Throw an exception because there are no rules to parse the node.
- */
- def noParseRule(msg: String, node: ASTNode): Nothing = {
- throw new NotImplementedError(
- s"[$msg]: No parse rules for ASTNode type: ${node.tokenType}, tree:\n${node.treeString}")
+ /**
+ * Map a [[LogicalPlan]] to another [[LogicalPlan]] if the passed context exists using the
+ * passed function. The original plan is returned when the context does not exist.
+ */
+ def optionalMap[C <: ParserRuleContext](
+ ctx: C)(
+ f: (C, LogicalPlan) => LogicalPlan): LogicalPlan = {
+ if (ctx != null) {
+ f(ctx, plan)
+ } else {
+ plan
+ }
+ }
}
-
}
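Two of the helpers above do most of the day-to-day work for the AST builder: unescapeSQLString
(reached via string(token)) and the optional/optionalMap sugar on LogicalPlan. A small, hedged
sketch of the former:

  import org.apache.spark.sql.catalyst.parser.ParserUtils

  // Enclosing quotes are stripped and backslash escapes are resolved.
  ParserUtils.unescapeSQLString("'hello\\nworld'")  // => "hello\nworld" (a real newline)
  ParserUtils.unescapeSQLString("\"it\\'s\"")       // => "it's"

  // The EnhancedLogicalPlan sugar keeps visitor code flat when a clause may be absent,
  // e.g. (ctx fields below are hypothetical, not the generated SqlBaseParser types):
  //   queryPlan.optionalMap(ctx.whereClause) { (where, p) => Filter(expression(where), p) }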
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/planning/patterns.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/planning/patterns.scala
index 9c927077d0..0065619135 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/planning/patterns.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/planning/patterns.scala
@@ -22,6 +22,7 @@ import scala.collection.mutable
import org.apache.spark.internal.Logging
import org.apache.spark.sql.catalyst.expressions._
+import org.apache.spark.sql.catalyst.expressions.aggregate.AggregateExpression
import org.apache.spark.sql.catalyst.plans._
import org.apache.spark.sql.catalyst.plans.logical._
import org.apache.spark.sql.types.IntegerType
@@ -68,6 +69,9 @@ object PhysicalOperation extends PredicateHelper {
val substitutedCondition = substitute(aliases)(condition)
(fields, filters ++ splitConjunctivePredicates(substitutedCondition), other, aliases)
+ case BroadcastHint(child) =>
+ collectProjectsAndFilters(child)
+
case other =>
(None, Nil, other, Map.empty)
}
@@ -139,20 +143,20 @@ object ExtractEquiJoinKeys extends Logging with PredicateHelper {
}
/**
- * A pattern that collects the filter and inner joins.
- *
- * Filter
- * |
- * inner Join
- * / \ ----> (Seq(plan0, plan1, plan2), conditions)
- * Filter plan2
- * |
- * inner join
- * / \
- * plan0 plan1
- *
- * Note: This pattern currently only works for left-deep trees.
- */
+ * A pattern that collects the filter and inner joins.
+ *
+ * Filter
+ * |
+ * inner Join
+ * / \ ----> (Seq(plan0, plan1, plan2), conditions)
+ * Filter plan2
+ * |
+ * inner join
+ * / \
+ * plan0 plan1
+ *
+ * Note: This pattern currently only works for left-deep trees.
+ */
object ExtractFiltersAndInnerJoins extends PredicateHelper {
// flatten all inner joins, which are next to each other
@@ -216,3 +220,75 @@ object IntegerIndex {
case _ => None
}
}
+
+/**
+ * An extractor used when planning the physical execution of an aggregation. Compared with a logical
+ * aggregation, the following transformations are performed:
+ * - Unnamed grouping expressions are named so that they can be referred to across phases of
+ * aggregation
+ * - Aggregations that appear multiple times are deduplicated.
+ * - The computation of the aggregations themselves is separated from the final result. For example,
+ * the `count` in `count + 1` will be split into an [[AggregateExpression]] and a final
+ * computation that computes `count.resultAttribute + 1`.
+ */
+object PhysicalAggregation {
+ // groupingExpressions, aggregateExpressions, resultExpressions, child
+ type ReturnType =
+ (Seq[NamedExpression], Seq[AggregateExpression], Seq[NamedExpression], LogicalPlan)
+
+ def unapply(a: Any): Option[ReturnType] = a match {
+ case logical.Aggregate(groupingExpressions, resultExpressions, child) =>
+ // A single aggregate expression might appear multiple times in resultExpressions.
+ // In order to avoid evaluating an individual aggregate function multiple times, we'll
+ // build a set of the distinct aggregate expressions and build a function which can
+ // be used to re-write expressions so that they reference the single copy of the
+ // aggregate function which actually gets computed.
+ val aggregateExpressions = resultExpressions.flatMap { expr =>
+ expr.collect {
+ case agg: AggregateExpression => agg
+ }
+ }.distinct
+
+ val namedGroupingExpressions = groupingExpressions.map {
+ case ne: NamedExpression => ne -> ne
+      // If the expression is not a NamedExpression, we add an alias.
+ // So, when we generate the result of the operator, the Aggregate Operator
+ // can directly get the Seq of attributes representing the grouping expressions.
+ case other =>
+ val withAlias = Alias(other, other.toString)()
+ other -> withAlias
+ }
+ val groupExpressionMap = namedGroupingExpressions.toMap
+
+ // The original `resultExpressions` are a set of expressions which may reference
+      // aggregate expressions, grouping column values, and constants. When the aggregate operator
+ // emits output rows, we will use `resultExpressions` to generate an output projection
+ // which takes the grouping columns and final aggregate result buffer as input.
+ // Thus, we must re-write the result expressions so that their attributes match up with
+ // the attributes of the final result projection's input row:
+ val rewrittenResultExpressions = resultExpressions.map { expr =>
+ expr.transformDown {
+ case ae: AggregateExpression =>
+ // The final aggregation buffer's attributes will be `finalAggregationAttributes`,
+ // so replace each aggregate expression by its corresponding attribute in the set:
+ ae.resultAttribute
+ case expression =>
+ // Since we're using `namedGroupingAttributes` to extract the grouping key
+ // columns, we need to replace grouping key expressions with their corresponding
+ // attributes. We do not rely on the equality check at here since attributes may
+ // differ cosmetically. Instead, we use semanticEquals.
+ groupExpressionMap.collectFirst {
+ case (expr, ne) if expr semanticEquals expression => ne.toAttribute
+ }.getOrElse(expression)
+ }.asInstanceOf[NamedExpression]
+ }
+
+ Some((
+ namedGroupingExpressions.map(_._2),
+ aggregateExpressions,
+ rewrittenResultExpressions,
+ child))
+
+ case _ => None
+ }
+}
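A concrete case makes the PhysicalAggregation contract above easier to follow. For
`SELECT key, count(value) + 1 FROM t GROUP BY key`, the extractor names the grouping expression,
deduplicates the single count(value) aggregate, and rewrites the result expressions against count's
resultAttribute. A hedged sketch of strategy-side usage (createAggregate and planLater are
placeholders for whatever the physical planner does with the extracted pieces):

  plan match {
    case PhysicalAggregation(groupingExprs, aggregateExprs, resultExprs, child) =>
      // groupingExprs : Seq(key)
      // aggregateExprs: Seq(count(value))                          -- deduplicated
      // resultExprs   : Seq(key, count(value).resultAttribute + 1)
      createAggregate(groupingExprs, aggregateExprs, resultExprs, planLater(child))
    case _ => Nil
  }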
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/QueryPlan.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/QueryPlan.scala
index d31164fe94..d4447ca32d 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/QueryPlan.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/QueryPlan.scala
@@ -44,25 +44,9 @@ abstract class QueryPlan[PlanType <: QueryPlan[PlanType]] extends TreeNode[PlanT
* returns a constraint of the form `isNotNull(a)`
*/
private def constructIsNotNullConstraints(constraints: Set[Expression]): Set[Expression] = {
- var isNotNullConstraints = Set.empty[Expression]
-
- // First, we propagate constraints if the condition consists of equality and ranges. For all
- // other cases, we return an empty set of constraints
- constraints.foreach {
- case EqualTo(l, r) =>
- isNotNullConstraints ++= Set(IsNotNull(l), IsNotNull(r))
- case GreaterThan(l, r) =>
- isNotNullConstraints ++= Set(IsNotNull(l), IsNotNull(r))
- case GreaterThanOrEqual(l, r) =>
- isNotNullConstraints ++= Set(IsNotNull(l), IsNotNull(r))
- case LessThan(l, r) =>
- isNotNullConstraints ++= Set(IsNotNull(l), IsNotNull(r))
- case LessThanOrEqual(l, r) =>
- isNotNullConstraints ++= Set(IsNotNull(l), IsNotNull(r))
- case Not(EqualTo(l, r)) =>
- isNotNullConstraints ++= Set(IsNotNull(l), IsNotNull(r))
- case _ => // No inference
- }
+ // First, we propagate constraints from the null intolerant expressions.
+ var isNotNullConstraints: Set[Expression] =
+ constraints.flatMap(scanNullIntolerantExpr).map(IsNotNull(_))
// Second, we infer additional constraints from non-nullable attributes that are part of the
// operator's output
@@ -73,6 +57,17 @@ abstract class QueryPlan[PlanType <: QueryPlan[PlanType]] extends TreeNode[PlanT
}
/**
+ * Recursively explores the expressions which are null intolerant and returns all attributes
+ * in these expressions.
+ */
+ private def scanNullIntolerantExpr(expr: Expression): Seq[Attribute] = expr match {
+ case a: Attribute => Seq(a)
+ case _: NullIntolerant | IsNotNull(_: NullIntolerant) =>
+ expr.children.flatMap(scanNullIntolerantExpr)
+ case _ => Seq.empty[Attribute]
+ }
+
+ /**
* Infers an additional set of constraints from a given set of equality constraints.
* For e.g., if an operator has constraints of the form (`a = 5`, `a = b`), this returns an
* additional constraint of the form `b = 5`
@@ -127,8 +122,8 @@ abstract class QueryPlan[PlanType <: QueryPlan[PlanType]] extends TreeNode[PlanT
AttributeSet(children.flatMap(_.asInstanceOf[QueryPlan[PlanType]].output))
/**
- * The set of all attributes that are produced by this node.
- */
+ * The set of all attributes that are produced by this node.
+ */
def producedAttributes: AttributeSet = AttributeSet.empty
/**
@@ -216,8 +211,10 @@ abstract class QueryPlan[PlanType <: QueryPlan[PlanType]] extends TreeNode[PlanT
if (changed) makeCopy(newArgs).asInstanceOf[this.type] else this
}
- /** Returns the result of running [[transformExpressions]] on this node
- * and all its children. */
+ /**
+ * Returns the result of running [[transformExpressions]] on this node
+ * and all its children.
+ */
def transformAllExpressions(rule: PartialFunction[Expression, Expression]): this.type = {
transform {
case q: QueryPlan[_] => q.transformExpressions(rule).asInstanceOf[PlanType]
@@ -315,18 +312,17 @@ abstract class QueryPlan[PlanType <: QueryPlan[PlanType]] extends TreeNode[PlanT
/** Args that have been cleaned such that differences in expression id should not affect equality */
protected lazy val cleanArgs: Seq[Any] = {
def cleanArg(arg: Any): Any = arg match {
+ // Children are checked using sameResult above.
+ case tn: TreeNode[_] if containsChild(tn) => null
case e: Expression => cleanExpression(e).canonicalized
case other => other
}
productIterator.map {
- // Children are checked using sameResult above.
- case tn: TreeNode[_] if containsChild(tn) => null
- case e: Expression => cleanArg(e)
case s: Option[_] => s.map(cleanArg)
case s: Seq[_] => s.map(cleanArg)
case m: Map[_, _] => m.mapValues(cleanArg)
- case other => other
+ case other => cleanArg(other)
}.toSeq
}
}
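
As a rough, self-contained sketch of what `scanNullIntolerantExpr` does (using toy expression classes rather than Catalyst's): a comparison such as `a > b + c` only yields a non-null result when every referenced attribute is non-null, so each attribute reachable through null-intolerant nodes produces an `IsNotNull` constraint, while attributes hidden behind null-tolerant expressions such as `Coalesce` do not.

    // Toy expression tree; NullIntolerant marks nodes that return null when any child is null.
    sealed trait Expr { def children: Seq[Expr] }
    trait NullIntolerant
    case class Attr(name: String) extends Expr { def children: Seq[Expr] = Nil }
    case class GreaterThan(l: Expr, r: Expr) extends Expr with NullIntolerant {
      def children: Seq[Expr] = Seq(l, r)
    }
    case class Add(l: Expr, r: Expr) extends Expr with NullIntolerant {
      def children: Seq[Expr] = Seq(l, r)
    }
    case class Coalesce(cs: Seq[Expr]) extends Expr { def children: Seq[Expr] = cs }

    def scanNullIntolerant(e: Expr): Seq[Attr] = e match {
      case a: Attr           => Seq(a)
      case _: NullIntolerant => e.children.flatMap(scanNullIntolerant)
      case _                 => Seq.empty
    }

    // a > b + c  =>  a, b and c must all be non-null for the predicate to hold.
    println(scanNullIntolerant(GreaterThan(Attr("a"), Add(Attr("b"), Attr("c")))))
    // Coalesce is null-tolerant, so nothing is inferred from its children.
    println(scanNullIntolerant(GreaterThan(Coalesce(Seq(Attr("a"))), Attr("b"))))
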
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/joinTypes.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/joinTypes.scala
index 9ca4f13dd7..13f57c54a5 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/joinTypes.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/joinTypes.scala
@@ -26,13 +26,15 @@ object JoinType {
case "leftouter" | "left" => LeftOuter
case "rightouter" | "right" => RightOuter
case "leftsemi" => LeftSemi
+ case "leftanti" => LeftAnti
case _ =>
val supported = Seq(
"inner",
"outer", "full", "fullouter",
"leftouter", "left",
"rightouter", "right",
- "leftsemi")
+ "leftsemi",
+ "leftanti")
throw new IllegalArgumentException(s"Unsupported join type '$typ'. " +
"Supported join types include: " + supported.mkString("'", "', '", "'") + ".")
@@ -63,6 +65,10 @@ case object LeftSemi extends JoinType {
override def sql: String = "LEFT SEMI"
}
+case object LeftAnti extends JoinType {
+ override def sql: String = "LEFT ANTI"
+}
+
case class NaturalJoin(tpe: JoinType) extends JoinType {
require(Seq(Inner, LeftOuter, RightOuter, FullOuter).contains(tpe),
"Unsupported natural join type " + tpe)
@@ -70,7 +76,14 @@ case class NaturalJoin(tpe: JoinType) extends JoinType {
}
case class UsingJoin(tpe: JoinType, usingColumns: Seq[UnresolvedAttribute]) extends JoinType {
- require(Seq(Inner, LeftOuter, LeftSemi, RightOuter, FullOuter).contains(tpe),
+ require(Seq(Inner, LeftOuter, LeftSemi, RightOuter, FullOuter, LeftAnti).contains(tpe),
"Unsupported using join type " + tpe)
override def sql: String = "USING " + tpe.sql
}
+
+object LeftExistence {
+ def unapply(joinType: JoinType): Option[JoinType] = joinType match {
+ case LeftSemi | LeftAnti => Some(joinType)
+ case _ => None
+ }
+}
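
The `LeftExistence` extractor added here lets planner code treat the two "existence" joins, `LeftSemi` and `LeftAnti`, uniformly; both produce only the left side's output. A minimal standalone sketch of the pattern, with a cut-down JoinType hierarchy rather than Spark's:

    sealed trait JoinType
    case object Inner extends JoinType
    case object LeftSemi extends JoinType
    case object LeftAnti extends JoinType

    // Extractor grouping the joins that only test for (non-)existence of a match.
    object LeftExistence {
      def unapply(joinType: JoinType): Option[JoinType] = joinType match {
        case LeftSemi | LeftAnti => Some(joinType)
        case _                   => None
      }
    }

    def outputSides(joinType: JoinType): String = joinType match {
      case LeftExistence(_) => "left output only"      // semi and anti joins keep only left columns
      case _                => "left and right output"
    }

    println(outputSides(LeftAnti))  // left output only
    println(outputSides(Inner))     // left and right output
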
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LogicalPlan.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LogicalPlan.scala
index ecf4285c46..aceeb8aadc 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LogicalPlan.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LogicalPlan.scala
@@ -79,13 +79,13 @@ abstract class LogicalPlan extends QueryPlan[LogicalPlan] with Logging {
/**
* Computes [[Statistics]] for this plan. The default implementation assumes the output
- * cardinality is the product of of all child plan's cardinality, i.e. applies in the case
+ * cardinality is the product of all child plan's cardinality, i.e. applies in the case
* of cartesian joins.
*
* [[LeafNode]]s must override this.
*/
def statistics: Statistics = {
- if (children.size == 0) {
+ if (children.isEmpty) {
throw new UnsupportedOperationException(s"LeafNode $nodeName must implement statistics.")
}
Statistics(sizeInBytes = children.map(_.statistics.sizeInBytes).product)
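
The default estimate above multiplies the children's size estimates, which is the cartesian-join worst case; leaf nodes have no children and must supply their own statistics. A tiny sketch of that rule with a stand-in Statistics class:

    case class Statistics(sizeInBytes: BigInt)

    // Default: assume the worst case, a cartesian product of all children.
    def defaultStatistics(children: Seq[Statistics]): Statistics = {
      require(children.nonEmpty, "a leaf node must implement statistics itself")
      Statistics(sizeInBytes = children.map(_.sizeInBytes).product)
    }

    println(defaultStatistics(Seq(Statistics(1000), Statistics(200))).sizeInBytes)  // 200000
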
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala
index 09c200fa83..d4fc9e4da9 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala
@@ -236,10 +236,24 @@ case class Union(children: Seq[LogicalPlan]) extends LogicalPlan {
})
}
+ private def merge(a: Set[Expression], b: Set[Expression]): Set[Expression] = {
+ val common = a.intersect(b)
+ // A constraint with only one reference can easily be inferred as a predicate.
+ // Group the constraints by their references so we can combine constraints that share
+ // the same reference.
+ val othera = a.diff(common).filter(_.references.size == 1).groupBy(_.references.head)
+ val otherb = b.diff(common).filter(_.references.size == 1).groupBy(_.references.head)
+ // Loosen the constraints: (A1 && B1) || (A2 && B2) -> (A1 || A2) && (B1 || B2)
+ val others = (othera.keySet intersect otherb.keySet).map { attr =>
+ Or(othera(attr).reduceLeft(And), otherb(attr).reduceLeft(And))
+ }
+ common ++ others
+ }
+
override protected def validConstraints: Set[Expression] = {
children
.map(child => rewriteConstraints(children.head.output, child.output, child.constraints))
- .reduce(_ intersect _)
+ .reduce(merge(_, _))
}
}
@@ -252,7 +266,7 @@ case class Join(
override def output: Seq[Attribute] = {
joinType match {
- case LeftSemi =>
+ case LeftExistence(_) =>
left.output
case LeftOuter =>
left.output ++ right.output.map(_.withNullability(true))
@@ -276,7 +290,7 @@ case class Join(
.union(splitConjunctivePredicates(condition.get).toSet)
case Inner =>
left.constraints.union(right.constraints)
- case LeftSemi =>
+ case LeftExistence(_) =>
left.constraints
case LeftOuter =>
left.constraints
@@ -519,7 +533,6 @@ case class Expand(
projections: Seq[Seq[Expression]],
output: Seq[Attribute],
child: LogicalPlan) extends UnaryNode {
-
override def references: AttributeSet =
AttributeSet(projections.flatten.flatMap(_.references))
@@ -527,6 +540,10 @@ case class Expand(
val sizeInBytes = super.statistics.sizeInBytes * projections.length
Statistics(sizeInBytes = sizeInBytes)
}
+
+ // This operator can reuse attributes (for example, making them null when doing a rollup), so
+ // the constraints of the child may no longer be valid.
+ override protected def validConstraints: Set[Expression] = Set.empty[Expression]
}
/**
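
To make the `merge` above concrete, here is a toy version that operates on strings instead of Catalyst expressions (and, unlike the real rule, does not restrict itself to single-reference constraints). A row of the union comes from one child or the other, so constraints common to both are kept, and the remaining per-attribute constraints are loosened into a disjunction:

    def merge(a: Set[String], b: Set[String], refOf: String => String): Set[String] = {
      val common = a.intersect(b)
      val othera = a.diff(common).groupBy(refOf)
      val otherb = b.diff(common).groupBy(refOf)
      val others = (othera.keySet intersect otherb.keySet).map { attr =>
        s"((${othera(attr).mkString(" && ")}) || (${otherb(attr).mkString(" && ")}))"
      }
      common ++ others
    }

    val left  = Set("x > 0", "x < 10", "y = 1")
    val right = Set("x > 5", "y = 1")
    // "y = 1" holds for both children and survives as-is; the x-constraints are loosened to
    // ((x > 0 && x < 10) || (x > 5)).  (Set ordering of the conjuncts may vary.)
    println(merge(left, right, _.takeWhile(_ != ' ')))
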
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/object.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/object.scala
index da7f81c785..6df46189b6 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/object.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/object.scala
@@ -18,9 +18,45 @@
package org.apache.spark.sql.catalyst.plans.logical
import org.apache.spark.sql.Encoder
+import org.apache.spark.sql.catalyst.analysis.UnresolvedDeserializer
import org.apache.spark.sql.catalyst.encoders._
import org.apache.spark.sql.catalyst.expressions._
-import org.apache.spark.sql.types.{ObjectType, StructType}
+import org.apache.spark.sql.types.{DataType, ObjectType, StructType}
+
+object CatalystSerde {
+ def deserialize[T : Encoder](child: LogicalPlan): DeserializeToObject = {
+ val deserializer = UnresolvedDeserializer(encoderFor[T].deserializer)
+ DeserializeToObject(Alias(deserializer, "obj")(), child)
+ }
+
+ def serialize[T : Encoder](child: LogicalPlan): SerializeFromObject = {
+ SerializeFromObject(encoderFor[T].namedExpressions, child)
+ }
+}
+
+/**
+ * Takes the input row from the child and turns it into an object using the given deserializer expression.
+ * The output of this operator is a single-field safe row containing the deserialized object.
+ */
+case class DeserializeToObject(
+ deserializer: Alias,
+ child: LogicalPlan) extends UnaryNode {
+ override def output: Seq[Attribute] = deserializer.toAttribute :: Nil
+
+ def outputObjectType: DataType = deserializer.dataType
+}
+
+/**
+ * Takes the input object from the child and turns it into an unsafe row using the given serializer
+ * expression. The output of its child must be a single-field row containing the input object.
+ */
+case class SerializeFromObject(
+ serializer: Seq[NamedExpression],
+ child: LogicalPlan) extends UnaryNode {
+ override def output: Seq[Attribute] = serializer.map(_.toAttribute)
+
+ def inputObjectType: DataType = child.output.head.dataType
+}
/**
* A trait for logical operators that apply user defined functions to domain objects.
@@ -33,13 +69,6 @@ trait ObjectOperator extends LogicalPlan {
override def output: Seq[Attribute] = serializer.map(_.toAttribute)
/**
- * An [[ObjectOperator]] may have one or more deserializers to convert internal rows to objects.
- * It must also provide the attributes that are available during the resolution of each
- * deserializer.
- */
- def deserializers: Seq[(Expression, Seq[Attribute])]
-
- /**
* The object type that is produced by the user defined function. Note that the return type here
 * is the same whether or not the operator outputs serialized data.
*/
@@ -71,7 +100,7 @@ object MapPartitions {
child: LogicalPlan): MapPartitions = {
MapPartitions(
func.asInstanceOf[Iterator[Any] => Iterator[Any]],
- encoderFor[T].fromRowExpression,
+ UnresolvedDeserializer(encoderFor[T].deserializer),
encoderFor[U].namedExpressions,
child)
}
@@ -87,10 +116,32 @@ case class MapPartitions(
func: Iterator[Any] => Iterator[Any],
deserializer: Expression,
serializer: Seq[NamedExpression],
- child: LogicalPlan) extends UnaryNode with ObjectOperator {
- override def deserializers: Seq[(Expression, Seq[Attribute])] = Seq(deserializer -> child.output)
+ child: LogicalPlan) extends UnaryNode with ObjectOperator
+
+object MapElements {
+ def apply[T : Encoder, U : Encoder](
+ func: AnyRef,
+ child: LogicalPlan): MapElements = {
+ MapElements(
+ func,
+ UnresolvedDeserializer(encoderFor[T].deserializer),
+ encoderFor[U].namedExpressions,
+ child)
+ }
}
+/**
+ * A relation produced by applying `func` to each element of the `child`.
+ *
+ * @param deserializer used to extract the input to `func` from an input row.
+ * @param serializer used to serialize the output of `func`.
+ */
+case class MapElements(
+ func: AnyRef,
+ deserializer: Expression,
+ serializer: Seq[NamedExpression],
+ child: LogicalPlan) extends UnaryNode with ObjectOperator
+
/** Factory for constructing new `AppendColumn` nodes. */
object AppendColumns {
def apply[T : Encoder, U : Encoder](
@@ -98,7 +149,7 @@ object AppendColumns {
child: LogicalPlan): AppendColumns = {
new AppendColumns(
func.asInstanceOf[Any => Any],
- encoderFor[T].fromRowExpression,
+ UnresolvedDeserializer(encoderFor[T].deserializer),
encoderFor[U].namedExpressions,
child)
}
@@ -120,8 +171,6 @@ case class AppendColumns(
override def output: Seq[Attribute] = child.output ++ newColumns
def newColumns: Seq[Attribute] = serializer.map(_.toAttribute)
-
- override def deserializers: Seq[(Expression, Seq[Attribute])] = Seq(deserializer -> child.output)
}
/** Factory for constructing new `MapGroups` nodes. */
@@ -133,8 +182,8 @@ object MapGroups {
child: LogicalPlan): MapGroups = {
new MapGroups(
func.asInstanceOf[(Any, Iterator[Any]) => TraversableOnce[Any]],
- encoderFor[K].fromRowExpression,
- encoderFor[T].fromRowExpression,
+ UnresolvedDeserializer(encoderFor[K].deserializer, groupingAttributes),
+ UnresolvedDeserializer(encoderFor[T].deserializer, dataAttributes),
encoderFor[U].namedExpressions,
groupingAttributes,
dataAttributes,
@@ -158,11 +207,7 @@ case class MapGroups(
serializer: Seq[NamedExpression],
groupingAttributes: Seq[Attribute],
dataAttributes: Seq[Attribute],
- child: LogicalPlan) extends UnaryNode with ObjectOperator {
-
- override def deserializers: Seq[(Expression, Seq[Attribute])] =
- Seq(keyDeserializer -> groupingAttributes, valueDeserializer -> dataAttributes)
-}
+ child: LogicalPlan) extends UnaryNode with ObjectOperator
/** Factory for constructing new `CoGroup` nodes. */
object CoGroup {
@@ -170,22 +215,24 @@ object CoGroup {
func: (Key, Iterator[Left], Iterator[Right]) => TraversableOnce[Result],
leftGroup: Seq[Attribute],
rightGroup: Seq[Attribute],
- leftData: Seq[Attribute],
- rightData: Seq[Attribute],
+ leftAttr: Seq[Attribute],
+ rightAttr: Seq[Attribute],
left: LogicalPlan,
right: LogicalPlan): CoGroup = {
require(StructType.fromAttributes(leftGroup) == StructType.fromAttributes(rightGroup))
CoGroup(
func.asInstanceOf[(Any, Iterator[Any], Iterator[Any]) => TraversableOnce[Any]],
- encoderFor[Key].fromRowExpression,
- encoderFor[Left].fromRowExpression,
- encoderFor[Right].fromRowExpression,
+ // The `leftGroup` and `rightGroup` are guaranteed to be of the same schema, so it's safe to
+ // resolve the `keyDeserializer` based on either of them; here we pick the left one.
+ UnresolvedDeserializer(encoderFor[Key].deserializer, leftGroup),
+ UnresolvedDeserializer(encoderFor[Left].deserializer, leftAttr),
+ UnresolvedDeserializer(encoderFor[Right].deserializer, rightAttr),
encoderFor[Result].namedExpressions,
leftGroup,
rightGroup,
- leftData,
- rightData,
+ leftAttr,
+ rightAttr,
left,
right)
}
@@ -206,10 +253,4 @@ case class CoGroup(
leftAttr: Seq[Attribute],
rightAttr: Seq[Attribute],
left: LogicalPlan,
- right: LogicalPlan) extends BinaryNode with ObjectOperator {
-
- override def deserializers: Seq[(Expression, Seq[Attribute])] =
- // The `leftGroup` and `rightGroup` are guaranteed te be of same schema, so it's safe to resolve
- // the `keyDeserializer` based on either of them, here we pick the left one.
- Seq(keyDeserializer -> leftGroup, leftDeserializer -> leftAttr, rightDeserializer -> rightAttr)
-}
+ right: LogicalPlan) extends BinaryNode with ObjectOperator
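
Conceptually, the helpers in `CatalystSerde` plan a typed operation as a deserialize -> user function -> serialize sandwich: `DeserializeToObject` turns rows into objects, the function runs on the objects, and `SerializeFromObject` turns the results back into rows. A rough model of that flow with stand-in Row and Encoder types, not Spark's actual classes:

    case class Row(fields: Seq[Any])

    trait Encoder[T] {
      def deserialize(row: Row): T   // plays the role of DeserializeToObject
      def serialize(obj: T): Row     // plays the role of SerializeFromObject
    }

    object StringEncoder extends Encoder[String] {
      def deserialize(row: Row): String = row.fields.head.asInstanceOf[String]
      def serialize(obj: String): Row = Row(Seq(obj))
    }

    // A typed map planned as the sandwich: rows in -> objects -> func -> objects -> rows out.
    def mapElements[T](rows: Seq[Row], enc: Encoder[T])(func: T => T): Seq[Row] =
      rows.map(enc.deserialize).map(func).map(enc.serialize)

    println(mapElements(Seq(Row(Seq("spark"))), StringEncoder)(_.toUpperCase))
    // List(Row(List(SPARK)))
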
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/physical/partitioning.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/physical/partitioning.scala
index be9f1ffa22..d449088498 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/physical/partitioning.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/physical/partitioning.scala
@@ -76,9 +76,9 @@ case class OrderedDistribution(ordering: Seq[SortOrder]) extends Distribution {
}
/**
- * Represents data where tuples are broadcasted to every node. It is quite common that the
- * entire set of tuples is transformed into different data structure.
- */
+ * Represents data where tuples are broadcast to every node. It is quite common that the
+ * entire set of tuples is transformed into a different data structure.
+ */
case class BroadcastDistribution(mode: BroadcastMode) extends Distribution
/**
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/trees/TreeNode.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/trees/TreeNode.scala
index 6b7997e903..232ca43588 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/trees/TreeNode.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/trees/TreeNode.scala
@@ -22,6 +22,7 @@ import java.util.UUID
import scala.collection.Map
import scala.collection.mutable.Stack
+import org.apache.commons.lang.ClassUtils
import org.json4s.JsonAST._
import org.json4s.JsonDSL._
import org.json4s.jackson.JsonMethods._
@@ -365,20 +366,32 @@ abstract class TreeNode[BaseType <: TreeNode[BaseType]] extends Product {
* @param newArgs the new product arguments.
*/
def makeCopy(newArgs: Array[AnyRef]): BaseType = attachTree(this, "makeCopy") {
+ // Skip no-arg constructors that are just there for kryo.
val ctors = getClass.getConstructors.filter(_.getParameterTypes.size != 0)
if (ctors.isEmpty) {
sys.error(s"No valid constructor for $nodeName")
}
- val defaultCtor = ctors.maxBy(_.getParameterTypes.size)
+ val allArgs: Array[AnyRef] = if (otherCopyArgs.isEmpty) {
+ newArgs
+ } else {
+ newArgs ++ otherCopyArgs
+ }
+ val defaultCtor = ctors.find { ctor =>
+ if (ctor.getParameterTypes.length != allArgs.length) {
+ false
+ } else if (allArgs.contains(null)) {
+ // if there is a `null`, we can't figure out the class, so we should just fall back
+ // to the older heuristic
+ false
+ } else {
+ val argsArray: Array[Class[_]] = allArgs.map(_.getClass)
+ ClassUtils.isAssignable(argsArray, ctor.getParameterTypes, true /* autoboxing */)
+ }
+ }.getOrElse(ctors.maxBy(_.getParameterTypes.length)) // fall back to the older heuristic
try {
CurrentOrigin.withOrigin(origin) {
- // Skip no-arg constructors that are just there for kryo.
- if (otherCopyArgs.isEmpty) {
- defaultCtor.newInstance(newArgs: _*).asInstanceOf[BaseType]
- } else {
- defaultCtor.newInstance((newArgs ++ otherCopyArgs).toArray: _*).asInstanceOf[BaseType]
- }
+ defaultCtor.newInstance(allArgs.toArray: _*).asInstanceOf[BaseType]
}
} catch {
case e: java.lang.IllegalArgumentException =>
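
The new `makeCopy` logic selects the constructor whose parameter types are assignable from the runtime classes of the supplied arguments (via commons-lang `ClassUtils.isAssignable` with autoboxing enabled), and falls back to the widest constructor when any argument is null. A condensed sketch of that selection, assuming commons-lang is on the classpath:

    import org.apache.commons.lang.ClassUtils

    // Prefer a constructor whose parameter types can accept the arguments' runtime classes;
    // a null argument hides its class, so in that case use the widest constructor instead.
    def chooseCtor(cls: Class[_], args: Array[AnyRef]): java.lang.reflect.Constructor[_] = {
      val ctors = cls.getConstructors.filter(_.getParameterTypes.nonEmpty)
      ctors.find { ctor =>
        if (ctor.getParameterTypes.length != args.length || args.contains(null)) {
          false
        } else {
          val argTypes: Array[Class[_]] = args.map(_.getClass)
          ClassUtils.isAssignable(argTypes, ctor.getParameterTypes, true /* autoboxing */)
        }
      }.getOrElse(ctors.maxBy(_.getParameterTypes.length))
    }
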
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/StringKeyHashMap.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/StringKeyHashMap.scala
index 191d5e6399..d5d151a580 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/StringKeyHashMap.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/StringKeyHashMap.scala
@@ -41,4 +41,6 @@ class StringKeyHashMap[T](normalizer: (String) => String) {
def remove(key: String): Option[T] = base.remove(normalizer(key))
def iterator: Iterator[(String, T)] = base.toIterator
+
+ def clear(): Unit = base.clear()
}
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/StringUtils.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/StringUtils.scala
index c2eeb3c565..cde8bd5b96 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/StringUtils.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/StringUtils.scala
@@ -17,7 +17,7 @@
package org.apache.spark.sql.catalyst.util
-import java.util.regex.Pattern
+import java.util.regex.{Pattern, PatternSyntaxException}
import org.apache.spark.unsafe.types.UTF8String
@@ -52,4 +52,25 @@ object StringUtils {
def isTrueString(s: UTF8String): Boolean = trueStrings.contains(s.toLowerCase)
def isFalseString(s: UTF8String): Boolean = falseStrings.contains(s.toLowerCase)
+
+ /**
+ * This utility can be used to filter a list of names by a pattern, as in the "LIKE" clause of
+ * the "SHOW TABLES / FUNCTIONS" DDL.
+ * @param names the list of names to be filtered
+ * @param pattern the filter pattern; only '*' and '|' are allowed as wildcards, other characters
+ * follow regular expression conventions; the match is case insensitive and white space
+ * on both ends of the pattern is ignored
+ * @return the filtered list of names, in sorted order
+ */
+ def filterPattern(names: Seq[String], pattern: String): Seq[String] = {
+ val funcNames = scala.collection.mutable.SortedSet.empty[String]
+ pattern.trim().split("\\|").foreach { subPattern =>
+ try {
+ val regex = ("(?i)" + subPattern.replaceAll("\\*", ".*")).r
+ funcNames ++= names.filter{ name => regex.pattern.matcher(name).matches() }
+ } catch {
+ case _: PatternSyntaxException =>
+ }
+ }
+ funcNames.toSeq
+ }
}
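
Assuming the catalyst module is on the classpath, a quick usage sketch of the new `filterPattern`: '*' behaves like the regex '.*', '|' separates alternative patterns, matching is case-insensitive, and the result comes back sorted:

    import org.apache.spark.sql.catalyst.util.StringUtils

    val names = Seq("db1_tbl1", "db1_tbl2", "db2_tbl1", "temp_view")
    StringUtils.filterPattern(names, "db1_*|TEMP*")
    // Seq("db1_tbl1", "db1_tbl2", "temp_view")
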
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/package.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/package.scala
index b11365b297..f879b34358 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/package.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/package.scala
@@ -155,10 +155,13 @@ package object util {
/**
* Returns the string representation of this expression that is safe to be put in
- * code comments of generated code.
+ * code comments of generated code. The length is capped at 128 characters.
*/
- def toCommentSafeString(str: String): String =
- str.replace("*/", "\\*\\/").replace("\\u", "\\\\u")
+ def toCommentSafeString(str: String): String = {
+ val len = math.min(str.length, 128)
+ val suffix = if (str.length > len) "..." else ""
+ str.substring(0, len).replace("*/", "\\*\\/").replace("\\u", "\\\\u") + suffix
+ }
/* FIX ME
implicit class debugLogging(a: Any) {
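
For illustration, a standalone re-implementation of the capped escaping above: "*/" is escaped so a value embedded in a generated-code comment cannot terminate the comment, "\u" is escaped so it cannot smuggle in a unicode escape, and anything beyond 128 characters is truncated with an ellipsis.

    def toCommentSafeString(str: String): String = {
      val len = math.min(str.length, 128)
      val suffix = if (str.length > len) "..." else ""
      str.substring(0, len).replace("*/", "\\*\\/").replace("\\u", "\\\\u") + suffix
    }

    println(toCommentSafeString("end of comment */ not really"))  // end of comment \*\/ not really
    println(toCommentSafeString("x" * 200).length)                // 131 = 128 chars + "..."
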
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/UserDefinedType.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/UserDefinedType.scala
index dabf9a2fc0..fb7251d71b 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/UserDefinedType.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/UserDefinedType.scala
@@ -23,7 +23,6 @@ import org.json4s.JsonDSL._
import org.apache.spark.annotation.DeveloperApi
/**
- * ::DeveloperApi::
* The data type for User Defined Types (UDTs).
*
* This interface allows a user to make their own classes more interoperable with SparkSQL;
@@ -35,8 +34,11 @@ import org.apache.spark.annotation.DeveloperApi
*
* The conversion via `serialize` occurs when instantiating a `DataFrame` from another RDD.
* The conversion via `deserialize` occurs when reading from a `DataFrame`.
+ *
+ * Note: This was previously a developer API in Spark 1.x. We are making this private in Spark 2.0
+ * because we will very likely create a new version of this that works better with Datasets.
*/
-@DeveloperApi
+private[spark]
abstract class UserDefinedType[UserType >: Null] extends DataType with Serializable {
/** Underlying storage type for this UDT */
diff --git a/sql/catalyst/src/test/resources/log4j.properties b/sql/catalyst/src/test/resources/log4j.properties
index eb3b1999eb..3706a6e361 100644
--- a/sql/catalyst/src/test/resources/log4j.properties
+++ b/sql/catalyst/src/test/resources/log4j.properties
@@ -24,5 +24,4 @@ log4j.appender.file.layout=org.apache.log4j.PatternLayout
log4j.appender.file.layout.ConversionPattern=%d{yy/MM/dd HH:mm:ss.SSS} %t %p %c{1}: %m%n
# Ignore messages below warning level from Jetty, because it's a bit verbose
-log4j.logger.org.spark-project.jetty=WARN
-org.spark-project.jetty.LEVEL=WARN
+log4j.logger.org.spark_project.jetty=WARN
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/RandomDataGenerator.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/RandomDataGenerator.scala
index 8207d64798..711e870711 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/RandomDataGenerator.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/RandomDataGenerator.scala
@@ -196,12 +196,11 @@ object RandomDataGenerator {
case ShortType => randomNumeric[Short](
rand, _.nextInt().toShort, Seq(Short.MinValue, Short.MaxValue, 0.toShort))
case NullType => Some(() => null)
- case ArrayType(elementType, containsNull) => {
+ case ArrayType(elementType, containsNull) =>
forType(elementType, nullable = containsNull, rand).map {
elementGenerator => () => Seq.fill(rand.nextInt(MAX_ARR_SIZE))(elementGenerator())
}
- }
- case MapType(keyType, valueType, valueContainsNull) => {
+ case MapType(keyType, valueType, valueContainsNull) =>
for (
keyGenerator <- forType(keyType, nullable = false, rand);
valueGenerator <-
@@ -221,8 +220,7 @@ object RandomDataGenerator {
keys.zip(values).toMap
}
}
- }
- case StructType(fields) => {
+ case StructType(fields) =>
val maybeFieldGenerators: Seq[Option[() => Any]] = fields.map { field =>
forType(field.dataType, nullable = field.nullable, rand)
}
@@ -232,8 +230,7 @@ object RandomDataGenerator {
} else {
None
}
- }
- case udt: UserDefinedType[_] => {
+ case udt: UserDefinedType[_] =>
val maybeSqlTypeGenerator = forType(udt.sqlType, nullable, rand)
// Because random data generator at here returns scala value, we need to
// convert it to catalyst value to call udt's deserialize.
@@ -253,7 +250,6 @@ object RandomDataGenerator {
} else {
None
}
- }
case unsupportedType => None
}
// Handle nullability by wrapping the non-null value generator:
@@ -277,7 +273,7 @@ object RandomDataGenerator {
val fields = mutable.ArrayBuffer.empty[Any]
schema.fields.foreach { f =>
f.dataType match {
- case ArrayType(childType, nullable) => {
+ case ArrayType(childType, nullable) =>
val data = if (f.nullable && rand.nextFloat() <= PROBABILITY_OF_NULL) {
null
} else {
@@ -294,10 +290,8 @@ object RandomDataGenerator {
arr
}
fields += data
- }
- case StructType(children) => {
+ case StructType(children) =>
fields += randomRow(rand, StructType(children))
- }
case _ =>
val generator = RandomDataGenerator.forType(f.dataType, f.nullable, rand)
assert(generator.isDefined, "Unsupported type")
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/RowTest.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/RowTest.scala
index d9577dea1b..c9c9599e7f 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/RowTest.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/RowTest.scala
@@ -121,7 +121,7 @@ class RowTest extends FunSpec with Matchers {
externalRow should be theSameInstanceAs externalRow.copy()
}
- it("copy should return same ref for interal rows") {
+ it("copy should return same ref for internal rows") {
internalRow should be theSameInstanceAs internalRow.copy()
}
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/ScalaReflectionSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/ScalaReflectionSuite.scala
index dd31050bb5..5ca5a72512 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/ScalaReflectionSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/ScalaReflectionSuite.scala
@@ -248,10 +248,10 @@ class ScalaReflectionSuite extends SparkFunSuite {
Seq(
("mirror", () => mirror),
("dataTypeFor", () => dataTypeFor[ComplexData]),
- ("constructorFor", () => constructorFor[ComplexData]),
+ ("constructorFor", () => deserializerFor[ComplexData]),
("extractorsFor", {
val inputObject = BoundReference(0, dataTypeForComplexData, nullable = false)
- () => extractorsFor[ComplexData](inputObject)
+ () => serializerFor[ComplexData](inputObject)
}),
("getConstructorParameters(cls)", () => getConstructorParameters(classOf[ComplexData])),
("getConstructorParameterNames", () => getConstructorParameterNames(classOf[ComplexData])),
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisErrorSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisErrorSuite.scala
index a90dfc5039..ad101d1c40 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisErrorSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisErrorSuite.scala
@@ -272,6 +272,62 @@ class AnalysisErrorSuite extends AnalysisTest {
testRelation2.where('bad_column > 1).groupBy('a)(UnresolvedAlias(max('b))),
"cannot resolve '`bad_column`'" :: Nil)
+ errorTest(
+ "slide duration greater than window in time window",
+ testRelation2.select(
+ TimeWindow(Literal("2016-01-01 01:01:01"), "1 second", "2 second", "0 second").as("window")),
+ s"The slide duration " :: " must be less than or equal to the windowDuration " :: Nil
+ )
+
+ errorTest(
+ "start time greater than slide duration in time window",
+ testRelation.select(
+ TimeWindow(Literal("2016-01-01 01:01:01"), "1 second", "1 second", "1 minute").as("window")),
+ "The start time " :: " must be less than the slideDuration " :: Nil
+ )
+
+ errorTest(
+ "start time equal to slide duration in time window",
+ testRelation.select(
+ TimeWindow(Literal("2016-01-01 01:01:01"), "1 second", "1 second", "1 second").as("window")),
+ "The start time " :: " must be less than the slideDuration " :: Nil
+ )
+
+ errorTest(
+ "negative window duration in time window",
+ testRelation.select(
+ TimeWindow(Literal("2016-01-01 01:01:01"), "-1 second", "1 second", "0 second").as("window")),
+ "The window duration " :: " must be greater than 0." :: Nil
+ )
+
+ errorTest(
+ "zero window duration in time window",
+ testRelation.select(
+ TimeWindow(Literal("2016-01-01 01:01:01"), "0 second", "1 second", "0 second").as("window")),
+ "The window duration " :: " must be greater than 0." :: Nil
+ )
+
+ errorTest(
+ "negative slide duration in time window",
+ testRelation.select(
+ TimeWindow(Literal("2016-01-01 01:01:01"), "1 second", "-1 second", "0 second").as("window")),
+ "The slide duration " :: " must be greater than 0." :: Nil
+ )
+
+ errorTest(
+ "zero slide duration in time window",
+ testRelation.select(
+ TimeWindow(Literal("2016-01-01 01:01:01"), "1 second", "0 second", "0 second").as("window")),
+ "The slide duration" :: " must be greater than 0." :: Nil
+ )
+
+ errorTest(
+ "negative start time in time window",
+ testRelation.select(
+ TimeWindow(Literal("2016-01-01 01:01:01"), "1 second", "1 second", "-5 second").as("window")),
+ "The start time" :: "must be greater than or equal to 0." :: Nil
+ )
+
test("SPARK-6452 regression test") {
// CheckAnalysis should throw AnalysisException when Aggregate contains missing attribute(s)
// Since we manually construct the logical plan at here and Sum only accept
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisTest.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisTest.scala
index 6fa4beed99..b1fcf011f4 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisTest.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisTest.scala
@@ -30,9 +30,9 @@ trait AnalysisTest extends PlanTest {
private def makeAnalyzer(caseSensitive: Boolean): Analyzer = {
val conf = new SimpleCatalystConf(caseSensitive)
- val catalog = new SessionCatalog(new InMemoryCatalog, conf)
- catalog.createTempTable("TaBlE", TestRelations.testRelation, ignoreIfExists = true)
- new Analyzer(catalog, EmptyFunctionRegistry, conf) {
+ val catalog = new SessionCatalog(new InMemoryCatalog, EmptyFunctionRegistry, conf)
+ catalog.createTempTable("TaBlE", TestRelations.testRelation, overrideIfExists = true)
+ new Analyzer(catalog, conf) {
override val extendedResolutionRules = EliminateSubqueryAliases :: Nil
}
}
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/DecimalPrecisionSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/DecimalPrecisionSuite.scala
index 31501864a8..b3b1f5b920 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/DecimalPrecisionSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/DecimalPrecisionSuite.scala
@@ -32,8 +32,8 @@ import org.apache.spark.sql.types._
class DecimalPrecisionSuite extends PlanTest with BeforeAndAfter {
private val conf = new SimpleCatalystConf(caseSensitiveAnalysis = true)
- private val catalog = new SessionCatalog(new InMemoryCatalog, conf)
- private val analyzer = new Analyzer(catalog, EmptyFunctionRegistry, conf)
+ private val catalog = new SessionCatalog(new InMemoryCatalog, EmptyFunctionRegistry, conf)
+ private val analyzer = new Analyzer(catalog, conf)
private val relation = LocalRelation(
AttributeReference("i", IntegerType)(),
@@ -52,7 +52,7 @@ class DecimalPrecisionSuite extends PlanTest with BeforeAndAfter {
private val b: Expression = UnresolvedAttribute("b")
before {
- catalog.createTempTable("table", relation, ignoreIfExists = true)
+ catalog.createTempTable("table", relation, overrideIfExists = true)
}
private def checkType(expression: Expression, expectedType: DataType): Unit = {
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/catalog/CatalogTestCases.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/catalog/CatalogTestCases.scala
index 277c2d717e..f961fe3292 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/catalog/CatalogTestCases.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/catalog/CatalogTestCases.scala
@@ -149,6 +149,15 @@ abstract class CatalogTestCases extends SparkFunSuite with BeforeAndAfterEach {
// Tables
// --------------------------------------------------------------------------
+ test("the table type of an external table should be EXTERNAL_TABLE") {
+ val catalog = newBasicCatalog()
+ val table =
+ newTable("external_table1", "db2").copy(tableType = CatalogTableType.EXTERNAL_TABLE)
+ catalog.createTable("db2", table, ignoreIfExists = false)
+ val actual = catalog.getTable("db2", "external_table1")
+ assert(actual.tableType === CatalogTableType.EXTERNAL_TABLE)
+ }
+
test("drop table") {
val catalog = newBasicCatalog()
assert(catalog.listTables("db2").toSet == Set("tbl1", "tbl2"))
@@ -210,7 +219,7 @@ abstract class CatalogTestCases extends SparkFunSuite with BeforeAndAfterEach {
}
test("get table") {
- assert(newBasicCatalog().getTable("db2", "tbl1").name.table == "tbl1")
+ assert(newBasicCatalog().getTable("db2", "tbl1").identifier.table == "tbl1")
}
test("get table when database/table does not exist") {
@@ -272,31 +281,37 @@ abstract class CatalogTestCases extends SparkFunSuite with BeforeAndAfterEach {
test("drop partitions") {
val catalog = newBasicCatalog()
assert(catalogPartitionsEqual(catalog, "db2", "tbl2", Seq(part1, part2)))
- catalog.dropPartitions("db2", "tbl2", Seq(part1.spec), ignoreIfNotExists = false)
+ catalog.dropPartitions(
+ "db2", "tbl2", Seq(part1.spec), ignoreIfNotExists = false)
assert(catalogPartitionsEqual(catalog, "db2", "tbl2", Seq(part2)))
resetState()
val catalog2 = newBasicCatalog()
assert(catalogPartitionsEqual(catalog2, "db2", "tbl2", Seq(part1, part2)))
- catalog2.dropPartitions("db2", "tbl2", Seq(part1.spec, part2.spec), ignoreIfNotExists = false)
+ catalog2.dropPartitions(
+ "db2", "tbl2", Seq(part1.spec, part2.spec), ignoreIfNotExists = false)
assert(catalog2.listPartitions("db2", "tbl2").isEmpty)
}
test("drop partitions when database/table does not exist") {
val catalog = newBasicCatalog()
intercept[AnalysisException] {
- catalog.dropPartitions("does_not_exist", "tbl1", Seq(), ignoreIfNotExists = false)
+ catalog.dropPartitions(
+ "does_not_exist", "tbl1", Seq(), ignoreIfNotExists = false)
}
intercept[AnalysisException] {
- catalog.dropPartitions("db2", "does_not_exist", Seq(), ignoreIfNotExists = false)
+ catalog.dropPartitions(
+ "db2", "does_not_exist", Seq(), ignoreIfNotExists = false)
}
}
test("drop partitions that do not exist") {
val catalog = newBasicCatalog()
intercept[AnalysisException] {
- catalog.dropPartitions("db2", "tbl2", Seq(part3.spec), ignoreIfNotExists = false)
+ catalog.dropPartitions(
+ "db2", "tbl2", Seq(part3.spec), ignoreIfNotExists = false)
}
- catalog.dropPartitions("db2", "tbl2", Seq(part3.spec), ignoreIfNotExists = true)
+ catalog.dropPartitions(
+ "db2", "tbl2", Seq(part3.spec), ignoreIfNotExists = true)
}
test("get partition") {
@@ -433,7 +448,8 @@ abstract class CatalogTestCases extends SparkFunSuite with BeforeAndAfterEach {
test("get function") {
val catalog = newBasicCatalog()
assert(catalog.getFunction("db2", "func1") ==
- CatalogFunction(FunctionIdentifier("func1", Some("db2")), funcClass))
+ CatalogFunction(FunctionIdentifier("func1", Some("db2")), funcClass,
+ Seq.empty[(String, String)]))
intercept[AnalysisException] {
catalog.getFunction("db2", "does_not_exist")
}
@@ -452,7 +468,7 @@ abstract class CatalogTestCases extends SparkFunSuite with BeforeAndAfterEach {
assert(catalog.getFunction("db2", "func1").className == funcClass)
catalog.renameFunction("db2", "func1", newName)
intercept[AnalysisException] { catalog.getFunction("db2", "func1") }
- assert(catalog.getFunction("db2", newName).name.funcName == newName)
+ assert(catalog.getFunction("db2", newName).identifier.funcName == newName)
assert(catalog.getFunction("db2", newName).className == funcClass)
intercept[AnalysisException] { catalog.renameFunction("db2", "does_not_exist", "me") }
}
@@ -464,21 +480,6 @@ abstract class CatalogTestCases extends SparkFunSuite with BeforeAndAfterEach {
}
}
- test("alter function") {
- val catalog = newBasicCatalog()
- assert(catalog.getFunction("db2", "func1").className == funcClass)
- catalog.alterFunction("db2", newFunc("func1").copy(className = "muhaha"))
- assert(catalog.getFunction("db2", "func1").className == "muhaha")
- intercept[AnalysisException] { catalog.alterFunction("db2", newFunc("funcky")) }
- }
-
- test("alter function when database does not exist") {
- val catalog = newBasicCatalog()
- intercept[AnalysisException] {
- catalog.alterFunction("does_not_exist", newFunc())
- }
- }
-
test("list functions") {
val catalog = newBasicCatalog()
catalog.createFunction("db2", newFunc("func2"))
@@ -549,15 +550,19 @@ abstract class CatalogTestUtils {
def newTable(name: String, database: Option[String] = None): CatalogTable = {
CatalogTable(
- name = TableIdentifier(name, database),
+ identifier = TableIdentifier(name, database),
tableType = CatalogTableType.EXTERNAL_TABLE,
storage = storageFormat,
- schema = Seq(CatalogColumn("col1", "int"), CatalogColumn("col2", "string")),
- partitionColumns = Seq(CatalogColumn("a", "int"), CatalogColumn("b", "string")))
+ schema = Seq(
+ CatalogColumn("col1", "int"),
+ CatalogColumn("col2", "string"),
+ CatalogColumn("a", "int"),
+ CatalogColumn("b", "string")),
+ partitionColumnNames = Seq("a", "b"))
}
def newFunc(name: String, database: Option[String] = None): CatalogFunction = {
- CatalogFunction(FunctionIdentifier(name, database), funcClass)
+ CatalogFunction(FunctionIdentifier(name, database), funcClass, Seq.empty[(String, String)])
}
/**
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/catalog/SessionCatalogSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/catalog/SessionCatalogSuite.scala
index 74e995cc5b..426273e1e3 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/catalog/SessionCatalogSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/catalog/SessionCatalogSuite.scala
@@ -20,6 +20,7 @@ package org.apache.spark.sql.catalyst.catalog
import org.apache.spark.SparkFunSuite
import org.apache.spark.sql.AnalysisException
import org.apache.spark.sql.catalyst.{FunctionIdentifier, TableIdentifier}
+import org.apache.spark.sql.catalyst.expressions.{Expression, ExpressionInfo, Literal}
import org.apache.spark.sql.catalyst.plans.logical.{Range, SubqueryAlias}
@@ -61,7 +62,7 @@ class SessionCatalogSuite extends SparkFunSuite {
test("get database when a database exists") {
val catalog = new SessionCatalog(newBasicCatalog())
- val db1 = catalog.getDatabase("db1")
+ val db1 = catalog.getDatabaseMetadata("db1")
assert(db1.name == "db1")
assert(db1.description.contains("db1"))
}
@@ -69,7 +70,7 @@ class SessionCatalogSuite extends SparkFunSuite {
test("get database should throw exception when the database does not exist") {
val catalog = new SessionCatalog(newBasicCatalog())
intercept[AnalysisException] {
- catalog.getDatabase("db_that_does_not_exist")
+ catalog.getDatabaseMetadata("db_that_does_not_exist")
}
}
@@ -127,10 +128,10 @@ class SessionCatalogSuite extends SparkFunSuite {
test("alter database") {
val catalog = new SessionCatalog(newBasicCatalog())
- val db1 = catalog.getDatabase("db1")
+ val db1 = catalog.getDatabaseMetadata("db1")
// Note: alter properties here because Hive does not support altering other fields
catalog.alterDatabase(db1.copy(properties = Map("k" -> "v3", "good" -> "true")))
- val newDb1 = catalog.getDatabase("db1")
+ val newDb1 = catalog.getDatabaseMetadata("db1")
assert(db1.properties.isEmpty)
assert(newDb1.properties.size == 2)
assert(newDb1.properties.get("k") == Some("v3"))
@@ -197,17 +198,17 @@ class SessionCatalogSuite extends SparkFunSuite {
val catalog = new SessionCatalog(newBasicCatalog())
val tempTable1 = Range(1, 10, 1, 10, Seq())
val tempTable2 = Range(1, 20, 2, 10, Seq())
- catalog.createTempTable("tbl1", tempTable1, ignoreIfExists = false)
- catalog.createTempTable("tbl2", tempTable2, ignoreIfExists = false)
+ catalog.createTempTable("tbl1", tempTable1, overrideIfExists = false)
+ catalog.createTempTable("tbl2", tempTable2, overrideIfExists = false)
assert(catalog.getTempTable("tbl1") == Some(tempTable1))
assert(catalog.getTempTable("tbl2") == Some(tempTable2))
assert(catalog.getTempTable("tbl3") == None)
// Temporary table already exists
intercept[AnalysisException] {
- catalog.createTempTable("tbl1", tempTable1, ignoreIfExists = false)
+ catalog.createTempTable("tbl1", tempTable1, overrideIfExists = false)
}
// Temporary table already exists but we override it
- catalog.createTempTable("tbl1", tempTable2, ignoreIfExists = true)
+ catalog.createTempTable("tbl1", tempTable2, overrideIfExists = true)
assert(catalog.getTempTable("tbl1") == Some(tempTable2))
}
@@ -232,10 +233,9 @@ class SessionCatalogSuite extends SparkFunSuite {
intercept[AnalysisException] {
catalog.dropTable(TableIdentifier("tbl1", Some("unknown_db")), ignoreIfNotExists = true)
}
- // Table does not exist
- intercept[AnalysisException] {
- catalog.dropTable(TableIdentifier("unknown_table", Some("db2")), ignoreIfNotExists = false)
- }
+ // If the table does not exist, we do not issue an exception. Instead, we output an error log
+ // message to console when ignoreIfNotExists is set to false.
+ catalog.dropTable(TableIdentifier("unknown_table", Some("db2")), ignoreIfNotExists = false)
catalog.dropTable(TableIdentifier("unknown_table", Some("db2")), ignoreIfNotExists = true)
}
@@ -243,7 +243,7 @@ class SessionCatalogSuite extends SparkFunSuite {
val externalCatalog = newBasicCatalog()
val sessionCatalog = new SessionCatalog(externalCatalog)
val tempTable = Range(1, 10, 2, 10, Seq())
- sessionCatalog.createTempTable("tbl1", tempTable, ignoreIfExists = false)
+ sessionCatalog.createTempTable("tbl1", tempTable, overrideIfExists = false)
sessionCatalog.setCurrentDatabase("db2")
assert(sessionCatalog.getTempTable("tbl1") == Some(tempTable))
assert(externalCatalog.listTables("db2").toSet == Set("tbl1", "tbl2"))
@@ -255,7 +255,7 @@ class SessionCatalogSuite extends SparkFunSuite {
sessionCatalog.dropTable(TableIdentifier("tbl1"), ignoreIfNotExists = false)
assert(externalCatalog.listTables("db2").toSet == Set("tbl2"))
// If database is specified, temp tables are never dropped
- sessionCatalog.createTempTable("tbl1", tempTable, ignoreIfExists = false)
+ sessionCatalog.createTempTable("tbl1", tempTable, overrideIfExists = false)
sessionCatalog.createTable(newTable("tbl1", "db2"), ignoreIfExists = false)
sessionCatalog.dropTable(TableIdentifier("tbl1", Some("db2")), ignoreIfNotExists = false)
assert(sessionCatalog.getTempTable("tbl1") == Some(tempTable))
@@ -299,7 +299,7 @@ class SessionCatalogSuite extends SparkFunSuite {
val externalCatalog = newBasicCatalog()
val sessionCatalog = new SessionCatalog(externalCatalog)
val tempTable = Range(1, 10, 2, 10, Seq())
- sessionCatalog.createTempTable("tbl1", tempTable, ignoreIfExists = false)
+ sessionCatalog.createTempTable("tbl1", tempTable, overrideIfExists = false)
sessionCatalog.setCurrentDatabase("db2")
assert(sessionCatalog.getTempTable("tbl1") == Some(tempTable))
assert(externalCatalog.listTables("db2").toSet == Set("tbl1", "tbl2"))
@@ -327,7 +327,7 @@ class SessionCatalogSuite extends SparkFunSuite {
assert(newTbl1.properties.get("toh") == Some("frem"))
// Alter table without explicitly specifying database
sessionCatalog.setCurrentDatabase("db2")
- sessionCatalog.alterTable(tbl1.copy(name = TableIdentifier("tbl1")))
+ sessionCatalog.alterTable(tbl1.copy(identifier = TableIdentifier("tbl1")))
val newestTbl1 = externalCatalog.getTable("db2", "tbl1")
assert(newestTbl1 == tbl1)
}
@@ -345,21 +345,21 @@ class SessionCatalogSuite extends SparkFunSuite {
test("get table") {
val externalCatalog = newBasicCatalog()
val sessionCatalog = new SessionCatalog(externalCatalog)
- assert(sessionCatalog.getTable(TableIdentifier("tbl1", Some("db2")))
+ assert(sessionCatalog.getTableMetadata(TableIdentifier("tbl1", Some("db2")))
== externalCatalog.getTable("db2", "tbl1"))
// Get table without explicitly specifying database
sessionCatalog.setCurrentDatabase("db2")
- assert(sessionCatalog.getTable(TableIdentifier("tbl1"))
+ assert(sessionCatalog.getTableMetadata(TableIdentifier("tbl1"))
== externalCatalog.getTable("db2", "tbl1"))
}
test("get table when database/table does not exist") {
val catalog = new SessionCatalog(newBasicCatalog())
intercept[AnalysisException] {
- catalog.getTable(TableIdentifier("tbl1", Some("unknown_db")))
+ catalog.getTableMetadata(TableIdentifier("tbl1", Some("unknown_db")))
}
intercept[AnalysisException] {
- catalog.getTable(TableIdentifier("unknown_table", Some("db2")))
+ catalog.getTableMetadata(TableIdentifier("unknown_table", Some("db2")))
}
}
@@ -368,7 +368,7 @@ class SessionCatalogSuite extends SparkFunSuite {
val sessionCatalog = new SessionCatalog(externalCatalog)
val tempTable1 = Range(1, 10, 1, 10, Seq())
val metastoreTable1 = externalCatalog.getTable("db2", "tbl1")
- sessionCatalog.createTempTable("tbl1", tempTable1, ignoreIfExists = false)
+ sessionCatalog.createTempTable("tbl1", tempTable1, overrideIfExists = false)
sessionCatalog.setCurrentDatabase("db2")
// If we explicitly specify the database, we'll look up the relation in that database
assert(sessionCatalog.lookupRelation(TableIdentifier("tbl1", Some("db2")))
@@ -385,7 +385,7 @@ class SessionCatalogSuite extends SparkFunSuite {
test("lookup table relation with alias") {
val catalog = new SessionCatalog(newBasicCatalog())
val alias = "monster"
- val tableMetadata = catalog.getTable(TableIdentifier("tbl1", Some("db2")))
+ val tableMetadata = catalog.getTableMetadata(TableIdentifier("tbl1", Some("db2")))
val relation = SubqueryAlias("tbl1", CatalogRelation("db2", tableMetadata))
val relationWithAlias =
SubqueryAlias(alias,
@@ -406,7 +406,7 @@ class SessionCatalogSuite extends SparkFunSuite {
assert(!catalog.tableExists(TableIdentifier("tbl2", Some("db1"))))
// If database is explicitly specified, do not check temporary tables
val tempTable = Range(1, 10, 1, 10, Seq())
- catalog.createTempTable("tbl3", tempTable, ignoreIfExists = false)
+ catalog.createTempTable("tbl3", tempTable, overrideIfExists = false)
assert(!catalog.tableExists(TableIdentifier("tbl3", Some("db2"))))
// If database is not explicitly specified, check the current database
catalog.setCurrentDatabase("db2")
@@ -418,8 +418,8 @@ class SessionCatalogSuite extends SparkFunSuite {
test("list tables without pattern") {
val catalog = new SessionCatalog(newBasicCatalog())
val tempTable = Range(1, 10, 2, 10, Seq())
- catalog.createTempTable("tbl1", tempTable, ignoreIfExists = false)
- catalog.createTempTable("tbl4", tempTable, ignoreIfExists = false)
+ catalog.createTempTable("tbl1", tempTable, overrideIfExists = false)
+ catalog.createTempTable("tbl4", tempTable, overrideIfExists = false)
assert(catalog.listTables("db1").toSet ==
Set(TableIdentifier("tbl1"), TableIdentifier("tbl4")))
assert(catalog.listTables("db2").toSet ==
@@ -435,8 +435,8 @@ class SessionCatalogSuite extends SparkFunSuite {
test("list tables with pattern") {
val catalog = new SessionCatalog(newBasicCatalog())
val tempTable = Range(1, 10, 2, 10, Seq())
- catalog.createTempTable("tbl1", tempTable, ignoreIfExists = false)
- catalog.createTempTable("tbl4", tempTable, ignoreIfExists = false)
+ catalog.createTempTable("tbl1", tempTable, overrideIfExists = false)
+ catalog.createTempTable("tbl4", tempTable, overrideIfExists = false)
assert(catalog.listTables("db1", "*").toSet == catalog.listTables("db1").toSet)
assert(catalog.listTables("db2", "*").toSet == catalog.listTables("db2").toSet)
assert(catalog.listTables("db2", "tbl*").toSet ==
@@ -496,19 +496,25 @@ class SessionCatalogSuite extends SparkFunSuite {
val sessionCatalog = new SessionCatalog(externalCatalog)
assert(catalogPartitionsEqual(externalCatalog, "db2", "tbl2", Seq(part1, part2)))
sessionCatalog.dropPartitions(
- TableIdentifier("tbl2", Some("db2")), Seq(part1.spec), ignoreIfNotExists = false)
+ TableIdentifier("tbl2", Some("db2")),
+ Seq(part1.spec),
+ ignoreIfNotExists = false)
assert(catalogPartitionsEqual(externalCatalog, "db2", "tbl2", Seq(part2)))
// Drop partitions without explicitly specifying database
sessionCatalog.setCurrentDatabase("db2")
sessionCatalog.dropPartitions(
- TableIdentifier("tbl2"), Seq(part2.spec), ignoreIfNotExists = false)
+ TableIdentifier("tbl2"),
+ Seq(part2.spec),
+ ignoreIfNotExists = false)
assert(externalCatalog.listPartitions("db2", "tbl2").isEmpty)
// Drop multiple partitions at once
sessionCatalog.createPartitions(
TableIdentifier("tbl2", Some("db2")), Seq(part1, part2), ignoreIfExists = false)
assert(catalogPartitionsEqual(externalCatalog, "db2", "tbl2", Seq(part1, part2)))
sessionCatalog.dropPartitions(
- TableIdentifier("tbl2", Some("db2")), Seq(part1.spec, part2.spec), ignoreIfNotExists = false)
+ TableIdentifier("tbl2", Some("db2")),
+ Seq(part1.spec, part2.spec),
+ ignoreIfNotExists = false)
assert(externalCatalog.listPartitions("db2", "tbl2").isEmpty)
}
@@ -516,11 +522,15 @@ class SessionCatalogSuite extends SparkFunSuite {
val catalog = new SessionCatalog(newBasicCatalog())
intercept[AnalysisException] {
catalog.dropPartitions(
- TableIdentifier("tbl1", Some("does_not_exist")), Seq(), ignoreIfNotExists = false)
+ TableIdentifier("tbl1", Some("does_not_exist")),
+ Seq(),
+ ignoreIfNotExists = false)
}
intercept[AnalysisException] {
catalog.dropPartitions(
- TableIdentifier("does_not_exist", Some("db2")), Seq(), ignoreIfNotExists = false)
+ TableIdentifier("does_not_exist", Some("db2")),
+ Seq(),
+ ignoreIfNotExists = false)
}
}
@@ -528,10 +538,14 @@ class SessionCatalogSuite extends SparkFunSuite {
val catalog = new SessionCatalog(newBasicCatalog())
intercept[AnalysisException] {
catalog.dropPartitions(
- TableIdentifier("tbl2", Some("db2")), Seq(part3.spec), ignoreIfNotExists = false)
+ TableIdentifier("tbl2", Some("db2")),
+ Seq(part3.spec),
+ ignoreIfNotExists = false)
}
catalog.dropPartitions(
- TableIdentifier("tbl2", Some("db2")), Seq(part3.spec), ignoreIfNotExists = true)
+ TableIdentifier("tbl2", Some("db2")),
+ Seq(part3.spec),
+ ignoreIfNotExists = true)
}
test("get partition") {
@@ -658,78 +672,94 @@ class SessionCatalogSuite extends SparkFunSuite {
val externalCatalog = newEmptyCatalog()
val sessionCatalog = new SessionCatalog(externalCatalog)
sessionCatalog.createDatabase(newDb("mydb"), ignoreIfExists = false)
- sessionCatalog.createFunction(newFunc("myfunc", Some("mydb")))
+ sessionCatalog.createFunction(newFunc("myfunc", Some("mydb")), ignoreIfExists = false)
assert(externalCatalog.listFunctions("mydb", "*").toSet == Set("myfunc"))
// Create function without explicitly specifying database
sessionCatalog.setCurrentDatabase("mydb")
- sessionCatalog.createFunction(newFunc("myfunc2"))
+ sessionCatalog.createFunction(newFunc("myfunc2"), ignoreIfExists = false)
assert(externalCatalog.listFunctions("mydb", "*").toSet == Set("myfunc", "myfunc2"))
}
test("create function when database does not exist") {
val catalog = new SessionCatalog(newBasicCatalog())
intercept[AnalysisException] {
- catalog.createFunction(newFunc("func5", Some("does_not_exist")))
+ catalog.createFunction(
+ newFunc("func5", Some("does_not_exist")), ignoreIfExists = false)
}
}
test("create function that already exists") {
val catalog = new SessionCatalog(newBasicCatalog())
intercept[AnalysisException] {
- catalog.createFunction(newFunc("func1", Some("db2")))
+ catalog.createFunction(newFunc("func1", Some("db2")), ignoreIfExists = false)
}
+ catalog.createFunction(newFunc("func1", Some("db2")), ignoreIfExists = true)
}
test("create temp function") {
val catalog = new SessionCatalog(newBasicCatalog())
- val tempFunc1 = newFunc("temp1")
- val tempFunc2 = newFunc("temp2")
- catalog.createTempFunction(tempFunc1, ignoreIfExists = false)
- catalog.createTempFunction(tempFunc2, ignoreIfExists = false)
- assert(catalog.getTempFunction("temp1") == Some(tempFunc1))
- assert(catalog.getTempFunction("temp2") == Some(tempFunc2))
- assert(catalog.getTempFunction("temp3") == None)
+ val tempFunc1 = (e: Seq[Expression]) => e.head
+ val tempFunc2 = (e: Seq[Expression]) => e.last
+ val info1 = new ExpressionInfo("tempFunc1", "temp1")
+ val info2 = new ExpressionInfo("tempFunc2", "temp2")
+ catalog.createTempFunction("temp1", info1, tempFunc1, ignoreIfExists = false)
+ catalog.createTempFunction("temp2", info2, tempFunc2, ignoreIfExists = false)
+ val arguments = Seq(Literal(1), Literal(2), Literal(3))
+ assert(catalog.lookupFunction("temp1", arguments) === Literal(1))
+ assert(catalog.lookupFunction("temp2", arguments) === Literal(3))
+ // Temporary function does not exist.
+ intercept[AnalysisException] {
+ catalog.lookupFunction("temp3", arguments)
+ }
+ val tempFunc3 = (e: Seq[Expression]) => Literal(e.size)
+ val info3 = new ExpressionInfo("tempFunc3", "temp1")
// Temporary function already exists
intercept[AnalysisException] {
- catalog.createTempFunction(tempFunc1, ignoreIfExists = false)
+ catalog.createTempFunction("temp1", info3, tempFunc3, ignoreIfExists = false)
}
// Temporary function is overridden
- val tempFunc3 = tempFunc1.copy(className = "something else")
- catalog.createTempFunction(tempFunc3, ignoreIfExists = true)
- assert(catalog.getTempFunction("temp1") == Some(tempFunc3))
+ catalog.createTempFunction("temp1", info3, tempFunc3, ignoreIfExists = true)
+ assert(catalog.lookupFunction("temp1", arguments) === Literal(arguments.length))
}
test("drop function") {
val externalCatalog = newBasicCatalog()
val sessionCatalog = new SessionCatalog(externalCatalog)
assert(externalCatalog.listFunctions("db2", "*").toSet == Set("func1"))
- sessionCatalog.dropFunction(FunctionIdentifier("func1", Some("db2")))
+ sessionCatalog.dropFunction(
+ FunctionIdentifier("func1", Some("db2")), ignoreIfNotExists = false)
assert(externalCatalog.listFunctions("db2", "*").isEmpty)
// Drop function without explicitly specifying database
sessionCatalog.setCurrentDatabase("db2")
- sessionCatalog.createFunction(newFunc("func2", Some("db2")))
+ sessionCatalog.createFunction(newFunc("func2", Some("db2")), ignoreIfExists = false)
assert(externalCatalog.listFunctions("db2", "*").toSet == Set("func2"))
- sessionCatalog.dropFunction(FunctionIdentifier("func2"))
+ sessionCatalog.dropFunction(FunctionIdentifier("func2"), ignoreIfNotExists = false)
assert(externalCatalog.listFunctions("db2", "*").isEmpty)
}
test("drop function when database/function does not exist") {
val catalog = new SessionCatalog(newBasicCatalog())
intercept[AnalysisException] {
- catalog.dropFunction(FunctionIdentifier("something", Some("does_not_exist")))
+ catalog.dropFunction(
+ FunctionIdentifier("something", Some("does_not_exist")), ignoreIfNotExists = false)
}
intercept[AnalysisException] {
- catalog.dropFunction(FunctionIdentifier("does_not_exist"))
+ catalog.dropFunction(FunctionIdentifier("does_not_exist"), ignoreIfNotExists = false)
}
+ catalog.dropFunction(FunctionIdentifier("does_not_exist"), ignoreIfNotExists = true)
}
test("drop temp function") {
val catalog = new SessionCatalog(newBasicCatalog())
- val tempFunc = newFunc("func1")
- catalog.createTempFunction(tempFunc, ignoreIfExists = false)
- assert(catalog.getTempFunction("func1") == Some(tempFunc))
+ val info = new ExpressionInfo("tempFunc", "func1")
+ val tempFunc = (e: Seq[Expression]) => e.head
+ catalog.createTempFunction("func1", info, tempFunc, ignoreIfExists = false)
+ val arguments = Seq(Literal(1), Literal(2), Literal(3))
+ assert(catalog.lookupFunction("func1", arguments) === Literal(1))
catalog.dropTempFunction("func1", ignoreIfNotExists = false)
- assert(catalog.getTempFunction("func1") == None)
+ intercept[AnalysisException] {
+ catalog.lookupFunction("func1", arguments)
+ }
intercept[AnalysisException] {
catalog.dropTempFunction("func1", ignoreIfNotExists = false)
}
@@ -738,132 +768,47 @@ class SessionCatalogSuite extends SparkFunSuite {
test("get function") {
val catalog = new SessionCatalog(newBasicCatalog())
- val expected = CatalogFunction(FunctionIdentifier("func1", Some("db2")), funcClass)
- assert(catalog.getFunction(FunctionIdentifier("func1", Some("db2"))) == expected)
+ val expected =
+ CatalogFunction(FunctionIdentifier("func1", Some("db2")), funcClass,
+ Seq.empty[(String, String)])
+ assert(catalog.getFunctionMetadata(FunctionIdentifier("func1", Some("db2"))) == expected)
// Get function without explicitly specifying database
catalog.setCurrentDatabase("db2")
- assert(catalog.getFunction(FunctionIdentifier("func1")) == expected)
+ assert(catalog.getFunctionMetadata(FunctionIdentifier("func1")) == expected)
}
test("get function when database/function does not exist") {
val catalog = new SessionCatalog(newBasicCatalog())
intercept[AnalysisException] {
- catalog.getFunction(FunctionIdentifier("func1", Some("does_not_exist")))
- }
- intercept[AnalysisException] {
- catalog.getFunction(FunctionIdentifier("does_not_exist", Some("db2")))
- }
- }
-
- test("get temp function") {
- val externalCatalog = newBasicCatalog()
- val sessionCatalog = new SessionCatalog(externalCatalog)
- val metastoreFunc = externalCatalog.getFunction("db2", "func1")
- val tempFunc = newFunc("func1").copy(className = "something weird")
- sessionCatalog.createTempFunction(tempFunc, ignoreIfExists = false)
- sessionCatalog.setCurrentDatabase("db2")
- // If a database is specified, we'll always return the function in that database
- assert(sessionCatalog.getFunction(FunctionIdentifier("func1", Some("db2"))) == metastoreFunc)
- // If no database is specified, we'll first return temporary functions
- assert(sessionCatalog.getFunction(FunctionIdentifier("func1")) == tempFunc)
- // Then, if no such temporary function exist, check the current database
- sessionCatalog.dropTempFunction("func1", ignoreIfNotExists = false)
- assert(sessionCatalog.getFunction(FunctionIdentifier("func1")) == metastoreFunc)
- }
-
- test("rename function") {
- val externalCatalog = newBasicCatalog()
- val sessionCatalog = new SessionCatalog(externalCatalog)
- val newName = "funcky"
- assert(sessionCatalog.getFunction(
- FunctionIdentifier("func1", Some("db2"))) == newFunc("func1", Some("db2")))
- assert(externalCatalog.listFunctions("db2", "*").toSet == Set("func1"))
- sessionCatalog.renameFunction(
- FunctionIdentifier("func1", Some("db2")), FunctionIdentifier(newName, Some("db2")))
- assert(sessionCatalog.getFunction(
- FunctionIdentifier(newName, Some("db2"))) == newFunc(newName, Some("db2")))
- assert(externalCatalog.listFunctions("db2", "*").toSet == Set(newName))
- // Rename function without explicitly specifying database
- sessionCatalog.setCurrentDatabase("db2")
- sessionCatalog.renameFunction(FunctionIdentifier(newName), FunctionIdentifier("func1"))
- assert(sessionCatalog.getFunction(
- FunctionIdentifier("func1")) == newFunc("func1", Some("db2")))
- assert(externalCatalog.listFunctions("db2", "*").toSet == Set("func1"))
- // Renaming "db2.func1" to "db1.func2" should fail because databases don't match
- intercept[AnalysisException] {
- sessionCatalog.renameFunction(
- FunctionIdentifier("func1", Some("db2")), FunctionIdentifier("func2", Some("db1")))
- }
- }
-
- test("rename function when database/function does not exist") {
- val catalog = new SessionCatalog(newBasicCatalog())
- intercept[AnalysisException] {
- catalog.renameFunction(
- FunctionIdentifier("func1", Some("does_not_exist")),
- FunctionIdentifier("func5", Some("does_not_exist")))
+ catalog.getFunctionMetadata(FunctionIdentifier("func1", Some("does_not_exist")))
}
intercept[AnalysisException] {
- catalog.renameFunction(
- FunctionIdentifier("does_not_exist", Some("db2")),
- FunctionIdentifier("x", Some("db2")))
+ catalog.getFunctionMetadata(FunctionIdentifier("does_not_exist", Some("db2")))
}
}
- test("rename temp function") {
- val externalCatalog = newBasicCatalog()
- val sessionCatalog = new SessionCatalog(externalCatalog)
- val tempFunc = newFunc("func1").copy(className = "something weird")
- sessionCatalog.createTempFunction(tempFunc, ignoreIfExists = false)
- sessionCatalog.setCurrentDatabase("db2")
- // If a database is specified, we'll always rename the function in that database
- sessionCatalog.renameFunction(
- FunctionIdentifier("func1", Some("db2")), FunctionIdentifier("func3", Some("db2")))
- assert(sessionCatalog.getTempFunction("func1") == Some(tempFunc))
- assert(sessionCatalog.getTempFunction("func3") == None)
- assert(externalCatalog.listFunctions("db2", "*").toSet == Set("func3"))
- // If no database is specified, we'll first rename temporary functions
- sessionCatalog.createFunction(newFunc("func1", Some("db2")))
- sessionCatalog.renameFunction(FunctionIdentifier("func1"), FunctionIdentifier("func4"))
- assert(sessionCatalog.getTempFunction("func4") ==
- Some(tempFunc.copy(name = FunctionIdentifier("func4"))))
- assert(sessionCatalog.getTempFunction("func1") == None)
- assert(externalCatalog.listFunctions("db2", "*").toSet == Set("func1", "func3"))
- // Then, if no such temporary function exist, rename the function in the current database
- sessionCatalog.renameFunction(FunctionIdentifier("func1"), FunctionIdentifier("func5"))
- assert(sessionCatalog.getTempFunction("func5") == None)
- assert(externalCatalog.listFunctions("db2", "*").toSet == Set("func3", "func5"))
- }
-
- test("alter function") {
- val catalog = new SessionCatalog(newBasicCatalog())
- assert(catalog.getFunction(FunctionIdentifier("func1", Some("db2"))).className == funcClass)
- catalog.alterFunction(newFunc("func1", Some("db2")).copy(className = "muhaha"))
- assert(catalog.getFunction(FunctionIdentifier("func1", Some("db2"))).className == "muhaha")
- // Alter function without explicitly specifying database
- catalog.setCurrentDatabase("db2")
- catalog.alterFunction(newFunc("func1").copy(className = "derpy"))
- assert(catalog.getFunction(FunctionIdentifier("func1")).className == "derpy")
- }
-
- test("alter function when database/function does not exist") {
+ test("lookup temp function") {
val catalog = new SessionCatalog(newBasicCatalog())
+ val info1 = new ExpressionInfo("tempFunc1", "func1")
+ val tempFunc1 = (e: Seq[Expression]) => e.head
+ catalog.createTempFunction("func1", info1, tempFunc1, ignoreIfExists = false)
+ assert(catalog.lookupFunction("func1", Seq(Literal(1), Literal(2), Literal(3))) == Literal(1))
+ catalog.dropTempFunction("func1", ignoreIfNotExists = false)
intercept[AnalysisException] {
- catalog.alterFunction(newFunc("func5", Some("does_not_exist")))
- }
- intercept[AnalysisException] {
- catalog.alterFunction(newFunc("funcky", Some("db2")))
+ catalog.lookupFunction("func1", Seq(Literal(1), Literal(2), Literal(3)))
}
}
test("list functions") {
val catalog = new SessionCatalog(newBasicCatalog())
- val tempFunc1 = newFunc("func1").copy(className = "march")
- val tempFunc2 = newFunc("yes_me").copy(className = "april")
- catalog.createFunction(newFunc("func2", Some("db2")))
- catalog.createFunction(newFunc("not_me", Some("db2")))
- catalog.createTempFunction(tempFunc1, ignoreIfExists = false)
- catalog.createTempFunction(tempFunc2, ignoreIfExists = false)
+ val info1 = new ExpressionInfo("tempFunc1", "func1")
+ val info2 = new ExpressionInfo("tempFunc2", "yes_me")
+ val tempFunc1 = (e: Seq[Expression]) => e.head
+ val tempFunc2 = (e: Seq[Expression]) => e.last
+ catalog.createFunction(newFunc("func2", Some("db2")), ignoreIfExists = false)
+ catalog.createFunction(newFunc("not_me", Some("db2")), ignoreIfExists = false)
+ catalog.createTempFunction("func1", info1, tempFunc1, ignoreIfExists = false)
+ catalog.createTempFunction("yes_me", info2, tempFunc2, ignoreIfExists = false)
assert(catalog.listFunctions("db1", "*").toSet ==
Set(FunctionIdentifier("func1"),
FunctionIdentifier("yes_me")))
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoderSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoderSuite.scala
index f6583bfe42..18752014ea 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoderSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoderSuite.scala
@@ -315,7 +315,7 @@ class ExpressionEncoderSuite extends PlanTest with AnalysisTest {
val attr = AttributeReference("obj", ObjectType(encoder.clsTag.runtimeClass))()
val inputPlan = LocalRelation(attr)
val plan =
- Project(Alias(encoder.fromRowExpression, "obj")() :: Nil,
+ Project(Alias(encoder.deserializer, "obj")() :: Nil,
Project(encoder.namedExpressions,
inputPlan))
assertAnalysisSuccess(plan)
@@ -360,7 +360,7 @@ class ExpressionEncoderSuite extends PlanTest with AnalysisTest {
|${encoder.schema.treeString}
|
|fromRow Expressions:
- |${boundEncoder.fromRowExpression.treeString}
+ |${boundEncoder.deserializer.treeString}
""".stripMargin)
}
}
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/StringExpressionsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/StringExpressionsSuite.scala
index 99e3b13ce8..2cf8ca7000 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/StringExpressionsSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/StringExpressionsSuite.scala
@@ -382,6 +382,7 @@ class StringExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper {
checkEvaluation(InitCap(Literal("a b")), "A B")
checkEvaluation(InitCap(Literal(" a")), " A")
checkEvaluation(InitCap(Literal("the test")), "The Test")
+ checkEvaluation(InitCap(Literal("sParK")), "Spark")
// scalastyle:off
// non ascii characters are not allowed in the code, so we disable the scalastyle here.
checkEvaluation(InitCap(Literal("世界")), "世界")
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/TimeWindowSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/TimeWindowSuite.scala
new file mode 100644
index 0000000000..b82cf8d169
--- /dev/null
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/TimeWindowSuite.scala
@@ -0,0 +1,111 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.catalyst.expressions
+
+import org.scalatest.PrivateMethodTester
+
+import org.apache.spark.SparkFunSuite
+import org.apache.spark.sql.AnalysisException
+import org.apache.spark.sql.types.LongType
+
+class TimeWindowSuite extends SparkFunSuite with ExpressionEvalHelper with PrivateMethodTester {
+
+ test("time window is unevaluable") {
+ intercept[UnsupportedOperationException] {
+ evaluate(TimeWindow(Literal(10L), "1 second", "1 second", "0 second"))
+ }
+ }
+
+ private def checkErrorMessage(msg: String, value: String): Unit = {
+ val validDuration = "10 second"
+ val validTime = "5 second"
+ val e1 = intercept[IllegalArgumentException] {
+ TimeWindow(Literal(10L), value, validDuration, validTime).windowDuration
+ }
+ val e2 = intercept[IllegalArgumentException] {
+ TimeWindow(Literal(10L), validDuration, value, validTime).slideDuration
+ }
+ val e3 = intercept[IllegalArgumentException] {
+ TimeWindow(Literal(10L), validDuration, validDuration, value).startTime
+ }
+ Seq(e1, e2, e3).foreach { e =>
+ assert(e.getMessage.contains(msg))
+ }
+ }
+
+ test("blank intervals throw exception") {
+ for (blank <- Seq(null, " ", "\n", "\t")) {
+ checkErrorMessage(
+ "The window duration, slide duration and start time cannot be null or blank.", blank)
+ }
+ }
+
+ test("invalid intervals throw exception") {
+ checkErrorMessage(
+ "did not correspond to a valid interval string.", "2 apples")
+ }
+
+ test("intervals greater than a month throws exception") {
+ checkErrorMessage(
+ "Intervals greater than or equal to a month is not supported (1 month).", "1 month")
+ }
+
+ test("interval strings work with and without 'interval' prefix and return microseconds") {
+ val validDuration = "10 second"
+ for ((text, seconds) <- Seq(
+ ("1 second", 1000000), // 1e6
+ ("1 minute", 60000000), // 6e7
+ ("2 hours", 7200000000L))) { // 72e9
+ assert(TimeWindow(Literal(10L), text, validDuration, "0 seconds").windowDuration === seconds)
+ assert(TimeWindow(Literal(10L), "interval " + text, validDuration, "0 seconds").windowDuration
+ === seconds)
+ }
+ }
+
+ private val parseExpression = PrivateMethod[Long]('parseExpression)
+
+ test("parse sql expression for duration in microseconds - string") {
+ val dur = TimeWindow.invokePrivate(parseExpression(Literal("5 seconds")))
+ assert(dur.isInstanceOf[Long])
+ assert(dur === 5000000)
+ }
+
+ test("parse sql expression for duration in microseconds - integer") {
+ val dur = TimeWindow.invokePrivate(parseExpression(Literal(100)))
+ assert(dur.isInstanceOf[Long])
+ assert(dur === 100)
+ }
+
+ test("parse sql expression for duration in microseconds - long") {
+ val dur = TimeWindow.invokePrivate(parseExpression(Literal.create(2L << 52, LongType)))
+ assert(dur.isInstanceOf[Long])
+ assert(dur === (2L << 52))
+ }
+
+ test("parse sql expression for duration in microseconds - invalid interval") {
+ intercept[IllegalArgumentException] {
+ TimeWindow.invokePrivate(parseExpression(Literal("2 apples")))
+ }
+ }
+
+ test("parse sql expression for duration in microseconds - invalid expression") {
+ intercept[AnalysisException] {
+ TimeWindow.invokePrivate(parseExpression(Rand(123)))
+ }
+ }
+}
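A small sketch (illustrative only) of the string-duration API the suite above covers: durations are parsed into microseconds, and an optional "interval " prefix is accepted.

    import org.apache.spark.sql.catalyst.expressions.{Literal, TimeWindow}

    val w = TimeWindow(Literal(10L), "interval 10 seconds", "5 seconds", "0 seconds")
    assert(w.windowDuration == 10L * 1000 * 1000) // stored as microseconds
    assert(w.slideDuration == 5L * 1000 * 1000)
    assert(w.startTime == 0L)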
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeFormatterSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeFormatterSuite.scala
index 9da1068e9c..f57b82bb96 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeFormatterSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeFormatterSuite.scala
@@ -18,13 +18,20 @@
package org.apache.spark.sql.catalyst.expressions.codegen
import org.apache.spark.SparkFunSuite
+import org.apache.spark.sql.catalyst.util._
class CodeFormatterSuite extends SparkFunSuite {
def testCase(name: String)(input: String)(expected: String): Unit = {
test(name) {
- assert(CodeFormatter.format(input).trim === expected.trim)
+ if (CodeFormatter.format(input).trim !== expected.trim) {
+ fail(
+ s"""
+ |== FAIL: Formatted code doesn't match ===
+ |${sideBySide(CodeFormatter.format(input).trim, expected.trim).mkString("\n")}
+ """.stripMargin)
+ }
}
}
@@ -93,4 +100,50 @@ class CodeFormatterSuite extends SparkFunSuite {
|/* 004 */ c)
""".stripMargin
}
+
+ testCase("single line comments") {
+ """// This is a comment about class A { { { ( (
+ |class A {
+ |class body;
+ |}""".stripMargin
+ }{
+ """
+ |/* 001 */ // This is a comment about class A { { { ( (
+ |/* 002 */ class A {
+ |/* 003 */ class body;
+ |/* 004 */ }
+ """.stripMargin
+ }
+
+ testCase("single line comments /* */ ") {
+ """/** This is a comment about class A { { { ( ( */
+ |class A {
+ |class body;
+ |}""".stripMargin
+ }{
+ """
+ |/* 001 */ /** This is a comment about class A { { { ( ( */
+ |/* 002 */ class A {
+ |/* 003 */ class body;
+ |/* 004 */ }
+ """.stripMargin
+ }
+
+ testCase("multi-line comments") {
+ """ /* This is a comment about
+ |class A {
+ |class body; ...*/
+ |class A {
+ |class body;
+ |}""".stripMargin
+ }{
+ """
+ |/* 001 */ /* This is a comment about
+ |/* 002 */ class A {
+ |/* 003 */ class body; ...*/
+ |/* 004 */ class A {
+ |/* 005 */ class body;
+ |/* 006 */ }
+ """.stripMargin
+ }
}
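A minimal sketch (illustrative only) of the call the testCase helper above wraps: CodeFormatter.format prefixes each line of generated code with a /* NNN */ line number, which is why the expected outputs above carry those markers.

    import org.apache.spark.sql.catalyst.expressions.codegen.CodeFormatter

    val code =
      """class A {
        |class body;
        |}""".stripMargin
    // Prints the code with /* 001 */, /* 002 */, ... prefixes and brace-based indentation.
    println(CodeFormatter.format(code))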
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/BinaryComparisonSimplificationSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/BinaryComparisonSimplificationSuite.scala
new file mode 100644
index 0000000000..7cd038570b
--- /dev/null
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/BinaryComparisonSimplificationSuite.scala
@@ -0,0 +1,95 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.catalyst.optimizer
+
+import org.apache.spark.sql.catalyst.analysis._
+import org.apache.spark.sql.catalyst.dsl.expressions._
+import org.apache.spark.sql.catalyst.dsl.plans._
+import org.apache.spark.sql.catalyst.expressions._
+import org.apache.spark.sql.catalyst.expressions.Literal.{FalseLiteral, TrueLiteral}
+import org.apache.spark.sql.catalyst.plans.PlanTest
+import org.apache.spark.sql.catalyst.plans.logical._
+import org.apache.spark.sql.catalyst.rules._
+
+class BinaryComparisonSimplificationSuite extends PlanTest with PredicateHelper {
+
+ object Optimize extends RuleExecutor[LogicalPlan] {
+ val batches =
+ Batch("AnalysisNodes", Once,
+ EliminateSubqueryAliases) ::
+ Batch("Constant Folding", FixedPoint(50),
+ NullPropagation,
+ ConstantFolding,
+ BooleanSimplification,
+ BinaryComparisonSimplification,
+ PruneFilters) :: Nil
+ }
+
+ val nullableRelation = LocalRelation('a.int.withNullability(true))
+ val nonNullableRelation = LocalRelation('a.int.withNullability(false))
+
+ test("Preserve nullable exprs in general") {
+ for (e <- Seq('a === 'a, 'a <= 'a, 'a >= 'a, 'a < 'a, 'a > 'a)) {
+ val plan = nullableRelation.where(e).analyze
+ val actual = Optimize.execute(plan)
+ val correctAnswer = plan
+ comparePlans(actual, correctAnswer)
+ }
+ }
+
+ test("Preserve non-deterministic exprs") {
+ val plan = nonNullableRelation
+ .where(Rand(0) === Rand(0) && Rand(1) <=> Rand(1)).analyze
+ val actual = Optimize.execute(plan)
+ val correctAnswer = plan
+ comparePlans(actual, correctAnswer)
+ }
+
+ test("Nullable Simplification Primitive: <=>") {
+ val plan = nullableRelation.select('a <=> 'a).analyze
+ val actual = Optimize.execute(plan)
+ val correctAnswer = nullableRelation.select(Alias(TrueLiteral, "(a <=> a)")()).analyze
+ comparePlans(actual, correctAnswer)
+ }
+
+ test("Non-Nullable Simplification Primitive") {
+ val plan = nonNullableRelation
+ .select('a === 'a, 'a <=> 'a, 'a <= 'a, 'a >= 'a, 'a < 'a, 'a > 'a).analyze
+ val actual = Optimize.execute(plan)
+ val correctAnswer = nonNullableRelation
+ .select(
+ Alias(TrueLiteral, "(a = a)")(),
+ Alias(TrueLiteral, "(a <=> a)")(),
+ Alias(TrueLiteral, "(a <= a)")(),
+ Alias(TrueLiteral, "(a >= a)")(),
+ Alias(FalseLiteral, "(a < a)")(),
+ Alias(FalseLiteral, "(a > a)")())
+ .analyze
+ comparePlans(actual, correctAnswer)
+ }
+
+ test("Expression Normalization") {
+ val plan = nonNullableRelation.where(
+ 'a * Literal(100) + Pi() === Pi() + Literal(100) * 'a &&
+ DateAdd(CurrentDate(), 'a + Literal(2)) <= DateAdd(CurrentDate(), Literal(2) + 'a))
+ .analyze
+ val actual = Optimize.execute(plan)
+ val correctAnswer = nonNullableRelation.analyze
+ comparePlans(actual, correctAnswer)
+ }
+}
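An illustrative sketch that mirrors the Constant Folding batch defined in the suite above, showing what BinaryComparisonSimplification does to self-comparisons over a non-nullable attribute.

    import org.apache.spark.sql.catalyst.dsl.expressions._
    import org.apache.spark.sql.catalyst.dsl.plans._
    import org.apache.spark.sql.catalyst.optimizer._
    import org.apache.spark.sql.catalyst.plans.logical.{LocalRelation, LogicalPlan}
    import org.apache.spark.sql.catalyst.rules.RuleExecutor

    object Simplify extends RuleExecutor[LogicalPlan] {
      val batches = Batch("Constant Folding", FixedPoint(50),
        NullPropagation, ConstantFolding, BooleanSimplification,
        BinaryComparisonSimplification, PruneFilters) :: Nil
    }

    val rel = LocalRelation('a.int.withNullability(false))
    // 'a === 'a folds to true and the filter is then pruned, leaving just the relation;
    // 'a < 'a would fold to false instead.
    val optimized = Simplify.execute(rel.where('a === 'a).analyze)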
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/BooleanSimplificationSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/BooleanSimplificationSuite.scala
index e2c76b700f..8147d06969 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/BooleanSimplificationSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/BooleanSimplificationSuite.scala
@@ -140,8 +140,7 @@ class BooleanSimplificationSuite extends PlanTest with PredicateHelper {
private val caseInsensitiveConf = new SimpleCatalystConf(false)
private val caseInsensitiveAnalyzer = new Analyzer(
- new SessionCatalog(new InMemoryCatalog, caseInsensitiveConf),
- EmptyFunctionRegistry,
+ new SessionCatalog(new InMemoryCatalog, EmptyFunctionRegistry, caseInsensitiveConf),
caseInsensitiveConf)
test("(a && b) || (a && c) => a && (b || c) when case insensitive") {
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/ColumnPruningSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/ColumnPruningSuite.scala
index 2248e03b2f..52b574c0e6 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/ColumnPruningSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/ColumnPruningSuite.scala
@@ -34,7 +34,7 @@ class ColumnPruningSuite extends PlanTest {
object Optimize extends RuleExecutor[LogicalPlan] {
val batches = Batch("Column pruning", FixedPoint(100),
- PushPredicateThroughProject,
+ PushDownPredicate,
ColumnPruning,
CollapseProject) :: Nil
}
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/EliminateSortsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/EliminateSortsSuite.scala
index 3824c67563..8c92ad82ac 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/EliminateSortsSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/EliminateSortsSuite.scala
@@ -29,8 +29,8 @@ import org.apache.spark.sql.catalyst.rules._
class EliminateSortsSuite extends PlanTest {
val conf = new SimpleCatalystConf(caseSensitiveAnalysis = true, orderByOrdinal = false)
- val catalog = new SessionCatalog(new InMemoryCatalog, conf)
- val analyzer = new Analyzer(catalog, EmptyFunctionRegistry, conf)
+ val catalog = new SessionCatalog(new InMemoryCatalog, EmptyFunctionRegistry, conf)
+ val analyzer = new Analyzer(catalog, conf)
object Optimize extends RuleExecutor[LogicalPlan] {
val batches =
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/FilterPushdownSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/FilterPushdownSuite.scala
index b84ae7c5bb..df7529d83f 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/FilterPushdownSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/FilterPushdownSuite.scala
@@ -33,14 +33,12 @@ class FilterPushdownSuite extends PlanTest {
val batches =
Batch("Subqueries", Once,
EliminateSubqueryAliases) ::
- Batch("Filter Pushdown", Once,
+ Batch("Filter Pushdown", FixedPoint(10),
SamplePushDown,
CombineFilters,
- PushPredicateThroughProject,
+ PushDownPredicate,
BooleanSimplification,
PushPredicateThroughJoin,
- PushPredicateThroughGenerate,
- PushPredicateThroughAggregate,
CollapseProject) :: Nil
}
@@ -620,8 +618,8 @@ class FilterPushdownSuite extends PlanTest {
val optimized = Optimize.execute(originalQuery.analyze)
val correctAnswer = testRelation
- .select('a, 'b)
.where('a === 3)
+ .select('a, 'b)
.groupBy('a)('a, count('b) as 'c)
.where('c === 2L)
.analyze
@@ -638,8 +636,8 @@ class FilterPushdownSuite extends PlanTest {
val optimized = Optimize.execute(originalQuery.analyze)
val correctAnswer = testRelation
- .select('a, 'b)
.where('a + 1 < 3)
+ .select('a, 'b)
.groupBy('a)(('a + 1) as 'aa, count('b) as 'c)
.where('c === 2L || 'aa > 4)
.analyze
@@ -656,8 +654,8 @@ class FilterPushdownSuite extends PlanTest {
val optimized = Optimize.execute(originalQuery.analyze)
val correctAnswer = testRelation
- .select('a, 'b)
.where("s" === "s")
+ .select('a, 'b)
.groupBy('a)('a, count('b) as 'c, "s" as 'd)
.where('c === 2L)
.analyze
@@ -681,4 +679,68 @@ class FilterPushdownSuite extends PlanTest {
comparePlans(optimized, correctAnswer)
}
+
+ test("broadcast hint") {
+ val originalQuery = BroadcastHint(testRelation)
+ .where('a === 2L && 'b + Rand(10).as("rnd") === 3)
+
+ val optimized = Optimize.execute(originalQuery.analyze)
+
+ val correctAnswer = BroadcastHint(testRelation.where('a === 2L))
+ .where('b + Rand(10).as("rnd") === 3)
+ .analyze
+
+ comparePlans(optimized, correctAnswer)
+ }
+
+ test("union") {
+ val testRelation2 = LocalRelation('d.int, 'e.int, 'f.int)
+
+ val originalQuery = Union(Seq(testRelation, testRelation2))
+ .where('a === 2L && 'b + Rand(10).as("rnd") === 3)
+
+ val optimized = Optimize.execute(originalQuery.analyze)
+
+ val correctAnswer = Union(Seq(
+ testRelation.where('a === 2L),
+ testRelation2.where('d === 2L)))
+ .where('b + Rand(10).as("rnd") === 3)
+ .analyze
+
+ comparePlans(optimized, correctAnswer)
+ }
+
+ test("intersect") {
+ val testRelation2 = LocalRelation('d.int, 'e.int, 'f.int)
+
+ val originalQuery = Intersect(testRelation, testRelation2)
+ .where('a === 2L && 'b + Rand(10).as("rnd") === 3)
+
+ val optimized = Optimize.execute(originalQuery.analyze)
+
+ val correctAnswer = Intersect(
+ testRelation.where('a === 2L),
+ testRelation2.where('d === 2L))
+ .where('b + Rand(10).as("rnd") === 3)
+ .analyze
+
+ comparePlans(optimized, correctAnswer)
+ }
+
+ test("except") {
+ val testRelation2 = LocalRelation('d.int, 'e.int, 'f.int)
+
+ val originalQuery = Except(testRelation, testRelation2)
+ .where('a === 2L && 'b + Rand(10).as("rnd") === 3)
+
+ val optimized = Optimize.execute(originalQuery.analyze)
+
+ val correctAnswer = Except(
+ testRelation.where('a === 2L),
+ testRelation2)
+ .where('b + Rand(10).as("rnd") === 3)
+ .analyze
+
+ comparePlans(optimized, correctAnswer)
+ }
}
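An illustrative sketch of the behavior the new broadcast-hint and set-operation tests assert, using the same Filter Pushdown batch as this suite: a deterministic predicate on a Union is copied into every child (with attributes remapped per child), while non-deterministic predicates stay above it.

    import org.apache.spark.sql.catalyst.analysis.EliminateSubqueryAliases
    import org.apache.spark.sql.catalyst.dsl.expressions._
    import org.apache.spark.sql.catalyst.dsl.plans._
    import org.apache.spark.sql.catalyst.optimizer._
    import org.apache.spark.sql.catalyst.plans.logical.{LocalRelation, LogicalPlan, Union}
    import org.apache.spark.sql.catalyst.rules.RuleExecutor

    object Pushdown extends RuleExecutor[LogicalPlan] {
      val batches =
        Batch("Subqueries", Once, EliminateSubqueryAliases) ::
        Batch("Filter Pushdown", FixedPoint(10),
          SamplePushDown, CombineFilters, PushDownPredicate,
          BooleanSimplification, PushPredicateThroughJoin, CollapseProject) :: Nil
    }

    val left = LocalRelation('a.int, 'b.int, 'c.int)
    val right = LocalRelation('d.int, 'e.int, 'f.int)
    // The filter on 'a ends up on both children: 'a === 2 on the left, 'd === 2 on the right.
    val optimized = Pushdown.execute(Union(Seq(left, right)).where('a === 2).analyze)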
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/JoinOptimizationSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/JoinOptimizationSuite.scala
index e2f8146bee..c1ebf8b09e 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/JoinOptimizationSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/JoinOptimizationSuite.scala
@@ -36,12 +36,10 @@ class JoinOptimizationSuite extends PlanTest {
EliminateSubqueryAliases) ::
Batch("Filter Pushdown", FixedPoint(100),
CombineFilters,
- PushPredicateThroughProject,
+ PushDownPredicate,
BooleanSimplification,
ReorderJoin,
PushPredicateThroughJoin,
- PushPredicateThroughGenerate,
- PushPredicateThroughAggregate,
ColumnPruning,
CollapseProject) :: Nil
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/LikeSimplificationSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/LikeSimplificationSuite.scala
index 741bc113cf..fdde89d079 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/LikeSimplificationSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/LikeSimplificationSuite.scala
@@ -61,6 +61,20 @@ class LikeSimplificationSuite extends PlanTest {
comparePlans(optimized, correctAnswer)
}
+ test("simplify Like into startsWith and EndsWith") {
+ val originalQuery =
+ testRelation
+ .where(('a like "abc\\%def") || ('a like "abc%def"))
+
+ val optimized = Optimize.execute(originalQuery.analyze)
+ val correctAnswer = testRelation
+ .where(('a like "abc\\%def") ||
+ (Length('a) >= 6 && (StartsWith('a, "abc") && EndsWith('a, "def"))))
+ .analyze
+
+ comparePlans(optimized, correctAnswer)
+ }
+
test("simplify Like into Contains") {
val originalQuery =
testRelation
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/OptimizerExtendableSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/OptimizerExtendableSuite.scala
index 7e3da6bea7..6e5672ddc3 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/OptimizerExtendableSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/OptimizerExtendableSuite.scala
@@ -23,21 +23,21 @@ import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
import org.apache.spark.sql.catalyst.rules.Rule
/**
- * This is a test for SPARK-7727 if the Optimizer is kept being extendable
- */
+ * This is a test for SPARK-7727 if the Optimizer is kept being extendable
+ */
class OptimizerExtendableSuite extends SparkFunSuite {
/**
- * Dummy rule for test batches
- */
+ * Dummy rule for test batches
+ */
object DummyRule extends Rule[LogicalPlan] {
def apply(p: LogicalPlan): LogicalPlan = p
}
/**
- * This class represents a dummy extended optimizer that takes the batches of the
- * Optimizer and adds custom ones.
- */
+ * This class represents a dummy extended optimizer that takes the batches of the
+ * Optimizer and adds custom ones.
+ */
class ExtendedOptimizer extends Optimizer {
// rules set to DummyRule, would not be executed anyways
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/PruneFiltersSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/PruneFiltersSuite.scala
index 14fb72a8a3..d8cfec5391 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/PruneFiltersSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/PruneFiltersSuite.scala
@@ -34,7 +34,7 @@ class PruneFiltersSuite extends PlanTest {
Batch("Filter Pushdown and Pruning", Once,
CombineFilters,
PruneFilters,
- PushPredicateThroughProject,
+ PushDownPredicate,
PushPredicateThroughJoin) :: Nil
}
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/SimplifyConditionalSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/SimplifyConditionalSuite.scala
index d436b627f6..c02fec3085 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/SimplifyConditionalSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/SimplifyConditionalSuite.scala
@@ -23,7 +23,7 @@ import org.apache.spark.sql.catalyst.expressions.Literal.{FalseLiteral, TrueLite
import org.apache.spark.sql.catalyst.plans.PlanTest
import org.apache.spark.sql.catalyst.plans.logical._
import org.apache.spark.sql.catalyst.rules._
-import org.apache.spark.sql.types.IntegerType
+import org.apache.spark.sql.types.{IntegerType, NullType}
class SimplifyConditionalSuite extends PlanTest with PredicateHelper {
@@ -41,6 +41,7 @@ class SimplifyConditionalSuite extends PlanTest with PredicateHelper {
private val trueBranch = (TrueLiteral, Literal(5))
private val normalBranch = (NonFoldableLiteral(true), Literal(10))
private val unreachableBranch = (FalseLiteral, Literal(20))
+ private val nullBranch = (Literal.create(null, NullType), Literal(30))
test("simplify if") {
assertEquivalent(
@@ -50,18 +51,22 @@ class SimplifyConditionalSuite extends PlanTest with PredicateHelper {
assertEquivalent(
If(FalseLiteral, Literal(10), Literal(20)),
Literal(20))
+
+ assertEquivalent(
+ If(Literal.create(null, NullType), Literal(10), Literal(20)),
+ Literal(20))
}
test("remove unreachable branches") {
// i.e. removing branches whose conditions are always false
assertEquivalent(
- CaseWhen(unreachableBranch :: normalBranch :: unreachableBranch :: Nil, None),
+ CaseWhen(unreachableBranch :: normalBranch :: unreachableBranch :: nullBranch :: Nil, None),
CaseWhen(normalBranch :: Nil, None))
}
test("remove entire CaseWhen if only the else branch is reachable") {
assertEquivalent(
- CaseWhen(unreachableBranch :: unreachableBranch :: Nil, Some(Literal(30))),
+ CaseWhen(unreachableBranch :: unreachableBranch :: nullBranch :: Nil, Some(Literal(30))),
Literal(30))
assertEquivalent(
@@ -71,12 +76,13 @@ class SimplifyConditionalSuite extends PlanTest with PredicateHelper {
test("remove entire CaseWhen if the first branch is always true") {
assertEquivalent(
- CaseWhen(trueBranch :: normalBranch :: Nil, None),
+ CaseWhen(trueBranch :: normalBranch :: nullBranch :: Nil, None),
Literal(5))
// Test branch elimination and simplification in combination
assertEquivalent(
- CaseWhen(unreachableBranch :: unreachableBranch:: trueBranch :: normalBranch :: Nil, None),
+ CaseWhen(unreachableBranch :: unreachableBranch :: nullBranch :: trueBranch :: normalBranch
+ :: Nil, None),
Literal(5))
// Make sure this doesn't trigger if there is a non-foldable branch before the true branch
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/TypedFilterOptimizationSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/TypedFilterOptimizationSuite.scala
new file mode 100644
index 0000000000..1fae64e3bc
--- /dev/null
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/TypedFilterOptimizationSuite.scala
@@ -0,0 +1,74 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.catalyst.optimizer
+
+import scala.reflect.runtime.universe.TypeTag
+
+import org.apache.spark.sql.catalyst.analysis.UnresolvedDeserializer
+import org.apache.spark.sql.catalyst.dsl.expressions._
+import org.apache.spark.sql.catalyst.dsl.plans._
+import org.apache.spark.sql.catalyst.encoders.{encoderFor, ExpressionEncoder}
+import org.apache.spark.sql.catalyst.plans.PlanTest
+import org.apache.spark.sql.catalyst.plans.logical.{LocalRelation, LogicalPlan}
+import org.apache.spark.sql.catalyst.rules.RuleExecutor
+import org.apache.spark.sql.types.BooleanType
+
+class TypedFilterOptimizationSuite extends PlanTest {
+ object Optimize extends RuleExecutor[LogicalPlan] {
+ val batches =
+ Batch("EliminateSerialization", FixedPoint(50),
+ EliminateSerialization) ::
+ Batch("EmbedSerializerInFilter", FixedPoint(50),
+ EmbedSerializerInFilter) :: Nil
+ }
+
+ implicit private def productEncoder[T <: Product : TypeTag] = ExpressionEncoder[T]()
+
+ test("back to back filter") {
+ val input = LocalRelation('_1.int, '_2.int)
+ val f1 = (i: (Int, Int)) => i._1 > 0
+ val f2 = (i: (Int, Int)) => i._2 > 0
+
+ val query = input.filter(f1).filter(f2).analyze
+
+ val optimized = Optimize.execute(query)
+
+ val expected = input.deserialize[(Int, Int)]
+ .where(callFunction(f1, BooleanType, 'obj))
+ .select('obj.as("obj"))
+ .where(callFunction(f2, BooleanType, 'obj))
+ .serialize[(Int, Int)].analyze
+
+ comparePlans(optimized, expected)
+ }
+
+ test("embed deserializer in filter condition if there is only one filter") {
+ val input = LocalRelation('_1.int, '_2.int)
+ val f = (i: (Int, Int)) => i._1 > 0
+
+ val query = input.filter(f).analyze
+
+ val optimized = Optimize.execute(query)
+
+ val deserializer = UnresolvedDeserializer(encoderFor[(Int, Int)].deserializer)
+ val condition = callFunction(f, BooleanType, deserializer)
+ val expected = input.where(condition).analyze
+
+ comparePlans(optimized, expected)
+ }
+}
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/CatalystQlSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/CatalystQlSuite.scala
deleted file mode 100644
index c068e895b6..0000000000
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/CatalystQlSuite.scala
+++ /dev/null
@@ -1,243 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one or more
- * contributor license agreements. See the NOTICE file distributed with
- * this work for additional information regarding copyright ownership.
- * The ASF licenses this file to You under the Apache License, Version 2.0
- * (the "License"); you may not use this file except in compliance with
- * the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-package org.apache.spark.sql.catalyst.parser
-
-import org.apache.spark.sql.AnalysisException
-import org.apache.spark.sql.catalyst.TableIdentifier
-import org.apache.spark.sql.catalyst.analysis._
-import org.apache.spark.sql.catalyst.expressions._
-import org.apache.spark.sql.catalyst.plans.PlanTest
-import org.apache.spark.sql.catalyst.plans.logical._
-import org.apache.spark.unsafe.types.CalendarInterval
-
-class CatalystQlSuite extends PlanTest {
- val parser = new CatalystQl()
-
- test("test case insensitive") {
- val result = Project(UnresolvedAlias(Literal(1)):: Nil, OneRowRelation)
- assert(result === parser.parsePlan("seLect 1"))
- assert(result === parser.parsePlan("select 1"))
- assert(result === parser.parsePlan("SELECT 1"))
- }
-
- test("test NOT operator with comparison operations") {
- val parsed = parser.parsePlan("SELECT NOT TRUE > TRUE")
- val expected = Project(
- UnresolvedAlias(
- Not(
- GreaterThan(Literal(true), Literal(true)))
- ) :: Nil,
- OneRowRelation)
- comparePlans(parsed, expected)
- }
-
- test("test Union Distinct operator") {
- val parsed1 = parser.parsePlan("SELECT * FROM t0 UNION SELECT * FROM t1")
- val parsed2 = parser.parsePlan("SELECT * FROM t0 UNION DISTINCT SELECT * FROM t1")
- val expected =
- Project(UnresolvedAlias(UnresolvedStar(None)) :: Nil,
- SubqueryAlias("u_1",
- Distinct(
- Union(
- Project(UnresolvedAlias(UnresolvedStar(None)) :: Nil,
- UnresolvedRelation(TableIdentifier("t0"), None)),
- Project(UnresolvedAlias(UnresolvedStar(None)) :: Nil,
- UnresolvedRelation(TableIdentifier("t1"), None))))))
- comparePlans(parsed1, expected)
- comparePlans(parsed2, expected)
- }
-
- test("test Union All operator") {
- val parsed = parser.parsePlan("SELECT * FROM t0 UNION ALL SELECT * FROM t1")
- val expected =
- Project(UnresolvedAlias(UnresolvedStar(None)) :: Nil,
- SubqueryAlias("u_1",
- Union(
- Project(UnresolvedAlias(UnresolvedStar(None)) :: Nil,
- UnresolvedRelation(TableIdentifier("t0"), None)),
- Project(UnresolvedAlias(UnresolvedStar(None)) :: Nil,
- UnresolvedRelation(TableIdentifier("t1"), None)))))
- comparePlans(parsed, expected)
- }
-
- test("support hive interval literal") {
- def checkInterval(sql: String, result: CalendarInterval): Unit = {
- val parsed = parser.parsePlan(sql)
- val expected = Project(
- UnresolvedAlias(
- Literal(result)
- ) :: Nil,
- OneRowRelation)
- comparePlans(parsed, expected)
- }
-
- def checkYearMonth(lit: String): Unit = {
- checkInterval(
- s"SELECT INTERVAL '$lit' YEAR TO MONTH",
- CalendarInterval.fromYearMonthString(lit))
- }
-
- def checkDayTime(lit: String): Unit = {
- checkInterval(
- s"SELECT INTERVAL '$lit' DAY TO SECOND",
- CalendarInterval.fromDayTimeString(lit))
- }
-
- def checkSingleUnit(lit: String, unit: String): Unit = {
- checkInterval(
- s"SELECT INTERVAL '$lit' $unit",
- CalendarInterval.fromSingleUnitString(unit, lit))
- }
-
- checkYearMonth("123-10")
- checkYearMonth("496-0")
- checkYearMonth("-2-3")
- checkYearMonth("-123-0")
-
- checkDayTime("99 11:22:33.123456789")
- checkDayTime("-99 11:22:33.123456789")
- checkDayTime("10 9:8:7.123456789")
- checkDayTime("1 0:0:0")
- checkDayTime("-1 0:0:0")
- checkDayTime("1 0:0:1")
-
- for (unit <- Seq("year", "month", "day", "hour", "minute", "second")) {
- checkSingleUnit("7", unit)
- checkSingleUnit("-7", unit)
- checkSingleUnit("0", unit)
- }
-
- checkSingleUnit("13.123456789", "second")
- checkSingleUnit("-13.123456789", "second")
- }
-
- test("support scientific notation") {
- def assertRight(input: String, output: Double): Unit = {
- val parsed = parser.parsePlan("SELECT " + input)
- val expected = Project(
- UnresolvedAlias(
- Literal(output)
- ) :: Nil,
- OneRowRelation)
- comparePlans(parsed, expected)
- }
-
- assertRight("9.0e1", 90)
- assertRight(".9e+2", 90)
- assertRight("0.9e+2", 90)
- assertRight("900e-1", 90)
- assertRight("900.0E-1", 90)
- assertRight("9.e+1", 90)
-
- intercept[AnalysisException](parser.parsePlan("SELECT .e3"))
- }
-
- test("parse expressions") {
- compareExpressions(
- parser.parseExpression("prinln('hello', 'world')"),
- UnresolvedFunction(
- "prinln", Literal("hello") :: Literal("world") :: Nil, false))
-
- compareExpressions(
- parser.parseExpression("1 + r.r As q"),
- Alias(Add(Literal(1), UnresolvedAttribute("r.r")), "q")())
-
- compareExpressions(
- parser.parseExpression("1 - f('o', o(bar))"),
- Subtract(Literal(1),
- UnresolvedFunction("f",
- Literal("o") ::
- UnresolvedFunction("o", UnresolvedAttribute("bar") :: Nil, false) ::
- Nil, false)))
-
- intercept[AnalysisException](parser.parseExpression("1 - f('o', o(bar)) hello * world"))
- }
-
- test("table identifier") {
- assert(TableIdentifier("q") === parser.parseTableIdentifier("q"))
- assert(TableIdentifier("q", Some("d")) === parser.parseTableIdentifier("d.q"))
- intercept[AnalysisException](parser.parseTableIdentifier(""))
- intercept[AnalysisException](parser.parseTableIdentifier("d.q.g"))
- }
-
- test("parse union/except/intersect") {
- parser.parsePlan("select * from t1 union all select * from t2")
- parser.parsePlan("select * from t1 union distinct select * from t2")
- parser.parsePlan("select * from t1 union select * from t2")
- parser.parsePlan("select * from t1 except select * from t2")
- parser.parsePlan("select * from t1 intersect select * from t2")
- parser.parsePlan("(select * from t1) union all (select * from t2)")
- parser.parsePlan("(select * from t1) union distinct (select * from t2)")
- parser.parsePlan("(select * from t1) union (select * from t2)")
- parser.parsePlan("select * from ((select * from t1) union (select * from t2)) t")
- }
-
- test("window function: better support of parentheses") {
- parser.parsePlan("select sum(product + 1) over (partition by ((1) + (product / 2)) " +
- "order by 2) from windowData")
- parser.parsePlan("select sum(product + 1) over (partition by (1 + (product / 2)) " +
- "order by 2) from windowData")
- parser.parsePlan("select sum(product + 1) over (partition by ((product / 2) + 1) " +
- "order by 2) from windowData")
-
- parser.parsePlan("select sum(product + 1) over (partition by ((product) + (1)) order by 2) " +
- "from windowData")
- parser.parsePlan("select sum(product + 1) over (partition by ((product) + 1) order by 2) " +
- "from windowData")
- parser.parsePlan("select sum(product + 1) over (partition by (product + (1)) order by 2) " +
- "from windowData")
- }
-
- test("very long AND/OR expression") {
- val equals = (1 to 1000).map(x => s"$x == $x")
- val expr = parser.parseExpression(equals.mkString(" AND "))
- assert(expr.isInstanceOf[And])
- assert(expr.collect( { case EqualTo(_, _) => true } ).size == 1000)
-
- val expr2 = parser.parseExpression(equals.mkString(" OR "))
- assert(expr2.isInstanceOf[Or])
- assert(expr2.collect( { case EqualTo(_, _) => true } ).size == 1000)
- }
-
- test("subquery") {
- parser.parsePlan("select (select max(b) from s) ss from t")
- parser.parsePlan("select * from t where a = (select b from s)")
- parser.parsePlan("select * from t group by g having a > (select b from s)")
- }
-
- test("using clause in JOIN") {
- // Tests parsing of using clause for different join types.
- parser.parsePlan("select * from t1 join t2 using (c1)")
- parser.parsePlan("select * from t1 join t2 using (c1, c2)")
- parser.parsePlan("select * from t1 left join t2 using (c1, c2)")
- parser.parsePlan("select * from t1 right join t2 using (c1, c2)")
- parser.parsePlan("select * from t1 full outer join t2 using (c1, c2)")
- parser.parsePlan("select * from t1 join t2 using (c1) join t3 using (c2)")
- // Tests errors
- // (1) Empty using clause
- // (2) Qualified columns in using
- // (3) Both on and using clause
- var error = intercept[AnalysisException](parser.parsePlan("select * from t1 join t2 using ()"))
- assert(error.message.contains("cannot recognize input near ')'"))
- error = intercept[AnalysisException](parser.parsePlan("select * from t1 join t2 using (t1.c1)"))
- assert(error.message.contains("mismatched input '.'"))
- error = intercept[AnalysisException](parser.parsePlan("select * from t1" +
- " join t2 using (c1) on t1.c1 = t2.c1"))
- assert(error.message.contains("missing EOF at 'on' near ')'"))
- }
-}
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/DataTypeParserSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/DataTypeParserSuite.scala
index 7d3608033b..07b89cb61f 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/DataTypeParserSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/DataTypeParserSuite.scala
@@ -20,17 +20,21 @@ package org.apache.spark.sql.catalyst.parser
import org.apache.spark.SparkFunSuite
import org.apache.spark.sql.types._
-class DataTypeParserSuite extends SparkFunSuite {
+abstract class AbstractDataTypeParserSuite extends SparkFunSuite {
+
+ def parse(sql: String): DataType
def checkDataType(dataTypeString: String, expectedDataType: DataType): Unit = {
test(s"parse ${dataTypeString.replace("\n", "")}") {
- assert(DataTypeParser.parse(dataTypeString) === expectedDataType)
+ assert(parse(dataTypeString) === expectedDataType)
}
}
+ def intercept(sql: String)
+
def unsupported(dataTypeString: String): Unit = {
test(s"$dataTypeString is not supported") {
- intercept[DataTypeException](DataTypeParser.parse(dataTypeString))
+ intercept(dataTypeString)
}
}
@@ -97,13 +101,6 @@ class DataTypeParserSuite extends SparkFunSuite {
StructField("arrAy", ArrayType(DoubleType, true), true) ::
StructField("anotherArray", ArrayType(StringType, true), true) :: Nil)
)
- // A column name can be a reserved word in our DDL parser and SqlParser.
- checkDataType(
- "Struct<TABLE: string, CASE:boolean>",
- StructType(
- StructField("TABLE", StringType, true) ::
- StructField("CASE", BooleanType, true) :: Nil)
- )
// Use backticks to quote column names having special characters.
checkDataType(
"struct<`x+y`:int, `!@#$%^&*()`:string, `1_2.345<>:\"`:varchar(20)>",
@@ -118,6 +115,43 @@ class DataTypeParserSuite extends SparkFunSuite {
unsupported("it is not a data type")
unsupported("struct<x+y: int, 1.1:timestamp>")
unsupported("struct<x: int")
+}
+
+class DataTypeParserSuite extends AbstractDataTypeParserSuite {
+ override def intercept(sql: String): Unit =
+ intercept[DataTypeException](DataTypeParser.parse(sql))
+
+ override def parse(sql: String): DataType =
+ DataTypeParser.parse(sql)
+
+ // A column name can be a reserved word in our DDL parser and SqlParser.
+ checkDataType(
+ "Struct<TABLE: string, CASE:boolean>",
+ StructType(
+ StructField("TABLE", StringType, true) ::
+ StructField("CASE", BooleanType, true) :: Nil)
+ )
+
unsupported("struct<x int, y string>")
+
unsupported("struct<`x``y` int>")
}
+
+class CatalystQlDataTypeParserSuite extends AbstractDataTypeParserSuite {
+ override def intercept(sql: String): Unit =
+ intercept[ParseException](CatalystSqlParser.parseDataType(sql))
+
+ override def parse(sql: String): DataType =
+ CatalystSqlParser.parseDataType(sql)
+
+ // A column name can be a reserved word in our DDL parser and SqlParser.
+ unsupported("Struct<TABLE: string, CASE:boolean>")
+
+ checkDataType(
+ "struct<x int, y string>",
+ (new StructType).add("x", IntegerType).add("y", StringType))
+
+ checkDataType(
+ "struct<`x``y` int>",
+ (new StructType).add("x`y", IntegerType))
+}
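A short sketch (illustrative only) of the new ANTLR4-based entry point exercised by CatalystQlDataTypeParserSuite: CatalystSqlParser.parseDataType accepts the "name type" field syntax and un-escapes doubled backticks in quoted names.

    import org.apache.spark.sql.catalyst.parser.CatalystSqlParser
    import org.apache.spark.sql.types._

    val dt = CatalystSqlParser.parseDataType("struct<x int, `a``b` string>")
    assert(dt == new StructType().add("x", IntegerType).add("a`b", StringType))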
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/ErrorParserSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/ErrorParserSuite.scala
new file mode 100644
index 0000000000..db96bfb652
--- /dev/null
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/ErrorParserSuite.scala
@@ -0,0 +1,67 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.spark.sql.catalyst.parser
+
+import org.apache.spark.SparkFunSuite
+
+/**
+ * Test various parser errors.
+ */
+class ErrorParserSuite extends SparkFunSuite {
+ def intercept(sql: String, line: Int, startPosition: Int, messages: String*): Unit = {
+ val e = intercept[ParseException](CatalystSqlParser.parsePlan(sql))
+
+ // Check position.
+ assert(e.line.isDefined)
+ assert(e.line.get === line)
+ assert(e.startPosition.isDefined)
+ assert(e.startPosition.get === startPosition)
+
+ // Check messages.
+ val error = e.getMessage
+ messages.foreach { message =>
+ assert(error.contains(message))
+ }
+ }
+
+ test("no viable input") {
+ intercept("select from tbl", 1, 7, "no viable alternative at input", "-------^^^")
+ intercept("select\nfrom tbl", 2, 0, "no viable alternative at input", "^^^")
+ intercept("select ((r + 1) ", 1, 16, "no viable alternative at input", "----------------^^^")
+ }
+
+ test("extraneous input") {
+ intercept("select 1 1", 1, 9, "extraneous input '1' expecting", "---------^^^")
+ intercept("select *\nfrom r as q t", 2, 12, "extraneous input", "------------^^^")
+ }
+
+ test("mismatched input") {
+ intercept("select * from r order by q from t", 1, 27,
+ "mismatched input",
+ "---------------------------^^^")
+ intercept("select *\nfrom r\norder by q\nfrom t", 4, 0, "mismatched input", "^^^")
+ }
+
+ test("semantic errors") {
+ intercept("select *\nfrom r\norder by q\ncluster by q", 3, 0,
+ "Combination of ORDER BY/SORT BY/DISTRIBUTE BY/CLUSTER BY is not supported",
+ "^^^")
+ intercept("select * from r where a in (select * from t)", 1, 24,
+ "IN with a Sub-query is currently not supported",
+ "------------------------^^^")
+ }
+}
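A small sketch (illustrative only) of the error information the intercept helper above checks: ParseException exposes the failing line and start position alongside the ANTLR message.

    import org.apache.spark.sql.catalyst.parser.{CatalystSqlParser, ParseException}

    try {
      CatalystSqlParser.parsePlan("select from tbl")
    } catch {
      case e: ParseException =>
        println(s"line=${e.line}, startPosition=${e.startPosition}") // both are Option[Int]
        println(e.getMessage)
    }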
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/ExpressionParserSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/ExpressionParserSuite.scala
new file mode 100644
index 0000000000..6f40ec67ec
--- /dev/null
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/ExpressionParserSuite.scala
@@ -0,0 +1,497 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.spark.sql.catalyst.parser
+
+import java.sql.{Date, Timestamp}
+
+import org.apache.spark.sql.catalyst.analysis.{UnresolvedAttribute, _}
+import org.apache.spark.sql.catalyst.expressions._
+import org.apache.spark.sql.catalyst.plans.PlanTest
+import org.apache.spark.sql.types._
+import org.apache.spark.unsafe.types.CalendarInterval
+
+/**
+ * Test basic expression parsing. If a type of expression is supported it should be tested here.
+ *
+ * Please note that some of the expressions tested here don't have to be sound expressions;
+ * only their structure needs to be valid. Unsound expressions should be caught by the Analyzer or
+ * CheckAnalysis classes.
+ */
+class ExpressionParserSuite extends PlanTest {
+ import CatalystSqlParser._
+ import org.apache.spark.sql.catalyst.dsl.expressions._
+ import org.apache.spark.sql.catalyst.dsl.plans._
+
+ def assertEqual(sqlCommand: String, e: Expression): Unit = {
+ compareExpressions(parseExpression(sqlCommand), e)
+ }
+
+ def intercept(sqlCommand: String, messages: String*): Unit = {
+ val e = intercept[ParseException](parseExpression(sqlCommand))
+ messages.foreach { message =>
+ assert(e.message.contains(message))
+ }
+ }
+
+ test("star expressions") {
+ // Global Star
+ assertEqual("*", UnresolvedStar(None))
+
+ // Targeted Star
+ assertEqual("a.b.*", UnresolvedStar(Option(Seq("a", "b"))))
+ }
+
+ // NamedExpression (Alias/Multialias)
+ test("named expressions") {
+ // No Alias
+ val r0 = 'a
+ assertEqual("a", r0)
+
+ // Single Alias.
+ val r1 = 'a as "b"
+ assertEqual("a as b", r1)
+ assertEqual("a b", r1)
+
+ // Multi-Alias
+ assertEqual("a as (b, c)", MultiAlias('a, Seq("b", "c")))
+ assertEqual("a() (b, c)", MultiAlias('a.function(), Seq("b", "c")))
+
+ // Numeric literals without a space between the literal qualifier and the alias should not be
+ // interpreted as such. An unresolved reference should be returned instead.
+ // TODO add the JIRA-ticket number.
+ assertEqual("1SL", Symbol("1SL"))
+
+ // Aliased star is allowed.
+ assertEqual("a.* b", UnresolvedStar(Option(Seq("a"))) as 'b)
+ }
+
+ test("binary logical expressions") {
+ // And
+ assertEqual("a and b", 'a && 'b)
+
+ // Or
+ assertEqual("a or b", 'a || 'b)
+
+ // Combination And/Or check precedence
+ assertEqual("a and b or c and d", ('a && 'b) || ('c && 'd))
+ assertEqual("a or b or c and d", 'a || 'b || ('c && 'd))
+
+ // Multiple AND/OR get converted into a balanced tree
+ assertEqual("a or b or c or d or e or f", (('a || 'b) || 'c) || (('d || 'e) || 'f))
+ assertEqual("a and b and c and d and e and f", (('a && 'b) && 'c) && (('d && 'e) && 'f))
+ }
+
+ test("long binary logical expressions") {
+ def testVeryBinaryExpression(op: String, clazz: Class[_]): Unit = {
+ val sql = (1 to 1000).map(x => s"$x == $x").mkString(op)
+ val e = parseExpression(sql)
+ assert(e.collect { case _: EqualTo => true }.size === 1000)
+ assert(e.collect { case x if clazz.isInstance(x) => true }.size === 999)
+ }
+ testVeryBinaryExpression(" AND ", classOf[And])
+ testVeryBinaryExpression(" OR ", classOf[Or])
+ }
+
+ test("not expressions") {
+ assertEqual("not a", !'a)
+ assertEqual("!a", !'a)
+ assertEqual("not true > true", Not(GreaterThan(true, true)))
+ }
+
+ test("exists expression") {
+ intercept("exists (select 1 from b where b.x = a.x)", "EXISTS clauses are not supported")
+ }
+
+ test("comparison expressions") {
+ assertEqual("a = b", 'a === 'b)
+ assertEqual("a == b", 'a === 'b)
+ assertEqual("a <=> b", 'a <=> 'b)
+ assertEqual("a <> b", 'a =!= 'b)
+ assertEqual("a != b", 'a =!= 'b)
+ assertEqual("a < b", 'a < 'b)
+ assertEqual("a <= b", 'a <= 'b)
+ assertEqual("a > b", 'a > 'b)
+ assertEqual("a >= b", 'a >= 'b)
+ }
+
+ test("between expressions") {
+ assertEqual("a between b and c", 'a >= 'b && 'a <= 'c)
+ assertEqual("a not between b and c", !('a >= 'b && 'a <= 'c))
+ }
+
+ test("in expressions") {
+ assertEqual("a in (b, c, d)", 'a in ('b, 'c, 'd))
+ assertEqual("a not in (b, c, d)", !('a in ('b, 'c, 'd)))
+ }
+
+ test("in sub-query") {
+ intercept("a in (select b from c)", "IN with a Sub-query is currently not supported")
+ }
+
+ test("like expressions") {
+ assertEqual("a like 'pattern%'", 'a like "pattern%")
+ assertEqual("a not like 'pattern%'", !('a like "pattern%"))
+ assertEqual("a rlike 'pattern%'", 'a rlike "pattern%")
+ assertEqual("a not rlike 'pattern%'", !('a rlike "pattern%"))
+ assertEqual("a regexp 'pattern%'", 'a rlike "pattern%")
+ assertEqual("a not regexp 'pattern%'", !('a rlike "pattern%"))
+ }
+
+ test("is null expressions") {
+ assertEqual("a is null", 'a.isNull)
+ assertEqual("a is not null", 'a.isNotNull)
+ assertEqual("a = b is null", ('a === 'b).isNull)
+ assertEqual("a = b is not null", ('a === 'b).isNotNull)
+ }
+
+ test("binary arithmetic expressions") {
+ // Simple operations
+ assertEqual("a * b", 'a * 'b)
+ assertEqual("a / b", 'a / 'b)
+ assertEqual("a DIV b", ('a / 'b).cast(LongType))
+ assertEqual("a % b", 'a % 'b)
+ assertEqual("a + b", 'a + 'b)
+ assertEqual("a - b", 'a - 'b)
+ assertEqual("a & b", 'a & 'b)
+ assertEqual("a ^ b", 'a ^ 'b)
+ assertEqual("a | b", 'a | 'b)
+
+ // Check precedences
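+ // Reading the expected tree below: the multiplicative operators (*, /, %, DIV) bind tightest,
+ // followed by + and -, then &, then ^, with | binding loosest.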
+ assertEqual(
+ "a * t | b ^ c & d - e + f % g DIV h / i * k",
+ 'a * 't | ('b ^ ('c & ('d - 'e + (('f % 'g / 'h).cast(LongType) / 'i * 'k)))))
+ }
+
+ test("unary arithmetic expressions") {
+ assertEqual("+a", 'a)
+ assertEqual("-a", -'a)
+ assertEqual("~a", ~'a)
+ assertEqual("-+~~a", -(~(~'a)))
+ }
+
+ test("cast expressions") {
+ // Note that DataType parsing is tested elsewhere.
+ assertEqual("cast(a as int)", 'a.cast(IntegerType))
+ assertEqual("cast(a as timestamp)", 'a.cast(TimestampType))
+ assertEqual("cast(a as array<int>)", 'a.cast(ArrayType(IntegerType)))
+ assertEqual("cast(cast(a as int) as long)", 'a.cast(IntegerType).cast(LongType))
+ }
+
+ test("function expressions") {
+ assertEqual("foo()", 'foo.function())
+ assertEqual("foo.bar()", Symbol("foo.bar").function())
+ assertEqual("foo(*)", 'foo.function(star()))
+ assertEqual("count(*)", 'count.function(1))
+ assertEqual("foo(a, b)", 'foo.function('a, 'b))
+ assertEqual("foo(all a, b)", 'foo.function('a, 'b))
+ assertEqual("foo(distinct a, b)", 'foo.distinctFunction('a, 'b))
+ assertEqual("grouping(distinct a, b)", 'grouping.distinctFunction('a, 'b))
+ assertEqual("`select`(all a, b)", 'select.function('a, 'b))
+ }
+
+ test("window function expressions") {
+ val func = 'foo.function(star())
+ def windowed(
+ partitioning: Seq[Expression] = Seq.empty,
+ ordering: Seq[SortOrder] = Seq.empty,
+ frame: WindowFrame = UnspecifiedFrame): Expression = {
+ WindowExpression(func, WindowSpecDefinition(partitioning, ordering, frame))
+ }
+
+ // Basic window testing.
+ assertEqual("foo(*) over w1", UnresolvedWindowExpression(func, WindowSpecReference("w1")))
+ assertEqual("foo(*) over ()", windowed())
+ assertEqual("foo(*) over (partition by a, b)", windowed(Seq('a, 'b)))
+ assertEqual("foo(*) over (distribute by a, b)", windowed(Seq('a, 'b)))
+ assertEqual("foo(*) over (cluster by a, b)", windowed(Seq('a, 'b)))
+ assertEqual("foo(*) over (order by a desc, b asc)", windowed(Seq.empty, Seq('a.desc, 'b.asc )))
+ assertEqual("foo(*) over (sort by a desc, b asc)", windowed(Seq.empty, Seq('a.desc, 'b.asc )))
+ assertEqual("foo(*) over (partition by a, b order by c)", windowed(Seq('a, 'b), Seq('c.asc)))
+ assertEqual("foo(*) over (distribute by a, b sort by c)", windowed(Seq('a, 'b), Seq('c.asc)))
+
+ // Test use of expressions in window functions.
+ assertEqual(
+ "sum(product + 1) over (partition by ((product) + (1)) order by 2)",
+ WindowExpression('sum.function('product + 1),
+ WindowSpecDefinition(Seq('product + 1), Seq(Literal(2).asc), UnspecifiedFrame)))
+ assertEqual(
+ "sum(product + 1) over (partition by ((product / 2) + 1) order by 2)",
+ WindowExpression('sum.function('product + 1),
+ WindowSpecDefinition(Seq('product / 2 + 1), Seq(Literal(2).asc), UnspecifiedFrame)))
+
+ // Range/Row
+ val frameTypes = Seq(("rows", RowFrame), ("range", RangeFrame))
+ val boundaries = Seq(
+ ("10 preceding", ValuePreceding(10), CurrentRow),
+ ("3 + 1 following", ValueFollowing(4), CurrentRow), // Will fail during analysis
+ ("unbounded preceding", UnboundedPreceding, CurrentRow),
+ ("unbounded following", UnboundedFollowing, CurrentRow), // Will fail during analysis
+ ("between unbounded preceding and current row", UnboundedPreceding, CurrentRow),
+ ("between unbounded preceding and unbounded following",
+ UnboundedPreceding, UnboundedFollowing),
+ ("between 10 preceding and current row", ValuePreceding(10), CurrentRow),
+ ("between current row and 5 following", CurrentRow, ValueFollowing(5)),
+ ("between 10 preceding and 5 following", ValuePreceding(10), ValueFollowing(5))
+ )
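+ // Note that constant arithmetic in a boundary ("3 + 1 following") is folded by the parser into
+ // a single value (ValueFollowing(4)). Every boundary is exercised for both the ROWS and the
+ // RANGE frame type.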
+ frameTypes.foreach {
+ case (frameTypeSql, frameType) =>
+ boundaries.foreach {
+ case (boundarySql, begin, end) =>
+ val query = s"foo(*) over (partition by a order by b $frameTypeSql $boundarySql)"
+ val expr = windowed(Seq('a), Seq('b.asc), SpecifiedWindowFrame(frameType, begin, end))
+ assertEqual(query, expr)
+ }
+ }
+
+ // We cannot use non-integer constants.
+ intercept("foo(*) over (partition by a order by b rows 10.0 preceding)",
+ "Frame bound value must be a constant integer.")
+
+ // We cannot use an arbitrary expression.
+ intercept("foo(*) over (partition by a order by b rows exp(b) preceding)",
+ "Frame bound value must be a constant integer.")
+ }
+
+ test("row constructor") {
+ // Note that '(a)' will be interpreted as a nested expression.
+ assertEqual("(a, b)", CreateStruct(Seq('a, 'b)))
+ assertEqual("(a, b, c)", CreateStruct(Seq('a, 'b, 'c)))
+ }
+
+ test("scalar sub-query") {
+ assertEqual(
+ "(select max(val) from tbl) > current",
+ ScalarSubquery(table("tbl").select('max.function('val))) > 'current)
+ assertEqual(
+ "a = (select b from s)",
+ 'a === ScalarSubquery(table("s").select('b)))
+ }
+
+ test("case when") {
+ assertEqual("case a when 1 then b when 2 then c else d end",
+ CaseKeyWhen('a, Seq(1, 'b, 2, 'c, 'd)))
+ assertEqual("case when a = 1 then b when a = 2 then c else d end",
+ CaseWhen(Seq(('a === 1, 'b.expr), ('a === 2, 'c.expr)), 'd))
+ }
+
+ test("dereference") {
+ assertEqual("a.b", UnresolvedAttribute("a.b"))
+ assertEqual("`select`.b", UnresolvedAttribute("select.b"))
+ assertEqual("(a + b).b", ('a + 'b).getField("b")) // This will fail analysis.
+ assertEqual("struct(a, b).b", 'struct.function('a, 'b).getField("b"))
+ }
+
+ test("reference") {
+ // Regular
+ assertEqual("a", 'a)
+
+ // Starting with a digit.
+ assertEqual("1a", Symbol("1a"))
+
+ // Quoted using a keyword.
+ assertEqual("`select`", 'select)
+
+ // Unquoted using an unreserved keyword.
+ assertEqual("columns", 'columns)
+ }
+
+ test("subscript") {
+ assertEqual("a[b]", 'a.getItem('b))
+ assertEqual("a[1 + 1]", 'a.getItem(Literal(1) + 1))
+ assertEqual("`c`.a[b]", UnresolvedAttribute("c.a").getItem('b))
+ }
+
+ test("parenthesis") {
+ assertEqual("(a)", 'a)
+ assertEqual("r * (a + b)", 'r * ('a + 'b))
+ }
+
+ test("type constructors") {
+ // Dates.
+ assertEqual("dAte '2016-03-11'", Literal(Date.valueOf("2016-03-11")))
+ intercept[IllegalArgumentException] {
+ parseExpression("DAtE 'mar 11 2016'")
+ }
+
+ // Timestamps.
+ assertEqual("tImEstAmp '2016-03-11 20:54:00.000'",
+ Literal(Timestamp.valueOf("2016-03-11 20:54:00.000")))
+ intercept[IllegalArgumentException] {
+ parseExpression("timestamP '2016-33-11 20:54:00.000'")
+ }
+
+ // Unsupported datatype.
+ intercept("GEO '(10,-6)'", "Literals of type 'GEO' are currently not supported.")
+ }
+
+ test("literals") {
+ // NULL
+ assertEqual("null", Literal(null))
+
+ // Boolean
+ assertEqual("trUe", Literal(true))
+ assertEqual("False", Literal(false))
+
+ // Integral literals should have the narrowest possible type
+ assertEqual("787324", Literal(787324))
+ assertEqual("7873247234798249234", Literal(7873247234798249234L))
+ assertEqual("78732472347982492793712334",
+ Literal(BigDecimal("78732472347982492793712334").underlying()))
+
+ // Decimal
+ assertEqual("7873247234798249279371.2334",
+ Literal(BigDecimal("7873247234798249279371.2334").underlying()))
+
+ // Scientific Decimal
+ assertEqual("9.0e1", 90d)
+ assertEqual(".9e+2", 90d)
+ assertEqual("0.9e+2", 90d)
+ assertEqual("900e-1", 90d)
+ assertEqual("900.0E-1", 90d)
+ assertEqual("9.e+1", 90d)
+ intercept(".e3")
+
+ // Tiny Int Literal
+ assertEqual("10Y", Literal(10.toByte))
+ intercept("-1000Y")
+
+ // Small Int Literal
+ assertEqual("10S", Literal(10.toShort))
+ intercept("40000S")
+
+ // Long Int Literal
+ assertEqual("10L", Literal(10L))
+ intercept("78732472347982492793712334L")
+
+ // Double Literal
+ assertEqual("10.0D", Literal(10.0D))
+ // TODO we need to figure out if we should throw an exception here!
+ assertEqual("1E309", Literal(Double.PositiveInfinity))
+ }
+
+ test("strings") {
+ // Single Strings.
+ assertEqual("\"hello\"", "hello")
+ assertEqual("'hello'", "hello")
+
+ // Multi-Strings.
+ assertEqual("\"hello\" 'world'", "helloworld")
+ assertEqual("'hello' \" \" 'world'", "hello world")
+
+ // 'LIKE' string literals. Note that an escaped '%' is the same as an escaped '\' followed by a
+ // regular '%'; to get the correct result you need to add another escaped '\'.
+ // TODO figure out whether we should change the ParseUtils.unescapeSQLString method.
+ assertEqual("'pattern%'", "pattern%")
+ assertEqual("'no-pattern\\%'", "no-pattern\\%")
+ assertEqual("'pattern\\\\%'", "pattern\\%")
+ assertEqual("'pattern\\\\\\%'", "pattern\\\\%")
+
+ // Escaped characters.
+ // See: http://dev.mysql.com/doc/refman/5.7/en/string-literals.html
+ assertEqual("'\\0'", "\u0000") // ASCII NUL (X'00')
+ assertEqual("'\\''", "\'") // Single quote
+ assertEqual("'\\\"'", "\"") // Double quote
+ assertEqual("'\\b'", "\b") // Backspace
+ assertEqual("'\\n'", "\n") // Newline
+ assertEqual("'\\r'", "\r") // Carriage return
+ assertEqual("'\\t'", "\t") // Tab character
+ assertEqual("'\\Z'", "\u001A") // ASCII 26 - CTRL + Z (EOF on windows)
+
+ // Octals
+ assertEqual("'\\110\\145\\154\\154\\157\\041'", "Hello!")
+
+ // Unicode
+ assertEqual("'\\u0057\\u006F\\u0072\\u006C\\u0064\\u0020\\u003A\\u0029'", "World :)")
+ }
+
+ test("intervals") {
+ def intervalLiteral(u: String, s: String): Literal = {
+ Literal(CalendarInterval.fromSingleUnitString(u, s))
+ }
+
+ // Empty interval statement
+ intercept("interval", "at least one time unit should be given for interval literal")
+
+ // Single Intervals.
+ val units = Seq(
+ "year",
+ "month",
+ "week",
+ "day",
+ "hour",
+ "minute",
+ "second",
+ "millisecond",
+ "microsecond")
+ val forms = Seq("", "s")
+ val values = Seq("0", "10", "-7", "21")
+ units.foreach { unit =>
+ forms.foreach { form =>
+ values.foreach { value =>
+ val expected = intervalLiteral(unit, value)
+ assertEqual(s"interval $value $unit$form", expected)
+ assertEqual(s"interval '$value' $unit$form", expected)
+ }
+ }
+ }
+
+ // Hive nanosecond notation.
+ assertEqual("interval 13.123456789 seconds", intervalLiteral("second", "13.123456789"))
+ assertEqual("interval -13.123456789 second", intervalLiteral("second", "-13.123456789"))
+
+ // Non-existent unit
+ intercept("interval 10 nanoseconds", "No interval can be constructed")
+
+ // Year-Month intervals.
+ val yearMonthValues = Seq("123-10", "496-0", "-2-3", "-123-0")
+ yearMonthValues.foreach { value =>
+ val result = Literal(CalendarInterval.fromYearMonthString(value))
+ assertEqual(s"interval '$value' year to month", result)
+ }
+
+ // Day-Time intervals.
+ val dayTimeValues = Seq(
+ "99 11:22:33.123456789",
+ "-99 11:22:33.123456789",
+ "10 9:8:7.123456789",
+ "1 0:0:0",
+ "-1 0:0:0",
+ "1 0:0:1")
+ dayTimeValues.foreach { value =>
+ val result = Literal(CalendarInterval.fromDayTimeString(value))
+ assertEqual(s"interval '$value' day to second", result)
+ }
+
+ // Unknown FROM TO intervals
+ intercept("interval 10 month to second", "Intervals FROM month TO second are not supported.")
+
+ // Composed intervals.
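+ // The expected values follow from simple unit arithmetic: 22 seconds plus 1 millisecond is
+ // 22001000 microseconds, while 3 years minus the '-1-10' year-to-month interval leaves
+ // 14 months, and 3 weeks plus '1 0:0:2' day-to-second adds up to 22 days and 2 seconds.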
+ assertEqual(
+ "interval 3 months 22 seconds 1 millisecond",
+ Literal(new CalendarInterval(3, 22001000L)))
+ assertEqual(
+ "interval 3 years '-1-10' year to month 3 weeks '1 0:0:2' day to second",
+ Literal(new CalendarInterval(14,
+ 22 * CalendarInterval.MICROS_PER_DAY + 2 * CalendarInterval.MICROS_PER_SECOND)))
+ }
+
+ test("composed expressions") {
+ assertEqual("1 + r.r As q", (Literal(1) + UnresolvedAttribute("r.r")).as("q"))
+ assertEqual("1 - f('o', o(bar))", Literal(1) - 'f.function("o", 'o.function('bar)))
+ intercept("1 - f('o', o(bar)) hello * world", "mismatched input '*'")
+ }
+}
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/ParserUtilsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/ParserUtilsSuite.scala
new file mode 100644
index 0000000000..d090daf7b4
--- /dev/null
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/ParserUtilsSuite.scala
@@ -0,0 +1,65 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.spark.sql.catalyst.parser
+
+import org.apache.spark.SparkFunSuite
+
+class ParserUtilsSuite extends SparkFunSuite {
+
+ import ParserUtils._
+
+ test("unescapeSQLString") {
+ // scalastyle:off nonascii
+
+ // String not including escaped characters and enclosed by double quotes.
+ assert(unescapeSQLString(""""abcdefg"""") == "abcdefg")
+
+ // String enclosed by single quotes.
+ assert(unescapeSQLString("""'C0FFEE'""") == "C0FFEE")
+
+ // Strings including single escaped characters.
+ assert(unescapeSQLString("""'\0'""") == "\u0000")
+ assert(unescapeSQLString(""""\'"""") == "\'")
+ assert(unescapeSQLString("""'\"'""") == "\"")
+ assert(unescapeSQLString(""""\b"""") == "\b")
+ assert(unescapeSQLString("""'\n'""") == "\n")
+ assert(unescapeSQLString(""""\r"""") == "\r")
+ assert(unescapeSQLString("""'\t'""") == "\t")
+ assert(unescapeSQLString(""""\Z"""") == "\u001A")
+ assert(unescapeSQLString("""'\\'""") == "\\")
+ assert(unescapeSQLString(""""\%"""") == "\\%")
+ assert(unescapeSQLString("""'\_'""") == "\\_")
+
+ // String including '\000' style literal characters.
+ assert(unescapeSQLString("""'3 + 5 = \070'""") == "3 + 5 = \u0038")
+ assert(unescapeSQLString(""""\000"""") == "\u0000")
+
+ // String including invalid '\000' style literal characters.
+ assert(unescapeSQLString(""""\256"""") == "256")
+
+ // String including a '\u0000' style literal character (\u732B is a cat in Kanji).
+ assert(unescapeSQLString(""""How cute \u732B are"""") == "How cute \u732B are")
+
+ // String including a surrogate pair character
+ // (\uD867\uDE3D is Okhotsk atka mackerel in Kanji).
+ assert(unescapeSQLString(""""\uD867\uDE3D is a fish"""") == "\uD867\uDE3D is a fish")
+
+ // scalastyle:on nonascii
+ }
+
+ // TODO: Add test cases for other methods in ParserUtils
+}
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/PlanParserSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/PlanParserSuite.scala
new file mode 100644
index 0000000000..411e2372f2
--- /dev/null
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/PlanParserSuite.scala
@@ -0,0 +1,431 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.spark.sql.catalyst.parser
+
+import org.apache.spark.sql.Row
+import org.apache.spark.sql.catalyst.analysis.UnresolvedGenerator
+import org.apache.spark.sql.catalyst.expressions._
+import org.apache.spark.sql.catalyst.plans._
+import org.apache.spark.sql.catalyst.plans.logical._
+import org.apache.spark.sql.types.{BooleanType, IntegerType}
+
+class PlanParserSuite extends PlanTest {
+ import CatalystSqlParser._
+ import org.apache.spark.sql.catalyst.dsl.expressions._
+ import org.apache.spark.sql.catalyst.dsl.plans._
+
+ def assertEqual(sqlCommand: String, plan: LogicalPlan): Unit = {
+ comparePlans(parsePlan(sqlCommand), plan)
+ }
+
+ def intercept(sqlCommand: String, messages: String*): Unit = {
+ val e = intercept[ParseException](parsePlan(sqlCommand))
+ messages.foreach { message =>
+ assert(e.message.contains(message))
+ }
+ }
+
+ test("case insensitive") {
+ val plan = table("a").select(star())
+ assertEqual("sELEct * FroM a", plan)
+ assertEqual("select * fRoM a", plan)
+ assertEqual("SELECT * FROM a", plan)
+ }
+
+ test("show functions") {
+ assertEqual("show functions", ShowFunctions(None, None))
+ assertEqual("show functions foo", ShowFunctions(None, Some("foo")))
+ assertEqual("show functions foo.bar", ShowFunctions(Some("foo"), Some("bar")))
+ assertEqual("show functions 'foo\\\\.*'", ShowFunctions(None, Some("foo\\.*")))
+ intercept("show functions foo.bar.baz", "SHOW FUNCTIONS unsupported name")
+ }
+
+ test("describe function") {
+ assertEqual("describe function bar", DescribeFunction("bar", isExtended = false))
+ assertEqual("describe function extended bar", DescribeFunction("bar", isExtended = true))
+ assertEqual("describe function foo.bar", DescribeFunction("foo.bar", isExtended = false))
+ assertEqual("describe function extended f.bar", DescribeFunction("f.bar", isExtended = true))
+ }
+
+ test("set operations") {
+ val a = table("a").select(star())
+ val b = table("b").select(star())
+
+ assertEqual("select * from a union select * from b", Distinct(a.union(b)))
+ assertEqual("select * from a union distinct select * from b", Distinct(a.union(b)))
+ assertEqual("select * from a union all select * from b", a.union(b))
+ assertEqual("select * from a except select * from b", a.except(b))
+ intercept("select * from a except all select * from b", "EXCEPT ALL is not supported.")
+ assertEqual("select * from a except distinct select * from b", a.except(b))
+ assertEqual("select * from a intersect select * from b", a.intersect(b))
+ intercept("select * from a intersect all select * from b", "INTERSECT ALL is not supported.")
+ assertEqual("select * from a intersect distinct select * from b", a.intersect(b))
+ }
+
+ test("common table expressions") {
+ def cte(plan: LogicalPlan, namedPlans: (String, LogicalPlan)*): With = {
+ val ctes = namedPlans.map {
+ case (name, cte) =>
+ name -> SubqueryAlias(name, cte)
+ }.toMap
+ With(plan, ctes)
+ }
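+ // As the queries below show, the AS keyword between a CTE name and its definition is optional.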
+ assertEqual(
+ "with cte1 as (select * from a) select * from cte1",
+ cte(table("cte1").select(star()), "cte1" -> table("a").select(star())))
+ assertEqual(
+ "with cte1 (select 1) select * from cte1",
+ cte(table("cte1").select(star()), "cte1" -> OneRowRelation.select(1)))
+ assertEqual(
+ "with cte1 (select 1), cte2 as (select * from cte1) select * from cte2",
+ cte(table("cte2").select(star()),
+ "cte1" -> OneRowRelation.select(1),
+ "cte2" -> table("cte1").select(star())))
+ intercept(
+ "with cte1 (select 1), cte1 as (select 1 from cte1) select * from cte1",
+ "Name 'cte1' is used for multiple common table expressions")
+ }
+
+ test("simple select query") {
+ assertEqual("select 1", OneRowRelation.select(1))
+ assertEqual("select a, b", OneRowRelation.select('a, 'b))
+ assertEqual("select a, b from db.c", table("db", "c").select('a, 'b))
+ assertEqual("select a, b from db.c where x < 1", table("db", "c").where('x < 1).select('a, 'b))
+ assertEqual(
+ "select a, b from db.c having x < 1",
+ table("db", "c").select('a, 'b).where(('x < 1).cast(BooleanType)))
+ assertEqual("select distinct a, b from db.c", Distinct(table("db", "c").select('a, 'b)))
+ assertEqual("select all a, b from db.c", table("db", "c").select('a, 'b))
+ }
+
+ test("reverse select query") {
+ assertEqual("from a", table("a"))
+ assertEqual("from a select b, c", table("a").select('b, 'c))
+ assertEqual(
+ "from db.a select b, c where d < 1", table("db", "a").where('d < 1).select('b, 'c))
+ assertEqual("from a select distinct b, c", Distinct(table("a").select('b, 'c)))
+ assertEqual(
+ "from (from a union all from b) c select *",
+ table("a").union(table("b")).as("c").select(star()))
+ }
+
+ test("multi select query") {
+ assertEqual(
+ "from a select * select * where s < 10",
+ table("a").select(star()).union(table("a").where('s < 10).select(star())))
+ intercept(
+ "from a select * select * from x where a.s < 10",
+ "Multi-Insert queries cannot have a FROM clause in their individual SELECT statements")
+ assertEqual(
+ "from a insert into tbl1 select * insert into tbl2 select * where s < 10",
+ table("a").select(star()).insertInto("tbl1").union(
+ table("a").where('s < 10).select(star()).insertInto("tbl2")))
+ }
+
+ test("query organization") {
+ // Test all valid combinations of order by/sort by/distribute by/cluster by/limit/windows
+ val baseSql = "select * from t"
+ val basePlan = table("t").select(star())
+
+ val ws = Map("w1" -> WindowSpecDefinition(Seq.empty, Seq.empty, UnspecifiedFrame))
+ val limitWindowClauses = Seq(
+ ("", (p: LogicalPlan) => p),
+ (" limit 10", (p: LogicalPlan) => p.limit(10)),
+ (" window w1 as ()", (p: LogicalPlan) => WithWindowDefinition(ws, p)),
+ (" window w1 as () limit 10", (p: LogicalPlan) => WithWindowDefinition(ws, p).limit(10))
+ )
+
+ val orderSortDistrClusterClauses = Seq(
+ ("", basePlan),
+ (" order by a, b desc", basePlan.orderBy('a.asc, 'b.desc)),
+ (" sort by a, b desc", basePlan.sortBy('a.asc, 'b.desc)),
+ (" distribute by a, b", basePlan.distribute('a, 'b)),
+ (" distribute by a sort by b", basePlan.distribute('a).sortBy('b.asc)),
+ (" cluster by a, b", basePlan.distribute('a, 'b).sortBy('a.asc, 'b.asc))
+ )
+
+ orderSortDistrClusterClauses.foreach {
+ case (s1, p1) =>
+ limitWindowClauses.foreach {
+ case (s2, pf2) =>
+ assertEqual(baseSql + s1 + s2, pf2(p1))
+ }
+ }
+
+ val msg = "Combination of ORDER BY/SORT BY/DISTRIBUTE BY/CLUSTER BY is not supported"
+ intercept(s"$baseSql order by a sort by a", msg)
+ intercept(s"$baseSql cluster by a distribute by a", msg)
+ intercept(s"$baseSql order by a cluster by a", msg)
+ intercept(s"$baseSql order by a distribute by a", msg)
+ }
+
+ test("insert into") {
+ val sql = "select * from t"
+ val plan = table("t").select(star())
+ def insert(
+ partition: Map[String, Option[String]],
+ overwrite: Boolean = false,
+ ifNotExists: Boolean = false): LogicalPlan =
+ InsertIntoTable(table("s"), partition, plan, overwrite, ifNotExists)
+
+ // Single inserts
+ assertEqual(s"insert overwrite table s $sql",
+ insert(Map.empty, overwrite = true))
+ assertEqual(s"insert overwrite table s if not exists $sql",
+ insert(Map.empty, overwrite = true, ifNotExists = true))
+ assertEqual(s"insert into s $sql",
+ insert(Map.empty))
+ assertEqual(s"insert into table s partition (c = 'd', e = 1) $sql",
+ insert(Map("c" -> Option("d"), "e" -> Option("1"))))
+ assertEqual(s"insert overwrite table s partition (c = 'd', x) if not exists $sql",
+ insert(Map("c" -> Option("d"), "x" -> None), overwrite = true, ifNotExists = true))
+
+ // Multi insert
+ val plan2 = table("t").where('x > 5).select(star())
+ assertEqual("from t insert into s select * limit 1 insert into u select * where x > 5",
+ InsertIntoTable(
+ table("s"), Map.empty, plan.limit(1), overwrite = false, ifNotExists = false).union(
+ InsertIntoTable(
+ table("u"), Map.empty, plan2, overwrite = false, ifNotExists = false)))
+ }
+
+ test("aggregation") {
+ val sql = "select a, b, sum(c) as c from d group by a, b"
+
+ // Normal
+ assertEqual(sql, table("d").groupBy('a, 'b)('a, 'b, 'sum.function('c).as("c")))
+
+ // Cube
+ assertEqual(s"$sql with cube",
+ table("d").groupBy(Cube(Seq('a, 'b)))('a, 'b, 'sum.function('c).as("c")))
+
+ // Rollup
+ assertEqual(s"$sql with rollup",
+ table("d").groupBy(Rollup(Seq('a, 'b)))('a, 'b, 'sum.function('c).as("c")))
+
+ // Grouping Sets
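+ // The bitmasks Seq(0, 1, 3) appear to encode, per grouping set, which GROUP BY columns are
+ // absent: (a, b) -> 0, (a) -> 1 (b missing), () -> 3 (both missing).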
+ assertEqual(s"$sql grouping sets((a, b), (a), ())",
+ GroupingSets(Seq(0, 1, 3), Seq('a, 'b), table("d"), Seq('a, 'b, 'sum.function('c).as("c"))))
+ intercept(s"$sql grouping sets((a, b), (c), ())",
+ "c doesn't show up in the GROUP BY list")
+ }
+
+ test("limit") {
+ val sql = "select * from t"
+ val plan = table("t").select(star())
+ assertEqual(s"$sql limit 10", plan.limit(10))
+ assertEqual(s"$sql limit cast(9 / 4 as int)", plan.limit(Cast(Literal(9) / 4, IntegerType)))
+ }
+
+ test("window spec") {
+ // Note that WindowSpecs are tested in the ExpressionParserSuite
+ val sql = "select * from t"
+ val plan = table("t").select(star())
+ val spec = WindowSpecDefinition(Seq('a, 'b), Seq('c.asc),
+ SpecifiedWindowFrame(RowFrame, ValuePreceding(1), ValueFollowing(1)))
+
+ // Test window resolution.
+ val ws1 = Map("w1" -> spec, "w2" -> spec, "w3" -> spec)
+ assertEqual(
+ s"""$sql
+ |window w1 as (partition by a, b order by c rows between 1 preceding and 1 following),
+ | w2 as w1,
+ | w3 as w1""".stripMargin,
+ WithWindowDefinition(ws1, plan))
+
+ // Fail with no reference.
+ intercept(s"$sql window w2 as w1", "Cannot resolve window reference 'w1'")
+
+ // Fail when resolved reference is not a window spec.
+ intercept(
+ s"""$sql
+ |window w1 as (partition by a, b order by c rows between 1 preceding and 1 following),
+ | w2 as w1,
+ | w3 as w2""".stripMargin,
+ "Window reference 'w2' is not a window specification"
+ )
+ }
+
+ test("lateral view") {
+ // Single lateral view
+ assertEqual(
+ "select * from t lateral view explode(x) expl as x",
+ table("t")
+ .generate(Explode('x), join = true, outer = false, Some("expl"), Seq("x"))
+ .select(star()))
+
+ // Multiple lateral views
+ assertEqual(
+ """select *
+ |from t
+ |lateral view explode(x) expl
+ |lateral view outer json_tuple(x, y) jtup q, z""".stripMargin,
+ table("t")
+ .generate(Explode('x), join = true, outer = false, Some("expl"), Seq.empty)
+ .generate(JsonTuple(Seq('x, 'y)), join = true, outer = true, Some("jtup"), Seq("q", "z"))
+ .select(star()))
+
+ // Multi-Insert lateral views.
+ val from = table("t1").generate(Explode('x), join = true, outer = false, Some("expl"), Seq("x"))
+ assertEqual(
+ """from t1
+ |lateral view explode(x) expl as x
+ |insert into t2
+ |select *
+ |lateral view json_tuple(x, y) jtup q, z
+ |insert into t3
+ |select *
+ |where s < 10
+ """.stripMargin,
+ Union(from
+ .generate(JsonTuple(Seq('x, 'y)), join = true, outer = false, Some("jtup"), Seq("q", "z"))
+ .select(star())
+ .insertInto("t2"),
+ from.where('s < 10).select(star()).insertInto("t3")))
+
+ // Unresolved generator.
+ val expected = table("t")
+ .generate(
+ UnresolvedGenerator("posexplode", Seq('x)),
+ join = true,
+ outer = false,
+ Some("posexpl"),
+ Seq("x", "y"))
+ .select(star())
+ assertEqual(
+ "select * from t lateral view posexplode(x) posexpl as x, y",
+ expected)
+ }
+
+ test("joins") {
+ // Test single joins.
+ val testUnconditionalJoin = (sql: String, jt: JoinType) => {
+ assertEqual(
+ s"select * from t as tt $sql u",
+ table("t").as("tt").join(table("u"), jt, None).select(star()))
+ }
+ val testConditionalJoin = (sql: String, jt: JoinType) => {
+ assertEqual(
+ s"select * from t $sql u as uu on a = b",
+ table("t").join(table("u").as("uu"), jt, Option('a === 'b)).select(star()))
+ }
+ val testNaturalJoin = (sql: String, jt: JoinType) => {
+ assertEqual(
+ s"select * from t tt natural $sql u as uu",
+ table("t").as("tt").join(table("u").as("uu"), NaturalJoin(jt), None).select(star()))
+ }
+ val testUsingJoin = (sql: String, jt: JoinType) => {
+ assertEqual(
+ s"select * from t $sql u using(a, b)",
+ table("t").join(table("u"), UsingJoin(jt, Seq('a.attr, 'b.attr)), None).select(star()))
+ }
+ val testAll = Seq(testUnconditionalJoin, testConditionalJoin, testNaturalJoin, testUsingJoin)
+ val testExistence = Seq(testUnconditionalJoin, testConditionalJoin, testUsingJoin)
+ def test(sql: String, jt: JoinType, tests: Seq[(String, JoinType) => Unit]): Unit = {
+ tests.foreach(_(sql, jt))
+ }
+ test("cross join", Inner, Seq(testUnconditionalJoin))
+ test(",", Inner, Seq(testUnconditionalJoin))
+ test("join", Inner, testAll)
+ test("inner join", Inner, testAll)
+ test("left join", LeftOuter, testAll)
+ test("left outer join", LeftOuter, testAll)
+ test("right join", RightOuter, testAll)
+ test("right outer join", RightOuter, testAll)
+ test("full join", FullOuter, testAll)
+ test("full outer join", FullOuter, testAll)
+ test("left semi join", LeftSemi, testExistence)
+ test("left anti join", LeftAnti, testExistence)
+ test("anti join", LeftAnti, testExistence)
+
+ // Test multiple consecutive joins
+ assertEqual(
+ "select * from a join b join c right join d",
+ table("a").join(table("b")).join(table("c")).join(table("d"), RightOuter).select(star()))
+ }
+
+ test("sampled relations") {
+ val sql = "select * from t"
+ assertEqual(s"$sql tablesample(100 rows)",
+ table("t").limit(100).select(star()))
+ assertEqual(s"$sql tablesample(43 percent) as x",
+ Sample(0, .43d, withReplacement = false, 10L, table("t").as("x"))(true).select(star()))
+ assertEqual(s"$sql tablesample(bucket 4 out of 10) as x",
+ Sample(0, .4d, withReplacement = false, 10L, table("t").as("x"))(true).select(star()))
+ intercept(s"$sql tablesample(bucket 4 out of 10 on x) as x",
+ "TABLESAMPLE(BUCKET x OUT OF y ON id) is not supported")
+ intercept(s"$sql tablesample(bucket 11 out of 10) as x",
+ s"Sampling fraction (${11.0/10.0}) must be on interval [0, 1]")
+ }
+
+ test("sub-query") {
+ val plan = table("t0").select('id)
+ assertEqual("select id from (t0)", plan)
+ assertEqual("select id from ((((((t0))))))", plan)
+ assertEqual(
+ "(select * from t1) union distinct (select * from t2)",
+ Distinct(table("t1").select(star()).union(table("t2").select(star()))))
+ assertEqual(
+ "select * from ((select * from t1) union (select * from t2)) t",
+ Distinct(
+ table("t1").select(star()).union(table("t2").select(star()))).as("t").select(star()))
+ assertEqual(
+ """select id
+ |from (((select id from t0)
+ | union all
+ | (select id from t0))
+ | union all
+ | (select id from t0)) as u_1
+ """.stripMargin,
+ plan.union(plan).union(plan).as("u_1").select('id))
+ }
+
+ test("scalar sub-query") {
+ assertEqual(
+ "select (select max(b) from s) ss from t",
+ table("t").select(ScalarSubquery(table("s").select('max.function('b))).as("ss")))
+ assertEqual(
+ "select * from t where a = (select b from s)",
+ table("t").where('a === ScalarSubquery(table("s").select('b))).select(star()))
+ assertEqual(
+ "select g from t group by g having a > (select b from s)",
+ table("t")
+ .groupBy('g)('g)
+ .where(('a > ScalarSubquery(table("s").select('b))).cast(BooleanType)))
+ }
+
+ test("table reference") {
+ assertEqual("table t", table("t"))
+ assertEqual("table d.t", table("d", "t"))
+ }
+
+ test("inline table") {
+ assertEqual("values 1, 2, 3, 4", LocalRelation.fromExternalRows(
+ Seq('col1.int),
+ Seq(1, 2, 3, 4).map(x => Row(x))))
+ assertEqual(
+ "values (1, 'a'), (2, 'b'), (3, 'c') as tbl(a, b)",
+ LocalRelation.fromExternalRows(
+ Seq('a.int, 'b.string),
+ Seq((1, "a"), (2, "b"), (3, "c")).map(x => Row(x._1, x._2))).as("tbl"))
+ intercept("values (a, 'a'), (b, 'b')",
+ "All expressions in an inline table must be constants.")
+ intercept("values (1, 'a'), (2, 'b') as tbl(a, b, c)",
+ "Number of aliases must match the number of fields in an inline table.")
+ intercept[ArrayIndexOutOfBoundsException](parsePlan("values (1, 'a'), (2, 'b', 5Y)"))
+ }
+}
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/ASTNodeSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/TableIdentifierParserSuite.scala
index 8b05f9e33d..297b1931a9 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/ASTNodeSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/TableIdentifierParserSuite.scala
@@ -17,22 +17,26 @@
package org.apache.spark.sql.catalyst.parser
import org.apache.spark.SparkFunSuite
+import org.apache.spark.sql.catalyst.TableIdentifier
-class ASTNodeSuite extends SparkFunSuite {
- test("SPARK-13157 - remainder must return all input chars") {
- val inputs = Seq(
- ("add jar", "file:///tmp/ab/TestUDTF.jar"),
- ("add jar", "file:///tmp/a@b/TestUDTF.jar"),
- ("add jar", "c:\\windows32\\TestUDTF.jar"),
- ("add jar", "some \nbad\t\tfile\r\n.\njar"),
- ("ADD JAR", "@*#&@(!#@$^*!@^@#(*!@#"),
- ("SET", "foo=bar"),
- ("SET", "foo*)(@#^*@&!#^=bar")
- )
- inputs.foreach {
- case (command, arguments) =>
- val node = ParseDriver.parsePlan(s"$command $arguments", null)
- assert(node.remainder === arguments)
+class TableIdentifierParserSuite extends SparkFunSuite {
+ import CatalystSqlParser._
+
+ test("table identifier") {
+ // Regular names.
+ assert(TableIdentifier("q") === parseTableIdentifier("q"))
+ assert(TableIdentifier("q", Option("d")) === parseTableIdentifier("d.q"))
+
+ // Illegal names.
+ intercept[ParseException](parseTableIdentifier(""))
+ intercept[ParseException](parseTableIdentifier("d.q.g"))
+
+ // SQL Keywords.
+ val keywords = Seq("select", "from", "where", "left", "right")
+ keywords.foreach { keyword =>
+ intercept[ParseException](parseTableIdentifier(keyword))
+ assert(TableIdentifier(keyword) === parseTableIdentifier(s"`$keyword`"))
+ assert(TableIdentifier(keyword, Option("db")) === parseTableIdentifier(s"db.`$keyword`"))
}
}
}
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/plans/ConstraintPropagationSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/plans/ConstraintPropagationSuite.scala
index e5063599a3..81cc6b123c 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/plans/ConstraintPropagationSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/plans/ConstraintPropagationSuite.scala
@@ -23,7 +23,7 @@ import org.apache.spark.sql.catalyst.dsl.expressions._
import org.apache.spark.sql.catalyst.dsl.plans._
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.plans.logical._
-import org.apache.spark.sql.types.{IntegerType, StringType}
+import org.apache.spark.sql.types.{DoubleType, IntegerType, LongType, StringType}
class ConstraintPropagationSuite extends SparkFunSuite {
@@ -88,6 +88,33 @@ class ConstraintPropagationSuite extends SparkFunSuite {
IsNotNull(resolveColumn(aliasedRelation.analyze, "a")))))
}
+ test("propagating constraints in expand") {
+ val tr = LocalRelation('a.int, 'b.int, 'c.int)
+
+ assert(tr.analyze.constraints.isEmpty)
+
+ // We add IsNotNull constraints for 'a, 'b and 'c into LocalRelation
+ // by creating notNullRelation.
+ val notNullRelation = tr.where('c.attr > 10 && 'a.attr < 5 && 'b.attr > 2)
+ verifyConstraints(notNullRelation.analyze.constraints,
+ ExpressionSet(Seq(resolveColumn(notNullRelation.analyze, "c") > 10,
+ IsNotNull(resolveColumn(notNullRelation.analyze, "c")),
+ resolveColumn(notNullRelation.analyze, "a") < 5,
+ IsNotNull(resolveColumn(notNullRelation.analyze, "a")),
+ resolveColumn(notNullRelation.analyze, "b") > 2,
+ IsNotNull(resolveColumn(notNullRelation.analyze, "b")))))
+
+ val expand = Expand(
+ Seq(
+ Seq('c, Literal.create(null, StringType), 1),
+ Seq('c, 'a, 2)),
+ Seq('c, 'a, 'gid.int),
+ Project(Seq('a, 'c),
+ notNullRelation))
+ verifyConstraints(expand.analyze.constraints,
+ ExpressionSet(Seq.empty[Expression]))
+ }
+
test("propagating constraints in aliases") {
val tr = LocalRelation('a.int, 'b.string, 'c.int)
@@ -121,6 +148,20 @@ class ConstraintPropagationSuite extends SparkFunSuite {
.analyze.constraints,
ExpressionSet(Seq(resolveColumn(tr1, "a") > 10,
IsNotNull(resolveColumn(tr1, "a")))))
+
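+ // Judging from the expectations below, a union keeps only constraints that hold on both sides:
+ // matching per-branch predicates are combined with a disjunction and mapped onto the first
+ // child's output attributes (e.g. a > 10 || a > 11).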
+ val a = resolveColumn(tr1, "a")
+ verifyConstraints(tr1
+ .where('a.attr > 10)
+ .union(tr2.where('d.attr > 11))
+ .analyze.constraints,
+ ExpressionSet(Seq(a > 10 || a > 11, IsNotNull(a))))
+
+ val b = resolveColumn(tr1, "b")
+ verifyConstraints(tr1
+ .where('a.attr > 10 && 'b.attr < 10)
+ .union(tr2.where('d.attr > 11 && 'e.attr < 11))
+ .analyze.constraints,
+ ExpressionSet(Seq(a > 10 || a > 11, b < 10 || b < 11, IsNotNull(a), IsNotNull(b))))
}
test("propagating constraints in intersect") {
@@ -219,6 +260,89 @@ class ConstraintPropagationSuite extends SparkFunSuite {
IsNotNull(resolveColumn(tr, "b")))))
}
+ test("infer constraints on cast") {
+ val tr = LocalRelation('a.int, 'b.long, 'c.int, 'd.long, 'e.int)
+ verifyConstraints(
+ tr.where('a.attr === 'b.attr &&
+ 'c.attr + 100 > 'd.attr &&
+ IsNotNull(Cast(Cast(resolveColumn(tr, "e"), LongType), LongType))).analyze.constraints,
+ ExpressionSet(Seq(Cast(resolveColumn(tr, "a"), LongType) === resolveColumn(tr, "b"),
+ Cast(resolveColumn(tr, "c") + 100, LongType) > resolveColumn(tr, "d"),
+ IsNotNull(resolveColumn(tr, "a")),
+ IsNotNull(resolveColumn(tr, "b")),
+ IsNotNull(resolveColumn(tr, "c")),
+ IsNotNull(resolveColumn(tr, "d")),
+ IsNotNull(resolveColumn(tr, "e")),
+ IsNotNull(Cast(Cast(resolveColumn(tr, "e"), LongType), LongType)))))
+ }
+
+ test("infer isnotnull constraints from compound expressions") {
+ val tr = LocalRelation('a.int, 'b.long, 'c.int, 'd.long, 'e.int)
+ verifyConstraints(
+ tr.where('a.attr + 'b.attr === 'c.attr &&
+ IsNotNull(
+ Cast(
+ Cast(Cast(resolveColumn(tr, "e"), LongType), LongType), LongType))).analyze.constraints,
+ ExpressionSet(Seq(
+ Cast(resolveColumn(tr, "a"), LongType) + resolveColumn(tr, "b") ===
+ Cast(resolveColumn(tr, "c"), LongType),
+ IsNotNull(resolveColumn(tr, "a")),
+ IsNotNull(resolveColumn(tr, "b")),
+ IsNotNull(resolveColumn(tr, "c")),
+ IsNotNull(resolveColumn(tr, "e")),
+ IsNotNull(Cast(Cast(Cast(resolveColumn(tr, "e"), LongType), LongType), LongType)))))
+
+ verifyConstraints(
+ tr.where(('a.attr * 'b.attr + 100) === 'c.attr && 'd / 10 === 'e).analyze.constraints,
+ ExpressionSet(Seq(
+ Cast(resolveColumn(tr, "a"), LongType) * resolveColumn(tr, "b") + Cast(100, LongType) ===
+ Cast(resolveColumn(tr, "c"), LongType),
+ Cast(resolveColumn(tr, "d"), DoubleType) /
+ Cast(Cast(10, LongType), DoubleType) ===
+ Cast(resolveColumn(tr, "e"), DoubleType),
+ IsNotNull(resolveColumn(tr, "a")),
+ IsNotNull(resolveColumn(tr, "b")),
+ IsNotNull(resolveColumn(tr, "c")),
+ IsNotNull(resolveColumn(tr, "d")),
+ IsNotNull(resolveColumn(tr, "e")))))
+
+ verifyConstraints(
+ tr.where(('a.attr * 'b.attr - 10) >= 'c.attr && 'd / 10 < 'e).analyze.constraints,
+ ExpressionSet(Seq(
+ Cast(resolveColumn(tr, "a"), LongType) * resolveColumn(tr, "b") - Cast(10, LongType) >=
+ Cast(resolveColumn(tr, "c"), LongType),
+ Cast(resolveColumn(tr, "d"), DoubleType) /
+ Cast(Cast(10, LongType), DoubleType) <
+ Cast(resolveColumn(tr, "e"), DoubleType),
+ IsNotNull(resolveColumn(tr, "a")),
+ IsNotNull(resolveColumn(tr, "b")),
+ IsNotNull(resolveColumn(tr, "c")),
+ IsNotNull(resolveColumn(tr, "d")),
+ IsNotNull(resolveColumn(tr, "e")))))
+
+ verifyConstraints(
+ tr.where('a.attr + 'b.attr - 'c.attr * 'd.attr > 'e.attr * 1000).analyze.constraints,
+ ExpressionSet(Seq(
+ (Cast(resolveColumn(tr, "a"), LongType) + resolveColumn(tr, "b")) -
+ (Cast(resolveColumn(tr, "c"), LongType) * resolveColumn(tr, "d")) >
+ Cast(resolveColumn(tr, "e") * 1000, LongType),
+ IsNotNull(resolveColumn(tr, "a")),
+ IsNotNull(resolveColumn(tr, "b")),
+ IsNotNull(resolveColumn(tr, "c")),
+ IsNotNull(resolveColumn(tr, "d")),
+ IsNotNull(resolveColumn(tr, "e")))))
+
+ // The constraint IsNotNull(IsNotNull(expr)) doesn't guarantee expr is not null.
+ verifyConstraints(
+ tr.where('a.attr === 'c.attr &&
+ IsNotNull(IsNotNull(resolveColumn(tr, "b")))).analyze.constraints,
+ ExpressionSet(Seq(
+ resolveColumn(tr, "a") === resolveColumn(tr, "c"),
+ IsNotNull(IsNotNull(resolveColumn(tr, "b"))),
+ IsNotNull(resolveColumn(tr, "a")),
+ IsNotNull(resolveColumn(tr, "c")))))
+ }
+
test("infer IsNotNull constraints from non-nullable attributes") {
val tr = LocalRelation('a.int, AttributeReference("b", IntegerType, nullable = false)(),
AttributeReference("c", StringType, nullable = false)())
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/plans/PlanTest.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/plans/PlanTest.scala
index 0541844e0b..7191936699 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/plans/PlanTest.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/plans/PlanTest.scala
@@ -19,7 +19,8 @@ package org.apache.spark.sql.catalyst.plans
import org.apache.spark.SparkFunSuite
import org.apache.spark.sql.catalyst.expressions._
-import org.apache.spark.sql.catalyst.plans.logical.{Filter, LogicalPlan, OneRowRelation}
+import org.apache.spark.sql.catalyst.expressions.aggregate.AggregateExpression
+import org.apache.spark.sql.catalyst.plans.logical.{Filter, LogicalPlan, OneRowRelation, Sample}
import org.apache.spark.sql.catalyst.util._
/**
@@ -32,29 +33,37 @@ abstract class PlanTest extends SparkFunSuite with PredicateHelper {
*/
protected def normalizeExprIds(plan: LogicalPlan) = {
plan transformAllExpressions {
+ case s: ScalarSubquery =>
+ ScalarSubquery(s.query, ExprId(0))
case a: AttributeReference =>
AttributeReference(a.name, a.dataType, a.nullable)(exprId = ExprId(0))
case a: Alias =>
Alias(a.child, a.name)(exprId = ExprId(0))
+ case ae: AggregateExpression =>
+ ae.copy(resultId = ExprId(0))
}
}
/**
- * Normalizes the filter conditions that appear in the plan. For instance,
- * ((expr 1 && expr 2) && expr 3), (expr 1 && expr 2 && expr 3), (expr 3 && (expr 1 && expr 2)
- * etc., will all now be equivalent.
+ * Normalizes plans:
+ * - Filter: the filter conditions that appear in a plan are normalized. For instance,
+ * ((expr 1 && expr 2) && expr 3), (expr 1 && expr 2 && expr 3), (expr 3 && (expr 1 && expr 2)
+ * etc., will all now be equivalent.
+ * - Sample: the seed will be replaced by 0L.
*/
- private def normalizeFilters(plan: LogicalPlan) = {
+ private def normalizePlan(plan: LogicalPlan): LogicalPlan = {
plan transform {
case filter @ Filter(condition: Expression, child: LogicalPlan) =>
Filter(splitConjunctivePredicates(condition).sortBy(_.hashCode()).reduce(And), child)
+ case sample: Sample =>
+ sample.copy(seed = 0L)(true)
}
}
/** Fails the test if the two plans do not match */
protected def comparePlans(plan1: LogicalPlan, plan2: LogicalPlan) {
- val normalized1 = normalizeFilters(normalizeExprIds(plan1))
- val normalized2 = normalizeFilters(normalizeExprIds(plan2))
+ val normalized1 = normalizePlan(normalizeExprIds(plan1))
+ val normalized2 = normalizePlan(normalizeExprIds(plan2))
if (normalized1 != normalized2) {
fail(
s"""
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/plans/SameResultSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/plans/SameResultSuite.scala
index 37941cf34e..467f76193c 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/plans/SameResultSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/plans/SameResultSuite.scala
@@ -20,7 +20,7 @@ package org.apache.spark.sql.catalyst.plans
import org.apache.spark.SparkFunSuite
import org.apache.spark.sql.catalyst.dsl.expressions._
import org.apache.spark.sql.catalyst.dsl.plans._
-import org.apache.spark.sql.catalyst.plans.logical.{LocalRelation, LogicalPlan}
+import org.apache.spark.sql.catalyst.plans.logical.{LocalRelation, LogicalPlan, Union}
import org.apache.spark.sql.catalyst.util._
/**
@@ -61,4 +61,9 @@ class SameResultSuite extends SparkFunSuite {
test("sorts") {
assertSameResult(testRelation.orderBy('a.asc), testRelation2.orderBy('a.asc))
}
+
+ test("union") {
+ assertSameResult(Union(Seq(testRelation, testRelation2)),
+ Union(Seq(testRelation2, testRelation)))
+ }
}
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/util/StringUtilsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/util/StringUtilsSuite.scala
index d6f273f9e5..2ffc18a8d1 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/util/StringUtilsSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/util/StringUtilsSuite.scala
@@ -31,4 +31,16 @@ class StringUtilsSuite extends SparkFunSuite {
assert(escapeLikeRegex("**") === "(?s)\\Q*\\E\\Q*\\E")
assert(escapeLikeRegex("a_b") === "(?s)\\Qa\\E.\\Qb\\E")
}
+
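+ // Judging from the expected results, filterPattern trims the pattern, treats '*' as a
+ // '.*'-style wildcard, accepts '|' between alternatives, and matches names case-insensitively.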
+ test("filter pattern") {
+ val names = Seq("a1", "a2", "b2", "c3")
+ assert(filterPattern(names, " * ") === Seq("a1", "a2", "b2", "c3"))
+ assert(filterPattern(names, "*a*") === Seq("a1", "a2"))
+ assert(filterPattern(names, " *a* ") === Seq("a1", "a2"))
+ assert(filterPattern(names, " a* ") === Seq("a1", "a2"))
+ assert(filterPattern(names, " a.* ") === Seq("a1", "a2"))
+ assert(filterPattern(names, " B.*|a* ") === Seq("a1", "a2", "b2"))
+ assert(filterPattern(names, " a. ") === Seq("a1", "a2"))
+ assert(filterPattern(names, " d* ") === Nil)
+ }
}