From e9e2c612d58a19ddcb4b6abfb7389a4b0f7ef6f8 Mon Sep 17 00:00:00 2001
From: Wojtek Szymanski
Date: Wed, 8 Mar 2017 12:36:16 -0800
Subject: [SPARK-19727][SQL] Fix for round function that modifies original
 column

## What changes were proposed in this pull request?

Fix for the SQL `round` function, which modifies the original column when the
underlying data frame is created from a local product.

    import org.apache.spark.sql.functions._

    case class NumericRow(value: BigDecimal)

    val df = spark.createDataFrame(Seq(NumericRow(BigDecimal("1.23456789"))))

    df.show()
    +--------------------+
    |               value|
    +--------------------+
    |1.234567890000000000|
    +--------------------+

    df.withColumn("value_rounded", round('value)).show()

    // before
    +--------------------+-------------+
    |               value|value_rounded|
    +--------------------+-------------+
    |1.000000000000000000|            1|
    +--------------------+-------------+

    // after
    +--------------------+-------------+
    |               value|value_rounded|
    +--------------------+-------------+
    |1.234567890000000000|            1|
    +--------------------+-------------+

## How was this patch tested?

A new unit test was added to the existing suite
`org.apache.spark.sql.MathFunctionsSuite`.

Author: Wojtek Szymanski

Closes #17075 from wojtek-szymanski/SPARK-19727.
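A note on the root cause, which the diff below addresses: `Decimal` is a
mutable wrapper, and its `changePrecision()` rewrites the receiver in place,
so the old `round` code path rounded the very instance that the locally
created row still referenced. A minimal sketch of that failure mode, using
only public `Decimal` API; the variable names are illustrative, not from the
patch:

    import org.apache.spark.sql.types.Decimal

    // One Decimal instance shared by two holders, as happens when a local
    // row and a rounding expression both reference the same converted value.
    val original = Decimal(BigDecimal("1.23456789"), 18, 8)
    val shared = original

    // changePrecision() mutates the receiver in place (ROUND_HALF_UP by
    // default), so the other holder observes the rounded value too.
    shared.changePrecision(18, 0)
    println(original)  // prints 1, not 1.23456789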
---
 .../sql/catalyst/CatalystTypeConverters.scala      |  6 +----
 .../spark/sql/catalyst/expressions/Cast.scala      | 13 ++++++++--
 .../catalyst/expressions/decimalExpressions.scala  | 10 ++------
 .../sql/catalyst/expressions/mathExpressions.scala |  2 +-
 .../scala/org/apache/spark/sql/types/Decimal.scala | 28 +++++++++++++++-------
 .../org/apache/spark/sql/types/DecimalSuite.scala  |  8 ++++++-
 6 files changed, 42 insertions(+), 25 deletions(-)

diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/CatalystTypeConverters.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/CatalystTypeConverters.scala
index 5b9161551a..d4ebdb139f 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/CatalystTypeConverters.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/CatalystTypeConverters.scala
@@ -310,11 +310,7 @@ object CatalystTypeConverters {
         case d: JavaBigInteger => Decimal(d)
         case d: Decimal => d
       }
-      if (decimal.changePrecision(dataType.precision, dataType.scale)) {
-        decimal
-      } else {
-        null
-      }
+      decimal.toPrecision(dataType.precision, dataType.scale).orNull
     }
     override def toScala(catalystValue: Decimal): JavaBigDecimal = {
       if (catalystValue == null) null
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala
index 7c60f7d57a..1049915986 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala
@@ -352,6 +352,15 @@ case class Cast(child: Expression, dataType: DataType, timeZoneId: Option[String
     if (value.changePrecision(decimalType.precision, decimalType.scale)) value else null
   }
 
+  /**
+   * Create new `Decimal` with precision and scale given in `decimalType` (if any),
+   * returning null if it overflows or creating a new `value` and returning it if successful.
+   *
+   */
+  private[this] def toPrecision(value: Decimal, decimalType: DecimalType): Decimal =
+    value.toPrecision(decimalType.precision, decimalType.scale).orNull
+
+
   private[this] def castToDecimal(from: DataType, target: DecimalType): Any => Any = from match {
     case StringType =>
       buildCast[UTF8String](_, s => try {
@@ -360,14 +369,14 @@ case class Cast(child: Expression, dataType: DataType, timeZoneId: Option[String
         case _: NumberFormatException => null
       })
     case BooleanType =>
-      buildCast[Boolean](_, b => changePrecision(if (b) Decimal.ONE else Decimal.ZERO, target))
+      buildCast[Boolean](_, b => toPrecision(if (b) Decimal.ONE else Decimal.ZERO, target))
     case DateType =>
       buildCast[Int](_, d => null) // date can't cast to decimal in Hive
     case TimestampType =>
       // Note that we lose precision here.
       buildCast[Long](_, t => changePrecision(Decimal(timestampToDouble(t)), target))
     case dt: DecimalType =>
-      b => changePrecision(b.asInstanceOf[Decimal].clone(), target)
+      b => toPrecision(b.asInstanceOf[Decimal], target)
     case t: IntegralType =>
       b => changePrecision(Decimal(t.integral.asInstanceOf[Integral[Any]].toLong(b)), target)
     case x: FractionalType =>
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/decimalExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/decimalExpressions.scala
index fa5dea6841..c2211ae5d5 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/decimalExpressions.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/decimalExpressions.scala
@@ -84,14 +84,8 @@ case class CheckOverflow(child: Expression, dataType: DecimalType) extends Unary
 
   override def nullable: Boolean = true
 
-  override def nullSafeEval(input: Any): Any = {
-    val d = input.asInstanceOf[Decimal].clone()
-    if (d.changePrecision(dataType.precision, dataType.scale)) {
-      d
-    } else {
-      null
-    }
-  }
+  override def nullSafeEval(input: Any): Any =
+    input.asInstanceOf[Decimal].toPrecision(dataType.precision, dataType.scale).orNull
 
   override protected def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
     nullSafeCodeGen(ctx, ev, eval => {
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/mathExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/mathExpressions.scala
index 65273a77b1..dea5f85cb0 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/mathExpressions.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/mathExpressions.scala
@@ -1024,7 +1024,7 @@ abstract class RoundBase(child: Expression, scale: Expression,
     child.dataType match {
       case _: DecimalType =>
         val decimal = input1.asInstanceOf[Decimal]
-        if (decimal.changePrecision(decimal.precision, _scale, mode)) decimal else null
+        decimal.toPrecision(decimal.precision, _scale, mode).orNull
      case ByteType =>
        BigDecimal(input1.asInstanceOf[Byte]).setScale(_scale, mode).toByte
      case ShortType =>
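The three call sites changed above all funnel into the new `Decimal.toPrecision`
helper added in the next file. A small sketch of its contract; since the helper
is `private[sql]`, this would have to live under Spark's `org.apache.spark.sql`
package tree, and the values are illustrative:

    import org.apache.spark.sql.types.Decimal

    val d = Decimal(BigDecimal("123.45"), 5, 2)
    d.toPrecision(5, 2)          // Some(123.45): fits, returns a fresh copy
    d.toPrecision(4, 2)          // None: five significant digits do not fit
    d.toPrecision(4, 2).orNull   // null, matching the old `else null` branches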
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/Decimal.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/Decimal.scala
index 089c84d5f7..e8f6884c02 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/Decimal.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/Decimal.scala
@@ -21,6 +21,7 @@ import java.lang.{Long => JLong}
 import java.math.{BigInteger, MathContext, RoundingMode}
 
 import org.apache.spark.annotation.InterfaceStability
+import org.apache.spark.sql.AnalysisException
 
 /**
  * A mutable implementation of BigDecimal that can hold a Long if values are small enough.
@@ -222,6 +223,19 @@ final class Decimal extends Ordered[Decimal] with Serializable {
       case java.math.BigDecimal.ROUND_HALF_EVEN => changePrecision(precision, scale, ROUND_HALF_EVEN)
   }
 
+  /**
+   * Create new `Decimal` with given precision and scale.
+   *
+   * @return `Some(decimal)` if successful or `None` if overflow would occur
+   */
+  private[sql] def toPrecision(
+      precision: Int,
+      scale: Int,
+      roundMode: BigDecimal.RoundingMode.Value = ROUND_HALF_UP): Option[Decimal] = {
+    val copy = clone()
+    if (copy.changePrecision(precision, scale, roundMode)) Some(copy) else None
+  }
+
   /**
    * Update precision and scale while keeping our value the same, and return true if successful.
    *
@@ -362,17 +376,15 @@ final class Decimal extends Ordered[Decimal] with Serializable {
 
   def abs: Decimal = if (this.compare(Decimal.ZERO) < 0) this.unary_- else this
 
   def floor: Decimal = if (scale == 0) this else {
-    val value = this.clone()
-    value.changePrecision(
-      DecimalType.bounded(precision - scale + 1, 0).precision, 0, ROUND_FLOOR)
-    value
+    val newPrecision = DecimalType.bounded(precision - scale + 1, 0).precision
+    toPrecision(newPrecision, 0, ROUND_FLOOR).getOrElse(
+      throw new AnalysisException(s"Overflow when setting precision to $newPrecision"))
   }
 
   def ceil: Decimal = if (scale == 0) this else {
-    val value = this.clone()
-    value.changePrecision(
-      DecimalType.bounded(precision - scale + 1, 0).precision, 0, ROUND_CEILING)
-    value
+    val newPrecision = DecimalType.bounded(precision - scale + 1, 0).precision
+    toPrecision(newPrecision, 0, ROUND_CEILING).getOrElse(
+      throw new AnalysisException(s"Overflow when setting precision to $newPrecision"))
   }
 }
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/types/DecimalSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/types/DecimalSuite.scala
index 52d0692524..714883a409 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/types/DecimalSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/types/DecimalSuite.scala
@@ -193,7 +193,7 @@ class DecimalSuite extends SparkFunSuite with PrivateMethodTester {
     assert(Decimal(Long.MaxValue, 100, 0).toUnscaledLong === Long.MaxValue)
   }
 
-  test("changePrecision() on compact decimal should respect rounding mode") {
+  test("changePrecision/toPrecision on compact decimal should respect rounding mode") {
     Seq(ROUND_FLOOR, ROUND_CEILING, ROUND_HALF_UP, ROUND_HALF_EVEN).foreach { mode =>
       Seq("0.4", "0.5", "0.6", "1.0", "1.1", "1.6", "2.5", "5.5").foreach { n =>
         Seq("", "-").foreach { sign =>
@@ -202,6 +202,12 @@ class DecimalSuite extends SparkFunSuite with PrivateMethodTester {
           val d = Decimal(unscaled, 8, 1)
           assert(d.changePrecision(10, 0, mode))
           assert(d.toString === bd.setScale(0, mode).toString(), s"num: $sign$n, mode: $mode")
+
+          val copy = d.toPrecision(10, 0, mode).orNull
+          assert(copy !== null)
+          assert(d.ne(copy))
+          assert(d === copy)
+          assert(copy.toString === bd.setScale(0, mode).toString(), s"num: $sign$n, mode: $mode")
         }
       }
     }
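Worth noting about the new assertions: `d === copy` holds only because the
preceding `changePrecision(10, 0, mode)` call already rounded `d` in place, so
`d.ne(copy)` is the key check, proving `toPrecision` returned a distinct
object. The fixed contract shows more directly on a fresh value; a sketch,
again assuming the code sits in the `org.apache.spark.sql.types` package so
the `private[sql]` helper is visible, with illustrative values:

    import org.apache.spark.sql.types.Decimal
    import org.apache.spark.sql.types.Decimal.ROUND_HALF_UP

    val d = Decimal(145, 8, 1)  // 14.5, held in compact long form
    val copy = d.toPrecision(10, 0, ROUND_HALF_UP).orNull

    assert(copy != null)            // rounding succeeded, no overflow
    assert(d ne copy)               // a distinct object was returned
    assert(d.toString == "14.5")    // the receiver was left untouched
    assert(copy.toString == "15")   // the copy carries the rounded value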