author     Wenchen Fan <wenchen@databricks.com>    2015-10-19 11:02:26 -0700
committer  Davies Liu <davies.liu@gmail.com>       2015-10-19 11:02:26 -0700
commit     7893cd95db5f2caba59ff5c859d7e4964ad7938d (patch)
tree       6bfb747406a4e138b0f6042c51b420940a378836 /sql
parent     4c33a34ba3167ae67fdb4978ea2166ce65638fb9 (diff)
[SPARK-11119] [SQL] cleanup for unsafe array and map
The purpose of this PR is to keep the unsafe format details only inside the unsafe classes themselves, so that when we use them (e.g., an unsafe array inside an unsafe map, or unsafe arrays and maps in the columnar cache) we don't need to understand the format before using them.

Change list:
* The unsafe array's 4-byte numElements header is now required (it was optional before) and becomes part of the unsafe array format.
* As a consequence of the previous change, the `sizeInBytes` of an unsafe array now includes the 4-byte header.
* The unsafe map's format used to be `[numElements] [key array numBytes] [key array content (without numElements header)] [value array content (without numElements header)]`, which was a little hacky because it made the unsafe array's header optional. Saving 4 bytes is not a big deal, so the format is now: `[key array numBytes] [unsafe key array] [unsafe value array]`.
* As a consequence of the previous change, the `sizeInBytes` of an unsafe map now includes both the map's header and the arrays' headers.

Author: Wenchen Fan <wenchen@databricks.com>

Closes #9131 from cloud-fan/unsafe.
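The two layouts are easiest to see in code. A minimal reading sketch, assuming only the formats stated above; the helper class and method names are illustrative, not part of this patch:

    import org.apache.spark.unsafe.Platform;

    // Array layout: [numElements (4 bytes)] [offsets (4 bytes per element)] [values]
    // Map layout:   [key array numBytes (4 bytes)] [unsafe key array] [unsafe value array]
    class UnsafeFormatSketch {
      // The array's element count is now always its first 4 bytes.
      static int arrayNumElements(Object base, long arrayOffset) {
        return Platform.getInt(base, arrayOffset);
      }
      // The key array starts right after the map's single 4-byte header.
      static long keyArrayStart(long mapOffset) {
        return mapOffset + 4;
      }
      // The value array starts after the header plus the key array's bytes.
      static long valueArrayStart(Object base, long mapOffset) {
        int keyArrayNumBytes = Platform.getInt(base, mapOffset);
        return mapOffset + 4 + keyArrayNumBytes;
      }
    }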
Diffstat (limited to 'sql')
-rw-r--r--  sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeArrayData.java                 | 43
-rw-r--r--  sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeMapData.java                   | 88
-rw-r--r--  sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeReaders.java                   | 54
-rw-r--r--  sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeRow.java                       |  8
-rw-r--r--  sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/codegen/UnsafeArrayWriter.java       | 24
-rw-r--r--  sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/codegen/UnsafeRowWriter.java         | 15
-rw-r--r--  sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeProjection.scala | 60
-rw-r--r--  sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/UnsafeRowConverterSuite.scala       | 42
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/columnar/ColumnType.scala                                    | 30
-rw-r--r--  sql/core/src/test/scala/org/apache/spark/sql/columnar/ColumnTypeSuite.scala                               |  2
10 files changed, 174 insertions(+), 192 deletions(-)
diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeArrayData.java b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeArrayData.java
index 4c63abb071..761f044794 100644
--- a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeArrayData.java
+++ b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeArrayData.java
@@ -30,19 +30,18 @@ import org.apache.spark.unsafe.types.UTF8String;
/**
* An Unsafe implementation of Array which is backed by raw memory instead of Java objects.
*
- * Each tuple has two parts: [offsets] [values]
+ * Each tuple has three parts: [numElements] [offsets] [values]
*
- * In the `offsets` region, we store 4 bytes per element, represents the start address of this
- * element in `values` region. We can get the length of this element by subtracting next offset.
+ * The `numElements` is 4 bytes storing the number of elements of this array.
+ *
+ * In the `offsets` region, we store 4 bytes per element, represents the relative offset (w.r.t. the
+ * base address of the array) of this element in `values` region. We can get the length of this
+ * element by subtracting next offset.
* Note that offset can by negative which means this element is null.
*
* In the `values` region, we store the content of elements. As we can get length info, so elements
* can be variable-length.
*
- * Note that when we write out this array, we should write out the `numElements` at first 4 bytes,
- * then follows content. When we read in an array, we should read first 4 bytes as `numElements`
- * and take the rest as content.
- *
* Instances of `UnsafeArrayData` act as pointers to row data stored in this format.
*/
// todo: there is a lof of duplicated code between UnsafeRow and UnsafeArrayData.
@@ -54,11 +53,16 @@ public class UnsafeArrayData extends ArrayData {
// The number of elements in this array
private int numElements;
- // The size of this array's backing data, in bytes
+ // The size of this array's backing data, in bytes.
+ // The 4-bytes header of `numElements` is also included.
private int sizeInBytes;
+ public Object getBaseObject() { return baseObject; }
+ public long getBaseOffset() { return baseOffset; }
+ public int getSizeInBytes() { return sizeInBytes; }
+
private int getElementOffset(int ordinal) {
- return Platform.getInt(baseObject, baseOffset + ordinal * 4L);
+ return Platform.getInt(baseObject, baseOffset + 4 + ordinal * 4L);
}
private int getElementSize(int offset, int ordinal) {
@@ -85,10 +89,6 @@ public class UnsafeArrayData extends ArrayData {
*/
public UnsafeArrayData() { }
- public Object getBaseObject() { return baseObject; }
- public long getBaseOffset() { return baseOffset; }
- public int getSizeInBytes() { return sizeInBytes; }
-
@Override
public int numElements() { return numElements; }
@@ -97,10 +97,13 @@ public class UnsafeArrayData extends ArrayData {
*
* @param baseObject the base object
* @param baseOffset the offset within the base object
- * @param sizeInBytes the size of this row's backing data, in bytes
+ * @param sizeInBytes the size of this array's backing data, in bytes
*/
- public void pointTo(Object baseObject, long baseOffset, int numElements, int sizeInBytes) {
+ public void pointTo(Object baseObject, long baseOffset, int sizeInBytes) {
+ // Read the number of elements from the first 4 bytes.
+ final int numElements = Platform.getInt(baseObject, baseOffset);
assert numElements >= 0 : "numElements (" + numElements + ") should >= 0";
+
this.numElements = numElements;
this.baseObject = baseObject;
this.baseOffset = baseOffset;
@@ -277,7 +280,9 @@ public class UnsafeArrayData extends ArrayData {
final int offset = getElementOffset(ordinal);
if (offset < 0) return null;
final int size = getElementSize(offset, ordinal);
- return UnsafeReaders.readArray(baseObject, baseOffset + offset, size);
+ final UnsafeArrayData array = new UnsafeArrayData();
+ array.pointTo(baseObject, baseOffset + offset, size);
+ return array;
}
@Override
@@ -286,7 +291,9 @@ public class UnsafeArrayData extends ArrayData {
final int offset = getElementOffset(ordinal);
if (offset < 0) return null;
final int size = getElementSize(offset, ordinal);
- return UnsafeReaders.readMap(baseObject, baseOffset + offset, size);
+ final UnsafeMapData map = new UnsafeMapData();
+ map.pointTo(baseObject, baseOffset + offset, size);
+ return map;
}
@Override
@@ -328,7 +335,7 @@ public class UnsafeArrayData extends ArrayData {
final byte[] arrayDataCopy = new byte[sizeInBytes];
Platform.copyMemory(
baseObject, baseOffset, arrayDataCopy, Platform.BYTE_ARRAY_OFFSET, sizeInBytes);
- arrayCopy.pointTo(arrayDataCopy, Platform.BYTE_ARRAY_OFFSET, numElements, sizeInBytes);
+ arrayCopy.pointTo(arrayDataCopy, Platform.BYTE_ARRAY_OFFSET, sizeInBytes);
return arrayCopy;
}
}
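With the header now mandatory, every offset lookup shifts past the first 4 bytes, as the patched getElementOffset() shows. A standalone sketch of the same address arithmetic; the helper class is hypothetical, not part of the patch:

    import org.apache.spark.unsafe.Platform;

    class ArrayOffsetSketch {
      // Locate element `ordinal` under the new header-inclusive layout:
      // skip the 4-byte numElements header, then read the ordinal-th 4-byte offset.
      // Per the Javadoc above, a negative result marks a null element.
      static int elementOffset(Object base, long arrayBaseOffset, int ordinal) {
        return Platform.getInt(base, arrayBaseOffset + 4 + ordinal * 4L);
      }
    }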
diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeMapData.java b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeMapData.java
index e9dab9edb6..5bebe2a96e 100644
--- a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeMapData.java
+++ b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeMapData.java
@@ -17,41 +17,73 @@
package org.apache.spark.sql.catalyst.expressions;
+import java.nio.ByteBuffer;
+
import org.apache.spark.sql.types.MapData;
+import org.apache.spark.unsafe.Platform;
/**
* An Unsafe implementation of Map which is backed by raw memory instead of Java objects.
*
- * Currently we just use 2 UnsafeArrayData to represent UnsafeMapData.
- *
- * Note that when we write out this map, we should write out the `numElements` at first 4 bytes,
- * and numBytes of key array at second 4 bytes, then follows key array content and value array
- * content without `numElements` header.
- * When we read in a map, we should read first 4 bytes as `numElements` and second 4 bytes as
- * numBytes of key array, and construct unsafe key array and value array with these 2 information.
+ * Currently we just use 2 UnsafeArrayData to represent UnsafeMapData, with extra 4 bytes at head
+ * to indicate the number of bytes of the unsafe key array.
+ * [unsafe key array numBytes] [unsafe key array] [unsafe value array]
*/
+// TODO: Use a more efficient format which doesn't depend on unsafe array.
public class UnsafeMapData extends MapData {
- private final UnsafeArrayData keys;
- private final UnsafeArrayData values;
- // The number of elements in this array
- private int numElements;
- // The size of this array's backing data, in bytes
+ private Object baseObject;
+ private long baseOffset;
+
+ // The size of this map's backing data, in bytes.
+ // The 4-bytes header of key array `numBytes` is also included, so it's actually equal to
+ // 4 + key array numBytes + value array numBytes.
private int sizeInBytes;
+ public Object getBaseObject() { return baseObject; }
+ public long getBaseOffset() { return baseOffset; }
public int getSizeInBytes() { return sizeInBytes; }
- public UnsafeMapData(UnsafeArrayData keys, UnsafeArrayData values) {
+ private final UnsafeArrayData keys;
+ private final UnsafeArrayData values;
+
+ /**
+ * Construct a new UnsafeMapData. The resulting UnsafeMapData won't be usable until
+ * `pointTo()` has been called, since the value returned by this constructor is equivalent
+ * to a null pointer.
+ */
+ public UnsafeMapData() {
+ keys = new UnsafeArrayData();
+ values = new UnsafeArrayData();
+ }
+
+ /**
+ * Update this UnsafeMapData to point to different backing data.
+ *
+ * @param baseObject the base object
+ * @param baseOffset the offset within the base object
+ * @param sizeInBytes the size of this map's backing data, in bytes
+ */
+ public void pointTo(Object baseObject, long baseOffset, int sizeInBytes) {
+ // Read the numBytes of key array from the first 4 bytes.
+ final int keyArraySize = Platform.getInt(baseObject, baseOffset);
+ final int valueArraySize = sizeInBytes - keyArraySize - 4;
+ assert keyArraySize >= 0 : "keyArraySize (" + keyArraySize + ") should >= 0";
+ assert valueArraySize >= 0 : "valueArraySize (" + valueArraySize + ") should >= 0";
+
+ keys.pointTo(baseObject, baseOffset + 4, keyArraySize);
+ values.pointTo(baseObject, baseOffset + 4 + keyArraySize, valueArraySize);
+
assert keys.numElements() == values.numElements();
- this.sizeInBytes = keys.getSizeInBytes() + values.getSizeInBytes();
- this.numElements = keys.numElements();
- this.keys = keys;
- this.values = values;
+
+ this.baseObject = baseObject;
+ this.baseOffset = baseOffset;
+ this.sizeInBytes = sizeInBytes;
}
@Override
public int numElements() {
- return numElements;
+ return keys.numElements();
}
@Override
@@ -64,8 +96,26 @@ public class UnsafeMapData extends MapData {
return values;
}
+ public void writeToMemory(Object target, long targetOffset) {
+ Platform.copyMemory(baseObject, baseOffset, target, targetOffset, sizeInBytes);
+ }
+
+ public void writeTo(ByteBuffer buffer) {
+ assert(buffer.hasArray());
+ byte[] target = buffer.array();
+ int offset = buffer.arrayOffset();
+ int pos = buffer.position();
+ writeToMemory(target, Platform.BYTE_ARRAY_OFFSET + offset + pos);
+ buffer.position(pos + sizeInBytes);
+ }
+
@Override
public UnsafeMapData copy() {
- return new UnsafeMapData(keys.copy(), values.copy());
+ UnsafeMapData mapCopy = new UnsafeMapData();
+ final byte[] mapDataCopy = new byte[sizeInBytes];
+ Platform.copyMemory(
+ baseObject, baseOffset, mapDataCopy, Platform.BYTE_ARRAY_OFFSET, sizeInBytes);
+ mapCopy.pointTo(mapDataCopy, Platform.BYTE_ARRAY_OFFSET, sizeInBytes);
+ return mapCopy;
}
}
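A usage sketch for the new pointer-style API, using only methods visible in this diff (writeTo, pointTo, getSizeInBytes); the round-trip helper itself is illustrative: serialize a map into a heap ByteBuffer, then point a fresh instance at the copied bytes.

    import java.nio.ByteBuffer;
    import org.apache.spark.sql.catalyst.expressions.UnsafeMapData;
    import org.apache.spark.unsafe.Platform;

    class MapRoundTripSketch {
      static UnsafeMapData roundTrip(UnsafeMapData map) {
        ByteBuffer buffer = ByteBuffer.allocate(map.getSizeInBytes());
        map.writeTo(buffer);                       // copies all sizeInBytes bytes
        UnsafeMapData copy = new UnsafeMapData();  // unusable until pointTo() is called
        copy.pointTo(buffer.array(),
            Platform.BYTE_ARRAY_OFFSET + buffer.arrayOffset(),
            map.getSizeInBytes());
        return copy;
      }
    }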
diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeReaders.java b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeReaders.java
deleted file mode 100644
index 6c5fcbca63..0000000000
--- a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeReaders.java
+++ /dev/null
@@ -1,54 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one or more
- * contributor license agreements. See the NOTICE file distributed with
- * this work for additional information regarding copyright ownership.
- * The ASF licenses this file to You under the Apache License, Version 2.0
- * (the "License"); you may not use this file except in compliance with
- * the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-package org.apache.spark.sql.catalyst.expressions;
-
-import org.apache.spark.unsafe.Platform;
-
-public class UnsafeReaders {
-
- /**
- * Reads in unsafe array according to the format described in `UnsafeArrayData`.
- */
- public static UnsafeArrayData readArray(Object baseObject, long baseOffset, int numBytes) {
- // Read the number of elements from first 4 bytes.
- final int numElements = Platform.getInt(baseObject, baseOffset);
- final UnsafeArrayData array = new UnsafeArrayData();
- // Skip the first 4 bytes.
- array.pointTo(baseObject, baseOffset + 4, numElements, numBytes - 4);
- return array;
- }
-
- /**
- * Reads in unsafe map according to the format described in `UnsafeMapData`.
- */
- public static UnsafeMapData readMap(Object baseObject, long baseOffset, int numBytes) {
- // Read the number of elements from first 4 bytes.
- final int numElements = Platform.getInt(baseObject, baseOffset);
- // Read the numBytes of key array in second 4 bytes.
- final int keyArraySize = Platform.getInt(baseObject, baseOffset + 4);
- final int valueArraySize = numBytes - 8 - keyArraySize;
-
- final UnsafeArrayData keyArray = new UnsafeArrayData();
- keyArray.pointTo(baseObject, baseOffset + 8, numElements, keyArraySize);
-
- final UnsafeArrayData valueArray = new UnsafeArrayData();
- valueArray.pointTo(baseObject, baseOffset + 8 + keyArraySize, numElements, valueArraySize);
-
- return new UnsafeMapData(keyArray, valueArray);
- }
-}
diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeRow.java b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeRow.java
index 36859fbab9..366615f6fe 100644
--- a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeRow.java
+++ b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeRow.java
@@ -461,7 +461,9 @@ public final class UnsafeRow extends MutableRow implements Externalizable, KryoS
final long offsetAndSize = getLong(ordinal);
final int offset = (int) (offsetAndSize >> 32);
final int size = (int) (offsetAndSize & ((1L << 32) - 1));
- return UnsafeReaders.readArray(baseObject, baseOffset + offset, size);
+ final UnsafeArrayData array = new UnsafeArrayData();
+ array.pointTo(baseObject, baseOffset + offset, size);
+ return array;
}
}
@@ -473,7 +475,9 @@ public final class UnsafeRow extends MutableRow implements Externalizable, KryoS
final long offsetAndSize = getLong(ordinal);
final int offset = (int) (offsetAndSize >> 32);
final int size = (int) (offsetAndSize & ((1L << 32) - 1));
- return UnsafeReaders.readMap(baseObject, baseOffset + offset, size);
+ final UnsafeMapData map = new UnsafeMapData();
+ map.pointTo(baseObject, baseOffset + offset, size);
+ return map;
}
}
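The offsetAndSize decoding above is UnsafeRow's packing of one 8-byte slot per variable-length field: offset in the high 32 bits, size in the low 32 bits. A sketch of both directions, with illustrative helper names:

    class OffsetAndSizeSketch {
      static long pack(int offset, int size) {
        return (((long) offset) << 32) | (size & 0xFFFFFFFFL);
      }
      static int offset(long offsetAndSize) {
        return (int) (offsetAndSize >> 32);
      }
      static int size(long offsetAndSize) {
        return (int) (offsetAndSize & ((1L << 32) - 1));
      }
    }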
diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/codegen/UnsafeArrayWriter.java b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/codegen/UnsafeArrayWriter.java
index 138178ce99..7f2a1cb07a 100644
--- a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/codegen/UnsafeArrayWriter.java
+++ b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/codegen/UnsafeArrayWriter.java
@@ -30,17 +30,19 @@ import org.apache.spark.unsafe.types.UTF8String;
public class UnsafeArrayWriter {
private BufferHolder holder;
+
// The offset of the global buffer where we start to write this array.
private int startingOffset;
public void initialize(BufferHolder holder, int numElements, int fixedElementSize) {
- // We need 4 bytes each element to store offset.
- final int fixedSize = 4 * numElements;
+ // We need 4 bytes to store numElements and 4 bytes each element to store offset.
+ final int fixedSize = 4 + 4 * numElements;
this.holder = holder;
this.startingOffset = holder.cursor;
holder.grow(fixedSize);
+ Platform.putInt(holder.buffer, holder.cursor, numElements);
holder.cursor += fixedSize;
// Grows the global buffer ahead for fixed size data.
@@ -48,7 +50,7 @@ public class UnsafeArrayWriter {
}
private long getElementOffset(int ordinal) {
- return startingOffset + 4 * ordinal;
+ return startingOffset + 4 + 4 * ordinal;
}
public void setNullAt(int ordinal) {
@@ -132,20 +134,4 @@ public class UnsafeArrayWriter {
// move the cursor forward.
holder.cursor += 16;
}
-
-
-
- // If this array is already an UnsafeArray, we don't need to go through all elements, we can
- // directly write it.
- public static void directWrite(BufferHolder holder, UnsafeArrayData input) {
- final int numBytes = input.getSizeInBytes();
-
- // grow the global buffer before writing data.
- holder.grow(numBytes);
-
- // Writes the array content to the variable length portion.
- input.writeToMemory(holder.buffer, holder.cursor);
-
- holder.cursor += numBytes;
- }
}
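A worked example of the new fixed-region size reserved by initialize(), with illustrative numbers:

    class ArrayWriterSizeExample {
      public static void main(String[] args) {
        int numElements = 3;                   // illustrative
        int fixedSize = 4 + 4 * numElements;   // numElements header + one 4-byte offset per element
        System.out.println(fixedSize);         // 16; variable-length values follow this region
      }
    }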
diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/codegen/UnsafeRowWriter.java b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/codegen/UnsafeRowWriter.java
index 8b7debd440..e1f5a05d1d 100644
--- a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/codegen/UnsafeRowWriter.java
+++ b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/codegen/UnsafeRowWriter.java
@@ -181,19 +181,4 @@ public class UnsafeRowWriter {
// move the cursor forward.
holder.cursor += 16;
}
-
-
-
- // If this struct is already an UnsafeRow, we don't need to go through all fields, we can
- // directly write it.
- public static void directWrite(BufferHolder holder, UnsafeRow input) {
- // No need to zero-out the bytes as UnsafeRow is word aligned for sure.
- final int numBytes = input.getSizeInBytes();
- // grow the global buffer before writing data.
- holder.grow(numBytes);
- // Write the bytes to the variable length portion.
- input.writeToMemory(holder.buffer, holder.cursor);
- // move the cursor forward.
- holder.cursor += numBytes;
- }
}
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeProjection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeProjection.scala
index 1b957a508d..dbe92d6a83 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeProjection.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeProjection.scala
@@ -62,7 +62,7 @@ object GenerateUnsafeProjection extends CodeGenerator[Seq[Expression], UnsafePro
s"""
if ($input instanceof UnsafeRow) {
- $rowWriterClass.directWrite($bufferHolder, (UnsafeRow) $input);
+ ${writeUnsafeData(ctx, s"((UnsafeRow) $input)", bufferHolder)}
} else {
${writeExpressionsToBuffer(ctx, input, fieldEvals, fieldTypes, bufferHolder)}
}
@@ -164,8 +164,7 @@ object GenerateUnsafeProjection extends CodeGenerator[Seq[Expression], UnsafePro
ctx: CodeGenContext,
input: String,
elementType: DataType,
- bufferHolder: String,
- needHeader: Boolean = true): String = {
+ bufferHolder: String): String = {
val arrayWriter = ctx.freshName("arrayWriter")
ctx.addMutableState(arrayWriterClass, arrayWriter,
s"this.$arrayWriter = new $arrayWriterClass();")
@@ -227,21 +226,11 @@ object GenerateUnsafeProjection extends CodeGenerator[Seq[Expression], UnsafePro
case _ => s"$arrayWriter.write($index, $element);"
}
- val writeHeader = if (needHeader) {
- // If header is required, we need to write the number of elements into first 4 bytes.
- s"""
- $bufferHolder.grow(4);
- Platform.putInt($bufferHolder.buffer, $bufferHolder.cursor, $numElements);
- $bufferHolder.cursor += 4;
- """
- } else ""
-
s"""
- final int $numElements = $input.numElements();
- $writeHeader
if ($input instanceof UnsafeArrayData) {
- $arrayWriterClass.directWrite($bufferHolder, (UnsafeArrayData) $input);
+ ${writeUnsafeData(ctx, s"((UnsafeArrayData) $input)", bufferHolder)}
} else {
+ final int $numElements = $input.numElements();
$arrayWriter.initialize($bufferHolder, $numElements, $fixedElementSize);
for (int $index = 0; $index < $numElements; $index++) {
@@ -270,23 +259,40 @@ object GenerateUnsafeProjection extends CodeGenerator[Seq[Expression], UnsafePro
// Writes out unsafe map according to the format described in `UnsafeMapData`.
s"""
- final ArrayData $keys = $input.keyArray();
- final ArrayData $values = $input.valueArray();
+ if ($input instanceof UnsafeMapData) {
+ ${writeUnsafeData(ctx, s"((UnsafeMapData) $input)", bufferHolder)}
+ } else {
+ final ArrayData $keys = $input.keyArray();
+ final ArrayData $values = $input.valueArray();
- $bufferHolder.grow(8);
+ // preserve 4 bytes to write the key array numBytes later.
+ $bufferHolder.grow(4);
+ $bufferHolder.cursor += 4;
- // Write the numElements into first 4 bytes.
- Platform.putInt($bufferHolder.buffer, $bufferHolder.cursor, $keys.numElements());
+ // Remember the current cursor so that we can write numBytes of key array later.
+ final int $tmpCursor = $bufferHolder.cursor;
- $bufferHolder.cursor += 8;
- // Remember the current cursor so that we can write numBytes of key array later.
- final int $tmpCursor = $bufferHolder.cursor;
+ ${writeArrayToBuffer(ctx, keys, keyType, bufferHolder)}
+ // Write the numBytes of key array into the first 4 bytes.
+ Platform.putInt($bufferHolder.buffer, $tmpCursor - 4, $bufferHolder.cursor - $tmpCursor);
- ${writeArrayToBuffer(ctx, keys, keyType, bufferHolder, needHeader = false)}
- // Write the numBytes of key array into second 4 bytes.
- Platform.putInt($bufferHolder.buffer, $tmpCursor - 4, $bufferHolder.cursor - $tmpCursor);
+ ${writeArrayToBuffer(ctx, values, valueType, bufferHolder)}
+ }
+ """
+ }
- ${writeArrayToBuffer(ctx, values, valueType, bufferHolder, needHeader = false)}
+ /**
+ * If the input is already in unsafe format, we don't need to go through all elements/fields,
+ * we can directly write it.
+ */
+ private def writeUnsafeData(ctx: CodeGenContext, input: String, bufferHolder: String) = {
+ val sizeInBytes = ctx.freshName("sizeInBytes")
+ s"""
+ final int $sizeInBytes = $input.getSizeInBytes();
+ // grow the global buffer before writing data.
+ $bufferHolder.grow($sizeInBytes);
+ $input.writeToMemory($bufferHolder.buffer, $bufferHolder.cursor);
+ $bufferHolder.cursor += $sizeInBytes;
"""
}
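For readability, a hand-written Java fragment sketching what this template generates for the non-unsafe map branch; writeKeyArray/writeValueArray are placeholders for the nested writeArrayToBuffer output:

    // Reserve 4 bytes for the key array's numBytes, write the key array,
    // then back-patch the reserved slot once its length is known.
    holder.grow(4);
    holder.cursor += 4;
    final int tmpCursor = holder.cursor;
    writeKeyArray(holder);    // placeholder: emits the unsafe key array
    Platform.putInt(holder.buffer, tmpCursor - 4, holder.cursor - tmpCursor);
    writeValueArray(holder);  // placeholder: emits the unsafe value array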
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/UnsafeRowConverterSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/UnsafeRowConverterSuite.scala
index c991cd86d2..c6aad34e97 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/UnsafeRowConverterSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/UnsafeRowConverterSuite.scala
@@ -296,13 +296,9 @@ class UnsafeRowConverterSuite extends SparkFunSuite with Matchers {
new ArrayBasedMapData(createArray(keys: _*), createArray(values: _*))
}
- private def arraySizeInRow(numBytes: Int): Int = roundedSize(4 + numBytes)
-
- private def mapSizeInRow(numBytes: Int): Int = roundedSize(8 + numBytes)
-
private def testArrayInt(array: UnsafeArrayData, values: Seq[Int]): Unit = {
assert(array.numElements == values.length)
- assert(array.getSizeInBytes == (4 + 4) * values.length)
+ assert(array.getSizeInBytes == 4 + (4 + 4) * values.length)
values.zipWithIndex.foreach {
case (value, index) => assert(array.getInt(index) == value)
}
@@ -315,7 +311,7 @@ class UnsafeRowConverterSuite extends SparkFunSuite with Matchers {
testArrayInt(map.keyArray, keys)
testArrayInt(map.valueArray, values)
- assert(map.getSizeInBytes == map.keyArray.getSizeInBytes + map.valueArray.getSizeInBytes)
+ assert(map.getSizeInBytes == 4 + map.keyArray.getSizeInBytes + map.valueArray.getSizeInBytes)
}
test("basic conversion with array type") {
@@ -341,10 +337,10 @@ class UnsafeRowConverterSuite extends SparkFunSuite with Matchers {
val nestedArray = unsafeArray2.getArray(0)
testArrayInt(nestedArray, Seq(3, 4))
- assert(unsafeArray2.getSizeInBytes == 4 + (4 + nestedArray.getSizeInBytes))
+ assert(unsafeArray2.getSizeInBytes == 4 + 4 + nestedArray.getSizeInBytes)
- val array1Size = arraySizeInRow(unsafeArray1.getSizeInBytes)
- val array2Size = arraySizeInRow(unsafeArray2.getSizeInBytes)
+ val array1Size = roundedSize(unsafeArray1.getSizeInBytes)
+ val array2Size = roundedSize(unsafeArray2.getSizeInBytes)
assert(unsafeRow.getSizeInBytes == 8 + 8 * 2 + array1Size + array2Size)
}
@@ -384,13 +380,13 @@ class UnsafeRowConverterSuite extends SparkFunSuite with Matchers {
val nestedMap = valueArray.getMap(0)
testMapInt(nestedMap, Seq(5, 6), Seq(7, 8))
- assert(valueArray.getSizeInBytes == 4 + (8 + nestedMap.getSizeInBytes))
+ assert(valueArray.getSizeInBytes == 4 + 4 + nestedMap.getSizeInBytes)
}
- assert(unsafeMap2.getSizeInBytes == keyArray.getSizeInBytes + valueArray.getSizeInBytes)
+ assert(unsafeMap2.getSizeInBytes == 4 + keyArray.getSizeInBytes + valueArray.getSizeInBytes)
- val map1Size = mapSizeInRow(unsafeMap1.getSizeInBytes)
- val map2Size = mapSizeInRow(unsafeMap2.getSizeInBytes)
+ val map1Size = roundedSize(unsafeMap1.getSizeInBytes)
+ val map2Size = roundedSize(unsafeMap2.getSizeInBytes)
assert(unsafeRow.getSizeInBytes == 8 + 8 * 2 + map1Size + map2Size)
}
@@ -414,7 +410,7 @@ class UnsafeRowConverterSuite extends SparkFunSuite with Matchers {
val innerArray = field1.getArray(0)
testArrayInt(innerArray, Seq(1))
- assert(field1.getSizeInBytes == 8 + 8 + arraySizeInRow(innerArray.getSizeInBytes))
+ assert(field1.getSizeInBytes == 8 + 8 + roundedSize(innerArray.getSizeInBytes))
val field2 = unsafeRow.getArray(1)
assert(field2.numElements == 1)
@@ -427,10 +423,10 @@ class UnsafeRowConverterSuite extends SparkFunSuite with Matchers {
assert(innerStruct.getLong(0) == 2L)
}
- assert(field2.getSizeInBytes == 4 + innerStruct.getSizeInBytes)
+ assert(field2.getSizeInBytes == 4 + 4 + innerStruct.getSizeInBytes)
assert(unsafeRow.getSizeInBytes ==
- 8 + 8 * 2 + field1.getSizeInBytes + arraySizeInRow(field2.getSizeInBytes))
+ 8 + 8 * 2 + field1.getSizeInBytes + roundedSize(field2.getSizeInBytes))
}
test("basic conversion with struct and map") {
@@ -453,7 +449,7 @@ class UnsafeRowConverterSuite extends SparkFunSuite with Matchers {
val innerMap = field1.getMap(0)
testMapInt(innerMap, Seq(1), Seq(2))
- assert(field1.getSizeInBytes == 8 + 8 + mapSizeInRow(innerMap.getSizeInBytes))
+ assert(field1.getSizeInBytes == 8 + 8 + roundedSize(innerMap.getSizeInBytes))
val field2 = unsafeRow.getMap(1)
@@ -470,13 +466,13 @@ class UnsafeRowConverterSuite extends SparkFunSuite with Matchers {
assert(innerStruct.getSizeInBytes == 8 + 8)
assert(innerStruct.getLong(0) == 4L)
- assert(valueArray.getSizeInBytes == 4 + innerStruct.getSizeInBytes)
+ assert(valueArray.getSizeInBytes == 4 + 4 + innerStruct.getSizeInBytes)
}
- assert(field2.getSizeInBytes == keyArray.getSizeInBytes + valueArray.getSizeInBytes)
+ assert(field2.getSizeInBytes == 4 + keyArray.getSizeInBytes + valueArray.getSizeInBytes)
assert(unsafeRow.getSizeInBytes ==
- 8 + 8 * 2 + field1.getSizeInBytes + mapSizeInRow(field2.getSizeInBytes))
+ 8 + 8 * 2 + field1.getSizeInBytes + roundedSize(field2.getSizeInBytes))
}
test("basic conversion with array and map") {
@@ -499,7 +495,7 @@ class UnsafeRowConverterSuite extends SparkFunSuite with Matchers {
val innerMap = field1.getMap(0)
testMapInt(innerMap, Seq(1), Seq(2))
- assert(field1.getSizeInBytes == 4 + (8 + innerMap.getSizeInBytes))
+ assert(field1.getSizeInBytes == 4 + 4 + innerMap.getSizeInBytes)
val field2 = unsafeRow.getMap(1)
assert(field2.numElements == 1)
@@ -518,9 +514,9 @@ class UnsafeRowConverterSuite extends SparkFunSuite with Matchers {
assert(valueArray.getSizeInBytes == 4 + (4 + innerArray.getSizeInBytes))
}
- assert(field2.getSizeInBytes == keyArray.getSizeInBytes + valueArray.getSizeInBytes)
+ assert(field2.getSizeInBytes == 4 + keyArray.getSizeInBytes + valueArray.getSizeInBytes)
assert(unsafeRow.getSizeInBytes ==
- 8 + 8 * 2 + arraySizeInRow(field1.getSizeInBytes) + mapSizeInRow(field2.getSizeInBytes))
+ 8 + 8 * 2 + roundedSize(field1.getSizeInBytes) + roundedSize(field2.getSizeInBytes))
}
}
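The adjusted assertions all reduce to simple arithmetic over the new headers. A worked example with illustrative values, following the formulas in testArrayInt and testMapInt above:

    class SizeArithmeticExample {
      public static void main(String[] args) {
        // Unsafe int array holding Seq(1, 2):
        // numElements header + (offset + value) per element.
        int intArraySize = 4 + (4 + 4) * 2;             // = 20
        // Map using such arrays for both keys and values:
        // key-array-numBytes header + key array + value array.
        int mapSize = 4 + intArraySize + intArraySize;  // = 44
        System.out.println(intArraySize + ", " + mapSize);
      }
    }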
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/columnar/ColumnType.scala b/sql/core/src/main/scala/org/apache/spark/sql/columnar/ColumnType.scala
index 2bc2c96b61..a41f04dd3b 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/columnar/ColumnType.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/columnar/ColumnType.scala
@@ -482,12 +482,14 @@ private[sql] case class STRUCT(dataType: StructType) extends ColumnType[UnsafeRo
override def extract(buffer: ByteBuffer): UnsafeRow = {
val sizeInBytes = buffer.getInt()
assert(buffer.hasArray)
- val base = buffer.array()
- val offset = buffer.arrayOffset()
val cursor = buffer.position()
buffer.position(cursor + sizeInBytes)
val unsafeRow = new UnsafeRow
- unsafeRow.pointTo(base, Platform.BYTE_ARRAY_OFFSET + offset + cursor, numOfFields, sizeInBytes)
+ unsafeRow.pointTo(
+ buffer.array(),
+ Platform.BYTE_ARRAY_OFFSET + buffer.arrayOffset() + cursor,
+ numOfFields,
+ sizeInBytes)
unsafeRow
}
@@ -508,12 +510,11 @@ private[sql] case class ARRAY(dataType: ArrayType) extends ColumnType[UnsafeArra
override def actualSize(row: InternalRow, ordinal: Int): Int = {
val unsafeArray = getField(row, ordinal)
- 4 + 4 + unsafeArray.getSizeInBytes
+ 4 + unsafeArray.getSizeInBytes
}
override def append(value: UnsafeArrayData, buffer: ByteBuffer): Unit = {
- buffer.putInt(4 + value.getSizeInBytes)
- buffer.putInt(value.numElements())
+ buffer.putInt(value.getSizeInBytes)
value.writeTo(buffer)
}
@@ -522,10 +523,12 @@ private[sql] case class ARRAY(dataType: ArrayType) extends ColumnType[UnsafeArra
assert(buffer.hasArray)
val cursor = buffer.position()
buffer.position(cursor + numBytes)
- UnsafeReaders.readArray(
+ val array = new UnsafeArrayData
+ array.pointTo(
buffer.array(),
Platform.BYTE_ARRAY_OFFSET + buffer.arrayOffset() + cursor,
numBytes)
+ array
}
override def clone(v: UnsafeArrayData): UnsafeArrayData = v.copy()
@@ -545,15 +548,12 @@ private[sql] case class MAP(dataType: MapType) extends ColumnType[UnsafeMapData]
override def actualSize(row: InternalRow, ordinal: Int): Int = {
val unsafeMap = getField(row, ordinal)
- 12 + unsafeMap.keyArray().getSizeInBytes + unsafeMap.valueArray().getSizeInBytes
+ 4 + unsafeMap.getSizeInBytes
}
override def append(value: UnsafeMapData, buffer: ByteBuffer): Unit = {
- buffer.putInt(8 + value.keyArray().getSizeInBytes + value.valueArray().getSizeInBytes)
- buffer.putInt(value.numElements())
- buffer.putInt(value.keyArray().getSizeInBytes)
- value.keyArray().writeTo(buffer)
- value.valueArray().writeTo(buffer)
+ buffer.putInt(value.getSizeInBytes)
+ value.writeTo(buffer)
}
override def extract(buffer: ByteBuffer): UnsafeMapData = {
@@ -561,10 +561,12 @@ private[sql] case class MAP(dataType: MapType) extends ColumnType[UnsafeMapData]
assert(buffer.hasArray)
val cursor = buffer.position()
buffer.position(cursor + numBytes)
- UnsafeReaders.readMap(
+ val map = new UnsafeMapData
+ map.pointTo(
buffer.array(),
Platform.BYTE_ARRAY_OFFSET + buffer.arrayOffset() + cursor,
numBytes)
+ map
}
override def clone(v: UnsafeMapData): UnsafeMapData = v.copy()
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/columnar/ColumnTypeSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/columnar/ColumnTypeSuite.scala
index 0e6e1bcf72..63bc39bfa0 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/columnar/ColumnTypeSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/columnar/ColumnTypeSuite.scala
@@ -73,7 +73,7 @@ class ColumnTypeSuite extends SparkFunSuite with Logging {
checkActualSize(COMPACT_DECIMAL(15, 10), Decimal(0, 15, 10), 8)
checkActualSize(LARGE_DECIMAL(20, 10), Decimal(0, 20, 10), 5)
checkActualSize(ARRAY_TYPE, Array[Any](1), 16)
- checkActualSize(MAP_TYPE, Map(1 -> "a"), 25)
+ checkActualSize(MAP_TYPE, Map(1 -> "a"), 29)
checkActualSize(STRUCT_TYPE, Row("hello"), 28)
}
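The new expected MAP_TYPE size follows directly from the header-inclusive format. A worked computation for Map(1 -> "a"), assuming one byte of UTF8 content for the string value:

    class MapActualSizeExample {
      public static void main(String[] args) {
        int keyArray = 4 + 4 + 4;     // numElements + one offset + one 4-byte int key = 12
        int valueArray = 4 + 4 + 1;   // numElements + one offset + 1 byte of UTF8 = 9
        int mapSizeInBytes = 4 + keyArray + valueArray;  // key-array-numBytes header = 25
        int actualSize = 4 + mapSizeInBytes;             // column's own length prefix = 29
        System.out.println(actualSize);                  // matches the updated test
      }
    }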