author     Matei Zaharia <matei@databricks.com>       2014-11-01 19:29:14 -0700
committer  Michael Armbrust <michael@databricks.com>  2014-11-01 19:29:14 -0700
commit     23f966f47523f85ba440b4080eee665271f53b5e
tree       d796351567f8b187511b9049199cbf99c5826fb3 /sql/core
parent     56f2c61cde3f5d906c2a58e9af1a661222f2c679
[SPARK-3930] [SPARK-3933] Support fixed-precision decimal in SQL, and some optimizations
- Adds optional precision and scale to Spark SQL's decimal type, which behave similarly to those in Hive 13 (https://cwiki.apache.org/confluence/download/attachments/27362075/Hive_Decimal_Precision_Scale_Support.pdf)
- Replaces our internal representation of decimals with a Decimal class that can store small values in a mutable Long, saving memory in this situation and letting some operations happen directly on Longs

This is still marked WIP because there are a few TODOs, but I'll remove that tag when done.

Author: Matei Zaharia <matei@databricks.com>

Closes #2983 from mateiz/decimal-1 and squashes the following commits:

35e6b02 [Matei Zaharia] Fix issues after merge
227f24a [Matei Zaharia] Review comments
31f915e [Matei Zaharia] Implement Davies's suggestions in Python
eb84820 [Matei Zaharia] Support reading/writing decimals as fixed-length binary in Parquet
4dc6bae [Matei Zaharia] Fix decimal support in PySpark
d1d9d68 [Matei Zaharia] Fix compile error and test issues after rebase
b28933d [Matei Zaharia] Support decimal precision/scale in Hive metastore
2118c0d [Matei Zaharia] Some test and bug fixes
81db9cb [Matei Zaharia] Added mutable Decimal that will be more efficient for small precisions
7af0c3b [Matei Zaharia] Add optional precision and scale to DecimalType, but use Unlimited for now
ec0a947 [Matei Zaharia] Make the result of AVG on Decimals be Decimal, not Double
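
A minimal standalone Scala sketch of the representation change described above, using hypothetical names (FixedDecimal is not Spark's Decimal class): values with small precision are kept as an unscaled Long plus a scale, so same-scale arithmetic stays on Longs and only the conversion back to BigDecimal allocates.

object DecimalSketch {
  // Hypothetical stand-in for the mutable Decimal class added in this patch:
  // the unscaled value of e.g. 123.45 at scale 2 is the Long 12345.
  final case class FixedDecimal(unscaled: Long, precision: Int, scale: Int) {
    def toBigDecimal: BigDecimal = BigDecimal(BigInt(unscaled), scale)
    // Same-scale addition happens directly on Longs, with no BigDecimal allocation
    def + (other: FixedDecimal): FixedDecimal = {
      require(scale == other.scale, "this sketch only adds values with matching scales")
      copy(unscaled = unscaled + other.unscaled)
    }
  }

  def main(args: Array[String]): Unit = {
    val a = FixedDecimal(12345L, precision = 5, scale = 2)  // 123.45
    val b = FixedDecimal(100L, precision = 5, scale = 2)    //   1.00
    println((a + b).toBigDecimal)                           // 124.45
  }
}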
Diffstat (limited to 'sql/core')
-rw-r--r--  sql/core/src/main/java/org/apache/spark/sql/api/java/DataType.java | 5
-rw-r--r--  sql/core/src/main/java/org/apache/spark/sql/api/java/DecimalType.java | 58
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/SchemaRDD.scala | 3
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/api/java/JavaSQLContext.scala | 2
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/api/java/Row.scala | 4
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/execution/GeneratedAggregate.scala | 41
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlan.scala | 4
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/execution/SparkSqlSerializer.scala | 2
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala | 7
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastHashJoin.scala | 3
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/execution/pythonUdfs.scala | 6
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/json/JsonRDD.scala | 20
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/package.scala | 14
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetConverter.scala | 43
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetTableSupport.scala | 28
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetTypes.scala | 79
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/types/util/DataTypeConversions.scala | 13
-rw-r--r--  sql/core/src/test/java/org/apache/spark/sql/api/java/JavaApplySchemaSuite.java | 2
-rw-r--r--  sql/core/src/test/java/org/apache/spark/sql/api/java/JavaSideDataTypeConversionSuite.java | 9
-rw-r--r--  sql/core/src/test/scala/org/apache/spark/sql/DataTypeSuite.scala | 2
-rw-r--r--  sql/core/src/test/scala/org/apache/spark/sql/ScalaReflectionRelationSuite.scala | 5
-rw-r--r--  sql/core/src/test/scala/org/apache/spark/sql/api/java/JavaSQLSuite.scala | 2
-rw-r--r--  sql/core/src/test/scala/org/apache/spark/sql/api/java/ScalaSideDataTypeConversionSuite.scala | 4
-rw-r--r--  sql/core/src/test/scala/org/apache/spark/sql/json/JsonSuite.scala | 46
-rw-r--r--  sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetQuerySuite.scala | 35
25 files changed, 350 insertions(+), 87 deletions(-)
diff --git a/sql/core/src/main/java/org/apache/spark/sql/api/java/DataType.java b/sql/core/src/main/java/org/apache/spark/sql/api/java/DataType.java
index 0c85cdc0aa..c38354039d 100644
--- a/sql/core/src/main/java/org/apache/spark/sql/api/java/DataType.java
+++ b/sql/core/src/main/java/org/apache/spark/sql/api/java/DataType.java
@@ -53,11 +53,6 @@ public abstract class DataType {
public static final TimestampType TimestampType = new TimestampType();
/**
- * Gets the DecimalType object.
- */
- public static final DecimalType DecimalType = new DecimalType();
-
- /**
* Gets the DoubleType object.
*/
public static final DoubleType DoubleType = new DoubleType();
diff --git a/sql/core/src/main/java/org/apache/spark/sql/api/java/DecimalType.java b/sql/core/src/main/java/org/apache/spark/sql/api/java/DecimalType.java
index bc54c078d7..60752451ec 100644
--- a/sql/core/src/main/java/org/apache/spark/sql/api/java/DecimalType.java
+++ b/sql/core/src/main/java/org/apache/spark/sql/api/java/DecimalType.java
@@ -19,9 +19,61 @@ package org.apache.spark.sql.api.java;
/**
* The data type representing java.math.BigDecimal values.
- *
- * {@code DecimalType} is represented by the singleton object {@link DataType#DecimalType}.
*/
public class DecimalType extends DataType {
- protected DecimalType() {}
+ private boolean hasPrecisionInfo;
+ private int precision;
+ private int scale;
+
+ public DecimalType(int precision, int scale) {
+ this.hasPrecisionInfo = true;
+ this.precision = precision;
+ this.scale = scale;
+ }
+
+ public DecimalType() {
+ this.hasPrecisionInfo = false;
+ this.precision = -1;
+ this.scale = -1;
+ }
+
+ public boolean isUnlimited() {
+ return !hasPrecisionInfo;
+ }
+
+ public boolean isFixed() {
+ return hasPrecisionInfo;
+ }
+
+ /** Return the precision, or -1 if no precision is set */
+ public int getPrecision() {
+ return precision;
+ }
+
+ /** Return the scale, or -1 if no precision is set */
+ public int getScale() {
+ return scale;
+ }
+
+ @Override
+ public boolean equals(Object o) {
+ if (this == o) return true;
+ if (o == null || getClass() != o.getClass()) return false;
+
+ DecimalType that = (DecimalType) o;
+
+ if (hasPrecisionInfo != that.hasPrecisionInfo) return false;
+ if (precision != that.precision) return false;
+ if (scale != that.scale) return false;
+
+ return true;
+ }
+
+ @Override
+ public int hashCode() {
+ int result = (hasPrecisionInfo ? 1 : 0);
+ result = 31 * result + precision;
+ result = 31 * result + scale;
+ return result;
+ }
}
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/SchemaRDD.scala b/sql/core/src/main/scala/org/apache/spark/sql/SchemaRDD.scala
index 8b96df1096..018a18c4ac 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/SchemaRDD.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/SchemaRDD.scala
@@ -19,6 +19,7 @@ package org.apache.spark.sql
import java.util.{Map => JMap, List => JList}
+import org.apache.spark.sql.catalyst.ScalaReflection
import org.apache.spark.storage.StorageLevel
import scala.collection.JavaConversions._
@@ -113,7 +114,7 @@ class SchemaRDD(
// =========================================================================================
override def compute(split: Partition, context: TaskContext): Iterator[Row] =
- firstParent[Row].compute(split, context).map(_.copy())
+ firstParent[Row].compute(split, context).map(ScalaReflection.convertRowToScala)
override def getPartitions: Array[Partition] = firstParent[Row].partitions
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/api/java/JavaSQLContext.scala b/sql/core/src/main/scala/org/apache/spark/sql/api/java/JavaSQLContext.scala
index 082ae03eef..876b1c6ede 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/api/java/JavaSQLContext.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/api/java/JavaSQLContext.scala
@@ -230,7 +230,7 @@ class JavaSQLContext(val sqlContext: SQLContext) extends UDFRegistration {
case c: Class[_] if c == classOf[java.lang.Boolean] =>
(org.apache.spark.sql.BooleanType, true)
case c: Class[_] if c == classOf[java.math.BigDecimal] =>
- (org.apache.spark.sql.DecimalType, true)
+ (org.apache.spark.sql.DecimalType(), true)
case c: Class[_] if c == classOf[java.sql.Date] =>
(org.apache.spark.sql.DateType, true)
case c: Class[_] if c == classOf[java.sql.Timestamp] =>
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/api/java/Row.scala b/sql/core/src/main/scala/org/apache/spark/sql/api/java/Row.scala
index df01411f60..401798e317 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/api/java/Row.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/api/java/Row.scala
@@ -17,6 +17,8 @@
package org.apache.spark.sql.api.java
+import org.apache.spark.sql.catalyst.types.decimal.Decimal
+
import scala.annotation.varargs
import scala.collection.convert.Wrappers.{JListWrapper, JMapWrapper}
import scala.collection.JavaConversions
@@ -106,6 +108,8 @@ class Row(private[spark] val row: ScalaRow) extends Serializable {
}
override def hashCode(): Int = row.hashCode()
+
+ override def toString: String = row.toString
}
object Row {
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/GeneratedAggregate.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/GeneratedAggregate.scala
index b3edd5020f..087b0ecbb2 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/GeneratedAggregate.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/GeneratedAggregate.scala
@@ -70,16 +70,29 @@ case class GeneratedAggregate(
val computeFunctions = aggregatesToCompute.map {
case c @ Count(expr) =>
+ // If we're evaluating UnscaledValue(x), we can do Count on x directly, since its
+ // UnscaledValue will be null if and only if x is null; helps with Average on decimals
+ val toCount = expr match {
+ case UnscaledValue(e) => e
+ case _ => expr
+ }
val currentCount = AttributeReference("currentCount", LongType, nullable = false)()
val initialValue = Literal(0L)
- val updateFunction = If(IsNotNull(expr), Add(currentCount, Literal(1L)), currentCount)
+ val updateFunction = If(IsNotNull(toCount), Add(currentCount, Literal(1L)), currentCount)
val result = currentCount
AggregateEvaluation(currentCount :: Nil, initialValue :: Nil, updateFunction :: Nil, result)
case Sum(expr) =>
- val currentSum = AttributeReference("currentSum", expr.dataType, nullable = false)()
- val initialValue = Cast(Literal(0L), expr.dataType)
+ val resultType = expr.dataType match {
+ case DecimalType.Fixed(precision, scale) =>
+ DecimalType(precision + 10, scale)
+ case _ =>
+ expr.dataType
+ }
+
+ val currentSum = AttributeReference("currentSum", resultType, nullable = false)()
+ val initialValue = Cast(Literal(0L), resultType)
// Coalasce avoids double calculation...
// but really, common sub expression elimination would be better....
@@ -93,10 +106,26 @@ case class GeneratedAggregate(
val currentSum = AttributeReference("currentSum", expr.dataType, nullable = false)()
val initialCount = Literal(0L)
val initialSum = Cast(Literal(0L), expr.dataType)
- val updateCount = If(IsNotNull(expr), Add(currentCount, Literal(1L)), currentCount)
+
+ // If we're evaluating UnscaledValue(x), we can do Count on x directly, since its
+ // UnscaledValue will be null if and only if x is null; helps with Average on decimals
+ val toCount = expr match {
+ case UnscaledValue(e) => e
+ case _ => expr
+ }
+
+ val updateCount = If(IsNotNull(toCount), Add(currentCount, Literal(1L)), currentCount)
val updateSum = Coalesce(Add(expr, currentSum) :: currentSum :: Nil)
- val result = Divide(Cast(currentSum, DoubleType), Cast(currentCount, DoubleType))
+ val resultType = expr.dataType match {
+ case DecimalType.Fixed(precision, scale) =>
+ DecimalType(precision + 4, scale + 4)
+ case DecimalType.Unlimited =>
+ DecimalType.Unlimited
+ case _ =>
+ DoubleType
+ }
+ val result = Divide(Cast(currentSum, resultType), Cast(currentCount, resultType))
AggregateEvaluation(
currentCount :: currentSum :: Nil,
@@ -142,7 +171,7 @@ case class GeneratedAggregate(
val computationSchema = computeFunctions.flatMap(_.schema)
- val resultMap: Map[TreeNodeRef, Expression] =
+ val resultMap: Map[TreeNodeRef, Expression] =
aggregatesToCompute.zip(computeFunctions).map {
case (agg, func) => new TreeNodeRef(agg) -> func.result
}.toMap
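
The result-type rules in the hunks above widen fixed decimals instead of falling back to Double: SUM adds 10 digits of precision as headroom for accumulating many values, and AVG adds 4 digits of both precision and scale before the divide. A small Scala sketch of just that arithmetic, with hypothetical helper names:

object DecimalWidening {
  final case class Fixed(precision: Int, scale: Int)

  // Mirrors the expr.dataType matches above (names here are illustrative only)
  def sumResultType(in: Fixed): Fixed = Fixed(in.precision + 10, in.scale)
  def avgResultType(in: Fixed): Fixed = Fixed(in.precision + 4, in.scale + 4)

  def main(args: Array[String]): Unit = {
    println(sumResultType(Fixed(10, 2)))  // Fixed(20,2)
    println(avgResultType(Fixed(10, 2)))  // Fixed(14,6)
  }
}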
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlan.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlan.scala
index b1a7948b66..aafcce0572 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlan.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlan.scala
@@ -23,7 +23,7 @@ import org.apache.spark.rdd.RDD
import org.apache.spark.sql.SQLContext
-import org.apache.spark.sql.catalyst.trees
+import org.apache.spark.sql.catalyst.{ScalaReflection, trees}
import org.apache.spark.sql.catalyst.analysis.MultiInstanceRelation
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.expressions.codegen._
@@ -82,7 +82,7 @@ abstract class SparkPlan extends QueryPlan[SparkPlan] with Logging with Serializ
/**
* Runs this query returning the result as an array.
*/
- def executeCollect(): Array[Row] = execute().map(_.copy()).collect()
+ def executeCollect(): Array[Row] = execute().map(ScalaReflection.convertRowToScala).collect()
protected def newProjection(
expressions: Seq[Expression], inputSchema: Seq[Attribute]): Projection = {
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkSqlSerializer.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkSqlSerializer.scala
index 077e6ebc5f..84d96e612f 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkSqlSerializer.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkSqlSerializer.scala
@@ -29,6 +29,7 @@ import com.twitter.chill.{AllScalaRegistrar, ResourcePool}
import org.apache.spark.{SparkEnv, SparkConf}
import org.apache.spark.serializer.{SerializerInstance, KryoSerializer}
import org.apache.spark.sql.catalyst.expressions.GenericRow
+import org.apache.spark.sql.catalyst.types.decimal.Decimal
import org.apache.spark.util.collection.OpenHashSet
import org.apache.spark.util.MutablePair
import org.apache.spark.util.Utils
@@ -51,6 +52,7 @@ private[sql] class SparkSqlSerializer(conf: SparkConf) extends KryoSerializer(co
kryo.register(classOf[LongHashSet], new LongHashSetSerializer)
kryo.register(classOf[org.apache.spark.util.collection.OpenHashSet[_]],
new OpenHashSetSerializer)
+ kryo.register(classOf[Decimal])
kryo.setReferences(false)
kryo.setClassLoader(Utils.getSparkClassLoader)
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala
index 977f3c9f32..e6cd1a9d04 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala
@@ -143,7 +143,7 @@ case class Limit(limit: Int, child: SparkPlan)
partsScanned += numPartsToTry
}
- buf.toArray
+ buf.toArray.map(ScalaReflection.convertRowToScala)
}
override def execute() = {
@@ -176,10 +176,11 @@ case class TakeOrdered(limit: Int, sortOrder: Seq[SortOrder], child: SparkPlan)
override def output = child.output
override def outputPartitioning = SinglePartition
- val ordering = new RowOrdering(sortOrder, child.output)
+ val ord = new RowOrdering(sortOrder, child.output)
// TODO: Is this copying for no reason?
- override def executeCollect() = child.execute().map(_.copy()).takeOrdered(limit)(ordering)
+ override def executeCollect() =
+ child.execute().map(_.copy()).takeOrdered(limit)(ord).map(ScalaReflection.convertRowToScala)
// TODO: Terminal split should be implemented differently from non-terminal split.
// TODO: Pick num splits based on |limit|.
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastHashJoin.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastHashJoin.scala
index 8fd35880ee..5cf2a785ad 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastHashJoin.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastHashJoin.scala
@@ -49,7 +49,8 @@ case class BroadcastHashJoin(
@transient
private val broadcastFuture = future {
- val input: Array[Row] = buildPlan.executeCollect()
+ // Note that we use .execute().collect() because we don't want to convert data to Scala types
+ val input: Array[Row] = buildPlan.execute().map(_.copy()).collect()
val hashed = HashedRelation(input.iterator, buildSideKeyGenerator, input.length)
sparkContext.broadcast(hashed)
}
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/pythonUdfs.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/pythonUdfs.scala
index a1961bba18..997669051e 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/pythonUdfs.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/pythonUdfs.scala
@@ -19,6 +19,8 @@ package org.apache.spark.sql.execution
import java.util.{List => JList, Map => JMap}
+import org.apache.spark.sql.catalyst.types.decimal.Decimal
+
import scala.collection.JavaConversions._
import scala.collection.JavaConverters._
@@ -116,7 +118,7 @@ object EvaluatePython {
def toJava(obj: Any, dataType: DataType): Any = (obj, dataType) match {
case (null, _) => null
- case (row: Row, struct: StructType) =>
+ case (row: Seq[Any], struct: StructType) =>
val fields = struct.fields.map(field => field.dataType)
row.zip(fields).map {
case (obj, dataType) => toJava(obj, dataType)
@@ -133,6 +135,8 @@ object EvaluatePython {
case (k, v) => (k, toJava(v, mt.valueType)) // key should be primitive type
}.asJava
+ case (dec: BigDecimal, dt: DecimalType) => dec.underlying() // Pyrolite can handle BigDecimal
+
// Pyrolite can handle Timestamp
case (other, _) => other
}
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/json/JsonRDD.scala b/sql/core/src/main/scala/org/apache/spark/sql/json/JsonRDD.scala
index eabe312f92..5bb6f6c85d 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/json/JsonRDD.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/json/JsonRDD.scala
@@ -17,6 +17,8 @@
package org.apache.spark.sql.json
+import org.apache.spark.sql.catalyst.types.decimal.Decimal
+
import scala.collection.Map
import scala.collection.convert.Wrappers.{JMapWrapper, JListWrapper}
import scala.math.BigDecimal
@@ -175,9 +177,9 @@ private[sql] object JsonRDD extends Logging {
ScalaReflection.typeOfObject orElse {
// Since we do not have a data type backed by BigInteger,
// when we see a Java BigInteger, we use DecimalType.
- case value: java.math.BigInteger => DecimalType
+ case value: java.math.BigInteger => DecimalType.Unlimited
// DecimalType's JVMType is scala BigDecimal.
- case value: java.math.BigDecimal => DecimalType
+ case value: java.math.BigDecimal => DecimalType.Unlimited
// Unexpected data type.
case _ => StringType
}
@@ -319,13 +321,13 @@ private[sql] object JsonRDD extends Logging {
}
}
- private def toDecimal(value: Any): BigDecimal = {
+ private def toDecimal(value: Any): Decimal = {
value match {
- case value: java.lang.Integer => BigDecimal(value)
- case value: java.lang.Long => BigDecimal(value)
- case value: java.math.BigInteger => BigDecimal(value)
- case value: java.lang.Double => BigDecimal(value)
- case value: java.math.BigDecimal => BigDecimal(value)
+ case value: java.lang.Integer => Decimal(value)
+ case value: java.lang.Long => Decimal(value)
+ case value: java.math.BigInteger => Decimal(BigDecimal(value))
+ case value: java.lang.Double => Decimal(value)
+ case value: java.math.BigDecimal => Decimal(BigDecimal(value))
}
}
@@ -391,7 +393,7 @@ private[sql] object JsonRDD extends Logging {
case IntegerType => value.asInstanceOf[IntegerType.JvmType]
case LongType => toLong(value)
case DoubleType => toDouble(value)
- case DecimalType => toDecimal(value)
+ case DecimalType() => toDecimal(value)
case BooleanType => value.asInstanceOf[BooleanType.JvmType]
case NullType => null
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/package.scala b/sql/core/src/main/scala/org/apache/spark/sql/package.scala
index f0e57e2a74..05926a24c5 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/package.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/package.scala
@@ -183,6 +183,20 @@ package object sql {
*
* The data type representing `scala.math.BigDecimal` values.
*
+ * TODO(matei): explain precision and scale
+ *
+ * @group dataType
+ */
+ @DeveloperApi
+ type DecimalType = catalyst.types.DecimalType
+
+ /**
+ * :: DeveloperApi ::
+ *
+ * The data type representing `scala.math.BigDecimal` values.
+ *
+ * TODO(matei): explain precision and scale
+ *
* @group dataType
*/
@DeveloperApi
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetConverter.scala b/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetConverter.scala
index 2fc7e1cf23..08feced61a 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetConverter.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetConverter.scala
@@ -17,6 +17,8 @@
package org.apache.spark.sql.parquet
+import org.apache.spark.sql.catalyst.types.decimal.Decimal
+
import scala.collection.mutable.{Buffer, ArrayBuffer, HashMap}
import parquet.io.api.{PrimitiveConverter, GroupConverter, Binary, Converter}
@@ -117,6 +119,12 @@ private[sql] object CatalystConverter {
parent.updateByte(fieldIndex, value.asInstanceOf[ByteType.JvmType])
}
}
+ case d: DecimalType => {
+ new CatalystPrimitiveConverter(parent, fieldIndex) {
+ override def addBinary(value: Binary): Unit =
+ parent.updateDecimal(fieldIndex, value, d)
+ }
+ }
// All other primitive types use the default converter
case ctype: PrimitiveType => { // note: need the type tag here!
new CatalystPrimitiveConverter(parent, fieldIndex)
@@ -191,6 +199,10 @@ private[parquet] abstract class CatalystConverter extends GroupConverter {
protected[parquet] def updateString(fieldIndex: Int, value: Binary): Unit =
updateField(fieldIndex, value.toStringUsingUTF8)
+ protected[parquet] def updateDecimal(fieldIndex: Int, value: Binary, ctype: DecimalType): Unit = {
+ updateField(fieldIndex, readDecimal(new Decimal(), value, ctype))
+ }
+
protected[parquet] def isRootConverter: Boolean = parent == null
protected[parquet] def clearBuffer(): Unit
@@ -201,6 +213,27 @@ private[parquet] abstract class CatalystConverter extends GroupConverter {
* @return
*/
def getCurrentRecord: Row = throw new UnsupportedOperationException
+
+ /**
+ * Read a decimal value from a Parquet Binary into "dest". Only supports decimals that fit in
+ * a long (i.e. precision <= 18)
+ */
+ protected[parquet] def readDecimal(dest: Decimal, value: Binary, ctype: DecimalType): Unit = {
+ val precision = ctype.precisionInfo.get.precision
+ val scale = ctype.precisionInfo.get.scale
+ val bytes = value.getBytes
+ require(bytes.length <= 16, "Decimal field too large to read")
+ var unscaled = 0L
+ var i = 0
+ while (i < bytes.length) {
+ unscaled = (unscaled << 8) | (bytes(i) & 0xFF)
+ i += 1
+ }
+ // Make sure unscaled has the right sign, by sign-extending the first bit
+ val numBits = 8 * bytes.length
+ unscaled = (unscaled << (64 - numBits)) >> (64 - numBits)
+ dest.set(unscaled, precision, scale)
+ }
}
/**
@@ -352,6 +385,16 @@ private[parquet] class CatalystPrimitiveRowConverter(
override protected[parquet] def updateString(fieldIndex: Int, value: Binary): Unit =
current.setString(fieldIndex, value.toStringUsingUTF8)
+
+ override protected[parquet] def updateDecimal(
+ fieldIndex: Int, value: Binary, ctype: DecimalType): Unit = {
+ var decimal = current(fieldIndex).asInstanceOf[Decimal]
+ if (decimal == null) {
+ decimal = new Decimal
+ current(fieldIndex) = decimal
+ }
+ readDecimal(decimal, value, ctype)
+ }
}
/**
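
The readDecimal helper above folds a big-endian FIXED_LEN_BYTE_ARRAY into a Long and then sign-extends from the value's top bit. A standalone sketch of that read path (hypothetical object name, restricted to at most 8 bytes so the result genuinely fits in a Long):

object UnscaledLongReader {
  def readUnscaled(bytes: Array[Byte]): Long = {
    require(bytes.length <= 8, "this sketch only reads values that fit in a Long")
    var unscaled = 0L
    var i = 0
    while (i < bytes.length) {
      unscaled = (unscaled << 8) | (bytes(i) & 0xFF)  // accumulate big-endian bytes
      i += 1
    }
    val numBits = 8 * bytes.length
    (unscaled << (64 - numBits)) >> (64 - numBits)    // arithmetic shift sign-extends
  }

  def main(args: Array[String]): Unit = {
    println(readUnscaled(Array(0x30, 0x39).map(_.toByte)))  // 12345
    println(readUnscaled(Array(0xCF, 0xC7).map(_.toByte)))  // -12345
  }
}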
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetTableSupport.scala b/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetTableSupport.scala
index bdf02401b2..2a5f23b24e 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetTableSupport.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetTableSupport.scala
@@ -20,6 +20,7 @@ package org.apache.spark.sql.parquet
import java.util.{HashMap => JHashMap}
import org.apache.hadoop.conf.Configuration
+import org.apache.spark.sql.catalyst.types.decimal.Decimal
import parquet.column.ParquetProperties
import parquet.hadoop.ParquetOutputFormat
import parquet.hadoop.api.ReadSupport.ReadContext
@@ -204,6 +205,11 @@ private[parquet] class RowWriteSupport extends WriteSupport[Row] with Logging {
case DoubleType => writer.addDouble(value.asInstanceOf[Double])
case FloatType => writer.addFloat(value.asInstanceOf[Float])
case BooleanType => writer.addBoolean(value.asInstanceOf[Boolean])
+ case d: DecimalType =>
+ if (d.precisionInfo == None || d.precisionInfo.get.precision > 18) {
+ sys.error(s"Unsupported datatype $d, cannot write to consumer")
+ }
+ writeDecimal(value.asInstanceOf[Decimal], d.precisionInfo.get.precision)
case _ => sys.error(s"Do not know how to writer $schema to consumer")
}
}
@@ -283,6 +289,23 @@ private[parquet] class RowWriteSupport extends WriteSupport[Row] with Logging {
}
writer.endGroup()
}
+
+ // Scratch array used to write decimals as fixed-length binary
+ private val scratchBytes = new Array[Byte](8)
+
+ private[parquet] def writeDecimal(decimal: Decimal, precision: Int): Unit = {
+ val numBytes = ParquetTypesConverter.BYTES_FOR_PRECISION(precision)
+ val unscaledLong = decimal.toUnscaledLong
+ var i = 0
+ var shift = 8 * (numBytes - 1)
+ while (i < numBytes) {
+ scratchBytes(i) = (unscaledLong >> shift).toByte
+ i += 1
+ shift -= 8
+ }
+ writer.addBinary(Binary.fromByteArray(scratchBytes, 0, numBytes))
+ }
+
}
// Optimized for non-nested rows
@@ -326,6 +349,11 @@ private[parquet] class MutableRowWriteSupport extends RowWriteSupport {
case DoubleType => writer.addDouble(record.getDouble(index))
case FloatType => writer.addFloat(record.getFloat(index))
case BooleanType => writer.addBoolean(record.getBoolean(index))
+ case d: DecimalType =>
+ if (d.precisionInfo == None || d.precisionInfo.get.precision > 18) {
+ sys.error(s"Unsupported datatype $d, cannot write to consumer")
+ }
+ writeDecimal(record(index).asInstanceOf[Decimal], d.precisionInfo.get.precision)
case _ => sys.error(s"Unsupported datatype $ctype, cannot write to consumer")
}
}
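
writeDecimal above is the mirror image: it emits the unscaled Long as a fixed number of big-endian bytes, most significant byte first, so the reader's sign extension reproduces the original value. A self-contained sketch under the same assumptions (hypothetical names, values that fit in a Long):

object UnscaledLongWriter {
  def writeUnscaled(unscaled: Long, numBytes: Int): Array[Byte] = {
    val out = new Array[Byte](numBytes)
    var i = 0
    var shift = 8 * (numBytes - 1)
    while (i < numBytes) {
      out(i) = (unscaled >> shift).toByte  // take the next-most-significant byte
      i += 1
      shift -= 8
    }
    out
  }

  def main(args: Array[String]): Unit = {
    def hex(bytes: Array[Byte]): String = bytes.map(b => f"${b & 0xFF}%02X").mkString(" ")
    println(hex(writeUnscaled(12345L, 2)))   // 30 39
    println(hex(writeUnscaled(-12345L, 2)))  // CF C7
  }
}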
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetTypes.scala b/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetTypes.scala
index e6389cf77a..e5077de8dd 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetTypes.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetTypes.scala
@@ -29,8 +29,8 @@ import org.apache.hadoop.mapreduce.lib.output.FileOutputCommitter
import parquet.hadoop.{ParquetFileReader, Footer, ParquetFileWriter}
import parquet.hadoop.metadata.{ParquetMetadata, FileMetaData}
import parquet.hadoop.util.ContextUtil
-import parquet.schema.{Type => ParquetType, PrimitiveType => ParquetPrimitiveType, MessageType}
-import parquet.schema.{GroupType => ParquetGroupType, OriginalType => ParquetOriginalType, ConversionPatterns}
+import parquet.schema.{Type => ParquetType, Types => ParquetTypes, PrimitiveType => ParquetPrimitiveType, MessageType}
+import parquet.schema.{GroupType => ParquetGroupType, OriginalType => ParquetOriginalType, ConversionPatterns, DecimalMetadata}
import parquet.schema.PrimitiveType.{PrimitiveTypeName => ParquetPrimitiveTypeName}
import parquet.schema.Type.Repetition
@@ -41,17 +41,25 @@ import org.apache.spark.sql.catalyst.types._
// Implicits
import scala.collection.JavaConversions._
+/** A class representing Parquet info fields we care about, for passing back to Parquet */
+private[parquet] case class ParquetTypeInfo(
+ primitiveType: ParquetPrimitiveTypeName,
+ originalType: Option[ParquetOriginalType] = None,
+ decimalMetadata: Option[DecimalMetadata] = None,
+ length: Option[Int] = None)
+
private[parquet] object ParquetTypesConverter extends Logging {
def isPrimitiveType(ctype: DataType): Boolean =
classOf[PrimitiveType] isAssignableFrom ctype.getClass
def toPrimitiveDataType(
parquetType: ParquetPrimitiveType,
- binayAsString: Boolean): DataType =
+ binaryAsString: Boolean): DataType = {
+ val originalType = parquetType.getOriginalType
+ val decimalInfo = parquetType.getDecimalMetadata
parquetType.getPrimitiveTypeName match {
case ParquetPrimitiveTypeName.BINARY
- if (parquetType.getOriginalType == ParquetOriginalType.UTF8 ||
- binayAsString) => StringType
+ if (originalType == ParquetOriginalType.UTF8 || binaryAsString) => StringType
case ParquetPrimitiveTypeName.BINARY => BinaryType
case ParquetPrimitiveTypeName.BOOLEAN => BooleanType
case ParquetPrimitiveTypeName.DOUBLE => DoubleType
@@ -61,9 +69,14 @@ private[parquet] object ParquetTypesConverter extends Logging {
case ParquetPrimitiveTypeName.INT96 =>
// TODO: add BigInteger type? TODO(andre) use DecimalType instead????
sys.error("Potential loss of precision: cannot convert INT96")
+ case ParquetPrimitiveTypeName.FIXED_LEN_BYTE_ARRAY
+ if (originalType == ParquetOriginalType.DECIMAL && decimalInfo.getPrecision <= 18) =>
+ // TODO: for now, our reader only supports decimals that fit in a Long
+ DecimalType(decimalInfo.getPrecision, decimalInfo.getScale)
case _ => sys.error(
s"Unsupported parquet datatype $parquetType")
}
+ }
/**
* Converts a given Parquet `Type` into the corresponding
@@ -183,24 +196,41 @@ private[parquet] object ParquetTypesConverter extends Logging {
* is not primitive.
*
* @param ctype The type to convert
- * @return The name of the corresponding Parquet primitive type
+ * @return The name of the corresponding Parquet type properties
*/
- def fromPrimitiveDataType(ctype: DataType):
- Option[(ParquetPrimitiveTypeName, Option[ParquetOriginalType])] = ctype match {
- case StringType => Some(ParquetPrimitiveTypeName.BINARY, Some(ParquetOriginalType.UTF8))
- case BinaryType => Some(ParquetPrimitiveTypeName.BINARY, None)
- case BooleanType => Some(ParquetPrimitiveTypeName.BOOLEAN, None)
- case DoubleType => Some(ParquetPrimitiveTypeName.DOUBLE, None)
- case FloatType => Some(ParquetPrimitiveTypeName.FLOAT, None)
- case IntegerType => Some(ParquetPrimitiveTypeName.INT32, None)
+ def fromPrimitiveDataType(ctype: DataType): Option[ParquetTypeInfo] = ctype match {
+ case StringType => Some(ParquetTypeInfo(
+ ParquetPrimitiveTypeName.BINARY, Some(ParquetOriginalType.UTF8)))
+ case BinaryType => Some(ParquetTypeInfo(ParquetPrimitiveTypeName.BINARY))
+ case BooleanType => Some(ParquetTypeInfo(ParquetPrimitiveTypeName.BOOLEAN))
+ case DoubleType => Some(ParquetTypeInfo(ParquetPrimitiveTypeName.DOUBLE))
+ case FloatType => Some(ParquetTypeInfo(ParquetPrimitiveTypeName.FLOAT))
+ case IntegerType => Some(ParquetTypeInfo(ParquetPrimitiveTypeName.INT32))
// There is no type for Byte or Short so we promote them to INT32.
- case ShortType => Some(ParquetPrimitiveTypeName.INT32, None)
- case ByteType => Some(ParquetPrimitiveTypeName.INT32, None)
- case LongType => Some(ParquetPrimitiveTypeName.INT64, None)
+ case ShortType => Some(ParquetTypeInfo(ParquetPrimitiveTypeName.INT32))
+ case ByteType => Some(ParquetTypeInfo(ParquetPrimitiveTypeName.INT32))
+ case LongType => Some(ParquetTypeInfo(ParquetPrimitiveTypeName.INT64))
+ case DecimalType.Fixed(precision, scale) if precision <= 18 =>
+ // TODO: for now, our writer only supports decimals that fit in a Long
+ Some(ParquetTypeInfo(ParquetPrimitiveTypeName.FIXED_LEN_BYTE_ARRAY,
+ Some(ParquetOriginalType.DECIMAL),
+ Some(new DecimalMetadata(precision, scale)),
+ Some(BYTES_FOR_PRECISION(precision))))
case _ => None
}
/**
+ * Compute the FIXED_LEN_BYTE_ARRAY length needed to represent a given DECIMAL precision.
+ */
+ private[parquet] val BYTES_FOR_PRECISION = Array.tabulate[Int](38) { precision =>
+ var length = 1
+ while (math.pow(2.0, 8 * length - 1) < math.pow(10.0, precision)) {
+ length += 1
+ }
+ length
+ }
+
+ /**
* Converts a given Catalyst [[org.apache.spark.sql.catalyst.types.DataType]] into
* the corresponding Parquet `Type`.
*
@@ -247,10 +277,17 @@ private[parquet] object ParquetTypesConverter extends Logging {
} else {
if (nullable) Repetition.OPTIONAL else Repetition.REQUIRED
}
- val primitiveType = fromPrimitiveDataType(ctype)
- primitiveType.map {
- case (primitiveType, originalType) =>
- new ParquetPrimitiveType(repetition, primitiveType, name, originalType.orNull)
+ val typeInfo = fromPrimitiveDataType(ctype)
+ typeInfo.map {
+ case ParquetTypeInfo(primitiveType, originalType, decimalMetadata, length) =>
+ val builder = ParquetTypes.primitive(primitiveType, repetition).as(originalType.orNull)
+ for (len <- length) {
+ builder.length(len)
+ }
+ for (metadata <- decimalMetadata) {
+ builder.precision(metadata.getPrecision).scale(metadata.getScale)
+ }
+ builder.named(name)
}.getOrElse {
ctype match {
case ArrayType(elementType, false) => {
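
BYTES_FOR_PRECISION above precomputes, for each precision, the smallest byte length whose signed range covers every unscaled value with that many digits. A runnable sketch of the same table (sized here for precisions 0 through 38, an assumption chosen to match Hive's maximum precision rather than the exact array length in the hunk):

object DecimalByteLengths {
  // Smallest length such that a signed value of 8*length bits can hold 10^precision - 1
  val bytesForPrecision: Array[Int] = Array.tabulate(39) { precision =>
    var length = 1
    while (math.pow(2.0, 8 * length - 1) < math.pow(10.0, precision)) {
      length += 1
    }
    length
  }

  def main(args: Array[String]): Unit = {
    println(bytesForPrecision(9))   // 4: nine digits need a 4-byte field
    println(bytesForPrecision(18))  // 8: precision 18 is the largest that fits in a Long
  }
}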
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/types/util/DataTypeConversions.scala b/sql/core/src/main/scala/org/apache/spark/sql/types/util/DataTypeConversions.scala
index 142598c904..7564bf3923 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/types/util/DataTypeConversions.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/types/util/DataTypeConversions.scala
@@ -19,6 +19,8 @@ package org.apache.spark.sql.types.util
import org.apache.spark.sql._
import org.apache.spark.sql.api.java.{DataType => JDataType, StructField => JStructField, MetadataBuilder => JMetaDataBuilder}
+import org.apache.spark.sql.api.java.{DecimalType => JDecimalType}
+import org.apache.spark.sql.catalyst.types.decimal.Decimal
import scala.collection.JavaConverters._
@@ -44,7 +46,8 @@ protected[sql] object DataTypeConversions {
case BooleanType => JDataType.BooleanType
case DateType => JDataType.DateType
case TimestampType => JDataType.TimestampType
- case DecimalType => JDataType.DecimalType
+ case DecimalType.Fixed(precision, scale) => new JDecimalType(precision, scale)
+ case DecimalType.Unlimited => new JDecimalType()
case DoubleType => JDataType.DoubleType
case FloatType => JDataType.FloatType
case ByteType => JDataType.ByteType
@@ -88,7 +91,11 @@ protected[sql] object DataTypeConversions {
case timestampType: org.apache.spark.sql.api.java.TimestampType =>
TimestampType
case decimalType: org.apache.spark.sql.api.java.DecimalType =>
- DecimalType
+ if (decimalType.isFixed) {
+ DecimalType(decimalType.getPrecision, decimalType.getScale)
+ } else {
+ DecimalType.Unlimited
+ }
case doubleType: org.apache.spark.sql.api.java.DoubleType =>
DoubleType
case floatType: org.apache.spark.sql.api.java.FloatType =>
@@ -115,7 +122,7 @@ protected[sql] object DataTypeConversions {
/** Converts Java objects to catalyst rows / types */
def convertJavaToCatalyst(a: Any): Any = a match {
- case d: java.math.BigDecimal => BigDecimal(d)
+ case d: java.math.BigDecimal => Decimal(BigDecimal(d))
case other => other
}
diff --git a/sql/core/src/test/java/org/apache/spark/sql/api/java/JavaApplySchemaSuite.java b/sql/core/src/test/java/org/apache/spark/sql/api/java/JavaApplySchemaSuite.java
index 9435a88009..a04b8060cd 100644
--- a/sql/core/src/test/java/org/apache/spark/sql/api/java/JavaApplySchemaSuite.java
+++ b/sql/core/src/test/java/org/apache/spark/sql/api/java/JavaApplySchemaSuite.java
@@ -118,7 +118,7 @@ public class JavaApplySchemaSuite implements Serializable {
"\"bigInteger\":92233720368547758069, \"double\":1.7976931348623157E305, " +
"\"boolean\":false, \"null\":null}"));
List<StructField> fields = new ArrayList<StructField>(7);
- fields.add(DataType.createStructField("bigInteger", DataType.DecimalType, true));
+ fields.add(DataType.createStructField("bigInteger", new DecimalType(), true));
fields.add(DataType.createStructField("boolean", DataType.BooleanType, true));
fields.add(DataType.createStructField("double", DataType.DoubleType, true));
fields.add(DataType.createStructField("integer", DataType.IntegerType, true));
diff --git a/sql/core/src/test/java/org/apache/spark/sql/api/java/JavaSideDataTypeConversionSuite.java b/sql/core/src/test/java/org/apache/spark/sql/api/java/JavaSideDataTypeConversionSuite.java
index d04396a5f8..8396a29c61 100644
--- a/sql/core/src/test/java/org/apache/spark/sql/api/java/JavaSideDataTypeConversionSuite.java
+++ b/sql/core/src/test/java/org/apache/spark/sql/api/java/JavaSideDataTypeConversionSuite.java
@@ -41,7 +41,8 @@ public class JavaSideDataTypeConversionSuite {
checkDataType(DataType.BooleanType);
checkDataType(DataType.DateType);
checkDataType(DataType.TimestampType);
- checkDataType(DataType.DecimalType);
+ checkDataType(new DecimalType());
+ checkDataType(new DecimalType(10, 4));
checkDataType(DataType.DoubleType);
checkDataType(DataType.FloatType);
checkDataType(DataType.ByteType);
@@ -59,7 +60,7 @@ public class JavaSideDataTypeConversionSuite {
// Simple StructType.
List<StructField> simpleFields = new ArrayList<StructField>();
- simpleFields.add(DataType.createStructField("a", DataType.DecimalType, false));
+ simpleFields.add(DataType.createStructField("a", new DecimalType(), false));
simpleFields.add(DataType.createStructField("b", DataType.BooleanType, true));
simpleFields.add(DataType.createStructField("c", DataType.LongType, true));
simpleFields.add(DataType.createStructField("d", DataType.BinaryType, false));
@@ -128,7 +129,7 @@ public class JavaSideDataTypeConversionSuite {
// StructType
try {
List<StructField> simpleFields = new ArrayList<StructField>();
- simpleFields.add(DataType.createStructField("a", DataType.DecimalType, false));
+ simpleFields.add(DataType.createStructField("a", new DecimalType(), false));
simpleFields.add(DataType.createStructField("b", DataType.BooleanType, true));
simpleFields.add(DataType.createStructField("c", DataType.LongType, true));
simpleFields.add(null);
@@ -138,7 +139,7 @@ public class JavaSideDataTypeConversionSuite {
}
try {
List<StructField> simpleFields = new ArrayList<StructField>();
- simpleFields.add(DataType.createStructField("a", DataType.DecimalType, false));
+ simpleFields.add(DataType.createStructField("a", new DecimalType(), false));
simpleFields.add(DataType.createStructField("a", DataType.BooleanType, true));
simpleFields.add(DataType.createStructField("c", DataType.LongType, true));
DataType.createStructType(simpleFields);
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataTypeSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataTypeSuite.scala
index 6c9db639c0..e9740d913c 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/DataTypeSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/DataTypeSuite.scala
@@ -69,7 +69,7 @@ class DataTypeSuite extends FunSuite {
checkDataTypeJsonRepr(LongType)
checkDataTypeJsonRepr(FloatType)
checkDataTypeJsonRepr(DoubleType)
- checkDataTypeJsonRepr(DecimalType)
+ checkDataTypeJsonRepr(DecimalType.Unlimited)
checkDataTypeJsonRepr(TimestampType)
checkDataTypeJsonRepr(StringType)
checkDataTypeJsonRepr(BinaryType)
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/ScalaReflectionRelationSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/ScalaReflectionRelationSuite.scala
index bfa9ea4162..cf3a59e545 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/ScalaReflectionRelationSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/ScalaReflectionRelationSuite.scala
@@ -19,6 +19,7 @@ package org.apache.spark.sql
import java.sql.{Date, Timestamp}
+import org.apache.spark.sql.catalyst.types.decimal.Decimal
import org.scalatest.FunSuite
import org.apache.spark.sql.catalyst.expressions._
@@ -81,7 +82,9 @@ class ScalaReflectionRelationSuite extends FunSuite {
val rdd = sparkContext.parallelize(data :: Nil)
rdd.registerTempTable("reflectData")
- assert(sql("SELECT * FROM reflectData").collect().head === data.productIterator.toSeq)
+ assert(sql("SELECT * FROM reflectData").collect().head ===
+ Seq("a", 1, 1L, 1.toFloat, 1.toDouble, 1.toShort, 1.toByte, true,
+ BigDecimal(1), new Date(12345), new Timestamp(12345), Seq(1,2,3)))
}
test("query case class RDD with nulls") {
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/api/java/JavaSQLSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/api/java/JavaSQLSuite.scala
index d83f3e23a9..c9012c9e47 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/api/java/JavaSQLSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/api/java/JavaSQLSuite.scala
@@ -17,6 +17,8 @@
package org.apache.spark.sql.api.java
+import org.apache.spark.sql.catalyst.types.decimal.Decimal
+
import scala.beans.BeanProperty
import org.scalatest.FunSuite
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/api/java/ScalaSideDataTypeConversionSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/api/java/ScalaSideDataTypeConversionSuite.scala
index e0e0ff9cb3..62fe59dd34 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/api/java/ScalaSideDataTypeConversionSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/api/java/ScalaSideDataTypeConversionSuite.scala
@@ -38,7 +38,7 @@ class ScalaSideDataTypeConversionSuite extends FunSuite {
checkDataType(org.apache.spark.sql.BooleanType)
checkDataType(org.apache.spark.sql.DateType)
checkDataType(org.apache.spark.sql.TimestampType)
- checkDataType(org.apache.spark.sql.DecimalType)
+ checkDataType(org.apache.spark.sql.DecimalType.Unlimited)
checkDataType(org.apache.spark.sql.DoubleType)
checkDataType(org.apache.spark.sql.FloatType)
checkDataType(org.apache.spark.sql.ByteType)
@@ -58,7 +58,7 @@ class ScalaSideDataTypeConversionSuite extends FunSuite {
// Simple StructType.
val simpleScalaStructType = SStructType(
- SStructField("a", org.apache.spark.sql.DecimalType, false) ::
+ SStructField("a", org.apache.spark.sql.DecimalType.Unlimited, false) ::
SStructField("b", org.apache.spark.sql.BooleanType, true) ::
SStructField("c", org.apache.spark.sql.LongType, true) ::
SStructField("d", org.apache.spark.sql.BinaryType, false) :: Nil)
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/json/JsonSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/json/JsonSuite.scala
index ce6184f5d8..1cb6c23c58 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/json/JsonSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/json/JsonSuite.scala
@@ -18,6 +18,7 @@
package org.apache.spark.sql.json
import org.apache.spark.sql.catalyst.types._
+import org.apache.spark.sql.catalyst.types.decimal.Decimal
import org.apache.spark.sql.catalyst.util._
import org.apache.spark.sql.json.JsonRDD.{enforceCorrectType, compatibleType}
import org.apache.spark.sql.QueryTest
@@ -44,19 +45,22 @@ class JsonSuite extends QueryTest {
checkTypePromotion(intNumber, enforceCorrectType(intNumber, IntegerType))
checkTypePromotion(intNumber.toLong, enforceCorrectType(intNumber, LongType))
checkTypePromotion(intNumber.toDouble, enforceCorrectType(intNumber, DoubleType))
- checkTypePromotion(BigDecimal(intNumber), enforceCorrectType(intNumber, DecimalType))
+ checkTypePromotion(
+ Decimal(intNumber), enforceCorrectType(intNumber, DecimalType.Unlimited))
val longNumber: Long = 9223372036854775807L
checkTypePromotion(longNumber, enforceCorrectType(longNumber, LongType))
checkTypePromotion(longNumber.toDouble, enforceCorrectType(longNumber, DoubleType))
- checkTypePromotion(BigDecimal(longNumber), enforceCorrectType(longNumber, DecimalType))
+ checkTypePromotion(
+ Decimal(longNumber), enforceCorrectType(longNumber, DecimalType.Unlimited))
val doubleNumber: Double = 1.7976931348623157E308d
checkTypePromotion(doubleNumber.toDouble, enforceCorrectType(doubleNumber, DoubleType))
- checkTypePromotion(BigDecimal(doubleNumber), enforceCorrectType(doubleNumber, DecimalType))
-
+ checkTypePromotion(
+ Decimal(doubleNumber), enforceCorrectType(doubleNumber, DecimalType.Unlimited))
+
checkTypePromotion(new Timestamp(intNumber), enforceCorrectType(intNumber, TimestampType))
- checkTypePromotion(new Timestamp(intNumber.toLong),
+ checkTypePromotion(new Timestamp(intNumber.toLong),
enforceCorrectType(intNumber.toLong, TimestampType))
val strTime = "2014-09-30 12:34:56"
checkTypePromotion(Timestamp.valueOf(strTime), enforceCorrectType(strTime, TimestampType))
@@ -80,7 +84,7 @@ class JsonSuite extends QueryTest {
checkDataType(NullType, IntegerType, IntegerType)
checkDataType(NullType, LongType, LongType)
checkDataType(NullType, DoubleType, DoubleType)
- checkDataType(NullType, DecimalType, DecimalType)
+ checkDataType(NullType, DecimalType.Unlimited, DecimalType.Unlimited)
checkDataType(NullType, StringType, StringType)
checkDataType(NullType, ArrayType(IntegerType), ArrayType(IntegerType))
checkDataType(NullType, StructType(Nil), StructType(Nil))
@@ -91,7 +95,7 @@ class JsonSuite extends QueryTest {
checkDataType(BooleanType, IntegerType, StringType)
checkDataType(BooleanType, LongType, StringType)
checkDataType(BooleanType, DoubleType, StringType)
- checkDataType(BooleanType, DecimalType, StringType)
+ checkDataType(BooleanType, DecimalType.Unlimited, StringType)
checkDataType(BooleanType, StringType, StringType)
checkDataType(BooleanType, ArrayType(IntegerType), StringType)
checkDataType(BooleanType, StructType(Nil), StringType)
@@ -100,7 +104,7 @@ class JsonSuite extends QueryTest {
checkDataType(IntegerType, IntegerType, IntegerType)
checkDataType(IntegerType, LongType, LongType)
checkDataType(IntegerType, DoubleType, DoubleType)
- checkDataType(IntegerType, DecimalType, DecimalType)
+ checkDataType(IntegerType, DecimalType.Unlimited, DecimalType.Unlimited)
checkDataType(IntegerType, StringType, StringType)
checkDataType(IntegerType, ArrayType(IntegerType), StringType)
checkDataType(IntegerType, StructType(Nil), StringType)
@@ -108,23 +112,23 @@ class JsonSuite extends QueryTest {
// LongType
checkDataType(LongType, LongType, LongType)
checkDataType(LongType, DoubleType, DoubleType)
- checkDataType(LongType, DecimalType, DecimalType)
+ checkDataType(LongType, DecimalType.Unlimited, DecimalType.Unlimited)
checkDataType(LongType, StringType, StringType)
checkDataType(LongType, ArrayType(IntegerType), StringType)
checkDataType(LongType, StructType(Nil), StringType)
// DoubleType
checkDataType(DoubleType, DoubleType, DoubleType)
- checkDataType(DoubleType, DecimalType, DecimalType)
+ checkDataType(DoubleType, DecimalType.Unlimited, DecimalType.Unlimited)
checkDataType(DoubleType, StringType, StringType)
checkDataType(DoubleType, ArrayType(IntegerType), StringType)
checkDataType(DoubleType, StructType(Nil), StringType)
// DoubleType
- checkDataType(DecimalType, DecimalType, DecimalType)
- checkDataType(DecimalType, StringType, StringType)
- checkDataType(DecimalType, ArrayType(IntegerType), StringType)
- checkDataType(DecimalType, StructType(Nil), StringType)
+ checkDataType(DecimalType.Unlimited, DecimalType.Unlimited, DecimalType.Unlimited)
+ checkDataType(DecimalType.Unlimited, StringType, StringType)
+ checkDataType(DecimalType.Unlimited, ArrayType(IntegerType), StringType)
+ checkDataType(DecimalType.Unlimited, StructType(Nil), StringType)
// StringType
checkDataType(StringType, StringType, StringType)
@@ -178,7 +182,7 @@ class JsonSuite extends QueryTest {
checkDataType(
StructType(
StructField("f1", IntegerType, true) :: Nil),
- DecimalType,
+ DecimalType.Unlimited,
StringType)
}
@@ -186,7 +190,7 @@ class JsonSuite extends QueryTest {
val jsonSchemaRDD = jsonRDD(primitiveFieldAndType)
val expectedSchema = StructType(
- StructField("bigInteger", DecimalType, true) ::
+ StructField("bigInteger", DecimalType.Unlimited, true) ::
StructField("boolean", BooleanType, true) ::
StructField("double", DoubleType, true) ::
StructField("integer", IntegerType, true) ::
@@ -216,7 +220,7 @@ class JsonSuite extends QueryTest {
val expectedSchema = StructType(
StructField("arrayOfArray1", ArrayType(ArrayType(StringType, false), false), true) ::
StructField("arrayOfArray2", ArrayType(ArrayType(DoubleType, false), false), true) ::
- StructField("arrayOfBigInteger", ArrayType(DecimalType, false), true) ::
+ StructField("arrayOfBigInteger", ArrayType(DecimalType.Unlimited, false), true) ::
StructField("arrayOfBoolean", ArrayType(BooleanType, false), true) ::
StructField("arrayOfDouble", ArrayType(DoubleType, false), true) ::
StructField("arrayOfInteger", ArrayType(IntegerType, false), true) ::
@@ -230,7 +234,7 @@ class JsonSuite extends QueryTest {
StructField("field3", StringType, true) :: Nil), false), true) ::
StructField("struct", StructType(
StructField("field1", BooleanType, true) ::
- StructField("field2", DecimalType, true) :: Nil), true) ::
+ StructField("field2", DecimalType.Unlimited, true) :: Nil), true) ::
StructField("structWithArrayFields", StructType(
StructField("field1", ArrayType(IntegerType, false), true) ::
StructField("field2", ArrayType(StringType, false), true) :: Nil), true) :: Nil)
@@ -331,7 +335,7 @@ class JsonSuite extends QueryTest {
val expectedSchema = StructType(
StructField("num_bool", StringType, true) ::
StructField("num_num_1", LongType, true) ::
- StructField("num_num_2", DecimalType, true) ::
+ StructField("num_num_2", DecimalType.Unlimited, true) ::
StructField("num_num_3", DoubleType, true) ::
StructField("num_str", StringType, true) ::
StructField("str_bool", StringType, true) :: Nil)
@@ -521,7 +525,7 @@ class JsonSuite extends QueryTest {
val jsonSchemaRDD = jsonFile(path)
val expectedSchema = StructType(
- StructField("bigInteger", DecimalType, true) ::
+ StructField("bigInteger", DecimalType.Unlimited, true) ::
StructField("boolean", BooleanType, true) ::
StructField("double", DoubleType, true) ::
StructField("integer", IntegerType, true) ::
@@ -551,7 +555,7 @@ class JsonSuite extends QueryTest {
primitiveFieldAndType.map(record => record.replaceAll("\n", " ")).saveAsTextFile(path)
val schema = StructType(
- StructField("bigInteger", DecimalType, true) ::
+ StructField("bigInteger", DecimalType.Unlimited, true) ::
StructField("boolean", BooleanType, true) ::
StructField("double", DoubleType, true) ::
StructField("integer", IntegerType, true) ::
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetQuerySuite.scala
index 9979ab446d..08d9da27f1 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetQuerySuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetQuerySuite.scala
@@ -77,6 +77,8 @@ case class AllDataTypesWithNonPrimitiveType(
case class BinaryData(binaryData: Array[Byte])
+case class NumericData(i: Int, d: Double)
+
class ParquetQuerySuite extends QueryTest with FunSuiteLike with BeforeAndAfterAll {
TestData // Load test data tables.
@@ -560,7 +562,7 @@ class ParquetQuerySuite extends QueryTest with FunSuiteLike with BeforeAndAfterA
assert(stringResult.size === 1)
assert(stringResult(0).getString(2) == "100", "stringvalue incorrect")
assert(stringResult(0).getInt(1) === 100)
-
+
val query7 = sql(s"SELECT * FROM testfiltersource WHERE myoptint < 40")
assert(
query7.queryExecution.executedPlan(0)(0).isInstanceOf[ParquetTableScan],
@@ -869,4 +871,35 @@ class ParquetQuerySuite extends QueryTest with FunSuiteLike with BeforeAndAfterA
assert(a.dataType === b.dataType)
}
}
+
+ test("read/write fixed-length decimals") {
+ for ((precision, scale) <- Seq((5, 2), (1, 0), (1, 1), (18, 10), (18, 17))) {
+ val tempDir = getTempFilePath("parquetTest").getCanonicalPath
+ val data = sparkContext.parallelize(0 to 1000)
+ .map(i => NumericData(i, i / 100.0))
+ .select('i, 'd cast DecimalType(precision, scale))
+ data.saveAsParquetFile(tempDir)
+ checkAnswer(parquetFile(tempDir), data.toSchemaRDD.collect().toSeq)
+ }
+
+ // Decimals with precision above 18 are not yet supported
+ intercept[RuntimeException] {
+ val tempDir = getTempFilePath("parquetTest").getCanonicalPath
+ val data = sparkContext.parallelize(0 to 1000)
+ .map(i => NumericData(i, i / 100.0))
+ .select('i, 'd cast DecimalType(19, 10))
+ data.saveAsParquetFile(tempDir)
+ checkAnswer(parquetFile(tempDir), data.toSchemaRDD.collect().toSeq)
+ }
+
+ // Unlimited-length decimals are not yet supported
+ intercept[RuntimeException] {
+ val tempDir = getTempFilePath("parquetTest").getCanonicalPath
+ val data = sparkContext.parallelize(0 to 1000)
+ .map(i => NumericData(i, i / 100.0))
+ .select('i, 'd cast DecimalType.Unlimited)
+ data.saveAsParquetFile(tempDir)
+ checkAnswer(parquetFile(tempDir), data.toSchemaRDD.collect().toSeq)
+ }
+ }
}