author    Sameer Agarwal <sameer@databricks.com>  2016-04-21 16:00:59 -0700
committer Reynold Xin <rxin@databricks.com>  2016-04-21 16:00:59 -0700
commit    f82aa82480d95451510ee7a74c52e83e98c8b794 (patch)
tree      4fbb47feb8e96b517c07fceafcef20d268624285
parent    1a95397bb6a4e7e7a06ac450bf556fa3aa47b8cd (diff)
[SPARK-14774][SQL] Write unscaled values in ColumnVector.putDecimal
## What changes were proposed in this pull request?

We recently made `ColumnarBatch.row` mutable and added a new `ColumnVector.putDecimal` method to support putting `Decimal` values in the `ColumnarBatch`. This unfortunately introduced a bug wherein we were not updating the vector with the proper unscaled values.

## How was this patch tested?

This codepath is hit only when the vectorized aggregate hashmap is enabled. https://github.com/apache/spark/pull/12440 makes sure that a number of regression tests/benchmarks test this bugfix.

Author: Sameer Agarwal <sameer@databricks.com>

Closes #12541 from sameeragarwal/fix-bigdecimal.
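For context on the bug described above: a `Decimal` carries a logical value plus a scale, and the columnar format persists only the unscaled digits. A minimal Scala sketch (illustrative, not part of the patch) using Spark's `org.apache.spark.sql.types.Decimal`:

```scala
import org.apache.spark.sql.types.Decimal

// A Decimal with precision 5, scale 2: logical value 123.45,
// unscaled representation 12345 (the digits with the point removed).
val d = Decimal(BigDecimal("123.45"), 5, 2)

d.toLong          // 123   -- the *scaled* value, truncated: drops .45
d.toUnscaledLong  // 12345 -- lossless, what the vector must store
```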
-rw-r--r--  sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/ColumnVector.java         |  4
-rw-r--r--  sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/ColumnVectorUtils.java    |  8
-rw-r--r--  sql/core/src/test/scala/org/apache/spark/sql/execution/vectorized/ColumnarBatchSuite.scala | 55
3 files changed, 37 insertions(+), 30 deletions(-)
diff --git a/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/ColumnVector.java b/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/ColumnVector.java
index ff1f6680a7..e7dccd1b22 100644
--- a/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/ColumnVector.java
+++ b/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/ColumnVector.java
@@ -569,9 +569,9 @@ public abstract class ColumnVector implements AutoCloseable {
   public final void putDecimal(int rowId, Decimal value, int precision) {
     if (precision <= Decimal.MAX_INT_DIGITS()) {
-      putInt(rowId, value.toInt());
+      putInt(rowId, (int) value.toUnscaledLong());
     } else if (precision <= Decimal.MAX_LONG_DIGITS()) {
-      putLong(rowId, value.toLong());
+      putLong(rowId, value.toUnscaledLong());
     } else {
       BigInteger bigInteger = value.toJavaBigDecimal().unscaledValue();
       putByteArray(rowId, bigInteger.toByteArray());
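To see why storing the unscaled long matters, here is a hedged round-trip sketch. It assumes the read side rebuilds the value with `Decimal.createUnsafe`, Spark's factory that reinterprets a long as an unscaled value:

```scala
import org.apache.spark.sql.types.Decimal

val original = Decimal(BigDecimal("987654.321"), 9, 3)

// What the fixed putDecimal stores for precision <= MAX_LONG_DIGITS:
val stored: Long = original.toUnscaledLong        // 987654321

// The reader reinterprets the stored long as an unscaled value:
val restored = Decimal.createUnsafe(stored, 9, 3) // == 987654.321
assert(restored == original)

// With the old code the vector held value.toLong == 987654, which a
// reader would decode as 987.654 -- the bug this patch fixes.
```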
diff --git a/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/ColumnVectorUtils.java b/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/ColumnVectorUtils.java
index 2dc57dc50d..f50c35fc64 100644
--- a/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/ColumnVectorUtils.java
+++ b/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/ColumnVectorUtils.java
@@ -142,9 +142,11 @@ public class ColumnVectorUtils {
         byte[] b =((String)o).getBytes(StandardCharsets.UTF_8);
         dst.appendByteArray(b, 0, b.length);
       } else if (t instanceof DecimalType) {
-        DecimalType dt = (DecimalType)t;
-        Decimal d = Decimal.apply((BigDecimal)o, dt.precision(), dt.scale());
-        if (dt.precision() <= Decimal.MAX_LONG_DIGITS()) {
+        DecimalType dt = (DecimalType) t;
+        Decimal d = Decimal.apply((BigDecimal) o, dt.precision(), dt.scale());
+        if (dt.precision() <= Decimal.MAX_INT_DIGITS()) {
+          dst.appendInt((int) d.toUnscaledLong());
+        } else if (dt.precision() <= Decimal.MAX_LONG_DIGITS()) {
           dst.appendLong(d.toUnscaledLong());
         } else {
           final BigInteger integer = d.toJavaBigDecimal().unscaledValue();
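The same precision-based dispatch now appears at both fixed call sites. A sketch of the rule (Spark defines `MAX_INT_DIGITS = 9` and `MAX_LONG_DIGITS = 18` on the `Decimal` companion object):

```scala
import org.apache.spark.sql.types.Decimal

// Pick the narrowest physical column that can hold `precision` digits.
def physicalStorage(precision: Int): String =
  if (precision <= Decimal.MAX_INT_DIGITS) "int (unscaled fits in 32 bits)"
  else if (precision <= Decimal.MAX_LONG_DIGITS) "long (unscaled fits in 64 bits)"
  else "byte array (unscaled BigInteger)"

physicalStorage(5)   // int
physicalStorage(12)  // long
physicalStorage(30)  // byte array
```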
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/vectorized/ColumnarBatchSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/vectorized/ColumnarBatchSuite.scala
index 31b63f2ce1..a63007fc3b 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/vectorized/ColumnarBatchSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/vectorized/ColumnarBatchSuite.scala
@@ -586,30 +586,31 @@ class ColumnarBatchSuite extends SparkFunSuite {
   }
 
   private def compareStruct(fields: Seq[StructField], r1: InternalRow, r2: Row, seed: Long) {
-    fields.zipWithIndex.foreach { v => {
-      assert(r1.isNullAt(v._2) == r2.isNullAt(v._2), "Seed = " + seed)
-      if (!r1.isNullAt(v._2)) {
-        v._1.dataType match {
-          case BooleanType => assert(r1.getBoolean(v._2) == r2.getBoolean(v._2), "Seed = " + seed)
-          case ByteType => assert(r1.getByte(v._2) == r2.getByte(v._2), "Seed = " + seed)
-          case ShortType => assert(r1.getShort(v._2) == r2.getShort(v._2), "Seed = " + seed)
-          case IntegerType => assert(r1.getInt(v._2) == r2.getInt(v._2), "Seed = " + seed)
-          case LongType => assert(r1.getLong(v._2) == r2.getLong(v._2), "Seed = " + seed)
-          case FloatType => assert(doubleEquals(r1.getFloat(v._2), r2.getFloat(v._2)),
+    fields.zipWithIndex.foreach { case (field: StructField, ordinal: Int) =>
+      assert(r1.isNullAt(ordinal) == r2.isNullAt(ordinal), "Seed = " + seed)
+      if (!r1.isNullAt(ordinal)) {
+        field.dataType match {
+          case BooleanType => assert(r1.getBoolean(ordinal) == r2.getBoolean(ordinal),
             "Seed = " + seed)
-          case DoubleType => assert(doubleEquals(r1.getDouble(v._2), r2.getDouble(v._2)),
+          case ByteType => assert(r1.getByte(ordinal) == r2.getByte(ordinal), "Seed = " + seed)
+          case ShortType => assert(r1.getShort(ordinal) == r2.getShort(ordinal), "Seed = " + seed)
+          case IntegerType => assert(r1.getInt(ordinal) == r2.getInt(ordinal), "Seed = " + seed)
+          case LongType => assert(r1.getLong(ordinal) == r2.getLong(ordinal), "Seed = " + seed)
+          case FloatType => assert(doubleEquals(r1.getFloat(ordinal), r2.getFloat(ordinal)),
+            "Seed = " + seed)
+          case DoubleType => assert(doubleEquals(r1.getDouble(ordinal), r2.getDouble(ordinal)),
             "Seed = " + seed)
           case t: DecimalType =>
-            val d1 = r1.getDecimal(v._2, t.precision, t.scale).toBigDecimal
-            val d2 = r2.getDecimal(v._2)
+            val d1 = r1.getDecimal(ordinal, t.precision, t.scale).toBigDecimal
+            val d2 = r2.getDecimal(ordinal)
             assert(d1.compare(d2) == 0, "Seed = " + seed)
           case StringType =>
-            assert(r1.getString(v._2) == r2.getString(v._2), "Seed = " + seed)
+            assert(r1.getString(ordinal) == r2.getString(ordinal), "Seed = " + seed)
           case CalendarIntervalType =>
-            assert(r1.getInterval(v._2) === r2.get(v._2).asInstanceOf[CalendarInterval])
+            assert(r1.getInterval(ordinal) === r2.get(ordinal).asInstanceOf[CalendarInterval])
           case ArrayType(childType, n) =>
-            val a1 = r1.getArray(v._2).array
-            val a2 = r2.getList(v._2).toArray
+            val a1 = r1.getArray(ordinal).array
+            val a2 = r2.getList(ordinal).toArray
             assert(a1.length == a2.length, "Seed = " + seed)
             childType match {
               case DoubleType =>
@@ -640,12 +641,13 @@ class ColumnarBatchSuite extends SparkFunSuite {
               case _ => assert(a1 === a2, "Seed = " + seed)
             }
           case StructType(childFields) =>
-            compareStruct(childFields, r1.getStruct(v._2, fields.length), r2.getStruct(v._2), seed)
+            compareStruct(childFields, r1.getStruct(ordinal, fields.length),
+              r2.getStruct(ordinal), seed)
           case _ =>
-            throw new NotImplementedError("Not implemented " + v._1.dataType)
+            throw new NotImplementedError("Not implemented " + field.dataType)
         }
       }
-    }}
+    }
   }
 
   test("Convert rows") {
@@ -678,9 +680,10 @@ class ColumnarBatchSuite extends SparkFunSuite {
   def testRandomRows(flatSchema: Boolean, numFields: Int) {
     // TODO: Figure out why StringType doesn't work on jenkins.
     val types = Array(
-      BooleanType, ByteType, FloatType, DoubleType,
-      IntegerType, LongType, ShortType, DecimalType.IntDecimal, new DecimalType(30, 10),
-      CalendarIntervalType)
+      BooleanType, ByteType, FloatType, DoubleType, IntegerType, LongType, ShortType,
+      DecimalType.ShortDecimal, DecimalType.IntDecimal, DecimalType.ByteDecimal,
+      DecimalType.FloatDecimal, DecimalType.LongDecimal, new DecimalType(5, 2),
+      new DecimalType(12, 2), new DecimalType(30, 10), CalendarIntervalType)
     val seed = System.nanoTime()
     val NUM_ROWS = 200
     val NUM_ITERS = 1000
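The widened type lists deliberately cover all three decimal storage widths. If the shorthands on Spark's `DecimalType` companion expand as assumed below, the suite now exercises the int, long, and byte-array paths:

```scala
import org.apache.spark.sql.types.DecimalType

// Assumed (precision, scale) for the companion-object shorthands:
DecimalType.ByteDecimal   // DecimalType(3, 0)   -> int-backed vector
DecimalType.ShortDecimal  // DecimalType(5, 0)   -> int-backed vector
DecimalType.IntDecimal    // DecimalType(10, 0)  -> long-backed vector
DecimalType.FloatDecimal  // DecimalType(14, 7)  -> long-backed vector
DecimalType.LongDecimal   // DecimalType(20, 0)  -> byte-array-backed vector
```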
@@ -756,8 +759,10 @@ class ColumnarBatchSuite extends SparkFunSuite {
test("mutable ColumnarBatch rows") {
val NUM_ITERS = 10
val types = Array(
- BooleanType, FloatType, DoubleType,
- IntegerType, LongType, ShortType, DecimalType.IntDecimal, new DecimalType(30, 10))
+ BooleanType, FloatType, DoubleType, IntegerType, LongType, ShortType,
+ DecimalType.ShortDecimal, DecimalType.IntDecimal, DecimalType.ByteDecimal,
+ DecimalType.FloatDecimal, DecimalType.LongDecimal, new DecimalType(5, 2),
+ new DecimalType(12, 2), new DecimalType(30, 10))
for (i <- 0 to NUM_ITERS) {
val random = new Random(System.nanoTime())
val schema = RandomDataGenerator.randomSchema(random, numFields = 20, types)