4 files changed, 85 insertions, 60 deletions
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/parquet/CatalystRowConverter.scala b/sql/core/src/main/scala/org/apache/spark/sql/parquet/CatalystRowConverter.scala
index b5e4263008..e00bd90edb 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/parquet/CatalystRowConverter.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/parquet/CatalystRowConverter.scala
@@ -17,6 +17,7 @@
 
 package org.apache.spark.sql.parquet
 
+import java.math.{BigDecimal, BigInteger}
 import java.nio.ByteOrder
 
 import scala.collection.JavaConversions._
@@ -263,17 +264,23 @@ private[parquet] class CatalystRowConverter(
       val precision = decimalType.precision
      val scale = decimalType.scale
      val bytes = value.getBytes
 
-      var unscaled = 0L
-      var i = 0
+      if (precision <= 8) {
+        // Constructs a `Decimal` with an unscaled `Long` value if possible.
+        var unscaled = 0L
+        var i = 0
 
-      while (i < bytes.length) {
-        unscaled = (unscaled << 8) | (bytes(i) & 0xff)
-        i += 1
-      }
+        while (i < bytes.length) {
+          unscaled = (unscaled << 8) | (bytes(i) & 0xff)
+          i += 1
+        }
 
-      val bits = 8 * bytes.length
-      unscaled = (unscaled << (64 - bits)) >> (64 - bits)
-      Decimal(unscaled, precision, scale)
+        val bits = 8 * bytes.length
+        unscaled = (unscaled << (64 - bits)) >> (64 - bits)
+        Decimal(unscaled, precision, scale)
+      } else {
+        // Otherwise, resorts to an unscaled `BigInteger` instead.
+        Decimal(new BigDecimal(new BigInteger(bytes), scale), precision, scale)
+      }
     }
   }
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/parquet/CatalystSchemaConverter.scala b/sql/core/src/main/scala/org/apache/spark/sql/parquet/CatalystSchemaConverter.scala
index e9ef01e2db..d43ca95b4e 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/parquet/CatalystSchemaConverter.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/parquet/CatalystSchemaConverter.scala
@@ -387,24 +387,18 @@ private[parquet] class CatalystSchemaConverter(
       // =====================================
 
       // Spark 1.4.x and prior versions only support decimals with a maximum precision of 18 and
-      // always store decimals in fixed-length byte arrays.
-      case DecimalType.Fixed(precision, scale)
-        if precision <= maxPrecisionForBytes(8) && !followParquetFormatSpec =>
+      // always store decimals in fixed-length byte arrays. To keep compatibility with these older
+      // versions, here we convert decimals with all precisions to `FIXED_LEN_BYTE_ARRAY` annotated
+      // by `DECIMAL`.
+      case DecimalType.Fixed(precision, scale) if !followParquetFormatSpec =>
         Types
           .primitive(FIXED_LEN_BYTE_ARRAY, repetition)
          .as(DECIMAL)
          .precision(precision)
          .scale(scale)
-          .length(minBytesForPrecision(precision))
+          .length(CatalystSchemaConverter.minBytesForPrecision(precision))
          .named(field.name)
 
-      case dec @ DecimalType() if !followParquetFormatSpec =>
-        throw new AnalysisException(
-          s"Data type $dec is not supported. " +
-            s"When ${SQLConf.PARQUET_FOLLOW_PARQUET_FORMAT_SPEC.key} is set to false," +
-            "decimal precision and scale must be specified, " +
-            "and precision must be less than or equal to 18.")
-
       // =====================================
       // Decimals (follow Parquet format spec)
       // =====================================
@@ -436,7 +430,7 @@ private[parquet] class CatalystSchemaConverter(
          .as(DECIMAL)
          .precision(precision)
          .scale(scale)
-          .length(minBytesForPrecision(precision))
+          .length(CatalystSchemaConverter.minBytesForPrecision(precision))
          .named(field.name)
 
       // ===================================================
@@ -548,15 +542,6 @@ private[parquet] class CatalystSchemaConverter(
        Math.pow(2, 8 * numBytes - 1) - 1)))  // max value stored in numBytes
      .asInstanceOf[Int]
  }
-
-  // Min byte counts needed to store decimals with various precisions
-  private val minBytesForPrecision: Array[Int] = Array.tabulate(38) { precision =>
-    var numBytes = 1
-    while (math.pow(2.0, 8 * numBytes - 1) < math.pow(10.0, precision)) {
-      numBytes += 1
-    }
-    numBytes
-  }
 }
@@ -580,4 +565,23 @@ private[parquet] object CatalystSchemaConverter {
      throw new AnalysisException(message)
    }
  }
+
+  private def computeMinBytesForPrecision(precision : Int) : Int = {
+    var numBytes = 1
+    while (math.pow(2.0, 8 * numBytes - 1) < math.pow(10.0, precision)) {
+      numBytes += 1
+    }
+    numBytes
+  }
+
+  private val MIN_BYTES_FOR_PRECISION = Array.tabulate[Int](39)(computeMinBytesForPrecision)
+
+  // Returns the minimum number of bytes needed to store a decimal with a given `precision`.
+  def minBytesForPrecision(precision : Int) : Int = {
+    if (precision < MIN_BYTES_FOR_PRECISION.length) {
+      MIN_BYTES_FOR_PRECISION(precision)
+    } else {
+      computeMinBytesForPrecision(precision)
+    }
+  }
 }
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetTableSupport.scala b/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetTableSupport.scala
index fc9f61a636..78ecfad1d5 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetTableSupport.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetTableSupport.scala
@@ -17,6 +17,7 @@
 
 package org.apache.spark.sql.parquet
 
+import java.math.BigInteger
 import java.nio.{ByteBuffer, ByteOrder}
 import java.util.{HashMap => JHashMap}
 
@@ -114,11 +115,8 @@ private[parquet] class RowWriteSupport extends WriteSupport[InternalRow] with Lo
          Binary.fromByteArray(value.asInstanceOf[UTF8String].getBytes))
      case BinaryType => writer.addBinary(
        Binary.fromByteArray(value.asInstanceOf[Array[Byte]]))
-      case d: DecimalType =>
-        if (d.precision > 18) {
-          sys.error(s"Unsupported datatype $d, cannot write to consumer")
-        }
-        writeDecimal(value.asInstanceOf[Decimal], d.precision)
+      case DecimalType.Fixed(precision, _) =>
+        writeDecimal(value.asInstanceOf[Decimal], precision)
      case _ => sys.error(s"Do not know how to writer $schema to consumer")
    }
  }
@@ -199,20 +197,47 @@ private[parquet] class RowWriteSupport extends WriteSupport[InternalRow] with Lo
    writer.endGroup()
  }
 
-  // Scratch array used to write decimals as fixed-length binary
-  private[this] val scratchBytes = new Array[Byte](8)
+  // Scratch array used to write decimals as fixed-length byte array
+  private[this] var reusableDecimalBytes = new Array[Byte](16)
 
  private[parquet] def writeDecimal(decimal: Decimal, precision: Int): Unit = {
-    val numBytes = ParquetTypesConverter.BYTES_FOR_PRECISION(precision)
-    val unscaledLong = decimal.toUnscaledLong
-    var i = 0
-    var shift = 8 * (numBytes - 1)
-    while (i < numBytes) {
-      scratchBytes(i) = (unscaledLong >> shift).toByte
-      i += 1
-      shift -= 8
+    val numBytes = CatalystSchemaConverter.minBytesForPrecision(precision)
+
+    def longToBinary(unscaled: Long): Binary = {
+      var i = 0
+      var shift = 8 * (numBytes - 1)
+      while (i < numBytes) {
+        reusableDecimalBytes(i) = (unscaled >> shift).toByte
+        i += 1
+        shift -= 8
+      }
+      Binary.fromByteArray(reusableDecimalBytes, 0, numBytes)
    }
-    writer.addBinary(Binary.fromByteArray(scratchBytes, 0, numBytes))
+
+    def bigIntegerToBinary(unscaled: BigInteger): Binary = {
+      unscaled.toByteArray match {
+        case bytes if bytes.length == numBytes =>
+          Binary.fromByteArray(bytes)
+
+        case bytes if bytes.length <= reusableDecimalBytes.length =>
+          val signedByte = (if (bytes.head < 0) -1 else 0).toByte
+          java.util.Arrays.fill(reusableDecimalBytes, 0, numBytes - bytes.length, signedByte)
+          System.arraycopy(bytes, 0, reusableDecimalBytes, numBytes - bytes.length, bytes.length)
+          Binary.fromByteArray(reusableDecimalBytes, 0, numBytes)
+
+        case bytes =>
+          reusableDecimalBytes = new Array[Byte](bytes.length)
+          bigIntegerToBinary(unscaled)
+      }
+    }
+
+    val binary = if (numBytes <= 8) {
+      longToBinary(decimal.toUnscaledLong)
+    } else {
+      bigIntegerToBinary(decimal.toJavaBigDecimal.unscaledValue())
+    }
+
+    writer.addBinary(binary)
  }
 
  // array used to write Timestamp as Int96 (fixed-length binary)
@@ -268,11 +293,8 @@ private[parquet] class MutableRowWriteSupport extends RowWriteSupport {
        writer.addBinary(Binary.fromByteArray(record.getUTF8String(index).getBytes))
      case BinaryType =>
        writer.addBinary(Binary.fromByteArray(record.getBinary(index)))
-      case d: DecimalType =>
-        if (d.precision > 18) {
-          sys.error(s"Unsupported datatype $d, cannot write to consumer")
-        }
-        writeDecimal(record.getDecimal(index), d.precision)
+      case DecimalType.Fixed(precision, _) =>
+        writeDecimal(record.getDecimal(index), precision)
      case _ => sys.error(s"Unsupported datatype $ctype, cannot write to consumer")
    }
  }
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetIOSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetIOSuite.scala
index b5314a3dd9..b415da5b8c 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetIOSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetIOSuite.scala
@@ -106,21 +106,13 @@ class ParquetIOSuite extends QueryTest with ParquetTest {
        // Parquet doesn't allow column names with spaces, have to add an alias here
        .select($"_1" cast decimal as "dec")
 
-    for ((precision, scale) <- Seq((5, 2), (1, 0), (1, 1), (18, 10), (18, 17))) {
+    for ((precision, scale) <- Seq((5, 2), (1, 0), (1, 1), (18, 10), (18, 17), (19, 0), (38, 37))) {
      withTempPath { dir =>
        val data = makeDecimalRDD(DecimalType(precision, scale))
        data.write.parquet(dir.getCanonicalPath)
        checkAnswer(sqlContext.read.parquet(dir.getCanonicalPath), data.collect().toSeq)
      }
    }
-
-    // Decimals with precision above 18 are not yet supported
-    intercept[Throwable] {
-      withTempPath { dir =>
-        makeDecimalRDD(DecimalType(19, 10)).write.parquet(dir.getCanonicalPath)
-        sqlContext.read.parquet(dir.getCanonicalPath).collect()
-      }
-    }
  }
 
  test("date type") {
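
A note for reviewers on the byte counts that drive the fast/slow split above: minBytesForPrecision yields 8 bytes at precision 18 (the old limit, the largest precision whose unscaled value still fits in a Long), 9 bytes at precision 19, and 16 bytes at precision 38, which is why reusableDecimalBytes starts out at 16 bytes and why longToBinary covers exactly the numBytes <= 8 cases. A standalone sketch (the object name MinBytesForPrecisionDemo is invented here; the computation itself is copied from the patch):

object MinBytesForPrecisionDemo extends App {
  // Smallest n such that an n-byte signed two's-complement integer can hold
  // 10^precision - 1; same computation as computeMinBytesForPrecision above.
  def minBytesForPrecision(precision: Int): Int = {
    var numBytes = 1
    while (math.pow(2.0, 8 * numBytes - 1) < math.pow(10.0, precision)) {
      numBytes += 1
    }
    numBytes
  }

  assert(minBytesForPrecision(18) == 8)  // old upper bound: one Long suffices
  assert(minBytesForPrecision(19) == 9)  // first precision past the Long fast path
  assert(minBytesForPrecision(38) == 16) // max precision: the scratch array's initial size
}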
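
The fixed-length encoding itself can be exercised end to end: pad the unscaled value's minimal two's-complement bytes with sign bytes on the left when writing, then either shift-and-sign-extend into a Long (the read fast path) or hand the whole array to BigInteger (the read path for precision > 18). The sketch below mirrors the patch's arithmetic; the object and method names (FixedLenDecimalRoundTrip, encode, decodeAsLong, decodeAsBigInteger) are invented for illustration:

import java.math.BigInteger

object FixedLenDecimalRoundTrip extends App {
  // Write side: left-pad with 0x00 (non-negative) or 0xff (negative), as
  // bigIntegerToBinary does in RowWriteSupport.
  def encode(unscaled: BigInteger, numBytes: Int): Array[Byte] = {
    val bytes = unscaled.toByteArray // minimal big-endian two's complement
    require(bytes.length <= numBytes)
    val fixed = new Array[Byte](numBytes)
    val signByte = (if (bytes.head < 0) -1 else 0).toByte
    java.util.Arrays.fill(fixed, signByte)
    System.arraycopy(bytes, 0, fixed, numBytes - bytes.length, bytes.length)
    fixed
  }

  // Read fast path (numBytes <= 8): accumulate bytes into a Long, then
  // sign-extend by shifting the top byte up to bit 63 and back down.
  def decodeAsLong(bytes: Array[Byte]): Long = {
    var unscaled = 0L
    var i = 0
    while (i < bytes.length) {
      unscaled = (unscaled << 8) | (bytes(i) & 0xff)
      i += 1
    }
    val bits = 8 * bytes.length
    (unscaled << (64 - bits)) >> (64 - bits)
  }

  // Read general path: BigInteger's byte-array constructor already expects
  // big-endian two's complement.
  def decodeAsBigInteger(bytes: Array[Byte]): BigInteger = new BigInteger(bytes)

  assert(decodeAsLong(encode(BigInteger.valueOf(-42), 8)) == -42L)

  val big = new BigInteger("-12345678901234567890123456789012345678") // 38 digits
  assert(decodeAsBigInteger(encode(big, 16)) == big)
}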