diff options
Diffstat (limited to 'sql/catalyst')
3 files changed, 8 insertions, 5 deletions
diff --git a/sql/catalyst/src/main/antlr4/org/apache/spark/sql/catalyst/parser/SqlBase.g4 b/sql/catalyst/src/main/antlr4/org/apache/spark/sql/catalyst/parser/SqlBase.g4 index 59f93b3c46..cc3b8fd3b4 100644 --- a/sql/catalyst/src/main/antlr4/org/apache/spark/sql/catalyst/parser/SqlBase.g4 +++ b/sql/catalyst/src/main/antlr4/org/apache/spark/sql/catalyst/parser/SqlBase.g4 @@ -506,10 +506,10 @@ expression booleanExpression : NOT booleanExpression #logicalNot + | EXISTS '(' query ')' #exists | predicated #booleanDefault | left=booleanExpression operator=AND right=booleanExpression #logicalBinary | left=booleanExpression operator=OR right=booleanExpression #logicalBinary - | EXISTS '(' query ')' #exists ; // workaround for: @@ -546,9 +546,10 @@ primaryExpression | constant #constantDefault | ASTERISK #star | qualifiedName '.' ASTERISK #star - | '(' expression (',' expression)+ ')' #rowConstructor + | '(' namedExpression (',' namedExpression)+ ')' #rowConstructor | '(' query ')' #subqueryExpression - | qualifiedName '(' (setQuantifier? expression (',' expression)*)? ')' (OVER windowSpec)? #functionCall + | qualifiedName '(' (setQuantifier? namedExpression (',' namedExpression)*)? ')' + (OVER windowSpec)? #functionCall | value=primaryExpression '[' index=valueExpression ']' #subscript | identifier #columnReference | base=primaryExpression '.' fieldName=identifier #dereference diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala index 3cf11adc19..4c9fb2ec27 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala @@ -1016,7 +1016,7 @@ class AstBuilder extends SqlBaseBaseVisitor[AnyRef] with Logging { // Create the function call. val name = ctx.qualifiedName.getText val isDistinct = Option(ctx.setQuantifier()).exists(_.DISTINCT != null) - val arguments = ctx.expression().asScala.map(expression) match { + val arguments = ctx.namedExpression().asScala.map(expression) match { case Seq(UnresolvedStar(None)) if name.toLowerCase == "count" && !isDistinct => // Transform COUNT(*) into COUNT(1). Seq(Literal(1)) @@ -1127,7 +1127,7 @@ class AstBuilder extends SqlBaseBaseVisitor[AnyRef] with Logging { * Create a [[CreateStruct]] expression. */ override def visitRowConstructor(ctx: RowConstructorContext): Expression = withOrigin(ctx) { - CreateStruct(ctx.expression.asScala.map(expression)) + CreateStruct(ctx.namedExpression().asScala.map(expression)) } /** diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/ExpressionParserSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/ExpressionParserSuite.scala index 2fecb8dc4a..c2e62e7397 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/ExpressionParserSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/ExpressionParserSuite.scala @@ -209,6 +209,7 @@ class ExpressionParserSuite extends PlanTest { assertEqual("foo(distinct a, b)", 'foo.distinctFunction('a, 'b)) assertEqual("grouping(distinct a, b)", 'grouping.distinctFunction('a, 'b)) assertEqual("`select`(all a, b)", 'select.function('a, 'b)) + assertEqual("foo(a as x, b as e)", 'foo.function('a as 'x, 'b as 'e)) } test("window function expressions") { @@ -278,6 +279,7 @@ class ExpressionParserSuite extends PlanTest { // Note that '(a)' will be interpreted as a nested expression. assertEqual("(a, b)", CreateStruct(Seq('a, 'b))) assertEqual("(a, b, c)", CreateStruct(Seq('a, 'b, 'c))) + assertEqual("(a as b, b as c)", CreateStruct(Seq('a as 'b, 'b as 'c))) } test("scalar sub-query") { |