author     Herman van Hovell <hvanhovell@questtec.nl>    2016-01-15 15:19:10 -0800
committer  Reynold Xin <rxin@databricks.com>             2016-01-15 15:19:10 -0800
commit     7cd7f2202547224593517b392f56e49e4c94cabc (patch)
tree       3deb6f260ce94c59d2e25bc29095582dfd637173
parent     3f1c58d60b85625ab3abf16456ce27c820453ecf (diff)
[SPARK-12575][SQL] Grammar parity with existing SQL parser
In this PR the new CatalystQl parser stack reaches grammar parity with the old parser-combinator based SQL parser. This PR also replaces all uses of the old parser and removes it from the code base.

Although the existing Hive and SQL parser dialects were mostly the same, some kinks had to be worked out:
- The SQL parser allowed syntax like ```APPROXIMATE(0.01) COUNT(DISTINCT a)```. Making this work would require either hardcoding approximate operators in the parser or creating an approximate expression. ```APPROXIMATE_COUNT_DISTINCT(a, 0.01)``` does the same job and is much easier to maintain, so this PR **removes** the keyword.
- The old SQL parser supported ```LIMIT``` clauses in nested queries. This is **not supported** anymore; see https://github.com/apache/spark/pull/10689 for the rationale.
- Hive supports a charset-name/charset-literal combination: for instance, the expression ```_ISO-8859-1 0x4341464562616265``` yields the string ```CAFEbabe```. Hive only allows charset names that start with an underscore, which is quite annoying in Spark because tuple field names also start with an underscore. This PR **removes** the feature from the parser; it would be quite easy to implement it as an Expression later on.
- Hive and the SQL parser treat decimal literals differently. Hive turns any decimal into a ```Double```, whereas the SQL parser converted a non-scientific decimal into a ```BigDecimal``` and a scientific decimal into a ```Double```. We follow Hive's behavior here. The new parser also supports a big-decimal literal, for instance ```81923801.42BD```, which can be used when a big decimal is needed.

cc rxin viirya marmbrus yhuai cloud-fan

Author: Herman van Hovell <hvanhovell@questtec.nl>

Closes #10745 from hvanhovell/SPARK-12575-2.
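To make the decimal-literal point above concrete, here is a minimal, self-contained sketch of the narrowing behavior. The names `NumericLiteralSketch` and `toLiteralValue` are made up for illustration; the real logic lives in `CatalystQl.nodeToExpr` further down in this diff. Plain integral text narrows to the smallest fitting integer type, a `BD`-suffixed literal keeps full decimal precision, and every other number becomes a `Double`, following Hive.

```scala
object NumericLiteralSketch {
  private val Integral = "[+-]?\\d+".r

  def toLiteralValue(text: String): Any = text match {
    case t if t.endsWith("BD") =>
      // e.g. 81923801.42BD -> exact BigDecimal, no precision loss
      BigDecimal(t.dropRight(2))
    case Integral() =>
      // plain integers narrow to Int, then Long, then BigDecimal
      val v = BigDecimal(text)
      if (v.isValidInt) v.intValue()
      else if (v.isValidLong) v.longValue()
      else v.underlying()
    case _ =>
      // 1.5, .9e+2, 900e-1, ... all become doubles, as in Hive
      text.toDouble
  }
}
```

The removed charset literal could likewise return later as an ordinary function. The sketch below uses a hypothetical `CharsetLiteralSketch.decode` helper that is not part of this patch; it reproduces the hex-decoding path of the deleted `ParseUtils.charSetString` shown further down in the diff.

```scala
import java.nio.charset.Charset

object CharsetLiteralSketch {
  // decode("_ISO-8859-1", "0x4341464562616265") == "CAFEbabe"
  def decode(charSetName: String, hexLiteral: String): String = {
    require(hexLiteral.toLowerCase.startsWith("0x"), "expected a 0x-prefixed hex literal")
    // convert each pair of hex digits into one byte, then decode with the named charset
    val bytes = hexLiteral.drop(2).grouped(2).map(Integer.parseInt(_, 16).toByte).toArray
    new String(bytes, Charset.forName(charSetName.stripPrefix("_")))
  }
}
```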
-rw-r--r--  python/pyspark/sql/tests.py | 3
-rw-r--r--  sql/catalyst/src/main/antlr3/org/apache/spark/sql/catalyst/parser/ExpressionParser.g | 57
-rw-r--r--  sql/catalyst/src/main/antlr3/org/apache/spark/sql/catalyst/parser/FromClauseParser.g | 7
-rw-r--r--  sql/catalyst/src/main/antlr3/org/apache/spark/sql/catalyst/parser/SparkSqlLexer.g | 32
-rw-r--r--  sql/catalyst/src/main/antlr3/org/apache/spark/sql/catalyst/parser/SparkSqlParser.g | 10
-rw-r--r--  sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/parser/ParseUtils.java | 31
-rw-r--r--  sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/AbstractSparkSQLParser.scala | 4
-rw-r--r--  sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/CatalystQl.scala | 139
-rw-r--r--  sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ParserDialect.scala | 46
-rw-r--r--  sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/SqlParser.scala | 509
-rw-r--r--  sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/HyperLogLogPlusPlus.scala | 1
-rw-r--r--  sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/CatalystQlSuite.scala | 1
-rw-r--r--  sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/SqlParserSuite.scala | 150
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/Column.scala | 1
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala | 8
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala | 4
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/DataFrameWriter.scala | 6
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala | 17
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/execution/SparkSQLParser.scala | 19
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DDLParser.scala | 17
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/functions.scala | 7
-rw-r--r--  sql/core/src/test/scala/org/apache/spark/sql/MathExpressionsSuite.scala | 2
-rw-r--r--  sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala | 73
-rw-r--r--  sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/json/JsonSuite.scala | 12
-rw-r--r--  sql/hive/compatibility/src/test/scala/org/apache/spark/sql/hive/execution/HiveCompatibilitySuite.scala | 4
-rw-r--r--  sql/hive/src/main/scala/org/apache/spark/sql/hive/ExtendedHiveQlParser.scala | 18
-rw-r--r--  sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveContext.scala | 21
-rw-r--r--  sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveMetastoreCatalog.scala | 4
-rw-r--r--  sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveQl.scala | 7
-rw-r--r--  sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveQlSuite.scala | 9
-rw-r--r--  sql/hive/src/test/scala/org/apache/spark/sql/hive/StatisticsSuite.scala | 5
-rw-r--r--  sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala | 20
-rw-r--r--  unsafe/src/main/java/org/apache/spark/unsafe/types/CalendarInterval.java | 14
33 files changed, 286 insertions, 972 deletions
diff --git a/python/pyspark/sql/tests.py b/python/pyspark/sql/tests.py
index e396cf41f2..c03cb9338a 100644
--- a/python/pyspark/sql/tests.py
+++ b/python/pyspark/sql/tests.py
@@ -1081,8 +1081,7 @@ class SQLTests(ReusedPySparkTestCase):
def test_capture_analysis_exception(self):
self.assertRaises(AnalysisException, lambda: self.sqlCtx.sql("select abc"))
self.assertRaises(AnalysisException, lambda: self.df.selectExpr("a + b"))
- # RuntimeException should not be captured
- self.assertRaises(py4j.protocol.Py4JJavaError, lambda: self.sqlCtx.sql("abc"))
+ self.assertRaises(AnalysisException, lambda: self.sqlCtx.sql("abc"))
def test_capture_illegalargument_exception(self):
self.assertRaisesRegexp(IllegalArgumentException, "Setting negative mapred.reduce.tasks",
diff --git a/sql/catalyst/src/main/antlr3/org/apache/spark/sql/catalyst/parser/ExpressionParser.g b/sql/catalyst/src/main/antlr3/org/apache/spark/sql/catalyst/parser/ExpressionParser.g
index aabb5d4958..047a7e56cb 100644
--- a/sql/catalyst/src/main/antlr3/org/apache/spark/sql/catalyst/parser/ExpressionParser.g
+++ b/sql/catalyst/src/main/antlr3/org/apache/spark/sql/catalyst/parser/ExpressionParser.g
@@ -123,7 +123,6 @@ constant
| SmallintLiteral
| TinyintLiteral
| DecimalLiteral
- | charSetStringLiteral
| booleanValue
;
@@ -132,13 +131,6 @@ stringLiteralSequence
StringLiteral StringLiteral+ -> ^(TOK_STRINGLITERALSEQUENCE StringLiteral StringLiteral+)
;
-charSetStringLiteral
-@init { gParent.pushMsg("character string literal", state); }
-@after { gParent.popMsg(state); }
- :
- csName=CharSetName csLiteral=CharSetLiteral -> ^(TOK_CHARSETLITERAL $csName $csLiteral)
- ;
-
dateLiteral
:
KW_DATE StringLiteral ->
@@ -163,22 +155,38 @@ timestampLiteral
intervalLiteral
:
- KW_INTERVAL StringLiteral qualifiers=intervalQualifiers ->
- {
- adaptor.create($qualifiers.tree.token.getType(), $StringLiteral.text)
+ (KW_INTERVAL intervalConstant KW_YEAR KW_TO KW_MONTH) => KW_INTERVAL intervalConstant KW_YEAR KW_TO KW_MONTH
+ -> ^(TOK_INTERVAL_YEAR_MONTH_LITERAL intervalConstant)
+ | (KW_INTERVAL intervalConstant KW_DAY KW_TO KW_SECOND) => KW_INTERVAL intervalConstant KW_DAY KW_TO KW_SECOND
+ -> ^(TOK_INTERVAL_DAY_TIME_LITERAL intervalConstant)
+ | KW_INTERVAL
+ ((intervalConstant KW_YEAR)=> year=intervalConstant KW_YEAR)?
+ ((intervalConstant KW_MONTH)=> month=intervalConstant KW_MONTH)?
+ ((intervalConstant KW_WEEK)=> week=intervalConstant KW_WEEK)?
+ ((intervalConstant KW_DAY)=> day=intervalConstant KW_DAY)?
+ ((intervalConstant KW_HOUR)=> hour=intervalConstant KW_HOUR)?
+ ((intervalConstant KW_MINUTE)=> minute=intervalConstant KW_MINUTE)?
+ ((intervalConstant KW_SECOND)=> second=intervalConstant KW_SECOND)?
+ (millisecond=intervalConstant KW_MILLISECOND)?
+ (microsecond=intervalConstant KW_MICROSECOND)?
+ -> ^(TOK_INTERVAL
+ ^(TOK_INTERVAL_YEAR_LITERAL $year?)
+ ^(TOK_INTERVAL_MONTH_LITERAL $month?)
+ ^(TOK_INTERVAL_WEEK_LITERAL $week?)
+ ^(TOK_INTERVAL_DAY_LITERAL $day?)
+ ^(TOK_INTERVAL_HOUR_LITERAL $hour?)
+ ^(TOK_INTERVAL_MINUTE_LITERAL $minute?)
+ ^(TOK_INTERVAL_SECOND_LITERAL $second?)
+ ^(TOK_INTERVAL_MILLISECOND_LITERAL $millisecond?)
+ ^(TOK_INTERVAL_MICROSECOND_LITERAL $microsecond?))
+ ;
+
+intervalConstant
+ :
+ sign=(MINUS|PLUS)? value=Number -> {
+ adaptor.create(Number, ($sign != null ? $sign.getText() : "") + $value.getText())
}
- ;
-
-intervalQualifiers
- :
- KW_YEAR KW_TO KW_MONTH -> TOK_INTERVAL_YEAR_MONTH_LITERAL
- | KW_DAY KW_TO KW_SECOND -> TOK_INTERVAL_DAY_TIME_LITERAL
- | KW_YEAR -> TOK_INTERVAL_YEAR_LITERAL
- | KW_MONTH -> TOK_INTERVAL_MONTH_LITERAL
- | KW_DAY -> TOK_INTERVAL_DAY_LITERAL
- | KW_HOUR -> TOK_INTERVAL_HOUR_LITERAL
- | KW_MINUTE -> TOK_INTERVAL_MINUTE_LITERAL
- | KW_SECOND -> TOK_INTERVAL_SECOND_LITERAL
+ | StringLiteral
;
expression
@@ -219,7 +227,8 @@ nullCondition
precedenceUnaryPrefixExpression
:
- (precedenceUnaryOperator^)* precedenceFieldExpression
+ (precedenceUnaryOperator+)=> precedenceUnaryOperator^ precedenceUnaryPrefixExpression
+ | precedenceFieldExpression
;
precedenceUnarySuffixExpression
diff --git a/sql/catalyst/src/main/antlr3/org/apache/spark/sql/catalyst/parser/FromClauseParser.g b/sql/catalyst/src/main/antlr3/org/apache/spark/sql/catalyst/parser/FromClauseParser.g
index 972c52e3ff..6d76afcd4a 100644
--- a/sql/catalyst/src/main/antlr3/org/apache/spark/sql/catalyst/parser/FromClauseParser.g
+++ b/sql/catalyst/src/main/antlr3/org/apache/spark/sql/catalyst/parser/FromClauseParser.g
@@ -206,11 +206,8 @@ tableName
@init { gParent.pushMsg("table name", state); }
@after { gParent.popMsg(state); }
:
- db=identifier DOT tab=identifier
- -> ^(TOK_TABNAME $db $tab)
- |
- tab=identifier
- -> ^(TOK_TABNAME $tab)
+ id1=identifier (DOT id2=identifier)?
+ -> ^(TOK_TABNAME $id1 $id2?)
;
viewName
diff --git a/sql/catalyst/src/main/antlr3/org/apache/spark/sql/catalyst/parser/SparkSqlLexer.g b/sql/catalyst/src/main/antlr3/org/apache/spark/sql/catalyst/parser/SparkSqlLexer.g
index 44a63fbef2..ee2882e51c 100644
--- a/sql/catalyst/src/main/antlr3/org/apache/spark/sql/catalyst/parser/SparkSqlLexer.g
+++ b/sql/catalyst/src/main/antlr3/org/apache/spark/sql/catalyst/parser/SparkSqlLexer.g
@@ -307,12 +307,12 @@ KW_AUTHORIZATION: 'AUTHORIZATION';
KW_CONF: 'CONF';
KW_VALUES: 'VALUES';
KW_RELOAD: 'RELOAD';
-KW_YEAR: 'YEAR';
-KW_MONTH: 'MONTH';
-KW_DAY: 'DAY';
-KW_HOUR: 'HOUR';
-KW_MINUTE: 'MINUTE';
-KW_SECOND: 'SECOND';
+KW_YEAR: 'YEAR'|'YEARS';
+KW_MONTH: 'MONTH'|'MONTHS';
+KW_DAY: 'DAY'|'DAYS';
+KW_HOUR: 'HOUR'|'HOURS';
+KW_MINUTE: 'MINUTE'|'MINUTES';
+KW_SECOND: 'SECOND'|'SECONDS';
KW_START: 'START';
KW_TRANSACTION: 'TRANSACTION';
KW_COMMIT: 'COMMIT';
@@ -324,6 +324,9 @@ KW_ISOLATION: 'ISOLATION';
KW_LEVEL: 'LEVEL';
KW_SNAPSHOT: 'SNAPSHOT';
KW_AUTOCOMMIT: 'AUTOCOMMIT';
+KW_WEEK: 'WEEK'|'WEEKS';
+KW_MILLISECOND: 'MILLISECOND'|'MILLISECONDS';
+KW_MICROSECOND: 'MICROSECOND'|'MICROSECONDS';
// Operators
// NOTE: if you add a new function/operator, add it to sysFuncNames so that describe function _FUNC_ will work.
@@ -400,12 +403,6 @@ StringLiteral
)+
;
-CharSetLiteral
- :
- StringLiteral
- | '0' 'X' (HexDigit|Digit)+
- ;
-
BigintLiteral
:
(Digit)+ 'L'
@@ -433,7 +430,7 @@ ByteLengthLiteral
Number
:
- (Digit)+ ( DOT (Digit)* (Exponent)? | Exponent)?
+ ((Digit+ (DOT Digit*)?) | (DOT Digit+)) Exponent?
;
/*
@@ -456,10 +453,10 @@ An Identifier can be:
- macro name
- hint name
- window name
-*/
+*/
Identifier
:
- (Letter | Digit) (Letter | Digit | '_')*
+ (Letter | Digit | '_')+
| {allowQuotedId()}? QuotedIdentifier /* though at the language level we allow all Identifiers to be QuotedIdentifiers;
at the API level only columns are allowed to be of this form */
| '`' RegexComponent+ '`'
@@ -471,11 +468,6 @@ QuotedIdentifier
'`' ( '``' | ~('`') )* '`' { setText(getText().substring(1, getText().length() -1 ).replaceAll("``", "`")); }
;
-CharSetName
- :
- '_' (Letter | Digit | '_' | '-' | '.' | ':' )+
- ;
-
WS : (' '|'\r'|'\t'|'\n') {$channel=HIDDEN;}
;
diff --git a/sql/catalyst/src/main/antlr3/org/apache/spark/sql/catalyst/parser/SparkSqlParser.g b/sql/catalyst/src/main/antlr3/org/apache/spark/sql/catalyst/parser/SparkSqlParser.g
index 2c13d3056f..c146ca5914 100644
--- a/sql/catalyst/src/main/antlr3/org/apache/spark/sql/catalyst/parser/SparkSqlParser.g
+++ b/sql/catalyst/src/main/antlr3/org/apache/spark/sql/catalyst/parser/SparkSqlParser.g
@@ -116,16 +116,20 @@ TOK_DATELITERAL;
TOK_DATETIME;
TOK_TIMESTAMP;
TOK_TIMESTAMPLITERAL;
+TOK_INTERVAL;
TOK_INTERVAL_YEAR_MONTH;
TOK_INTERVAL_YEAR_MONTH_LITERAL;
TOK_INTERVAL_DAY_TIME;
TOK_INTERVAL_DAY_TIME_LITERAL;
TOK_INTERVAL_YEAR_LITERAL;
TOK_INTERVAL_MONTH_LITERAL;
+TOK_INTERVAL_WEEK_LITERAL;
TOK_INTERVAL_DAY_LITERAL;
TOK_INTERVAL_HOUR_LITERAL;
TOK_INTERVAL_MINUTE_LITERAL;
TOK_INTERVAL_SECOND_LITERAL;
+TOK_INTERVAL_MILLISECOND_LITERAL;
+TOK_INTERVAL_MICROSECOND_LITERAL;
TOK_STRING;
TOK_CHAR;
TOK_VARCHAR;
@@ -228,7 +232,6 @@ TOK_TMP_FILE;
TOK_TABSORTCOLNAMEASC;
TOK_TABSORTCOLNAMEDESC;
TOK_STRINGLITERALSEQUENCE;
-TOK_CHARSETLITERAL;
TOK_CREATEFUNCTION;
TOK_DROPFUNCTION;
TOK_RELOADFUNCTION;
@@ -509,7 +512,9 @@ import java.util.HashMap;
xlateMap.put("KW_UPDATE", "UPDATE");
xlateMap.put("KW_VALUES", "VALUES");
xlateMap.put("KW_PURGE", "PURGE");
-
+ xlateMap.put("KW_WEEK", "WEEK");
+ xlateMap.put("KW_MILLISECOND", "MILLISECOND");
+ xlateMap.put("KW_MICROSECOND", "MICROSECOND");
// Operators
xlateMap.put("DOT", ".");
@@ -2078,6 +2083,7 @@ primitiveType
| KW_SMALLINT -> TOK_SMALLINT
| KW_INT -> TOK_INT
| KW_BIGINT -> TOK_BIGINT
+ | KW_LONG -> TOK_BIGINT
| KW_BOOLEAN -> TOK_BOOLEAN
| KW_FLOAT -> TOK_FLOAT
| KW_DOUBLE -> TOK_DOUBLE
diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/parser/ParseUtils.java b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/parser/ParseUtils.java
index 5bc87b680f..2520c7bb8d 100644
--- a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/parser/ParseUtils.java
+++ b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/parser/ParseUtils.java
@@ -18,12 +18,10 @@
package org.apache.spark.sql.catalyst.parser;
-import java.io.UnsupportedEncodingException;
-
/**
* A couple of utility methods that help with parsing ASTs.
*
- * Both methods in this class were take from the SemanticAnalyzer in Hive:
+ * The 'unescapeSQLString' method in this class was take from the SemanticAnalyzer in Hive:
* ql/src/java/org/apache/hadoop/hive/ql/parse/BaseSemanticAnalyzer.java
*/
public final class ParseUtils {
@@ -31,33 +29,6 @@ public final class ParseUtils {
super();
}
- public static String charSetString(String charSetName, String charSetString)
- throws UnsupportedEncodingException {
- // The character set name starts with a _, so strip that
- charSetName = charSetName.substring(1);
- if (charSetString.charAt(0) == '\'') {
- return new String(unescapeSQLString(charSetString).getBytes(), charSetName);
- } else // hex input is also supported
- {
- assert charSetString.charAt(0) == '0';
- assert charSetString.charAt(1) == 'x';
- charSetString = charSetString.substring(2);
-
- byte[] bArray = new byte[charSetString.length() / 2];
- int j = 0;
- for (int i = 0; i < charSetString.length(); i += 2) {
- int val = Character.digit(charSetString.charAt(i), 16) * 16
- + Character.digit(charSetString.charAt(i + 1), 16);
- if (val > 127) {
- val = val - 256;
- }
- bArray[j++] = (byte)val;
- }
-
- return new String(bArray, charSetName);
- }
- }
-
private static final int[] multiplier = new int[] {1000, 100, 10, 1};
@SuppressWarnings("nls")
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/AbstractSparkSQLParser.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/AbstractSparkSQLParser.scala
index bdc52c08ac..9443369808 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/AbstractSparkSQLParser.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/AbstractSparkSQLParser.scala
@@ -26,9 +26,9 @@ import scala.util.parsing.input.CharArrayReader.EofCh
import org.apache.spark.sql.catalyst.plans.logical._
private[sql] abstract class AbstractSparkSQLParser
- extends StandardTokenParsers with PackratParsers {
+ extends StandardTokenParsers with PackratParsers with ParserDialect {
- def parse(input: String): LogicalPlan = synchronized {
+ def parsePlan(input: String): LogicalPlan = synchronized {
// Initialize the Keywords.
initLexical
phrase(start)(new lexical.Scanner(input)) match {
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/CatalystQl.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/CatalystQl.scala
index d0fbdacf6e..c1591ecfe2 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/CatalystQl.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/CatalystQl.scala
@@ -30,16 +30,10 @@ import org.apache.spark.sql.types._
import org.apache.spark.unsafe.types.CalendarInterval
import org.apache.spark.util.random.RandomSampler
-private[sql] object CatalystQl {
- val parser = new CatalystQl
- def parseExpression(sql: String): Expression = parser.parseExpression(sql)
- def parseTableIdentifier(sql: String): TableIdentifier = parser.parseTableIdentifier(sql)
-}
-
/**
* This class translates a HQL String to a Catalyst [[LogicalPlan]] or [[Expression]].
*/
-private[sql] class CatalystQl(val conf: ParserConf = SimpleParserConf()) {
+private[sql] class CatalystQl(val conf: ParserConf = SimpleParserConf()) extends ParserDialect {
object Token {
def unapply(node: ASTNode): Some[(String, List[ASTNode])] = {
CurrentOrigin.setPosition(node.line, node.positionInLine)
@@ -611,13 +605,6 @@ https://cwiki.apache.org/confluence/display/Hive/Enhanced+Aggregation%2C+Cube%2C
case plainIdent => plainIdent
}
- val numericAstTypes = Seq(
- SparkSqlParser.Number,
- SparkSqlParser.TinyintLiteral,
- SparkSqlParser.SmallintLiteral,
- SparkSqlParser.BigintLiteral,
- SparkSqlParser.DecimalLiteral)
-
/* Case insensitive matches */
val COUNT = "(?i)COUNT".r
val SUM = "(?i)SUM".r
@@ -635,6 +622,8 @@ https://cwiki.apache.org/confluence/display/Hive/Enhanced+Aggregation%2C+Cube%2C
val WHEN = "(?i)WHEN".r
val CASE = "(?i)CASE".r
+ val INTEGRAL = "[+-]?\\d+".r
+
protected def nodeToExpr(node: ASTNode): Expression = node match {
/* Attribute References */
case Token("TOK_TABLE_OR_COL", Token(name, Nil) :: Nil) =>
@@ -650,8 +639,8 @@ https://cwiki.apache.org/confluence/display/Hive/Enhanced+Aggregation%2C+Cube%2C
case Token("TOK_ALLCOLREF", Nil) => UnresolvedStar(None)
// The format of dbName.tableName.* cannot be parsed by HiveParser. TOK_TABNAME will only
// has a single child which is tableName.
- case Token("TOK_ALLCOLREF", Token("TOK_TABNAME", Token(name, Nil) :: Nil) :: Nil) =>
- UnresolvedStar(Some(UnresolvedAttribute.parseAttributeName(name)))
+ case Token("TOK_ALLCOLREF", Token("TOK_TABNAME", target) :: Nil) if target.nonEmpty =>
+ UnresolvedStar(Some(target.map(_.text)))
/* Aggregate Functions */
case Token("TOK_FUNCTIONDI", Token(COUNT(), Nil) :: args) =>
@@ -787,71 +776,71 @@ https://cwiki.apache.org/confluence/display/Hive/Enhanced+Aggregation%2C+Cube%2C
case Token("TOK_STRINGLITERALSEQUENCE", strings) =>
Literal(strings.map(s => ParseUtils.unescapeSQLString(s.text)).mkString)
- // This code is adapted from
- // /ql/src/java/org/apache/hadoop/hive/ql/parse/TypeCheckProcFactory.java#L223
- case ast: ASTNode if numericAstTypes contains ast.tokenType =>
- var v: Literal = null
- try {
- if (ast.text.endsWith("L")) {
- // Literal bigint.
- v = Literal.create(ast.text.substring(0, ast.text.length() - 1).toLong, LongType)
- } else if (ast.text.endsWith("S")) {
- // Literal smallint.
- v = Literal.create(ast.text.substring(0, ast.text.length() - 1).toShort, ShortType)
- } else if (ast.text.endsWith("Y")) {
- // Literal tinyint.
- v = Literal.create(ast.text.substring(0, ast.text.length() - 1).toByte, ByteType)
- } else if (ast.text.endsWith("BD") || ast.text.endsWith("D")) {
- // Literal decimal
- val strVal = ast.text.stripSuffix("D").stripSuffix("B")
- v = Literal(Decimal(strVal))
- } else {
- v = Literal.create(ast.text.toDouble, DoubleType)
- v = Literal.create(ast.text.toLong, LongType)
- v = Literal.create(ast.text.toInt, IntegerType)
- }
- } catch {
- case nfe: NumberFormatException => // Do nothing
- }
-
- if (v == null) {
- sys.error(s"Failed to parse number '${ast.text}'.")
- } else {
- v
- }
-
- case ast: ASTNode if ast.tokenType == SparkSqlParser.StringLiteral =>
- Literal(ParseUtils.unescapeSQLString(ast.text))
+ case ast if ast.tokenType == SparkSqlParser.TinyintLiteral =>
+ Literal.create(ast.text.substring(0, ast.text.length() - 1).toByte, ByteType)
- case ast: ASTNode if ast.tokenType == SparkSqlParser.TOK_DATELITERAL =>
- Literal(Date.valueOf(ast.text.substring(1, ast.text.length - 1)))
+ case ast if ast.tokenType == SparkSqlParser.SmallintLiteral =>
+ Literal.create(ast.text.substring(0, ast.text.length() - 1).toShort, ShortType)
- case ast: ASTNode if ast.tokenType == SparkSqlParser.TOK_CHARSETLITERAL =>
- Literal(ParseUtils.charSetString(ast.children.head.text, ast.children(1).text))
+ case ast if ast.tokenType == SparkSqlParser.BigintLiteral =>
+ Literal.create(ast.text.substring(0, ast.text.length() - 1).toLong, LongType)
- case ast: ASTNode if ast.tokenType == SparkSqlParser.TOK_INTERVAL_YEAR_MONTH_LITERAL =>
- Literal(CalendarInterval.fromYearMonthString(ast.text))
+ case ast if ast.tokenType == SparkSqlParser.DecimalLiteral =>
+ Literal(Decimal(ast.text.substring(0, ast.text.length() - 2)))
- case ast: ASTNode if ast.tokenType == SparkSqlParser.TOK_INTERVAL_DAY_TIME_LITERAL =>
- Literal(CalendarInterval.fromDayTimeString(ast.text))
-
- case ast: ASTNode if ast.tokenType == SparkSqlParser.TOK_INTERVAL_YEAR_LITERAL =>
- Literal(CalendarInterval.fromSingleUnitString("year", ast.text))
-
- case ast: ASTNode if ast.tokenType == SparkSqlParser.TOK_INTERVAL_MONTH_LITERAL =>
- Literal(CalendarInterval.fromSingleUnitString("month", ast.text))
-
- case ast: ASTNode if ast.tokenType == SparkSqlParser.TOK_INTERVAL_DAY_LITERAL =>
- Literal(CalendarInterval.fromSingleUnitString("day", ast.text))
-
- case ast: ASTNode if ast.tokenType == SparkSqlParser.TOK_INTERVAL_HOUR_LITERAL =>
- Literal(CalendarInterval.fromSingleUnitString("hour", ast.text))
+ case ast if ast.tokenType == SparkSqlParser.Number =>
+ val text = ast.text
+ text match {
+ case INTEGRAL() =>
+ BigDecimal(text) match {
+ case v if v.isValidInt =>
+ Literal(v.intValue())
+ case v if v.isValidLong =>
+ Literal(v.longValue())
+ case v => Literal(v.underlying())
+ }
+ case _ =>
+ Literal(text.toDouble)
+ }
+ case ast if ast.tokenType == SparkSqlParser.StringLiteral =>
+ Literal(ParseUtils.unescapeSQLString(ast.text))
- case ast: ASTNode if ast.tokenType == SparkSqlParser.TOK_INTERVAL_MINUTE_LITERAL =>
- Literal(CalendarInterval.fromSingleUnitString("minute", ast.text))
+ case ast if ast.tokenType == SparkSqlParser.TOK_DATELITERAL =>
+ Literal(Date.valueOf(ast.text.substring(1, ast.text.length - 1)))
- case ast: ASTNode if ast.tokenType == SparkSqlParser.TOK_INTERVAL_SECOND_LITERAL =>
- Literal(CalendarInterval.fromSingleUnitString("second", ast.text))
+ case ast if ast.tokenType == SparkSqlParser.TOK_INTERVAL_YEAR_MONTH_LITERAL =>
+ Literal(CalendarInterval.fromYearMonthString(ast.children.head.text))
+
+ case ast if ast.tokenType == SparkSqlParser.TOK_INTERVAL_DAY_TIME_LITERAL =>
+ Literal(CalendarInterval.fromDayTimeString(ast.children.head.text))
+
+ case Token("TOK_INTERVAL", elements) =>
+ var interval = new CalendarInterval(0, 0)
+ var updated = false
+ elements.foreach {
+ // The interval node will always contain children for all possible time units. A child node
+ // is only useful when it contains exactly one (numeric) child.
+ case e @ Token(name, Token(value, Nil) :: Nil) =>
+ val unit = name match {
+ case "TOK_INTERVAL_YEAR_LITERAL" => "year"
+ case "TOK_INTERVAL_MONTH_LITERAL" => "month"
+ case "TOK_INTERVAL_WEEK_LITERAL" => "week"
+ case "TOK_INTERVAL_DAY_LITERAL" => "day"
+ case "TOK_INTERVAL_HOUR_LITERAL" => "hour"
+ case "TOK_INTERVAL_MINUTE_LITERAL" => "minute"
+ case "TOK_INTERVAL_SECOND_LITERAL" => "second"
+ case "TOK_INTERVAL_MILLISECOND_LITERAL" => "millisecond"
+ case "TOK_INTERVAL_MICROSECOND_LITERAL" => "microsecond"
+ case _ => noParseRule(s"Interval($name)", e)
+ }
+ interval = interval.add(CalendarInterval.fromSingleUnitString(unit, value))
+ updated = true
+ case _ =>
+ }
+ if (!updated) {
+ throw new AnalysisException("at least one time unit should be given for interval literal")
+ }
+ Literal(interval)
case _ =>
noParseRule("Expression", node)
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ParserDialect.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ParserDialect.scala
index e21d3c0546..7d9fbf2f12 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ParserDialect.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ParserDialect.scala
@@ -18,52 +18,22 @@
package org.apache.spark.sql.catalyst
import org.apache.spark.annotation.DeveloperApi
+import org.apache.spark.sql.catalyst.expressions.Expression
import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
/**
* Root class of SQL Parser Dialect, and we don't guarantee the binary
* compatibility for the future release, let's keep it as the internal
* interface for advanced user.
- *
*/
@DeveloperApi
-abstract class ParserDialect {
- // this is the main function that will be implemented by sql parser.
- def parse(sqlText: String): LogicalPlan
-}
+trait ParserDialect {
+ /** Creates LogicalPlan for a given SQL string. */
+ def parsePlan(sqlText: String): LogicalPlan
-/**
- * Currently we support the default dialect named "sql", associated with the class
- * [[DefaultParserDialect]]
- *
- * And we can also provide custom SQL Dialect, for example in Spark SQL CLI:
- * {{{
- *-- switch to "hiveql" dialect
- * spark-sql>SET spark.sql.dialect=hiveql;
- * spark-sql>SELECT * FROM src LIMIT 1;
- *
- *-- switch to "sql" dialect
- * spark-sql>SET spark.sql.dialect=sql;
- * spark-sql>SELECT * FROM src LIMIT 1;
- *
- *-- register the new SQL dialect
- * spark-sql> SET spark.sql.dialect=com.xxx.xxx.SQL99Dialect;
- * spark-sql> SELECT * FROM src LIMIT 1;
- *
- *-- register the non-exist SQL dialect
- * spark-sql> SET spark.sql.dialect=NotExistedClass;
- * spark-sql> SELECT * FROM src LIMIT 1;
- *
- *-- Exception will be thrown and switch to dialect
- *-- "sql" (for SQLContext) or
- *-- "hiveql" (for HiveContext)
- * }}}
- */
-private[spark] class DefaultParserDialect extends ParserDialect {
- @transient
- protected val sqlParser = SqlParser
+ /** Creates Expression for a given SQL string. */
+ def parseExpression(sqlText: String): Expression
- override def parse(sqlText: String): LogicalPlan = {
- sqlParser.parse(sqlText)
- }
+ /** Creates TableIdentifier for a given SQL string. */
+ def parseTableIdentifier(sqlText: String): TableIdentifier
}
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/SqlParser.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/SqlParser.scala
deleted file mode 100644
index 85ff4ea0c9..0000000000
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/SqlParser.scala
+++ /dev/null
@@ -1,509 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one or more
- * contributor license agreements. See the NOTICE file distributed with
- * this work for additional information regarding copyright ownership.
- * The ASF licenses this file to You under the Apache License, Version 2.0
- * (the "License"); you may not use this file except in compliance with
- * the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-package org.apache.spark.sql.catalyst
-
-import scala.language.implicitConversions
-
-import org.apache.spark.sql.AnalysisException
-import org.apache.spark.sql.catalyst.analysis._
-import org.apache.spark.sql.catalyst.expressions._
-import org.apache.spark.sql.catalyst.expressions.aggregate._
-import org.apache.spark.sql.catalyst.plans._
-import org.apache.spark.sql.catalyst.plans.logical._
-import org.apache.spark.sql.catalyst.util.DataTypeParser
-import org.apache.spark.sql.types._
-import org.apache.spark.unsafe.types.CalendarInterval
-
-/**
- * A very simple SQL parser. Based loosely on:
- * https://github.com/stephentu/scala-sql-parser/blob/master/src/main/scala/parser.scala
- *
- * Limitations:
- * - Only supports a very limited subset of SQL.
- *
- * This is currently included mostly for illustrative purposes. Users wanting more complete support
- * for a SQL like language should checkout the HiveQL support in the sql/hive sub-project.
- */
-object SqlParser extends AbstractSparkSQLParser with DataTypeParser {
-
- def parseExpression(input: String): Expression = synchronized {
- // Initialize the Keywords.
- initLexical
- phrase(projection)(new lexical.Scanner(input)) match {
- case Success(plan, _) => plan
- case failureOrError => sys.error(failureOrError.toString)
- }
- }
-
- def parseTableIdentifier(input: String): TableIdentifier = synchronized {
- // Initialize the Keywords.
- initLexical
- phrase(tableIdentifier)(new lexical.Scanner(input)) match {
- case Success(ident, _) => ident
- case failureOrError => sys.error(failureOrError.toString)
- }
- }
-
- // Keyword is a convention with AbstractSparkSQLParser, which will scan all of the `Keyword`
- // properties via reflection the class in runtime for constructing the SqlLexical object
- protected val ALL = Keyword("ALL")
- protected val AND = Keyword("AND")
- protected val APPROXIMATE = Keyword("APPROXIMATE")
- protected val AS = Keyword("AS")
- protected val ASC = Keyword("ASC")
- protected val BETWEEN = Keyword("BETWEEN")
- protected val BY = Keyword("BY")
- protected val CASE = Keyword("CASE")
- protected val CAST = Keyword("CAST")
- protected val DESC = Keyword("DESC")
- protected val DISTINCT = Keyword("DISTINCT")
- protected val ELSE = Keyword("ELSE")
- protected val END = Keyword("END")
- protected val EXCEPT = Keyword("EXCEPT")
- protected val FALSE = Keyword("FALSE")
- protected val FROM = Keyword("FROM")
- protected val FULL = Keyword("FULL")
- protected val GROUP = Keyword("GROUP")
- protected val HAVING = Keyword("HAVING")
- protected val IN = Keyword("IN")
- protected val INNER = Keyword("INNER")
- protected val INSERT = Keyword("INSERT")
- protected val INTERSECT = Keyword("INTERSECT")
- protected val INTERVAL = Keyword("INTERVAL")
- protected val INTO = Keyword("INTO")
- protected val IS = Keyword("IS")
- protected val JOIN = Keyword("JOIN")
- protected val LEFT = Keyword("LEFT")
- protected val LIKE = Keyword("LIKE")
- protected val LIMIT = Keyword("LIMIT")
- protected val NOT = Keyword("NOT")
- protected val NULL = Keyword("NULL")
- protected val ON = Keyword("ON")
- protected val OR = Keyword("OR")
- protected val ORDER = Keyword("ORDER")
- protected val SORT = Keyword("SORT")
- protected val OUTER = Keyword("OUTER")
- protected val OVERWRITE = Keyword("OVERWRITE")
- protected val REGEXP = Keyword("REGEXP")
- protected val RIGHT = Keyword("RIGHT")
- protected val RLIKE = Keyword("RLIKE")
- protected val SELECT = Keyword("SELECT")
- protected val SEMI = Keyword("SEMI")
- protected val TABLE = Keyword("TABLE")
- protected val THEN = Keyword("THEN")
- protected val TRUE = Keyword("TRUE")
- protected val UNION = Keyword("UNION")
- protected val WHEN = Keyword("WHEN")
- protected val WHERE = Keyword("WHERE")
- protected val WITH = Keyword("WITH")
-
- protected lazy val start: Parser[LogicalPlan] =
- start1 | insert | cte
-
- protected lazy val start1: Parser[LogicalPlan] =
- (select | ("(" ~> select <~ ")")) *
- ( UNION ~ ALL ^^^ { (q1: LogicalPlan, q2: LogicalPlan) => Union(q1, q2) }
- | INTERSECT ^^^ { (q1: LogicalPlan, q2: LogicalPlan) => Intersect(q1, q2) }
- | EXCEPT ^^^ { (q1: LogicalPlan, q2: LogicalPlan) => Except(q1, q2)}
- | UNION ~ DISTINCT.? ^^^ { (q1: LogicalPlan, q2: LogicalPlan) => Distinct(Union(q1, q2)) }
- )
-
- protected lazy val select: Parser[LogicalPlan] =
- SELECT ~> DISTINCT.? ~
- repsep(projection, ",") ~
- (FROM ~> relations).? ~
- (WHERE ~> expression).? ~
- (GROUP ~ BY ~> rep1sep(expression, ",")).? ~
- (HAVING ~> expression).? ~
- sortType.? ~
- (LIMIT ~> expression).? ^^ {
- case d ~ p ~ r ~ f ~ g ~ h ~ o ~ l =>
- val base = r.getOrElse(OneRowRelation)
- val withFilter = f.map(Filter(_, base)).getOrElse(base)
- val withProjection = g
- .map(Aggregate(_, p.map(UnresolvedAlias(_)), withFilter))
- .getOrElse(Project(p.map(UnresolvedAlias(_)), withFilter))
- val withDistinct = d.map(_ => Distinct(withProjection)).getOrElse(withProjection)
- val withHaving = h.map(Filter(_, withDistinct)).getOrElse(withDistinct)
- val withOrder = o.map(_(withHaving)).getOrElse(withHaving)
- val withLimit = l.map(Limit(_, withOrder)).getOrElse(withOrder)
- withLimit
- }
-
- protected lazy val insert: Parser[LogicalPlan] =
- INSERT ~> (OVERWRITE ^^^ true | INTO ^^^ false) ~ (TABLE ~> relation) ~ select ^^ {
- case o ~ r ~ s => InsertIntoTable(r, Map.empty[String, Option[String]], s, o, false)
- }
-
- protected lazy val cte: Parser[LogicalPlan] =
- WITH ~> rep1sep(ident ~ ( AS ~ "(" ~> start1 <~ ")"), ",") ~ (start1 | insert) ^^ {
- case r ~ s => With(s, r.map({case n ~ s => (n, Subquery(n, s))}).toMap)
- }
-
- protected lazy val projection: Parser[Expression] =
- expression ~ (AS.? ~> ident.?) ^^ {
- case e ~ a => a.fold(e)(Alias(e, _)())
- }
-
- // Based very loosely on the MySQL Grammar.
- // http://dev.mysql.com/doc/refman/5.0/en/join.html
- protected lazy val relations: Parser[LogicalPlan] =
- ( relation ~ rep1("," ~> relation) ^^ {
- case r1 ~ joins => joins.foldLeft(r1) { case(lhs, r) => Join(lhs, r, Inner, None) } }
- | relation
- )
-
- protected lazy val relation: Parser[LogicalPlan] =
- joinedRelation | relationFactor
-
- protected lazy val relationFactor: Parser[LogicalPlan] =
- ( tableIdentifier ~ (opt(AS) ~> opt(ident)) ^^ {
- case tableIdent ~ alias => UnresolvedRelation(tableIdent, alias)
- }
- | ("(" ~> start <~ ")") ~ (AS.? ~> ident) ^^ { case s ~ a => Subquery(a, s) }
- )
-
- protected lazy val joinedRelation: Parser[LogicalPlan] =
- relationFactor ~ rep1(joinType.? ~ (JOIN ~> relationFactor) ~ joinConditions.?) ^^ {
- case r1 ~ joins =>
- joins.foldLeft(r1) { case (lhs, jt ~ rhs ~ cond) =>
- Join(lhs, rhs, joinType = jt.getOrElse(Inner), cond)
- }
- }
-
- protected lazy val joinConditions: Parser[Expression] =
- ON ~> expression
-
- protected lazy val joinType: Parser[JoinType] =
- ( INNER ^^^ Inner
- | LEFT ~ SEMI ^^^ LeftSemi
- | LEFT ~ OUTER.? ^^^ LeftOuter
- | RIGHT ~ OUTER.? ^^^ RightOuter
- | FULL ~ OUTER.? ^^^ FullOuter
- )
-
- protected lazy val sortType: Parser[LogicalPlan => LogicalPlan] =
- ( ORDER ~ BY ~> ordering ^^ { case o => l: LogicalPlan => Sort(o, true, l) }
- | SORT ~ BY ~> ordering ^^ { case o => l: LogicalPlan => Sort(o, false, l) }
- )
-
- protected lazy val ordering: Parser[Seq[SortOrder]] =
- ( rep1sep(expression ~ direction.?, ",") ^^ {
- case exps => exps.map(pair => SortOrder(pair._1, pair._2.getOrElse(Ascending)))
- }
- )
-
- protected lazy val direction: Parser[SortDirection] =
- ( ASC ^^^ Ascending
- | DESC ^^^ Descending
- )
-
- protected lazy val expression: Parser[Expression] =
- orExpression
-
- protected lazy val orExpression: Parser[Expression] =
- andExpression * (OR ^^^ { (e1: Expression, e2: Expression) => Or(e1, e2) })
-
- protected lazy val andExpression: Parser[Expression] =
- notExpression * (AND ^^^ { (e1: Expression, e2: Expression) => And(e1, e2) })
-
- protected lazy val notExpression: Parser[Expression] =
- NOT.? ~ comparisonExpression ^^ { case maybeNot ~ e => maybeNot.map(_ => Not(e)).getOrElse(e) }
-
- protected lazy val comparisonExpression: Parser[Expression] =
- ( termExpression ~ ("=" ~> termExpression) ^^ { case e1 ~ e2 => EqualTo(e1, e2) }
- | termExpression ~ ("<" ~> termExpression) ^^ { case e1 ~ e2 => LessThan(e1, e2) }
- | termExpression ~ ("<=" ~> termExpression) ^^ { case e1 ~ e2 => LessThanOrEqual(e1, e2) }
- | termExpression ~ (">" ~> termExpression) ^^ { case e1 ~ e2 => GreaterThan(e1, e2) }
- | termExpression ~ (">=" ~> termExpression) ^^ { case e1 ~ e2 => GreaterThanOrEqual(e1, e2) }
- | termExpression ~ ("!=" ~> termExpression) ^^ { case e1 ~ e2 => Not(EqualTo(e1, e2)) }
- | termExpression ~ ("<>" ~> termExpression) ^^ { case e1 ~ e2 => Not(EqualTo(e1, e2)) }
- | termExpression ~ ("<=>" ~> termExpression) ^^ { case e1 ~ e2 => EqualNullSafe(e1, e2) }
- | termExpression ~ NOT.? ~ (BETWEEN ~> termExpression) ~ (AND ~> termExpression) ^^ {
- case e ~ not ~ el ~ eu =>
- val betweenExpr: Expression = And(GreaterThanOrEqual(e, el), LessThanOrEqual(e, eu))
- not.fold(betweenExpr)(f => Not(betweenExpr))
- }
- | termExpression ~ (RLIKE ~> termExpression) ^^ { case e1 ~ e2 => RLike(e1, e2) }
- | termExpression ~ (REGEXP ~> termExpression) ^^ { case e1 ~ e2 => RLike(e1, e2) }
- | termExpression ~ (LIKE ~> termExpression) ^^ { case e1 ~ e2 => Like(e1, e2) }
- | termExpression ~ (NOT ~ LIKE ~> termExpression) ^^ { case e1 ~ e2 => Not(Like(e1, e2)) }
- | termExpression ~ (IN ~ "(" ~> rep1sep(termExpression, ",")) <~ ")" ^^ {
- case e1 ~ e2 => In(e1, e2)
- }
- | termExpression ~ (NOT ~ IN ~ "(" ~> rep1sep(termExpression, ",")) <~ ")" ^^ {
- case e1 ~ e2 => Not(In(e1, e2))
- }
- | termExpression <~ IS ~ NULL ^^ { case e => IsNull(e) }
- | termExpression <~ IS ~ NOT ~ NULL ^^ { case e => IsNotNull(e) }
- | termExpression
- )
-
- protected lazy val termExpression: Parser[Expression] =
- productExpression *
- ( "+" ^^^ { (e1: Expression, e2: Expression) => Add(e1, e2) }
- | "-" ^^^ { (e1: Expression, e2: Expression) => Subtract(e1, e2) }
- )
-
- protected lazy val productExpression: Parser[Expression] =
- baseExpression *
- ( "*" ^^^ { (e1: Expression, e2: Expression) => Multiply(e1, e2) }
- | "/" ^^^ { (e1: Expression, e2: Expression) => Divide(e1, e2) }
- | "%" ^^^ { (e1: Expression, e2: Expression) => Remainder(e1, e2) }
- | "&" ^^^ { (e1: Expression, e2: Expression) => BitwiseAnd(e1, e2) }
- | "|" ^^^ { (e1: Expression, e2: Expression) => BitwiseOr(e1, e2) }
- | "^" ^^^ { (e1: Expression, e2: Expression) => BitwiseXor(e1, e2) }
- )
-
- protected lazy val function: Parser[Expression] =
- ( ident <~ ("(" ~ "*" ~ ")") ^^ { case udfName =>
- if (lexical.normalizeKeyword(udfName) == "count") {
- AggregateExpression(Count(Literal(1)), mode = Complete, isDistinct = false)
- } else {
- throw new AnalysisException(s"invalid expression $udfName(*)")
- }
- }
- | ident ~ ("(" ~> repsep(expression, ",")) <~ ")" ^^
- { case udfName ~ exprs => UnresolvedFunction(udfName, exprs, isDistinct = false) }
- | ident ~ ("(" ~ DISTINCT ~> repsep(expression, ",")) <~ ")" ^^ { case udfName ~ exprs =>
- lexical.normalizeKeyword(udfName) match {
- case "count" =>
- aggregate.Count(exprs).toAggregateExpression(isDistinct = true)
- case _ => UnresolvedFunction(udfName, exprs, isDistinct = true)
- }
- }
- | APPROXIMATE ~> ident ~ ("(" ~ DISTINCT ~> expression <~ ")") ^^ { case udfName ~ exp =>
- if (lexical.normalizeKeyword(udfName) == "count") {
- AggregateExpression(new HyperLogLogPlusPlus(exp), mode = Complete, isDistinct = false)
- } else {
- throw new AnalysisException(s"invalid function approximate $udfName")
- }
- }
- | APPROXIMATE ~> "(" ~> unsignedFloat ~ ")" ~ ident ~ "(" ~ DISTINCT ~ expression <~ ")" ^^
- { case s ~ _ ~ udfName ~ _ ~ _ ~ exp =>
- if (lexical.normalizeKeyword(udfName) == "count") {
- AggregateExpression(
- HyperLogLogPlusPlus(exp, s.toDouble, 0, 0),
- mode = Complete,
- isDistinct = false)
- } else {
- throw new AnalysisException(s"invalid function approximate($s) $udfName")
- }
- }
- | CASE ~> whenThenElse ^^
- { case branches => CaseWhen.createFromParser(branches) }
- | CASE ~> expression ~ whenThenElse ^^
- { case keyPart ~ branches => CaseKeyWhen(keyPart, branches) }
- )
-
- protected lazy val whenThenElse: Parser[List[Expression]] =
- rep1(WHEN ~> expression ~ (THEN ~> expression)) ~ (ELSE ~> expression).? <~ END ^^ {
- case altPart ~ elsePart =>
- altPart.flatMap { case whenExpr ~ thenExpr =>
- Seq(whenExpr, thenExpr)
- } ++ elsePart
- }
-
- protected lazy val cast: Parser[Expression] =
- CAST ~ "(" ~> expression ~ (AS ~> dataType) <~ ")" ^^ {
- case exp ~ t => Cast(exp, t)
- }
-
- protected lazy val literal: Parser[Literal] =
- ( numericLiteral
- | booleanLiteral
- | stringLit ^^ { case s => Literal.create(s, StringType) }
- | intervalLiteral
- | NULL ^^^ Literal.create(null, NullType)
- )
-
- protected lazy val booleanLiteral: Parser[Literal] =
- ( TRUE ^^^ Literal.create(true, BooleanType)
- | FALSE ^^^ Literal.create(false, BooleanType)
- )
-
- protected lazy val numericLiteral: Parser[Literal] =
- ( integral ^^ { case i => Literal(toNarrowestIntegerType(i)) }
- | sign.? ~ unsignedFloat ^^
- { case s ~ f => Literal(toDecimalOrDouble(s.getOrElse("") + f)) }
- )
-
- protected lazy val unsignedFloat: Parser[String] =
- ( "." ~> numericLit ^^ { u => "0." + u }
- | elem("decimal", _.isInstanceOf[lexical.DecimalLit]) ^^ (_.chars)
- )
-
- protected lazy val sign: Parser[String] = ("+" | "-")
-
- protected lazy val integral: Parser[String] =
- sign.? ~ numericLit ^^ { case s ~ n => s.getOrElse("") + n }
-
- private def intervalUnit(unitName: String) = acceptIf {
- case lexical.Identifier(str) =>
- val normalized = lexical.normalizeKeyword(str)
- normalized == unitName || normalized == unitName + "s"
- case _ => false
- } {_ => "wrong interval unit"}
-
- protected lazy val month: Parser[Int] =
- integral <~ intervalUnit("month") ^^ { case num => num.toInt }
-
- protected lazy val year: Parser[Int] =
- integral <~ intervalUnit("year") ^^ { case num => num.toInt * 12 }
-
- protected lazy val microsecond: Parser[Long] =
- integral <~ intervalUnit("microsecond") ^^ { case num => num.toLong }
-
- protected lazy val millisecond: Parser[Long] =
- integral <~ intervalUnit("millisecond") ^^ {
- case num => num.toLong * CalendarInterval.MICROS_PER_MILLI
- }
-
- protected lazy val second: Parser[Long] =
- integral <~ intervalUnit("second") ^^ {
- case num => num.toLong * CalendarInterval.MICROS_PER_SECOND
- }
-
- protected lazy val minute: Parser[Long] =
- integral <~ intervalUnit("minute") ^^ {
- case num => num.toLong * CalendarInterval.MICROS_PER_MINUTE
- }
-
- protected lazy val hour: Parser[Long] =
- integral <~ intervalUnit("hour") ^^ {
- case num => num.toLong * CalendarInterval.MICROS_PER_HOUR
- }
-
- protected lazy val day: Parser[Long] =
- integral <~ intervalUnit("day") ^^ {
- case num => num.toLong * CalendarInterval.MICROS_PER_DAY
- }
-
- protected lazy val week: Parser[Long] =
- integral <~ intervalUnit("week") ^^ {
- case num => num.toLong * CalendarInterval.MICROS_PER_WEEK
- }
-
- private def intervalKeyword(keyword: String) = acceptIf {
- case lexical.Identifier(str) =>
- lexical.normalizeKeyword(str) == keyword
- case _ => false
- } {_ => "wrong interval keyword"}
-
- protected lazy val intervalLiteral: Parser[Literal] =
- ( INTERVAL ~> stringLit <~ intervalKeyword("year") ~ intervalKeyword("to") ~
- intervalKeyword("month") ^^ { case s =>
- Literal(CalendarInterval.fromYearMonthString(s))
- }
- | INTERVAL ~> stringLit <~ intervalKeyword("day") ~ intervalKeyword("to") ~
- intervalKeyword("second") ^^ { case s =>
- Literal(CalendarInterval.fromDayTimeString(s))
- }
- | INTERVAL ~> stringLit <~ intervalKeyword("year") ^^ { case s =>
- Literal(CalendarInterval.fromSingleUnitString("year", s))
- }
- | INTERVAL ~> stringLit <~ intervalKeyword("month") ^^ { case s =>
- Literal(CalendarInterval.fromSingleUnitString("month", s))
- }
- | INTERVAL ~> stringLit <~ intervalKeyword("day") ^^ { case s =>
- Literal(CalendarInterval.fromSingleUnitString("day", s))
- }
- | INTERVAL ~> stringLit <~ intervalKeyword("hour") ^^ { case s =>
- Literal(CalendarInterval.fromSingleUnitString("hour", s))
- }
- | INTERVAL ~> stringLit <~ intervalKeyword("minute") ^^ { case s =>
- Literal(CalendarInterval.fromSingleUnitString("minute", s))
- }
- | INTERVAL ~> stringLit <~ intervalKeyword("second") ^^ { case s =>
- Literal(CalendarInterval.fromSingleUnitString("second", s))
- }
- | INTERVAL ~> year.? ~ month.? ~ week.? ~ day.? ~ hour.? ~ minute.? ~ second.? ~
- millisecond.? ~ microsecond.? ^^ { case year ~ month ~ week ~ day ~ hour ~ minute ~ second ~
- millisecond ~ microsecond =>
- if (!Seq(year, month, week, day, hour, minute, second,
- millisecond, microsecond).exists(_.isDefined)) {
- throw new AnalysisException(
- "at least one time unit should be given for interval literal")
- }
- val months = Seq(year, month).map(_.getOrElse(0)).sum
- val microseconds = Seq(week, day, hour, minute, second, millisecond, microsecond)
- .map(_.getOrElse(0L)).sum
- Literal(new CalendarInterval(months, microseconds))
- }
- )
-
- private def toNarrowestIntegerType(value: String): Any = {
- val bigIntValue = BigDecimal(value)
-
- bigIntValue match {
- case v if bigIntValue.isValidInt => v.toIntExact
- case v if bigIntValue.isValidLong => v.toLongExact
- case v => v.underlying()
- }
- }
-
- private def toDecimalOrDouble(value: String): Any = {
- val decimal = BigDecimal(value)
- // follow the behavior in MS SQL Server
- // https://msdn.microsoft.com/en-us/library/ms179899.aspx
- if (value.contains('E') || value.contains('e')) {
- decimal.doubleValue()
- } else {
- decimal.underlying()
- }
- }
-
- protected lazy val baseExpression: Parser[Expression] =
- ( "*" ^^^ UnresolvedStar(None)
- | rep1(ident <~ ".") <~ "*" ^^ { case target => UnresolvedStar(Option(target))}
- | primary
- )
-
- protected lazy val signedPrimary: Parser[Expression] =
- sign ~ primary ^^ { case s ~ e => if (s == "-") UnaryMinus(e) else e }
-
- protected lazy val attributeName: Parser[String] = acceptMatch("attribute name", {
- case lexical.Identifier(str) => str
- case lexical.Keyword(str) if !lexical.delimiters.contains(str) => str
- })
-
- protected lazy val primary: PackratParser[Expression] =
- ( literal
- | expression ~ ("[" ~> expression <~ "]") ^^
- { case base ~ ordinal => UnresolvedExtractValue(base, ordinal) }
- | (expression <~ ".") ~ ident ^^
- { case base ~ fieldName => UnresolvedExtractValue(base, Literal(fieldName)) }
- | cast
- | "(" ~> expression <~ ")"
- | function
- | dotExpressionHeader
- | signedPrimary
- | "~" ~> expression ^^ BitwiseNot
- | attributeName ^^ UnresolvedAttribute.quoted
- )
-
- protected lazy val dotExpressionHeader: Parser[Expression] =
- (ident <~ ".") ~ ident ~ rep("." ~> ident) ^^ {
- case i1 ~ i2 ~ rest => UnresolvedAttribute(Seq(i1, i2) ++ rest)
- }
-
- protected lazy val tableIdentifier: Parser[TableIdentifier] =
- (ident <~ ".").? ~ ident ^^ {
- case maybeDbName ~ tableName => TableIdentifier(tableName, maybeDbName)
- }
-}
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/HyperLogLogPlusPlus.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/HyperLogLogPlusPlus.scala
index e1fd22e367..ec833d6789 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/HyperLogLogPlusPlus.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/HyperLogLogPlusPlus.scala
@@ -447,6 +447,7 @@ object HyperLogLogPlusPlus {
private def validateDoubleLiteral(exp: Expression): Double = exp match {
case Literal(d: Double, DoubleType) => d
+ case Literal(dec: Decimal, _) => dec.toDouble
case _ =>
throw new AnalysisException("The second argument should be a double literal.")
}
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/CatalystQlSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/CatalystQlSuite.scala
index ba9d2524a9..6d25de98ce 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/CatalystQlSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/CatalystQlSuite.scala
@@ -108,6 +108,7 @@ class CatalystQlSuite extends PlanTest {
}
assertRight("9.0e1", 90)
+ assertRight(".9e+2", 90)
assertRight("0.9e+2", 90)
assertRight("900e-1", 90)
assertRight("900.0E-1", 90)
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/SqlParserSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/SqlParserSuite.scala
deleted file mode 100644
index b0884f5287..0000000000
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/SqlParserSuite.scala
+++ /dev/null
@@ -1,150 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one or more
- * contributor license agreements. See the NOTICE file distributed with
- * this work for additional information regarding copyright ownership.
- * The ASF licenses this file to You under the Apache License, Version 2.0
- * (the "License"); you may not use this file except in compliance with
- * the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-package org.apache.spark.sql.catalyst
-
-import org.apache.spark.sql.catalyst.analysis.UnresolvedAlias
-import org.apache.spark.sql.catalyst.expressions.{Attribute, GreaterThan, Literal, Not}
-import org.apache.spark.sql.catalyst.plans.PlanTest
-import org.apache.spark.sql.catalyst.plans.logical.{Command, LogicalPlan, OneRowRelation, Project}
-import org.apache.spark.unsafe.types.CalendarInterval
-
-private[sql] case class TestCommand(cmd: String) extends LogicalPlan with Command {
- override def output: Seq[Attribute] = Seq.empty
- override def children: Seq[LogicalPlan] = Seq.empty
-}
-
-private[sql] class SuperLongKeywordTestParser extends AbstractSparkSQLParser {
- protected val EXECUTE = Keyword("THISISASUPERLONGKEYWORDTEST")
-
- override protected lazy val start: Parser[LogicalPlan] = set
-
- private lazy val set: Parser[LogicalPlan] =
- EXECUTE ~> ident ^^ {
- case fileName => TestCommand(fileName)
- }
-}
-
-private[sql] class CaseInsensitiveTestParser extends AbstractSparkSQLParser {
- protected val EXECUTE = Keyword("EXECUTE")
-
- override protected lazy val start: Parser[LogicalPlan] = set
-
- private lazy val set: Parser[LogicalPlan] =
- EXECUTE ~> ident ^^ {
- case fileName => TestCommand(fileName)
- }
-}
-
-class SqlParserSuite extends PlanTest {
-
- test("test long keyword") {
- val parser = new SuperLongKeywordTestParser
- assert(TestCommand("NotRealCommand") ===
- parser.parse("ThisIsASuperLongKeyWordTest NotRealCommand"))
- }
-
- test("test case insensitive") {
- val parser = new CaseInsensitiveTestParser
- assert(TestCommand("NotRealCommand") === parser.parse("EXECUTE NotRealCommand"))
- assert(TestCommand("NotRealCommand") === parser.parse("execute NotRealCommand"))
- assert(TestCommand("NotRealCommand") === parser.parse("exEcute NotRealCommand"))
- }
-
- test("test NOT operator with comparison operations") {
- val parsed = SqlParser.parse("SELECT NOT TRUE > TRUE")
- val expected = Project(
- UnresolvedAlias(
- Not(
- GreaterThan(Literal(true), Literal(true)))
- ) :: Nil,
- OneRowRelation)
- comparePlans(parsed, expected)
- }
-
- test("support hive interval literal") {
- def checkInterval(sql: String, result: CalendarInterval): Unit = {
- val parsed = SqlParser.parse(sql)
- val expected = Project(
- UnresolvedAlias(
- Literal(result)
- ) :: Nil,
- OneRowRelation)
- comparePlans(parsed, expected)
- }
-
- def checkYearMonth(lit: String): Unit = {
- checkInterval(
- s"SELECT INTERVAL '$lit' YEAR TO MONTH",
- CalendarInterval.fromYearMonthString(lit))
- }
-
- def checkDayTime(lit: String): Unit = {
- checkInterval(
- s"SELECT INTERVAL '$lit' DAY TO SECOND",
- CalendarInterval.fromDayTimeString(lit))
- }
-
- def checkSingleUnit(lit: String, unit: String): Unit = {
- checkInterval(
- s"SELECT INTERVAL '$lit' $unit",
- CalendarInterval.fromSingleUnitString(unit, lit))
- }
-
- checkYearMonth("123-10")
- checkYearMonth("496-0")
- checkYearMonth("-2-3")
- checkYearMonth("-123-0")
-
- checkDayTime("99 11:22:33.123456789")
- checkDayTime("-99 11:22:33.123456789")
- checkDayTime("10 9:8:7.123456789")
- checkDayTime("1 0:0:0")
- checkDayTime("-1 0:0:0")
- checkDayTime("1 0:0:1")
-
- for (unit <- Seq("year", "month", "day", "hour", "minute", "second")) {
- checkSingleUnit("7", unit)
- checkSingleUnit("-7", unit)
- checkSingleUnit("0", unit)
- }
-
- checkSingleUnit("13.123456789", "second")
- checkSingleUnit("-13.123456789", "second")
- }
-
- test("support scientific notation") {
- def assertRight(input: String, output: Double): Unit = {
- val parsed = SqlParser.parse("SELECT " + input)
- val expected = Project(
- UnresolvedAlias(
- Literal(output)
- ) :: Nil,
- OneRowRelation)
- comparePlans(parsed, expected)
- }
-
- assertRight("9.0e1", 90)
- assertRight(".9e+2", 90)
- assertRight("0.9e+2", 90)
- assertRight("900e-1", 90)
- assertRight("900.0E-1", 90)
- assertRight("9.e+1", 90)
-
- intercept[RuntimeException](SqlParser.parse("SELECT .e3"))
- }
-}
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/Column.scala b/sql/core/src/main/scala/org/apache/spark/sql/Column.scala
index 6a020f9f28..97bf7a0cc4 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/Column.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/Column.scala
@@ -21,7 +21,6 @@ import scala.language.implicitConversions
import org.apache.spark.Logging
import org.apache.spark.annotation.Experimental
-import org.apache.spark.sql.catalyst.SqlParser._
import org.apache.spark.sql.catalyst.analysis._
import org.apache.spark.sql.catalyst.encoders.{encoderFor, ExpressionEncoder}
import org.apache.spark.sql.catalyst.expressions._
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala b/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala
index 91bf2f8ce4..3422d0ead4 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala
@@ -30,7 +30,7 @@ import org.apache.spark.annotation.{DeveloperApi, Experimental}
import org.apache.spark.api.java.JavaRDD
import org.apache.spark.api.python.PythonRDD
import org.apache.spark.rdd.RDD
-import org.apache.spark.sql.catalyst.{CatalystTypeConverters, InternalRow, ScalaReflection, SqlParser}
+import org.apache.spark.sql.catalyst._
import org.apache.spark.sql.catalyst.analysis._
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.expressions.aggregate._
@@ -737,7 +737,7 @@ class DataFrame private[sql](
@scala.annotation.varargs
def selectExpr(exprs: String*): DataFrame = {
select(exprs.map { expr =>
- Column(SqlParser.parseExpression(expr))
+ Column(sqlContext.sqlParser.parseExpression(expr))
}: _*)
}
@@ -764,7 +764,7 @@ class DataFrame private[sql](
* @since 1.3.0
*/
def filter(conditionExpr: String): DataFrame = {
- filter(Column(SqlParser.parseExpression(conditionExpr)))
+ filter(Column(sqlContext.sqlParser.parseExpression(conditionExpr)))
}
/**
@@ -788,7 +788,7 @@ class DataFrame private[sql](
* @since 1.5.0
*/
def where(conditionExpr: String): DataFrame = {
- filter(Column(SqlParser.parseExpression(conditionExpr)))
+ filter(Column(sqlContext.sqlParser.parseExpression(conditionExpr)))
}
/**
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala
index d948e48942..8f852e5216 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala
@@ -29,7 +29,7 @@ import org.apache.spark.annotation.Experimental
import org.apache.spark.api.java.JavaRDD
import org.apache.spark.deploy.SparkHadoopUtil
import org.apache.spark.rdd.RDD
-import org.apache.spark.sql.catalyst.SqlParser
+import org.apache.spark.sql.catalyst.{CatalystQl}
import org.apache.spark.sql.execution.datasources.{LogicalRelation, ResolvedDataSource}
import org.apache.spark.sql.execution.datasources.jdbc.{JDBCPartition, JDBCPartitioningInfo, JDBCRelation}
import org.apache.spark.sql.execution.datasources.json.JSONRelation
@@ -337,7 +337,7 @@ class DataFrameReader private[sql](sqlContext: SQLContext) extends Logging {
*/
def table(tableName: String): DataFrame = {
DataFrame(sqlContext,
- sqlContext.catalog.lookupRelation(SqlParser.parseTableIdentifier(tableName)))
+ sqlContext.catalog.lookupRelation(sqlContext.sqlParser.parseTableIdentifier(tableName)))
}
/**
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameWriter.scala b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameWriter.scala
index 00f9817b53..ab63fe4aa8 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameWriter.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameWriter.scala
@@ -22,7 +22,7 @@ import java.util.Properties
import scala.collection.JavaConverters._
import org.apache.spark.annotation.Experimental
-import org.apache.spark.sql.catalyst.{SqlParser, TableIdentifier}
+import org.apache.spark.sql.catalyst.{CatalystQl, TableIdentifier}
import org.apache.spark.sql.catalyst.analysis.UnresolvedRelation
import org.apache.spark.sql.catalyst.plans.logical.{InsertIntoTable, Project}
import org.apache.spark.sql.execution.datasources.{BucketSpec, CreateTableUsingAsSelect, ResolvedDataSource}
@@ -192,7 +192,7 @@ final class DataFrameWriter private[sql](df: DataFrame) {
* @since 1.4.0
*/
def insertInto(tableName: String): Unit = {
- insertInto(SqlParser.parseTableIdentifier(tableName))
+ insertInto(df.sqlContext.sqlParser.parseTableIdentifier(tableName))
}
private def insertInto(tableIdent: TableIdentifier): Unit = {
@@ -282,7 +282,7 @@ final class DataFrameWriter private[sql](df: DataFrame) {
* @since 1.4.0
*/
def saveAsTable(tableName: String): Unit = {
- saveAsTable(SqlParser.parseTableIdentifier(tableName))
+ saveAsTable(df.sqlContext.sqlParser.parseTableIdentifier(tableName))
}
private def saveAsTable(tableIdent: TableIdentifier): Unit = {
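Table name strings handed to `insertInto` and `saveAsTable` above (and to `DataFrameReader.table` earlier) are now turned into a `TableIdentifier` by the context's `sqlParser` rather than the removed `SqlParser` object. A short sketch, assuming a spark-shell style `sqlContext`; `mydb.people` is a hypothetical, database-qualified table name:

```
val df = sqlContext.range(0, 10).selectExpr("id", "id * 2 AS doubled")
df.write.saveAsTable("mydb.people")          // the name is parsed into a database-qualified TableIdentifier
sqlContext.read.table("mydb.people").show()  // DataFrameReader.table goes through the same parser
```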
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala b/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala
index b909765a7c..a0939adb6d 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala
@@ -39,6 +39,7 @@ import org.apache.spark.sql.catalyst.encoders.encoderFor
import org.apache.spark.sql.catalyst.errors.DialectException
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.optimizer.Optimizer
+import org.apache.spark.sql.catalyst.parser.ParserConf
import org.apache.spark.sql.catalyst.plans.logical.{LocalRelation, LogicalPlan, Range}
import org.apache.spark.sql.catalyst.rules.RuleExecutor
import org.apache.spark.sql.execution._
@@ -205,15 +206,17 @@ class SQLContext private[sql](
protected[sql] lazy val optimizer: Optimizer = new SparkOptimizer(this)
@transient
- protected[sql] val ddlParser = new DDLParser(sqlParser.parse(_))
+ protected[sql] val ddlParser = new DDLParser(sqlParser)
@transient
- protected[sql] val sqlParser = new SparkSQLParser(getSQLDialect().parse(_))
+ protected[sql] val sqlParser = new SparkSQLParser(getSQLDialect())
protected[sql] def getSQLDialect(): ParserDialect = {
try {
val clazz = Utils.classForName(dialectClassName)
- clazz.newInstance().asInstanceOf[ParserDialect]
+ clazz.getConstructor(classOf[ParserConf])
+ .newInstance(conf)
+ .asInstanceOf[ParserDialect]
} catch {
case NonFatal(e) =>
// Since we didn't find the available SQL Dialect, it will fail even for SET command:
@@ -237,7 +240,7 @@ class SQLContext private[sql](
new sparkexecution.QueryExecution(this, plan)
protected[sql] def dialectClassName = if (conf.dialect == "sql") {
- classOf[DefaultParserDialect].getCanonicalName
+ classOf[SparkQl].getCanonicalName
} else {
conf.dialect
}
@@ -682,7 +685,7 @@ class SQLContext private[sql](
tableName: String,
source: String,
options: Map[String, String]): DataFrame = {
- val tableIdent = SqlParser.parseTableIdentifier(tableName)
+ val tableIdent = sqlParser.parseTableIdentifier(tableName)
val cmd =
CreateTableUsing(
tableIdent,
@@ -728,7 +731,7 @@ class SQLContext private[sql](
source: String,
schema: StructType,
options: Map[String, String]): DataFrame = {
- val tableIdent = SqlParser.parseTableIdentifier(tableName)
+ val tableIdent = sqlParser.parseTableIdentifier(tableName)
val cmd =
CreateTableUsing(
tableIdent,
@@ -833,7 +836,7 @@ class SQLContext private[sql](
* @since 1.3.0
*/
def table(tableName: String): DataFrame = {
- table(SqlParser.parseTableIdentifier(tableName))
+ table(sqlParser.parseTableIdentifier(tableName))
}
private def table(tableIdent: TableIdentifier): DataFrame = {
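With the reflective construction above, any class configured through `spark.sql.dialect` must now expose a constructor taking a `ParserConf`. A minimal sketch of such a dialect, modelled on the `MyDialect` classes in the test suites below; the class name is illustrative:

```
import org.apache.spark.sql.catalyst.CatalystQl
import org.apache.spark.sql.catalyst.parser.ParserConf

// Reuses the CatalystQl grammar unchanged; getSQLDialect() instantiates it via
// getConstructor(classOf[ParserConf]).newInstance(conf).
class MyDialect(conf: ParserConf) extends CatalystQl(conf)
```

Setting `spark.sql.dialect` to the class's canonical name then routes all SQL text through it, as exercised in the SQLQuerySuite changes below.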
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkSQLParser.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkSQLParser.scala
index b3e8d0d849..1af2c756cd 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkSQLParser.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkSQLParser.scala
@@ -19,8 +19,8 @@ package org.apache.spark.sql.execution
import scala.util.parsing.combinator.RegexParsers
-import org.apache.spark.sql.catalyst.AbstractSparkSQLParser
-import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeReference}
+import org.apache.spark.sql.catalyst.{AbstractSparkSQLParser, ParserDialect, TableIdentifier}
+import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeReference, Expression}
import org.apache.spark.sql.catalyst.plans.logical
import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
import org.apache.spark.sql.types.StringType
@@ -29,9 +29,16 @@ import org.apache.spark.sql.types.StringType
* The top level Spark SQL parser. This parser recognizes syntaxes that are available for all SQL
* dialects supported by Spark SQL, and delegates all the other syntaxes to the `fallback` parser.
*
- * @param fallback A function that parses an input string to a logical plan
+ * @param fallback A function that returns the next parser in the chain. This is a call-by-name
+ *                 parameter so that the dialect is re-evaluated on each use, allowing a
+ *                 different dialect to be returned when the configuration changes.
*/
-class SparkSQLParser(fallback: String => LogicalPlan) extends AbstractSparkSQLParser {
+class SparkSQLParser(fallback: => ParserDialect) extends AbstractSparkSQLParser {
+
+ override def parseExpression(sql: String): Expression = fallback.parseExpression(sql)
+
+ override def parseTableIdentifier(sql: String): TableIdentifier =
+ fallback.parseTableIdentifier(sql)
// A parser for the key-value part of the "SET [key = [value ]]" syntax
private object SetCommandParser extends RegexParsers {
@@ -74,7 +81,7 @@ class SparkSQLParser(fallback: String => LogicalPlan) extends AbstractSparkSQLPa
private lazy val cache: Parser[LogicalPlan] =
CACHE ~> LAZY.? ~ (TABLE ~> ident) ~ (AS ~> restInput).? ^^ {
case isLazy ~ tableName ~ plan =>
- CacheTableCommand(tableName, plan.map(fallback), isLazy.isDefined)
+ CacheTableCommand(tableName, plan.map(fallback.parsePlan), isLazy.isDefined)
}
private lazy val uncache: Parser[LogicalPlan] =
@@ -111,7 +118,7 @@ class SparkSQLParser(fallback: String => LogicalPlan) extends AbstractSparkSQLPa
private lazy val others: Parser[LogicalPlan] =
wholeInput ^^ {
- case input => fallback(input)
+ case input => fallback.parsePlan(input)
}
}
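A sketch of the delegation above: the parser handles its own commands (such as `SET`) and hands everything else to the call-by-name `fallback` dialect. `SimpleParserConf` is the helper configuration used by the test suites further down; a real `SQLContext` passes `getSQLDialect()` instead:

```
import org.apache.spark.sql.catalyst.CatalystQl
import org.apache.spark.sql.catalyst.parser.SimpleParserConf
import org.apache.spark.sql.execution.SparkSQLParser

val parser = new SparkSQLParser(new CatalystQl(SimpleParserConf()))
parser.parsePlan("SET spark.sql.shuffle.partitions=10") // recognized by SparkSQLParser itself
parser.parsePlan("SELECT a + 1 FROM t")                 // delegated to CatalystQl.parsePlan
parser.parseExpression("a + 1")                         // delegated to CatalystQl.parseExpression
```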
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DDLParser.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DDLParser.scala
index d8d21b06b8..10655a85cc 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DDLParser.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DDLParser.scala
@@ -22,25 +22,30 @@ import scala.util.matching.Regex
import org.apache.spark.Logging
import org.apache.spark.sql.SaveMode
-import org.apache.spark.sql.catalyst.{AbstractSparkSQLParser, TableIdentifier}
+import org.apache.spark.sql.catalyst.{AbstractSparkSQLParser, ParserDialect, TableIdentifier}
import org.apache.spark.sql.catalyst.analysis.UnresolvedRelation
+import org.apache.spark.sql.catalyst.expressions.Expression
import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
import org.apache.spark.sql.catalyst.util.DataTypeParser
import org.apache.spark.sql.types._
-
/**
* A parser for foreign DDL commands.
*/
-class DDLParser(parseQuery: String => LogicalPlan)
+class DDLParser(fallback: => ParserDialect)
extends AbstractSparkSQLParser with DataTypeParser with Logging {
+ override def parseExpression(sql: String): Expression = fallback.parseExpression(sql)
+
+ override def parseTableIdentifier(sql: String): TableIdentifier =
+ fallback.parseTableIdentifier(sql)
+
def parse(input: String, exceptionOnError: Boolean): LogicalPlan = {
try {
- parse(input)
+ parsePlan(input)
} catch {
case ddlException: DDLException => throw ddlException
- case _ if !exceptionOnError => parseQuery(input)
+ case _ if !exceptionOnError => fallback.parsePlan(input)
case x: Throwable => throw x
}
}
@@ -104,7 +109,7 @@ class DDLParser(parseQuery: String => LogicalPlan)
SaveMode.ErrorIfExists
}
- val queryPlan = parseQuery(query.get)
+ val queryPlan = fallback.parsePlan(query.get)
CreateTableUsingAsSelect(tableIdent,
provider,
temp.isDefined,
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala
index b8ea2261e9..8c2530fd68 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala
@@ -22,7 +22,7 @@ import scala.reflect.runtime.universe.{typeTag, TypeTag}
import scala.util.Try
import org.apache.spark.annotation.Experimental
-import org.apache.spark.sql.catalyst.{ScalaReflection, SqlParser}
+import org.apache.spark.sql.catalyst.{CatalystQl, ScalaReflection}
import org.apache.spark.sql.catalyst.analysis.{Star, UnresolvedFunction}
import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder
import org.apache.spark.sql.catalyst.expressions._
@@ -1063,7 +1063,10 @@ object functions extends LegacyFunctions {
*
* @group normal_funcs
*/
- def expr(expr: String): Column = Column(SqlParser.parseExpression(expr))
+ def expr(expr: String): Column = {
+ val parser = SQLContext.getActive().map(_.getSQLDialect()).getOrElse(new CatalystQl())
+ Column(parser.parseExpression(expr))
+ }
//////////////////////////////////////////////////////////////////////////////////////////////
// Math Functions
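`expr` now asks `SQLContext.getActive()` for the current dialect and only falls back to a plain `CatalystQl` when no context is active, so the strings it parses follow whatever dialect is configured. A short usage sketch, assuming a spark-shell style `sqlContext`:

```
import org.apache.spark.sql.functions._

val df = sqlContext.range(0, 5)
df.select(expr("id * 2").alias("doubled"), expr("id % 2 = 0").alias("even")).show()
```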
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/MathExpressionsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/MathExpressionsSuite.scala
index 58f982c2bc..aec450e0a6 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/MathExpressionsSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/MathExpressionsSuite.scala
@@ -212,7 +212,7 @@ class MathExpressionsSuite extends QueryTest with SharedSQLContext {
Seq(Row(5, 10, 0), Row(55, 60, 100), Row(555, 560, 600))
)
- val pi = 3.1415
+ val pi = "3.1415BD"
checkAnswer(
sql(s"SELECT round($pi, -3), round($pi, -2), round($pi, -1), " +
s"round($pi, 0), round($pi, 1), round($pi, 2), round($pi, 3)"),
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala
index 03d67c4e91..75e81b9c91 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala
@@ -21,10 +21,11 @@ import java.math.MathContext
import java.sql.Timestamp
import org.apache.spark.AccumulatorSuite
-import org.apache.spark.sql.catalyst.DefaultParserDialect
+import org.apache.spark.sql.catalyst.CatalystQl
import org.apache.spark.sql.catalyst.analysis.FunctionRegistry
import org.apache.spark.sql.catalyst.errors.DialectException
-import org.apache.spark.sql.execution.aggregate
+import org.apache.spark.sql.catalyst.parser.ParserConf
+import org.apache.spark.sql.execution.{aggregate, SparkQl}
import org.apache.spark.sql.execution.joins.{CartesianProduct, SortMergeJoin}
import org.apache.spark.sql.functions._
import org.apache.spark.sql.test.{SharedSQLContext, TestSQLContext}
@@ -32,7 +33,7 @@ import org.apache.spark.sql.test.SQLTestData._
import org.apache.spark.sql.types._
/** A SQL Dialect for testing purpose, and it can not be nested type */
-class MyDialect extends DefaultParserDialect
+class MyDialect(conf: ParserConf) extends CatalystQl(conf)
class SQLQuerySuite extends QueryTest with SharedSQLContext {
import testImplicits._
@@ -161,7 +162,7 @@ class SQLQuerySuite extends QueryTest with SharedSQLContext {
newContext.sql("SELECT 1")
}
// test if the dialect set back to DefaultSQLDialect
- assert(newContext.getSQLDialect().getClass === classOf[DefaultParserDialect])
+ assert(newContext.getSQLDialect().getClass === classOf[SparkQl])
}
test("SPARK-4625 support SORT BY in SimpleSQLParser & DSL") {
@@ -586,7 +587,7 @@ class SQLQuerySuite extends QueryTest with SharedSQLContext {
}
test("Allow only a single WITH clause per query") {
- intercept[RuntimeException] {
+ intercept[AnalysisException] {
sql(
"with q1 as (select * from testData) with q2 as (select * from q1) select * from q2")
}
@@ -602,8 +603,8 @@ class SQLQuerySuite extends QueryTest with SharedSQLContext {
test("from follow multiple brackets") {
checkAnswer(sql(
"""
- |select key from ((select * from testData limit 1)
- | union all (select * from testData limit 1)) x limit 1
+ |select key from ((select * from testData)
+ | union all (select * from testData)) x limit 1
""".stripMargin),
Row(1)
)
@@ -616,7 +617,7 @@ class SQLQuerySuite extends QueryTest with SharedSQLContext {
checkAnswer(sql(
"""
|select key from
- | (select * from testData limit 1 union all select * from testData limit 1) x
+ | (select * from testData union all select * from testData) x
| limit 1
""".stripMargin),
Row(1)
@@ -649,13 +650,13 @@ class SQLQuerySuite extends QueryTest with SharedSQLContext {
test("approximate count distinct") {
checkAnswer(
- sql("SELECT APPROXIMATE COUNT(DISTINCT a) FROM testData2"),
+ sql("SELECT APPROX_COUNT_DISTINCT(a) FROM testData2"),
Row(3))
}
test("approximate count distinct with user provided standard deviation") {
checkAnswer(
- sql("SELECT APPROXIMATE(0.04) COUNT(DISTINCT a) FROM testData2"),
+ sql("SELECT APPROX_COUNT_DISTINCT(a, 0.04) FROM testData2"),
Row(3))
}
@@ -1192,19 +1193,19 @@ class SQLQuerySuite extends QueryTest with SharedSQLContext {
test("Floating point number format") {
checkAnswer(
- sql("SELECT 0.3"), Row(BigDecimal(0.3).underlying())
+ sql("SELECT 0.3"), Row(0.3)
)
checkAnswer(
- sql("SELECT -0.8"), Row(BigDecimal(-0.8).underlying())
+ sql("SELECT -0.8"), Row(-0.8)
)
checkAnswer(
- sql("SELECT .5"), Row(BigDecimal(0.5))
+ sql("SELECT .5"), Row(0.5)
)
checkAnswer(
- sql("SELECT -.18"), Row(BigDecimal(-0.18))
+ sql("SELECT -.18"), Row(-0.18)
)
}
@@ -1218,11 +1219,11 @@ class SQLQuerySuite extends QueryTest with SharedSQLContext {
)
checkAnswer(
- sql("SELECT 9223372036854775808"), Row(new java.math.BigDecimal("9223372036854775808"))
+ sql("SELECT 9223372036854775808BD"), Row(new java.math.BigDecimal("9223372036854775808"))
)
checkAnswer(
- sql("SELECT -9223372036854775809"), Row(new java.math.BigDecimal("-9223372036854775809"))
+ sql("SELECT -9223372036854775809BD"), Row(new java.math.BigDecimal("-9223372036854775809"))
)
}
@@ -1237,11 +1238,11 @@ class SQLQuerySuite extends QueryTest with SharedSQLContext {
)
checkAnswer(
- sql("SELECT -5.2"), Row(BigDecimal(-5.2))
+ sql("SELECT -5.2BD"), Row(BigDecimal(-5.2))
)
checkAnswer(
- sql("SELECT +6.8"), Row(BigDecimal(6.8))
+ sql("SELECT +6.8"), Row(6.8d)
)
checkAnswer(
@@ -1616,20 +1617,20 @@ class SQLQuerySuite extends QueryTest with SharedSQLContext {
}
test("decimal precision with multiply/division") {
- checkAnswer(sql("select 10.3 * 3.0"), Row(BigDecimal("30.90")))
- checkAnswer(sql("select 10.3000 * 3.0"), Row(BigDecimal("30.90000")))
- checkAnswer(sql("select 10.30000 * 30.0"), Row(BigDecimal("309.000000")))
- checkAnswer(sql("select 10.300000000000000000 * 3.000000000000000000"),
+ checkAnswer(sql("select 10.3BD * 3.0BD"), Row(BigDecimal("30.90")))
+ checkAnswer(sql("select 10.3000BD * 3.0BD"), Row(BigDecimal("30.90000")))
+ checkAnswer(sql("select 10.30000BD * 30.0BD"), Row(BigDecimal("309.000000")))
+ checkAnswer(sql("select 10.300000000000000000BD * 3.000000000000000000BD"),
Row(BigDecimal("30.900000000000000000000000000000000000", new MathContext(38))))
- checkAnswer(sql("select 10.300000000000000000 * 3.0000000000000000000"),
+ checkAnswer(sql("select 10.300000000000000000BD * 3.0000000000000000000BD"),
Row(null))
- checkAnswer(sql("select 10.3 / 3.0"), Row(BigDecimal("3.433333")))
- checkAnswer(sql("select 10.3000 / 3.0"), Row(BigDecimal("3.4333333")))
- checkAnswer(sql("select 10.30000 / 30.0"), Row(BigDecimal("0.343333333")))
- checkAnswer(sql("select 10.300000000000000000 / 3.00000000000000000"),
+ checkAnswer(sql("select 10.3BD / 3.0BD"), Row(BigDecimal("3.433333")))
+ checkAnswer(sql("select 10.3000BD / 3.0BD"), Row(BigDecimal("3.4333333")))
+ checkAnswer(sql("select 10.30000BD / 30.0BD"), Row(BigDecimal("0.343333333")))
+ checkAnswer(sql("select 10.300000000000000000BD / 3.00000000000000000BD"),
Row(BigDecimal("3.433333333333333333333333333", new MathContext(38))))
- checkAnswer(sql("select 10.3000000000000000000 / 3.00000000000000000"),
+ checkAnswer(sql("select 10.3000000000000000000BD / 3.00000000000000000BD"),
Row(BigDecimal("3.4333333333333333333333333333", new MathContext(38))))
}
@@ -1655,13 +1656,13 @@ class SQLQuerySuite extends QueryTest with SharedSQLContext {
}
test("precision smaller than scale") {
- checkAnswer(sql("select 10.00"), Row(BigDecimal("10.00")))
- checkAnswer(sql("select 1.00"), Row(BigDecimal("1.00")))
- checkAnswer(sql("select 0.10"), Row(BigDecimal("0.10")))
- checkAnswer(sql("select 0.01"), Row(BigDecimal("0.01")))
- checkAnswer(sql("select 0.001"), Row(BigDecimal("0.001")))
- checkAnswer(sql("select -0.01"), Row(BigDecimal("-0.01")))
- checkAnswer(sql("select -0.001"), Row(BigDecimal("-0.001")))
+ checkAnswer(sql("select 10.00BD"), Row(BigDecimal("10.00")))
+ checkAnswer(sql("select 1.00BD"), Row(BigDecimal("1.00")))
+ checkAnswer(sql("select 0.10BD"), Row(BigDecimal("0.10")))
+ checkAnswer(sql("select 0.01BD"), Row(BigDecimal("0.01")))
+ checkAnswer(sql("select 0.001BD"), Row(BigDecimal("0.001")))
+ checkAnswer(sql("select -0.01BD"), Row(BigDecimal("-0.01")))
+ checkAnswer(sql("select -0.001BD"), Row(BigDecimal("-0.001")))
}
test("external sorting updates peak execution memory") {
@@ -1750,7 +1751,7 @@ class SQLQuerySuite extends QueryTest with SharedSQLContext {
assert(e1.message.contains("Table not found"))
val e2 = intercept[AnalysisException] {
- sql("select * from no_db.no_table")
+ sql("select * from no_db.no_table").show()
}
assert(e2.message.contains("Table not found"))
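The test updates above reflect the new literal rules: plain decimal literals are now doubles, the `BD` suffix requests a `BigDecimal`, and the old `APPROXIMATE ... COUNT(DISTINCT ...)` syntax is replaced by the `APPROX_COUNT_DISTINCT` function. A sketch, assuming a spark-shell style `sqlContext` and a registered `testData2` table with an integer column `a` as in the suite:

```
sqlContext.sql("SELECT 0.3").collect()             // Row(0.3), a double rather than a BigDecimal
sqlContext.sql("SELECT 10.3BD * 3.0BD").collect()  // BD literals keep BigDecimal precision
sqlContext.sql("SELECT APPROX_COUNT_DISTINCT(a, 0.04) FROM testData2").collect()
```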
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/json/JsonSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/json/JsonSuite.scala
index 860e07c68c..e70eb2a060 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/json/JsonSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/json/JsonSuite.scala
@@ -442,13 +442,13 @@ class JsonSuite extends QueryTest with SharedSQLContext with TestJsonData {
// Number and String conflict: resolve the type as number in this query.
checkAnswer(
- sql("select num_str + 1.2 from jsonTable where num_str > 14"),
+ sql("select num_str + 1.2BD from jsonTable where num_str > 14"),
Row(BigDecimal("92233720368547758071.2"))
)
// Number and String conflict: resolve the type as number in this query.
checkAnswer(
- sql("select num_str + 1.2 from jsonTable where num_str >= 92233720368547758060"),
+ sql("select num_str + 1.2BD from jsonTable where num_str >= 92233720368547758060BD"),
Row(new java.math.BigDecimal("92233720368547758071.2"))
)
@@ -856,7 +856,7 @@ class JsonSuite extends QueryTest with SharedSQLContext with TestJsonData {
jsonWithSimpleMap.registerTempTable("jsonWithSimpleMap")
checkAnswer(
- sql("select map from jsonWithSimpleMap"),
+ sql("select `map` from jsonWithSimpleMap"),
Row(Map("a" -> 1)) ::
Row(Map("b" -> 2)) ::
Row(Map("c" -> 3)) ::
@@ -865,7 +865,7 @@ class JsonSuite extends QueryTest with SharedSQLContext with TestJsonData {
)
checkAnswer(
- sql("select map['c'] from jsonWithSimpleMap"),
+ sql("select `map`['c'] from jsonWithSimpleMap"),
Row(null) ::
Row(null) ::
Row(3) ::
@@ -884,7 +884,7 @@ class JsonSuite extends QueryTest with SharedSQLContext with TestJsonData {
jsonWithComplexMap.registerTempTable("jsonWithComplexMap")
checkAnswer(
- sql("select map from jsonWithComplexMap"),
+ sql("select `map` from jsonWithComplexMap"),
Row(Map("a" -> Row(Seq(1, 2, 3, null), null))) ::
Row(Map("b" -> Row(null, 2))) ::
Row(Map("c" -> Row(Seq(), 4))) ::
@@ -894,7 +894,7 @@ class JsonSuite extends QueryTest with SharedSQLContext with TestJsonData {
)
checkAnswer(
- sql("select map['a'].field1, map['c'].field2 from jsonWithComplexMap"),
+ sql("select `map`['a'].field1, `map`['c'].field2 from jsonWithComplexMap"),
Row(Seq(1, 2, 3, null), null) ::
Row(null, null) ::
Row(null, 4) ::
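The quoting changes above suggest that `map` is treated as a keyword by the new grammar, so columns with that name need backticks. A sketch, assuming the `jsonWithSimpleMap` temp table registered by the suite:

```
sqlContext.sql("SELECT `map` FROM jsonWithSimpleMap").show()
sqlContext.sql("SELECT `map`['c'] FROM jsonWithSimpleMap").show()
```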
diff --git a/sql/hive/compatibility/src/test/scala/org/apache/spark/sql/hive/execution/HiveCompatibilitySuite.scala b/sql/hive/compatibility/src/test/scala/org/apache/spark/sql/hive/execution/HiveCompatibilitySuite.scala
index afd2f61158..828ec97105 100644
--- a/sql/hive/compatibility/src/test/scala/org/apache/spark/sql/hive/execution/HiveCompatibilitySuite.scala
+++ b/sql/hive/compatibility/src/test/scala/org/apache/spark/sql/hive/execution/HiveCompatibilitySuite.scala
@@ -296,6 +296,9 @@ class HiveCompatibilitySuite extends HiveQueryFileTest with BeforeAndAfter {
// Odd changes to output
"merge4",
+ // Unsupported underscore syntax.
+ "inputddl5",
+
// Thrift is broken...
"inputddl8",
@@ -603,7 +606,6 @@ class HiveCompatibilitySuite extends HiveQueryFileTest with BeforeAndAfter {
"inputddl2",
"inputddl3",
"inputddl4",
- "inputddl5",
"inputddl6",
"inputddl7",
"inputddl8",
diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/ExtendedHiveQlParser.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/ExtendedHiveQlParser.scala
index b22f424981..313ba18f6a 100644
--- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/ExtendedHiveQlParser.scala
+++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/ExtendedHiveQlParser.scala
@@ -19,14 +19,23 @@ package org.apache.spark.sql.hive
import scala.language.implicitConversions
+import org.apache.spark.sql.catalyst.{AbstractSparkSQLParser, TableIdentifier}
+import org.apache.spark.sql.catalyst.expressions.Expression
import org.apache.spark.sql.catalyst.plans.logical._
-import org.apache.spark.sql.catalyst.AbstractSparkSQLParser
import org.apache.spark.sql.hive.execution.{AddFile, AddJar, HiveNativeCommand}
/**
* A parser that recognizes all HiveQL constructs together with Spark SQL specific extensions.
*/
-private[hive] class ExtendedHiveQlParser extends AbstractSparkSQLParser {
+private[hive] class ExtendedHiveQlParser(sqlContext: HiveContext) extends AbstractSparkSQLParser {
+
+ val parser = new HiveQl(sqlContext.conf)
+
+ override def parseExpression(sql: String): Expression = parser.parseExpression(sql)
+
+ override def parseTableIdentifier(sql: String): TableIdentifier =
+ parser.parseTableIdentifier(sql)
+
// Keyword is a convention with AbstractSparkSQLParser, which will scan all of the `Keyword`
// properties via reflection the class in runtime for constructing the SqlLexical object
protected val ADD = Keyword("ADD")
@@ -38,7 +47,10 @@ private[hive] class ExtendedHiveQlParser extends AbstractSparkSQLParser {
protected lazy val hiveQl: Parser[LogicalPlan] =
restInput ^^ {
- case statement => HiveQl.parsePlan(statement.trim)
+ case statement =>
+ sqlContext.executionHive.withHiveState {
+ parser.parsePlan(statement.trim)
+ }
}
protected lazy val dfs: Parser[LogicalPlan] =
diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveContext.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveContext.scala
index cbaf00603e..7bdca52200 100644
--- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveContext.scala
+++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveContext.scala
@@ -42,7 +42,7 @@ import org.apache.spark.api.java.JavaSparkContext
import org.apache.spark.sql._
import org.apache.spark.sql.SQLConf.SQLConfEntry
import org.apache.spark.sql.SQLConf.SQLConfEntry._
-import org.apache.spark.sql.catalyst.{InternalRow, ParserDialect, SqlParser}
+import org.apache.spark.sql.catalyst.{InternalRow, ParserDialect}
import org.apache.spark.sql.catalyst.analysis._
import org.apache.spark.sql.catalyst.expressions.{Expression, LeafExpression}
import org.apache.spark.sql.catalyst.expressions.codegen.CodegenFallback
@@ -57,17 +57,6 @@ import org.apache.spark.unsafe.types.UTF8String
import org.apache.spark.util.Utils
/**
- * This is the HiveQL Dialect, this dialect is strongly bind with HiveContext
- */
-private[hive] class HiveQLDialect(sqlContext: HiveContext) extends ParserDialect {
- override def parse(sqlText: String): LogicalPlan = {
- sqlContext.executionHive.withHiveState {
- HiveQl.parseSql(sqlText)
- }
- }
-}
-
-/**
* Returns the current database of metadataHive.
*/
private[hive] case class CurrentDatabase(ctx: HiveContext)
@@ -342,12 +331,12 @@ class HiveContext private[hive](
* @since 1.3.0
*/
def refreshTable(tableName: String): Unit = {
- val tableIdent = SqlParser.parseTableIdentifier(tableName)
+ val tableIdent = sqlParser.parseTableIdentifier(tableName)
catalog.refreshTable(tableIdent)
}
protected[hive] def invalidateTable(tableName: String): Unit = {
- val tableIdent = SqlParser.parseTableIdentifier(tableName)
+ val tableIdent = sqlParser.parseTableIdentifier(tableName)
catalog.invalidateTable(tableIdent)
}
@@ -361,7 +350,7 @@ class HiveContext private[hive](
* @since 1.2.0
*/
def analyze(tableName: String) {
- val tableIdent = SqlParser.parseTableIdentifier(tableName)
+ val tableIdent = sqlParser.parseTableIdentifier(tableName)
val relation = EliminateSubQueries(catalog.lookupRelation(tableIdent))
relation match {
@@ -559,7 +548,7 @@ class HiveContext private[hive](
protected[sql] override def getSQLDialect(): ParserDialect = {
if (conf.dialect == "hiveql") {
- new HiveQLDialect(this)
+ new ExtendedHiveQlParser(this)
} else {
super.getSQLDialect()
}
diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveMetastoreCatalog.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveMetastoreCatalog.scala
index daaa5a5709..3d54048c24 100644
--- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveMetastoreCatalog.scala
+++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveMetastoreCatalog.scala
@@ -416,8 +416,8 @@ private[hive] class HiveMetastoreCatalog(val client: ClientInterface, hive: Hive
alias match {
// because hive use things like `_c0` to build the expanded text
// currently we cannot support view from "create view v1(c1) as ..."
- case None => Subquery(table.name, HiveQl.parsePlan(viewText))
- case Some(aliasText) => Subquery(aliasText, HiveQl.parsePlan(viewText))
+ case None => Subquery(table.name, hive.parseSql(viewText))
+ case Some(aliasText) => Subquery(aliasText, hive.parseSql(viewText))
}
} else {
MetastoreRelation(qualifiedTableName.database, qualifiedTableName.name, alias)(table)(hive)
diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveQl.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveQl.scala
index ca9ddf94c1..46246f8191 100644
--- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveQl.scala
+++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveQl.scala
@@ -79,7 +79,7 @@ private[hive] case class CreateViewAsSelect(
}
/** Provides a mapping from HiveQL statements to catalyst logical plans and expression trees. */
-private[hive] object HiveQl extends SparkQl with Logging {
+private[hive] class HiveQl(conf: ParserConf) extends SparkQl(conf) with Logging {
protected val nativeCommands = Seq(
"TOK_ALTERDATABASE_OWNER",
"TOK_ALTERDATABASE_PROPERTIES",
@@ -168,8 +168,6 @@ private[hive] object HiveQl extends SparkQl with Logging {
"TOK_TRUNCATETABLE" // truncate table" is a NativeCommand, does not need to explain.
) ++ nativeCommands
- protected val hqlParser = new ExtendedHiveQlParser
-
/**
* Returns the HiveConf
*/
@@ -186,9 +184,6 @@ private[hive] object HiveQl extends SparkQl with Logging {
ss.getConf
}
- /** Returns a LogicalPlan for a given HiveQL string. */
- def parseSql(sql: String): LogicalPlan = hqlParser.parse(sql)
-
protected def getProperties(node: ASTNode): Seq[(String, String)] = node match {
case Token("TOK_TABLEPROPLIST", list) =>
list.map {
diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveQlSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveQlSuite.scala
index 53d15c14cb..137dadd6c6 100644
--- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveQlSuite.scala
+++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveQlSuite.scala
@@ -23,12 +23,15 @@ import org.scalatest.BeforeAndAfterAll
import org.apache.spark.SparkFunSuite
import org.apache.spark.sql.AnalysisException
import org.apache.spark.sql.catalyst.expressions.JsonTuple
+import org.apache.spark.sql.catalyst.parser.SimpleParserConf
import org.apache.spark.sql.catalyst.plans.logical.Generate
import org.apache.spark.sql.hive.client.{ExternalTable, HiveColumn, HiveTable, ManagedTable}
class HiveQlSuite extends SparkFunSuite with BeforeAndAfterAll {
+ val parser = new HiveQl(SimpleParserConf())
+
private def extractTableDesc(sql: String): (HiveTable, Boolean) = {
- HiveQl.parsePlan(sql).collect {
+ parser.parsePlan(sql).collect {
case CreateTableAsSelect(desc, child, allowExisting) => (desc, allowExisting)
}.head
}
@@ -173,7 +176,7 @@ class HiveQlSuite extends SparkFunSuite with BeforeAndAfterAll {
test("Invalid interval term should throw AnalysisException") {
def assertError(sql: String, errorMessage: String): Unit = {
val e = intercept[AnalysisException] {
- HiveQl.parseSql(sql)
+ parser.parsePlan(sql)
}
assert(e.getMessage.contains(errorMessage))
}
@@ -186,7 +189,7 @@ class HiveQlSuite extends SparkFunSuite with BeforeAndAfterAll {
}
test("use native json_tuple instead of hive's UDTF in LATERAL VIEW") {
- val plan = HiveQl.parseSql(
+ val plan = parser.parsePlan(
"""
|SELECT *
|FROM (SELECT '{"f1": "value1", "f2": 12}' json) test
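`HiveQl` is no longer a singleton object, so the suites construct it with a `SimpleParserConf` and call `parsePlan` directly (the class is `private[hive]`, which is fine from test code in the same package). A sketch of the pattern used here and in `StatisticsSuite` below; the table name is illustrative and is not resolved during parsing:

```
import org.apache.spark.sql.catalyst.parser.SimpleParserConf

val parser = new HiveQl(SimpleParserConf())
val plan = parser.parsePlan("SELECT key FROM src LIMIT 1") // builds an unresolved logical plan
```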
diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/StatisticsSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/StatisticsSuite.scala
index 78f74cdc19..91bedf9c5a 100644
--- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/StatisticsSuite.scala
+++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/StatisticsSuite.scala
@@ -21,6 +21,7 @@ import scala.reflect.ClassTag
import org.apache.spark.sql.{QueryTest, Row, SQLConf}
import org.apache.spark.sql.catalyst.TableIdentifier
+import org.apache.spark.sql.catalyst.parser.SimpleParserConf
import org.apache.spark.sql.execution.joins._
import org.apache.spark.sql.hive.execution._
import org.apache.spark.sql.hive.test.TestHiveSingleton
@@ -28,9 +29,11 @@ import org.apache.spark.sql.hive.test.TestHiveSingleton
class StatisticsSuite extends QueryTest with TestHiveSingleton {
import hiveContext.sql
+ val parser = new HiveQl(SimpleParserConf())
+
test("parse analyze commands") {
def assertAnalyzeCommand(analyzeCommand: String, c: Class[_]) {
- val parsed = HiveQl.parseSql(analyzeCommand)
+ val parsed = parser.parsePlan(analyzeCommand)
val operators = parsed.collect {
case a: AnalyzeTable => a
case o => o
diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala
index f6c687aab7..61d5aa7ae6 100644
--- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala
+++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala
@@ -22,12 +22,14 @@ import java.sql.{Date, Timestamp}
import scala.collection.JavaConverters._
import org.apache.spark.sql._
-import org.apache.spark.sql.catalyst.{DefaultParserDialect, TableIdentifier}
+import org.apache.spark.sql.catalyst.TableIdentifier
import org.apache.spark.sql.catalyst.analysis.{EliminateSubQueries, FunctionRegistry}
import org.apache.spark.sql.catalyst.errors.DialectException
+import org.apache.spark.sql.catalyst.parser.ParserConf
+import org.apache.spark.sql.execution.SparkQl
import org.apache.spark.sql.execution.datasources.LogicalRelation
import org.apache.spark.sql.execution.datasources.parquet.ParquetRelation
-import org.apache.spark.sql.hive.{HiveContext, HiveQLDialect, MetastoreRelation}
+import org.apache.spark.sql.hive.{ExtendedHiveQlParser, HiveContext, HiveQl, MetastoreRelation}
import org.apache.spark.sql.hive.test.TestHiveSingleton
import org.apache.spark.sql.test.SQLTestUtils
import org.apache.spark.sql.types._
@@ -56,7 +58,7 @@ case class WindowData(
area: String,
product: Int)
/** A SQL Dialect for testing purpose, and it can not be nested type */
-class MyDialect extends DefaultParserDialect
+class MyDialect(conf: ParserConf) extends HiveQl(conf)
/**
* A collection of hive query tests where we generate the answers ourselves instead of depending on
@@ -339,20 +341,20 @@ class SQLQuerySuite extends QueryTest with SQLTestUtils with TestHiveSingleton {
val hiveContext = new HiveContext(sqlContext.sparkContext)
val dialectConf = "spark.sql.dialect"
checkAnswer(hiveContext.sql(s"set $dialectConf"), Row(dialectConf, "hiveql"))
- assert(hiveContext.getSQLDialect().getClass === classOf[HiveQLDialect])
+ assert(hiveContext.getSQLDialect().getClass === classOf[ExtendedHiveQlParser])
}
test("SQL Dialect Switching") {
- assert(getSQLDialect().getClass === classOf[HiveQLDialect])
+ assert(getSQLDialect().getClass === classOf[ExtendedHiveQlParser])
setConf("spark.sql.dialect", classOf[MyDialect].getCanonicalName())
assert(getSQLDialect().getClass === classOf[MyDialect])
assert(sql("SELECT 1").collect() === Array(Row(1)))
// set the dialect back to the DefaultSQLDialect
sql("SET spark.sql.dialect=sql")
- assert(getSQLDialect().getClass === classOf[DefaultParserDialect])
+ assert(getSQLDialect().getClass === classOf[SparkQl])
sql("SET spark.sql.dialect=hiveql")
- assert(getSQLDialect().getClass === classOf[HiveQLDialect])
+ assert(getSQLDialect().getClass === classOf[ExtendedHiveQlParser])
// set invalid dialect
sql("SET spark.sql.dialect.abc=MyTestClass")
@@ -361,14 +363,14 @@ class SQLQuerySuite extends QueryTest with SQLTestUtils with TestHiveSingleton {
sql("SELECT 1")
}
// test if the dialect set back to HiveQLDialect
- getSQLDialect().getClass === classOf[HiveQLDialect]
+ getSQLDialect().getClass === classOf[ExtendedHiveQlParser]
sql("SET spark.sql.dialect=MyTestClass")
intercept[DialectException] {
sql("SELECT 1")
}
// test if the dialect set back to HiveQLDialect
- assert(getSQLDialect().getClass === classOf[HiveQLDialect])
+ assert(getSQLDialect().getClass === classOf[ExtendedHiveQlParser])
}
test("CTAS with serde") {
diff --git a/unsafe/src/main/java/org/apache/spark/unsafe/types/CalendarInterval.java b/unsafe/src/main/java/org/apache/spark/unsafe/types/CalendarInterval.java
index 30e1758076..62edf6c64b 100644
--- a/unsafe/src/main/java/org/apache/spark/unsafe/types/CalendarInterval.java
+++ b/unsafe/src/main/java/org/apache/spark/unsafe/types/CalendarInterval.java
@@ -188,6 +188,11 @@ public final class CalendarInterval implements Serializable {
Integer.MIN_VALUE, Integer.MAX_VALUE);
result = new CalendarInterval(month, 0L);
+ } else if (unit.equals("week")) {
+ long week = toLongWithRange("week", m.group(1),
+ Long.MIN_VALUE / MICROS_PER_WEEK, Long.MAX_VALUE / MICROS_PER_WEEK);
+ result = new CalendarInterval(0, week * MICROS_PER_WEEK);
+
} else if (unit.equals("day")) {
long day = toLongWithRange("day", m.group(1),
Long.MIN_VALUE / MICROS_PER_DAY, Long.MAX_VALUE / MICROS_PER_DAY);
@@ -206,6 +211,15 @@ public final class CalendarInterval implements Serializable {
} else if (unit.equals("second")) {
long micros = parseSecondNano(m.group(1));
result = new CalendarInterval(0, micros);
+
+ } else if (unit.equals("millisecond")) {
+ long millisecond = toLongWithRange("millisecond", m.group(1),
+ Long.MIN_VALUE / MICROS_PER_MILLI, Long.MAX_VALUE / MICROS_PER_MILLI);
+ result = new CalendarInterval(0, millisecond * MICROS_PER_MILLI);
+
+ } else if (unit.equals("microsecond")) {
+ long micros = Long.valueOf(m.group(1));
+ result = new CalendarInterval(0, micros);
}
} catch (Exception e) {
throw new IllegalArgumentException("Error parsing interval string: " + e.getMessage(), e);
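The new `week`, `millisecond` and `microsecond` branches above extend single-unit interval parsing. Assuming these units are reachable through SQL interval literals in the new parser (and a spark-shell style `sqlContext`), a quick check might look like:

```
sqlContext.sql("SELECT interval 3 weeks").collect()
sqlContext.sql("SELECT interval 5 milliseconds").collect()
```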