From f82aa82480d95451510ee7a74c52e83e98c8b794 Mon Sep 17 00:00:00 2001
From: Sameer Agarwal
Date: Thu, 21 Apr 2016 16:00:59 -0700
Subject: [SPARK-14774][SQL] Write unscaled values in ColumnVector.putDecimal

## What changes were proposed in this pull request?

We recently made `ColumnarBatch.row` mutable and added a new
`ColumnVector.putDecimal` method to support putting `Decimal` values in the
`ColumnarBatch`. This unfortunately introduced a bug wherein we were not
updating the vector with the proper unscaled values.

## How was this patch tested?

This codepath is hit only when the vectorized aggregate hashmap is enabled.
https://github.com/apache/spark/pull/12440 makes sure that a number of
regression tests/benchmarks test this bugfix.

Author: Sameer Agarwal

Closes #12541 from sameeragarwal/fix-bigdecimal.
---
 .../sql/execution/vectorized/ColumnVector.java     |  4 +-
 .../execution/vectorized/ColumnVectorUtils.java    |  8 ++--
 .../execution/vectorized/ColumnarBatchSuite.scala  | 55 ++++++++++++----------
 3 files changed, 37 insertions(+), 30 deletions(-)

diff --git a/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/ColumnVector.java b/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/ColumnVector.java
index ff1f6680a7..e7dccd1b22 100644
--- a/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/ColumnVector.java
+++ b/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/ColumnVector.java
@@ -569,9 +569,9 @@ public abstract class ColumnVector implements AutoCloseable {
 
   public final void putDecimal(int rowId, Decimal value, int precision) {
     if (precision <= Decimal.MAX_INT_DIGITS()) {
-      putInt(rowId, value.toInt());
+      putInt(rowId, (int) value.toUnscaledLong());
     } else if (precision <= Decimal.MAX_LONG_DIGITS()) {
-      putLong(rowId, value.toLong());
+      putLong(rowId, value.toUnscaledLong());
     } else {
       BigInteger bigInteger = value.toJavaBigDecimal().unscaledValue();
       putByteArray(rowId, bigInteger.toByteArray());
diff --git a/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/ColumnVectorUtils.java b/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/ColumnVectorUtils.java
index 2dc57dc50d..f50c35fc64 100644
--- a/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/ColumnVectorUtils.java
+++ b/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/ColumnVectorUtils.java
@@ -142,9 +142,11 @@ public class ColumnVectorUtils {
       byte[] b =((String)o).getBytes(StandardCharsets.UTF_8);
       dst.appendByteArray(b, 0, b.length);
     } else if (t instanceof DecimalType) {
-      DecimalType dt = (DecimalType)t;
-      Decimal d = Decimal.apply((BigDecimal)o, dt.precision(), dt.scale());
-      if (dt.precision() <= Decimal.MAX_LONG_DIGITS()) {
+      DecimalType dt = (DecimalType) t;
+      Decimal d = Decimal.apply((BigDecimal) o, dt.precision(), dt.scale());
+      if (dt.precision() <= Decimal.MAX_INT_DIGITS()) {
+        dst.appendInt((int) d.toUnscaledLong());
+      } else if (dt.precision() <= Decimal.MAX_LONG_DIGITS()) {
         dst.appendLong(d.toUnscaledLong());
       } else {
         final BigInteger integer = d.toJavaBigDecimal().unscaledValue();
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/vectorized/ColumnarBatchSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/vectorized/ColumnarBatchSuite.scala
index 31b63f2ce1..a63007fc3b 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/vectorized/ColumnarBatchSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/vectorized/ColumnarBatchSuite.scala
@@ -586,30 +586,31 @@ class ColumnarBatchSuite extends SparkFunSuite {
   }
 
   private def compareStruct(fields: Seq[StructField], r1: InternalRow, r2: Row, seed: Long) {
-    fields.zipWithIndex.foreach { v => {
-      assert(r1.isNullAt(v._2) == r2.isNullAt(v._2), "Seed = " + seed)
-      if (!r1.isNullAt(v._2)) {
-        v._1.dataType match {
-          case BooleanType => assert(r1.getBoolean(v._2) == r2.getBoolean(v._2), "Seed = " + seed)
-          case ByteType => assert(r1.getByte(v._2) == r2.getByte(v._2), "Seed = " + seed)
-          case ShortType => assert(r1.getShort(v._2) == r2.getShort(v._2), "Seed = " + seed)
-          case IntegerType => assert(r1.getInt(v._2) == r2.getInt(v._2), "Seed = " + seed)
-          case LongType => assert(r1.getLong(v._2) == r2.getLong(v._2), "Seed = " + seed)
-          case FloatType => assert(doubleEquals(r1.getFloat(v._2), r2.getFloat(v._2)),
+    fields.zipWithIndex.foreach { case (field: StructField, ordinal: Int) =>
+      assert(r1.isNullAt(ordinal) == r2.isNullAt(ordinal), "Seed = " + seed)
+      if (!r1.isNullAt(ordinal)) {
+        field.dataType match {
+          case BooleanType => assert(r1.getBoolean(ordinal) == r2.getBoolean(ordinal),
             "Seed = " + seed)
-          case DoubleType => assert(doubleEquals(r1.getDouble(v._2), r2.getDouble(v._2)),
+          case ByteType => assert(r1.getByte(ordinal) == r2.getByte(ordinal), "Seed = " + seed)
+          case ShortType => assert(r1.getShort(ordinal) == r2.getShort(ordinal), "Seed = " + seed)
+          case IntegerType => assert(r1.getInt(ordinal) == r2.getInt(ordinal), "Seed = " + seed)
+          case LongType => assert(r1.getLong(ordinal) == r2.getLong(ordinal), "Seed = " + seed)
+          case FloatType => assert(doubleEquals(r1.getFloat(ordinal), r2.getFloat(ordinal)),
+            "Seed = " + seed)
+          case DoubleType => assert(doubleEquals(r1.getDouble(ordinal), r2.getDouble(ordinal)),
             "Seed = " + seed)
           case t: DecimalType =>
-            val d1 = r1.getDecimal(v._2, t.precision, t.scale).toBigDecimal
-            val d2 = r2.getDecimal(v._2)
+            val d1 = r1.getDecimal(ordinal, t.precision, t.scale).toBigDecimal
+            val d2 = r2.getDecimal(ordinal)
             assert(d1.compare(d2) == 0, "Seed = " + seed)
           case StringType =>
-            assert(r1.getString(v._2) == r2.getString(v._2), "Seed = " + seed)
+            assert(r1.getString(ordinal) == r2.getString(ordinal), "Seed = " + seed)
           case CalendarIntervalType =>
-            assert(r1.getInterval(v._2) === r2.get(v._2).asInstanceOf[CalendarInterval])
+            assert(r1.getInterval(ordinal) === r2.get(ordinal).asInstanceOf[CalendarInterval])
           case ArrayType(childType, n) =>
-            val a1 = r1.getArray(v._2).array
-            val a2 = r2.getList(v._2).toArray
+            val a1 = r1.getArray(ordinal).array
+            val a2 = r2.getList(ordinal).toArray
             assert(a1.length == a2.length, "Seed = " + seed)
             childType match {
               case DoubleType =>
@@ -640,12 +641,13 @@ class ColumnarBatchSuite extends SparkFunSuite {
               case _ => assert(a1 === a2, "Seed = " + seed)
             }
           case StructType(childFields) =>
-            compareStruct(childFields, r1.getStruct(v._2, fields.length), r2.getStruct(v._2), seed)
+            compareStruct(childFields, r1.getStruct(ordinal, fields.length),
+              r2.getStruct(ordinal), seed)
           case _ =>
-            throw new NotImplementedError("Not implemented " + v._1.dataType)
+            throw new NotImplementedError("Not implemented " + field.dataType)
         }
       }
-    }}
+    }
   }
 
   test("Convert rows") {
@@ -678,9 +680,10 @@ class ColumnarBatchSuite extends SparkFunSuite {
   def testRandomRows(flatSchema: Boolean, numFields: Int) {
     // TODO: Figure out why StringType doesn't work on jenkins.
     val types = Array(
-      BooleanType, ByteType, FloatType, DoubleType,
-      IntegerType, LongType, ShortType, DecimalType.IntDecimal, new DecimalType(30, 10),
-      CalendarIntervalType)
+      BooleanType, ByteType, FloatType, DoubleType, IntegerType, LongType, ShortType,
+      DecimalType.ShortDecimal, DecimalType.IntDecimal, DecimalType.ByteDecimal,
+      DecimalType.FloatDecimal, DecimalType.LongDecimal, new DecimalType(5, 2),
+      new DecimalType(12, 2), new DecimalType(30, 10), CalendarIntervalType)
     val seed = System.nanoTime()
     val NUM_ROWS = 200
     val NUM_ITERS = 1000
@@ -756,8 +759,10 @@ class ColumnarBatchSuite extends SparkFunSuite {
   test("mutable ColumnarBatch rows") {
     val NUM_ITERS = 10
     val types = Array(
-      BooleanType, FloatType, DoubleType,
-      IntegerType, LongType, ShortType, DecimalType.IntDecimal, new DecimalType(30, 10))
+      BooleanType, FloatType, DoubleType, IntegerType, LongType, ShortType,
+      DecimalType.ShortDecimal, DecimalType.IntDecimal, DecimalType.ByteDecimal,
+      DecimalType.FloatDecimal, DecimalType.LongDecimal, new DecimalType(5, 2),
+      new DecimalType(12, 2), new DecimalType(30, 10))
     for (i <- 0 to NUM_ITERS) {
       val random = new Random(System.nanoTime())
       val schema = RandomDataGenerator.randomSchema(random, numFields = 20, types)
--
cgit v1.2.3
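
To see why the fix writes `toUnscaledLong` rather than `toInt`/`toLong`: the fixed-width decimal columns store the *unscaled* integer, and readers reconstruct the value by reapplying the column's scale, so a writer that stores the scaled (truncated) value silently corrupts the round trip. Below is a minimal sketch of that round trip using Spark's `org.apache.spark.sql.types.Decimal` API; the demo object itself is hypothetical and not part of the patch, only the `Decimal` calls are Spark's.

```scala
import org.apache.spark.sql.types.Decimal

// Hypothetical demo object illustrating the bug, not part of the patch.
object UnscaledRoundTrip {
  def main(args: Array[String]): Unit = {
    // A DECIMAL(5, 2) value: 123.45 is held compactly as the
    // unscaled long 12345 together with scale 2.
    val d = Decimal(BigDecimal("123.45"), 5, 2)

    val scaled = d.toLong            // 123   -- what the buggy code stored
    val unscaled = d.toUnscaledLong  // 12345 -- what the fix stores

    // Readers rebuild the decimal from (unscaled long, precision, scale),
    // so the buggy value reads back as 1.23 instead of 123.45.
    println(Decimal(scaled, 5, 2))   // 1.23
    println(Decimal(unscaled, 5, 2)) // 123.45
  }
}
```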