aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorHerman van Hovell <hvanhovell@databricks.com>2016-10-03 19:32:59 -0700
committerHerman van Hovell <hvanhovell@databricks.com>2016-10-03 19:32:59 -0700
commit2bbecdec2023143fd144e4242ff70822e0823986 (patch)
treebcb61dd7987595769bb4271b340b014d4170c7b0
parentd8399b600cef706c22d381b01fab19c610db439a (diff)
downloadspark-2bbecdec2023143fd144e4242ff70822e0823986.tar.gz
spark-2bbecdec2023143fd144e4242ff70822e0823986.tar.bz2
spark-2bbecdec2023143fd144e4242ff70822e0823986.zip
[SPARK-17753][SQL] Allow a complex expression as the input a value based case statement
## What changes were proposed in this pull request? We currently only allow relatively simple expressions as the input for a value based case statement. Expressions like `case (a > 1) or (b = 2) when true then 1 when false then 0 end` currently fail. This PR adds support for such expressions. ## How was this patch tested? Added a test to the ExpressionParserSuite. Author: Herman van Hovell <hvanhovell@databricks.com> Closes #15322 from hvanhovell/SPARK-17753.
-rw-r--r--sql/catalyst/src/main/antlr4/org/apache/spark/sql/catalyst/parser/SqlBase.g412
-rw-r--r--sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala2
-rw-r--r--sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/ExpressionParserSuite.scala4
3 files changed, 11 insertions, 7 deletions
diff --git a/sql/catalyst/src/main/antlr4/org/apache/spark/sql/catalyst/parser/SqlBase.g4 b/sql/catalyst/src/main/antlr4/org/apache/spark/sql/catalyst/parser/SqlBase.g4
index 1284681fe8..c336a0c8ea 100644
--- a/sql/catalyst/src/main/antlr4/org/apache/spark/sql/catalyst/parser/SqlBase.g4
+++ b/sql/catalyst/src/main/antlr4/org/apache/spark/sql/catalyst/parser/SqlBase.g4
@@ -527,16 +527,16 @@ valueExpression
;
primaryExpression
- : constant #constantDefault
- | name=(CURRENT_DATE | CURRENT_TIMESTAMP) #timeFunctionCall
+ : name=(CURRENT_DATE | CURRENT_TIMESTAMP) #timeFunctionCall
+ | CASE value=expression whenClause+ (ELSE elseExpression=expression)? END #simpleCase
+ | CASE whenClause+ (ELSE elseExpression=expression)? END #searchedCase
+ | CAST '(' expression AS dataType ')' #cast
+ | constant #constantDefault
| ASTERISK #star
| qualifiedName '.' ASTERISK #star
| '(' expression (',' expression)+ ')' #rowConstructor
- | qualifiedName '(' (setQuantifier? expression (',' expression)*)? ')' (OVER windowSpec)? #functionCall
| '(' query ')' #subqueryExpression
- | CASE valueExpression whenClause+ (ELSE elseExpression=expression)? END #simpleCase
- | CASE whenClause+ (ELSE elseExpression=expression)? END #searchedCase
- | CAST '(' expression AS dataType ')' #cast
+ | qualifiedName '(' (setQuantifier? expression (',' expression)*)? ')' (OVER windowSpec)? #functionCall
| value=primaryExpression '[' index=valueExpression ']' #subscript
| identifier #columnReference
| base=primaryExpression '.' fieldName=identifier #dereference
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala
index 12a70b7769..cd0c70a491 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala
@@ -1138,7 +1138,7 @@ class AstBuilder extends SqlBaseBaseVisitor[AnyRef] with Logging {
* }}}
*/
override def visitSimpleCase(ctx: SimpleCaseContext): Expression = withOrigin(ctx) {
- val e = expression(ctx.valueExpression)
+ val e = expression(ctx.value)
val branches = ctx.whenClause.asScala.map { wCtx =>
(EqualTo(e, expression(wCtx.condition)), expression(wCtx.result))
}
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/ExpressionParserSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/ExpressionParserSuite.scala
index f319215f05..3718ac5f1e 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/ExpressionParserSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/ExpressionParserSuite.scala
@@ -292,6 +292,10 @@ class ExpressionParserSuite extends PlanTest {
test("case when") {
assertEqual("case a when 1 then b when 2 then c else d end",
CaseKeyWhen('a, Seq(1, 'b, 2, 'c, 'd)))
+ assertEqual("case (a or b) when true then c when false then d else e end",
+ CaseKeyWhen('a || 'b, Seq(true, 'c, false, 'd, 'e)))
+ assertEqual("case 'a'='a' when true then 1 end",
+ CaseKeyWhen("a" === "a", Seq(true, 1)))
assertEqual("case when a = 1 then b when a = 2 then c else d end",
CaseWhen(Seq(('a === 1, 'b.expr), ('a === 2, 'c.expr)), 'd))
}