aboutsummaryrefslogtreecommitdiff
path: root/sql/catalyst
diff options
context:
space:
mode:
authorWenchen Fan <wenchen@databricks.com>2015-12-08 10:13:40 -0800
committerMichael Armbrust <michael@databricks.com>2015-12-08 10:13:40 -0800
commit381f17b540d92507cc07adf18bce8bc7e5ca5407 (patch)
tree41ea6806106769d9d9ab1ed81412bf4fe3b973f2 /sql/catalyst
parent75c60bf4ba91e45e76a6e27f054a1c550eb6ff94 (diff)
downloadspark-381f17b540d92507cc07adf18bce8bc7e5ca5407.tar.gz
spark-381f17b540d92507cc07adf18bce8bc7e5ca5407.tar.bz2
spark-381f17b540d92507cc07adf18bce8bc7e5ca5407.zip
[SPARK-12201][SQL] add type coercion rule for greatest/least
checked with hive, greatest/least should cast their children to a tightest common type, i.e. `(int, long) => long`, `(int, string) => error`, `(decimal(10,5), decimal(5, 10)) => error` Author: Wenchen Fan <wenchen@databricks.com> Closes #10196 from cloud-fan/type-coercion.
Diffstat (limited to 'sql/catalyst')
-rw-r--r--sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercion.scala14
-rw-r--r--sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/ExpressionTypeCheckingSuite.scala10
-rw-r--r--sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercionSuite.scala23
3 files changed, 47 insertions, 0 deletions
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercion.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercion.scala
index 29502a5991..dbcbd6854b 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercion.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercion.scala
@@ -594,6 +594,20 @@ object HiveTypeCoercion {
case None => c
}
+ case g @ Greatest(children) if children.map(_.dataType).distinct.size > 1 =>
+ val types = children.map(_.dataType)
+ findTightestCommonType(types) match {
+ case Some(finalDataType) => Greatest(children.map(Cast(_, finalDataType)))
+ case None => g
+ }
+
+ case l @ Least(children) if children.map(_.dataType).distinct.size > 1 =>
+ val types = children.map(_.dataType)
+ findTightestCommonType(types) match {
+ case Some(finalDataType) => Least(children.map(Cast(_, finalDataType)))
+ case None => l
+ }
+
case NaNvl(l, r) if l.dataType == DoubleType && r.dataType == FloatType =>
NaNvl(l, Cast(r, DoubleType))
case NaNvl(l, r) if l.dataType == FloatType && r.dataType == DoubleType =>
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/ExpressionTypeCheckingSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/ExpressionTypeCheckingSuite.scala
index ba1866efc8..915c585ec9 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/ExpressionTypeCheckingSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/ExpressionTypeCheckingSuite.scala
@@ -32,6 +32,7 @@ class ExpressionTypeCheckingSuite extends SparkFunSuite {
'intField.int,
'stringField.string,
'booleanField.boolean,
+ 'decimalField.decimal(8, 0),
'arrayField.array(StringType),
'mapField.map(StringType, LongType))
@@ -189,4 +190,13 @@ class ExpressionTypeCheckingSuite extends SparkFunSuite {
assertError(Round('intField, 'mapField), "requires int type")
assertError(Round('booleanField, 'intField), "requires numeric type")
}
+
+ test("check types for Greatest/Least") {
+ for (operator <- Seq[(Seq[Expression] => Expression)](Greatest, Least)) {
+ assertError(operator(Seq('booleanField)), "requires at least 2 arguments")
+ assertError(operator(Seq('intField, 'stringField)), "should all have the same type")
+ assertError(operator(Seq('intField, 'decimalField)), "should all have the same type")
+ assertError(operator(Seq('mapField, 'mapField)), "does not support ordering")
+ }
+ }
}
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercionSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercionSuite.scala
index d3fafaae89..142915056f 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercionSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercionSuite.scala
@@ -251,6 +251,29 @@ class HiveTypeCoercionSuite extends PlanTest {
:: Nil))
}
+ test("greatest/least cast") {
+ for (operator <- Seq[(Seq[Expression] => Expression)](Greatest, Least)) {
+ ruleTest(HiveTypeCoercion.FunctionArgumentConversion,
+ operator(Literal(1.0)
+ :: Literal(1)
+ :: Literal.create(1.0, FloatType)
+ :: Nil),
+ operator(Cast(Literal(1.0), DoubleType)
+ :: Cast(Literal(1), DoubleType)
+ :: Cast(Literal.create(1.0, FloatType), DoubleType)
+ :: Nil))
+ ruleTest(HiveTypeCoercion.FunctionArgumentConversion,
+ operator(Literal(1L)
+ :: Literal(1)
+ :: Literal(new java.math.BigDecimal("1000000000000000000000"))
+ :: Nil),
+ operator(Cast(Literal(1L), DecimalType(22, 0))
+ :: Cast(Literal(1), DecimalType(22, 0))
+ :: Cast(Literal(new java.math.BigDecimal("1000000000000000000000")), DecimalType(22, 0))
+ :: Nil))
+ }
+ }
+
test("nanvl casts") {
ruleTest(HiveTypeCoercion.FunctionArgumentConversion,
NaNvl(Literal.create(1.0, FloatType), Literal.create(1.0, DoubleType)),