path: root/sql/catalyst
author      Matei Zaharia <matei@databricks.com>        2014-11-01 19:29:14 -0700
committer   Michael Armbrust <michael@databricks.com>   2014-11-01 19:29:14 -0700
commit      23f966f47523f85ba440b4080eee665271f53b5e (patch)
tree        d796351567f8b187511b9049199cbf99c5826fb3 /sql/catalyst
parent      56f2c61cde3f5d906c2a58e9af1a661222f2c679 (diff)
download    spark-23f966f47523f85ba440b4080eee665271f53b5e.tar.gz
            spark-23f966f47523f85ba440b4080eee665271f53b5e.tar.bz2
            spark-23f966f47523f85ba440b4080eee665271f53b5e.zip
[SPARK-3930] [SPARK-3933] Support fixed-precision decimal in SQL, and some optimizations
- Adds optional precision and scale to Spark SQL's decimal type, which behave similarly to those in Hive 13 (https://cwiki.apache.org/confluence/download/attachments/27362075/Hive_Decimal_Precision_Scale_Support.pdf)
- Replaces our internal representation of decimals with a Decimal class that can store small values in a mutable Long, saving memory in this situation and letting some operations happen directly on Longs

This is still marked WIP because there are a few TODOs, but I'll remove that tag when done.

Author: Matei Zaharia <matei@databricks.com>

Closes #2983 from mateiz/decimal-1 and squashes the following commits:

35e6b02 [Matei Zaharia] Fix issues after merge
227f24a [Matei Zaharia] Review comments
31f915e [Matei Zaharia] Implement Davies's suggestions in Python
eb84820 [Matei Zaharia] Support reading/writing decimals as fixed-length binary in Parquet
4dc6bae [Matei Zaharia] Fix decimal support in PySpark
d1d9d68 [Matei Zaharia] Fix compile error and test issues after rebase
b28933d [Matei Zaharia] Support decimal precision/scale in Hive metastore
2118c0d [Matei Zaharia] Some test and bug fixes
81db9cb [Matei Zaharia] Added mutable Decimal that will be more efficient for small precisions
7af0c3b [Matei Zaharia] Add optional precision and scale to DecimalType, but use Unlimited for now
ec0a947 [Matei Zaharia] Make the result of AVG on Decimals be Decimal, not Double
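For orientation before the diff, here is what the new API surface looks like in use. This is an illustrative sketch, not part of the patch, and assumes the classes introduced below are on the classpath.

  import org.apache.spark.sql.catalyst.types.DecimalType
  import org.apache.spark.sql.catalyst.types.decimal.Decimal

  val unlimited = DecimalType.Unlimited   // no constraint on precision/scale
  val fixed     = DecimalType(10, 2)      // at most 10 digits, 2 of them after the decimal point
  val value     = Decimal(12345L, 10, 2)  // the value 123.45, held compactly as the unscaled Long 12345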
Diffstat (limited to 'sql/catalyst')
-rw-r--r--  sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala | 20
-rwxr-xr-x  sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/SqlParser.scala | 14
-rw-r--r--  sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercion.scala | 146
-rwxr-xr-x  sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/dsl/package.scala | 11
-rw-r--r--  sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala | 78
-rwxr-xr-x  sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregates.scala | 55
-rw-r--r--  sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala | 10
-rw-r--r--  sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala | 31
-rw-r--r--  sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/decimalFunctions.scala | 59
-rw-r--r--  sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/literals.scala | 6
-rw-r--r--  sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala | 38
-rw-r--r--  sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/types/dataTypes.scala | 84
-rw-r--r--  sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/types/decimal/Decimal.scala | 335
-rw-r--r--  sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/ScalaReflectionSuite.scala | 14
-rw-r--r--  sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisSuite.scala | 6
-rw-r--r--  sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/DecimalPrecisionSuite.scala | 88
-rw-r--r--  sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercionSuite.scala | 17
-rw-r--r--  sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvaluationSuite.scala | 90
-rw-r--r--  sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/types/decimal/DecimalSuite.scala | 158
19 files changed, 1154 insertions, 106 deletions
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala
index 75923d9e8d..8fbdf664b7 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala
@@ -19,9 +19,10 @@ package org.apache.spark.sql.catalyst
import java.sql.{Date, Timestamp}
-import org.apache.spark.sql.catalyst.expressions.{GenericRow, Attribute, AttributeReference}
+import org.apache.spark.sql.catalyst.expressions.{GenericRow, Attribute, AttributeReference, Row}
import org.apache.spark.sql.catalyst.plans.logical.LocalRelation
import org.apache.spark.sql.catalyst.types._
+import org.apache.spark.sql.catalyst.types.decimal.Decimal
/**
* Provides experimental support for generating catalyst schemas for scala objects.
@@ -40,9 +41,20 @@ object ScalaReflection {
case s: Seq[_] => s.map(convertToCatalyst)
case m: Map[_, _] => m.map { case (k, v) => convertToCatalyst(k) -> convertToCatalyst(v) }
case p: Product => new GenericRow(p.productIterator.map(convertToCatalyst).toArray)
+ case d: BigDecimal => Decimal(d)
case other => other
}
+ /** Converts Catalyst types used internally in rows to standard Scala types */
+ def convertToScala(a: Any): Any = a match {
+ case s: Seq[_] => s.map(convertToScala)
+ case m: Map[_, _] => m.map { case (k, v) => convertToScala(k) -> convertToScala(v) }
+ case d: Decimal => d.toBigDecimal
+ case other => other
+ }
+
+ def convertRowToScala(r: Row): Row = new GenericRow(r.toArray.map(convertToScala))
+
/** Returns a Sequence of attributes for the given case class type. */
def attributesFor[T: TypeTag]: Seq[Attribute] = schemaFor[T] match {
case Schema(s: StructType, _) =>
@@ -83,7 +95,8 @@ object ScalaReflection {
case t if t <:< typeOf[String] => Schema(StringType, nullable = true)
case t if t <:< typeOf[Timestamp] => Schema(TimestampType, nullable = true)
case t if t <:< typeOf[Date] => Schema(DateType, nullable = true)
- case t if t <:< typeOf[BigDecimal] => Schema(DecimalType, nullable = true)
+ case t if t <:< typeOf[BigDecimal] => Schema(DecimalType.Unlimited, nullable = true)
+ case t if t <:< typeOf[Decimal] => Schema(DecimalType.Unlimited, nullable = true)
case t if t <:< typeOf[java.lang.Integer] => Schema(IntegerType, nullable = true)
case t if t <:< typeOf[java.lang.Long] => Schema(LongType, nullable = true)
case t if t <:< typeOf[java.lang.Double] => Schema(DoubleType, nullable = true)
@@ -111,8 +124,9 @@ object ScalaReflection {
case obj: LongType.JvmType => LongType
case obj: FloatType.JvmType => FloatType
case obj: DoubleType.JvmType => DoubleType
- case obj: DecimalType.JvmType => DecimalType
case obj: DateType.JvmType => DateType
+ case obj: BigDecimal => DecimalType.Unlimited
+ case obj: Decimal => DecimalType.Unlimited
case obj: TimestampType.JvmType => TimestampType
case null => NullType
// For other cases, there is no obvious mapping from the type of the given object to a
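The two converters above let external BigDecimals and internal Decimals round-trip. A minimal sketch of that behaviour, assuming the patch is applied:

  import org.apache.spark.sql.catalyst.ScalaReflection.{convertToCatalyst, convertToScala}
  import org.apache.spark.sql.catalyst.types.decimal.Decimal

  convertToCatalyst(BigDecimal("1.5"))         // => Decimal(1.5), the internal representation
  convertToScala(Decimal(BigDecimal("1.5")))   // => BigDecimal("1.5") again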
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/SqlParser.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/SqlParser.scala
index b1e7570f57..00fc4d75c9 100755
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/SqlParser.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/SqlParser.scala
@@ -52,11 +52,13 @@ class SqlParser extends AbstractSparkSQLParser {
protected val CASE = Keyword("CASE")
protected val CAST = Keyword("CAST")
protected val COUNT = Keyword("COUNT")
+ protected val DECIMAL = Keyword("DECIMAL")
protected val DESC = Keyword("DESC")
protected val DISTINCT = Keyword("DISTINCT")
protected val ELSE = Keyword("ELSE")
protected val END = Keyword("END")
protected val EXCEPT = Keyword("EXCEPT")
+ protected val DOUBLE = Keyword("DOUBLE")
protected val FALSE = Keyword("FALSE")
protected val FIRST = Keyword("FIRST")
protected val FROM = Keyword("FROM")
@@ -385,5 +387,15 @@ class SqlParser extends AbstractSparkSQLParser {
}
protected lazy val dataType: Parser[DataType] =
- STRING ^^^ StringType | TIMESTAMP ^^^ TimestampType
+ ( STRING ^^^ StringType
+ | TIMESTAMP ^^^ TimestampType
+ | DOUBLE ^^^ DoubleType
+ | fixedDecimalType
+ | DECIMAL ^^^ DecimalType.Unlimited
+ )
+
+ protected lazy val fixedDecimalType: Parser[DataType] =
+ (DECIMAL ~ "(" ~> numericLit) ~ ("," ~> numericLit <~ ")") ^^ {
+ case precision ~ scale => DecimalType(precision.toInt, scale.toInt)
+ }
}
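The new dataType productions map a bare DECIMAL to unlimited precision and DECIMAL(p, s) to a fixed-precision type. A small sketch of the resulting types, using the DecimalType factory added later in this patch:

  import org.apache.spark.sql.catalyst.types.DecimalType

  // "DECIMAL"        parses to DecimalType.Unlimited
  // "DECIMAL(10, 2)" parses to DecimalType(10, 2)
  assert(DecimalType.Unlimited.typeName == "decimal")
  assert(DecimalType(10, 2).typeName == "decimal(10,2)")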
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercion.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercion.scala
index 2b69c02b28..e38114ab3c 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercion.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercion.scala
@@ -25,19 +25,31 @@ import org.apache.spark.sql.catalyst.types._
object HiveTypeCoercion {
// See https://cwiki.apache.org/confluence/display/Hive/LanguageManual+Types.
// The conversion for integral and floating point types have a linear widening hierarchy:
- val numericPrecedence =
- Seq(ByteType, ShortType, IntegerType, LongType, FloatType, DoubleType, DecimalType)
- val allPromotions: Seq[Seq[DataType]] = numericPrecedence :: Nil
+ private val numericPrecedence =
+ Seq(ByteType, ShortType, IntegerType, LongType, FloatType, DoubleType, DecimalType.Unlimited)
+ /**
+ * Find the tightest common type of two types that might be used in a binary expression.
+ * This handles all numeric types except fixed-precision decimals interacting with each other or
+ * with primitive types, because in that case the precision and scale of the result depends on
+ * the operation. Those rules are implemented in [[HiveTypeCoercion.DecimalPrecision]].
+ */
def findTightestCommonType(t1: DataType, t2: DataType): Option[DataType] = {
val valueTypes = Seq(t1, t2).filter(t => t != NullType)
if (valueTypes.distinct.size > 1) {
- // Try and find a promotion rule that contains both types in question.
- val applicableConversion =
- HiveTypeCoercion.allPromotions.find(p => p.contains(t1) && p.contains(t2))
-
- // If found return the widest common type, otherwise None
- applicableConversion.map(_.filter(t => t == t1 || t == t2).last)
+ // Promote numeric types to the highest of the two and all numeric types to unlimited decimal
+ if (numericPrecedence.contains(t1) && numericPrecedence.contains(t2)) {
+ Some(numericPrecedence.filter(t => t == t1 || t == t2).last)
+ } else if (t1.isInstanceOf[DecimalType] && t2.isInstanceOf[DecimalType]) {
+ // Fixed-precision decimals can up-cast into unlimited
+ if (t1 == DecimalType.Unlimited || t2 == DecimalType.Unlimited) {
+ Some(DecimalType.Unlimited)
+ } else {
+ None
+ }
+ } else {
+ None
+ }
} else {
Some(if (valueTypes.size == 0) NullType else valueTypes.head)
}
@@ -59,6 +71,7 @@ trait HiveTypeCoercion {
ConvertNaNs ::
WidenTypes ::
PromoteStrings ::
+ DecimalPrecision ::
BooleanComparisons ::
BooleanCasts ::
StringToIntegralCasts ::
@@ -151,6 +164,7 @@ trait HiveTypeCoercion {
import HiveTypeCoercion._
def apply(plan: LogicalPlan): LogicalPlan = plan transform {
+ // TODO: unions with fixed-precision decimals
case u @ Union(left, right) if u.childrenResolved && !u.resolved =>
val castedInput = left.output.zip(right.output).map {
// When a string is found on one side, make the other side a string too.
@@ -265,6 +279,110 @@ trait HiveTypeCoercion {
}
}
+ // scalastyle:off
+ /**
+ * Calculates and propagates precision for fixed-precision decimals. Hive has a number of
+ * rules for this based on the SQL standard and MS SQL:
+ * https://cwiki.apache.org/confluence/download/attachments/27362075/Hive_Decimal_Precision_Scale_Support.pdf
+ *
+ * In particular, if we have expressions e1 and e2 with precision/scale p1/s1 and p2/s2
+ * respectively, then the following operations have the following precision / scale:
+ *
+ * Operation Result Precision Result Scale
+ * ------------------------------------------------------------------------
+ * e1 + e2 max(s1, s2) + max(p1-s1, p2-s2) + 1 max(s1, s2)
+ * e1 - e2 max(s1, s2) + max(p1-s1, p2-s2) + 1 max(s1, s2)
+ * e1 * e2 p1 + p2 + 1 s1 + s2
+ * e1 / e2 p1 - s1 + s2 + max(6, s1 + p2 + 1) max(6, s1 + p2 + 1)
+ * e1 % e2 min(p1-s1, p2-s2) + max(s1, s2) max(s1, s2)
+ * sum(e1) p1 + 10 s1
+ * avg(e1) p1 + 4 s1 + 4
+ *
+ * Catalyst also has unlimited-precision decimals. For those, all ops return unlimited precision.
+ *
+ * To implement the rules for fixed-precision types, we introduce casts to turn them to unlimited
+ * precision, do the math on unlimited-precision numbers, then introduce casts back to the
+ * required fixed precision. This allows us to do all rounding and overflow handling in the
+ * cast-to-fixed-precision operator.
+ *
+ * In addition, when mixing non-decimal types with decimals, we use the following rules:
+ * - BYTE gets turned into DECIMAL(3, 0)
+ * - SHORT gets turned into DECIMAL(5, 0)
+ * - INT gets turned into DECIMAL(10, 0)
+ * - LONG gets turned into DECIMAL(20, 0)
+ * - FLOAT and DOUBLE cause fixed-length decimals to turn into DOUBLE (this is the same as Hive,
+ * but note that unlimited decimals are considered bigger than doubles in WidenTypes)
+ */
+ // scalastyle:on
+ object DecimalPrecision extends Rule[LogicalPlan] {
+ import scala.math.{max, min}
+
+ // Conversion rules for integer types into fixed-precision decimals
+ val intTypeToFixed: Map[DataType, DecimalType] = Map(
+ ByteType -> DecimalType(3, 0),
+ ShortType -> DecimalType(5, 0),
+ IntegerType -> DecimalType(10, 0),
+ LongType -> DecimalType(20, 0)
+ )
+
+ def isFloat(t: DataType): Boolean = t == FloatType || t == DoubleType
+
+ def apply(plan: LogicalPlan): LogicalPlan = plan transformAllExpressions {
+ // Skip nodes whose children have not been resolved yet
+ case e if !e.childrenResolved => e
+
+ case Add(e1 @ DecimalType.Expression(p1, s1), e2 @ DecimalType.Expression(p2, s2)) =>
+ Cast(
+ Add(Cast(e1, DecimalType.Unlimited), Cast(e2, DecimalType.Unlimited)),
+ DecimalType(max(s1, s2) + max(p1 - s1, p2 - s2) + 1, max(s1, s2))
+ )
+
+ case Subtract(e1 @ DecimalType.Expression(p1, s1), e2 @ DecimalType.Expression(p2, s2)) =>
+ Cast(
+ Subtract(Cast(e1, DecimalType.Unlimited), Cast(e2, DecimalType.Unlimited)),
+ DecimalType(max(s1, s2) + max(p1 - s1, p2 - s2) + 1, max(s1, s2))
+ )
+
+ case Multiply(e1 @ DecimalType.Expression(p1, s1), e2 @ DecimalType.Expression(p2, s2)) =>
+ Cast(
+ Multiply(Cast(e1, DecimalType.Unlimited), Cast(e2, DecimalType.Unlimited)),
+ DecimalType(p1 + p2 + 1, s1 + s2)
+ )
+
+ case Divide(e1 @ DecimalType.Expression(p1, s1), e2 @ DecimalType.Expression(p2, s2)) =>
+ Cast(
+ Divide(Cast(e1, DecimalType.Unlimited), Cast(e2, DecimalType.Unlimited)),
+ DecimalType(p1 - s1 + s2 + max(6, s1 + p2 + 1), max(6, s1 + p2 + 1))
+ )
+
+ case Remainder(e1 @ DecimalType.Expression(p1, s1), e2 @ DecimalType.Expression(p2, s2)) =>
+ Cast(
+ Remainder(Cast(e1, DecimalType.Unlimited), Cast(e2, DecimalType.Unlimited)),
+ DecimalType(min(p1 - s1, p2 - s2) + max(s1, s2), max(s1, s2))
+ )
+
+ // Promote integers inside a binary expression with fixed-precision decimals to decimals,
+ // and fixed-precision decimals in an expression with floats / doubles to doubles
+ case b: BinaryExpression if b.left.dataType != b.right.dataType =>
+ (b.left.dataType, b.right.dataType) match {
+ case (t, DecimalType.Fixed(p, s)) if intTypeToFixed.contains(t) =>
+ b.makeCopy(Array(Cast(b.left, intTypeToFixed(t)), b.right))
+ case (DecimalType.Fixed(p, s), t) if intTypeToFixed.contains(t) =>
+ b.makeCopy(Array(b.left, Cast(b.right, intTypeToFixed(t))))
+ case (t, DecimalType.Fixed(p, s)) if isFloat(t) =>
+ b.makeCopy(Array(b.left, Cast(b.right, DoubleType)))
+ case (DecimalType.Fixed(p, s), t) if isFloat(t) =>
+ b.makeCopy(Array(Cast(b.left, DoubleType), b.right))
+ case _ =>
+ b
+ }
+
+ // TODO: MaxOf, MinOf, etc might want other rules
+
+ // SUM and AVERAGE are handled by the implementations of those expressions
+ }
+ }
+
/**
* Changes Boolean values to Bytes so that expressions like true < false can be Evaluated.
*/
@@ -330,7 +448,7 @@ trait HiveTypeCoercion {
case e if !e.childrenResolved => e
case Cast(e @ StringType(), t: IntegralType) =>
- Cast(Cast(e, DecimalType), t)
+ Cast(Cast(e, DecimalType.Unlimited), t)
}
}
@@ -383,10 +501,12 @@ trait HiveTypeCoercion {
// Decimal and Double remain the same
case d: Divide if d.resolved && d.dataType == DoubleType => d
- case d: Divide if d.resolved && d.dataType == DecimalType => d
+ case d: Divide if d.resolved && d.dataType.isInstanceOf[DecimalType] => d
- case Divide(l, r) if l.dataType == DecimalType => Divide(l, Cast(r, DecimalType))
- case Divide(l, r) if r.dataType == DecimalType => Divide(Cast(l, DecimalType), r)
+ case Divide(l, r) if l.dataType.isInstanceOf[DecimalType] =>
+ Divide(l, Cast(r, DecimalType.Unlimited))
+ case Divide(l, r) if r.dataType.isInstanceOf[DecimalType] =>
+ Divide(Cast(l, DecimalType.Unlimited), r)
case Divide(l, r) => Divide(Cast(l, DoubleType), Cast(r, DoubleType))
}
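A standalone restatement of the precision/scale table in the DecimalPrecision comment above, as plain Scala. The helper is illustrative only, but the arithmetic matches the casts introduced by the rule:

  import scala.math.{max, min}

  // Result (precision, scale) of combining decimal(p1,s1) and decimal(p2,s2)
  def resultType(op: String, p1: Int, s1: Int, p2: Int, s2: Int): (Int, Int) = op match {
    case "+" | "-" => (max(s1, s2) + max(p1 - s1, p2 - s2) + 1, max(s1, s2))
    case "*"       => (p1 + p2 + 1, s1 + s2)
    case "/"       => (p1 - s1 + s2 + max(6, s1 + p2 + 1), max(6, s1 + p2 + 1))
    case "%"       => (min(p1 - s1, p2 - s2) + max(s1, s2), max(s1, s2))
  }

  // e.g. decimal(10,2) + decimal(5,3) yields decimal(12,3); decimal(10,2) * decimal(5,3) yields decimal(16,5)
  assert(resultType("+", 10, 2, 5, 3) == (12, 3))
  assert(resultType("*", 10, 2, 5, 3) == (16, 5))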
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/dsl/package.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/dsl/package.scala
index 23cfd483ec..7e6d770314 100755
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/dsl/package.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/dsl/package.scala
@@ -19,6 +19,8 @@ package org.apache.spark.sql.catalyst
import java.sql.{Date, Timestamp}
+import org.apache.spark.sql.catalyst.types.decimal.Decimal
+
import scala.language.implicitConversions
import org.apache.spark.sql.catalyst.analysis.UnresolvedAttribute
@@ -124,7 +126,8 @@ package object dsl {
implicit def doubleToLiteral(d: Double) = Literal(d)
implicit def stringToLiteral(s: String) = Literal(s)
implicit def dateToLiteral(d: Date) = Literal(d)
- implicit def decimalToLiteral(d: BigDecimal) = Literal(d)
+ implicit def bigDecimalToLiteral(d: BigDecimal) = Literal(d)
+ implicit def decimalToLiteral(d: Decimal) = Literal(d)
implicit def timestampToLiteral(t: Timestamp) = Literal(t)
implicit def binaryToLiteral(a: Array[Byte]) = Literal(a)
@@ -183,7 +186,11 @@ package object dsl {
def date = AttributeReference(s, DateType, nullable = true)()
/** Creates a new AttributeReference of type decimal */
- def decimal = AttributeReference(s, DecimalType, nullable = true)()
+ def decimal = AttributeReference(s, DecimalType.Unlimited, nullable = true)()
+
+ /** Creates a new AttributeReference of type decimal */
+ def decimal(precision: Int, scale: Int) =
+ AttributeReference(s, DecimalType(precision, scale), nullable = true)()
/** Creates a new AttributeReference of type timestamp */
def timestamp = AttributeReference(s, TimestampType, nullable = true)()
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala
index 8e5baf0eb8..2200966619 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala
@@ -23,6 +23,7 @@ import java.text.{DateFormat, SimpleDateFormat}
import org.apache.spark.Logging
import org.apache.spark.sql.catalyst.errors.TreeNodeException
import org.apache.spark.sql.catalyst.types._
+import org.apache.spark.sql.catalyst.types.decimal.Decimal
/** Cast the child expression to the target data type. */
case class Cast(child: Expression, dataType: DataType) extends UnaryExpression with Logging {
@@ -36,6 +37,7 @@ case class Cast(child: Expression, dataType: DataType) extends UnaryExpression w
case (BooleanType, DateType) => true
case (DateType, _: NumericType) => true
case (DateType, BooleanType) => true
+ case (_, DecimalType.Fixed(_, _)) => true // TODO: not all upcasts here can really give null
case _ => child.nullable
}
@@ -76,8 +78,8 @@ case class Cast(child: Expression, dataType: DataType) extends UnaryExpression w
buildCast[Short](_, _ != 0)
case ByteType =>
buildCast[Byte](_, _ != 0)
- case DecimalType =>
- buildCast[BigDecimal](_, _ != 0)
+ case DecimalType() =>
+ buildCast[Decimal](_, _ != 0)
case DoubleType =>
buildCast[Double](_, _ != 0)
case FloatType =>
@@ -109,19 +111,19 @@ case class Cast(child: Expression, dataType: DataType) extends UnaryExpression w
case DateType =>
buildCast[Date](_, d => new Timestamp(d.getTime))
// TimestampWritable.decimalToTimestamp
- case DecimalType =>
- buildCast[BigDecimal](_, d => decimalToTimestamp(d))
+ case DecimalType() =>
+ buildCast[Decimal](_, d => decimalToTimestamp(d))
// TimestampWritable.doubleToTimestamp
case DoubleType =>
- buildCast[Double](_, d => decimalToTimestamp(d))
+ buildCast[Double](_, d => decimalToTimestamp(Decimal(d)))
// TimestampWritable.floatToTimestamp
case FloatType =>
- buildCast[Float](_, f => decimalToTimestamp(f))
+ buildCast[Float](_, f => decimalToTimestamp(Decimal(f)))
}
- private[this] def decimalToTimestamp(d: BigDecimal) = {
+ private[this] def decimalToTimestamp(d: Decimal) = {
val seconds = Math.floor(d.toDouble).toLong
- val bd = (d - seconds) * 1000000000
+ val bd = (d.toBigDecimal - seconds) * 1000000000
val nanos = bd.intValue()
val millis = seconds * 1000
@@ -196,8 +198,8 @@ case class Cast(child: Expression, dataType: DataType) extends UnaryExpression w
buildCast[Date](_, d => dateToLong(d))
case TimestampType =>
buildCast[Timestamp](_, t => timestampToLong(t))
- case DecimalType =>
- buildCast[BigDecimal](_, _.toLong)
+ case DecimalType() =>
+ buildCast[Decimal](_, _.toLong)
case x: NumericType =>
b => x.numeric.asInstanceOf[Numeric[Any]].toLong(b)
}
@@ -214,8 +216,8 @@ case class Cast(child: Expression, dataType: DataType) extends UnaryExpression w
buildCast[Date](_, d => dateToLong(d))
case TimestampType =>
buildCast[Timestamp](_, t => timestampToLong(t).toInt)
- case DecimalType =>
- buildCast[BigDecimal](_, _.toInt)
+ case DecimalType() =>
+ buildCast[Decimal](_, _.toInt)
case x: NumericType =>
b => x.numeric.asInstanceOf[Numeric[Any]].toInt(b)
}
@@ -232,8 +234,8 @@ case class Cast(child: Expression, dataType: DataType) extends UnaryExpression w
buildCast[Date](_, d => dateToLong(d))
case TimestampType =>
buildCast[Timestamp](_, t => timestampToLong(t).toShort)
- case DecimalType =>
- buildCast[BigDecimal](_, _.toShort)
+ case DecimalType() =>
+ buildCast[Decimal](_, _.toShort)
case x: NumericType =>
b => x.numeric.asInstanceOf[Numeric[Any]].toInt(b).toShort
}
@@ -250,27 +252,45 @@ case class Cast(child: Expression, dataType: DataType) extends UnaryExpression w
buildCast[Date](_, d => dateToLong(d))
case TimestampType =>
buildCast[Timestamp](_, t => timestampToLong(t).toByte)
- case DecimalType =>
- buildCast[BigDecimal](_, _.toByte)
+ case DecimalType() =>
+ buildCast[Decimal](_, _.toByte)
case x: NumericType =>
b => x.numeric.asInstanceOf[Numeric[Any]].toInt(b).toByte
}
- // DecimalConverter
- private[this] def castToDecimal: Any => Any = child.dataType match {
+ /**
+ * Change the precision / scale in a given decimal to those set in `decimalType` (if any),
+ * returning null if it overflows or modifying `value` in-place and returning it if successful.
+ *
+ * NOTE: this modifies `value` in-place, so don't call it on external data.
+ */
+ private[this] def changePrecision(value: Decimal, decimalType: DecimalType): Decimal = {
+ decimalType match {
+ case DecimalType.Unlimited =>
+ value
+ case DecimalType.Fixed(precision, scale) =>
+ if (value.changePrecision(precision, scale)) value else null
+ }
+ }
+
+ private[this] def castToDecimal(target: DecimalType): Any => Any = child.dataType match {
case StringType =>
- buildCast[String](_, s => try BigDecimal(s.toDouble) catch {
+ buildCast[String](_, s => try changePrecision(Decimal(s.toDouble), target) catch {
case _: NumberFormatException => null
})
case BooleanType =>
- buildCast[Boolean](_, b => if (b) BigDecimal(1) else BigDecimal(0))
+ buildCast[Boolean](_, b => changePrecision(if (b) Decimal(1) else Decimal(0), target))
case DateType =>
- buildCast[Date](_, d => dateToDouble(d))
+ buildCast[Date](_, d => changePrecision(null, target)) // date can't cast to decimal in Hive
case TimestampType =>
// Note that we lose precision here.
- buildCast[Timestamp](_, t => BigDecimal(timestampToDouble(t)))
- case x: NumericType =>
- b => BigDecimal(x.numeric.asInstanceOf[Numeric[Any]].toDouble(b))
+ buildCast[Timestamp](_, t => changePrecision(Decimal(timestampToDouble(t)), target))
+ case DecimalType() =>
+ b => changePrecision(b.asInstanceOf[Decimal].clone(), target)
+ case LongType =>
+ b => changePrecision(Decimal(b.asInstanceOf[Long]), target)
+ case x: NumericType => // All other numeric types can be represented precisely as Doubles
+ b => changePrecision(Decimal(x.numeric.asInstanceOf[Numeric[Any]].toDouble(b)), target)
}
// DoubleConverter
@@ -285,8 +305,8 @@ case class Cast(child: Expression, dataType: DataType) extends UnaryExpression w
buildCast[Date](_, d => dateToDouble(d))
case TimestampType =>
buildCast[Timestamp](_, t => timestampToDouble(t))
- case DecimalType =>
- buildCast[BigDecimal](_, _.toDouble)
+ case DecimalType() =>
+ buildCast[Decimal](_, _.toDouble)
case x: NumericType =>
b => x.numeric.asInstanceOf[Numeric[Any]].toDouble(b)
}
@@ -303,8 +323,8 @@ case class Cast(child: Expression, dataType: DataType) extends UnaryExpression w
buildCast[Date](_, d => dateToDouble(d))
case TimestampType =>
buildCast[Timestamp](_, t => timestampToDouble(t).toFloat)
- case DecimalType =>
- buildCast[BigDecimal](_, _.toFloat)
+ case DecimalType() =>
+ buildCast[Decimal](_, _.toFloat)
case x: NumericType =>
b => x.numeric.asInstanceOf[Numeric[Any]].toFloat(b)
}
@@ -313,8 +333,8 @@ case class Cast(child: Expression, dataType: DataType) extends UnaryExpression w
case dt if dt == child.dataType => identity[Any]
case StringType => castToString
case BinaryType => castToBinary
- case DecimalType => castToDecimal
case DateType => castToDate
+ case decimal: DecimalType => castToDecimal(decimal)
case TimestampType => castToTimestamp
case BooleanType => castToBoolean
case ByteType => castToByte
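The changePrecision helper above is what makes a cast to a fixed-precision decimal either round the value or return null. A sketch of the two outcomes, using the Decimal class added later in this patch:

  import org.apache.spark.sql.catalyst.types.decimal.Decimal

  val d = Decimal(BigDecimal("123.45"))        // precision 5, scale 2

  val rounded = d.clone()
  assert(rounded.changePrecision(4, 1))        // fits: rounded HALF_UP to 123.5
  assert(rounded.toBigDecimal == BigDecimal("123.5"))

  val overflow = d.clone()
  assert(!overflow.changePrecision(3, 1))      // 123.5 needs 4 digits, so a cast to decimal(3,1) yields null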
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregates.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregates.scala
index 1b4d892625..2b364fc1df 100755
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregates.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregates.scala
@@ -286,18 +286,38 @@ case class ApproxCountDistinct(child: Expression, relativeSD: Double = 0.05)
case class Average(child: Expression) extends PartialAggregate with trees.UnaryNode[Expression] {
override def nullable = false
- override def dataType = DoubleType
+
+ override def dataType = child.dataType match {
+ case DecimalType.Fixed(precision, scale) =>
+ DecimalType(precision + 4, scale + 4) // Add 4 digits after decimal point, like Hive
+ case DecimalType.Unlimited =>
+ DecimalType.Unlimited
+ case _ =>
+ DoubleType
+ }
+
override def toString = s"AVG($child)"
override def asPartial: SplitEvaluation = {
val partialSum = Alias(Sum(child), "PartialSum")()
val partialCount = Alias(Count(child), "PartialCount")()
- val castedSum = Cast(Sum(partialSum.toAttribute), dataType)
- val castedCount = Cast(Sum(partialCount.toAttribute), dataType)
- SplitEvaluation(
- Divide(castedSum, castedCount),
- partialCount :: partialSum :: Nil)
+ child.dataType match {
+ case DecimalType.Fixed(_, _) =>
+ // Turn the results to unlimited decimals for the division, before going back to fixed
+ val castedSum = Cast(Sum(partialSum.toAttribute), DecimalType.Unlimited)
+ val castedCount = Cast(Sum(partialCount.toAttribute), DecimalType.Unlimited)
+ SplitEvaluation(
+ Cast(Divide(castedSum, castedCount), dataType),
+ partialCount :: partialSum :: Nil)
+
+ case _ =>
+ val castedSum = Cast(Sum(partialSum.toAttribute), dataType)
+ val castedCount = Cast(Sum(partialCount.toAttribute), dataType)
+ SplitEvaluation(
+ Divide(castedSum, castedCount),
+ partialCount :: partialSum :: Nil)
+ }
}
override def newInstance() = new AverageFunction(child, this)
@@ -306,7 +326,16 @@ case class Average(child: Expression) extends PartialAggregate with trees.UnaryN
case class Sum(child: Expression) extends PartialAggregate with trees.UnaryNode[Expression] {
override def nullable = false
- override def dataType = child.dataType
+
+ override def dataType = child.dataType match {
+ case DecimalType.Fixed(precision, scale) =>
+ DecimalType(precision + 10, scale) // Add 10 digits left of decimal point, like Hive
+ case DecimalType.Unlimited =>
+ DecimalType.Unlimited
+ case _ =>
+ child.dataType
+ }
+
override def toString = s"SUM($child)"
override def asPartial: SplitEvaluation = {
@@ -322,9 +351,17 @@ case class Sum(child: Expression) extends PartialAggregate with trees.UnaryNode[
case class SumDistinct(child: Expression)
extends AggregateExpression with trees.UnaryNode[Expression] {
-
override def nullable = false
- override def dataType = child.dataType
+
+ override def dataType = child.dataType match {
+ case DecimalType.Fixed(precision, scale) =>
+ DecimalType(precision + 10, scale) // Add 10 digits left of decimal point, like Hive
+ case DecimalType.Unlimited =>
+ DecimalType.Unlimited
+ case _ =>
+ child.dataType
+ }
+
override def toString = s"SUM(DISTINCT $child)"
override def newInstance() = new SumDistinctFunction(child, this)
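The Hive-style widening above can be summarised as: SUM gains 10 digits of integer capacity, while AVG gains 4 fractional digits (both precision and scale grow by 4). A plain-Scala restatement with illustrative helper names, not part of the patch:

  def sumResultType(p: Int, s: Int): (Int, Int) = (p + 10, s)
  def avgResultType(p: Int, s: Int): (Int, Int) = (p + 4, s + 4)

  assert(sumResultType(10, 2) == (20, 2))   // SUM over decimal(10,2) yields decimal(20,2)
  assert(avgResultType(10, 2) == (14, 6))   // AVG over decimal(10,2) yields decimal(14,6)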
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala
index 83e8466ec2..8574cabc43 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala
@@ -36,7 +36,7 @@ case class UnaryMinus(child: Expression) extends UnaryExpression {
case class Sqrt(child: Expression) extends UnaryExpression {
type EvaluatedType = Any
-
+
def dataType = DoubleType
override def foldable = child.foldable
def nullable = child.nullable
@@ -55,7 +55,9 @@ abstract class BinaryArithmetic extends BinaryExpression {
def nullable = left.nullable || right.nullable
override lazy val resolved =
- left.resolved && right.resolved && left.dataType == right.dataType
+ left.resolved && right.resolved &&
+ left.dataType == right.dataType &&
+ !DecimalType.isFixed(left.dataType)
def dataType = {
if (!resolved) {
@@ -104,6 +106,8 @@ case class Multiply(left: Expression, right: Expression) extends BinaryArithmeti
case class Divide(left: Expression, right: Expression) extends BinaryArithmetic {
def symbol = "/"
+ override def nullable = left.nullable || right.nullable || dataType.isInstanceOf[DecimalType]
+
override def eval(input: Row): Any = dataType match {
case _: FractionalType => f2(input, left, right, _.div(_, _))
case _: IntegralType => i2(input, left , right, _.quot(_, _))
@@ -114,6 +118,8 @@ case class Divide(left: Expression, right: Expression) extends BinaryArithmetic
case class Remainder(left: Expression, right: Expression) extends BinaryArithmetic {
def symbol = "%"
+ override def nullable = left.nullable || right.nullable || dataType.isInstanceOf[DecimalType]
+
override def eval(input: Row): Any = i2(input, left, right, _.rem(_, _))
}
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala
index 5a3f013c34..67f8d411b6 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala
@@ -18,6 +18,7 @@
package org.apache.spark.sql.catalyst.expressions.codegen
import com.google.common.cache.{CacheLoader, CacheBuilder}
+import org.apache.spark.sql.catalyst.types.decimal.Decimal
import scala.language.existentials
@@ -485,6 +486,34 @@ abstract class CodeGenerator[InType <: AnyRef, OutType <: AnyRef] extends Loggin
}
""".children
+ case UnscaledValue(child) =>
+ val childEval = expressionEvaluator(child)
+
+ childEval.code ++
+ q"""
+ var $nullTerm = ${childEval.nullTerm}
+ var $primitiveTerm: Long = if (!$nullTerm) {
+ ${childEval.primitiveTerm}.toUnscaledLong
+ } else {
+ ${defaultPrimitive(LongType)}
+ }
+ """.children
+
+ case MakeDecimal(child, precision, scale) =>
+ val childEval = expressionEvaluator(child)
+
+ childEval.code ++
+ q"""
+ var $nullTerm = ${childEval.nullTerm}
+ var $primitiveTerm: org.apache.spark.sql.catalyst.types.decimal.Decimal =
+ ${defaultPrimitive(DecimalType())}
+
+ if (!$nullTerm) {
+ $primitiveTerm = new org.apache.spark.sql.catalyst.types.decimal.Decimal()
+ $primitiveTerm = $primitiveTerm.setOrNull(${childEval.primitiveTerm}, $precision, $scale)
+ $nullTerm = $primitiveTerm == null
+ }
+ """.children
}
// If there was no match in the partial function above, we fall back on calling the interpreted
@@ -562,7 +591,7 @@ abstract class CodeGenerator[InType <: AnyRef, OutType <: AnyRef] extends Loggin
case LongType => ru.Literal(Constant(1L))
case ByteType => ru.Literal(Constant(-1.toByte))
case DoubleType => ru.Literal(Constant(-1.toDouble))
- case DecimalType => ru.Literal(Constant(-1)) // Will get implicity converted as needed.
+ case DecimalType() => q"org.apache.spark.sql.catalyst.types.decimal.Decimal(-1)"
case IntegerType => ru.Literal(Constant(-1))
case _ => ru.Literal(Constant(null))
}
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/decimalFunctions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/decimalFunctions.scala
new file mode 100644
index 0000000000..d1eab2eb4e
--- /dev/null
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/decimalFunctions.scala
@@ -0,0 +1,59 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.catalyst.expressions
+
+import org.apache.spark.sql.catalyst.types.decimal.Decimal
+import org.apache.spark.sql.catalyst.types.{DecimalType, LongType, DoubleType, DataType}
+
+/** Return the unscaled Long value of a Decimal, assuming it fits in a Long */
+case class UnscaledValue(child: Expression) extends UnaryExpression {
+ override type EvaluatedType = Any
+
+ override def dataType: DataType = LongType
+ override def foldable = child.foldable
+ def nullable = child.nullable
+ override def toString = s"UnscaledValue($child)"
+
+ override def eval(input: Row): Any = {
+ val childResult = child.eval(input)
+ if (childResult == null) {
+ null
+ } else {
+ childResult.asInstanceOf[Decimal].toUnscaledLong
+ }
+ }
+}
+
+/** Create a Decimal from an unscaled Long value */
+case class MakeDecimal(child: Expression, precision: Int, scale: Int) extends UnaryExpression {
+ override type EvaluatedType = Decimal
+
+ override def dataType: DataType = DecimalType(precision, scale)
+ override def foldable = child.foldable
+ def nullable = child.nullable
+ override def toString = s"MakeDecimal($child,$precision,$scale)"
+
+ override def eval(input: Row): Decimal = {
+ val childResult = child.eval(input)
+ if (childResult == null) {
+ null
+ } else {
+ new Decimal().setOrNull(childResult.asInstanceOf[Long], precision, scale)
+ }
+ }
+}
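UnscaledValue and MakeDecimal are inverses around the unscaled-Long representation. A sketch of evaluating them directly on literals, assuming (as catalyst expression tests commonly do) that a null input row is acceptable for literal-only expressions:

  import org.apache.spark.sql.catalyst.expressions.{Literal, MakeDecimal, UnscaledValue}
  import org.apache.spark.sql.catalyst.types.decimal.Decimal

  // 123.45 at scale 2 is the unscaled Long 12345
  assert(UnscaledValue(Literal(Decimal(BigDecimal("123.45")))).eval(null) == 12345L)

  // ...and MakeDecimal reattaches the precision and scale
  assert(MakeDecimal(Literal(12345L), 10, 2).eval(null) == Decimal(BigDecimal("123.45")))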
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/literals.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/literals.scala
index ba240233ca..93c1932515 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/literals.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/literals.scala
@@ -20,6 +20,7 @@ package org.apache.spark.sql.catalyst.expressions
import java.sql.{Date, Timestamp}
import org.apache.spark.sql.catalyst.types._
+import org.apache.spark.sql.catalyst.types.decimal.Decimal
object Literal {
def apply(v: Any): Literal = v match {
@@ -31,7 +32,8 @@ object Literal {
case s: Short => Literal(s, ShortType)
case s: String => Literal(s, StringType)
case b: Boolean => Literal(b, BooleanType)
- case d: BigDecimal => Literal(d, DecimalType)
+ case d: BigDecimal => Literal(Decimal(d), DecimalType.Unlimited)
+ case d: Decimal => Literal(d, DecimalType.Unlimited)
case t: Timestamp => Literal(t, TimestampType)
case d: Date => Literal(d, DateType)
case a: Array[Byte] => Literal(a, BinaryType)
@@ -62,7 +64,7 @@ case class Literal(value: Any, dataType: DataType) extends LeafExpression {
}
// TODO: Specialize
-case class MutableLiteral(var value: Any, dataType: DataType, nullable: Boolean = true)
+case class MutableLiteral(var value: Any, dataType: DataType, nullable: Boolean = true)
extends LeafExpression {
type EvaluatedType = Any
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala
index 9ce7c78195..a4aa322fc5 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala
@@ -27,6 +27,7 @@ import org.apache.spark.sql.catalyst.plans.LeftSemi
import org.apache.spark.sql.catalyst.plans.logical._
import org.apache.spark.sql.catalyst.rules._
import org.apache.spark.sql.catalyst.types._
+import org.apache.spark.sql.catalyst.types.decimal.Decimal
abstract class Optimizer extends RuleExecutor[LogicalPlan]
@@ -43,6 +44,8 @@ object DefaultOptimizer extends Optimizer {
SimplifyCasts,
SimplifyCaseConversionExpressions,
OptimizeIn) ::
+ Batch("Decimal Optimizations", FixedPoint(100),
+ DecimalAggregates) ::
Batch("Filter Pushdown", FixedPoint(100),
UnionPushdown,
CombineFilters,
@@ -390,9 +393,9 @@ object PushPredicateThroughProject extends Rule[LogicalPlan] {
* evaluated using only the attributes of the left or right side of a join. Other
* [[Filter]] conditions are moved into the `condition` of the [[Join]].
*
- * And also Pushes down the join filter, where the `condition` can be evaluated using only the
- * attributes of the left or right side of sub query when applicable.
- *
+ * And also Pushes down the join filter, where the `condition` can be evaluated using only the
+ * attributes of the left or right side of sub query when applicable.
+ *
* Check https://cwiki.apache.org/confluence/display/Hive/OuterJoinBehavior for more details
*/
object PushPredicateThroughJoin extends Rule[LogicalPlan] with PredicateHelper {
@@ -404,7 +407,7 @@ object PushPredicateThroughJoin extends Rule[LogicalPlan] with PredicateHelper {
private def split(condition: Seq[Expression], left: LogicalPlan, right: LogicalPlan) = {
val (leftEvaluateCondition, rest) =
condition.partition(_.references subsetOf left.outputSet)
- val (rightEvaluateCondition, commonCondition) =
+ val (rightEvaluateCondition, commonCondition) =
rest.partition(_.references subsetOf right.outputSet)
(leftEvaluateCondition, rightEvaluateCondition, commonCondition)
@@ -413,7 +416,7 @@ object PushPredicateThroughJoin extends Rule[LogicalPlan] with PredicateHelper {
def apply(plan: LogicalPlan): LogicalPlan = plan transform {
// push the where condition down into join filter
case f @ Filter(filterCondition, Join(left, right, joinType, joinCondition)) =>
- val (leftFilterConditions, rightFilterConditions, commonFilterCondition) =
+ val (leftFilterConditions, rightFilterConditions, commonFilterCondition) =
split(splitConjunctivePredicates(filterCondition), left, right)
joinType match {
@@ -451,7 +454,7 @@ object PushPredicateThroughJoin extends Rule[LogicalPlan] with PredicateHelper {
// push down the join filter into sub query scanning if applicable
case f @ Join(left, right, joinType, joinCondition) =>
- val (leftJoinConditions, rightJoinConditions, commonJoinCondition) =
+ val (leftJoinConditions, rightJoinConditions, commonJoinCondition) =
split(joinCondition.map(splitConjunctivePredicates).getOrElse(Nil), left, right)
joinType match {
@@ -519,3 +522,26 @@ object SimplifyCaseConversionExpressions extends Rule[LogicalPlan] {
}
}
}
+
+/**
+ * Speeds up aggregates on fixed-precision decimals by executing them on unscaled Long values.
+ *
+ * This uses the same rules for increasing the precision and scale of the output as
+ * [[org.apache.spark.sql.catalyst.analysis.HiveTypeCoercion.DecimalPrecision]].
+ */
+object DecimalAggregates extends Rule[LogicalPlan] {
+ import Decimal.MAX_LONG_DIGITS
+
+ /** Maximum number of decimal digits representable precisely in a Double */
+ val MAX_DOUBLE_DIGITS = 15
+
+ def apply(plan: LogicalPlan): LogicalPlan = plan transformAllExpressions {
+ case Sum(e @ DecimalType.Expression(prec, scale)) if prec + 10 <= MAX_LONG_DIGITS =>
+ MakeDecimal(Sum(UnscaledValue(e)), prec + 10, scale)
+
+ case Average(e @ DecimalType.Expression(prec, scale)) if prec + 4 <= MAX_DOUBLE_DIGITS =>
+ Cast(
+ Divide(Average(UnscaledValue(e)), Literal(math.pow(10.0, scale), DoubleType)),
+ DecimalType(prec + 4, scale + 4))
+ }
+}
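The DecimalAggregates rule above is sound because summing fixed-scale decimals equals summing their unscaled Longs and dividing by 10^scale once at the end; the guard prec + 10 <= MAX_LONG_DIGITS keeps the Long sum from overflowing. A plain-Scala sketch of that identity:

  // Three decimal(5,2) values and their unscaled Long forms
  val values   = Seq(BigDecimal("1.25"), BigDecimal("2.50"), BigDecimal("3.75"))
  val unscaled = values.map(v => (v * 100).toLong)   // 125, 250, 375

  assert(BigDecimal(unscaled.sum, 2) == values.sum)  // 7.50 either way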
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/types/dataTypes.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/types/dataTypes.scala
index 6069f9b0a6..8dda0b1828 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/types/dataTypes.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/types/dataTypes.scala
@@ -19,7 +19,7 @@ package org.apache.spark.sql.catalyst.types
import java.sql.{Date, Timestamp}
-import scala.math.Numeric.{BigDecimalAsIfIntegral, DoubleAsIfIntegral, FloatAsIfIntegral}
+import scala.math.Numeric.{FloatAsIfIntegral, BigDecimalAsIfIntegral, DoubleAsIfIntegral}
import scala.reflect.ClassTag
import scala.reflect.runtime.universe.{TypeTag, runtimeMirror, typeTag}
import scala.util.parsing.combinator.RegexParsers
@@ -33,6 +33,7 @@ import org.apache.spark.sql.catalyst.ScalaReflectionLock
import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeReference, Expression}
import org.apache.spark.sql.catalyst.util.Metadata
import org.apache.spark.util.Utils
+import org.apache.spark.sql.catalyst.types.decimal._
object DataType {
def fromJson(json: String): DataType = parseDataType(parse(json))
@@ -91,11 +92,17 @@ object DataType {
| "LongType" ^^^ LongType
| "BinaryType" ^^^ BinaryType
| "BooleanType" ^^^ BooleanType
- | "DecimalType" ^^^ DecimalType
| "DateType" ^^^ DateType
+ | "DecimalType()" ^^^ DecimalType.Unlimited
+ | fixedDecimalType
| "TimestampType" ^^^ TimestampType
)
+ protected lazy val fixedDecimalType: Parser[DataType] =
+ ("DecimalType(" ~> "[0-9]+".r) ~ ("," ~> "[0-9]+".r <~ ")") ^^ {
+ case precision ~ scale => DecimalType(precision.toInt, scale.toInt)
+ }
+
protected lazy val arrayType: Parser[DataType] =
"ArrayType" ~> "(" ~> dataType ~ "," ~ boolVal <~ ")" ^^ {
case tpe ~ _ ~ containsNull => ArrayType(tpe, containsNull)
@@ -200,10 +207,18 @@ trait PrimitiveType extends DataType {
}
object PrimitiveType {
- private[sql] val all = Seq(DecimalType, DateType, TimestampType, BinaryType) ++
- NativeType.all
-
- private[sql] val nameToType = all.map(t => t.typeName -> t).toMap
+ private val nonDecimals = Seq(DateType, TimestampType, BinaryType) ++ NativeType.all
+ private val nonDecimalNameToType = nonDecimals.map(t => t.typeName -> t).toMap
+
+ /** Given the string representation of a type, return its DataType */
+ private[sql] def nameToType(name: String): DataType = {
+ val FIXED_DECIMAL = """decimal\(\s*(\d+)\s*,\s*(\d+)\s*\)""".r
+ name match {
+ case "decimal" => DecimalType.Unlimited
+ case FIXED_DECIMAL(precision, scale) => DecimalType(precision.toInt, scale.toInt)
+ case other => nonDecimalNameToType(other)
+ }
+ }
}
abstract class NativeType extends DataType {
@@ -332,13 +347,58 @@ abstract class FractionalType extends NumericType {
private[sql] val asIntegral: Integral[JvmType]
}
-case object DecimalType extends FractionalType {
- private[sql] type JvmType = BigDecimal
+/** Precision parameters for a Decimal */
+case class PrecisionInfo(precision: Int, scale: Int)
+
+/** A Decimal that might have fixed precision and scale, or unlimited values for these */
+case class DecimalType(precisionInfo: Option[PrecisionInfo]) extends FractionalType {
+ private[sql] type JvmType = Decimal
@transient private[sql] lazy val tag = ScalaReflectionLock.synchronized { typeTag[JvmType] }
- private[sql] val numeric = implicitly[Numeric[BigDecimal]]
- private[sql] val fractional = implicitly[Fractional[BigDecimal]]
- private[sql] val ordering = implicitly[Ordering[JvmType]]
- private[sql] val asIntegral = BigDecimalAsIfIntegral
+ private[sql] val numeric = Decimal.DecimalIsFractional
+ private[sql] val fractional = Decimal.DecimalIsFractional
+ private[sql] val ordering = Decimal.DecimalIsFractional
+ private[sql] val asIntegral = Decimal.DecimalAsIfIntegral
+
+ override def typeName: String = precisionInfo match {
+ case Some(PrecisionInfo(precision, scale)) => s"decimal($precision,$scale)"
+ case None => "decimal"
+ }
+
+ override def toString: String = precisionInfo match {
+ case Some(PrecisionInfo(precision, scale)) => s"DecimalType($precision,$scale)"
+ case None => "DecimalType()"
+ }
+}
+
+/** Extra factory methods and pattern matchers for Decimals */
+object DecimalType {
+ val Unlimited: DecimalType = DecimalType(None)
+
+ object Fixed {
+ def unapply(t: DecimalType): Option[(Int, Int)] =
+ t.precisionInfo.map(p => (p.precision, p.scale))
+ }
+
+ object Expression {
+ def unapply(e: Expression): Option[(Int, Int)] = e.dataType match {
+ case t: DecimalType => t.precisionInfo.map(p => (p.precision, p.scale))
+ case _ => None
+ }
+ }
+
+ def apply(): DecimalType = Unlimited
+
+ def apply(precision: Int, scale: Int): DecimalType =
+ DecimalType(Some(PrecisionInfo(precision, scale)))
+
+ def unapply(t: DataType): Boolean = t.isInstanceOf[DecimalType]
+
+ def unapply(e: Expression): Boolean = e.dataType.isInstanceOf[DecimalType]
+
+ def isFixed(dataType: DataType): Boolean = dataType match {
+ case DecimalType.Fixed(_, _) => true
+ case _ => false
+ }
}
case object DoubleType extends FractionalType {
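The new extractors make it easy to distinguish fixed and unlimited decimals when matching on a DataType. A short sketch; note that Fixed must be matched before the catch-all DecimalType(), since the latter matches both:

  import org.apache.spark.sql.catalyst.types.{DataType, DecimalType}

  def describe(t: DataType): String = t match {
    case DecimalType.Fixed(p, s) => s"fixed decimal, precision $p, scale $s"
    case DecimalType()           => "decimal with unlimited precision"
    case other                   => other.toString
  }

  assert(describe(DecimalType(10, 2)) == "fixed decimal, precision 10, scale 2")
  assert(describe(DecimalType.Unlimited) == "decimal with unlimited precision")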
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/types/decimal/Decimal.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/types/decimal/Decimal.scala
new file mode 100644
index 0000000000..708362acf3
--- /dev/null
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/types/decimal/Decimal.scala
@@ -0,0 +1,335 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.catalyst.types.decimal
+
+import org.apache.spark.annotation.DeveloperApi
+
+/**
+ * A mutable implementation of BigDecimal that can hold a Long if values are small enough.
+ *
+ * The semantics of the fields are as follows:
+ * - _precision and _scale represent the SQL precision and scale we are looking for
+ * - If decimalVal is set, it represents the whole decimal value
+ * - Otherwise, the decimal value is longVal / (10 ** _scale)
+ */
+final class Decimal extends Ordered[Decimal] with Serializable {
+ import Decimal.{MAX_LONG_DIGITS, POW_10, ROUNDING_MODE, BIG_DEC_ZERO}
+
+ private var decimalVal: BigDecimal = null
+ private var longVal: Long = 0L
+ private var _precision: Int = 1
+ private var _scale: Int = 0
+
+ def precision: Int = _precision
+ def scale: Int = _scale
+
+ /**
+ * Set this Decimal to the given Long. Will have precision 20 and scale 0.
+ */
+ def set(longVal: Long): Decimal = {
+ if (longVal <= -POW_10(MAX_LONG_DIGITS) || longVal >= POW_10(MAX_LONG_DIGITS)) {
+ // We can't represent this compactly as a long without risking overflow
+ this.decimalVal = BigDecimal(longVal)
+ this.longVal = 0L
+ } else {
+ this.decimalVal = null
+ this.longVal = longVal
+ }
+ this._precision = 20
+ this._scale = 0
+ this
+ }
+
+ /**
+ * Set this Decimal to the given Int. Will have precision 10 and scale 0.
+ */
+ def set(intVal: Int): Decimal = {
+ this.decimalVal = null
+ this.longVal = intVal
+ this._precision = 10
+ this._scale = 0
+ this
+ }
+
+ /**
+ * Set this Decimal to the given unscaled Long, with a given precision and scale.
+ */
+ def set(unscaled: Long, precision: Int, scale: Int): Decimal = {
+ if (setOrNull(unscaled, precision, scale) == null) {
+ throw new IllegalArgumentException("Unscaled value too large for precision")
+ }
+ this
+ }
+
+ /**
+ * Set this Decimal to the given unscaled Long, with a given precision and scale,
+ * and return it, or return null if it cannot be set due to overflow.
+ */
+ def setOrNull(unscaled: Long, precision: Int, scale: Int): Decimal = {
+ if (unscaled <= -POW_10(MAX_LONG_DIGITS) || unscaled >= POW_10(MAX_LONG_DIGITS)) {
+ // We can't represent this compactly as a long without risking overflow
+ if (precision < 19) {
+ return null // Requested precision is too low to represent this value
+ }
+ this.decimalVal = BigDecimal(longVal)
+ this.longVal = 0L
+ } else {
+ val p = POW_10(math.min(precision, MAX_LONG_DIGITS))
+ if (unscaled <= -p || unscaled >= p) {
+ return null // Requested precision is too low to represent this value
+ }
+ this.decimalVal = null
+ this.longVal = unscaled
+ }
+ this._precision = precision
+ this._scale = scale
+ this
+ }
+
+ /**
+ * Set this Decimal to the given BigDecimal value, with a given precision and scale.
+ */
+ def set(decimal: BigDecimal, precision: Int, scale: Int): Decimal = {
+ this.decimalVal = decimal.setScale(scale, ROUNDING_MODE)
+ require(decimalVal.precision <= precision, "Overflowed precision")
+ this.longVal = 0L
+ this._precision = precision
+ this._scale = scale
+ this
+ }
+
+ /**
+ * Set this Decimal to the given BigDecimal value, inheriting its precision and scale.
+ */
+ def set(decimal: BigDecimal): Decimal = {
+ this.decimalVal = decimal
+ this.longVal = 0L
+ this._precision = decimal.precision
+ this._scale = decimal.scale
+ this
+ }
+
+ /**
+ * Set this Decimal to the given Decimal value.
+ */
+ def set(decimal: Decimal): Decimal = {
+ this.decimalVal = decimal.decimalVal
+ this.longVal = decimal.longVal
+ this._precision = decimal._precision
+ this._scale = decimal._scale
+ this
+ }
+
+ def toBigDecimal: BigDecimal = {
+ if (decimalVal.ne(null)) {
+ decimalVal
+ } else {
+ BigDecimal(longVal, _scale)
+ }
+ }
+
+ def toUnscaledLong: Long = {
+ if (decimalVal.ne(null)) {
+ decimalVal.underlying().unscaledValue().longValue()
+ } else {
+ longVal
+ }
+ }
+
+ override def toString: String = toBigDecimal.toString()
+
+ @DeveloperApi
+ def toDebugString: String = {
+ if (decimalVal.ne(null)) {
+ s"Decimal(expanded,$decimalVal,$precision,$scale})"
+ } else {
+ s"Decimal(compact,$longVal,$precision,$scale})"
+ }
+ }
+
+ def toDouble: Double = toBigDecimal.doubleValue()
+
+ def toFloat: Float = toBigDecimal.floatValue()
+
+ def toLong: Long = {
+ if (decimalVal.eq(null)) {
+ longVal / POW_10(_scale)
+ } else {
+ decimalVal.longValue()
+ }
+ }
+
+ def toInt: Int = toLong.toInt
+
+ def toShort: Short = toLong.toShort
+
+ def toByte: Byte = toLong.toByte
+
+ /**
+ * Update precision and scale while keeping our value the same, and return true if successful.
+ *
+ * @return true if successful, false if overflow would occur
+ */
+ def changePrecision(precision: Int, scale: Int): Boolean = {
+ // First, update our longVal if we can, or transfer over to using a BigDecimal
+ if (decimalVal.eq(null)) {
+ if (scale < _scale) {
+ // Easier case: we just need to divide our scale down
+ val diff = _scale - scale
+ val droppedDigits = longVal % POW_10(diff)
+ longVal /= POW_10(diff)
+ if (math.abs(droppedDigits) * 2 >= POW_10(diff)) {
+ longVal += (if (longVal < 0) -1L else 1L)
+ }
+ } else if (scale > _scale) {
+ // We might be able to multiply longVal by a power of 10 and not overflow, but if not,
+ // switch to using a BigDecimal
+ val diff = scale - _scale
+ val p = POW_10(math.max(MAX_LONG_DIGITS - diff, 0))
+ if (diff <= MAX_LONG_DIGITS && longVal > -p && longVal < p) {
+ // Multiplying longVal by POW_10(diff) will still keep it below MAX_LONG_DIGITS
+ longVal *= POW_10(diff)
+ } else {
+ // Give up on using Longs; switch to BigDecimal, which we'll modify below
+ decimalVal = BigDecimal(longVal, _scale)
+ }
+ }
+ // In both cases, we will check whether our precision is okay below
+ }
+
+ if (decimalVal.ne(null)) {
+ // We get here if either we started with a BigDecimal, or we switched to one because we would
+ // have overflowed our Long; in either case we must rescale decimalVal to the new scale.
+ val newVal = decimalVal.setScale(scale, ROUNDING_MODE)
+ if (newVal.precision > precision) {
+ return false
+ }
+ decimalVal = newVal
+ } else {
+ // We're still using Longs, but we should check whether we match the new precision
+ val p = POW_10(math.min(_precision, MAX_LONG_DIGITS))
+ if (longVal <= -p || longVal >= p) {
+ // Note that we shouldn't have been able to fix this by switching to BigDecimal
+ return false
+ }
+ }
+
+ _precision = precision
+ _scale = scale
+ true
+ }
+
+ override def clone(): Decimal = new Decimal().set(this)
+
+ override def compare(other: Decimal): Int = {
+ if (decimalVal.eq(null) && other.decimalVal.eq(null) && _scale == other._scale) {
+ if (longVal < other.longVal) -1 else if (longVal == other.longVal) 0 else 1
+ } else {
+ toBigDecimal.compare(other.toBigDecimal)
+ }
+ }
+
+ override def equals(other: Any) = other match {
+ case d: Decimal =>
+ compare(d) == 0
+ case _ =>
+ false
+ }
+
+ override def hashCode(): Int = toBigDecimal.hashCode()
+
+ def isZero: Boolean = if (decimalVal.ne(null)) decimalVal == BIG_DEC_ZERO else longVal == 0
+
+ def + (that: Decimal): Decimal = Decimal(toBigDecimal + that.toBigDecimal)
+
+ def - (that: Decimal): Decimal = Decimal(toBigDecimal - that.toBigDecimal)
+
+ def * (that: Decimal): Decimal = Decimal(toBigDecimal * that.toBigDecimal)
+
+ def / (that: Decimal): Decimal =
+ if (that.isZero) null else Decimal(toBigDecimal / that.toBigDecimal)
+
+ def % (that: Decimal): Decimal =
+ if (that.isZero) null else Decimal(toBigDecimal % that.toBigDecimal)
+
+ def remainder(that: Decimal): Decimal = this % that
+
+ def unary_- : Decimal = {
+ if (decimalVal.ne(null)) {
+ Decimal(-decimalVal)
+ } else {
+ Decimal(-longVal, precision, scale)
+ }
+ }
+}
+
+object Decimal {
+ private val ROUNDING_MODE = BigDecimal.RoundingMode.HALF_UP
+
+ /** Maximum number of decimal digits a Long can represent */
+ val MAX_LONG_DIGITS = 18
+
+ private val POW_10 = Array.tabulate[Long](MAX_LONG_DIGITS + 1)(i => math.pow(10, i).toLong)
+
+ private val BIG_DEC_ZERO = BigDecimal(0)
+
+ def apply(value: Double): Decimal = new Decimal().set(value)
+
+ def apply(value: Long): Decimal = new Decimal().set(value)
+
+ def apply(value: Int): Decimal = new Decimal().set(value)
+
+ def apply(value: BigDecimal): Decimal = new Decimal().set(value)
+
+ def apply(value: BigDecimal, precision: Int, scale: Int): Decimal =
+ new Decimal().set(value, precision, scale)
+
+ def apply(unscaled: Long, precision: Int, scale: Int): Decimal =
+ new Decimal().set(unscaled, precision, scale)
+
+ def apply(value: String): Decimal = new Decimal().set(BigDecimal(value))
+
+ // Evidence parameters for treating Decimal as either Fractional or Integral. A single object
+ // cannot extend both traits because each defines mkNumericOps, so we provide two objects that
+ // share a common base trait. See scala.math.Numeric for how this is done for Scala's built-in types.
+
+ /** Common methods for Decimal evidence parameters */
+ trait DecimalIsConflicted extends Numeric[Decimal] {
+ override def plus(x: Decimal, y: Decimal): Decimal = x + y
+ override def times(x: Decimal, y: Decimal): Decimal = x * y
+ override def minus(x: Decimal, y: Decimal): Decimal = x - y
+ override def negate(x: Decimal): Decimal = -x
+ override def toDouble(x: Decimal): Double = x.toDouble
+ override def toFloat(x: Decimal): Float = x.toFloat
+ override def toInt(x: Decimal): Int = x.toInt
+ override def toLong(x: Decimal): Long = x.toLong
+ override def fromInt(x: Int): Decimal = new Decimal().set(x)
+ override def compare(x: Decimal, y: Decimal): Int = x.compare(y)
+ }
+
+ /** A [[scala.math.Fractional]] evidence parameter for Decimals. */
+ object DecimalIsFractional extends DecimalIsConflicted with Fractional[Decimal] {
+ override def div(x: Decimal, y: Decimal): Decimal = x / y
+ }
+
+ /** A [[scala.math.Integral]] evidence parameter for Decimals. */
+ object DecimalAsIfIntegral extends DecimalIsConflicted with Integral[Decimal] {
+ override def quot(x: Decimal, y: Decimal): Decimal = x / y
+ override def rem(x: Decimal, y: Decimal): Decimal = x % y
+ }
+}
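
The Decimal class completed above can hold small values (at most MAX_LONG_DIGITS = 18 digits of unscaled value) in a mutable Long and otherwise falls back to a BigDecimal. The snippet below is a minimal usage sketch, assuming the sql/catalyst classes from this patch are on the classpath; the results in the comments follow from the changePrecision logic above and the DecimalSuite tests further down.

import org.apache.spark.sql.catalyst.types.decimal.Decimal

object DecimalUsageSketch {
  def main(args: Array[String]): Unit = {
    // Lowering the scale rounds HALF_UP; a value that cannot fit the target precision makes
    // changePrecision return false and leaves the caller to decide what overflow means.
    val d = Decimal(10.05)                    // precision 4, scale 2
    val narrowed = d.clone()
    println(narrowed.changePrecision(3, 1))   // true: narrowed is now 10.1
    println(narrowed)                         // 10.1
    println(d.clone().changePrecision(2, 1))  // false: 10.05 needs precision 3 at scale 1

    // Division and remainder by zero return null instead of throwing.
    println(Decimal(100) / Decimal(0))        // null

    // The Numeric evidence parameters let Decimal plug into ordinary Scala collections.
    println(Seq(Decimal(1), Decimal(2), Decimal(3)).sum(Decimal.DecimalIsFractional))  // 6
  }
}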
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/ScalaReflectionSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/ScalaReflectionSuite.scala
index 430f0664b7..21b2c8e20d 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/ScalaReflectionSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/ScalaReflectionSuite.scala
@@ -96,7 +96,7 @@ class ScalaReflectionSuite extends FunSuite {
StructField("byteField", ByteType, nullable = true),
StructField("booleanField", BooleanType, nullable = true),
StructField("stringField", StringType, nullable = true),
- StructField("decimalField", DecimalType, nullable = true),
+ StructField("decimalField", DecimalType.Unlimited, nullable = true),
StructField("dateField", DateType, nullable = true),
StructField("timestampField", TimestampType, nullable = true),
StructField("binaryField", BinaryType, nullable = true))),
@@ -199,7 +199,7 @@ class ScalaReflectionSuite extends FunSuite {
assert(DoubleType === typeOfObject(1.7976931348623157E308))
// DecimalType
- assert(DecimalType === typeOfObject(BigDecimal("1.7976931348623157E318")))
+ assert(DecimalType.Unlimited === typeOfObject(BigDecimal("1.7976931348623157E318")))
// DateType
assert(DateType === typeOfObject(Date.valueOf("2014-07-25")))
@@ -211,19 +211,19 @@ class ScalaReflectionSuite extends FunSuite {
assert(NullType === typeOfObject(null))
def typeOfObject1: PartialFunction[Any, DataType] = typeOfObject orElse {
- case value: java.math.BigInteger => DecimalType
- case value: java.math.BigDecimal => DecimalType
+ case value: java.math.BigInteger => DecimalType.Unlimited
+ case value: java.math.BigDecimal => DecimalType.Unlimited
case _ => StringType
}
- assert(DecimalType === typeOfObject1(
+ assert(DecimalType.Unlimited === typeOfObject1(
new BigInteger("92233720368547758070")))
- assert(DecimalType === typeOfObject1(
+ assert(DecimalType.Unlimited === typeOfObject1(
new java.math.BigDecimal("1.7976931348623157E318")))
assert(StringType === typeOfObject1(BigInt("92233720368547758070")))
def typeOfObject2: PartialFunction[Any, DataType] = typeOfObject orElse {
- case value: java.math.BigInteger => DecimalType
+ case value: java.math.BigInteger => DecimalType.Unlimited
}
intercept[MatchError](typeOfObject2(BigInt("92233720368547758070")))
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisSuite.scala
index 7b45738c4f..33a3cba3d4 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisSuite.scala
@@ -38,7 +38,7 @@ class AnalysisSuite extends FunSuite with BeforeAndAfter {
AttributeReference("a", StringType)(),
AttributeReference("b", StringType)(),
AttributeReference("c", DoubleType)(),
- AttributeReference("d", DecimalType)(),
+ AttributeReference("d", DecimalType.Unlimited)(),
AttributeReference("e", ShortType)())
before {
@@ -119,7 +119,7 @@ class AnalysisSuite extends FunSuite with BeforeAndAfter {
AttributeReference("a", StringType)(),
AttributeReference("b", StringType)(),
AttributeReference("c", DoubleType)(),
- AttributeReference("d", DecimalType)(),
+ AttributeReference("d", DecimalType.Unlimited)(),
AttributeReference("e", ShortType)())
val expr0 = 'a / 2
@@ -137,7 +137,7 @@ class AnalysisSuite extends FunSuite with BeforeAndAfter {
assert(pl(0).dataType == DoubleType)
assert(pl(1).dataType == DoubleType)
assert(pl(2).dataType == DoubleType)
- assert(pl(3).dataType == DecimalType)
+ assert(pl(3).dataType == DecimalType.Unlimited)
assert(pl(4).dataType == DoubleType)
}
}
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/DecimalPrecisionSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/DecimalPrecisionSuite.scala
new file mode 100644
index 0000000000..d5b7d2789a
--- /dev/null
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/DecimalPrecisionSuite.scala
@@ -0,0 +1,88 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.catalyst.analysis
+
+import org.apache.spark.sql.catalyst.expressions._
+import org.apache.spark.sql.catalyst.plans.logical.{Project, LocalRelation}
+import org.apache.spark.sql.catalyst.types._
+import org.scalatest.{BeforeAndAfter, FunSuite}
+
+class DecimalPrecisionSuite extends FunSuite with BeforeAndAfter {
+ val catalog = new SimpleCatalog(false)
+ val analyzer = new Analyzer(catalog, EmptyFunctionRegistry, caseSensitive = false)
+
+ val relation = LocalRelation(
+ AttributeReference("i", IntegerType)(),
+ AttributeReference("d1", DecimalType(2, 1))(),
+ AttributeReference("d2", DecimalType(5, 2))(),
+ AttributeReference("u", DecimalType.Unlimited)(),
+ AttributeReference("f", FloatType)()
+ )
+
+ val i: Expression = UnresolvedAttribute("i")
+ val d1: Expression = UnresolvedAttribute("d1")
+ val d2: Expression = UnresolvedAttribute("d2")
+ val u: Expression = UnresolvedAttribute("u")
+ val f: Expression = UnresolvedAttribute("f")
+
+ before {
+ catalog.registerTable(None, "table", relation)
+ }
+
+ private def checkType(expression: Expression, expectedType: DataType): Unit = {
+ val plan = Project(Seq(Alias(expression, "c")()), relation)
+ assert(analyzer(plan).schema.fields(0).dataType === expectedType)
+ }
+
+ test("basic operations") {
+ checkType(Add(d1, d2), DecimalType(6, 2))
+ checkType(Subtract(d1, d2), DecimalType(6, 2))
+ checkType(Multiply(d1, d2), DecimalType(8, 3))
+ checkType(Divide(d1, d2), DecimalType(10, 7))
+ checkType(Divide(d2, d1), DecimalType(10, 6))
+ checkType(Remainder(d1, d2), DecimalType(3, 2))
+ checkType(Remainder(d2, d1), DecimalType(3, 2))
+ checkType(Sum(d1), DecimalType(12, 1))
+ checkType(Average(d1), DecimalType(6, 5))
+
+ checkType(Add(Add(d1, d2), d1), DecimalType(7, 2))
+ checkType(Add(Add(Add(d1, d2), d1), d2), DecimalType(8, 2))
+ checkType(Add(Add(d1, d2), Add(d1, d2)), DecimalType(7, 2))
+ }
+
+ test("bringing in primitive types") {
+ checkType(Add(d1, i), DecimalType(12, 1))
+ checkType(Add(d1, f), DoubleType)
+ checkType(Add(i, d1), DecimalType(12, 1))
+ checkType(Add(f, d1), DoubleType)
+ checkType(Add(d1, Cast(i, LongType)), DecimalType(22, 1))
+ checkType(Add(d1, Cast(i, ShortType)), DecimalType(7, 1))
+ checkType(Add(d1, Cast(i, ByteType)), DecimalType(5, 1))
+ checkType(Add(d1, Cast(i, DoubleType)), DoubleType)
+ }
+
+ test("unlimited decimals make everything else cast up") {
+ for (expr <- Seq(d1, d2, i, f, u)) {
+ checkType(Add(expr, u), DecimalType.Unlimited)
+ checkType(Subtract(expr, u), DecimalType.Unlimited)
+ checkType(Multiply(expr, u), DecimalType.Unlimited)
+ checkType(Divide(expr, u), DecimalType.Unlimited)
+ checkType(Remainder(expr, u), DecimalType.Unlimited)
+ }
+ }
+}
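
The expected types in DecimalPrecisionSuite follow Hive-style fixed-decimal precision/scale rules. The sketch below recomputes a few of the expectations from those formulas; the formulas and helper names are stated here only for illustration (they match the assertions above, not any specific rule object in the analyzer).

object DecimalPrecisionSketch {
  // (precision, scale) of a fixed-precision decimal
  case class Fixed(precision: Int, scale: Int)

  // p = max(s1, s2) + max(p1 - s1, p2 - s2) + 1, s = max(s1, s2)
  def add(a: Fixed, b: Fixed): Fixed =
    Fixed(math.max(a.scale, b.scale) + math.max(a.precision - a.scale, b.precision - b.scale) + 1,
      math.max(a.scale, b.scale))

  // p = p1 + p2 + 1, s = s1 + s2
  def multiply(a: Fixed, b: Fixed): Fixed =
    Fixed(a.precision + b.precision + 1, a.scale + b.scale)

  // s = max(6, s1 + p2 + 1), p = p1 - s1 + s2 + s
  def divide(a: Fixed, b: Fixed): Fixed = {
    val s = math.max(6, a.scale + b.precision + 1)
    Fixed(a.precision - a.scale + b.scale + s, s)
  }

  def main(args: Array[String]): Unit = {
    val d1 = Fixed(2, 1)
    val d2 = Fixed(5, 2)
    println(add(d1, d2))       // Fixed(6,2)  -> DecimalType(6, 2), as asserted above
    println(multiply(d1, d2))  // Fixed(8,3)  -> DecimalType(8, 3)
    println(divide(d1, d2))    // Fixed(10,7) -> DecimalType(10, 7)
    println(divide(d2, d1))    // Fixed(10,6) -> DecimalType(10, 6)
  }
}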
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercionSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercionSuite.scala
index baeb9b0cf5..dfa2d958c0 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercionSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercionSuite.scala
@@ -68,6 +68,21 @@ class HiveTypeCoercionSuite extends FunSuite {
widenTest(LongType, FloatType, Some(FloatType))
widenTest(LongType, DoubleType, Some(DoubleType))
+ // Casting up to unlimited-precision decimal
+ widenTest(IntegerType, DecimalType.Unlimited, Some(DecimalType.Unlimited))
+ widenTest(DoubleType, DecimalType.Unlimited, Some(DecimalType.Unlimited))
+ widenTest(DecimalType(3, 2), DecimalType.Unlimited, Some(DecimalType.Unlimited))
+ widenTest(DecimalType.Unlimited, IntegerType, Some(DecimalType.Unlimited))
+ widenTest(DecimalType.Unlimited, DoubleType, Some(DecimalType.Unlimited))
+ widenTest(DecimalType.Unlimited, DecimalType(3, 2), Some(DecimalType.Unlimited))
+
+ // No up-casting for fixed-precision decimal (this is handled by arithmetic rules)
+ widenTest(DecimalType(2, 1), DecimalType(3, 2), None)
+ widenTest(DecimalType(2, 1), DoubleType, None)
+ widenTest(DecimalType(2, 1), IntegerType, None)
+ widenTest(DoubleType, DecimalType(2, 1), None)
+ widenTest(IntegerType, DecimalType(2, 1), None)
+
// StringType
widenTest(NullType, StringType, Some(StringType))
widenTest(StringType, StringType, Some(StringType))
@@ -92,7 +107,7 @@ class HiveTypeCoercionSuite extends FunSuite {
def ruleTest(initial: Expression, transformed: Expression) {
val testRelation = LocalRelation(AttributeReference("a", IntegerType)())
assert(booleanCasts(Project(Seq(Alias(initial, "a")()), testRelation)) ==
- Project(Seq(Alias(transformed, "a")()), testRelation))
+ Project(Seq(Alias(transformed, "a")()), testRelation))
}
// Remove superflous boolean -> boolean casts.
ruleTest(Cast(Literal(true), BooleanType), Literal(true))
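
The new widenTest cases encode a simple policy: DecimalType.Unlimited is a valid widened type for the other numeric types, while fixed-precision decimals never widen here and are instead left to the dedicated arithmetic precision rules. The following self-contained model (not Spark's API) is a sketch of that policy over a toy set of types.

object WideningSketch {
  sealed trait T
  case object IntT extends T
  case object DoubleT extends T
  case object UnlimitedDec extends T
  case class FixedDec(precision: Int, scale: Int) extends T

  def widen(a: T, b: T): Option[T] = (a, b) match {
    case (x, y) if x == y                      => Some(x)
    case (UnlimitedDec, _) | (_, UnlimitedDec) => Some(UnlimitedDec)
    case (_: FixedDec, _) | (_, _: FixedDec)   => None   // left to the arithmetic rules
    case (IntT, DoubleT) | (DoubleT, IntT)     => Some(DoubleT)
    case _                                     => None
  }

  def main(args: Array[String]): Unit = {
    println(widen(IntT, UnlimitedDec))             // Some(UnlimitedDec)
    println(widen(FixedDec(2, 1), DoubleT))        // None
    println(widen(FixedDec(2, 1), FixedDec(3, 2))) // None
  }
}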
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvaluationSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvaluationSuite.scala
index 5657bc555e..6bfa0dbd65 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvaluationSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvaluationSuite.scala
@@ -21,6 +21,7 @@ import java.sql.{Date, Timestamp}
import scala.collection.immutable.HashSet
+import org.apache.spark.sql.catalyst.types.decimal.Decimal
import org.scalatest.FunSuite
import org.scalatest.Matchers._
import org.scalactic.TripleEqualsSupport.Spread
@@ -138,7 +139,7 @@ class ExpressionEvaluationSuite extends FunSuite {
val actual = try evaluate(expression, inputRow) catch {
case e: Exception => fail(s"Exception evaluating $expression", e)
}
- actual.asInstanceOf[Double] shouldBe expected
+ actual.asInstanceOf[Double] shouldBe expected
}
test("IN") {
@@ -165,7 +166,7 @@ class ExpressionEvaluationSuite extends FunSuite {
checkEvaluation(InSet(three, nS, three +: nullS), false)
checkEvaluation(InSet(one, hS, one +: s) && InSet(two, hS, two +: s), true)
}
-
+
test("MaxOf") {
checkEvaluation(MaxOf(1, 2), 2)
checkEvaluation(MaxOf(2, 1), 2)
@@ -265,9 +266,9 @@ class ExpressionEvaluationSuite extends FunSuite {
val ts = Timestamp.valueOf(nts)
checkEvaluation("abdef" cast StringType, "abdef")
- checkEvaluation("abdef" cast DecimalType, null)
+ checkEvaluation("abdef" cast DecimalType.Unlimited, null)
checkEvaluation("abdef" cast TimestampType, null)
- checkEvaluation("12.65" cast DecimalType, BigDecimal(12.65))
+ checkEvaluation("12.65" cast DecimalType.Unlimited, Decimal(12.65))
checkEvaluation(Literal(1) cast LongType, 1)
checkEvaluation(Cast(Literal(1000) cast TimestampType, LongType), 1.toLong)
@@ -289,12 +290,12 @@ class ExpressionEvaluationSuite extends FunSuite {
checkEvaluation(Cast(Cast(Cast(Cast(
Cast("5" cast ByteType, ShortType), IntegerType), FloatType), DoubleType), LongType), 5)
- checkEvaluation(Cast(Cast(Cast(Cast(
- Cast("5" cast ByteType, TimestampType), DecimalType), LongType), StringType), ShortType), 0)
- checkEvaluation(Cast(Cast(Cast(Cast(
- Cast("5" cast TimestampType, ByteType), DecimalType), LongType), StringType), ShortType), null)
- checkEvaluation(Cast(Cast(Cast(Cast(
- Cast("5" cast DecimalType, ByteType), TimestampType), LongType), StringType), ShortType), 0)
+ checkEvaluation(Cast(Cast(Cast(Cast(Cast("5" cast
+ ByteType, TimestampType), DecimalType.Unlimited), LongType), StringType), ShortType), 0)
+ checkEvaluation(Cast(Cast(Cast(Cast(Cast("5" cast
+ TimestampType, ByteType), DecimalType.Unlimited), LongType), StringType), ShortType), null)
+ checkEvaluation(Cast(Cast(Cast(Cast(Cast("5" cast
+ DecimalType.Unlimited, ByteType), TimestampType), LongType), StringType), ShortType), 0)
checkEvaluation(Literal(true) cast IntegerType, 1)
checkEvaluation(Literal(false) cast IntegerType, 0)
checkEvaluation(Cast(Literal(1) cast BooleanType, IntegerType), 1)
@@ -302,7 +303,7 @@ class ExpressionEvaluationSuite extends FunSuite {
checkEvaluation("23" cast DoubleType, 23d)
checkEvaluation("23" cast IntegerType, 23)
checkEvaluation("23" cast FloatType, 23f)
- checkEvaluation("23" cast DecimalType, 23: BigDecimal)
+ checkEvaluation("23" cast DecimalType.Unlimited, Decimal(23))
checkEvaluation("23" cast ByteType, 23.toByte)
checkEvaluation("23" cast ShortType, 23.toShort)
checkEvaluation("2012-12-11" cast DoubleType, null)
@@ -311,7 +312,7 @@ class ExpressionEvaluationSuite extends FunSuite {
checkEvaluation(Literal(23d) + Cast(true, DoubleType), 24d)
checkEvaluation(Literal(23) + Cast(true, IntegerType), 24)
checkEvaluation(Literal(23f) + Cast(true, FloatType), 24f)
- checkEvaluation(Literal(BigDecimal(23)) + Cast(true, DecimalType), 24: BigDecimal)
+ checkEvaluation(Literal(Decimal(23)) + Cast(true, DecimalType.Unlimited), Decimal(24))
checkEvaluation(Literal(23.toByte) + Cast(true, ByteType), 24.toByte)
checkEvaluation(Literal(23.toShort) + Cast(true, ShortType), 24.toShort)
@@ -325,7 +326,8 @@ class ExpressionEvaluationSuite extends FunSuite {
assert(("abcdef" cast IntegerType).nullable === true)
assert(("abcdef" cast ShortType).nullable === true)
assert(("abcdef" cast ByteType).nullable === true)
- assert(("abcdef" cast DecimalType).nullable === true)
+ assert(("abcdef" cast DecimalType.Unlimited).nullable === true)
+ assert(("abcdef" cast DecimalType(4, 2)).nullable === true)
assert(("abcdef" cast DoubleType).nullable === true)
assert(("abcdef" cast FloatType).nullable === true)
@@ -338,6 +340,64 @@ class ExpressionEvaluationSuite extends FunSuite {
checkEvaluation(Literal(d1) < Literal(d2), true)
}
+ test("casting to fixed-precision decimals") {
+ // Overflow and rounding for casting to fixed-precision decimals:
+ // - Values should round with HALF_UP mode by default when you lower scale
+ // - Values that would overflow the target precision should turn into null
+ // - Because of this, casts to fixed-precision decimals should be nullable
+
+ assert(Cast(Literal(123), DecimalType.Unlimited).nullable === false)
+ assert(Cast(Literal(10.03f), DecimalType.Unlimited).nullable === false)
+ assert(Cast(Literal(10.03), DecimalType.Unlimited).nullable === false)
+ assert(Cast(Literal(Decimal(10.03)), DecimalType.Unlimited).nullable === false)
+
+ assert(Cast(Literal(123), DecimalType(2, 1)).nullable === true)
+ assert(Cast(Literal(10.03f), DecimalType(2, 1)).nullable === true)
+ assert(Cast(Literal(10.03), DecimalType(2, 1)).nullable === true)
+ assert(Cast(Literal(Decimal(10.03)), DecimalType(2, 1)).nullable === true)
+
+ checkEvaluation(Cast(Literal(123), DecimalType.Unlimited), Decimal(123))
+ checkEvaluation(Cast(Literal(123), DecimalType(3, 0)), Decimal(123))
+ checkEvaluation(Cast(Literal(123), DecimalType(3, 1)), null)
+ checkEvaluation(Cast(Literal(123), DecimalType(2, 0)), null)
+
+ checkEvaluation(Cast(Literal(10.03), DecimalType.Unlimited), Decimal(10.03))
+ checkEvaluation(Cast(Literal(10.03), DecimalType(4, 2)), Decimal(10.03))
+ checkEvaluation(Cast(Literal(10.03), DecimalType(3, 1)), Decimal(10.0))
+ checkEvaluation(Cast(Literal(10.03), DecimalType(2, 0)), Decimal(10))
+ checkEvaluation(Cast(Literal(10.03), DecimalType(1, 0)), null)
+ checkEvaluation(Cast(Literal(10.03), DecimalType(2, 1)), null)
+ checkEvaluation(Cast(Literal(10.03), DecimalType(3, 2)), null)
+ checkEvaluation(Cast(Literal(Decimal(10.03)), DecimalType(3, 1)), Decimal(10.0))
+ checkEvaluation(Cast(Literal(Decimal(10.03)), DecimalType(3, 2)), null)
+
+ checkEvaluation(Cast(Literal(10.05), DecimalType.Unlimited), Decimal(10.05))
+ checkEvaluation(Cast(Literal(10.05), DecimalType(4, 2)), Decimal(10.05))
+ checkEvaluation(Cast(Literal(10.05), DecimalType(3, 1)), Decimal(10.1))
+ checkEvaluation(Cast(Literal(10.05), DecimalType(2, 0)), Decimal(10))
+ checkEvaluation(Cast(Literal(10.05), DecimalType(1, 0)), null)
+ checkEvaluation(Cast(Literal(10.05), DecimalType(2, 1)), null)
+ checkEvaluation(Cast(Literal(10.05), DecimalType(3, 2)), null)
+ checkEvaluation(Cast(Literal(Decimal(10.05)), DecimalType(3, 1)), Decimal(10.1))
+ checkEvaluation(Cast(Literal(Decimal(10.05)), DecimalType(3, 2)), null)
+
+ checkEvaluation(Cast(Literal(9.95), DecimalType(3, 2)), Decimal(9.95))
+ checkEvaluation(Cast(Literal(9.95), DecimalType(3, 1)), Decimal(10.0))
+ checkEvaluation(Cast(Literal(9.95), DecimalType(2, 0)), Decimal(10))
+ checkEvaluation(Cast(Literal(9.95), DecimalType(2, 1)), null)
+ checkEvaluation(Cast(Literal(9.95), DecimalType(1, 0)), null)
+ checkEvaluation(Cast(Literal(Decimal(9.95)), DecimalType(3, 1)), Decimal(10.0))
+ checkEvaluation(Cast(Literal(Decimal(9.95)), DecimalType(1, 0)), null)
+
+ checkEvaluation(Cast(Literal(-9.95), DecimalType(3, 2)), Decimal(-9.95))
+ checkEvaluation(Cast(Literal(-9.95), DecimalType(3, 1)), Decimal(-10.0))
+ checkEvaluation(Cast(Literal(-9.95), DecimalType(2, 0)), Decimal(-10))
+ checkEvaluation(Cast(Literal(-9.95), DecimalType(2, 1)), null)
+ checkEvaluation(Cast(Literal(-9.95), DecimalType(1, 0)), null)
+ checkEvaluation(Cast(Literal(Decimal(-9.95)), DecimalType(3, 1)), Decimal(-10.0))
+ checkEvaluation(Cast(Literal(Decimal(-9.95)), DecimalType(1, 0)), null)
+ }
+
test("timestamp") {
val ts1 = new Timestamp(12)
val ts2 = new Timestamp(123)
@@ -374,7 +434,7 @@ class ExpressionEvaluationSuite extends FunSuite {
millis.toFloat / 1000)
checkEvaluation(Cast(Cast(millis.toDouble / 1000, TimestampType), DoubleType),
millis.toDouble / 1000)
- checkEvaluation(Cast(Literal(BigDecimal(1)) cast TimestampType, DecimalType), 1)
+ checkEvaluation(Cast(Literal(Decimal(1)) cast TimestampType, DecimalType.Unlimited), Decimal(1))
// A test for higher precision than millis
checkEvaluation(Cast(Cast(0.00000001, TimestampType), DoubleType), 0.00000001)
@@ -673,7 +733,7 @@ class ExpressionEvaluationSuite extends FunSuite {
val expectedResults = inputSequence.map(l => math.sqrt(l.toDouble))
val rowSequence = inputSequence.map(l => new GenericRow(Array[Any](l.toDouble)))
val d = 'a.double.at(0)
-
+
for ((row, expected) <- rowSequence zip expectedResults) {
checkEvaluation(Sqrt(d), expected, row)
}
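
The new "casting to fixed-precision decimals" test pins down the cast semantics: lowering the scale rounds HALF_UP, and any value that cannot fit the target precision becomes null, which is why such casts are nullable. The helper below is an illustrative sketch of that behaviour built on Decimal.changePrecision, assuming the sql/catalyst classes from this patch; the real logic lives in Cast.scala and may differ in detail.

import org.apache.spark.sql.catalyst.types.decimal.Decimal

object CastSketch {
  // Round d to the target precision and scale, or return null on overflow.
  def toFixed(d: Decimal, precision: Int, scale: Int): Any = {
    val copy = d.clone()
    if (copy.changePrecision(precision, scale)) copy else null
  }

  def main(args: Array[String]): Unit = {
    println(toFixed(Decimal(10.05), 3, 1))  // 10.1 (rounded up)
    println(toFixed(Decimal(9.95), 2, 0))   // 10
    println(toFixed(Decimal(10.05), 2, 1))  // null: needs precision 3 at scale 1
  }
}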
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/types/decimal/DecimalSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/types/decimal/DecimalSuite.scala
new file mode 100644
index 0000000000..5aa263484d
--- /dev/null
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/types/decimal/DecimalSuite.scala
@@ -0,0 +1,158 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.catalyst.types.decimal
+
+import org.scalatest.{PrivateMethodTester, FunSuite}
+
+import scala.language.postfixOps
+
+class DecimalSuite extends FunSuite with PrivateMethodTester {
+ test("creating decimals") {
+ /** Check that a Decimal has the given string representation, precision and scale */
+ def checkDecimal(d: Decimal, string: String, precision: Int, scale: Int): Unit = {
+ assert(d.toString === string)
+ assert(d.precision === precision)
+ assert(d.scale === scale)
+ }
+
+ checkDecimal(new Decimal(), "0", 1, 0)
+ checkDecimal(Decimal(BigDecimal("10.030")), "10.030", 5, 3)
+ checkDecimal(Decimal(BigDecimal("10.030"), 4, 1), "10.0", 4, 1)
+ checkDecimal(Decimal(BigDecimal("-9.95"), 4, 1), "-10.0", 4, 1)
+ checkDecimal(Decimal("10.030"), "10.030", 5, 3)
+ checkDecimal(Decimal(10.03), "10.03", 4, 2)
+ checkDecimal(Decimal(17L), "17", 20, 0)
+ checkDecimal(Decimal(17), "17", 10, 0)
+ checkDecimal(Decimal(17L, 2, 1), "1.7", 2, 1)
+ checkDecimal(Decimal(170L, 4, 2), "1.70", 4, 2)
+ checkDecimal(Decimal(17L, 24, 1), "1.7", 24, 1)
+ checkDecimal(Decimal(1e17.toLong, 18, 0), 1e17.toLong.toString, 18, 0)
+ checkDecimal(Decimal(Long.MaxValue), Long.MaxValue.toString, 20, 0)
+ checkDecimal(Decimal(Long.MinValue), Long.MinValue.toString, 20, 0)
+ intercept[IllegalArgumentException](Decimal(170L, 2, 1))
+ intercept[IllegalArgumentException](Decimal(170L, 2, 0))
+ intercept[IllegalArgumentException](Decimal(BigDecimal("10.030"), 2, 1))
+ intercept[IllegalArgumentException](Decimal(BigDecimal("-9.95"), 2, 1))
+ intercept[IllegalArgumentException](Decimal(1e17.toLong, 17, 0))
+ }
+
+ test("double and long values") {
+ /** Check that a Decimal converts to the given double and long values */
+ def checkValues(d: Decimal, doubleValue: Double, longValue: Long): Unit = {
+ assert(d.toDouble === doubleValue)
+ assert(d.toLong === longValue)
+ }
+
+ checkValues(new Decimal(), 0.0, 0L)
+ checkValues(Decimal(BigDecimal("10.030")), 10.03, 10L)
+ checkValues(Decimal(BigDecimal("10.030"), 4, 1), 10.0, 10L)
+ checkValues(Decimal(BigDecimal("-9.95"), 4, 1), -10.0, -10L)
+ checkValues(Decimal(10.03), 10.03, 10L)
+ checkValues(Decimal(17L), 17.0, 17L)
+ checkValues(Decimal(17), 17.0, 17L)
+ checkValues(Decimal(17L, 2, 1), 1.7, 1L)
+ checkValues(Decimal(170L, 4, 2), 1.7, 1L)
+ checkValues(Decimal(1e16.toLong), 1e16, 1e16.toLong)
+ checkValues(Decimal(1e17.toLong), 1e17, 1e17.toLong)
+ checkValues(Decimal(1e18.toLong), 1e18, 1e18.toLong)
+ checkValues(Decimal(2e18.toLong), 2e18, 2e18.toLong)
+ checkValues(Decimal(Long.MaxValue), Long.MaxValue.toDouble, Long.MaxValue)
+ checkValues(Decimal(Long.MinValue), Long.MinValue.toDouble, Long.MinValue)
+ checkValues(Decimal(Double.MaxValue), Double.MaxValue, 0L)
+ checkValues(Decimal(Double.MinValue), Double.MinValue, 0L)
+ }
+
+ // Accessor for the BigDecimal value of a Decimal, which will be null if it's using Longs
+ private val decimalVal = PrivateMethod[BigDecimal]('decimalVal)
+
+ /** Check whether a Decimal is stored compactly as an unscaled Long, given whether we expect it to be */
+ private def checkCompact(d: Decimal, expected: Boolean): Unit = {
+ val isCompact = d.invokePrivate(decimalVal()).eq(null)
+ assert(isCompact == expected, s"$d ${if (expected) "was not" else "was"} compact")
+ }
+
+ test("small decimals represented as unscaled long") {
+ checkCompact(new Decimal(), true)
+ checkCompact(Decimal(BigDecimal(10.03)), false)
+ checkCompact(Decimal(BigDecimal(1e20)), false)
+ checkCompact(Decimal(17L), true)
+ checkCompact(Decimal(17), true)
+ checkCompact(Decimal(17L, 2, 1), true)
+ checkCompact(Decimal(170L, 4, 2), true)
+ checkCompact(Decimal(17L, 24, 1), true)
+ checkCompact(Decimal(1e16.toLong), true)
+ checkCompact(Decimal(1e17.toLong), true)
+ checkCompact(Decimal(1e18.toLong - 1), true)
+ checkCompact(Decimal(- 1e18.toLong + 1), true)
+ checkCompact(Decimal(1e18.toLong - 1, 30, 10), true)
+ checkCompact(Decimal(- 1e18.toLong + 1, 30, 10), true)
+ checkCompact(Decimal(1e18.toLong), false)
+ checkCompact(Decimal(-1e18.toLong), false)
+ checkCompact(Decimal(1e18.toLong, 30, 10), false)
+ checkCompact(Decimal(-1e18.toLong, 30, 10), false)
+ checkCompact(Decimal(Long.MaxValue), false)
+ checkCompact(Decimal(Long.MinValue), false)
+ }
+
+ test("hash code") {
+ assert(Decimal(123).hashCode() === (123).##)
+ assert(Decimal(-123).hashCode() === (-123).##)
+ assert(Decimal(123.312).hashCode() === (123.312).##)
+ assert(Decimal(Int.MaxValue).hashCode() === Int.MaxValue.##)
+ assert(Decimal(Long.MaxValue).hashCode() === Long.MaxValue.##)
+ assert(Decimal(BigDecimal(123)).hashCode() === (123).##)
+
+ val reallyBig = BigDecimal("123182312312313232112312312123.1231231231")
+ assert(Decimal(reallyBig).hashCode() === reallyBig.hashCode)
+ }
+
+ test("equals") {
+ // The decimals on the left are stored compactly, while the ones on the right aren't
+ checkCompact(Decimal(123), true)
+ checkCompact(Decimal(BigDecimal(123)), false)
+ checkCompact(Decimal("123"), false)
+ assert(Decimal(123) === Decimal(BigDecimal(123)))
+ assert(Decimal(123) === Decimal(BigDecimal("123.00")))
+ assert(Decimal(-123) === Decimal(BigDecimal(-123)))
+ assert(Decimal(-123) === Decimal(BigDecimal("-123.00")))
+ }
+
+ test("isZero") {
+ assert(Decimal(0).isZero)
+ assert(Decimal(0, 4, 2).isZero)
+ assert(Decimal("0").isZero)
+ assert(Decimal("0.000").isZero)
+ assert(!Decimal(1).isZero)
+ assert(!Decimal(1, 4, 2).isZero)
+ assert(!Decimal("1").isZero)
+ assert(!Decimal("0.001").isZero)
+ }
+
+ test("arithmetic") {
+ assert(Decimal(100) + Decimal(-100) === Decimal(0))
+ assert(Decimal(100) - Decimal(-100) === Decimal(200))
+ assert(Decimal(100) * Decimal(-100) === Decimal(-10000))
+ assert(Decimal(1e13) * Decimal(1e13) === Decimal(1e26))
+ assert(Decimal(100) / Decimal(-100) === Decimal(-1))
+ assert(Decimal(100) / Decimal(0) === null)
+ assert(Decimal(100) % Decimal(-100) === Decimal(0))
+ assert(Decimal(100) % Decimal(3) === Decimal(1))
+ assert(Decimal(-100) % Decimal(3) === Decimal(-1))
+ assert(Decimal(100) % Decimal(0) === null)
+ }
+}
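
The compactness checks above all hinge on MAX_LONG_DIGITS = 18: an unscaled value with at most 18 decimal digits (strictly between -10^18 and 10^18) can stay in the mutable Long, and 10^18 is the first positive unscaled value that forces a BigDecimal. A small sketch of that boundary, assuming the classes from this patch are on the classpath:

import org.apache.spark.sql.catalyst.types.decimal.Decimal

object CompactBoundarySketch {
  def main(args: Array[String]): Unit = {
    val limit = math.pow(10, Decimal.MAX_LONG_DIGITS).toLong  // 1000000000000000000L
    println(Decimal(limit - 1, 30, 10))  // 99999999.9999999999, kept in the compact Long form
    println(Decimal(limit, 30, 10))      // 100000000.0000000000, stored as a BigDecimal
  }
}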