diff options
author | Herman van Hovell <hvanhovell@databricks.com> | 2016-10-03 19:32:59 -0700 |
---|---|---|
committer | Herman van Hovell <hvanhovell@databricks.com> | 2016-10-03 19:32:59 -0700 |
commit | 2bbecdec2023143fd144e4242ff70822e0823986 (patch) | |
tree | bcb61dd7987595769bb4271b340b014d4170c7b0 /sql | |
parent | d8399b600cef706c22d381b01fab19c610db439a (diff) | |
download | spark-2bbecdec2023143fd144e4242ff70822e0823986.tar.gz spark-2bbecdec2023143fd144e4242ff70822e0823986.tar.bz2 spark-2bbecdec2023143fd144e4242ff70822e0823986.zip |
[SPARK-17753][SQL] Allow a complex expression as the input a value based case statement
## What changes were proposed in this pull request?
We currently only allow relatively simple expressions as the input for a value based case statement. Expressions like `case (a > 1) or (b = 2) when true then 1 when false then 0 end` currently fail. This PR adds support for such expressions.
## How was this patch tested?
Added a test to the ExpressionParserSuite.
Author: Herman van Hovell <hvanhovell@databricks.com>
Closes #15322 from hvanhovell/SPARK-17753.
Diffstat (limited to 'sql')
3 files changed, 11 insertions, 7 deletions
diff --git a/sql/catalyst/src/main/antlr4/org/apache/spark/sql/catalyst/parser/SqlBase.g4 b/sql/catalyst/src/main/antlr4/org/apache/spark/sql/catalyst/parser/SqlBase.g4 index 1284681fe8..c336a0c8ea 100644 --- a/sql/catalyst/src/main/antlr4/org/apache/spark/sql/catalyst/parser/SqlBase.g4 +++ b/sql/catalyst/src/main/antlr4/org/apache/spark/sql/catalyst/parser/SqlBase.g4 @@ -527,16 +527,16 @@ valueExpression ; primaryExpression - : constant #constantDefault - | name=(CURRENT_DATE | CURRENT_TIMESTAMP) #timeFunctionCall + : name=(CURRENT_DATE | CURRENT_TIMESTAMP) #timeFunctionCall + | CASE value=expression whenClause+ (ELSE elseExpression=expression)? END #simpleCase + | CASE whenClause+ (ELSE elseExpression=expression)? END #searchedCase + | CAST '(' expression AS dataType ')' #cast + | constant #constantDefault | ASTERISK #star | qualifiedName '.' ASTERISK #star | '(' expression (',' expression)+ ')' #rowConstructor - | qualifiedName '(' (setQuantifier? expression (',' expression)*)? ')' (OVER windowSpec)? #functionCall | '(' query ')' #subqueryExpression - | CASE valueExpression whenClause+ (ELSE elseExpression=expression)? END #simpleCase - | CASE whenClause+ (ELSE elseExpression=expression)? END #searchedCase - | CAST '(' expression AS dataType ')' #cast + | qualifiedName '(' (setQuantifier? expression (',' expression)*)? ')' (OVER windowSpec)? #functionCall | value=primaryExpression '[' index=valueExpression ']' #subscript | identifier #columnReference | base=primaryExpression '.' fieldName=identifier #dereference diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala index 12a70b7769..cd0c70a491 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala @@ -1138,7 +1138,7 @@ class AstBuilder extends SqlBaseBaseVisitor[AnyRef] with Logging { * }}} */ override def visitSimpleCase(ctx: SimpleCaseContext): Expression = withOrigin(ctx) { - val e = expression(ctx.valueExpression) + val e = expression(ctx.value) val branches = ctx.whenClause.asScala.map { wCtx => (EqualTo(e, expression(wCtx.condition)), expression(wCtx.result)) } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/ExpressionParserSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/ExpressionParserSuite.scala index f319215f05..3718ac5f1e 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/ExpressionParserSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/ExpressionParserSuite.scala @@ -292,6 +292,10 @@ class ExpressionParserSuite extends PlanTest { test("case when") { assertEqual("case a when 1 then b when 2 then c else d end", CaseKeyWhen('a, Seq(1, 'b, 2, 'c, 'd))) + assertEqual("case (a or b) when true then c when false then d else e end", + CaseKeyWhen('a || 'b, Seq(true, 'c, false, 'd, 'e))) + assertEqual("case 'a'='a' when true then 1 end", + CaseKeyWhen("a" === "a", Seq(true, 1))) assertEqual("case when a = 1 then b when a = 2 then c else d end", CaseWhen(Seq(('a === 1, 'b.expr), ('a === 2, 'c.expr)), 'd)) } |