author     Wojtek Szymanski <wk.szymanski@gmail.com>    2017-03-08 12:36:16 -0800
committer  Wenchen Fan <wenchen@databricks.com>         2017-03-08 12:36:16 -0800
commit     e9e2c612d58a19ddcb4b6abfb7389a4b0f7ef6f8 (patch)
tree       31f22bac0755a6384fef07155531f77423b242af /sql
parent     f3387d97487cbef894b6963bc008f6a5c4294a85 (diff)
[SPARK-19727][SQL] Fix for round function that modifies original column
## What changes were proposed in this pull request?

Fix for SQL round function that modifies original column when underlying data frame is created from a local product.

```
import org.apache.spark.sql.functions._

case class NumericRow(value: BigDecimal)

val df = spark.createDataFrame(Seq(NumericRow(BigDecimal("1.23456789"))))

df.show()
+--------------------+
|               value|
+--------------------+
|1.234567890000000000|
+--------------------+

df.withColumn("value_rounded", round('value)).show()

// before
+--------------------+-------------+
|               value|value_rounded|
+--------------------+-------------+
|1.000000000000000000|            1|
+--------------------+-------------+

// after
+--------------------+-------------+
|               value|value_rounded|
+--------------------+-------------+
|1.234567890000000000|            1|
+--------------------+-------------+
```

## How was this patch tested?

New unit test added to existing suite `org.apache.spark.sql.MathFunctionsSuite`

Author: Wojtek Szymanski <wk.szymanski@gmail.com>

Closes #17075 from wojtek-szymanski/SPARK-19727.
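The root cause is that `Decimal` is mutable: before this patch, `RoundBase.nullSafeEval` called `changePrecision()` directly on the `Decimal` instance backing the input row, so rounding overwrote the source column in place. A minimal sketch of the difference between mutating and copy-on-round behaviour, using only the public `Decimal` API (the `roundedCopy` helper below is illustrative, not part of the patch):

```scala
import org.apache.spark.sql.types.Decimal

val original = Decimal(BigDecimal("1.23456789"), 18, 8)

// Pre-patch behaviour: changePrecision() mutates the receiver in place,
// so rounding a shared instance silently rewrites the source value.
val shared = original.clone()
shared.changePrecision(shared.precision, 0) // returns true; `shared` is now 1

// Post-patch behaviour: clone first, round the copy, leave the input alone.
// This is the essence of the new Decimal.toPrecision helper.
def roundedCopy(d: Decimal, scale: Int): Option[Decimal] = {
  val copy = d.clone()
  if (copy.changePrecision(copy.precision, scale)) Some(copy) else None
}

assert(roundedCopy(original, 0).exists(_.toString == "1"))
assert(original.toString == "1.23456789") // the source column survives the round
```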
Diffstat (limited to 'sql')
-rw-r--r--  sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/CatalystTypeConverters.scala | 6
-rw-r--r--  sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala | 13
-rw-r--r--  sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/decimalExpressions.scala | 10
-rw-r--r--  sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/mathExpressions.scala | 2
-rw-r--r--  sql/catalyst/src/main/scala/org/apache/spark/sql/types/Decimal.scala | 28
-rw-r--r--  sql/catalyst/src/test/scala/org/apache/spark/sql/types/DecimalSuite.scala | 8
-rw-r--r--  sql/core/src/test/scala/org/apache/spark/sql/MathFunctionsSuite.scala | 12
7 files changed, 54 insertions, 25 deletions
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/CatalystTypeConverters.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/CatalystTypeConverters.scala
index 5b9161551a..d4ebdb139f 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/CatalystTypeConverters.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/CatalystTypeConverters.scala
@@ -310,11 +310,7 @@ object CatalystTypeConverters {
case d: JavaBigInteger => Decimal(d)
case d: Decimal => d
}
- if (decimal.changePrecision(dataType.precision, dataType.scale)) {
- decimal
- } else {
- null
- }
+ decimal.toPrecision(dataType.precision, dataType.scale).orNull
}
override def toScala(catalystValue: Decimal): JavaBigDecimal = {
if (catalystValue == null) null
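The `orNull` idiom in the replacement line above is worth spelling out: on an `Option[A]` it collapses `None` to `null`, which is exactly what the deleted four-line if/else did by hand. A self-contained sketch, plain Scala with no Spark required:

```scala
// Option.orNull maps None to null and unwraps Some.
val fits: Option[String] = Some("decimal that fits")
val overflows: Option[String] = None

assert(fits.orNull == "decimal that fits")
assert(overflows.orNull == null) // matches the old `else null` branch
```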
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala
index 7c60f7d57a..1049915986 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala
@@ -352,6 +352,15 @@ case class Cast(child: Expression, dataType: DataType, timeZoneId: Option[String
if (value.changePrecision(decimalType.precision, decimalType.scale)) value else null
}
+ /**
+ * Create new `Decimal` with precision and scale given in `decimalType` (if any),
+ * returning null if it overflows or creating a new `value` and returning it if successful.
+ *
+ */
+ private[this] def toPrecision(value: Decimal, decimalType: DecimalType): Decimal =
+ value.toPrecision(decimalType.precision, decimalType.scale).orNull
+
+
private[this] def castToDecimal(from: DataType, target: DecimalType): Any => Any = from match {
case StringType =>
buildCast[UTF8String](_, s => try {
@@ -360,14 +369,14 @@ case class Cast(child: Expression, dataType: DataType, timeZoneId: Option[String
case _: NumberFormatException => null
})
case BooleanType =>
- buildCast[Boolean](_, b => changePrecision(if (b) Decimal.ONE else Decimal.ZERO, target))
+ buildCast[Boolean](_, b => toPrecision(if (b) Decimal.ONE else Decimal.ZERO, target))
case DateType =>
buildCast[Int](_, d => null) // date can't cast to decimal in Hive
case TimestampType =>
// Note that we lose precision here.
buildCast[Long](_, t => changePrecision(Decimal(timestampToDouble(t)), target))
case dt: DecimalType =>
- b => changePrecision(b.asInstanceOf[Decimal].clone(), target)
+ b => toPrecision(b.asInstanceOf[Decimal], target)
case t: IntegralType =>
b => changePrecision(Decimal(t.integral.asInstanceOf[Integral[Any]].toLong(b)), target)
case x: FractionalType =>
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/decimalExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/decimalExpressions.scala
index fa5dea6841..c2211ae5d5 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/decimalExpressions.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/decimalExpressions.scala
@@ -84,14 +84,8 @@ case class CheckOverflow(child: Expression, dataType: DecimalType) extends Unary
override def nullable: Boolean = true
- override def nullSafeEval(input: Any): Any = {
- val d = input.asInstanceOf[Decimal].clone()
- if (d.changePrecision(dataType.precision, dataType.scale)) {
- d
- } else {
- null
- }
- }
+ override def nullSafeEval(input: Any): Any =
+ input.asInstanceOf[Decimal].toPrecision(dataType.precision, dataType.scale).orNull
override protected def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
nullSafeCodeGen(ctx, ev, eval => {
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/mathExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/mathExpressions.scala
index 65273a77b1..dea5f85cb0 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/mathExpressions.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/mathExpressions.scala
@@ -1024,7 +1024,7 @@ abstract class RoundBase(child: Expression, scale: Expression,
child.dataType match {
case _: DecimalType =>
val decimal = input1.asInstanceOf[Decimal]
- if (decimal.changePrecision(decimal.precision, _scale, mode)) decimal else null
+ decimal.toPrecision(decimal.precision, _scale, mode).orNull
case ByteType =>
BigDecimal(input1.asInstanceOf[Byte]).setScale(_scale, mode).toByte
case ShortType =>
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/Decimal.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/Decimal.scala
index 089c84d5f7..e8f6884c02 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/Decimal.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/Decimal.scala
@@ -21,6 +21,7 @@ import java.lang.{Long => JLong}
import java.math.{BigInteger, MathContext, RoundingMode}
import org.apache.spark.annotation.InterfaceStability
+import org.apache.spark.sql.AnalysisException
/**
* A mutable implementation of BigDecimal that can hold a Long if values are small enough.
@@ -223,6 +224,19 @@ final class Decimal extends Ordered[Decimal] with Serializable {
}
/**
+ * Create new `Decimal` with given precision and scale.
+ *
+ * @return `Some(decimal)` if successful or `None` if overflow would occur
+ */
+ private[sql] def toPrecision(
+ precision: Int,
+ scale: Int,
+ roundMode: BigDecimal.RoundingMode.Value = ROUND_HALF_UP): Option[Decimal] = {
+ val copy = clone()
+ if (copy.changePrecision(precision, scale, roundMode)) Some(copy) else None
+ }
+
+ /**
* Update precision and scale while keeping our value the same, and return true if successful.
*
* @return true if successful, false if overflow would occur
@@ -362,17 +376,15 @@ final class Decimal extends Ordered[Decimal] with Serializable {
def abs: Decimal = if (this.compare(Decimal.ZERO) < 0) this.unary_- else this
def floor: Decimal = if (scale == 0) this else {
- val value = this.clone()
- value.changePrecision(
- DecimalType.bounded(precision - scale + 1, 0).precision, 0, ROUND_FLOOR)
- value
+ val newPrecision = DecimalType.bounded(precision - scale + 1, 0).precision
+ toPrecision(newPrecision, 0, ROUND_FLOOR).getOrElse(
+ throw new AnalysisException(s"Overflow when setting precision to $newPrecision"))
}
def ceil: Decimal = if (scale == 0) this else {
- val value = this.clone()
- value.changePrecision(
- DecimalType.bounded(precision - scale + 1, 0).precision, 0, ROUND_CEILING)
- value
+ val newPrecision = DecimalType.bounded(precision - scale + 1, 0).precision
+ toPrecision(newPrecision, 0, ROUND_CEILING).getOrElse(
+ throw new AnalysisException(s"Overflow when setting precision to $newPrecision"))
}
}
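Taken together, the new `toPrecision` clones before calling `changePrecision`, so the receiver is never mutated and overflow surfaces as `None` rather than a half-mutated value. A usage sketch (note `toPrecision` is `private[sql]`, so real callers live under `org.apache.spark.sql`; the values here are illustrative):

```scala
import org.apache.spark.sql.types.Decimal

val d = Decimal(BigDecimal("123.456"), 6, 3)

// Rounds half-up by default and still fits precision 6, so we get a fresh copy.
val rounded = d.toPrecision(6, 2)   // Some(123.46)

// 123 needs three integer digits, which precision 2 cannot hold.
val overflow = d.toPrecision(2, 0)  // None

assert(rounded.exists(_.toString == "123.46"))
assert(overflow.isEmpty)
assert(d.toString == "123.456")     // the receiver is untouched either way
```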
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/types/DecimalSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/types/DecimalSuite.scala
index 52d0692524..714883a409 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/types/DecimalSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/types/DecimalSuite.scala
@@ -193,7 +193,7 @@ class DecimalSuite extends SparkFunSuite with PrivateMethodTester {
assert(Decimal(Long.MaxValue, 100, 0).toUnscaledLong === Long.MaxValue)
}
- test("changePrecision() on compact decimal should respect rounding mode") {
+ test("changePrecision/toPrecision on compact decimal should respect rounding mode") {
Seq(ROUND_FLOOR, ROUND_CEILING, ROUND_HALF_UP, ROUND_HALF_EVEN).foreach { mode =>
Seq("0.4", "0.5", "0.6", "1.0", "1.1", "1.6", "2.5", "5.5").foreach { n =>
Seq("", "-").foreach { sign =>
@@ -202,6 +202,12 @@ class DecimalSuite extends SparkFunSuite with PrivateMethodTester {
val d = Decimal(unscaled, 8, 1)
assert(d.changePrecision(10, 0, mode))
assert(d.toString === bd.setScale(0, mode).toString(), s"num: $sign$n, mode: $mode")
+
+ val copy = d.toPrecision(10, 0, mode).orNull
+ assert(copy !== null)
+ assert(d.ne(copy))
+ assert(d === copy)
+ assert(copy.toString === bd.setScale(0, mode).toString(), s"num: $sign$n, mode: $mode")
}
}
}
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/MathFunctionsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/MathFunctionsSuite.scala
index 37443d0342..328c5395ec 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/MathFunctionsSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/MathFunctionsSuite.scala
@@ -233,6 +233,18 @@ class MathFunctionsSuite extends QueryTest with SharedSQLContext {
)
}
+ test("round/bround with data frame from a local Seq of Product") {
+ val df = spark.createDataFrame(Seq(Tuple1(BigDecimal("5.9")))).toDF("value")
+ checkAnswer(
+ df.withColumn("value_rounded", round('value)),
+ Seq(Row(BigDecimal("5.9"), BigDecimal("6")))
+ )
+ checkAnswer(
+ df.withColumn("value_brounded", bround('value)),
+ Seq(Row(BigDecimal("5.9"), BigDecimal("6")))
+ )
+ }
+
test("exp") {
testOneToOneMathFunction(exp, math.exp)
}
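To close the loop, a spark-shell style check of the two functions exercised by the new test: `round` uses HALF_UP while `bround` uses HALF_EVEN, so they only differ on ties (a running `spark` session is assumed, as in the suite):

```scala
import org.apache.spark.sql.functions.{round, bround}
import spark.implicits._

val ties = Seq(BigDecimal("2.5"), BigDecimal("3.5")).toDF("value")

ties.select('value, round('value), bround('value)).show()
// 2.5 -> round = 3, bround = 2   (half-even rounds ties toward the even digit)
// 3.5 -> round = 4, bround = 4
// and, with this fix applied, `ties` itself still holds 2.5 and 3.5.
```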