diff options
author | Wojtek Szymanski <wk.szymanski@gmail.com> | 2017-03-08 12:36:16 -0800 |
---|---|---|
committer | Wenchen Fan <wenchen@databricks.com> | 2017-03-08 12:36:16 -0800 |
commit | e9e2c612d58a19ddcb4b6abfb7389a4b0f7ef6f8 (patch) | |
tree | 31f22bac0755a6384fef07155531f77423b242af /sql | |
parent | f3387d97487cbef894b6963bc008f6a5c4294a85 (diff) | |
download | spark-e9e2c612d58a19ddcb4b6abfb7389a4b0f7ef6f8.tar.gz spark-e9e2c612d58a19ddcb4b6abfb7389a4b0f7ef6f8.tar.bz2 spark-e9e2c612d58a19ddcb4b6abfb7389a4b0f7ef6f8.zip |
[SPARK-19727][SQL] Fix for round function that modifies original column
## What changes were proposed in this pull request?
Fix for the SQL `round` function, which modified the original column in place when the underlying data frame was created from a local collection of Products (e.g. case classes).
import org.apache.spark.sql.functions._
case class NumericRow(value: BigDecimal)
val df = spark.createDataFrame(Seq(NumericRow(BigDecimal("1.23456789"))))
df.show()
+--------------------+
| value|
+--------------------+
|1.234567890000000000|
+--------------------+
df.withColumn("value_rounded", round('value)).show()
// before
+--------------------+-------------+
| value|value_rounded|
+--------------------+-------------+
|1.000000000000000000| 1|
+--------------------+-------------+
// after
+--------------------+-------------+
| value|value_rounded|
+--------------------+-------------+
|1.234567890000000000| 1|
+--------------------+-------------+
## How was this patch tested?
A new unit test was added to the existing suite `org.apache.spark.sql.MathFunctionsSuite`.
Author: Wojtek Szymanski <wk.szymanski@gmail.com>
Closes #17075 from wojtek-szymanski/SPARK-19727.
Diffstat (limited to 'sql')
7 files changed, 54 insertions, 25 deletions
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/CatalystTypeConverters.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/CatalystTypeConverters.scala index 5b9161551a..d4ebdb139f 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/CatalystTypeConverters.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/CatalystTypeConverters.scala @@ -310,11 +310,7 @@ object CatalystTypeConverters { case d: JavaBigInteger => Decimal(d) case d: Decimal => d } - if (decimal.changePrecision(dataType.precision, dataType.scale)) { - decimal - } else { - null - } + decimal.toPrecision(dataType.precision, dataType.scale).orNull } override def toScala(catalystValue: Decimal): JavaBigDecimal = { if (catalystValue == null) null diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala index 7c60f7d57a..1049915986 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala @@ -352,6 +352,15 @@ case class Cast(child: Expression, dataType: DataType, timeZoneId: Option[String if (value.changePrecision(decimalType.precision, decimalType.scale)) value else null } + /** + * Create new `Decimal` with precision and scale given in `decimalType` (if any), + * returning null if it overflows or creating a new `value` and returning it if successful. 
+ * + */ + private[this] def toPrecision(value: Decimal, decimalType: DecimalType): Decimal = + value.toPrecision(decimalType.precision, decimalType.scale).orNull + + private[this] def castToDecimal(from: DataType, target: DecimalType): Any => Any = from match { case StringType => buildCast[UTF8String](_, s => try { @@ -360,14 +369,14 @@ case class Cast(child: Expression, dataType: DataType, timeZoneId: Option[String case _: NumberFormatException => null }) case BooleanType => - buildCast[Boolean](_, b => changePrecision(if (b) Decimal.ONE else Decimal.ZERO, target)) + buildCast[Boolean](_, b => toPrecision(if (b) Decimal.ONE else Decimal.ZERO, target)) case DateType => buildCast[Int](_, d => null) // date can't cast to decimal in Hive case TimestampType => // Note that we lose precision here. buildCast[Long](_, t => changePrecision(Decimal(timestampToDouble(t)), target)) case dt: DecimalType => - b => changePrecision(b.asInstanceOf[Decimal].clone(), target) + b => toPrecision(b.asInstanceOf[Decimal], target) case t: IntegralType => b => changePrecision(Decimal(t.integral.asInstanceOf[Integral[Any]].toLong(b)), target) case x: FractionalType => diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/decimalExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/decimalExpressions.scala index fa5dea6841..c2211ae5d5 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/decimalExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/decimalExpressions.scala @@ -84,14 +84,8 @@ case class CheckOverflow(child: Expression, dataType: DecimalType) extends Unary override def nullable: Boolean = true - override def nullSafeEval(input: Any): Any = { - val d = input.asInstanceOf[Decimal].clone() - if (d.changePrecision(dataType.precision, dataType.scale)) { - d - } else { - null - } - } + override def nullSafeEval(input: Any): Any = + 
input.asInstanceOf[Decimal].toPrecision(dataType.precision, dataType.scale).orNull override protected def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { nullSafeCodeGen(ctx, ev, eval => { diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/mathExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/mathExpressions.scala index 65273a77b1..dea5f85cb0 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/mathExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/mathExpressions.scala @@ -1024,7 +1024,7 @@ abstract class RoundBase(child: Expression, scale: Expression, child.dataType match { case _: DecimalType => val decimal = input1.asInstanceOf[Decimal] - if (decimal.changePrecision(decimal.precision, _scale, mode)) decimal else null + decimal.toPrecision(decimal.precision, _scale, mode).orNull case ByteType => BigDecimal(input1.asInstanceOf[Byte]).setScale(_scale, mode).toByte case ShortType => diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/Decimal.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/Decimal.scala index 089c84d5f7..e8f6884c02 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/Decimal.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/Decimal.scala @@ -21,6 +21,7 @@ import java.lang.{Long => JLong} import java.math.{BigInteger, MathContext, RoundingMode} import org.apache.spark.annotation.InterfaceStability +import org.apache.spark.sql.AnalysisException /** * A mutable implementation of BigDecimal that can hold a Long if values are small enough. @@ -223,6 +224,19 @@ final class Decimal extends Ordered[Decimal] with Serializable { } /** + * Create new `Decimal` with given precision and scale. 
+ * + * @return `Some(decimal)` if successful or `None` if overflow would occur + */ + private[sql] def toPrecision( + precision: Int, + scale: Int, + roundMode: BigDecimal.RoundingMode.Value = ROUND_HALF_UP): Option[Decimal] = { + val copy = clone() + if (copy.changePrecision(precision, scale, roundMode)) Some(copy) else None + } + + /** * Update precision and scale while keeping our value the same, and return true if successful. * * @return true if successful, false if overflow would occur @@ -362,17 +376,15 @@ final class Decimal extends Ordered[Decimal] with Serializable { def abs: Decimal = if (this.compare(Decimal.ZERO) < 0) this.unary_- else this def floor: Decimal = if (scale == 0) this else { - val value = this.clone() - value.changePrecision( - DecimalType.bounded(precision - scale + 1, 0).precision, 0, ROUND_FLOOR) - value + val newPrecision = DecimalType.bounded(precision - scale + 1, 0).precision + toPrecision(newPrecision, 0, ROUND_FLOOR).getOrElse( + throw new AnalysisException(s"Overflow when setting precision to $newPrecision")) } def ceil: Decimal = if (scale == 0) this else { - val value = this.clone() - value.changePrecision( - DecimalType.bounded(precision - scale + 1, 0).precision, 0, ROUND_CEILING) - value + val newPrecision = DecimalType.bounded(precision - scale + 1, 0).precision + toPrecision(newPrecision, 0, ROUND_CEILING).getOrElse( + throw new AnalysisException(s"Overflow when setting precision to $newPrecision")) } } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/types/DecimalSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/types/DecimalSuite.scala index 52d0692524..714883a409 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/types/DecimalSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/types/DecimalSuite.scala @@ -193,7 +193,7 @@ class DecimalSuite extends SparkFunSuite with PrivateMethodTester { assert(Decimal(Long.MaxValue, 100, 0).toUnscaledLong === Long.MaxValue) } - 
test("changePrecision() on compact decimal should respect rounding mode") { + test("changePrecision/toPrecision on compact decimal should respect rounding mode") { Seq(ROUND_FLOOR, ROUND_CEILING, ROUND_HALF_UP, ROUND_HALF_EVEN).foreach { mode => Seq("0.4", "0.5", "0.6", "1.0", "1.1", "1.6", "2.5", "5.5").foreach { n => Seq("", "-").foreach { sign => @@ -202,6 +202,12 @@ class DecimalSuite extends SparkFunSuite with PrivateMethodTester { val d = Decimal(unscaled, 8, 1) assert(d.changePrecision(10, 0, mode)) assert(d.toString === bd.setScale(0, mode).toString(), s"num: $sign$n, mode: $mode") + + val copy = d.toPrecision(10, 0, mode).orNull + assert(copy !== null) + assert(d.ne(copy)) + assert(d === copy) + assert(copy.toString === bd.setScale(0, mode).toString(), s"num: $sign$n, mode: $mode") } } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/MathFunctionsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/MathFunctionsSuite.scala index 37443d0342..328c5395ec 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/MathFunctionsSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/MathFunctionsSuite.scala @@ -233,6 +233,18 @@ class MathFunctionsSuite extends QueryTest with SharedSQLContext { ) } + test("round/bround with data frame from a local Seq of Product") { + val df = spark.createDataFrame(Seq(Tuple1(BigDecimal("5.9")))).toDF("value") + checkAnswer( + df.withColumn("value_rounded", round('value)), + Seq(Row(BigDecimal("5.9"), BigDecimal("6"))) + ) + checkAnswer( + df.withColumn("value_brounded", bround('value)), + Seq(Row(BigDecimal("5.9"), BigDecimal("6"))) + ) + } + test("exp") { testOneToOneMathFunction(exp, math.exp) } |