author    | Matei Zaharia <matei@databricks.com>      | 2014-11-01 19:29:14 -0700
committer | Michael Armbrust <michael@databricks.com> | 2014-11-01 19:29:14 -0700
commit    | 23f966f47523f85ba440b4080eee665271f53b5e (patch)
tree      | d796351567f8b187511b9049199cbf99c5826fb3 /sql/catalyst
parent    | 56f2c61cde3f5d906c2a58e9af1a661222f2c679 (diff)
[SPARK-3930] [SPARK-3933] Support fixed-precision decimal in SQL, and some optimizations
- Adds optional precision and scale to Spark SQL's decimal type, which behave similarly to those in Hive 13 (https://cwiki.apache.org/confluence/download/attachments/27362075/Hive_Decimal_Precision_Scale_Support.pdf)
- Replaces our internal representation of decimals with a Decimal class that can store small values in a mutable Long, saving memory when values fit in a Long and letting some operations happen directly on Longs
This is still marked WIP because there are a few TODOs, but I'll remove that tag when done.
Author: Matei Zaharia <matei@databricks.com>
Closes #2983 from mateiz/decimal-1 and squashes the following commits:
35e6b02 [Matei Zaharia] Fix issues after merge
227f24a [Matei Zaharia] Review comments
31f915e [Matei Zaharia] Implement Davies's suggestions in Python
eb84820 [Matei Zaharia] Support reading/writing decimals as fixed-length binary in Parquet
4dc6bae [Matei Zaharia] Fix decimal support in PySpark
d1d9d68 [Matei Zaharia] Fix compile error and test issues after rebase
b28933d [Matei Zaharia] Support decimal precision/scale in Hive metastore
2118c0d [Matei Zaharia] Some test and bug fixes
81db9cb [Matei Zaharia] Added mutable Decimal that will be more efficient for small precisions
7af0c3b [Matei Zaharia] Add optional precision and scale to DecimalType, but use Unlimited for now
ec0a947 [Matei Zaharia] Make the result of AVG on Decimals be Decimal, not Double
Diffstat (limited to 'sql/catalyst')
19 files changed, 1154 insertions, 106 deletions
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala
index 75923d9e8d..8fbdf664b7 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala
@@ -19,9 +19,10 @@ package org.apache.spark.sql.catalyst
 
 import java.sql.{Date, Timestamp}
 
-import org.apache.spark.sql.catalyst.expressions.{GenericRow, Attribute, AttributeReference}
+import org.apache.spark.sql.catalyst.expressions.{GenericRow, Attribute, AttributeReference, Row}
 import org.apache.spark.sql.catalyst.plans.logical.LocalRelation
 import org.apache.spark.sql.catalyst.types._
+import org.apache.spark.sql.catalyst.types.decimal.Decimal
 
 /**
  * Provides experimental support for generating catalyst schemas for scala objects.
@@ -40,9 +41,20 @@ object ScalaReflection {
     case s: Seq[_] => s.map(convertToCatalyst)
     case m: Map[_, _] => m.map { case (k, v) => convertToCatalyst(k) -> convertToCatalyst(v) }
     case p: Product => new GenericRow(p.productIterator.map(convertToCatalyst).toArray)
+    case d: BigDecimal => Decimal(d)
     case other => other
   }
 
+  /** Converts Catalyst types used internally in rows to standard Scala types */
+  def convertToScala(a: Any): Any = a match {
+    case s: Seq[_] => s.map(convertToScala)
+    case m: Map[_, _] => m.map { case (k, v) => convertToScala(k) -> convertToScala(v) }
+    case d: Decimal => d.toBigDecimal
+    case other => other
+  }
+
+  def convertRowToScala(r: Row): Row = new GenericRow(r.toArray.map(convertToScala))
+
   /** Returns a Sequence of attributes for the given case class type. */
   def attributesFor[T: TypeTag]: Seq[Attribute] = schemaFor[T] match {
     case Schema(s: StructType, _) =>
@@ -83,7 +95,8 @@ object ScalaReflection {
     case t if t <:< typeOf[String] => Schema(StringType, nullable = true)
     case t if t <:< typeOf[Timestamp] => Schema(TimestampType, nullable = true)
     case t if t <:< typeOf[Date] => Schema(DateType, nullable = true)
-    case t if t <:< typeOf[BigDecimal] => Schema(DecimalType, nullable = true)
+    case t if t <:< typeOf[BigDecimal] => Schema(DecimalType.Unlimited, nullable = true)
+    case t if t <:< typeOf[Decimal] => Schema(DecimalType.Unlimited, nullable = true)
     case t if t <:< typeOf[java.lang.Integer] => Schema(IntegerType, nullable = true)
     case t if t <:< typeOf[java.lang.Long] => Schema(LongType, nullable = true)
     case t if t <:< typeOf[java.lang.Double] => Schema(DoubleType, nullable = true)
@@ -111,8 +124,9 @@ object ScalaReflection {
     case obj: LongType.JvmType => LongType
     case obj: FloatType.JvmType => FloatType
     case obj: DoubleType.JvmType => DoubleType
-    case obj: DecimalType.JvmType => DecimalType
     case obj: DateType.JvmType => DateType
+    case obj: BigDecimal => DecimalType.Unlimited
+    case obj: Decimal => DecimalType.Unlimited
     case obj: TimestampType.JvmType => TimestampType
     case null => NullType
     // For other cases, there is no obvious mapping from the type of the given object to a
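Editor's note: a minimal sketch of the round-trip this hunk introduces, assuming only the names defined above (the commented results are illustrative):

    import org.apache.spark.sql.catalyst.ScalaReflection._
    import org.apache.spark.sql.catalyst.types.decimal.Decimal

    // Scala BigDecimals are now stored internally as Catalyst Decimals...
    val internal = convertToCatalyst(BigDecimal("123.45"))   // Decimal(123.45)
    // ...and converted back to BigDecimals when rows are handed to user code.
    val external = convertToScala(internal)                  // BigDecimal("123.45")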
Keyword("CAST") protected val COUNT = Keyword("COUNT") + protected val DECIMAL = Keyword("DECIMAL") protected val DESC = Keyword("DESC") protected val DISTINCT = Keyword("DISTINCT") protected val ELSE = Keyword("ELSE") protected val END = Keyword("END") protected val EXCEPT = Keyword("EXCEPT") + protected val DOUBLE = Keyword("DOUBLE") protected val FALSE = Keyword("FALSE") protected val FIRST = Keyword("FIRST") protected val FROM = Keyword("FROM") @@ -385,5 +387,15 @@ class SqlParser extends AbstractSparkSQLParser { } protected lazy val dataType: Parser[DataType] = - STRING ^^^ StringType | TIMESTAMP ^^^ TimestampType + ( STRING ^^^ StringType + | TIMESTAMP ^^^ TimestampType + | DOUBLE ^^^ DoubleType + | fixedDecimalType + | DECIMAL ^^^ DecimalType.Unlimited + ) + + protected lazy val fixedDecimalType: Parser[DataType] = + (DECIMAL ~ "(" ~> numericLit) ~ ("," ~> numericLit <~ ")") ^^ { + case precision ~ scale => DecimalType(precision.toInt, scale.toInt) + } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercion.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercion.scala index 2b69c02b28..e38114ab3c 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercion.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercion.scala @@ -25,19 +25,31 @@ import org.apache.spark.sql.catalyst.types._ object HiveTypeCoercion { // See https://cwiki.apache.org/confluence/display/Hive/LanguageManual+Types. // The conversion for integral and floating point types have a linear widening hierarchy: - val numericPrecedence = - Seq(ByteType, ShortType, IntegerType, LongType, FloatType, DoubleType, DecimalType) - val allPromotions: Seq[Seq[DataType]] = numericPrecedence :: Nil + private val numericPrecedence = + Seq(ByteType, ShortType, IntegerType, LongType, FloatType, DoubleType, DecimalType.Unlimited) + /** + * Find the tightest common type of two types that might be used in a binary expression. + * This handles all numeric types except fixed-precision decimals interacting with each other or + * with primitive types, because in that case the precision and scale of the result depends on + * the operation. Those rules are implemented in [[HiveTypeCoercion.DecimalPrecision]]. + */ def findTightestCommonType(t1: DataType, t2: DataType): Option[DataType] = { val valueTypes = Seq(t1, t2).filter(t => t != NullType) if (valueTypes.distinct.size > 1) { - // Try and find a promotion rule that contains both types in question. 
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercion.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercion.scala
index 2b69c02b28..e38114ab3c 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercion.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercion.scala
@@ -25,19 +25,31 @@ import org.apache.spark.sql.catalyst.types._
 object HiveTypeCoercion {
   // See https://cwiki.apache.org/confluence/display/Hive/LanguageManual+Types.
   // The conversion for integral and floating point types have a linear widening hierarchy:
-  val numericPrecedence =
-    Seq(ByteType, ShortType, IntegerType, LongType, FloatType, DoubleType, DecimalType)
-  val allPromotions: Seq[Seq[DataType]] = numericPrecedence :: Nil
+  private val numericPrecedence =
+    Seq(ByteType, ShortType, IntegerType, LongType, FloatType, DoubleType, DecimalType.Unlimited)
 
+  /**
+   * Find the tightest common type of two types that might be used in a binary expression.
+   * This handles all numeric types except fixed-precision decimals interacting with each other or
+   * with primitive types, because in that case the precision and scale of the result depends on
+   * the operation. Those rules are implemented in [[HiveTypeCoercion.DecimalPrecision]].
+   */
   def findTightestCommonType(t1: DataType, t2: DataType): Option[DataType] = {
     val valueTypes = Seq(t1, t2).filter(t => t != NullType)
     if (valueTypes.distinct.size > 1) {
-      // Try and find a promotion rule that contains both types in question.
-      val applicableConversion =
-        HiveTypeCoercion.allPromotions.find(p => p.contains(t1) && p.contains(t2))
-
-      // If found return the widest common type, otherwise None
-      applicableConversion.map(_.filter(t => t == t1 || t == t2).last)
+      // Promote numeric types to the highest of the two and all numeric types to unlimited decimal
+      if (numericPrecedence.contains(t1) && numericPrecedence.contains(t2)) {
+        Some(numericPrecedence.filter(t => t == t1 || t == t2).last)
+      } else if (t1.isInstanceOf[DecimalType] && t2.isInstanceOf[DecimalType]) {
+        // Fixed-precision decimals can up-cast into unlimited
+        if (t1 == DecimalType.Unlimited || t2 == DecimalType.Unlimited) {
+          Some(DecimalType.Unlimited)
+        } else {
+          None
+        }
+      } else {
+        None
+      }
     } else {
       Some(if (valueTypes.size == 0) NullType else valueTypes.head)
     }
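Editor's note: a quick sketch of how the rewritten promotion logic behaves, with expected results in comments (the fixed-vs-fixed case is deliberately left to the DecimalPrecision rule below):

    import org.apache.spark.sql.catalyst.analysis.HiveTypeCoercion.findTightestCommonType
    import org.apache.spark.sql.catalyst.types._

    findTightestCommonType(IntegerType, DoubleType)                  // Some(DoubleType): widened along numericPrecedence
    findTightestCommonType(LongType, DecimalType.Unlimited)          // Some(DecimalType.Unlimited)
    findTightestCommonType(DecimalType(5, 2), DecimalType.Unlimited) // Some(DecimalType.Unlimited): fixed up-casts to unlimited
    findTightestCommonType(DecimalType(5, 2), DecimalType(6, 2))     // None: handled by DecimalPrecision instead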
@@ -59,6 +71,7 @@ trait HiveTypeCoercion {
     ConvertNaNs ::
     WidenTypes ::
     PromoteStrings ::
+    DecimalPrecision ::
     BooleanComparisons ::
     BooleanCasts ::
     StringToIntegralCasts ::
@@ -151,6 +164,7 @@ trait HiveTypeCoercion {
     import HiveTypeCoercion._
 
     def apply(plan: LogicalPlan): LogicalPlan = plan transform {
+      // TODO: unions with fixed-precision decimals
       case u @ Union(left, right) if u.childrenResolved && !u.resolved =>
         val castedInput = left.output.zip(right.output).map {
           // When a string is found on one side, make the other side a string too.
@@ -265,6 +279,110 @@ trait HiveTypeCoercion {
     }
   }
 
+  // scalastyle:off
+  /**
+   * Calculates and propagates precision for fixed-precision decimals. Hive has a number of
+   * rules for this based on the SQL standard and MS SQL:
+   * https://cwiki.apache.org/confluence/download/attachments/27362075/Hive_Decimal_Precision_Scale_Support.pdf
+   *
+   * In particular, if we have expressions e1 and e2 with precision/scale p1/s1 and p2/s2
+   * respectively, then the following operations have the following precision / scale:
+   *
+   *   Operation    Result Precision                        Result Scale
+   *   ------------------------------------------------------------------------
+   *   e1 + e2      max(s1, s2) + max(p1-s1, p2-s2) + 1     max(s1, s2)
+   *   e1 - e2      max(s1, s2) + max(p1-s1, p2-s2) + 1     max(s1, s2)
+   *   e1 * e2      p1 + p2 + 1                             s1 + s2
+   *   e1 / e2      p1 - s1 + s2 + max(6, s1 + p2 + 1)      max(6, s1 + p2 + 1)
+   *   e1 % e2      min(p1-s1, p2-s2) + max(s1, s2)         max(s1, s2)
+   *   sum(e1)      p1 + 10                                 s1
+   *   avg(e1)      p1 + 4                                  s1 + 4
+   *
+   * Catalyst also has unlimited-precision decimals. For those, all ops return unlimited precision.
+   *
+   * To implement the rules for fixed-precision types, we introduce casts to turn them to unlimited
+   * precision, do the math on unlimited-precision numbers, then introduce casts back to the
+   * required fixed precision. This allows us to do all rounding and overflow handling in the
+   * cast-to-fixed-precision operator.
+   *
+   * In addition, when mixing non-decimal types with decimals, we use the following rules:
+   * - BYTE gets turned into DECIMAL(3, 0)
+   * - SHORT gets turned into DECIMAL(5, 0)
+   * - INT gets turned into DECIMAL(10, 0)
+   * - LONG gets turned into DECIMAL(20, 0)
+   * - FLOAT and DOUBLE cause fixed-length decimals to turn into DOUBLE (this is the same as Hive,
+   *   but note that unlimited decimals are considered bigger than doubles in WidenTypes)
+   */
+  // scalastyle:on
+  object DecimalPrecision extends Rule[LogicalPlan] {
+    import scala.math.{max, min}
+
+    // Conversion rules for integer types into fixed-precision decimals
+    val intTypeToFixed: Map[DataType, DecimalType] = Map(
+      ByteType -> DecimalType(3, 0),
+      ShortType -> DecimalType(5, 0),
+      IntegerType -> DecimalType(10, 0),
+      LongType -> DecimalType(20, 0)
+    )
+
+    def isFloat(t: DataType): Boolean = t == FloatType || t == DoubleType
+
+    def apply(plan: LogicalPlan): LogicalPlan = plan transformAllExpressions {
+      // Skip nodes whose children have not been resolved yet
+      case e if !e.childrenResolved => e
+
+      case Add(e1 @ DecimalType.Expression(p1, s1), e2 @ DecimalType.Expression(p2, s2)) =>
+        Cast(
+          Add(Cast(e1, DecimalType.Unlimited), Cast(e2, DecimalType.Unlimited)),
+          DecimalType(max(s1, s2) + max(p1 - s1, p2 - s2) + 1, max(s1, s2))
+        )
+
+      case Subtract(e1 @ DecimalType.Expression(p1, s1), e2 @ DecimalType.Expression(p2, s2)) =>
+        Cast(
+          Subtract(Cast(e1, DecimalType.Unlimited), Cast(e2, DecimalType.Unlimited)),
+          DecimalType(max(s1, s2) + max(p1 - s1, p2 - s2) + 1, max(s1, s2))
+        )
+
+      case Multiply(e1 @ DecimalType.Expression(p1, s1), e2 @ DecimalType.Expression(p2, s2)) =>
+        Cast(
+          Multiply(Cast(e1, DecimalType.Unlimited), Cast(e2, DecimalType.Unlimited)),
+          DecimalType(p1 + p2 + 1, s1 + s2)
+        )
+
+      case Divide(e1 @ DecimalType.Expression(p1, s1), e2 @ DecimalType.Expression(p2, s2)) =>
+        Cast(
+          Divide(Cast(e1, DecimalType.Unlimited), Cast(e2, DecimalType.Unlimited)),
+          DecimalType(p1 - s1 + s2 + max(6, s1 + p2 + 1), max(6, s1 + p2 + 1))
+        )
+
+      case Remainder(e1 @ DecimalType.Expression(p1, s1), e2 @ DecimalType.Expression(p2, s2)) =>
+        Cast(
+          Remainder(Cast(e1, DecimalType.Unlimited), Cast(e2, DecimalType.Unlimited)),
+          DecimalType(min(p1 - s1, p2 - s2) + max(s1, s2), max(s1, s2))
+        )
+
+      // Promote integers inside a binary expression with fixed-precision decimals to decimals,
+      // and fixed-precision decimals in an expression with floats / doubles to doubles
+      case b: BinaryExpression if b.left.dataType != b.right.dataType =>
+        (b.left.dataType, b.right.dataType) match {
+          case (t, DecimalType.Fixed(p, s)) if intTypeToFixed.contains(t) =>
+            b.makeCopy(Array(Cast(b.left, intTypeToFixed(t)), b.right))
+          case (DecimalType.Fixed(p, s), t) if intTypeToFixed.contains(t) =>
+            b.makeCopy(Array(b.left, Cast(b.right, intTypeToFixed(t))))
+          case (t, DecimalType.Fixed(p, s)) if isFloat(t) =>
+            b.makeCopy(Array(b.left, Cast(b.right, DoubleType)))
+          case (DecimalType.Fixed(p, s), t) if isFloat(t) =>
+            b.makeCopy(Array(Cast(b.left, DoubleType), b.right))
+          case _ =>
+            b
+        }
+
+      // TODO: MaxOf, MinOf, etc might want other rules
+
+      // SUM and AVERAGE are handled by the implementations of those expressions
+    }
+  }
+
   /**
    * Changes Boolean values to Bytes so that expressions like true < false can be Evaluated.
    */
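Editor's note: a worked instance of the table above, using the shapes DECIMAL(2,1) and DECIMAL(5,2) that DecimalPrecisionSuite (later in this diff) checks:

    // e1: DECIMAL(2, 1), e2: DECIMAL(5, 2)
    // e1 + e2: scale = max(1, 2) = 2; precision = max(1, 2) + max(2 - 1, 5 - 2) + 1 = 2 + 3 + 1 = 6
    //          => DecimalType(6, 2)
    // e1 * e2: precision = 2 + 5 + 1 = 8; scale = 1 + 2 = 3              => DecimalType(8, 3)
    // e1 / e2: scale = max(6, 1 + 5 + 1) = 7; precision = 2 - 1 + 2 + 7 = 10  => DecimalType(10, 7)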
@@ -330,7 +448,7 @@ trait HiveTypeCoercion {
       case e if !e.childrenResolved => e
 
       case Cast(e @ StringType(), t: IntegralType) =>
-        Cast(Cast(e, DecimalType), t)
+        Cast(Cast(e, DecimalType.Unlimited), t)
     }
   }
 
@@ -383,10 +501,12 @@ trait HiveTypeCoercion {
       // Decimal and Double remain the same
       case d: Divide if d.resolved && d.dataType == DoubleType => d
-      case d: Divide if d.resolved && d.dataType == DecimalType => d
+      case d: Divide if d.resolved && d.dataType.isInstanceOf[DecimalType] => d
 
-      case Divide(l, r) if l.dataType == DecimalType => Divide(l, Cast(r, DecimalType))
-      case Divide(l, r) if r.dataType == DecimalType => Divide(Cast(l, DecimalType), r)
+      case Divide(l, r) if l.dataType.isInstanceOf[DecimalType] =>
+        Divide(l, Cast(r, DecimalType.Unlimited))
+      case Divide(l, r) if r.dataType.isInstanceOf[DecimalType] =>
+        Divide(Cast(l, DecimalType.Unlimited), r)
 
       case Divide(l, r) => Divide(Cast(l, DoubleType), Cast(r, DoubleType))
     }
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/dsl/package.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/dsl/package.scala
index 23cfd483ec..7e6d770314 100755
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/dsl/package.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/dsl/package.scala
@@ -19,6 +19,8 @@ package org.apache.spark.sql.catalyst
 
 import java.sql.{Date, Timestamp}
 
+import org.apache.spark.sql.catalyst.types.decimal.Decimal
+
 import scala.language.implicitConversions
 
 import org.apache.spark.sql.catalyst.analysis.UnresolvedAttribute
@@ -124,7 +126,8 @@ package object dsl {
     implicit def doubleToLiteral(d: Double) = Literal(d)
     implicit def stringToLiteral(s: String) = Literal(s)
     implicit def dateToLiteral(d: Date) = Literal(d)
-    implicit def decimalToLiteral(d: BigDecimal) = Literal(d)
+    implicit def bigDecimalToLiteral(d: BigDecimal) = Literal(d)
+    implicit def decimalToLiteral(d: Decimal) = Literal(d)
     implicit def timestampToLiteral(t: Timestamp) = Literal(t)
     implicit def binaryToLiteral(a: Array[Byte]) = Literal(a)
@@ -183,7 +186,11 @@ package object dsl {
     def date = AttributeReference(s, DateType, nullable = true)()
 
     /** Creates a new AttributeReference of type decimal */
-    def decimal = AttributeReference(s, DecimalType, nullable = true)()
+    def decimal = AttributeReference(s, DecimalType.Unlimited, nullable = true)()
+
+    /** Creates a new AttributeReference of type decimal */
+    def decimal(precision: Int, scale: Int) =
+      AttributeReference(s, DecimalType(precision, scale), nullable = true)()
 
     /** Creates a new AttributeReference of type timestamp */
     def timestamp = AttributeReference(s, TimestampType, nullable = true)()
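Editor's note: with the DSL addition, test code can declare fixed-precision attributes directly; a sketch, assuming the usual catalyst dsl imports (the exact import path is an assumption):

    import org.apache.spark.sql.catalyst.dsl.expressions._

    val unlimited = 'price.decimal        // AttributeReference of DecimalType.Unlimited
    val fixed = 'price.decimal(10, 2)     // AttributeReference of DecimalType(10, 2)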
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala
index 8e5baf0eb8..2200966619 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala
@@ -23,6 +23,7 @@ import java.text.{DateFormat, SimpleDateFormat}
 import org.apache.spark.Logging
 import org.apache.spark.sql.catalyst.errors.TreeNodeException
 import org.apache.spark.sql.catalyst.types._
+import org.apache.spark.sql.catalyst.types.decimal.Decimal
 
 /** Cast the child expression to the target data type. */
 case class Cast(child: Expression, dataType: DataType) extends UnaryExpression with Logging {
@@ -36,6 +37,7 @@ case class Cast(child: Expression, dataType: DataType) extends UnaryExpression w
     case (BooleanType, DateType) => true
     case (DateType, _: NumericType) => true
     case (DateType, BooleanType) => true
+    case (_, DecimalType.Fixed(_, _)) => true // TODO: not all upcasts here can really give null
     case _ => child.nullable
   }
@@ -76,8 +78,8 @@ case class Cast(child: Expression, dataType: DataType) extends UnaryExpression w
       buildCast[Short](_, _ != 0)
     case ByteType =>
       buildCast[Byte](_, _ != 0)
-    case DecimalType =>
-      buildCast[BigDecimal](_, _ != 0)
+    case DecimalType() =>
+      buildCast[Decimal](_, _ != 0)
     case DoubleType =>
       buildCast[Double](_, _ != 0)
     case FloatType =>
@@ -109,19 +111,19 @@ case class Cast(child: Expression, dataType: DataType) extends UnaryExpression w
     case DateType =>
       buildCast[Date](_, d => new Timestamp(d.getTime))
     // TimestampWritable.decimalToTimestamp
-    case DecimalType =>
-      buildCast[BigDecimal](_, d => decimalToTimestamp(d))
+    case DecimalType() =>
+      buildCast[Decimal](_, d => decimalToTimestamp(d))
     // TimestampWritable.doubleToTimestamp
     case DoubleType =>
-      buildCast[Double](_, d => decimalToTimestamp(d))
+      buildCast[Double](_, d => decimalToTimestamp(Decimal(d)))
     // TimestampWritable.floatToTimestamp
     case FloatType =>
-      buildCast[Float](_, f => decimalToTimestamp(f))
+      buildCast[Float](_, f => decimalToTimestamp(Decimal(f)))
   }
 
-  private[this] def decimalToTimestamp(d: BigDecimal) = {
+  private[this] def decimalToTimestamp(d: Decimal) = {
     val seconds = Math.floor(d.toDouble).toLong
-    val bd = (d - seconds) * 1000000000
+    val bd = (d.toBigDecimal - seconds) * 1000000000
     val nanos = bd.intValue()
     val millis = seconds * 1000
@@ -196,8 +198,8 @@ case class Cast(child: Expression, dataType: DataType) extends UnaryExpression w
       buildCast[Date](_, d => dateToLong(d))
     case TimestampType =>
       buildCast[Timestamp](_, t => timestampToLong(t))
-    case DecimalType =>
-      buildCast[BigDecimal](_, _.toLong)
+    case DecimalType() =>
+      buildCast[Decimal](_, _.toLong)
     case x: NumericType =>
       b => x.numeric.asInstanceOf[Numeric[Any]].toLong(b)
   }
@@ -214,8 +216,8 @@ case class Cast(child: Expression, dataType: DataType) extends UnaryExpression w
       buildCast[Date](_, d => dateToLong(d))
     case TimestampType =>
       buildCast[Timestamp](_, t => timestampToLong(t).toInt)
-    case DecimalType =>
-      buildCast[BigDecimal](_, _.toInt)
+    case DecimalType() =>
+      buildCast[Decimal](_, _.toInt)
     case x: NumericType =>
       b => x.numeric.asInstanceOf[Numeric[Any]].toInt(b)
   }
@@ -232,8 +234,8 @@ case class Cast(child: Expression, dataType: DataType) extends UnaryExpression w
       buildCast[Date](_, d => dateToLong(d))
     case TimestampType =>
       buildCast[Timestamp](_, t => timestampToLong(t).toShort)
-    case DecimalType =>
-      buildCast[BigDecimal](_, _.toShort)
+    case DecimalType() =>
+      buildCast[Decimal](_, _.toShort)
     case x: NumericType =>
       b => x.numeric.asInstanceOf[Numeric[Any]].toInt(b).toShort
   }
@@ -250,27 +252,45 @@ case class Cast(child: Expression, dataType: DataType) extends UnaryExpression w
       buildCast[Date](_, d => dateToLong(d))
     case TimestampType =>
       buildCast[Timestamp](_, t => timestampToLong(t).toByte)
-    case DecimalType =>
-      buildCast[BigDecimal](_, _.toByte)
+    case DecimalType() =>
+      buildCast[Decimal](_, _.toByte)
     case x: NumericType =>
       b => x.numeric.asInstanceOf[Numeric[Any]].toInt(b).toByte
   }
 
-  // DecimalConverter
-  private[this] def castToDecimal: Any => Any = child.dataType match {
+  /**
+   * Change the precision / scale in a given decimal to those set in `decimalType` (if any),
+   * returning null if it overflows or modifying `value` in-place and returning it if successful.
+   *
+   * NOTE: this modifies `value` in-place, so don't call it on external data.
+   */
+  private[this] def changePrecision(value: Decimal, decimalType: DecimalType): Decimal = {
+    decimalType match {
+      case DecimalType.Unlimited =>
+        value
+      case DecimalType.Fixed(precision, scale) =>
+        if (value.changePrecision(precision, scale)) value else null
+    }
+  }
+
+  private[this] def castToDecimal(target: DecimalType): Any => Any = child.dataType match {
     case StringType =>
-      buildCast[String](_, s => try BigDecimal(s.toDouble) catch {
+      buildCast[String](_, s => try changePrecision(Decimal(s.toDouble), target) catch {
         case _: NumberFormatException => null
       })
     case BooleanType =>
-      buildCast[Boolean](_, b => if (b) BigDecimal(1) else BigDecimal(0))
+      buildCast[Boolean](_, b => changePrecision(if (b) Decimal(1) else Decimal(0), target))
     case DateType =>
-      buildCast[Date](_, d => dateToDouble(d))
+      buildCast[Date](_, d => changePrecision(null, target)) // date can't cast to decimal in Hive
     case TimestampType =>
       // Note that we lose precision here.
-      buildCast[Timestamp](_, t => BigDecimal(timestampToDouble(t)))
-    case x: NumericType =>
-      b => BigDecimal(x.numeric.asInstanceOf[Numeric[Any]].toDouble(b))
+      buildCast[Timestamp](_, t => changePrecision(Decimal(timestampToDouble(t)), target))
+    case DecimalType() =>
+      b => changePrecision(b.asInstanceOf[Decimal].clone(), target)
+    case LongType =>
+      b => changePrecision(Decimal(b.asInstanceOf[Long]), target)
+    case x: NumericType => // All other numeric types can be represented precisely as Doubles
+      b => changePrecision(Decimal(x.numeric.asInstanceOf[Numeric[Any]].toDouble(b)), target)
   }
 
   // DoubleConverter
@@ -285,8 +305,8 @@ case class Cast(child: Expression, dataType: DataType) extends UnaryExpression w
       buildCast[Date](_, d => dateToDouble(d))
     case TimestampType =>
       buildCast[Timestamp](_, t => timestampToDouble(t))
-    case DecimalType =>
-      buildCast[BigDecimal](_, _.toDouble)
+    case DecimalType() =>
+      buildCast[Decimal](_, _.toDouble)
     case x: NumericType =>
       b => x.numeric.asInstanceOf[Numeric[Any]].toDouble(b)
   }
@@ -303,8 +323,8 @@ case class Cast(child: Expression, dataType: DataType) extends UnaryExpression w
       buildCast[Date](_, d => dateToDouble(d))
     case TimestampType =>
       buildCast[Timestamp](_, t => timestampToDouble(t).toFloat)
-    case DecimalType =>
-      buildCast[BigDecimal](_, _.toFloat)
+    case DecimalType() =>
+      buildCast[Decimal](_, _.toFloat)
     case x: NumericType =>
       b => x.numeric.asInstanceOf[Numeric[Any]].toFloat(b)
   }
@@ -313,8 +333,8 @@ case class Cast(child: Expression, dataType: DataType) extends UnaryExpression w
     case dt if dt == child.dataType => identity[Any]
     case StringType => castToString
     case BinaryType => castToBinary
-    case DecimalType => castToDecimal
     case DateType => castToDate
+    case decimal: DecimalType => castToDecimal(decimal)
     case TimestampType => castToTimestamp
     case BooleanType => castToBoolean
     case ByteType => castToByte
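Editor's note: the observable effect of routing casts through changePrecision, sketched with literals (evaluating against a null input row is fine for literals; commented results are illustrative):

    import org.apache.spark.sql.catalyst.expressions.{Cast, Literal}
    import org.apache.spark.sql.catalyst.types.DecimalType
    import org.apache.spark.sql.catalyst.types.decimal.Decimal

    // 123.456 is rounded (HALF_UP) to fit DECIMAL(6, 2)...
    Cast(Literal(Decimal(123456L, 6, 3)), DecimalType(6, 2)).eval(null)  // 123.46
    // ...but overflows DECIMAL(4, 2), so the cast yields null instead of a wrong value.
    Cast(Literal(Decimal(123456L, 6, 3)), DecimalType(4, 2)).eval(null)  // null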
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregates.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregates.scala
index 1b4d892625..2b364fc1df 100755
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregates.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregates.scala
@@ -286,18 +286,38 @@ case class ApproxCountDistinct(child: Expression, relativeSD: Double = 0.05)
 
 case class Average(child: Expression) extends PartialAggregate with trees.UnaryNode[Expression] {
 
   override def nullable = false
-  override def dataType = DoubleType
+
+  override def dataType = child.dataType match {
+    case DecimalType.Fixed(precision, scale) =>
+      DecimalType(precision + 4, scale + 4)  // Add 4 digits after decimal point, like Hive
+    case DecimalType.Unlimited =>
+      DecimalType.Unlimited
+    case _ =>
+      DoubleType
+  }
+
   override def toString = s"AVG($child)"
 
   override def asPartial: SplitEvaluation = {
     val partialSum = Alias(Sum(child), "PartialSum")()
     val partialCount = Alias(Count(child), "PartialCount")()
-    val castedSum = Cast(Sum(partialSum.toAttribute), dataType)
-    val castedCount = Cast(Sum(partialCount.toAttribute), dataType)
 
-    SplitEvaluation(
-      Divide(castedSum, castedCount),
-      partialCount :: partialSum :: Nil)
+    child.dataType match {
+      case DecimalType.Fixed(_, _) =>
+        // Turn the results to unlimited decimals for the division, before going back to fixed
+        val castedSum = Cast(Sum(partialSum.toAttribute), DecimalType.Unlimited)
+        val castedCount = Cast(Sum(partialCount.toAttribute), DecimalType.Unlimited)
+        SplitEvaluation(
+          Cast(Divide(castedSum, castedCount), dataType),
+          partialCount :: partialSum :: Nil)
+
+      case _ =>
+        val castedSum = Cast(Sum(partialSum.toAttribute), dataType)
+        val castedCount = Cast(Sum(partialCount.toAttribute), dataType)
+        SplitEvaluation(
+          Divide(castedSum, castedCount),
+          partialCount :: partialSum :: Nil)
+    }
   }
 
   override def newInstance() = new AverageFunction(child, this)
@@ -306,7 +326,16 @@ case class Average(child: Expression) extends PartialAggregate with trees.UnaryN
 case class Sum(child: Expression) extends PartialAggregate with trees.UnaryNode[Expression] {
 
   override def nullable = false
-  override def dataType = child.dataType
+
+  override def dataType = child.dataType match {
+    case DecimalType.Fixed(precision, scale) =>
+      DecimalType(precision + 10, scale)  // Add 10 digits left of decimal point, like Hive
+    case DecimalType.Unlimited =>
+      DecimalType.Unlimited
+    case _ =>
+      child.dataType
+  }
+
   override def toString = s"SUM($child)"
 
   override def asPartial: SplitEvaluation = {
@@ -322,9 +351,17 @@ case class Sum(child: Expression) extends PartialAggregate with trees.UnaryNode[
 case class SumDistinct(child: Expression)
   extends AggregateExpression with trees.UnaryNode[Expression] {
 
-  override def nullable = false
-  override def dataType = child.dataType
+
+  override def dataType = child.dataType match {
+    case DecimalType.Fixed(precision, scale) =>
+      DecimalType(precision + 10, scale)  // Add 10 digits left of decimal point, like Hive
+    case DecimalType.Unlimited =>
+      DecimalType.Unlimited
+    case _ =>
+      child.dataType
+  }
+
   override def toString = s"SUM(DISTINCT $child)"
 
   override def newInstance() = new SumDistinctFunction(child, this)
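Editor's note: the result-type changes in practice, assuming the usual catalyst dsl imports for building the attribute (import path is an assumption; the types follow from the code above):

    import org.apache.spark.sql.catalyst.dsl.expressions._
    import org.apache.spark.sql.catalyst.expressions.{Average, Sum}

    val col = 'amount.decimal(5, 2)
    Sum(col).dataType      // DecimalType(15, 2): 10 extra integer digits, like Hive
    Average(col).dataType  // DecimalType(9, 6): 4 extra digits of precision and scale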
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala
index 83e8466ec2..8574cabc43 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala
@@ -36,7 +36,7 @@ case class UnaryMinus(child: Expression) extends UnaryExpression {
 
 case class Sqrt(child: Expression) extends UnaryExpression {
   type EvaluatedType = Any
-  
+
   def dataType = DoubleType
   override def foldable = child.foldable
   def nullable = child.nullable
@@ -55,7 +55,9 @@ abstract class BinaryArithmetic extends BinaryExpression {
   def nullable = left.nullable || right.nullable
 
   override lazy val resolved =
-    left.resolved && right.resolved && left.dataType == right.dataType
+    left.resolved && right.resolved &&
+    left.dataType == right.dataType &&
+    !DecimalType.isFixed(left.dataType)
 
   def dataType = {
     if (!resolved) {
@@ -104,6 +106,8 @@ case class Multiply(left: Expression, right: Expression) extends BinaryArithmeti
 case class Divide(left: Expression, right: Expression) extends BinaryArithmetic {
   def symbol = "/"
 
+  override def nullable = left.nullable || right.nullable || dataType.isInstanceOf[DecimalType]
+
   override def eval(input: Row): Any = dataType match {
     case _: FractionalType => f2(input, left, right, _.div(_, _))
     case _: IntegralType => i2(input, left , right, _.quot(_, _))
@@ -114,6 +118,8 @@ case class Divide(left: Expression, right: Expression) extends BinaryArithmetic
 case class Remainder(left: Expression, right: Expression) extends BinaryArithmetic {
   def symbol = "%"
 
+  override def nullable = left.nullable || right.nullable || dataType.isInstanceOf[DecimalType]
+
   override def eval(input: Row): Any = i2(input, left, right, _.rem(_, _))
 }
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala
index 5a3f013c34..67f8d411b6 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala
@@ -18,6 +18,7 @@ package org.apache.spark.sql.catalyst.expressions.codegen
 
 import com.google.common.cache.{CacheLoader, CacheBuilder}
+import org.apache.spark.sql.catalyst.types.decimal.Decimal
 
 import scala.language.existentials
 
@@ -485,6 +486,34 @@ abstract class CodeGenerator[InType <: AnyRef, OutType <: AnyRef] extends Loggin
         }
       """.children
 
+      case UnscaledValue(child) =>
+        val childEval = expressionEvaluator(child)
+
+        childEval.code ++
+        q"""
+          var $nullTerm = ${childEval.nullTerm}
+          var $primitiveTerm: Long = if (!$nullTerm) {
+            ${childEval.primitiveTerm}.toUnscaledLong
+          } else {
+            ${defaultPrimitive(LongType)}
+          }
+        """.children
+
+      case MakeDecimal(child, precision, scale) =>
+        val childEval = expressionEvaluator(child)
+
+        childEval.code ++
+        q"""
+          var $nullTerm = ${childEval.nullTerm}
+          var $primitiveTerm: org.apache.spark.sql.catalyst.types.decimal.Decimal =
+            ${defaultPrimitive(DecimalType())}
+
+          if (!$nullTerm) {
+            $primitiveTerm = new org.apache.spark.sql.catalyst.types.decimal.Decimal()
+            $primitiveTerm = $primitiveTerm.setOrNull(${childEval.primitiveTerm}, $precision, $scale)
+            $nullTerm = $primitiveTerm == null
+          }
+        """.children
     }
 
     // If there was no match in the partial function above, we fall back on calling the interpreted
@@ -562,7 +591,7 @@ abstract class CodeGenerator[InType <: AnyRef, OutType <: AnyRef] extends Loggin
     case LongType => ru.Literal(Constant(1L))
     case ByteType => ru.Literal(Constant(-1.toByte))
     case DoubleType => ru.Literal(Constant(-1.toDouble))
-    case DecimalType => ru.Literal(Constant(-1)) // Will get implicity converted as needed.
+    case DecimalType() => q"org.apache.spark.sql.catalyst.types.decimal.Decimal(-1)"
    case IntegerType => ru.Literal(Constant(-1))
    case _ => ru.Literal(Constant(null))
  }
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/decimalFunctions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/decimalFunctions.scala
new file mode 100644
index 0000000000..d1eab2eb4e
--- /dev/null
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/decimalFunctions.scala
@@ -0,0 +1,59 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.catalyst.expressions
+
+import org.apache.spark.sql.catalyst.types.decimal.Decimal
+import org.apache.spark.sql.catalyst.types.{DecimalType, LongType, DoubleType, DataType}
+
+/** Return the unscaled Long value of a Decimal, assuming it fits in a Long */
+case class UnscaledValue(child: Expression) extends UnaryExpression {
+  override type EvaluatedType = Any
+
+  override def dataType: DataType = LongType
+  override def foldable = child.foldable
+  def nullable = child.nullable
+  override def toString = s"UnscaledValue($child)"
+
+  override def eval(input: Row): Any = {
+    val childResult = child.eval(input)
+    if (childResult == null) {
+      null
+    } else {
+      childResult.asInstanceOf[Decimal].toUnscaledLong
+    }
+  }
+}
+
+/** Create a Decimal from an unscaled Long value */
+case class MakeDecimal(child: Expression, precision: Int, scale: Int) extends UnaryExpression {
+  override type EvaluatedType = Decimal
+
+  override def dataType: DataType = DecimalType(precision, scale)
+  override def foldable = child.foldable
+  def nullable = child.nullable
+  override def toString = s"MakeDecimal($child,$precision,$scale)"
+
+  override def eval(input: Row): Decimal = {
+    val childResult = child.eval(input)
+    if (childResult == null) {
+      null
+    } else {
+      new Decimal().setOrNull(childResult.asInstanceOf[Long], precision, scale)
+    }
+  }
+}
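Editor's note: a small sketch of the round-trip these two expressions provide (commented results follow directly from the definitions above):

    import org.apache.spark.sql.catalyst.expressions.{Literal, MakeDecimal, UnscaledValue}
    import org.apache.spark.sql.catalyst.types.decimal.Decimal

    // A Decimal round-trips through its unscaled Long representation:
    val d = Literal(Decimal(12345L, 5, 2))         // 123.45
    UnscaledValue(d).eval(null)                    // 12345L
    MakeDecimal(Literal(12345L), 5, 2).eval(null)  // 123.45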
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/literals.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/literals.scala
index ba240233ca..93c1932515 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/literals.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/literals.scala
@@ -20,6 +20,7 @@ package org.apache.spark.sql.catalyst.expressions
 import java.sql.{Date, Timestamp}
 
 import org.apache.spark.sql.catalyst.types._
+import org.apache.spark.sql.catalyst.types.decimal.Decimal
 
 object Literal {
   def apply(v: Any): Literal = v match {
@@ -31,7 +32,8 @@ object Literal {
     case s: Short => Literal(s, ShortType)
     case s: String => Literal(s, StringType)
     case b: Boolean => Literal(b, BooleanType)
-    case d: BigDecimal => Literal(d, DecimalType)
+    case d: BigDecimal => Literal(Decimal(d), DecimalType.Unlimited)
+    case d: Decimal => Literal(d, DecimalType.Unlimited)
     case t: Timestamp => Literal(t, TimestampType)
     case d: Date => Literal(d, DateType)
     case a: Array[Byte] => Literal(a, BinaryType)
@@ -62,7 +64,7 @@ case class Literal(value: Any, dataType: DataType) extends LeafExpression {
 }
 
 // TODO: Specialize
-case class MutableLiteral(var value: Any, dataType: DataType, nullable: Boolean = true) 
+case class MutableLiteral(var value: Any, dataType: DataType, nullable: Boolean = true)
   extends LeafExpression {
 
   type EvaluatedType = Any
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala
index 9ce7c78195..a4aa322fc5 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala
@@ -27,6 +27,7 @@ import org.apache.spark.sql.catalyst.plans.LeftSemi
 import org.apache.spark.sql.catalyst.plans.logical._
 import org.apache.spark.sql.catalyst.rules._
 import org.apache.spark.sql.catalyst.types._
+import org.apache.spark.sql.catalyst.types.decimal.Decimal
 
 abstract class Optimizer extends RuleExecutor[LogicalPlan]
 
@@ -43,6 +44,8 @@ object DefaultOptimizer extends Optimizer {
       SimplifyCasts,
       SimplifyCaseConversionExpressions,
       OptimizeIn) ::
+    Batch("Decimal Optimizations", FixedPoint(100),
+      DecimalAggregates) ::
     Batch("Filter Pushdown", FixedPoint(100),
       UnionPushdown,
       CombineFilters,
@@ -390,9 +393,9 @@ object PushPredicateThroughProject extends Rule[LogicalPlan] {
 * evaluated using only the attributes of the left or right side of a join.  Other
 * [[Filter]] conditions are moved into the `condition` of the [[Join]].
 *
- * And also Pushes down the join filter, where the `condition` can be evaluated using only the 
- * attributes of the left or right side of sub query when applicable. 
- * 
+ * And also Pushes down the join filter, where the `condition` can be evaluated using only the
+ * attributes of the left or right side of sub query when applicable.
+ *
 * Check https://cwiki.apache.org/confluence/display/Hive/OuterJoinBehavior for more details
 */
object PushPredicateThroughJoin extends Rule[LogicalPlan] with PredicateHelper {
@@ -404,7 +407,7 @@ object PushPredicateThroughJoin extends Rule[LogicalPlan] with PredicateHelper {
  private def split(condition: Seq[Expression], left: LogicalPlan, right: LogicalPlan) = {
    val (leftEvaluateCondition, rest) =
      condition.partition(_.references subsetOf left.outputSet)
-    val (rightEvaluateCondition, commonCondition) = 
+    val (rightEvaluateCondition, commonCondition) =
      rest.partition(_.references subsetOf right.outputSet)
 
    (leftEvaluateCondition, rightEvaluateCondition, commonCondition)
@@ -413,7 +416,7 @@ object PushPredicateThroughJoin extends Rule[LogicalPlan] with PredicateHelper {
  def apply(plan: LogicalPlan): LogicalPlan = plan transform {
    // push the where condition down into join filter
    case f @ Filter(filterCondition, Join(left, right, joinType, joinCondition)) =>
-      val (leftFilterConditions, rightFilterConditions, commonFilterCondition) = 
+      val (leftFilterConditions, rightFilterConditions, commonFilterCondition) =
        split(splitConjunctivePredicates(filterCondition), left, right)
 
      joinType match {
@@ -451,7 +454,7 @@ object PushPredicateThroughJoin extends Rule[LogicalPlan] with PredicateHelper {
    // push down the join filter into sub query scanning if applicable
    case f @ Join(left, right, joinType, joinCondition) =>
-      val (leftJoinConditions, rightJoinConditions, commonJoinCondition) = 
+      val (leftJoinConditions, rightJoinConditions, commonJoinCondition) =
        split(joinCondition.map(splitConjunctivePredicates).getOrElse(Nil), left, right)
 
      joinType match {
@@ -519,3 +522,26 @@ object SimplifyCaseConversionExpressions extends Rule[LogicalPlan] {
    }
  }
}
+
+/**
+ * Speeds up aggregates on fixed-precision decimals by executing them on unscaled Long values.
+ *
+ * This uses the same rules for increasing the precision and scale of the output as
+ * [[org.apache.spark.sql.catalyst.analysis.HiveTypeCoercion.DecimalPrecision]].
+ */
+object DecimalAggregates extends Rule[LogicalPlan] {
+  import Decimal.MAX_LONG_DIGITS
+
+  /** Maximum number of decimal digits representable precisely in a Double */
+  val MAX_DOUBLE_DIGITS = 15
+
+  def apply(plan: LogicalPlan): LogicalPlan = plan transformAllExpressions {
+    case Sum(e @ DecimalType.Expression(prec, scale)) if prec + 10 <= MAX_LONG_DIGITS =>
+      MakeDecimal(Sum(UnscaledValue(e)), prec + 10, scale)
+
+    case Average(e @ DecimalType.Expression(prec, scale)) if prec + 4 <= MAX_DOUBLE_DIGITS =>
+      Cast(
+        Divide(Average(UnscaledValue(e)), Literal(math.pow(10.0, scale), DoubleType)),
+        DecimalType(prec + 4, scale + 4))
+  }
+}
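Editor's note: what the rule does to a plan, sketched for a DECIMAL(8, 2) column (8 + 10 <= MAX_LONG_DIGITS, so the sum fits in a Long):

    // Before: Sum(e)  where e: DecimalType(8, 2)
    // After:  MakeDecimal(Sum(UnscaledValue(e)), 18, 2)
    //
    // The aggregation itself then runs on plain unscaled Longs, and the
    // fixed-precision Decimal is only rebuilt once per group. A DECIMAL(9, 2)
    // input is left untouched, since 9 + 10 > 18 and the unscaled sum could
    // overflow a Long.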
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/types/dataTypes.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/types/dataTypes.scala
index 6069f9b0a6..8dda0b1828 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/types/dataTypes.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/types/dataTypes.scala
@@ -19,7 +19,7 @@ package org.apache.spark.sql.catalyst.types
 
 import java.sql.{Date, Timestamp}
 
-import scala.math.Numeric.{BigDecimalAsIfIntegral, DoubleAsIfIntegral, FloatAsIfIntegral}
+import scala.math.Numeric.{FloatAsIfIntegral, BigDecimalAsIfIntegral, DoubleAsIfIntegral}
 import scala.reflect.ClassTag
 import scala.reflect.runtime.universe.{TypeTag, runtimeMirror, typeTag}
 import scala.util.parsing.combinator.RegexParsers
@@ -33,6 +33,7 @@ import org.apache.spark.sql.catalyst.ScalaReflectionLock
 import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeReference, Expression}
 import org.apache.spark.sql.catalyst.util.Metadata
 import org.apache.spark.util.Utils
+import org.apache.spark.sql.catalyst.types.decimal._
 
 object DataType {
   def fromJson(json: String): DataType = parseDataType(parse(json))
@@ -91,11 +92,17 @@ object DataType {
       | "LongType" ^^^ LongType
       | "BinaryType" ^^^ BinaryType
       | "BooleanType" ^^^ BooleanType
-      | "DecimalType" ^^^ DecimalType
       | "DateType" ^^^ DateType
+      | "DecimalType()" ^^^ DecimalType.Unlimited
+      | fixedDecimalType
       | "TimestampType" ^^^ TimestampType
     )
 
+  protected lazy val fixedDecimalType: Parser[DataType] =
+    ("DecimalType(" ~> "[0-9]+".r) ~ ("," ~> "[0-9]+".r <~ ")") ^^ {
+      case precision ~ scale => DecimalType(precision.toInt, scale.toInt)
+    }
+
   protected lazy val arrayType: Parser[DataType] =
     "ArrayType" ~> "(" ~> dataType ~ "," ~ boolVal <~ ")" ^^ {
       case tpe ~ _ ~ containsNull => ArrayType(tpe, containsNull)
@@ -200,10 +207,18 @@ trait PrimitiveType extends DataType {
 }
 
 object PrimitiveType {
-  private[sql] val all = Seq(DecimalType, DateType, TimestampType, BinaryType) ++
-    NativeType.all
-
-  private[sql] val nameToType = all.map(t => t.typeName -> t).toMap
+  private val nonDecimals = Seq(DateType, TimestampType, BinaryType) ++ NativeType.all
+  private val nonDecimalNameToType = nonDecimals.map(t => t.typeName -> t).toMap
+
+  /** Given the string representation of a type, return its DataType */
+  private[sql] def nameToType(name: String): DataType = {
+    val FIXED_DECIMAL = """decimal\(\s*(\d+)\s*,\s*(\d+)\s*\)""".r
+    name match {
+      case "decimal" => DecimalType.Unlimited
+      case FIXED_DECIMAL(precision, scale) => DecimalType(precision.toInt, scale.toInt)
+      case other => nonDecimalNameToType(other)
+    }
+  }
 }
 
 abstract class NativeType extends DataType {
@@ -332,13 +347,58 @@ abstract class FractionalType extends NumericType {
   private[sql] val asIntegral: Integral[JvmType]
 }
 
-case object DecimalType extends FractionalType {
-  private[sql] type JvmType = BigDecimal
+/** Precision parameters for a Decimal */
+case class PrecisionInfo(precision: Int, scale: Int)
+
+/** A Decimal that might have fixed precision and scale, or unlimited values for these */
+case class DecimalType(precisionInfo: Option[PrecisionInfo]) extends FractionalType {
+  private[sql] type JvmType = Decimal
   @transient private[sql] lazy val tag = ScalaReflectionLock.synchronized { typeTag[JvmType] }
-  private[sql] val numeric = implicitly[Numeric[BigDecimal]]
-  private[sql] val fractional = implicitly[Fractional[BigDecimal]]
-  private[sql] val ordering = implicitly[Ordering[JvmType]]
-  private[sql] val asIntegral = BigDecimalAsIfIntegral
+  private[sql] val numeric = Decimal.DecimalIsFractional
+  private[sql] val fractional = Decimal.DecimalIsFractional
+  private[sql] val ordering = Decimal.DecimalIsFractional
+  private[sql] val asIntegral = Decimal.DecimalAsIfIntegral
+
+  override def typeName: String = precisionInfo match {
+    case Some(PrecisionInfo(precision, scale)) => s"decimal($precision,$scale)"
+    case None => "decimal"
+  }
+
+  override def toString: String = precisionInfo match {
+    case Some(PrecisionInfo(precision, scale)) => s"DecimalType($precision,$scale)"
+    case None => "DecimalType()"
+  }
+}
+
+/** Extra factory methods and pattern matchers for Decimals */
+object DecimalType {
+  val Unlimited: DecimalType = DecimalType(None)
+
+  object Fixed {
+    def unapply(t: DecimalType): Option[(Int, Int)] =
+      t.precisionInfo.map(p => (p.precision, p.scale))
+  }
+
+  object Expression {
+    def unapply(e: Expression): Option[(Int, Int)] = e.dataType match {
+      case t: DecimalType => t.precisionInfo.map(p => (p.precision, p.scale))
+      case _ => None
+    }
+  }
+
+  def apply(): DecimalType = Unlimited
+
+  def apply(precision: Int, scale: Int): DecimalType =
+    DecimalType(Some(PrecisionInfo(precision, scale)))
+
+  def unapply(t: DataType): Boolean = t.isInstanceOf[DecimalType]
+
+  def unapply(e: Expression): Boolean = e.dataType.isInstanceOf[DecimalType]
+
+  def isFixed(dataType: DataType): Boolean = dataType match {
+    case DecimalType.Fixed(_, _) => true
+    case _ => false
+  }
 }
 
 case object DoubleType extends FractionalType {
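Editor's note: how the new extractors compose in pattern matches (a self-contained sketch; the helper describe is hypothetical):

    import org.apache.spark.sql.catalyst.types._

    def describe(dt: DataType): String = dt match {
      case DecimalType.Fixed(p, s) => s"fixed decimal($p,$s)"
      case DecimalType()           => "unlimited decimal"
      case _                       => "not a decimal"
    }

    describe(DecimalType(10, 2))     // "fixed decimal(10,2)"
    describe(DecimalType.Unlimited)  // "unlimited decimal"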
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/types/decimal/Decimal.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/types/decimal/Decimal.scala
new file mode 100644
index 0000000000..708362acf3
--- /dev/null
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/types/decimal/Decimal.scala
@@ -0,0 +1,335 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.catalyst.types.decimal
+
+import org.apache.spark.annotation.DeveloperApi
+
+/**
+ * A mutable implementation of BigDecimal that can hold a Long if values are small enough.
+ *
+ * The semantics of the fields are as follows:
+ * - _precision and _scale represent the SQL precision and scale we are looking for
+ * - If decimalVal is set, it represents the whole decimal value
+ * - Otherwise, the decimal value is longVal / (10 ** _scale)
+ */
+final class Decimal extends Ordered[Decimal] with Serializable {
+  import Decimal.{MAX_LONG_DIGITS, POW_10, ROUNDING_MODE, BIG_DEC_ZERO}
+
+  private var decimalVal: BigDecimal = null
+  private var longVal: Long = 0L
+  private var _precision: Int = 1
+  private var _scale: Int = 0
+
+  def precision: Int = _precision
+  def scale: Int = _scale
+
+  /**
+   * Set this Decimal to the given Long. Will have precision 20 and scale 0.
+   */
+  def set(longVal: Long): Decimal = {
+    if (longVal <= -POW_10(MAX_LONG_DIGITS) || longVal >= POW_10(MAX_LONG_DIGITS)) {
+      // We can't represent this compactly as a long without risking overflow
+      this.decimalVal = BigDecimal(longVal)
+      this.longVal = 0L
+    } else {
+      this.decimalVal = null
+      this.longVal = longVal
+    }
+    this._precision = 20
+    this._scale = 0
+    this
+  }
+
+  /**
+   * Set this Decimal to the given Int. Will have precision 10 and scale 0.
+   */
+  def set(intVal: Int): Decimal = {
+    this.decimalVal = null
+    this.longVal = intVal
+    this._precision = 10
+    this._scale = 0
+    this
+  }
+
+  /**
+   * Set this Decimal to the given unscaled Long, with a given precision and scale.
+   */
+  def set(unscaled: Long, precision: Int, scale: Int): Decimal = {
+    if (setOrNull(unscaled, precision, scale) == null) {
+      throw new IllegalArgumentException("Unscaled value too large for precision")
+    }
+    this
+  }
+
+  /**
+   * Set this Decimal to the given unscaled Long, with a given precision and scale,
+   * and return it, or return null if it cannot be set due to overflow.
+   */
+  def setOrNull(unscaled: Long, precision: Int, scale: Int): Decimal = {
+    if (unscaled <= -POW_10(MAX_LONG_DIGITS) || unscaled >= POW_10(MAX_LONG_DIGITS)) {
+      // We can't represent this compactly as a long without risking overflow
+      if (precision < 19) {
+        return null  // Requested precision is too low to represent this value
+      }
+      this.decimalVal = BigDecimal(longVal)
+      this.longVal = 0L
+    } else {
+      val p = POW_10(math.min(precision, MAX_LONG_DIGITS))
+      if (unscaled <= -p || unscaled >= p) {
+        return null  // Requested precision is too low to represent this value
+      }
+      this.decimalVal = null
+      this.longVal = unscaled
+    }
+    this._precision = precision
+    this._scale = scale
+    this
+  }
+
+  /**
+   * Set this Decimal to the given BigDecimal value, with a given precision and scale.
+   */
+  def set(decimal: BigDecimal, precision: Int, scale: Int): Decimal = {
+    this.decimalVal = decimal.setScale(scale, ROUNDING_MODE)
+    require(decimalVal.precision <= precision, "Overflowed precision")
+    this.longVal = 0L
+    this._precision = precision
+    this._scale = scale
+    this
+  }
+
+  /**
+   * Set this Decimal to the given BigDecimal value, inheriting its precision and scale.
+   */
+  def set(decimal: BigDecimal): Decimal = {
+    this.decimalVal = decimal
+    this.longVal = 0L
+    this._precision = decimal.precision
+    this._scale = decimal.scale
+    this
+  }
+
+  /**
+   * Set this Decimal to the given Decimal value.
+   */
+  def set(decimal: Decimal): Decimal = {
+    this.decimalVal = decimal.decimalVal
+    this.longVal = decimal.longVal
+    this._precision = decimal._precision
+    this._scale = decimal._scale
+    this
+  }
+
+  def toBigDecimal: BigDecimal = {
+    if (decimalVal.ne(null)) {
+      decimalVal
+    } else {
+      BigDecimal(longVal, _scale)
+    }
+  }
+
+  def toUnscaledLong: Long = {
+    if (decimalVal.ne(null)) {
+      decimalVal.underlying().unscaledValue().longValue()
+    } else {
+      longVal
+    }
+  }
+
+  override def toString: String = toBigDecimal.toString()
+
+  @DeveloperApi
+  def toDebugString: String = {
+    if (decimalVal.ne(null)) {
+      s"Decimal(expanded,$decimalVal,$precision,$scale})"
+    } else {
+      s"Decimal(compact,$longVal,$precision,$scale})"
+    }
+  }
+
+  def toDouble: Double = toBigDecimal.doubleValue()
+
+  def toFloat: Float = toBigDecimal.floatValue()
+
+  def toLong: Long = {
+    if (decimalVal.eq(null)) {
+      longVal / POW_10(_scale)
+    } else {
+      decimalVal.longValue()
+    }
+  }
+
+  def toInt: Int = toLong.toInt
+
+  def toShort: Short = toLong.toShort
+
+  def toByte: Byte = toLong.toByte
+
+  /**
+   * Update precision and scale while keeping our value the same, and return true if successful.
+   *
+   * @return true if successful, false if overflow would occur
+   */
+  def changePrecision(precision: Int, scale: Int): Boolean = {
+    // First, update our longVal if we can, or transfer over to using a BigDecimal
+    if (decimalVal.eq(null)) {
+      if (scale < _scale) {
+        // Easier case: we just need to divide our scale down
+        val diff = _scale - scale
+        val droppedDigits = longVal % POW_10(diff)
+        longVal /= POW_10(diff)
+        if (math.abs(droppedDigits) * 2 >= POW_10(diff)) {
+          longVal += (if (longVal < 0) -1L else 1L)
+        }
+      } else if (scale > _scale) {
+        // We might be able to multiply longVal by a power of 10 and not overflow, but if not,
+        // switch to using a BigDecimal
+        val diff = scale - _scale
+        val p = POW_10(math.max(MAX_LONG_DIGITS - diff, 0))
+        if (diff <= MAX_LONG_DIGITS && longVal > -p && longVal < p) {
+          // Multiplying longVal by POW_10(diff) will still keep it below MAX_LONG_DIGITS
+          longVal *= POW_10(diff)
+        } else {
+          // Give up on using Longs; switch to BigDecimal, which we'll modify below
+          decimalVal = BigDecimal(longVal, _scale)
+        }
+      }
+      // In both cases, we will check whether our precision is okay below
+    }
+
+    if (decimalVal.ne(null)) {
+      // We get here if either we started with a BigDecimal, or we switched to one because we would
+      // have overflowed our Long; in either case we must rescale decimalVal to the new scale.
+      val newVal = decimalVal.setScale(scale, ROUNDING_MODE)
+      if (newVal.precision > precision) {
+        return false
+      }
+      decimalVal = newVal
+    } else {
+      // We're still using Longs, but we should check whether we match the new precision
+      val p = POW_10(math.min(_precision, MAX_LONG_DIGITS))
+      if (longVal <= -p || longVal >= p) {
+        // Note that we shouldn't have been able to fix this by switching to BigDecimal
+        return false
+      }
+    }
+
+    _precision = precision
+    _scale = scale
+    true
+  }
+
+  override def clone(): Decimal = new Decimal().set(this)
+
+  override def compare(other: Decimal): Int = {
+    if (decimalVal.eq(null) && other.decimalVal.eq(null) && _scale == other._scale) {
+      if (longVal < other.longVal) -1 else if (longVal == other.longVal) 0 else 1
+    } else {
+      toBigDecimal.compare(other.toBigDecimal)
+    }
+  }
+
+  override def equals(other: Any) = other match {
+    case d: Decimal =>
+      compare(d) == 0
+    case _ =>
+      false
+  }
+
+  override def hashCode(): Int = toBigDecimal.hashCode()
+
+  def isZero: Boolean = if (decimalVal.ne(null)) decimalVal == BIG_DEC_ZERO else longVal == 0
+
+  def + (that: Decimal): Decimal = Decimal(toBigDecimal + that.toBigDecimal)
+
+  def - (that: Decimal): Decimal = Decimal(toBigDecimal - that.toBigDecimal)
+
+  def * (that: Decimal): Decimal = Decimal(toBigDecimal * that.toBigDecimal)
+
+  def / (that: Decimal): Decimal =
+    if (that.isZero) null else Decimal(toBigDecimal / that.toBigDecimal)
+
+  def % (that: Decimal): Decimal =
+    if (that.isZero) null else Decimal(toBigDecimal % that.toBigDecimal)
+
+  def remainder(that: Decimal): Decimal = this % that
+
+  def unary_- : Decimal = {
+    if (decimalVal.ne(null)) {
+      Decimal(-decimalVal)
+    } else {
+      Decimal(-longVal, precision, scale)
+    }
+  }
+}
+
+object Decimal {
+  private val ROUNDING_MODE = BigDecimal.RoundingMode.HALF_UP
+
+  /** Maximum number of decimal digits a Long can represent */
+  val MAX_LONG_DIGITS = 18
+
+  private val POW_10 = Array.tabulate[Long](MAX_LONG_DIGITS + 1)(i => math.pow(10, i).toLong)
+
+  private val BIG_DEC_ZERO = BigDecimal(0)
+
+  def apply(value: Double): Decimal = new Decimal().set(value)
+
+  def apply(value: Long): Decimal = new Decimal().set(value)
+
+  def apply(value: Int): Decimal = new Decimal().set(value)
+
+  def apply(value: BigDecimal): Decimal = new Decimal().set(value)
+
+  def apply(value: BigDecimal, precision: Int, scale: Int): Decimal =
+    new Decimal().set(value, precision, scale)
+
+  def apply(unscaled: Long, precision: Int, scale: Int): Decimal =
+    new Decimal().set(unscaled, precision, scale)
+
+  def apply(value: String): Decimal = new Decimal().set(BigDecimal(value))
+
+  // Evidence parameters for Decimal considered either as Fractional or Integral. We provide two
+  // parameters inheriting from a common trait since both traits define mkNumericOps.
+  // See scala.math's Numeric.scala for examples for Scala's built-in types.
+
+  /** Common methods for Decimal evidence parameters */
+  trait DecimalIsConflicted extends Numeric[Decimal] {
+    override def plus(x: Decimal, y: Decimal): Decimal = x + y
+    override def times(x: Decimal, y: Decimal): Decimal = x * y
+    override def minus(x: Decimal, y: Decimal): Decimal = x - y
+    override def negate(x: Decimal): Decimal = -x
+    override def toDouble(x: Decimal): Double = x.toDouble
+    override def toFloat(x: Decimal): Float = x.toFloat
+    override def toInt(x: Decimal): Int = x.toInt
+    override def toLong(x: Decimal): Long = x.toLong
+    override def fromInt(x: Int): Decimal = new Decimal().set(x)
+    override def compare(x: Decimal, y: Decimal): Int = x.compare(y)
+  }
+
+  /** A [[scala.math.Fractional]] evidence parameter for Decimals. */
+  object DecimalIsFractional extends DecimalIsConflicted with Fractional[Decimal] {
+    override def div(x: Decimal, y: Decimal): Decimal = x / y
+  }
+
+  /** A [[scala.math.Integral]] evidence parameter for Decimals. */
+  object DecimalAsIfIntegral extends DecimalIsConflicted with Integral[Decimal] {
+    override def quot(x: Decimal, y: Decimal): Decimal = x / y
+    override def rem(x: Decimal, y: Decimal): Decimal = x % y
+  }
+}
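Editor's note: the compact-vs-expanded representation in action, as a sketch (commented values follow from the definitions above; rounding uses HALF_UP):

    import org.apache.spark.sql.catalyst.types.decimal.Decimal

    val small = Decimal(123456L, 9, 2)  // stored compactly in a mutable Long: 1234.56
    small.changePrecision(9, 1)         // true: rounds in place to 1234.6
    small.toBigDecimal                  // BigDecimal("1234.6")

    val big = Decimal(BigDecimal("12345678901234567890"))  // 20 digits: falls back to BigDecimal
    big.precision                       // 20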
a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisSuite.scala index 7b45738c4f..33a3cba3d4 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisSuite.scala @@ -38,7 +38,7 @@ class AnalysisSuite extends FunSuite with BeforeAndAfter { AttributeReference("a", StringType)(), AttributeReference("b", StringType)(), AttributeReference("c", DoubleType)(), - AttributeReference("d", DecimalType)(), + AttributeReference("d", DecimalType.Unlimited)(), AttributeReference("e", ShortType)()) before { @@ -119,7 +119,7 @@ class AnalysisSuite extends FunSuite with BeforeAndAfter { AttributeReference("a", StringType)(), AttributeReference("b", StringType)(), AttributeReference("c", DoubleType)(), - AttributeReference("d", DecimalType)(), + AttributeReference("d", DecimalType.Unlimited)(), AttributeReference("e", ShortType)()) val expr0 = 'a / 2 @@ -137,7 +137,7 @@ class AnalysisSuite extends FunSuite with BeforeAndAfter { assert(pl(0).dataType == DoubleType) assert(pl(1).dataType == DoubleType) assert(pl(2).dataType == DoubleType) - assert(pl(3).dataType == DecimalType) + assert(pl(3).dataType == DecimalType.Unlimited) assert(pl(4).dataType == DoubleType) } } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/DecimalPrecisionSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/DecimalPrecisionSuite.scala new file mode 100644 index 0000000000..d5b7d2789a --- /dev/null +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/DecimalPrecisionSuite.scala @@ -0,0 +1,88 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.sql.catalyst.analysis + +import org.apache.spark.sql.catalyst.expressions._ +import org.apache.spark.sql.catalyst.plans.logical.{Project, LocalRelation} +import org.apache.spark.sql.catalyst.types._ +import org.scalatest.{BeforeAndAfter, FunSuite} + +class DecimalPrecisionSuite extends FunSuite with BeforeAndAfter { + val catalog = new SimpleCatalog(false) + val analyzer = new Analyzer(catalog, EmptyFunctionRegistry, caseSensitive = false) + + val relation = LocalRelation( + AttributeReference("i", IntegerType)(), + AttributeReference("d1", DecimalType(2, 1))(), + AttributeReference("d2", DecimalType(5, 2))(), + AttributeReference("u", DecimalType.Unlimited)(), + AttributeReference("f", FloatType)() + ) + + val i: Expression = UnresolvedAttribute("i") + val d1: Expression = UnresolvedAttribute("d1") + val d2: Expression = UnresolvedAttribute("d2") + val u: Expression = UnresolvedAttribute("u") + val f: Expression = UnresolvedAttribute("f") + + before { + catalog.registerTable(None, "table", relation) + } + + private def checkType(expression: Expression, expectedType: DataType): Unit = { + val plan = Project(Seq(Alias(expression, "c")()), relation) + assert(analyzer(plan).schema.fields(0).dataType === expectedType) + } + + test("basic operations") { + checkType(Add(d1, d2), DecimalType(6, 2)) + checkType(Subtract(d1, d2), DecimalType(6, 2)) + checkType(Multiply(d1, d2), DecimalType(8, 3)) + checkType(Divide(d1, d2), DecimalType(10, 7)) + checkType(Divide(d2, d1), DecimalType(10, 6)) + checkType(Remainder(d1, d2), DecimalType(3, 2)) + checkType(Remainder(d2, d1), DecimalType(3, 2)) + checkType(Sum(d1), DecimalType(12, 1)) + checkType(Average(d1), DecimalType(6, 5)) + + checkType(Add(Add(d1, d2), d1), DecimalType(7, 2)) + checkType(Add(Add(Add(d1, d2), d1), d2), DecimalType(8, 2)) + checkType(Add(Add(d1, d2), Add(d1, d2)), DecimalType(7, 2)) + } + + test("bringing in primitive types") { + checkType(Add(d1, i), DecimalType(12, 1)) + checkType(Add(d1, f), DoubleType) + checkType(Add(i, d1), DecimalType(12, 1)) + checkType(Add(f, d1), DoubleType) + checkType(Add(d1, Cast(i, LongType)), DecimalType(22, 1)) + checkType(Add(d1, Cast(i, ShortType)), DecimalType(7, 1)) + checkType(Add(d1, Cast(i, ByteType)), DecimalType(5, 1)) + checkType(Add(d1, Cast(i, DoubleType)), DoubleType) + } + + test("unlimited decimals make everything else cast up") { + for (expr <- Seq(d1, d2, i, f, u)) { + checkType(Add(expr, u), DecimalType.Unlimited) + checkType(Subtract(expr, u), DecimalType.Unlimited) + checkType(Multiply(expr, u), DecimalType.Unlimited) + checkType(Divide(expr, u), DecimalType.Unlimited) + checkType(Remainder(expr, u), DecimalType.Unlimited) + } + } +} diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercionSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercionSuite.scala index baeb9b0cf5..dfa2d958c0 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercionSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercionSuite.scala @@ -68,6 +68,21 @@ class HiveTypeCoercionSuite extends FunSuite { widenTest(LongType, FloatType, Some(FloatType)) widenTest(LongType, DoubleType, Some(DoubleType)) + // Casting up to unlimited-precision decimal + widenTest(IntegerType, DecimalType.Unlimited, Some(DecimalType.Unlimited)) + widenTest(DoubleType, DecimalType.Unlimited, Some(DecimalType.Unlimited)) + 
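// As a worked check of the fixed-precision result types asserted in DecimalPrecisionSuite
+ // above (a sketch of the Hive 13 rules this patch follows, not the analyzer's implementation):
+ // for Add/Subtract, scale = max(s1, s2) and precision = max(s1, s2) + max(p1 - s1, p2 - s2) + 1.
+ // With d1: DecimalType(2, 1) and d2: DecimalType(5, 2), scale = max(1, 2) = 2 and
+ // precision = 2 + max(2 - 1, 5 - 2) + 1 = 6, hence Add(d1, d2): DecimalType(6, 2); Multiply
+ // uses (p1 + p2 + 1, s1 + s2), giving DecimalType(8, 3) for Multiply(d1, d2).
+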
widenTest(DecimalType(3, 2), DecimalType.Unlimited, Some(DecimalType.Unlimited)) + widenTest(DecimalType.Unlimited, IntegerType, Some(DecimalType.Unlimited)) + widenTest(DecimalType.Unlimited, DoubleType, Some(DecimalType.Unlimited)) + widenTest(DecimalType.Unlimited, DecimalType(3, 2), Some(DecimalType.Unlimited)) + + // No up-casting for fixed-precision decimal (this is handled by arithmetic rules) + widenTest(DecimalType(2, 1), DecimalType(3, 2), None) + widenTest(DecimalType(2, 1), DoubleType, None) + widenTest(DecimalType(2, 1), IntegerType, None) + widenTest(DoubleType, DecimalType(2, 1), None) + widenTest(IntegerType, DecimalType(2, 1), None) + + // StringType widenTest(NullType, StringType, Some(StringType)) widenTest(StringType, StringType, Some(StringType)) @@ -92,7 +107,7 @@ class HiveTypeCoercionSuite extends FunSuite { def ruleTest(initial: Expression, transformed: Expression) { val testRelation = LocalRelation(AttributeReference("a", IntegerType)()) assert(booleanCasts(Project(Seq(Alias(initial, "a")()), testRelation)) == - Project(Seq(Alias(transformed, "a")()), testRelation)) + Project(Seq(Alias(transformed, "a")()), testRelation)) } // Remove superfluous boolean -> boolean casts. ruleTest(Cast(Literal(true), BooleanType), Literal(true)) diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvaluationSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvaluationSuite.scala index 5657bc555e..6bfa0dbd65 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvaluationSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvaluationSuite.scala @@ -21,6 +21,7 @@ import java.sql.{Date, Timestamp} import scala.collection.immutable.HashSet +import org.apache.spark.sql.catalyst.types.decimal.Decimal import org.scalatest.FunSuite import org.scalatest.Matchers._ import org.scalactic.TripleEqualsSupport.Spread @@ -138,7 +139,7 @@ class ExpressionEvaluationSuite extends FunSuite { val actual = try evaluate(expression, inputRow) catch { case e: Exception => fail(s"Exception evaluating $expression", e) } - actual.asInstanceOf[Double] shouldBe expected + actual.asInstanceOf[Double] shouldBe expected } test("IN") { @@ -165,7 +166,7 @@ class ExpressionEvaluationSuite extends FunSuite { checkEvaluation(InSet(three, nS, three +: nullS), false) checkEvaluation(InSet(one, hS, one +: s) && InSet(two, hS, two +: s), true) } - + test("MaxOf") { checkEvaluation(MaxOf(1, 2), 2) checkEvaluation(MaxOf(2, 1), 2) @@ -265,9 +266,9 @@ class ExpressionEvaluationSuite extends FunSuite { val ts = Timestamp.valueOf(nts) checkEvaluation("abdef" cast StringType, "abdef") - checkEvaluation("abdef" cast DecimalType, null) + checkEvaluation("abdef" cast DecimalType.Unlimited, null) checkEvaluation("abdef" cast TimestampType, null) - checkEvaluation("12.65" cast DecimalType, BigDecimal(12.65)) + checkEvaluation("12.65" cast DecimalType.Unlimited, Decimal(12.65)) checkEvaluation(Literal(1) cast LongType, 1) checkEvaluation(Cast(Literal(1000) cast TimestampType, LongType), 1.toLong) @@ -289,12 +290,12 @@ class ExpressionEvaluationSuite extends FunSuite { checkEvaluation(Cast(Cast(Cast(Cast( Cast("5" cast ByteType, ShortType), IntegerType), FloatType), DoubleType), LongType), 5) - checkEvaluation(Cast(Cast(Cast(Cast( - Cast("5" cast ByteType, TimestampType), DecimalType), LongType), StringType), ShortType), 0) - checkEvaluation(Cast(Cast(Cast(Cast( - Cast("5" 
cast TimestampType, ByteType), DecimalType), LongType), StringType), ShortType), null) - checkEvaluation(Cast(Cast(Cast(Cast( - Cast("5" cast DecimalType, ByteType), TimestampType), LongType), StringType), ShortType), 0) + checkEvaluation(Cast(Cast(Cast(Cast(Cast("5" cast + ByteType, TimestampType), DecimalType.Unlimited), LongType), StringType), ShortType), 0) + checkEvaluation(Cast(Cast(Cast(Cast(Cast("5" cast + TimestampType, ByteType), DecimalType.Unlimited), LongType), StringType), ShortType), null) + checkEvaluation(Cast(Cast(Cast(Cast(Cast("5" cast + DecimalType.Unlimited, ByteType), TimestampType), LongType), StringType), ShortType), 0) checkEvaluation(Literal(true) cast IntegerType, 1) checkEvaluation(Literal(false) cast IntegerType, 0) checkEvaluation(Cast(Literal(1) cast BooleanType, IntegerType), 1) @@ -302,7 +303,7 @@ class ExpressionEvaluationSuite extends FunSuite { checkEvaluation("23" cast DoubleType, 23d) checkEvaluation("23" cast IntegerType, 23) checkEvaluation("23" cast FloatType, 23f) - checkEvaluation("23" cast DecimalType, 23: BigDecimal) + checkEvaluation("23" cast DecimalType.Unlimited, Decimal(23)) checkEvaluation("23" cast ByteType, 23.toByte) checkEvaluation("23" cast ShortType, 23.toShort) checkEvaluation("2012-12-11" cast DoubleType, null) @@ -311,7 +312,7 @@ class ExpressionEvaluationSuite extends FunSuite { checkEvaluation(Literal(23d) + Cast(true, DoubleType), 24d) checkEvaluation(Literal(23) + Cast(true, IntegerType), 24) checkEvaluation(Literal(23f) + Cast(true, FloatType), 24f) - checkEvaluation(Literal(BigDecimal(23)) + Cast(true, DecimalType), 24: BigDecimal) + checkEvaluation(Literal(Decimal(23)) + Cast(true, DecimalType.Unlimited), Decimal(24)) checkEvaluation(Literal(23.toByte) + Cast(true, ByteType), 24.toByte) checkEvaluation(Literal(23.toShort) + Cast(true, ShortType), 24.toShort) @@ -325,7 +326,8 @@ class ExpressionEvaluationSuite extends FunSuite { assert(("abcdef" cast IntegerType).nullable === true) assert(("abcdef" cast ShortType).nullable === true) assert(("abcdef" cast ByteType).nullable === true) - assert(("abcdef" cast DecimalType).nullable === true) + assert(("abcdef" cast DecimalType.Unlimited).nullable === true) + assert(("abcdef" cast DecimalType(4, 2)).nullable === true) assert(("abcdef" cast DoubleType).nullable === true) assert(("abcdef" cast FloatType).nullable === true) @@ -338,6 +340,64 @@ class ExpressionEvaluationSuite extends FunSuite { checkEvaluation(Literal(d1) < Literal(d2), true) } + test("casting to fixed-precision decimals") { + // Overflow and rounding for casting to fixed-precision decimals: + // - Values should round with HALF_UP mode by default when you lower scale + // - Values that would overflow the target precision should turn into null + // - Because of this, casts to fixed-precision decimals should be nullable + + assert(Cast(Literal(123), DecimalType.Unlimited).nullable === false) + assert(Cast(Literal(10.03f), DecimalType.Unlimited).nullable === false) + assert(Cast(Literal(10.03), DecimalType.Unlimited).nullable === false) + assert(Cast(Literal(Decimal(10.03)), DecimalType.Unlimited).nullable === false) + + assert(Cast(Literal(123), DecimalType(2, 1)).nullable === true) + assert(Cast(Literal(10.03f), DecimalType(2, 1)).nullable === true) + assert(Cast(Literal(10.03), DecimalType(2, 1)).nullable === true) + assert(Cast(Literal(Decimal(10.03)), DecimalType(2, 1)).nullable === true) + + checkEvaluation(Cast(Literal(123), DecimalType.Unlimited), Decimal(123)) + checkEvaluation(Cast(Literal(123), 
DecimalType(3, 0)), Decimal(123)) + checkEvaluation(Cast(Literal(123), DecimalType(3, 1)), null) + checkEvaluation(Cast(Literal(123), DecimalType(2, 0)), null) + + checkEvaluation(Cast(Literal(10.03), DecimalType.Unlimited), Decimal(10.03)) + checkEvaluation(Cast(Literal(10.03), DecimalType(4, 2)), Decimal(10.03)) + checkEvaluation(Cast(Literal(10.03), DecimalType(3, 1)), Decimal(10.0)) + checkEvaluation(Cast(Literal(10.03), DecimalType(2, 0)), Decimal(10)) + checkEvaluation(Cast(Literal(10.03), DecimalType(1, 0)), null) + checkEvaluation(Cast(Literal(10.03), DecimalType(2, 1)), null) + checkEvaluation(Cast(Literal(10.03), DecimalType(3, 2)), null) + checkEvaluation(Cast(Literal(Decimal(10.03)), DecimalType(3, 1)), Decimal(10.0)) + checkEvaluation(Cast(Literal(Decimal(10.03)), DecimalType(3, 2)), null) + + checkEvaluation(Cast(Literal(10.05), DecimalType.Unlimited), Decimal(10.05)) + checkEvaluation(Cast(Literal(10.05), DecimalType(4, 2)), Decimal(10.05)) + checkEvaluation(Cast(Literal(10.05), DecimalType(3, 1)), Decimal(10.1)) + checkEvaluation(Cast(Literal(10.05), DecimalType(2, 0)), Decimal(10)) + checkEvaluation(Cast(Literal(10.05), DecimalType(1, 0)), null) + checkEvaluation(Cast(Literal(10.05), DecimalType(2, 1)), null) + checkEvaluation(Cast(Literal(10.05), DecimalType(3, 2)), null) + checkEvaluation(Cast(Literal(Decimal(10.05)), DecimalType(3, 1)), Decimal(10.1)) + checkEvaluation(Cast(Literal(Decimal(10.05)), DecimalType(3, 2)), null) + + checkEvaluation(Cast(Literal(9.95), DecimalType(3, 2)), Decimal(9.95)) + checkEvaluation(Cast(Literal(9.95), DecimalType(3, 1)), Decimal(10.0)) + checkEvaluation(Cast(Literal(9.95), DecimalType(2, 0)), Decimal(10)) + checkEvaluation(Cast(Literal(9.95), DecimalType(2, 1)), null) + checkEvaluation(Cast(Literal(9.95), DecimalType(1, 0)), null) + checkEvaluation(Cast(Literal(Decimal(9.95)), DecimalType(3, 1)), Decimal(10.0)) + checkEvaluation(Cast(Literal(Decimal(9.95)), DecimalType(1, 0)), null) + + checkEvaluation(Cast(Literal(-9.95), DecimalType(3, 2)), Decimal(-9.95)) + checkEvaluation(Cast(Literal(-9.95), DecimalType(3, 1)), Decimal(-10.0)) + checkEvaluation(Cast(Literal(-9.95), DecimalType(2, 0)), Decimal(-10)) + checkEvaluation(Cast(Literal(-9.95), DecimalType(2, 1)), null) + checkEvaluation(Cast(Literal(-9.95), DecimalType(1, 0)), null) + checkEvaluation(Cast(Literal(Decimal(-9.95)), DecimalType(3, 1)), Decimal(-10.0)) + checkEvaluation(Cast(Literal(Decimal(-9.95)), DecimalType(1, 0)), null) + } + test("timestamp") { val ts1 = new Timestamp(12) val ts2 = new Timestamp(123) @@ -374,7 +434,7 @@ class ExpressionEvaluationSuite extends FunSuite { millis.toFloat / 1000) checkEvaluation(Cast(Cast(millis.toDouble / 1000, TimestampType), DoubleType), millis.toDouble / 1000) - checkEvaluation(Cast(Literal(BigDecimal(1)) cast TimestampType, DecimalType), 1) + checkEvaluation(Cast(Literal(Decimal(1)) cast TimestampType, DecimalType.Unlimited), Decimal(1)) // A test for higher precision than millis checkEvaluation(Cast(Cast(0.00000001, TimestampType), DoubleType), 0.00000001) @@ -673,7 +733,7 @@ class ExpressionEvaluationSuite extends FunSuite { val expectedResults = inputSequence.map(l => math.sqrt(l.toDouble)) val rowSequence = inputSequence.map(l => new GenericRow(Array[Any](l.toDouble))) val d = 'a.double.at(0) - + for ((row, expected) <- rowSequence zip expectedResults) { checkEvaluation(Sqrt(d), expected, row) } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/types/decimal/DecimalSuite.scala 
b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/types/decimal/DecimalSuite.scala new file mode 100644 index 0000000000..5aa263484d --- /dev/null +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/types/decimal/DecimalSuite.scala @@ -0,0 +1,158 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.catalyst.types.decimal + +import org.scalatest.{PrivateMethodTester, FunSuite} + +import scala.language.postfixOps + +class DecimalSuite extends FunSuite with PrivateMethodTester { + test("creating decimals") { + /** Check that a Decimal has the given string representation, precision and scale */ + def checkDecimal(d: Decimal, string: String, precision: Int, scale: Int): Unit = { + assert(d.toString === string) + assert(d.precision === precision) + assert(d.scale === scale) + } + + checkDecimal(new Decimal(), "0", 1, 0) + checkDecimal(Decimal(BigDecimal("10.030")), "10.030", 5, 3) + checkDecimal(Decimal(BigDecimal("10.030"), 4, 1), "10.0", 4, 1) + checkDecimal(Decimal(BigDecimal("-9.95"), 4, 1), "-10.0", 4, 1) + checkDecimal(Decimal("10.030"), "10.030", 5, 3) + checkDecimal(Decimal(10.03), "10.03", 4, 2) + checkDecimal(Decimal(17L), "17", 20, 0) + checkDecimal(Decimal(17), "17", 10, 0) + checkDecimal(Decimal(17L, 2, 1), "1.7", 2, 1) + checkDecimal(Decimal(170L, 4, 2), "1.70", 4, 2) + checkDecimal(Decimal(17L, 24, 1), "1.7", 24, 1) + checkDecimal(Decimal(1e17.toLong, 18, 0), 1e17.toLong.toString, 18, 0) + checkDecimal(Decimal(Long.MaxValue), Long.MaxValue.toString, 20, 0) + checkDecimal(Decimal(Long.MinValue), Long.MinValue.toString, 20, 0) + intercept[IllegalArgumentException](Decimal(170L, 2, 1)) + intercept[IllegalArgumentException](Decimal(170L, 2, 0)) + intercept[IllegalArgumentException](Decimal(BigDecimal("10.030"), 2, 1)) + intercept[IllegalArgumentException](Decimal(BigDecimal("-9.95"), 2, 1)) + intercept[IllegalArgumentException](Decimal(1e17.toLong, 17, 0)) + } + + test("double and long values") { + /** Check that a Decimal converts to the given double and long values */ + def checkValues(d: Decimal, doubleValue: Double, longValue: Long): Unit = { + assert(d.toDouble === doubleValue) + assert(d.toLong === longValue) + } + + checkValues(new Decimal(), 0.0, 0L) + checkValues(Decimal(BigDecimal("10.030")), 10.03, 10L) + checkValues(Decimal(BigDecimal("10.030"), 4, 1), 10.0, 10L) + checkValues(Decimal(BigDecimal("-9.95"), 4, 1), -10.0, -10L) + checkValues(Decimal(10.03), 10.03, 10L) + checkValues(Decimal(17L), 17.0, 17L) + checkValues(Decimal(17), 17.0, 17L) + checkValues(Decimal(17L, 2, 1), 1.7, 1L) + checkValues(Decimal(170L, 4, 2), 1.7, 1L) + checkValues(Decimal(1e16.toLong), 1e16, 1e16.toLong) + checkValues(Decimal(1e17.toLong), 1e17, 1e17.toLong) + checkValues(Decimal(1e18.toLong), 
1e18, 1e18.toLong) + checkValues(Decimal(2e18.toLong), 2e18, 2e18.toLong) + checkValues(Decimal(Long.MaxValue), Long.MaxValue.toDouble, Long.MaxValue) + checkValues(Decimal(Long.MinValue), Long.MinValue.toDouble, Long.MinValue) + checkValues(Decimal(Double.MaxValue), Double.MaxValue, 0L) + checkValues(Decimal(Double.MinValue), Double.MinValue, 0L) + } + + // Accessor for the BigDecimal value of a Decimal, which will be null if it's using Longs + private val decimalVal = PrivateMethod[BigDecimal]('decimalVal) + + /** Check whether a decimal is represented compactly (passing whether we expect it to be) */ + private def checkCompact(d: Decimal, expected: Boolean): Unit = { + val isCompact = d.invokePrivate(decimalVal()).eq(null) + assert(isCompact == expected, s"$d ${if (expected) "was not" else "was"} compact") + } + + test("small decimals represented as unscaled long") { + checkCompact(new Decimal(), true) + checkCompact(Decimal(BigDecimal(10.03)), false) + checkCompact(Decimal(BigDecimal(1e20)), false) + checkCompact(Decimal(17L), true) + checkCompact(Decimal(17), true) + checkCompact(Decimal(17L, 2, 1), true) + checkCompact(Decimal(170L, 4, 2), true) + checkCompact(Decimal(17L, 24, 1), true) + checkCompact(Decimal(1e16.toLong), true) + checkCompact(Decimal(1e17.toLong), true) + checkCompact(Decimal(1e18.toLong - 1), true) + checkCompact(Decimal(- 1e18.toLong + 1), true) + checkCompact(Decimal(1e18.toLong - 1, 30, 10), true) + checkCompact(Decimal(- 1e18.toLong + 1, 30, 10), true) + checkCompact(Decimal(1e18.toLong), false) + checkCompact(Decimal(-1e18.toLong), false) + checkCompact(Decimal(1e18.toLong, 30, 10), false) + checkCompact(Decimal(-1e18.toLong, 30, 10), false) + checkCompact(Decimal(Long.MaxValue), false) + checkCompact(Decimal(Long.MinValue), false) + } + + test("hash code") { + assert(Decimal(123).hashCode() === (123).##) + assert(Decimal(-123).hashCode() === (-123).##) + assert(Decimal(123.312).hashCode() === (123.312).##) + assert(Decimal(Int.MaxValue).hashCode() === Int.MaxValue.##) + assert(Decimal(Long.MaxValue).hashCode() === Long.MaxValue.##) + assert(Decimal(BigDecimal(123)).hashCode() === (123).##) + + val reallyBig = BigDecimal("123182312312313232112312312123.1231231231") + assert(Decimal(reallyBig).hashCode() === reallyBig.hashCode) + } + + test("equals") { + // The decimals on the left are stored compactly, while the ones on the right aren't + checkCompact(Decimal(123), true) + checkCompact(Decimal(BigDecimal(123)), false) + checkCompact(Decimal("123"), false) + assert(Decimal(123) === Decimal(BigDecimal(123))) + assert(Decimal(123) === Decimal(BigDecimal("123.00"))) + assert(Decimal(-123) === Decimal(BigDecimal(-123))) + assert(Decimal(-123) === Decimal(BigDecimal("-123.00"))) + } + + test("isZero") { + assert(Decimal(0).isZero) + assert(Decimal(0, 4, 2).isZero) + assert(Decimal("0").isZero) + assert(Decimal("0.000").isZero) + assert(!Decimal(1).isZero) + assert(!Decimal(1, 4, 2).isZero) + assert(!Decimal("1").isZero) + assert(!Decimal("0.001").isZero) + } + + test("arithmetic") { + assert(Decimal(100) + Decimal(-100) === Decimal(0)) + assert(Decimal(100) + Decimal(-100) === Decimal(0)) + assert(Decimal(100) * Decimal(-100) === Decimal(-10000)) + assert(Decimal(1e13) * Decimal(1e13) === Decimal(1e26)) + assert(Decimal(100) / Decimal(-100) === Decimal(-1)) + assert(Decimal(100) / Decimal(0) === null) + assert(Decimal(100) % Decimal(-100) === Decimal(0)) + assert(Decimal(100) % Decimal(3) === Decimal(1)) + assert(Decimal(-100) % Decimal(3) === Decimal(-1)) + 
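// Note: as defined above, / and % return null (rather than throwing) when the divisor
+ // isZero, so the zero-divisor assert below expects null:
+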
assert(Decimal(100) % Decimal(0) === null) + } +}
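
Taken together, the cast tests in ExpressionEvaluationSuite above pin down the observable contract of changePrecision: lowering the scale rounds with HALF_UP, and a value whose rounded form exceeds the target precision becomes null. A minimal sketch of that contract, assuming only the public Decimal API added by this patch (toPrecision is a hypothetical helper mirroring what Cast does internally, not code from the patch):

import org.apache.spark.sql.catalyst.types.decimal.Decimal

object DecimalCastSketch {
  // Hypothetical helper: copy d, then try to fit the copy to (precision, scale).
  // changePrecision mutates in place and only reports success or failure, so
  // Cast-style callers clone first and map failure to null.
  def toPrecision(d: Decimal, precision: Int, scale: Int): Decimal = {
    val copy = d.clone()
    if (copy.changePrecision(precision, scale)) copy else null
  }

  def main(args: Array[String]): Unit = {
    println(toPrecision(Decimal(10.05), 3, 1)) // 10.1: scale lowered with HALF_UP rounding
    println(toPrecision(Decimal(9.95), 2, 0))  // 10: rounds up into the integer digits
    println(toPrecision(Decimal(10.03), 1, 0)) // null: the rounded value 10 needs precision 2
  }
}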