aboutsummaryrefslogtreecommitdiff
path: root/sql/catalyst
diff options
context:
space:
mode:
authorKazuaki Ishizaki <ishizaki@jp.ibm.com>2016-09-27 14:18:32 +0800
committerWenchen Fan <wenchen@databricks.com>2016-09-27 14:18:32 +0800
commit85b0a157543201895557d66306b38b3ca52f2151 (patch)
treecfeae0712fd72172e7d0fa1c16f50df32e7777f6 /sql/catalyst
parent6ee28423ad1b2e6089b82af64a31d77d3552bb38 (diff)
downloadspark-85b0a157543201895557d66306b38b3ca52f2151.tar.gz
spark-85b0a157543201895557d66306b38b3ca52f2151.tar.bz2
spark-85b0a157543201895557d66306b38b3ca52f2151.zip
[SPARK-15962][SQL] Introduce implementation with a dense format for UnsafeArrayData
## What changes were proposed in this pull request? This PR introduces more compact representation for ```UnsafeArrayData```. ```UnsafeArrayData``` needs to accept ```null``` value in each entry of an array. In the current version, it has three parts ``` [numElements] [offsets] [values] ``` `Offsets` has the number of `numElements`, and represents `null` if its value is negative. It may increase memory footprint, and introduces an indirection for accessing each of `values`. This PR uses bitvectors to represent nullability for each element like `UnsafeRow`, and eliminates an indirection for accessing each element. The new ```UnsafeArrayData``` has four parts. ``` [numElements][null bits][values or offset&length][variable length portion] ``` In the `null bits` region, we store 1 bit per element, represents whether an element is null. Its total size is ceil(numElements / 8) bytes, and it is aligned to 8-byte boundaries. In the `values or offset&length` region, we store the content of elements. For fields that hold fixed-length primitive types, such as long, double, or int, we store the value directly in the field. For fields with non-primitive or variable-length values, we store a relative offset (w.r.t. the base address of the array) that points to the beginning of the variable-length field and length (they are combined into a long). Each is word-aligned. For `variable length portion`, each is aligned to 8-byte boundaries. The new format can reduce memory footprint and improve performance of accessing each element. An example of memory foot comparison: 1024x1024 elements integer array Size of ```baseObject``` for ```UnsafeArrayData```: 8 + 1024x1024 + 1024x1024 = 2M bytes Size of ```baseObject``` for ```UnsafeArrayData```: 8 + 1024x1024/8 + 1024x1024 = 1.25M bytes In summary, we got 1.0-2.6x performance improvements over the code before applying this PR. Here are performance results of [benchmark programs](https://github.com/kiszk/spark/blob/04d2e4b6dbdc4eff43ce18b3c9b776e0129257c7/sql/core/src/test/scala/org/apache/spark/sql/execution/benchmark/UnsafeArrayDataBenchmark.scala): **Read UnsafeArrayData**: 1.7x and 1.6x performance improvements over the code before applying this PR ```` OpenJDK 64-Bit Server VM 1.8.0_91-b14 on Linux 4.4.11-200.fc22.x86_64 Intel Xeon E3-12xx v2 (Ivy Bridge) Without SPARK-15962 Read UnsafeArrayData: Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative ------------------------------------------------------------------------------------------------ Int 430 / 436 390.0 2.6 1.0X Double 456 / 485 367.8 2.7 0.9X With SPARK-15962 Read UnsafeArrayData: Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative ------------------------------------------------------------------------------------------------ Int 252 / 260 666.1 1.5 1.0X Double 281 / 292 597.7 1.7 0.9X ```` **Write UnsafeArrayData**: 1.0x and 1.1x performance improvements over the code before applying this PR ```` OpenJDK 64-Bit Server VM 1.8.0_91-b14 on Linux 4.0.4-301.fc22.x86_64 Intel Xeon E3-12xx v2 (Ivy Bridge) Without SPARK-15962 Write UnsafeArrayData: Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative ------------------------------------------------------------------------------------------------ Int 203 / 273 103.4 9.7 1.0X Double 239 / 356 87.9 11.4 0.8X With SPARK-15962 Write UnsafeArrayData: Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative ------------------------------------------------------------------------------------------------ Int 196 / 249 107.0 9.3 1.0X Double 227 / 367 92.3 10.8 0.9X ```` **Get primitive array from UnsafeArrayData**: 2.6x and 1.6x performance improvements over the code before applying this PR ```` OpenJDK 64-Bit Server VM 1.8.0_91-b14 on Linux 4.0.4-301.fc22.x86_64 Intel Xeon E3-12xx v2 (Ivy Bridge) Without SPARK-15962 Get primitive array from UnsafeArrayData: Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative ------------------------------------------------------------------------------------------------ Int 207 / 217 304.2 3.3 1.0X Double 257 / 363 245.2 4.1 0.8X With SPARK-15962 Get primitive array from UnsafeArrayData: Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative ------------------------------------------------------------------------------------------------ Int 151 / 198 415.8 2.4 1.0X Double 214 / 394 293.6 3.4 0.7X ```` **Create UnsafeArrayData from primitive array**: 1.7x and 2.1x performance improvements over the code before applying this PR ```` OpenJDK 64-Bit Server VM 1.8.0_91-b14 on Linux 4.0.4-301.fc22.x86_64 Intel Xeon E3-12xx v2 (Ivy Bridge) Without SPARK-15962 Create UnsafeArrayData from primitive array: Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative ------------------------------------------------------------------------------------------------ Int 340 / 385 185.1 5.4 1.0X Double 479 / 705 131.3 7.6 0.7X With SPARK-15962 Create UnsafeArrayData from primitive array: Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative ------------------------------------------------------------------------------------------------ Int 206 / 211 306.0 3.3 1.0X Double 232 / 406 271.6 3.7 0.9X ```` 1.7x and 1.4x performance improvements in [```UDTSerializationBenchmark```](https://github.com/apache/spark/blob/master/mllib/src/test/scala/org/apache/spark/mllib/linalg/UDTSerializationBenchmark.scala) over the code before applying this PR ```` OpenJDK 64-Bit Server VM 1.8.0_91-b14 on Linux 4.4.11-200.fc22.x86_64 Intel Xeon E3-12xx v2 (Ivy Bridge) Without SPARK-15962 VectorUDT de/serialization: Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative ------------------------------------------------------------------------------------------------ serialize 442 / 533 0.0 441927.1 1.0X deserialize 217 / 274 0.0 217087.6 2.0X With SPARK-15962 VectorUDT de/serialization: Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative ------------------------------------------------------------------------------------------------ serialize 265 / 318 0.0 265138.5 1.0X deserialize 155 / 197 0.0 154611.4 1.7X ```` ## How was this patch tested? Added unit tests into ```UnsafeArraySuite``` Author: Kazuaki Ishizaki <ishizaki@jp.ibm.com> Closes #13680 from kiszk/SPARK-15962.
Diffstat (limited to 'sql/catalyst')
-rw-r--r--sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeArrayData.java269
-rw-r--r--sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeMapData.java13
-rw-r--r--sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/codegen/UnsafeArrayWriter.java193
-rw-r--r--sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeProjection.scala31
-rw-r--r--sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/UnsafeRowConverterSuite.scala23
-rw-r--r--sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/util/UnsafeArraySuite.scala195
6 files changed, 504 insertions, 220 deletions
diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeArrayData.java b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeArrayData.java
index 6302660548..86523c1474 100644
--- a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeArrayData.java
+++ b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeArrayData.java
@@ -25,6 +25,7 @@ import org.apache.spark.sql.catalyst.util.ArrayData;
import org.apache.spark.sql.types.*;
import org.apache.spark.unsafe.Platform;
import org.apache.spark.unsafe.array.ByteArrayMethods;
+import org.apache.spark.unsafe.bitset.BitSetMethods;
import org.apache.spark.unsafe.hash.Murmur3_x86_32;
import org.apache.spark.unsafe.types.CalendarInterval;
import org.apache.spark.unsafe.types.UTF8String;
@@ -32,23 +33,31 @@ import org.apache.spark.unsafe.types.UTF8String;
/**
* An Unsafe implementation of Array which is backed by raw memory instead of Java objects.
*
- * Each tuple has three parts: [numElements] [offsets] [values]
+ * Each array has four parts:
+ * [numElements][null bits][values or offset&length][variable length portion]
*
- * The `numElements` is 4 bytes storing the number of elements of this array.
+ * The `numElements` is 8 bytes storing the number of elements of this array.
*
- * In the `offsets` region, we store 4 bytes per element, represents the relative offset (w.r.t. the
- * base address of the array) of this element in `values` region. We can get the length of this
- * element by subtracting next offset.
- * Note that offset can by negative which means this element is null.
+ * In the `null bits` region, we store 1 bit per element, represents whether an element is null
+ * Its total size is ceil(numElements / 8) bytes, and it is aligned to 8-byte boundaries.
*
- * In the `values` region, we store the content of elements. As we can get length info, so elements
- * can be variable-length.
+ * In the `values or offset&length` region, we store the content of elements. For fields that hold
+ * fixed-length primitive types, such as long, double, or int, we store the value directly
+ * in the field. The whole fixed-length portion (even for byte) is aligned to 8-byte boundaries.
+ * For fields with non-primitive or variable-length values, we store a relative offset
+ * (w.r.t. the base address of the array) that points to the beginning of the variable-length field
+ * and length (they are combined into a long). For variable length portion, each is aligned
+ * to 8-byte boundaries.
*
* Instances of `UnsafeArrayData` act as pointers to row data stored in this format.
*/
-// todo: there is a lof of duplicated code between UnsafeRow and UnsafeArrayData.
+
public final class UnsafeArrayData extends ArrayData {
+ public static int calculateHeaderPortionInBytes(int numFields) {
+ return 8 + ((numFields + 63)/ 64) * 8;
+ }
+
private Object baseObject;
private long baseOffset;
@@ -56,24 +65,19 @@ public final class UnsafeArrayData extends ArrayData {
private int numElements;
// The size of this array's backing data, in bytes.
- // The 4-bytes header of `numElements` is also included.
+ // The 8-bytes header of `numElements` is also included.
private int sizeInBytes;
- public Object getBaseObject() { return baseObject; }
- public long getBaseOffset() { return baseOffset; }
- public int getSizeInBytes() { return sizeInBytes; }
+ /** The position to start storing array elements, */
+ private long elementOffset;
- private int getElementOffset(int ordinal) {
- return Platform.getInt(baseObject, baseOffset + 4 + ordinal * 4L);
+ private long getElementOffset(int ordinal, int elementSize) {
+ return elementOffset + ordinal * elementSize;
}
- private int getElementSize(int offset, int ordinal) {
- if (ordinal == numElements - 1) {
- return sizeInBytes - offset;
- } else {
- return Math.abs(getElementOffset(ordinal + 1)) - offset;
- }
- }
+ public Object getBaseObject() { return baseObject; }
+ public long getBaseOffset() { return baseOffset; }
+ public int getSizeInBytes() { return sizeInBytes; }
private void assertIndexIsValid(int ordinal) {
assert ordinal >= 0 : "ordinal (" + ordinal + ") should >= 0";
@@ -102,20 +106,22 @@ public final class UnsafeArrayData extends ArrayData {
* @param sizeInBytes the size of this array's backing data, in bytes
*/
public void pointTo(Object baseObject, long baseOffset, int sizeInBytes) {
- // Read the number of elements from the first 4 bytes.
- final int numElements = Platform.getInt(baseObject, baseOffset);
+ // Read the number of elements from the first 8 bytes.
+ final long numElements = Platform.getLong(baseObject, baseOffset);
assert numElements >= 0 : "numElements (" + numElements + ") should >= 0";
+ assert numElements <= Integer.MAX_VALUE : "numElements (" + numElements + ") should <= Integer.MAX_VALUE";
- this.numElements = numElements;
+ this.numElements = (int)numElements;
this.baseObject = baseObject;
this.baseOffset = baseOffset;
this.sizeInBytes = sizeInBytes;
+ this.elementOffset = baseOffset + calculateHeaderPortionInBytes(this.numElements);
}
@Override
public boolean isNullAt(int ordinal) {
assertIndexIsValid(ordinal);
- return getElementOffset(ordinal) < 0;
+ return BitSetMethods.isSet(baseObject, baseOffset + 8, ordinal);
}
@Override
@@ -165,68 +171,50 @@ public final class UnsafeArrayData extends ArrayData {
@Override
public boolean getBoolean(int ordinal) {
assertIndexIsValid(ordinal);
- final int offset = getElementOffset(ordinal);
- if (offset < 0) return false;
- return Platform.getBoolean(baseObject, baseOffset + offset);
+ return Platform.getBoolean(baseObject, getElementOffset(ordinal, 1));
}
@Override
public byte getByte(int ordinal) {
assertIndexIsValid(ordinal);
- final int offset = getElementOffset(ordinal);
- if (offset < 0) return 0;
- return Platform.getByte(baseObject, baseOffset + offset);
+ return Platform.getByte(baseObject, getElementOffset(ordinal, 1));
}
@Override
public short getShort(int ordinal) {
assertIndexIsValid(ordinal);
- final int offset = getElementOffset(ordinal);
- if (offset < 0) return 0;
- return Platform.getShort(baseObject, baseOffset + offset);
+ return Platform.getShort(baseObject, getElementOffset(ordinal, 2));
}
@Override
public int getInt(int ordinal) {
assertIndexIsValid(ordinal);
- final int offset = getElementOffset(ordinal);
- if (offset < 0) return 0;
- return Platform.getInt(baseObject, baseOffset + offset);
+ return Platform.getInt(baseObject, getElementOffset(ordinal, 4));
}
@Override
public long getLong(int ordinal) {
assertIndexIsValid(ordinal);
- final int offset = getElementOffset(ordinal);
- if (offset < 0) return 0;
- return Platform.getLong(baseObject, baseOffset + offset);
+ return Platform.getLong(baseObject, getElementOffset(ordinal, 8));
}
@Override
public float getFloat(int ordinal) {
assertIndexIsValid(ordinal);
- final int offset = getElementOffset(ordinal);
- if (offset < 0) return 0;
- return Platform.getFloat(baseObject, baseOffset + offset);
+ return Platform.getFloat(baseObject, getElementOffset(ordinal, 4));
}
@Override
public double getDouble(int ordinal) {
assertIndexIsValid(ordinal);
- final int offset = getElementOffset(ordinal);
- if (offset < 0) return 0;
- return Platform.getDouble(baseObject, baseOffset + offset);
+ return Platform.getDouble(baseObject, getElementOffset(ordinal, 8));
}
@Override
public Decimal getDecimal(int ordinal, int precision, int scale) {
- assertIndexIsValid(ordinal);
- final int offset = getElementOffset(ordinal);
- if (offset < 0) return null;
-
+ if (isNullAt(ordinal)) return null;
if (precision <= Decimal.MAX_LONG_DIGITS()) {
- final long value = Platform.getLong(baseObject, baseOffset + offset);
- return Decimal.apply(value, precision, scale);
+ return Decimal.apply(getLong(ordinal), precision, scale);
} else {
final byte[] bytes = getBinary(ordinal);
final BigInteger bigInteger = new BigInteger(bytes);
@@ -237,19 +225,19 @@ public final class UnsafeArrayData extends ArrayData {
@Override
public UTF8String getUTF8String(int ordinal) {
- assertIndexIsValid(ordinal);
- final int offset = getElementOffset(ordinal);
- if (offset < 0) return null;
- final int size = getElementSize(offset, ordinal);
+ if (isNullAt(ordinal)) return null;
+ final long offsetAndSize = getLong(ordinal);
+ final int offset = (int) (offsetAndSize >> 32);
+ final int size = (int) offsetAndSize;
return UTF8String.fromAddress(baseObject, baseOffset + offset, size);
}
@Override
public byte[] getBinary(int ordinal) {
- assertIndexIsValid(ordinal);
- final int offset = getElementOffset(ordinal);
- if (offset < 0) return null;
- final int size = getElementSize(offset, ordinal);
+ if (isNullAt(ordinal)) return null;
+ final long offsetAndSize = getLong(ordinal);
+ final int offset = (int) (offsetAndSize >> 32);
+ final int size = (int) offsetAndSize;
final byte[] bytes = new byte[size];
Platform.copyMemory(baseObject, baseOffset + offset, bytes, Platform.BYTE_ARRAY_OFFSET, size);
return bytes;
@@ -257,9 +245,9 @@ public final class UnsafeArrayData extends ArrayData {
@Override
public CalendarInterval getInterval(int ordinal) {
- assertIndexIsValid(ordinal);
- final int offset = getElementOffset(ordinal);
- if (offset < 0) return null;
+ if (isNullAt(ordinal)) return null;
+ final long offsetAndSize = getLong(ordinal);
+ final int offset = (int) (offsetAndSize >> 32);
final int months = (int) Platform.getLong(baseObject, baseOffset + offset);
final long microseconds = Platform.getLong(baseObject, baseOffset + offset + 8);
return new CalendarInterval(months, microseconds);
@@ -267,10 +255,10 @@ public final class UnsafeArrayData extends ArrayData {
@Override
public UnsafeRow getStruct(int ordinal, int numFields) {
- assertIndexIsValid(ordinal);
- final int offset = getElementOffset(ordinal);
- if (offset < 0) return null;
- final int size = getElementSize(offset, ordinal);
+ if (isNullAt(ordinal)) return null;
+ final long offsetAndSize = getLong(ordinal);
+ final int offset = (int) (offsetAndSize >> 32);
+ final int size = (int) offsetAndSize;
final UnsafeRow row = new UnsafeRow(numFields);
row.pointTo(baseObject, baseOffset + offset, size);
return row;
@@ -278,10 +266,10 @@ public final class UnsafeArrayData extends ArrayData {
@Override
public UnsafeArrayData getArray(int ordinal) {
- assertIndexIsValid(ordinal);
- final int offset = getElementOffset(ordinal);
- if (offset < 0) return null;
- final int size = getElementSize(offset, ordinal);
+ if (isNullAt(ordinal)) return null;
+ final long offsetAndSize = getLong(ordinal);
+ final int offset = (int) (offsetAndSize >> 32);
+ final int size = (int) offsetAndSize;
final UnsafeArrayData array = new UnsafeArrayData();
array.pointTo(baseObject, baseOffset + offset, size);
return array;
@@ -289,10 +277,10 @@ public final class UnsafeArrayData extends ArrayData {
@Override
public UnsafeMapData getMap(int ordinal) {
- assertIndexIsValid(ordinal);
- final int offset = getElementOffset(ordinal);
- if (offset < 0) return null;
- final int size = getElementSize(offset, ordinal);
+ if (isNullAt(ordinal)) return null;
+ final long offsetAndSize = getLong(ordinal);
+ final int offset = (int) (offsetAndSize >> 32);
+ final int size = (int) offsetAndSize;
final UnsafeMapData map = new UnsafeMapData();
map.pointTo(baseObject, baseOffset + offset, size);
return map;
@@ -341,63 +329,108 @@ public final class UnsafeArrayData extends ArrayData {
return arrayCopy;
}
- public static UnsafeArrayData fromPrimitiveArray(int[] arr) {
- if (arr.length > (Integer.MAX_VALUE - 4) / 8) {
- throw new UnsupportedOperationException("Cannot convert this array to unsafe format as " +
- "it's too big.");
- }
+ @Override
+ public boolean[] toBooleanArray() {
+ boolean[] values = new boolean[numElements];
+ Platform.copyMemory(
+ baseObject, elementOffset, values, Platform.BOOLEAN_ARRAY_OFFSET, numElements);
+ return values;
+ }
- final int offsetRegionSize = 4 * arr.length;
- final int valueRegionSize = 4 * arr.length;
- final int totalSize = 4 + offsetRegionSize + valueRegionSize;
- final byte[] data = new byte[totalSize];
+ @Override
+ public byte[] toByteArray() {
+ byte[] values = new byte[numElements];
+ Platform.copyMemory(
+ baseObject, elementOffset, values, Platform.BYTE_ARRAY_OFFSET, numElements);
+ return values;
+ }
- Platform.putInt(data, Platform.BYTE_ARRAY_OFFSET, arr.length);
+ @Override
+ public short[] toShortArray() {
+ short[] values = new short[numElements];
+ Platform.copyMemory(
+ baseObject, elementOffset, values, Platform.SHORT_ARRAY_OFFSET, numElements * 2);
+ return values;
+ }
- int offsetPosition = Platform.BYTE_ARRAY_OFFSET + 4;
- int valueOffset = 4 + offsetRegionSize;
- for (int i = 0; i < arr.length; i++) {
- Platform.putInt(data, offsetPosition, valueOffset);
- offsetPosition += 4;
- valueOffset += 4;
- }
+ @Override
+ public int[] toIntArray() {
+ int[] values = new int[numElements];
+ Platform.copyMemory(
+ baseObject, elementOffset, values, Platform.INT_ARRAY_OFFSET, numElements * 4);
+ return values;
+ }
- Platform.copyMemory(arr, Platform.INT_ARRAY_OFFSET, data,
- Platform.BYTE_ARRAY_OFFSET + 4 + offsetRegionSize, valueRegionSize);
+ @Override
+ public long[] toLongArray() {
+ long[] values = new long[numElements];
+ Platform.copyMemory(
+ baseObject, elementOffset, values, Platform.LONG_ARRAY_OFFSET, numElements * 8);
+ return values;
+ }
- UnsafeArrayData result = new UnsafeArrayData();
- result.pointTo(data, Platform.BYTE_ARRAY_OFFSET, totalSize);
- return result;
+ @Override
+ public float[] toFloatArray() {
+ float[] values = new float[numElements];
+ Platform.copyMemory(
+ baseObject, elementOffset, values, Platform.FLOAT_ARRAY_OFFSET, numElements * 4);
+ return values;
}
- public static UnsafeArrayData fromPrimitiveArray(double[] arr) {
- if (arr.length > (Integer.MAX_VALUE - 4) / 12) {
+ @Override
+ public double[] toDoubleArray() {
+ double[] values = new double[numElements];
+ Platform.copyMemory(
+ baseObject, elementOffset, values, Platform.DOUBLE_ARRAY_OFFSET, numElements * 8);
+ return values;
+ }
+
+ private static UnsafeArrayData fromPrimitiveArray(
+ Object arr, int offset, int length, int elementSize) {
+ final long headerInBytes = calculateHeaderPortionInBytes(length);
+ final long valueRegionInBytes = elementSize * length;
+ final long totalSizeInLongs = (headerInBytes + valueRegionInBytes + 7) / 8;
+ if (totalSizeInLongs > Integer.MAX_VALUE / 8) {
throw new UnsupportedOperationException("Cannot convert this array to unsafe format as " +
"it's too big.");
}
- final int offsetRegionSize = 4 * arr.length;
- final int valueRegionSize = 8 * arr.length;
- final int totalSize = 4 + offsetRegionSize + valueRegionSize;
- final byte[] data = new byte[totalSize];
+ final long[] data = new long[(int)totalSizeInLongs];
- Platform.putInt(data, Platform.BYTE_ARRAY_OFFSET, arr.length);
-
- int offsetPosition = Platform.BYTE_ARRAY_OFFSET + 4;
- int valueOffset = 4 + offsetRegionSize;
- for (int i = 0; i < arr.length; i++) {
- Platform.putInt(data, offsetPosition, valueOffset);
- offsetPosition += 4;
- valueOffset += 8;
- }
-
- Platform.copyMemory(arr, Platform.DOUBLE_ARRAY_OFFSET, data,
- Platform.BYTE_ARRAY_OFFSET + 4 + offsetRegionSize, valueRegionSize);
+ Platform.putLong(data, Platform.LONG_ARRAY_OFFSET, length);
+ Platform.copyMemory(arr, offset, data,
+ Platform.LONG_ARRAY_OFFSET + headerInBytes, valueRegionInBytes);
UnsafeArrayData result = new UnsafeArrayData();
- result.pointTo(data, Platform.BYTE_ARRAY_OFFSET, totalSize);
+ result.pointTo(data, Platform.LONG_ARRAY_OFFSET, (int)totalSizeInLongs * 8);
return result;
}
- // TODO: add more specialized methods.
+ public static UnsafeArrayData fromPrimitiveArray(boolean[] arr) {
+ return fromPrimitiveArray(arr, Platform.BOOLEAN_ARRAY_OFFSET, arr.length, 1);
+ }
+
+ public static UnsafeArrayData fromPrimitiveArray(byte[] arr) {
+ return fromPrimitiveArray(arr, Platform.BYTE_ARRAY_OFFSET, arr.length, 1);
+ }
+
+ public static UnsafeArrayData fromPrimitiveArray(short[] arr) {
+ return fromPrimitiveArray(arr, Platform.SHORT_ARRAY_OFFSET, arr.length, 2);
+ }
+
+ public static UnsafeArrayData fromPrimitiveArray(int[] arr) {
+ return fromPrimitiveArray(arr, Platform.INT_ARRAY_OFFSET, arr.length, 4);
+ }
+
+ public static UnsafeArrayData fromPrimitiveArray(long[] arr) {
+ return fromPrimitiveArray(arr, Platform.LONG_ARRAY_OFFSET, arr.length, 8);
+ }
+
+ public static UnsafeArrayData fromPrimitiveArray(float[] arr) {
+ return fromPrimitiveArray(arr, Platform.FLOAT_ARRAY_OFFSET, arr.length, 4);
+ }
+
+ public static UnsafeArrayData fromPrimitiveArray(double[] arr) {
+ return fromPrimitiveArray(arr, Platform.DOUBLE_ARRAY_OFFSET, arr.length, 8);
+ }
}
diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeMapData.java b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeMapData.java
index 0700148bec..35029f5a50 100644
--- a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeMapData.java
+++ b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeMapData.java
@@ -25,7 +25,7 @@ import org.apache.spark.unsafe.Platform;
/**
* An Unsafe implementation of Map which is backed by raw memory instead of Java objects.
*
- * Currently we just use 2 UnsafeArrayData to represent UnsafeMapData, with extra 4 bytes at head
+ * Currently we just use 2 UnsafeArrayData to represent UnsafeMapData, with extra 8 bytes at head
* to indicate the number of bytes of the unsafe key array.
* [unsafe key array numBytes] [unsafe key array] [unsafe value array]
*/
@@ -65,14 +65,15 @@ public final class UnsafeMapData extends MapData {
* @param sizeInBytes the size of this map's backing data, in bytes
*/
public void pointTo(Object baseObject, long baseOffset, int sizeInBytes) {
- // Read the numBytes of key array from the first 4 bytes.
- final int keyArraySize = Platform.getInt(baseObject, baseOffset);
- final int valueArraySize = sizeInBytes - keyArraySize - 4;
+ // Read the numBytes of key array from the first 8 bytes.
+ final long keyArraySize = Platform.getLong(baseObject, baseOffset);
assert keyArraySize >= 0 : "keyArraySize (" + keyArraySize + ") should >= 0";
+ assert keyArraySize <= Integer.MAX_VALUE : "keyArraySize (" + keyArraySize + ") should <= Integer.MAX_VALUE";
+ final int valueArraySize = sizeInBytes - (int)keyArraySize - 8;
assert valueArraySize >= 0 : "valueArraySize (" + valueArraySize + ") should >= 0";
- keys.pointTo(baseObject, baseOffset + 4, keyArraySize);
- values.pointTo(baseObject, baseOffset + 4 + keyArraySize, valueArraySize);
+ keys.pointTo(baseObject, baseOffset + 8, (int)keyArraySize);
+ values.pointTo(baseObject, baseOffset + 8 + keyArraySize, valueArraySize);
assert keys.numElements() == values.numElements();
diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/codegen/UnsafeArrayWriter.java b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/codegen/UnsafeArrayWriter.java
index 7dd932d198..afea467689 100644
--- a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/codegen/UnsafeArrayWriter.java
+++ b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/codegen/UnsafeArrayWriter.java
@@ -19,9 +19,13 @@ package org.apache.spark.sql.catalyst.expressions.codegen;
import org.apache.spark.sql.types.Decimal;
import org.apache.spark.unsafe.Platform;
+import org.apache.spark.unsafe.array.ByteArrayMethods;
+import org.apache.spark.unsafe.bitset.BitSetMethods;
import org.apache.spark.unsafe.types.CalendarInterval;
import org.apache.spark.unsafe.types.UTF8String;
+import static org.apache.spark.sql.catalyst.expressions.UnsafeArrayData.calculateHeaderPortionInBytes;
+
/**
* A helper class to write data into global row buffer using `UnsafeArrayData` format,
* used by {@link org.apache.spark.sql.catalyst.expressions.codegen.GenerateUnsafeProjection}.
@@ -33,134 +37,213 @@ public class UnsafeArrayWriter {
// The offset of the global buffer where we start to write this array.
private int startingOffset;
- public void initialize(BufferHolder holder, int numElements, int fixedElementSize) {
- // We need 4 bytes to store numElements and 4 bytes each element to store offset.
- final int fixedSize = 4 + 4 * numElements;
+ // The number of elements in this array
+ private int numElements;
+
+ private int headerInBytes;
+
+ private void assertIndexIsValid(int index) {
+ assert index >= 0 : "index (" + index + ") should >= 0";
+ assert index < numElements : "index (" + index + ") should < " + numElements;
+ }
+
+ public void initialize(BufferHolder holder, int numElements, int elementSize) {
+ // We need 8 bytes to store numElements in header
+ this.numElements = numElements;
+ this.headerInBytes = calculateHeaderPortionInBytes(numElements);
this.holder = holder;
this.startingOffset = holder.cursor;
- holder.grow(fixedSize);
- Platform.putInt(holder.buffer, holder.cursor, numElements);
- holder.cursor += fixedSize;
+ // Grows the global buffer ahead for header and fixed size data.
+ int fixedPartInBytes =
+ ByteArrayMethods.roundNumberOfBytesToNearestWord(elementSize * numElements);
+ holder.grow(headerInBytes + fixedPartInBytes);
+
+ // Write numElements and clear out null bits to header
+ Platform.putLong(holder.buffer, startingOffset, numElements);
+ for (int i = 8; i < headerInBytes; i += 8) {
+ Platform.putLong(holder.buffer, startingOffset + i, 0L);
+ }
+
+ // fill 0 into reminder part of 8-bytes alignment in unsafe array
+ for (int i = elementSize * numElements; i < fixedPartInBytes; i++) {
+ Platform.putByte(holder.buffer, startingOffset + headerInBytes + i, (byte) 0);
+ }
+ holder.cursor += (headerInBytes + fixedPartInBytes);
+ }
+
+ private void zeroOutPaddingBytes(int numBytes) {
+ if ((numBytes & 0x07) > 0) {
+ Platform.putLong(holder.buffer, holder.cursor + ((numBytes >> 3) << 3), 0L);
+ }
+ }
+
+ private long getElementOffset(int ordinal, int elementSize) {
+ return startingOffset + headerInBytes + ordinal * elementSize;
+ }
+
+ public void setOffsetAndSize(int ordinal, long currentCursor, int size) {
+ assertIndexIsValid(ordinal);
+ final long relativeOffset = currentCursor - startingOffset;
+ final long offsetAndSize = (relativeOffset << 32) | (long)size;
- // Grows the global buffer ahead for fixed size data.
- holder.grow(fixedElementSize * numElements);
+ write(ordinal, offsetAndSize);
}
- private long getElementOffset(int ordinal) {
- return startingOffset + 4 + 4 * ordinal;
+ private void setNullBit(int ordinal) {
+ assertIndexIsValid(ordinal);
+ BitSetMethods.set(holder.buffer, startingOffset + 8, ordinal);
}
- public void setNullAt(int ordinal) {
- final int relativeOffset = holder.cursor - startingOffset;
- // Writes negative offset value to represent null element.
- Platform.putInt(holder.buffer, getElementOffset(ordinal), -relativeOffset);
+ public void setNullBoolean(int ordinal) {
+ setNullBit(ordinal);
+ // put zero into the corresponding field when set null
+ Platform.putBoolean(holder.buffer, getElementOffset(ordinal, 1), false);
}
- public void setOffset(int ordinal) {
- final int relativeOffset = holder.cursor - startingOffset;
- Platform.putInt(holder.buffer, getElementOffset(ordinal), relativeOffset);
+ public void setNullByte(int ordinal) {
+ setNullBit(ordinal);
+ // put zero into the corresponding field when set null
+ Platform.putByte(holder.buffer, getElementOffset(ordinal, 1), (byte)0);
}
+ public void setNullShort(int ordinal) {
+ setNullBit(ordinal);
+ // put zero into the corresponding field when set null
+ Platform.putShort(holder.buffer, getElementOffset(ordinal, 2), (short)0);
+ }
+
+ public void setNullInt(int ordinal) {
+ setNullBit(ordinal);
+ // put zero into the corresponding field when set null
+ Platform.putInt(holder.buffer, getElementOffset(ordinal, 4), (int)0);
+ }
+
+ public void setNullLong(int ordinal) {
+ setNullBit(ordinal);
+ // put zero into the corresponding field when set null
+ Platform.putLong(holder.buffer, getElementOffset(ordinal, 8), (long)0);
+ }
+
+ public void setNullFloat(int ordinal) {
+ setNullBit(ordinal);
+ // put zero into the corresponding field when set null
+ Platform.putFloat(holder.buffer, getElementOffset(ordinal, 4), (float)0);
+ }
+
+ public void setNullDouble(int ordinal) {
+ setNullBit(ordinal);
+ // put zero into the corresponding field when set null
+ Platform.putDouble(holder.buffer, getElementOffset(ordinal, 8), (double)0);
+ }
+
+ public void setNull(int ordinal) { setNullLong(ordinal); }
+
public void write(int ordinal, boolean value) {
- Platform.putBoolean(holder.buffer, holder.cursor, value);
- setOffset(ordinal);
- holder.cursor += 1;
+ assertIndexIsValid(ordinal);
+ Platform.putBoolean(holder.buffer, getElementOffset(ordinal, 1), value);
}
public void write(int ordinal, byte value) {
- Platform.putByte(holder.buffer, holder.cursor, value);
- setOffset(ordinal);
- holder.cursor += 1;
+ assertIndexIsValid(ordinal);
+ Platform.putByte(holder.buffer, getElementOffset(ordinal, 1), value);
}
public void write(int ordinal, short value) {
- Platform.putShort(holder.buffer, holder.cursor, value);
- setOffset(ordinal);
- holder.cursor += 2;
+ assertIndexIsValid(ordinal);
+ Platform.putShort(holder.buffer, getElementOffset(ordinal, 2), value);
}
public void write(int ordinal, int value) {
- Platform.putInt(holder.buffer, holder.cursor, value);
- setOffset(ordinal);
- holder.cursor += 4;
+ assertIndexIsValid(ordinal);
+ Platform.putInt(holder.buffer, getElementOffset(ordinal, 4), value);
}
public void write(int ordinal, long value) {
- Platform.putLong(holder.buffer, holder.cursor, value);
- setOffset(ordinal);
- holder.cursor += 8;
+ assertIndexIsValid(ordinal);
+ Platform.putLong(holder.buffer, getElementOffset(ordinal, 8), value);
}
public void write(int ordinal, float value) {
if (Float.isNaN(value)) {
value = Float.NaN;
}
- Platform.putFloat(holder.buffer, holder.cursor, value);
- setOffset(ordinal);
- holder.cursor += 4;
+ assertIndexIsValid(ordinal);
+ Platform.putFloat(holder.buffer, getElementOffset(ordinal, 4), value);
}
public void write(int ordinal, double value) {
if (Double.isNaN(value)) {
value = Double.NaN;
}
- Platform.putDouble(holder.buffer, holder.cursor, value);
- setOffset(ordinal);
- holder.cursor += 8;
+ assertIndexIsValid(ordinal);
+ Platform.putDouble(holder.buffer, getElementOffset(ordinal, 8), value);
}
public void write(int ordinal, Decimal input, int precision, int scale) {
// make sure Decimal object has the same scale as DecimalType
+ assertIndexIsValid(ordinal);
if (input.changePrecision(precision, scale)) {
if (precision <= Decimal.MAX_LONG_DIGITS()) {
- Platform.putLong(holder.buffer, holder.cursor, input.toUnscaledLong());
- setOffset(ordinal);
- holder.cursor += 8;
+ write(ordinal, input.toUnscaledLong());
} else {
final byte[] bytes = input.toJavaBigDecimal().unscaledValue().toByteArray();
- assert bytes.length <= 16;
- holder.grow(bytes.length);
+ final int numBytes = bytes.length;
+ assert numBytes <= 16;
+ int roundedSize = ByteArrayMethods.roundNumberOfBytesToNearestWord(numBytes);
+ holder.grow(roundedSize);
+
+ zeroOutPaddingBytes(numBytes);
// Write the bytes to the variable length portion.
Platform.copyMemory(
- bytes, Platform.BYTE_ARRAY_OFFSET, holder.buffer, holder.cursor, bytes.length);
- setOffset(ordinal);
- holder.cursor += bytes.length;
+ bytes, Platform.BYTE_ARRAY_OFFSET, holder.buffer, holder.cursor, numBytes);
+ setOffsetAndSize(ordinal, holder.cursor, numBytes);
+
+ // move the cursor forward with 8-bytes boundary
+ holder.cursor += roundedSize;
}
} else {
- setNullAt(ordinal);
+ setNull(ordinal);
}
}
public void write(int ordinal, UTF8String input) {
final int numBytes = input.numBytes();
+ final int roundedSize = ByteArrayMethods.roundNumberOfBytesToNearestWord(numBytes);
// grow the global buffer before writing data.
- holder.grow(numBytes);
+ holder.grow(roundedSize);
+
+ zeroOutPaddingBytes(numBytes);
// Write the bytes to the variable length portion.
input.writeToMemory(holder.buffer, holder.cursor);
- setOffset(ordinal);
+ setOffsetAndSize(ordinal, holder.cursor, numBytes);
// move the cursor forward.
- holder.cursor += numBytes;
+ holder.cursor += roundedSize;
}
public void write(int ordinal, byte[] input) {
+ final int numBytes = input.length;
+ final int roundedSize = ByteArrayMethods.roundNumberOfBytesToNearestWord(input.length);
+
// grow the global buffer before writing data.
- holder.grow(input.length);
+ holder.grow(roundedSize);
+
+ zeroOutPaddingBytes(numBytes);
// Write the bytes to the variable length portion.
Platform.copyMemory(
- input, Platform.BYTE_ARRAY_OFFSET, holder.buffer, holder.cursor, input.length);
+ input, Platform.BYTE_ARRAY_OFFSET, holder.buffer, holder.cursor, numBytes);
- setOffset(ordinal);
+ setOffsetAndSize(ordinal, holder.cursor, numBytes);
// move the cursor forward.
- holder.cursor += input.length;
+ holder.cursor += roundedSize;
}
public void write(int ordinal, CalendarInterval input) {
@@ -171,7 +254,7 @@ public class UnsafeArrayWriter {
Platform.putLong(holder.buffer, holder.cursor, input.months);
Platform.putLong(holder.buffer, holder.cursor + 8, input.microseconds);
- setOffset(ordinal);
+ setOffsetAndSize(ordinal, holder.cursor, 16);
// move the cursor forward.
holder.cursor += 16;
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeProjection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeProjection.scala
index 5efba4b3a6..75bb6936b4 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeProjection.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeProjection.scala
@@ -124,7 +124,6 @@ object GenerateUnsafeProjection extends CodeGenerator[Seq[Expression], UnsafePro
final int $tmpCursor = $bufferHolder.cursor;
${writeArrayToBuffer(ctx, input.value, et, bufferHolder)}
$rowWriter.setOffsetAndSize($index, $tmpCursor, $bufferHolder.cursor - $tmpCursor);
- $rowWriter.alignToWords($bufferHolder.cursor - $tmpCursor);
"""
case m @ MapType(kt, vt, _) =>
@@ -134,7 +133,6 @@ object GenerateUnsafeProjection extends CodeGenerator[Seq[Expression], UnsafePro
final int $tmpCursor = $bufferHolder.cursor;
${writeMapToBuffer(ctx, input.value, kt, vt, bufferHolder)}
$rowWriter.setOffsetAndSize($index, $tmpCursor, $bufferHolder.cursor - $tmpCursor);
- $rowWriter.alignToWords($bufferHolder.cursor - $tmpCursor);
"""
case t: DecimalType =>
@@ -189,29 +187,33 @@ object GenerateUnsafeProjection extends CodeGenerator[Seq[Expression], UnsafePro
val jt = ctx.javaType(et)
- val fixedElementSize = et match {
+ val elementOrOffsetSize = et match {
case t: DecimalType if t.precision <= Decimal.MAX_LONG_DIGITS => 8
case _ if ctx.isPrimitiveType(jt) => et.defaultSize
- case _ => 0
+ case _ => 8 // we need 8 bytes to store offset and length
}
+ val tmpCursor = ctx.freshName("tmpCursor")
val writeElement = et match {
case t: StructType =>
s"""
- $arrayWriter.setOffset($index);
+ final int $tmpCursor = $bufferHolder.cursor;
${writeStructToBuffer(ctx, element, t.map(_.dataType), bufferHolder)}
+ $arrayWriter.setOffsetAndSize($index, $tmpCursor, $bufferHolder.cursor - $tmpCursor);
"""
case a @ ArrayType(et, _) =>
s"""
- $arrayWriter.setOffset($index);
+ final int $tmpCursor = $bufferHolder.cursor;
${writeArrayToBuffer(ctx, element, et, bufferHolder)}
+ $arrayWriter.setOffsetAndSize($index, $tmpCursor, $bufferHolder.cursor - $tmpCursor);
"""
case m @ MapType(kt, vt, _) =>
s"""
- $arrayWriter.setOffset($index);
+ final int $tmpCursor = $bufferHolder.cursor;
${writeMapToBuffer(ctx, element, kt, vt, bufferHolder)}
+ $arrayWriter.setOffsetAndSize($index, $tmpCursor, $bufferHolder.cursor - $tmpCursor);
"""
case t: DecimalType =>
@@ -222,16 +224,17 @@ object GenerateUnsafeProjection extends CodeGenerator[Seq[Expression], UnsafePro
case _ => s"$arrayWriter.write($index, $element);"
}
+ val primitiveTypeName = if (ctx.isPrimitiveType(jt)) ctx.primitiveTypeName(et) else ""
s"""
if ($input instanceof UnsafeArrayData) {
${writeUnsafeData(ctx, s"((UnsafeArrayData) $input)", bufferHolder)}
} else {
final int $numElements = $input.numElements();
- $arrayWriter.initialize($bufferHolder, $numElements, $fixedElementSize);
+ $arrayWriter.initialize($bufferHolder, $numElements, $elementOrOffsetSize);
for (int $index = 0; $index < $numElements; $index++) {
if ($input.isNullAt($index)) {
- $arrayWriter.setNullAt($index);
+ $arrayWriter.setNull$primitiveTypeName($index);
} else {
final $jt $element = ${ctx.getValue(input, et, index)};
$writeElement
@@ -261,16 +264,16 @@ object GenerateUnsafeProjection extends CodeGenerator[Seq[Expression], UnsafePro
final ArrayData $keys = $input.keyArray();
final ArrayData $values = $input.valueArray();
- // preserve 4 bytes to write the key array numBytes later.
- $bufferHolder.grow(4);
- $bufferHolder.cursor += 4;
+ // preserve 8 bytes to write the key array numBytes later.
+ $bufferHolder.grow(8);
+ $bufferHolder.cursor += 8;
// Remember the current cursor so that we can write numBytes of key array later.
final int $tmpCursor = $bufferHolder.cursor;
${writeArrayToBuffer(ctx, keys, keyType, bufferHolder)}
- // Write the numBytes of key array into the first 4 bytes.
- Platform.putInt($bufferHolder.buffer, $tmpCursor - 4, $bufferHolder.cursor - $tmpCursor);
+ // Write the numBytes of key array into the first 8 bytes.
+ Platform.putLong($bufferHolder.buffer, $tmpCursor - 8, $bufferHolder.cursor - $tmpCursor);
${writeArrayToBuffer(ctx, values, valueType, bufferHolder)}
}
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/UnsafeRowConverterSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/UnsafeRowConverterSuite.scala
index 1265908182..90790dda75 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/UnsafeRowConverterSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/UnsafeRowConverterSuite.scala
@@ -300,7 +300,8 @@ class UnsafeRowConverterSuite extends SparkFunSuite with Matchers {
private def testArrayInt(array: UnsafeArrayData, values: Seq[Int]): Unit = {
assert(array.numElements == values.length)
- assert(array.getSizeInBytes == 4 + (4 + 4) * values.length)
+ assert(array.getSizeInBytes ==
+ 8 + scala.math.ceil(values.length / 64.toDouble) * 8 + roundedSize(4 * values.length))
values.zipWithIndex.foreach {
case (value, index) => assert(array.getInt(index) == value)
}
@@ -313,7 +314,7 @@ class UnsafeRowConverterSuite extends SparkFunSuite with Matchers {
testArrayInt(map.keyArray, keys)
testArrayInt(map.valueArray, values)
- assert(map.getSizeInBytes == 4 + map.keyArray.getSizeInBytes + map.valueArray.getSizeInBytes)
+ assert(map.getSizeInBytes == 8 + map.keyArray.getSizeInBytes + map.valueArray.getSizeInBytes)
}
test("basic conversion with array type") {
@@ -339,7 +340,7 @@ class UnsafeRowConverterSuite extends SparkFunSuite with Matchers {
val nestedArray = unsafeArray2.getArray(0)
testArrayInt(nestedArray, Seq(3, 4))
- assert(unsafeArray2.getSizeInBytes == 4 + 4 + nestedArray.getSizeInBytes)
+ assert(unsafeArray2.getSizeInBytes == 8 + 8 + 8 + nestedArray.getSizeInBytes)
val array1Size = roundedSize(unsafeArray1.getSizeInBytes)
val array2Size = roundedSize(unsafeArray2.getSizeInBytes)
@@ -382,10 +383,10 @@ class UnsafeRowConverterSuite extends SparkFunSuite with Matchers {
val nestedMap = valueArray.getMap(0)
testMapInt(nestedMap, Seq(5, 6), Seq(7, 8))
- assert(valueArray.getSizeInBytes == 4 + 4 + nestedMap.getSizeInBytes)
+ assert(valueArray.getSizeInBytes == 8 + 8 + 8 + roundedSize(nestedMap.getSizeInBytes))
}
- assert(unsafeMap2.getSizeInBytes == 4 + keyArray.getSizeInBytes + valueArray.getSizeInBytes)
+ assert(unsafeMap2.getSizeInBytes == 8 + keyArray.getSizeInBytes + valueArray.getSizeInBytes)
val map1Size = roundedSize(unsafeMap1.getSizeInBytes)
val map2Size = roundedSize(unsafeMap2.getSizeInBytes)
@@ -425,7 +426,7 @@ class UnsafeRowConverterSuite extends SparkFunSuite with Matchers {
assert(innerStruct.getLong(0) == 2L)
}
- assert(field2.getSizeInBytes == 4 + 4 + innerStruct.getSizeInBytes)
+ assert(field2.getSizeInBytes == 8 + 8 + 8 + innerStruct.getSizeInBytes)
assert(unsafeRow.getSizeInBytes ==
8 + 8 * 2 + field1.getSizeInBytes + roundedSize(field2.getSizeInBytes))
@@ -468,10 +469,10 @@ class UnsafeRowConverterSuite extends SparkFunSuite with Matchers {
assert(innerStruct.getSizeInBytes == 8 + 8)
assert(innerStruct.getLong(0) == 4L)
- assert(valueArray.getSizeInBytes == 4 + 4 + innerStruct.getSizeInBytes)
+ assert(valueArray.getSizeInBytes == 8 + 8 + 8 + innerStruct.getSizeInBytes)
}
- assert(field2.getSizeInBytes == 4 + keyArray.getSizeInBytes + valueArray.getSizeInBytes)
+ assert(field2.getSizeInBytes == 8 + keyArray.getSizeInBytes + valueArray.getSizeInBytes)
assert(unsafeRow.getSizeInBytes ==
8 + 8 * 2 + field1.getSizeInBytes + roundedSize(field2.getSizeInBytes))
@@ -497,7 +498,7 @@ class UnsafeRowConverterSuite extends SparkFunSuite with Matchers {
val innerMap = field1.getMap(0)
testMapInt(innerMap, Seq(1), Seq(2))
- assert(field1.getSizeInBytes == 4 + 4 + innerMap.getSizeInBytes)
+ assert(field1.getSizeInBytes == 8 + 8 + 8 + roundedSize(innerMap.getSizeInBytes))
val field2 = unsafeRow.getMap(1)
assert(field2.numElements == 1)
@@ -513,10 +514,10 @@ class UnsafeRowConverterSuite extends SparkFunSuite with Matchers {
val innerArray = valueArray.getArray(0)
testArrayInt(innerArray, Seq(4))
- assert(valueArray.getSizeInBytes == 4 + (4 + innerArray.getSizeInBytes))
+ assert(valueArray.getSizeInBytes == 8 + 8 + 8 + innerArray.getSizeInBytes)
}
- assert(field2.getSizeInBytes == 4 + keyArray.getSizeInBytes + valueArray.getSizeInBytes)
+ assert(field2.getSizeInBytes == 8 + keyArray.getSizeInBytes + valueArray.getSizeInBytes)
assert(unsafeRow.getSizeInBytes ==
8 + 8 * 2 + roundedSize(field1.getSizeInBytes) + roundedSize(field2.getSizeInBytes))
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/util/UnsafeArraySuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/util/UnsafeArraySuite.scala
index 1685276ff1..f0e247bf46 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/util/UnsafeArraySuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/util/UnsafeArraySuite.scala
@@ -18,27 +18,190 @@
package org.apache.spark.sql.catalyst.util
import org.apache.spark.SparkFunSuite
+import org.apache.spark.sql.catalyst.encoders.{ExpressionEncoder, RowEncoder}
import org.apache.spark.sql.catalyst.expressions.UnsafeArrayData
+import org.apache.spark.sql.Row
+import org.apache.spark.sql.types._
+import org.apache.spark.unsafe.types.{CalendarInterval, UTF8String}
class UnsafeArraySuite extends SparkFunSuite {
- test("from primitive int array") {
- val array = Array(1, 10, 100)
- val unsafe = UnsafeArrayData.fromPrimitiveArray(array)
- assert(unsafe.numElements == 3)
- assert(unsafe.getSizeInBytes == 4 + 4 * 3 + 4 * 3)
- assert(unsafe.getInt(0) == 1)
- assert(unsafe.getInt(1) == 10)
- assert(unsafe.getInt(2) == 100)
+ val booleanArray = Array(false, true)
+ val shortArray = Array(1.toShort, 10.toShort, 100.toShort)
+ val intArray = Array(1, 10, 100)
+ val longArray = Array(1.toLong, 10.toLong, 100.toLong)
+ val floatArray = Array(1.1.toFloat, 2.2.toFloat, 3.3.toFloat)
+ val doubleArray = Array(1.1, 2.2, 3.3)
+ val stringArray = Array("1", "10", "100")
+ val dateArray = Array(
+ DateTimeUtils.stringToDate(UTF8String.fromString("1970-1-1")).get,
+ DateTimeUtils.stringToDate(UTF8String.fromString("2016-7-26")).get)
+ val timestampArray = Array(
+ DateTimeUtils.stringToTimestamp(UTF8String.fromString("1970-1-1 00:00:00")).get,
+ DateTimeUtils.stringToTimestamp(UTF8String.fromString("2016-7-26 00:00:00")).get)
+ val decimalArray4_1 = Array(
+ BigDecimal("123.4").setScale(1, BigDecimal.RoundingMode.FLOOR),
+ BigDecimal("567.8").setScale(1, BigDecimal.RoundingMode.FLOOR))
+ val decimalArray20_20 = Array(
+ BigDecimal("1.2345678901234567890123456").setScale(21, BigDecimal.RoundingMode.FLOOR),
+ BigDecimal("2.3456789012345678901234567").setScale(21, BigDecimal.RoundingMode.FLOOR))
+
+ val calenderintervalArray = Array(new CalendarInterval(3, 321), new CalendarInterval(1, 123))
+
+ val intMultiDimArray = Array(Array(1), Array(2, 20), Array(3, 30, 300))
+ val doubleMultiDimArray = Array(
+ Array(1.1, 11.1), Array(2.2, 22.2, 222.2), Array(3.3, 33.3, 333.3, 3333.3))
+
+ test("read array") {
+ val unsafeBoolean = ExpressionEncoder[Array[Boolean]].resolveAndBind().
+ toRow(booleanArray).getArray(0)
+ assert(unsafeBoolean.isInstanceOf[UnsafeArrayData])
+ assert(unsafeBoolean.numElements == booleanArray.length)
+ booleanArray.zipWithIndex.map { case (e, i) =>
+ assert(unsafeBoolean.getBoolean(i) == e)
+ }
+
+ val unsafeShort = ExpressionEncoder[Array[Short]].resolveAndBind().
+ toRow(shortArray).getArray(0)
+ assert(unsafeShort.isInstanceOf[UnsafeArrayData])
+ assert(unsafeShort.numElements == shortArray.length)
+ shortArray.zipWithIndex.map { case (e, i) =>
+ assert(unsafeShort.getShort(i) == e)
+ }
+
+ val unsafeInt = ExpressionEncoder[Array[Int]].resolveAndBind().
+ toRow(intArray).getArray(0)
+ assert(unsafeInt.isInstanceOf[UnsafeArrayData])
+ assert(unsafeInt.numElements == intArray.length)
+ intArray.zipWithIndex.map { case (e, i) =>
+ assert(unsafeInt.getInt(i) == e)
+ }
+
+ val unsafeLong = ExpressionEncoder[Array[Long]].resolveAndBind().
+ toRow(longArray).getArray(0)
+ assert(unsafeLong.isInstanceOf[UnsafeArrayData])
+ assert(unsafeLong.numElements == longArray.length)
+ longArray.zipWithIndex.map { case (e, i) =>
+ assert(unsafeLong.getLong(i) == e)
+ }
+
+ val unsafeFloat = ExpressionEncoder[Array[Float]].resolveAndBind().
+ toRow(floatArray).getArray(0)
+ assert(unsafeFloat.isInstanceOf[UnsafeArrayData])
+ assert(unsafeFloat.numElements == floatArray.length)
+ floatArray.zipWithIndex.map { case (e, i) =>
+ assert(unsafeFloat.getFloat(i) == e)
+ }
+
+ val unsafeDouble = ExpressionEncoder[Array[Double]].resolveAndBind().
+ toRow(doubleArray).getArray(0)
+ assert(unsafeDouble.isInstanceOf[UnsafeArrayData])
+ assert(unsafeDouble.numElements == doubleArray.length)
+ doubleArray.zipWithIndex.map { case (e, i) =>
+ assert(unsafeDouble.getDouble(i) == e)
+ }
+
+ val unsafeString = ExpressionEncoder[Array[String]].resolveAndBind().
+ toRow(stringArray).getArray(0)
+ assert(unsafeString.isInstanceOf[UnsafeArrayData])
+ assert(unsafeString.numElements == stringArray.length)
+ stringArray.zipWithIndex.map { case (e, i) =>
+ assert(unsafeString.getUTF8String(i).toString().equals(e))
+ }
+
+ val unsafeDate = ExpressionEncoder[Array[Int]].resolveAndBind().
+ toRow(dateArray).getArray(0)
+ assert(unsafeDate.isInstanceOf[UnsafeArrayData])
+ assert(unsafeDate.numElements == dateArray.length)
+ dateArray.zipWithIndex.map { case (e, i) =>
+ assert(unsafeDate.get(i, DateType) == e)
+ }
+
+ val unsafeTimestamp = ExpressionEncoder[Array[Long]].resolveAndBind().
+ toRow(timestampArray).getArray(0)
+ assert(unsafeTimestamp.isInstanceOf[UnsafeArrayData])
+ assert(unsafeTimestamp.numElements == timestampArray.length)
+ timestampArray.zipWithIndex.map { case (e, i) =>
+ assert(unsafeTimestamp.get(i, TimestampType) == e)
+ }
+
+ Seq(decimalArray4_1, decimalArray20_20).map { decimalArray =>
+ val decimal = decimalArray(0)
+ val schema = new StructType().add(
+ "array", ArrayType(DecimalType(decimal.precision, decimal.scale)))
+ val encoder = RowEncoder(schema).resolveAndBind()
+ val externalRow = Row(decimalArray)
+ val ir = encoder.toRow(externalRow)
+
+ val unsafeDecimal = ir.getArray(0)
+ assert(unsafeDecimal.isInstanceOf[UnsafeArrayData])
+ assert(unsafeDecimal.numElements == decimalArray.length)
+ decimalArray.zipWithIndex.map { case (e, i) =>
+ assert(unsafeDecimal.getDecimal(i, e.precision, e.scale).toBigDecimal == e)
+ }
+ }
+
+ val schema = new StructType().add("array", ArrayType(CalendarIntervalType))
+ val encoder = RowEncoder(schema).resolveAndBind()
+ val externalRow = Row(calenderintervalArray)
+ val ir = encoder.toRow(externalRow)
+ val unsafeCalendar = ir.getArray(0)
+ assert(unsafeCalendar.isInstanceOf[UnsafeArrayData])
+ assert(unsafeCalendar.numElements == calenderintervalArray.length)
+ calenderintervalArray.zipWithIndex.map { case (e, i) =>
+ assert(unsafeCalendar.getInterval(i) == e)
+ }
+
+ val unsafeMultiDimInt = ExpressionEncoder[Array[Array[Int]]].resolveAndBind().
+ toRow(intMultiDimArray).getArray(0)
+ assert(unsafeMultiDimInt.isInstanceOf[UnsafeArrayData])
+ assert(unsafeMultiDimInt.numElements == intMultiDimArray.length)
+ intMultiDimArray.zipWithIndex.map { case (a, j) =>
+ val u = unsafeMultiDimInt.getArray(j)
+ assert(u.isInstanceOf[UnsafeArrayData])
+ assert(u.numElements == a.length)
+ a.zipWithIndex.map { case (e, i) =>
+ assert(u.getInt(i) == e)
+ }
+ }
+
+ val unsafeMultiDimDouble = ExpressionEncoder[Array[Array[Double]]].resolveAndBind().
+ toRow(doubleMultiDimArray).getArray(0)
+ assert(unsafeDouble.isInstanceOf[UnsafeArrayData])
+ assert(unsafeMultiDimDouble.numElements == doubleMultiDimArray.length)
+ doubleMultiDimArray.zipWithIndex.map { case (a, j) =>
+ val u = unsafeMultiDimDouble.getArray(j)
+ assert(u.isInstanceOf[UnsafeArrayData])
+ assert(u.numElements == a.length)
+ a.zipWithIndex.map { case (e, i) =>
+ assert(u.getDouble(i) == e)
+ }
+ }
}
- test("from primitive double array") {
- val array = Array(1.1, 2.2, 3.3)
- val unsafe = UnsafeArrayData.fromPrimitiveArray(array)
- assert(unsafe.numElements == 3)
- assert(unsafe.getSizeInBytes == 4 + 4 * 3 + 8 * 3)
- assert(unsafe.getDouble(0) == 1.1)
- assert(unsafe.getDouble(1) == 2.2)
- assert(unsafe.getDouble(2) == 3.3)
+ test("from primitive array") {
+ val unsafeInt = UnsafeArrayData.fromPrimitiveArray(intArray)
+ assert(unsafeInt.numElements == 3)
+ assert(unsafeInt.getSizeInBytes ==
+ ((8 + scala.math.ceil(3/64.toDouble) * 8 + 4 * 3 + 7).toInt / 8) * 8)
+ intArray.zipWithIndex.map { case (e, i) =>
+ assert(unsafeInt.getInt(i) == e)
+ }
+
+ val unsafeDouble = UnsafeArrayData.fromPrimitiveArray(doubleArray)
+ assert(unsafeDouble.numElements == 3)
+ assert(unsafeDouble.getSizeInBytes ==
+ ((8 + scala.math.ceil(3/64.toDouble) * 8 + 8 * 3 + 7).toInt / 8) * 8)
+ doubleArray.zipWithIndex.map { case (e, i) =>
+ assert(unsafeDouble.getDouble(i) == e)
+ }
+ }
+
+ test("to primitive array") {
+ val intEncoder = ExpressionEncoder[Array[Int]].resolveAndBind()
+ assert(intEncoder.toRow(intArray).getArray(0).toIntArray.sameElements(intArray))
+
+ val doubleEncoder = ExpressionEncoder[Array[Double]].resolveAndBind()
+ assert(doubleEncoder.toRow(doubleArray).getArray(0).toDoubleArray.sameElements(doubleArray))
}
}