-rw-r--r--  sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala | 42
-rw-r--r--  sql/catalyst/src/main/scala/org/apache/spark/sql/types/Decimal.scala | 12
2 files changed, 22 insertions(+), 32 deletions(-)
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala
index 946c5a9c04..616b9e0e65 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala
@@ -155,7 +155,7 @@ case class Cast(child: Expression, dataType: DataType)
case ByteType =>
buildCast[Byte](_, _ != 0)
case DecimalType() =>
- buildCast[Decimal](_, _ != Decimal.ZERO)
+ buildCast[Decimal](_, !_.isZero)
case DoubleType =>
buildCast[Double](_, _ != 0)
case FloatType =>
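The new test asks the value directly instead of comparing against the shared Decimal.ZERO constant, which routes through equals/compare. A minimal sketch of why isZero is the cheaper check, assuming Spark Decimal's dual storage (an unscaled Long plus scale for small values, a BigDecimal otherwise; the class below is a hypothetical stand-in, not the real one):

```scala
// Hypothetical stand-in for org.apache.spark.sql.types.Decimal, which keeps
// small values as an unscaled Long and large ones as a BigDecimal.
final class MiniDecimal(longVal: Long, decimalVal: BigDecimal) {
  // No allocation and no compare against a shared ZERO instance:
  // just inspect whichever representation is populated.
  def isZero: Boolean =
    if (decimalVal ne null) decimalVal.signum == 0 else longVal == 0
}
```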
@@ -315,13 +315,13 @@ case class Cast(child: Expression, dataType: DataType)
case TimestampType =>
// Note that we lose precision here.
buildCast[Long](_, t => changePrecision(Decimal(timestampToDouble(t)), target))
- case DecimalType() =>
+ case dt: DecimalType =>
b => changePrecision(b.asInstanceOf[Decimal].clone(), target)
- case LongType =>
- b => changePrecision(Decimal(b.asInstanceOf[Long]), target)
- case x: NumericType => // All other numeric types can be represented precisely as Doubles
+ case t: IntegralType =>
+ b => changePrecision(Decimal(t.integral.asInstanceOf[Integral[Any]].toLong(b)), target)
+ case x: FractionalType =>
b => try {
- changePrecision(Decimal(x.numeric.asInstanceOf[Numeric[Any]].toDouble(b)), target)
+ changePrecision(Decimal(x.fractional.asInstanceOf[Fractional[Any]].toDouble(b)), target)
} catch {
case _: NumberFormatException => null
}
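Splitting the old NumericType fallback into IntegralType and FractionalType matters most for Long: a Double mantissa holds only 53 bits, so long values above 2^53 lost digits on the old Double detour, which is why LongType alone used to be special-cased. The Integral instance widens every integral type exactly. A sketch of the typeclass dispatch used above (widenToLong is an illustrative name, not Spark API):

```scala
// Every IntegralType carries an Integral[T] instance, so a boxed integral
// value widens to Long without loss; the erased cast mirrors the diff's usage.
def widenToLong[T](b: Any, integral: Integral[T]): Long =
  integral.asInstanceOf[Integral[Any]].toLong(b)

val fromByte = widenToLong(7.toByte, implicitly[Integral[Byte]])       // 7L
val fromLong = widenToLong((1L << 60) + 1, implicitly[Integral[Long]]) // exact
// By contrast, ((1L << 60) + 1).toDouble.toLong != (1L << 60) + 1.
```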
@@ -534,10 +534,7 @@ case class Cast(child: Expression, dataType: DataType)
(c, evPrim, evNull) =>
s"""
try {
- org.apache.spark.sql.types.Decimal tmpDecimal =
- new org.apache.spark.sql.types.Decimal().set(
- new scala.math.BigDecimal(
- new java.math.BigDecimal($c.toString())));
+ Decimal tmpDecimal = Decimal.apply(new java.math.BigDecimal($c.toString()));
${changePrecision("tmpDecimal", target, evPrim, evNull)}
} catch (java.lang.NumberFormatException e) {
$evNull = true;
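The generated Java collapses to a single factory call because Decimal.apply has overloads for java.math.BigDecimal, Long, and Int, and the bare name Decimal presumably resolves through an import in the generated class. A quick illustration of the overloads the new templates lean on, all of which appear in this diff:

```scala
import org.apache.spark.sql.types.Decimal

val fromString = Decimal(new java.math.BigDecimal("12.345")) // string cast path
val fromLong   = Decimal(42L)                                // integral path
val fromBool   = Decimal(1)                                  // boolean path
```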
@@ -546,12 +543,7 @@ case class Cast(child: Expression, dataType: DataType)
case BooleanType =>
(c, evPrim, evNull) =>
s"""
- org.apache.spark.sql.types.Decimal tmpDecimal = null;
- if ($c) {
- tmpDecimal = new org.apache.spark.sql.types.Decimal().set(1);
- } else {
- tmpDecimal = new org.apache.spark.sql.types.Decimal().set(0);
- }
+ Decimal tmpDecimal = $c ? Decimal.apply(1) : Decimal.apply(0);
${changePrecision("tmpDecimal", target, evPrim, evNull)}
"""
case DateType =>
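Every branch ends by splicing in changePrecision("tmpDecimal", target, evPrim, evNull), a codegen helper returning a Java snippet that fits the value to the target precision and scale or flags NULL on overflow. A hedged sketch of the shape it plausibly has; the name changePrecisionCode and the body are assumptions modeled on the public Decimal.changePrecision(precision, scale): Boolean, not Spark's source:

```scala
// Assumption: a helper shaped roughly like this generates the splice.
def changePrecisionCode(d: String, precision: Int, scale: Int,
    evPrim: String, evNull: String): String =
  s"""
    if ($d.changePrecision($precision, $scale)) {
      $evPrim = $d;
    } else {
      $evNull = true;
    }
  """
```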
@@ -561,32 +553,28 @@ case class Cast(child: Expression, dataType: DataType)
// Note that we lose precision here.
(c, evPrim, evNull) =>
s"""
- org.apache.spark.sql.types.Decimal tmpDecimal =
- new org.apache.spark.sql.types.Decimal().set(
- scala.math.BigDecimal.valueOf(${timestampToDoubleCode(c)}));
+ Decimal tmpDecimal = Decimal.apply(
+ scala.math.BigDecimal.valueOf(${timestampToDoubleCode(c)}));
${changePrecision("tmpDecimal", target, evPrim, evNull)}
"""
case DecimalType() =>
(c, evPrim, evNull) =>
s"""
- org.apache.spark.sql.types.Decimal tmpDecimal = $c.clone();
+ Decimal tmpDecimal = $c.clone();
${changePrecision("tmpDecimal", target, evPrim, evNull)}
"""
- case LongType =>
+ case x: IntegralType =>
(c, evPrim, evNull) =>
s"""
- org.apache.spark.sql.types.Decimal tmpDecimal =
- new org.apache.spark.sql.types.Decimal().set($c);
+ Decimal tmpDecimal = Decimal.apply((long) $c);
${changePrecision("tmpDecimal", target, evPrim, evNull)}
"""
- case x: NumericType =>
+ case x: FractionalType =>
// All other numeric types can be represented precisely as Doubles
(c, evPrim, evNull) =>
s"""
try {
- org.apache.spark.sql.types.Decimal tmpDecimal =
- new org.apache.spark.sql.types.Decimal().set(
- scala.math.BigDecimal.valueOf((double) $c));
+ Decimal tmpDecimal = Decimal.apply(scala.math.BigDecimal.valueOf((double) $c));
${changePrecision("tmpDecimal", target, evPrim, evNull)}
} catch (java.lang.NumberFormatException e) {
$evNull = true;
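The try/catch around the fractional branch exists because java.math.BigDecimal.valueOf rejects NaN and infinities with a NumberFormatException, which the generated code converts into a SQL NULL:

```scala
import java.math.{BigDecimal => JBigDecimal}

val ok = JBigDecimal.valueOf(1.5) // 1.5, via the double's canonical string form
val nanIsNull =
  try { JBigDecimal.valueOf(Double.NaN); false }
  catch { case _: NumberFormatException => true } // true: the cast yields NULL
```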
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/Decimal.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/Decimal.scala
index 624c3f3d7f..d95805c245 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/Decimal.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/Decimal.scala
@@ -139,9 +139,9 @@ final class Decimal extends Ordered[Decimal] with Serializable {
def toBigDecimal: BigDecimal = {
if (decimalVal.ne(null)) {
- decimalVal(MATH_CONTEXT)
+ decimalVal
} else {
- BigDecimal(longVal, _scale)(MATH_CONTEXT)
+ BigDecimal(longVal, _scale)
}
}
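Dropping MATH_CONTEXT from toBigDecimal makes the conversion exact; rounding now happens once, inside the arithmetic operators below, instead of on every read:

```scala
// BigDecimal(unscaledVal, scale) reconstructs the stored value exactly:
// unscaled 12345 at scale 2 is precisely 123.45. Rounding a value that is
// merely being read is what the removed MATH_CONTEXT arguments risked.
val exact = BigDecimal(12345L, 2) // == 123.45
```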
@@ -280,13 +280,15 @@ final class Decimal extends Ordered[Decimal] with Serializable {
}
// HiveTypeCoercion will take care of the precision, scale of result
- def * (that: Decimal): Decimal = Decimal(toBigDecimal * that.toBigDecimal)
+ def * (that: Decimal): Decimal =
+ Decimal(toJavaBigDecimal.multiply(that.toJavaBigDecimal, MATH_CONTEXT))
def / (that: Decimal): Decimal =
- if (that.isZero) null else Decimal(toBigDecimal / that.toBigDecimal)
+ if (that.isZero) null else Decimal(toJavaBigDecimal.divide(that.toJavaBigDecimal, MATH_CONTEXT))
def % (that: Decimal): Decimal =
- if (that.isZero) null else Decimal(toBigDecimal % that.toBigDecimal)
+ if (that.isZero) null
+ else Decimal(toJavaBigDecimal.remainder(that.toJavaBigDecimal, MATH_CONTEXT))
def remainder(that: Decimal): Decimal = this % that
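Routing *, /, and % through java.math.BigDecimal with an explicit MathContext also sidesteps the failure mode where divide with no context throws on non-terminating expansions. A small demonstration, with DECIMAL128 standing in for Spark's MATH_CONTEXT (an assumption; the real context is defined elsewhere in Decimal.scala):

```scala
import java.math.{BigDecimal => JBigDecimal, MathContext}

val mc = MathContext.DECIMAL128

val one   = new JBigDecimal(1)
val three = new JBigDecimal(3)

// Without a context, divide throws ArithmeticException for non-terminating
// expansions such as 1/3; with one, the quotient is rounded exactly once.
val q = one.divide(three, mc) // 0.3333... to 34 significant digits
```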