author     Sameer Agarwal <sameer@databricks.com>    2016-04-06 11:59:42 -0700
committer  Yin Huai <yhuai@databricks.com>           2016-04-06 11:59:42 -0700
commit     bb1fa5b2182f384cb711fc2be45b0f1a8c466ed6 (patch)
tree       9b8a37bd7b069fc525c67a02861ab8acbad5a262 /sql
parent     af73d9737874f7adaec3cd19ac889ab3badb8e2a (diff)
[SPARK-14320][SQL] Make ColumnarBatch.Row mutable
## What changes were proposed in this pull request?

In order to leverage a data structure like `AggregateHashMap` (https://github.com/apache/spark/pull/12055) to speed up aggregates with keys, we need to make `ColumnarBatch.Row` mutable.

## How was this patch tested?

Unit test in `ColumnarBatchSuite`. Also tested via `BenchmarkWholeStageCodegen`.

Author: Sameer Agarwal <sameer@databricks.com>

Closes #12103 from sameeragarwal/mutable-row.
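A minimal sketch (not part of the patch) of the calling pattern this change enables: `findOrInsert` now hands back a mutable `ColumnarBatch.Row`, so callers bump the aggregate buffer in place instead of reaching into the map's internal `batch`/`buckets` arrays. Only the APIs visible in the diff below are assumed; the class name, schema, and key distribution are illustrative, mirroring the `BenchmarkWholeStageCodegen` change further down.

```java
// Sketch only: illustrative driver for the new mutable-row API.
import org.apache.spark.sql.execution.vectorized.AggregateHashMap;
import org.apache.spark.sql.execution.vectorized.ColumnarBatch;
import org.apache.spark.sql.types.DataTypes;
import org.apache.spark.sql.types.StructType;

public class MutableRowSketch {
  public static void main(String[] args) {
    // A (long key, long value) aggregation schema, as in the benchmark below.
    StructType schema = new StructType()
        .add("key", DataTypes.LongType)
        .add("value", DataTypes.LongType);
    AggregateHashMap map = new AggregateHashMap(schema);

    // findOrInsert returns the row backing the aggregate buffer, so the
    // running count is updated in place; no direct access to batch/buckets.
    for (long key = 0; key < 1000; key++) {
      ColumnarBatch.Row row = map.findOrInsert(key % 16);
      row.setLong(1, row.getLong(1) + 1);
    }
  }
}
```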
Diffstat (limited to 'sql')
-rw-r--r--  sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/AggregateHashMap.java     | 11
-rw-r--r--  sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/ColumnVector.java         | 12
-rw-r--r--  sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/ColumnarBatch.java        | 94
-rw-r--r--  sql/core/src/test/scala/org/apache/spark/sql/execution/BenchmarkWholeStageCodegen.scala    |  5
-rw-r--r--  sql/core/src/test/scala/org/apache/spark/sql/execution/vectorized/ColumnarBatchSuite.scala | 21
5 files changed, 135 insertions, 8 deletions
diff --git a/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/AggregateHashMap.java b/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/AggregateHashMap.java
index abe8db589d..69ce54390f 100644
--- a/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/AggregateHashMap.java
+++ b/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/AggregateHashMap.java
@@ -19,6 +19,8 @@ package org.apache.spark.sql.execution.vectorized;
import java.util.Arrays;
+import com.google.common.annotations.VisibleForTesting;
+
import org.apache.spark.memory.MemoryMode;
import org.apache.spark.sql.types.StructType;
@@ -38,9 +40,9 @@ import static org.apache.spark.sql.types.DataTypes.LongType;
* for certain distribution of keys) and requires us to fall back on the latter for correctness.
*/
public class AggregateHashMap {
- public ColumnarBatch batch;
- public int[] buckets;
+ private ColumnarBatch batch;
+ private int[] buckets;
private int numBuckets;
private int numRows = 0;
private int maxSteps = 3;
@@ -69,16 +71,17 @@ public class AggregateHashMap {
this(schema, DEFAULT_CAPACITY, DEFAULT_LOAD_FACTOR, DEFAULT_MAX_STEPS);
}
- public int findOrInsert(long key) {
+ public ColumnarBatch.Row findOrInsert(long key) {
int idx = find(key);
if (idx != -1 && buckets[idx] == -1) {
batch.column(0).putLong(numRows, key);
batch.column(1).putLong(numRows, 0);
buckets[idx] = numRows++;
}
- return idx;
+ return batch.getRow(buckets[idx]);
}
+ @VisibleForTesting
public int find(long key) {
long h = hash(key);
int step = 0;
diff --git a/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/ColumnVector.java b/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/ColumnVector.java
index 74fa6323cc..d5daaf99df 100644
--- a/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/ColumnVector.java
+++ b/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/ColumnVector.java
@@ -566,6 +566,18 @@ public abstract class ColumnVector {
}
}
+
+ public final void putDecimal(int rowId, Decimal value, int precision) {
+ if (precision <= Decimal.MAX_INT_DIGITS()) {
+ putInt(rowId, value.toInt());
+ } else if (precision <= Decimal.MAX_LONG_DIGITS()) {
+ putLong(rowId, value.toLong());
+ } else {
+ BigInteger bigInteger = value.toJavaBigDecimal().unscaledValue();
+ putByteArray(rowId, bigInteger.toByteArray());
+ }
+ }
+
/**
* Returns the UTF8String for rowId.
*/
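As a side note on the new `putDecimal`: it follows Spark's usual decimal storage scheme, packing small-precision values into the int or long column and falling back to the unscaled `BigInteger` bytes for anything wider. A small illustrative sketch follows; the `storageFor` helper and class are hypothetical, and only `Decimal.MAX_INT_DIGITS()`/`MAX_LONG_DIGITS()` come from the code in this patch.

```java
// Sketch only: shows which physical representation putDecimal selects for a
// given precision. storageFor is a hypothetical helper, not Spark code.
import org.apache.spark.sql.types.Decimal;

public class DecimalStorageSketch {
  static String storageFor(int precision) {
    if (precision <= Decimal.MAX_INT_DIGITS()) {          // up to 9 digits
      return "int column (compact)";
    } else if (precision <= Decimal.MAX_LONG_DIGITS()) {  // up to 18 digits
      return "long column (compact)";
    } else {
      return "unscaled BigInteger byte array";            // arbitrary precision
    }
  }

  public static void main(String[] args) {
    System.out.println(storageFor(9));   // int column (compact)
    System.out.println(storageFor(18));  // long column (compact)
    System.out.println(storageFor(30));  // unscaled BigInteger byte array
  }
}
```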
diff --git a/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/ColumnarBatch.java b/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/ColumnarBatch.java
index d1cc4e6d03..8cece73faa 100644
--- a/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/ColumnarBatch.java
+++ b/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/ColumnarBatch.java
@@ -16,6 +16,7 @@
*/
package org.apache.spark.sql.execution.vectorized;
+import java.math.BigDecimal;
import java.util.*;
import org.apache.commons.lang.NotImplementedException;
@@ -23,6 +24,7 @@ import org.apache.commons.lang.NotImplementedException;
import org.apache.spark.memory.MemoryMode;
import org.apache.spark.sql.catalyst.InternalRow;
import org.apache.spark.sql.catalyst.expressions.GenericMutableRow;
+import org.apache.spark.sql.catalyst.expressions.MutableRow;
import org.apache.spark.sql.catalyst.expressions.UnsafeRow;
import org.apache.spark.sql.catalyst.util.ArrayData;
import org.apache.spark.sql.catalyst.util.MapData;
@@ -91,7 +93,7 @@ public final class ColumnarBatch {
* Adapter class to interop with existing components that expect internal row. A lot of
* performance is lost with this translation.
*/
- public static final class Row extends InternalRow {
+ public static final class Row extends MutableRow {
protected int rowId;
private final ColumnarBatch parent;
private final int fixedLenRowSize;
@@ -232,6 +234,96 @@ public final class ColumnarBatch {
public Object get(int ordinal, DataType dataType) {
throw new NotImplementedException();
}
+
+ @Override
+ public void update(int ordinal, Object value) {
+ if (value == null) {
+ setNullAt(ordinal);
+ } else {
+ DataType dt = columns[ordinal].dataType();
+ if (dt instanceof BooleanType) {
+ setBoolean(ordinal, (boolean) value);
+ } else if (dt instanceof IntegerType) {
+ setInt(ordinal, (int) value);
+ } else if (dt instanceof ShortType) {
+ setShort(ordinal, (short) value);
+ } else if (dt instanceof LongType) {
+ setLong(ordinal, (long) value);
+ } else if (dt instanceof FloatType) {
+ setFloat(ordinal, (float) value);
+ } else if (dt instanceof DoubleType) {
+ setDouble(ordinal, (double) value);
+ } else if (dt instanceof DecimalType) {
+ DecimalType t = (DecimalType) dt;
+ setDecimal(ordinal, Decimal.apply((BigDecimal) value, t.precision(), t.scale()),
+ t.precision());
+ } else {
+ throw new NotImplementedException("Datatype not supported " + dt);
+ }
+ }
+ }
+
+ @Override
+ public void setNullAt(int ordinal) {
+ assert (!columns[ordinal].isConstant);
+ columns[ordinal].putNull(rowId);
+ }
+
+ @Override
+ public void setBoolean(int ordinal, boolean value) {
+ assert (!columns[ordinal].isConstant);
+ columns[ordinal].putNotNull(rowId);
+ columns[ordinal].putBoolean(rowId, value);
+ }
+
+ @Override
+ public void setByte(int ordinal, byte value) {
+ assert (!columns[ordinal].isConstant);
+ columns[ordinal].putNotNull(rowId);
+ columns[ordinal].putByte(rowId, value);
+ }
+
+ @Override
+ public void setShort(int ordinal, short value) {
+ assert (!columns[ordinal].isConstant);
+ columns[ordinal].putNotNull(rowId);
+ columns[ordinal].putShort(rowId, value);
+ }
+
+ @Override
+ public void setInt(int ordinal, int value) {
+ assert (!columns[ordinal].isConstant);
+ columns[ordinal].putNotNull(rowId);
+ columns[ordinal].putInt(rowId, value);
+ }
+
+ @Override
+ public void setLong(int ordinal, long value) {
+ assert (!columns[ordinal].isConstant);
+ columns[ordinal].putNotNull(rowId);
+ columns[ordinal].putLong(rowId, value);
+ }
+
+ @Override
+ public void setFloat(int ordinal, float value) {
+ assert (!columns[ordinal].isConstant);
+ columns[ordinal].putNotNull(rowId);
+ columns[ordinal].putFloat(rowId, value);
+ }
+
+ @Override
+ public void setDouble(int ordinal, double value) {
+ assert (!columns[ordinal].isConstant);
+ columns[ordinal].putNotNull(rowId);
+ columns[ordinal].putDouble(rowId, value);
+ }
+
+ @Override
+ public void setDecimal(int ordinal, Decimal value, int precision) {
+ assert (!columns[ordinal].isConstant);
+ columns[ordinal].putNotNull(rowId);
+ columns[ordinal].putDecimal(rowId, value, precision);
+ }
}
/**
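To make the new setter surface concrete, here is a hedged, self-contained sketch that drives `update()` through several of the branches added above, building a one-row batch via `ColumnVectorUtils.toBatch` the same way the new suite test does. The class name, schema, and seed values are illustrative assumptions, not part of the patch.

```java
// Sketch only: exercises the update() dispatch added in this patch.
import java.math.BigDecimal;
import java.util.Arrays;
import org.apache.spark.memory.MemoryMode;
import org.apache.spark.sql.Row;
import org.apache.spark.sql.RowFactory;
import org.apache.spark.sql.execution.vectorized.ColumnVectorUtils;
import org.apache.spark.sql.execution.vectorized.ColumnarBatch;
import org.apache.spark.sql.types.DataTypes;
import org.apache.spark.sql.types.StructType;

public class RowUpdateSketch {
  public static void main(String[] args) {
    StructType schema = new StructType()
        .add("l", DataTypes.LongType)
        .add("d", DataTypes.DoubleType)
        .add("dec", DataTypes.createDecimalType(30, 10));

    // Seed the batch with one row, then mutate it through ColumnarBatch.Row.
    Row seed = RowFactory.create(0L, 0.0d, new BigDecimal("0.0"));
    ColumnarBatch batch = ColumnVectorUtils.toBatch(
        schema, MemoryMode.ON_HEAP, Arrays.asList(seed).iterator());

    ColumnarBatch.Row row = batch.getRow(0);
    row.update(0, 42L);                    // LongType    -> setLong
    row.update(1, 3.14d);                  // DoubleType  -> setDouble
    row.update(2, new BigDecimal("1.5"));  // DecimalType -> setDecimal
    row.update(0, null);                   // null        -> setNullAt
    batch.close();
  }
}
```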
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/BenchmarkWholeStageCodegen.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/BenchmarkWholeStageCodegen.scala
index 3566ef3043..5dbf619876 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/BenchmarkWholeStageCodegen.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/BenchmarkWholeStageCodegen.scala
@@ -517,9 +517,8 @@ class BenchmarkWholeStageCodegen extends SparkFunSuite {
.add("value", LongType)
val map = new AggregateHashMap(schema)
while (i < numKeys) {
- val idx = map.findOrInsert(i.toLong)
- map.batch.column(1).putLong(map.buckets(idx),
- map.batch.column(1).getLong(map.buckets(idx)) + 1)
+ val row = map.findOrInsert(i.toLong)
+ row.setLong(1, row.getLong(1) + 1)
i += 1
}
var s = 0
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/vectorized/ColumnarBatchSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/vectorized/ColumnarBatchSuite.scala
index 4262097e8f..8a551cd78c 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/vectorized/ColumnarBatchSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/vectorized/ColumnarBatchSuite.scala
@@ -756,4 +756,25 @@ class ColumnarBatchSuite extends SparkFunSuite {
}}
}
}
+
+ test("mutable ColumnarBatch rows") {
+ val NUM_ITERS = 10
+ val types = Array(
+ BooleanType, FloatType, DoubleType,
+ IntegerType, LongType, ShortType, DecimalType.IntDecimal, new DecimalType(30, 10))
+ for (i <- 0 to NUM_ITERS) {
+ val random = new Random(System.nanoTime())
+ val schema = RandomDataGenerator.randomSchema(random, numFields = 20, types)
+ val oldRow = RandomDataGenerator.randomRow(random, schema)
+ val newRow = RandomDataGenerator.randomRow(random, schema)
+
+ (MemoryMode.ON_HEAP :: MemoryMode.OFF_HEAP :: Nil).foreach { memMode =>
+ val batch = ColumnVectorUtils.toBatch(schema, memMode, (oldRow :: Nil).iterator.asJava)
+ val columnarBatchRow = batch.getRow(0)
+ newRow.toSeq.zipWithIndex.foreach(i => columnarBatchRow.update(i._2, i._1))
+ compareStruct(schema, columnarBatchRow, newRow, 0)
+ batch.close()
+ }
+ }
+ }
}