-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/parquet/CatalystRowConverter.scala     25
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/parquet/CatalystSchemaConverter.scala   46
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetTableSupport.scala       64
-rw-r--r--  sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetIOSuite.scala            10
4 files changed, 85 insertions, 60 deletions
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/parquet/CatalystRowConverter.scala b/sql/core/src/main/scala/org/apache/spark/sql/parquet/CatalystRowConverter.scala
index b5e4263008..e00bd90edb 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/parquet/CatalystRowConverter.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/parquet/CatalystRowConverter.scala
@@ -17,6 +17,7 @@
package org.apache.spark.sql.parquet
+import java.math.{BigDecimal, BigInteger}
import java.nio.ByteOrder
import scala.collection.JavaConversions._
@@ -263,17 +264,23 @@ private[parquet] class CatalystRowConverter(
val scale = decimalType.scale
val bytes = value.getBytes
- var unscaled = 0L
- var i = 0
+ if (precision <= 8) {
+ // Constructs a `Decimal` with an unscaled `Long` value if possible.
+ var unscaled = 0L
+ var i = 0
- while (i < bytes.length) {
- unscaled = (unscaled << 8) | (bytes(i) & 0xff)
- i += 1
- }
+ while (i < bytes.length) {
+ unscaled = (unscaled << 8) | (bytes(i) & 0xff)
+ i += 1
+ }
- val bits = 8 * bytes.length
- unscaled = (unscaled << (64 - bits)) >> (64 - bits)
- Decimal(unscaled, precision, scale)
+ val bits = 8 * bytes.length
+ unscaled = (unscaled << (64 - bits)) >> (64 - bits)
+ Decimal(unscaled, precision, scale)
+ } else {
+ // Otherwise, resorts to an unscaled `BigInteger` instead.
+ Decimal(new BigDecimal(new BigInteger(bytes), scale), precision, scale)
+ }
}
}
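
[Side note, not part of the patch] A minimal standalone sketch of the two decoding paths the
converter hunk above introduces; the object name and the sample bytes are invented for
illustration. Small precisions sign-extend the big-endian bytes into an unscaled Long, while
larger ones fall back to java.math.BigInteger, which reads the same bytes as two's complement:

    import java.math.{BigDecimal, BigInteger}

    object DecimalBinaryDecodeSketch {
      // Packs big-endian two's-complement bytes into a Long, then sign-extends
      // by shifting the top byte up to bit 63 and arithmetically back down.
      def unscaledLong(bytes: Array[Byte]): Long = {
        var unscaled = 0L
        var i = 0
        while (i < bytes.length) {
          unscaled = (unscaled << 8) | (bytes(i) & 0xff)
          i += 1
        }
        val bits = 8 * bytes.length
        (unscaled << (64 - bits)) >> (64 - bits)
      }

      def main(args: Array[String]): Unit = {
        val bytes = Array[Byte](0xff.toByte, 0x85.toByte)  // -123 encoded in two bytes
        println(unscaledLong(bytes))                        // -123
        println(new BigDecimal(new BigInteger(bytes), 2))   // -1.23 (same bytes, scale 2)
      }
    }
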
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/parquet/CatalystSchemaConverter.scala b/sql/core/src/main/scala/org/apache/spark/sql/parquet/CatalystSchemaConverter.scala
index e9ef01e2db..d43ca95b4e 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/parquet/CatalystSchemaConverter.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/parquet/CatalystSchemaConverter.scala
@@ -387,24 +387,18 @@ private[parquet] class CatalystSchemaConverter(
// =====================================
// Spark 1.4.x and prior versions only support decimals with a maximum precision of 18 and
- // always store decimals in fixed-length byte arrays.
- case DecimalType.Fixed(precision, scale)
- if precision <= maxPrecisionForBytes(8) && !followParquetFormatSpec =>
+ // always store decimals in fixed-length byte arrays. To keep compatibility with these older
+ // versions, here we convert decimals with all precisions to `FIXED_LEN_BYTE_ARRAY` annotated
+ // by `DECIMAL`.
+ case DecimalType.Fixed(precision, scale) if !followParquetFormatSpec =>
Types
.primitive(FIXED_LEN_BYTE_ARRAY, repetition)
.as(DECIMAL)
.precision(precision)
.scale(scale)
- .length(minBytesForPrecision(precision))
+ .length(CatalystSchemaConverter.minBytesForPrecision(precision))
.named(field.name)
- case dec @ DecimalType() if !followParquetFormatSpec =>
- throw new AnalysisException(
- s"Data type $dec is not supported. " +
- s"When ${SQLConf.PARQUET_FOLLOW_PARQUET_FORMAT_SPEC.key} is set to false," +
- "decimal precision and scale must be specified, " +
- "and precision must be less than or equal to 18.")
-
// =====================================
// Decimals (follow Parquet format spec)
// =====================================
@@ -436,7 +430,7 @@ private[parquet] class CatalystSchemaConverter(
.as(DECIMAL)
.precision(precision)
.scale(scale)
- .length(minBytesForPrecision(precision))
+ .length(CatalystSchemaConverter.minBytesForPrecision(precision))
.named(field.name)
// ===================================================
@@ -548,15 +542,6 @@ private[parquet] class CatalystSchemaConverter(
Math.pow(2, 8 * numBytes - 1) - 1))) // max value stored in numBytes
.asInstanceOf[Int]
}
-
- // Min byte counts needed to store decimals with various precisions
- private val minBytesForPrecision: Array[Int] = Array.tabulate(38) { precision =>
- var numBytes = 1
- while (math.pow(2.0, 8 * numBytes - 1) < math.pow(10.0, precision)) {
- numBytes += 1
- }
- numBytes
- }
}
@@ -580,4 +565,23 @@ private[parquet] object CatalystSchemaConverter {
throw new AnalysisException(message)
}
}
+
+ private def computeMinBytesForPrecision(precision : Int) : Int = {
+ var numBytes = 1
+ while (math.pow(2.0, 8 * numBytes - 1) < math.pow(10.0, precision)) {
+ numBytes += 1
+ }
+ numBytes
+ }
+
+ private val MIN_BYTES_FOR_PRECISION = Array.tabulate[Int](39)(computeMinBytesForPrecision)
+
+ // Returns the minimum number of bytes needed to store a decimal with a given `precision`.
+ def minBytesForPrecision(precision : Int) : Int = {
+ if (precision < MIN_BYTES_FOR_PRECISION.length) {
+ MIN_BYTES_FOR_PRECISION(precision)
+ } else {
+ computeMinBytesForPrecision(precision)
+ }
+ }
}
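
[Side note, not part of the patch] A quick sanity check of the byte counts that
CatalystSchemaConverter.minBytesForPrecision produces: it returns the smallest numBytes whose
signed range 2^(8 * numBytes - 1) - 1 covers 10^precision - 1. The object name below is made up;
the loop mirrors the computation in the patch:

    object MinBytesSketch extends App {
      def minBytes(precision: Int): Int = {
        var numBytes = 1
        while (math.pow(2.0, 8 * numBytes - 1) < math.pow(10.0, precision)) numBytes += 1
        numBytes
      }

      // Precision 18 still fits in the 8 bytes of a Long; 19 is the first precision
      // that needs 9 bytes; the maximum SQL precision of 38 needs 16 bytes.
      Seq(1, 9, 18, 19, 38).foreach(p => println(s"precision $p -> ${minBytes(p)} bytes"))
      // precision 1 -> 1 bytes, 9 -> 4, 18 -> 8, 19 -> 9, 38 -> 16
    }
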
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetTableSupport.scala b/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetTableSupport.scala
index fc9f61a636..78ecfad1d5 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetTableSupport.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetTableSupport.scala
@@ -17,6 +17,7 @@
package org.apache.spark.sql.parquet
+import java.math.BigInteger
import java.nio.{ByteBuffer, ByteOrder}
import java.util.{HashMap => JHashMap}
@@ -114,11 +115,8 @@ private[parquet] class RowWriteSupport extends WriteSupport[InternalRow] with Lo
Binary.fromByteArray(value.asInstanceOf[UTF8String].getBytes))
case BinaryType => writer.addBinary(
Binary.fromByteArray(value.asInstanceOf[Array[Byte]]))
- case d: DecimalType =>
- if (d.precision > 18) {
- sys.error(s"Unsupported datatype $d, cannot write to consumer")
- }
- writeDecimal(value.asInstanceOf[Decimal], d.precision)
+ case DecimalType.Fixed(precision, _) =>
+ writeDecimal(value.asInstanceOf[Decimal], precision)
case _ => sys.error(s"Do not know how to writer $schema to consumer")
}
}
@@ -199,20 +197,47 @@ private[parquet] class RowWriteSupport extends WriteSupport[InternalRow] with Lo
writer.endGroup()
}
- // Scratch array used to write decimals as fixed-length binary
- private[this] val scratchBytes = new Array[Byte](8)
+ // Scratch array used to write decimals as fixed-length byte array
+ private[this] var reusableDecimalBytes = new Array[Byte](16)
private[parquet] def writeDecimal(decimal: Decimal, precision: Int): Unit = {
- val numBytes = ParquetTypesConverter.BYTES_FOR_PRECISION(precision)
- val unscaledLong = decimal.toUnscaledLong
- var i = 0
- var shift = 8 * (numBytes - 1)
- while (i < numBytes) {
- scratchBytes(i) = (unscaledLong >> shift).toByte
- i += 1
- shift -= 8
+ val numBytes = CatalystSchemaConverter.minBytesForPrecision(precision)
+
+ def longToBinary(unscaled: Long): Binary = {
+ var i = 0
+ var shift = 8 * (numBytes - 1)
+ while (i < numBytes) {
+ reusableDecimalBytes(i) = (unscaled >> shift).toByte
+ i += 1
+ shift -= 8
+ }
+ Binary.fromByteArray(reusableDecimalBytes, 0, numBytes)
}
- writer.addBinary(Binary.fromByteArray(scratchBytes, 0, numBytes))
+
+ def bigIntegerToBinary(unscaled: BigInteger): Binary = {
+ unscaled.toByteArray match {
+ case bytes if bytes.length == numBytes =>
+ Binary.fromByteArray(bytes)
+
+ case bytes if bytes.length <= reusableDecimalBytes.length =>
+ val signedByte = (if (bytes.head < 0) -1 else 0).toByte
+ java.util.Arrays.fill(reusableDecimalBytes, 0, numBytes - bytes.length, signedByte)
+ System.arraycopy(bytes, 0, reusableDecimalBytes, numBytes - bytes.length, bytes.length)
+ Binary.fromByteArray(reusableDecimalBytes, 0, numBytes)
+
+ case bytes =>
+ reusableDecimalBytes = new Array[Byte](bytes.length)
+ bigIntegerToBinary(unscaled)
+ }
+ }
+
+ val binary = if (numBytes <= 8) {
+ longToBinary(decimal.toUnscaledLong)
+ } else {
+ bigIntegerToBinary(decimal.toJavaBigDecimal.unscaledValue())
+ }
+
+ writer.addBinary(binary)
}
// array used to write Timestamp as Int96 (fixed-length binary)
@@ -268,11 +293,8 @@ private[parquet] class MutableRowWriteSupport extends RowWriteSupport {
writer.addBinary(Binary.fromByteArray(record.getUTF8String(index).getBytes))
case BinaryType =>
writer.addBinary(Binary.fromByteArray(record.getBinary(index)))
- case d: DecimalType =>
- if (d.precision > 18) {
- sys.error(s"Unsupported datatype $d, cannot write to consumer")
- }
- writeDecimal(record.getDecimal(index), d.precision)
+ case DecimalType.Fixed(precision, _) =>
+ writeDecimal(record.getDecimal(index), precision)
case _ => sys.error(s"Unsupported datatype $ctype, cannot write to consumer")
}
}
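
[Side note, not part of the patch] An illustrative sketch of the sign-extension padding that
bigIntegerToBinary performs when writing: BigInteger.toByteArray yields the shortest
two's-complement encoding, which is left-padded with 0x00 or 0xff up to the fixed
FIXED_LEN_BYTE_ARRAY width declared in the schema. The helper and sample values are assumptions
for the example only:

    import java.math.BigInteger
    import java.util.Arrays

    object DecimalPaddingSketch {
      def pad(unscaled: BigInteger, numBytes: Int): Array[Byte] = {
        val bytes = unscaled.toByteArray                    // shortest two's-complement form
        val padded = new Array[Byte](numBytes)
        val signByte: Byte = if (bytes.head < 0) -1 else 0  // 0xff for negative values
        Arrays.fill(padded, 0, numBytes - bytes.length, signByte)
        System.arraycopy(bytes, 0, padded, numBytes - bytes.length, bytes.length)
        padded
      }

      def main(args: Array[String]): Unit = {
        // Unscaled value of a decimal like -1.23 at scale 2, written with precision 19 (9 bytes).
        val padded = pad(BigInteger.valueOf(-123), 9)
        println(padded.map(b => f"${b & 0xff}%02x").mkString(" "))
        // ff ff ff ff ff ff ff ff 85
      }
    }
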
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetIOSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetIOSuite.scala
index b5314a3dd9..b415da5b8c 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetIOSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetIOSuite.scala
@@ -106,21 +106,13 @@ class ParquetIOSuite extends QueryTest with ParquetTest {
// Parquet doesn't allow column names with spaces, have to add an alias here
.select($"_1" cast decimal as "dec")
- for ((precision, scale) <- Seq((5, 2), (1, 0), (1, 1), (18, 10), (18, 17))) {
+ for ((precision, scale) <- Seq((5, 2), (1, 0), (1, 1), (18, 10), (18, 17), (19, 0), (38, 37))) {
withTempPath { dir =>
val data = makeDecimalRDD(DecimalType(precision, scale))
data.write.parquet(dir.getCanonicalPath)
checkAnswer(sqlContext.read.parquet(dir.getCanonicalPath), data.collect().toSeq)
}
}
-
- // Decimals with precision above 18 are not yet supported
- intercept[Throwable] {
- withTempPath { dir =>
- makeDecimalRDD(DecimalType(19, 10)).write.parquet(dir.getCanonicalPath)
- sqlContext.read.parquet(dir.getCanonicalPath).collect()
- }
- }
}
test("date type") {