From 8cf4a1f02e40f37f940f6a347c078f5879585bf4 Mon Sep 17 00:00:00 2001 From: Daoyuan Wang Date: Sun, 1 Feb 2015 18:51:38 -0800 Subject: [SPARK-5262] [SPARK-5244] [SQL] add coalesce in SQLParser and widen types for parameters of coalesce I'll add test case in #4040 Author: Daoyuan Wang Closes #4057 from adrian-wang/coal and squashes the following commits: 4d0111a [Daoyuan Wang] address Yin's comments c393e18 [Daoyuan Wang] fix rebase conflicts e47c03a [Daoyuan Wang] add coalesce in parser c74828d [Daoyuan Wang] cast types for coalesce --- .../org/apache/spark/sql/catalyst/SqlParser.scala | 2 ++ .../sql/catalyst/analysis/HiveTypeCoercion.scala | 16 +++++++++++++ .../catalyst/analysis/HiveTypeCoercionSuite.scala | 27 ++++++++++++++++++++++ .../scala/org/apache/spark/sql/SQLQuerySuite.scala | 12 ++++++++++ .../scala/org/apache/spark/sql/hive/HiveQl.scala | 2 ++ .../sql/hive/execution/HiveTypeCoercionSuite.scala | 6 +++++ 6 files changed, 65 insertions(+) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/SqlParser.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/SqlParser.scala index 24a65f8f4d..594a423146 100755 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/SqlParser.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/SqlParser.scala @@ -50,6 +50,7 @@ class SqlParser extends AbstractSparkSQLParser { protected val CACHE = Keyword("CACHE") protected val CASE = Keyword("CASE") protected val CAST = Keyword("CAST") + protected val COALESCE = Keyword("COALESCE") protected val COUNT = Keyword("COUNT") protected val DECIMAL = Keyword("DECIMAL") protected val DESC = Keyword("DESC") @@ -295,6 +296,7 @@ class SqlParser extends AbstractSparkSQLParser { { case s ~ p => Substring(s, p, Literal(Integer.MAX_VALUE)) } | (SUBSTR | SUBSTRING) ~ "(" ~> expression ~ ("," ~> expression) ~ ("," ~> expression) <~ ")" ^^ { case s ~ p ~ l => Substring(s, p, l) } + | COALESCE ~ "(" ~> repsep(expression, ",") <~ ")" ^^ { case exprs => Coalesce(exprs) } | SQRT ~ "(" ~> expression <~ ")" ^^ { case exp => Sqrt(exp) } | ABS ~ "(" ~> expression <~ ")" ^^ { case exp => Abs(exp) } | ident ~ ("(" ~> repsep(expression, ",")) <~ ")" ^^ diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercion.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercion.scala index 6ef8577fd0..34ef7d28cc 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercion.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercion.scala @@ -503,6 +503,22 @@ trait HiveTypeCoercion { // Hive lets you do aggregation of timestamps... for some reason case Sum(e @ TimestampType()) => Sum(Cast(e, DoubleType)) case Average(e @ TimestampType()) => Average(Cast(e, DoubleType)) + + // Coalesce should return the first non-null value, which could be any column + // from the list. So we need to make sure the return type is deterministic and + // compatible with every child column. + case Coalesce(es) if es.map(_.dataType).distinct.size > 1 => + val dt: Option[DataType] = Some(NullType) + val types = es.map(_.dataType) + val rt = types.foldLeft(dt)((r, c) => r match { + case None => None + case Some(d) => findTightestCommonType(d, c) + }) + rt match { + case Some(finaldt) => Coalesce(es.map(Cast(_, finaldt))) + case None => + sys.error(s"Could not determine return type of Coalesce for ${types.mkString(",")}") + } } } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercionSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercionSuite.scala index f5a502b43f..85798d0871 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercionSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercionSuite.scala @@ -114,4 +114,31 @@ class HiveTypeCoercionSuite extends FunSuite { // Stringify boolean when casting to string. ruleTest(Cast(Literal(false), StringType), If(Literal(false), Literal("true"), Literal("false"))) } + + test("coalesce casts") { + val fac = new HiveTypeCoercion { }.FunctionArgumentConversion + def ruleTest(initial: Expression, transformed: Expression) { + val testRelation = LocalRelation(AttributeReference("a", IntegerType)()) + assert(fac(Project(Seq(Alias(initial, "a")()), testRelation)) == + Project(Seq(Alias(transformed, "a")()), testRelation)) + } + ruleTest( + Coalesce(Literal(1.0) + :: Literal(1) + :: Literal(1.0, FloatType) + :: Nil), + Coalesce(Cast(Literal(1.0), DoubleType) + :: Cast(Literal(1), DoubleType) + :: Cast(Literal(1.0, FloatType), DoubleType) + :: Nil)) + ruleTest( + Coalesce(Literal(1L) + :: Literal(1) + :: Literal(new java.math.BigDecimal("1000000000000000000000")) + :: Nil), + Coalesce(Cast(Literal(1L), DecimalType()) + :: Cast(Literal(1), DecimalType()) + :: Cast(Literal(new java.math.BigDecimal("1000000000000000000000")), DecimalType()) + :: Nil)) + } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala index d684278f11..d82c34316c 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala @@ -88,6 +88,18 @@ class SQLQuerySuite extends QueryTest with BeforeAndAfterAll { setConf(SQLConf.CODEGEN_ENABLED, originalValue.toString) } + test("Add Parser of SQL COALESCE()") { + checkAnswer( + sql("""SELECT COALESCE(1, 2)"""), + Row(1)) + checkAnswer( + sql("SELECT COALESCE(null, 1, 1.5)"), + Row(1.toDouble)) + checkAnswer( + sql("SELECT COALESCE(null, null, null)"), + Row(null)) + } + test("SPARK-3176 Added Parser of SQL LAST()") { checkAnswer( sql("SELECT LAST(n) FROM lowerCaseData"), diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveQl.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveQl.scala index 399e58b259..30a64b48d7 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveQl.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveQl.scala @@ -965,6 +965,7 @@ https://cwiki.apache.org/confluence/display/Hive/Enhanced+Aggregation%2C+Cube%2C /* Case insensitive matches */ val ARRAY = "(?i)ARRAY".r + val COALESCE = "(?i)COALESCE".r val COUNT = "(?i)COUNT".r val AVG = "(?i)AVG".r val SUM = "(?i)SUM".r @@ -1140,6 +1141,7 @@ https://cwiki.apache.org/confluence/display/Hive/Enhanced+Aggregation%2C+Cube%2C Substring(nodeToExpr(string), nodeToExpr(pos), Literal(Integer.MAX_VALUE, IntegerType)) case Token("TOK_FUNCTION", Token(SUBSTR(), Nil) :: string :: pos :: length :: Nil) => Substring(nodeToExpr(string), nodeToExpr(pos), nodeToExpr(length)) + case Token("TOK_FUNCTION", Token(COALESCE(), Nil) :: list) => Coalesce(list.map(nodeToExpr)) /* UDFs - Must be last otherwise will preempt built in functions */ case Token("TOK_FUNCTION", Token(name, Nil) :: args) => diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveTypeCoercionSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveTypeCoercionSuite.scala index 48fffe53cf..ab0e0443c7 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveTypeCoercionSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveTypeCoercionSuite.scala @@ -57,4 +57,10 @@ class HiveTypeCoercionSuite extends HiveComparisonTest { } assert(numEquals === 1) } + + test("COALESCE with different types") { + intercept[RuntimeException] { + TestHive.sql("""SELECT COALESCE(1, true, "abc") FROM src limit 1""").collect() + } + } } -- cgit v1.2.3