-rw-r--r--  python/pyspark/sql/tests.py | 3
-rw-r--r--  sql/catalyst/src/main/antlr3/org/apache/spark/sql/catalyst/parser/ExpressionParser.g | 57
-rw-r--r--  sql/catalyst/src/main/antlr3/org/apache/spark/sql/catalyst/parser/FromClauseParser.g | 7
-rw-r--r--  sql/catalyst/src/main/antlr3/org/apache/spark/sql/catalyst/parser/SparkSqlLexer.g | 32
-rw-r--r--  sql/catalyst/src/main/antlr3/org/apache/spark/sql/catalyst/parser/SparkSqlParser.g | 10
-rw-r--r--  sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/parser/ParseUtils.java | 31
-rw-r--r--  sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/AbstractSparkSQLParser.scala | 4
-rw-r--r--  sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/CatalystQl.scala | 139
-rw-r--r--  sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ParserDialect.scala | 46
-rw-r--r--  sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/SqlParser.scala | 509
-rw-r--r--  sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/HyperLogLogPlusPlus.scala | 1
-rw-r--r--  sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/CatalystQlSuite.scala | 1
-rw-r--r--  sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/SqlParserSuite.scala | 150
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/Column.scala | 1
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala | 8
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala | 4
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/DataFrameWriter.scala | 6
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala | 17
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/execution/SparkSQLParser.scala | 19
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DDLParser.scala | 17
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/functions.scala | 7
-rw-r--r--  sql/core/src/test/scala/org/apache/spark/sql/MathExpressionsSuite.scala | 2
-rw-r--r--  sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala | 73
-rw-r--r--  sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/json/JsonSuite.scala | 12
-rw-r--r--  sql/hive/compatibility/src/test/scala/org/apache/spark/sql/hive/execution/HiveCompatibilitySuite.scala | 4
-rw-r--r--  sql/hive/src/main/scala/org/apache/spark/sql/hive/ExtendedHiveQlParser.scala | 18
-rw-r--r--  sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveContext.scala | 21
-rw-r--r--  sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveMetastoreCatalog.scala | 4
-rw-r--r--  sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveQl.scala | 7
-rw-r--r--  sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveQlSuite.scala | 9
-rw-r--r--  sql/hive/src/test/scala/org/apache/spark/sql/hive/StatisticsSuite.scala | 5
-rw-r--r--  sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala | 20
-rw-r--r--  unsafe/src/main/java/org/apache/spark/unsafe/types/CalendarInterval.java | 14
33 files changed, 286 insertions, 972 deletions
diff --git a/python/pyspark/sql/tests.py b/python/pyspark/sql/tests.py
index e396cf41f2..c03cb9338a 100644
--- a/python/pyspark/sql/tests.py
+++ b/python/pyspark/sql/tests.py
@@ -1081,8 +1081,7 @@ class SQLTests(ReusedPySparkTestCase):
def test_capture_analysis_exception(self):
self.assertRaises(AnalysisException, lambda: self.sqlCtx.sql("select abc"))
self.assertRaises(AnalysisException, lambda: self.df.selectExpr("a + b"))
- # RuntimeException should not be captured
- self.assertRaises(py4j.protocol.Py4JJavaError, lambda: self.sqlCtx.sql("abc"))
+ self.assertRaises(AnalysisException, lambda: self.sqlCtx.sql("abc"))
def test_capture_illegalargument_exception(self):
self.assertRaisesRegexp(IllegalArgumentException, "Setting negative mapred.reduce.tasks",
diff --git a/sql/catalyst/src/main/antlr3/org/apache/spark/sql/catalyst/parser/ExpressionParser.g b/sql/catalyst/src/main/antlr3/org/apache/spark/sql/catalyst/parser/ExpressionParser.g
index aabb5d4958..047a7e56cb 100644
--- a/sql/catalyst/src/main/antlr3/org/apache/spark/sql/catalyst/parser/ExpressionParser.g
+++ b/sql/catalyst/src/main/antlr3/org/apache/spark/sql/catalyst/parser/ExpressionParser.g
@@ -123,7 +123,6 @@ constant
| SmallintLiteral
| TinyintLiteral
| DecimalLiteral
- | charSetStringLiteral
| booleanValue
;
@@ -132,13 +131,6 @@ stringLiteralSequence
StringLiteral StringLiteral+ -> ^(TOK_STRINGLITERALSEQUENCE StringLiteral StringLiteral+)
;
-charSetStringLiteral
-@init { gParent.pushMsg("character string literal", state); }
-@after { gParent.popMsg(state); }
- :
- csName=CharSetName csLiteral=CharSetLiteral -> ^(TOK_CHARSETLITERAL $csName $csLiteral)
- ;
-
dateLiteral
:
KW_DATE StringLiteral ->
@@ -163,22 +155,38 @@ timestampLiteral
intervalLiteral
:
- KW_INTERVAL StringLiteral qualifiers=intervalQualifiers ->
- {
- adaptor.create($qualifiers.tree.token.getType(), $StringLiteral.text)
+ (KW_INTERVAL intervalConstant KW_YEAR KW_TO KW_MONTH) => KW_INTERVAL intervalConstant KW_YEAR KW_TO KW_MONTH
+ -> ^(TOK_INTERVAL_YEAR_MONTH_LITERAL intervalConstant)
+ | (KW_INTERVAL intervalConstant KW_DAY KW_TO KW_SECOND) => KW_INTERVAL intervalConstant KW_DAY KW_TO KW_SECOND
+ -> ^(TOK_INTERVAL_DAY_TIME_LITERAL intervalConstant)
+ | KW_INTERVAL
+ ((intervalConstant KW_YEAR)=> year=intervalConstant KW_YEAR)?
+ ((intervalConstant KW_MONTH)=> month=intervalConstant KW_MONTH)?
+ ((intervalConstant KW_WEEK)=> week=intervalConstant KW_WEEK)?
+ ((intervalConstant KW_DAY)=> day=intervalConstant KW_DAY)?
+ ((intervalConstant KW_HOUR)=> hour=intervalConstant KW_HOUR)?
+ ((intervalConstant KW_MINUTE)=> minute=intervalConstant KW_MINUTE)?
+ ((intervalConstant KW_SECOND)=> second=intervalConstant KW_SECOND)?
+ (millisecond=intervalConstant KW_MILLISECOND)?
+ (microsecond=intervalConstant KW_MICROSECOND)?
+ -> ^(TOK_INTERVAL
+ ^(TOK_INTERVAL_YEAR_LITERAL $year?)
+ ^(TOK_INTERVAL_MONTH_LITERAL $month?)
+ ^(TOK_INTERVAL_WEEK_LITERAL $week?)
+ ^(TOK_INTERVAL_DAY_LITERAL $day?)
+ ^(TOK_INTERVAL_HOUR_LITERAL $hour?)
+ ^(TOK_INTERVAL_MINUTE_LITERAL $minute?)
+ ^(TOK_INTERVAL_SECOND_LITERAL $second?)
+ ^(TOK_INTERVAL_MILLISECOND_LITERAL $millisecond?)
+ ^(TOK_INTERVAL_MICROSECOND_LITERAL $microsecond?))
+ ;
+
+intervalConstant
+ :
+ sign=(MINUS|PLUS)? value=Number -> {
+ adaptor.create(Number, ($sign != null ? $sign.getText() : "") + $value.getText())
}
- ;
-
-intervalQualifiers
- :
- KW_YEAR KW_TO KW_MONTH -> TOK_INTERVAL_YEAR_MONTH_LITERAL
- | KW_DAY KW_TO KW_SECOND -> TOK_INTERVAL_DAY_TIME_LITERAL
- | KW_YEAR -> TOK_INTERVAL_YEAR_LITERAL
- | KW_MONTH -> TOK_INTERVAL_MONTH_LITERAL
- | KW_DAY -> TOK_INTERVAL_DAY_LITERAL
- | KW_HOUR -> TOK_INTERVAL_HOUR_LITERAL
- | KW_MINUTE -> TOK_INTERVAL_MINUTE_LITERAL
- | KW_SECOND -> TOK_INTERVAL_SECOND_LITERAL
+ | StringLiteral
;
expression
@@ -219,7 +227,8 @@ nullCondition
precedenceUnaryPrefixExpression
:
- (precedenceUnaryOperator^)* precedenceFieldExpression
+ (precedenceUnaryOperator+)=> precedenceUnaryOperator^ precedenceUnaryPrefixExpression
+ | precedenceFieldExpression
;
precedenceUnarySuffixExpression
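The rewritten intervalLiteral rule accepts a compound list of optionally signed unit values (via the new intervalConstant rule) and emits a single TOK_INTERVAL node with one child per unit; the CatalystQl change further down in this patch folds those children into one CalendarInterval. A standalone sketch of that folding, assuming only the public CalendarInterval API the parser itself uses (the object name and sample values are illustrative):

    import org.apache.spark.unsafe.types.CalendarInterval

    object IntervalLiteralSketch {
      def main(args: Array[String]): Unit = {
        // Each (unit, value) pair mimics one populated child of TOK_INTERVAL.
        val units = Seq("year" -> "1", "month" -> "-2", "week" -> "3")
        // Fold the per-unit intervals into a single value, as the parser does for the literal.
        val combined = units.foldLeft(new CalendarInterval(0, 0)) {
          case (acc, (unit, value)) =>
            acc.add(CalendarInterval.fromSingleUnitString(unit, value))
        }
        println(combined)
      }
    }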
diff --git a/sql/catalyst/src/main/antlr3/org/apache/spark/sql/catalyst/parser/FromClauseParser.g b/sql/catalyst/src/main/antlr3/org/apache/spark/sql/catalyst/parser/FromClauseParser.g
index 972c52e3ff..6d76afcd4a 100644
--- a/sql/catalyst/src/main/antlr3/org/apache/spark/sql/catalyst/parser/FromClauseParser.g
+++ b/sql/catalyst/src/main/antlr3/org/apache/spark/sql/catalyst/parser/FromClauseParser.g
@@ -206,11 +206,8 @@ tableName
@init { gParent.pushMsg("table name", state); }
@after { gParent.popMsg(state); }
:
- db=identifier DOT tab=identifier
- -> ^(TOK_TABNAME $db $tab)
- |
- tab=identifier
- -> ^(TOK_TABNAME $tab)
+ id1=identifier (DOT id2=identifier)?
+ -> ^(TOK_TABNAME $id1 $id2?)
;
viewName
diff --git a/sql/catalyst/src/main/antlr3/org/apache/spark/sql/catalyst/parser/SparkSqlLexer.g b/sql/catalyst/src/main/antlr3/org/apache/spark/sql/catalyst/parser/SparkSqlLexer.g
index 44a63fbef2..ee2882e51c 100644
--- a/sql/catalyst/src/main/antlr3/org/apache/spark/sql/catalyst/parser/SparkSqlLexer.g
+++ b/sql/catalyst/src/main/antlr3/org/apache/spark/sql/catalyst/parser/SparkSqlLexer.g
@@ -307,12 +307,12 @@ KW_AUTHORIZATION: 'AUTHORIZATION';
KW_CONF: 'CONF';
KW_VALUES: 'VALUES';
KW_RELOAD: 'RELOAD';
-KW_YEAR: 'YEAR';
-KW_MONTH: 'MONTH';
-KW_DAY: 'DAY';
-KW_HOUR: 'HOUR';
-KW_MINUTE: 'MINUTE';
-KW_SECOND: 'SECOND';
+KW_YEAR: 'YEAR'|'YEARS';
+KW_MONTH: 'MONTH'|'MONTHS';
+KW_DAY: 'DAY'|'DAYS';
+KW_HOUR: 'HOUR'|'HOURS';
+KW_MINUTE: 'MINUTE'|'MINUTES';
+KW_SECOND: 'SECOND'|'SECONDS';
KW_START: 'START';
KW_TRANSACTION: 'TRANSACTION';
KW_COMMIT: 'COMMIT';
@@ -324,6 +324,9 @@ KW_ISOLATION: 'ISOLATION';
KW_LEVEL: 'LEVEL';
KW_SNAPSHOT: 'SNAPSHOT';
KW_AUTOCOMMIT: 'AUTOCOMMIT';
+KW_WEEK: 'WEEK'|'WEEKS';
+KW_MILLISECOND: 'MILLISECOND'|'MILLISECONDS';
+KW_MICROSECOND: 'MICROSECOND'|'MICROSECONDS';
// Operators
// NOTE: if you add a new function/operator, add it to sysFuncNames so that describe function _FUNC_ will work.
@@ -400,12 +403,6 @@ StringLiteral
)+
;
-CharSetLiteral
- :
- StringLiteral
- | '0' 'X' (HexDigit|Digit)+
- ;
-
BigintLiteral
:
(Digit)+ 'L'
@@ -433,7 +430,7 @@ ByteLengthLiteral
Number
:
- (Digit)+ ( DOT (Digit)* (Exponent)? | Exponent)?
+ ((Digit+ (DOT Digit*)?) | (DOT Digit+)) Exponent?
;
/*
@@ -456,10 +453,10 @@ An Identifier can be:
- macro name
- hint name
- window name
-*/
+*/
Identifier
:
- (Letter | Digit) (Letter | Digit | '_')*
+ (Letter | Digit | '_')+
| {allowQuotedId()}? QuotedIdentifier /* though at the language level we allow all Identifiers to be QuotedIdentifiers;
at the API level only columns are allowed to be of this form */
| '`' RegexComponent+ '`'
@@ -471,11 +468,6 @@ QuotedIdentifier
'`' ( '``' | ~('`') )* '`' { setText(getText().substring(1, getText().length() -1 ).replaceAll("``", "`")); }
;
-CharSetName
- :
- '_' (Letter | Digit | '_' | '-' | '.' | ':' )+
- ;
-
WS : (' '|'\r'|'\t'|'\n') {$channel=HIDDEN;}
;
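Two lexer changes stand out: the time-unit keywords now also match their plural spellings (YEAR|YEARS and so on), and the Number rule now admits literals that begin with a bare decimal point, such as .9e+2. A quick standalone check of the new Number shape, using a Scala regex that is only an illustrative translation of the ANTLR rule, not code from this patch:

    object NumberRuleSketch {
      // ((Digit+ (DOT Digit*)?) | (DOT Digit+)) Exponent?  rendered as a plain regex.
      private val Number = """(\d+(\.\d*)?|\.\d+)([eE][+-]?\d+)?""".r

      def main(args: Array[String]): Unit = {
        // ".e3" should be rejected; the others should be accepted.
        Seq("9.0e1", ".9e+2", "900e-1", "9.e+1", ".e3").foreach { s =>
          println(s"$s -> ${Number.pattern.matcher(s).matches()}")
        }
      }
    }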
diff --git a/sql/catalyst/src/main/antlr3/org/apache/spark/sql/catalyst/parser/SparkSqlParser.g b/sql/catalyst/src/main/antlr3/org/apache/spark/sql/catalyst/parser/SparkSqlParser.g
index 2c13d3056f..c146ca5914 100644
--- a/sql/catalyst/src/main/antlr3/org/apache/spark/sql/catalyst/parser/SparkSqlParser.g
+++ b/sql/catalyst/src/main/antlr3/org/apache/spark/sql/catalyst/parser/SparkSqlParser.g
@@ -116,16 +116,20 @@ TOK_DATELITERAL;
TOK_DATETIME;
TOK_TIMESTAMP;
TOK_TIMESTAMPLITERAL;
+TOK_INTERVAL;
TOK_INTERVAL_YEAR_MONTH;
TOK_INTERVAL_YEAR_MONTH_LITERAL;
TOK_INTERVAL_DAY_TIME;
TOK_INTERVAL_DAY_TIME_LITERAL;
TOK_INTERVAL_YEAR_LITERAL;
TOK_INTERVAL_MONTH_LITERAL;
+TOK_INTERVAL_WEEK_LITERAL;
TOK_INTERVAL_DAY_LITERAL;
TOK_INTERVAL_HOUR_LITERAL;
TOK_INTERVAL_MINUTE_LITERAL;
TOK_INTERVAL_SECOND_LITERAL;
+TOK_INTERVAL_MILLISECOND_LITERAL;
+TOK_INTERVAL_MICROSECOND_LITERAL;
TOK_STRING;
TOK_CHAR;
TOK_VARCHAR;
@@ -228,7 +232,6 @@ TOK_TMP_FILE;
TOK_TABSORTCOLNAMEASC;
TOK_TABSORTCOLNAMEDESC;
TOK_STRINGLITERALSEQUENCE;
-TOK_CHARSETLITERAL;
TOK_CREATEFUNCTION;
TOK_DROPFUNCTION;
TOK_RELOADFUNCTION;
@@ -509,7 +512,9 @@ import java.util.HashMap;
xlateMap.put("KW_UPDATE", "UPDATE");
xlateMap.put("KW_VALUES", "VALUES");
xlateMap.put("KW_PURGE", "PURGE");
-
+ xlateMap.put("KW_WEEK", "WEEK");
+ xlateMap.put("KW_MILLISECOND", "MILLISECOND");
+ xlateMap.put("KW_MICROSECOND", "MICROSECOND");
// Operators
xlateMap.put("DOT", ".");
@@ -2078,6 +2083,7 @@ primitiveType
| KW_SMALLINT -> TOK_SMALLINT
| KW_INT -> TOK_INT
| KW_BIGINT -> TOK_BIGINT
+ | KW_LONG -> TOK_BIGINT
| KW_BOOLEAN -> TOK_BOOLEAN
| KW_FLOAT -> TOK_FLOAT
| KW_DOUBLE -> TOK_DOUBLE
diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/parser/ParseUtils.java b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/parser/ParseUtils.java
index 5bc87b680f..2520c7bb8d 100644
--- a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/parser/ParseUtils.java
+++ b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/parser/ParseUtils.java
@@ -18,12 +18,10 @@
package org.apache.spark.sql.catalyst.parser;
-import java.io.UnsupportedEncodingException;
-
/**
* A couple of utility methods that help with parsing ASTs.
*
- * Both methods in this class were take from the SemanticAnalyzer in Hive:
+ * The 'unescapeSQLString' method in this class was taken from the SemanticAnalyzer in Hive:
* ql/src/java/org/apache/hadoop/hive/ql/parse/BaseSemanticAnalyzer.java
*/
public final class ParseUtils {
@@ -31,33 +29,6 @@ public final class ParseUtils {
super();
}
- public static String charSetString(String charSetName, String charSetString)
- throws UnsupportedEncodingException {
- // The character set name starts with a _, so strip that
- charSetName = charSetName.substring(1);
- if (charSetString.charAt(0) == '\'') {
- return new String(unescapeSQLString(charSetString).getBytes(), charSetName);
- } else // hex input is also supported
- {
- assert charSetString.charAt(0) == '0';
- assert charSetString.charAt(1) == 'x';
- charSetString = charSetString.substring(2);
-
- byte[] bArray = new byte[charSetString.length() / 2];
- int j = 0;
- for (int i = 0; i < charSetString.length(); i += 2) {
- int val = Character.digit(charSetString.charAt(i), 16) * 16
- + Character.digit(charSetString.charAt(i + 1), 16);
- if (val > 127) {
- val = val - 256;
- }
- bArray[j++] = (byte)val;
- }
-
- return new String(bArray, charSetName);
- }
- }
-
private static final int[] multiplier = new int[] {1000, 100, 10, 1};
@SuppressWarnings("nls")
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/AbstractSparkSQLParser.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/AbstractSparkSQLParser.scala
index bdc52c08ac..9443369808 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/AbstractSparkSQLParser.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/AbstractSparkSQLParser.scala
@@ -26,9 +26,9 @@ import scala.util.parsing.input.CharArrayReader.EofCh
import org.apache.spark.sql.catalyst.plans.logical._
private[sql] abstract class AbstractSparkSQLParser
- extends StandardTokenParsers with PackratParsers {
+ extends StandardTokenParsers with PackratParsers with ParserDialect {
- def parse(input: String): LogicalPlan = synchronized {
+ def parsePlan(input: String): LogicalPlan = synchronized {
// Initialize the Keywords.
initLexical
phrase(start)(new lexical.Scanner(input)) match {
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/CatalystQl.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/CatalystQl.scala
index d0fbdacf6e..c1591ecfe2 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/CatalystQl.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/CatalystQl.scala
@@ -30,16 +30,10 @@ import org.apache.spark.sql.types._
import org.apache.spark.unsafe.types.CalendarInterval
import org.apache.spark.util.random.RandomSampler
-private[sql] object CatalystQl {
- val parser = new CatalystQl
- def parseExpression(sql: String): Expression = parser.parseExpression(sql)
- def parseTableIdentifier(sql: String): TableIdentifier = parser.parseTableIdentifier(sql)
-}
-
/**
* This class translates a HQL String to a Catalyst [[LogicalPlan]] or [[Expression]].
*/
-private[sql] class CatalystQl(val conf: ParserConf = SimpleParserConf()) {
+private[sql] class CatalystQl(val conf: ParserConf = SimpleParserConf()) extends ParserDialect {
object Token {
def unapply(node: ASTNode): Some[(String, List[ASTNode])] = {
CurrentOrigin.setPosition(node.line, node.positionInLine)
@@ -611,13 +605,6 @@ https://cwiki.apache.org/confluence/display/Hive/Enhanced+Aggregation%2C+Cube%2C
case plainIdent => plainIdent
}
- val numericAstTypes = Seq(
- SparkSqlParser.Number,
- SparkSqlParser.TinyintLiteral,
- SparkSqlParser.SmallintLiteral,
- SparkSqlParser.BigintLiteral,
- SparkSqlParser.DecimalLiteral)
-
/* Case insensitive matches */
val COUNT = "(?i)COUNT".r
val SUM = "(?i)SUM".r
@@ -635,6 +622,8 @@ https://cwiki.apache.org/confluence/display/Hive/Enhanced+Aggregation%2C+Cube%2C
val WHEN = "(?i)WHEN".r
val CASE = "(?i)CASE".r
+ val INTEGRAL = "[+-]?\\d+".r
+
protected def nodeToExpr(node: ASTNode): Expression = node match {
/* Attribute References */
case Token("TOK_TABLE_OR_COL", Token(name, Nil) :: Nil) =>
@@ -650,8 +639,8 @@ https://cwiki.apache.org/confluence/display/Hive/Enhanced+Aggregation%2C+Cube%2C
case Token("TOK_ALLCOLREF", Nil) => UnresolvedStar(None)
// The format of dbName.tableName.* cannot be parsed by HiveParser. TOK_TABNAME will only
// has a single child which is tableName.
- case Token("TOK_ALLCOLREF", Token("TOK_TABNAME", Token(name, Nil) :: Nil) :: Nil) =>
- UnresolvedStar(Some(UnresolvedAttribute.parseAttributeName(name)))
+ case Token("TOK_ALLCOLREF", Token("TOK_TABNAME", target) :: Nil) if target.nonEmpty =>
+ UnresolvedStar(Some(target.map(_.text)))
/* Aggregate Functions */
case Token("TOK_FUNCTIONDI", Token(COUNT(), Nil) :: args) =>
@@ -787,71 +776,71 @@ https://cwiki.apache.org/confluence/display/Hive/Enhanced+Aggregation%2C+Cube%2C
case Token("TOK_STRINGLITERALSEQUENCE", strings) =>
Literal(strings.map(s => ParseUtils.unescapeSQLString(s.text)).mkString)
- // This code is adapted from
- // /ql/src/java/org/apache/hadoop/hive/ql/parse/TypeCheckProcFactory.java#L223
- case ast: ASTNode if numericAstTypes contains ast.tokenType =>
- var v: Literal = null
- try {
- if (ast.text.endsWith("L")) {
- // Literal bigint.
- v = Literal.create(ast.text.substring(0, ast.text.length() - 1).toLong, LongType)
- } else if (ast.text.endsWith("S")) {
- // Literal smallint.
- v = Literal.create(ast.text.substring(0, ast.text.length() - 1).toShort, ShortType)
- } else if (ast.text.endsWith("Y")) {
- // Literal tinyint.
- v = Literal.create(ast.text.substring(0, ast.text.length() - 1).toByte, ByteType)
- } else if (ast.text.endsWith("BD") || ast.text.endsWith("D")) {
- // Literal decimal
- val strVal = ast.text.stripSuffix("D").stripSuffix("B")
- v = Literal(Decimal(strVal))
- } else {
- v = Literal.create(ast.text.toDouble, DoubleType)
- v = Literal.create(ast.text.toLong, LongType)
- v = Literal.create(ast.text.toInt, IntegerType)
- }
- } catch {
- case nfe: NumberFormatException => // Do nothing
- }
-
- if (v == null) {
- sys.error(s"Failed to parse number '${ast.text}'.")
- } else {
- v
- }
-
- case ast: ASTNode if ast.tokenType == SparkSqlParser.StringLiteral =>
- Literal(ParseUtils.unescapeSQLString(ast.text))
+ case ast if ast.tokenType == SparkSqlParser.TinyintLiteral =>
+ Literal.create(ast.text.substring(0, ast.text.length() - 1).toByte, ByteType)
- case ast: ASTNode if ast.tokenType == SparkSqlParser.TOK_DATELITERAL =>
- Literal(Date.valueOf(ast.text.substring(1, ast.text.length - 1)))
+ case ast if ast.tokenType == SparkSqlParser.SmallintLiteral =>
+ Literal.create(ast.text.substring(0, ast.text.length() - 1).toShort, ShortType)
- case ast: ASTNode if ast.tokenType == SparkSqlParser.TOK_CHARSETLITERAL =>
- Literal(ParseUtils.charSetString(ast.children.head.text, ast.children(1).text))
+ case ast if ast.tokenType == SparkSqlParser.BigintLiteral =>
+ Literal.create(ast.text.substring(0, ast.text.length() - 1).toLong, LongType)
- case ast: ASTNode if ast.tokenType == SparkSqlParser.TOK_INTERVAL_YEAR_MONTH_LITERAL =>
- Literal(CalendarInterval.fromYearMonthString(ast.text))
+ case ast if ast.tokenType == SparkSqlParser.DecimalLiteral =>
+ Literal(Decimal(ast.text.substring(0, ast.text.length() - 2)))
- case ast: ASTNode if ast.tokenType == SparkSqlParser.TOK_INTERVAL_DAY_TIME_LITERAL =>
- Literal(CalendarInterval.fromDayTimeString(ast.text))
-
- case ast: ASTNode if ast.tokenType == SparkSqlParser.TOK_INTERVAL_YEAR_LITERAL =>
- Literal(CalendarInterval.fromSingleUnitString("year", ast.text))
-
- case ast: ASTNode if ast.tokenType == SparkSqlParser.TOK_INTERVAL_MONTH_LITERAL =>
- Literal(CalendarInterval.fromSingleUnitString("month", ast.text))
-
- case ast: ASTNode if ast.tokenType == SparkSqlParser.TOK_INTERVAL_DAY_LITERAL =>
- Literal(CalendarInterval.fromSingleUnitString("day", ast.text))
-
- case ast: ASTNode if ast.tokenType == SparkSqlParser.TOK_INTERVAL_HOUR_LITERAL =>
- Literal(CalendarInterval.fromSingleUnitString("hour", ast.text))
+ case ast if ast.tokenType == SparkSqlParser.Number =>
+ val text = ast.text
+ text match {
+ case INTEGRAL() =>
+ BigDecimal(text) match {
+ case v if v.isValidInt =>
+ Literal(v.intValue())
+ case v if v.isValidLong =>
+ Literal(v.longValue())
+ case v => Literal(v.underlying())
+ }
+ case _ =>
+ Literal(text.toDouble)
+ }
+ case ast if ast.tokenType == SparkSqlParser.StringLiteral =>
+ Literal(ParseUtils.unescapeSQLString(ast.text))
- case ast: ASTNode if ast.tokenType == SparkSqlParser.TOK_INTERVAL_MINUTE_LITERAL =>
- Literal(CalendarInterval.fromSingleUnitString("minute", ast.text))
+ case ast if ast.tokenType == SparkSqlParser.TOK_DATELITERAL =>
+ Literal(Date.valueOf(ast.text.substring(1, ast.text.length - 1)))
- case ast: ASTNode if ast.tokenType == SparkSqlParser.TOK_INTERVAL_SECOND_LITERAL =>
- Literal(CalendarInterval.fromSingleUnitString("second", ast.text))
+ case ast if ast.tokenType == SparkSqlParser.TOK_INTERVAL_YEAR_MONTH_LITERAL =>
+ Literal(CalendarInterval.fromYearMonthString(ast.children.head.text))
+
+ case ast if ast.tokenType == SparkSqlParser.TOK_INTERVAL_DAY_TIME_LITERAL =>
+ Literal(CalendarInterval.fromDayTimeString(ast.children.head.text))
+
+ case Token("TOK_INTERVAL", elements) =>
+ var interval = new CalendarInterval(0, 0)
+ var updated = false
+ elements.foreach {
+ // The interval node will always contain children for all possible time units. A child node
+ // is only useful when it contains exactly one (numeric) child.
+ case e @ Token(name, Token(value, Nil) :: Nil) =>
+ val unit = name match {
+ case "TOK_INTERVAL_YEAR_LITERAL" => "year"
+ case "TOK_INTERVAL_MONTH_LITERAL" => "month"
+ case "TOK_INTERVAL_WEEK_LITERAL" => "week"
+ case "TOK_INTERVAL_DAY_LITERAL" => "day"
+ case "TOK_INTERVAL_HOUR_LITERAL" => "hour"
+ case "TOK_INTERVAL_MINUTE_LITERAL" => "minute"
+ case "TOK_INTERVAL_SECOND_LITERAL" => "second"
+ case "TOK_INTERVAL_MILLISECOND_LITERAL" => "millisecond"
+ case "TOK_INTERVAL_MICROSECOND_LITERAL" => "microsecond"
+ case _ => noParseRule(s"Interval($name)", e)
+ }
+ interval = interval.add(CalendarInterval.fromSingleUnitString(unit, value))
+ updated = true
+ case _ =>
+ }
+ if (!updated) {
+ throw new AnalysisException("at least one time unit should be given for interval literal")
+ }
+ Literal(interval)
case _ =>
noParseRule("Expression", node)
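The Number-literal case above narrows integral text to the smallest fitting type (Int, then Long, then an arbitrary-precision decimal) and falls back to Double for anything non-integral. A self-contained sketch of that narrowing, mirroring the INTEGRAL regex and the BigDecimal branches above (the object and method names are illustrative):

    object NarrowestLiteralSketch {
      private val Integral = "[+-]?\\d+".r

      // Integral values are narrowed to Int/Long/BigDecimal; everything else becomes Double.
      def parseNumber(text: String): Any = text match {
        case Integral() =>
          BigDecimal(text) match {
            case v if v.isValidInt  => v.intValue()
            case v if v.isValidLong => v.longValue()
            case v                  => v.underlying()   // java.math.BigDecimal
          }
        case _ => text.toDouble
      }

      def main(args: Array[String]): Unit = {
        Seq("123", "3000000000", "12345678901234567890123", "1.5", "9.0e1").foreach { s =>
          val v = parseNumber(s)
          println(s"$s -> $v (${v.getClass.getSimpleName})")
        }
      }
    }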
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ParserDialect.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ParserDialect.scala
index e21d3c0546..7d9fbf2f12 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ParserDialect.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ParserDialect.scala
@@ -18,52 +18,22 @@
package org.apache.spark.sql.catalyst
import org.apache.spark.annotation.DeveloperApi
+import org.apache.spark.sql.catalyst.expressions.Expression
import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
/**
* Root class of SQL Parser Dialect, and we don't guarantee the binary
* compatibility for the future release, let's keep it as the internal
* interface for advanced user.
- *
*/
@DeveloperApi
-abstract class ParserDialect {
- // this is the main function that will be implemented by sql parser.
- def parse(sqlText: String): LogicalPlan
-}
+trait ParserDialect {
+ /** Creates LogicalPlan for a given SQL string. */
+ def parsePlan(sqlText: String): LogicalPlan
-/**
- * Currently we support the default dialect named "sql", associated with the class
- * [[DefaultParserDialect]]
- *
- * And we can also provide custom SQL Dialect, for example in Spark SQL CLI:
- * {{{
- *-- switch to "hiveql" dialect
- * spark-sql>SET spark.sql.dialect=hiveql;
- * spark-sql>SELECT * FROM src LIMIT 1;
- *
- *-- switch to "sql" dialect
- * spark-sql>SET spark.sql.dialect=sql;
- * spark-sql>SELECT * FROM src LIMIT 1;
- *
- *-- register the new SQL dialect
- * spark-sql> SET spark.sql.dialect=com.xxx.xxx.SQL99Dialect;
- * spark-sql> SELECT * FROM src LIMIT 1;
- *
- *-- register the non-exist SQL dialect
- * spark-sql> SET spark.sql.dialect=NotExistedClass;
- * spark-sql> SELECT * FROM src LIMIT 1;
- *
- *-- Exception will be thrown and switch to dialect
- *-- "sql" (for SQLContext) or
- *-- "hiveql" (for HiveContext)
- * }}}
- */
-private[spark] class DefaultParserDialect extends ParserDialect {
- @transient
- protected val sqlParser = SqlParser
+ /** Creates Expression for a given SQL string. */
+ def parseExpression(sqlText: String): Expression
- override def parse(sqlText: String): LogicalPlan = {
- sqlParser.parse(sqlText)
- }
+ /** Creates TableIdentifier for a given SQL string. */
+ def parseTableIdentifier(sqlText: String): TableIdentifier
}
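ParserDialect is now a plain trait with three entry points, so a parser front end only has to produce a LogicalPlan, an Expression, and a TableIdentifier. A minimal sketch of an implementation that wraps an existing dialect, assuming it is compiled against spark-catalyst (the class name and the tracing are illustrative):

    import org.apache.spark.sql.catalyst.{ParserDialect, TableIdentifier}
    import org.apache.spark.sql.catalyst.expressions.Expression
    import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan

    // Delegates to an existing dialect while tracing what gets parsed.
    class TracingDialect(underlying: ParserDialect) extends ParserDialect {
      override def parsePlan(sqlText: String): LogicalPlan = {
        println(s"parsePlan: $sqlText")
        underlying.parsePlan(sqlText)
      }

      override def parseExpression(sqlText: String): Expression = {
        println(s"parseExpression: $sqlText")
        underlying.parseExpression(sqlText)
      }

      override def parseTableIdentifier(sqlText: String): TableIdentifier = {
        println(s"parseTableIdentifier: $sqlText")
        underlying.parseTableIdentifier(sqlText)
      }
    }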
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/SqlParser.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/SqlParser.scala
deleted file mode 100644
index 85ff4ea0c9..0000000000
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/SqlParser.scala
+++ /dev/null
@@ -1,509 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one or more
- * contributor license agreements. See the NOTICE file distributed with
- * this work for additional information regarding copyright ownership.
- * The ASF licenses this file to You under the Apache License, Version 2.0
- * (the "License"); you may not use this file except in compliance with
- * the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-package org.apache.spark.sql.catalyst
-
-import scala.language.implicitConversions
-
-import org.apache.spark.sql.AnalysisException
-import org.apache.spark.sql.catalyst.analysis._
-import org.apache.spark.sql.catalyst.expressions._
-import org.apache.spark.sql.catalyst.expressions.aggregate._
-import org.apache.spark.sql.catalyst.plans._
-import org.apache.spark.sql.catalyst.plans.logical._
-import org.apache.spark.sql.catalyst.util.DataTypeParser
-import org.apache.spark.sql.types._
-import org.apache.spark.unsafe.types.CalendarInterval
-
-/**
- * A very simple SQL parser. Based loosely on:
- * https://github.com/stephentu/scala-sql-parser/blob/master/src/main/scala/parser.scala
- *
- * Limitations:
- * - Only supports a very limited subset of SQL.
- *
- * This is currently included mostly for illustrative purposes. Users wanting more complete support
- * for a SQL like language should checkout the HiveQL support in the sql/hive sub-project.
- */
-object SqlParser extends AbstractSparkSQLParser with DataTypeParser {
-
- def parseExpression(input: String): Expression = synchronized {
- // Initialize the Keywords.
- initLexical
- phrase(projection)(new lexical.Scanner(input)) match {
- case Success(plan, _) => plan
- case failureOrError => sys.error(failureOrError.toString)
- }
- }
-
- def parseTableIdentifier(input: String): TableIdentifier = synchronized {
- // Initialize the Keywords.
- initLexical
- phrase(tableIdentifier)(new lexical.Scanner(input)) match {
- case Success(ident, _) => ident
- case failureOrError => sys.error(failureOrError.toString)
- }
- }
-
- // Keyword is a convention with AbstractSparkSQLParser, which will scan all of the `Keyword`
- // properties via reflection the class in runtime for constructing the SqlLexical object
- protected val ALL = Keyword("ALL")
- protected val AND = Keyword("AND")
- protected val APPROXIMATE = Keyword("APPROXIMATE")
- protected val AS = Keyword("AS")
- protected val ASC = Keyword("ASC")
- protected val BETWEEN = Keyword("BETWEEN")
- protected val BY = Keyword("BY")
- protected val CASE = Keyword("CASE")
- protected val CAST = Keyword("CAST")
- protected val DESC = Keyword("DESC")
- protected val DISTINCT = Keyword("DISTINCT")
- protected val ELSE = Keyword("ELSE")
- protected val END = Keyword("END")
- protected val EXCEPT = Keyword("EXCEPT")
- protected val FALSE = Keyword("FALSE")
- protected val FROM = Keyword("FROM")
- protected val FULL = Keyword("FULL")
- protected val GROUP = Keyword("GROUP")
- protected val HAVING = Keyword("HAVING")
- protected val IN = Keyword("IN")
- protected val INNER = Keyword("INNER")
- protected val INSERT = Keyword("INSERT")
- protected val INTERSECT = Keyword("INTERSECT")
- protected val INTERVAL = Keyword("INTERVAL")
- protected val INTO = Keyword("INTO")
- protected val IS = Keyword("IS")
- protected val JOIN = Keyword("JOIN")
- protected val LEFT = Keyword("LEFT")
- protected val LIKE = Keyword("LIKE")
- protected val LIMIT = Keyword("LIMIT")
- protected val NOT = Keyword("NOT")
- protected val NULL = Keyword("NULL")
- protected val ON = Keyword("ON")
- protected val OR = Keyword("OR")
- protected val ORDER = Keyword("ORDER")
- protected val SORT = Keyword("SORT")
- protected val OUTER = Keyword("OUTER")
- protected val OVERWRITE = Keyword("OVERWRITE")
- protected val REGEXP = Keyword("REGEXP")
- protected val RIGHT = Keyword("RIGHT")
- protected val RLIKE = Keyword("RLIKE")
- protected val SELECT = Keyword("SELECT")
- protected val SEMI = Keyword("SEMI")
- protected val TABLE = Keyword("TABLE")
- protected val THEN = Keyword("THEN")
- protected val TRUE = Keyword("TRUE")
- protected val UNION = Keyword("UNION")
- protected val WHEN = Keyword("WHEN")
- protected val WHERE = Keyword("WHERE")
- protected val WITH = Keyword("WITH")
-
- protected lazy val start: Parser[LogicalPlan] =
- start1 | insert | cte
-
- protected lazy val start1: Parser[LogicalPlan] =
- (select | ("(" ~> select <~ ")")) *
- ( UNION ~ ALL ^^^ { (q1: LogicalPlan, q2: LogicalPlan) => Union(q1, q2) }
- | INTERSECT ^^^ { (q1: LogicalPlan, q2: LogicalPlan) => Intersect(q1, q2) }
- | EXCEPT ^^^ { (q1: LogicalPlan, q2: LogicalPlan) => Except(q1, q2)}
- | UNION ~ DISTINCT.? ^^^ { (q1: LogicalPlan, q2: LogicalPlan) => Distinct(Union(q1, q2)) }
- )
-
- protected lazy val select: Parser[LogicalPlan] =
- SELECT ~> DISTINCT.? ~
- repsep(projection, ",") ~
- (FROM ~> relations).? ~
- (WHERE ~> expression).? ~
- (GROUP ~ BY ~> rep1sep(expression, ",")).? ~
- (HAVING ~> expression).? ~
- sortType.? ~
- (LIMIT ~> expression).? ^^ {
- case d ~ p ~ r ~ f ~ g ~ h ~ o ~ l =>
- val base = r.getOrElse(OneRowRelation)
- val withFilter = f.map(Filter(_, base)).getOrElse(base)
- val withProjection = g
- .map(Aggregate(_, p.map(UnresolvedAlias(_)), withFilter))
- .getOrElse(Project(p.map(UnresolvedAlias(_)), withFilter))
- val withDistinct = d.map(_ => Distinct(withProjection)).getOrElse(withProjection)
- val withHaving = h.map(Filter(_, withDistinct)).getOrElse(withDistinct)
- val withOrder = o.map(_(withHaving)).getOrElse(withHaving)
- val withLimit = l.map(Limit(_, withOrder)).getOrElse(withOrder)
- withLimit
- }
-
- protected lazy val insert: Parser[LogicalPlan] =
- INSERT ~> (OVERWRITE ^^^ true | INTO ^^^ false) ~ (TABLE ~> relation) ~ select ^^ {
- case o ~ r ~ s => InsertIntoTable(r, Map.empty[String, Option[String]], s, o, false)
- }
-
- protected lazy val cte: Parser[LogicalPlan] =
- WITH ~> rep1sep(ident ~ ( AS ~ "(" ~> start1 <~ ")"), ",") ~ (start1 | insert) ^^ {
- case r ~ s => With(s, r.map({case n ~ s => (n, Subquery(n, s))}).toMap)
- }
-
- protected lazy val projection: Parser[Expression] =
- expression ~ (AS.? ~> ident.?) ^^ {
- case e ~ a => a.fold(e)(Alias(e, _)())
- }
-
- // Based very loosely on the MySQL Grammar.
- // http://dev.mysql.com/doc/refman/5.0/en/join.html
- protected lazy val relations: Parser[LogicalPlan] =
- ( relation ~ rep1("," ~> relation) ^^ {
- case r1 ~ joins => joins.foldLeft(r1) { case(lhs, r) => Join(lhs, r, Inner, None) } }
- | relation
- )
-
- protected lazy val relation: Parser[LogicalPlan] =
- joinedRelation | relationFactor
-
- protected lazy val relationFactor: Parser[LogicalPlan] =
- ( tableIdentifier ~ (opt(AS) ~> opt(ident)) ^^ {
- case tableIdent ~ alias => UnresolvedRelation(tableIdent, alias)
- }
- | ("(" ~> start <~ ")") ~ (AS.? ~> ident) ^^ { case s ~ a => Subquery(a, s) }
- )
-
- protected lazy val joinedRelation: Parser[LogicalPlan] =
- relationFactor ~ rep1(joinType.? ~ (JOIN ~> relationFactor) ~ joinConditions.?) ^^ {
- case r1 ~ joins =>
- joins.foldLeft(r1) { case (lhs, jt ~ rhs ~ cond) =>
- Join(lhs, rhs, joinType = jt.getOrElse(Inner), cond)
- }
- }
-
- protected lazy val joinConditions: Parser[Expression] =
- ON ~> expression
-
- protected lazy val joinType: Parser[JoinType] =
- ( INNER ^^^ Inner
- | LEFT ~ SEMI ^^^ LeftSemi
- | LEFT ~ OUTER.? ^^^ LeftOuter
- | RIGHT ~ OUTER.? ^^^ RightOuter
- | FULL ~ OUTER.? ^^^ FullOuter
- )
-
- protected lazy val sortType: Parser[LogicalPlan => LogicalPlan] =
- ( ORDER ~ BY ~> ordering ^^ { case o => l: LogicalPlan => Sort(o, true, l) }
- | SORT ~ BY ~> ordering ^^ { case o => l: LogicalPlan => Sort(o, false, l) }
- )
-
- protected lazy val ordering: Parser[Seq[SortOrder]] =
- ( rep1sep(expression ~ direction.?, ",") ^^ {
- case exps => exps.map(pair => SortOrder(pair._1, pair._2.getOrElse(Ascending)))
- }
- )
-
- protected lazy val direction: Parser[SortDirection] =
- ( ASC ^^^ Ascending
- | DESC ^^^ Descending
- )
-
- protected lazy val expression: Parser[Expression] =
- orExpression
-
- protected lazy val orExpression: Parser[Expression] =
- andExpression * (OR ^^^ { (e1: Expression, e2: Expression) => Or(e1, e2) })
-
- protected lazy val andExpression: Parser[Expression] =
- notExpression * (AND ^^^ { (e1: Expression, e2: Expression) => And(e1, e2) })
-
- protected lazy val notExpression: Parser[Expression] =
- NOT.? ~ comparisonExpression ^^ { case maybeNot ~ e => maybeNot.map(_ => Not(e)).getOrElse(e) }
-
- protected lazy val comparisonExpression: Parser[Expression] =
- ( termExpression ~ ("=" ~> termExpression) ^^ { case e1 ~ e2 => EqualTo(e1, e2) }
- | termExpression ~ ("<" ~> termExpression) ^^ { case e1 ~ e2 => LessThan(e1, e2) }
- | termExpression ~ ("<=" ~> termExpression) ^^ { case e1 ~ e2 => LessThanOrEqual(e1, e2) }
- | termExpression ~ (">" ~> termExpression) ^^ { case e1 ~ e2 => GreaterThan(e1, e2) }
- | termExpression ~ (">=" ~> termExpression) ^^ { case e1 ~ e2 => GreaterThanOrEqual(e1, e2) }
- | termExpression ~ ("!=" ~> termExpression) ^^ { case e1 ~ e2 => Not(EqualTo(e1, e2)) }
- | termExpression ~ ("<>" ~> termExpression) ^^ { case e1 ~ e2 => Not(EqualTo(e1, e2)) }
- | termExpression ~ ("<=>" ~> termExpression) ^^ { case e1 ~ e2 => EqualNullSafe(e1, e2) }
- | termExpression ~ NOT.? ~ (BETWEEN ~> termExpression) ~ (AND ~> termExpression) ^^ {
- case e ~ not ~ el ~ eu =>
- val betweenExpr: Expression = And(GreaterThanOrEqual(e, el), LessThanOrEqual(e, eu))
- not.fold(betweenExpr)(f => Not(betweenExpr))
- }
- | termExpression ~ (RLIKE ~> termExpression) ^^ { case e1 ~ e2 => RLike(e1, e2) }
- | termExpression ~ (REGEXP ~> termExpression) ^^ { case e1 ~ e2 => RLike(e1, e2) }
- | termExpression ~ (LIKE ~> termExpression) ^^ { case e1 ~ e2 => Like(e1, e2) }
- | termExpression ~ (NOT ~ LIKE ~> termExpression) ^^ { case e1 ~ e2 => Not(Like(e1, e2)) }
- | termExpression ~ (IN ~ "(" ~> rep1sep(termExpression, ",")) <~ ")" ^^ {
- case e1 ~ e2 => In(e1, e2)
- }
- | termExpression ~ (NOT ~ IN ~ "(" ~> rep1sep(termExpression, ",")) <~ ")" ^^ {
- case e1 ~ e2 => Not(In(e1, e2))
- }
- | termExpression <~ IS ~ NULL ^^ { case e => IsNull(e) }
- | termExpression <~ IS ~ NOT ~ NULL ^^ { case e => IsNotNull(e) }
- | termExpression
- )
-
- protected lazy val termExpression: Parser[Expression] =
- productExpression *
- ( "+" ^^^ { (e1: Expression, e2: Expression) => Add(e1, e2) }
- | "-" ^^^ { (e1: Expression, e2: Expression) => Subtract(e1, e2) }
- )
-
- protected lazy val productExpression: Parser[Expression] =
- baseExpression *
- ( "*" ^^^ { (e1: Expression, e2: Expression) => Multiply(e1, e2) }
- | "/" ^^^ { (e1: Expression, e2: Expression) => Divide(e1, e2) }
- | "%" ^^^ { (e1: Expression, e2: Expression) => Remainder(e1, e2) }
- | "&" ^^^ { (e1: Expression, e2: Expression) => BitwiseAnd(e1, e2) }
- | "|" ^^^ { (e1: Expression, e2: Expression) => BitwiseOr(e1, e2) }
- | "^" ^^^ { (e1: Expression, e2: Expression) => BitwiseXor(e1, e2) }
- )
-
- protected lazy val function: Parser[Expression] =
- ( ident <~ ("(" ~ "*" ~ ")") ^^ { case udfName =>
- if (lexical.normalizeKeyword(udfName) == "count") {
- AggregateExpression(Count(Literal(1)), mode = Complete, isDistinct = false)
- } else {
- throw new AnalysisException(s"invalid expression $udfName(*)")
- }
- }
- | ident ~ ("(" ~> repsep(expression, ",")) <~ ")" ^^
- { case udfName ~ exprs => UnresolvedFunction(udfName, exprs, isDistinct = false) }
- | ident ~ ("(" ~ DISTINCT ~> repsep(expression, ",")) <~ ")" ^^ { case udfName ~ exprs =>
- lexical.normalizeKeyword(udfName) match {
- case "count" =>
- aggregate.Count(exprs).toAggregateExpression(isDistinct = true)
- case _ => UnresolvedFunction(udfName, exprs, isDistinct = true)
- }
- }
- | APPROXIMATE ~> ident ~ ("(" ~ DISTINCT ~> expression <~ ")") ^^ { case udfName ~ exp =>
- if (lexical.normalizeKeyword(udfName) == "count") {
- AggregateExpression(new HyperLogLogPlusPlus(exp), mode = Complete, isDistinct = false)
- } else {
- throw new AnalysisException(s"invalid function approximate $udfName")
- }
- }
- | APPROXIMATE ~> "(" ~> unsignedFloat ~ ")" ~ ident ~ "(" ~ DISTINCT ~ expression <~ ")" ^^
- { case s ~ _ ~ udfName ~ _ ~ _ ~ exp =>
- if (lexical.normalizeKeyword(udfName) == "count") {
- AggregateExpression(
- HyperLogLogPlusPlus(exp, s.toDouble, 0, 0),
- mode = Complete,
- isDistinct = false)
- } else {
- throw new AnalysisException(s"invalid function approximate($s) $udfName")
- }
- }
- | CASE ~> whenThenElse ^^
- { case branches => CaseWhen.createFromParser(branches) }
- | CASE ~> expression ~ whenThenElse ^^
- { case keyPart ~ branches => CaseKeyWhen(keyPart, branches) }
- )
-
- protected lazy val whenThenElse: Parser[List[Expression]] =
- rep1(WHEN ~> expression ~ (THEN ~> expression)) ~ (ELSE ~> expression).? <~ END ^^ {
- case altPart ~ elsePart =>
- altPart.flatMap { case whenExpr ~ thenExpr =>
- Seq(whenExpr, thenExpr)
- } ++ elsePart
- }
-
- protected lazy val cast: Parser[Expression] =
- CAST ~ "(" ~> expression ~ (AS ~> dataType) <~ ")" ^^ {
- case exp ~ t => Cast(exp, t)
- }
-
- protected lazy val literal: Parser[Literal] =
- ( numericLiteral
- | booleanLiteral
- | stringLit ^^ { case s => Literal.create(s, StringType) }
- | intervalLiteral
- | NULL ^^^ Literal.create(null, NullType)
- )
-
- protected lazy val booleanLiteral: Parser[Literal] =
- ( TRUE ^^^ Literal.create(true, BooleanType)
- | FALSE ^^^ Literal.create(false, BooleanType)
- )
-
- protected lazy val numericLiteral: Parser[Literal] =
- ( integral ^^ { case i => Literal(toNarrowestIntegerType(i)) }
- | sign.? ~ unsignedFloat ^^
- { case s ~ f => Literal(toDecimalOrDouble(s.getOrElse("") + f)) }
- )
-
- protected lazy val unsignedFloat: Parser[String] =
- ( "." ~> numericLit ^^ { u => "0." + u }
- | elem("decimal", _.isInstanceOf[lexical.DecimalLit]) ^^ (_.chars)
- )
-
- protected lazy val sign: Parser[String] = ("+" | "-")
-
- protected lazy val integral: Parser[String] =
- sign.? ~ numericLit ^^ { case s ~ n => s.getOrElse("") + n }
-
- private def intervalUnit(unitName: String) = acceptIf {
- case lexical.Identifier(str) =>
- val normalized = lexical.normalizeKeyword(str)
- normalized == unitName || normalized == unitName + "s"
- case _ => false
- } {_ => "wrong interval unit"}
-
- protected lazy val month: Parser[Int] =
- integral <~ intervalUnit("month") ^^ { case num => num.toInt }
-
- protected lazy val year: Parser[Int] =
- integral <~ intervalUnit("year") ^^ { case num => num.toInt * 12 }
-
- protected lazy val microsecond: Parser[Long] =
- integral <~ intervalUnit("microsecond") ^^ { case num => num.toLong }
-
- protected lazy val millisecond: Parser[Long] =
- integral <~ intervalUnit("millisecond") ^^ {
- case num => num.toLong * CalendarInterval.MICROS_PER_MILLI
- }
-
- protected lazy val second: Parser[Long] =
- integral <~ intervalUnit("second") ^^ {
- case num => num.toLong * CalendarInterval.MICROS_PER_SECOND
- }
-
- protected lazy val minute: Parser[Long] =
- integral <~ intervalUnit("minute") ^^ {
- case num => num.toLong * CalendarInterval.MICROS_PER_MINUTE
- }
-
- protected lazy val hour: Parser[Long] =
- integral <~ intervalUnit("hour") ^^ {
- case num => num.toLong * CalendarInterval.MICROS_PER_HOUR
- }
-
- protected lazy val day: Parser[Long] =
- integral <~ intervalUnit("day") ^^ {
- case num => num.toLong * CalendarInterval.MICROS_PER_DAY
- }
-
- protected lazy val week: Parser[Long] =
- integral <~ intervalUnit("week") ^^ {
- case num => num.toLong * CalendarInterval.MICROS_PER_WEEK
- }
-
- private def intervalKeyword(keyword: String) = acceptIf {
- case lexical.Identifier(str) =>
- lexical.normalizeKeyword(str) == keyword
- case _ => false
- } {_ => "wrong interval keyword"}
-
- protected lazy val intervalLiteral: Parser[Literal] =
- ( INTERVAL ~> stringLit <~ intervalKeyword("year") ~ intervalKeyword("to") ~
- intervalKeyword("month") ^^ { case s =>
- Literal(CalendarInterval.fromYearMonthString(s))
- }
- | INTERVAL ~> stringLit <~ intervalKeyword("day") ~ intervalKeyword("to") ~
- intervalKeyword("second") ^^ { case s =>
- Literal(CalendarInterval.fromDayTimeString(s))
- }
- | INTERVAL ~> stringLit <~ intervalKeyword("year") ^^ { case s =>
- Literal(CalendarInterval.fromSingleUnitString("year", s))
- }
- | INTERVAL ~> stringLit <~ intervalKeyword("month") ^^ { case s =>
- Literal(CalendarInterval.fromSingleUnitString("month", s))
- }
- | INTERVAL ~> stringLit <~ intervalKeyword("day") ^^ { case s =>
- Literal(CalendarInterval.fromSingleUnitString("day", s))
- }
- | INTERVAL ~> stringLit <~ intervalKeyword("hour") ^^ { case s =>
- Literal(CalendarInterval.fromSingleUnitString("hour", s))
- }
- | INTERVAL ~> stringLit <~ intervalKeyword("minute") ^^ { case s =>
- Literal(CalendarInterval.fromSingleUnitString("minute", s))
- }
- | INTERVAL ~> stringLit <~ intervalKeyword("second") ^^ { case s =>
- Literal(CalendarInterval.fromSingleUnitString("second", s))
- }
- | INTERVAL ~> year.? ~ month.? ~ week.? ~ day.? ~ hour.? ~ minute.? ~ second.? ~
- millisecond.? ~ microsecond.? ^^ { case year ~ month ~ week ~ day ~ hour ~ minute ~ second ~
- millisecond ~ microsecond =>
- if (!Seq(year, month, week, day, hour, minute, second,
- millisecond, microsecond).exists(_.isDefined)) {
- throw new AnalysisException(
- "at least one time unit should be given for interval literal")
- }
- val months = Seq(year, month).map(_.getOrElse(0)).sum
- val microseconds = Seq(week, day, hour, minute, second, millisecond, microsecond)
- .map(_.getOrElse(0L)).sum
- Literal(new CalendarInterval(months, microseconds))
- }
- )
-
- private def toNarrowestIntegerType(value: String): Any = {
- val bigIntValue = BigDecimal(value)
-
- bigIntValue match {
- case v if bigIntValue.isValidInt => v.toIntExact
- case v if bigIntValue.isValidLong => v.toLongExact
- case v => v.underlying()
- }
- }
-
- private def toDecimalOrDouble(value: String): Any = {
- val decimal = BigDecimal(value)
- // follow the behavior in MS SQL Server
- // https://msdn.microsoft.com/en-us/library/ms179899.aspx
- if (value.contains('E') || value.contains('e')) {
- decimal.doubleValue()
- } else {
- decimal.underlying()
- }
- }
-
- protected lazy val baseExpression: Parser[Expression] =
- ( "*" ^^^ UnresolvedStar(None)
- | rep1(ident <~ ".") <~ "*" ^^ { case target => UnresolvedStar(Option(target))}
- | primary
- )
-
- protected lazy val signedPrimary: Parser[Expression] =
- sign ~ primary ^^ { case s ~ e => if (s == "-") UnaryMinus(e) else e }
-
- protected lazy val attributeName: Parser[String] = acceptMatch("attribute name", {
- case lexical.Identifier(str) => str
- case lexical.Keyword(str) if !lexical.delimiters.contains(str) => str
- })
-
- protected lazy val primary: PackratParser[Expression] =
- ( literal
- | expression ~ ("[" ~> expression <~ "]") ^^
- { case base ~ ordinal => UnresolvedExtractValue(base, ordinal) }
- | (expression <~ ".") ~ ident ^^
- { case base ~ fieldName => UnresolvedExtractValue(base, Literal(fieldName)) }
- | cast
- | "(" ~> expression <~ ")"
- | function
- | dotExpressionHeader
- | signedPrimary
- | "~" ~> expression ^^ BitwiseNot
- | attributeName ^^ UnresolvedAttribute.quoted
- )
-
- protected lazy val dotExpressionHeader: Parser[Expression] =
- (ident <~ ".") ~ ident ~ rep("." ~> ident) ^^ {
- case i1 ~ i2 ~ rest => UnresolvedAttribute(Seq(i1, i2) ++ rest)
- }
-
- protected lazy val tableIdentifier: Parser[TableIdentifier] =
- (ident <~ ".").? ~ ident ^^ {
- case maybeDbName ~ tableName => TableIdentifier(tableName, maybeDbName)
- }
-}
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/HyperLogLogPlusPlus.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/HyperLogLogPlusPlus.scala
index e1fd22e367..ec833d6789 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/HyperLogLogPlusPlus.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/HyperLogLogPlusPlus.scala
@@ -447,6 +447,7 @@ object HyperLogLogPlusPlus {
private def validateDoubleLiteral(exp: Expression): Double = exp match {
case Literal(d: Double, DoubleType) => d
+ case Literal(dec: Decimal, _) => dec.toDouble
case _ =>
throw new AnalysisException("The second argument should be a double literal.")
}
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/CatalystQlSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/CatalystQlSuite.scala
index ba9d2524a9..6d25de98ce 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/CatalystQlSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/CatalystQlSuite.scala
@@ -108,6 +108,7 @@ class CatalystQlSuite extends PlanTest {
}
assertRight("9.0e1", 90)
+ assertRight(".9e+2", 90)
assertRight("0.9e+2", 90)
assertRight("900e-1", 90)
assertRight("900.0E-1", 90)
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/SqlParserSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/SqlParserSuite.scala
deleted file mode 100644
index b0884f5287..0000000000
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/SqlParserSuite.scala
+++ /dev/null
@@ -1,150 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one or more
- * contributor license agreements. See the NOTICE file distributed with
- * this work for additional information regarding copyright ownership.
- * The ASF licenses this file to You under the Apache License, Version 2.0
- * (the "License"); you may not use this file except in compliance with
- * the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-package org.apache.spark.sql.catalyst
-
-import org.apache.spark.sql.catalyst.analysis.UnresolvedAlias
-import org.apache.spark.sql.catalyst.expressions.{Attribute, GreaterThan, Literal, Not}
-import org.apache.spark.sql.catalyst.plans.PlanTest
-import org.apache.spark.sql.catalyst.plans.logical.{Command, LogicalPlan, OneRowRelation, Project}
-import org.apache.spark.unsafe.types.CalendarInterval
-
-private[sql] case class TestCommand(cmd: String) extends LogicalPlan with Command {
- override def output: Seq[Attribute] = Seq.empty
- override def children: Seq[LogicalPlan] = Seq.empty
-}
-
-private[sql] class SuperLongKeywordTestParser extends AbstractSparkSQLParser {
- protected val EXECUTE = Keyword("THISISASUPERLONGKEYWORDTEST")
-
- override protected lazy val start: Parser[LogicalPlan] = set
-
- private lazy val set: Parser[LogicalPlan] =
- EXECUTE ~> ident ^^ {
- case fileName => TestCommand(fileName)
- }
-}
-
-private[sql] class CaseInsensitiveTestParser extends AbstractSparkSQLParser {
- protected val EXECUTE = Keyword("EXECUTE")
-
- override protected lazy val start: Parser[LogicalPlan] = set
-
- private lazy val set: Parser[LogicalPlan] =
- EXECUTE ~> ident ^^ {
- case fileName => TestCommand(fileName)
- }
-}
-
-class SqlParserSuite extends PlanTest {
-
- test("test long keyword") {
- val parser = new SuperLongKeywordTestParser
- assert(TestCommand("NotRealCommand") ===
- parser.parse("ThisIsASuperLongKeyWordTest NotRealCommand"))
- }
-
- test("test case insensitive") {
- val parser = new CaseInsensitiveTestParser
- assert(TestCommand("NotRealCommand") === parser.parse("EXECUTE NotRealCommand"))
- assert(TestCommand("NotRealCommand") === parser.parse("execute NotRealCommand"))
- assert(TestCommand("NotRealCommand") === parser.parse("exEcute NotRealCommand"))
- }
-
- test("test NOT operator with comparison operations") {
- val parsed = SqlParser.parse("SELECT NOT TRUE > TRUE")
- val expected = Project(
- UnresolvedAlias(
- Not(
- GreaterThan(Literal(true), Literal(true)))
- ) :: Nil,
- OneRowRelation)
- comparePlans(parsed, expected)
- }
-
- test("support hive interval literal") {
- def checkInterval(sql: String, result: CalendarInterval): Unit = {
- val parsed = SqlParser.parse(sql)
- val expected = Project(
- UnresolvedAlias(
- Literal(result)
- ) :: Nil,
- OneRowRelation)
- comparePlans(parsed, expected)
- }
-
- def checkYearMonth(lit: String): Unit = {
- checkInterval(
- s"SELECT INTERVAL '$lit' YEAR TO MONTH",
- CalendarInterval.fromYearMonthString(lit))
- }
-
- def checkDayTime(lit: String): Unit = {
- checkInterval(
- s"SELECT INTERVAL '$lit' DAY TO SECOND",
- CalendarInterval.fromDayTimeString(lit))
- }
-
- def checkSingleUnit(lit: String, unit: String): Unit = {
- checkInterval(
- s"SELECT INTERVAL '$lit' $unit",
- CalendarInterval.fromSingleUnitString(unit, lit))
- }
-
- checkYearMonth("123-10")
- checkYearMonth("496-0")
- checkYearMonth("-2-3")
- checkYearMonth("-123-0")
-
- checkDayTime("99 11:22:33.123456789")
- checkDayTime("-99 11:22:33.123456789")
- checkDayTime("10 9:8:7.123456789")
- checkDayTime("1 0:0:0")
- checkDayTime("-1 0:0:0")
- checkDayTime("1 0:0:1")
-
- for (unit <- Seq("year", "month", "day", "hour", "minute", "second")) {
- checkSingleUnit("7", unit)
- checkSingleUnit("-7", unit)
- checkSingleUnit("0", unit)
- }
-
- checkSingleUnit("13.123456789", "second")
- checkSingleUnit("-13.123456789", "second")
- }
-
- test("support scientific notation") {
- def assertRight(input: String, output: Double): Unit = {
- val parsed = SqlParser.parse("SELECT " + input)
- val expected = Project(
- UnresolvedAlias(
- Literal(output)
- ) :: Nil,
- OneRowRelation)
- comparePlans(parsed, expected)
- }
-
- assertRight("9.0e1", 90)
- assertRight(".9e+2", 90)
- assertRight("0.9e+2", 90)
- assertRight("900e-1", 90)
- assertRight("900.0E-1", 90)
- assertRight("9.e+1", 90)
-
- intercept[RuntimeException](SqlParser.parse("SELECT .e3"))
- }
-}
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/Column.scala b/sql/core/src/main/scala/org/apache/spark/sql/Column.scala
index 6a020f9f28..97bf7a0cc4 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/Column.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/Column.scala
@@ -21,7 +21,6 @@ import scala.language.implicitConversions
import org.apache.spark.Logging
import org.apache.spark.annotation.Experimental
-import org.apache.spark.sql.catalyst.SqlParser._
import org.apache.spark.sql.catalyst.analysis._
import org.apache.spark.sql.catalyst.encoders.{encoderFor, ExpressionEncoder}
import org.apache.spark.sql.catalyst.expressions._
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala b/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala
index 91bf2f8ce4..3422d0ead4 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala
@@ -30,7 +30,7 @@ import org.apache.spark.annotation.{DeveloperApi, Experimental}
import org.apache.spark.api.java.JavaRDD
import org.apache.spark.api.python.PythonRDD
import org.apache.spark.rdd.RDD
-import org.apache.spark.sql.catalyst.{CatalystTypeConverters, InternalRow, ScalaReflection, SqlParser}
+import org.apache.spark.sql.catalyst._
import org.apache.spark.sql.catalyst.analysis._
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.expressions.aggregate._
@@ -737,7 +737,7 @@ class DataFrame private[sql](
@scala.annotation.varargs
def selectExpr(exprs: String*): DataFrame = {
select(exprs.map { expr =>
- Column(SqlParser.parseExpression(expr))
+ Column(sqlContext.sqlParser.parseExpression(expr))
}: _*)
}
@@ -764,7 +764,7 @@ class DataFrame private[sql](
* @since 1.3.0
*/
def filter(conditionExpr: String): DataFrame = {
- filter(Column(SqlParser.parseExpression(conditionExpr)))
+ filter(Column(sqlContext.sqlParser.parseExpression(conditionExpr)))
}
/**
@@ -788,7 +788,7 @@ class DataFrame private[sql](
* @since 1.5.0
*/
def where(conditionExpr: String): DataFrame = {
- filter(Column(SqlParser.parseExpression(conditionExpr)))
+ filter(Column(sqlContext.sqlParser.parseExpression(conditionExpr)))
}
/**
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala
index d948e48942..8f852e5216 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala
@@ -29,7 +29,7 @@ import org.apache.spark.annotation.Experimental
import org.apache.spark.api.java.JavaRDD
import org.apache.spark.deploy.SparkHadoopUtil
import org.apache.spark.rdd.RDD
-import org.apache.spark.sql.catalyst.SqlParser
+import org.apache.spark.sql.catalyst.{CatalystQl}
import org.apache.spark.sql.execution.datasources.{LogicalRelation, ResolvedDataSource}
import org.apache.spark.sql.execution.datasources.jdbc.{JDBCPartition, JDBCPartitioningInfo, JDBCRelation}
import org.apache.spark.sql.execution.datasources.json.JSONRelation
@@ -337,7 +337,7 @@ class DataFrameReader private[sql](sqlContext: SQLContext) extends Logging {
*/
def table(tableName: String): DataFrame = {
DataFrame(sqlContext,
- sqlContext.catalog.lookupRelation(SqlParser.parseTableIdentifier(tableName)))
+ sqlContext.catalog.lookupRelation(sqlContext.sqlParser.parseTableIdentifier(tableName)))
}
/**
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameWriter.scala b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameWriter.scala
index 00f9817b53..ab63fe4aa8 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameWriter.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameWriter.scala
@@ -22,7 +22,7 @@ import java.util.Properties
import scala.collection.JavaConverters._
import org.apache.spark.annotation.Experimental
-import org.apache.spark.sql.catalyst.{SqlParser, TableIdentifier}
+import org.apache.spark.sql.catalyst.{CatalystQl, TableIdentifier}
import org.apache.spark.sql.catalyst.analysis.UnresolvedRelation
import org.apache.spark.sql.catalyst.plans.logical.{InsertIntoTable, Project}
import org.apache.spark.sql.execution.datasources.{BucketSpec, CreateTableUsingAsSelect, ResolvedDataSource}
@@ -192,7 +192,7 @@ final class DataFrameWriter private[sql](df: DataFrame) {
* @since 1.4.0
*/
def insertInto(tableName: String): Unit = {
- insertInto(SqlParser.parseTableIdentifier(tableName))
+ insertInto(df.sqlContext.sqlParser.parseTableIdentifier(tableName))
}
private def insertInto(tableIdent: TableIdentifier): Unit = {
@@ -282,7 +282,7 @@ final class DataFrameWriter private[sql](df: DataFrame) {
* @since 1.4.0
*/
def saveAsTable(tableName: String): Unit = {
- saveAsTable(SqlParser.parseTableIdentifier(tableName))
+ saveAsTable(df.sqlContext.sqlParser.parseTableIdentifier(tableName))
}
private def saveAsTable(tableIdent: TableIdentifier): Unit = {
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala b/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala
index b909765a7c..a0939adb6d 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala
@@ -39,6 +39,7 @@ import org.apache.spark.sql.catalyst.encoders.encoderFor
import org.apache.spark.sql.catalyst.errors.DialectException
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.optimizer.Optimizer
+import org.apache.spark.sql.catalyst.parser.ParserConf
import org.apache.spark.sql.catalyst.plans.logical.{LocalRelation, LogicalPlan, Range}
import org.apache.spark.sql.catalyst.rules.RuleExecutor
import org.apache.spark.sql.execution._
@@ -205,15 +206,17 @@ class SQLContext private[sql](
protected[sql] lazy val optimizer: Optimizer = new SparkOptimizer(this)
@transient
- protected[sql] val ddlParser = new DDLParser(sqlParser.parse(_))
+ protected[sql] val ddlParser = new DDLParser(sqlParser)
@transient
- protected[sql] val sqlParser = new SparkSQLParser(getSQLDialect().parse(_))
+ protected[sql] val sqlParser = new SparkSQLParser(getSQLDialect())
protected[sql] def getSQLDialect(): ParserDialect = {
try {
val clazz = Utils.classForName(dialectClassName)
- clazz.newInstance().asInstanceOf[ParserDialect]
+ clazz.getConstructor(classOf[ParserConf])
+ .newInstance(conf)
+ .asInstanceOf[ParserDialect]
} catch {
case NonFatal(e) =>
// Since we didn't find the available SQL Dialect, it will fail even for SET command:
@@ -237,7 +240,7 @@ class SQLContext private[sql](
new sparkexecution.QueryExecution(this, plan)
protected[sql] def dialectClassName = if (conf.dialect == "sql") {
- classOf[DefaultParserDialect].getCanonicalName
+ classOf[SparkQl].getCanonicalName
} else {
conf.dialect
}
@@ -682,7 +685,7 @@ class SQLContext private[sql](
tableName: String,
source: String,
options: Map[String, String]): DataFrame = {
- val tableIdent = SqlParser.parseTableIdentifier(tableName)
+ val tableIdent = sqlParser.parseTableIdentifier(tableName)
val cmd =
CreateTableUsing(
tableIdent,
@@ -728,7 +731,7 @@ class SQLContext private[sql](
source: String,
schema: StructType,
options: Map[String, String]): DataFrame = {
- val tableIdent = SqlParser.parseTableIdentifier(tableName)
+ val tableIdent = sqlParser.parseTableIdentifier(tableName)
val cmd =
CreateTableUsing(
tableIdent,
@@ -833,7 +836,7 @@ class SQLContext private[sql](
* @since 1.3.0
*/
def table(tableName: String): DataFrame = {
- table(SqlParser.parseTableIdentifier(tableName))
+ table(sqlParser.parseTableIdentifier(tableName))
}
private def table(tableIdent: TableIdentifier): DataFrame = {
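The reflective construction above changes the contract for custom values of spark.sql.dialect: the class must now expose a constructor taking a ParserConf, because getSQLDialect() calls getConstructor(classOf[ParserConf]).newInstance(conf). A minimal sketch, mirroring the MyDialect test class later in this patch; MyCompanyDialect is a hypothetical name:

    import org.apache.spark.sql.catalyst.CatalystQl
    import org.apache.spark.sql.catalyst.parser.ParserConf

    // Any dialect plugged in via spark.sql.dialect now needs a ParserConf constructor.
    class MyCompanyDialect(conf: ParserConf) extends CatalystQl(conf)

    // sqlContext.setConf("spark.sql.dialect", classOf[MyCompanyDialect].getCanonicalName)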
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkSQLParser.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkSQLParser.scala
index b3e8d0d849..1af2c756cd 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkSQLParser.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkSQLParser.scala
@@ -19,8 +19,8 @@ package org.apache.spark.sql.execution
import scala.util.parsing.combinator.RegexParsers
-import org.apache.spark.sql.catalyst.AbstractSparkSQLParser
-import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeReference}
+import org.apache.spark.sql.catalyst.{AbstractSparkSQLParser, ParserDialect, TableIdentifier}
+import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeReference, Expression}
import org.apache.spark.sql.catalyst.plans.logical
import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
import org.apache.spark.sql.types.StringType
@@ -29,9 +29,16 @@ import org.apache.spark.sql.types.StringType
* The top level Spark SQL parser. This parser recognizes syntaxes that are available for all SQL
* dialects supported by Spark SQL, and delegates all the other syntaxes to the `fallback` parser.
*
- * @param fallback A function that parses an input string to a logical plan
+ * @param fallback A function that returns the next parser in the chain. This is a call-by-name
+ * parameter so that a different dialect can be returned if the configuration
+ * changes.
*/
-class SparkSQLParser(fallback: String => LogicalPlan) extends AbstractSparkSQLParser {
+class SparkSQLParser(fallback: => ParserDialect) extends AbstractSparkSQLParser {
+
+ override def parseExpression(sql: String): Expression = fallback.parseExpression(sql)
+
+ override def parseTableIdentifier(sql: String): TableIdentifier =
+ fallback.parseTableIdentifier(sql)
// A parser for the key-value part of the "SET [key = [value ]]" syntax
private object SetCommandParser extends RegexParsers {
@@ -74,7 +81,7 @@ class SparkSQLParser(fallback: String => LogicalPlan) extends AbstractSparkSQLPa
private lazy val cache: Parser[LogicalPlan] =
CACHE ~> LAZY.? ~ (TABLE ~> ident) ~ (AS ~> restInput).? ^^ {
case isLazy ~ tableName ~ plan =>
- CacheTableCommand(tableName, plan.map(fallback), isLazy.isDefined)
+ CacheTableCommand(tableName, plan.map(fallback.parsePlan), isLazy.isDefined)
}
private lazy val uncache: Parser[LogicalPlan] =
@@ -111,7 +118,7 @@ class SparkSQLParser(fallback: String => LogicalPlan) extends AbstractSparkSQLPa
private lazy val others: Parser[LogicalPlan] =
wholeInput ^^ {
- case input => fallback(input)
+ case input => fallback.parsePlan(input)
}
}
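The call-by-name fallback is what lets a dialect switch take effect without rebuilding SparkSQLParser. A toy illustration of the mechanism, using made-up names rather than the real Spark classes:

    // Because the parameter is by-name, the dialect expression is re-evaluated on
    // every parse, so a later SET spark.sql.dialect change is picked up automatically.
    object CallByNameDemo extends App {
      trait Dialect { def parsePlan(sql: String): String }

      class TopLevelParser(fallback: => Dialect) {
        def parsePlan(sql: String): String = fallback.parsePlan(sql)
      }

      var current: Dialect = new Dialect { def parsePlan(sql: String) = "A: " + sql }
      val top = new TopLevelParser(current)
      println(top.parsePlan("SELECT 1"))   // A: SELECT 1
      current = new Dialect { def parsePlan(sql: String) = "B: " + sql }
      println(top.parsePlan("SELECT 1"))   // B: SELECT 1, no new TopLevelParser needed
    }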
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DDLParser.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DDLParser.scala
index d8d21b06b8..10655a85cc 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DDLParser.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DDLParser.scala
@@ -22,25 +22,30 @@ import scala.util.matching.Regex
import org.apache.spark.Logging
import org.apache.spark.sql.SaveMode
-import org.apache.spark.sql.catalyst.{AbstractSparkSQLParser, TableIdentifier}
+import org.apache.spark.sql.catalyst.{AbstractSparkSQLParser, ParserDialect, TableIdentifier}
import org.apache.spark.sql.catalyst.analysis.UnresolvedRelation
+import org.apache.spark.sql.catalyst.expressions.Expression
import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
import org.apache.spark.sql.catalyst.util.DataTypeParser
import org.apache.spark.sql.types._
-
/**
* A parser for foreign DDL commands.
*/
-class DDLParser(parseQuery: String => LogicalPlan)
+class DDLParser(fallback: => ParserDialect)
extends AbstractSparkSQLParser with DataTypeParser with Logging {
+ override def parseExpression(sql: String): Expression = fallback.parseExpression(sql)
+
+ override def parseTableIdentifier(sql: String): TableIdentifier =
+ fallback.parseTableIdentifier(sql)
def parse(input: String, exceptionOnError: Boolean): LogicalPlan = {
try {
- parse(input)
+ parsePlan(input)
} catch {
case ddlException: DDLException => throw ddlException
- case _ if !exceptionOnError => parseQuery(input)
+ case _ if !exceptionOnError => fallback.parsePlan(input)
case x: Throwable => throw x
}
}
@@ -104,7 +109,7 @@ class DDLParser(parseQuery: String => LogicalPlan)
SaveMode.ErrorIfExists
}
- val queryPlan = parseQuery(query.get)
+ val queryPlan = fallback.parsePlan(query.get)
CreateTableUsingAsSelect(tableIdent,
provider,
temp.isDefined,
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala
index b8ea2261e9..8c2530fd68 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala
@@ -22,7 +22,7 @@ import scala.reflect.runtime.universe.{typeTag, TypeTag}
import scala.util.Try
import org.apache.spark.annotation.Experimental
-import org.apache.spark.sql.catalyst.{ScalaReflection, SqlParser}
+import org.apache.spark.sql.catalyst.{CatalystQl, ScalaReflection}
import org.apache.spark.sql.catalyst.analysis.{Star, UnresolvedFunction}
import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder
import org.apache.spark.sql.catalyst.expressions._
@@ -1063,7 +1063,10 @@ object functions extends LegacyFunctions {
*
* @group normal_funcs
*/
- def expr(expr: String): Column = Column(SqlParser.parseExpression(expr))
+ def expr(expr: String): Column = {
+ val parser = SQLContext.getActive().map(_.getSQLDialect()).getOrElse(new CatalystQl())
+ Column(parser.parseExpression(expr))
+ }
//////////////////////////////////////////////////////////////////////////////////////////////
// Math Functions
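With this change, functions.expr consults the active SQLContext's dialect when one exists and only falls back to a plain CatalystQl otherwise. A usage sketch, assuming a hypothetical DataFrame `employees` with `salary` and `name` columns:

    import org.apache.spark.sql.functions.expr

    // Both expression strings are parsed by the active context's dialect.
    val projected = employees.select(expr("salary * 0.1"), expr("upper(name)"))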
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/MathExpressionsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/MathExpressionsSuite.scala
index 58f982c2bc..aec450e0a6 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/MathExpressionsSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/MathExpressionsSuite.scala
@@ -212,7 +212,7 @@ class MathExpressionsSuite extends QueryTest with SharedSQLContext {
Seq(Row(5, 10, 0), Row(55, 60, 100), Row(555, 560, 600))
)
- val pi = 3.1415
+ val pi = "3.1415BD"
checkAnswer(
sql(s"SELECT round($pi, -3), round($pi, -2), round($pi, -1), " +
s"round($pi, 0), round($pi, 1), round($pi, 2), round($pi, 3)"),
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala
index 03d67c4e91..75e81b9c91 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala
@@ -21,10 +21,11 @@ import java.math.MathContext
import java.sql.Timestamp
import org.apache.spark.AccumulatorSuite
-import org.apache.spark.sql.catalyst.DefaultParserDialect
+import org.apache.spark.sql.catalyst.CatalystQl
import org.apache.spark.sql.catalyst.analysis.FunctionRegistry
import org.apache.spark.sql.catalyst.errors.DialectException
-import org.apache.spark.sql.execution.aggregate
+import org.apache.spark.sql.catalyst.parser.ParserConf
+import org.apache.spark.sql.execution.{aggregate, SparkQl}
import org.apache.spark.sql.execution.joins.{CartesianProduct, SortMergeJoin}
import org.apache.spark.sql.functions._
import org.apache.spark.sql.test.{SharedSQLContext, TestSQLContext}
@@ -32,7 +33,7 @@ import org.apache.spark.sql.test.SQLTestData._
import org.apache.spark.sql.types._
/** A SQL Dialect for testing purposes; it cannot be a nested type. */
-class MyDialect extends DefaultParserDialect
+class MyDialect(conf: ParserConf) extends CatalystQl(conf)
class SQLQuerySuite extends QueryTest with SharedSQLContext {
import testImplicits._
@@ -161,7 +162,7 @@ class SQLQuerySuite extends QueryTest with SharedSQLContext {
newContext.sql("SELECT 1")
}
// test if the dialect set back to DefaultSQLDialect
- assert(newContext.getSQLDialect().getClass === classOf[DefaultParserDialect])
+ assert(newContext.getSQLDialect().getClass === classOf[SparkQl])
}
test("SPARK-4625 support SORT BY in SimpleSQLParser & DSL") {
@@ -586,7 +587,7 @@ class SQLQuerySuite extends QueryTest with SharedSQLContext {
}
test("Allow only a single WITH clause per query") {
- intercept[RuntimeException] {
+ intercept[AnalysisException] {
sql(
"with q1 as (select * from testData) with q2 as (select * from q1) select * from q2")
}
@@ -602,8 +603,8 @@ class SQLQuerySuite extends QueryTest with SharedSQLContext {
test("from follow multiple brackets") {
checkAnswer(sql(
"""
- |select key from ((select * from testData limit 1)
- | union all (select * from testData limit 1)) x limit 1
+ |select key from ((select * from testData)
+ | union all (select * from testData)) x limit 1
""".stripMargin),
Row(1)
)
@@ -616,7 +617,7 @@ class SQLQuerySuite extends QueryTest with SharedSQLContext {
checkAnswer(sql(
"""
|select key from
- | (select * from testData limit 1 union all select * from testData limit 1) x
+ | (select * from testData union all select * from testData) x
| limit 1
""".stripMargin),
Row(1)
@@ -649,13 +650,13 @@ class SQLQuerySuite extends QueryTest with SharedSQLContext {
test("approximate count distinct") {
checkAnswer(
- sql("SELECT APPROXIMATE COUNT(DISTINCT a) FROM testData2"),
+ sql("SELECT APPROX_COUNT_DISTINCT(a) FROM testData2"),
Row(3))
}
test("approximate count distinct with user provided standard deviation") {
checkAnswer(
- sql("SELECT APPROXIMATE(0.04) COUNT(DISTINCT a) FROM testData2"),
+ sql("SELECT APPROX_COUNT_DISTINCT(a, 0.04) FROM testData2"),
Row(3))
}
@@ -1192,19 +1193,19 @@ class SQLQuerySuite extends QueryTest with SharedSQLContext {
test("Floating point number format") {
checkAnswer(
- sql("SELECT 0.3"), Row(BigDecimal(0.3).underlying())
+ sql("SELECT 0.3"), Row(0.3)
)
checkAnswer(
- sql("SELECT -0.8"), Row(BigDecimal(-0.8).underlying())
+ sql("SELECT -0.8"), Row(-0.8)
)
checkAnswer(
- sql("SELECT .5"), Row(BigDecimal(0.5))
+ sql("SELECT .5"), Row(0.5)
)
checkAnswer(
- sql("SELECT -.18"), Row(BigDecimal(-0.18))
+ sql("SELECT -.18"), Row(-0.18)
)
}
@@ -1218,11 +1219,11 @@ class SQLQuerySuite extends QueryTest with SharedSQLContext {
)
checkAnswer(
- sql("SELECT 9223372036854775808"), Row(new java.math.BigDecimal("9223372036854775808"))
+ sql("SELECT 9223372036854775808BD"), Row(new java.math.BigDecimal("9223372036854775808"))
)
checkAnswer(
- sql("SELECT -9223372036854775809"), Row(new java.math.BigDecimal("-9223372036854775809"))
+ sql("SELECT -9223372036854775809BD"), Row(new java.math.BigDecimal("-9223372036854775809"))
)
}
@@ -1237,11 +1238,11 @@ class SQLQuerySuite extends QueryTest with SharedSQLContext {
)
checkAnswer(
- sql("SELECT -5.2"), Row(BigDecimal(-5.2))
+ sql("SELECT -5.2BD"), Row(BigDecimal(-5.2))
)
checkAnswer(
- sql("SELECT +6.8"), Row(BigDecimal(6.8))
+ sql("SELECT +6.8"), Row(6.8d)
)
checkAnswer(
@@ -1616,20 +1617,20 @@ class SQLQuerySuite extends QueryTest with SharedSQLContext {
}
test("decimal precision with multiply/division") {
- checkAnswer(sql("select 10.3 * 3.0"), Row(BigDecimal("30.90")))
- checkAnswer(sql("select 10.3000 * 3.0"), Row(BigDecimal("30.90000")))
- checkAnswer(sql("select 10.30000 * 30.0"), Row(BigDecimal("309.000000")))
- checkAnswer(sql("select 10.300000000000000000 * 3.000000000000000000"),
+ checkAnswer(sql("select 10.3BD * 3.0BD"), Row(BigDecimal("30.90")))
+ checkAnswer(sql("select 10.3000BD * 3.0BD"), Row(BigDecimal("30.90000")))
+ checkAnswer(sql("select 10.30000BD * 30.0BD"), Row(BigDecimal("309.000000")))
+ checkAnswer(sql("select 10.300000000000000000BD * 3.000000000000000000BD"),
Row(BigDecimal("30.900000000000000000000000000000000000", new MathContext(38))))
- checkAnswer(sql("select 10.300000000000000000 * 3.0000000000000000000"),
+ checkAnswer(sql("select 10.300000000000000000BD * 3.0000000000000000000BD"),
Row(null))
- checkAnswer(sql("select 10.3 / 3.0"), Row(BigDecimal("3.433333")))
- checkAnswer(sql("select 10.3000 / 3.0"), Row(BigDecimal("3.4333333")))
- checkAnswer(sql("select 10.30000 / 30.0"), Row(BigDecimal("0.343333333")))
- checkAnswer(sql("select 10.300000000000000000 / 3.00000000000000000"),
+ checkAnswer(sql("select 10.3BD / 3.0BD"), Row(BigDecimal("3.433333")))
+ checkAnswer(sql("select 10.3000BD / 3.0BD"), Row(BigDecimal("3.4333333")))
+ checkAnswer(sql("select 10.30000BD / 30.0BD"), Row(BigDecimal("0.343333333")))
+ checkAnswer(sql("select 10.300000000000000000BD / 3.00000000000000000BD"),
Row(BigDecimal("3.433333333333333333333333333", new MathContext(38))))
- checkAnswer(sql("select 10.3000000000000000000 / 3.00000000000000000"),
+ checkAnswer(sql("select 10.3000000000000000000BD / 3.00000000000000000BD"),
Row(BigDecimal("3.4333333333333333333333333333", new MathContext(38))))
}
@@ -1655,13 +1656,13 @@ class SQLQuerySuite extends QueryTest with SharedSQLContext {
}
test("precision smaller than scale") {
- checkAnswer(sql("select 10.00"), Row(BigDecimal("10.00")))
- checkAnswer(sql("select 1.00"), Row(BigDecimal("1.00")))
- checkAnswer(sql("select 0.10"), Row(BigDecimal("0.10")))
- checkAnswer(sql("select 0.01"), Row(BigDecimal("0.01")))
- checkAnswer(sql("select 0.001"), Row(BigDecimal("0.001")))
- checkAnswer(sql("select -0.01"), Row(BigDecimal("-0.01")))
- checkAnswer(sql("select -0.001"), Row(BigDecimal("-0.001")))
+ checkAnswer(sql("select 10.00BD"), Row(BigDecimal("10.00")))
+ checkAnswer(sql("select 1.00BD"), Row(BigDecimal("1.00")))
+ checkAnswer(sql("select 0.10BD"), Row(BigDecimal("0.10")))
+ checkAnswer(sql("select 0.01BD"), Row(BigDecimal("0.01")))
+ checkAnswer(sql("select 0.001BD"), Row(BigDecimal("0.001")))
+ checkAnswer(sql("select -0.01BD"), Row(BigDecimal("-0.01")))
+ checkAnswer(sql("select -0.001BD"), Row(BigDecimal("-0.001")))
}
test("external sorting updates peak execution memory") {
@@ -1750,7 +1751,7 @@ class SQLQuerySuite extends QueryTest with SharedSQLContext {
assert(e1.message.contains("Table not found"))
val e2 = intercept[AnalysisException] {
- sql("select * from no_db.no_table")
+ sql("select * from no_db.no_table").show()
}
assert(e2.message.contains("Table not found"))
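Several of the test rewrites above reflect a change in literal typing: an unsuffixed decimal such as 0.3 now parses as a Double, while a BD suffix is required to get a BigDecimal (DecimalType) literal. A short sketch, assuming a sqlContext in scope:

    sqlContext.sql("SELECT 0.3").collect()            // Row(0.3): a Double literal
    sqlContext.sql("SELECT 10.3BD * 3.0BD").collect() // Row(30.90): decimal arithmetic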
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/json/JsonSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/json/JsonSuite.scala
index 860e07c68c..e70eb2a060 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/json/JsonSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/json/JsonSuite.scala
@@ -442,13 +442,13 @@ class JsonSuite extends QueryTest with SharedSQLContext with TestJsonData {
// Number and String conflict: resolve the type as number in this query.
checkAnswer(
- sql("select num_str + 1.2 from jsonTable where num_str > 14"),
+ sql("select num_str + 1.2BD from jsonTable where num_str > 14"),
Row(BigDecimal("92233720368547758071.2"))
)
// Number and String conflict: resolve the type as number in this query.
checkAnswer(
- sql("select num_str + 1.2 from jsonTable where num_str >= 92233720368547758060"),
+ sql("select num_str + 1.2BD from jsonTable where num_str >= 92233720368547758060BD"),
Row(new java.math.BigDecimal("92233720368547758071.2"))
)
@@ -856,7 +856,7 @@ class JsonSuite extends QueryTest with SharedSQLContext with TestJsonData {
jsonWithSimpleMap.registerTempTable("jsonWithSimpleMap")
checkAnswer(
- sql("select map from jsonWithSimpleMap"),
+ sql("select `map` from jsonWithSimpleMap"),
Row(Map("a" -> 1)) ::
Row(Map("b" -> 2)) ::
Row(Map("c" -> 3)) ::
@@ -865,7 +865,7 @@ class JsonSuite extends QueryTest with SharedSQLContext with TestJsonData {
)
checkAnswer(
- sql("select map['c'] from jsonWithSimpleMap"),
+ sql("select `map`['c'] from jsonWithSimpleMap"),
Row(null) ::
Row(null) ::
Row(3) ::
@@ -884,7 +884,7 @@ class JsonSuite extends QueryTest with SharedSQLContext with TestJsonData {
jsonWithComplexMap.registerTempTable("jsonWithComplexMap")
checkAnswer(
- sql("select map from jsonWithComplexMap"),
+ sql("select `map` from jsonWithComplexMap"),
Row(Map("a" -> Row(Seq(1, 2, 3, null), null))) ::
Row(Map("b" -> Row(null, 2))) ::
Row(Map("c" -> Row(Seq(), 4))) ::
@@ -894,7 +894,7 @@ class JsonSuite extends QueryTest with SharedSQLContext with TestJsonData {
)
checkAnswer(
- sql("select map['a'].field1, map['c'].field2 from jsonWithComplexMap"),
+ sql("select `map`['a'].field1, `map`['c'].field2 from jsonWithComplexMap"),
Row(Seq(1, 2, 3, null), null) ::
Row(null, null) ::
Row(null, 4) ::
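The JsonSuite updates show a second consequence of the new parser: map is no longer accepted as a bare identifier and has to be backtick-quoted, and the decimal comparisons pick up the BD suffix described above. A sketch, assuming the jsonWithSimpleMap temp table registered as in the test:

    // Identifiers the new grammar treats as keywords, such as map, need backticks.
    sqlContext.sql("select `map`['c'] from jsonWithSimpleMap").collect()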
diff --git a/sql/hive/compatibility/src/test/scala/org/apache/spark/sql/hive/execution/HiveCompatibilitySuite.scala b/sql/hive/compatibility/src/test/scala/org/apache/spark/sql/hive/execution/HiveCompatibilitySuite.scala
index afd2f61158..828ec97105 100644
--- a/sql/hive/compatibility/src/test/scala/org/apache/spark/sql/hive/execution/HiveCompatibilitySuite.scala
+++ b/sql/hive/compatibility/src/test/scala/org/apache/spark/sql/hive/execution/HiveCompatibilitySuite.scala
@@ -296,6 +296,9 @@ class HiveCompatibilitySuite extends HiveQueryFileTest with BeforeAndAfter {
// Odd changes to output
"merge4",
+ // Unsupported underscore syntax.
+ "inputddl5",
+
// Thrift is broken...
"inputddl8",
@@ -603,7 +606,6 @@ class HiveCompatibilitySuite extends HiveQueryFileTest with BeforeAndAfter {
"inputddl2",
"inputddl3",
"inputddl4",
- "inputddl5",
"inputddl6",
"inputddl7",
"inputddl8",
diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/ExtendedHiveQlParser.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/ExtendedHiveQlParser.scala
index b22f424981..313ba18f6a 100644
--- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/ExtendedHiveQlParser.scala
+++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/ExtendedHiveQlParser.scala
@@ -19,14 +19,23 @@ package org.apache.spark.sql.hive
import scala.language.implicitConversions
+import org.apache.spark.sql.catalyst.{AbstractSparkSQLParser, TableIdentifier}
+import org.apache.spark.sql.catalyst.expressions.Expression
import org.apache.spark.sql.catalyst.plans.logical._
-import org.apache.spark.sql.catalyst.AbstractSparkSQLParser
import org.apache.spark.sql.hive.execution.{AddFile, AddJar, HiveNativeCommand}
/**
* A parser that recognizes all HiveQL constructs together with Spark SQL specific extensions.
*/
-private[hive] class ExtendedHiveQlParser extends AbstractSparkSQLParser {
+private[hive] class ExtendedHiveQlParser(sqlContext: HiveContext) extends AbstractSparkSQLParser {
+
+ val parser = new HiveQl(sqlContext.conf)
+
+ override def parseExpression(sql: String): Expression = parser.parseExpression(sql)
+
+ override def parseTableIdentifier(sql: String): TableIdentifier =
+ parser.parseTableIdentifier(sql)
+
// Keyword is a convention with AbstractSparkSQLParser, which scans all of the `Keyword`
// properties of the class via reflection at runtime to construct the SqlLexical object
protected val ADD = Keyword("ADD")
@@ -38,7 +47,10 @@ private[hive] class ExtendedHiveQlParser extends AbstractSparkSQLParser {
protected lazy val hiveQl: Parser[LogicalPlan] =
restInput ^^ {
- case statement => HiveQl.parsePlan(statement.trim)
+ case statement =>
+ sqlContext.executionHive.withHiveState {
+ parser.parsePlan(statement.trim)
+ }
}
protected lazy val dfs: Parser[LogicalPlan] =
diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveContext.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveContext.scala
index cbaf00603e..7bdca52200 100644
--- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveContext.scala
+++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveContext.scala
@@ -42,7 +42,7 @@ import org.apache.spark.api.java.JavaSparkContext
import org.apache.spark.sql._
import org.apache.spark.sql.SQLConf.SQLConfEntry
import org.apache.spark.sql.SQLConf.SQLConfEntry._
-import org.apache.spark.sql.catalyst.{InternalRow, ParserDialect, SqlParser}
+import org.apache.spark.sql.catalyst.{InternalRow, ParserDialect}
import org.apache.spark.sql.catalyst.analysis._
import org.apache.spark.sql.catalyst.expressions.{Expression, LeafExpression}
import org.apache.spark.sql.catalyst.expressions.codegen.CodegenFallback
@@ -57,17 +57,6 @@ import org.apache.spark.unsafe.types.UTF8String
import org.apache.spark.util.Utils
/**
- * This is the HiveQL Dialect, this dialect is strongly bind with HiveContext
- */
-private[hive] class HiveQLDialect(sqlContext: HiveContext) extends ParserDialect {
- override def parse(sqlText: String): LogicalPlan = {
- sqlContext.executionHive.withHiveState {
- HiveQl.parseSql(sqlText)
- }
- }
-}
-
-/**
* Returns the current database of metadataHive.
*/
private[hive] case class CurrentDatabase(ctx: HiveContext)
@@ -342,12 +331,12 @@ class HiveContext private[hive](
* @since 1.3.0
*/
def refreshTable(tableName: String): Unit = {
- val tableIdent = SqlParser.parseTableIdentifier(tableName)
+ val tableIdent = sqlParser.parseTableIdentifier(tableName)
catalog.refreshTable(tableIdent)
}
protected[hive] def invalidateTable(tableName: String): Unit = {
- val tableIdent = SqlParser.parseTableIdentifier(tableName)
+ val tableIdent = sqlParser.parseTableIdentifier(tableName)
catalog.invalidateTable(tableIdent)
}
@@ -361,7 +350,7 @@ class HiveContext private[hive](
* @since 1.2.0
*/
def analyze(tableName: String) {
- val tableIdent = SqlParser.parseTableIdentifier(tableName)
+ val tableIdent = sqlParser.parseTableIdentifier(tableName)
val relation = EliminateSubQueries(catalog.lookupRelation(tableIdent))
relation match {
@@ -559,7 +548,7 @@ class HiveContext private[hive](
protected[sql] override def getSQLDialect(): ParserDialect = {
if (conf.dialect == "hiveql") {
- new HiveQLDialect(this)
+ new ExtendedHiveQlParser(this)
} else {
super.getSQLDialect()
}
diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveMetastoreCatalog.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveMetastoreCatalog.scala
index daaa5a5709..3d54048c24 100644
--- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveMetastoreCatalog.scala
+++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveMetastoreCatalog.scala
@@ -416,8 +416,8 @@ private[hive] class HiveMetastoreCatalog(val client: ClientInterface, hive: Hive
alias match {
// because hive use things like `_c0` to build the expanded text
// currently we cannot support view from "create view v1(c1) as ..."
- case None => Subquery(table.name, HiveQl.parsePlan(viewText))
- case Some(aliasText) => Subquery(aliasText, HiveQl.parsePlan(viewText))
+ case None => Subquery(table.name, hive.parseSql(viewText))
+ case Some(aliasText) => Subquery(aliasText, hive.parseSql(viewText))
}
} else {
MetastoreRelation(qualifiedTableName.database, qualifiedTableName.name, alias)(table)(hive)
diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveQl.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveQl.scala
index ca9ddf94c1..46246f8191 100644
--- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveQl.scala
+++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveQl.scala
@@ -79,7 +79,7 @@ private[hive] case class CreateViewAsSelect(
}
/** Provides a mapping from HiveQL statements to catalyst logical plans and expression trees. */
-private[hive] object HiveQl extends SparkQl with Logging {
+private[hive] class HiveQl(conf: ParserConf) extends SparkQl(conf) with Logging {
protected val nativeCommands = Seq(
"TOK_ALTERDATABASE_OWNER",
"TOK_ALTERDATABASE_PROPERTIES",
@@ -168,8 +168,6 @@ private[hive] object HiveQl extends SparkQl with Logging {
"TOK_TRUNCATETABLE" // truncate table" is a NativeCommand, does not need to explain.
) ++ nativeCommands
- protected val hqlParser = new ExtendedHiveQlParser
-
/**
* Returns the HiveConf
*/
@@ -186,9 +184,6 @@ private[hive] object HiveQl extends SparkQl with Logging {
ss.getConf
}
- /** Returns a LogicalPlan for a given HiveQL string. */
- def parseSql(sql: String): LogicalPlan = hqlParser.parse(sql)
-
protected def getProperties(node: ASTNode): Seq[(String, String)] = node match {
case Token("TOK_TABLEPROPLIST", list) =>
list.map {
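HiveQl is no longer a singleton: it is now a class parameterized by a ParserConf, and HiveQl.parseSql is gone in favour of parsePlan on an instance. A minimal sketch mirroring the updated test suites below (HiveQl is private[hive], so this only compiles from within that package, as the tests do):

    import org.apache.spark.sql.catalyst.parser.SimpleParserConf

    val parser = new HiveQl(SimpleParserConf())
    val plan = parser.parsePlan("SELECT key FROM src")   // replaces HiveQl.parseSql(...)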
diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveQlSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveQlSuite.scala
index 53d15c14cb..137dadd6c6 100644
--- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveQlSuite.scala
+++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveQlSuite.scala
@@ -23,12 +23,15 @@ import org.scalatest.BeforeAndAfterAll
import org.apache.spark.SparkFunSuite
import org.apache.spark.sql.AnalysisException
import org.apache.spark.sql.catalyst.expressions.JsonTuple
+import org.apache.spark.sql.catalyst.parser.SimpleParserConf
import org.apache.spark.sql.catalyst.plans.logical.Generate
import org.apache.spark.sql.hive.client.{ExternalTable, HiveColumn, HiveTable, ManagedTable}
class HiveQlSuite extends SparkFunSuite with BeforeAndAfterAll {
+ val parser = new HiveQl(SimpleParserConf())
+
private def extractTableDesc(sql: String): (HiveTable, Boolean) = {
- HiveQl.parsePlan(sql).collect {
+ parser.parsePlan(sql).collect {
case CreateTableAsSelect(desc, child, allowExisting) => (desc, allowExisting)
}.head
}
@@ -173,7 +176,7 @@ class HiveQlSuite extends SparkFunSuite with BeforeAndAfterAll {
test("Invalid interval term should throw AnalysisException") {
def assertError(sql: String, errorMessage: String): Unit = {
val e = intercept[AnalysisException] {
- HiveQl.parseSql(sql)
+ parser.parsePlan(sql)
}
assert(e.getMessage.contains(errorMessage))
}
@@ -186,7 +189,7 @@ class HiveQlSuite extends SparkFunSuite with BeforeAndAfterAll {
}
test("use native json_tuple instead of hive's UDTF in LATERAL VIEW") {
- val plan = HiveQl.parseSql(
+ val plan = parser.parsePlan(
"""
|SELECT *
|FROM (SELECT '{"f1": "value1", "f2": 12}' json) test
diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/StatisticsSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/StatisticsSuite.scala
index 78f74cdc19..91bedf9c5a 100644
--- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/StatisticsSuite.scala
+++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/StatisticsSuite.scala
@@ -21,6 +21,7 @@ import scala.reflect.ClassTag
import org.apache.spark.sql.{QueryTest, Row, SQLConf}
import org.apache.spark.sql.catalyst.TableIdentifier
+import org.apache.spark.sql.catalyst.parser.SimpleParserConf
import org.apache.spark.sql.execution.joins._
import org.apache.spark.sql.hive.execution._
import org.apache.spark.sql.hive.test.TestHiveSingleton
@@ -28,9 +29,11 @@ import org.apache.spark.sql.hive.test.TestHiveSingleton
class StatisticsSuite extends QueryTest with TestHiveSingleton {
import hiveContext.sql
+ val parser = new HiveQl(SimpleParserConf())
+
test("parse analyze commands") {
def assertAnalyzeCommand(analyzeCommand: String, c: Class[_]) {
- val parsed = HiveQl.parseSql(analyzeCommand)
+ val parsed = parser.parsePlan(analyzeCommand)
val operators = parsed.collect {
case a: AnalyzeTable => a
case o => o
diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala
index f6c687aab7..61d5aa7ae6 100644
--- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala
+++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala
@@ -22,12 +22,14 @@ import java.sql.{Date, Timestamp}
import scala.collection.JavaConverters._
import org.apache.spark.sql._
-import org.apache.spark.sql.catalyst.{DefaultParserDialect, TableIdentifier}
+import org.apache.spark.sql.catalyst.TableIdentifier
import org.apache.spark.sql.catalyst.analysis.{EliminateSubQueries, FunctionRegistry}
import org.apache.spark.sql.catalyst.errors.DialectException
+import org.apache.spark.sql.catalyst.parser.ParserConf
+import org.apache.spark.sql.execution.SparkQl
import org.apache.spark.sql.execution.datasources.LogicalRelation
import org.apache.spark.sql.execution.datasources.parquet.ParquetRelation
-import org.apache.spark.sql.hive.{HiveContext, HiveQLDialect, MetastoreRelation}
+import org.apache.spark.sql.hive.{ExtendedHiveQlParser, HiveContext, HiveQl, MetastoreRelation}
import org.apache.spark.sql.hive.test.TestHiveSingleton
import org.apache.spark.sql.test.SQLTestUtils
import org.apache.spark.sql.types._
@@ -56,7 +58,7 @@ case class WindowData(
area: String,
product: Int)
/** A SQL Dialect for testing purposes; it cannot be a nested type. */
-class MyDialect extends DefaultParserDialect
+class MyDialect(conf: ParserConf) extends HiveQl(conf)
/**
* A collection of hive query tests where we generate the answers ourselves instead of depending on
@@ -339,20 +341,20 @@ class SQLQuerySuite extends QueryTest with SQLTestUtils with TestHiveSingleton {
val hiveContext = new HiveContext(sqlContext.sparkContext)
val dialectConf = "spark.sql.dialect"
checkAnswer(hiveContext.sql(s"set $dialectConf"), Row(dialectConf, "hiveql"))
- assert(hiveContext.getSQLDialect().getClass === classOf[HiveQLDialect])
+ assert(hiveContext.getSQLDialect().getClass === classOf[ExtendedHiveQlParser])
}
test("SQL Dialect Switching") {
- assert(getSQLDialect().getClass === classOf[HiveQLDialect])
+ assert(getSQLDialect().getClass === classOf[ExtendedHiveQlParser])
setConf("spark.sql.dialect", classOf[MyDialect].getCanonicalName())
assert(getSQLDialect().getClass === classOf[MyDialect])
assert(sql("SELECT 1").collect() === Array(Row(1)))
// set the dialect back to the DefaultSQLDialect
sql("SET spark.sql.dialect=sql")
- assert(getSQLDialect().getClass === classOf[DefaultParserDialect])
+ assert(getSQLDialect().getClass === classOf[SparkQl])
sql("SET spark.sql.dialect=hiveql")
- assert(getSQLDialect().getClass === classOf[HiveQLDialect])
+ assert(getSQLDialect().getClass === classOf[ExtendedHiveQlParser])
// set invalid dialect
sql("SET spark.sql.dialect.abc=MyTestClass")
@@ -361,14 +363,14 @@ class SQLQuerySuite extends QueryTest with SQLTestUtils with TestHiveSingleton {
sql("SELECT 1")
}
// test if the dialect set back to HiveQLDialect
- getSQLDialect().getClass === classOf[HiveQLDialect]
+ getSQLDialect().getClass === classOf[ExtendedHiveQlParser]
sql("SET spark.sql.dialect=MyTestClass")
intercept[DialectException] {
sql("SELECT 1")
}
// test if the dialect set back to HiveQLDialect
- assert(getSQLDialect().getClass === classOf[HiveQLDialect])
+ assert(getSQLDialect().getClass === classOf[ExtendedHiveQlParser])
}
test("CTAS with serde") {
diff --git a/unsafe/src/main/java/org/apache/spark/unsafe/types/CalendarInterval.java b/unsafe/src/main/java/org/apache/spark/unsafe/types/CalendarInterval.java
index 30e1758076..62edf6c64b 100644
--- a/unsafe/src/main/java/org/apache/spark/unsafe/types/CalendarInterval.java
+++ b/unsafe/src/main/java/org/apache/spark/unsafe/types/CalendarInterval.java
@@ -188,6 +188,11 @@ public final class CalendarInterval implements Serializable {
Integer.MIN_VALUE, Integer.MAX_VALUE);
result = new CalendarInterval(month, 0L);
+ } else if (unit.equals("week")) {
+ long week = toLongWithRange("week", m.group(1),
+ Long.MIN_VALUE / MICROS_PER_WEEK, Long.MAX_VALUE / MICROS_PER_WEEK);
+ result = new CalendarInterval(0, week * MICROS_PER_WEEK);
+
} else if (unit.equals("day")) {
long day = toLongWithRange("day", m.group(1),
Long.MIN_VALUE / MICROS_PER_DAY, Long.MAX_VALUE / MICROS_PER_DAY);
@@ -206,6 +211,15 @@ public final class CalendarInterval implements Serializable {
} else if (unit.equals("second")) {
long micros = parseSecondNano(m.group(1));
result = new CalendarInterval(0, micros);
+
+ } else if (unit.equals("millisecond")) {
+ long millisecond = toLongWithRange("millisecond", m.group(1),
+ Long.MIN_VALUE / MICROS_PER_MILLI, Long.MAX_VALUE / MICROS_PER_MILLI);
+ result = new CalendarInterval(0, millisecond * MICROS_PER_MILLI);
+
+ } else if (unit.equals("microsecond")) {
+ long micros = Long.valueOf(m.group(1));
+ result = new CalendarInterval(0, micros);
}
} catch (Exception e) {
throw new IllegalArgumentException("Error parsing interval string: " + e.getMessage(), e);
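For completeness, a sketch of what the new interval branches produce; the parsing entry point that reaches them is outside this hunk, so this only illustrates the resulting values via the public constructor and the constants referenced above:

    import org.apache.spark.unsafe.types.CalendarInterval

    // A week contributes MICROS_PER_WEEK microseconds; millisecond and microsecond
    // units map onto the same (months, microseconds) representation.
    val threeWeeks = new CalendarInterval(0, 3L * CalendarInterval.MICROS_PER_WEEK)
    val fiveMillis = new CalendarInterval(0, 5L * CalendarInterval.MICROS_PER_MILLI)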