path: root/sql
author    Davies Liu <davies@databricks.com>    2015-06-29 15:59:20 -0700
committer Davies Liu <davies@databricks.com>    2015-06-29 15:59:20 -0700
commit    ed359de595d5dd67b666660eddf092eaf89041c8 (patch)
tree      1cb3205b2155cfd828522f25d6ff74e468917010 /sql
parent    931da5c8ab271ff2ee04419c7e3c6b0012459694 (diff)
[SPARK-8579] [SQL] support arbitrary object in UnsafeRow
This PR brings arbitrary object support to UnsafeRow (in both the grouping key and the aggregation buffer). Two object pools are created to hold non-primitive objects, and their indexes are stored in the UnsafeRow. So that grouping keys can still be compared as bytes, the objects in a key are stored in a unique object pool, which guarantees that the same object always gets the same index (used as the hash code). For StringType and BinaryType, values are still written as var-length data in the UnsafeRow during initialization, for better performance; on update, however, they may become objects in the pools (leaving some garbage behind in the buffer).

BTW: Will create a JIRA once issue.apache.org is available.

cc JoshRosen rxin

Author: Davies Liu <davies@databricks.com>

Closes #6959 from davies/unsafe_obj and squashes the following commits:

5ce39da [Davies Liu] fix comment
5e797bf [Davies Liu] Merge branch 'master' of github.com:apache/spark into unsafe_obj
5803d64 [Davies Liu] fix conflict
461d304 [Davies Liu] Merge branch 'master' of github.com:apache/spark into unsafe_obj
2f41c90 [Davies Liu] Merge branch 'master' of github.com:apache/spark into unsafe_obj
b04d69c [Davies Liu] address comments
4859b80 [Davies Liu] fix comments
f38011c [Davies Liu] add a test for grouping by decimal
d2cf7ab [Davies Liu] add more tests for null checking
71983c5 [Davies Liu] add test for timestamp
e8a1649 [Davies Liu] reuse buffer for string
39f09ca [Davies Liu] Merge branch 'master' of github.com:apache/spark into unsafe_obj
035501e [Davies Liu] fix style
236d6de [Davies Liu] support arbitrary object in UnsafeRow
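A minimal, self-contained Java sketch of the idea, for illustration only (the SimplePool, SimpleUniquePool, and PoolSketch names below are hypothetical and much simpler than the classes added by this patch): non-primitive values live in a side pool, the fixed-width row slot holds the negated pool index, and a unique pool deduplicates key objects so that equal keys produce identical row bytes.

import java.math.BigDecimal;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
import java.util.Map;

// Hypothetical, simplified pool classes; they only illustrate the idea, not Spark's actual API.
class SimplePool {
  private final List<Object> objects = new ArrayList<Object>();

  // Appends an object and returns its slot index.
  int put(Object obj) {
    objects.add(obj);
    return objects.size() - 1;
  }

  Object get(int idx) {
    return objects.get(idx);
  }
}

class SimpleUniquePool extends SimplePool {
  private final Map<Object, Integer> index = new HashMap<Object, Integer>();

  // Equal objects always receive the same index, so rows referencing them stay byte-comparable.
  @Override
  int put(Object obj) {
    Integer existing = index.get(obj);
    if (existing != null) {
      return existing;
    }
    int idx = super.put(obj);
    index.put(obj, idx);
    return idx;
  }
}

public class PoolSketch {
  public static void main(String[] args) {
    SimpleUniquePool keyPool = new SimpleUniquePool();
    long[] fixedWidthKeyRow = new long[1];

    // A non-primitive grouping key is stored as the negated index of the pooled object.
    int idx = keyPool.put(new BigDecimal("42"));
    fixedWidthKeyRow[0] = -idx;

    // Inserting an equal key reuses the same index, so the fixed-width bytes match exactly.
    int again = keyPool.put(new BigDecimal("42"));
    System.out.println(idx == again);                            // true
    System.out.println(keyPool.get((int) -fixedWidthKeyRow[0])); // 42
  }
}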
Diffstat (limited to 'sql')
-rw-r--r--  sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeFixedWidthAggregationMap.java  144
-rw-r--r--  sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeRow.java  218
-rw-r--r--  sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/util/ObjectPool.java  78
-rw-r--r--  sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/util/UniqueObjectPool.java  59
-rw-r--r--  sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/InternalRow.scala  5
-rw-r--r--  sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/UnsafeRowConverter.scala  94
-rw-r--r--  sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/UnsafeFixedWidthAggregationMapSuite.scala  65
-rw-r--r--  sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/UnsafeRowConverterSuite.scala  190
-rw-r--r--  sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/util/ObjectPoolSuite.scala  57
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/execution/GeneratedAggregate.scala  16
10 files changed, 615 insertions, 311 deletions
diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeFixedWidthAggregationMap.java b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeFixedWidthAggregationMap.java
index 83f2a31297..1e79f4b2e8 100644
--- a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeFixedWidthAggregationMap.java
+++ b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeFixedWidthAggregationMap.java
@@ -19,9 +19,11 @@ package org.apache.spark.sql.catalyst.expressions;
import java.util.Iterator;
+import scala.Function1;
+
import org.apache.spark.sql.catalyst.InternalRow;
-import org.apache.spark.sql.types.StructField;
-import org.apache.spark.sql.types.StructType;
+import org.apache.spark.sql.catalyst.util.ObjectPool;
+import org.apache.spark.sql.catalyst.util.UniqueObjectPool;
import org.apache.spark.unsafe.PlatformDependent;
import org.apache.spark.unsafe.map.BytesToBytesMap;
import org.apache.spark.unsafe.memory.MemoryLocation;
@@ -38,16 +40,28 @@ public final class UnsafeFixedWidthAggregationMap {
* An empty aggregation buffer, encoded in UnsafeRow format. When inserting a new key into the
* map, we copy this buffer and use it as the value.
*/
- private final byte[] emptyAggregationBuffer;
+ private final byte[] emptyBuffer;
- private final StructType aggregationBufferSchema;
+ /**
+ * An empty row used by `initProjection`
+ */
+ private static final InternalRow emptyRow = new GenericInternalRow();
- private final StructType groupingKeySchema;
+ /**
+ * Whether the empty aggregation buffer can be reused without calling `initProjection`.
+ */
+ private final boolean reuseEmptyBuffer;
/**
- * Encodes grouping keys as UnsafeRows.
+ * The projection used to initialize the emptyBuffer
*/
- private final UnsafeRowConverter groupingKeyToUnsafeRowConverter;
+ private final Function1<InternalRow, InternalRow> initProjection;
+
+ /**
+ * Encodes grouping keys or buffers as UnsafeRows.
+ */
+ private final UnsafeRowConverter keyConverter;
+ private final UnsafeRowConverter bufferConverter;
/**
* A hashmap which maps from opaque bytearray keys to bytearray values.
@@ -55,9 +69,19 @@ public final class UnsafeFixedWidthAggregationMap {
private final BytesToBytesMap map;
/**
+ * An object pool for objects that are used in grouping keys.
+ */
+ private final UniqueObjectPool keyPool;
+
+ /**
+ * An object pool for objects that are used in aggregation buffers.
+ */
+ private final ObjectPool bufferPool;
+
+ /**
* Re-used pointer to the current aggregation buffer
*/
- private final UnsafeRow currentAggregationBuffer = new UnsafeRow();
+ private final UnsafeRow currentBuffer = new UnsafeRow();
/**
* Scratch space that is used when encoding grouping keys into UnsafeRow format.
@@ -70,67 +94,38 @@ public final class UnsafeFixedWidthAggregationMap {
private final boolean enablePerfMetrics;
/**
- * @return true if UnsafeFixedWidthAggregationMap supports grouping keys with the given schema,
- * false otherwise.
- */
- public static boolean supportsGroupKeySchema(StructType schema) {
- for (StructField field: schema.fields()) {
- if (!UnsafeRow.readableFieldTypes.contains(field.dataType())) {
- return false;
- }
- }
- return true;
- }
-
- /**
- * @return true if UnsafeFixedWidthAggregationMap supports aggregation buffers with the given
- * schema, false otherwise.
- */
- public static boolean supportsAggregationBufferSchema(StructType schema) {
- for (StructField field: schema.fields()) {
- if (!UnsafeRow.settableFieldTypes.contains(field.dataType())) {
- return false;
- }
- }
- return true;
- }
-
- /**
* Create a new UnsafeFixedWidthAggregationMap.
*
- * @param emptyAggregationBuffer the default value for new keys (a "zero" of the agg. function)
- * @param aggregationBufferSchema the schema of the aggregation buffer, used for row conversion.
- * @param groupingKeySchema the schema of the grouping key, used for row conversion.
+ * @param initProjection the default value for new keys (a "zero" of the agg. function)
+ * @param keyConverter the converter of the grouping key, used for row conversion.
+ * @param bufferConverter the converter of the aggregation buffer, used for row conversion.
* @param memoryManager the memory manager used to allocate our Unsafe memory structures.
* @param initialCapacity the initial capacity of the map (a sizing hint to avoid re-hashing).
* @param enablePerfMetrics if true, performance metrics will be recorded (has minor perf impact)
*/
public UnsafeFixedWidthAggregationMap(
- InternalRow emptyAggregationBuffer,
- StructType aggregationBufferSchema,
- StructType groupingKeySchema,
+ Function1<InternalRow, InternalRow> initProjection,
+ UnsafeRowConverter keyConverter,
+ UnsafeRowConverter bufferConverter,
TaskMemoryManager memoryManager,
int initialCapacity,
boolean enablePerfMetrics) {
- this.emptyAggregationBuffer =
- convertToUnsafeRow(emptyAggregationBuffer, aggregationBufferSchema);
- this.aggregationBufferSchema = aggregationBufferSchema;
- this.groupingKeyToUnsafeRowConverter = new UnsafeRowConverter(groupingKeySchema);
- this.groupingKeySchema = groupingKeySchema;
- this.map = new BytesToBytesMap(memoryManager, initialCapacity, enablePerfMetrics);
+ this.initProjection = initProjection;
+ this.keyConverter = keyConverter;
+ this.bufferConverter = bufferConverter;
this.enablePerfMetrics = enablePerfMetrics;
- }
- /**
- * Convert a Java object row into an UnsafeRow, allocating it into a new byte array.
- */
- private static byte[] convertToUnsafeRow(InternalRow javaRow, StructType schema) {
- final UnsafeRowConverter converter = new UnsafeRowConverter(schema);
- final byte[] unsafeRow = new byte[converter.getSizeRequirement(javaRow)];
- final int writtenLength =
- converter.writeRow(javaRow, unsafeRow, PlatformDependent.BYTE_ARRAY_OFFSET);
- assert (writtenLength == unsafeRow.length): "Size requirement calculation was wrong!";
- return unsafeRow;
+ this.map = new BytesToBytesMap(memoryManager, initialCapacity, enablePerfMetrics);
+ this.keyPool = new UniqueObjectPool(100);
+ this.bufferPool = new ObjectPool(initialCapacity);
+
+ InternalRow initRow = initProjection.apply(emptyRow);
+ this.emptyBuffer = new byte[bufferConverter.getSizeRequirement(initRow)];
+ int writtenLength = bufferConverter.writeRow(
+ initRow, emptyBuffer, PlatformDependent.BYTE_ARRAY_OFFSET, bufferPool);
+ assert (writtenLength == emptyBuffer.length): "Size requirement calculation was wrong!";
+ // re-use the empty buffer only when there is no object saved in pool.
+ reuseEmptyBuffer = bufferPool.size() == 0;
}
/**
@@ -138,15 +133,16 @@ public final class UnsafeFixedWidthAggregationMap {
* return the same object.
*/
public UnsafeRow getAggregationBuffer(InternalRow groupingKey) {
- final int groupingKeySize = groupingKeyToUnsafeRowConverter.getSizeRequirement(groupingKey);
+ final int groupingKeySize = keyConverter.getSizeRequirement(groupingKey);
// Make sure that the buffer is large enough to hold the key. If it's not, grow it:
if (groupingKeySize > groupingKeyConversionScratchSpace.length) {
groupingKeyConversionScratchSpace = new byte[groupingKeySize];
}
- final int actualGroupingKeySize = groupingKeyToUnsafeRowConverter.writeRow(
+ final int actualGroupingKeySize = keyConverter.writeRow(
groupingKey,
groupingKeyConversionScratchSpace,
- PlatformDependent.BYTE_ARRAY_OFFSET);
+ PlatformDependent.BYTE_ARRAY_OFFSET,
+ keyPool);
assert (groupingKeySize == actualGroupingKeySize) : "Size requirement calculation was wrong!";
// Probe our map using the serialized key
@@ -157,25 +153,31 @@ public final class UnsafeFixedWidthAggregationMap {
if (!loc.isDefined()) {
// This is the first time that we've seen this grouping key, so we'll insert a copy of the
// empty aggregation buffer into the map:
+ if (!reuseEmptyBuffer) {
+ // There are objects referenced by emptyBuffer, so generate a new one
+ InternalRow initRow = initProjection.apply(emptyRow);
+ bufferConverter.writeRow(initRow, emptyBuffer, PlatformDependent.BYTE_ARRAY_OFFSET,
+ bufferPool);
+ }
loc.putNewKey(
groupingKeyConversionScratchSpace,
PlatformDependent.BYTE_ARRAY_OFFSET,
groupingKeySize,
- emptyAggregationBuffer,
+ emptyBuffer,
PlatformDependent.BYTE_ARRAY_OFFSET,
- emptyAggregationBuffer.length
+ emptyBuffer.length
);
}
// Reset the pointer to point to the value that we just stored or looked up:
final MemoryLocation address = loc.getValueAddress();
- currentAggregationBuffer.pointTo(
+ currentBuffer.pointTo(
address.getBaseObject(),
address.getBaseOffset(),
- aggregationBufferSchema.length(),
- aggregationBufferSchema
+ bufferConverter.numFields(),
+ bufferPool
);
- return currentAggregationBuffer;
+ return currentBuffer;
}
/**
@@ -211,14 +213,14 @@ public final class UnsafeFixedWidthAggregationMap {
entry.key.pointTo(
keyAddress.getBaseObject(),
keyAddress.getBaseOffset(),
- groupingKeySchema.length(),
- groupingKeySchema
+ keyConverter.numFields(),
+ keyPool
);
entry.value.pointTo(
valueAddress.getBaseObject(),
valueAddress.getBaseOffset(),
- aggregationBufferSchema.length(),
- aggregationBufferSchema
+ bufferConverter.numFields(),
+ bufferPool
);
return entry;
}
@@ -246,6 +248,8 @@ public final class UnsafeFixedWidthAggregationMap {
System.out.println("Number of hash collisions: " + map.getNumHashCollisions());
System.out.println("Time spent resizing (ns): " + map.getTimeSpentResizingNs());
System.out.println("Total memory consumption (bytes): " + map.getTotalMemoryConsumption());
+ System.out.println("Number of unique objects in keys: " + keyPool.size());
+ System.out.println("Number of objects in buffers: " + bufferPool.size());
}
}
diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeRow.java b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeRow.java
index 11d51d90f1..f077064a02 100644
--- a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeRow.java
+++ b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeRow.java
@@ -17,20 +17,12 @@
package org.apache.spark.sql.catalyst.expressions;
-import javax.annotation.Nullable;
-import java.util.Arrays;
-import java.util.Collections;
-import java.util.HashSet;
-import java.util.Set;
-
import org.apache.spark.sql.catalyst.InternalRow;
-import org.apache.spark.sql.types.DataType;
-import org.apache.spark.sql.types.StructType;
+import org.apache.spark.sql.catalyst.util.ObjectPool;
import org.apache.spark.unsafe.PlatformDependent;
import org.apache.spark.unsafe.bitset.BitSetMethods;
import org.apache.spark.unsafe.types.UTF8String;
-import static org.apache.spark.sql.types.DataTypes.*;
/**
* An Unsafe implementation of Row which is backed by raw memory instead of Java objects.
@@ -44,7 +36,20 @@ import static org.apache.spark.sql.types.DataTypes.*;
* primitive types, such as long, double, or int, we store the value directly in the word. For
* fields with non-primitive or variable-length values, we store a relative offset (w.r.t. the
* base address of the row) that points to the beginning of the variable-length field, and length
- * (they are combined into a long).
+ * (they are combined into a long). Other objects are stored in a pool, and their indexes are
+ * held in the word.
+ *
+ * In order to support fast hashing and equality checks for UnsafeRows that contain objects
+ * when used as grouping keys in BytesToBytesMap, we put the objects in a UniqueObjectPool to make
+ * sure the same object always gets the same index across keys, so the objects can be hashed and
+ * compared via their indexes.
+ *
+ * For non-primitive types, the word of a field could be:
+ *   UNION {
+ *     [1] [offset: 31bits] [length: 31bits]  // StringType
+ *     [0] [offset: 31bits] [length: 31bits]  // BinaryType
+ *     -   [index: 63bits]                    // an object in the pool, stored as a negated index
+ *   }
*
* Instances of `UnsafeRow` act as pointers to row data stored in this format.
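A rough, self-contained Java sketch of the field-word layout described in the comment above (illustrative only, not part of the patch; FieldWordSketch is a hypothetical name): a positive word packs a string/binary flag plus a 31-bit offset and a 31-bit length, while a word <= 0 is read back as the negated index of an object in the pool.

public class FieldWordSketch {
  // Mirrors the 31-bit offset / 31-bit length packing described above; a sketch, not UnsafeRow itself.
  static final long OFFSET_BITS = 31L;

  static long encodeInline(boolean isString, int offset, int lengthInBytes) {
    long flag = isString ? 1L << (OFFSET_BITS * 2) : 0L; // bit 62 marks StringType
    return flag | (((long) offset) << OFFSET_BITS) | (long) lengthInBytes;
  }

  static String describe(long word) {
    if (word <= 0) {
      // Negated index of an object stored in the pool.
      return "pool object at index " + (int) -word;
    }
    boolean isString = (word >> (OFFSET_BITS * 2)) > 0;
    int offset = (int) ((word >> OFFSET_BITS) & Integer.MAX_VALUE);
    int length = (int) (word & Integer.MAX_VALUE);
    return (isString ? "string" : "binary") + " at offset " + offset + ", length " + length;
  }

  public static void main(String[] args) {
    System.out.println(describe(encodeInline(true, 24, 5))); // string at offset 24, length 5
    System.out.println(describe(-3L));                       // pool object at index 3
  }
}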
*/
@@ -53,8 +58,12 @@ public final class UnsafeRow extends MutableRow {
private Object baseObject;
private long baseOffset;
+ /** A pool to hold non-primitive objects */
+ private ObjectPool pool;
+
Object getBaseObject() { return baseObject; }
long getBaseOffset() { return baseOffset; }
+ ObjectPool getPool() { return pool; }
/** The number of fields in this row, used for calculating the bitset width (and in assertions) */
private int numFields;
@@ -63,15 +72,6 @@ public final class UnsafeRow extends MutableRow {
/** The width of the null tracking bit set, in bytes */
private int bitSetWidthInBytes;
- /**
- * This optional schema is required if you want to call generic get() and set() methods on
- * this UnsafeRow, but is optional if callers will only use type-specific getTYPE() and setTYPE()
- * methods. This should be removed after the planned InternalRow / Row split; right now, it's only
- * needed by the generic get() method, which is only called internally by code that accesses
- * UTF8String-typed columns.
- */
- @Nullable
- private StructType schema;
private long getFieldOffset(int ordinal) {
return baseOffset + bitSetWidthInBytes + ordinal * 8L;
@@ -81,42 +81,7 @@ public final class UnsafeRow extends MutableRow {
return ((numFields / 64) + (numFields % 64 == 0 ? 0 : 1)) * 8;
}
- /**
- * Field types that can be updated in place in UnsafeRows (e.g. we support set() for these types)
- */
- public static final Set<DataType> settableFieldTypes;
-
- /**
- * Fields types can be read(but not set (e.g. set() will throw UnsupportedOperationException).
- */
- public static final Set<DataType> readableFieldTypes;
-
- // TODO: support DecimalType
- static {
- settableFieldTypes = Collections.unmodifiableSet(
- new HashSet<DataType>(
- Arrays.asList(new DataType[] {
- NullType,
- BooleanType,
- ByteType,
- ShortType,
- IntegerType,
- LongType,
- FloatType,
- DoubleType,
- DateType,
- TimestampType
- })));
-
- // We support get() on a superset of the types for which we support set():
- final Set<DataType> _readableFieldTypes = new HashSet<DataType>(
- Arrays.asList(new DataType[]{
- StringType,
- BinaryType
- }));
- _readableFieldTypes.addAll(settableFieldTypes);
- readableFieldTypes = Collections.unmodifiableSet(_readableFieldTypes);
- }
+ public static final long OFFSET_BITS = 31L;
/**
* Construct a new UnsafeRow. The resulting row won't be usable until `pointTo()` has been called,
@@ -130,22 +95,15 @@ public final class UnsafeRow extends MutableRow {
* @param baseObject the base object
* @param baseOffset the offset within the base object
* @param numFields the number of fields in this row
- * @param schema an optional schema; this is necessary if you want to call generic get() or set()
- * methods on this row, but is optional if the caller will only use type-specific
- * getTYPE() and setTYPE() methods.
+ * @param pool the object pool to hold arbitrary objects
*/
- public void pointTo(
- Object baseObject,
- long baseOffset,
- int numFields,
- @Nullable StructType schema) {
+ public void pointTo(Object baseObject, long baseOffset, int numFields, ObjectPool pool) {
assert numFields >= 0 : "numFields should >= 0";
- assert schema == null || schema.fields().length == numFields;
this.bitSetWidthInBytes = calculateBitSetWidthInBytes(numFields);
this.baseObject = baseObject;
this.baseOffset = baseOffset;
this.numFields = numFields;
- this.schema = schema;
+ this.pool = pool;
}
private void assertIndexIsValid(int index) {
@@ -168,9 +126,68 @@ public final class UnsafeRow extends MutableRow {
BitSetMethods.unset(baseObject, baseOffset, i);
}
+ /**
+ * Updates column `i` with Object `value`, which must not be a primitive type.
+ */
@Override
- public void update(int ordinal, Object value) {
- throw new UnsupportedOperationException();
+ public void update(int i, Object value) {
+ if (value == null) {
+ if (!isNullAt(i)) {
+ // remove the old value from pool
+ long idx = getLong(i);
+ if (idx <= 0) {
+ // this is the index of old value in pool, remove it
+ pool.replace((int)-idx, null);
+ } else {
+ // there will be some garbage left (UTF8String or byte[])
+ }
+ setNullAt(i);
+ }
+ return;
+ }
+
+ if (isNullAt(i)) {
+ // there is not an old value, put the new value into pool
+ int idx = pool.put(value);
+ setLong(i, (long)-idx);
+ } else {
+ // there is an old value, check the type, then replace it or update it
+ long v = getLong(i);
+ if (v <= 0) {
+ // it's the index in the pool, replace old value with new one
+ int idx = (int)-v;
+ pool.replace(idx, value);
+ } else {
+ // old value is UTF8String or byte[], try to reuse the space
+ boolean isString;
+ byte[] newBytes;
+ if (value instanceof UTF8String) {
+ newBytes = ((UTF8String) value).getBytes();
+ isString = true;
+ } else {
+ newBytes = (byte[]) value;
+ isString = false;
+ }
+ int offset = (int) ((v >> OFFSET_BITS) & Integer.MAX_VALUE);
+ int oldLength = (int) (v & Integer.MAX_VALUE);
+ if (newBytes.length <= oldLength) {
+ // the new value can fit in the old buffer, re-use it
+ PlatformDependent.copyMemory(
+ newBytes,
+ PlatformDependent.BYTE_ARRAY_OFFSET,
+ baseObject,
+ baseOffset + offset,
+ newBytes.length);
+ long flag = isString ? 1L << (OFFSET_BITS * 2) : 0L;
+ setLong(i, flag | (((long) offset) << OFFSET_BITS) | (long) newBytes.length);
+ } else {
+ // Cannot fit in the buffer
+ int idx = pool.put(value);
+ setLong(i, (long) -idx);
+ }
+ }
+ }
+ setNotNullAt(i);
}
@Override
@@ -227,28 +244,38 @@ public final class UnsafeRow extends MutableRow {
return numFields;
}
- @Override
- public StructType schema() {
- return schema;
- }
-
+ /**
+ * Returns the object for column `i`, which should not be a primitive type.
+ */
@Override
public Object get(int i) {
assertIndexIsValid(i);
- assert (schema != null) : "Schema must be defined when calling generic get() method";
- final DataType dataType = schema.fields()[i].dataType();
- // UnsafeRow is only designed to be invoked by internal code, which only invokes this generic
- // get() method when trying to access UTF8String-typed columns. If we refactor the codebase to
- // separate the internal and external row interfaces, then internal code can fetch strings via
- // a new getUTF8String() method and we'll be able to remove this method.
if (isNullAt(i)) {
return null;
- } else if (dataType == StringType) {
- return getUTF8String(i);
- } else if (dataType == BinaryType) {
- return getBinary(i);
+ }
+ long v = PlatformDependent.UNSAFE.getLong(baseObject, getFieldOffset(i));
+ if (v <= 0) {
+ // It's the index of an object in the pool.
+ int idx = (int)-v;
+ return pool.get(idx);
} else {
- throw new UnsupportedOperationException();
+ // The column could be StringType or BinaryType
+ boolean isString = (v >> (OFFSET_BITS * 2)) > 0;
+ int offset = (int) ((v >> OFFSET_BITS) & Integer.MAX_VALUE);
+ int size = (int) (v & Integer.MAX_VALUE);
+ final byte[] bytes = new byte[size];
+ PlatformDependent.copyMemory(
+ baseObject,
+ baseOffset + offset,
+ bytes,
+ PlatformDependent.BYTE_ARRAY_OFFSET,
+ size
+ );
+ if (isString) {
+ return UTF8String.fromBytes(bytes);
+ } else {
+ return bytes;
+ }
}
}
@@ -308,31 +335,6 @@ public final class UnsafeRow extends MutableRow {
}
}
- public UTF8String getUTF8String(int i) {
- return UTF8String.fromBytes(getBinary(i));
- }
-
- public byte[] getBinary(int i) {
- assertIndexIsValid(i);
- final long offsetAndSize = getLong(i);
- final int offset = (int)(offsetAndSize >> 32);
- final int size = (int)(offsetAndSize & ((1L << 32) - 1));
- final byte[] bytes = new byte[size];
- PlatformDependent.copyMemory(
- baseObject,
- baseOffset + offset,
- bytes,
- PlatformDependent.BYTE_ARRAY_OFFSET,
- size
- );
- return bytes;
- }
-
- @Override
- public String getString(int i) {
- return getUTF8String(i).toString();
- }
-
@Override
public InternalRow copy() {
throw new UnsupportedOperationException();
diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/util/ObjectPool.java b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/util/ObjectPool.java
new file mode 100644
index 0000000000..97f89a7d0b
--- /dev/null
+++ b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/util/ObjectPool.java
@@ -0,0 +1,78 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.catalyst.util;
+
+/**
+ * An object pool stores a collection of objects in an array, so that they can be referenced by
+ * the pool plus an index.
+ */
+public class ObjectPool {
+
+ /**
+ * An array to hold objects, which will grow as needed.
+ */
+ private Object[] objects;
+
+ /**
+ * The number of objects in the pool.
+ */
+ private int numObj;
+
+ public ObjectPool(int capacity) {
+ objects = new Object[capacity];
+ numObj = 0;
+ }
+
+ /**
+ * Returns the number of objects in the pool.
+ */
+ public int size() {
+ return numObj;
+ }
+
+ /**
+ * Returns the object at position `idx` in the array.
+ */
+ public Object get(int idx) {
+ assert (idx < numObj);
+ return objects[idx];
+ }
+
+ /**
+ * Puts an object `obj` at the end of the array and returns its index.
+ * <p/>
+ * The array will grow as needed.
+ */
+ public int put(Object obj) {
+ if (numObj >= objects.length) {
+ Object[] tmp = new Object[objects.length * 2];
+ System.arraycopy(objects, 0, tmp, 0, objects.length);
+ objects = tmp;
+ }
+ objects[numObj++] = obj;
+ return numObj - 1;
+ }
+
+ /**
+ * Replaces the object at `idx` with the new object `obj`.
+ */
+ public void replace(int idx, Object obj) {
+ assert (idx < numObj);
+ objects[idx] = obj;
+ }
+}
diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/util/UniqueObjectPool.java b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/util/UniqueObjectPool.java
new file mode 100644
index 0000000000..d512392dca
--- /dev/null
+++ b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/util/UniqueObjectPool.java
@@ -0,0 +1,59 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.catalyst.util;
+
+import java.util.HashMap;
+
+/**
+ * A unique object pool stores a collection of distinct objects.
+ */
+public class UniqueObjectPool extends ObjectPool {
+
+ /**
+ * A hash map from objects to their indexes in the array.
+ */
+ private HashMap<Object, Integer> objIndex;
+
+ public UniqueObjectPool(int capacity) {
+ super(capacity);
+ objIndex = new HashMap<Object, Integer>();
+ }
+
+ /**
+ * Puts an object `obj` into the pool. If an object equal to `obj` already exists, the index of
+ * the existing one is returned.
+ */
+ @Override
+ public int put(Object obj) {
+ if (objIndex.containsKey(obj)) {
+ return objIndex.get(obj);
+ } else {
+ int idx = super.put(obj);
+ objIndex.put(obj, idx);
+ return idx;
+ }
+ }
+
+ /**
+ * The objects cannot be replaced.
+ */
+ @Override
+ public void replace(int idx, Object obj) {
+ throw new UnsupportedOperationException();
+ }
+}
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/InternalRow.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/InternalRow.scala
index 61a29c89d8..57de0f26a9 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/InternalRow.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/InternalRow.scala
@@ -28,7 +28,10 @@ import org.apache.spark.unsafe.types.UTF8String
abstract class InternalRow extends Row {
// This is only use for test
- override def getString(i: Int): String = getAs[UTF8String](i).toString
+ override def getString(i: Int): String = {
+ val str = getAs[UTF8String](i)
+ if (str != null) str.toString else null
+ }
// These expensive API should not be used internally.
final override def getDecimal(i: Int): java.math.BigDecimal =
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/UnsafeRowConverter.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/UnsafeRowConverter.scala
index b61d490429..b11fc245c4 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/UnsafeRowConverter.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/UnsafeRowConverter.scala
@@ -17,6 +17,7 @@
package org.apache.spark.sql.catalyst.expressions
+import org.apache.spark.sql.catalyst.util.ObjectPool
import org.apache.spark.sql.types._
import org.apache.spark.unsafe.PlatformDependent
import org.apache.spark.unsafe.array.ByteArrayMethods
@@ -33,6 +34,8 @@ class UnsafeRowConverter(fieldTypes: Array[DataType]) {
this(schema.fields.map(_.dataType))
}
+ def numFields: Int = fieldTypes.length
+
/** Re-used pointer to the unsafe row being written */
private[this] val unsafeRow = new UnsafeRow()
@@ -68,8 +71,8 @@ class UnsafeRowConverter(fieldTypes: Array[DataType]) {
* @param baseOffset the base offset of the destination address
* @return the number of bytes written. This should be equal to `getSizeRequirement(row)`.
*/
- def writeRow(row: InternalRow, baseObject: Object, baseOffset: Long): Int = {
- unsafeRow.pointTo(baseObject, baseOffset, writers.length, null)
+ def writeRow(row: InternalRow, baseObject: Object, baseOffset: Long, pool: ObjectPool): Int = {
+ unsafeRow.pointTo(baseObject, baseOffset, writers.length, pool)
if (writers.length > 0) {
// zero-out the bitset
@@ -84,16 +87,16 @@ class UnsafeRowConverter(fieldTypes: Array[DataType]) {
}
var fieldNumber = 0
- var appendCursor: Int = fixedLengthSize
+ var cursor: Int = fixedLengthSize
while (fieldNumber < writers.length) {
if (row.isNullAt(fieldNumber)) {
unsafeRow.setNullAt(fieldNumber)
} else {
- appendCursor += writers(fieldNumber).write(row, unsafeRow, fieldNumber, appendCursor)
+ cursor += writers(fieldNumber).write(row, unsafeRow, fieldNumber, cursor)
}
fieldNumber += 1
}
- appendCursor
+ cursor
}
}
@@ -108,11 +111,11 @@ private abstract class UnsafeColumnWriter {
* @param source the row being converted
* @param target a pointer to the converted unsafe row
* @param column the column to write
- * @param appendCursor the offset from the start of the unsafe row to the end of the row;
+ * @param cursor the offset from the start of the unsafe row to the end of the row;
* used for calculating where variable-length data should be written
* @return the number of variable-length bytes written
*/
- def write(source: InternalRow, target: UnsafeRow, column: Int, appendCursor: Int): Int
+ def write(source: InternalRow, target: UnsafeRow, column: Int, cursor: Int): Int
/**
* Return the number of bytes that are needed to write this variable-length value.
@@ -134,8 +137,7 @@ private object UnsafeColumnWriter {
case DoubleType => DoubleUnsafeColumnWriter
case StringType => StringUnsafeColumnWriter
case BinaryType => BinaryUnsafeColumnWriter
- case t =>
- throw new UnsupportedOperationException(s"Do not know how to write columns of type $t")
+ case t => ObjectUnsafeColumnWriter
}
}
}
@@ -152,6 +154,7 @@ private object FloatUnsafeColumnWriter extends FloatUnsafeColumnWriter
private object DoubleUnsafeColumnWriter extends DoubleUnsafeColumnWriter
private object StringUnsafeColumnWriter extends StringUnsafeColumnWriter
private object BinaryUnsafeColumnWriter extends BinaryUnsafeColumnWriter
+private object ObjectUnsafeColumnWriter extends ObjectUnsafeColumnWriter
private abstract class PrimitiveUnsafeColumnWriter extends UnsafeColumnWriter {
// Primitives don't write to the variable-length region:
@@ -159,88 +162,56 @@ private abstract class PrimitiveUnsafeColumnWriter extends UnsafeColumnWriter {
}
private class NullUnsafeColumnWriter private() extends PrimitiveUnsafeColumnWriter {
- override def write(
- source: InternalRow,
- target: UnsafeRow,
- column: Int,
- appendCursor: Int): Int = {
+ override def write(source: InternalRow, target: UnsafeRow, column: Int, cursor: Int): Int = {
target.setNullAt(column)
0
}
}
private class BooleanUnsafeColumnWriter private() extends PrimitiveUnsafeColumnWriter {
- override def write(
- source: InternalRow,
- target: UnsafeRow,
- column: Int,
- appendCursor: Int): Int = {
+ override def write(source: InternalRow, target: UnsafeRow, column: Int, cursor: Int): Int = {
target.setBoolean(column, source.getBoolean(column))
0
}
}
private class ByteUnsafeColumnWriter private() extends PrimitiveUnsafeColumnWriter {
- override def write(
- source: InternalRow,
- target: UnsafeRow,
- column: Int,
- appendCursor: Int): Int = {
+ override def write(source: InternalRow, target: UnsafeRow, column: Int, cursor: Int): Int = {
target.setByte(column, source.getByte(column))
0
}
}
private class ShortUnsafeColumnWriter private() extends PrimitiveUnsafeColumnWriter {
- override def write(
- source: InternalRow,
- target: UnsafeRow,
- column: Int,
- appendCursor: Int): Int = {
+ override def write(source: InternalRow, target: UnsafeRow, column: Int, cursor: Int): Int = {
target.setShort(column, source.getShort(column))
0
}
}
private class IntUnsafeColumnWriter private() extends PrimitiveUnsafeColumnWriter {
- override def write(
- source: InternalRow,
- target: UnsafeRow,
- column: Int,
- appendCursor: Int): Int = {
+ override def write(source: InternalRow, target: UnsafeRow, column: Int, cursor: Int): Int = {
target.setInt(column, source.getInt(column))
0
}
}
private class LongUnsafeColumnWriter private() extends PrimitiveUnsafeColumnWriter {
- override def write(
- source: InternalRow,
- target: UnsafeRow,
- column: Int,
- appendCursor: Int): Int = {
+ override def write(source: InternalRow, target: UnsafeRow, column: Int, cursor: Int): Int = {
target.setLong(column, source.getLong(column))
0
}
}
private class FloatUnsafeColumnWriter private() extends PrimitiveUnsafeColumnWriter {
- override def write(
- source: InternalRow,
- target: UnsafeRow,
- column: Int,
- appendCursor: Int): Int = {
+ override def write(source: InternalRow, target: UnsafeRow, column: Int, cursor: Int): Int = {
target.setFloat(column, source.getFloat(column))
0
}
}
private class DoubleUnsafeColumnWriter private() extends PrimitiveUnsafeColumnWriter {
- override def write(
- source: InternalRow,
- target: UnsafeRow,
- column: Int,
- appendCursor: Int): Int = {
+ override def write(source: InternalRow, target: UnsafeRow, column: Int, cursor: Int): Int = {
target.setDouble(column, source.getDouble(column))
0
}
@@ -255,12 +226,10 @@ private abstract class BytesUnsafeColumnWriter extends UnsafeColumnWriter {
ByteArrayMethods.roundNumberOfBytesToNearestWord(numBytes)
}
- override def write(
- source: InternalRow,
- target: UnsafeRow,
- column: Int,
- appendCursor: Int): Int = {
- val offset = target.getBaseOffset + appendCursor
+ protected[this] def isString: Boolean
+
+ override def write(source: InternalRow, target: UnsafeRow, column: Int, cursor: Int): Int = {
+ val offset = target.getBaseOffset + cursor
val bytes = getBytes(source, column)
val numBytes = bytes.length
if ((numBytes & 0x07) > 0) {
@@ -274,19 +243,32 @@ private abstract class BytesUnsafeColumnWriter extends UnsafeColumnWriter {
offset,
numBytes
)
- target.setLong(column, (appendCursor.toLong << 32L) | numBytes.toLong)
+ val flag = if (isString) 1L << (UnsafeRow.OFFSET_BITS * 2) else 0
+ target.setLong(column, flag | (cursor.toLong << UnsafeRow.OFFSET_BITS) | numBytes.toLong)
ByteArrayMethods.roundNumberOfBytesToNearestWord(numBytes)
}
}
private class StringUnsafeColumnWriter private() extends BytesUnsafeColumnWriter {
+ protected[this] def isString: Boolean = true
def getBytes(source: InternalRow, column: Int): Array[Byte] = {
source.getAs[UTF8String](column).getBytes
}
}
private class BinaryUnsafeColumnWriter private() extends BytesUnsafeColumnWriter {
+ protected[this] def isString: Boolean = false
def getBytes(source: InternalRow, column: Int): Array[Byte] = {
source.getAs[Array[Byte]](column)
}
}
+
+private class ObjectUnsafeColumnWriter private() extends UnsafeColumnWriter {
+ def getSize(sourceRow: InternalRow, column: Int): Int = 0
+ override def write(source: InternalRow, target: UnsafeRow, column: Int, cursor: Int): Int = {
+ val obj = source.get(column)
+ val idx = target.getPool.put(obj)
+ target.setLong(column, - idx)
+ 0
+ }
+}
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/UnsafeFixedWidthAggregationMapSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/UnsafeFixedWidthAggregationMapSuite.scala
index 3095ccb777..6fafc2f866 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/UnsafeFixedWidthAggregationMapSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/UnsafeFixedWidthAggregationMapSuite.scala
@@ -23,8 +23,9 @@ import scala.util.Random
import org.scalatest.{BeforeAndAfterEach, Matchers}
import org.apache.spark.SparkFunSuite
+import org.apache.spark.sql.catalyst.expressions.codegen.GenerateProjection
import org.apache.spark.sql.types._
-import org.apache.spark.unsafe.memory.{ExecutorMemoryManager, TaskMemoryManager, MemoryAllocator}
+import org.apache.spark.unsafe.memory.{ExecutorMemoryManager, MemoryAllocator, TaskMemoryManager}
import org.apache.spark.unsafe.types.UTF8String
@@ -33,10 +34,10 @@ class UnsafeFixedWidthAggregationMapSuite
with Matchers
with BeforeAndAfterEach {
- import UnsafeFixedWidthAggregationMap._
-
private val groupKeySchema = StructType(StructField("product", StringType) :: Nil)
private val aggBufferSchema = StructType(StructField("salePrice", IntegerType) :: Nil)
+ private def emptyProjection: Projection =
+ GenerateProjection.generate(Seq(Literal(0)), Seq(AttributeReference("price", IntegerType)()))
private def emptyAggregationBuffer: InternalRow = InternalRow(0)
private var memoryManager: TaskMemoryManager = null
@@ -52,21 +53,11 @@ class UnsafeFixedWidthAggregationMapSuite
}
}
- test("supported schemas") {
- assert(!supportsAggregationBufferSchema(StructType(StructField("x", StringType) :: Nil)))
- assert(supportsGroupKeySchema(StructType(StructField("x", StringType) :: Nil)))
-
- assert(
- !supportsAggregationBufferSchema(StructType(StructField("x", ArrayType(IntegerType)) :: Nil)))
- assert(
- !supportsGroupKeySchema(StructType(StructField("x", ArrayType(IntegerType)) :: Nil)))
- }
-
test("empty map") {
val map = new UnsafeFixedWidthAggregationMap(
- emptyAggregationBuffer,
- aggBufferSchema,
- groupKeySchema,
+ emptyProjection,
+ new UnsafeRowConverter(groupKeySchema),
+ new UnsafeRowConverter(aggBufferSchema),
memoryManager,
1024, // initial capacity
false // disable perf metrics
@@ -77,9 +68,9 @@ class UnsafeFixedWidthAggregationMapSuite
test("updating values for a single key") {
val map = new UnsafeFixedWidthAggregationMap(
- emptyAggregationBuffer,
- aggBufferSchema,
- groupKeySchema,
+ emptyProjection,
+ new UnsafeRowConverter(groupKeySchema),
+ new UnsafeRowConverter(aggBufferSchema),
memoryManager,
1024, // initial capacity
false // disable perf metrics
@@ -103,9 +94,9 @@ class UnsafeFixedWidthAggregationMapSuite
test("inserting large random keys") {
val map = new UnsafeFixedWidthAggregationMap(
- emptyAggregationBuffer,
- aggBufferSchema,
- groupKeySchema,
+ emptyProjection,
+ new UnsafeRowConverter(groupKeySchema),
+ new UnsafeRowConverter(aggBufferSchema),
memoryManager,
128, // initial capacity
false // disable perf metrics
@@ -120,6 +111,36 @@ class UnsafeFixedWidthAggregationMapSuite
}.toSet
seenKeys.size should be (groupKeys.size)
seenKeys should be (groupKeys)
+
+ map.free()
+ }
+
+ test("with decimal in the key and values") {
+ val groupKeySchema = StructType(StructField("price", DecimalType(10, 0)) :: Nil)
+ val aggBufferSchema = StructType(StructField("amount", DecimalType.Unlimited) :: Nil)
+ val emptyProjection = GenerateProjection.generate(Seq(Literal(Decimal(0))),
+ Seq(AttributeReference("price", DecimalType.Unlimited)()))
+ val map = new UnsafeFixedWidthAggregationMap(
+ emptyProjection,
+ new UnsafeRowConverter(groupKeySchema),
+ new UnsafeRowConverter(aggBufferSchema),
+ memoryManager,
+ 1, // initial capacity
+ false // disable perf metrics
+ )
+
+ (0 until 100).foreach { i =>
+ val groupKey = InternalRow(Decimal(i % 10))
+ val row = map.getAggregationBuffer(groupKey)
+ row.update(0, Decimal(i))
+ }
+ val seenKeys: Set[Int] = map.iterator().asScala.map { entry =>
+ entry.key.getAs[Decimal](0).toInt
+ }.toSet
+ seenKeys.size should be (10)
+ seenKeys should be ((0 until 10).toSet)
+
+ map.free()
}
}
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/UnsafeRowConverterSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/UnsafeRowConverterSuite.scala
index c0675f4f4d..94c2f3242b 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/UnsafeRowConverterSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/UnsafeRowConverterSuite.scala
@@ -23,10 +23,11 @@ import java.util.Arrays
import org.scalatest.Matchers
import org.apache.spark.SparkFunSuite
-import org.apache.spark.sql.catalyst.util.DateTimeUtils
+import org.apache.spark.sql.catalyst.util.{ObjectPool, DateTimeUtils}
import org.apache.spark.sql.types._
import org.apache.spark.unsafe.PlatformDependent
import org.apache.spark.unsafe.array.ByteArrayMethods
+import org.apache.spark.unsafe.types.UTF8String
class UnsafeRowConverterSuite extends SparkFunSuite with Matchers {
@@ -40,16 +41,21 @@ class UnsafeRowConverterSuite extends SparkFunSuite with Matchers {
row.setInt(2, 2)
val sizeRequired: Int = converter.getSizeRequirement(row)
- sizeRequired should be (8 + (3 * 8))
+ assert(sizeRequired === 8 + (3 * 8))
val buffer: Array[Long] = new Array[Long](sizeRequired / 8)
- val numBytesWritten = converter.writeRow(row, buffer, PlatformDependent.LONG_ARRAY_OFFSET)
- numBytesWritten should be (sizeRequired)
+ val numBytesWritten = converter.writeRow(row, buffer, PlatformDependent.LONG_ARRAY_OFFSET, null)
+ assert(numBytesWritten === sizeRequired)
val unsafeRow = new UnsafeRow()
unsafeRow.pointTo(buffer, PlatformDependent.LONG_ARRAY_OFFSET, fieldTypes.length, null)
- unsafeRow.getLong(0) should be (0)
- unsafeRow.getLong(1) should be (1)
- unsafeRow.getInt(2) should be (2)
+ assert(unsafeRow.getLong(0) === 0)
+ assert(unsafeRow.getLong(1) === 1)
+ assert(unsafeRow.getInt(2) === 2)
+
+ unsafeRow.setLong(1, 3)
+ assert(unsafeRow.getLong(1) === 3)
+ unsafeRow.setInt(2, 4)
+ assert(unsafeRow.getInt(2) === 4)
}
test("basic conversion with primitive, string and binary types") {
@@ -58,22 +64,67 @@ class UnsafeRowConverterSuite extends SparkFunSuite with Matchers {
val row = new SpecificMutableRow(fieldTypes)
row.setLong(0, 0)
- row.setString(1, "Hello")
+ row.update(1, UTF8String.fromString("Hello"))
row.update(2, "World".getBytes)
val sizeRequired: Int = converter.getSizeRequirement(row)
- sizeRequired should be (8 + (8 * 3) +
+ assert(sizeRequired === 8 + (8 * 3) +
ByteArrayMethods.roundNumberOfBytesToNearestWord("Hello".getBytes.length) +
ByteArrayMethods.roundNumberOfBytesToNearestWord("World".getBytes.length))
val buffer: Array[Long] = new Array[Long](sizeRequired / 8)
- val numBytesWritten = converter.writeRow(row, buffer, PlatformDependent.LONG_ARRAY_OFFSET)
- numBytesWritten should be (sizeRequired)
+ val numBytesWritten = converter.writeRow(row, buffer, PlatformDependent.LONG_ARRAY_OFFSET, null)
+ assert(numBytesWritten === sizeRequired)
val unsafeRow = new UnsafeRow()
- unsafeRow.pointTo(buffer, PlatformDependent.LONG_ARRAY_OFFSET, fieldTypes.length, null)
- unsafeRow.getLong(0) should be (0)
- unsafeRow.getString(1) should be ("Hello")
- unsafeRow.getBinary(2) should be ("World".getBytes)
+ val pool = new ObjectPool(10)
+ unsafeRow.pointTo(buffer, PlatformDependent.LONG_ARRAY_OFFSET, fieldTypes.length, pool)
+ assert(unsafeRow.getLong(0) === 0)
+ assert(unsafeRow.getString(1) === "Hello")
+ assert(unsafeRow.get(2) === "World".getBytes)
+
+ unsafeRow.update(1, UTF8String.fromString("World"))
+ assert(unsafeRow.getString(1) === "World")
+ assert(pool.size === 0)
+ unsafeRow.update(1, UTF8String.fromString("Hello World"))
+ assert(unsafeRow.getString(1) === "Hello World")
+ assert(pool.size === 1)
+
+ unsafeRow.update(2, "World".getBytes)
+ assert(unsafeRow.get(2) === "World".getBytes)
+ assert(pool.size === 1)
+ unsafeRow.update(2, "Hello World".getBytes)
+ assert(unsafeRow.get(2) === "Hello World".getBytes)
+ assert(pool.size === 2)
+ }
+
+ test("basic conversion with primitive, decimal and array") {
+ val fieldTypes: Array[DataType] = Array(LongType, DecimalType(10, 0), ArrayType(StringType))
+ val converter = new UnsafeRowConverter(fieldTypes)
+
+ val row = new SpecificMutableRow(fieldTypes)
+ row.setLong(0, 0)
+ row.update(1, Decimal(1))
+ row.update(2, Array(2))
+
+ val pool = new ObjectPool(10)
+ val sizeRequired: Int = converter.getSizeRequirement(row)
+ assert(sizeRequired === 8 + (8 * 3))
+ val buffer: Array[Long] = new Array[Long](sizeRequired / 8)
+ val numBytesWritten = converter.writeRow(row, buffer, PlatformDependent.LONG_ARRAY_OFFSET, pool)
+ assert(numBytesWritten === sizeRequired)
+ assert(pool.size === 2)
+
+ val unsafeRow = new UnsafeRow()
+ unsafeRow.pointTo(buffer, PlatformDependent.LONG_ARRAY_OFFSET, fieldTypes.length, pool)
+ assert(unsafeRow.getLong(0) === 0)
+ assert(unsafeRow.get(1) === Decimal(1))
+ assert(unsafeRow.get(2) === Array(2))
+
+ unsafeRow.update(1, Decimal(2))
+ assert(unsafeRow.get(1) === Decimal(2))
+ unsafeRow.update(2, Array(3, 4))
+ assert(unsafeRow.get(2) === Array(3, 4))
+ assert(pool.size === 2)
}
test("basic conversion with primitive, string, date and timestamp types") {
@@ -87,21 +138,27 @@ class UnsafeRowConverterSuite extends SparkFunSuite with Matchers {
row.update(3, DateTimeUtils.fromJavaTimestamp(Timestamp.valueOf("2015-05-08 08:10:25")))
val sizeRequired: Int = converter.getSizeRequirement(row)
- sizeRequired should be (8 + (8 * 4) +
+ assert(sizeRequired === 8 + (8 * 4) +
ByteArrayMethods.roundNumberOfBytesToNearestWord("Hello".getBytes.length))
val buffer: Array[Long] = new Array[Long](sizeRequired / 8)
- val numBytesWritten = converter.writeRow(row, buffer, PlatformDependent.LONG_ARRAY_OFFSET)
- numBytesWritten should be (sizeRequired)
+ val numBytesWritten = converter.writeRow(row, buffer, PlatformDependent.LONG_ARRAY_OFFSET, null)
+ assert(numBytesWritten === sizeRequired)
val unsafeRow = new UnsafeRow()
unsafeRow.pointTo(buffer, PlatformDependent.LONG_ARRAY_OFFSET, fieldTypes.length, null)
- unsafeRow.getLong(0) should be (0)
- unsafeRow.getString(1) should be ("Hello")
+ assert(unsafeRow.getLong(0) === 0)
+ assert(unsafeRow.getString(1) === "Hello")
// Date is represented as Int in unsafeRow
- DateTimeUtils.toJavaDate(unsafeRow.getInt(2)) should be (Date.valueOf("1970-01-01"))
+ assert(DateTimeUtils.toJavaDate(unsafeRow.getInt(2)) === Date.valueOf("1970-01-01"))
// Timestamp is represented as Long in unsafeRow
DateTimeUtils.toJavaTimestamp(unsafeRow.getLong(3)) should be
(Timestamp.valueOf("2015-05-08 08:10:25"))
+
+ unsafeRow.setInt(2, DateTimeUtils.fromJavaDate(Date.valueOf("2015-06-22")))
+ assert(DateTimeUtils.toJavaDate(unsafeRow.getInt(2)) === Date.valueOf("2015-06-22"))
+ unsafeRow.setLong(3, DateTimeUtils.fromJavaTimestamp(Timestamp.valueOf("2015-06-22 08:10:25")))
+ DateTimeUtils.toJavaTimestamp(unsafeRow.getLong(3)) should be
+ (Timestamp.valueOf("2015-06-22 08:10:25"))
}
test("null handling") {
@@ -113,7 +170,12 @@ class UnsafeRowConverterSuite extends SparkFunSuite with Matchers {
IntegerType,
LongType,
FloatType,
- DoubleType)
+ DoubleType,
+ StringType,
+ BinaryType,
+ DecimalType.Unlimited,
+ ArrayType(IntegerType)
+ )
val converter = new UnsafeRowConverter(fieldTypes)
val rowWithAllNullColumns: InternalRow = {
@@ -127,8 +189,8 @@ class UnsafeRowConverterSuite extends SparkFunSuite with Matchers {
val sizeRequired: Int = converter.getSizeRequirement(rowWithAllNullColumns)
val createdFromNullBuffer: Array[Long] = new Array[Long](sizeRequired / 8)
val numBytesWritten = converter.writeRow(
- rowWithAllNullColumns, createdFromNullBuffer, PlatformDependent.LONG_ARRAY_OFFSET)
- numBytesWritten should be (sizeRequired)
+ rowWithAllNullColumns, createdFromNullBuffer, PlatformDependent.LONG_ARRAY_OFFSET, null)
+ assert(numBytesWritten === sizeRequired)
val createdFromNull = new UnsafeRow()
createdFromNull.pointTo(
@@ -136,13 +198,17 @@ class UnsafeRowConverterSuite extends SparkFunSuite with Matchers {
for (i <- 0 to fieldTypes.length - 1) {
assert(createdFromNull.isNullAt(i))
}
- createdFromNull.getBoolean(1) should be (false)
- createdFromNull.getByte(2) should be (0)
- createdFromNull.getShort(3) should be (0)
- createdFromNull.getInt(4) should be (0)
- createdFromNull.getLong(5) should be (0)
+ assert(createdFromNull.getBoolean(1) === false)
+ assert(createdFromNull.getByte(2) === 0)
+ assert(createdFromNull.getShort(3) === 0)
+ assert(createdFromNull.getInt(4) === 0)
+ assert(createdFromNull.getLong(5) === 0)
assert(java.lang.Float.isNaN(createdFromNull.getFloat(6)))
- assert(java.lang.Double.isNaN(createdFromNull.getFloat(7)))
+ assert(java.lang.Double.isNaN(createdFromNull.getDouble(7)))
+ assert(createdFromNull.getString(8) === null)
+ assert(createdFromNull.get(9) === null)
+ assert(createdFromNull.get(10) === null)
+ assert(createdFromNull.get(11) === null)
// If we have an UnsafeRow with columns that are initially non-null and we null out those
// columns, then the serialized row representation should be identical to what we would get by
@@ -157,28 +223,68 @@ class UnsafeRowConverterSuite extends SparkFunSuite with Matchers {
r.setLong(5, 500)
r.setFloat(6, 600)
r.setDouble(7, 700)
+ r.update(8, UTF8String.fromString("hello"))
+ r.update(9, "world".getBytes)
+ r.update(10, Decimal(10))
+ r.update(11, Array(11))
r
}
- val setToNullAfterCreationBuffer: Array[Long] = new Array[Long](sizeRequired / 8)
+ val pool = new ObjectPool(1)
+ val setToNullAfterCreationBuffer: Array[Long] = new Array[Long](sizeRequired / 8 + 2)
converter.writeRow(
- rowWithNoNullColumns, setToNullAfterCreationBuffer, PlatformDependent.LONG_ARRAY_OFFSET)
+ rowWithNoNullColumns, setToNullAfterCreationBuffer, PlatformDependent.LONG_ARRAY_OFFSET, pool)
val setToNullAfterCreation = new UnsafeRow()
setToNullAfterCreation.pointTo(
- setToNullAfterCreationBuffer, PlatformDependent.LONG_ARRAY_OFFSET, fieldTypes.length, null)
+ setToNullAfterCreationBuffer, PlatformDependent.LONG_ARRAY_OFFSET, fieldTypes.length, pool)
- setToNullAfterCreation.isNullAt(0) should be (rowWithNoNullColumns.isNullAt(0))
- setToNullAfterCreation.getBoolean(1) should be (rowWithNoNullColumns.getBoolean(1))
- setToNullAfterCreation.getByte(2) should be (rowWithNoNullColumns.getByte(2))
- setToNullAfterCreation.getShort(3) should be (rowWithNoNullColumns.getShort(3))
- setToNullAfterCreation.getInt(4) should be (rowWithNoNullColumns.getInt(4))
- setToNullAfterCreation.getLong(5) should be (rowWithNoNullColumns.getLong(5))
- setToNullAfterCreation.getFloat(6) should be (rowWithNoNullColumns.getFloat(6))
- setToNullAfterCreation.getDouble(7) should be (rowWithNoNullColumns.getDouble(7))
+ assert(setToNullAfterCreation.isNullAt(0) === rowWithNoNullColumns.isNullAt(0))
+ assert(setToNullAfterCreation.getBoolean(1) === rowWithNoNullColumns.getBoolean(1))
+ assert(setToNullAfterCreation.getByte(2) === rowWithNoNullColumns.getByte(2))
+ assert(setToNullAfterCreation.getShort(3) === rowWithNoNullColumns.getShort(3))
+ assert(setToNullAfterCreation.getInt(4) === rowWithNoNullColumns.getInt(4))
+ assert(setToNullAfterCreation.getLong(5) === rowWithNoNullColumns.getLong(5))
+ assert(setToNullAfterCreation.getFloat(6) === rowWithNoNullColumns.getFloat(6))
+ assert(setToNullAfterCreation.getDouble(7) === rowWithNoNullColumns.getDouble(7))
+ assert(setToNullAfterCreation.getString(8) === rowWithNoNullColumns.getString(8))
+ assert(setToNullAfterCreation.get(9) === rowWithNoNullColumns.get(9))
+ assert(setToNullAfterCreation.get(10) === rowWithNoNullColumns.get(10))
+ assert(setToNullAfterCreation.get(11) === rowWithNoNullColumns.get(11))
for (i <- 0 to fieldTypes.length - 1) {
+ if (i >= 8) {
+ setToNullAfterCreation.update(i, null)
+ }
setToNullAfterCreation.setNullAt(i)
}
- assert(Arrays.equals(createdFromNullBuffer, setToNullAfterCreationBuffer))
+ // There are some garbage left in the var-length area
+ assert(Arrays.equals(createdFromNullBuffer,
+ java.util.Arrays.copyOf(setToNullAfterCreationBuffer, sizeRequired / 8)))
+
+ setToNullAfterCreation.setNullAt(0)
+ setToNullAfterCreation.setBoolean(1, false)
+ setToNullAfterCreation.setByte(2, 20)
+ setToNullAfterCreation.setShort(3, 30)
+ setToNullAfterCreation.setInt(4, 400)
+ setToNullAfterCreation.setLong(5, 500)
+ setToNullAfterCreation.setFloat(6, 600)
+ setToNullAfterCreation.setDouble(7, 700)
+ setToNullAfterCreation.update(8, UTF8String.fromString("hello"))
+ setToNullAfterCreation.update(9, "world".getBytes)
+ setToNullAfterCreation.update(10, Decimal(10))
+ setToNullAfterCreation.update(11, Array(11))
+
+ assert(setToNullAfterCreation.isNullAt(0) === rowWithNoNullColumns.isNullAt(0))
+ assert(setToNullAfterCreation.getBoolean(1) === rowWithNoNullColumns.getBoolean(1))
+ assert(setToNullAfterCreation.getByte(2) === rowWithNoNullColumns.getByte(2))
+ assert(setToNullAfterCreation.getShort(3) === rowWithNoNullColumns.getShort(3))
+ assert(setToNullAfterCreation.getInt(4) === rowWithNoNullColumns.getInt(4))
+ assert(setToNullAfterCreation.getLong(5) === rowWithNoNullColumns.getLong(5))
+ assert(setToNullAfterCreation.getFloat(6) === rowWithNoNullColumns.getFloat(6))
+ assert(setToNullAfterCreation.getDouble(7) === rowWithNoNullColumns.getDouble(7))
+ assert(setToNullAfterCreation.getString(8) === rowWithNoNullColumns.getString(8))
+ assert(setToNullAfterCreation.get(9) === rowWithNoNullColumns.get(9))
+ assert(setToNullAfterCreation.get(10) === rowWithNoNullColumns.get(10))
+ assert(setToNullAfterCreation.get(11) === rowWithNoNullColumns.get(11))
}
}
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/util/ObjectPoolSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/util/ObjectPoolSuite.scala
new file mode 100644
index 0000000000..94764df4b9
--- /dev/null
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/util/ObjectPoolSuite.scala
@@ -0,0 +1,57 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.catalyst.util
+
+import org.scalatest.Matchers
+
+import org.apache.spark.SparkFunSuite
+
+class ObjectPoolSuite extends SparkFunSuite with Matchers {
+
+ test("pool") {
+ val pool = new ObjectPool(1)
+ assert(pool.put(1) === 0)
+ assert(pool.put("hello") === 1)
+ assert(pool.put(false) === 2)
+
+ assert(pool.get(0) === 1)
+ assert(pool.get(1) === "hello")
+ assert(pool.get(2) === false)
+ assert(pool.size() === 3)
+
+ pool.replace(1, "world")
+ assert(pool.get(1) === "world")
+ assert(pool.size() === 3)
+ }
+
+ test("unique pool") {
+ val pool = new UniqueObjectPool(1)
+ assert(pool.put(1) === 0)
+ assert(pool.put("hello") === 1)
+ assert(pool.put(1) === 0)
+ assert(pool.put("hello") === 1)
+
+ assert(pool.get(0) === 1)
+ assert(pool.get(1) === "hello")
+ assert(pool.size() === 2)
+
+ intercept[UnsupportedOperationException] {
+ pool.replace(1, "world")
+ }
+ }
+}
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/GeneratedAggregate.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/GeneratedAggregate.scala
index ba2c8f53d7..44930f82b5 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/GeneratedAggregate.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/GeneratedAggregate.scala
@@ -238,11 +238,6 @@ case class GeneratedAggregate(
StructType(fields)
}
- val schemaSupportsUnsafe: Boolean = {
- UnsafeFixedWidthAggregationMap.supportsAggregationBufferSchema(aggregationBufferSchema) &&
- UnsafeFixedWidthAggregationMap.supportsGroupKeySchema(groupKeySchema)
- }
-
child.execute().mapPartitions { iter =>
// Builds a new custom class for holding the results of aggregation for a group.
val initialValues = computeFunctions.flatMap(_.initialValues)
@@ -283,12 +278,12 @@ case class GeneratedAggregate(
val resultProjection = resultProjectionBuilder()
Iterator(resultProjection(buffer))
- } else if (unsafeEnabled && schemaSupportsUnsafe) {
+ } else if (unsafeEnabled) {
log.info("Using Unsafe-based aggregator")
val aggregationMap = new UnsafeFixedWidthAggregationMap(
- newAggregationBuffer(EmptyRow),
- aggregationBufferSchema,
- groupKeySchema,
+ newAggregationBuffer,
+ new UnsafeRowConverter(groupKeySchema),
+ new UnsafeRowConverter(aggregationBufferSchema),
TaskContext.get.taskMemoryManager(),
1024 * 16, // initial capacity
false // disable tracking of performance metrics
@@ -323,9 +318,6 @@ case class GeneratedAggregate(
}
}
} else {
- if (unsafeEnabled) {
- log.info("Not using Unsafe-based aggregator because it is not supported for this schema")
- }
val buffers = new java.util.HashMap[InternalRow, MutableRow]()
var currentRow: InternalRow = null