Diffstat (limited to 'sql')
-rw-r--r--  sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeArrayData.java                    | 43
-rw-r--r--  sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeMapData.java                      | 88
-rw-r--r--  sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeReaders.java                      | 54
-rw-r--r--  sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeRow.java                          |  8
-rw-r--r--  sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/codegen/UnsafeArrayWriter.java          | 24
-rw-r--r--  sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/codegen/UnsafeRowWriter.java            | 15
-rw-r--r--  sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeProjection.scala | 60
-rw-r--r--  sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/UnsafeRowConverterSuite.scala          | 42
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/columnar/ColumnType.scala                                       | 30
-rw-r--r--  sql/core/src/test/scala/org/apache/spark/sql/columnar/ColumnTypeSuite.scala                                  |  2
10 files changed, 174 insertions(+), 192 deletions(-)
diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeArrayData.java b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeArrayData.java
index 4c63abb071..761f044794 100644
--- a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeArrayData.java
+++ b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeArrayData.java
@@ -30,19 +30,18 @@ import org.apache.spark.unsafe.types.UTF8String;
/**
* An Unsafe implementation of Array which is backed by raw memory instead of Java objects.
*
- * Each tuple has two parts: [offsets] [values]
+ * Each tuple has three parts: [numElements] [offsets] [values]
*
- * In the `offsets` region, we store 4 bytes per element, represents the start address of this
- * element in `values` region. We can get the length of this element by subtracting next offset.
+ * The `numElements` part is 4 bytes, storing the number of elements in this array.
+ *
+ * In the `offsets` region, we store 4 bytes per element, representing the relative offset (w.r.t.
+ * the base address of the array) of that element in the `values` region. We can get the length of
+ * an element by subtracting its offset from the next offset.
 * Note that an offset can be negative, which means the element is null.
*
 * In the `values` region, we store the contents of the elements. Since the length of each element
 * is known, elements can be variable-length.
*
- * Note that when we write out this array, we should write out the `numElements` at first 4 bytes,
- * then follows content. When we read in an array, we should read first 4 bytes as `numElements`
- * and take the rest as content.
- *
* Instances of `UnsafeArrayData` act as pointers to row data stored in this format.
*/
// todo: there is a lot of duplicated code between UnsafeRow and UnsafeArrayData.
@@ -54,11 +53,16 @@ public class UnsafeArrayData extends ArrayData {
// The number of elements in this array
private int numElements;
- // The size of this array's backing data, in bytes
+ // The size of this array's backing data, in bytes.
+ // The 4-byte `numElements` header is also included.
private int sizeInBytes;
+ public Object getBaseObject() { return baseObject; }
+ public long getBaseOffset() { return baseOffset; }
+ public int getSizeInBytes() { return sizeInBytes; }
+
private int getElementOffset(int ordinal) {
- return Platform.getInt(baseObject, baseOffset + ordinal * 4L);
+ return Platform.getInt(baseObject, baseOffset + 4 + ordinal * 4L);
}
private int getElementSize(int offset, int ordinal) {
@@ -85,10 +89,6 @@ public class UnsafeArrayData extends ArrayData {
*/
public UnsafeArrayData() { }
- public Object getBaseObject() { return baseObject; }
- public long getBaseOffset() { return baseOffset; }
- public int getSizeInBytes() { return sizeInBytes; }
-
@Override
public int numElements() { return numElements; }
@@ -97,10 +97,13 @@ public class UnsafeArrayData extends ArrayData {
*
* @param baseObject the base object
* @param baseOffset the offset within the base object
- * @param sizeInBytes the size of this row's backing data, in bytes
+ * @param sizeInBytes the size of this array's backing data, in bytes
*/
- public void pointTo(Object baseObject, long baseOffset, int numElements, int sizeInBytes) {
+ public void pointTo(Object baseObject, long baseOffset, int sizeInBytes) {
+ // Read the number of elements from the first 4 bytes.
+ final int numElements = Platform.getInt(baseObject, baseOffset);
assert numElements >= 0 : "numElements (" + numElements + ") should >= 0";
+
this.numElements = numElements;
this.baseObject = baseObject;
this.baseOffset = baseOffset;
@@ -277,7 +280,9 @@ public class UnsafeArrayData extends ArrayData {
final int offset = getElementOffset(ordinal);
if (offset < 0) return null;
final int size = getElementSize(offset, ordinal);
- return UnsafeReaders.readArray(baseObject, baseOffset + offset, size);
+ final UnsafeArrayData array = new UnsafeArrayData();
+ array.pointTo(baseObject, baseOffset + offset, size);
+ return array;
}
@Override
@@ -286,7 +291,9 @@ public class UnsafeArrayData extends ArrayData {
final int offset = getElementOffset(ordinal);
if (offset < 0) return null;
final int size = getElementSize(offset, ordinal);
- return UnsafeReaders.readMap(baseObject, baseOffset + offset, size);
+ final UnsafeMapData map = new UnsafeMapData();
+ map.pointTo(baseObject, baseOffset + offset, size);
+ return map;
}
@Override
@@ -328,7 +335,7 @@ public class UnsafeArrayData extends ArrayData {
final byte[] arrayDataCopy = new byte[sizeInBytes];
Platform.copyMemory(
baseObject, baseOffset, arrayDataCopy, Platform.BYTE_ARRAY_OFFSET, sizeInBytes);
- arrayCopy.pointTo(arrayDataCopy, Platform.BYTE_ARRAY_OFFSET, numElements, sizeInBytes);
+ arrayCopy.pointTo(arrayDataCopy, Platform.BYTE_ARRAY_OFFSET, sizeInBytes);
return arrayCopy;
}
}
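
For illustration, here is a minimal hand-built instance of the new self-describing layout (a sketch following the doc comment above, not code from the patch; `Platform` is org.apache.spark.unsafe.Platform):

    // Layout of the int array [7, 8]: [numElements=2][offset0=12][offset1=16][7][8]
    byte[] buf = new byte[20];
    long base = Platform.BYTE_ARRAY_OFFSET;
    Platform.putInt(buf, base,      2);   // numElements header
    Platform.putInt(buf, base + 4,  12);  // element 0 starts 12 bytes from the array base
    Platform.putInt(buf, base + 8,  16);  // element 1 starts 16 bytes from the array base
    Platform.putInt(buf, base + 12, 7);   // values region
    Platform.putInt(buf, base + 16, 8);

    UnsafeArrayData array = new UnsafeArrayData();
    array.pointTo(buf, base, 20);         // numElements is now read from the first 4 bytes
    assert array.numElements() == 2;
    assert array.getInt(1) == 8;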
diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeMapData.java b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeMapData.java
index e9dab9edb6..5bebe2a96e 100644
--- a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeMapData.java
+++ b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeMapData.java
@@ -17,41 +17,73 @@
package org.apache.spark.sql.catalyst.expressions;
+import java.nio.ByteBuffer;
+
import org.apache.spark.sql.types.MapData;
+import org.apache.spark.unsafe.Platform;
/**
* An Unsafe implementation of Map which is backed by raw memory instead of Java objects.
*
- * Currently we just use 2 UnsafeArrayData to represent UnsafeMapData.
- *
- * Note that when we write out this map, we should write out the `numElements` at first 4 bytes,
- * and numBytes of key array at second 4 bytes, then follows key array content and value array
- * content without `numElements` header.
- * When we read in a map, we should read first 4 bytes as `numElements` and second 4 bytes as
- * numBytes of key array, and construct unsafe key array and value array with these 2 information.
+ * Currently we just use 2 UnsafeArrayData to represent UnsafeMapData, with an extra 4 bytes at the
+ * head to indicate the number of bytes in the unsafe key array.
+ * [unsafe key array numBytes] [unsafe key array] [unsafe value array]
*/
+// TODO: Use a more efficient format which doesn't depend on unsafe array.
public class UnsafeMapData extends MapData {
- private final UnsafeArrayData keys;
- private final UnsafeArrayData values;
- // The number of elements in this array
- private int numElements;
- // The size of this array's backing data, in bytes
+ private Object baseObject;
+ private long baseOffset;
+
+ // The size of this map's backing data, in bytes.
+ // The 4-byte key array `numBytes` header is also included, so it's equal to
+ // 4 + key array numBytes + value array numBytes.
private int sizeInBytes;
+ public Object getBaseObject() { return baseObject; }
+ public long getBaseOffset() { return baseOffset; }
public int getSizeInBytes() { return sizeInBytes; }
- public UnsafeMapData(UnsafeArrayData keys, UnsafeArrayData values) {
+ private final UnsafeArrayData keys;
+ private final UnsafeArrayData values;
+
+ /**
+ * Construct a new UnsafeMapData. The resulting UnsafeMapData won't be usable until
+ * `pointTo()` has been called, since the value returned by this constructor is equivalent
+ * to a null pointer.
+ */
+ public UnsafeMapData() {
+ keys = new UnsafeArrayData();
+ values = new UnsafeArrayData();
+ }
+
+ /**
+ * Update this UnsafeMapData to point to different backing data.
+ *
+ * @param baseObject the base object
+ * @param baseOffset the offset within the base object
+ * @param sizeInBytes the size of this map's backing data, in bytes
+ */
+ public void pointTo(Object baseObject, long baseOffset, int sizeInBytes) {
+ // Read the key array's numBytes from the first 4 bytes.
+ final int keyArraySize = Platform.getInt(baseObject, baseOffset);
+ final int valueArraySize = sizeInBytes - keyArraySize - 4;
+ assert keyArraySize >= 0 : "keyArraySize (" + keyArraySize + ") should >= 0";
+ assert valueArraySize >= 0 : "valueArraySize (" + valueArraySize + ") should >= 0";
+
+ keys.pointTo(baseObject, baseOffset + 4, keyArraySize);
+ values.pointTo(baseObject, baseOffset + 4 + keyArraySize, valueArraySize);
+
assert keys.numElements() == values.numElements();
- this.sizeInBytes = keys.getSizeInBytes() + values.getSizeInBytes();
- this.numElements = keys.numElements();
- this.keys = keys;
- this.values = values;
+
+ this.baseObject = baseObject;
+ this.baseOffset = baseOffset;
+ this.sizeInBytes = sizeInBytes;
}
@Override
public int numElements() {
- return numElements;
+ return keys.numElements();
}
@Override
@@ -64,8 +96,26 @@ public class UnsafeMapData extends MapData {
return values;
}
+ public void writeToMemory(Object target, long targetOffset) {
+ Platform.copyMemory(baseObject, baseOffset, target, targetOffset, sizeInBytes);
+ }
+
+ public void writeTo(ByteBuffer buffer) {
+ assert(buffer.hasArray());
+ byte[] target = buffer.array();
+ int offset = buffer.arrayOffset();
+ int pos = buffer.position();
+ writeToMemory(target, Platform.BYTE_ARRAY_OFFSET + offset + pos);
+ buffer.position(pos + sizeInBytes);
+ }
+
@Override
public UnsafeMapData copy() {
- return new UnsafeMapData(keys.copy(), values.copy());
+ UnsafeMapData mapCopy = new UnsafeMapData();
+ final byte[] mapDataCopy = new byte[sizeInBytes];
+ Platform.copyMemory(
+ baseObject, baseOffset, mapDataCopy, Platform.BYTE_ARRAY_OFFSET, sizeInBytes);
+ mapCopy.pointTo(mapDataCopy, Platform.BYTE_ARRAY_OFFSET, sizeInBytes);
+ return mapCopy;
}
}
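
A matching sketch for the map layout, hand-built for the single-entry map {1 -> 2} (illustration only; the two nested arrays use the UnsafeArrayData format shown earlier):

    // [keyArrayNumBytes=12][key array: [1][8][1]][value array: [1][8][2]]
    byte[] buf = new byte[28];
    long base = Platform.BYTE_ARRAY_OFFSET;
    Platform.putInt(buf, base,      12);  // numBytes of the unsafe key array
    Platform.putInt(buf, base + 4,  1);   // key array: numElements
    Platform.putInt(buf, base + 8,  8);   //            offset of key 0 from the key array base
    Platform.putInt(buf, base + 12, 1);   //            key 0
    Platform.putInt(buf, base + 16, 1);   // value array: numElements
    Platform.putInt(buf, base + 20, 8);   //              offset of value 0 from the value array base
    Platform.putInt(buf, base + 24, 2);   //              value 0

    UnsafeMapData map = new UnsafeMapData();
    map.pointTo(buf, base, 28);
    assert map.numElements() == 1;
    assert map.valueArray().getInt(0) == 2;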
diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeReaders.java b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeReaders.java
deleted file mode 100644
index 6c5fcbca63..0000000000
--- a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeReaders.java
+++ /dev/null
@@ -1,54 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one or more
- * contributor license agreements. See the NOTICE file distributed with
- * this work for additional information regarding copyright ownership.
- * The ASF licenses this file to You under the Apache License, Version 2.0
- * (the "License"); you may not use this file except in compliance with
- * the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-package org.apache.spark.sql.catalyst.expressions;
-
-import org.apache.spark.unsafe.Platform;
-
-public class UnsafeReaders {
-
- /**
- * Reads in unsafe array according to the format described in `UnsafeArrayData`.
- */
- public static UnsafeArrayData readArray(Object baseObject, long baseOffset, int numBytes) {
- // Read the number of elements from first 4 bytes.
- final int numElements = Platform.getInt(baseObject, baseOffset);
- final UnsafeArrayData array = new UnsafeArrayData();
- // Skip the first 4 bytes.
- array.pointTo(baseObject, baseOffset + 4, numElements, numBytes - 4);
- return array;
- }
-
- /**
- * Reads in unsafe map according to the format described in `UnsafeMapData`.
- */
- public static UnsafeMapData readMap(Object baseObject, long baseOffset, int numBytes) {
- // Read the number of elements from first 4 bytes.
- final int numElements = Platform.getInt(baseObject, baseOffset);
- // Read the numBytes of key array in second 4 bytes.
- final int keyArraySize = Platform.getInt(baseObject, baseOffset + 4);
- final int valueArraySize = numBytes - 8 - keyArraySize;
-
- final UnsafeArrayData keyArray = new UnsafeArrayData();
- keyArray.pointTo(baseObject, baseOffset + 8, numElements, keyArraySize);
-
- final UnsafeArrayData valueArray = new UnsafeArrayData();
- valueArray.pointTo(baseObject, baseOffset + 8 + keyArraySize, numElements, valueArraySize);
-
- return new UnsafeMapData(keyArray, valueArray);
- }
-}
diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeRow.java b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeRow.java
index 36859fbab9..366615f6fe 100644
--- a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeRow.java
+++ b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeRow.java
@@ -461,7 +461,9 @@ public final class UnsafeRow extends MutableRow implements Externalizable, KryoS
final long offsetAndSize = getLong(ordinal);
final int offset = (int) (offsetAndSize >> 32);
final int size = (int) (offsetAndSize & ((1L << 32) - 1));
- return UnsafeReaders.readArray(baseObject, baseOffset + offset, size);
+ final UnsafeArrayData array = new UnsafeArrayData();
+ array.pointTo(baseObject, baseOffset + offset, size);
+ return array;
}
}
@@ -473,7 +475,9 @@ public final class UnsafeRow extends MutableRow implements Externalizable, KryoS
final long offsetAndSize = getLong(ordinal);
final int offset = (int) (offsetAndSize >> 32);
final int size = (int) (offsetAndSize & ((1L << 32) - 1));
- return UnsafeReaders.readMap(baseObject, baseOffset + offset, size);
+ final UnsafeMapData map = new UnsafeMapData();
+ map.pointTo(baseObject, baseOffset + offset, size);
+ return map;
}
}
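
The `offsetAndSize` long used in both hunks packs the field's offset into the high 32 bits and its size into the low 32 bits; a small sketch with made-up numbers:

    // Hypothetical field stored 64 bytes into the row, 20 bytes long:
    long offsetAndSize = (64L << 32) | 20L;
    int offset = (int) (offsetAndSize >> 32);              // 64
    int size   = (int) (offsetAndSize & ((1L << 32) - 1)); // 20
    // The array/map is then pointed directly at the row's backing data:
    //   array.pointTo(baseObject, baseOffset + offset, size);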
diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/codegen/UnsafeArrayWriter.java b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/codegen/UnsafeArrayWriter.java
index 138178ce99..7f2a1cb07a 100644
--- a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/codegen/UnsafeArrayWriter.java
+++ b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/codegen/UnsafeArrayWriter.java
@@ -30,17 +30,19 @@ import org.apache.spark.unsafe.types.UTF8String;
public class UnsafeArrayWriter {
private BufferHolder holder;
+
// The offset of the global buffer where we start to write this array.
private int startingOffset;
public void initialize(BufferHolder holder, int numElements, int fixedElementSize) {
- // We need 4 bytes each element to store offset.
- final int fixedSize = 4 * numElements;
+ // We need 4 bytes to store numElements, plus 4 bytes per element to store its offset.
+ final int fixedSize = 4 + 4 * numElements;
this.holder = holder;
this.startingOffset = holder.cursor;
holder.grow(fixedSize);
+ Platform.putInt(holder.buffer, holder.cursor, numElements);
holder.cursor += fixedSize;
// Grows the global buffer ahead for fixed size data.
@@ -48,7 +50,7 @@ public class UnsafeArrayWriter {
}
private long getElementOffset(int ordinal) {
- return startingOffset + 4 * ordinal;
+ return startingOffset + 4 + 4 * ordinal;
}
public void setNullAt(int ordinal) {
@@ -132,20 +134,4 @@ public class UnsafeArrayWriter {
// move the cursor forward.
holder.cursor += 16;
}
-
-
-
- // If this array is already an UnsafeArray, we don't need to go through all elements, we can
- // directly write it.
- public static void directWrite(BufferHolder holder, UnsafeArrayData input) {
- final int numBytes = input.getSizeInBytes();
-
- // grow the global buffer before writing data.
- holder.grow(numBytes);
-
- // Writes the array content to the variable length portion.
- input.writeToMemory(holder.buffer, holder.cursor);
-
- holder.cursor += numBytes;
- }
}
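
The writer's arithmetic after this change, spelled out for a hypothetical 3-element array of ints (a sketch of the offsets, not the actual writer):

    int numElements = 3;
    int fixedSize = 4 + 4 * numElements;   // 16: numElements header + one offset slot per element
    int startingOffset = 0;                // where this array starts in the global buffer
    for (int ordinal = 0; ordinal < numElements; ordinal++) {
      long slot = startingOffset + 4 + 4 * ordinal;  // matches getElementOffset above
      System.out.println("offset slot for element " + ordinal + " at byte " + slot);
    }
    // fixed-size elements are then written starting at startingOffset + fixedSize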
diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/codegen/UnsafeRowWriter.java b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/codegen/UnsafeRowWriter.java
index 8b7debd440..e1f5a05d1d 100644
--- a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/codegen/UnsafeRowWriter.java
+++ b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/codegen/UnsafeRowWriter.java
@@ -181,19 +181,4 @@ public class UnsafeRowWriter {
// move the cursor forward.
holder.cursor += 16;
}
-
-
-
- // If this struct is already an UnsafeRow, we don't need to go through all fields, we can
- // directly write it.
- public static void directWrite(BufferHolder holder, UnsafeRow input) {
- // No need to zero-out the bytes as UnsafeRow is word aligned for sure.
- final int numBytes = input.getSizeInBytes();
- // grow the global buffer before writing data.
- holder.grow(numBytes);
- // Write the bytes to the variable length portion.
- input.writeToMemory(holder.buffer, holder.cursor);
- // move the cursor forward.
- holder.cursor += numBytes;
- }
}
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeProjection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeProjection.scala
index 1b957a508d..dbe92d6a83 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeProjection.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeProjection.scala
@@ -62,7 +62,7 @@ object GenerateUnsafeProjection extends CodeGenerator[Seq[Expression], UnsafePro
s"""
if ($input instanceof UnsafeRow) {
- $rowWriterClass.directWrite($bufferHolder, (UnsafeRow) $input);
+ ${writeUnsafeData(ctx, s"((UnsafeRow) $input)", bufferHolder)}
} else {
${writeExpressionsToBuffer(ctx, input, fieldEvals, fieldTypes, bufferHolder)}
}
@@ -164,8 +164,7 @@ object GenerateUnsafeProjection extends CodeGenerator[Seq[Expression], UnsafePro
ctx: CodeGenContext,
input: String,
elementType: DataType,
- bufferHolder: String,
- needHeader: Boolean = true): String = {
+ bufferHolder: String): String = {
val arrayWriter = ctx.freshName("arrayWriter")
ctx.addMutableState(arrayWriterClass, arrayWriter,
s"this.$arrayWriter = new $arrayWriterClass();")
@@ -227,21 +226,11 @@ object GenerateUnsafeProjection extends CodeGenerator[Seq[Expression], UnsafePro
case _ => s"$arrayWriter.write($index, $element);"
}
- val writeHeader = if (needHeader) {
- // If header is required, we need to write the number of elements into first 4 bytes.
- s"""
- $bufferHolder.grow(4);
- Platform.putInt($bufferHolder.buffer, $bufferHolder.cursor, $numElements);
- $bufferHolder.cursor += 4;
- """
- } else ""
-
s"""
- final int $numElements = $input.numElements();
- $writeHeader
if ($input instanceof UnsafeArrayData) {
- $arrayWriterClass.directWrite($bufferHolder, (UnsafeArrayData) $input);
+ ${writeUnsafeData(ctx, s"((UnsafeArrayData) $input)", bufferHolder)}
} else {
+ final int $numElements = $input.numElements();
$arrayWriter.initialize($bufferHolder, $numElements, $fixedElementSize);
for (int $index = 0; $index < $numElements; $index++) {
@@ -270,23 +259,40 @@ object GenerateUnsafeProjection extends CodeGenerator[Seq[Expression], UnsafePro
// Writes out unsafe map according to the format described in `UnsafeMapData`.
s"""
- final ArrayData $keys = $input.keyArray();
- final ArrayData $values = $input.valueArray();
+ if ($input instanceof UnsafeMapData) {
+ ${writeUnsafeData(ctx, s"((UnsafeMapData) $input)", bufferHolder)}
+ } else {
+ final ArrayData $keys = $input.keyArray();
+ final ArrayData $values = $input.valueArray();
- $bufferHolder.grow(8);
+ // Reserve 4 bytes to write the key array numBytes later.
+ $bufferHolder.grow(4);
+ $bufferHolder.cursor += 4;
- // Write the numElements into first 4 bytes.
- Platform.putInt($bufferHolder.buffer, $bufferHolder.cursor, $keys.numElements());
+ // Remember the current cursor so that we can write numBytes of key array later.
+ final int $tmpCursor = $bufferHolder.cursor;
- $bufferHolder.cursor += 8;
- // Remember the current cursor so that we can write numBytes of key array later.
- final int $tmpCursor = $bufferHolder.cursor;
+ ${writeArrayToBuffer(ctx, keys, keyType, bufferHolder)}
+ // Write the numBytes of key array into the first 4 bytes.
+ Platform.putInt($bufferHolder.buffer, $tmpCursor - 4, $bufferHolder.cursor - $tmpCursor);
- ${writeArrayToBuffer(ctx, keys, keyType, bufferHolder, needHeader = false)}
- // Write the numBytes of key array into second 4 bytes.
- Platform.putInt($bufferHolder.buffer, $tmpCursor - 4, $bufferHolder.cursor - $tmpCursor);
+ ${writeArrayToBuffer(ctx, values, valueType, bufferHolder)}
+ }
+ """
+ }
- ${writeArrayToBuffer(ctx, values, valueType, bufferHolder, needHeader = false)}
+ /**
+ * If the input is already in unsafe format, we don't need to go through all elements/fields;
+ * we can write it out directly.
+ */
+ private def writeUnsafeData(ctx: CodeGenContext, input: String, bufferHolder: String) = {
+ val sizeInBytes = ctx.freshName("sizeInBytes")
+ s"""
+ final int $sizeInBytes = $input.getSizeInBytes();
+ // grow the global buffer before writing data.
+ $bufferHolder.grow($sizeInBytes);
+ $input.writeToMemory($bufferHolder.buffer, $bufferHolder.cursor);
+ $bufferHolder.cursor += $sizeInBytes;
"""
}
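
Roughly, the Java emitted by `writeUnsafeData` for an array input looks like the following (illustrative shape of the generated code, with `input` and `holder` standing in for the fresh names codegen picks):

    if (input instanceof UnsafeArrayData) {
      final int sizeInBytes = ((UnsafeArrayData) input).getSizeInBytes();
      holder.grow(sizeInBytes);             // grow the global buffer before writing data
      ((UnsafeArrayData) input).writeToMemory(holder.buffer, holder.cursor);
      holder.cursor += sizeInBytes;
    } else {
      // fall back to the per-element loop driven by UnsafeArrayWriter
    }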
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/UnsafeRowConverterSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/UnsafeRowConverterSuite.scala
index c991cd86d2..c6aad34e97 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/UnsafeRowConverterSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/UnsafeRowConverterSuite.scala
@@ -296,13 +296,9 @@ class UnsafeRowConverterSuite extends SparkFunSuite with Matchers {
new ArrayBasedMapData(createArray(keys: _*), createArray(values: _*))
}
- private def arraySizeInRow(numBytes: Int): Int = roundedSize(4 + numBytes)
-
- private def mapSizeInRow(numBytes: Int): Int = roundedSize(8 + numBytes)
-
private def testArrayInt(array: UnsafeArrayData, values: Seq[Int]): Unit = {
assert(array.numElements == values.length)
- assert(array.getSizeInBytes == (4 + 4) * values.length)
+ assert(array.getSizeInBytes == 4 + (4 + 4) * values.length)
values.zipWithIndex.foreach {
case (value, index) => assert(array.getInt(index) == value)
}
@@ -315,7 +311,7 @@ class UnsafeRowConverterSuite extends SparkFunSuite with Matchers {
testArrayInt(map.keyArray, keys)
testArrayInt(map.valueArray, values)
- assert(map.getSizeInBytes == map.keyArray.getSizeInBytes + map.valueArray.getSizeInBytes)
+ assert(map.getSizeInBytes == 4 + map.keyArray.getSizeInBytes + map.valueArray.getSizeInBytes)
}
test("basic conversion with array type") {
@@ -341,10 +337,10 @@ class UnsafeRowConverterSuite extends SparkFunSuite with Matchers {
val nestedArray = unsafeArray2.getArray(0)
testArrayInt(nestedArray, Seq(3, 4))
- assert(unsafeArray2.getSizeInBytes == 4 + (4 + nestedArray.getSizeInBytes))
+ assert(unsafeArray2.getSizeInBytes == 4 + 4 + nestedArray.getSizeInBytes)
- val array1Size = arraySizeInRow(unsafeArray1.getSizeInBytes)
- val array2Size = arraySizeInRow(unsafeArray2.getSizeInBytes)
+ val array1Size = roundedSize(unsafeArray1.getSizeInBytes)
+ val array2Size = roundedSize(unsafeArray2.getSizeInBytes)
assert(unsafeRow.getSizeInBytes == 8 + 8 * 2 + array1Size + array2Size)
}
@@ -384,13 +380,13 @@ class UnsafeRowConverterSuite extends SparkFunSuite with Matchers {
val nestedMap = valueArray.getMap(0)
testMapInt(nestedMap, Seq(5, 6), Seq(7, 8))
- assert(valueArray.getSizeInBytes == 4 + (8 + nestedMap.getSizeInBytes))
+ assert(valueArray.getSizeInBytes == 4 + 4 + nestedMap.getSizeInBytes)
}
- assert(unsafeMap2.getSizeInBytes == keyArray.getSizeInBytes + valueArray.getSizeInBytes)
+ assert(unsafeMap2.getSizeInBytes == 4 + keyArray.getSizeInBytes + valueArray.getSizeInBytes)
- val map1Size = mapSizeInRow(unsafeMap1.getSizeInBytes)
- val map2Size = mapSizeInRow(unsafeMap2.getSizeInBytes)
+ val map1Size = roundedSize(unsafeMap1.getSizeInBytes)
+ val map2Size = roundedSize(unsafeMap2.getSizeInBytes)
assert(unsafeRow.getSizeInBytes == 8 + 8 * 2 + map1Size + map2Size)
}
@@ -414,7 +410,7 @@ class UnsafeRowConverterSuite extends SparkFunSuite with Matchers {
val innerArray = field1.getArray(0)
testArrayInt(innerArray, Seq(1))
- assert(field1.getSizeInBytes == 8 + 8 + arraySizeInRow(innerArray.getSizeInBytes))
+ assert(field1.getSizeInBytes == 8 + 8 + roundedSize(innerArray.getSizeInBytes))
val field2 = unsafeRow.getArray(1)
assert(field2.numElements == 1)
@@ -427,10 +423,10 @@ class UnsafeRowConverterSuite extends SparkFunSuite with Matchers {
assert(innerStruct.getLong(0) == 2L)
}
- assert(field2.getSizeInBytes == 4 + innerStruct.getSizeInBytes)
+ assert(field2.getSizeInBytes == 4 + 4 + innerStruct.getSizeInBytes)
assert(unsafeRow.getSizeInBytes ==
- 8 + 8 * 2 + field1.getSizeInBytes + arraySizeInRow(field2.getSizeInBytes))
+ 8 + 8 * 2 + field1.getSizeInBytes + roundedSize(field2.getSizeInBytes))
}
test("basic conversion with struct and map") {
@@ -453,7 +449,7 @@ class UnsafeRowConverterSuite extends SparkFunSuite with Matchers {
val innerMap = field1.getMap(0)
testMapInt(innerMap, Seq(1), Seq(2))
- assert(field1.getSizeInBytes == 8 + 8 + mapSizeInRow(innerMap.getSizeInBytes))
+ assert(field1.getSizeInBytes == 8 + 8 + roundedSize(innerMap.getSizeInBytes))
val field2 = unsafeRow.getMap(1)
@@ -470,13 +466,13 @@ class UnsafeRowConverterSuite extends SparkFunSuite with Matchers {
assert(innerStruct.getSizeInBytes == 8 + 8)
assert(innerStruct.getLong(0) == 4L)
- assert(valueArray.getSizeInBytes == 4 + innerStruct.getSizeInBytes)
+ assert(valueArray.getSizeInBytes == 4 + 4 + innerStruct.getSizeInBytes)
}
- assert(field2.getSizeInBytes == keyArray.getSizeInBytes + valueArray.getSizeInBytes)
+ assert(field2.getSizeInBytes == 4 + keyArray.getSizeInBytes + valueArray.getSizeInBytes)
assert(unsafeRow.getSizeInBytes ==
- 8 + 8 * 2 + field1.getSizeInBytes + mapSizeInRow(field2.getSizeInBytes))
+ 8 + 8 * 2 + field1.getSizeInBytes + roundedSize(field2.getSizeInBytes))
}
test("basic conversion with array and map") {
@@ -499,7 +495,7 @@ class UnsafeRowConverterSuite extends SparkFunSuite with Matchers {
val innerMap = field1.getMap(0)
testMapInt(innerMap, Seq(1), Seq(2))
- assert(field1.getSizeInBytes == 4 + (8 + innerMap.getSizeInBytes))
+ assert(field1.getSizeInBytes == 4 + 4 + innerMap.getSizeInBytes)
val field2 = unsafeRow.getMap(1)
assert(field2.numElements == 1)
@@ -518,9 +514,9 @@ class UnsafeRowConverterSuite extends SparkFunSuite with Matchers {
assert(valueArray.getSizeInBytes == 4 + (4 + innerArray.getSizeInBytes))
}
- assert(field2.getSizeInBytes == keyArray.getSizeInBytes + valueArray.getSizeInBytes)
+ assert(field2.getSizeInBytes == 4 + keyArray.getSizeInBytes + valueArray.getSizeInBytes)
assert(unsafeRow.getSizeInBytes ==
- 8 + 8 * 2 + arraySizeInRow(field1.getSizeInBytes) + mapSizeInRow(field2.getSizeInBytes))
+ 8 + 8 * 2 + roundedSize(field1.getSizeInBytes) + roundedSize(field2.getSizeInBytes))
}
}
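
The pattern behind these test updates: every nested array/map now carries its own header, so each expected size gains a `4 +`. As a worked check (a sketch of the arithmetic, not test code):

    // int array of n elements: 4 (numElements) + 4 * n (offsets) + 4 * n (values)
    int n = 2;
    int arraySize = 4 + (4 + 4) * n;          // 20, as testArrayInt asserts
    // map of two such arrays: 4 (key array numBytes) + keyArraySize + valueArraySize
    int mapSize = 4 + arraySize + arraySize;  // 44, as testMapInt asserts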
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/columnar/ColumnType.scala b/sql/core/src/main/scala/org/apache/spark/sql/columnar/ColumnType.scala
index 2bc2c96b61..a41f04dd3b 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/columnar/ColumnType.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/columnar/ColumnType.scala
@@ -482,12 +482,14 @@ private[sql] case class STRUCT(dataType: StructType) extends ColumnType[UnsafeRo
override def extract(buffer: ByteBuffer): UnsafeRow = {
val sizeInBytes = buffer.getInt()
assert(buffer.hasArray)
- val base = buffer.array()
- val offset = buffer.arrayOffset()
val cursor = buffer.position()
buffer.position(cursor + sizeInBytes)
val unsafeRow = new UnsafeRow
- unsafeRow.pointTo(base, Platform.BYTE_ARRAY_OFFSET + offset + cursor, numOfFields, sizeInBytes)
+ unsafeRow.pointTo(
+ buffer.array(),
+ Platform.BYTE_ARRAY_OFFSET + buffer.arrayOffset() + cursor,
+ numOfFields,
+ sizeInBytes)
unsafeRow
}
@@ -508,12 +510,11 @@ private[sql] case class ARRAY(dataType: ArrayType) extends ColumnType[UnsafeArra
override def actualSize(row: InternalRow, ordinal: Int): Int = {
val unsafeArray = getField(row, ordinal)
- 4 + 4 + unsafeArray.getSizeInBytes
+ 4 + unsafeArray.getSizeInBytes
}
override def append(value: UnsafeArrayData, buffer: ByteBuffer): Unit = {
- buffer.putInt(4 + value.getSizeInBytes)
- buffer.putInt(value.numElements())
+ buffer.putInt(value.getSizeInBytes)
value.writeTo(buffer)
}
@@ -522,10 +523,12 @@ private[sql] case class ARRAY(dataType: ArrayType) extends ColumnType[UnsafeArra
assert(buffer.hasArray)
val cursor = buffer.position()
buffer.position(cursor + numBytes)
- UnsafeReaders.readArray(
+ val array = new UnsafeArrayData
+ array.pointTo(
buffer.array(),
Platform.BYTE_ARRAY_OFFSET + buffer.arrayOffset() + cursor,
numBytes)
+ array
}
override def clone(v: UnsafeArrayData): UnsafeArrayData = v.copy()
@@ -545,15 +548,12 @@ private[sql] case class MAP(dataType: MapType) extends ColumnType[UnsafeMapData]
override def actualSize(row: InternalRow, ordinal: Int): Int = {
val unsafeMap = getField(row, ordinal)
- 12 + unsafeMap.keyArray().getSizeInBytes + unsafeMap.valueArray().getSizeInBytes
+ 4 + unsafeMap.getSizeInBytes
}
override def append(value: UnsafeMapData, buffer: ByteBuffer): Unit = {
- buffer.putInt(8 + value.keyArray().getSizeInBytes + value.valueArray().getSizeInBytes)
- buffer.putInt(value.numElements())
- buffer.putInt(value.keyArray().getSizeInBytes)
- value.keyArray().writeTo(buffer)
- value.valueArray().writeTo(buffer)
+ buffer.putInt(value.getSizeInBytes)
+ value.writeTo(buffer)
}
override def extract(buffer: ByteBuffer): UnsafeMapData = {
@@ -561,10 +561,12 @@ private[sql] case class MAP(dataType: MapType) extends ColumnType[UnsafeMapData]
assert(buffer.hasArray)
val cursor = buffer.position()
buffer.position(cursor + numBytes)
- UnsafeReaders.readMap(
+ val map = new UnsafeMapData
+ map.pointTo(
buffer.array(),
Platform.BYTE_ARRAY_OFFSET + buffer.arrayOffset() + cursor,
numBytes)
+ map
}
override def clone(v: UnsafeMapData): UnsafeMapData = v.copy()
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/columnar/ColumnTypeSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/columnar/ColumnTypeSuite.scala
index 0e6e1bcf72..63bc39bfa0 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/columnar/ColumnTypeSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/columnar/ColumnTypeSuite.scala
@@ -73,7 +73,7 @@ class ColumnTypeSuite extends SparkFunSuite with Logging {
checkActualSize(COMPACT_DECIMAL(15, 10), Decimal(0, 15, 10), 8)
checkActualSize(LARGE_DECIMAL(20, 10), Decimal(0, 20, 10), 5)
checkActualSize(ARRAY_TYPE, Array[Any](1), 16)
- checkActualSize(MAP_TYPE, Map(1 -> "a"), 25)
+ checkActualSize(MAP_TYPE, Map(1 -> "a"), 29)
checkActualSize(STRUCT_TYPE, Row("hello"), 28)
}
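
For the record, the new expected size of 29 for Map(1 -> "a") decomposes as follows (assuming no per-element padding inside unsafe arrays, consistent with the byte counts asserted in UnsafeRowConverterSuite):

    int keyArray   = 4 /* numElements */ + 4 /* offset */ + 4 /* int key 1 */;   // 12
    int valueArray = 4 /* numElements */ + 4 /* offset */ + 1 /* "a" in UTF8 */; //  9
    int mapSize    = 4 /* key array numBytes */ + keyArray + valueArray;         // 25
    int actualSize = 4 /* size field written by MAP.append */ + mapSize;         // 29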