 sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/SqlParser.scala | 14 ++++++++++++++
 sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala          | 15 +++++++++++++--
 2 files changed, 27 insertions(+), 2 deletions(-)
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/SqlParser.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/SqlParser.scala
index 854b5b461b..4662f585cf 100755
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/SqlParser.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/SqlParser.scala
@@ -77,10 +77,13 @@ class SqlParser extends StandardTokenParsers with PackratParsers {
protected val BETWEEN = Keyword("BETWEEN")
protected val BY = Keyword("BY")
protected val CACHE = Keyword("CACHE")
+ protected val CASE = Keyword("CASE")
protected val CAST = Keyword("CAST")
protected val COUNT = Keyword("COUNT")
protected val DESC = Keyword("DESC")
protected val DISTINCT = Keyword("DISTINCT")
+ protected val ELSE = Keyword("ELSE")
+ protected val END = Keyword("END")
protected val EXCEPT = Keyword("EXCEPT")
protected val FALSE = Keyword("FALSE")
protected val FIRST = Keyword("FIRST")
@@ -122,11 +125,13 @@ class SqlParser extends StandardTokenParsers with PackratParsers {
protected val SUBSTRING = Keyword("SUBSTRING")
protected val SUM = Keyword("SUM")
protected val TABLE = Keyword("TABLE")
+ protected val THEN = Keyword("THEN")
protected val TIMESTAMP = Keyword("TIMESTAMP")
protected val TRUE = Keyword("TRUE")
protected val UNCACHE = Keyword("UNCACHE")
protected val UNION = Keyword("UNION")
protected val UPPER = Keyword("UPPER")
+ protected val WHEN = Keyword("WHEN")
protected val WHERE = Keyword("WHERE")

// Use reflection to find the reserved words defined in this class.
@@ -333,6 +338,15 @@ class SqlParser extends StandardTokenParsers with PackratParsers {
IF ~> "(" ~> expression ~ "," ~ expression ~ "," ~ expression <~ ")" ^^ {
case c ~ "," ~ t ~ "," ~ f => If(c,t,f)
} |
+ CASE ~> expression.? ~ (WHEN ~> expression ~ (THEN ~> expression)).* ~
+ (ELSE ~> expression).? <~ END ^^ {
+ case casePart ~ altPart ~ elsePart =>
+ val altExprs = altPart.flatMap {
+ case we ~ te =>
+ Seq(casePart.fold(we)(EqualTo(_, we)), te)
+ }
+ CaseWhen(altExprs ++ elsePart.toList)
+ } |
(SUBSTR | SUBSTRING) ~> "(" ~> expression ~ "," ~ expression <~ ")" ^^ {
case s ~ "," ~ p => Substring(s,p,Literal(Integer.MAX_VALUE))
} |
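
The new alternative above covers both CASE forms. For the simple form (CASE key WHEN v THEN r ... END) the optional operand is parsed into casePart and each WHEN value is wrapped in an EqualTo against it; for the searched form (CASE WHEN cond THEN r ... END) casePart is empty and each WHEN expression is used as the branch condition as-is. Below is a minimal, self-contained sketch of that flattening step, in plain Scala rather than Catalyst code, with strings standing in for parsed expressions:

    object CaseFlattenSketch extends App {
      // Mirrors the names in the parser rule above; strings stand in for Catalyst expressions.
      val casePart: Option[String]       = Some("key")                    // CASE key ...
      val altPart: Seq[(String, String)] = Seq(("1", "one"), ("2", "two")) // WHEN ... THEN ... pairs
      val elsePart: Option[String]       = Some("other")                  // ELSE ...

      // Simple form: wrap each WHEN value in an equality test against the operand.
      // Searched form (casePart = None): keep the WHEN expression as the condition.
      val altExprs = altPart.flatMap { case (we, te) =>
        Seq(casePart.fold(we)(c => s"($c = $we)"), te)
      }

      println(altExprs ++ elsePart.toList)
      // List((key = 1), one, (key = 2), two, other)
    }

The resulting alternating list (condition, value, condition, value, ..., optional else) is the flat shape that the CaseWhen expression is constructed from in the rule above.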
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala
index b9b196ea5a..79de1bb855 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala
@@ -680,9 +680,20 @@ class SQLQuerySuite extends QueryTest with BeforeAndAfterAll {
sql("SELECT CAST(TRUE AS STRING), CAST(FALSE AS STRING) FROM testData LIMIT 1"),
("true", "false") :: Nil)
}
-
+
test("SPARK-3371 Renaming a function expression with group by gives error") {
registerFunction("len", (s: String) => s.length)
checkAnswer(
- sql("SELECT len(value) as temp FROM testData WHERE key = 1 group by len(value)"), 1)}
+ sql("SELECT len(value) as temp FROM testData WHERE key = 1 group by len(value)"), 1)
+ }
+
+ test("SPARK-3813 CASE a WHEN b THEN c [WHEN d THEN e]* [ELSE f] END") {
+ checkAnswer(
+ sql("SELECT CASE key WHEN 1 THEN 1 ELSE 0 END FROM testData WHERE key = 1 group by key"), 1)
+ }
+
+ test("SPARK-3813 CASE WHEN a THEN b [WHEN c THEN d]* [ELSE e] END") {
+ checkAnswer(
+ sql("SELECT CASE WHEN key=1 THEN 1 ELSE 2 END FROM testData WHERE key = 1 group by key"), 1)
+ }
}
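
For readers unfamiliar with the flat branches encoding, the sequence built by the parser rule is read as alternating (condition, value) pairs, with an optional trailing element serving as the ELSE result. A rough evaluation sketch in plain Scala, assuming that encoding and not taken from the actual Catalyst CaseWhen implementation:

    object CaseWhenEvalSketch extends App {
      // branches = cond1, val1, cond2, val2, ..., [elseValue]
      def eval(branches: Seq[Any]): Option[Any] =
        branches.grouped(2).collectFirst {
          case Seq(cond: Boolean, value) if cond => value     // first WHEN whose condition holds
          case Seq(elseValue)                    => elseValue // trailing ELSE, if present
        }

      // CASE WHEN key = 1 THEN 1 ELSE 2 END, evaluated with key = 1 and key = 3
      println(eval(Seq(true, 1, 2)))   // Some(1)
      println(eval(Seq(false, 1, 2)))  // Some(2)
      // No ELSE and no matching WHEN yields no result here; in SQL the CASE expression is NULL.
      println(eval(Seq(false, 1)))     // None
    }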