From 23f966f47523f85ba440b4080eee665271f53b5e Mon Sep 17 00:00:00 2001 From: Matei Zaharia Date: Sat, 1 Nov 2014 19:29:14 -0700 Subject: [SPARK-3930] [SPARK-3933] Support fixed-precision decimal in SQL, and some optimizations - Adds optional precision and scale to Spark SQL's decimal type, which behave similarly to those in Hive 13 (https://cwiki.apache.org/confluence/download/attachments/27362075/Hive_Decimal_Precision_Scale_Support.pdf) - Replaces our internal representation of decimals with a Decimal class that can store small values in a mutable Long, saving memory in this situation and letting some operations happen directly on Longs This is still marked WIP because there are a few TODOs, but I'll remove that tag when done. Author: Matei Zaharia Closes #2983 from mateiz/decimal-1 and squashes the following commits: 35e6b02 [Matei Zaharia] Fix issues after merge 227f24a [Matei Zaharia] Review comments 31f915e [Matei Zaharia] Implement Davies's suggestions in Python eb84820 [Matei Zaharia] Support reading/writing decimals as fixed-length binary in Parquet 4dc6bae [Matei Zaharia] Fix decimal support in PySpark d1d9d68 [Matei Zaharia] Fix compile error and test issues after rebase b28933d [Matei Zaharia] Support decimal precision/scale in Hive metastore 2118c0d [Matei Zaharia] Some test and bug fixes 81db9cb [Matei Zaharia] Added mutable Decimal that will be more efficient for small precisions 7af0c3b [Matei Zaharia] Add optional precision and scale to DecimalType, but use Unlimited for now ec0a947 [Matei Zaharia] Make the result of AVG on Decimals be Decimal, not Double --- .../spark/sql/catalyst/ScalaReflection.scala | 20 +- .../org/apache/spark/sql/catalyst/SqlParser.scala | 14 +- .../sql/catalyst/analysis/HiveTypeCoercion.scala | 146 ++++++++- .../apache/spark/sql/catalyst/dsl/package.scala | 11 +- .../spark/sql/catalyst/expressions/Cast.scala | 78 +++-- .../sql/catalyst/expressions/aggregates.scala | 55 +++- .../sql/catalyst/expressions/arithmetic.scala | 10 +- .../expressions/codegen/CodeGenerator.scala | 31 +- .../catalyst/expressions/decimalFunctions.scala | 59 ++++ .../spark/sql/catalyst/expressions/literals.scala | 6 +- .../spark/sql/catalyst/optimizer/Optimizer.scala | 38 ++- .../spark/sql/catalyst/types/dataTypes.scala | 84 +++++- .../spark/sql/catalyst/types/decimal/Decimal.scala | 335 +++++++++++++++++++++ .../spark/sql/catalyst/ScalaReflectionSuite.scala | 14 +- .../sql/catalyst/analysis/AnalysisSuite.scala | 6 +- .../catalyst/analysis/DecimalPrecisionSuite.scala | 88 ++++++ .../catalyst/analysis/HiveTypeCoercionSuite.scala | 17 +- .../expressions/ExpressionEvaluationSuite.scala | 90 +++++- .../sql/catalyst/types/decimal/DecimalSuite.scala | 158 ++++++++++ .../org/apache/spark/sql/api/java/DataType.java | 5 - .../org/apache/spark/sql/api/java/DecimalType.java | 58 +++- .../scala/org/apache/spark/sql/SchemaRDD.scala | 3 +- .../apache/spark/sql/api/java/JavaSQLContext.scala | 2 +- .../scala/org/apache/spark/sql/api/java/Row.scala | 4 + .../spark/sql/execution/GeneratedAggregate.scala | 41 ++- .../org/apache/spark/sql/execution/SparkPlan.scala | 4 +- .../spark/sql/execution/SparkSqlSerializer.scala | 2 + .../spark/sql/execution/basicOperators.scala | 7 +- .../sql/execution/joins/BroadcastHashJoin.scala | 3 +- .../apache/spark/sql/execution/pythonUdfs.scala | 6 +- .../scala/org/apache/spark/sql/json/JsonRDD.scala | 20 +- .../main/scala/org/apache/spark/sql/package.scala | 14 + .../spark/sql/parquet/ParquetConverter.scala | 43 +++ .../spark/sql/parquet/ParquetTableSupport.scala 
| 28 ++ .../apache/spark/sql/parquet/ParquetTypes.scala | 79 +++-- .../spark/sql/types/util/DataTypeConversions.scala | 13 +- .../spark/sql/api/java/JavaApplySchemaSuite.java | 2 +- .../api/java/JavaSideDataTypeConversionSuite.java | 9 +- .../scala/org/apache/spark/sql/DataTypeSuite.scala | 2 +- .../spark/sql/ScalaReflectionRelationSuite.scala | 5 +- .../apache/spark/sql/api/java/JavaSQLSuite.scala | 2 + .../java/ScalaSideDataTypeConversionSuite.scala | 4 +- .../org/apache/spark/sql/json/JsonSuite.scala | 46 +-- .../spark/sql/parquet/ParquetQuerySuite.scala | 35 ++- .../server/SparkSQLOperationManager.scala | 4 +- .../spark/sql/hive/thriftserver/Shim12.scala | 4 +- .../spark/sql/hive/thriftserver/Shim13.scala | 2 +- .../org/apache/spark/sql/hive/HiveContext.scala | 9 +- .../org/apache/spark/sql/hive/HiveInspectors.scala | 24 +- .../spark/sql/hive/HiveMetastoreCatalog.scala | 14 +- .../scala/org/apache/spark/sql/hive/HiveQl.scala | 15 +- .../sql/hive/execution/InsertIntoHiveTable.scala | 3 +- .../scala/org/apache/spark/sql/hive/Shim12.scala | 22 +- .../scala/org/apache/spark/sql/hive/Shim13.scala | 39 ++- 54 files changed, 1604 insertions(+), 229 deletions(-) create mode 100644 sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/decimalFunctions.scala create mode 100644 sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/types/decimal/Decimal.scala create mode 100644 sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/DecimalPrecisionSuite.scala create mode 100644 sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/types/decimal/DecimalSuite.scala (limited to 'sql') diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala index 75923d9e8d..8fbdf664b7 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala @@ -19,9 +19,10 @@ package org.apache.spark.sql.catalyst import java.sql.{Date, Timestamp} -import org.apache.spark.sql.catalyst.expressions.{GenericRow, Attribute, AttributeReference} +import org.apache.spark.sql.catalyst.expressions.{GenericRow, Attribute, AttributeReference, Row} import org.apache.spark.sql.catalyst.plans.logical.LocalRelation import org.apache.spark.sql.catalyst.types._ +import org.apache.spark.sql.catalyst.types.decimal.Decimal /** * Provides experimental support for generating catalyst schemas for scala objects. @@ -40,9 +41,20 @@ object ScalaReflection { case s: Seq[_] => s.map(convertToCatalyst) case m: Map[_, _] => m.map { case (k, v) => convertToCatalyst(k) -> convertToCatalyst(v) } case p: Product => new GenericRow(p.productIterator.map(convertToCatalyst).toArray) + case d: BigDecimal => Decimal(d) case other => other } + /** Converts Catalyst types used internally in rows to standard Scala types */ + def convertToScala(a: Any): Any = a match { + case s: Seq[_] => s.map(convertToScala) + case m: Map[_, _] => m.map { case (k, v) => convertToScala(k) -> convertToScala(v) } + case d: Decimal => d.toBigDecimal + case other => other + } + + def convertRowToScala(r: Row): Row = new GenericRow(r.toArray.map(convertToScala)) + /** Returns a Sequence of attributes for the given case class type. 
*/ def attributesFor[T: TypeTag]: Seq[Attribute] = schemaFor[T] match { case Schema(s: StructType, _) => @@ -83,7 +95,8 @@ object ScalaReflection { case t if t <:< typeOf[String] => Schema(StringType, nullable = true) case t if t <:< typeOf[Timestamp] => Schema(TimestampType, nullable = true) case t if t <:< typeOf[Date] => Schema(DateType, nullable = true) - case t if t <:< typeOf[BigDecimal] => Schema(DecimalType, nullable = true) + case t if t <:< typeOf[BigDecimal] => Schema(DecimalType.Unlimited, nullable = true) + case t if t <:< typeOf[Decimal] => Schema(DecimalType.Unlimited, nullable = true) case t if t <:< typeOf[java.lang.Integer] => Schema(IntegerType, nullable = true) case t if t <:< typeOf[java.lang.Long] => Schema(LongType, nullable = true) case t if t <:< typeOf[java.lang.Double] => Schema(DoubleType, nullable = true) @@ -111,8 +124,9 @@ object ScalaReflection { case obj: LongType.JvmType => LongType case obj: FloatType.JvmType => FloatType case obj: DoubleType.JvmType => DoubleType - case obj: DecimalType.JvmType => DecimalType case obj: DateType.JvmType => DateType + case obj: BigDecimal => DecimalType.Unlimited + case obj: Decimal => DecimalType.Unlimited case obj: TimestampType.JvmType => TimestampType case null => NullType // For other cases, there is no obvious mapping from the type of the given object to a diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/SqlParser.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/SqlParser.scala index b1e7570f57..00fc4d75c9 100755 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/SqlParser.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/SqlParser.scala @@ -52,11 +52,13 @@ class SqlParser extends AbstractSparkSQLParser { protected val CASE = Keyword("CASE") protected val CAST = Keyword("CAST") protected val COUNT = Keyword("COUNT") + protected val DECIMAL = Keyword("DECIMAL") protected val DESC = Keyword("DESC") protected val DISTINCT = Keyword("DISTINCT") protected val ELSE = Keyword("ELSE") protected val END = Keyword("END") protected val EXCEPT = Keyword("EXCEPT") + protected val DOUBLE = Keyword("DOUBLE") protected val FALSE = Keyword("FALSE") protected val FIRST = Keyword("FIRST") protected val FROM = Keyword("FROM") @@ -385,5 +387,15 @@ class SqlParser extends AbstractSparkSQLParser { } protected lazy val dataType: Parser[DataType] = - STRING ^^^ StringType | TIMESTAMP ^^^ TimestampType + ( STRING ^^^ StringType + | TIMESTAMP ^^^ TimestampType + | DOUBLE ^^^ DoubleType + | fixedDecimalType + | DECIMAL ^^^ DecimalType.Unlimited + ) + + protected lazy val fixedDecimalType: Parser[DataType] = + (DECIMAL ~ "(" ~> numericLit) ~ ("," ~> numericLit <~ ")") ^^ { + case precision ~ scale => DecimalType(precision.toInt, scale.toInt) + } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercion.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercion.scala index 2b69c02b28..e38114ab3c 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercion.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercion.scala @@ -25,19 +25,31 @@ import org.apache.spark.sql.catalyst.types._ object HiveTypeCoercion { // See https://cwiki.apache.org/confluence/display/Hive/LanguageManual+Types. 
// The conversion for integral and floating point types have a linear widening hierarchy: - val numericPrecedence = - Seq(ByteType, ShortType, IntegerType, LongType, FloatType, DoubleType, DecimalType) - val allPromotions: Seq[Seq[DataType]] = numericPrecedence :: Nil + private val numericPrecedence = + Seq(ByteType, ShortType, IntegerType, LongType, FloatType, DoubleType, DecimalType.Unlimited) + /** + * Find the tightest common type of two types that might be used in a binary expression. + * This handles all numeric types except fixed-precision decimals interacting with each other or + * with primitive types, because in that case the precision and scale of the result depends on + * the operation. Those rules are implemented in [[HiveTypeCoercion.DecimalPrecision]]. + */ def findTightestCommonType(t1: DataType, t2: DataType): Option[DataType] = { val valueTypes = Seq(t1, t2).filter(t => t != NullType) if (valueTypes.distinct.size > 1) { - // Try and find a promotion rule that contains both types in question. - val applicableConversion = - HiveTypeCoercion.allPromotions.find(p => p.contains(t1) && p.contains(t2)) - - // If found return the widest common type, otherwise None - applicableConversion.map(_.filter(t => t == t1 || t == t2).last) + // Promote numeric types to the highest of the two and all numeric types to unlimited decimal + if (numericPrecedence.contains(t1) && numericPrecedence.contains(t2)) { + Some(numericPrecedence.filter(t => t == t1 || t == t2).last) + } else if (t1.isInstanceOf[DecimalType] && t2.isInstanceOf[DecimalType]) { + // Fixed-precision decimals can up-cast into unlimited + if (t1 == DecimalType.Unlimited || t2 == DecimalType.Unlimited) { + Some(DecimalType.Unlimited) + } else { + None + } + } else { + None + } } else { Some(if (valueTypes.size == 0) NullType else valueTypes.head) } @@ -59,6 +71,7 @@ trait HiveTypeCoercion { ConvertNaNs :: WidenTypes :: PromoteStrings :: + DecimalPrecision :: BooleanComparisons :: BooleanCasts :: StringToIntegralCasts :: @@ -151,6 +164,7 @@ trait HiveTypeCoercion { import HiveTypeCoercion._ def apply(plan: LogicalPlan): LogicalPlan = plan transform { + // TODO: unions with fixed-precision decimals case u @ Union(left, right) if u.childrenResolved && !u.resolved => val castedInput = left.output.zip(right.output).map { // When a string is found on one side, make the other side a string too. @@ -265,6 +279,110 @@ trait HiveTypeCoercion { } } + // scalastyle:off + /** + * Calculates and propagates precision for fixed-precision decimals. Hive has a number of + * rules for this based on the SQL standard and MS SQL: + * https://cwiki.apache.org/confluence/download/attachments/27362075/Hive_Decimal_Precision_Scale_Support.pdf + * + * In particular, if we have expressions e1 and e2 with precision/scale p1/s1 and p2/s2 + * respectively, then the following operations have the following precision / scale: + * + * Operation Result Precision Result Scale + * ------------------------------------------------------------------------ + * e1 + e2 max(s1, s2) + max(p1-s1, p2-s2) + 1 max(s1, s2) + * e1 - e2 max(s1, s2) + max(p1-s1, p2-s2) + 1 max(s1, s2) + * e1 * e2 p1 + p2 + 1 s1 + s2 + * e1 / e2 p1 - s1 + s2 + max(6, s1 + p2 + 1) max(6, s1 + p2 + 1) + * e1 % e2 min(p1-s1, p2-s2) + max(s1, s2) max(s1, s2) + * sum(e1) p1 + 10 s1 + * avg(e1) p1 + 4 s1 + 4 + * + * Catalyst also has unlimited-precision decimals. For those, all ops return unlimited precision.
+ * + * To implement the rules for fixed-precision types, we introduce casts to turn them to unlimited + * precision, do the math on unlimited-precision numbers, then introduce casts back to the + * required fixed precision. This allows us to do all rounding and overflow handling in the + * cast-to-fixed-precision operator. + * + * In addition, when mixing non-decimal types with decimals, we use the following rules: + * - BYTE gets turned into DECIMAL(3, 0) + * - SHORT gets turned into DECIMAL(5, 0) + * - INT gets turned into DECIMAL(10, 0) + * - LONG gets turned into DECIMAL(20, 0) + * - FLOAT and DOUBLE cause fixed-length decimals to turn into DOUBLE (this is the same as Hive, + * but note that unlimited decimals are considered bigger than doubles in WidenTypes) + */ + // scalastyle:on + object DecimalPrecision extends Rule[LogicalPlan] { + import scala.math.{max, min} + + // Conversion rules for integer types into fixed-precision decimals + val intTypeToFixed: Map[DataType, DecimalType] = Map( + ByteType -> DecimalType(3, 0), + ShortType -> DecimalType(5, 0), + IntegerType -> DecimalType(10, 0), + LongType -> DecimalType(20, 0) + ) + + def isFloat(t: DataType): Boolean = t == FloatType || t == DoubleType + + def apply(plan: LogicalPlan): LogicalPlan = plan transformAllExpressions { + // Skip nodes whose children have not been resolved yet + case e if !e.childrenResolved => e + + case Add(e1 @ DecimalType.Expression(p1, s1), e2 @ DecimalType.Expression(p2, s2)) => + Cast( + Add(Cast(e1, DecimalType.Unlimited), Cast(e2, DecimalType.Unlimited)), + DecimalType(max(s1, s2) + max(p1 - s1, p2 - s2) + 1, max(s1, s2)) + ) + + case Subtract(e1 @ DecimalType.Expression(p1, s1), e2 @ DecimalType.Expression(p2, s2)) => + Cast( + Subtract(Cast(e1, DecimalType.Unlimited), Cast(e2, DecimalType.Unlimited)), + DecimalType(max(s1, s2) + max(p1 - s1, p2 - s2) + 1, max(s1, s2)) + ) + + case Multiply(e1 @ DecimalType.Expression(p1, s1), e2 @ DecimalType.Expression(p2, s2)) => + Cast( + Multiply(Cast(e1, DecimalType.Unlimited), Cast(e2, DecimalType.Unlimited)), + DecimalType(p1 + p2 + 1, s1 + s2) + ) + + case Divide(e1 @ DecimalType.Expression(p1, s1), e2 @ DecimalType.Expression(p2, s2)) => + Cast( + Divide(Cast(e1, DecimalType.Unlimited), Cast(e2, DecimalType.Unlimited)), + DecimalType(p1 - s1 + s2 + max(6, s1 + p2 + 1), max(6, s1 + p2 + 1)) + ) + + case Remainder(e1 @ DecimalType.Expression(p1, s1), e2 @ DecimalType.Expression(p2, s2)) => + Cast( + Remainder(Cast(e1, DecimalType.Unlimited), Cast(e2, DecimalType.Unlimited)), + DecimalType(min(p1 - s1, p2 - s2) + max(s1, s2), max(s1, s2)) + ) + + // Promote integers inside a binary expression with fixed-precision decimals to decimals, + // and fixed-precision decimals in an expression with floats / doubles to doubles + case b: BinaryExpression if b.left.dataType != b.right.dataType => + (b.left.dataType, b.right.dataType) match { + case (t, DecimalType.Fixed(p, s)) if intTypeToFixed.contains(t) => + b.makeCopy(Array(Cast(b.left, intTypeToFixed(t)), b.right)) + case (DecimalType.Fixed(p, s), t) if intTypeToFixed.contains(t) => + b.makeCopy(Array(b.left, Cast(b.right, intTypeToFixed(t)))) + case (t, DecimalType.Fixed(p, s)) if isFloat(t) => + b.makeCopy(Array(b.left, Cast(b.right, DoubleType))) + case (DecimalType.Fixed(p, s), t) if isFloat(t) => + b.makeCopy(Array(Cast(b.left, DoubleType), b.right)) + case _ => + b + } + + // TODO: MaxOf, MinOf, etc might want other rules + + // SUM and AVERAGE are handled by the implementations of those expressions + } + } + 
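As a worked example of the precision rules above: for e1 of type DECIMAL(5, 2) and e2 of type DECIMAL(3, 1), i.e. p1/s1 = 5/2 and p2/s2 = 3/1, e1 + e2 gets precision max(2, 1) + max(5 - 2, 3 - 1) + 1 = 6 and scale max(2, 1) = 2. DecimalPrecision therefore rewrites the addition roughly into the following expression tree (a sketch only, using the Cast and Add expression constructors shown above; e1 and e2 stand for the resolved child expressions):

    Cast(
      Add(Cast(e1, DecimalType.Unlimited), Cast(e2, DecimalType.Unlimited)),
      DecimalType(6, 2))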
/** * Changes Boolean values to Bytes so that expressions like true < false can be Evaluated. */ @@ -330,7 +448,7 @@ trait HiveTypeCoercion { case e if !e.childrenResolved => e case Cast(e @ StringType(), t: IntegralType) => - Cast(Cast(e, DecimalType), t) + Cast(Cast(e, DecimalType.Unlimited), t) } } @@ -383,10 +501,12 @@ trait HiveTypeCoercion { // Decimal and Double remain the same case d: Divide if d.resolved && d.dataType == DoubleType => d - case d: Divide if d.resolved && d.dataType == DecimalType => d + case d: Divide if d.resolved && d.dataType.isInstanceOf[DecimalType] => d - case Divide(l, r) if l.dataType == DecimalType => Divide(l, Cast(r, DecimalType)) - case Divide(l, r) if r.dataType == DecimalType => Divide(Cast(l, DecimalType), r) + case Divide(l, r) if l.dataType.isInstanceOf[DecimalType] => + Divide(l, Cast(r, DecimalType.Unlimited)) + case Divide(l, r) if r.dataType.isInstanceOf[DecimalType] => + Divide(Cast(l, DecimalType.Unlimited), r) case Divide(l, r) => Divide(Cast(l, DoubleType), Cast(r, DoubleType)) } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/dsl/package.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/dsl/package.scala index 23cfd483ec..7e6d770314 100755 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/dsl/package.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/dsl/package.scala @@ -19,6 +19,8 @@ package org.apache.spark.sql.catalyst import java.sql.{Date, Timestamp} +import org.apache.spark.sql.catalyst.types.decimal.Decimal + import scala.language.implicitConversions import org.apache.spark.sql.catalyst.analysis.UnresolvedAttribute @@ -124,7 +126,8 @@ package object dsl { implicit def doubleToLiteral(d: Double) = Literal(d) implicit def stringToLiteral(s: String) = Literal(s) implicit def dateToLiteral(d: Date) = Literal(d) - implicit def decimalToLiteral(d: BigDecimal) = Literal(d) + implicit def bigDecimalToLiteral(d: BigDecimal) = Literal(d) + implicit def decimalToLiteral(d: Decimal) = Literal(d) implicit def timestampToLiteral(t: Timestamp) = Literal(t) implicit def binaryToLiteral(a: Array[Byte]) = Literal(a) @@ -183,7 +186,11 @@ package object dsl { def date = AttributeReference(s, DateType, nullable = true)() /** Creates a new AttributeReference of type decimal */ - def decimal = AttributeReference(s, DecimalType, nullable = true)() + def decimal = AttributeReference(s, DecimalType.Unlimited, nullable = true)() + + /** Creates a new AttributeReference of type decimal */ + def decimal(precision: Int, scale: Int) = + AttributeReference(s, DecimalType(precision, scale), nullable = true)() /** Creates a new AttributeReference of type timestamp */ def timestamp = AttributeReference(s, TimestampType, nullable = true)() diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala index 8e5baf0eb8..2200966619 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala @@ -23,6 +23,7 @@ import java.text.{DateFormat, SimpleDateFormat} import org.apache.spark.Logging import org.apache.spark.sql.catalyst.errors.TreeNodeException import org.apache.spark.sql.catalyst.types._ +import org.apache.spark.sql.catalyst.types.decimal.Decimal /** Cast the child expression to the target data type. 
*/ case class Cast(child: Expression, dataType: DataType) extends UnaryExpression with Logging { @@ -36,6 +37,7 @@ case class Cast(child: Expression, dataType: DataType) extends UnaryExpression w case (BooleanType, DateType) => true case (DateType, _: NumericType) => true case (DateType, BooleanType) => true + case (_, DecimalType.Fixed(_, _)) => true // TODO: not all upcasts here can really give null case _ => child.nullable } @@ -76,8 +78,8 @@ case class Cast(child: Expression, dataType: DataType) extends UnaryExpression w buildCast[Short](_, _ != 0) case ByteType => buildCast[Byte](_, _ != 0) - case DecimalType => - buildCast[BigDecimal](_, _ != 0) + case DecimalType() => + buildCast[Decimal](_, _ != 0) case DoubleType => buildCast[Double](_, _ != 0) case FloatType => @@ -109,19 +111,19 @@ case class Cast(child: Expression, dataType: DataType) extends UnaryExpression w case DateType => buildCast[Date](_, d => new Timestamp(d.getTime)) // TimestampWritable.decimalToTimestamp - case DecimalType => - buildCast[BigDecimal](_, d => decimalToTimestamp(d)) + case DecimalType() => + buildCast[Decimal](_, d => decimalToTimestamp(d)) // TimestampWritable.doubleToTimestamp case DoubleType => - buildCast[Double](_, d => decimalToTimestamp(d)) + buildCast[Double](_, d => decimalToTimestamp(Decimal(d))) // TimestampWritable.floatToTimestamp case FloatType => - buildCast[Float](_, f => decimalToTimestamp(f)) + buildCast[Float](_, f => decimalToTimestamp(Decimal(f))) } - private[this] def decimalToTimestamp(d: BigDecimal) = { + private[this] def decimalToTimestamp(d: Decimal) = { val seconds = Math.floor(d.toDouble).toLong - val bd = (d - seconds) * 1000000000 + val bd = (d.toBigDecimal - seconds) * 1000000000 val nanos = bd.intValue() val millis = seconds * 1000 @@ -196,8 +198,8 @@ case class Cast(child: Expression, dataType: DataType) extends UnaryExpression w buildCast[Date](_, d => dateToLong(d)) case TimestampType => buildCast[Timestamp](_, t => timestampToLong(t)) - case DecimalType => - buildCast[BigDecimal](_, _.toLong) + case DecimalType() => + buildCast[Decimal](_, _.toLong) case x: NumericType => b => x.numeric.asInstanceOf[Numeric[Any]].toLong(b) } @@ -214,8 +216,8 @@ case class Cast(child: Expression, dataType: DataType) extends UnaryExpression w buildCast[Date](_, d => dateToLong(d)) case TimestampType => buildCast[Timestamp](_, t => timestampToLong(t).toInt) - case DecimalType => - buildCast[BigDecimal](_, _.toInt) + case DecimalType() => + buildCast[Decimal](_, _.toInt) case x: NumericType => b => x.numeric.asInstanceOf[Numeric[Any]].toInt(b) } @@ -232,8 +234,8 @@ case class Cast(child: Expression, dataType: DataType) extends UnaryExpression w buildCast[Date](_, d => dateToLong(d)) case TimestampType => buildCast[Timestamp](_, t => timestampToLong(t).toShort) - case DecimalType => - buildCast[BigDecimal](_, _.toShort) + case DecimalType() => + buildCast[Decimal](_, _.toShort) case x: NumericType => b => x.numeric.asInstanceOf[Numeric[Any]].toInt(b).toShort } @@ -250,27 +252,45 @@ case class Cast(child: Expression, dataType: DataType) extends UnaryExpression w buildCast[Date](_, d => dateToLong(d)) case TimestampType => buildCast[Timestamp](_, t => timestampToLong(t).toByte) - case DecimalType => - buildCast[BigDecimal](_, _.toByte) + case DecimalType() => + buildCast[Decimal](_, _.toByte) case x: NumericType => b => x.numeric.asInstanceOf[Numeric[Any]].toInt(b).toByte } - // DecimalConverter - private[this] def castToDecimal: Any => Any = child.dataType match { + /** + * Change the 
precision / scale in a given decimal to those set in `decimalType` (if any), + * returning null if it overflows or modifying `value` in-place and returning it if successful. + * + * NOTE: this modifies `value` in-place, so don't call it on external data. + */ + private[this] def changePrecision(value: Decimal, decimalType: DecimalType): Decimal = { + decimalType match { + case DecimalType.Unlimited => + value + case DecimalType.Fixed(precision, scale) => + if (value.changePrecision(precision, scale)) value else null + } + } + + private[this] def castToDecimal(target: DecimalType): Any => Any = child.dataType match { case StringType => - buildCast[String](_, s => try BigDecimal(s.toDouble) catch { + buildCast[String](_, s => try changePrecision(Decimal(s.toDouble), target) catch { case _: NumberFormatException => null }) case BooleanType => - buildCast[Boolean](_, b => if (b) BigDecimal(1) else BigDecimal(0)) + buildCast[Boolean](_, b => changePrecision(if (b) Decimal(1) else Decimal(0), target)) case DateType => - buildCast[Date](_, d => dateToDouble(d)) + buildCast[Date](_, d => changePrecision(null, target)) // date can't cast to decimal in Hive case TimestampType => // Note that we lose precision here. - buildCast[Timestamp](_, t => BigDecimal(timestampToDouble(t))) - case x: NumericType => - b => BigDecimal(x.numeric.asInstanceOf[Numeric[Any]].toDouble(b)) + buildCast[Timestamp](_, t => changePrecision(Decimal(timestampToDouble(t)), target)) + case DecimalType() => + b => changePrecision(b.asInstanceOf[Decimal].clone(), target) + case LongType => + b => changePrecision(Decimal(b.asInstanceOf[Long]), target) + case x: NumericType => // All other numeric types can be represented precisely as Doubles + b => changePrecision(Decimal(x.numeric.asInstanceOf[Numeric[Any]].toDouble(b)), target) } // DoubleConverter @@ -285,8 +305,8 @@ case class Cast(child: Expression, dataType: DataType) extends UnaryExpression w buildCast[Date](_, d => dateToDouble(d)) case TimestampType => buildCast[Timestamp](_, t => timestampToDouble(t)) - case DecimalType => - buildCast[BigDecimal](_, _.toDouble) + case DecimalType() => + buildCast[Decimal](_, _.toDouble) case x: NumericType => b => x.numeric.asInstanceOf[Numeric[Any]].toDouble(b) } @@ -303,8 +323,8 @@ case class Cast(child: Expression, dataType: DataType) extends UnaryExpression w buildCast[Date](_, d => dateToDouble(d)) case TimestampType => buildCast[Timestamp](_, t => timestampToDouble(t).toFloat) - case DecimalType => - buildCast[BigDecimal](_, _.toFloat) + case DecimalType() => + buildCast[Decimal](_, _.toFloat) case x: NumericType => b => x.numeric.asInstanceOf[Numeric[Any]].toFloat(b) } @@ -313,8 +333,8 @@ case class Cast(child: Expression, dataType: DataType) extends UnaryExpression w case dt if dt == child.dataType => identity[Any] case StringType => castToString case BinaryType => castToBinary - case DecimalType => castToDecimal case DateType => castToDate + case decimal: DecimalType => castToDecimal(decimal) case TimestampType => castToTimestamp case BooleanType => castToBoolean case ByteType => castToByte diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregates.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregates.scala index 1b4d892625..2b364fc1df 100755 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregates.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregates.scala @@ -286,18 +286,38 @@ case class 
ApproxCountDistinct(child: Expression, relativeSD: Double = 0.05) case class Average(child: Expression) extends PartialAggregate with trees.UnaryNode[Expression] { override def nullable = false - override def dataType = DoubleType + + override def dataType = child.dataType match { + case DecimalType.Fixed(precision, scale) => + DecimalType(precision + 4, scale + 4) // Add 4 digits after decimal point, like Hive + case DecimalType.Unlimited => + DecimalType.Unlimited + case _ => + DoubleType + } + override def toString = s"AVG($child)" override def asPartial: SplitEvaluation = { val partialSum = Alias(Sum(child), "PartialSum")() val partialCount = Alias(Count(child), "PartialCount")() - val castedSum = Cast(Sum(partialSum.toAttribute), dataType) - val castedCount = Cast(Sum(partialCount.toAttribute), dataType) - SplitEvaluation( - Divide(castedSum, castedCount), - partialCount :: partialSum :: Nil) + child.dataType match { + case DecimalType.Fixed(_, _) => + // Turn the results to unlimited decimals for the division, before going back to fixed + val castedSum = Cast(Sum(partialSum.toAttribute), DecimalType.Unlimited) + val castedCount = Cast(Sum(partialCount.toAttribute), DecimalType.Unlimited) + SplitEvaluation( + Cast(Divide(castedSum, castedCount), dataType), + partialCount :: partialSum :: Nil) + + case _ => + val castedSum = Cast(Sum(partialSum.toAttribute), dataType) + val castedCount = Cast(Sum(partialCount.toAttribute), dataType) + SplitEvaluation( + Divide(castedSum, castedCount), + partialCount :: partialSum :: Nil) + } } override def newInstance() = new AverageFunction(child, this) @@ -306,7 +326,16 @@ case class Average(child: Expression) extends PartialAggregate with trees.UnaryN case class Sum(child: Expression) extends PartialAggregate with trees.UnaryNode[Expression] { override def nullable = false - override def dataType = child.dataType + + override def dataType = child.dataType match { + case DecimalType.Fixed(precision, scale) => + DecimalType(precision + 10, scale) // Add 10 digits left of decimal point, like Hive + case DecimalType.Unlimited => + DecimalType.Unlimited + case _ => + child.dataType + } + override def toString = s"SUM($child)" override def asPartial: SplitEvaluation = { @@ -322,9 +351,17 @@ case class Sum(child: Expression) extends PartialAggregate with trees.UnaryNode[ case class SumDistinct(child: Expression) extends AggregateExpression with trees.UnaryNode[Expression] { - override def nullable = false - override def dataType = child.dataType + + override def dataType = child.dataType match { + case DecimalType.Fixed(precision, scale) => + DecimalType(precision + 10, scale) // Add 10 digits left of decimal point, like Hive + case DecimalType.Unlimited => + DecimalType.Unlimited + case _ => + child.dataType + } + override def toString = s"SUM(DISTINCT $child)" override def newInstance() = new SumDistinctFunction(child, this) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala index 83e8466ec2..8574cabc43 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala @@ -36,7 +36,7 @@ case class UnaryMinus(child: Expression) extends UnaryExpression { case class Sqrt(child: Expression) extends UnaryExpression { type EvaluatedType = Any - + def dataType = DoubleType override def foldable = 
child.foldable def nullable = child.nullable @@ -55,7 +55,9 @@ abstract class BinaryArithmetic extends BinaryExpression { def nullable = left.nullable || right.nullable override lazy val resolved = - left.resolved && right.resolved && left.dataType == right.dataType + left.resolved && right.resolved && + left.dataType == right.dataType && + !DecimalType.isFixed(left.dataType) def dataType = { if (!resolved) { @@ -104,6 +106,8 @@ case class Multiply(left: Expression, right: Expression) extends BinaryArithmeti case class Divide(left: Expression, right: Expression) extends BinaryArithmetic { def symbol = "/" + override def nullable = left.nullable || right.nullable || dataType.isInstanceOf[DecimalType] + override def eval(input: Row): Any = dataType match { case _: FractionalType => f2(input, left, right, _.div(_, _)) case _: IntegralType => i2(input, left , right, _.quot(_, _)) @@ -114,6 +118,8 @@ case class Divide(left: Expression, right: Expression) extends BinaryArithmetic case class Remainder(left: Expression, right: Expression) extends BinaryArithmetic { def symbol = "%" + override def nullable = left.nullable || right.nullable || dataType.isInstanceOf[DecimalType] + override def eval(input: Row): Any = i2(input, left, right, _.rem(_, _)) } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala index 5a3f013c34..67f8d411b6 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala @@ -18,6 +18,7 @@ package org.apache.spark.sql.catalyst.expressions.codegen import com.google.common.cache.{CacheLoader, CacheBuilder} +import org.apache.spark.sql.catalyst.types.decimal.Decimal import scala.language.existentials @@ -485,6 +486,34 @@ abstract class CodeGenerator[InType <: AnyRef, OutType <: AnyRef] extends Loggin } """.children + case UnscaledValue(child) => + val childEval = expressionEvaluator(child) + + childEval.code ++ + q""" + var $nullTerm = ${childEval.nullTerm} + var $primitiveTerm: Long = if (!$nullTerm) { + ${childEval.primitiveTerm}.toUnscaledLong + } else { + ${defaultPrimitive(LongType)} + } + """.children + + case MakeDecimal(child, precision, scale) => + val childEval = expressionEvaluator(child) + + childEval.code ++ + q""" + var $nullTerm = ${childEval.nullTerm} + var $primitiveTerm: org.apache.spark.sql.catalyst.types.decimal.Decimal = + ${defaultPrimitive(DecimalType())} + + if (!$nullTerm) { + $primitiveTerm = new org.apache.spark.sql.catalyst.types.decimal.Decimal() + $primitiveTerm = $primitiveTerm.setOrNull(${childEval.primitiveTerm}, $precision, $scale) + $nullTerm = $primitiveTerm == null + } + """.children } // If there was no match in the partial function above, we fall back on calling the interpreted @@ -562,7 +591,7 @@ abstract class CodeGenerator[InType <: AnyRef, OutType <: AnyRef] extends Loggin case LongType => ru.Literal(Constant(1L)) case ByteType => ru.Literal(Constant(-1.toByte)) case DoubleType => ru.Literal(Constant(-1.toDouble)) - case DecimalType => ru.Literal(Constant(-1)) // Will get implicity converted as needed. 
+ case DecimalType() => q"org.apache.spark.sql.catalyst.types.decimal.Decimal(-1)" case IntegerType => ru.Literal(Constant(-1)) case _ => ru.Literal(Constant(null)) } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/decimalFunctions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/decimalFunctions.scala new file mode 100644 index 0000000000..d1eab2eb4e --- /dev/null +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/decimalFunctions.scala @@ -0,0 +1,59 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.catalyst.expressions + +import org.apache.spark.sql.catalyst.types.decimal.Decimal +import org.apache.spark.sql.catalyst.types.{DecimalType, LongType, DoubleType, DataType} + +/** Return the unscaled Long value of a Decimal, assuming it fits in a Long */ +case class UnscaledValue(child: Expression) extends UnaryExpression { + override type EvaluatedType = Any + + override def dataType: DataType = LongType + override def foldable = child.foldable + def nullable = child.nullable + override def toString = s"UnscaledValue($child)" + + override def eval(input: Row): Any = { + val childResult = child.eval(input) + if (childResult == null) { + null + } else { + childResult.asInstanceOf[Decimal].toUnscaledLong + } + } +} + +/** Create a Decimal from an unscaled Long value */ +case class MakeDecimal(child: Expression, precision: Int, scale: Int) extends UnaryExpression { + override type EvaluatedType = Decimal + + override def dataType: DataType = DecimalType(precision, scale) + override def foldable = child.foldable + def nullable = child.nullable + override def toString = s"MakeDecimal($child,$precision,$scale)" + + override def eval(input: Row): Decimal = { + val childResult = child.eval(input) + if (childResult == null) { + null + } else { + new Decimal().setOrNull(childResult.asInstanceOf[Long], precision, scale) + } + } +} diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/literals.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/literals.scala index ba240233ca..93c1932515 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/literals.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/literals.scala @@ -20,6 +20,7 @@ package org.apache.spark.sql.catalyst.expressions import java.sql.{Date, Timestamp} import org.apache.spark.sql.catalyst.types._ +import org.apache.spark.sql.catalyst.types.decimal.Decimal object Literal { def apply(v: Any): Literal = v match { @@ -31,7 +32,8 @@ object Literal { case s: Short => Literal(s, ShortType) case s: String => Literal(s, StringType) case b: Boolean => Literal(b, 
BooleanType) - case d: BigDecimal => Literal(d, DecimalType) + case d: BigDecimal => Literal(Decimal(d), DecimalType.Unlimited) + case d: Decimal => Literal(d, DecimalType.Unlimited) case t: Timestamp => Literal(t, TimestampType) case d: Date => Literal(d, DateType) case a: Array[Byte] => Literal(a, BinaryType) @@ -62,7 +64,7 @@ case class Literal(value: Any, dataType: DataType) extends LeafExpression { } // TODO: Specialize -case class MutableLiteral(var value: Any, dataType: DataType, nullable: Boolean = true) +case class MutableLiteral(var value: Any, dataType: DataType, nullable: Boolean = true) extends LeafExpression { type EvaluatedType = Any diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala index 9ce7c78195..a4aa322fc5 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala @@ -27,6 +27,7 @@ import org.apache.spark.sql.catalyst.plans.LeftSemi import org.apache.spark.sql.catalyst.plans.logical._ import org.apache.spark.sql.catalyst.rules._ import org.apache.spark.sql.catalyst.types._ +import org.apache.spark.sql.catalyst.types.decimal.Decimal abstract class Optimizer extends RuleExecutor[LogicalPlan] @@ -43,6 +44,8 @@ object DefaultOptimizer extends Optimizer { SimplifyCasts, SimplifyCaseConversionExpressions, OptimizeIn) :: + Batch("Decimal Optimizations", FixedPoint(100), + DecimalAggregates) :: Batch("Filter Pushdown", FixedPoint(100), UnionPushdown, CombineFilters, @@ -390,9 +393,9 @@ object PushPredicateThroughProject extends Rule[LogicalPlan] { * evaluated using only the attributes of the left or right side of a join. Other * [[Filter]] conditions are moved into the `condition` of the [[Join]]. * - * And also Pushes down the join filter, where the `condition` can be evaluated using only the - * attributes of the left or right side of sub query when applicable. - * + * And also Pushes down the join filter, where the `condition` can be evaluated using only the + * attributes of the left or right side of sub query when applicable. 
+ * * Check https://cwiki.apache.org/confluence/display/Hive/OuterJoinBehavior for more details */ object PushPredicateThroughJoin extends Rule[LogicalPlan] with PredicateHelper { @@ -404,7 +407,7 @@ object PushPredicateThroughJoin extends Rule[LogicalPlan] with PredicateHelper { private def split(condition: Seq[Expression], left: LogicalPlan, right: LogicalPlan) = { val (leftEvaluateCondition, rest) = condition.partition(_.references subsetOf left.outputSet) - val (rightEvaluateCondition, commonCondition) = + val (rightEvaluateCondition, commonCondition) = rest.partition(_.references subsetOf right.outputSet) (leftEvaluateCondition, rightEvaluateCondition, commonCondition) @@ -413,7 +416,7 @@ object PushPredicateThroughJoin extends Rule[LogicalPlan] with PredicateHelper { def apply(plan: LogicalPlan): LogicalPlan = plan transform { // push the where condition down into join filter case f @ Filter(filterCondition, Join(left, right, joinType, joinCondition)) => - val (leftFilterConditions, rightFilterConditions, commonFilterCondition) = + val (leftFilterConditions, rightFilterConditions, commonFilterCondition) = split(splitConjunctivePredicates(filterCondition), left, right) joinType match { @@ -451,7 +454,7 @@ object PushPredicateThroughJoin extends Rule[LogicalPlan] with PredicateHelper { // push down the join filter into sub query scanning if applicable case f @ Join(left, right, joinType, joinCondition) => - val (leftJoinConditions, rightJoinConditions, commonJoinCondition) = + val (leftJoinConditions, rightJoinConditions, commonJoinCondition) = split(joinCondition.map(splitConjunctivePredicates).getOrElse(Nil), left, right) joinType match { @@ -519,3 +522,26 @@ object SimplifyCaseConversionExpressions extends Rule[LogicalPlan] { } } } + +/** + * Speeds up aggregates on fixed-precision decimals by executing them on unscaled Long values. + * + * This uses the same rules for increasing the precision and scale of the output as + * [[org.apache.spark.sql.catalyst.analysis.HiveTypeCoercion.DecimalPrecision]]. 
+ */ +object DecimalAggregates extends Rule[LogicalPlan] { + import Decimal.MAX_LONG_DIGITS + + /** Maximum number of decimal digits representable precisely in a Double */ + val MAX_DOUBLE_DIGITS = 15 + + def apply(plan: LogicalPlan): LogicalPlan = plan transformAllExpressions { + case Sum(e @ DecimalType.Expression(prec, scale)) if prec + 10 <= MAX_LONG_DIGITS => + MakeDecimal(Sum(UnscaledValue(e)), prec + 10, scale) + + case Average(e @ DecimalType.Expression(prec, scale)) if prec + 4 <= MAX_DOUBLE_DIGITS => + Cast( + Divide(Average(UnscaledValue(e)), Literal(math.pow(10.0, scale), DoubleType)), + DecimalType(prec + 4, scale + 4)) + } +} diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/types/dataTypes.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/types/dataTypes.scala index 6069f9b0a6..8dda0b1828 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/types/dataTypes.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/types/dataTypes.scala @@ -19,7 +19,7 @@ package org.apache.spark.sql.catalyst.types import java.sql.{Date, Timestamp} -import scala.math.Numeric.{BigDecimalAsIfIntegral, DoubleAsIfIntegral, FloatAsIfIntegral} +import scala.math.Numeric.{FloatAsIfIntegral, BigDecimalAsIfIntegral, DoubleAsIfIntegral} import scala.reflect.ClassTag import scala.reflect.runtime.universe.{TypeTag, runtimeMirror, typeTag} import scala.util.parsing.combinator.RegexParsers @@ -33,6 +33,7 @@ import org.apache.spark.sql.catalyst.ScalaReflectionLock import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeReference, Expression} import org.apache.spark.sql.catalyst.util.Metadata import org.apache.spark.util.Utils +import org.apache.spark.sql.catalyst.types.decimal._ object DataType { def fromJson(json: String): DataType = parseDataType(parse(json)) @@ -91,11 +92,17 @@ object DataType { | "LongType" ^^^ LongType | "BinaryType" ^^^ BinaryType | "BooleanType" ^^^ BooleanType - | "DecimalType" ^^^ DecimalType | "DateType" ^^^ DateType + | "DecimalType()" ^^^ DecimalType.Unlimited + | fixedDecimalType | "TimestampType" ^^^ TimestampType ) + protected lazy val fixedDecimalType: Parser[DataType] = + ("DecimalType(" ~> "[0-9]+".r) ~ ("," ~> "[0-9]+".r <~ ")") ^^ { + case precision ~ scale => DecimalType(precision.toInt, scale.toInt) + } + protected lazy val arrayType: Parser[DataType] = "ArrayType" ~> "(" ~> dataType ~ "," ~ boolVal <~ ")" ^^ { case tpe ~ _ ~ containsNull => ArrayType(tpe, containsNull) @@ -200,10 +207,18 @@ trait PrimitiveType extends DataType { } object PrimitiveType { - private[sql] val all = Seq(DecimalType, DateType, TimestampType, BinaryType) ++ - NativeType.all - - private[sql] val nameToType = all.map(t => t.typeName -> t).toMap + private val nonDecimals = Seq(DateType, TimestampType, BinaryType) ++ NativeType.all + private val nonDecimalNameToType = nonDecimals.map(t => t.typeName -> t).toMap + + /** Given the string representation of a type, return its DataType */ + private[sql] def nameToType(name: String): DataType = { + val FIXED_DECIMAL = """decimal\(\s*(\d+)\s*,\s*(\d+)\s*\)""".r + name match { + case "decimal" => DecimalType.Unlimited + case FIXED_DECIMAL(precision, scale) => DecimalType(precision.toInt, scale.toInt) + case other => nonDecimalNameToType(other) + } + } } abstract class NativeType extends DataType { @@ -332,13 +347,58 @@ abstract class FractionalType extends NumericType { private[sql] val asIntegral: Integral[JvmType] } -case object DecimalType extends FractionalType { - 
private[sql] type JvmType = BigDecimal +/** Precision parameters for a Decimal */ +case class PrecisionInfo(precision: Int, scale: Int) + +/** A Decimal that might have fixed precision and scale, or unlimited values for these */ +case class DecimalType(precisionInfo: Option[PrecisionInfo]) extends FractionalType { + private[sql] type JvmType = Decimal @transient private[sql] lazy val tag = ScalaReflectionLock.synchronized { typeTag[JvmType] } - private[sql] val numeric = implicitly[Numeric[BigDecimal]] - private[sql] val fractional = implicitly[Fractional[BigDecimal]] - private[sql] val ordering = implicitly[Ordering[JvmType]] - private[sql] val asIntegral = BigDecimalAsIfIntegral + private[sql] val numeric = Decimal.DecimalIsFractional + private[sql] val fractional = Decimal.DecimalIsFractional + private[sql] val ordering = Decimal.DecimalIsFractional + private[sql] val asIntegral = Decimal.DecimalAsIfIntegral + + override def typeName: String = precisionInfo match { + case Some(PrecisionInfo(precision, scale)) => s"decimal($precision,$scale)" + case None => "decimal" + } + + override def toString: String = precisionInfo match { + case Some(PrecisionInfo(precision, scale)) => s"DecimalType($precision,$scale)" + case None => "DecimalType()" + } +} + +/** Extra factory methods and pattern matchers for Decimals */ +object DecimalType { + val Unlimited: DecimalType = DecimalType(None) + + object Fixed { + def unapply(t: DecimalType): Option[(Int, Int)] = + t.precisionInfo.map(p => (p.precision, p.scale)) + } + + object Expression { + def unapply(e: Expression): Option[(Int, Int)] = e.dataType match { + case t: DecimalType => t.precisionInfo.map(p => (p.precision, p.scale)) + case _ => None + } + } + + def apply(): DecimalType = Unlimited + + def apply(precision: Int, scale: Int): DecimalType = + DecimalType(Some(PrecisionInfo(precision, scale))) + + def unapply(t: DataType): Boolean = t.isInstanceOf[DecimalType] + + def unapply(e: Expression): Boolean = e.dataType.isInstanceOf[DecimalType] + + def isFixed(dataType: DataType): Boolean = dataType match { + case DecimalType.Fixed(_, _) => true + case _ => false + } } case object DoubleType extends FractionalType { diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/types/decimal/Decimal.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/types/decimal/Decimal.scala new file mode 100644 index 0000000000..708362acf3 --- /dev/null +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/types/decimal/Decimal.scala @@ -0,0 +1,335 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.sql.catalyst.types.decimal + +import org.apache.spark.annotation.DeveloperApi + +/** + * A mutable implementation of BigDecimal that can hold a Long if values are small enough. + * + * The semantics of the fields are as follows: + * - _precision and _scale represent the SQL precision and scale we are looking for + * - If decimalVal is set, it represents the whole decimal value + * - Otherwise, the decimal value is longVal / (10 ** _scale) + */ +final class Decimal extends Ordered[Decimal] with Serializable { + import Decimal.{MAX_LONG_DIGITS, POW_10, ROUNDING_MODE, BIG_DEC_ZERO} + + private var decimalVal: BigDecimal = null + private var longVal: Long = 0L + private var _precision: Int = 1 + private var _scale: Int = 0 + + def precision: Int = _precision + def scale: Int = _scale + + /** + * Set this Decimal to the given Long. Will have precision 20 and scale 0. + */ + def set(longVal: Long): Decimal = { + if (longVal <= -POW_10(MAX_LONG_DIGITS) || longVal >= POW_10(MAX_LONG_DIGITS)) { + // We can't represent this compactly as a long without risking overflow + this.decimalVal = BigDecimal(longVal) + this.longVal = 0L + } else { + this.decimalVal = null + this.longVal = longVal + } + this._precision = 20 + this._scale = 0 + this + } + + /** + * Set this Decimal to the given Int. Will have precision 10 and scale 0. + */ + def set(intVal: Int): Decimal = { + this.decimalVal = null + this.longVal = intVal + this._precision = 10 + this._scale = 0 + this + } + + /** + * Set this Decimal to the given unscaled Long, with a given precision and scale. + */ + def set(unscaled: Long, precision: Int, scale: Int): Decimal = { + if (setOrNull(unscaled, precision, scale) == null) { + throw new IllegalArgumentException("Unscaled value too large for precision") + } + this + } + + /** + * Set this Decimal to the given unscaled Long, with a given precision and scale, + * and return it, or return null if it cannot be set due to overflow. + */ + def setOrNull(unscaled: Long, precision: Int, scale: Int): Decimal = { + if (unscaled <= -POW_10(MAX_LONG_DIGITS) || unscaled >= POW_10(MAX_LONG_DIGITS)) { + // We can't represent this compactly as a long without risking overflow + if (precision < 19) { + return null // Requested precision is too low to represent this value + } + this.decimalVal = BigDecimal(longVal) + this.longVal = 0L + } else { + val p = POW_10(math.min(precision, MAX_LONG_DIGITS)) + if (unscaled <= -p || unscaled >= p) { + return null // Requested precision is too low to represent this value + } + this.decimalVal = null + this.longVal = unscaled + } + this._precision = precision + this._scale = scale + this + } + + /** + * Set this Decimal to the given BigDecimal value, with a given precision and scale. + */ + def set(decimal: BigDecimal, precision: Int, scale: Int): Decimal = { + this.decimalVal = decimal.setScale(scale, ROUNDING_MODE) + require(decimalVal.precision <= precision, "Overflowed precision") + this.longVal = 0L + this._precision = precision + this._scale = scale + this + } + + /** + * Set this Decimal to the given BigDecimal value, inheriting its precision and scale. + */ + def set(decimal: BigDecimal): Decimal = { + this.decimalVal = decimal + this.longVal = 0L + this._precision = decimal.precision + this._scale = decimal.scale + this + } + + /** + * Set this Decimal to the given Decimal value. 
+ */ + def set(decimal: Decimal): Decimal = { + this.decimalVal = decimal.decimalVal + this.longVal = decimal.longVal + this._precision = decimal._precision + this._scale = decimal._scale + this + } + + def toBigDecimal: BigDecimal = { + if (decimalVal.ne(null)) { + decimalVal + } else { + BigDecimal(longVal, _scale) + } + } + + def toUnscaledLong: Long = { + if (decimalVal.ne(null)) { + decimalVal.underlying().unscaledValue().longValue() + } else { + longVal + } + } + + override def toString: String = toBigDecimal.toString() + + @DeveloperApi + def toDebugString: String = { + if (decimalVal.ne(null)) { + s"Decimal(expanded,$decimalVal,$precision,$scale})" + } else { + s"Decimal(compact,$longVal,$precision,$scale})" + } + } + + def toDouble: Double = toBigDecimal.doubleValue() + + def toFloat: Float = toBigDecimal.floatValue() + + def toLong: Long = { + if (decimalVal.eq(null)) { + longVal / POW_10(_scale) + } else { + decimalVal.longValue() + } + } + + def toInt: Int = toLong.toInt + + def toShort: Short = toLong.toShort + + def toByte: Byte = toLong.toByte + + /** + * Update precision and scale while keeping our value the same, and return true if successful. + * + * @return true if successful, false if overflow would occur + */ + def changePrecision(precision: Int, scale: Int): Boolean = { + // First, update our longVal if we can, or transfer over to using a BigDecimal + if (decimalVal.eq(null)) { + if (scale < _scale) { + // Easier case: we just need to divide our scale down + val diff = _scale - scale + val droppedDigits = longVal % POW_10(diff) + longVal /= POW_10(diff) + if (math.abs(droppedDigits) * 2 >= POW_10(diff)) { + longVal += (if (longVal < 0) -1L else 1L) + } + } else if (scale > _scale) { + // We might be able to multiply longVal by a power of 10 and not overflow, but if not, + // switch to using a BigDecimal + val diff = scale - _scale + val p = POW_10(math.max(MAX_LONG_DIGITS - diff, 0)) + if (diff <= MAX_LONG_DIGITS && longVal > -p && longVal < p) { + // Multiplying longVal by POW_10(diff) will still keep it below MAX_LONG_DIGITS + longVal *= POW_10(diff) + } else { + // Give up on using Longs; switch to BigDecimal, which we'll modify below + decimalVal = BigDecimal(longVal, _scale) + } + } + // In both cases, we will check whether our precision is okay below + } + + if (decimalVal.ne(null)) { + // We get here if either we started with a BigDecimal, or we switched to one because we would + // have overflowed our Long; in either case we must rescale decimalVal to the new scale. 
+ val newVal = decimalVal.setScale(scale, ROUNDING_MODE) + if (newVal.precision > precision) { + return false + } + decimalVal = newVal + } else { + // We're still using Longs, but we should check whether we match the new precision + val p = POW_10(math.min(_precision, MAX_LONG_DIGITS)) + if (longVal <= -p || longVal >= p) { + // Note that we shouldn't have been able to fix this by switching to BigDecimal + return false + } + } + + _precision = precision + _scale = scale + true + } + + override def clone(): Decimal = new Decimal().set(this) + + override def compare(other: Decimal): Int = { + if (decimalVal.eq(null) && other.decimalVal.eq(null) && _scale == other._scale) { + if (longVal < other.longVal) -1 else if (longVal == other.longVal) 0 else 1 + } else { + toBigDecimal.compare(other.toBigDecimal) + } + } + + override def equals(other: Any) = other match { + case d: Decimal => + compare(d) == 0 + case _ => + false + } + + override def hashCode(): Int = toBigDecimal.hashCode() + + def isZero: Boolean = if (decimalVal.ne(null)) decimalVal == BIG_DEC_ZERO else longVal == 0 + + def + (that: Decimal): Decimal = Decimal(toBigDecimal + that.toBigDecimal) + + def - (that: Decimal): Decimal = Decimal(toBigDecimal - that.toBigDecimal) + + def * (that: Decimal): Decimal = Decimal(toBigDecimal * that.toBigDecimal) + + def / (that: Decimal): Decimal = + if (that.isZero) null else Decimal(toBigDecimal / that.toBigDecimal) + + def % (that: Decimal): Decimal = + if (that.isZero) null else Decimal(toBigDecimal % that.toBigDecimal) + + def remainder(that: Decimal): Decimal = this % that + + def unary_- : Decimal = { + if (decimalVal.ne(null)) { + Decimal(-decimalVal) + } else { + Decimal(-longVal, precision, scale) + } + } +} + +object Decimal { + private val ROUNDING_MODE = BigDecimal.RoundingMode.HALF_UP + + /** Maximum number of decimal digits a Long can represent */ + val MAX_LONG_DIGITS = 18 + + private val POW_10 = Array.tabulate[Long](MAX_LONG_DIGITS + 1)(i => math.pow(10, i).toLong) + + private val BIG_DEC_ZERO = BigDecimal(0) + + def apply(value: Double): Decimal = new Decimal().set(value) + + def apply(value: Long): Decimal = new Decimal().set(value) + + def apply(value: Int): Decimal = new Decimal().set(value) + + def apply(value: BigDecimal): Decimal = new Decimal().set(value) + + def apply(value: BigDecimal, precision: Int, scale: Int): Decimal = + new Decimal().set(value, precision, scale) + + def apply(unscaled: Long, precision: Int, scale: Int): Decimal = + new Decimal().set(unscaled, precision, scale) + + def apply(value: String): Decimal = new Decimal().set(BigDecimal(value)) + + // Evidence parameters for Decimal considered either as Fractional or Integral. We provide two + // parameters inheriting from a common trait since both traits define mkNumericOps. + // See scala.math's Numeric.scala for examples for Scala's built-in types. 
+ + /** Common methods for Decimal evidence parameters */ + trait DecimalIsConflicted extends Numeric[Decimal] { + override def plus(x: Decimal, y: Decimal): Decimal = x + y + override def times(x: Decimal, y: Decimal): Decimal = x * y + override def minus(x: Decimal, y: Decimal): Decimal = x - y + override def negate(x: Decimal): Decimal = -x + override def toDouble(x: Decimal): Double = x.toDouble + override def toFloat(x: Decimal): Float = x.toFloat + override def toInt(x: Decimal): Int = x.toInt + override def toLong(x: Decimal): Long = x.toLong + override def fromInt(x: Int): Decimal = new Decimal().set(x) + override def compare(x: Decimal, y: Decimal): Int = x.compare(y) + } + + /** A [[scala.math.Fractional]] evidence parameter for Decimals. */ + object DecimalIsFractional extends DecimalIsConflicted with Fractional[Decimal] { + override def div(x: Decimal, y: Decimal): Decimal = x / y + } + + /** A [[scala.math.Integral]] evidence parameter for Decimals. */ + object DecimalAsIfIntegral extends DecimalIsConflicted with Integral[Decimal] { + override def quot(x: Decimal, y: Decimal): Decimal = x / y + override def rem(x: Decimal, y: Decimal): Decimal = x % y + } +} diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/ScalaReflectionSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/ScalaReflectionSuite.scala index 430f0664b7..21b2c8e20d 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/ScalaReflectionSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/ScalaReflectionSuite.scala @@ -96,7 +96,7 @@ class ScalaReflectionSuite extends FunSuite { StructField("byteField", ByteType, nullable = true), StructField("booleanField", BooleanType, nullable = true), StructField("stringField", StringType, nullable = true), - StructField("decimalField", DecimalType, nullable = true), + StructField("decimalField", DecimalType.Unlimited, nullable = true), StructField("dateField", DateType, nullable = true), StructField("timestampField", TimestampType, nullable = true), StructField("binaryField", BinaryType, nullable = true))), @@ -199,7 +199,7 @@ class ScalaReflectionSuite extends FunSuite { assert(DoubleType === typeOfObject(1.7976931348623157E308)) // DecimalType - assert(DecimalType === typeOfObject(BigDecimal("1.7976931348623157E318"))) + assert(DecimalType.Unlimited === typeOfObject(BigDecimal("1.7976931348623157E318"))) // DateType assert(DateType === typeOfObject(Date.valueOf("2014-07-25"))) @@ -211,19 +211,19 @@ class ScalaReflectionSuite extends FunSuite { assert(NullType === typeOfObject(null)) def typeOfObject1: PartialFunction[Any, DataType] = typeOfObject orElse { - case value: java.math.BigInteger => DecimalType - case value: java.math.BigDecimal => DecimalType + case value: java.math.BigInteger => DecimalType.Unlimited + case value: java.math.BigDecimal => DecimalType.Unlimited case _ => StringType } - assert(DecimalType === typeOfObject1( + assert(DecimalType.Unlimited === typeOfObject1( new BigInteger("92233720368547758070"))) - assert(DecimalType === typeOfObject1( + assert(DecimalType.Unlimited === typeOfObject1( new java.math.BigDecimal("1.7976931348623157E318"))) assert(StringType === typeOfObject1(BigInt("92233720368547758070"))) def typeOfObject2: PartialFunction[Any, DataType] = typeOfObject orElse { - case value: java.math.BigInteger => DecimalType + case value: java.math.BigInteger => DecimalType.Unlimited } intercept[MatchError](typeOfObject2(BigInt("92233720368547758070"))) diff --git 
a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisSuite.scala index 7b45738c4f..33a3cba3d4 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisSuite.scala @@ -38,7 +38,7 @@ class AnalysisSuite extends FunSuite with BeforeAndAfter { AttributeReference("a", StringType)(), AttributeReference("b", StringType)(), AttributeReference("c", DoubleType)(), - AttributeReference("d", DecimalType)(), + AttributeReference("d", DecimalType.Unlimited)(), AttributeReference("e", ShortType)()) before { @@ -119,7 +119,7 @@ class AnalysisSuite extends FunSuite with BeforeAndAfter { AttributeReference("a", StringType)(), AttributeReference("b", StringType)(), AttributeReference("c", DoubleType)(), - AttributeReference("d", DecimalType)(), + AttributeReference("d", DecimalType.Unlimited)(), AttributeReference("e", ShortType)()) val expr0 = 'a / 2 @@ -137,7 +137,7 @@ class AnalysisSuite extends FunSuite with BeforeAndAfter { assert(pl(0).dataType == DoubleType) assert(pl(1).dataType == DoubleType) assert(pl(2).dataType == DoubleType) - assert(pl(3).dataType == DecimalType) + assert(pl(3).dataType == DecimalType.Unlimited) assert(pl(4).dataType == DoubleType) } } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/DecimalPrecisionSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/DecimalPrecisionSuite.scala new file mode 100644 index 0000000000..d5b7d2789a --- /dev/null +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/DecimalPrecisionSuite.scala @@ -0,0 +1,88 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.sql.catalyst.analysis + +import org.apache.spark.sql.catalyst.expressions._ +import org.apache.spark.sql.catalyst.plans.logical.{Project, LocalRelation} +import org.apache.spark.sql.catalyst.types._ +import org.scalatest.{BeforeAndAfter, FunSuite} + +class DecimalPrecisionSuite extends FunSuite with BeforeAndAfter { + val catalog = new SimpleCatalog(false) + val analyzer = new Analyzer(catalog, EmptyFunctionRegistry, caseSensitive = false) + + val relation = LocalRelation( + AttributeReference("i", IntegerType)(), + AttributeReference("d1", DecimalType(2, 1))(), + AttributeReference("d2", DecimalType(5, 2))(), + AttributeReference("u", DecimalType.Unlimited)(), + AttributeReference("f", FloatType)() + ) + + val i: Expression = UnresolvedAttribute("i") + val d1: Expression = UnresolvedAttribute("d1") + val d2: Expression = UnresolvedAttribute("d2") + val u: Expression = UnresolvedAttribute("u") + val f: Expression = UnresolvedAttribute("f") + + before { + catalog.registerTable(None, "table", relation) + } + + private def checkType(expression: Expression, expectedType: DataType): Unit = { + val plan = Project(Seq(Alias(expression, "c")()), relation) + assert(analyzer(plan).schema.fields(0).dataType === expectedType) + } + + test("basic operations") { + checkType(Add(d1, d2), DecimalType(6, 2)) + checkType(Subtract(d1, d2), DecimalType(6, 2)) + checkType(Multiply(d1, d2), DecimalType(8, 3)) + checkType(Divide(d1, d2), DecimalType(10, 7)) + checkType(Divide(d2, d1), DecimalType(10, 6)) + checkType(Remainder(d1, d2), DecimalType(3, 2)) + checkType(Remainder(d2, d1), DecimalType(3, 2)) + checkType(Sum(d1), DecimalType(12, 1)) + checkType(Average(d1), DecimalType(6, 5)) + + checkType(Add(Add(d1, d2), d1), DecimalType(7, 2)) + checkType(Add(Add(Add(d1, d2), d1), d2), DecimalType(8, 2)) + checkType(Add(Add(d1, d2), Add(d1, d2)), DecimalType(7, 2)) + } + + test("bringing in primitive types") { + checkType(Add(d1, i), DecimalType(12, 1)) + checkType(Add(d1, f), DoubleType) + checkType(Add(i, d1), DecimalType(12, 1)) + checkType(Add(f, d1), DoubleType) + checkType(Add(d1, Cast(i, LongType)), DecimalType(22, 1)) + checkType(Add(d1, Cast(i, ShortType)), DecimalType(7, 1)) + checkType(Add(d1, Cast(i, ByteType)), DecimalType(5, 1)) + checkType(Add(d1, Cast(i, DoubleType)), DoubleType) + } + + test("unlimited decimals make everything else cast up") { + for (expr <- Seq(d1, d2, i, f, u)) { + checkType(Add(expr, u), DecimalType.Unlimited) + checkType(Subtract(expr, u), DecimalType.Unlimited) + checkType(Multiply(expr, u), DecimalType.Unlimited) + checkType(Divide(expr, u), DecimalType.Unlimited) + checkType(Remainder(expr, u), DecimalType.Unlimited) + } + } +} diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercionSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercionSuite.scala index baeb9b0cf5..dfa2d958c0 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercionSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercionSuite.scala @@ -68,6 +68,21 @@ class HiveTypeCoercionSuite extends FunSuite { widenTest(LongType, FloatType, Some(FloatType)) widenTest(LongType, DoubleType, Some(DoubleType)) + // Casting up to unlimited-precision decimal + widenTest(IntegerType, DecimalType.Unlimited, Some(DecimalType.Unlimited)) + widenTest(DoubleType, DecimalType.Unlimited, Some(DecimalType.Unlimited)) + 
widenTest(DecimalType(3, 2), DecimalType.Unlimited, Some(DecimalType.Unlimited)) + widenTest(DecimalType.Unlimited, IntegerType, Some(DecimalType.Unlimited)) + widenTest(DecimalType.Unlimited, DoubleType, Some(DecimalType.Unlimited)) + widenTest(DecimalType.Unlimited, DecimalType(3, 2), Some(DecimalType.Unlimited)) + + // No up-casting for fixed-precision decimal (this is handled by arithmetic rules) + widenTest(DecimalType(2, 1), DecimalType(3, 2), None) + widenTest(DecimalType(2, 1), DoubleType, None) + widenTest(DecimalType(2, 1), IntegerType, None) + widenTest(DoubleType, DecimalType(2, 1), None) + widenTest(IntegerType, DecimalType(2, 1), None) + // StringType widenTest(NullType, StringType, Some(StringType)) widenTest(StringType, StringType, Some(StringType)) @@ -92,7 +107,7 @@ class HiveTypeCoercionSuite extends FunSuite { def ruleTest(initial: Expression, transformed: Expression) { val testRelation = LocalRelation(AttributeReference("a", IntegerType)()) assert(booleanCasts(Project(Seq(Alias(initial, "a")()), testRelation)) == - Project(Seq(Alias(transformed, "a")()), testRelation)) + Project(Seq(Alias(transformed, "a")()), testRelation)) } // Remove superflous boolean -> boolean casts. ruleTest(Cast(Literal(true), BooleanType), Literal(true)) diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvaluationSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvaluationSuite.scala index 5657bc555e..6bfa0dbd65 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvaluationSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvaluationSuite.scala @@ -21,6 +21,7 @@ import java.sql.{Date, Timestamp} import scala.collection.immutable.HashSet +import org.apache.spark.sql.catalyst.types.decimal.Decimal import org.scalatest.FunSuite import org.scalatest.Matchers._ import org.scalactic.TripleEqualsSupport.Spread @@ -138,7 +139,7 @@ class ExpressionEvaluationSuite extends FunSuite { val actual = try evaluate(expression, inputRow) catch { case e: Exception => fail(s"Exception evaluating $expression", e) } - actual.asInstanceOf[Double] shouldBe expected + actual.asInstanceOf[Double] shouldBe expected } test("IN") { @@ -165,7 +166,7 @@ class ExpressionEvaluationSuite extends FunSuite { checkEvaluation(InSet(three, nS, three +: nullS), false) checkEvaluation(InSet(one, hS, one +: s) && InSet(two, hS, two +: s), true) } - + test("MaxOf") { checkEvaluation(MaxOf(1, 2), 2) checkEvaluation(MaxOf(2, 1), 2) @@ -265,9 +266,9 @@ class ExpressionEvaluationSuite extends FunSuite { val ts = Timestamp.valueOf(nts) checkEvaluation("abdef" cast StringType, "abdef") - checkEvaluation("abdef" cast DecimalType, null) + checkEvaluation("abdef" cast DecimalType.Unlimited, null) checkEvaluation("abdef" cast TimestampType, null) - checkEvaluation("12.65" cast DecimalType, BigDecimal(12.65)) + checkEvaluation("12.65" cast DecimalType.Unlimited, Decimal(12.65)) checkEvaluation(Literal(1) cast LongType, 1) checkEvaluation(Cast(Literal(1000) cast TimestampType, LongType), 1.toLong) @@ -289,12 +290,12 @@ class ExpressionEvaluationSuite extends FunSuite { checkEvaluation(Cast(Cast(Cast(Cast( Cast("5" cast ByteType, ShortType), IntegerType), FloatType), DoubleType), LongType), 5) - checkEvaluation(Cast(Cast(Cast(Cast( - Cast("5" cast ByteType, TimestampType), DecimalType), LongType), StringType), ShortType), 0) - checkEvaluation(Cast(Cast(Cast(Cast( - Cast("5" 
cast TimestampType, ByteType), DecimalType), LongType), StringType), ShortType), null) - checkEvaluation(Cast(Cast(Cast(Cast( - Cast("5" cast DecimalType, ByteType), TimestampType), LongType), StringType), ShortType), 0) + checkEvaluation(Cast(Cast(Cast(Cast(Cast("5" cast + ByteType, TimestampType), DecimalType.Unlimited), LongType), StringType), ShortType), 0) + checkEvaluation(Cast(Cast(Cast(Cast(Cast("5" cast + TimestampType, ByteType), DecimalType.Unlimited), LongType), StringType), ShortType), null) + checkEvaluation(Cast(Cast(Cast(Cast(Cast("5" cast + DecimalType.Unlimited, ByteType), TimestampType), LongType), StringType), ShortType), 0) checkEvaluation(Literal(true) cast IntegerType, 1) checkEvaluation(Literal(false) cast IntegerType, 0) checkEvaluation(Cast(Literal(1) cast BooleanType, IntegerType), 1) @@ -302,7 +303,7 @@ class ExpressionEvaluationSuite extends FunSuite { checkEvaluation("23" cast DoubleType, 23d) checkEvaluation("23" cast IntegerType, 23) checkEvaluation("23" cast FloatType, 23f) - checkEvaluation("23" cast DecimalType, 23: BigDecimal) + checkEvaluation("23" cast DecimalType.Unlimited, Decimal(23)) checkEvaluation("23" cast ByteType, 23.toByte) checkEvaluation("23" cast ShortType, 23.toShort) checkEvaluation("2012-12-11" cast DoubleType, null) @@ -311,7 +312,7 @@ class ExpressionEvaluationSuite extends FunSuite { checkEvaluation(Literal(23d) + Cast(true, DoubleType), 24d) checkEvaluation(Literal(23) + Cast(true, IntegerType), 24) checkEvaluation(Literal(23f) + Cast(true, FloatType), 24f) - checkEvaluation(Literal(BigDecimal(23)) + Cast(true, DecimalType), 24: BigDecimal) + checkEvaluation(Literal(Decimal(23)) + Cast(true, DecimalType.Unlimited), Decimal(24)) checkEvaluation(Literal(23.toByte) + Cast(true, ByteType), 24.toByte) checkEvaluation(Literal(23.toShort) + Cast(true, ShortType), 24.toShort) @@ -325,7 +326,8 @@ class ExpressionEvaluationSuite extends FunSuite { assert(("abcdef" cast IntegerType).nullable === true) assert(("abcdef" cast ShortType).nullable === true) assert(("abcdef" cast ByteType).nullable === true) - assert(("abcdef" cast DecimalType).nullable === true) + assert(("abcdef" cast DecimalType.Unlimited).nullable === true) + assert(("abcdef" cast DecimalType(4, 2)).nullable === true) assert(("abcdef" cast DoubleType).nullable === true) assert(("abcdef" cast FloatType).nullable === true) @@ -338,6 +340,64 @@ class ExpressionEvaluationSuite extends FunSuite { checkEvaluation(Literal(d1) < Literal(d2), true) } + test("casting to fixed-precision decimals") { + // Overflow and rounding for casting to fixed-precision decimals: + // - Values should round with HALF_UP mode by default when you lower scale + // - Values that would overflow the target precision should turn into null + // - Because of this, casts to fixed-precision decimals should be nullable + + assert(Cast(Literal(123), DecimalType.Unlimited).nullable === false) + assert(Cast(Literal(10.03f), DecimalType.Unlimited).nullable === false) + assert(Cast(Literal(10.03), DecimalType.Unlimited).nullable === false) + assert(Cast(Literal(Decimal(10.03)), DecimalType.Unlimited).nullable === false) + + assert(Cast(Literal(123), DecimalType(2, 1)).nullable === true) + assert(Cast(Literal(10.03f), DecimalType(2, 1)).nullable === true) + assert(Cast(Literal(10.03), DecimalType(2, 1)).nullable === true) + assert(Cast(Literal(Decimal(10.03)), DecimalType(2, 1)).nullable === true) + + checkEvaluation(Cast(Literal(123), DecimalType.Unlimited), Decimal(123)) + checkEvaluation(Cast(Literal(123), 
DecimalType(3, 0)), Decimal(123)) + checkEvaluation(Cast(Literal(123), DecimalType(3, 1)), null) + checkEvaluation(Cast(Literal(123), DecimalType(2, 0)), null) + + checkEvaluation(Cast(Literal(10.03), DecimalType.Unlimited), Decimal(10.03)) + checkEvaluation(Cast(Literal(10.03), DecimalType(4, 2)), Decimal(10.03)) + checkEvaluation(Cast(Literal(10.03), DecimalType(3, 1)), Decimal(10.0)) + checkEvaluation(Cast(Literal(10.03), DecimalType(2, 0)), Decimal(10)) + checkEvaluation(Cast(Literal(10.03), DecimalType(1, 0)), null) + checkEvaluation(Cast(Literal(10.03), DecimalType(2, 1)), null) + checkEvaluation(Cast(Literal(10.03), DecimalType(3, 2)), null) + checkEvaluation(Cast(Literal(Decimal(10.03)), DecimalType(3, 1)), Decimal(10.0)) + checkEvaluation(Cast(Literal(Decimal(10.03)), DecimalType(3, 2)), null) + + checkEvaluation(Cast(Literal(10.05), DecimalType.Unlimited), Decimal(10.05)) + checkEvaluation(Cast(Literal(10.05), DecimalType(4, 2)), Decimal(10.05)) + checkEvaluation(Cast(Literal(10.05), DecimalType(3, 1)), Decimal(10.1)) + checkEvaluation(Cast(Literal(10.05), DecimalType(2, 0)), Decimal(10)) + checkEvaluation(Cast(Literal(10.05), DecimalType(1, 0)), null) + checkEvaluation(Cast(Literal(10.05), DecimalType(2, 1)), null) + checkEvaluation(Cast(Literal(10.05), DecimalType(3, 2)), null) + checkEvaluation(Cast(Literal(Decimal(10.05)), DecimalType(3, 1)), Decimal(10.1)) + checkEvaluation(Cast(Literal(Decimal(10.05)), DecimalType(3, 2)), null) + + checkEvaluation(Cast(Literal(9.95), DecimalType(3, 2)), Decimal(9.95)) + checkEvaluation(Cast(Literal(9.95), DecimalType(3, 1)), Decimal(10.0)) + checkEvaluation(Cast(Literal(9.95), DecimalType(2, 0)), Decimal(10)) + checkEvaluation(Cast(Literal(9.95), DecimalType(2, 1)), null) + checkEvaluation(Cast(Literal(9.95), DecimalType(1, 0)), null) + checkEvaluation(Cast(Literal(Decimal(9.95)), DecimalType(3, 1)), Decimal(10.0)) + checkEvaluation(Cast(Literal(Decimal(9.95)), DecimalType(1, 0)), null) + + checkEvaluation(Cast(Literal(-9.95), DecimalType(3, 2)), Decimal(-9.95)) + checkEvaluation(Cast(Literal(-9.95), DecimalType(3, 1)), Decimal(-10.0)) + checkEvaluation(Cast(Literal(-9.95), DecimalType(2, 0)), Decimal(-10)) + checkEvaluation(Cast(Literal(-9.95), DecimalType(2, 1)), null) + checkEvaluation(Cast(Literal(-9.95), DecimalType(1, 0)), null) + checkEvaluation(Cast(Literal(Decimal(-9.95)), DecimalType(3, 1)), Decimal(-10.0)) + checkEvaluation(Cast(Literal(Decimal(-9.95)), DecimalType(1, 0)), null) + } + test("timestamp") { val ts1 = new Timestamp(12) val ts2 = new Timestamp(123) @@ -374,7 +434,7 @@ class ExpressionEvaluationSuite extends FunSuite { millis.toFloat / 1000) checkEvaluation(Cast(Cast(millis.toDouble / 1000, TimestampType), DoubleType), millis.toDouble / 1000) - checkEvaluation(Cast(Literal(BigDecimal(1)) cast TimestampType, DecimalType), 1) + checkEvaluation(Cast(Literal(Decimal(1)) cast TimestampType, DecimalType.Unlimited), Decimal(1)) // A test for higher precision than millis checkEvaluation(Cast(Cast(0.00000001, TimestampType), DoubleType), 0.00000001) @@ -673,7 +733,7 @@ class ExpressionEvaluationSuite extends FunSuite { val expectedResults = inputSequence.map(l => math.sqrt(l.toDouble)) val rowSequence = inputSequence.map(l => new GenericRow(Array[Any](l.toDouble))) val d = 'a.double.at(0) - + for ((row, expected) <- rowSequence zip expectedResults) { checkEvaluation(Sqrt(d), expected, row) } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/types/decimal/DecimalSuite.scala 
b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/types/decimal/DecimalSuite.scala new file mode 100644 index 0000000000..5aa263484d --- /dev/null +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/types/decimal/DecimalSuite.scala @@ -0,0 +1,158 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.catalyst.types.decimal + +import org.scalatest.{PrivateMethodTester, FunSuite} + +import scala.language.postfixOps + +class DecimalSuite extends FunSuite with PrivateMethodTester { + test("creating decimals") { + /** Check that a Decimal has the given string representation, precision and scale */ + def checkDecimal(d: Decimal, string: String, precision: Int, scale: Int): Unit = { + assert(d.toString === string) + assert(d.precision === precision) + assert(d.scale === scale) + } + + checkDecimal(new Decimal(), "0", 1, 0) + checkDecimal(Decimal(BigDecimal("10.030")), "10.030", 5, 3) + checkDecimal(Decimal(BigDecimal("10.030"), 4, 1), "10.0", 4, 1) + checkDecimal(Decimal(BigDecimal("-9.95"), 4, 1), "-10.0", 4, 1) + checkDecimal(Decimal("10.030"), "10.030", 5, 3) + checkDecimal(Decimal(10.03), "10.03", 4, 2) + checkDecimal(Decimal(17L), "17", 20, 0) + checkDecimal(Decimal(17), "17", 10, 0) + checkDecimal(Decimal(17L, 2, 1), "1.7", 2, 1) + checkDecimal(Decimal(170L, 4, 2), "1.70", 4, 2) + checkDecimal(Decimal(17L, 24, 1), "1.7", 24, 1) + checkDecimal(Decimal(1e17.toLong, 18, 0), 1e17.toLong.toString, 18, 0) + checkDecimal(Decimal(Long.MaxValue), Long.MaxValue.toString, 20, 0) + checkDecimal(Decimal(Long.MinValue), Long.MinValue.toString, 20, 0) + intercept[IllegalArgumentException](Decimal(170L, 2, 1)) + intercept[IllegalArgumentException](Decimal(170L, 2, 0)) + intercept[IllegalArgumentException](Decimal(BigDecimal("10.030"), 2, 1)) + intercept[IllegalArgumentException](Decimal(BigDecimal("-9.95"), 2, 1)) + intercept[IllegalArgumentException](Decimal(1e17.toLong, 17, 0)) + } + + test("double and long values") { + /** Check that a Decimal converts to the given double and long values */ + def checkValues(d: Decimal, doubleValue: Double, longValue: Long): Unit = { + assert(d.toDouble === doubleValue) + assert(d.toLong === longValue) + } + + checkValues(new Decimal(), 0.0, 0L) + checkValues(Decimal(BigDecimal("10.030")), 10.03, 10L) + checkValues(Decimal(BigDecimal("10.030"), 4, 1), 10.0, 10L) + checkValues(Decimal(BigDecimal("-9.95"), 4, 1), -10.0, -10L) + checkValues(Decimal(10.03), 10.03, 10L) + checkValues(Decimal(17L), 17.0, 17L) + checkValues(Decimal(17), 17.0, 17L) + checkValues(Decimal(17L, 2, 1), 1.7, 1L) + checkValues(Decimal(170L, 4, 2), 1.7, 1L) + checkValues(Decimal(1e16.toLong), 1e16, 1e16.toLong) + checkValues(Decimal(1e17.toLong), 1e17, 1e17.toLong) + checkValues(Decimal(1e18.toLong), 
1e18, 1e18.toLong) + checkValues(Decimal(2e18.toLong), 2e18, 2e18.toLong) + checkValues(Decimal(Long.MaxValue), Long.MaxValue.toDouble, Long.MaxValue) + checkValues(Decimal(Long.MinValue), Long.MinValue.toDouble, Long.MinValue) + checkValues(Decimal(Double.MaxValue), Double.MaxValue, 0L) + checkValues(Decimal(Double.MinValue), Double.MinValue, 0L) + } + + // Accessor for the BigDecimal value of a Decimal, which will be null if it's using Longs + private val decimalVal = PrivateMethod[BigDecimal]('decimalVal) + + /** Check whether a decimal is represented compactly (passing whether we expect it to be) */ + private def checkCompact(d: Decimal, expected: Boolean): Unit = { + val isCompact = d.invokePrivate(decimalVal()).eq(null) + assert(isCompact == expected, s"$d ${if (expected) "was not" else "was"} compact") + } + + test("small decimals represented as unscaled long") { + checkCompact(new Decimal(), true) + checkCompact(Decimal(BigDecimal(10.03)), false) + checkCompact(Decimal(BigDecimal(1e20)), false) + checkCompact(Decimal(17L), true) + checkCompact(Decimal(17), true) + checkCompact(Decimal(17L, 2, 1), true) + checkCompact(Decimal(170L, 4, 2), true) + checkCompact(Decimal(17L, 24, 1), true) + checkCompact(Decimal(1e16.toLong), true) + checkCompact(Decimal(1e17.toLong), true) + checkCompact(Decimal(1e18.toLong - 1), true) + checkCompact(Decimal(- 1e18.toLong + 1), true) + checkCompact(Decimal(1e18.toLong - 1, 30, 10), true) + checkCompact(Decimal(- 1e18.toLong + 1, 30, 10), true) + checkCompact(Decimal(1e18.toLong), false) + checkCompact(Decimal(-1e18.toLong), false) + checkCompact(Decimal(1e18.toLong, 30, 10), false) + checkCompact(Decimal(-1e18.toLong, 30, 10), false) + checkCompact(Decimal(Long.MaxValue), false) + checkCompact(Decimal(Long.MinValue), false) + } + + test("hash code") { + assert(Decimal(123).hashCode() === (123).##) + assert(Decimal(-123).hashCode() === (-123).##) + assert(Decimal(123.312).hashCode() === (123.312).##) + assert(Decimal(Int.MaxValue).hashCode() === Int.MaxValue.##) + assert(Decimal(Long.MaxValue).hashCode() === Long.MaxValue.##) + assert(Decimal(BigDecimal(123)).hashCode() === (123).##) + + val reallyBig = BigDecimal("123182312312313232112312312123.1231231231") + assert(Decimal(reallyBig).hashCode() === reallyBig.hashCode) + } + + test("equals") { + // The decimals on the left are stored compactly, while the ones on the right aren't + checkCompact(Decimal(123), true) + checkCompact(Decimal(BigDecimal(123)), false) + checkCompact(Decimal("123"), false) + assert(Decimal(123) === Decimal(BigDecimal(123))) + assert(Decimal(123) === Decimal(BigDecimal("123.00"))) + assert(Decimal(-123) === Decimal(BigDecimal(-123))) + assert(Decimal(-123) === Decimal(BigDecimal("-123.00"))) + } + + test("isZero") { + assert(Decimal(0).isZero) + assert(Decimal(0, 4, 2).isZero) + assert(Decimal("0").isZero) + assert(Decimal("0.000").isZero) + assert(!Decimal(1).isZero) + assert(!Decimal(1, 4, 2).isZero) + assert(!Decimal("1").isZero) + assert(!Decimal("0.001").isZero) + } + + test("arithmetic") { + assert(Decimal(100) + Decimal(-100) === Decimal(0)) + assert(Decimal(100) + Decimal(-100) === Decimal(0)) + assert(Decimal(100) * Decimal(-100) === Decimal(-10000)) + assert(Decimal(1e13) * Decimal(1e13) === Decimal(1e26)) + assert(Decimal(100) / Decimal(-100) === Decimal(-1)) + assert(Decimal(100) / Decimal(0) === null) + assert(Decimal(100) % Decimal(-100) === Decimal(0)) + assert(Decimal(100) % Decimal(3) === Decimal(1)) + assert(Decimal(-100) % Decimal(3) === Decimal(-1)) + 
assert(Decimal(100) % Decimal(0) === null) + } +} diff --git a/sql/core/src/main/java/org/apache/spark/sql/api/java/DataType.java b/sql/core/src/main/java/org/apache/spark/sql/api/java/DataType.java index 0c85cdc0aa..c38354039d 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/api/java/DataType.java +++ b/sql/core/src/main/java/org/apache/spark/sql/api/java/DataType.java @@ -52,11 +52,6 @@ public abstract class DataType { */ public static final TimestampType TimestampType = new TimestampType(); - /** - * Gets the DecimalType object. - */ - public static final DecimalType DecimalType = new DecimalType(); - /** * Gets the DoubleType object. */ diff --git a/sql/core/src/main/java/org/apache/spark/sql/api/java/DecimalType.java b/sql/core/src/main/java/org/apache/spark/sql/api/java/DecimalType.java index bc54c078d7..60752451ec 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/api/java/DecimalType.java +++ b/sql/core/src/main/java/org/apache/spark/sql/api/java/DecimalType.java @@ -19,9 +19,61 @@ package org.apache.spark.sql.api.java; /** * The data type representing java.math.BigDecimal values. - * - * {@code DecimalType} is represented by the singleton object {@link DataType#DecimalType}. */ public class DecimalType extends DataType { - protected DecimalType() {} + private boolean hasPrecisionInfo; + private int precision; + private int scale; + + public DecimalType(int precision, int scale) { + this.hasPrecisionInfo = true; + this.precision = precision; + this.scale = scale; + } + + public DecimalType() { + this.hasPrecisionInfo = false; + this.precision = -1; + this.scale = -1; + } + + public boolean isUnlimited() { + return !hasPrecisionInfo; + } + + public boolean isFixed() { + return hasPrecisionInfo; + } + + /** Return the precision, or -1 if no precision is set */ + public int getPrecision() { + return precision; + } + + /** Return the scale, or -1 if no precision is set */ + public int getScale() { + return scale; + } + + @Override + public boolean equals(Object o) { + if (this == o) return true; + if (o == null || getClass() != o.getClass()) return false; + + DecimalType that = (DecimalType) o; + + if (hasPrecisionInfo != that.hasPrecisionInfo) return false; + if (precision != that.precision) return false; + if (scale != that.scale) return false; + + return true; + } + + @Override + public int hashCode() { + int result = (hasPrecisionInfo ? 
1 : 0); + result = 31 * result + precision; + result = 31 * result + scale; + return result; + } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/SchemaRDD.scala b/sql/core/src/main/scala/org/apache/spark/sql/SchemaRDD.scala index 8b96df1096..018a18c4ac 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/SchemaRDD.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/SchemaRDD.scala @@ -19,6 +19,7 @@ package org.apache.spark.sql import java.util.{Map => JMap, List => JList} +import org.apache.spark.sql.catalyst.ScalaReflection import org.apache.spark.storage.StorageLevel import scala.collection.JavaConversions._ @@ -113,7 +114,7 @@ class SchemaRDD( // ========================================================================================= override def compute(split: Partition, context: TaskContext): Iterator[Row] = - firstParent[Row].compute(split, context).map(_.copy()) + firstParent[Row].compute(split, context).map(ScalaReflection.convertRowToScala) override def getPartitions: Array[Partition] = firstParent[Row].partitions diff --git a/sql/core/src/main/scala/org/apache/spark/sql/api/java/JavaSQLContext.scala b/sql/core/src/main/scala/org/apache/spark/sql/api/java/JavaSQLContext.scala index 082ae03eef..876b1c6ede 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/api/java/JavaSQLContext.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/api/java/JavaSQLContext.scala @@ -230,7 +230,7 @@ class JavaSQLContext(val sqlContext: SQLContext) extends UDFRegistration { case c: Class[_] if c == classOf[java.lang.Boolean] => (org.apache.spark.sql.BooleanType, true) case c: Class[_] if c == classOf[java.math.BigDecimal] => - (org.apache.spark.sql.DecimalType, true) + (org.apache.spark.sql.DecimalType(), true) case c: Class[_] if c == classOf[java.sql.Date] => (org.apache.spark.sql.DateType, true) case c: Class[_] if c == classOf[java.sql.Timestamp] => diff --git a/sql/core/src/main/scala/org/apache/spark/sql/api/java/Row.scala b/sql/core/src/main/scala/org/apache/spark/sql/api/java/Row.scala index df01411f60..401798e317 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/api/java/Row.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/api/java/Row.scala @@ -17,6 +17,8 @@ package org.apache.spark.sql.api.java +import org.apache.spark.sql.catalyst.types.decimal.Decimal + import scala.annotation.varargs import scala.collection.convert.Wrappers.{JListWrapper, JMapWrapper} import scala.collection.JavaConversions @@ -106,6 +108,8 @@ class Row(private[spark] val row: ScalaRow) extends Serializable { } override def hashCode(): Int = row.hashCode() + + override def toString: String = row.toString } object Row { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/GeneratedAggregate.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/GeneratedAggregate.scala index b3edd5020f..087b0ecbb2 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/GeneratedAggregate.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/GeneratedAggregate.scala @@ -70,16 +70,29 @@ case class GeneratedAggregate( val computeFunctions = aggregatesToCompute.map { case c @ Count(expr) => + // If we're evaluating UnscaledValue(x), we can do Count on x directly, since its + // UnscaledValue will be null if and only if x is null; helps with Average on decimals + val toCount = expr match { + case UnscaledValue(e) => e + case _ => expr + } val currentCount = AttributeReference("currentCount", LongType, nullable = false)() val initialValue = 
Literal(0L) - val updateFunction = If(IsNotNull(expr), Add(currentCount, Literal(1L)), currentCount) + val updateFunction = If(IsNotNull(toCount), Add(currentCount, Literal(1L)), currentCount) val result = currentCount AggregateEvaluation(currentCount :: Nil, initialValue :: Nil, updateFunction :: Nil, result) case Sum(expr) => - val currentSum = AttributeReference("currentSum", expr.dataType, nullable = false)() - val initialValue = Cast(Literal(0L), expr.dataType) + val resultType = expr.dataType match { + case DecimalType.Fixed(precision, scale) => + DecimalType(precision + 10, scale) + case _ => + expr.dataType + } + + val currentSum = AttributeReference("currentSum", resultType, nullable = false)() + val initialValue = Cast(Literal(0L), resultType) // Coalasce avoids double calculation... // but really, common sub expression elimination would be better.... @@ -93,10 +106,26 @@ case class GeneratedAggregate( val currentSum = AttributeReference("currentSum", expr.dataType, nullable = false)() val initialCount = Literal(0L) val initialSum = Cast(Literal(0L), expr.dataType) - val updateCount = If(IsNotNull(expr), Add(currentCount, Literal(1L)), currentCount) + + // If we're evaluating UnscaledValue(x), we can do Count on x directly, since its + // UnscaledValue will be null if and only if x is null; helps with Average on decimals + val toCount = expr match { + case UnscaledValue(e) => e + case _ => expr + } + + val updateCount = If(IsNotNull(toCount), Add(currentCount, Literal(1L)), currentCount) val updateSum = Coalesce(Add(expr, currentSum) :: currentSum :: Nil) - val result = Divide(Cast(currentSum, DoubleType), Cast(currentCount, DoubleType)) + val resultType = expr.dataType match { + case DecimalType.Fixed(precision, scale) => + DecimalType(precision + 4, scale + 4) + case DecimalType.Unlimited => + DecimalType.Unlimited + case _ => + DoubleType + } + val result = Divide(Cast(currentSum, resultType), Cast(currentCount, resultType)) AggregateEvaluation( currentCount :: currentSum :: Nil, @@ -142,7 +171,7 @@ case class GeneratedAggregate( val computationSchema = computeFunctions.flatMap(_.schema) - val resultMap: Map[TreeNodeRef, Expression] = + val resultMap: Map[TreeNodeRef, Expression] = aggregatesToCompute.zip(computeFunctions).map { case (agg, func) => new TreeNodeRef(agg) -> func.result }.toMap diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlan.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlan.scala index b1a7948b66..aafcce0572 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlan.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlan.scala @@ -23,7 +23,7 @@ import org.apache.spark.rdd.RDD import org.apache.spark.sql.SQLContext -import org.apache.spark.sql.catalyst.trees +import org.apache.spark.sql.catalyst.{ScalaReflection, trees} import org.apache.spark.sql.catalyst.analysis.MultiInstanceRelation import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.expressions.codegen._ @@ -82,7 +82,7 @@ abstract class SparkPlan extends QueryPlan[SparkPlan] with Logging with Serializ /** * Runs this query returning the result as an array. 
*/ - def executeCollect(): Array[Row] = execute().map(_.copy()).collect() + def executeCollect(): Array[Row] = execute().map(ScalaReflection.convertRowToScala).collect() protected def newProjection( expressions: Seq[Expression], inputSchema: Seq[Attribute]): Projection = { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkSqlSerializer.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkSqlSerializer.scala index 077e6ebc5f..84d96e612f 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkSqlSerializer.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkSqlSerializer.scala @@ -29,6 +29,7 @@ import com.twitter.chill.{AllScalaRegistrar, ResourcePool} import org.apache.spark.{SparkEnv, SparkConf} import org.apache.spark.serializer.{SerializerInstance, KryoSerializer} import org.apache.spark.sql.catalyst.expressions.GenericRow +import org.apache.spark.sql.catalyst.types.decimal.Decimal import org.apache.spark.util.collection.OpenHashSet import org.apache.spark.util.MutablePair import org.apache.spark.util.Utils @@ -51,6 +52,7 @@ private[sql] class SparkSqlSerializer(conf: SparkConf) extends KryoSerializer(co kryo.register(classOf[LongHashSet], new LongHashSetSerializer) kryo.register(classOf[org.apache.spark.util.collection.OpenHashSet[_]], new OpenHashSetSerializer) + kryo.register(classOf[Decimal]) kryo.setReferences(false) kryo.setClassLoader(Utils.getSparkClassLoader) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala index 977f3c9f32..e6cd1a9d04 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala @@ -143,7 +143,7 @@ case class Limit(limit: Int, child: SparkPlan) partsScanned += numPartsToTry } - buf.toArray + buf.toArray.map(ScalaReflection.convertRowToScala) } override def execute() = { @@ -176,10 +176,11 @@ case class TakeOrdered(limit: Int, sortOrder: Seq[SortOrder], child: SparkPlan) override def output = child.output override def outputPartitioning = SinglePartition - val ordering = new RowOrdering(sortOrder, child.output) + val ord = new RowOrdering(sortOrder, child.output) // TODO: Is this copying for no reason? - override def executeCollect() = child.execute().map(_.copy()).takeOrdered(limit)(ordering) + override def executeCollect() = + child.execute().map(_.copy()).takeOrdered(limit)(ord).map(ScalaReflection.convertRowToScala) // TODO: Terminal split should be implemented differently from non-terminal split. // TODO: Pick num splits based on |limit|. 
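The executeCollect, Limit, and TakeOrdered changes above all route collected rows through ScalaReflection.convertRowToScala so that internal representations, in particular the new Decimal class, reach user code as ordinary Scala values. Below is a minimal sketch of the per-field mapping such a helper performs, assuming it simply rebuilds each row; the actual implementation may also handle nested structures. As the BroadcastHashJoin change that follows notes, the build side of that join deliberately skips this conversion because its rows never leave the physical plan.

  import org.apache.spark.sql.catalyst.expressions.{GenericRow, Row}
  import org.apache.spark.sql.catalyst.types.decimal.Decimal

  // Sketch only: convert internal field values to their public Scala counterparts.
  def convertToScala(a: Any): Any = a match {
    case d: Decimal => d.toBigDecimal   // users see scala.math.BigDecimal, not Decimal
    case other      => other
  }

  def convertRowToScalaSketch(r: Row): Row = new GenericRow(r.toArray.map(convertToScala))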
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastHashJoin.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastHashJoin.scala index 8fd35880ee..5cf2a785ad 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastHashJoin.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastHashJoin.scala @@ -49,7 +49,8 @@ case class BroadcastHashJoin( @transient private val broadcastFuture = future { - val input: Array[Row] = buildPlan.executeCollect() + // Note that we use .execute().collect() because we don't want to convert data to Scala types + val input: Array[Row] = buildPlan.execute().map(_.copy()).collect() val hashed = HashedRelation(input.iterator, buildSideKeyGenerator, input.length) sparkContext.broadcast(hashed) } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/pythonUdfs.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/pythonUdfs.scala index a1961bba18..997669051e 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/pythonUdfs.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/pythonUdfs.scala @@ -19,6 +19,8 @@ package org.apache.spark.sql.execution import java.util.{List => JList, Map => JMap} +import org.apache.spark.sql.catalyst.types.decimal.Decimal + import scala.collection.JavaConversions._ import scala.collection.JavaConverters._ @@ -116,7 +118,7 @@ object EvaluatePython { def toJava(obj: Any, dataType: DataType): Any = (obj, dataType) match { case (null, _) => null - case (row: Row, struct: StructType) => + case (row: Seq[Any], struct: StructType) => val fields = struct.fields.map(field => field.dataType) row.zip(fields).map { case (obj, dataType) => toJava(obj, dataType) @@ -133,6 +135,8 @@ object EvaluatePython { case (k, v) => (k, toJava(v, mt.valueType)) // key should be primitive type }.asJava + case (dec: BigDecimal, dt: DecimalType) => dec.underlying() // Pyrolite can handle BigDecimal + // Pyrolite can handle Timestamp case (other, _) => other } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/json/JsonRDD.scala b/sql/core/src/main/scala/org/apache/spark/sql/json/JsonRDD.scala index eabe312f92..5bb6f6c85d 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/json/JsonRDD.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/json/JsonRDD.scala @@ -17,6 +17,8 @@ package org.apache.spark.sql.json +import org.apache.spark.sql.catalyst.types.decimal.Decimal + import scala.collection.Map import scala.collection.convert.Wrappers.{JMapWrapper, JListWrapper} import scala.math.BigDecimal @@ -175,9 +177,9 @@ private[sql] object JsonRDD extends Logging { ScalaReflection.typeOfObject orElse { // Since we do not have a data type backed by BigInteger, // when we see a Java BigInteger, we use DecimalType. - case value: java.math.BigInteger => DecimalType + case value: java.math.BigInteger => DecimalType.Unlimited // DecimalType's JVMType is scala BigDecimal. - case value: java.math.BigDecimal => DecimalType + case value: java.math.BigDecimal => DecimalType.Unlimited // Unexpected data type. 
case _ => StringType } @@ -319,13 +321,13 @@ private[sql] object JsonRDD extends Logging { } } - private def toDecimal(value: Any): BigDecimal = { + private def toDecimal(value: Any): Decimal = { value match { - case value: java.lang.Integer => BigDecimal(value) - case value: java.lang.Long => BigDecimal(value) - case value: java.math.BigInteger => BigDecimal(value) - case value: java.lang.Double => BigDecimal(value) - case value: java.math.BigDecimal => BigDecimal(value) + case value: java.lang.Integer => Decimal(value) + case value: java.lang.Long => Decimal(value) + case value: java.math.BigInteger => Decimal(BigDecimal(value)) + case value: java.lang.Double => Decimal(value) + case value: java.math.BigDecimal => Decimal(BigDecimal(value)) } } @@ -391,7 +393,7 @@ private[sql] object JsonRDD extends Logging { case IntegerType => value.asInstanceOf[IntegerType.JvmType] case LongType => toLong(value) case DoubleType => toDouble(value) - case DecimalType => toDecimal(value) + case DecimalType() => toDecimal(value) case BooleanType => value.asInstanceOf[BooleanType.JvmType] case NullType => null diff --git a/sql/core/src/main/scala/org/apache/spark/sql/package.scala b/sql/core/src/main/scala/org/apache/spark/sql/package.scala index f0e57e2a74..05926a24c5 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/package.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/package.scala @@ -183,6 +183,20 @@ package object sql { * * The data type representing `scala.math.BigDecimal` values. * + * TODO(matei): explain precision and scale + * + * @group dataType + */ + @DeveloperApi + type DecimalType = catalyst.types.DecimalType + + /** + * :: DeveloperApi :: + * + * The data type representing `scala.math.BigDecimal` values. + * + * TODO(matei): explain precision and scale + * * @group dataType */ @DeveloperApi diff --git a/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetConverter.scala b/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetConverter.scala index 2fc7e1cf23..08feced61a 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetConverter.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetConverter.scala @@ -17,6 +17,8 @@ package org.apache.spark.sql.parquet +import org.apache.spark.sql.catalyst.types.decimal.Decimal + import scala.collection.mutable.{Buffer, ArrayBuffer, HashMap} import parquet.io.api.{PrimitiveConverter, GroupConverter, Binary, Converter} @@ -117,6 +119,12 @@ private[sql] object CatalystConverter { parent.updateByte(fieldIndex, value.asInstanceOf[ByteType.JvmType]) } } + case d: DecimalType => { + new CatalystPrimitiveConverter(parent, fieldIndex) { + override def addBinary(value: Binary): Unit = + parent.updateDecimal(fieldIndex, value, d) + } + } // All other primitive types use the default converter case ctype: PrimitiveType => { // note: need the type tag here! 
new CatalystPrimitiveConverter(parent, fieldIndex) @@ -191,6 +199,10 @@ private[parquet] abstract class CatalystConverter extends GroupConverter { protected[parquet] def updateString(fieldIndex: Int, value: Binary): Unit = updateField(fieldIndex, value.toStringUsingUTF8) + protected[parquet] def updateDecimal(fieldIndex: Int, value: Binary, ctype: DecimalType): Unit = { + updateField(fieldIndex, readDecimal(new Decimal(), value, ctype)) + } + protected[parquet] def isRootConverter: Boolean = parent == null protected[parquet] def clearBuffer(): Unit @@ -201,6 +213,27 @@ private[parquet] abstract class CatalystConverter extends GroupConverter { * @return */ def getCurrentRecord: Row = throw new UnsupportedOperationException + + /** + * Read a decimal value from a Parquet Binary into "dest". Only supports decimals that fit in + * a long (i.e. precision <= 18) + */ + protected[parquet] def readDecimal(dest: Decimal, value: Binary, ctype: DecimalType): Unit = { + val precision = ctype.precisionInfo.get.precision + val scale = ctype.precisionInfo.get.scale + val bytes = value.getBytes + require(bytes.length <= 16, "Decimal field too large to read") + var unscaled = 0L + var i = 0 + while (i < bytes.length) { + unscaled = (unscaled << 8) | (bytes(i) & 0xFF) + i += 1 + } + // Make sure unscaled has the right sign, by sign-extending the first bit + val numBits = 8 * bytes.length + unscaled = (unscaled << (64 - numBits)) >> (64 - numBits) + dest.set(unscaled, precision, scale) + } } /** @@ -352,6 +385,16 @@ private[parquet] class CatalystPrimitiveRowConverter( override protected[parquet] def updateString(fieldIndex: Int, value: Binary): Unit = current.setString(fieldIndex, value.toStringUsingUTF8) + + override protected[parquet] def updateDecimal( + fieldIndex: Int, value: Binary, ctype: DecimalType): Unit = { + var decimal = current(fieldIndex).asInstanceOf[Decimal] + if (decimal == null) { + decimal = new Decimal + current(fieldIndex) = decimal + } + readDecimal(decimal, value, ctype) + } } /** diff --git a/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetTableSupport.scala b/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetTableSupport.scala index bdf02401b2..2a5f23b24e 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetTableSupport.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetTableSupport.scala @@ -20,6 +20,7 @@ package org.apache.spark.sql.parquet import java.util.{HashMap => JHashMap} import org.apache.hadoop.conf.Configuration +import org.apache.spark.sql.catalyst.types.decimal.Decimal import parquet.column.ParquetProperties import parquet.hadoop.ParquetOutputFormat import parquet.hadoop.api.ReadSupport.ReadContext @@ -204,6 +205,11 @@ private[parquet] class RowWriteSupport extends WriteSupport[Row] with Logging { case DoubleType => writer.addDouble(value.asInstanceOf[Double]) case FloatType => writer.addFloat(value.asInstanceOf[Float]) case BooleanType => writer.addBoolean(value.asInstanceOf[Boolean]) + case d: DecimalType => + if (d.precisionInfo == None || d.precisionInfo.get.precision > 18) { + sys.error(s"Unsupported datatype $d, cannot write to consumer") + } + writeDecimal(value.asInstanceOf[Decimal], d.precisionInfo.get.precision) case _ => sys.error(s"Do not know how to writer $schema to consumer") } } @@ -283,6 +289,23 @@ private[parquet] class RowWriteSupport extends WriteSupport[Row] with Logging { } writer.endGroup() } + + // Scratch array used to write decimals as fixed-length binary + private val 
scratchBytes = new Array[Byte](8) + + private[parquet] def writeDecimal(decimal: Decimal, precision: Int): Unit = { + val numBytes = ParquetTypesConverter.BYTES_FOR_PRECISION(precision) + val unscaledLong = decimal.toUnscaledLong + var i = 0 + var shift = 8 * (numBytes - 1) + while (i < numBytes) { + scratchBytes(i) = (unscaledLong >> shift).toByte + i += 1 + shift -= 8 + } + writer.addBinary(Binary.fromByteArray(scratchBytes, 0, numBytes)) + } + } // Optimized for non-nested rows @@ -326,6 +349,11 @@ private[parquet] class MutableRowWriteSupport extends RowWriteSupport { case DoubleType => writer.addDouble(record.getDouble(index)) case FloatType => writer.addFloat(record.getFloat(index)) case BooleanType => writer.addBoolean(record.getBoolean(index)) + case d: DecimalType => + if (d.precisionInfo == None || d.precisionInfo.get.precision > 18) { + sys.error(s"Unsupported datatype $d, cannot write to consumer") + } + writeDecimal(record(index).asInstanceOf[Decimal], d.precisionInfo.get.precision) case _ => sys.error(s"Unsupported datatype $ctype, cannot write to consumer") } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetTypes.scala b/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetTypes.scala index e6389cf77a..e5077de8dd 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetTypes.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetTypes.scala @@ -29,8 +29,8 @@ import org.apache.hadoop.mapreduce.lib.output.FileOutputCommitter import parquet.hadoop.{ParquetFileReader, Footer, ParquetFileWriter} import parquet.hadoop.metadata.{ParquetMetadata, FileMetaData} import parquet.hadoop.util.ContextUtil -import parquet.schema.{Type => ParquetType, PrimitiveType => ParquetPrimitiveType, MessageType} -import parquet.schema.{GroupType => ParquetGroupType, OriginalType => ParquetOriginalType, ConversionPatterns} +import parquet.schema.{Type => ParquetType, Types => ParquetTypes, PrimitiveType => ParquetPrimitiveType, MessageType} +import parquet.schema.{GroupType => ParquetGroupType, OriginalType => ParquetOriginalType, ConversionPatterns, DecimalMetadata} import parquet.schema.PrimitiveType.{PrimitiveTypeName => ParquetPrimitiveTypeName} import parquet.schema.Type.Repetition @@ -41,17 +41,25 @@ import org.apache.spark.sql.catalyst.types._ // Implicits import scala.collection.JavaConversions._ +/** A class representing Parquet info fields we care about, for passing back to Parquet */ +private[parquet] case class ParquetTypeInfo( + primitiveType: ParquetPrimitiveTypeName, + originalType: Option[ParquetOriginalType] = None, + decimalMetadata: Option[DecimalMetadata] = None, + length: Option[Int] = None) + private[parquet] object ParquetTypesConverter extends Logging { def isPrimitiveType(ctype: DataType): Boolean = classOf[PrimitiveType] isAssignableFrom ctype.getClass def toPrimitiveDataType( parquetType: ParquetPrimitiveType, - binayAsString: Boolean): DataType = + binaryAsString: Boolean): DataType = { + val originalType = parquetType.getOriginalType + val decimalInfo = parquetType.getDecimalMetadata parquetType.getPrimitiveTypeName match { case ParquetPrimitiveTypeName.BINARY - if (parquetType.getOriginalType == ParquetOriginalType.UTF8 || - binayAsString) => StringType + if (originalType == ParquetOriginalType.UTF8 || binaryAsString) => StringType case ParquetPrimitiveTypeName.BINARY => BinaryType case ParquetPrimitiveTypeName.BOOLEAN => BooleanType case ParquetPrimitiveTypeName.DOUBLE => DoubleType @@ -61,9 +69,14 @@ 
private[parquet] object ParquetTypesConverter extends Logging { case ParquetPrimitiveTypeName.INT96 => // TODO: add BigInteger type? TODO(andre) use DecimalType instead???? sys.error("Potential loss of precision: cannot convert INT96") + case ParquetPrimitiveTypeName.FIXED_LEN_BYTE_ARRAY + if (originalType == ParquetOriginalType.DECIMAL && decimalInfo.getPrecision <= 18) => + // TODO: for now, our reader only supports decimals that fit in a Long + DecimalType(decimalInfo.getPrecision, decimalInfo.getScale) case _ => sys.error( s"Unsupported parquet datatype $parquetType") } + } /** * Converts a given Parquet `Type` into the corresponding @@ -183,23 +196,40 @@ private[parquet] object ParquetTypesConverter extends Logging { * is not primitive. * * @param ctype The type to convert - * @return The name of the corresponding Parquet primitive type + * @return The name of the corresponding Parquet type properties */ - def fromPrimitiveDataType(ctype: DataType): - Option[(ParquetPrimitiveTypeName, Option[ParquetOriginalType])] = ctype match { - case StringType => Some(ParquetPrimitiveTypeName.BINARY, Some(ParquetOriginalType.UTF8)) - case BinaryType => Some(ParquetPrimitiveTypeName.BINARY, None) - case BooleanType => Some(ParquetPrimitiveTypeName.BOOLEAN, None) - case DoubleType => Some(ParquetPrimitiveTypeName.DOUBLE, None) - case FloatType => Some(ParquetPrimitiveTypeName.FLOAT, None) - case IntegerType => Some(ParquetPrimitiveTypeName.INT32, None) + def fromPrimitiveDataType(ctype: DataType): Option[ParquetTypeInfo] = ctype match { + case StringType => Some(ParquetTypeInfo( + ParquetPrimitiveTypeName.BINARY, Some(ParquetOriginalType.UTF8))) + case BinaryType => Some(ParquetTypeInfo(ParquetPrimitiveTypeName.BINARY)) + case BooleanType => Some(ParquetTypeInfo(ParquetPrimitiveTypeName.BOOLEAN)) + case DoubleType => Some(ParquetTypeInfo(ParquetPrimitiveTypeName.DOUBLE)) + case FloatType => Some(ParquetTypeInfo(ParquetPrimitiveTypeName.FLOAT)) + case IntegerType => Some(ParquetTypeInfo(ParquetPrimitiveTypeName.INT32)) // There is no type for Byte or Short so we promote them to INT32. - case ShortType => Some(ParquetPrimitiveTypeName.INT32, None) - case ByteType => Some(ParquetPrimitiveTypeName.INT32, None) - case LongType => Some(ParquetPrimitiveTypeName.INT64, None) + case ShortType => Some(ParquetTypeInfo(ParquetPrimitiveTypeName.INT32)) + case ByteType => Some(ParquetTypeInfo(ParquetPrimitiveTypeName.INT32)) + case LongType => Some(ParquetTypeInfo(ParquetPrimitiveTypeName.INT64)) + case DecimalType.Fixed(precision, scale) if precision <= 18 => + // TODO: for now, our writer only supports decimals that fit in a Long + Some(ParquetTypeInfo(ParquetPrimitiveTypeName.FIXED_LEN_BYTE_ARRAY, + Some(ParquetOriginalType.DECIMAL), + Some(new DecimalMetadata(precision, scale)), + Some(BYTES_FOR_PRECISION(precision)))) case _ => None } + /** + * Compute the FIXED_LEN_BYTE_ARRAY length needed to represent a given DECIMAL precision. + */ + private[parquet] val BYTES_FOR_PRECISION = Array.tabulate[Int](38) { precision => + var length = 1 + while (math.pow(2.0, 8 * length - 1) < math.pow(10.0, precision)) { + length += 1 + } + length + } + /** * Converts a given Catalyst [[org.apache.spark.sql.catalyst.types.DataType]] into * the corresponding Parquet `Type`. 
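For reference, BYTES_FOR_PRECISION(p) above picks the smallest length n such that a signed n-byte integer, whose magnitude is bounded by 2^(8n - 1), can hold any p-digit unscaled value; precision 9 therefore maps to 4 bytes and precision 18 to 8. The helpers below are a self-contained sketch (the names are ours, not Spark's) of the big-endian round trip that writeDecimal and readDecimal perform earlier in this patch, including the sign extension on read:

  def writeUnscaled(unscaled: Long, numBytes: Int): Array[Byte] =
    Array.tabulate(numBytes)(i => (unscaled >> (8 * (numBytes - 1 - i))).toByte)

  def readUnscaled(bytes: Array[Byte]): Long = {
    var unscaled = 0L
    for (b <- bytes) unscaled = (unscaled << 8) | (b & 0xFF)
    val numBits = 8 * bytes.length
    (unscaled << (64 - numBits)) >> (64 - numBits)  // sign-extend the top bit
  }

  // Example: Decimal(-9.95) has unscaled value -995 at precision 3, scale 2;
  // BYTES_FOR_PRECISION(3) == 2 and readUnscaled(writeUnscaled(-995, 2)) == -995L.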
@@ -247,10 +277,17 @@ private[parquet] object ParquetTypesConverter extends Logging { } else { if (nullable) Repetition.OPTIONAL else Repetition.REQUIRED } - val primitiveType = fromPrimitiveDataType(ctype) - primitiveType.map { - case (primitiveType, originalType) => - new ParquetPrimitiveType(repetition, primitiveType, name, originalType.orNull) + val typeInfo = fromPrimitiveDataType(ctype) + typeInfo.map { + case ParquetTypeInfo(primitiveType, originalType, decimalMetadata, length) => + val builder = ParquetTypes.primitive(primitiveType, repetition).as(originalType.orNull) + for (len <- length) { + builder.length(len) + } + for (metadata <- decimalMetadata) { + builder.precision(metadata.getPrecision).scale(metadata.getScale) + } + builder.named(name) }.getOrElse { ctype match { case ArrayType(elementType, false) => { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/types/util/DataTypeConversions.scala b/sql/core/src/main/scala/org/apache/spark/sql/types/util/DataTypeConversions.scala index 142598c904..7564bf3923 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/types/util/DataTypeConversions.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/types/util/DataTypeConversions.scala @@ -19,6 +19,8 @@ package org.apache.spark.sql.types.util import org.apache.spark.sql._ import org.apache.spark.sql.api.java.{DataType => JDataType, StructField => JStructField, MetadataBuilder => JMetaDataBuilder} +import org.apache.spark.sql.api.java.{DecimalType => JDecimalType} +import org.apache.spark.sql.catalyst.types.decimal.Decimal import scala.collection.JavaConverters._ @@ -44,7 +46,8 @@ protected[sql] object DataTypeConversions { case BooleanType => JDataType.BooleanType case DateType => JDataType.DateType case TimestampType => JDataType.TimestampType - case DecimalType => JDataType.DecimalType + case DecimalType.Fixed(precision, scale) => new JDecimalType(precision, scale) + case DecimalType.Unlimited => new JDecimalType() case DoubleType => JDataType.DoubleType case FloatType => JDataType.FloatType case ByteType => JDataType.ByteType @@ -88,7 +91,11 @@ protected[sql] object DataTypeConversions { case timestampType: org.apache.spark.sql.api.java.TimestampType => TimestampType case decimalType: org.apache.spark.sql.api.java.DecimalType => - DecimalType + if (decimalType.isFixed) { + DecimalType(decimalType.getPrecision, decimalType.getScale) + } else { + DecimalType.Unlimited + } case doubleType: org.apache.spark.sql.api.java.DoubleType => DoubleType case floatType: org.apache.spark.sql.api.java.FloatType => @@ -115,7 +122,7 @@ protected[sql] object DataTypeConversions { /** Converts Java objects to catalyst rows / types */ def convertJavaToCatalyst(a: Any): Any = a match { - case d: java.math.BigDecimal => BigDecimal(d) + case d: java.math.BigDecimal => Decimal(BigDecimal(d)) case other => other } diff --git a/sql/core/src/test/java/org/apache/spark/sql/api/java/JavaApplySchemaSuite.java b/sql/core/src/test/java/org/apache/spark/sql/api/java/JavaApplySchemaSuite.java index 9435a88009..a04b8060cd 100644 --- a/sql/core/src/test/java/org/apache/spark/sql/api/java/JavaApplySchemaSuite.java +++ b/sql/core/src/test/java/org/apache/spark/sql/api/java/JavaApplySchemaSuite.java @@ -118,7 +118,7 @@ public class JavaApplySchemaSuite implements Serializable { "\"bigInteger\":92233720368547758069, \"double\":1.7976931348623157E305, " + "\"boolean\":false, \"null\":null}")); List fields = new ArrayList(7); - fields.add(DataType.createStructField("bigInteger", DataType.DecimalType, true)); + 
fields.add(DataType.createStructField("bigInteger", new DecimalType(), true)); fields.add(DataType.createStructField("boolean", DataType.BooleanType, true)); fields.add(DataType.createStructField("double", DataType.DoubleType, true)); fields.add(DataType.createStructField("integer", DataType.IntegerType, true)); diff --git a/sql/core/src/test/java/org/apache/spark/sql/api/java/JavaSideDataTypeConversionSuite.java b/sql/core/src/test/java/org/apache/spark/sql/api/java/JavaSideDataTypeConversionSuite.java index d04396a5f8..8396a29c61 100644 --- a/sql/core/src/test/java/org/apache/spark/sql/api/java/JavaSideDataTypeConversionSuite.java +++ b/sql/core/src/test/java/org/apache/spark/sql/api/java/JavaSideDataTypeConversionSuite.java @@ -41,7 +41,8 @@ public class JavaSideDataTypeConversionSuite { checkDataType(DataType.BooleanType); checkDataType(DataType.DateType); checkDataType(DataType.TimestampType); - checkDataType(DataType.DecimalType); + checkDataType(new DecimalType()); + checkDataType(new DecimalType(10, 4)); checkDataType(DataType.DoubleType); checkDataType(DataType.FloatType); checkDataType(DataType.ByteType); @@ -59,7 +60,7 @@ public class JavaSideDataTypeConversionSuite { // Simple StructType. List simpleFields = new ArrayList(); - simpleFields.add(DataType.createStructField("a", DataType.DecimalType, false)); + simpleFields.add(DataType.createStructField("a", new DecimalType(), false)); simpleFields.add(DataType.createStructField("b", DataType.BooleanType, true)); simpleFields.add(DataType.createStructField("c", DataType.LongType, true)); simpleFields.add(DataType.createStructField("d", DataType.BinaryType, false)); @@ -128,7 +129,7 @@ public class JavaSideDataTypeConversionSuite { // StructType try { List simpleFields = new ArrayList(); - simpleFields.add(DataType.createStructField("a", DataType.DecimalType, false)); + simpleFields.add(DataType.createStructField("a", new DecimalType(), false)); simpleFields.add(DataType.createStructField("b", DataType.BooleanType, true)); simpleFields.add(DataType.createStructField("c", DataType.LongType, true)); simpleFields.add(null); @@ -138,7 +139,7 @@ public class JavaSideDataTypeConversionSuite { } try { List simpleFields = new ArrayList(); - simpleFields.add(DataType.createStructField("a", DataType.DecimalType, false)); + simpleFields.add(DataType.createStructField("a", new DecimalType(), false)); simpleFields.add(DataType.createStructField("a", DataType.BooleanType, true)); simpleFields.add(DataType.createStructField("c", DataType.LongType, true)); DataType.createStructType(simpleFields); diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataTypeSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataTypeSuite.scala index 6c9db639c0..e9740d913c 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataTypeSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataTypeSuite.scala @@ -69,7 +69,7 @@ class DataTypeSuite extends FunSuite { checkDataTypeJsonRepr(LongType) checkDataTypeJsonRepr(FloatType) checkDataTypeJsonRepr(DoubleType) - checkDataTypeJsonRepr(DecimalType) + checkDataTypeJsonRepr(DecimalType.Unlimited) checkDataTypeJsonRepr(TimestampType) checkDataTypeJsonRepr(StringType) checkDataTypeJsonRepr(BinaryType) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/ScalaReflectionRelationSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/ScalaReflectionRelationSuite.scala index bfa9ea4162..cf3a59e545 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/ScalaReflectionRelationSuite.scala +++ 
b/sql/core/src/test/scala/org/apache/spark/sql/ScalaReflectionRelationSuite.scala @@ -19,6 +19,7 @@ package org.apache.spark.sql import java.sql.{Date, Timestamp} +import org.apache.spark.sql.catalyst.types.decimal.Decimal import org.scalatest.FunSuite import org.apache.spark.sql.catalyst.expressions._ @@ -81,7 +82,9 @@ class ScalaReflectionRelationSuite extends FunSuite { val rdd = sparkContext.parallelize(data :: Nil) rdd.registerTempTable("reflectData") - assert(sql("SELECT * FROM reflectData").collect().head === data.productIterator.toSeq) + assert(sql("SELECT * FROM reflectData").collect().head === + Seq("a", 1, 1L, 1.toFloat, 1.toDouble, 1.toShort, 1.toByte, true, + BigDecimal(1), new Date(12345), new Timestamp(12345), Seq(1,2,3))) } test("query case class RDD with nulls") { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/api/java/JavaSQLSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/api/java/JavaSQLSuite.scala index d83f3e23a9..c9012c9e47 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/api/java/JavaSQLSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/api/java/JavaSQLSuite.scala @@ -17,6 +17,8 @@ package org.apache.spark.sql.api.java +import org.apache.spark.sql.catalyst.types.decimal.Decimal + import scala.beans.BeanProperty import org.scalatest.FunSuite diff --git a/sql/core/src/test/scala/org/apache/spark/sql/api/java/ScalaSideDataTypeConversionSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/api/java/ScalaSideDataTypeConversionSuite.scala index e0e0ff9cb3..62fe59dd34 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/api/java/ScalaSideDataTypeConversionSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/api/java/ScalaSideDataTypeConversionSuite.scala @@ -38,7 +38,7 @@ class ScalaSideDataTypeConversionSuite extends FunSuite { checkDataType(org.apache.spark.sql.BooleanType) checkDataType(org.apache.spark.sql.DateType) checkDataType(org.apache.spark.sql.TimestampType) - checkDataType(org.apache.spark.sql.DecimalType) + checkDataType(org.apache.spark.sql.DecimalType.Unlimited) checkDataType(org.apache.spark.sql.DoubleType) checkDataType(org.apache.spark.sql.FloatType) checkDataType(org.apache.spark.sql.ByteType) @@ -58,7 +58,7 @@ class ScalaSideDataTypeConversionSuite extends FunSuite { // Simple StructType. 
val simpleScalaStructType = SStructType( - SStructField("a", org.apache.spark.sql.DecimalType, false) :: + SStructField("a", org.apache.spark.sql.DecimalType.Unlimited, false) :: SStructField("b", org.apache.spark.sql.BooleanType, true) :: SStructField("c", org.apache.spark.sql.LongType, true) :: SStructField("d", org.apache.spark.sql.BinaryType, false) :: Nil) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/json/JsonSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/json/JsonSuite.scala index ce6184f5d8..1cb6c23c58 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/json/JsonSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/json/JsonSuite.scala @@ -18,6 +18,7 @@ package org.apache.spark.sql.json import org.apache.spark.sql.catalyst.types._ +import org.apache.spark.sql.catalyst.types.decimal.Decimal import org.apache.spark.sql.catalyst.util._ import org.apache.spark.sql.json.JsonRDD.{enforceCorrectType, compatibleType} import org.apache.spark.sql.QueryTest @@ -44,19 +45,22 @@ class JsonSuite extends QueryTest { checkTypePromotion(intNumber, enforceCorrectType(intNumber, IntegerType)) checkTypePromotion(intNumber.toLong, enforceCorrectType(intNumber, LongType)) checkTypePromotion(intNumber.toDouble, enforceCorrectType(intNumber, DoubleType)) - checkTypePromotion(BigDecimal(intNumber), enforceCorrectType(intNumber, DecimalType)) + checkTypePromotion( + Decimal(intNumber), enforceCorrectType(intNumber, DecimalType.Unlimited)) val longNumber: Long = 9223372036854775807L checkTypePromotion(longNumber, enforceCorrectType(longNumber, LongType)) checkTypePromotion(longNumber.toDouble, enforceCorrectType(longNumber, DoubleType)) - checkTypePromotion(BigDecimal(longNumber), enforceCorrectType(longNumber, DecimalType)) + checkTypePromotion( + Decimal(longNumber), enforceCorrectType(longNumber, DecimalType.Unlimited)) val doubleNumber: Double = 1.7976931348623157E308d checkTypePromotion(doubleNumber.toDouble, enforceCorrectType(doubleNumber, DoubleType)) - checkTypePromotion(BigDecimal(doubleNumber), enforceCorrectType(doubleNumber, DecimalType)) - + checkTypePromotion( + Decimal(doubleNumber), enforceCorrectType(doubleNumber, DecimalType.Unlimited)) + checkTypePromotion(new Timestamp(intNumber), enforceCorrectType(intNumber, TimestampType)) - checkTypePromotion(new Timestamp(intNumber.toLong), + checkTypePromotion(new Timestamp(intNumber.toLong), enforceCorrectType(intNumber.toLong, TimestampType)) val strTime = "2014-09-30 12:34:56" checkTypePromotion(Timestamp.valueOf(strTime), enforceCorrectType(strTime, TimestampType)) @@ -80,7 +84,7 @@ class JsonSuite extends QueryTest { checkDataType(NullType, IntegerType, IntegerType) checkDataType(NullType, LongType, LongType) checkDataType(NullType, DoubleType, DoubleType) - checkDataType(NullType, DecimalType, DecimalType) + checkDataType(NullType, DecimalType.Unlimited, DecimalType.Unlimited) checkDataType(NullType, StringType, StringType) checkDataType(NullType, ArrayType(IntegerType), ArrayType(IntegerType)) checkDataType(NullType, StructType(Nil), StructType(Nil)) @@ -91,7 +95,7 @@ class JsonSuite extends QueryTest { checkDataType(BooleanType, IntegerType, StringType) checkDataType(BooleanType, LongType, StringType) checkDataType(BooleanType, DoubleType, StringType) - checkDataType(BooleanType, DecimalType, StringType) + checkDataType(BooleanType, DecimalType.Unlimited, StringType) checkDataType(BooleanType, StringType, StringType) checkDataType(BooleanType, ArrayType(IntegerType), StringType) 
checkDataType(BooleanType, StructType(Nil), StringType) @@ -100,7 +104,7 @@ class JsonSuite extends QueryTest { checkDataType(IntegerType, IntegerType, IntegerType) checkDataType(IntegerType, LongType, LongType) checkDataType(IntegerType, DoubleType, DoubleType) - checkDataType(IntegerType, DecimalType, DecimalType) + checkDataType(IntegerType, DecimalType.Unlimited, DecimalType.Unlimited) checkDataType(IntegerType, StringType, StringType) checkDataType(IntegerType, ArrayType(IntegerType), StringType) checkDataType(IntegerType, StructType(Nil), StringType) @@ -108,23 +112,23 @@ class JsonSuite extends QueryTest { // LongType checkDataType(LongType, LongType, LongType) checkDataType(LongType, DoubleType, DoubleType) - checkDataType(LongType, DecimalType, DecimalType) + checkDataType(LongType, DecimalType.Unlimited, DecimalType.Unlimited) checkDataType(LongType, StringType, StringType) checkDataType(LongType, ArrayType(IntegerType), StringType) checkDataType(LongType, StructType(Nil), StringType) // DoubleType checkDataType(DoubleType, DoubleType, DoubleType) - checkDataType(DoubleType, DecimalType, DecimalType) + checkDataType(DoubleType, DecimalType.Unlimited, DecimalType.Unlimited) checkDataType(DoubleType, StringType, StringType) checkDataType(DoubleType, ArrayType(IntegerType), StringType) checkDataType(DoubleType, StructType(Nil), StringType) // DoubleType - checkDataType(DecimalType, DecimalType, DecimalType) - checkDataType(DecimalType, StringType, StringType) - checkDataType(DecimalType, ArrayType(IntegerType), StringType) - checkDataType(DecimalType, StructType(Nil), StringType) + checkDataType(DecimalType.Unlimited, DecimalType.Unlimited, DecimalType.Unlimited) + checkDataType(DecimalType.Unlimited, StringType, StringType) + checkDataType(DecimalType.Unlimited, ArrayType(IntegerType), StringType) + checkDataType(DecimalType.Unlimited, StructType(Nil), StringType) // StringType checkDataType(StringType, StringType, StringType) @@ -178,7 +182,7 @@ class JsonSuite extends QueryTest { checkDataType( StructType( StructField("f1", IntegerType, true) :: Nil), - DecimalType, + DecimalType.Unlimited, StringType) } @@ -186,7 +190,7 @@ class JsonSuite extends QueryTest { val jsonSchemaRDD = jsonRDD(primitiveFieldAndType) val expectedSchema = StructType( - StructField("bigInteger", DecimalType, true) :: + StructField("bigInteger", DecimalType.Unlimited, true) :: StructField("boolean", BooleanType, true) :: StructField("double", DoubleType, true) :: StructField("integer", IntegerType, true) :: @@ -216,7 +220,7 @@ class JsonSuite extends QueryTest { val expectedSchema = StructType( StructField("arrayOfArray1", ArrayType(ArrayType(StringType, false), false), true) :: StructField("arrayOfArray2", ArrayType(ArrayType(DoubleType, false), false), true) :: - StructField("arrayOfBigInteger", ArrayType(DecimalType, false), true) :: + StructField("arrayOfBigInteger", ArrayType(DecimalType.Unlimited, false), true) :: StructField("arrayOfBoolean", ArrayType(BooleanType, false), true) :: StructField("arrayOfDouble", ArrayType(DoubleType, false), true) :: StructField("arrayOfInteger", ArrayType(IntegerType, false), true) :: @@ -230,7 +234,7 @@ class JsonSuite extends QueryTest { StructField("field3", StringType, true) :: Nil), false), true) :: StructField("struct", StructType( StructField("field1", BooleanType, true) :: - StructField("field2", DecimalType, true) :: Nil), true) :: + StructField("field2", DecimalType.Unlimited, true) :: Nil), true) :: StructField("structWithArrayFields", StructType( 
StructField("field1", ArrayType(IntegerType, false), true) :: StructField("field2", ArrayType(StringType, false), true) :: Nil), true) :: Nil) @@ -331,7 +335,7 @@ class JsonSuite extends QueryTest { val expectedSchema = StructType( StructField("num_bool", StringType, true) :: StructField("num_num_1", LongType, true) :: - StructField("num_num_2", DecimalType, true) :: + StructField("num_num_2", DecimalType.Unlimited, true) :: StructField("num_num_3", DoubleType, true) :: StructField("num_str", StringType, true) :: StructField("str_bool", StringType, true) :: Nil) @@ -521,7 +525,7 @@ class JsonSuite extends QueryTest { val jsonSchemaRDD = jsonFile(path) val expectedSchema = StructType( - StructField("bigInteger", DecimalType, true) :: + StructField("bigInteger", DecimalType.Unlimited, true) :: StructField("boolean", BooleanType, true) :: StructField("double", DoubleType, true) :: StructField("integer", IntegerType, true) :: @@ -551,7 +555,7 @@ class JsonSuite extends QueryTest { primitiveFieldAndType.map(record => record.replaceAll("\n", " ")).saveAsTextFile(path) val schema = StructType( - StructField("bigInteger", DecimalType, true) :: + StructField("bigInteger", DecimalType.Unlimited, true) :: StructField("boolean", BooleanType, true) :: StructField("double", DoubleType, true) :: StructField("integer", IntegerType, true) :: diff --git a/sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetQuerySuite.scala index 9979ab446d..08d9da27f1 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetQuerySuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetQuerySuite.scala @@ -77,6 +77,8 @@ case class AllDataTypesWithNonPrimitiveType( case class BinaryData(binaryData: Array[Byte]) +case class NumericData(i: Int, d: Double) + class ParquetQuerySuite extends QueryTest with FunSuiteLike with BeforeAndAfterAll { TestData // Load test data tables. 
@@ -560,7 +562,7 @@ class ParquetQuerySuite extends QueryTest with FunSuiteLike with BeforeAndAfterA assert(stringResult.size === 1) assert(stringResult(0).getString(2) == "100", "stringvalue incorrect") assert(stringResult(0).getInt(1) === 100) - + val query7 = sql(s"SELECT * FROM testfiltersource WHERE myoptint < 40") assert( query7.queryExecution.executedPlan(0)(0).isInstanceOf[ParquetTableScan], @@ -869,4 +871,35 @@ class ParquetQuerySuite extends QueryTest with FunSuiteLike with BeforeAndAfterA assert(a.dataType === b.dataType) } } + + test("read/write fixed-length decimals") { + for ((precision, scale) <- Seq((5, 2), (1, 0), (1, 1), (18, 10), (18, 17))) { + val tempDir = getTempFilePath("parquetTest").getCanonicalPath + val data = sparkContext.parallelize(0 to 1000) + .map(i => NumericData(i, i / 100.0)) + .select('i, 'd cast DecimalType(precision, scale)) + data.saveAsParquetFile(tempDir) + checkAnswer(parquetFile(tempDir), data.toSchemaRDD.collect().toSeq) + } + + // Decimals with precision above 18 are not yet supported + intercept[RuntimeException] { + val tempDir = getTempFilePath("parquetTest").getCanonicalPath + val data = sparkContext.parallelize(0 to 1000) + .map(i => NumericData(i, i / 100.0)) + .select('i, 'd cast DecimalType(19, 10)) + data.saveAsParquetFile(tempDir) + checkAnswer(parquetFile(tempDir), data.toSchemaRDD.collect().toSeq) + } + + // Unlimited-length decimals are not yet supported + intercept[RuntimeException] { + val tempDir = getTempFilePath("parquetTest").getCanonicalPath + val data = sparkContext.parallelize(0 to 1000) + .map(i => NumericData(i, i / 100.0)) + .select('i, 'd cast DecimalType.Unlimited) + data.saveAsParquetFile(tempDir) + checkAnswer(parquetFile(tempDir), data.toSchemaRDD.collect().toSeq) + } + } } diff --git a/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/server/SparkSQLOperationManager.scala b/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/server/SparkSQLOperationManager.scala index 2a4f24132c..99c4f46a82 100644 --- a/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/server/SparkSQLOperationManager.scala +++ b/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/server/SparkSQLOperationManager.scala @@ -47,7 +47,7 @@ private[thriftserver] class SparkSQLOperationManager(hiveContext: HiveContext) val operation = new SparkExecuteStatementOperation(parentSession, statement, confOverlay)( hiveContext, sessionToActivePool) - handleToOperation.put(operation.getHandle, operation) - operation + handleToOperation.put(operation.getHandle, operation) + operation } } diff --git a/sql/hive-thriftserver/v0.12.0/src/main/scala/org/apache/spark/sql/hive/thriftserver/Shim12.scala b/sql/hive-thriftserver/v0.12.0/src/main/scala/org/apache/spark/sql/hive/thriftserver/Shim12.scala index bbd727c686..8077d0ec46 100644 --- a/sql/hive-thriftserver/v0.12.0/src/main/scala/org/apache/spark/sql/hive/thriftserver/Shim12.scala +++ b/sql/hive-thriftserver/v0.12.0/src/main/scala/org/apache/spark/sql/hive/thriftserver/Shim12.scala @@ -123,7 +123,7 @@ private[hive] class SparkExecuteStatementOperation( to.addColumnValue(ColumnValue.doubleValue(from.getDouble(ordinal))) case FloatType => to.addColumnValue(ColumnValue.floatValue(from.getFloat(ordinal))) - case DecimalType => + case DecimalType() => val hiveDecimal = from.get(ordinal).asInstanceOf[BigDecimal].bigDecimal to.addColumnValue(ColumnValue.stringValue(new HiveDecimal(hiveDecimal))) case LongType => @@ -156,7 
+156,7 @@ private[hive] class SparkExecuteStatementOperation( to.addColumnValue(ColumnValue.doubleValue(null)) case FloatType => to.addColumnValue(ColumnValue.floatValue(null)) - case DecimalType => + case DecimalType() => to.addColumnValue(ColumnValue.stringValue(null: HiveDecimal)) case LongType => to.addColumnValue(ColumnValue.longValue(null)) diff --git a/sql/hive-thriftserver/v0.13.1/src/main/scala/org/apache/spark/sql/hive/thriftserver/Shim13.scala b/sql/hive-thriftserver/v0.13.1/src/main/scala/org/apache/spark/sql/hive/thriftserver/Shim13.scala index e59681bfbe..2c1983de1d 100644 --- a/sql/hive-thriftserver/v0.13.1/src/main/scala/org/apache/spark/sql/hive/thriftserver/Shim13.scala +++ b/sql/hive-thriftserver/v0.13.1/src/main/scala/org/apache/spark/sql/hive/thriftserver/Shim13.scala @@ -123,7 +123,7 @@ private[hive] class SparkExecuteStatementOperation( to += from.getDouble(ordinal) case FloatType => to += from.getFloat(ordinal) - case DecimalType => + case DecimalType() => to += from.get(ordinal).asInstanceOf[BigDecimal].bigDecimal case LongType => to += from.getLong(ordinal) diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveContext.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveContext.scala index ff8fa44194..2e27817d60 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveContext.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveContext.scala @@ -21,6 +21,10 @@ import java.io.{BufferedReader, File, InputStreamReader, PrintStream} import java.sql.{Date, Timestamp} import java.util.{ArrayList => JArrayList} +import org.apache.hadoop.hive.common.`type`.HiveDecimal +import org.apache.spark.sql.catalyst.types.DecimalType +import org.apache.spark.sql.catalyst.types.decimal.Decimal + import scala.collection.JavaConversions._ import scala.language.implicitConversions import scala.reflect.runtime.universe.{TypeTag, typeTag} @@ -370,7 +374,7 @@ class HiveContext(sc: SparkContext) extends SQLContext(sc) { protected val primitiveTypes = Seq(StringType, IntegerType, LongType, DoubleType, FloatType, BooleanType, ByteType, - ShortType, DecimalType, DateType, TimestampType, BinaryType) + ShortType, DateType, TimestampType, BinaryType) protected[sql] def toHiveString(a: (Any, DataType)): String = a match { case (struct: Row, StructType(fields)) => @@ -388,6 +392,8 @@ class HiveContext(sc: SparkContext) extends SQLContext(sc) { case (d: Date, DateType) => new DateWritable(d).toString case (t: Timestamp, TimestampType) => new TimestampWritable(t).toString case (bin: Array[Byte], BinaryType) => new String(bin, "UTF-8") + case (decimal: Decimal, DecimalType()) => // Hive strips trailing zeros so use its toString + HiveShim.createDecimal(decimal.toBigDecimal.underlying()).toString case (other, tpe) if primitiveTypes contains tpe => other.toString } @@ -406,6 +412,7 @@ class HiveContext(sc: SparkContext) extends SQLContext(sc) { }.toSeq.sorted.mkString("{", ",", "}") case (null, _) => "null" case (s: String, StringType) => "\"" + s + "\"" + case (decimal, DecimalType()) => decimal.toString case (other, tpe) if primitiveTypes contains tpe => other.toString } diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveInspectors.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveInspectors.scala index 0439ab97d8..1e2bf5cc4b 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveInspectors.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveInspectors.scala @@ -28,6 +28,7 @@ import org.apache.hadoop.{io 
=> hadoopIo} import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.types import org.apache.spark.sql.catalyst.types._ +import org.apache.spark.sql.catalyst.types.decimal.Decimal /* Implicit conversions */ import scala.collection.JavaConversions._ @@ -38,7 +39,7 @@ private[hive] trait HiveInspectors { // writable case c: Class[_] if c == classOf[hadoopIo.DoubleWritable] => DoubleType case c: Class[_] if c == classOf[hiveIo.DoubleWritable] => DoubleType - case c: Class[_] if c == classOf[hiveIo.HiveDecimalWritable] => DecimalType + case c: Class[_] if c == classOf[hiveIo.HiveDecimalWritable] => DecimalType.Unlimited case c: Class[_] if c == classOf[hiveIo.ByteWritable] => ByteType case c: Class[_] if c == classOf[hiveIo.ShortWritable] => ShortType case c: Class[_] if c == classOf[hiveIo.DateWritable] => DateType @@ -54,8 +55,8 @@ private[hive] trait HiveInspectors { case c: Class[_] if c == classOf[java.lang.String] => StringType case c: Class[_] if c == classOf[java.sql.Date] => DateType case c: Class[_] if c == classOf[java.sql.Timestamp] => TimestampType - case c: Class[_] if c == classOf[HiveDecimal] => DecimalType - case c: Class[_] if c == classOf[java.math.BigDecimal] => DecimalType + case c: Class[_] if c == classOf[HiveDecimal] => DecimalType.Unlimited + case c: Class[_] if c == classOf[java.math.BigDecimal] => DecimalType.Unlimited case c: Class[_] if c == classOf[Array[Byte]] => BinaryType case c: Class[_] if c == classOf[java.lang.Short] => ShortType case c: Class[_] if c == classOf[java.lang.Integer] => IntegerType @@ -90,7 +91,7 @@ private[hive] trait HiveInspectors { case hvoi: HiveVarcharObjectInspector => if (data == null) null else hvoi.getPrimitiveJavaObject(data).getValue case hdoi: HiveDecimalObjectInspector => - if (data == null) null else BigDecimal(hdoi.getPrimitiveJavaObject(data).bigDecimalValue()) + if (data == null) null else HiveShim.toCatalystDecimal(hdoi, data) // org.apache.hadoop.hive.serde2.io.TimestampWritable.set will reset current time object // if next timestamp is null, so Timestamp object is cloned case ti: TimestampObjectInspector => ti.getPrimitiveJavaObject(data).clone() @@ -137,8 +138,9 @@ private[hive] trait HiveInspectors { case l: Short => l: java.lang.Short case l: Byte => l: java.lang.Byte case b: BigDecimal => HiveShim.createDecimal(b.underlying()) + case d: Decimal => HiveShim.createDecimal(d.toBigDecimal.underlying()) case b: Array[Byte] => b - case d: java.sql.Date => d + case d: java.sql.Date => d case t: java.sql.Timestamp => t } case x: StructObjectInspector => @@ -200,7 +202,7 @@ private[hive] trait HiveInspectors { case BinaryType => PrimitiveObjectInspectorFactory.javaByteArrayObjectInspector case DateType => PrimitiveObjectInspectorFactory.javaDateObjectInspector case TimestampType => PrimitiveObjectInspectorFactory.javaTimestampObjectInspector - case DecimalType => PrimitiveObjectInspectorFactory.javaHiveDecimalObjectInspector + case DecimalType() => PrimitiveObjectInspectorFactory.javaHiveDecimalObjectInspector case StructType(fields) => ObjectInspectorFactory.getStandardStructObjectInspector( fields.map(f => f.name), fields.map(f => toInspector(f.dataType))) @@ -229,8 +231,10 @@ private[hive] trait HiveInspectors { HiveShim.getPrimitiveWritableConstantObjectInspector(value) case Literal(value: java.sql.Timestamp, TimestampType) => HiveShim.getPrimitiveWritableConstantObjectInspector(value) - case Literal(value: BigDecimal, DecimalType) => + case Literal(value: BigDecimal, DecimalType()) => 
HiveShim.getPrimitiveWritableConstantObjectInspector(value) + case Literal(value: Decimal, DecimalType()) => + HiveShim.getPrimitiveWritableConstantObjectInspector(value.toBigDecimal) case Literal(_, NullType) => HiveShim.getPrimitiveNullWritableConstantObjectInspector case Literal(value: Seq[_], ArrayType(dt, _)) => @@ -277,8 +281,8 @@ private[hive] trait HiveInspectors { case _: JavaFloatObjectInspector => FloatType case _: WritableBinaryObjectInspector => BinaryType case _: JavaBinaryObjectInspector => BinaryType - case _: WritableHiveDecimalObjectInspector => DecimalType - case _: JavaHiveDecimalObjectInspector => DecimalType + case w: WritableHiveDecimalObjectInspector => HiveShim.decimalTypeInfoToCatalyst(w) + case j: JavaHiveDecimalObjectInspector => HiveShim.decimalTypeInfoToCatalyst(j) case _: WritableDateObjectInspector => DateType case _: JavaDateObjectInspector => DateType case _: WritableTimestampObjectInspector => TimestampType @@ -307,7 +311,7 @@ private[hive] trait HiveInspectors { case LongType => longTypeInfo case ShortType => shortTypeInfo case StringType => stringTypeInfo - case DecimalType => decimalTypeInfo + case d: DecimalType => HiveShim.decimalTypeInfo(d) case DateType => dateTypeInfo case TimestampType => timestampTypeInfo case NullType => voidTypeInfo diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveMetastoreCatalog.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveMetastoreCatalog.scala index 2dd2c882a8..096b4a07aa 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveMetastoreCatalog.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveMetastoreCatalog.scala @@ -20,6 +20,7 @@ package org.apache.spark.sql.hive import java.io.IOException import java.util.{List => JList} +import scala.util.matching.Regex import scala.util.parsing.combinator.RegexParsers import org.apache.hadoop.util.ReflectionUtils @@ -321,11 +322,18 @@ object HiveMetastoreTypes extends RegexParsers { "bigint" ^^^ LongType | "binary" ^^^ BinaryType | "boolean" ^^^ BooleanType | - HiveShim.metastoreDecimal ^^^ DecimalType | + fixedDecimalType | // Hive 0.13+ decimal with precision/scale + "decimal" ^^^ DecimalType.Unlimited | // Hive 0.12 decimal with no precision/scale "date" ^^^ DateType | "timestamp" ^^^ TimestampType | "varchar\\((\\d+)\\)".r ^^^ StringType + protected lazy val fixedDecimalType: Parser[DataType] = + ("decimal" ~> "(" ~> "\\d+".r) ~ ("," ~> "\\d+".r <~ ")") ^^ { + case precision ~ scale => + DecimalType(precision.toInt, scale.toInt) + } + protected lazy val arrayType: Parser[DataType] = "array" ~> "<" ~> dataType <~ ">" ^^ { case tpe => ArrayType(tpe) @@ -373,7 +381,7 @@ object HiveMetastoreTypes extends RegexParsers { case BinaryType => "binary" case BooleanType => "boolean" case DateType => "date" - case DecimalType => "decimal" + case d: DecimalType => HiveShim.decimalMetastoreString(d) case TimestampType => "timestamp" case NullType => "void" } @@ -441,7 +449,7 @@ private[hive] case class MetastoreRelation val partitionKeys = hiveQlTable.getPartitionKeys.map(_.toAttribute) /** Non-partitionKey attributes */ - val attributes = hiveQlTable.getCols.map(_.toAttribute) + val attributes = hiveQlTable.getCols.map(_.toAttribute) val output = attributes ++ partitionKeys diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveQl.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveQl.scala index a3573e6502..74f68d0f95 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveQl.scala +++ 
b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveQl.scala @@ -31,6 +31,7 @@ import org.apache.spark.sql.catalyst.plans._ import org.apache.spark.sql.catalyst.plans.logical import org.apache.spark.sql.catalyst.plans.logical._ import org.apache.spark.sql.catalyst.types._ +import org.apache.spark.sql.catalyst.types.decimal.Decimal /* Implicit conversions */ import scala.collection.JavaConversions._ @@ -325,7 +326,11 @@ private[hive] object HiveQl { } protected def nodeToDataType(node: Node): DataType = node match { - case Token("TOK_DECIMAL", Nil) => DecimalType + case Token("TOK_DECIMAL", precision :: scale :: Nil) => + DecimalType(precision.getText.toInt, scale.getText.toInt) + case Token("TOK_DECIMAL", precision :: Nil) => + DecimalType(precision.getText.toInt, 0) + case Token("TOK_DECIMAL", Nil) => DecimalType.Unlimited case Token("TOK_BIGINT", Nil) => LongType case Token("TOK_INT", Nil) => IntegerType case Token("TOK_TINYINT", Nil) => ByteType @@ -942,8 +947,12 @@ private[hive] object HiveQl { Cast(nodeToExpr(arg), BinaryType) case Token("TOK_FUNCTION", Token("TOK_BOOLEAN", Nil) :: arg :: Nil) => Cast(nodeToExpr(arg), BooleanType) + case Token("TOK_FUNCTION", Token("TOK_DECIMAL", precision :: scale :: nil) :: arg :: Nil) => + Cast(nodeToExpr(arg), DecimalType(precision.getText.toInt, scale.getText.toInt)) + case Token("TOK_FUNCTION", Token("TOK_DECIMAL", precision :: Nil) :: arg :: Nil) => + Cast(nodeToExpr(arg), DecimalType(precision.getText.toInt, 0)) case Token("TOK_FUNCTION", Token("TOK_DECIMAL", Nil) :: arg :: Nil) => - Cast(nodeToExpr(arg), DecimalType) + Cast(nodeToExpr(arg), DecimalType.Unlimited) case Token("TOK_FUNCTION", Token("TOK_TIMESTAMP", Nil) :: arg :: Nil) => Cast(nodeToExpr(arg), TimestampType) case Token("TOK_FUNCTION", Token("TOK_DATE", Nil) :: arg :: Nil) => @@ -1063,7 +1072,7 @@ private[hive] object HiveQl { } else if (ast.getText.endsWith("BD") || ast.getText.endsWith("D")) { // Literal decimal val strVal = ast.getText.stripSuffix("D").stripSuffix("B") - v = Literal(BigDecimal(strVal)) + v = Literal(Decimal(strVal)) } else { v = Literal(ast.getText.toDouble, DoubleType) v = Literal(ast.getText.toLong, LongType) diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/InsertIntoHiveTable.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/InsertIntoHiveTable.scala index 79234f8a66..92bc1c6625 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/InsertIntoHiveTable.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/InsertIntoHiveTable.scala @@ -35,6 +35,7 @@ import org.apache.hadoop.mapred.{FileOutputCommitter, FileOutputFormat, JobConf} import org.apache.spark.annotation.DeveloperApi import org.apache.spark.rdd.RDD import org.apache.spark.sql.catalyst.expressions.Row +import org.apache.spark.sql.catalyst.types.decimal.Decimal import org.apache.spark.sql.execution.{Command, SparkPlan, UnaryNode} import org.apache.spark.sql.hive._ import org.apache.spark.sql.hive.{ ShimFileSinkDesc => FileSinkDesc} @@ -76,7 +77,7 @@ case class InsertIntoHiveTable( (o: Any) => new HiveVarchar(o.asInstanceOf[String], o.asInstanceOf[String].size) case _: JavaHiveDecimalObjectInspector => - (o: Any) => HiveShim.createDecimal(o.asInstanceOf[BigDecimal].underlying()) + (o: Any) => HiveShim.createDecimal(o.asInstanceOf[Decimal].toBigDecimal.underlying()) case soi: StandardStructObjectInspector => val wrappers = soi.getAllStructFieldRefs.map(ref => wrapperFor(ref.getFieldObjectInspector)) diff --git 
a/sql/hive/v0.12.0/src/main/scala/org/apache/spark/sql/hive/Shim12.scala b/sql/hive/v0.12.0/src/main/scala/org/apache/spark/sql/hive/Shim12.scala index afc252ac27..8e946b7e82 100644 --- a/sql/hive/v0.12.0/src/main/scala/org/apache/spark/sql/hive/Shim12.scala +++ b/sql/hive/v0.12.0/src/main/scala/org/apache/spark/sql/hive/Shim12.scala @@ -30,21 +30,24 @@ import org.apache.hadoop.hive.ql.plan.{CreateTableDesc, FileSinkDesc, TableDesc} import org.apache.hadoop.hive.ql.processors._ import org.apache.hadoop.hive.ql.stats.StatsSetupConst import org.apache.hadoop.hive.serde2.objectinspector.PrimitiveObjectInspector.PrimitiveCategory -import org.apache.hadoop.hive.serde2.objectinspector.primitive.PrimitiveObjectInspectorFactory -import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspector +import org.apache.hadoop.hive.serde2.objectinspector.primitive.{HiveDecimalObjectInspector, PrimitiveObjectInspectorFactory} +import org.apache.hadoop.hive.serde2.objectinspector.{PrimitiveObjectInspector, ObjectInspector} +import org.apache.hadoop.hive.serde2.typeinfo.{TypeInfo, TypeInfoFactory} import org.apache.hadoop.hive.serde2.{Deserializer, ColumnProjectionUtils} import org.apache.hadoop.hive.serde2.{io => hiveIo} import org.apache.hadoop.{io => hadoopIo} import org.apache.hadoop.mapred.InputFormat +import org.apache.spark.sql.catalyst.types.decimal.Decimal import scala.collection.JavaConversions._ import scala.language.implicitConversions +import org.apache.spark.sql.catalyst.types.DecimalType + /** * A compatibility layer for interacting with Hive version 0.12.0. */ private[hive] object HiveShim { val version = "0.12.0" - val metastoreDecimal = "decimal" def getTableDesc( serdeClass: Class[_ <: Deserializer], @@ -149,6 +152,19 @@ private[hive] object HiveShim { def setLocation(tbl: Table, crtTbl: CreateTableDesc): Unit = { tbl.setDataLocation(new Path(crtTbl.getLocation()).toUri()) } + + def decimalMetastoreString(decimalType: DecimalType): String = "decimal" + + def decimalTypeInfo(decimalType: DecimalType): TypeInfo = + TypeInfoFactory.decimalTypeInfo + + def decimalTypeInfoToCatalyst(inspector: PrimitiveObjectInspector): DecimalType = { + DecimalType.Unlimited + } + + def toCatalystDecimal(hdoi: HiveDecimalObjectInspector, data: Any): Decimal = { + Decimal(hdoi.getPrimitiveJavaObject(data).bigDecimalValue()) + } } class ShimFileSinkDesc(var dir: String, var tableInfo: TableDesc, var compressed: Boolean) diff --git a/sql/hive/v0.13.1/src/main/scala/org/apache/spark/sql/hive/Shim13.scala b/sql/hive/v0.13.1/src/main/scala/org/apache/spark/sql/hive/Shim13.scala index 42cd65b251..0bc330cdbe 100644 --- a/sql/hive/v0.13.1/src/main/scala/org/apache/spark/sql/hive/Shim13.scala +++ b/sql/hive/v0.13.1/src/main/scala/org/apache/spark/sql/hive/Shim13.scala @@ -29,15 +29,15 @@ import org.apache.hadoop.hive.ql.Context import org.apache.hadoop.hive.ql.metadata.{Table, Hive, Partition} import org.apache.hadoop.hive.ql.plan.{CreateTableDesc, FileSinkDesc, TableDesc} import org.apache.hadoop.hive.ql.processors.CommandProcessorFactory -import org.apache.hadoop.hive.serde2.typeinfo.TypeInfoFactory -import org.apache.hadoop.hive.serde2.{ColumnProjectionUtils, Deserializer} -import org.apache.hadoop.hive.serde2.objectinspector.PrimitiveObjectInspector.PrimitiveCategory -import org.apache.hadoop.hive.serde2.objectinspector.primitive.PrimitiveObjectInspectorFactory -import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspector +import org.apache.hadoop.hive.serde2.typeinfo.{TypeInfo, DecimalTypeInfo, 
TypeInfoFactory} +import org.apache.hadoop.hive.serde2.objectinspector.primitive.{HiveDecimalObjectInspector, PrimitiveObjectInspectorFactory} +import org.apache.hadoop.hive.serde2.objectinspector.{PrimitiveObjectInspector, ObjectInspector} import org.apache.hadoop.hive.serde2.{Deserializer, ColumnProjectionUtils} import org.apache.hadoop.hive.serde2.{io => hiveIo} import org.apache.hadoop.{io => hadoopIo} import org.apache.spark.Logging +import org.apache.spark.sql.catalyst.types.DecimalType +import org.apache.spark.sql.catalyst.types.decimal.Decimal import scala.collection.JavaConversions._ import scala.language.implicitConversions @@ -47,11 +47,6 @@ import scala.language.implicitConversions */ private[hive] object HiveShim { val version = "0.13.1" - /* - * TODO: hive-0.13 support DECIMAL(precision, scale), DECIMAL in hive-0.12 is actually DECIMAL(38,unbounded) - * Full support of new decimal feature need to be fixed in seperate PR. - */ - val metastoreDecimal = "decimal\\((\\d+),(\\d+)\\)".r def getTableDesc( serdeClass: Class[_ <: Deserializer], @@ -197,6 +192,30 @@ private[hive] object HiveShim { f.setDestTableId(w.destTableId) f } + + // Precision and scale to pass for unlimited decimals; these are the same as the precision and + // scale Hive 0.13 infers for BigDecimals from sources that don't specify them (e.g. UDFs) + private val UNLIMITED_DECIMAL_PRECISION = 38 + private val UNLIMITED_DECIMAL_SCALE = 18 + + def decimalMetastoreString(decimalType: DecimalType): String = decimalType match { + case DecimalType.Fixed(precision, scale) => s"decimal($precision,$scale)" + case _ => s"decimal($UNLIMITED_DECIMAL_PRECISION,$UNLIMITED_DECIMAL_SCALE)" + } + + def decimalTypeInfo(decimalType: DecimalType): TypeInfo = decimalType match { + case DecimalType.Fixed(precision, scale) => new DecimalTypeInfo(precision, scale) + case _ => new DecimalTypeInfo(UNLIMITED_DECIMAL_PRECISION, UNLIMITED_DECIMAL_SCALE) + } + + def decimalTypeInfoToCatalyst(inspector: PrimitiveObjectInspector): DecimalType = { + val info = inspector.getTypeInfo.asInstanceOf[DecimalTypeInfo] + DecimalType(info.precision(), info.scale()) + } + + def toCatalystDecimal(hdoi: HiveDecimalObjectInspector, data: Any): Decimal = { + Decimal(hdoi.getPrimitiveJavaObject(data).bigDecimalValue(), hdoi.precision(), hdoi.scale()) + } } /* -- cgit v1.2.3
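
For context on the two shim implementations above, a short usage sketch, not part of the patch: the DecimalShimSketch name is hypothetical, and HiveShim is private[hive], so code like this would have to live inside the org.apache.spark.sql.hive package (for example in a test). With the 0.13.1 shim the metastore string and TypeInfo preserve precision and scale, while the 0.12 shim always reports plain "decimal" and maps any Hive decimal inspector back to DecimalType.Unlimited.

    package org.apache.spark.sql.hive

    import org.apache.spark.sql.catalyst.types.DecimalType

    // Hypothetical helper, for illustration only.
    object DecimalShimSketch {
      def main(args: Array[String]): Unit = {
        // Hive 0.13.1 shim: fixed-precision types round-trip through the
        // metastore string; Unlimited falls back to the (38, 18) defaults.
        assert(HiveShim.decimalMetastoreString(DecimalType(10, 4)) == "decimal(10,4)")
        assert(HiveShim.decimalMetastoreString(DecimalType.Unlimited) == "decimal(38,18)")

        // decimalTypeInfo builds the matching Hive DecimalTypeInfo, and
        // decimalTypeInfoToCatalyst recovers a fixed-precision DecimalType from
        // a decimal object inspector. Under the 0.12 shim both collapse to the
        // unparameterized "decimal" / DecimalType.Unlimited.
        val info = HiveShim.decimalTypeInfo(DecimalType(10, 4))
        println(info)
      }
    }

The (38, 18) pair mirrors what Hive 0.13 itself assumes for decimals whose precision and scale are unspecified, as noted by the UNLIMITED_DECIMAL_PRECISION and UNLIMITED_DECIMAL_SCALE constants introduced in the Shim13 hunk.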