 sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercion.scala      | 30
 sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/decimalFunctions.scala   | 13
 sql/catalyst/src/main/scala/org/apache/spark/sql/types/DecimalType.scala                       |  6
 sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/DecimalPrecisionSuite.scala | 26
 sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercionSuite.scala |  6
 5 files changed, 55 insertions(+), 26 deletions(-)
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercion.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercion.scala
index 87ffbfe791..e052750344 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercion.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercion.scala
@@ -19,9 +19,7 @@ package org.apache.spark.sql.catalyst.analysis
import javax.annotation.Nullable
-import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions._
-import org.apache.spark.sql.catalyst.expressions.codegen.{GeneratedExpressionCode, CodeGenContext}
import org.apache.spark.sql.catalyst.plans.logical._
import org.apache.spark.sql.catalyst.rules.Rule
import org.apache.spark.sql.types._
@@ -38,7 +36,7 @@ object HiveTypeCoercion {
val typeCoercionRules =
PropagateTypes ::
InConversion ::
- WidenTypes ::
+ WidenSetOperationTypes ::
PromoteStrings ::
DecimalPrecision ::
BooleanEquality ::
@@ -175,7 +173,7 @@ object HiveTypeCoercion {
*
* This rule is only applied to Union/Except/Intersect
*/
- object WidenTypes extends Rule[LogicalPlan] {
+ object WidenSetOperationTypes extends Rule[LogicalPlan] {
private[this] def widenOutputTypes(
planName: String,
@@ -203,9 +201,9 @@ object HiveTypeCoercion {
def castOutput(plan: LogicalPlan): LogicalPlan = {
val casted = plan.output.zip(castedTypes).map {
- case (hs, Some(dt)) if dt != hs.dataType =>
- Alias(Cast(hs, dt), hs.name)()
- case (hs, _) => hs
+ case (e, Some(dt)) if e.dataType != dt =>
+ Alias(Cast(e, dt), e.name)()
+ case (e, _) => e
}
Project(casted, plan)
}
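
castOutput rewrites one branch of a set operation so its columns match the widened row type: a Cast (re-aliased to the original column name) is inserted only where a column's type differs from the target, and the branch is wrapped in a Project. A minimal standalone sketch of that shape, using plain strings instead of Spark's attribute and type classes:

// Hypothetical stand-ins for Spark's attributes and data types.
case class Column(name: String, dataType: String)

// Insert a cast only when the widened type differs; otherwise keep the column.
def castOutput(output: Seq[Column], widened: Seq[Option[String]]): Seq[String] =
  output.zip(widened).map {
    case (c, Some(dt)) if c.dataType != dt => s"CAST(${c.name} AS $dt) AS ${c.name}"
    case (c, _) => c.name
  }

// A branch (i INT, f FLOAT) widened to (FLOAT, FLOAT): only the INT column
// needs a cast; the FLOAT column passes through untouched.
assert(castOutput(
  Seq(Column("i", "INT"), Column("f", "FLOAT")),
  Seq(Some("FLOAT"), Some("FLOAT"))) ==
  Seq("CAST(i AS FLOAT) AS i", "f"))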
@@ -355,20 +353,8 @@ object HiveTypeCoercion {
DecimalType.bounded(range + scale, scale)
}
- /**
- * An expression used to wrap the children when promote the precision of DecimalType to avoid
- * promote multiple times.
- */
- case class ChangePrecision(child: Expression) extends UnaryExpression {
- override def dataType: DataType = child.dataType
- override def eval(input: InternalRow): Any = child.eval(input)
- override def gen(ctx: CodeGenContext): GeneratedExpressionCode = child.gen(ctx)
- override protected def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = ""
- override def prettyName: String = "change_precision"
- }
-
- def changePrecision(e: Expression, dataType: DataType): Expression = {
- ChangePrecision(Cast(e, dataType))
+ private def changePrecision(e: Expression, dataType: DataType): Expression = {
+ ChangeDecimalPrecision(Cast(e, dataType))
}
def apply(plan: LogicalPlan): LogicalPlan = plan transform {
@@ -378,7 +364,7 @@ object HiveTypeCoercion {
case e if !e.childrenResolved => e
// Skip nodes that are already promoted
- case e: BinaryArithmetic if e.left.isInstanceOf[ChangePrecision] => e
+ case e: BinaryArithmetic if e.left.isInstanceOf[ChangeDecimalPrecision] => e
case Add(e1 @ DecimalType.Expression(p1, s1), e2 @ DecimalType.Expression(p2, s2)) =>
val dt = DecimalType.bounded(max(s1, s2) + max(p1 - s1, p2 - s2) + 1, max(s1, s2))
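
The Add rule above implements the usual SQL decimal-addition sizing: the result keeps the larger scale, enough integral digits for the wider operand, and one extra digit for a possible carry. A quick hand-check of that formula (a standalone sketch, not Spark code):

import scala.math.max

// Result (precision, scale) for DECIMAL(p1,s1) + DECIMAL(p2,s2), per the rule:
// scale = max(s1, s2); integral digits = max(p1 - s1, p2 - s2); +1 for carry.
def addResultType(p1: Int, s1: Int, p2: Int, s2: Int): (Int, Int) = {
  val scale = max(s1, s2)
  (max(p1 - s1, p2 - s2) + 1 + scale, scale)
}

// DECIMAL(5,2) + DECIMAL(10,4): integral = max(3, 6) = 6, scale = 4, so the
// sum needs DECIMAL(11,4) -- e.g. 999.99 + 999999.9999 = 1000999.9899.
assert(addResultType(5, 2, 10, 4) == Tuple2(11, 4))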
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/decimalFunctions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/decimalFunctions.scala
index b9d4736a65..adb33e4c8d 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/decimalFunctions.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/decimalFunctions.scala
@@ -17,6 +17,7 @@
package org.apache.spark.sql.catalyst.expressions
+import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions.codegen.{CodeGenContext, GeneratedExpressionCode}
import org.apache.spark.sql.types._
@@ -60,3 +61,15 @@ case class MakeDecimal(child: Expression, precision: Int, scale: Int) extends Un
})
}
}
+
+/**
+ * An expression used to wrap the children when promoting the precision of
+ * DecimalType, so that the promotion is not applied multiple times.
+ */
+case class ChangeDecimalPrecision(child: Expression) extends UnaryExpression {
+ override def dataType: DataType = child.dataType
+ override def eval(input: InternalRow): Any = child.eval(input)
+ override def gen(ctx: CodeGenContext): GeneratedExpressionCode = child.gen(ctx)
+ override protected def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = ""
+ override def prettyName: String = "change_decimal_precision"
+}
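
ChangeDecimalPrecision is a pure marker: eval and gen delegate straight to the child and genCode emits nothing, so the wrapper is free at runtime. Its only job is to let the DecimalPrecision rule recognize subtrees it has already rewritten and leave them alone, which keeps a transform-to-fixed-point analyzer from promoting the same node repeatedly. A minimal standalone sketch of that idempotence trick, on a hypothetical mini-AST rather than Spark's Expression API:

sealed trait Expr
case class Num(v: BigDecimal) extends Expr
case class Cast(child: Expr, precision: Int, scale: Int) extends Expr
case class Marker(child: Expr) extends Expr // pass-through, like ChangeDecimalPrecision
case class Plus(left: Expr, right: Expr) extends Expr

// The rewrite skips nodes whose left child is already marked, mirroring
// `case e: BinaryArithmetic if e.left.isInstanceOf[ChangeDecimalPrecision] => e`.
def promote(e: Expr): Expr = e match {
  case Plus(_: Marker, _) => e // already promoted, leave untouched
  case Plus(l, r) => Plus(Marker(Cast(l, 11, 4)), Marker(Cast(r, 11, 4)))
  case other => other
}

// Applying the rule twice equals applying it once, so iteration terminates.
val once = promote(Plus(Num(BigDecimal("1.5")), Num(BigDecimal("2.25"))))
assert(promote(once) == once)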
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/DecimalType.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/DecimalType.scala
index 26b24616d9..0cd352d0fa 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/DecimalType.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/DecimalType.scala
@@ -78,6 +78,10 @@ case class DecimalType(precision: Int, scale: Int) extends FractionalType {
override def toString: String = s"DecimalType($precision,$scale)"
+ /**
+ * Returns whether this DecimalType is wider than `other`. If yes, it means `other`
+ * can be cast into `this` safely without losing any precision or range.
+ */
private[sql] def isWiderThan(other: DataType): Boolean = other match {
case dt: DecimalType =>
(precision - scale) >= (dt.precision - dt.scale) && scale >= dt.scale
@@ -109,7 +113,7 @@ object DecimalType extends AbstractDataType {
@deprecated("Does not support unlimited precision, please specify the precision and scale", "1.5")
val Unlimited: DecimalType = SYSTEM_DEFAULT
- // The decimal types compatible with other numberic types
+ // The decimal types compatible with other numeric types
private[sql] val ByteDecimal = DecimalType(3, 0)
private[sql] val ShortDecimal = DecimalType(5, 0)
private[sql] val IntDecimal = DecimalType(10, 0)
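
isWiderThan compares the two components independently: this type must have at least as many integral digits (precision - scale) and at least as large a scale as `other`. For the primitive numeric types, the comparison presumably goes through the decimal equivalents above (that reading is an assumption here, but it matches the test expectations below). A quick hand-check of the decimal-vs-decimal branch:

// The DecimalType branch of isWiderThan, restated as a sketch.
def isWider(p1: Int, s1: Int, p2: Int, s2: Int): Boolean =
  (p1 - s1) >= (p2 - s2) && s1 >= s2

// DecimalType(5, 2) vs ByteDecimal = (3, 0): 3 integral digits cover ByteType's
// 3 and scale 2 >= 0, so DecimalType(5, 2).isWiderThan(ByteType) holds...
assert(isWider(5, 2, 3, 0))
// ...but ShortDecimal = (5, 0) needs 5 integral digits, and (5, 2) has only 3.
assert(!isWider(5, 2, 5, 0))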
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/DecimalPrecisionSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/DecimalPrecisionSuite.scala
index f9f15e7a66..fc11627da6 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/DecimalPrecisionSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/DecimalPrecisionSuite.scala
@@ -154,4 +154,30 @@ class DecimalPrecisionSuite extends SparkFunSuite with BeforeAndAfter {
checkType(Remainder(expr, u), DoubleType)
}
}
+
+ test("DecimalType.isWiderThan") {
+ val d0 = DecimalType(2, 0)
+ val d1 = DecimalType(2, 1)
+ val d2 = DecimalType(5, 2)
+ val d3 = DecimalType(15, 3)
+ val d4 = DecimalType(25, 4)
+
+ assert(d0.isWiderThan(d1) === false)
+ assert(d1.isWiderThan(d0) === false)
+ assert(d1.isWiderThan(d2) === false)
+ assert(d2.isWiderThan(d1) === true)
+ assert(d2.isWiderThan(d3) === false)
+ assert(d3.isWiderThan(d2) === true)
+ assert(d4.isWiderThan(d3) === true)
+
+ assert(d1.isWiderThan(ByteType) === false)
+ assert(d2.isWiderThan(ByteType) === true)
+ assert(d2.isWiderThan(ShortType) === false)
+ assert(d3.isWiderThan(ShortType) === true)
+ assert(d3.isWiderThan(IntegerType) === true)
+ assert(d3.isWiderThan(LongType) === false)
+ assert(d4.isWiderThan(LongType) === true)
+ assert(d4.isWiderThan(FloatType) === false)
+ assert(d4.isWiderThan(DoubleType) === false)
+ }
}
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercionSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercionSuite.scala
index 55865bdb53..4454d51b75 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercionSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercionSuite.scala
@@ -314,7 +314,7 @@ class HiveTypeCoercionSuite extends PlanTest {
)
}
- test("WidenTypes for union except and intersect") {
+ test("WidenSetOperationTypes for union except and intersect") {
def checkOutput(logical: LogicalPlan, expectTypes: Seq[DataType]): Unit = {
logical.output.zip(expectTypes).foreach { case (attr, dt) =>
assert(attr.dataType === dt)
@@ -332,7 +332,7 @@ class HiveTypeCoercionSuite extends PlanTest {
AttributeReference("f", FloatType)(),
AttributeReference("l", LongType)())
- val wt = HiveTypeCoercion.WidenTypes
+ val wt = HiveTypeCoercion.WidenSetOperationTypes
val expectedTypes = Seq(StringType, DecimalType.SYSTEM_DEFAULT, FloatType, DoubleType)
val r1 = wt(Union(left, right)).asInstanceOf[Union]
@@ -353,7 +353,7 @@ class HiveTypeCoercionSuite extends PlanTest {
}
}
- val dp = HiveTypeCoercion.WidenTypes
+ val dp = HiveTypeCoercion.WidenSetOperationTypes
val left1 = LocalRelation(
AttributeReference("l", DecimalType(10, 8))())