From 9bd80ad6bd43462d16ce24cda77cdfaa336c4e02 Mon Sep 17 00:00:00 2001 From: Sean Zhong Date: Wed, 15 Jun 2016 14:34:15 -0700 Subject: [SPARK-15776][SQL] Divide Expression inside Aggregation function is casted to wrong type ## What changes were proposed in this pull request? This PR fixes the problem that Divide Expression inside Aggregation function is casted to wrong type, which cause `select 1/2` and `select sum(1/2)`returning different result. **Before the change:** ``` scala> sql("select 1/2 as a").show() +---+ | a| +---+ |0.5| +---+ scala> sql("select sum(1/2) as a").show() +---+ | a| +---+ |0 | +---+ scala> sql("select sum(1 / 2) as a").schema res4: org.apache.spark.sql.types.StructType = StructType(StructField(a,LongType,true)) ``` **After the change:** ``` scala> sql("select 1/2 as a").show() +---+ | a| +---+ |0.5| +---+ scala> sql("select sum(1/2) as a").show() +---+ | a| +---+ |0.5| +---+ scala> sql("select sum(1/2) as a").schema res4: org.apache.spark.sql.types.StructType = StructType(StructField(a,DoubleType,true)) ``` ## How was this patch tested? Unit test. This PR is based on https://github.com/apache/spark/pull/13524 by Sephiroth-Lin Author: Sean Zhong Closes #13651 from clockfly/SPARK-15776. --- .../spark/sql/catalyst/analysis/TypeCoercion.scala | 8 +++-- .../sql/catalyst/expressions/arithmetic.scala | 3 +- .../sql/catalyst/analysis/AnalysisSuite.scala | 32 +++++++++++++++++++ .../analysis/ExpressionTypeCheckingSuite.scala | 2 +- .../sql/catalyst/analysis/TypeCoercionSuite.scala | 37 ++++++++++++++++++++-- .../expressions/ArithmeticExpressionSuite.scala | 19 +++++------ .../plans/ConstraintPropagationSuite.scala | 4 +-- 7 files changed, 86 insertions(+), 19 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercion.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercion.scala index a5b5b91e4a..16df628a57 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercion.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercion.scala @@ -525,14 +525,16 @@ object TypeCoercion { def apply(plan: LogicalPlan): LogicalPlan = plan resolveExpressions { // Skip nodes who has not been resolved yet, // as this is an extra rule which should be applied at last. - case e if !e.resolved => e + case e if !e.childrenResolved => e // Decimal and Double remain the same case d: Divide if d.dataType == DoubleType => d case d: Divide if d.dataType.isInstanceOf[DecimalType] => d - - case Divide(left, right) => Divide(Cast(left, DoubleType), Cast(right, DoubleType)) + case Divide(left, right) if isNumeric(left) && isNumeric(right) => + Divide(Cast(left, DoubleType), Cast(right, DoubleType)) } + + private def isNumeric(ex: Expression): Boolean = ex.dataType.isInstanceOf[NumericType] } /** diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala index b2df79a588..4db1352291 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala @@ -213,7 +213,7 @@ case class Multiply(left: Expression, right: Expression) case class Divide(left: Expression, right: Expression) extends BinaryArithmetic with NullIntolerant { - override def inputType: AbstractDataType = NumericType + override def inputType: AbstractDataType = TypeCollection(DoubleType, DecimalType) override def symbol: String = "/" override def decimalMethod: String = "$div" @@ -221,7 +221,6 @@ case class Divide(left: Expression, right: Expression) private lazy val div: (Any, Any) => Any = dataType match { case ft: FractionalType => ft.fractional.asInstanceOf[Fractional[Any]].div - case it: IntegralType => it.integral.asInstanceOf[Integral[Any]].quot } override def eval(input: InternalRow): Any = { diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisSuite.scala index 77ea29ead9..102c78bd72 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisSuite.scala @@ -345,4 +345,36 @@ class AnalysisSuite extends AnalysisTest { assertAnalysisSuccess(query) } + + private def assertExpressionType( + expression: Expression, + expectedDataType: DataType): Unit = { + val afterAnalyze = + Project(Seq(Alias(expression, "a")()), OneRowRelation).analyze.expressions.head + if (!afterAnalyze.dataType.equals(expectedDataType)) { + fail( + s""" + |data type of expression $expression doesn't match expected: + |Actual data type: + |${afterAnalyze.dataType} + | + |Expected data type: + |${expectedDataType} + """.stripMargin) + } + } + + test("SPARK-15776: test whether Divide expression's data type can be deduced correctly by " + + "analyzer") { + assertExpressionType(sum(Divide(1, 2)), DoubleType) + assertExpressionType(sum(Divide(1.0, 2)), DoubleType) + assertExpressionType(sum(Divide(1, 2.0)), DoubleType) + assertExpressionType(sum(Divide(1.0, 2.0)), DoubleType) + assertExpressionType(sum(Divide(1, 2.0f)), DoubleType) + assertExpressionType(sum(Divide(1.0f, 2)), DoubleType) + assertExpressionType(sum(Divide(1, Decimal(2))), DecimalType(31, 11)) + assertExpressionType(sum(Divide(Decimal(1), 2)), DecimalType(31, 11)) + assertExpressionType(sum(Divide(Decimal(1), 2.0)), DoubleType) + assertExpressionType(sum(Divide(1.0, Decimal(2.0))), DoubleType) + } } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/ExpressionTypeCheckingSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/ExpressionTypeCheckingSuite.scala index 660dc86c3e..54436ea9a4 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/ExpressionTypeCheckingSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/ExpressionTypeCheckingSuite.scala @@ -85,7 +85,7 @@ class ExpressionTypeCheckingSuite extends SparkFunSuite { assertError(Subtract('booleanField, 'booleanField), "requires (numeric or calendarinterval) type") assertError(Multiply('booleanField, 'booleanField), "requires numeric type") - assertError(Divide('booleanField, 'booleanField), "requires numeric type") + assertError(Divide('booleanField, 'booleanField), "requires (double or decimal) type") assertError(Remainder('booleanField, 'booleanField), "requires numeric type") assertError(BitwiseAnd('booleanField, 'booleanField), "requires integral type") diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercionSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercionSuite.scala index 7435399b14..971c99b671 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercionSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercionSuite.scala @@ -19,10 +19,12 @@ package org.apache.spark.sql.catalyst.analysis import java.sql.Timestamp +import org.apache.spark.sql.catalyst.analysis.TypeCoercion.{Division, FunctionArgumentConversion} +import org.apache.spark.sql.catalyst.dsl.expressions._ import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.plans.PlanTest import org.apache.spark.sql.catalyst.plans.logical._ -import org.apache.spark.sql.catalyst.rules.Rule +import org.apache.spark.sql.catalyst.rules.{Rule, RuleExecutor} import org.apache.spark.sql.types._ import org.apache.spark.unsafe.types.CalendarInterval @@ -199,9 +201,20 @@ class TypeCoercionSuite extends PlanTest { } private def ruleTest(rule: Rule[LogicalPlan], initial: Expression, transformed: Expression) { + ruleTest(Seq(rule), initial, transformed) + } + + private def ruleTest( + rules: Seq[Rule[LogicalPlan]], + initial: Expression, + transformed: Expression): Unit = { val testRelation = LocalRelation(AttributeReference("a", IntegerType)()) + val analyzer = new RuleExecutor[LogicalPlan] { + override val batches = Seq(Batch("Resolution", FixedPoint(3), rules: _*)) + } + comparePlans( - rule(Project(Seq(Alias(initial, "a")()), testRelation)), + analyzer.execute(Project(Seq(Alias(initial, "a")()), testRelation)), Project(Seq(Alias(transformed, "a")()), testRelation)) } @@ -630,6 +643,26 @@ class TypeCoercionSuite extends PlanTest { Seq(Cast(Literal(1), StringType), Cast(Literal("b"), StringType))) ) } + + test("SPARK-15776 Divide expression's dataType should be casted to Double or Decimal " + + "in aggregation function like sum") { + val rules = Seq(FunctionArgumentConversion, Division) + // Casts Integer to Double + ruleTest(rules, sum(Divide(4, 3)), sum(Divide(Cast(4, DoubleType), Cast(3, DoubleType)))) + // Left expression is Double, right expression is Int. Another rule ImplicitTypeCasts will + // cast the right expression to Double. + ruleTest(rules, sum(Divide(4.0, 3)), sum(Divide(4.0, 3))) + // Left expression is Int, right expression is Double + ruleTest(rules, sum(Divide(4, 3.0)), sum(Divide(Cast(4, DoubleType), Cast(3.0, DoubleType)))) + // Casts Float to Double + ruleTest( + rules, + sum(Divide(4.0f, 3)), + sum(Divide(Cast(4.0f, DoubleType), Cast(3, DoubleType)))) + // Left expression is Decimal, right expression is Int. Another rule DecimalPrecision will cast + // the right expression to Decimal. + ruleTest(rules, sum(Divide(Decimal(4.0), 3)), sum(Divide(Decimal(4.0), 3))) + } } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ArithmeticExpressionSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ArithmeticExpressionSuite.scala index 72285c6a24..2e37887fbc 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ArithmeticExpressionSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ArithmeticExpressionSuite.scala @@ -117,8 +117,13 @@ class ArithmeticExpressionSuite extends SparkFunSuite with ExpressionEvalHelper } } + private def testDecimalAndDoubleType(testFunc: (Int => Any) => Unit): Unit = { + testFunc(_.toDouble) + testFunc(Decimal(_)) + } + test("/ (Divide) basic") { - testNumericDataTypes { convert => + testDecimalAndDoubleType { convert => val left = Literal(convert(2)) val right = Literal(convert(1)) val dataType = left.dataType @@ -128,12 +133,14 @@ class ArithmeticExpressionSuite extends SparkFunSuite with ExpressionEvalHelper checkEvaluation(Divide(left, Literal(convert(0))), null) // divide by zero } - DataTypeTestUtils.numericTypeWithoutDecimal.foreach { tpe => + Seq(DoubleType, DecimalType.SYSTEM_DEFAULT).foreach { tpe => checkConsistencyBetweenInterpretedAndCodegen(Divide, tpe, tpe) } } - test("/ (Divide) for integral type") { + // By fixing SPARK-15776, Divide's inputType is required to be DoubleType of DecimalType. + // TODO: in future release, we should add a IntegerDivide to support integral types. + ignore("/ (Divide) for integral type") { checkEvaluation(Divide(Literal(1.toByte), Literal(2.toByte)), 0.toByte) checkEvaluation(Divide(Literal(1.toShort), Literal(2.toShort)), 0.toShort) checkEvaluation(Divide(Literal(1), Literal(2)), 0) @@ -143,12 +150,6 @@ class ArithmeticExpressionSuite extends SparkFunSuite with ExpressionEvalHelper checkEvaluation(Divide(positiveLongLit, negativeLongLit), 0L) } - test("/ (Divide) for floating point") { - checkEvaluation(Divide(Literal(1.0f), Literal(2.0f)), 0.5f) - checkEvaluation(Divide(Literal(1.0), Literal(2.0)), 0.5) - checkEvaluation(Divide(Literal(Decimal(1.0)), Literal(Decimal(2.0))), Decimal(0.5)) - } - test("% (Remainder)") { testNumericDataTypes { convert => val left = Literal(convert(1)) diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/plans/ConstraintPropagationSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/plans/ConstraintPropagationSuite.scala index 81cc6b123c..0b73b5e009 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/plans/ConstraintPropagationSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/plans/ConstraintPropagationSuite.scala @@ -298,7 +298,7 @@ class ConstraintPropagationSuite extends SparkFunSuite { Cast(resolveColumn(tr, "a"), LongType) * resolveColumn(tr, "b") + Cast(100, LongType) === Cast(resolveColumn(tr, "c"), LongType), Cast(resolveColumn(tr, "d"), DoubleType) / - Cast(Cast(10, LongType), DoubleType) === + Cast(10, DoubleType) === Cast(resolveColumn(tr, "e"), DoubleType), IsNotNull(resolveColumn(tr, "a")), IsNotNull(resolveColumn(tr, "b")), @@ -312,7 +312,7 @@ class ConstraintPropagationSuite extends SparkFunSuite { Cast(resolveColumn(tr, "a"), LongType) * resolveColumn(tr, "b") - Cast(10, LongType) >= Cast(resolveColumn(tr, "c"), LongType), Cast(resolveColumn(tr, "d"), DoubleType) / - Cast(Cast(10, LongType), DoubleType) < + Cast(10, DoubleType) < Cast(resolveColumn(tr, "e"), DoubleType), IsNotNull(resolveColumn(tr, "a")), IsNotNull(resolveColumn(tr, "b")), -- cgit v1.2.3