aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorSean Zhong <seanzhong@databricks.com>2016-06-15 14:34:15 -0700
committerWenchen Fan <wenchen@databricks.com>2016-06-15 14:34:15 -0700
commit9bd80ad6bd43462d16ce24cda77cdfaa336c4e02 (patch)
tree0db8b623ae893265d94ea487c6bb1be549203fee
parent049e639fc22100af3cbe2ffcf3a2fd3ae2461373 (diff)
downloadspark-9bd80ad6bd43462d16ce24cda77cdfaa336c4e02.tar.gz
spark-9bd80ad6bd43462d16ce24cda77cdfaa336c4e02.tar.bz2
spark-9bd80ad6bd43462d16ce24cda77cdfaa336c4e02.zip
[SPARK-15776][SQL] Divide Expression inside Aggregation function is casted to wrong type
## What changes were proposed in this pull request? This PR fixes the problem that Divide Expression inside Aggregation function is cast to the wrong type, which causes `select 1/2` and `select sum(1/2)` returning different results. **Before the change:** ``` scala> sql("select 1/2 as a").show() +---+ | a| +---+ |0.5| +---+ scala> sql("select sum(1/2) as a").show() +---+ | a| +---+ |0 | +---+ scala> sql("select sum(1 / 2) as a").schema res4: org.apache.spark.sql.types.StructType = StructType(StructField(a,LongType,true)) ``` **After the change:** ``` scala> sql("select 1/2 as a").show() +---+ | a| +---+ |0.5| +---+ scala> sql("select sum(1/2) as a").show() +---+ | a| +---+ |0.5| +---+ scala> sql("select sum(1/2) as a").schema res4: org.apache.spark.sql.types.StructType = StructType(StructField(a,DoubleType,true)) ``` ## How was this patch tested? Unit test. This PR is based on https://github.com/apache/spark/pull/13524 by Sephiroth-Lin Author: Sean Zhong <seanzhong@databricks.com> Closes #13651 from clockfly/SPARK-15776.
-rw-r--r--sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercion.scala8
-rw-r--r--sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala3
-rw-r--r--sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisSuite.scala32
-rw-r--r--sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/ExpressionTypeCheckingSuite.scala2
-rw-r--r--sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercionSuite.scala37
-rw-r--r--sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ArithmeticExpressionSuite.scala19
-rw-r--r--sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/plans/ConstraintPropagationSuite.scala4
7 files changed, 86 insertions, 19 deletions
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercion.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercion.scala
index a5b5b91e4a..16df628a57 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercion.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercion.scala
@@ -525,14 +525,16 @@ object TypeCoercion {
def apply(plan: LogicalPlan): LogicalPlan = plan resolveExpressions {
// Skip nodes that have not been resolved yet,
// as this is an extra rule which should be applied at last.
- case e if !e.resolved => e
+ case e if !e.childrenResolved => e
// Decimal and Double remain the same
case d: Divide if d.dataType == DoubleType => d
case d: Divide if d.dataType.isInstanceOf[DecimalType] => d
-
- case Divide(left, right) => Divide(Cast(left, DoubleType), Cast(right, DoubleType))
+ case Divide(left, right) if isNumeric(left) && isNumeric(right) =>
+ Divide(Cast(left, DoubleType), Cast(right, DoubleType))
}
+
+ private def isNumeric(ex: Expression): Boolean = ex.dataType.isInstanceOf[NumericType]
}
/**
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala
index b2df79a588..4db1352291 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala
@@ -213,7 +213,7 @@ case class Multiply(left: Expression, right: Expression)
case class Divide(left: Expression, right: Expression)
extends BinaryArithmetic with NullIntolerant {
- override def inputType: AbstractDataType = NumericType
+ override def inputType: AbstractDataType = TypeCollection(DoubleType, DecimalType)
override def symbol: String = "/"
override def decimalMethod: String = "$div"
@@ -221,7 +221,6 @@ case class Divide(left: Expression, right: Expression)
private lazy val div: (Any, Any) => Any = dataType match {
case ft: FractionalType => ft.fractional.asInstanceOf[Fractional[Any]].div
- case it: IntegralType => it.integral.asInstanceOf[Integral[Any]].quot
}
override def eval(input: InternalRow): Any = {
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisSuite.scala
index 77ea29ead9..102c78bd72 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisSuite.scala
@@ -345,4 +345,36 @@ class AnalysisSuite extends AnalysisTest {
assertAnalysisSuccess(query)
}
+
+ private def assertExpressionType(
+ expression: Expression,
+ expectedDataType: DataType): Unit = {
+ val afterAnalyze =
+ Project(Seq(Alias(expression, "a")()), OneRowRelation).analyze.expressions.head
+ if (!afterAnalyze.dataType.equals(expectedDataType)) {
+ fail(
+ s"""
+ |data type of expression $expression doesn't match expected:
+ |Actual data type:
+ |${afterAnalyze.dataType}
+ |
+ |Expected data type:
+ |${expectedDataType}
+ """.stripMargin)
+ }
+ }
+
+ test("SPARK-15776: test whether Divide expression's data type can be deduced correctly by " +
+ "analyzer") {
+ assertExpressionType(sum(Divide(1, 2)), DoubleType)
+ assertExpressionType(sum(Divide(1.0, 2)), DoubleType)
+ assertExpressionType(sum(Divide(1, 2.0)), DoubleType)
+ assertExpressionType(sum(Divide(1.0, 2.0)), DoubleType)
+ assertExpressionType(sum(Divide(1, 2.0f)), DoubleType)
+ assertExpressionType(sum(Divide(1.0f, 2)), DoubleType)
+ assertExpressionType(sum(Divide(1, Decimal(2))), DecimalType(31, 11))
+ assertExpressionType(sum(Divide(Decimal(1), 2)), DecimalType(31, 11))
+ assertExpressionType(sum(Divide(Decimal(1), 2.0)), DoubleType)
+ assertExpressionType(sum(Divide(1.0, Decimal(2.0))), DoubleType)
+ }
}
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/ExpressionTypeCheckingSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/ExpressionTypeCheckingSuite.scala
index 660dc86c3e..54436ea9a4 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/ExpressionTypeCheckingSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/ExpressionTypeCheckingSuite.scala
@@ -85,7 +85,7 @@ class ExpressionTypeCheckingSuite extends SparkFunSuite {
assertError(Subtract('booleanField, 'booleanField),
"requires (numeric or calendarinterval) type")
assertError(Multiply('booleanField, 'booleanField), "requires numeric type")
- assertError(Divide('booleanField, 'booleanField), "requires numeric type")
+ assertError(Divide('booleanField, 'booleanField), "requires (double or decimal) type")
assertError(Remainder('booleanField, 'booleanField), "requires numeric type")
assertError(BitwiseAnd('booleanField, 'booleanField), "requires integral type")
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercionSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercionSuite.scala
index 7435399b14..971c99b671 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercionSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercionSuite.scala
@@ -19,10 +19,12 @@ package org.apache.spark.sql.catalyst.analysis
import java.sql.Timestamp
+import org.apache.spark.sql.catalyst.analysis.TypeCoercion.{Division, FunctionArgumentConversion}
+import org.apache.spark.sql.catalyst.dsl.expressions._
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.plans.PlanTest
import org.apache.spark.sql.catalyst.plans.logical._
-import org.apache.spark.sql.catalyst.rules.Rule
+import org.apache.spark.sql.catalyst.rules.{Rule, RuleExecutor}
import org.apache.spark.sql.types._
import org.apache.spark.unsafe.types.CalendarInterval
@@ -199,9 +201,20 @@ class TypeCoercionSuite extends PlanTest {
}
private def ruleTest(rule: Rule[LogicalPlan], initial: Expression, transformed: Expression) {
+ ruleTest(Seq(rule), initial, transformed)
+ }
+
+ private def ruleTest(
+ rules: Seq[Rule[LogicalPlan]],
+ initial: Expression,
+ transformed: Expression): Unit = {
val testRelation = LocalRelation(AttributeReference("a", IntegerType)())
+ val analyzer = new RuleExecutor[LogicalPlan] {
+ override val batches = Seq(Batch("Resolution", FixedPoint(3), rules: _*))
+ }
+
comparePlans(
- rule(Project(Seq(Alias(initial, "a")()), testRelation)),
+ analyzer.execute(Project(Seq(Alias(initial, "a")()), testRelation)),
Project(Seq(Alias(transformed, "a")()), testRelation))
}
@@ -630,6 +643,26 @@ class TypeCoercionSuite extends PlanTest {
Seq(Cast(Literal(1), StringType), Cast(Literal("b"), StringType)))
)
}
+
+ test("SPARK-15776 Divide expression's dataType should be casted to Double or Decimal " +
+ "in aggregation function like sum") {
+ val rules = Seq(FunctionArgumentConversion, Division)
+ // Casts Integer to Double
+ ruleTest(rules, sum(Divide(4, 3)), sum(Divide(Cast(4, DoubleType), Cast(3, DoubleType))))
+ // Left expression is Double, right expression is Int. Another rule ImplicitTypeCasts will
+ // cast the right expression to Double.
+ ruleTest(rules, sum(Divide(4.0, 3)), sum(Divide(4.0, 3)))
+ // Left expression is Int, right expression is Double
+ ruleTest(rules, sum(Divide(4, 3.0)), sum(Divide(Cast(4, DoubleType), Cast(3.0, DoubleType))))
+ // Casts Float to Double
+ ruleTest(
+ rules,
+ sum(Divide(4.0f, 3)),
+ sum(Divide(Cast(4.0f, DoubleType), Cast(3, DoubleType))))
+ // Left expression is Decimal, right expression is Int. Another rule DecimalPrecision will cast
+ // the right expression to Decimal.
+ ruleTest(rules, sum(Divide(Decimal(4.0), 3)), sum(Divide(Decimal(4.0), 3)))
+ }
}
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ArithmeticExpressionSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ArithmeticExpressionSuite.scala
index 72285c6a24..2e37887fbc 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ArithmeticExpressionSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ArithmeticExpressionSuite.scala
@@ -117,8 +117,13 @@ class ArithmeticExpressionSuite extends SparkFunSuite with ExpressionEvalHelper
}
}
+ private def testDecimalAndDoubleType(testFunc: (Int => Any) => Unit): Unit = {
+ testFunc(_.toDouble)
+ testFunc(Decimal(_))
+ }
+
test("/ (Divide) basic") {
- testNumericDataTypes { convert =>
+ testDecimalAndDoubleType { convert =>
val left = Literal(convert(2))
val right = Literal(convert(1))
val dataType = left.dataType
@@ -128,12 +133,14 @@ class ArithmeticExpressionSuite extends SparkFunSuite with ExpressionEvalHelper
checkEvaluation(Divide(left, Literal(convert(0))), null) // divide by zero
}
- DataTypeTestUtils.numericTypeWithoutDecimal.foreach { tpe =>
+ Seq(DoubleType, DecimalType.SYSTEM_DEFAULT).foreach { tpe =>
checkConsistencyBetweenInterpretedAndCodegen(Divide, tpe, tpe)
}
}
- test("/ (Divide) for integral type") {
- // By fixing SPARK-15776, Divide's inputType is required to be DoubleType or DecimalType.
- // TODO: in future release, we should add an IntegerDivide to support integral types.
+ ignore("/ (Divide) for integral type") {
checkEvaluation(Divide(Literal(1.toByte), Literal(2.toByte)), 0.toByte)
checkEvaluation(Divide(Literal(1.toShort), Literal(2.toShort)), 0.toShort)
checkEvaluation(Divide(Literal(1), Literal(2)), 0)
@@ -143,12 +150,6 @@ class ArithmeticExpressionSuite extends SparkFunSuite with ExpressionEvalHelper
checkEvaluation(Divide(positiveLongLit, negativeLongLit), 0L)
}
- test("/ (Divide) for floating point") {
- checkEvaluation(Divide(Literal(1.0f), Literal(2.0f)), 0.5f)
- checkEvaluation(Divide(Literal(1.0), Literal(2.0)), 0.5)
- checkEvaluation(Divide(Literal(Decimal(1.0)), Literal(Decimal(2.0))), Decimal(0.5))
- }
-
test("% (Remainder)") {
testNumericDataTypes { convert =>
val left = Literal(convert(1))
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/plans/ConstraintPropagationSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/plans/ConstraintPropagationSuite.scala
index 81cc6b123c..0b73b5e009 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/plans/ConstraintPropagationSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/plans/ConstraintPropagationSuite.scala
@@ -298,7 +298,7 @@ class ConstraintPropagationSuite extends SparkFunSuite {
Cast(resolveColumn(tr, "a"), LongType) * resolveColumn(tr, "b") + Cast(100, LongType) ===
Cast(resolveColumn(tr, "c"), LongType),
Cast(resolveColumn(tr, "d"), DoubleType) /
- Cast(Cast(10, LongType), DoubleType) ===
+ Cast(10, DoubleType) ===
Cast(resolveColumn(tr, "e"), DoubleType),
IsNotNull(resolveColumn(tr, "a")),
IsNotNull(resolveColumn(tr, "b")),
@@ -312,7 +312,7 @@ class ConstraintPropagationSuite extends SparkFunSuite {
Cast(resolveColumn(tr, "a"), LongType) * resolveColumn(tr, "b") - Cast(10, LongType) >=
Cast(resolveColumn(tr, "c"), LongType),
Cast(resolveColumn(tr, "d"), DoubleType) /
- Cast(Cast(10, LongType), DoubleType) <
+ Cast(10, DoubleType) <
Cast(resolveColumn(tr, "e"), DoubleType),
IsNotNull(resolveColumn(tr, "a")),
IsNotNull(resolveColumn(tr, "b")),