aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorravipesala <ravindra.pesala@huawei.com>2014-10-09 15:14:58 -0700
committerMichael Armbrust <michael@databricks.com>2014-10-09 15:14:58 -0700
commitac302052870a650d56f2d3131c27755bb2960ad7 (patch)
treecba7622be52f2daa07fc07dab123c9e8a83639ab
parentbc3b6cb06153d6b05f311dd78459768b6cf6a404 (diff)
downloadspark-ac302052870a650d56f2d3131c27755bb2960ad7.tar.gz
spark-ac302052870a650d56f2d3131c27755bb2960ad7.tar.bz2
spark-ac302052870a650d56f2d3131c27755bb2960ad7.zip
[SPARK-3813][SQL] Support "case when" conditional functions in Spark SQL.
"case when" conditional function is already supported in Spark SQL but there is no support in SqlParser. So added parser support to it. Author : ravipesala ravindra.pesalahuawei.com Author: ravipesala <ravindra.pesala@huawei.com> Closes #2678 from ravipesala/SPARK-3813 and squashes the following commits: 70c75a7 [ravipesala] Fixed styles 713ea84 [ravipesala] Updated as per admin comments 709684f [ravipesala] Changed parser to support case when function.
-rwxr-xr-xsql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/SqlParser.scala14
-rw-r--r--sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala15
2 files changed, 27 insertions, 2 deletions
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/SqlParser.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/SqlParser.scala
index 854b5b461b..4662f585cf 100755
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/SqlParser.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/SqlParser.scala
@@ -77,10 +77,13 @@ class SqlParser extends StandardTokenParsers with PackratParsers {
protected val BETWEEN = Keyword("BETWEEN")
protected val BY = Keyword("BY")
protected val CACHE = Keyword("CACHE")
+ protected val CASE = Keyword("CASE")
protected val CAST = Keyword("CAST")
protected val COUNT = Keyword("COUNT")
protected val DESC = Keyword("DESC")
protected val DISTINCT = Keyword("DISTINCT")
+ protected val ELSE = Keyword("ELSE")
+ protected val END = Keyword("END")
protected val EXCEPT = Keyword("EXCEPT")
protected val FALSE = Keyword("FALSE")
protected val FIRST = Keyword("FIRST")
@@ -122,11 +125,13 @@ class SqlParser extends StandardTokenParsers with PackratParsers {
protected val SUBSTRING = Keyword("SUBSTRING")
protected val SUM = Keyword("SUM")
protected val TABLE = Keyword("TABLE")
+ protected val THEN = Keyword("THEN")
protected val TIMESTAMP = Keyword("TIMESTAMP")
protected val TRUE = Keyword("TRUE")
protected val UNCACHE = Keyword("UNCACHE")
protected val UNION = Keyword("UNION")
protected val UPPER = Keyword("UPPER")
+ protected val WHEN = Keyword("WHEN")
protected val WHERE = Keyword("WHERE")
// Use reflection to find the reserved words defined in this class.
@@ -333,6 +338,15 @@ class SqlParser extends StandardTokenParsers with PackratParsers {
IF ~> "(" ~> expression ~ "," ~ expression ~ "," ~ expression <~ ")" ^^ {
case c ~ "," ~ t ~ "," ~ f => If(c,t,f)
} |
+ CASE ~> expression.? ~ (WHEN ~> expression ~ (THEN ~> expression)).* ~
+ (ELSE ~> expression).? <~ END ^^ {
+ case casePart ~ altPart ~ elsePart =>
+ val altExprs = altPart.flatMap {
+ case we ~ te =>
+ Seq(casePart.fold(we)(EqualTo(_, we)), te)
+ }
+ CaseWhen(altExprs ++ elsePart.toList)
+ } |
(SUBSTR | SUBSTRING) ~> "(" ~> expression ~ "," ~ expression <~ ")" ^^ {
case s ~ "," ~ p => Substring(s,p,Literal(Integer.MAX_VALUE))
} |
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala
index b9b196ea5a..79de1bb855 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala
@@ -680,9 +680,20 @@ class SQLQuerySuite extends QueryTest with BeforeAndAfterAll {
sql("SELECT CAST(TRUE AS STRING), CAST(FALSE AS STRING) FROM testData LIMIT 1"),
("true", "false") :: Nil)
}
-
+
test("SPARK-3371 Renaming a function expression with group by gives error") {
registerFunction("len", (s: String) => s.length)
checkAnswer(
- sql("SELECT len(value) as temp FROM testData WHERE key = 1 group by len(value)"), 1)}
+ sql("SELECT len(value) as temp FROM testData WHERE key = 1 group by len(value)"), 1)
+ }
+
+ test("SPARK-3813 CASE a WHEN b THEN c [WHEN d THEN e]* [ELSE f] END") {
+ checkAnswer(
+ sql("SELECT CASE key WHEN 1 THEN 1 ELSE 0 END FROM testData WHERE key = 1 group by key"), 1)
+ }
+
+ test("SPARK-3813 CASE WHEN a THEN b [WHEN c THEN d]* [ELSE e] END") {
+ checkAnswer(
+ sql("SELECT CASE WHEN key=1 THEN 1 ELSE 2 END FROM testData WHERE key = 1 group by key"), 1)
+ }
}