- author: Sameer Agarwal <sameer@databricks.com>, 2016-04-06 11:59:42 -0700
- committer: Yin Huai <yhuai@databricks.com>, 2016-04-06 11:59:42 -0700
- commit: bb1fa5b2182f384cb711fc2be45b0f1a8c466ed6
- tree: 9b8a37bd7b069fc525c67a02861ab8acbad5a262 /sql
- parent: af73d9737874f7adaec3cd19ac889ab3badb8e2a
# [SPARK-14320][SQL] Make ColumnarBatch.Row mutable
## What changes were proposed in this pull request?
In order to leverage a data structure like `AggregateHashMap` (https://github.com/apache/spark/pull/12055) to speed up aggregates with keys, we need to make `ColumnarBatch.Row` mutable.
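As a concrete illustration of the pattern this enables, here is a minimal Java sketch. The schema, key distribution, and class name are made up for illustration; the mutation idiom itself matches the updated `BenchmarkWholeStageCodegen` shown further down.

```java
import org.apache.spark.sql.execution.vectorized.AggregateHashMap;
import org.apache.spark.sql.execution.vectorized.ColumnarBatch;
import org.apache.spark.sql.types.DataTypes;
import org.apache.spark.sql.types.StructType;

public class MutableRowSketch {
  public static void main(String[] args) {
    // Hypothetical two-column aggregate layout: a long grouping key and a long count.
    StructType schema = new StructType()
        .add("key", DataTypes.LongType)
        .add("value", DataTypes.LongType);
    AggregateHashMap map = new AggregateHashMap(schema);

    // findOrInsert now returns a mutable ColumnarBatch.Row, so the aggregate
    // buffer can be updated in place instead of poking at the map's internals.
    for (long i = 0; i < 1000; i++) {
      ColumnarBatch.Row row = map.findOrInsert(i % 16);
      row.setLong(1, row.getLong(1) + 1);
    }
  }
}
```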
## How was this patch tested?
Unit test in `ColumnarBatchSuite`. Also, tested via `BenchmarkWholeStageCodegen`.
Author: Sameer Agarwal <sameer@databricks.com>
Closes #12103 from sameeragarwal/mutable-row.
Diffstat (limited to 'sql')
5 files changed, 135 insertions, 8 deletions
```diff
diff --git a/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/AggregateHashMap.java b/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/AggregateHashMap.java
index abe8db589d..69ce54390f 100644
--- a/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/AggregateHashMap.java
+++ b/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/AggregateHashMap.java
@@ -19,6 +19,8 @@ package org.apache.spark.sql.execution.vectorized;
 
 import java.util.Arrays;
 
+import com.google.common.annotations.VisibleForTesting;
+
 import org.apache.spark.memory.MemoryMode;
 import org.apache.spark.sql.types.StructType;
 
@@ -38,9 +40,9 @@ import static org.apache.spark.sql.types.DataTypes.LongType;
  * for certain distribution of keys) and requires us to fall back on the latter for correctness.
  */
 public class AggregateHashMap {
 
-  public ColumnarBatch batch;
-  public int[] buckets;
+  private ColumnarBatch batch;
+  private int[] buckets;
   private int numBuckets;
   private int numRows = 0;
   private int maxSteps = 3;
@@ -69,16 +71,17 @@ public class AggregateHashMap {
     this(schema, DEFAULT_CAPACITY, DEFAULT_LOAD_FACTOR, DEFAULT_MAX_STEPS);
   }
 
-  public int findOrInsert(long key) {
+  public ColumnarBatch.Row findOrInsert(long key) {
     int idx = find(key);
     if (idx != -1 && buckets[idx] == -1) {
       batch.column(0).putLong(numRows, key);
       batch.column(1).putLong(numRows, 0);
       buckets[idx] = numRows++;
     }
-    return idx;
+    return batch.getRow(buckets[idx]);
   }
 
+  @VisibleForTesting
   public int find(long key) {
     long h = hash(key);
     int step = 0;
diff --git a/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/ColumnVector.java b/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/ColumnVector.java
index 74fa6323cc..d5daaf99df 100644
--- a/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/ColumnVector.java
+++ b/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/ColumnVector.java
@@ -566,6 +566,18 @@ public abstract class ColumnVector {
     }
   }
 
+  public final void putDecimal(int rowId, Decimal value, int precision) {
+    if (precision <= Decimal.MAX_INT_DIGITS()) {
+      putInt(rowId, value.toInt());
+    } else if (precision <= Decimal.MAX_LONG_DIGITS()) {
+      putLong(rowId, value.toLong());
+    } else {
+      BigInteger bigInteger = value.toJavaBigDecimal().unscaledValue();
+      putByteArray(rowId, bigInteger.toByteArray());
+    }
+  }
+
   /**
    * Returns the UTF8String for rowId.
    */
```
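The new `putDecimal` compacts storage by precision: up to `Decimal.MAX_INT_DIGITS()` (9) digits are written as an int, up to `Decimal.MAX_LONG_DIGITS()` (18) as a long, and anything wider as the unscaled value's byte array. Below is a minimal sketch of the wide path, assuming `ColumnVector.allocate` and `getDecimal` behave as used elsewhere in this package; the class name, capacity, and values are illustrative.

```java
import java.math.BigDecimal;

import org.apache.spark.memory.MemoryMode;
import org.apache.spark.sql.execution.vectorized.ColumnVector;
import org.apache.spark.sql.types.Decimal;
import org.apache.spark.sql.types.DecimalType;

public class PutDecimalSketch {
  public static void main(String[] args) {
    // Precision 30 exceeds MAX_LONG_DIGITS (18), so this exercises the byte-array path.
    DecimalType wide = new DecimalType(30, 10);
    ColumnVector col = ColumnVector.allocate(16, wide, MemoryMode.ON_HEAP);

    Decimal d = Decimal.apply(new BigDecimal("12345678901234567890.0123456789"),
        wide.precision(), wide.scale());
    col.putDecimal(0, d, wide.precision());

    // Read the value back through the matching getter.
    System.out.println(col.getDecimal(0, wide.precision(), wide.scale()));
  }
}
```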
```diff
diff --git a/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/ColumnarBatch.java b/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/ColumnarBatch.java
index d1cc4e6d03..8cece73faa 100644
--- a/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/ColumnarBatch.java
+++ b/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/ColumnarBatch.java
@@ -16,6 +16,7 @@
  */
 package org.apache.spark.sql.execution.vectorized;
 
+import java.math.BigDecimal;
 import java.util.*;
 
 import org.apache.commons.lang.NotImplementedException;
@@ -23,6 +24,7 @@ import org.apache.commons.lang.NotImplementedException;
 import org.apache.spark.memory.MemoryMode;
 import org.apache.spark.sql.catalyst.InternalRow;
 import org.apache.spark.sql.catalyst.expressions.GenericMutableRow;
+import org.apache.spark.sql.catalyst.expressions.MutableRow;
 import org.apache.spark.sql.catalyst.expressions.UnsafeRow;
 import org.apache.spark.sql.catalyst.util.ArrayData;
 import org.apache.spark.sql.catalyst.util.MapData;
@@ -91,7 +93,7 @@ public final class ColumnarBatch {
    * Adapter class to interop with existing components that expect internal row. A lot of
    * performance is lost with this translation.
    */
-  public static final class Row extends InternalRow {
+  public static final class Row extends MutableRow {
     protected int rowId;
     private final ColumnarBatch parent;
     private final int fixedLenRowSize;
@@ -232,6 +234,96 @@ public final class ColumnarBatch {
     public Object get(int ordinal, DataType dataType) {
       throw new NotImplementedException();
     }
+
+    @Override
+    public void update(int ordinal, Object value) {
+      if (value == null) {
+        setNullAt(ordinal);
+      } else {
+        DataType dt = columns[ordinal].dataType();
+        if (dt instanceof BooleanType) {
+          setBoolean(ordinal, (boolean) value);
+        } else if (dt instanceof IntegerType) {
+          setInt(ordinal, (int) value);
+        } else if (dt instanceof ShortType) {
+          setShort(ordinal, (short) value);
+        } else if (dt instanceof LongType) {
+          setLong(ordinal, (long) value);
+        } else if (dt instanceof FloatType) {
+          setFloat(ordinal, (float) value);
+        } else if (dt instanceof DoubleType) {
+          setDouble(ordinal, (double) value);
+        } else if (dt instanceof DecimalType) {
+          DecimalType t = (DecimalType) dt;
+          setDecimal(ordinal, Decimal.apply((BigDecimal) value, t.precision(), t.scale()),
+            t.precision());
+        } else {
+          throw new NotImplementedException("Datatype not supported " + dt);
+        }
+      }
+    }
+
+    @Override
+    public void setNullAt(int ordinal) {
+      assert (!columns[ordinal].isConstant);
+      columns[ordinal].putNull(rowId);
+    }
+
+    @Override
+    public void setBoolean(int ordinal, boolean value) {
+      assert (!columns[ordinal].isConstant);
+      columns[ordinal].putNotNull(rowId);
+      columns[ordinal].putBoolean(rowId, value);
+    }
+
+    @Override
+    public void setByte(int ordinal, byte value) {
+      assert (!columns[ordinal].isConstant);
+      columns[ordinal].putNotNull(rowId);
+      columns[ordinal].putByte(rowId, value);
+    }
+
+    @Override
+    public void setShort(int ordinal, short value) {
+      assert (!columns[ordinal].isConstant);
+      columns[ordinal].putNotNull(rowId);
+      columns[ordinal].putShort(rowId, value);
+    }
+
+    @Override
+    public void setInt(int ordinal, int value) {
+      assert (!columns[ordinal].isConstant);
+      columns[ordinal].putNotNull(rowId);
+      columns[ordinal].putInt(rowId, value);
+    }
+
+    @Override
+    public void setLong(int ordinal, long value) {
+      assert (!columns[ordinal].isConstant);
+      columns[ordinal].putNotNull(rowId);
+      columns[ordinal].putLong(rowId, value);
+    }
+
+    @Override
+    public void setFloat(int ordinal, float value) {
+      assert (!columns[ordinal].isConstant);
+      columns[ordinal].putNotNull(rowId);
+      columns[ordinal].putFloat(rowId, value);
+    }
+
+    @Override
+    public void setDouble(int ordinal, double value) {
+      assert (!columns[ordinal].isConstant);
+      columns[ordinal].putNotNull(rowId);
+      columns[ordinal].putDouble(rowId, value);
+    }
+
+    @Override
+    public void setDecimal(int ordinal, Decimal value, int precision) {
+      assert (!columns[ordinal].isConstant);
+      columns[ordinal].putNotNull(rowId);
+      columns[ordinal].putDecimal(rowId, value, precision);
+    }
   }
 
   /**
```
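Since `Row` now extends `MutableRow`, callers can use either the typed setters or the generic `update(ordinal, value)`, which dispatches on the column's `DataType` and routes a `java.math.BigDecimal` through `setDecimal`. A rough sketch of both paths, assuming `ColumnarBatch.allocate`, `setNumRows`, and `getRow` as used in the surrounding code; the class name, schema, and values are illustrative.

```java
import java.math.BigDecimal;

import org.apache.spark.memory.MemoryMode;
import org.apache.spark.sql.execution.vectorized.ColumnarBatch;
import org.apache.spark.sql.types.DataTypes;
import org.apache.spark.sql.types.DecimalType;
import org.apache.spark.sql.types.StructType;

public class RowUpdateSketch {
  public static void main(String[] args) {
    StructType schema = new StructType()
        .add("count", DataTypes.LongType)
        .add("price", new DecimalType(30, 10));
    ColumnarBatch batch = ColumnarBatch.allocate(schema, MemoryMode.ON_HEAP);
    batch.setNumRows(1);

    ColumnarBatch.Row row = batch.getRow(0);
    // Typed setter: marks the slot not-null, then writes the value.
    row.setLong(0, 42L);
    // Generic setter: inspects the column's DataType; a BigDecimal is
    // converted with Decimal.apply and written via setDecimal.
    row.update(1, new BigDecimal("12345678901234567890.0123456789"));

    batch.close();
  }
}
```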
```diff
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/BenchmarkWholeStageCodegen.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/BenchmarkWholeStageCodegen.scala
index 3566ef3043..5dbf619876 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/BenchmarkWholeStageCodegen.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/BenchmarkWholeStageCodegen.scala
@@ -517,9 +517,8 @@ class BenchmarkWholeStageCodegen extends SparkFunSuite {
       .add("value", LongType)
     val map = new AggregateHashMap(schema)
     while (i < numKeys) {
-      val idx = map.findOrInsert(i.toLong)
-      map.batch.column(1).putLong(map.buckets(idx),
-        map.batch.column(1).getLong(map.buckets(idx)) + 1)
+      val row = map.findOrInsert(i.toLong)
+      row.setLong(1, row.getLong(1) + 1)
       i += 1
     }
     var s = 0
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/vectorized/ColumnarBatchSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/vectorized/ColumnarBatchSuite.scala
index 4262097e8f..8a551cd78c 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/vectorized/ColumnarBatchSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/vectorized/ColumnarBatchSuite.scala
@@ -756,4 +756,25 @@ class ColumnarBatchSuite extends SparkFunSuite {
       }}
     }
   }
+
+  test("mutable ColumnarBatch rows") {
+    val NUM_ITERS = 10
+    val types = Array(
+      BooleanType, FloatType, DoubleType,
+      IntegerType, LongType, ShortType, DecimalType.IntDecimal, new DecimalType(30, 10))
+    for (i <- 0 to NUM_ITERS) {
+      val random = new Random(System.nanoTime())
+      val schema = RandomDataGenerator.randomSchema(random, numFields = 20, types)
+      val oldRow = RandomDataGenerator.randomRow(random, schema)
+      val newRow = RandomDataGenerator.randomRow(random, schema)
+
+      (MemoryMode.ON_HEAP :: MemoryMode.OFF_HEAP :: Nil).foreach { memMode =>
+        val batch = ColumnVectorUtils.toBatch(schema, memMode, (oldRow :: Nil).iterator.asJava)
+        val columnarBatchRow = batch.getRow(0)
+        newRow.toSeq.zipWithIndex.foreach(i => columnarBatchRow.update(i._2, i._1))
+        compareStruct(schema, columnarBatchRow, newRow, 0)
+        batch.close()
+      }
+    }
+  }
 }
```