author     Daoyuan Wang <daoyuan.wang@intel.com>      2015-02-01 18:51:38 -0800
committer  Michael Armbrust <michael@databricks.com>  2015-02-01 18:51:38 -0800
commit     8cf4a1f02e40f37f940f6a347c078f5879585bf4 (patch)
tree       430ebc09a5753a8cb1738ce289be6057b1dbd0c2 /sql
parent     1b56f1d6bb079a669ae83e70ee515373ade2a469 (diff)
[SPARK-5262] [SPARK-5244] [SQL] add coalesce in SQLParser and widen types for parameters of coalesce
I'll add a test case in #4040.

Author: Daoyuan Wang <daoyuan.wang@intel.com>

Closes #4057 from adrian-wang/coal and squashes the following commits:

4d0111a [Daoyuan Wang] address Yin's comments
c393e18 [Daoyuan Wang] fix rebase conflicts
e47c03a [Daoyuan Wang] add coalesce in parser
c74828d [Daoyuan Wang] cast types for coalesce
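For orientation before the diffs, here is a minimal usage sketch of what this patch enables, assuming a local Spark build from this release line; the object name and master setting are illustrative, while SQLContext and its sql()/collect() calls are the standard entry points of the time.

import org.apache.spark.{SparkConf, SparkContext}
import org.apache.spark.sql.SQLContext

// Hypothetical driver program; names here are illustrative only.
object CoalesceUsageSketch {
  def main(args: Array[String]): Unit = {
    val sc = new SparkContext(new SparkConf().setAppName("coalesce-sketch").setMaster("local[*]"))
    val sqlContext = new SQLContext(sc)

    // COALESCE is now handled by SqlParser itself, and the mixed INT/DOUBLE
    // arguments are widened to DOUBLE by the new coercion rule, so this
    // returns 1.0 (see the SQLQuerySuite test below).
    sqlContext.sql("SELECT COALESCE(null, 1, 1.5)").collect().foreach(println)

    sc.stop()
  }
}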
Diffstat (limited to 'sql')
-rwxr-xr-x  sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/SqlParser.scala                      |  2 +
-rw-r--r--  sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercion.scala       | 16 +
-rw-r--r--  sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercionSuite.scala  | 27 +
-rw-r--r--  sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala                                | 12 +
-rw-r--r--  sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveQl.scala                                  |  2 +
-rw-r--r--  sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveTypeCoercionSuite.scala         |  6 +
6 files changed, 65 insertions(+), 0 deletions(-)
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/SqlParser.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/SqlParser.scala
index 24a65f8f4d..594a423146 100755
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/SqlParser.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/SqlParser.scala
@@ -50,6 +50,7 @@ class SqlParser extends AbstractSparkSQLParser {
protected val CACHE = Keyword("CACHE")
protected val CASE = Keyword("CASE")
protected val CAST = Keyword("CAST")
+ protected val COALESCE = Keyword("COALESCE")
protected val COUNT = Keyword("COUNT")
protected val DECIMAL = Keyword("DECIMAL")
protected val DESC = Keyword("DESC")
@@ -295,6 +296,7 @@ class SqlParser extends AbstractSparkSQLParser {
{ case s ~ p => Substring(s, p, Literal(Integer.MAX_VALUE)) }
| (SUBSTR | SUBSTRING) ~ "(" ~> expression ~ ("," ~> expression) ~ ("," ~> expression) <~ ")" ^^
{ case s ~ p ~ l => Substring(s, p, l) }
+ | COALESCE ~ "(" ~> repsep(expression, ",") <~ ")" ^^ { case exprs => Coalesce(exprs) }
| SQRT ~ "(" ~> expression <~ ")" ^^ { case exp => Sqrt(exp) }
| ABS ~ "(" ~> expression <~ ")" ^^ { case exp => Abs(exp) }
| ident ~ ("(" ~> repsep(expression, ",")) <~ ")" ^^
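The grammar change above is small: a new COALESCE keyword plus one more alternative that reuses repsep for the comma-separated argument list. The following standalone sketch shows the same combinator shape in isolation; it is not Spark's parser, and Expr, Num and the local Coalesce case class are hypothetical stand-ins for Catalyst's Expression hierarchy. It assumes the scala-parser-combinators library, which Spark of this era already depended on.

import scala.util.parsing.combinator.syntactical.StandardTokenParsers

// Standalone sketch of the combinator shape added above, not Spark's parser.
object CoalesceParserSketch extends StandardTokenParsers {
  lexical.reserved += "COALESCE"
  lexical.delimiters ++= Seq("(", ")", ",")

  sealed trait Expr
  case class Num(value: Int) extends Expr
  case class Coalesce(children: Seq[Expr]) extends Expr

  // Same pattern as SqlParser: keyword, "(", a comma-separated list of
  // sub-expressions via repsep, ")", then build a single Coalesce node.
  def expression: Parser[Expr] =
    ( "COALESCE" ~ "(" ~> repsep(expression, ",") <~ ")" ^^ { exprs => Coalesce(exprs) }
    | numericLit ^^ { s => Num(s.toInt) }
    )

  def parse(input: String): Expr =
    phrase(expression)(new lexical.Scanner(input)) match {
      case Success(result, _) => result
      case failure            => sys.error(failure.toString)
    }

  def main(args: Array[String]): Unit = {
    // Nested calls parse into nested Coalesce nodes.
    println(parse("COALESCE(1, COALESCE(2, 3))"))
  }
}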
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercion.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercion.scala
index 6ef8577fd0..34ef7d28cc 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercion.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercion.scala
@@ -503,6 +503,22 @@ trait HiveTypeCoercion {
// Hive lets you do aggregation of timestamps... for some reason
case Sum(e @ TimestampType()) => Sum(Cast(e, DoubleType))
case Average(e @ TimestampType()) => Average(Cast(e, DoubleType))
+
+ // Coalesce should return the first non-null value, which could be any column
+ // from the list. So we need to make sure the return type is deterministic and
+ // compatible with every child column.
+ case Coalesce(es) if es.map(_.dataType).distinct.size > 1 =>
+ val dt: Option[DataType] = Some(NullType)
+ val types = es.map(_.dataType)
+ val rt = types.foldLeft(dt)((r, c) => r match {
+ case None => None
+ case Some(d) => findTightestCommonType(d, c)
+ })
+ rt match {
+ case Some(finaldt) => Coalesce(es.map(Cast(_, finaldt)))
+ case None =>
+ sys.error(s"Could not determine return type of Coalesce for ${types.mkString(",")}")
+ }
}
}
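The rule above reduces the children's types with findTightestCommonType, starting from NullType, and only rewrites the children when a single common type exists; otherwise it fails fast with sys.error. Below is a self-contained sketch of that fold using a toy promotion table in place of Catalyst's real DataType hierarchy; every name in it is illustrative, not Spark's.

// Self-contained sketch of the widening fold used by the rule above.
object CoalesceWideningSketch {
  sealed trait DataType
  case object NullType    extends DataType
  case object IntegerType extends DataType
  case object DoubleType  extends DataType
  case object StringType  extends DataType

  // Toy promotion ladder: NullType widens to anything, IntegerType to DoubleType.
  private val widensTo: Map[DataType, Set[DataType]] = Map(
    NullType    -> Set(IntegerType, DoubleType, StringType),
    IntegerType -> Set(DoubleType))

  def findTightestCommonType(a: DataType, b: DataType): Option[DataType] =
    if (a == b) Some(a)
    else if (widensTo.getOrElse(a, Set.empty[DataType]).contains(b)) Some(b)
    else if (widensTo.getOrElse(b, Set.empty[DataType]).contains(a)) Some(a)
    else None

  // Mirrors the foldLeft in the rule: start from NullType, keep tightening,
  // and give up (None) as soon as two types have no common widening.
  def commonType(types: Seq[DataType]): Option[DataType] =
    types.foldLeft(Option(NullType: DataType)) {
      case (Some(acc), next) => findTightestCommonType(acc, next)
      case (None, _)         => None
    }

  def main(args: Array[String]): Unit = {
    println(commonType(Seq(IntegerType, DoubleType))) // Some(DoubleType)
    println(commonType(Seq(IntegerType, StringType))) // None: no common type, like the sys.error path
  }
}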
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercionSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercionSuite.scala
index f5a502b43f..85798d0871 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercionSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercionSuite.scala
@@ -114,4 +114,31 @@ class HiveTypeCoercionSuite extends FunSuite {
// Stringify boolean when casting to string.
ruleTest(Cast(Literal(false), StringType), If(Literal(false), Literal("true"), Literal("false")))
}
+
+ test("coalesce casts") {
+ val fac = new HiveTypeCoercion { }.FunctionArgumentConversion
+ def ruleTest(initial: Expression, transformed: Expression) {
+ val testRelation = LocalRelation(AttributeReference("a", IntegerType)())
+ assert(fac(Project(Seq(Alias(initial, "a")()), testRelation)) ==
+ Project(Seq(Alias(transformed, "a")()), testRelation))
+ }
+ ruleTest(
+ Coalesce(Literal(1.0)
+ :: Literal(1)
+ :: Literal(1.0, FloatType)
+ :: Nil),
+ Coalesce(Cast(Literal(1.0), DoubleType)
+ :: Cast(Literal(1), DoubleType)
+ :: Cast(Literal(1.0, FloatType), DoubleType)
+ :: Nil))
+ ruleTest(
+ Coalesce(Literal(1L)
+ :: Literal(1)
+ :: Literal(new java.math.BigDecimal("1000000000000000000000"))
+ :: Nil),
+ Coalesce(Cast(Literal(1L), DecimalType())
+ :: Cast(Literal(1), DecimalType())
+ :: Cast(Literal(new java.math.BigDecimal("1000000000000000000000")), DecimalType())
+ :: Nil))
+ }
}
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala
index d684278f11..d82c34316c 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala
@@ -88,6 +88,18 @@ class SQLQuerySuite extends QueryTest with BeforeAndAfterAll {
setConf(SQLConf.CODEGEN_ENABLED, originalValue.toString)
}
+ test("Add Parser of SQL COALESCE()") {
+ checkAnswer(
+ sql("""SELECT COALESCE(1, 2)"""),
+ Row(1))
+ checkAnswer(
+ sql("SELECT COALESCE(null, 1, 1.5)"),
+ Row(1.toDouble))
+ checkAnswer(
+ sql("SELECT COALESCE(null, null, null)"),
+ Row(null))
+ }
+
test("SPARK-3176 Added Parser of SQL LAST()") {
checkAnswer(
sql("SELECT LAST(n) FROM lowerCaseData"),
diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveQl.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveQl.scala
index 399e58b259..30a64b48d7 100644
--- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveQl.scala
+++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveQl.scala
@@ -965,6 +965,7 @@ https://cwiki.apache.org/confluence/display/Hive/Enhanced+Aggregation%2C+Cube%2C
/* Case insensitive matches */
val ARRAY = "(?i)ARRAY".r
+ val COALESCE = "(?i)COALESCE".r
val COUNT = "(?i)COUNT".r
val AVG = "(?i)AVG".r
val SUM = "(?i)SUM".r
@@ -1140,6 +1141,7 @@ https://cwiki.apache.org/confluence/display/Hive/Enhanced+Aggregation%2C+Cube%2C
Substring(nodeToExpr(string), nodeToExpr(pos), Literal(Integer.MAX_VALUE, IntegerType))
case Token("TOK_FUNCTION", Token(SUBSTR(), Nil) :: string :: pos :: length :: Nil) =>
Substring(nodeToExpr(string), nodeToExpr(pos), nodeToExpr(length))
+ case Token("TOK_FUNCTION", Token(COALESCE(), Nil) :: list) => Coalesce(list.map(nodeToExpr))
/* UDFs - Must be last otherwise will preempt built in functions */
case Token("TOK_FUNCTION", Token(name, Nil) :: args) =>
diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveTypeCoercionSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveTypeCoercionSuite.scala
index 48fffe53cf..ab0e0443c7 100644
--- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveTypeCoercionSuite.scala
+++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveTypeCoercionSuite.scala
@@ -57,4 +57,10 @@ class HiveTypeCoercionSuite extends HiveComparisonTest {
}
assert(numEquals === 1)
}
+
+ test("COALESCE with different types") {
+ intercept[RuntimeException] {
+ TestHive.sql("""SELECT COALESCE(1, true, "abc") FROM src limit 1""").collect()
+ }
+ }
}