diff options
author | Wenchen Fan <wenchen@databricks.com> | 2015-12-08 10:13:40 -0800 |
---|---|---|
committer | Michael Armbrust <michael@databricks.com> | 2015-12-08 10:13:40 -0800 |
commit | 381f17b540d92507cc07adf18bce8bc7e5ca5407 (patch) | |
tree | 41ea6806106769d9d9ab1ed81412bf4fe3b973f2 /sql/catalyst | |
parent | 75c60bf4ba91e45e76a6e27f054a1c550eb6ff94 (diff) | |
download | spark-381f17b540d92507cc07adf18bce8bc7e5ca5407.tar.gz spark-381f17b540d92507cc07adf18bce8bc7e5ca5407.tar.bz2 spark-381f17b540d92507cc07adf18bce8bc7e5ca5407.zip |
[SPARK-12201][SQL] add type coercion rule for greatest/least
checked with hive, greatest/least should cast their children to a tightest common type,
i.e. `(int, long) => long`, `(int, string) => error`, `(decimal(10,5), decimal(5, 10)) => error`
Author: Wenchen Fan <wenchen@databricks.com>
Closes #10196 from cloud-fan/type-coercion.
Diffstat (limited to 'sql/catalyst')
3 files changed, 47 insertions, 0 deletions
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercion.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercion.scala index 29502a5991..dbcbd6854b 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercion.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercion.scala @@ -594,6 +594,20 @@ object HiveTypeCoercion { case None => c } + case g @ Greatest(children) if children.map(_.dataType).distinct.size > 1 => + val types = children.map(_.dataType) + findTightestCommonType(types) match { + case Some(finalDataType) => Greatest(children.map(Cast(_, finalDataType))) + case None => g + } + + case l @ Least(children) if children.map(_.dataType).distinct.size > 1 => + val types = children.map(_.dataType) + findTightestCommonType(types) match { + case Some(finalDataType) => Least(children.map(Cast(_, finalDataType))) + case None => l + } + case NaNvl(l, r) if l.dataType == DoubleType && r.dataType == FloatType => NaNvl(l, Cast(r, DoubleType)) case NaNvl(l, r) if l.dataType == FloatType && r.dataType == DoubleType => diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/ExpressionTypeCheckingSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/ExpressionTypeCheckingSuite.scala index ba1866efc8..915c585ec9 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/ExpressionTypeCheckingSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/ExpressionTypeCheckingSuite.scala @@ -32,6 +32,7 @@ class ExpressionTypeCheckingSuite extends SparkFunSuite { 'intField.int, 'stringField.string, 'booleanField.boolean, + 'decimalField.decimal(8, 0), 'arrayField.array(StringType), 'mapField.map(StringType, LongType)) @@ -189,4 +190,13 @@ class ExpressionTypeCheckingSuite extends SparkFunSuite { assertError(Round('intField, 'mapField), "requires int type") assertError(Round('booleanField, 'intField), "requires numeric type") } + + test("check types for Greatest/Least") { + for (operator <- Seq[(Seq[Expression] => Expression)](Greatest, Least)) { + assertError(operator(Seq('booleanField)), "requires at least 2 arguments") + assertError(operator(Seq('intField, 'stringField)), "should all have the same type") + assertError(operator(Seq('intField, 'decimalField)), "should all have the same type") + assertError(operator(Seq('mapField, 'mapField)), "does not support ordering") + } + } } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercionSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercionSuite.scala index d3fafaae89..142915056f 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercionSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercionSuite.scala @@ -251,6 +251,29 @@ class HiveTypeCoercionSuite extends PlanTest { :: Nil)) } + test("greatest/least cast") { + for (operator <- Seq[(Seq[Expression] => Expression)](Greatest, Least)) { + ruleTest(HiveTypeCoercion.FunctionArgumentConversion, + operator(Literal(1.0) + :: Literal(1) + :: Literal.create(1.0, FloatType) + :: Nil), + operator(Cast(Literal(1.0), DoubleType) + :: Cast(Literal(1), DoubleType) + :: Cast(Literal.create(1.0, FloatType), DoubleType) + :: Nil)) + ruleTest(HiveTypeCoercion.FunctionArgumentConversion, + operator(Literal(1L) + :: Literal(1) + :: Literal(new java.math.BigDecimal("1000000000000000000000")) + :: Nil), + operator(Cast(Literal(1L), DecimalType(22, 0)) + :: Cast(Literal(1), DecimalType(22, 0)) + :: Cast(Literal(new java.math.BigDecimal("1000000000000000000000")), DecimalType(22, 0)) + :: Nil)) + } + } + test("nanvl casts") { ruleTest(HiveTypeCoercion.FunctionArgumentConversion, NaNvl(Literal.create(1.0, FloatType), Literal.create(1.0, DoubleType)), |