author     Matei Zaharia <matei@databricks.com>        2014-11-01 19:29:14 -0700
committer  Michael Armbrust <michael@databricks.com>   2014-11-01 19:29:14 -0700
commit     23f966f47523f85ba440b4080eee665271f53b5e (patch)
tree       d796351567f8b187511b9049199cbf99c5826fb3 /sql
parent     56f2c61cde3f5d906c2a58e9af1a661222f2c679 (diff)
[SPARK-3930] [SPARK-3933] Support fixed-precision decimal in SQL, and some optimizations
- Adds optional precision and scale to Spark SQL's decimal type, which behave similarly to those in Hive 13 (https://cwiki.apache.org/confluence/download/attachments/27362075/Hive_Decimal_Precision_Scale_Support.pdf)
- Replaces our internal representation of decimals with a Decimal class that can store small values in a mutable Long, saving memory in this situation and letting some operations happen directly on Longs

This is still marked WIP because there are a few TODOs, but I'll remove that tag when done.

Author: Matei Zaharia <matei@databricks.com>

Closes #2983 from mateiz/decimal-1 and squashes the following commits:

35e6b02 [Matei Zaharia] Fix issues after merge
227f24a [Matei Zaharia] Review comments
31f915e [Matei Zaharia] Implement Davies's suggestions in Python
eb84820 [Matei Zaharia] Support reading/writing decimals as fixed-length binary in Parquet
4dc6bae [Matei Zaharia] Fix decimal support in PySpark
d1d9d68 [Matei Zaharia] Fix compile error and test issues after rebase
b28933d [Matei Zaharia] Support decimal precision/scale in Hive metastore
2118c0d [Matei Zaharia] Some test and bug fixes
81db9cb [Matei Zaharia] Added mutable Decimal that will be more efficient for small precisions
7af0c3b [Matei Zaharia] Add optional precision and scale to DecimalType, but use Unlimited for now
ec0a947 [Matei Zaharia] Make the result of AVG on Decimals be Decimal, not Double
Diffstat (limited to 'sql')
-rw-r--r--  sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala  20
-rwxr-xr-x  sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/SqlParser.scala  14
-rw-r--r--  sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercion.scala  146
-rwxr-xr-x  sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/dsl/package.scala  11
-rw-r--r--  sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala  78
-rwxr-xr-x  sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregates.scala  55
-rw-r--r--  sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala  10
-rw-r--r--  sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala  31
-rw-r--r--  sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/decimalFunctions.scala  59
-rw-r--r--  sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/literals.scala  6
-rw-r--r--  sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala  38
-rw-r--r--  sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/types/dataTypes.scala  84
-rw-r--r--  sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/types/decimal/Decimal.scala  335
-rw-r--r--  sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/ScalaReflectionSuite.scala  14
-rw-r--r--  sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisSuite.scala  6
-rw-r--r--  sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/DecimalPrecisionSuite.scala  88
-rw-r--r--  sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercionSuite.scala  17
-rw-r--r--  sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvaluationSuite.scala  90
-rw-r--r--  sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/types/decimal/DecimalSuite.scala  158
-rw-r--r--  sql/core/src/main/java/org/apache/spark/sql/api/java/DataType.java  5
-rw-r--r--  sql/core/src/main/java/org/apache/spark/sql/api/java/DecimalType.java  58
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/SchemaRDD.scala  3
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/api/java/JavaSQLContext.scala  2
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/api/java/Row.scala  4
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/execution/GeneratedAggregate.scala  41
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlan.scala  4
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/execution/SparkSqlSerializer.scala  2
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala  7
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastHashJoin.scala  3
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/execution/pythonUdfs.scala  6
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/json/JsonRDD.scala  20
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/package.scala  14
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetConverter.scala  43
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetTableSupport.scala  28
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetTypes.scala  79
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/types/util/DataTypeConversions.scala  13
-rw-r--r--  sql/core/src/test/java/org/apache/spark/sql/api/java/JavaApplySchemaSuite.java  2
-rw-r--r--  sql/core/src/test/java/org/apache/spark/sql/api/java/JavaSideDataTypeConversionSuite.java  9
-rw-r--r--  sql/core/src/test/scala/org/apache/spark/sql/DataTypeSuite.scala  2
-rw-r--r--  sql/core/src/test/scala/org/apache/spark/sql/ScalaReflectionRelationSuite.scala  5
-rw-r--r--  sql/core/src/test/scala/org/apache/spark/sql/api/java/JavaSQLSuite.scala  2
-rw-r--r--  sql/core/src/test/scala/org/apache/spark/sql/api/java/ScalaSideDataTypeConversionSuite.scala  4
-rw-r--r--  sql/core/src/test/scala/org/apache/spark/sql/json/JsonSuite.scala  46
-rw-r--r--  sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetQuerySuite.scala  35
-rw-r--r--  sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/server/SparkSQLOperationManager.scala  4
-rw-r--r--  sql/hive-thriftserver/v0.12.0/src/main/scala/org/apache/spark/sql/hive/thriftserver/Shim12.scala  4
-rw-r--r--  sql/hive-thriftserver/v0.13.1/src/main/scala/org/apache/spark/sql/hive/thriftserver/Shim13.scala  2
-rw-r--r--  sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveContext.scala  9
-rw-r--r--  sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveInspectors.scala  24
-rw-r--r--  sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveMetastoreCatalog.scala  14
-rw-r--r--  sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveQl.scala  15
-rw-r--r--  sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/InsertIntoHiveTable.scala  3
-rw-r--r--  sql/hive/v0.12.0/src/main/scala/org/apache/spark/sql/hive/Shim12.scala  22
-rw-r--r--  sql/hive/v0.13.1/src/main/scala/org/apache/spark/sql/hive/Shim13.scala  39
54 files changed, 1604 insertions, 229 deletions
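
The two headline changes are the optional precision/scale on DecimalType and the new mutable Decimal class that keeps small values in a Long. The snippet below is an illustration only, not part of the commit, written as it might be typed into a Scala REPL with the catalyst classes from this patch on the classpath; the variable names are made up.

    import org.apache.spark.sql.catalyst.types.DecimalType
    import org.apache.spark.sql.catalyst.types.decimal.Decimal

    // DecimalType now carries an optional precision/scale pair.
    val fixed     = DecimalType(10, 2)        // typeName "decimal(10,2)"
    val unlimited = DecimalType.Unlimited     // typeName "decimal"
    assert(fixed.typeName == "decimal(10,2)" && unlimited.typeName == "decimal")

    // Small values live in a mutable Long inside Decimal rather than in a BigDecimal.
    val compact = Decimal(12345L, 10, 2)      // unscaled 12345 at scale 2, i.e. 123.45
    assert(compact.toBigDecimal == BigDecimal("123.45"))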
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala
index 75923d9e8d..8fbdf664b7 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala
@@ -19,9 +19,10 @@ package org.apache.spark.sql.catalyst
import java.sql.{Date, Timestamp}
-import org.apache.spark.sql.catalyst.expressions.{GenericRow, Attribute, AttributeReference}
+import org.apache.spark.sql.catalyst.expressions.{GenericRow, Attribute, AttributeReference, Row}
import org.apache.spark.sql.catalyst.plans.logical.LocalRelation
import org.apache.spark.sql.catalyst.types._
+import org.apache.spark.sql.catalyst.types.decimal.Decimal
/**
* Provides experimental support for generating catalyst schemas for scala objects.
@@ -40,9 +41,20 @@ object ScalaReflection {
case s: Seq[_] => s.map(convertToCatalyst)
case m: Map[_, _] => m.map { case (k, v) => convertToCatalyst(k) -> convertToCatalyst(v) }
case p: Product => new GenericRow(p.productIterator.map(convertToCatalyst).toArray)
+ case d: BigDecimal => Decimal(d)
case other => other
}
+ /** Converts Catalyst types used internally in rows to standard Scala types */
+ def convertToScala(a: Any): Any = a match {
+ case s: Seq[_] => s.map(convertToScala)
+ case m: Map[_, _] => m.map { case (k, v) => convertToScala(k) -> convertToScala(v) }
+ case d: Decimal => d.toBigDecimal
+ case other => other
+ }
+
+ def convertRowToScala(r: Row): Row = new GenericRow(r.toArray.map(convertToScala))
+
/** Returns a Sequence of attributes for the given case class type. */
def attributesFor[T: TypeTag]: Seq[Attribute] = schemaFor[T] match {
case Schema(s: StructType, _) =>
@@ -83,7 +95,8 @@ object ScalaReflection {
case t if t <:< typeOf[String] => Schema(StringType, nullable = true)
case t if t <:< typeOf[Timestamp] => Schema(TimestampType, nullable = true)
case t if t <:< typeOf[Date] => Schema(DateType, nullable = true)
- case t if t <:< typeOf[BigDecimal] => Schema(DecimalType, nullable = true)
+ case t if t <:< typeOf[BigDecimal] => Schema(DecimalType.Unlimited, nullable = true)
+ case t if t <:< typeOf[Decimal] => Schema(DecimalType.Unlimited, nullable = true)
case t if t <:< typeOf[java.lang.Integer] => Schema(IntegerType, nullable = true)
case t if t <:< typeOf[java.lang.Long] => Schema(LongType, nullable = true)
case t if t <:< typeOf[java.lang.Double] => Schema(DoubleType, nullable = true)
@@ -111,8 +124,9 @@ object ScalaReflection {
case obj: LongType.JvmType => LongType
case obj: FloatType.JvmType => FloatType
case obj: DoubleType.JvmType => DoubleType
- case obj: DecimalType.JvmType => DecimalType
case obj: DateType.JvmType => DateType
+ case obj: BigDecimal => DecimalType.Unlimited
+ case obj: Decimal => DecimalType.Unlimited
case obj: TimestampType.JvmType => TimestampType
case null => NullType
// For other cases, there is no obvious mapping from the type of the given object to a
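
As an illustration only (not part of the commit), this is the round trip the two conversion helpers added above are meant to provide, assuming the one-argument signatures implied by the hunk:

    import org.apache.spark.sql.catalyst.ScalaReflection
    import org.apache.spark.sql.catalyst.types.decimal.Decimal

    // A user-facing BigDecimal becomes the internal Decimal representation...
    val internal = ScalaReflection.convertToCatalyst(BigDecimal("123.45"))
    assert(internal.isInstanceOf[Decimal])

    // ...and convertToScala turns it back into a scala.math.BigDecimal.
    assert(ScalaReflection.convertToScala(internal) == BigDecimal("123.45"))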
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/SqlParser.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/SqlParser.scala
index b1e7570f57..00fc4d75c9 100755
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/SqlParser.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/SqlParser.scala
@@ -52,11 +52,13 @@ class SqlParser extends AbstractSparkSQLParser {
protected val CASE = Keyword("CASE")
protected val CAST = Keyword("CAST")
protected val COUNT = Keyword("COUNT")
+ protected val DECIMAL = Keyword("DECIMAL")
protected val DESC = Keyword("DESC")
protected val DISTINCT = Keyword("DISTINCT")
protected val ELSE = Keyword("ELSE")
protected val END = Keyword("END")
protected val EXCEPT = Keyword("EXCEPT")
+ protected val DOUBLE = Keyword("DOUBLE")
protected val FALSE = Keyword("FALSE")
protected val FIRST = Keyword("FIRST")
protected val FROM = Keyword("FROM")
@@ -385,5 +387,15 @@ class SqlParser extends AbstractSparkSQLParser {
}
protected lazy val dataType: Parser[DataType] =
- STRING ^^^ StringType | TIMESTAMP ^^^ TimestampType
+ ( STRING ^^^ StringType
+ | TIMESTAMP ^^^ TimestampType
+ | DOUBLE ^^^ DoubleType
+ | fixedDecimalType
+ | DECIMAL ^^^ DecimalType.Unlimited
+ )
+
+ protected lazy val fixedDecimalType: Parser[DataType] =
+ (DECIMAL ~ "(" ~> numericLit) ~ ("," ~> numericLit <~ ")") ^^ {
+ case precision ~ scale => DecimalType(precision.toInt, scale.toInt)
+ }
}
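
For illustration only (not part of the commit), the extended dataType production is expected to map type strings onto the new catalyst types roughly as follows, using the DecimalType factories defined later in this patch:

    import org.apache.spark.sql.catalyst.types._

    // "STRING"         -> StringType
    // "DOUBLE"         -> DoubleType
    // "DECIMAL"        -> DecimalType.Unlimited   (no precision information)
    // "DECIMAL(10, 2)" -> DecimalType(10, 2)      (fixed precision and scale)
    assert(DecimalType(10, 2) == DecimalType(Some(PrecisionInfo(10, 2))))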
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercion.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercion.scala
index 2b69c02b28..e38114ab3c 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercion.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercion.scala
@@ -25,19 +25,31 @@ import org.apache.spark.sql.catalyst.types._
object HiveTypeCoercion {
// See https://cwiki.apache.org/confluence/display/Hive/LanguageManual+Types.
// The conversion for integral and floating point types have a linear widening hierarchy:
- val numericPrecedence =
- Seq(ByteType, ShortType, IntegerType, LongType, FloatType, DoubleType, DecimalType)
- val allPromotions: Seq[Seq[DataType]] = numericPrecedence :: Nil
+ private val numericPrecedence =
+ Seq(ByteType, ShortType, IntegerType, LongType, FloatType, DoubleType, DecimalType.Unlimited)
+ /**
+ * Find the tightest common type of two types that might be used in a binary expression.
+ * This handles all numeric types except fixed-precision decimals interacting with each other or
+ * with primitive types, because in that case the precision and scale of the result depends on
+ * the operation. Those rules are implemented in [[HiveTypeCoercion.DecimalPrecision]].
+ */
def findTightestCommonType(t1: DataType, t2: DataType): Option[DataType] = {
val valueTypes = Seq(t1, t2).filter(t => t != NullType)
if (valueTypes.distinct.size > 1) {
- // Try and find a promotion rule that contains both types in question.
- val applicableConversion =
- HiveTypeCoercion.allPromotions.find(p => p.contains(t1) && p.contains(t2))
-
- // If found return the widest common type, otherwise None
- applicableConversion.map(_.filter(t => t == t1 || t == t2).last)
+ // Promote numeric types to the highest of the two and all numeric types to unlimited decimal
+ if (numericPrecedence.contains(t1) && numericPrecedence.contains(t2)) {
+ Some(numericPrecedence.filter(t => t == t1 || t == t2).last)
+ } else if (t1.isInstanceOf[DecimalType] && t2.isInstanceOf[DecimalType]) {
+ // Fixed-precision decimals can up-cast into unlimited
+ if (t1 == DecimalType.Unlimited || t2 == DecimalType.Unlimited) {
+ Some(DecimalType.Unlimited)
+ } else {
+ None
+ }
+ } else {
+ None
+ }
} else {
Some(if (valueTypes.size == 0) NullType else valueTypes.head)
}
@@ -59,6 +71,7 @@ trait HiveTypeCoercion {
ConvertNaNs ::
WidenTypes ::
PromoteStrings ::
+ DecimalPrecision ::
BooleanComparisons ::
BooleanCasts ::
StringToIntegralCasts ::
@@ -151,6 +164,7 @@ trait HiveTypeCoercion {
import HiveTypeCoercion._
def apply(plan: LogicalPlan): LogicalPlan = plan transform {
+ // TODO: unions with fixed-precision decimals
case u @ Union(left, right) if u.childrenResolved && !u.resolved =>
val castedInput = left.output.zip(right.output).map {
// When a string is found on one side, make the other side a string too.
@@ -265,6 +279,110 @@ trait HiveTypeCoercion {
}
}
+ // scalastyle:off
+ /**
+ * Calculates and propagates precision for fixed-precision decimals. Hive has a number of
+ * rules for this based on the SQL standard and MS SQL:
+ * https://cwiki.apache.org/confluence/download/attachments/27362075/Hive_Decimal_Precision_Scale_Support.pdf
+ *
+ * In particular, if we have expressions e1 and e2 with precision/scale p1/s1 and p2/s2
+ * respectively, then the following operations have the following precision / scale:
+ *
+ * Operation Result Precision Result Scale
+ * ------------------------------------------------------------------------
+ * e1 + e2 max(s1, s2) + max(p1-s1, p2-s2) + 1 max(s1, s2)
+ * e1 - e2 max(s1, s2) + max(p1-s1, p2-s2) + 1 max(s1, s2)
+ * e1 * e2 p1 + p2 + 1 s1 + s2
+ * e1 / e2 p1 - s1 + s2 + max(6, s1 + p2 + 1) max(6, s1 + p2 + 1)
+ * e1 % e2 min(p1-s1, p2-s2) + max(s1, s2) max(s1, s2)
+ * sum(e1) p1 + 10 s1
+ * avg(e1) p1 + 4 s1 + 4
+ *
+ * Catalyst also has unlimited-precision decimals. For those, all ops return unlimited precision.
+ *
+ * To implement the rules for fixed-precision types, we introduce casts to turn them to unlimited
+ * precision, do the math on unlimited-precision numbers, then introduce casts back to the
+ * required fixed precision. This allows us to do all rounding and overflow handling in the
+ * cast-to-fixed-precision operator.
+ *
+ * In addition, when mixing non-decimal types with decimals, we use the following rules:
+ * - BYTE gets turned into DECIMAL(3, 0)
+ * - SHORT gets turned into DECIMAL(5, 0)
+ * - INT gets turned into DECIMAL(10, 0)
+ * - LONG gets turned into DECIMAL(20, 0)
+ * - FLOAT and DOUBLE cause fixed-length decimals to turn into DOUBLE (this is the same as Hive,
+ * but note that unlimited decimals are considered bigger than doubles in WidenTypes)
+ */
+ // scalastyle:on
+ object DecimalPrecision extends Rule[LogicalPlan] {
+ import scala.math.{max, min}
+
+ // Conversion rules for integer types into fixed-precision decimals
+ val intTypeToFixed: Map[DataType, DecimalType] = Map(
+ ByteType -> DecimalType(3, 0),
+ ShortType -> DecimalType(5, 0),
+ IntegerType -> DecimalType(10, 0),
+ LongType -> DecimalType(20, 0)
+ )
+
+ def isFloat(t: DataType): Boolean = t == FloatType || t == DoubleType
+
+ def apply(plan: LogicalPlan): LogicalPlan = plan transformAllExpressions {
+ // Skip nodes whose children have not been resolved yet
+ case e if !e.childrenResolved => e
+
+ case Add(e1 @ DecimalType.Expression(p1, s1), e2 @ DecimalType.Expression(p2, s2)) =>
+ Cast(
+ Add(Cast(e1, DecimalType.Unlimited), Cast(e2, DecimalType.Unlimited)),
+ DecimalType(max(s1, s2) + max(p1 - s1, p2 - s2) + 1, max(s1, s2))
+ )
+
+ case Subtract(e1 @ DecimalType.Expression(p1, s1), e2 @ DecimalType.Expression(p2, s2)) =>
+ Cast(
+ Subtract(Cast(e1, DecimalType.Unlimited), Cast(e2, DecimalType.Unlimited)),
+ DecimalType(max(s1, s2) + max(p1 - s1, p2 - s2) + 1, max(s1, s2))
+ )
+
+ case Multiply(e1 @ DecimalType.Expression(p1, s1), e2 @ DecimalType.Expression(p2, s2)) =>
+ Cast(
+ Multiply(Cast(e1, DecimalType.Unlimited), Cast(e2, DecimalType.Unlimited)),
+ DecimalType(p1 + p2 + 1, s1 + s2)
+ )
+
+ case Divide(e1 @ DecimalType.Expression(p1, s1), e2 @ DecimalType.Expression(p2, s2)) =>
+ Cast(
+ Divide(Cast(e1, DecimalType.Unlimited), Cast(e2, DecimalType.Unlimited)),
+ DecimalType(p1 - s1 + s2 + max(6, s1 + p2 + 1), max(6, s1 + p2 + 1))
+ )
+
+ case Remainder(e1 @ DecimalType.Expression(p1, s1), e2 @ DecimalType.Expression(p2, s2)) =>
+ Cast(
+ Remainder(Cast(e1, DecimalType.Unlimited), Cast(e2, DecimalType.Unlimited)),
+ DecimalType(min(p1 - s1, p2 - s2) + max(s1, s2), max(s1, s2))
+ )
+
+ // Promote integers inside a binary expression with fixed-precision decimals to decimals,
+ // and fixed-precision decimals in an expression with floats / doubles to doubles
+ case b: BinaryExpression if b.left.dataType != b.right.dataType =>
+ (b.left.dataType, b.right.dataType) match {
+ case (t, DecimalType.Fixed(p, s)) if intTypeToFixed.contains(t) =>
+ b.makeCopy(Array(Cast(b.left, intTypeToFixed(t)), b.right))
+ case (DecimalType.Fixed(p, s), t) if intTypeToFixed.contains(t) =>
+ b.makeCopy(Array(b.left, Cast(b.right, intTypeToFixed(t))))
+ case (t, DecimalType.Fixed(p, s)) if isFloat(t) =>
+ b.makeCopy(Array(b.left, Cast(b.right, DoubleType)))
+ case (DecimalType.Fixed(p, s), t) if isFloat(t) =>
+ b.makeCopy(Array(Cast(b.left, DoubleType), b.right))
+ case _ =>
+ b
+ }
+
+ // TODO: MaxOf, MinOf, etc might want other rules
+
+ // SUM and AVERAGE are handled by the implementations of those expressions
+ }
+ }
+
/**
* Changes Boolean values to Bytes so that expressions like true < false can be Evaluated.
*/
@@ -330,7 +448,7 @@ trait HiveTypeCoercion {
case e if !e.childrenResolved => e
case Cast(e @ StringType(), t: IntegralType) =>
- Cast(Cast(e, DecimalType), t)
+ Cast(Cast(e, DecimalType.Unlimited), t)
}
}
@@ -383,10 +501,12 @@ trait HiveTypeCoercion {
// Decimal and Double remain the same
case d: Divide if d.resolved && d.dataType == DoubleType => d
- case d: Divide if d.resolved && d.dataType == DecimalType => d
+ case d: Divide if d.resolved && d.dataType.isInstanceOf[DecimalType] => d
- case Divide(l, r) if l.dataType == DecimalType => Divide(l, Cast(r, DecimalType))
- case Divide(l, r) if r.dataType == DecimalType => Divide(Cast(l, DecimalType), r)
+ case Divide(l, r) if l.dataType.isInstanceOf[DecimalType] =>
+ Divide(l, Cast(r, DecimalType.Unlimited))
+ case Divide(l, r) if r.dataType.isInstanceOf[DecimalType] =>
+ Divide(Cast(l, DecimalType.Unlimited), r)
case Divide(l, r) => Divide(Cast(l, DoubleType), Cast(r, DoubleType))
}
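
As a worked example of the addition row in the table above (not part of the commit): adding a decimal(5,2) to a decimal(7,1) gives precision max(2,1) + max(5-2, 7-1) + 1 = 9 and scale max(2,1) = 2, so DecimalPrecision rewrites the Add into a cast of the unlimited-precision sum down to decimal(9,2).

    import scala.math.max

    val (p1, s1) = (5, 2)   // e1: decimal(5,2)
    val (p2, s2) = (7, 1)   // e2: decimal(7,1)
    val resultPrecision = max(s1, s2) + max(p1 - s1, p2 - s2) + 1   // 9
    val resultScale     = max(s1, s2)                               // 2
    // e1 + e2 becomes Cast(Add(Cast(e1, Unlimited), Cast(e2, Unlimited)), DecimalType(9, 2))
    assert(resultPrecision == 9 && resultScale == 2)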
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/dsl/package.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/dsl/package.scala
index 23cfd483ec..7e6d770314 100755
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/dsl/package.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/dsl/package.scala
@@ -19,6 +19,8 @@ package org.apache.spark.sql.catalyst
import java.sql.{Date, Timestamp}
+import org.apache.spark.sql.catalyst.types.decimal.Decimal
+
import scala.language.implicitConversions
import org.apache.spark.sql.catalyst.analysis.UnresolvedAttribute
@@ -124,7 +126,8 @@ package object dsl {
implicit def doubleToLiteral(d: Double) = Literal(d)
implicit def stringToLiteral(s: String) = Literal(s)
implicit def dateToLiteral(d: Date) = Literal(d)
- implicit def decimalToLiteral(d: BigDecimal) = Literal(d)
+ implicit def bigDecimalToLiteral(d: BigDecimal) = Literal(d)
+ implicit def decimalToLiteral(d: Decimal) = Literal(d)
implicit def timestampToLiteral(t: Timestamp) = Literal(t)
implicit def binaryToLiteral(a: Array[Byte]) = Literal(a)
@@ -183,7 +186,11 @@ package object dsl {
def date = AttributeReference(s, DateType, nullable = true)()
/** Creates a new AttributeReference of type decimal */
- def decimal = AttributeReference(s, DecimalType, nullable = true)()
+ def decimal = AttributeReference(s, DecimalType.Unlimited, nullable = true)()
+
+ /** Creates a new AttributeReference of type decimal */
+ def decimal(precision: Int, scale: Int) =
+ AttributeReference(s, DecimalType(precision, scale), nullable = true)()
/** Creates a new AttributeReference of type timestamp */
def timestamp = AttributeReference(s, TimestampType, nullable = true)()
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala
index 8e5baf0eb8..2200966619 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala
@@ -23,6 +23,7 @@ import java.text.{DateFormat, SimpleDateFormat}
import org.apache.spark.Logging
import org.apache.spark.sql.catalyst.errors.TreeNodeException
import org.apache.spark.sql.catalyst.types._
+import org.apache.spark.sql.catalyst.types.decimal.Decimal
/** Cast the child expression to the target data type. */
case class Cast(child: Expression, dataType: DataType) extends UnaryExpression with Logging {
@@ -36,6 +37,7 @@ case class Cast(child: Expression, dataType: DataType) extends UnaryExpression w
case (BooleanType, DateType) => true
case (DateType, _: NumericType) => true
case (DateType, BooleanType) => true
+ case (_, DecimalType.Fixed(_, _)) => true // TODO: not all upcasts here can really give null
case _ => child.nullable
}
@@ -76,8 +78,8 @@ case class Cast(child: Expression, dataType: DataType) extends UnaryExpression w
buildCast[Short](_, _ != 0)
case ByteType =>
buildCast[Byte](_, _ != 0)
- case DecimalType =>
- buildCast[BigDecimal](_, _ != 0)
+ case DecimalType() =>
+ buildCast[Decimal](_, _ != 0)
case DoubleType =>
buildCast[Double](_, _ != 0)
case FloatType =>
@@ -109,19 +111,19 @@ case class Cast(child: Expression, dataType: DataType) extends UnaryExpression w
case DateType =>
buildCast[Date](_, d => new Timestamp(d.getTime))
// TimestampWritable.decimalToTimestamp
- case DecimalType =>
- buildCast[BigDecimal](_, d => decimalToTimestamp(d))
+ case DecimalType() =>
+ buildCast[Decimal](_, d => decimalToTimestamp(d))
// TimestampWritable.doubleToTimestamp
case DoubleType =>
- buildCast[Double](_, d => decimalToTimestamp(d))
+ buildCast[Double](_, d => decimalToTimestamp(Decimal(d)))
// TimestampWritable.floatToTimestamp
case FloatType =>
- buildCast[Float](_, f => decimalToTimestamp(f))
+ buildCast[Float](_, f => decimalToTimestamp(Decimal(f)))
}
- private[this] def decimalToTimestamp(d: BigDecimal) = {
+ private[this] def decimalToTimestamp(d: Decimal) = {
val seconds = Math.floor(d.toDouble).toLong
- val bd = (d - seconds) * 1000000000
+ val bd = (d.toBigDecimal - seconds) * 1000000000
val nanos = bd.intValue()
val millis = seconds * 1000
@@ -196,8 +198,8 @@ case class Cast(child: Expression, dataType: DataType) extends UnaryExpression w
buildCast[Date](_, d => dateToLong(d))
case TimestampType =>
buildCast[Timestamp](_, t => timestampToLong(t))
- case DecimalType =>
- buildCast[BigDecimal](_, _.toLong)
+ case DecimalType() =>
+ buildCast[Decimal](_, _.toLong)
case x: NumericType =>
b => x.numeric.asInstanceOf[Numeric[Any]].toLong(b)
}
@@ -214,8 +216,8 @@ case class Cast(child: Expression, dataType: DataType) extends UnaryExpression w
buildCast[Date](_, d => dateToLong(d))
case TimestampType =>
buildCast[Timestamp](_, t => timestampToLong(t).toInt)
- case DecimalType =>
- buildCast[BigDecimal](_, _.toInt)
+ case DecimalType() =>
+ buildCast[Decimal](_, _.toInt)
case x: NumericType =>
b => x.numeric.asInstanceOf[Numeric[Any]].toInt(b)
}
@@ -232,8 +234,8 @@ case class Cast(child: Expression, dataType: DataType) extends UnaryExpression w
buildCast[Date](_, d => dateToLong(d))
case TimestampType =>
buildCast[Timestamp](_, t => timestampToLong(t).toShort)
- case DecimalType =>
- buildCast[BigDecimal](_, _.toShort)
+ case DecimalType() =>
+ buildCast[Decimal](_, _.toShort)
case x: NumericType =>
b => x.numeric.asInstanceOf[Numeric[Any]].toInt(b).toShort
}
@@ -250,27 +252,45 @@ case class Cast(child: Expression, dataType: DataType) extends UnaryExpression w
buildCast[Date](_, d => dateToLong(d))
case TimestampType =>
buildCast[Timestamp](_, t => timestampToLong(t).toByte)
- case DecimalType =>
- buildCast[BigDecimal](_, _.toByte)
+ case DecimalType() =>
+ buildCast[Decimal](_, _.toByte)
case x: NumericType =>
b => x.numeric.asInstanceOf[Numeric[Any]].toInt(b).toByte
}
- // DecimalConverter
- private[this] def castToDecimal: Any => Any = child.dataType match {
+ /**
+ * Change the precision / scale in a given decimal to those set in `decimalType` (if any),
+ * returning null if it overflows or modifying `value` in-place and returning it if successful.
+ *
+ * NOTE: this modifies `value` in-place, so don't call it on external data.
+ */
+ private[this] def changePrecision(value: Decimal, decimalType: DecimalType): Decimal = {
+ decimalType match {
+ case DecimalType.Unlimited =>
+ value
+ case DecimalType.Fixed(precision, scale) =>
+ if (value.changePrecision(precision, scale)) value else null
+ }
+ }
+
+ private[this] def castToDecimal(target: DecimalType): Any => Any = child.dataType match {
case StringType =>
- buildCast[String](_, s => try BigDecimal(s.toDouble) catch {
+ buildCast[String](_, s => try changePrecision(Decimal(s.toDouble), target) catch {
case _: NumberFormatException => null
})
case BooleanType =>
- buildCast[Boolean](_, b => if (b) BigDecimal(1) else BigDecimal(0))
+ buildCast[Boolean](_, b => changePrecision(if (b) Decimal(1) else Decimal(0), target))
case DateType =>
- buildCast[Date](_, d => dateToDouble(d))
+ buildCast[Date](_, d => changePrecision(null, target)) // date can't cast to decimal in Hive
case TimestampType =>
// Note that we lose precision here.
- buildCast[Timestamp](_, t => BigDecimal(timestampToDouble(t)))
- case x: NumericType =>
- b => BigDecimal(x.numeric.asInstanceOf[Numeric[Any]].toDouble(b))
+ buildCast[Timestamp](_, t => changePrecision(Decimal(timestampToDouble(t)), target))
+ case DecimalType() =>
+ b => changePrecision(b.asInstanceOf[Decimal].clone(), target)
+ case LongType =>
+ b => changePrecision(Decimal(b.asInstanceOf[Long]), target)
+ case x: NumericType => // All other numeric types can be represented precisely as Doubles
+ b => changePrecision(Decimal(x.numeric.asInstanceOf[Numeric[Any]].toDouble(b)), target)
}
// DoubleConverter
@@ -285,8 +305,8 @@ case class Cast(child: Expression, dataType: DataType) extends UnaryExpression w
buildCast[Date](_, d => dateToDouble(d))
case TimestampType =>
buildCast[Timestamp](_, t => timestampToDouble(t))
- case DecimalType =>
- buildCast[BigDecimal](_, _.toDouble)
+ case DecimalType() =>
+ buildCast[Decimal](_, _.toDouble)
case x: NumericType =>
b => x.numeric.asInstanceOf[Numeric[Any]].toDouble(b)
}
@@ -303,8 +323,8 @@ case class Cast(child: Expression, dataType: DataType) extends UnaryExpression w
buildCast[Date](_, d => dateToDouble(d))
case TimestampType =>
buildCast[Timestamp](_, t => timestampToDouble(t).toFloat)
- case DecimalType =>
- buildCast[BigDecimal](_, _.toFloat)
+ case DecimalType() =>
+ buildCast[Decimal](_, _.toFloat)
case x: NumericType =>
b => x.numeric.asInstanceOf[Numeric[Any]].toFloat(b)
}
@@ -313,8 +333,8 @@ case class Cast(child: Expression, dataType: DataType) extends UnaryExpression w
case dt if dt == child.dataType => identity[Any]
case StringType => castToString
case BinaryType => castToBinary
- case DecimalType => castToDecimal
case DateType => castToDate
+ case decimal: DecimalType => castToDecimal(decimal)
case TimestampType => castToTimestamp
case BooleanType => castToBoolean
case ByteType => castToByte
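
A REPL-style sketch (not part of the commit) of the overflow behaviour castToDecimal now relies on, exercising Decimal.changePrecision directly:

    import org.apache.spark.sql.catalyst.types.decimal.Decimal

    val d = Decimal(BigDecimal("123.456"))
    // Rounding into a wide enough target succeeds: 123.456 fits decimal(5,2) as 123.46.
    assert(d.clone().changePrecision(precision = 5, scale = 2))
    // A target too narrow for the value fails, which Cast then surfaces as null.
    assert(!d.clone().changePrecision(precision = 4, scale = 2))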
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregates.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregates.scala
index 1b4d892625..2b364fc1df 100755
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregates.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregates.scala
@@ -286,18 +286,38 @@ case class ApproxCountDistinct(child: Expression, relativeSD: Double = 0.05)
case class Average(child: Expression) extends PartialAggregate with trees.UnaryNode[Expression] {
override def nullable = false
- override def dataType = DoubleType
+
+ override def dataType = child.dataType match {
+ case DecimalType.Fixed(precision, scale) =>
+ DecimalType(precision + 4, scale + 4) // Add 4 digits after decimal point, like Hive
+ case DecimalType.Unlimited =>
+ DecimalType.Unlimited
+ case _ =>
+ DoubleType
+ }
+
override def toString = s"AVG($child)"
override def asPartial: SplitEvaluation = {
val partialSum = Alias(Sum(child), "PartialSum")()
val partialCount = Alias(Count(child), "PartialCount")()
- val castedSum = Cast(Sum(partialSum.toAttribute), dataType)
- val castedCount = Cast(Sum(partialCount.toAttribute), dataType)
- SplitEvaluation(
- Divide(castedSum, castedCount),
- partialCount :: partialSum :: Nil)
+ child.dataType match {
+ case DecimalType.Fixed(_, _) =>
+ // Turn the results to unlimited decimals for the division, before going back to fixed
+ val castedSum = Cast(Sum(partialSum.toAttribute), DecimalType.Unlimited)
+ val castedCount = Cast(Sum(partialCount.toAttribute), DecimalType.Unlimited)
+ SplitEvaluation(
+ Cast(Divide(castedSum, castedCount), dataType),
+ partialCount :: partialSum :: Nil)
+
+ case _ =>
+ val castedSum = Cast(Sum(partialSum.toAttribute), dataType)
+ val castedCount = Cast(Sum(partialCount.toAttribute), dataType)
+ SplitEvaluation(
+ Divide(castedSum, castedCount),
+ partialCount :: partialSum :: Nil)
+ }
}
override def newInstance() = new AverageFunction(child, this)
@@ -306,7 +326,16 @@ case class Average(child: Expression) extends PartialAggregate with trees.UnaryN
case class Sum(child: Expression) extends PartialAggregate with trees.UnaryNode[Expression] {
override def nullable = false
- override def dataType = child.dataType
+
+ override def dataType = child.dataType match {
+ case DecimalType.Fixed(precision, scale) =>
+ DecimalType(precision + 10, scale) // Add 10 digits left of decimal point, like Hive
+ case DecimalType.Unlimited =>
+ DecimalType.Unlimited
+ case _ =>
+ child.dataType
+ }
+
override def toString = s"SUM($child)"
override def asPartial: SplitEvaluation = {
@@ -322,9 +351,17 @@ case class Sum(child: Expression) extends PartialAggregate with trees.UnaryNode[
case class SumDistinct(child: Expression)
extends AggregateExpression with trees.UnaryNode[Expression] {
-
override def nullable = false
- override def dataType = child.dataType
+
+ override def dataType = child.dataType match {
+ case DecimalType.Fixed(precision, scale) =>
+ DecimalType(precision + 10, scale) // Add 10 digits left of decimal point, like Hive
+ case DecimalType.Unlimited =>
+ DecimalType.Unlimited
+ case _ =>
+ child.dataType
+ }
+
override def toString = s"SUM(DISTINCT $child)"
override def newInstance() = new SumDistinctFunction(child, this)
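
To make the new result types concrete (illustration only, not part of the commit): for a decimal(10,2) input, Sum and SumDistinct report decimal(20,2) and Average reports decimal(14,6); unlimited decimals stay unlimited, and all other numeric inputs keep the previous behaviour.

    import org.apache.spark.sql.catalyst.expressions.{AttributeReference, Average, Sum}
    import org.apache.spark.sql.catalyst.types.DecimalType

    val price = AttributeReference("price", DecimalType(10, 2), nullable = true)()
    assert(Sum(price).dataType == DecimalType(20, 2))      // 10 extra digits before the point
    assert(Average(price).dataType == DecimalType(14, 6))  // 4 extra digits on each side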
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala
index 83e8466ec2..8574cabc43 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala
@@ -36,7 +36,7 @@ case class UnaryMinus(child: Expression) extends UnaryExpression {
case class Sqrt(child: Expression) extends UnaryExpression {
type EvaluatedType = Any
-
+
def dataType = DoubleType
override def foldable = child.foldable
def nullable = child.nullable
@@ -55,7 +55,9 @@ abstract class BinaryArithmetic extends BinaryExpression {
def nullable = left.nullable || right.nullable
override lazy val resolved =
- left.resolved && right.resolved && left.dataType == right.dataType
+ left.resolved && right.resolved &&
+ left.dataType == right.dataType &&
+ !DecimalType.isFixed(left.dataType)
def dataType = {
if (!resolved) {
@@ -104,6 +106,8 @@ case class Multiply(left: Expression, right: Expression) extends BinaryArithmeti
case class Divide(left: Expression, right: Expression) extends BinaryArithmetic {
def symbol = "/"
+ override def nullable = left.nullable || right.nullable || dataType.isInstanceOf[DecimalType]
+
override def eval(input: Row): Any = dataType match {
case _: FractionalType => f2(input, left, right, _.div(_, _))
case _: IntegralType => i2(input, left , right, _.quot(_, _))
@@ -114,6 +118,8 @@ case class Divide(left: Expression, right: Expression) extends BinaryArithmetic
case class Remainder(left: Expression, right: Expression) extends BinaryArithmetic {
def symbol = "%"
+ override def nullable = left.nullable || right.nullable || dataType.isInstanceOf[DecimalType]
+
override def eval(input: Row): Any = i2(input, left, right, _.rem(_, _))
}
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala
index 5a3f013c34..67f8d411b6 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala
@@ -18,6 +18,7 @@
package org.apache.spark.sql.catalyst.expressions.codegen
import com.google.common.cache.{CacheLoader, CacheBuilder}
+import org.apache.spark.sql.catalyst.types.decimal.Decimal
import scala.language.existentials
@@ -485,6 +486,34 @@ abstract class CodeGenerator[InType <: AnyRef, OutType <: AnyRef] extends Loggin
}
""".children
+ case UnscaledValue(child) =>
+ val childEval = expressionEvaluator(child)
+
+ childEval.code ++
+ q"""
+ var $nullTerm = ${childEval.nullTerm}
+ var $primitiveTerm: Long = if (!$nullTerm) {
+ ${childEval.primitiveTerm}.toUnscaledLong
+ } else {
+ ${defaultPrimitive(LongType)}
+ }
+ """.children
+
+ case MakeDecimal(child, precision, scale) =>
+ val childEval = expressionEvaluator(child)
+
+ childEval.code ++
+ q"""
+ var $nullTerm = ${childEval.nullTerm}
+ var $primitiveTerm: org.apache.spark.sql.catalyst.types.decimal.Decimal =
+ ${defaultPrimitive(DecimalType())}
+
+ if (!$nullTerm) {
+ $primitiveTerm = new org.apache.spark.sql.catalyst.types.decimal.Decimal()
+ $primitiveTerm = $primitiveTerm.setOrNull(${childEval.primitiveTerm}, $precision, $scale)
+ $nullTerm = $primitiveTerm == null
+ }
+ """.children
}
// If there was no match in the partial function above, we fall back on calling the interpreted
@@ -562,7 +591,7 @@ abstract class CodeGenerator[InType <: AnyRef, OutType <: AnyRef] extends Loggin
case LongType => ru.Literal(Constant(1L))
case ByteType => ru.Literal(Constant(-1.toByte))
case DoubleType => ru.Literal(Constant(-1.toDouble))
- case DecimalType => ru.Literal(Constant(-1)) // Will get implicity converted as needed.
+ case DecimalType() => q"org.apache.spark.sql.catalyst.types.decimal.Decimal(-1)"
case IntegerType => ru.Literal(Constant(-1))
case _ => ru.Literal(Constant(null))
}
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/decimalFunctions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/decimalFunctions.scala
new file mode 100644
index 0000000000..d1eab2eb4e
--- /dev/null
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/decimalFunctions.scala
@@ -0,0 +1,59 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.catalyst.expressions
+
+import org.apache.spark.sql.catalyst.types.decimal.Decimal
+import org.apache.spark.sql.catalyst.types.{DecimalType, LongType, DoubleType, DataType}
+
+/** Return the unscaled Long value of a Decimal, assuming it fits in a Long */
+case class UnscaledValue(child: Expression) extends UnaryExpression {
+ override type EvaluatedType = Any
+
+ override def dataType: DataType = LongType
+ override def foldable = child.foldable
+ def nullable = child.nullable
+ override def toString = s"UnscaledValue($child)"
+
+ override def eval(input: Row): Any = {
+ val childResult = child.eval(input)
+ if (childResult == null) {
+ null
+ } else {
+ childResult.asInstanceOf[Decimal].toUnscaledLong
+ }
+ }
+}
+
+/** Create a Decimal from an unscaled Long value */
+case class MakeDecimal(child: Expression, precision: Int, scale: Int) extends UnaryExpression {
+ override type EvaluatedType = Decimal
+
+ override def dataType: DataType = DecimalType(precision, scale)
+ override def foldable = child.foldable
+ def nullable = child.nullable
+ override def toString = s"MakeDecimal($child,$precision,$scale)"
+
+ override def eval(input: Row): Decimal = {
+ val childResult = child.eval(input)
+ if (childResult == null) {
+ null
+ } else {
+ new Decimal().setOrNull(childResult.asInstanceOf[Long], precision, scale)
+ }
+ }
+}
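
A REPL-style sketch (not part of the commit) of the round trip the two new expressions provide when evaluated against literals:

    import org.apache.spark.sql.catalyst.expressions.{Literal, MakeDecimal, UnscaledValue}
    import org.apache.spark.sql.catalyst.types.decimal.Decimal

    val lit = Literal(Decimal(12345L, 10, 2))                          // 123.45 as decimal(10,2)
    val unscaled = UnscaledValue(lit).eval(null).asInstanceOf[Long]    // 12345, the raw unscaled Long
    assert(unscaled == 12345L)
    assert(MakeDecimal(Literal(unscaled), 10, 2).eval(null) == Decimal(12345L, 10, 2))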
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/literals.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/literals.scala
index ba240233ca..93c1932515 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/literals.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/literals.scala
@@ -20,6 +20,7 @@ package org.apache.spark.sql.catalyst.expressions
import java.sql.{Date, Timestamp}
import org.apache.spark.sql.catalyst.types._
+import org.apache.spark.sql.catalyst.types.decimal.Decimal
object Literal {
def apply(v: Any): Literal = v match {
@@ -31,7 +32,8 @@ object Literal {
case s: Short => Literal(s, ShortType)
case s: String => Literal(s, StringType)
case b: Boolean => Literal(b, BooleanType)
- case d: BigDecimal => Literal(d, DecimalType)
+ case d: BigDecimal => Literal(Decimal(d), DecimalType.Unlimited)
+ case d: Decimal => Literal(d, DecimalType.Unlimited)
case t: Timestamp => Literal(t, TimestampType)
case d: Date => Literal(d, DateType)
case a: Array[Byte] => Literal(a, BinaryType)
@@ -62,7 +64,7 @@ case class Literal(value: Any, dataType: DataType) extends LeafExpression {
}
// TODO: Specialize
-case class MutableLiteral(var value: Any, dataType: DataType, nullable: Boolean = true)
+case class MutableLiteral(var value: Any, dataType: DataType, nullable: Boolean = true)
extends LeafExpression {
type EvaluatedType = Any
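
For illustration only (not part of the commit), both external BigDecimals and internal Decimals now become literals of unlimited DecimalType, with the BigDecimal wrapped in a Decimal on the way in:

    import org.apache.spark.sql.catalyst.expressions.Literal
    import org.apache.spark.sql.catalyst.types.DecimalType
    import org.apache.spark.sql.catalyst.types.decimal.Decimal

    assert(Literal(BigDecimal("1.5")).dataType == DecimalType.Unlimited)
    assert(Literal(Decimal(15L, 2, 1)).dataType == DecimalType.Unlimited)
    assert(Literal(BigDecimal("1.5")).value == Decimal(BigDecimal("1.5")))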
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala
index 9ce7c78195..a4aa322fc5 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala
@@ -27,6 +27,7 @@ import org.apache.spark.sql.catalyst.plans.LeftSemi
import org.apache.spark.sql.catalyst.plans.logical._
import org.apache.spark.sql.catalyst.rules._
import org.apache.spark.sql.catalyst.types._
+import org.apache.spark.sql.catalyst.types.decimal.Decimal
abstract class Optimizer extends RuleExecutor[LogicalPlan]
@@ -43,6 +44,8 @@ object DefaultOptimizer extends Optimizer {
SimplifyCasts,
SimplifyCaseConversionExpressions,
OptimizeIn) ::
+ Batch("Decimal Optimizations", FixedPoint(100),
+ DecimalAggregates) ::
Batch("Filter Pushdown", FixedPoint(100),
UnionPushdown,
CombineFilters,
@@ -390,9 +393,9 @@ object PushPredicateThroughProject extends Rule[LogicalPlan] {
* evaluated using only the attributes of the left or right side of a join. Other
* [[Filter]] conditions are moved into the `condition` of the [[Join]].
*
- * And also Pushes down the join filter, where the `condition` can be evaluated using only the
- * attributes of the left or right side of sub query when applicable.
- *
+ * And also Pushes down the join filter, where the `condition` can be evaluated using only the
+ * attributes of the left or right side of sub query when applicable.
+ *
* Check https://cwiki.apache.org/confluence/display/Hive/OuterJoinBehavior for more details
*/
object PushPredicateThroughJoin extends Rule[LogicalPlan] with PredicateHelper {
@@ -404,7 +407,7 @@ object PushPredicateThroughJoin extends Rule[LogicalPlan] with PredicateHelper {
private def split(condition: Seq[Expression], left: LogicalPlan, right: LogicalPlan) = {
val (leftEvaluateCondition, rest) =
condition.partition(_.references subsetOf left.outputSet)
- val (rightEvaluateCondition, commonCondition) =
+ val (rightEvaluateCondition, commonCondition) =
rest.partition(_.references subsetOf right.outputSet)
(leftEvaluateCondition, rightEvaluateCondition, commonCondition)
@@ -413,7 +416,7 @@ object PushPredicateThroughJoin extends Rule[LogicalPlan] with PredicateHelper {
def apply(plan: LogicalPlan): LogicalPlan = plan transform {
// push the where condition down into join filter
case f @ Filter(filterCondition, Join(left, right, joinType, joinCondition)) =>
- val (leftFilterConditions, rightFilterConditions, commonFilterCondition) =
+ val (leftFilterConditions, rightFilterConditions, commonFilterCondition) =
split(splitConjunctivePredicates(filterCondition), left, right)
joinType match {
@@ -451,7 +454,7 @@ object PushPredicateThroughJoin extends Rule[LogicalPlan] with PredicateHelper {
// push down the join filter into sub query scanning if applicable
case f @ Join(left, right, joinType, joinCondition) =>
- val (leftJoinConditions, rightJoinConditions, commonJoinCondition) =
+ val (leftJoinConditions, rightJoinConditions, commonJoinCondition) =
split(joinCondition.map(splitConjunctivePredicates).getOrElse(Nil), left, right)
joinType match {
@@ -519,3 +522,26 @@ object SimplifyCaseConversionExpressions extends Rule[LogicalPlan] {
}
}
}
+
+/**
+ * Speeds up aggregates on fixed-precision decimals by executing them on unscaled Long values.
+ *
+ * This uses the same rules for increasing the precision and scale of the output as
+ * [[org.apache.spark.sql.catalyst.analysis.HiveTypeCoercion.DecimalPrecision]].
+ */
+object DecimalAggregates extends Rule[LogicalPlan] {
+ import Decimal.MAX_LONG_DIGITS
+
+ /** Maximum number of decimal digits representable precisely in a Double */
+ val MAX_DOUBLE_DIGITS = 15
+
+ def apply(plan: LogicalPlan): LogicalPlan = plan transformAllExpressions {
+ case Sum(e @ DecimalType.Expression(prec, scale)) if prec + 10 <= MAX_LONG_DIGITS =>
+ MakeDecimal(Sum(UnscaledValue(e)), prec + 10, scale)
+
+ case Average(e @ DecimalType.Expression(prec, scale)) if prec + 4 <= MAX_DOUBLE_DIGITS =>
+ Cast(
+ Divide(Average(UnscaledValue(e)), Literal(math.pow(10.0, scale), DoubleType)),
+ DecimalType(prec + 4, scale + 4))
+ }
+}
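
As an illustration of when the rewrite fires (not part of the commit): it only applies while the widened precision still fits in a Long, i.e. at most Decimal.MAX_LONG_DIGITS = 18 digits.

    import org.apache.spark.sql.catalyst.types.decimal.Decimal

    // Sum over decimal(8,2):  8 + 10 = 18 <= MAX_LONG_DIGITS, so
    //   Sum(e)  is rewritten to  MakeDecimal(Sum(UnscaledValue(e)), 18, 2)
    // Sum over decimal(10,2): 10 + 10 = 20 > 18, so the expression is left untouched.
    assert(Decimal.MAX_LONG_DIGITS == 18)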
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/types/dataTypes.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/types/dataTypes.scala
index 6069f9b0a6..8dda0b1828 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/types/dataTypes.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/types/dataTypes.scala
@@ -19,7 +19,7 @@ package org.apache.spark.sql.catalyst.types
import java.sql.{Date, Timestamp}
-import scala.math.Numeric.{BigDecimalAsIfIntegral, DoubleAsIfIntegral, FloatAsIfIntegral}
+import scala.math.Numeric.{FloatAsIfIntegral, BigDecimalAsIfIntegral, DoubleAsIfIntegral}
import scala.reflect.ClassTag
import scala.reflect.runtime.universe.{TypeTag, runtimeMirror, typeTag}
import scala.util.parsing.combinator.RegexParsers
@@ -33,6 +33,7 @@ import org.apache.spark.sql.catalyst.ScalaReflectionLock
import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeReference, Expression}
import org.apache.spark.sql.catalyst.util.Metadata
import org.apache.spark.util.Utils
+import org.apache.spark.sql.catalyst.types.decimal._
object DataType {
def fromJson(json: String): DataType = parseDataType(parse(json))
@@ -91,11 +92,17 @@ object DataType {
| "LongType" ^^^ LongType
| "BinaryType" ^^^ BinaryType
| "BooleanType" ^^^ BooleanType
- | "DecimalType" ^^^ DecimalType
| "DateType" ^^^ DateType
+ | "DecimalType()" ^^^ DecimalType.Unlimited
+ | fixedDecimalType
| "TimestampType" ^^^ TimestampType
)
+ protected lazy val fixedDecimalType: Parser[DataType] =
+ ("DecimalType(" ~> "[0-9]+".r) ~ ("," ~> "[0-9]+".r <~ ")") ^^ {
+ case precision ~ scale => DecimalType(precision.toInt, scale.toInt)
+ }
+
protected lazy val arrayType: Parser[DataType] =
"ArrayType" ~> "(" ~> dataType ~ "," ~ boolVal <~ ")" ^^ {
case tpe ~ _ ~ containsNull => ArrayType(tpe, containsNull)
@@ -200,10 +207,18 @@ trait PrimitiveType extends DataType {
}
object PrimitiveType {
- private[sql] val all = Seq(DecimalType, DateType, TimestampType, BinaryType) ++
- NativeType.all
-
- private[sql] val nameToType = all.map(t => t.typeName -> t).toMap
+ private val nonDecimals = Seq(DateType, TimestampType, BinaryType) ++ NativeType.all
+ private val nonDecimalNameToType = nonDecimals.map(t => t.typeName -> t).toMap
+
+ /** Given the string representation of a type, return its DataType */
+ private[sql] def nameToType(name: String): DataType = {
+ val FIXED_DECIMAL = """decimal\(\s*(\d+)\s*,\s*(\d+)\s*\)""".r
+ name match {
+ case "decimal" => DecimalType.Unlimited
+ case FIXED_DECIMAL(precision, scale) => DecimalType(precision.toInt, scale.toInt)
+ case other => nonDecimalNameToType(other)
+ }
+ }
}
abstract class NativeType extends DataType {
@@ -332,13 +347,58 @@ abstract class FractionalType extends NumericType {
private[sql] val asIntegral: Integral[JvmType]
}
-case object DecimalType extends FractionalType {
- private[sql] type JvmType = BigDecimal
+/** Precision parameters for a Decimal */
+case class PrecisionInfo(precision: Int, scale: Int)
+
+/** A Decimal that might have fixed precision and scale, or unlimited values for these */
+case class DecimalType(precisionInfo: Option[PrecisionInfo]) extends FractionalType {
+ private[sql] type JvmType = Decimal
@transient private[sql] lazy val tag = ScalaReflectionLock.synchronized { typeTag[JvmType] }
- private[sql] val numeric = implicitly[Numeric[BigDecimal]]
- private[sql] val fractional = implicitly[Fractional[BigDecimal]]
- private[sql] val ordering = implicitly[Ordering[JvmType]]
- private[sql] val asIntegral = BigDecimalAsIfIntegral
+ private[sql] val numeric = Decimal.DecimalIsFractional
+ private[sql] val fractional = Decimal.DecimalIsFractional
+ private[sql] val ordering = Decimal.DecimalIsFractional
+ private[sql] val asIntegral = Decimal.DecimalAsIfIntegral
+
+ override def typeName: String = precisionInfo match {
+ case Some(PrecisionInfo(precision, scale)) => s"decimal($precision,$scale)"
+ case None => "decimal"
+ }
+
+ override def toString: String = precisionInfo match {
+ case Some(PrecisionInfo(precision, scale)) => s"DecimalType($precision,$scale)"
+ case None => "DecimalType()"
+ }
+}
+
+/** Extra factory methods and pattern matchers for Decimals */
+object DecimalType {
+ val Unlimited: DecimalType = DecimalType(None)
+
+ object Fixed {
+ def unapply(t: DecimalType): Option[(Int, Int)] =
+ t.precisionInfo.map(p => (p.precision, p.scale))
+ }
+
+ object Expression {
+ def unapply(e: Expression): Option[(Int, Int)] = e.dataType match {
+ case t: DecimalType => t.precisionInfo.map(p => (p.precision, p.scale))
+ case _ => None
+ }
+ }
+
+ def apply(): DecimalType = Unlimited
+
+ def apply(precision: Int, scale: Int): DecimalType =
+ DecimalType(Some(PrecisionInfo(precision, scale)))
+
+ def unapply(t: DataType): Boolean = t.isInstanceOf[DecimalType]
+
+ def unapply(e: Expression): Boolean = e.dataType.isInstanceOf[DecimalType]
+
+ def isFixed(dataType: DataType): Boolean = dataType match {
+ case DecimalType.Fixed(_, _) => true
+ case _ => false
+ }
}
case object DoubleType extends FractionalType {
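
A REPL-style sketch (not part of the commit) of how the new extractors are meant to be used; DecimalType.Fixed only matches when precision information is present, while DecimalType() matches any decimal:

    import org.apache.spark.sql.catalyst.types._

    def describe(t: DataType): String = t match {
      case DecimalType.Fixed(p, s) => s"fixed decimal with precision $p and scale $s"
      case DecimalType()           => "decimal with unlimited precision"
      case other                   => s"not a decimal: $other"
    }

    assert(describe(DecimalType(10, 2)) == "fixed decimal with precision 10 and scale 2")
    assert(describe(DecimalType.Unlimited) == "decimal with unlimited precision")
    assert(describe(DoubleType) == "not a decimal: DoubleType")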
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/types/decimal/Decimal.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/types/decimal/Decimal.scala
new file mode 100644
index 0000000000..708362acf3
--- /dev/null
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/types/decimal/Decimal.scala
@@ -0,0 +1,335 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.catalyst.types.decimal
+
+import org.apache.spark.annotation.DeveloperApi
+
+/**
+ * A mutable implementation of BigDecimal that can hold a Long if values are small enough.
+ *
+ * The semantics of the fields are as follows:
+ * - _precision and _scale represent the SQL precision and scale we are looking for
+ * - If decimalVal is set, it represents the whole decimal value
+ * - Otherwise, the decimal value is longVal / (10 ** _scale)
+ */
+final class Decimal extends Ordered[Decimal] with Serializable {
+ import Decimal.{MAX_LONG_DIGITS, POW_10, ROUNDING_MODE, BIG_DEC_ZERO}
+
+ private var decimalVal: BigDecimal = null
+ private var longVal: Long = 0L
+ private var _precision: Int = 1
+ private var _scale: Int = 0
+
+ def precision: Int = _precision
+ def scale: Int = _scale
+
+ /**
+ * Set this Decimal to the given Long. Will have precision 20 and scale 0.
+ */
+ def set(longVal: Long): Decimal = {
+ if (longVal <= -POW_10(MAX_LONG_DIGITS) || longVal >= POW_10(MAX_LONG_DIGITS)) {
+ // We can't represent this compactly as a long without risking overflow
+ this.decimalVal = BigDecimal(longVal)
+ this.longVal = 0L
+ } else {
+ this.decimalVal = null
+ this.longVal = longVal
+ }
+ this._precision = 20
+ this._scale = 0
+ this
+ }
+
+ /**
+ * Set this Decimal to the given Int. Will have precision 10 and scale 0.
+ */
+ def set(intVal: Int): Decimal = {
+ this.decimalVal = null
+ this.longVal = intVal
+ this._precision = 10
+ this._scale = 0
+ this
+ }
+
+ /**
+ * Set this Decimal to the given unscaled Long, with a given precision and scale.
+ */
+ def set(unscaled: Long, precision: Int, scale: Int): Decimal = {
+ if (setOrNull(unscaled, precision, scale) == null) {
+ throw new IllegalArgumentException("Unscaled value too large for precision")
+ }
+ this
+ }
+
+ /**
+ * Set this Decimal to the given unscaled Long, with a given precision and scale,
+ * and return it, or return null if it cannot be set due to overflow.
+ */
+ def setOrNull(unscaled: Long, precision: Int, scale: Int): Decimal = {
+ if (unscaled <= -POW_10(MAX_LONG_DIGITS) || unscaled >= POW_10(MAX_LONG_DIGITS)) {
+ // We can't represent this compactly as a long without risking overflow
+ if (precision < 19) {
+ return null // Requested precision is too low to represent this value
+ }
+ this.decimalVal = BigDecimal(longVal)
+ this.longVal = 0L
+ } else {
+ val p = POW_10(math.min(precision, MAX_LONG_DIGITS))
+ if (unscaled <= -p || unscaled >= p) {
+ return null // Requested precision is too low to represent this value
+ }
+ this.decimalVal = null
+ this.longVal = unscaled
+ }
+ this._precision = precision
+ this._scale = scale
+ this
+ }
+
+ /**
+ * Set this Decimal to the given BigDecimal value, with a given precision and scale.
+ */
+ def set(decimal: BigDecimal, precision: Int, scale: Int): Decimal = {
+ this.decimalVal = decimal.setScale(scale, ROUNDING_MODE)
+ require(decimalVal.precision <= precision, "Overflowed precision")
+ this.longVal = 0L
+ this._precision = precision
+ this._scale = scale
+ this
+ }
+
+ /**
+ * Set this Decimal to the given BigDecimal value, inheriting its precision and scale.
+ */
+ def set(decimal: BigDecimal): Decimal = {
+ this.decimalVal = decimal
+ this.longVal = 0L
+ this._precision = decimal.precision
+ this._scale = decimal.scale
+ this
+ }
+
+ /**
+ * Set this Decimal to the given Decimal value.
+ */
+ def set(decimal: Decimal): Decimal = {
+ this.decimalVal = decimal.decimalVal
+ this.longVal = decimal.longVal
+ this._precision = decimal._precision
+ this._scale = decimal._scale
+ this
+ }
+
+ def toBigDecimal: BigDecimal = {
+ if (decimalVal.ne(null)) {
+ decimalVal
+ } else {
+ BigDecimal(longVal, _scale)
+ }
+ }
+
+ def toUnscaledLong: Long = {
+ if (decimalVal.ne(null)) {
+ decimalVal.underlying().unscaledValue().longValue()
+ } else {
+ longVal
+ }
+ }
+
+ override def toString: String = toBigDecimal.toString()
+
+ @DeveloperApi
+ def toDebugString: String = {
+ if (decimalVal.ne(null)) {
+ s"Decimal(expanded,$decimalVal,$precision,$scale)"
+ } else {
+ s"Decimal(compact,$longVal,$precision,$scale)"
+ }
+ }
+
+ def toDouble: Double = toBigDecimal.doubleValue()
+
+ def toFloat: Float = toBigDecimal.floatValue()
+
+ def toLong: Long = {
+ if (decimalVal.eq(null)) {
+ longVal / POW_10(_scale)
+ } else {
+ decimalVal.longValue()
+ }
+ }
+
+ def toInt: Int = toLong.toInt
+
+ def toShort: Short = toLong.toShort
+
+ def toByte: Byte = toLong.toByte
+
+ /**
+ * Update precision and scale while keeping our value the same, and return true if successful.
+ *
+ * @return true if successful, false if overflow would occur
+ */
+ def changePrecision(precision: Int, scale: Int): Boolean = {
+ // First, update our longVal if we can, or transfer over to using a BigDecimal
+ if (decimalVal.eq(null)) {
+ if (scale < _scale) {
+ // Easier case: we just need to divide our scale down
+ val diff = _scale - scale
+ val droppedDigits = longVal % POW_10(diff)
+ longVal /= POW_10(diff)
+ if (math.abs(droppedDigits) * 2 >= POW_10(diff)) {
+ longVal += (if (longVal < 0) -1L else 1L)
+ }
+ } else if (scale > _scale) {
+ // We might be able to multiply longVal by a power of 10 and not overflow, but if not,
+ // switch to using a BigDecimal
+ val diff = scale - _scale
+ val p = POW_10(math.max(MAX_LONG_DIGITS - diff, 0))
+ if (diff <= MAX_LONG_DIGITS && longVal > -p && longVal < p) {
+ // Multiplying longVal by POW_10(diff) will still keep it below MAX_LONG_DIGITS
+ longVal *= POW_10(diff)
+ } else {
+ // Give up on using Longs; switch to BigDecimal, which we'll modify below
+ decimalVal = BigDecimal(longVal, _scale)
+ }
+ }
+ // In both cases, we will check whether our precision is okay below
+ }
+
+ if (decimalVal.ne(null)) {
+ // We get here if either we started with a BigDecimal, or we switched to one because we would
+ // have overflowed our Long; in either case we must rescale decimalVal to the new scale.
+ val newVal = decimalVal.setScale(scale, ROUNDING_MODE)
+ if (newVal.precision > precision) {
+ return false
+ }
+ decimalVal = newVal
+ } else {
+ // We're still using Longs, but we should check whether we match the new precision
+ val p = POW_10(math.min(precision, MAX_LONG_DIGITS))
+ if (longVal <= -p || longVal >= p) {
+ // Note that we shouldn't have been able to fix this by switching to BigDecimal
+ return false
+ }
+ }
+
+ _precision = precision
+ _scale = scale
+ true
+ }
+
+ override def clone(): Decimal = new Decimal().set(this)
+
+ override def compare(other: Decimal): Int = {
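+ // Fast path: two compact values with the same scale can be compared directly on their
+ // unscaled Longs; otherwise fall back to comparing BigDecimals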
+ if (decimalVal.eq(null) && other.decimalVal.eq(null) && _scale == other._scale) {
+ if (longVal < other.longVal) -1 else if (longVal == other.longVal) 0 else 1
+ } else {
+ toBigDecimal.compare(other.toBigDecimal)
+ }
+ }
+
+ override def equals(other: Any) = other match {
+ case d: Decimal =>
+ compare(d) == 0
+ case _ =>
+ false
+ }
+
+ override def hashCode(): Int = toBigDecimal.hashCode()
+
+ def isZero: Boolean = if (decimalVal.ne(null)) decimalVal == BIG_DEC_ZERO else longVal == 0
+
+ def + (that: Decimal): Decimal = Decimal(toBigDecimal + that.toBigDecimal)
+
+ def - (that: Decimal): Decimal = Decimal(toBigDecimal - that.toBigDecimal)
+
+ def * (that: Decimal): Decimal = Decimal(toBigDecimal * that.toBigDecimal)
+
+ def / (that: Decimal): Decimal =
+ if (that.isZero) null else Decimal(toBigDecimal / that.toBigDecimal)
+
+ def % (that: Decimal): Decimal =
+ if (that.isZero) null else Decimal(toBigDecimal % that.toBigDecimal)
+
+ def remainder(that: Decimal): Decimal = this % that
+
+ def unary_- : Decimal = {
+ if (decimalVal.ne(null)) {
+ Decimal(-decimalVal)
+ } else {
+ Decimal(-longVal, precision, scale)
+ }
+ }
+}
+
+object Decimal {
+ private val ROUNDING_MODE = BigDecimal.RoundingMode.HALF_UP
+
+ /** Maximum number of decimal digits a Long can represent */
+ val MAX_LONG_DIGITS = 18
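+ // Unscaled values of up to 18 digits (at most 999999999999999999) are stored compactly in
+ // longVal; anything with 19 digits, e.g. Long.MaxValue, falls back to a BigDecimal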
+
+ private val POW_10 = Array.tabulate[Long](MAX_LONG_DIGITS + 1)(i => math.pow(10, i).toLong)
+
+ private val BIG_DEC_ZERO = BigDecimal(0)
+
+ def apply(value: Double): Decimal = new Decimal().set(value)
+
+ def apply(value: Long): Decimal = new Decimal().set(value)
+
+ def apply(value: Int): Decimal = new Decimal().set(value)
+
+ def apply(value: BigDecimal): Decimal = new Decimal().set(value)
+
+ def apply(value: BigDecimal, precision: Int, scale: Int): Decimal =
+ new Decimal().set(value, precision, scale)
+
+ def apply(unscaled: Long, precision: Int, scale: Int): Decimal =
+ new Decimal().set(unscaled, precision, scale)
+
+ def apply(value: String): Decimal = new Decimal().set(BigDecimal(value))
+
+ // Evidence parameters for Decimal considered either as Fractional or Integral. We provide two
+ // parameters inheriting from a common trait since both traits define mkNumericOps.
+ // See scala.math's Numeric.scala for examples for Scala's built-in types.
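+ // (a single object cannot extend both Fractional and Integral because their mkNumericOps
+ // definitions would clash, hence the two separate objects below sharing this trait)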
+
+ /** Common methods for Decimal evidence parameters */
+ trait DecimalIsConflicted extends Numeric[Decimal] {
+ override def plus(x: Decimal, y: Decimal): Decimal = x + y
+ override def times(x: Decimal, y: Decimal): Decimal = x * y
+ override def minus(x: Decimal, y: Decimal): Decimal = x - y
+ override def negate(x: Decimal): Decimal = -x
+ override def toDouble(x: Decimal): Double = x.toDouble
+ override def toFloat(x: Decimal): Float = x.toFloat
+ override def toInt(x: Decimal): Int = x.toInt
+ override def toLong(x: Decimal): Long = x.toLong
+ override def fromInt(x: Int): Decimal = new Decimal().set(x)
+ override def compare(x: Decimal, y: Decimal): Int = x.compare(y)
+ }
+
+ /** A [[scala.math.Fractional]] evidence parameter for Decimals. */
+ object DecimalIsFractional extends DecimalIsConflicted with Fractional[Decimal] {
+ override def div(x: Decimal, y: Decimal): Decimal = x / y
+ }
+
+ /** A [[scala.math.Integral]] evidence parameter for Decimals. */
+ object DecimalAsIfIntegral extends DecimalIsConflicted with Integral[Decimal] {
+ override def quot(x: Decimal, y: Decimal): Decimal = x / y
+ override def rem(x: Decimal, y: Decimal): Decimal = x % y
+ }
+}
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/ScalaReflectionSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/ScalaReflectionSuite.scala
index 430f0664b7..21b2c8e20d 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/ScalaReflectionSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/ScalaReflectionSuite.scala
@@ -96,7 +96,7 @@ class ScalaReflectionSuite extends FunSuite {
StructField("byteField", ByteType, nullable = true),
StructField("booleanField", BooleanType, nullable = true),
StructField("stringField", StringType, nullable = true),
- StructField("decimalField", DecimalType, nullable = true),
+ StructField("decimalField", DecimalType.Unlimited, nullable = true),
StructField("dateField", DateType, nullable = true),
StructField("timestampField", TimestampType, nullable = true),
StructField("binaryField", BinaryType, nullable = true))),
@@ -199,7 +199,7 @@ class ScalaReflectionSuite extends FunSuite {
assert(DoubleType === typeOfObject(1.7976931348623157E308))
// DecimalType
- assert(DecimalType === typeOfObject(BigDecimal("1.7976931348623157E318")))
+ assert(DecimalType.Unlimited === typeOfObject(BigDecimal("1.7976931348623157E318")))
// DateType
assert(DateType === typeOfObject(Date.valueOf("2014-07-25")))
@@ -211,19 +211,19 @@ class ScalaReflectionSuite extends FunSuite {
assert(NullType === typeOfObject(null))
def typeOfObject1: PartialFunction[Any, DataType] = typeOfObject orElse {
- case value: java.math.BigInteger => DecimalType
- case value: java.math.BigDecimal => DecimalType
+ case value: java.math.BigInteger => DecimalType.Unlimited
+ case value: java.math.BigDecimal => DecimalType.Unlimited
case _ => StringType
}
- assert(DecimalType === typeOfObject1(
+ assert(DecimalType.Unlimited === typeOfObject1(
new BigInteger("92233720368547758070")))
- assert(DecimalType === typeOfObject1(
+ assert(DecimalType.Unlimited === typeOfObject1(
new java.math.BigDecimal("1.7976931348623157E318")))
assert(StringType === typeOfObject1(BigInt("92233720368547758070")))
def typeOfObject2: PartialFunction[Any, DataType] = typeOfObject orElse {
- case value: java.math.BigInteger => DecimalType
+ case value: java.math.BigInteger => DecimalType.Unlimited
}
intercept[MatchError](typeOfObject2(BigInt("92233720368547758070")))
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisSuite.scala
index 7b45738c4f..33a3cba3d4 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisSuite.scala
@@ -38,7 +38,7 @@ class AnalysisSuite extends FunSuite with BeforeAndAfter {
AttributeReference("a", StringType)(),
AttributeReference("b", StringType)(),
AttributeReference("c", DoubleType)(),
- AttributeReference("d", DecimalType)(),
+ AttributeReference("d", DecimalType.Unlimited)(),
AttributeReference("e", ShortType)())
before {
@@ -119,7 +119,7 @@ class AnalysisSuite extends FunSuite with BeforeAndAfter {
AttributeReference("a", StringType)(),
AttributeReference("b", StringType)(),
AttributeReference("c", DoubleType)(),
- AttributeReference("d", DecimalType)(),
+ AttributeReference("d", DecimalType.Unlimited)(),
AttributeReference("e", ShortType)())
val expr0 = 'a / 2
@@ -137,7 +137,7 @@ class AnalysisSuite extends FunSuite with BeforeAndAfter {
assert(pl(0).dataType == DoubleType)
assert(pl(1).dataType == DoubleType)
assert(pl(2).dataType == DoubleType)
- assert(pl(3).dataType == DecimalType)
+ assert(pl(3).dataType == DecimalType.Unlimited)
assert(pl(4).dataType == DoubleType)
}
}
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/DecimalPrecisionSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/DecimalPrecisionSuite.scala
new file mode 100644
index 0000000000..d5b7d2789a
--- /dev/null
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/DecimalPrecisionSuite.scala
@@ -0,0 +1,88 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.catalyst.analysis
+
+import org.apache.spark.sql.catalyst.expressions._
+import org.apache.spark.sql.catalyst.plans.logical.{Project, LocalRelation}
+import org.apache.spark.sql.catalyst.types._
+import org.scalatest.{BeforeAndAfter, FunSuite}
+
+class DecimalPrecisionSuite extends FunSuite with BeforeAndAfter {
+ val catalog = new SimpleCatalog(false)
+ val analyzer = new Analyzer(catalog, EmptyFunctionRegistry, caseSensitive = false)
+
+ val relation = LocalRelation(
+ AttributeReference("i", IntegerType)(),
+ AttributeReference("d1", DecimalType(2, 1))(),
+ AttributeReference("d2", DecimalType(5, 2))(),
+ AttributeReference("u", DecimalType.Unlimited)(),
+ AttributeReference("f", FloatType)()
+ )
+
+ val i: Expression = UnresolvedAttribute("i")
+ val d1: Expression = UnresolvedAttribute("d1")
+ val d2: Expression = UnresolvedAttribute("d2")
+ val u: Expression = UnresolvedAttribute("u")
+ val f: Expression = UnresolvedAttribute("f")
+
+ before {
+ catalog.registerTable(None, "table", relation)
+ }
+
+ private def checkType(expression: Expression, expectedType: DataType): Unit = {
+ val plan = Project(Seq(Alias(expression, "c")()), relation)
+ assert(analyzer(plan).schema.fields(0).dataType === expectedType)
+ }
+
+ test("basic operations") {
+ checkType(Add(d1, d2), DecimalType(6, 2))
+ checkType(Subtract(d1, d2), DecimalType(6, 2))
+ checkType(Multiply(d1, d2), DecimalType(8, 3))
+ checkType(Divide(d1, d2), DecimalType(10, 7))
+ checkType(Divide(d2, d1), DecimalType(10, 6))
+ checkType(Remainder(d1, d2), DecimalType(3, 2))
+ checkType(Remainder(d2, d1), DecimalType(3, 2))
+ checkType(Sum(d1), DecimalType(12, 1))
+ checkType(Average(d1), DecimalType(6, 5))
+
+ checkType(Add(Add(d1, d2), d1), DecimalType(7, 2))
+ checkType(Add(Add(Add(d1, d2), d1), d2), DecimalType(8, 2))
+ checkType(Add(Add(d1, d2), Add(d1, d2)), DecimalType(7, 2))
+ }
+
+ test("bringing in primitive types") {
+ checkType(Add(d1, i), DecimalType(12, 1))
+ checkType(Add(d1, f), DoubleType)
+ checkType(Add(i, d1), DecimalType(12, 1))
+ checkType(Add(f, d1), DoubleType)
+ checkType(Add(d1, Cast(i, LongType)), DecimalType(22, 1))
+ checkType(Add(d1, Cast(i, ShortType)), DecimalType(7, 1))
+ checkType(Add(d1, Cast(i, ByteType)), DecimalType(5, 1))
+ checkType(Add(d1, Cast(i, DoubleType)), DoubleType)
+ }
+
+ test("unlimited decimals make everything else cast up") {
+ for (expr <- Seq(d1, d2, i, f, u)) {
+ checkType(Add(expr, u), DecimalType.Unlimited)
+ checkType(Subtract(expr, u), DecimalType.Unlimited)
+ checkType(Multiply(expr, u), DecimalType.Unlimited)
+ checkType(Divide(expr, u), DecimalType.Unlimited)
+ checkType(Remainder(expr, u), DecimalType.Unlimited)
+ }
+ }
+}
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercionSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercionSuite.scala
index baeb9b0cf5..dfa2d958c0 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercionSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercionSuite.scala
@@ -68,6 +68,21 @@ class HiveTypeCoercionSuite extends FunSuite {
widenTest(LongType, FloatType, Some(FloatType))
widenTest(LongType, DoubleType, Some(DoubleType))
+ // Casting up to unlimited-precision decimal
+ widenTest(IntegerType, DecimalType.Unlimited, Some(DecimalType.Unlimited))
+ widenTest(DoubleType, DecimalType.Unlimited, Some(DecimalType.Unlimited))
+ widenTest(DecimalType(3, 2), DecimalType.Unlimited, Some(DecimalType.Unlimited))
+ widenTest(DecimalType.Unlimited, IntegerType, Some(DecimalType.Unlimited))
+ widenTest(DecimalType.Unlimited, DoubleType, Some(DecimalType.Unlimited))
+ widenTest(DecimalType.Unlimited, DecimalType(3, 2), Some(DecimalType.Unlimited))
+
+ // No up-casting for fixed-precision decimal (this is handled by arithmetic rules)
+ widenTest(DecimalType(2, 1), DecimalType(3, 2), None)
+ widenTest(DecimalType(2, 1), DoubleType, None)
+ widenTest(DecimalType(2, 1), IntegerType, None)
+ widenTest(DoubleType, DecimalType(2, 1), None)
+ widenTest(IntegerType, DecimalType(2, 1), None)
+
// StringType
widenTest(NullType, StringType, Some(StringType))
widenTest(StringType, StringType, Some(StringType))
@@ -92,7 +107,7 @@ class HiveTypeCoercionSuite extends FunSuite {
def ruleTest(initial: Expression, transformed: Expression) {
val testRelation = LocalRelation(AttributeReference("a", IntegerType)())
assert(booleanCasts(Project(Seq(Alias(initial, "a")()), testRelation)) ==
- Project(Seq(Alias(transformed, "a")()), testRelation))
+ Project(Seq(Alias(transformed, "a")()), testRelation))
}
// Remove superflous boolean -> boolean casts.
ruleTest(Cast(Literal(true), BooleanType), Literal(true))
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvaluationSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvaluationSuite.scala
index 5657bc555e..6bfa0dbd65 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvaluationSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvaluationSuite.scala
@@ -21,6 +21,7 @@ import java.sql.{Date, Timestamp}
import scala.collection.immutable.HashSet
+import org.apache.spark.sql.catalyst.types.decimal.Decimal
import org.scalatest.FunSuite
import org.scalatest.Matchers._
import org.scalactic.TripleEqualsSupport.Spread
@@ -138,7 +139,7 @@ class ExpressionEvaluationSuite extends FunSuite {
val actual = try evaluate(expression, inputRow) catch {
case e: Exception => fail(s"Exception evaluating $expression", e)
}
- actual.asInstanceOf[Double] shouldBe expected
+ actual.asInstanceOf[Double] shouldBe expected
}
test("IN") {
@@ -165,7 +166,7 @@ class ExpressionEvaluationSuite extends FunSuite {
checkEvaluation(InSet(three, nS, three +: nullS), false)
checkEvaluation(InSet(one, hS, one +: s) && InSet(two, hS, two +: s), true)
}
-
+
test("MaxOf") {
checkEvaluation(MaxOf(1, 2), 2)
checkEvaluation(MaxOf(2, 1), 2)
@@ -265,9 +266,9 @@ class ExpressionEvaluationSuite extends FunSuite {
val ts = Timestamp.valueOf(nts)
checkEvaluation("abdef" cast StringType, "abdef")
- checkEvaluation("abdef" cast DecimalType, null)
+ checkEvaluation("abdef" cast DecimalType.Unlimited, null)
checkEvaluation("abdef" cast TimestampType, null)
- checkEvaluation("12.65" cast DecimalType, BigDecimal(12.65))
+ checkEvaluation("12.65" cast DecimalType.Unlimited, Decimal(12.65))
checkEvaluation(Literal(1) cast LongType, 1)
checkEvaluation(Cast(Literal(1000) cast TimestampType, LongType), 1.toLong)
@@ -289,12 +290,12 @@ class ExpressionEvaluationSuite extends FunSuite {
checkEvaluation(Cast(Cast(Cast(Cast(
Cast("5" cast ByteType, ShortType), IntegerType), FloatType), DoubleType), LongType), 5)
- checkEvaluation(Cast(Cast(Cast(Cast(
- Cast("5" cast ByteType, TimestampType), DecimalType), LongType), StringType), ShortType), 0)
- checkEvaluation(Cast(Cast(Cast(Cast(
- Cast("5" cast TimestampType, ByteType), DecimalType), LongType), StringType), ShortType), null)
- checkEvaluation(Cast(Cast(Cast(Cast(
- Cast("5" cast DecimalType, ByteType), TimestampType), LongType), StringType), ShortType), 0)
+ checkEvaluation(Cast(Cast(Cast(Cast(Cast("5" cast
+ ByteType, TimestampType), DecimalType.Unlimited), LongType), StringType), ShortType), 0)
+ checkEvaluation(Cast(Cast(Cast(Cast(Cast("5" cast
+ TimestampType, ByteType), DecimalType.Unlimited), LongType), StringType), ShortType), null)
+ checkEvaluation(Cast(Cast(Cast(Cast(Cast("5" cast
+ DecimalType.Unlimited, ByteType), TimestampType), LongType), StringType), ShortType), 0)
checkEvaluation(Literal(true) cast IntegerType, 1)
checkEvaluation(Literal(false) cast IntegerType, 0)
checkEvaluation(Cast(Literal(1) cast BooleanType, IntegerType), 1)
@@ -302,7 +303,7 @@ class ExpressionEvaluationSuite extends FunSuite {
checkEvaluation("23" cast DoubleType, 23d)
checkEvaluation("23" cast IntegerType, 23)
checkEvaluation("23" cast FloatType, 23f)
- checkEvaluation("23" cast DecimalType, 23: BigDecimal)
+ checkEvaluation("23" cast DecimalType.Unlimited, Decimal(23))
checkEvaluation("23" cast ByteType, 23.toByte)
checkEvaluation("23" cast ShortType, 23.toShort)
checkEvaluation("2012-12-11" cast DoubleType, null)
@@ -311,7 +312,7 @@ class ExpressionEvaluationSuite extends FunSuite {
checkEvaluation(Literal(23d) + Cast(true, DoubleType), 24d)
checkEvaluation(Literal(23) + Cast(true, IntegerType), 24)
checkEvaluation(Literal(23f) + Cast(true, FloatType), 24f)
- checkEvaluation(Literal(BigDecimal(23)) + Cast(true, DecimalType), 24: BigDecimal)
+ checkEvaluation(Literal(Decimal(23)) + Cast(true, DecimalType.Unlimited), Decimal(24))
checkEvaluation(Literal(23.toByte) + Cast(true, ByteType), 24.toByte)
checkEvaluation(Literal(23.toShort) + Cast(true, ShortType), 24.toShort)
@@ -325,7 +326,8 @@ class ExpressionEvaluationSuite extends FunSuite {
assert(("abcdef" cast IntegerType).nullable === true)
assert(("abcdef" cast ShortType).nullable === true)
assert(("abcdef" cast ByteType).nullable === true)
- assert(("abcdef" cast DecimalType).nullable === true)
+ assert(("abcdef" cast DecimalType.Unlimited).nullable === true)
+ assert(("abcdef" cast DecimalType(4, 2)).nullable === true)
assert(("abcdef" cast DoubleType).nullable === true)
assert(("abcdef" cast FloatType).nullable === true)
@@ -338,6 +340,64 @@ class ExpressionEvaluationSuite extends FunSuite {
checkEvaluation(Literal(d1) < Literal(d2), true)
}
+ test("casting to fixed-precision decimals") {
+ // Overflow and rounding for casting to fixed-precision decimals:
+ // - Values should round with HALF_UP mode by default when you lower scale
+ // - Values that would overflow the target precision should turn into null
+ // - Because of this, casts to fixed-precision decimals should be nullable
+
+ assert(Cast(Literal(123), DecimalType.Unlimited).nullable === false)
+ assert(Cast(Literal(10.03f), DecimalType.Unlimited).nullable === false)
+ assert(Cast(Literal(10.03), DecimalType.Unlimited).nullable === false)
+ assert(Cast(Literal(Decimal(10.03)), DecimalType.Unlimited).nullable === false)
+
+ assert(Cast(Literal(123), DecimalType(2, 1)).nullable === true)
+ assert(Cast(Literal(10.03f), DecimalType(2, 1)).nullable === true)
+ assert(Cast(Literal(10.03), DecimalType(2, 1)).nullable === true)
+ assert(Cast(Literal(Decimal(10.03)), DecimalType(2, 1)).nullable === true)
+
+ checkEvaluation(Cast(Literal(123), DecimalType.Unlimited), Decimal(123))
+ checkEvaluation(Cast(Literal(123), DecimalType(3, 0)), Decimal(123))
+ checkEvaluation(Cast(Literal(123), DecimalType(3, 1)), null)
+ checkEvaluation(Cast(Literal(123), DecimalType(2, 0)), null)
+
+ checkEvaluation(Cast(Literal(10.03), DecimalType.Unlimited), Decimal(10.03))
+ checkEvaluation(Cast(Literal(10.03), DecimalType(4, 2)), Decimal(10.03))
+ checkEvaluation(Cast(Literal(10.03), DecimalType(3, 1)), Decimal(10.0))
+ checkEvaluation(Cast(Literal(10.03), DecimalType(2, 0)), Decimal(10))
+ checkEvaluation(Cast(Literal(10.03), DecimalType(1, 0)), null)
+ checkEvaluation(Cast(Literal(10.03), DecimalType(2, 1)), null)
+ checkEvaluation(Cast(Literal(10.03), DecimalType(3, 2)), null)
+ checkEvaluation(Cast(Literal(Decimal(10.03)), DecimalType(3, 1)), Decimal(10.0))
+ checkEvaluation(Cast(Literal(Decimal(10.03)), DecimalType(3, 2)), null)
+
+ checkEvaluation(Cast(Literal(10.05), DecimalType.Unlimited), Decimal(10.05))
+ checkEvaluation(Cast(Literal(10.05), DecimalType(4, 2)), Decimal(10.05))
+ checkEvaluation(Cast(Literal(10.05), DecimalType(3, 1)), Decimal(10.1))
+ checkEvaluation(Cast(Literal(10.05), DecimalType(2, 0)), Decimal(10))
+ checkEvaluation(Cast(Literal(10.05), DecimalType(1, 0)), null)
+ checkEvaluation(Cast(Literal(10.05), DecimalType(2, 1)), null)
+ checkEvaluation(Cast(Literal(10.05), DecimalType(3, 2)), null)
+ checkEvaluation(Cast(Literal(Decimal(10.05)), DecimalType(3, 1)), Decimal(10.1))
+ checkEvaluation(Cast(Literal(Decimal(10.05)), DecimalType(3, 2)), null)
+
+ checkEvaluation(Cast(Literal(9.95), DecimalType(3, 2)), Decimal(9.95))
+ checkEvaluation(Cast(Literal(9.95), DecimalType(3, 1)), Decimal(10.0))
+ checkEvaluation(Cast(Literal(9.95), DecimalType(2, 0)), Decimal(10))
+ checkEvaluation(Cast(Literal(9.95), DecimalType(2, 1)), null)
+ checkEvaluation(Cast(Literal(9.95), DecimalType(1, 0)), null)
+ checkEvaluation(Cast(Literal(Decimal(9.95)), DecimalType(3, 1)), Decimal(10.0))
+ checkEvaluation(Cast(Literal(Decimal(9.95)), DecimalType(1, 0)), null)
+
+ checkEvaluation(Cast(Literal(-9.95), DecimalType(3, 2)), Decimal(-9.95))
+ checkEvaluation(Cast(Literal(-9.95), DecimalType(3, 1)), Decimal(-10.0))
+ checkEvaluation(Cast(Literal(-9.95), DecimalType(2, 0)), Decimal(-10))
+ checkEvaluation(Cast(Literal(-9.95), DecimalType(2, 1)), null)
+ checkEvaluation(Cast(Literal(-9.95), DecimalType(1, 0)), null)
+ checkEvaluation(Cast(Literal(Decimal(-9.95)), DecimalType(3, 1)), Decimal(-10.0))
+ checkEvaluation(Cast(Literal(Decimal(-9.95)), DecimalType(1, 0)), null)
+ }
+
test("timestamp") {
val ts1 = new Timestamp(12)
val ts2 = new Timestamp(123)
@@ -374,7 +434,7 @@ class ExpressionEvaluationSuite extends FunSuite {
millis.toFloat / 1000)
checkEvaluation(Cast(Cast(millis.toDouble / 1000, TimestampType), DoubleType),
millis.toDouble / 1000)
- checkEvaluation(Cast(Literal(BigDecimal(1)) cast TimestampType, DecimalType), 1)
+ checkEvaluation(Cast(Literal(Decimal(1)) cast TimestampType, DecimalType.Unlimited), Decimal(1))
// A test for higher precision than millis
checkEvaluation(Cast(Cast(0.00000001, TimestampType), DoubleType), 0.00000001)
@@ -673,7 +733,7 @@ class ExpressionEvaluationSuite extends FunSuite {
val expectedResults = inputSequence.map(l => math.sqrt(l.toDouble))
val rowSequence = inputSequence.map(l => new GenericRow(Array[Any](l.toDouble)))
val d = 'a.double.at(0)
-
+
for ((row, expected) <- rowSequence zip expectedResults) {
checkEvaluation(Sqrt(d), expected, row)
}
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/types/decimal/DecimalSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/types/decimal/DecimalSuite.scala
new file mode 100644
index 0000000000..5aa263484d
--- /dev/null
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/types/decimal/DecimalSuite.scala
@@ -0,0 +1,158 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.catalyst.types.decimal
+
+import org.scalatest.{PrivateMethodTester, FunSuite}
+
+import scala.language.postfixOps
+
+class DecimalSuite extends FunSuite with PrivateMethodTester {
+ test("creating decimals") {
+ /** Check that a Decimal has the given string representation, precision and scale */
+ def checkDecimal(d: Decimal, string: String, precision: Int, scale: Int): Unit = {
+ assert(d.toString === string)
+ assert(d.precision === precision)
+ assert(d.scale === scale)
+ }
+
+ checkDecimal(new Decimal(), "0", 1, 0)
+ checkDecimal(Decimal(BigDecimal("10.030")), "10.030", 5, 3)
+ checkDecimal(Decimal(BigDecimal("10.030"), 4, 1), "10.0", 4, 1)
+ checkDecimal(Decimal(BigDecimal("-9.95"), 4, 1), "-10.0", 4, 1)
+ checkDecimal(Decimal("10.030"), "10.030", 5, 3)
+ checkDecimal(Decimal(10.03), "10.03", 4, 2)
+ checkDecimal(Decimal(17L), "17", 20, 0)
+ checkDecimal(Decimal(17), "17", 10, 0)
+ checkDecimal(Decimal(17L, 2, 1), "1.7", 2, 1)
+ checkDecimal(Decimal(170L, 4, 2), "1.70", 4, 2)
+ checkDecimal(Decimal(17L, 24, 1), "1.7", 24, 1)
+ checkDecimal(Decimal(1e17.toLong, 18, 0), 1e17.toLong.toString, 18, 0)
+ checkDecimal(Decimal(Long.MaxValue), Long.MaxValue.toString, 20, 0)
+ checkDecimal(Decimal(Long.MinValue), Long.MinValue.toString, 20, 0)
+ intercept[IllegalArgumentException](Decimal(170L, 2, 1))
+ intercept[IllegalArgumentException](Decimal(170L, 2, 0))
+ intercept[IllegalArgumentException](Decimal(BigDecimal("10.030"), 2, 1))
+ intercept[IllegalArgumentException](Decimal(BigDecimal("-9.95"), 2, 1))
+ intercept[IllegalArgumentException](Decimal(1e17.toLong, 17, 0))
+ }
+
+ test("double and long values") {
+ /** Check that a Decimal converts to the given double and long values */
+ def checkValues(d: Decimal, doubleValue: Double, longValue: Long): Unit = {
+ assert(d.toDouble === doubleValue)
+ assert(d.toLong === longValue)
+ }
+
+ checkValues(new Decimal(), 0.0, 0L)
+ checkValues(Decimal(BigDecimal("10.030")), 10.03, 10L)
+ checkValues(Decimal(BigDecimal("10.030"), 4, 1), 10.0, 10L)
+ checkValues(Decimal(BigDecimal("-9.95"), 4, 1), -10.0, -10L)
+ checkValues(Decimal(10.03), 10.03, 10L)
+ checkValues(Decimal(17L), 17.0, 17L)
+ checkValues(Decimal(17), 17.0, 17L)
+ checkValues(Decimal(17L, 2, 1), 1.7, 1L)
+ checkValues(Decimal(170L, 4, 2), 1.7, 1L)
+ checkValues(Decimal(1e16.toLong), 1e16, 1e16.toLong)
+ checkValues(Decimal(1e17.toLong), 1e17, 1e17.toLong)
+ checkValues(Decimal(1e18.toLong), 1e18, 1e18.toLong)
+ checkValues(Decimal(2e18.toLong), 2e18, 2e18.toLong)
+ checkValues(Decimal(Long.MaxValue), Long.MaxValue.toDouble, Long.MaxValue)
+ checkValues(Decimal(Long.MinValue), Long.MinValue.toDouble, Long.MinValue)
+ checkValues(Decimal(Double.MaxValue), Double.MaxValue, 0L)
+ checkValues(Decimal(Double.MinValue), Double.MinValue, 0L)
+ }
+
+ // Accessor for the BigDecimal value of a Decimal, which will be null if it's using Longs
+ private val decimalVal = PrivateMethod[BigDecimal]('decimalVal)
+
+ /** Check whether a decimal is represented compactly (passing whether we expect it to be) */
+ private def checkCompact(d: Decimal, expected: Boolean): Unit = {
+ val isCompact = d.invokePrivate(decimalVal()).eq(null)
+ assert(isCompact == expected, s"$d ${if (expected) "was not" else "was"} compact")
+ }
+
+ test("small decimals represented as unscaled long") {
+ checkCompact(new Decimal(), true)
+ checkCompact(Decimal(BigDecimal(10.03)), false)
+ checkCompact(Decimal(BigDecimal(1e20)), false)
+ checkCompact(Decimal(17L), true)
+ checkCompact(Decimal(17), true)
+ checkCompact(Decimal(17L, 2, 1), true)
+ checkCompact(Decimal(170L, 4, 2), true)
+ checkCompact(Decimal(17L, 24, 1), true)
+ checkCompact(Decimal(1e16.toLong), true)
+ checkCompact(Decimal(1e17.toLong), true)
+ checkCompact(Decimal(1e18.toLong - 1), true)
+ checkCompact(Decimal(- 1e18.toLong + 1), true)
+ checkCompact(Decimal(1e18.toLong - 1, 30, 10), true)
+ checkCompact(Decimal(- 1e18.toLong + 1, 30, 10), true)
+ checkCompact(Decimal(1e18.toLong), false)
+ checkCompact(Decimal(-1e18.toLong), false)
+ checkCompact(Decimal(1e18.toLong, 30, 10), false)
+ checkCompact(Decimal(-1e18.toLong, 30, 10), false)
+ checkCompact(Decimal(Long.MaxValue), false)
+ checkCompact(Decimal(Long.MinValue), false)
+ }
+
+ test("hash code") {
+ assert(Decimal(123).hashCode() === (123).##)
+ assert(Decimal(-123).hashCode() === (-123).##)
+ assert(Decimal(123.312).hashCode() === (123.312).##)
+ assert(Decimal(Int.MaxValue).hashCode() === Int.MaxValue.##)
+ assert(Decimal(Long.MaxValue).hashCode() === Long.MaxValue.##)
+ assert(Decimal(BigDecimal(123)).hashCode() === (123).##)
+
+ val reallyBig = BigDecimal("123182312312313232112312312123.1231231231")
+ assert(Decimal(reallyBig).hashCode() === reallyBig.hashCode)
+ }
+
+ test("equals") {
+ // The decimals on the left are stored compactly, while the ones on the right aren't
+ checkCompact(Decimal(123), true)
+ checkCompact(Decimal(BigDecimal(123)), false)
+ checkCompact(Decimal("123"), false)
+ assert(Decimal(123) === Decimal(BigDecimal(123)))
+ assert(Decimal(123) === Decimal(BigDecimal("123.00")))
+ assert(Decimal(-123) === Decimal(BigDecimal(-123)))
+ assert(Decimal(-123) === Decimal(BigDecimal("-123.00")))
+ }
+
+ test("isZero") {
+ assert(Decimal(0).isZero)
+ assert(Decimal(0, 4, 2).isZero)
+ assert(Decimal("0").isZero)
+ assert(Decimal("0.000").isZero)
+ assert(!Decimal(1).isZero)
+ assert(!Decimal(1, 4, 2).isZero)
+ assert(!Decimal("1").isZero)
+ assert(!Decimal("0.001").isZero)
+ }
+
+ test("arithmetic") {
+ assert(Decimal(100) + Decimal(-100) === Decimal(0))
+ assert(Decimal(100) - Decimal(-100) === Decimal(200))
+ assert(Decimal(100) * Decimal(-100) === Decimal(-10000))
+ assert(Decimal(1e13) * Decimal(1e13) === Decimal(1e26))
+ assert(Decimal(100) / Decimal(-100) === Decimal(-1))
+ assert(Decimal(100) / Decimal(0) === null)
+ assert(Decimal(100) % Decimal(-100) === Decimal(0))
+ assert(Decimal(100) % Decimal(3) === Decimal(1))
+ assert(Decimal(-100) % Decimal(3) === Decimal(-1))
+ assert(Decimal(100) % Decimal(0) === null)
+ }
+}
diff --git a/sql/core/src/main/java/org/apache/spark/sql/api/java/DataType.java b/sql/core/src/main/java/org/apache/spark/sql/api/java/DataType.java
index 0c85cdc0aa..c38354039d 100644
--- a/sql/core/src/main/java/org/apache/spark/sql/api/java/DataType.java
+++ b/sql/core/src/main/java/org/apache/spark/sql/api/java/DataType.java
@@ -53,11 +53,6 @@ public abstract class DataType {
public static final TimestampType TimestampType = new TimestampType();
/**
- * Gets the DecimalType object.
- */
- public static final DecimalType DecimalType = new DecimalType();
-
- /**
* Gets the DoubleType object.
*/
public static final DoubleType DoubleType = new DoubleType();
diff --git a/sql/core/src/main/java/org/apache/spark/sql/api/java/DecimalType.java b/sql/core/src/main/java/org/apache/spark/sql/api/java/DecimalType.java
index bc54c078d7..60752451ec 100644
--- a/sql/core/src/main/java/org/apache/spark/sql/api/java/DecimalType.java
+++ b/sql/core/src/main/java/org/apache/spark/sql/api/java/DecimalType.java
@@ -19,9 +19,61 @@ package org.apache.spark.sql.api.java;
/**
* The data type representing java.math.BigDecimal values.
- *
- * {@code DecimalType} is represented by the singleton object {@link DataType#DecimalType}.
*/
public class DecimalType extends DataType {
- protected DecimalType() {}
+ private boolean hasPrecisionInfo;
+ private int precision;
+ private int scale;
+
+ public DecimalType(int precision, int scale) {
+ this.hasPrecisionInfo = true;
+ this.precision = precision;
+ this.scale = scale;
+ }
+
+ public DecimalType() {
+ this.hasPrecisionInfo = false;
+ this.precision = -1;
+ this.scale = -1;
+ }
+
+ public boolean isUnlimited() {
+ return !hasPrecisionInfo;
+ }
+
+ public boolean isFixed() {
+ return hasPrecisionInfo;
+ }
+
+ /** Return the precision, or -1 if no precision is set */
+ public int getPrecision() {
+ return precision;
+ }
+
+ /** Return the scale, or -1 if no precision is set */
+ public int getScale() {
+ return scale;
+ }
+
+ @Override
+ public boolean equals(Object o) {
+ if (this == o) return true;
+ if (o == null || getClass() != o.getClass()) return false;
+
+ DecimalType that = (DecimalType) o;
+
+ if (hasPrecisionInfo != that.hasPrecisionInfo) return false;
+ if (precision != that.precision) return false;
+ if (scale != that.scale) return false;
+
+ return true;
+ }
+
+ @Override
+ public int hashCode() {
+ int result = (hasPrecisionInfo ? 1 : 0);
+ result = 31 * result + precision;
+ result = 31 * result + scale;
+ return result;
+ }
}
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/SchemaRDD.scala b/sql/core/src/main/scala/org/apache/spark/sql/SchemaRDD.scala
index 8b96df1096..018a18c4ac 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/SchemaRDD.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/SchemaRDD.scala
@@ -19,6 +19,7 @@ package org.apache.spark.sql
import java.util.{Map => JMap, List => JList}
+import org.apache.spark.sql.catalyst.ScalaReflection
import org.apache.spark.storage.StorageLevel
import scala.collection.JavaConversions._
@@ -113,7 +114,7 @@ class SchemaRDD(
// =========================================================================================
override def compute(split: Partition, context: TaskContext): Iterator[Row] =
- firstParent[Row].compute(split, context).map(_.copy())
+ firstParent[Row].compute(split, context).map(ScalaReflection.convertRowToScala)
override def getPartitions: Array[Partition] = firstParent[Row].partitions
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/api/java/JavaSQLContext.scala b/sql/core/src/main/scala/org/apache/spark/sql/api/java/JavaSQLContext.scala
index 082ae03eef..876b1c6ede 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/api/java/JavaSQLContext.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/api/java/JavaSQLContext.scala
@@ -230,7 +230,7 @@ class JavaSQLContext(val sqlContext: SQLContext) extends UDFRegistration {
case c: Class[_] if c == classOf[java.lang.Boolean] =>
(org.apache.spark.sql.BooleanType, true)
case c: Class[_] if c == classOf[java.math.BigDecimal] =>
- (org.apache.spark.sql.DecimalType, true)
+ (org.apache.spark.sql.DecimalType(), true)
case c: Class[_] if c == classOf[java.sql.Date] =>
(org.apache.spark.sql.DateType, true)
case c: Class[_] if c == classOf[java.sql.Timestamp] =>
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/api/java/Row.scala b/sql/core/src/main/scala/org/apache/spark/sql/api/java/Row.scala
index df01411f60..401798e317 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/api/java/Row.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/api/java/Row.scala
@@ -17,6 +17,8 @@
package org.apache.spark.sql.api.java
+import org.apache.spark.sql.catalyst.types.decimal.Decimal
+
import scala.annotation.varargs
import scala.collection.convert.Wrappers.{JListWrapper, JMapWrapper}
import scala.collection.JavaConversions
@@ -106,6 +108,8 @@ class Row(private[spark] val row: ScalaRow) extends Serializable {
}
override def hashCode(): Int = row.hashCode()
+
+ override def toString: String = row.toString
}
object Row {
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/GeneratedAggregate.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/GeneratedAggregate.scala
index b3edd5020f..087b0ecbb2 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/GeneratedAggregate.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/GeneratedAggregate.scala
@@ -70,16 +70,29 @@ case class GeneratedAggregate(
val computeFunctions = aggregatesToCompute.map {
case c @ Count(expr) =>
+ // If we're evaluating UnscaledValue(x), we can do Count on x directly, since its
+ // UnscaledValue will be null if and only if x is null; helps with Average on decimals
+ val toCount = expr match {
+ case UnscaledValue(e) => e
+ case _ => expr
+ }
val currentCount = AttributeReference("currentCount", LongType, nullable = false)()
val initialValue = Literal(0L)
- val updateFunction = If(IsNotNull(expr), Add(currentCount, Literal(1L)), currentCount)
+ val updateFunction = If(IsNotNull(toCount), Add(currentCount, Literal(1L)), currentCount)
val result = currentCount
AggregateEvaluation(currentCount :: Nil, initialValue :: Nil, updateFunction :: Nil, result)
case Sum(expr) =>
- val currentSum = AttributeReference("currentSum", expr.dataType, nullable = false)()
- val initialValue = Cast(Literal(0L), expr.dataType)
+ val resultType = expr.dataType match {
+ case DecimalType.Fixed(precision, scale) =>
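+ // Keep 10 extra digits of precision in the running sum so that adding many rows
+ // does not overflow; this mirrors the result type given to Sum on fixed-precision decimals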
+ DecimalType(precision + 10, scale)
+ case _ =>
+ expr.dataType
+ }
+
+ val currentSum = AttributeReference("currentSum", resultType, nullable = false)()
+ val initialValue = Cast(Literal(0L), resultType)
// Coalasce avoids double calculation...
// but really, common sub expression elimination would be better....
@@ -93,10 +106,26 @@ case class GeneratedAggregate(
val currentSum = AttributeReference("currentSum", expr.dataType, nullable = false)()
val initialCount = Literal(0L)
val initialSum = Cast(Literal(0L), expr.dataType)
- val updateCount = If(IsNotNull(expr), Add(currentCount, Literal(1L)), currentCount)
+
+ // If we're evaluating UnscaledValue(x), we can do Count on x directly, since its
+ // UnscaledValue will be null if and only if x is null; helps with Average on decimals
+ val toCount = expr match {
+ case UnscaledValue(e) => e
+ case _ => expr
+ }
+
+ val updateCount = If(IsNotNull(toCount), Add(currentCount, Literal(1L)), currentCount)
val updateSum = Coalesce(Add(expr, currentSum) :: currentSum :: Nil)
- val result = Divide(Cast(currentSum, DoubleType), Cast(currentCount, DoubleType))
+ val resultType = expr.dataType match {
+ case DecimalType.Fixed(precision, scale) =>
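+ // An average gains 4 extra digits of precision and scale over its input, matching
+ // the result type given to Average on fixed-precision decimals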
+ DecimalType(precision + 4, scale + 4)
+ case DecimalType.Unlimited =>
+ DecimalType.Unlimited
+ case _ =>
+ DoubleType
+ }
+ val result = Divide(Cast(currentSum, resultType), Cast(currentCount, resultType))
AggregateEvaluation(
currentCount :: currentSum :: Nil,
@@ -142,7 +171,7 @@ case class GeneratedAggregate(
val computationSchema = computeFunctions.flatMap(_.schema)
- val resultMap: Map[TreeNodeRef, Expression] =
+ val resultMap: Map[TreeNodeRef, Expression] =
aggregatesToCompute.zip(computeFunctions).map {
case (agg, func) => new TreeNodeRef(agg) -> func.result
}.toMap
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlan.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlan.scala
index b1a7948b66..aafcce0572 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlan.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlan.scala
@@ -23,7 +23,7 @@ import org.apache.spark.rdd.RDD
import org.apache.spark.sql.SQLContext
-import org.apache.spark.sql.catalyst.trees
+import org.apache.spark.sql.catalyst.{ScalaReflection, trees}
import org.apache.spark.sql.catalyst.analysis.MultiInstanceRelation
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.expressions.codegen._
@@ -82,7 +82,7 @@ abstract class SparkPlan extends QueryPlan[SparkPlan] with Logging with Serializ
/**
* Runs this query returning the result as an array.
*/
- def executeCollect(): Array[Row] = execute().map(_.copy()).collect()
+ def executeCollect(): Array[Row] = execute().map(ScalaReflection.convertRowToScala).collect()
protected def newProjection(
expressions: Seq[Expression], inputSchema: Seq[Attribute]): Projection = {
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkSqlSerializer.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkSqlSerializer.scala
index 077e6ebc5f..84d96e612f 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkSqlSerializer.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkSqlSerializer.scala
@@ -29,6 +29,7 @@ import com.twitter.chill.{AllScalaRegistrar, ResourcePool}
import org.apache.spark.{SparkEnv, SparkConf}
import org.apache.spark.serializer.{SerializerInstance, KryoSerializer}
import org.apache.spark.sql.catalyst.expressions.GenericRow
+import org.apache.spark.sql.catalyst.types.decimal.Decimal
import org.apache.spark.util.collection.OpenHashSet
import org.apache.spark.util.MutablePair
import org.apache.spark.util.Utils
@@ -51,6 +52,7 @@ private[sql] class SparkSqlSerializer(conf: SparkConf) extends KryoSerializer(co
kryo.register(classOf[LongHashSet], new LongHashSetSerializer)
kryo.register(classOf[org.apache.spark.util.collection.OpenHashSet[_]],
new OpenHashSetSerializer)
+ kryo.register(classOf[Decimal])
kryo.setReferences(false)
kryo.setClassLoader(Utils.getSparkClassLoader)
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala
index 977f3c9f32..e6cd1a9d04 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala
@@ -143,7 +143,7 @@ case class Limit(limit: Int, child: SparkPlan)
partsScanned += numPartsToTry
}
- buf.toArray
+ buf.toArray.map(ScalaReflection.convertRowToScala)
}
override def execute() = {
@@ -176,10 +176,11 @@ case class TakeOrdered(limit: Int, sortOrder: Seq[SortOrder], child: SparkPlan)
override def output = child.output
override def outputPartitioning = SinglePartition
- val ordering = new RowOrdering(sortOrder, child.output)
+ val ord = new RowOrdering(sortOrder, child.output)
// TODO: Is this copying for no reason?
- override def executeCollect() = child.execute().map(_.copy()).takeOrdered(limit)(ordering)
+ override def executeCollect() =
+ child.execute().map(_.copy()).takeOrdered(limit)(ord).map(ScalaReflection.convertRowToScala)
// TODO: Terminal split should be implemented differently from non-terminal split.
// TODO: Pick num splits based on |limit|.
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastHashJoin.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastHashJoin.scala
index 8fd35880ee..5cf2a785ad 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastHashJoin.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastHashJoin.scala
@@ -49,7 +49,8 @@ case class BroadcastHashJoin(
@transient
private val broadcastFuture = future {
- val input: Array[Row] = buildPlan.executeCollect()
+ // Note that we use .execute().collect() because we don't want to convert data to Scala types
+ val input: Array[Row] = buildPlan.execute().map(_.copy()).collect()
val hashed = HashedRelation(input.iterator, buildSideKeyGenerator, input.length)
sparkContext.broadcast(hashed)
}
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/pythonUdfs.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/pythonUdfs.scala
index a1961bba18..997669051e 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/pythonUdfs.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/pythonUdfs.scala
@@ -19,6 +19,8 @@ package org.apache.spark.sql.execution
import java.util.{List => JList, Map => JMap}
+import org.apache.spark.sql.catalyst.types.decimal.Decimal
+
import scala.collection.JavaConversions._
import scala.collection.JavaConverters._
@@ -116,7 +118,7 @@ object EvaluatePython {
def toJava(obj: Any, dataType: DataType): Any = (obj, dataType) match {
case (null, _) => null
- case (row: Row, struct: StructType) =>
+ case (row: Seq[Any], struct: StructType) =>
val fields = struct.fields.map(field => field.dataType)
row.zip(fields).map {
case (obj, dataType) => toJava(obj, dataType)
@@ -133,6 +135,8 @@ object EvaluatePython {
case (k, v) => (k, toJava(v, mt.valueType)) // key should be primitive type
}.asJava
+ case (dec: BigDecimal, dt: DecimalType) => dec.underlying() // Pyrolite can handle BigDecimal
+
// Pyrolite can handle Timestamp
case (other, _) => other
}
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/json/JsonRDD.scala b/sql/core/src/main/scala/org/apache/spark/sql/json/JsonRDD.scala
index eabe312f92..5bb6f6c85d 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/json/JsonRDD.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/json/JsonRDD.scala
@@ -17,6 +17,8 @@
package org.apache.spark.sql.json
+import org.apache.spark.sql.catalyst.types.decimal.Decimal
+
import scala.collection.Map
import scala.collection.convert.Wrappers.{JMapWrapper, JListWrapper}
import scala.math.BigDecimal
@@ -175,9 +177,9 @@ private[sql] object JsonRDD extends Logging {
ScalaReflection.typeOfObject orElse {
// Since we do not have a data type backed by BigInteger,
// when we see a Java BigInteger, we use DecimalType.
- case value: java.math.BigInteger => DecimalType
+ case value: java.math.BigInteger => DecimalType.Unlimited
// DecimalType's JVMType is scala BigDecimal.
- case value: java.math.BigDecimal => DecimalType
+ case value: java.math.BigDecimal => DecimalType.Unlimited
// Unexpected data type.
case _ => StringType
}
@@ -319,13 +321,13 @@ private[sql] object JsonRDD extends Logging {
}
}
- private def toDecimal(value: Any): BigDecimal = {
+ private def toDecimal(value: Any): Decimal = {
value match {
- case value: java.lang.Integer => BigDecimal(value)
- case value: java.lang.Long => BigDecimal(value)
- case value: java.math.BigInteger => BigDecimal(value)
- case value: java.lang.Double => BigDecimal(value)
- case value: java.math.BigDecimal => BigDecimal(value)
+ case value: java.lang.Integer => Decimal(value)
+ case value: java.lang.Long => Decimal(value)
+ case value: java.math.BigInteger => Decimal(BigDecimal(value))
+ case value: java.lang.Double => Decimal(value)
+ case value: java.math.BigDecimal => Decimal(BigDecimal(value))
}
}
@@ -391,7 +393,7 @@ private[sql] object JsonRDD extends Logging {
case IntegerType => value.asInstanceOf[IntegerType.JvmType]
case LongType => toLong(value)
case DoubleType => toDouble(value)
- case DecimalType => toDecimal(value)
+ case DecimalType() => toDecimal(value)
case BooleanType => value.asInstanceOf[BooleanType.JvmType]
case NullType => null
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/package.scala b/sql/core/src/main/scala/org/apache/spark/sql/package.scala
index f0e57e2a74..05926a24c5 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/package.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/package.scala
@@ -183,6 +183,20 @@ package object sql {
*
* The data type representing `scala.math.BigDecimal` values.
*
+ * TODO(matei): explain precision and scale
+ *
+ * @group dataType
+ */
+ @DeveloperApi
+ type DecimalType = catalyst.types.DecimalType
+
+ /**
+ * :: DeveloperApi ::
+ *
+ * The data type representing `scala.math.BigDecimal` values.
+ *
+ * TODO(matei): explain precision and scale
+ *
* @group dataType
*/
@DeveloperApi
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetConverter.scala b/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetConverter.scala
index 2fc7e1cf23..08feced61a 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetConverter.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetConverter.scala
@@ -17,6 +17,8 @@
package org.apache.spark.sql.parquet
+import org.apache.spark.sql.catalyst.types.decimal.Decimal
+
import scala.collection.mutable.{Buffer, ArrayBuffer, HashMap}
import parquet.io.api.{PrimitiveConverter, GroupConverter, Binary, Converter}
@@ -117,6 +119,12 @@ private[sql] object CatalystConverter {
parent.updateByte(fieldIndex, value.asInstanceOf[ByteType.JvmType])
}
}
+ case d: DecimalType => {
+ new CatalystPrimitiveConverter(parent, fieldIndex) {
+ override def addBinary(value: Binary): Unit =
+ parent.updateDecimal(fieldIndex, value, d)
+ }
+ }
// All other primitive types use the default converter
case ctype: PrimitiveType => { // note: need the type tag here!
new CatalystPrimitiveConverter(parent, fieldIndex)
@@ -191,6 +199,10 @@ private[parquet] abstract class CatalystConverter extends GroupConverter {
protected[parquet] def updateString(fieldIndex: Int, value: Binary): Unit =
updateField(fieldIndex, value.toStringUsingUTF8)
+ protected[parquet] def updateDecimal(fieldIndex: Int, value: Binary, ctype: DecimalType): Unit = {
+ updateField(fieldIndex, readDecimal(new Decimal(), value, ctype))
+ }
+
protected[parquet] def isRootConverter: Boolean = parent == null
protected[parquet] def clearBuffer(): Unit
@@ -201,6 +213,27 @@ private[parquet] abstract class CatalystConverter extends GroupConverter {
* @return
*/
def getCurrentRecord: Row = throw new UnsupportedOperationException
+
+ /**
+ * Read a decimal value from a Parquet Binary into "dest". Only supports decimals that fit in
+ * a long (i.e. precision <= 18)
+ */
+ protected[parquet] def readDecimal(dest: Decimal, value: Binary, ctype: DecimalType): Unit = {
+ val precision = ctype.precisionInfo.get.precision
+ val scale = ctype.precisionInfo.get.scale
+ val bytes = value.getBytes
+ require(bytes.length <= 16, "Decimal field too large to read")
+ var unscaled = 0L
+ var i = 0
+ while (i < bytes.length) {
+ unscaled = (unscaled << 8) | (bytes(i) & 0xFF)
+ i += 1
+ }
+ // Make sure unscaled has the right sign, by sign-extending the first bit
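+ // e.g. a two-byte field 0xFF 0x38 is accumulated above as 0xFF38 = 65336; with numBits = 16,
+ // shifting left and then arithmetically right by 48 bits reinterprets it as the signed
+ // value -200, so a (precision 4, scale 2) column decodes to -2.00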
+ val numBits = 8 * bytes.length
+ unscaled = (unscaled << (64 - numBits)) >> (64 - numBits)
+ dest.set(unscaled, precision, scale)
+ }
}
/**
@@ -352,6 +385,16 @@ private[parquet] class CatalystPrimitiveRowConverter(
override protected[parquet] def updateString(fieldIndex: Int, value: Binary): Unit =
current.setString(fieldIndex, value.toStringUsingUTF8)
+
+ override protected[parquet] def updateDecimal(
+ fieldIndex: Int, value: Binary, ctype: DecimalType): Unit = {
+ var decimal = current(fieldIndex).asInstanceOf[Decimal]
+ if (decimal == null) {
+ decimal = new Decimal
+ current(fieldIndex) = decimal
+ }
+ readDecimal(decimal, value, ctype)
+ }
}
/**
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetTableSupport.scala b/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetTableSupport.scala
index bdf02401b2..2a5f23b24e 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetTableSupport.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetTableSupport.scala
@@ -20,6 +20,7 @@ package org.apache.spark.sql.parquet
import java.util.{HashMap => JHashMap}
import org.apache.hadoop.conf.Configuration
+import org.apache.spark.sql.catalyst.types.decimal.Decimal
import parquet.column.ParquetProperties
import parquet.hadoop.ParquetOutputFormat
import parquet.hadoop.api.ReadSupport.ReadContext
@@ -204,6 +205,11 @@ private[parquet] class RowWriteSupport extends WriteSupport[Row] with Logging {
case DoubleType => writer.addDouble(value.asInstanceOf[Double])
case FloatType => writer.addFloat(value.asInstanceOf[Float])
case BooleanType => writer.addBoolean(value.asInstanceOf[Boolean])
+ case d: DecimalType =>
+ if (d.precisionInfo == None || d.precisionInfo.get.precision > 18) {
+ sys.error(s"Unsupported datatype $d, cannot write to consumer")
+ }
+ writeDecimal(value.asInstanceOf[Decimal], d.precisionInfo.get.precision)
case _ => sys.error(s"Do not know how to writer $schema to consumer")
}
}
@@ -283,6 +289,23 @@ private[parquet] class RowWriteSupport extends WriteSupport[Row] with Logging {
}
writer.endGroup()
}
+
+ // Scratch array used to write decimals as fixed-length binary
+ private val scratchBytes = new Array[Byte](8)
+
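+ // Writes the unscaled Long as a big-endian, fixed-length byte array whose size is
+ // determined by the precision, e.g. Decimal(10.05) as DecimalType(4, 2) has unscaled
+ // value 1005, BYTES_FOR_PRECISION(4) == 2, and is written as the bytes 0x03 0xED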
+ private[parquet] def writeDecimal(decimal: Decimal, precision: Int): Unit = {
+ val numBytes = ParquetTypesConverter.BYTES_FOR_PRECISION(precision)
+ val unscaledLong = decimal.toUnscaledLong
+ var i = 0
+ var shift = 8 * (numBytes - 1)
+ while (i < numBytes) {
+ scratchBytes(i) = (unscaledLong >> shift).toByte
+ i += 1
+ shift -= 8
+ }
+ writer.addBinary(Binary.fromByteArray(scratchBytes, 0, numBytes))
+ }
+
}
// Optimized for non-nested rows
@@ -326,6 +349,11 @@ private[parquet] class MutableRowWriteSupport extends RowWriteSupport {
case DoubleType => writer.addDouble(record.getDouble(index))
case FloatType => writer.addFloat(record.getFloat(index))
case BooleanType => writer.addBoolean(record.getBoolean(index))
+ case d: DecimalType =>
+ if (d.precisionInfo == None || d.precisionInfo.get.precision > 18) {
+ sys.error(s"Unsupported datatype $d, cannot write to consumer")
+ }
+ writeDecimal(record(index).asInstanceOf[Decimal], d.precisionInfo.get.precision)
case _ => sys.error(s"Unsupported datatype $ctype, cannot write to consumer")
}
}
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetTypes.scala b/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetTypes.scala
index e6389cf77a..e5077de8dd 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetTypes.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetTypes.scala
@@ -29,8 +29,8 @@ import org.apache.hadoop.mapreduce.lib.output.FileOutputCommitter
import parquet.hadoop.{ParquetFileReader, Footer, ParquetFileWriter}
import parquet.hadoop.metadata.{ParquetMetadata, FileMetaData}
import parquet.hadoop.util.ContextUtil
-import parquet.schema.{Type => ParquetType, PrimitiveType => ParquetPrimitiveType, MessageType}
-import parquet.schema.{GroupType => ParquetGroupType, OriginalType => ParquetOriginalType, ConversionPatterns}
+import parquet.schema.{Type => ParquetType, Types => ParquetTypes, PrimitiveType => ParquetPrimitiveType, MessageType}
+import parquet.schema.{GroupType => ParquetGroupType, OriginalType => ParquetOriginalType, ConversionPatterns, DecimalMetadata}
import parquet.schema.PrimitiveType.{PrimitiveTypeName => ParquetPrimitiveTypeName}
import parquet.schema.Type.Repetition
@@ -41,17 +41,25 @@ import org.apache.spark.sql.catalyst.types._
// Implicits
import scala.collection.JavaConversions._
+/** A class representing Parquet info fields we care about, for passing back to Parquet */
+private[parquet] case class ParquetTypeInfo(
+ primitiveType: ParquetPrimitiveTypeName,
+ originalType: Option[ParquetOriginalType] = None,
+ decimalMetadata: Option[DecimalMetadata] = None,
+ length: Option[Int] = None)
+
private[parquet] object ParquetTypesConverter extends Logging {
def isPrimitiveType(ctype: DataType): Boolean =
classOf[PrimitiveType] isAssignableFrom ctype.getClass
def toPrimitiveDataType(
parquetType: ParquetPrimitiveType,
- binayAsString: Boolean): DataType =
+ binaryAsString: Boolean): DataType = {
+ val originalType = parquetType.getOriginalType
+ val decimalInfo = parquetType.getDecimalMetadata
parquetType.getPrimitiveTypeName match {
case ParquetPrimitiveTypeName.BINARY
- if (parquetType.getOriginalType == ParquetOriginalType.UTF8 ||
- binayAsString) => StringType
+ if (originalType == ParquetOriginalType.UTF8 || binaryAsString) => StringType
case ParquetPrimitiveTypeName.BINARY => BinaryType
case ParquetPrimitiveTypeName.BOOLEAN => BooleanType
case ParquetPrimitiveTypeName.DOUBLE => DoubleType
@@ -61,9 +69,14 @@ private[parquet] object ParquetTypesConverter extends Logging {
case ParquetPrimitiveTypeName.INT96 =>
// TODO: add BigInteger type? TODO(andre) use DecimalType instead????
sys.error("Potential loss of precision: cannot convert INT96")
+ case ParquetPrimitiveTypeName.FIXED_LEN_BYTE_ARRAY
+ if (originalType == ParquetOriginalType.DECIMAL && decimalInfo.getPrecision <= 18) =>
+ // TODO: for now, our reader only supports decimals that fit in a Long
+ DecimalType(decimalInfo.getPrecision, decimalInfo.getScale)
case _ => sys.error(
s"Unsupported parquet datatype $parquetType")
}
+ }
/**
* Converts a given Parquet `Type` into the corresponding
@@ -183,24 +196,41 @@ private[parquet] object ParquetTypesConverter extends Logging {
* is not primitive.
*
* @param ctype The type to convert
- * @return The name of the corresponding Parquet primitive type
+ * @return The name of the corresponding Parquet type properties
*/
- def fromPrimitiveDataType(ctype: DataType):
- Option[(ParquetPrimitiveTypeName, Option[ParquetOriginalType])] = ctype match {
- case StringType => Some(ParquetPrimitiveTypeName.BINARY, Some(ParquetOriginalType.UTF8))
- case BinaryType => Some(ParquetPrimitiveTypeName.BINARY, None)
- case BooleanType => Some(ParquetPrimitiveTypeName.BOOLEAN, None)
- case DoubleType => Some(ParquetPrimitiveTypeName.DOUBLE, None)
- case FloatType => Some(ParquetPrimitiveTypeName.FLOAT, None)
- case IntegerType => Some(ParquetPrimitiveTypeName.INT32, None)
+ def fromPrimitiveDataType(ctype: DataType): Option[ParquetTypeInfo] = ctype match {
+ case StringType => Some(ParquetTypeInfo(
+ ParquetPrimitiveTypeName.BINARY, Some(ParquetOriginalType.UTF8)))
+ case BinaryType => Some(ParquetTypeInfo(ParquetPrimitiveTypeName.BINARY))
+ case BooleanType => Some(ParquetTypeInfo(ParquetPrimitiveTypeName.BOOLEAN))
+ case DoubleType => Some(ParquetTypeInfo(ParquetPrimitiveTypeName.DOUBLE))
+ case FloatType => Some(ParquetTypeInfo(ParquetPrimitiveTypeName.FLOAT))
+ case IntegerType => Some(ParquetTypeInfo(ParquetPrimitiveTypeName.INT32))
// There is no type for Byte or Short so we promote them to INT32.
- case ShortType => Some(ParquetPrimitiveTypeName.INT32, None)
- case ByteType => Some(ParquetPrimitiveTypeName.INT32, None)
- case LongType => Some(ParquetPrimitiveTypeName.INT64, None)
+ case ShortType => Some(ParquetTypeInfo(ParquetPrimitiveTypeName.INT32))
+ case ByteType => Some(ParquetTypeInfo(ParquetPrimitiveTypeName.INT32))
+ case LongType => Some(ParquetTypeInfo(ParquetPrimitiveTypeName.INT64))
+ case DecimalType.Fixed(precision, scale) if precision <= 18 =>
+ // TODO: for now, our writer only supports decimals that fit in a Long
+ Some(ParquetTypeInfo(ParquetPrimitiveTypeName.FIXED_LEN_BYTE_ARRAY,
+ Some(ParquetOriginalType.DECIMAL),
+ Some(new DecimalMetadata(precision, scale)),
+ Some(BYTES_FOR_PRECISION(precision))))
case _ => None
}
/**
+ * Compute the FIXED_LEN_BYTE_ARRAY length needed to represent a given DECIMAL precision.
+ */
+ private[parquet] val BYTES_FOR_PRECISION = Array.tabulate[Int](38) { precision =>
+ var length = 1
+ while (math.pow(2.0, 8 * length - 1) < math.pow(10.0, precision)) {
+ length += 1
+ }
+ length
+ }
+
+ /**
* Converts a given Catalyst [[org.apache.spark.sql.catalyst.types.DataType]] into
* the corresponding Parquet `Type`.
*
@@ -247,10 +277,17 @@ private[parquet] object ParquetTypesConverter extends Logging {
} else {
if (nullable) Repetition.OPTIONAL else Repetition.REQUIRED
}
- val primitiveType = fromPrimitiveDataType(ctype)
- primitiveType.map {
- case (primitiveType, originalType) =>
- new ParquetPrimitiveType(repetition, primitiveType, name, originalType.orNull)
+ val typeInfo = fromPrimitiveDataType(ctype)
+ typeInfo.map {
+ case ParquetTypeInfo(primitiveType, originalType, decimalMetadata, length) =>
+ val builder = ParquetTypes.primitive(primitiveType, repetition).as(originalType.orNull)
+ for (len <- length) {
+ builder.length(len)
+ }
+ for (metadata <- decimalMetadata) {
+ builder.precision(metadata.getPrecision).scale(metadata.getScale)
+ }
+ builder.named(name)
}.getOrElse {
ctype match {
case ArrayType(elementType, false) => {
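The 18-digit cutoff in the reader and writer above follows from storing the unscaled value in a Long: 10^18 - 1 is still below Long.MaxValue (about 9.22 * 10^18), while 19 digits may not fit. BYTES_FOR_PRECISION is the matching calculation for the FIXED_LEN_BYTE_ARRAY width: the smallest byte count n whose signed two's-complement range covers every unscaled value of the given precision, i.e. the least n with 2^(8n - 1) >= 10^precision. A minimal standalone sketch of that calculation (plain Scala, nothing from the patch required):

object DecimalByteLengths {
  // Smallest n such that a signed n-byte integer can hold any value with
  // `precision` decimal digits, i.e. 2^(8n - 1) >= 10^precision.
  def bytesForPrecision(precision: Int): Int = {
    var length = 1
    while (math.pow(2.0, 8 * length - 1) < math.pow(10.0, precision)) {
      length += 1
    }
    length
  }

  def main(args: Array[String]): Unit = {
    // e.g. precision 9 -> 4 bytes, precision 18 -> 8 bytes, precision 38 -> 16 bytes
    println((1 to 38).map(p => p -> bytesForPrecision(p)).mkString(", "))
  }
}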
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/types/util/DataTypeConversions.scala b/sql/core/src/main/scala/org/apache/spark/sql/types/util/DataTypeConversions.scala
index 142598c904..7564bf3923 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/types/util/DataTypeConversions.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/types/util/DataTypeConversions.scala
@@ -19,6 +19,8 @@ package org.apache.spark.sql.types.util
import org.apache.spark.sql._
import org.apache.spark.sql.api.java.{DataType => JDataType, StructField => JStructField, MetadataBuilder => JMetaDataBuilder}
+import org.apache.spark.sql.api.java.{DecimalType => JDecimalType}
+import org.apache.spark.sql.catalyst.types.decimal.Decimal
import scala.collection.JavaConverters._
@@ -44,7 +46,8 @@ protected[sql] object DataTypeConversions {
case BooleanType => JDataType.BooleanType
case DateType => JDataType.DateType
case TimestampType => JDataType.TimestampType
- case DecimalType => JDataType.DecimalType
+ case DecimalType.Fixed(precision, scale) => new JDecimalType(precision, scale)
+ case DecimalType.Unlimited => new JDecimalType()
case DoubleType => JDataType.DoubleType
case FloatType => JDataType.FloatType
case ByteType => JDataType.ByteType
@@ -88,7 +91,11 @@ protected[sql] object DataTypeConversions {
case timestampType: org.apache.spark.sql.api.java.TimestampType =>
TimestampType
case decimalType: org.apache.spark.sql.api.java.DecimalType =>
- DecimalType
+ if (decimalType.isFixed) {
+ DecimalType(decimalType.getPrecision, decimalType.getScale)
+ } else {
+ DecimalType.Unlimited
+ }
case doubleType: org.apache.spark.sql.api.java.DoubleType =>
DoubleType
case floatType: org.apache.spark.sql.api.java.FloatType =>
@@ -115,7 +122,7 @@ protected[sql] object DataTypeConversions {
/** Converts Java objects to catalyst rows / types */
def convertJavaToCatalyst(a: Any): Any = a match {
- case d: java.math.BigDecimal => BigDecimal(d)
+ case d: java.math.BigDecimal => Decimal(BigDecimal(d))
case other => other
}
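With the Java-side DecimalType now carrying optional precision and scale, the Scala-to-Java direction emits new JDecimalType(precision, scale) for fixed decimals and a bare new JDecimalType() for DecimalType.Unlimited, and incoming java.math.BigDecimal values are wrapped in the new Decimal class instead of being passed through as Scala BigDecimals. A small sketch of the value conversion (requires spark-sql-catalyst on the classpath; Decimal is the class added earlier in this patch):

import org.apache.spark.sql.catalyst.types.decimal.Decimal

// A java.math.BigDecimal arriving from the Java API becomes a Catalyst Decimal,
// mirroring convertJavaToCatalyst above.
val fromJava = new java.math.BigDecimal("12.34")
val catalystValue: Decimal = Decimal(BigDecimal(fromJava))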
diff --git a/sql/core/src/test/java/org/apache/spark/sql/api/java/JavaApplySchemaSuite.java b/sql/core/src/test/java/org/apache/spark/sql/api/java/JavaApplySchemaSuite.java
index 9435a88009..a04b8060cd 100644
--- a/sql/core/src/test/java/org/apache/spark/sql/api/java/JavaApplySchemaSuite.java
+++ b/sql/core/src/test/java/org/apache/spark/sql/api/java/JavaApplySchemaSuite.java
@@ -118,7 +118,7 @@ public class JavaApplySchemaSuite implements Serializable {
"\"bigInteger\":92233720368547758069, \"double\":1.7976931348623157E305, " +
"\"boolean\":false, \"null\":null}"));
List<StructField> fields = new ArrayList<StructField>(7);
- fields.add(DataType.createStructField("bigInteger", DataType.DecimalType, true));
+ fields.add(DataType.createStructField("bigInteger", new DecimalType(), true));
fields.add(DataType.createStructField("boolean", DataType.BooleanType, true));
fields.add(DataType.createStructField("double", DataType.DoubleType, true));
fields.add(DataType.createStructField("integer", DataType.IntegerType, true));
diff --git a/sql/core/src/test/java/org/apache/spark/sql/api/java/JavaSideDataTypeConversionSuite.java b/sql/core/src/test/java/org/apache/spark/sql/api/java/JavaSideDataTypeConversionSuite.java
index d04396a5f8..8396a29c61 100644
--- a/sql/core/src/test/java/org/apache/spark/sql/api/java/JavaSideDataTypeConversionSuite.java
+++ b/sql/core/src/test/java/org/apache/spark/sql/api/java/JavaSideDataTypeConversionSuite.java
@@ -41,7 +41,8 @@ public class JavaSideDataTypeConversionSuite {
checkDataType(DataType.BooleanType);
checkDataType(DataType.DateType);
checkDataType(DataType.TimestampType);
- checkDataType(DataType.DecimalType);
+ checkDataType(new DecimalType());
+ checkDataType(new DecimalType(10, 4));
checkDataType(DataType.DoubleType);
checkDataType(DataType.FloatType);
checkDataType(DataType.ByteType);
@@ -59,7 +60,7 @@ public class JavaSideDataTypeConversionSuite {
// Simple StructType.
List<StructField> simpleFields = new ArrayList<StructField>();
- simpleFields.add(DataType.createStructField("a", DataType.DecimalType, false));
+ simpleFields.add(DataType.createStructField("a", new DecimalType(), false));
simpleFields.add(DataType.createStructField("b", DataType.BooleanType, true));
simpleFields.add(DataType.createStructField("c", DataType.LongType, true));
simpleFields.add(DataType.createStructField("d", DataType.BinaryType, false));
@@ -128,7 +129,7 @@ public class JavaSideDataTypeConversionSuite {
// StructType
try {
List<StructField> simpleFields = new ArrayList<StructField>();
- simpleFields.add(DataType.createStructField("a", DataType.DecimalType, false));
+ simpleFields.add(DataType.createStructField("a", new DecimalType(), false));
simpleFields.add(DataType.createStructField("b", DataType.BooleanType, true));
simpleFields.add(DataType.createStructField("c", DataType.LongType, true));
simpleFields.add(null);
@@ -138,7 +139,7 @@ public class JavaSideDataTypeConversionSuite {
}
try {
List<StructField> simpleFields = new ArrayList<StructField>();
- simpleFields.add(DataType.createStructField("a", DataType.DecimalType, false));
+ simpleFields.add(DataType.createStructField("a", new DecimalType(), false));
simpleFields.add(DataType.createStructField("a", DataType.BooleanType, true));
simpleFields.add(DataType.createStructField("c", DataType.LongType, true));
DataType.createStructType(simpleFields);
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataTypeSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataTypeSuite.scala
index 6c9db639c0..e9740d913c 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/DataTypeSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/DataTypeSuite.scala
@@ -69,7 +69,7 @@ class DataTypeSuite extends FunSuite {
checkDataTypeJsonRepr(LongType)
checkDataTypeJsonRepr(FloatType)
checkDataTypeJsonRepr(DoubleType)
- checkDataTypeJsonRepr(DecimalType)
+ checkDataTypeJsonRepr(DecimalType.Unlimited)
checkDataTypeJsonRepr(TimestampType)
checkDataTypeJsonRepr(StringType)
checkDataTypeJsonRepr(BinaryType)
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/ScalaReflectionRelationSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/ScalaReflectionRelationSuite.scala
index bfa9ea4162..cf3a59e545 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/ScalaReflectionRelationSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/ScalaReflectionRelationSuite.scala
@@ -19,6 +19,7 @@ package org.apache.spark.sql
import java.sql.{Date, Timestamp}
+import org.apache.spark.sql.catalyst.types.decimal.Decimal
import org.scalatest.FunSuite
import org.apache.spark.sql.catalyst.expressions._
@@ -81,7 +82,9 @@ class ScalaReflectionRelationSuite extends FunSuite {
val rdd = sparkContext.parallelize(data :: Nil)
rdd.registerTempTable("reflectData")
- assert(sql("SELECT * FROM reflectData").collect().head === data.productIterator.toSeq)
+ assert(sql("SELECT * FROM reflectData").collect().head ===
+ Seq("a", 1, 1L, 1.toFloat, 1.toDouble, 1.toShort, 1.toByte, true,
+ BigDecimal(1), new Date(12345), new Timestamp(12345), Seq(1,2,3)))
}
test("query case class RDD with nulls") {
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/api/java/JavaSQLSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/api/java/JavaSQLSuite.scala
index d83f3e23a9..c9012c9e47 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/api/java/JavaSQLSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/api/java/JavaSQLSuite.scala
@@ -17,6 +17,8 @@
package org.apache.spark.sql.api.java
+import org.apache.spark.sql.catalyst.types.decimal.Decimal
+
import scala.beans.BeanProperty
import org.scalatest.FunSuite
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/api/java/ScalaSideDataTypeConversionSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/api/java/ScalaSideDataTypeConversionSuite.scala
index e0e0ff9cb3..62fe59dd34 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/api/java/ScalaSideDataTypeConversionSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/api/java/ScalaSideDataTypeConversionSuite.scala
@@ -38,7 +38,7 @@ class ScalaSideDataTypeConversionSuite extends FunSuite {
checkDataType(org.apache.spark.sql.BooleanType)
checkDataType(org.apache.spark.sql.DateType)
checkDataType(org.apache.spark.sql.TimestampType)
- checkDataType(org.apache.spark.sql.DecimalType)
+ checkDataType(org.apache.spark.sql.DecimalType.Unlimited)
checkDataType(org.apache.spark.sql.DoubleType)
checkDataType(org.apache.spark.sql.FloatType)
checkDataType(org.apache.spark.sql.ByteType)
@@ -58,7 +58,7 @@ class ScalaSideDataTypeConversionSuite extends FunSuite {
// Simple StructType.
val simpleScalaStructType = SStructType(
- SStructField("a", org.apache.spark.sql.DecimalType, false) ::
+ SStructField("a", org.apache.spark.sql.DecimalType.Unlimited, false) ::
SStructField("b", org.apache.spark.sql.BooleanType, true) ::
SStructField("c", org.apache.spark.sql.LongType, true) ::
SStructField("d", org.apache.spark.sql.BinaryType, false) :: Nil)
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/json/JsonSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/json/JsonSuite.scala
index ce6184f5d8..1cb6c23c58 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/json/JsonSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/json/JsonSuite.scala
@@ -18,6 +18,7 @@
package org.apache.spark.sql.json
import org.apache.spark.sql.catalyst.types._
+import org.apache.spark.sql.catalyst.types.decimal.Decimal
import org.apache.spark.sql.catalyst.util._
import org.apache.spark.sql.json.JsonRDD.{enforceCorrectType, compatibleType}
import org.apache.spark.sql.QueryTest
@@ -44,19 +45,22 @@ class JsonSuite extends QueryTest {
checkTypePromotion(intNumber, enforceCorrectType(intNumber, IntegerType))
checkTypePromotion(intNumber.toLong, enforceCorrectType(intNumber, LongType))
checkTypePromotion(intNumber.toDouble, enforceCorrectType(intNumber, DoubleType))
- checkTypePromotion(BigDecimal(intNumber), enforceCorrectType(intNumber, DecimalType))
+ checkTypePromotion(
+ Decimal(intNumber), enforceCorrectType(intNumber, DecimalType.Unlimited))
val longNumber: Long = 9223372036854775807L
checkTypePromotion(longNumber, enforceCorrectType(longNumber, LongType))
checkTypePromotion(longNumber.toDouble, enforceCorrectType(longNumber, DoubleType))
- checkTypePromotion(BigDecimal(longNumber), enforceCorrectType(longNumber, DecimalType))
+ checkTypePromotion(
+ Decimal(longNumber), enforceCorrectType(longNumber, DecimalType.Unlimited))
val doubleNumber: Double = 1.7976931348623157E308d
checkTypePromotion(doubleNumber.toDouble, enforceCorrectType(doubleNumber, DoubleType))
- checkTypePromotion(BigDecimal(doubleNumber), enforceCorrectType(doubleNumber, DecimalType))
-
+ checkTypePromotion(
+ Decimal(doubleNumber), enforceCorrectType(doubleNumber, DecimalType.Unlimited))
+
checkTypePromotion(new Timestamp(intNumber), enforceCorrectType(intNumber, TimestampType))
- checkTypePromotion(new Timestamp(intNumber.toLong),
+ checkTypePromotion(new Timestamp(intNumber.toLong),
enforceCorrectType(intNumber.toLong, TimestampType))
val strTime = "2014-09-30 12:34:56"
checkTypePromotion(Timestamp.valueOf(strTime), enforceCorrectType(strTime, TimestampType))
@@ -80,7 +84,7 @@ class JsonSuite extends QueryTest {
checkDataType(NullType, IntegerType, IntegerType)
checkDataType(NullType, LongType, LongType)
checkDataType(NullType, DoubleType, DoubleType)
- checkDataType(NullType, DecimalType, DecimalType)
+ checkDataType(NullType, DecimalType.Unlimited, DecimalType.Unlimited)
checkDataType(NullType, StringType, StringType)
checkDataType(NullType, ArrayType(IntegerType), ArrayType(IntegerType))
checkDataType(NullType, StructType(Nil), StructType(Nil))
@@ -91,7 +95,7 @@ class JsonSuite extends QueryTest {
checkDataType(BooleanType, IntegerType, StringType)
checkDataType(BooleanType, LongType, StringType)
checkDataType(BooleanType, DoubleType, StringType)
- checkDataType(BooleanType, DecimalType, StringType)
+ checkDataType(BooleanType, DecimalType.Unlimited, StringType)
checkDataType(BooleanType, StringType, StringType)
checkDataType(BooleanType, ArrayType(IntegerType), StringType)
checkDataType(BooleanType, StructType(Nil), StringType)
@@ -100,7 +104,7 @@ class JsonSuite extends QueryTest {
checkDataType(IntegerType, IntegerType, IntegerType)
checkDataType(IntegerType, LongType, LongType)
checkDataType(IntegerType, DoubleType, DoubleType)
- checkDataType(IntegerType, DecimalType, DecimalType)
+ checkDataType(IntegerType, DecimalType.Unlimited, DecimalType.Unlimited)
checkDataType(IntegerType, StringType, StringType)
checkDataType(IntegerType, ArrayType(IntegerType), StringType)
checkDataType(IntegerType, StructType(Nil), StringType)
@@ -108,23 +112,23 @@ class JsonSuite extends QueryTest {
// LongType
checkDataType(LongType, LongType, LongType)
checkDataType(LongType, DoubleType, DoubleType)
- checkDataType(LongType, DecimalType, DecimalType)
+ checkDataType(LongType, DecimalType.Unlimited, DecimalType.Unlimited)
checkDataType(LongType, StringType, StringType)
checkDataType(LongType, ArrayType(IntegerType), StringType)
checkDataType(LongType, StructType(Nil), StringType)
// DoubleType
checkDataType(DoubleType, DoubleType, DoubleType)
- checkDataType(DoubleType, DecimalType, DecimalType)
+ checkDataType(DoubleType, DecimalType.Unlimited, DecimalType.Unlimited)
checkDataType(DoubleType, StringType, StringType)
checkDataType(DoubleType, ArrayType(IntegerType), StringType)
checkDataType(DoubleType, StructType(Nil), StringType)
// DoubleType
- checkDataType(DecimalType, DecimalType, DecimalType)
- checkDataType(DecimalType, StringType, StringType)
- checkDataType(DecimalType, ArrayType(IntegerType), StringType)
- checkDataType(DecimalType, StructType(Nil), StringType)
+ checkDataType(DecimalType.Unlimited, DecimalType.Unlimited, DecimalType.Unlimited)
+ checkDataType(DecimalType.Unlimited, StringType, StringType)
+ checkDataType(DecimalType.Unlimited, ArrayType(IntegerType), StringType)
+ checkDataType(DecimalType.Unlimited, StructType(Nil), StringType)
// StringType
checkDataType(StringType, StringType, StringType)
@@ -178,7 +182,7 @@ class JsonSuite extends QueryTest {
checkDataType(
StructType(
StructField("f1", IntegerType, true) :: Nil),
- DecimalType,
+ DecimalType.Unlimited,
StringType)
}
@@ -186,7 +190,7 @@ class JsonSuite extends QueryTest {
val jsonSchemaRDD = jsonRDD(primitiveFieldAndType)
val expectedSchema = StructType(
- StructField("bigInteger", DecimalType, true) ::
+ StructField("bigInteger", DecimalType.Unlimited, true) ::
StructField("boolean", BooleanType, true) ::
StructField("double", DoubleType, true) ::
StructField("integer", IntegerType, true) ::
@@ -216,7 +220,7 @@ class JsonSuite extends QueryTest {
val expectedSchema = StructType(
StructField("arrayOfArray1", ArrayType(ArrayType(StringType, false), false), true) ::
StructField("arrayOfArray2", ArrayType(ArrayType(DoubleType, false), false), true) ::
- StructField("arrayOfBigInteger", ArrayType(DecimalType, false), true) ::
+ StructField("arrayOfBigInteger", ArrayType(DecimalType.Unlimited, false), true) ::
StructField("arrayOfBoolean", ArrayType(BooleanType, false), true) ::
StructField("arrayOfDouble", ArrayType(DoubleType, false), true) ::
StructField("arrayOfInteger", ArrayType(IntegerType, false), true) ::
@@ -230,7 +234,7 @@ class JsonSuite extends QueryTest {
StructField("field3", StringType, true) :: Nil), false), true) ::
StructField("struct", StructType(
StructField("field1", BooleanType, true) ::
- StructField("field2", DecimalType, true) :: Nil), true) ::
+ StructField("field2", DecimalType.Unlimited, true) :: Nil), true) ::
StructField("structWithArrayFields", StructType(
StructField("field1", ArrayType(IntegerType, false), true) ::
StructField("field2", ArrayType(StringType, false), true) :: Nil), true) :: Nil)
@@ -331,7 +335,7 @@ class JsonSuite extends QueryTest {
val expectedSchema = StructType(
StructField("num_bool", StringType, true) ::
StructField("num_num_1", LongType, true) ::
- StructField("num_num_2", DecimalType, true) ::
+ StructField("num_num_2", DecimalType.Unlimited, true) ::
StructField("num_num_3", DoubleType, true) ::
StructField("num_str", StringType, true) ::
StructField("str_bool", StringType, true) :: Nil)
@@ -521,7 +525,7 @@ class JsonSuite extends QueryTest {
val jsonSchemaRDD = jsonFile(path)
val expectedSchema = StructType(
- StructField("bigInteger", DecimalType, true) ::
+ StructField("bigInteger", DecimalType.Unlimited, true) ::
StructField("boolean", BooleanType, true) ::
StructField("double", DoubleType, true) ::
StructField("integer", IntegerType, true) ::
@@ -551,7 +555,7 @@ class JsonSuite extends QueryTest {
primitiveFieldAndType.map(record => record.replaceAll("\n", " ")).saveAsTextFile(path)
val schema = StructType(
- StructField("bigInteger", DecimalType, true) ::
+ StructField("bigInteger", DecimalType.Unlimited, true) ::
StructField("boolean", BooleanType, true) ::
StructField("double", DoubleType, true) ::
StructField("integer", IntegerType, true) ::
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetQuerySuite.scala
index 9979ab446d..08d9da27f1 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetQuerySuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetQuerySuite.scala
@@ -77,6 +77,8 @@ case class AllDataTypesWithNonPrimitiveType(
case class BinaryData(binaryData: Array[Byte])
+case class NumericData(i: Int, d: Double)
+
class ParquetQuerySuite extends QueryTest with FunSuiteLike with BeforeAndAfterAll {
TestData // Load test data tables.
@@ -560,7 +562,7 @@ class ParquetQuerySuite extends QueryTest with FunSuiteLike with BeforeAndAfterA
assert(stringResult.size === 1)
assert(stringResult(0).getString(2) == "100", "stringvalue incorrect")
assert(stringResult(0).getInt(1) === 100)
-
+
val query7 = sql(s"SELECT * FROM testfiltersource WHERE myoptint < 40")
assert(
query7.queryExecution.executedPlan(0)(0).isInstanceOf[ParquetTableScan],
@@ -869,4 +871,35 @@ class ParquetQuerySuite extends QueryTest with FunSuiteLike with BeforeAndAfterA
assert(a.dataType === b.dataType)
}
}
+
+ test("read/write fixed-length decimals") {
+ for ((precision, scale) <- Seq((5, 2), (1, 0), (1, 1), (18, 10), (18, 17))) {
+ val tempDir = getTempFilePath("parquetTest").getCanonicalPath
+ val data = sparkContext.parallelize(0 to 1000)
+ .map(i => NumericData(i, i / 100.0))
+ .select('i, 'd cast DecimalType(precision, scale))
+ data.saveAsParquetFile(tempDir)
+ checkAnswer(parquetFile(tempDir), data.toSchemaRDD.collect().toSeq)
+ }
+
+ // Decimals with precision above 18 are not yet supported
+ intercept[RuntimeException] {
+ val tempDir = getTempFilePath("parquetTest").getCanonicalPath
+ val data = sparkContext.parallelize(0 to 1000)
+ .map(i => NumericData(i, i / 100.0))
+ .select('i, 'd cast DecimalType(19, 10))
+ data.saveAsParquetFile(tempDir)
+ checkAnswer(parquetFile(tempDir), data.toSchemaRDD.collect().toSeq)
+ }
+
+ // Unlimited-length decimals are not yet supported
+ intercept[RuntimeException] {
+ val tempDir = getTempFilePath("parquetTest").getCanonicalPath
+ val data = sparkContext.parallelize(0 to 1000)
+ .map(i => NumericData(i, i / 100.0))
+ .select('i, 'd cast DecimalType.Unlimited)
+ data.saveAsParquetFile(tempDir)
+ checkAnswer(parquetFile(tempDir), data.toSchemaRDD.collect().toSeq)
+ }
+ }
}
diff --git a/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/server/SparkSQLOperationManager.scala b/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/server/SparkSQLOperationManager.scala
index 2a4f24132c..99c4f46a82 100644
--- a/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/server/SparkSQLOperationManager.scala
+++ b/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/server/SparkSQLOperationManager.scala
@@ -47,7 +47,7 @@ private[thriftserver] class SparkSQLOperationManager(hiveContext: HiveContext)
val operation = new SparkExecuteStatementOperation(parentSession, statement, confOverlay)(
hiveContext, sessionToActivePool)
- handleToOperation.put(operation.getHandle, operation)
- operation
+ handleToOperation.put(operation.getHandle, operation)
+ operation
}
}
diff --git a/sql/hive-thriftserver/v0.12.0/src/main/scala/org/apache/spark/sql/hive/thriftserver/Shim12.scala b/sql/hive-thriftserver/v0.12.0/src/main/scala/org/apache/spark/sql/hive/thriftserver/Shim12.scala
index bbd727c686..8077d0ec46 100644
--- a/sql/hive-thriftserver/v0.12.0/src/main/scala/org/apache/spark/sql/hive/thriftserver/Shim12.scala
+++ b/sql/hive-thriftserver/v0.12.0/src/main/scala/org/apache/spark/sql/hive/thriftserver/Shim12.scala
@@ -123,7 +123,7 @@ private[hive] class SparkExecuteStatementOperation(
to.addColumnValue(ColumnValue.doubleValue(from.getDouble(ordinal)))
case FloatType =>
to.addColumnValue(ColumnValue.floatValue(from.getFloat(ordinal)))
- case DecimalType =>
+ case DecimalType() =>
val hiveDecimal = from.get(ordinal).asInstanceOf[BigDecimal].bigDecimal
to.addColumnValue(ColumnValue.stringValue(new HiveDecimal(hiveDecimal)))
case LongType =>
@@ -156,7 +156,7 @@ private[hive] class SparkExecuteStatementOperation(
to.addColumnValue(ColumnValue.doubleValue(null))
case FloatType =>
to.addColumnValue(ColumnValue.floatValue(null))
- case DecimalType =>
+ case DecimalType() =>
to.addColumnValue(ColumnValue.stringValue(null: HiveDecimal))
case LongType =>
to.addColumnValue(ColumnValue.longValue(null))
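With precision and scale added, DecimalType is no longer a singleton object, so the old pattern case DecimalType => can no longer appear in these shims; case DecimalType() => matches through an extractor on the companion that accepts any decimal, fixed or unlimited, while DecimalType.Fixed pulls out precision and scale when they are needed. A sketch of the two match forms (requires spark-sql-catalyst on the classpath; the extractors are the ones introduced elsewhere in this patch):

import org.apache.spark.sql.catalyst.types._

def describe(dt: DataType): String = dt match {
  case DecimalType.Fixed(precision, scale) => s"decimal($precision,$scale)"  // a specific precision/scale
  case DecimalType()                       => "decimal (unlimited)"          // any remaining decimal
  case other                               => other.toString
}

// describe(DecimalType(10, 4))    == "decimal(10,4)"
// describe(DecimalType.Unlimited) == "decimal (unlimited)"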
diff --git a/sql/hive-thriftserver/v0.13.1/src/main/scala/org/apache/spark/sql/hive/thriftserver/Shim13.scala b/sql/hive-thriftserver/v0.13.1/src/main/scala/org/apache/spark/sql/hive/thriftserver/Shim13.scala
index e59681bfbe..2c1983de1d 100644
--- a/sql/hive-thriftserver/v0.13.1/src/main/scala/org/apache/spark/sql/hive/thriftserver/Shim13.scala
+++ b/sql/hive-thriftserver/v0.13.1/src/main/scala/org/apache/spark/sql/hive/thriftserver/Shim13.scala
@@ -123,7 +123,7 @@ private[hive] class SparkExecuteStatementOperation(
to += from.getDouble(ordinal)
case FloatType =>
to += from.getFloat(ordinal)
- case DecimalType =>
+ case DecimalType() =>
to += from.get(ordinal).asInstanceOf[BigDecimal].bigDecimal
case LongType =>
to += from.getLong(ordinal)
diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveContext.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveContext.scala
index ff8fa44194..2e27817d60 100644
--- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveContext.scala
+++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveContext.scala
@@ -21,6 +21,10 @@ import java.io.{BufferedReader, File, InputStreamReader, PrintStream}
import java.sql.{Date, Timestamp}
import java.util.{ArrayList => JArrayList}
+import org.apache.hadoop.hive.common.`type`.HiveDecimal
+import org.apache.spark.sql.catalyst.types.DecimalType
+import org.apache.spark.sql.catalyst.types.decimal.Decimal
+
import scala.collection.JavaConversions._
import scala.language.implicitConversions
import scala.reflect.runtime.universe.{TypeTag, typeTag}
@@ -370,7 +374,7 @@ class HiveContext(sc: SparkContext) extends SQLContext(sc) {
protected val primitiveTypes =
Seq(StringType, IntegerType, LongType, DoubleType, FloatType, BooleanType, ByteType,
- ShortType, DecimalType, DateType, TimestampType, BinaryType)
+ ShortType, DateType, TimestampType, BinaryType)
protected[sql] def toHiveString(a: (Any, DataType)): String = a match {
case (struct: Row, StructType(fields)) =>
@@ -388,6 +392,8 @@ class HiveContext(sc: SparkContext) extends SQLContext(sc) {
case (d: Date, DateType) => new DateWritable(d).toString
case (t: Timestamp, TimestampType) => new TimestampWritable(t).toString
case (bin: Array[Byte], BinaryType) => new String(bin, "UTF-8")
+ case (decimal: Decimal, DecimalType()) => // Hive strips trailing zeros so use its toString
+ HiveShim.createDecimal(decimal.toBigDecimal.underlying()).toString
case (other, tpe) if primitiveTypes contains tpe => other.toString
}
@@ -406,6 +412,7 @@ class HiveContext(sc: SparkContext) extends SQLContext(sc) {
}.toSeq.sorted.mkString("{", ",", "}")
case (null, _) => "null"
case (s: String, StringType) => "\"" + s + "\""
+ case (decimal, DecimalType()) => decimal.toString
case (other, tpe) if primitiveTypes contains tpe => other.toString
}
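The comment about trailing zeros is why the Decimal branch goes through HiveShim.createDecimal rather than calling toString on the Decimal directly: Hive normalizes the textual form. The effect can be illustrated with plain java.math.BigDecimal, which HiveDecimal wraps:

// Illustration only: Hive's string form drops trailing fractional zeros.
val d = new java.math.BigDecimal("12.3400")
println(d.toPlainString)                       // 12.3400
println(d.stripTrailingZeros.toPlainString)    // 12.34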
diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveInspectors.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveInspectors.scala
index 0439ab97d8..1e2bf5cc4b 100644
--- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveInspectors.scala
+++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveInspectors.scala
@@ -28,6 +28,7 @@ import org.apache.hadoop.{io => hadoopIo}
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.types
import org.apache.spark.sql.catalyst.types._
+import org.apache.spark.sql.catalyst.types.decimal.Decimal
/* Implicit conversions */
import scala.collection.JavaConversions._
@@ -38,7 +39,7 @@ private[hive] trait HiveInspectors {
// writable
case c: Class[_] if c == classOf[hadoopIo.DoubleWritable] => DoubleType
case c: Class[_] if c == classOf[hiveIo.DoubleWritable] => DoubleType
- case c: Class[_] if c == classOf[hiveIo.HiveDecimalWritable] => DecimalType
+ case c: Class[_] if c == classOf[hiveIo.HiveDecimalWritable] => DecimalType.Unlimited
case c: Class[_] if c == classOf[hiveIo.ByteWritable] => ByteType
case c: Class[_] if c == classOf[hiveIo.ShortWritable] => ShortType
case c: Class[_] if c == classOf[hiveIo.DateWritable] => DateType
@@ -54,8 +55,8 @@ private[hive] trait HiveInspectors {
case c: Class[_] if c == classOf[java.lang.String] => StringType
case c: Class[_] if c == classOf[java.sql.Date] => DateType
case c: Class[_] if c == classOf[java.sql.Timestamp] => TimestampType
- case c: Class[_] if c == classOf[HiveDecimal] => DecimalType
- case c: Class[_] if c == classOf[java.math.BigDecimal] => DecimalType
+ case c: Class[_] if c == classOf[HiveDecimal] => DecimalType.Unlimited
+ case c: Class[_] if c == classOf[java.math.BigDecimal] => DecimalType.Unlimited
case c: Class[_] if c == classOf[Array[Byte]] => BinaryType
case c: Class[_] if c == classOf[java.lang.Short] => ShortType
case c: Class[_] if c == classOf[java.lang.Integer] => IntegerType
@@ -90,7 +91,7 @@ private[hive] trait HiveInspectors {
case hvoi: HiveVarcharObjectInspector =>
if (data == null) null else hvoi.getPrimitiveJavaObject(data).getValue
case hdoi: HiveDecimalObjectInspector =>
- if (data == null) null else BigDecimal(hdoi.getPrimitiveJavaObject(data).bigDecimalValue())
+ if (data == null) null else HiveShim.toCatalystDecimal(hdoi, data)
// org.apache.hadoop.hive.serde2.io.TimestampWritable.set will reset current time object
// if next timestamp is null, so Timestamp object is cloned
case ti: TimestampObjectInspector => ti.getPrimitiveJavaObject(data).clone()
@@ -137,8 +138,9 @@ private[hive] trait HiveInspectors {
case l: Short => l: java.lang.Short
case l: Byte => l: java.lang.Byte
case b: BigDecimal => HiveShim.createDecimal(b.underlying())
+ case d: Decimal => HiveShim.createDecimal(d.toBigDecimal.underlying())
case b: Array[Byte] => b
- case d: java.sql.Date => d
+ case d: java.sql.Date => d
case t: java.sql.Timestamp => t
}
case x: StructObjectInspector =>
@@ -200,7 +202,7 @@ private[hive] trait HiveInspectors {
case BinaryType => PrimitiveObjectInspectorFactory.javaByteArrayObjectInspector
case DateType => PrimitiveObjectInspectorFactory.javaDateObjectInspector
case TimestampType => PrimitiveObjectInspectorFactory.javaTimestampObjectInspector
- case DecimalType => PrimitiveObjectInspectorFactory.javaHiveDecimalObjectInspector
+ case DecimalType() => PrimitiveObjectInspectorFactory.javaHiveDecimalObjectInspector
case StructType(fields) =>
ObjectInspectorFactory.getStandardStructObjectInspector(
fields.map(f => f.name), fields.map(f => toInspector(f.dataType)))
@@ -229,8 +231,10 @@ private[hive] trait HiveInspectors {
HiveShim.getPrimitiveWritableConstantObjectInspector(value)
case Literal(value: java.sql.Timestamp, TimestampType) =>
HiveShim.getPrimitiveWritableConstantObjectInspector(value)
- case Literal(value: BigDecimal, DecimalType) =>
+ case Literal(value: BigDecimal, DecimalType()) =>
HiveShim.getPrimitiveWritableConstantObjectInspector(value)
+ case Literal(value: Decimal, DecimalType()) =>
+ HiveShim.getPrimitiveWritableConstantObjectInspector(value.toBigDecimal)
case Literal(_, NullType) =>
HiveShim.getPrimitiveNullWritableConstantObjectInspector
case Literal(value: Seq[_], ArrayType(dt, _)) =>
@@ -277,8 +281,8 @@ private[hive] trait HiveInspectors {
case _: JavaFloatObjectInspector => FloatType
case _: WritableBinaryObjectInspector => BinaryType
case _: JavaBinaryObjectInspector => BinaryType
- case _: WritableHiveDecimalObjectInspector => DecimalType
- case _: JavaHiveDecimalObjectInspector => DecimalType
+ case w: WritableHiveDecimalObjectInspector => HiveShim.decimalTypeInfoToCatalyst(w)
+ case j: JavaHiveDecimalObjectInspector => HiveShim.decimalTypeInfoToCatalyst(j)
case _: WritableDateObjectInspector => DateType
case _: JavaDateObjectInspector => DateType
case _: WritableTimestampObjectInspector => TimestampType
@@ -307,7 +311,7 @@ private[hive] trait HiveInspectors {
case LongType => longTypeInfo
case ShortType => shortTypeInfo
case StringType => stringTypeInfo
- case DecimalType => decimalTypeInfo
+ case d: DecimalType => HiveShim.decimalTypeInfo(d)
case DateType => dateTypeInfo
case TimestampType => timestampTypeInfo
case NullType => voidTypeInfo
diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveMetastoreCatalog.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveMetastoreCatalog.scala
index 2dd2c882a8..096b4a07aa 100644
--- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveMetastoreCatalog.scala
+++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveMetastoreCatalog.scala
@@ -20,6 +20,7 @@ package org.apache.spark.sql.hive
import java.io.IOException
import java.util.{List => JList}
+import scala.util.matching.Regex
import scala.util.parsing.combinator.RegexParsers
import org.apache.hadoop.util.ReflectionUtils
@@ -321,11 +322,18 @@ object HiveMetastoreTypes extends RegexParsers {
"bigint" ^^^ LongType |
"binary" ^^^ BinaryType |
"boolean" ^^^ BooleanType |
- HiveShim.metastoreDecimal ^^^ DecimalType |
+ fixedDecimalType | // Hive 0.13+ decimal with precision/scale
+ "decimal" ^^^ DecimalType.Unlimited | // Hive 0.12 decimal with no precision/scale
"date" ^^^ DateType |
"timestamp" ^^^ TimestampType |
"varchar\\((\\d+)\\)".r ^^^ StringType
+ protected lazy val fixedDecimalType: Parser[DataType] =
+ ("decimal" ~> "(" ~> "\\d+".r) ~ ("," ~> "\\d+".r <~ ")") ^^ {
+ case precision ~ scale =>
+ DecimalType(precision.toInt, scale.toInt)
+ }
+
protected lazy val arrayType: Parser[DataType] =
"array" ~> "<" ~> dataType <~ ">" ^^ {
case tpe => ArrayType(tpe)
@@ -373,7 +381,7 @@ object HiveMetastoreTypes extends RegexParsers {
case BinaryType => "binary"
case BooleanType => "boolean"
case DateType => "date"
- case DecimalType => "decimal"
+ case d: DecimalType => HiveShim.decimalMetastoreString(d)
case TimestampType => "timestamp"
case NullType => "void"
}
@@ -441,7 +449,7 @@ private[hive] case class MetastoreRelation
val partitionKeys = hiveQlTable.getPartitionKeys.map(_.toAttribute)
/** Non-partitionKey attributes */
- val attributes = hiveQlTable.getCols.map(_.toAttribute)
+ val attributes = hiveQlTable.getCols.map(_.toAttribute)
val output = attributes ++ partitionKeys
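Metastore column types are parsed with Scala's parser combinators, so supporting Hive 0.13's parameterized decimals is a new alternative in the grammar rather than a new regex in the shim: fixedDecimalType consumes "decimal(p,s)" and builds DecimalType(p, s), while a bare "decimal" still falls back to DecimalType.Unlimited. A standalone sketch of the same combinator (assumes scala.util.parsing.combinator is available, as it is on the Scala version this branch builds with):

import scala.util.parsing.combinator.RegexParsers

object DecimalTypeStringParser extends RegexParsers {
  // Parses "decimal(<precision>,<scale>)" into a (precision, scale) pair.
  lazy val fixedDecimal: Parser[(Int, Int)] =
    ("decimal" ~> "(" ~> "\\d+".r) ~ ("," ~> "\\d+".r <~ ")") ^^ {
      case precision ~ scale => (precision.toInt, scale.toInt)
    }

  def main(args: Array[String]): Unit = {
    println(parseAll(fixedDecimal, "decimal(10,2)"))  // succeeds with (10,2)
    println(parseAll(fixedDecimal, "decimal"))        // fails; handled by the plain "decimal" alternative
  }
}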
diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveQl.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveQl.scala
index a3573e6502..74f68d0f95 100644
--- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveQl.scala
+++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveQl.scala
@@ -31,6 +31,7 @@ import org.apache.spark.sql.catalyst.plans._
import org.apache.spark.sql.catalyst.plans.logical
import org.apache.spark.sql.catalyst.plans.logical._
import org.apache.spark.sql.catalyst.types._
+import org.apache.spark.sql.catalyst.types.decimal.Decimal
/* Implicit conversions */
import scala.collection.JavaConversions._
@@ -325,7 +326,11 @@ private[hive] object HiveQl {
}
protected def nodeToDataType(node: Node): DataType = node match {
- case Token("TOK_DECIMAL", Nil) => DecimalType
+ case Token("TOK_DECIMAL", precision :: scale :: Nil) =>
+ DecimalType(precision.getText.toInt, scale.getText.toInt)
+ case Token("TOK_DECIMAL", precision :: Nil) =>
+ DecimalType(precision.getText.toInt, 0)
+ case Token("TOK_DECIMAL", Nil) => DecimalType.Unlimited
case Token("TOK_BIGINT", Nil) => LongType
case Token("TOK_INT", Nil) => IntegerType
case Token("TOK_TINYINT", Nil) => ByteType
@@ -942,8 +947,12 @@ private[hive] object HiveQl {
Cast(nodeToExpr(arg), BinaryType)
case Token("TOK_FUNCTION", Token("TOK_BOOLEAN", Nil) :: arg :: Nil) =>
Cast(nodeToExpr(arg), BooleanType)
+ case Token("TOK_FUNCTION", Token("TOK_DECIMAL", precision :: scale :: nil) :: arg :: Nil) =>
+ Cast(nodeToExpr(arg), DecimalType(precision.getText.toInt, scale.getText.toInt))
+ case Token("TOK_FUNCTION", Token("TOK_DECIMAL", precision :: Nil) :: arg :: Nil) =>
+ Cast(nodeToExpr(arg), DecimalType(precision.getText.toInt, 0))
case Token("TOK_FUNCTION", Token("TOK_DECIMAL", Nil) :: arg :: Nil) =>
- Cast(nodeToExpr(arg), DecimalType)
+ Cast(nodeToExpr(arg), DecimalType.Unlimited)
case Token("TOK_FUNCTION", Token("TOK_TIMESTAMP", Nil) :: arg :: Nil) =>
Cast(nodeToExpr(arg), TimestampType)
case Token("TOK_FUNCTION", Token("TOK_DATE", Nil) :: arg :: Nil) =>
@@ -1063,7 +1072,7 @@ private[hive] object HiveQl {
} else if (ast.getText.endsWith("BD") || ast.getText.endsWith("D")) {
// Literal decimal
val strVal = ast.getText.stripSuffix("D").stripSuffix("B")
- v = Literal(BigDecimal(strVal))
+ v = Literal(Decimal(strVal))
} else {
v = Literal(ast.getText.toDouble, DoubleType)
v = Literal(ast.getText.toLong, LongType)
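Literals ending in "D" or "BD" are treated as decimals here, and stripping "D" and then "B" handles both suffixes with the same two calls before the text goes to the new Decimal constructor. A two-line illustration of the suffix handling:

// "123.45BD" -> strip "D", then "B" -> "123.45"; plain "1.5D" -> "1.5"
val example1 = "123.45BD".stripSuffix("D").stripSuffix("B")
val example2 = "1.5D".stripSuffix("D").stripSuffix("B")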
diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/InsertIntoHiveTable.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/InsertIntoHiveTable.scala
index 79234f8a66..92bc1c6625 100644
--- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/InsertIntoHiveTable.scala
+++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/InsertIntoHiveTable.scala
@@ -35,6 +35,7 @@ import org.apache.hadoop.mapred.{FileOutputCommitter, FileOutputFormat, JobConf}
import org.apache.spark.annotation.DeveloperApi
import org.apache.spark.rdd.RDD
import org.apache.spark.sql.catalyst.expressions.Row
+import org.apache.spark.sql.catalyst.types.decimal.Decimal
import org.apache.spark.sql.execution.{Command, SparkPlan, UnaryNode}
import org.apache.spark.sql.hive._
import org.apache.spark.sql.hive.{ ShimFileSinkDesc => FileSinkDesc}
@@ -76,7 +77,7 @@ case class InsertIntoHiveTable(
(o: Any) => new HiveVarchar(o.asInstanceOf[String], o.asInstanceOf[String].size)
case _: JavaHiveDecimalObjectInspector =>
- (o: Any) => HiveShim.createDecimal(o.asInstanceOf[BigDecimal].underlying())
+ (o: Any) => HiveShim.createDecimal(o.asInstanceOf[Decimal].toBigDecimal.underlying())
case soi: StandardStructObjectInspector =>
val wrappers = soi.getAllStructFieldRefs.map(ref => wrapperFor(ref.getFieldObjectInspector))
diff --git a/sql/hive/v0.12.0/src/main/scala/org/apache/spark/sql/hive/Shim12.scala b/sql/hive/v0.12.0/src/main/scala/org/apache/spark/sql/hive/Shim12.scala
index afc252ac27..8e946b7e82 100644
--- a/sql/hive/v0.12.0/src/main/scala/org/apache/spark/sql/hive/Shim12.scala
+++ b/sql/hive/v0.12.0/src/main/scala/org/apache/spark/sql/hive/Shim12.scala
@@ -30,21 +30,24 @@ import org.apache.hadoop.hive.ql.plan.{CreateTableDesc, FileSinkDesc, TableDesc}
import org.apache.hadoop.hive.ql.processors._
import org.apache.hadoop.hive.ql.stats.StatsSetupConst
import org.apache.hadoop.hive.serde2.objectinspector.PrimitiveObjectInspector.PrimitiveCategory
-import org.apache.hadoop.hive.serde2.objectinspector.primitive.PrimitiveObjectInspectorFactory
-import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspector
+import org.apache.hadoop.hive.serde2.objectinspector.primitive.{HiveDecimalObjectInspector, PrimitiveObjectInspectorFactory}
+import org.apache.hadoop.hive.serde2.objectinspector.{PrimitiveObjectInspector, ObjectInspector}
+import org.apache.hadoop.hive.serde2.typeinfo.{TypeInfo, TypeInfoFactory}
import org.apache.hadoop.hive.serde2.{Deserializer, ColumnProjectionUtils}
import org.apache.hadoop.hive.serde2.{io => hiveIo}
import org.apache.hadoop.{io => hadoopIo}
import org.apache.hadoop.mapred.InputFormat
+import org.apache.spark.sql.catalyst.types.decimal.Decimal
import scala.collection.JavaConversions._
import scala.language.implicitConversions
+import org.apache.spark.sql.catalyst.types.DecimalType
+
/**
* A compatibility layer for interacting with Hive version 0.12.0.
*/
private[hive] object HiveShim {
val version = "0.12.0"
- val metastoreDecimal = "decimal"
def getTableDesc(
serdeClass: Class[_ <: Deserializer],
@@ -149,6 +152,19 @@ private[hive] object HiveShim {
def setLocation(tbl: Table, crtTbl: CreateTableDesc): Unit = {
tbl.setDataLocation(new Path(crtTbl.getLocation()).toUri())
}
+
+ def decimalMetastoreString(decimalType: DecimalType): String = "decimal"
+
+ def decimalTypeInfo(decimalType: DecimalType): TypeInfo =
+ TypeInfoFactory.decimalTypeInfo
+
+ def decimalTypeInfoToCatalyst(inspector: PrimitiveObjectInspector): DecimalType = {
+ DecimalType.Unlimited
+ }
+
+ def toCatalystDecimal(hdoi: HiveDecimalObjectInspector, data: Any): Decimal = {
+ Decimal(hdoi.getPrimitiveJavaObject(data).bigDecimalValue())
+ }
}
class ShimFileSinkDesc(var dir: String, var tableInfo: TableDesc, var compressed: Boolean)
diff --git a/sql/hive/v0.13.1/src/main/scala/org/apache/spark/sql/hive/Shim13.scala b/sql/hive/v0.13.1/src/main/scala/org/apache/spark/sql/hive/Shim13.scala
index 42cd65b251..0bc330cdbe 100644
--- a/sql/hive/v0.13.1/src/main/scala/org/apache/spark/sql/hive/Shim13.scala
+++ b/sql/hive/v0.13.1/src/main/scala/org/apache/spark/sql/hive/Shim13.scala
@@ -29,15 +29,15 @@ import org.apache.hadoop.hive.ql.Context
import org.apache.hadoop.hive.ql.metadata.{Table, Hive, Partition}
import org.apache.hadoop.hive.ql.plan.{CreateTableDesc, FileSinkDesc, TableDesc}
import org.apache.hadoop.hive.ql.processors.CommandProcessorFactory
-import org.apache.hadoop.hive.serde2.typeinfo.TypeInfoFactory
-import org.apache.hadoop.hive.serde2.{ColumnProjectionUtils, Deserializer}
-import org.apache.hadoop.hive.serde2.objectinspector.PrimitiveObjectInspector.PrimitiveCategory
-import org.apache.hadoop.hive.serde2.objectinspector.primitive.PrimitiveObjectInspectorFactory
-import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspector
+import org.apache.hadoop.hive.serde2.typeinfo.{TypeInfo, DecimalTypeInfo, TypeInfoFactory}
+import org.apache.hadoop.hive.serde2.objectinspector.primitive.{HiveDecimalObjectInspector, PrimitiveObjectInspectorFactory}
+import org.apache.hadoop.hive.serde2.objectinspector.{PrimitiveObjectInspector, ObjectInspector}
import org.apache.hadoop.hive.serde2.{Deserializer, ColumnProjectionUtils}
import org.apache.hadoop.hive.serde2.{io => hiveIo}
import org.apache.hadoop.{io => hadoopIo}
import org.apache.spark.Logging
+import org.apache.spark.sql.catalyst.types.DecimalType
+import org.apache.spark.sql.catalyst.types.decimal.Decimal
import scala.collection.JavaConversions._
import scala.language.implicitConversions
@@ -47,11 +47,6 @@ import scala.language.implicitConversions
*/
private[hive] object HiveShim {
val version = "0.13.1"
- /*
- * TODO: hive-0.13 support DECIMAL(precision, scale), DECIMAL in hive-0.12 is actually DECIMAL(38,unbounded)
- * Full support of new decimal feature need to be fixed in seperate PR.
- */
- val metastoreDecimal = "decimal\\((\\d+),(\\d+)\\)".r
def getTableDesc(
serdeClass: Class[_ <: Deserializer],
@@ -197,6 +192,30 @@ private[hive] object HiveShim {
f.setDestTableId(w.destTableId)
f
}
+
+ // Precision and scale to pass for unlimited decimals; these are the same as the precision and
+ // scale Hive 0.13 infers for BigDecimals from sources that don't specify them (e.g. UDFs)
+ private val UNLIMITED_DECIMAL_PRECISION = 38
+ private val UNLIMITED_DECIMAL_SCALE = 18
+
+ def decimalMetastoreString(decimalType: DecimalType): String = decimalType match {
+ case DecimalType.Fixed(precision, scale) => s"decimal($precision,$scale)"
+ case _ => s"decimal($UNLIMITED_DECIMAL_PRECISION,$UNLIMITED_DECIMAL_SCALE)"
+ }
+
+ def decimalTypeInfo(decimalType: DecimalType): TypeInfo = decimalType match {
+ case DecimalType.Fixed(precision, scale) => new DecimalTypeInfo(precision, scale)
+ case _ => new DecimalTypeInfo(UNLIMITED_DECIMAL_PRECISION, UNLIMITED_DECIMAL_SCALE)
+ }
+
+ def decimalTypeInfoToCatalyst(inspector: PrimitiveObjectInspector): DecimalType = {
+ val info = inspector.getTypeInfo.asInstanceOf[DecimalTypeInfo]
+ DecimalType(info.precision(), info.scale())
+ }
+
+ def toCatalystDecimal(hdoi: HiveDecimalObjectInspector, data: Any): Decimal = {
+ Decimal(hdoi.getPrimitiveJavaObject(data).bigDecimalValue(), hdoi.precision(), hdoi.scale())
+ }
}
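Unlike the 0.12 shim, which can only describe decimals to the metastore as a bare "decimal" and always reports DecimalType.Unlimited back, the 0.13 shim threads precision and scale through both directions and substitutes 38/18 for unlimited decimals, matching what Hive 0.13 assumes when nothing is specified. A small sketch of the metastore-string side (DecimalType and its extractors come from spark-sql-catalyst; the match mirrors decimalMetastoreString above):

import org.apache.spark.sql.catalyst.types.DecimalType

def metastoreString(dt: DecimalType): String = dt match {
  case DecimalType.Fixed(precision, scale) => s"decimal($precision,$scale)"
  case _                                   => "decimal(38,18)"
}

// metastoreString(DecimalType(10, 2))    == "decimal(10,2)"
// metastoreString(DecimalType.Unlimited) == "decimal(38,18)"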
/*