Diffstat (limited to 'sql')
-rw-r--r--  sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeArrayData.java                    |   7
-rw-r--r--  sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeMapData.java                      |  15
-rw-r--r--  sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeReaders.java                      |   6
-rw-r--r--  sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeRow.java                          |   4
-rw-r--r--  sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeRowWriters.java                   |   4
-rw-r--r--  sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeWriters.java                      |   4
-rw-r--r--  sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/codegen/BufferHolder.java               |  54
-rw-r--r--  sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/codegen/UnsafeArrayWriter.java          | 151
-rw-r--r--  sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/codegen/UnsafeRowWriter.java            | 199
-rw-r--r--  sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala            |   1
-rw-r--r--  sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeProjection.scala | 284
-rw-r--r--  sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/UnsafeRowConverterSuite.scala          | 305
12 files changed, 950 insertions(+), 84 deletions(-)
diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeArrayData.java b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeArrayData.java
index 501dff0903..da9538b3f1 100644
--- a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeArrayData.java
+++ b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeArrayData.java
@@ -20,7 +20,6 @@ package org.apache.spark.sql.catalyst.expressions;
import java.math.BigDecimal;
import java.math.BigInteger;
-import org.apache.spark.sql.catalyst.InternalRow;
import org.apache.spark.sql.types.*;
import org.apache.spark.unsafe.Platform;
import org.apache.spark.unsafe.array.ByteArrayMethods;
@@ -256,7 +255,7 @@ public class UnsafeArrayData extends ArrayData {
}
@Override
- public InternalRow getStruct(int ordinal, int numFields) {
+ public UnsafeRow getStruct(int ordinal, int numFields) {
assertIndexIsValid(ordinal);
final int offset = getElementOffset(ordinal);
if (offset < 0) return null;
@@ -267,7 +266,7 @@ public class UnsafeArrayData extends ArrayData {
}
@Override
- public ArrayData getArray(int ordinal) {
+ public UnsafeArrayData getArray(int ordinal) {
assertIndexIsValid(ordinal);
final int offset = getElementOffset(ordinal);
if (offset < 0) return null;
@@ -276,7 +275,7 @@ public class UnsafeArrayData extends ArrayData {
}
@Override
- public MapData getMap(int ordinal) {
+ public UnsafeMapData getMap(int ordinal) {
assertIndexIsValid(ordinal);
final int offset = getElementOffset(ordinal);
if (offset < 0) return null;
diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeMapData.java b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeMapData.java
index 46216054ab..e9dab9edb6 100644
--- a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeMapData.java
+++ b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeMapData.java
@@ -17,18 +17,23 @@
package org.apache.spark.sql.catalyst.expressions;
-import org.apache.spark.sql.types.ArrayData;
import org.apache.spark.sql.types.MapData;
/**
* An Unsafe implementation of Map which is backed by raw memory instead of Java objects.
*
* Currently we just use 2 UnsafeArrayData to represent UnsafeMapData.
+ *
+ * Note that when we write out this map, we write `numElements` into the first 4 bytes, and the
+ * numBytes of the key array into the next 4 bytes, followed by the key array content and the
+ * value array content, each without its own `numElements` header.
+ * When we read a map back in, we read the first 4 bytes as `numElements` and the next 4 bytes as
+ * the numBytes of the key array, and reconstruct the unsafe key and value arrays from these two
+ * pieces of information.
*/
public class UnsafeMapData extends MapData {
- public final UnsafeArrayData keys;
- public final UnsafeArrayData values;
+ private final UnsafeArrayData keys;
+ private final UnsafeArrayData values;
// The number of elements in this map
private int numElements;
// The size of this map's backing data, in bytes
@@ -50,12 +55,12 @@ public class UnsafeMapData extends MapData {
}
@Override
- public ArrayData keyArray() {
+ public UnsafeArrayData keyArray() {
return keys;
}
@Override
- public ArrayData valueArray() {
+ public UnsafeArrayData valueArray() {
return values;
}
diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeReaders.java b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeReaders.java
index 7b03185a30..6c5fcbca63 100644
--- a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeReaders.java
+++ b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeReaders.java
@@ -21,6 +21,9 @@ import org.apache.spark.unsafe.Platform;
public class UnsafeReaders {
+ /**
+ * Reads in an unsafe array according to the format described in `UnsafeArrayData`.
+ */
public static UnsafeArrayData readArray(Object baseObject, long baseOffset, int numBytes) {
// Read the number of elements from first 4 bytes.
final int numElements = Platform.getInt(baseObject, baseOffset);
@@ -30,6 +33,9 @@ public class UnsafeReaders {
return array;
}
+ /**
+ * Reads in an unsafe map according to the format described in `UnsafeMapData`.
+ */
public static UnsafeMapData readMap(Object baseObject, long baseOffset, int numBytes) {
// Read the number of elements from first 4 bytes.
final int numElements = Platform.getInt(baseObject, baseOffset);
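To make the layout concrete, here is a minimal sketch of hand-assembling a one-entry map {42 -> 7} in the format described by the new `UnsafeMapData` javadoc, then reading it back through `UnsafeReaders.readMap`. It assumes the Spark-internal classes from this patch are on the classpath, and that a single-int `UnsafeArrayData` body is 8 bytes (a 4-byte relative element offset followed by the 4-byte value, as the tests below assert); treat it as an illustration, not the canonical serializer.

    import org.apache.spark.sql.catalyst.expressions.UnsafeMapData;
    import org.apache.spark.sql.catalyst.expressions.UnsafeReaders;
    import org.apache.spark.unsafe.Platform;

    public class MapLayoutSketch {
      public static void main(String[] args) {
        // Layout: [numElements][numBytes of key array][key array][value array].
        final byte[] buf = new byte[4 + 4 + 8 + 8];
        final long off = Platform.BYTE_ARRAY_OFFSET;
        Platform.putInt(buf, off, 1);        // numElements, first 4 bytes
        Platform.putInt(buf, off + 4, 8);    // numBytes of key array, next 4 bytes
        Platform.putInt(buf, off + 8, 4);    // key array: relative offset of element 0
        Platform.putInt(buf, off + 12, 42);  // key array: element 0
        Platform.putInt(buf, off + 16, 4);   // value array: relative offset of element 0
        Platform.putInt(buf, off + 20, 7);   // value array: element 0

        final UnsafeMapData map = UnsafeReaders.readMap(buf, off, buf.length);
        System.out.println(map.numElements());          // 1
        System.out.println(map.keyArray().getInt(0));   // 42
        System.out.println(map.valueArray().getInt(0)); // 7
      }
    }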
diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeRow.java b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeRow.java
index 6c020045c3..e8ac2999c2 100644
--- a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeRow.java
+++ b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeRow.java
@@ -446,7 +446,7 @@ public final class UnsafeRow extends MutableRow {
}
@Override
- public ArrayData getArray(int ordinal) {
+ public UnsafeArrayData getArray(int ordinal) {
if (isNullAt(ordinal)) {
return null;
} else {
@@ -458,7 +458,7 @@ public final class UnsafeRow extends MutableRow {
}
@Override
- public MapData getMap(int ordinal) {
+ public UnsafeMapData getMap(int ordinal) {
if (isNullAt(ordinal)) {
return null;
} else {
diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeRowWriters.java b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeRowWriters.java
index 2f43db68a7..0f1e0202aa 100644
--- a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeRowWriters.java
+++ b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeRowWriters.java
@@ -233,8 +233,8 @@ public class UnsafeRowWriters {
public static int write(UnsafeRow target, int ordinal, int cursor, UnsafeMapData input) {
final long offset = target.getBaseOffset() + cursor;
- final UnsafeArrayData keyArray = input.keys;
- final UnsafeArrayData valueArray = input.values;
+ final UnsafeArrayData keyArray = input.keyArray();
+ final UnsafeArrayData valueArray = input.valueArray();
final int keysNumBytes = keyArray.getSizeInBytes();
final int valuesNumBytes = valueArray.getSizeInBytes();
final int numBytes = 4 + 4 + keysNumBytes + valuesNumBytes;
diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeWriters.java b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeWriters.java
index cd83695fca..ce2d9c4ffb 100644
--- a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeWriters.java
+++ b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeWriters.java
@@ -168,8 +168,8 @@ public class UnsafeWriters {
}
public static int write(Object targetObject, long targetOffset, UnsafeMapData input) {
- final UnsafeArrayData keyArray = input.keys;
- final UnsafeArrayData valueArray = input.values;
+ final UnsafeArrayData keyArray = input.keyArray();
+ final UnsafeArrayData valueArray = input.valueArray();
final int keysNumBytes = keyArray.getSizeInBytes();
final int valuesNumBytes = valueArray.getSizeInBytes();
final int numBytes = 4 + 4 + keysNumBytes + valuesNumBytes;
diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/codegen/BufferHolder.java b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/codegen/BufferHolder.java
new file mode 100644
index 0000000000..9c94686780
--- /dev/null
+++ b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/codegen/BufferHolder.java
@@ -0,0 +1,54 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.catalyst.expressions.codegen;
+
+import org.apache.spark.unsafe.Platform;
+
+/**
+ * A helper class to manage the row buffer used in `GenerateUnsafeProjection`.
+ *
+ * Note that it is only used in `GenerateUnsafeProjection`, so it's safe to mark member variables
+ * public for ease of use.
+ */
+public class BufferHolder {
+ public byte[] buffer = new byte[64];
+ public int cursor = Platform.BYTE_ARRAY_OFFSET;
+
+ public void grow(int neededSize) {
+ final int length = totalSize() + neededSize;
+ if (buffer.length < length) {
+ // This will not happen frequently, because the buffer is re-used.
+ final byte[] tmp = new byte[length * 2];
+ Platform.copyMemory(
+ buffer,
+ Platform.BYTE_ARRAY_OFFSET,
+ tmp,
+ Platform.BYTE_ARRAY_OFFSET,
+ totalSize());
+ buffer = tmp;
+ }
+ }
+
+ public void reset() {
+ cursor = Platform.BYTE_ARRAY_OFFSET;
+ }
+
+ public int totalSize() {
+ return cursor - Platform.BYTE_ARRAY_OFFSET;
+ }
+}
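A minimal usage sketch of the contract above, assuming only the classes added in this patch: grow() merely guarantees capacity, and the caller both writes at `cursor` and advances it to commit the write.

    import org.apache.spark.sql.catalyst.expressions.codegen.BufferHolder;
    import org.apache.spark.unsafe.Platform;

    public class BufferHolderSketch {
      public static void main(String[] args) {
        final BufferHolder holder = new BufferHolder();
        holder.reset();                     // cursor back to BYTE_ARRAY_OFFSET

        holder.grow(8);                     // ensure room for one 8-byte word
        Platform.putLong(holder.buffer, holder.cursor, 7L);
        holder.cursor += 8;                 // commit the write

        System.out.println(holder.totalSize());  // 8
      }
    }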
diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/codegen/UnsafeArrayWriter.java b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/codegen/UnsafeArrayWriter.java
new file mode 100644
index 0000000000..138178ce99
--- /dev/null
+++ b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/codegen/UnsafeArrayWriter.java
@@ -0,0 +1,151 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.catalyst.expressions.codegen;
+
+import org.apache.spark.sql.catalyst.expressions.UnsafeArrayData;
+import org.apache.spark.sql.types.Decimal;
+import org.apache.spark.unsafe.Platform;
+import org.apache.spark.unsafe.types.CalendarInterval;
+import org.apache.spark.unsafe.types.UTF8String;
+
+/**
+ * A helper class to write data into the global row buffer using the `UnsafeArrayData` format,
+ * used by {@link org.apache.spark.sql.catalyst.expressions.codegen.GenerateUnsafeProjection}.
+ */
+public class UnsafeArrayWriter {
+
+ private BufferHolder holder;
+ // The offset into the global buffer where we start to write this array.
+ private int startingOffset;
+
+ public void initialize(BufferHolder holder, int numElements, int fixedElementSize) {
+ // We need 4 bytes per element to store its offset.
+ final int fixedSize = 4 * numElements;
+
+ this.holder = holder;
+ this.startingOffset = holder.cursor;
+
+ holder.grow(fixedSize);
+ holder.cursor += fixedSize;
+
+ // Grow the global buffer ahead of time for the fixed-size data.
+ holder.grow(fixedElementSize * numElements);
+ }
+
+ private long getElementOffset(int ordinal) {
+ return startingOffset + 4 * ordinal;
+ }
+
+ public void setNullAt(int ordinal) {
+ final int relativeOffset = holder.cursor - startingOffset;
+ // Write a negative offset value to represent a null element.
+ Platform.putInt(holder.buffer, getElementOffset(ordinal), -relativeOffset);
+ }
+
+ public void setOffset(int ordinal) {
+ final int relativeOffset = holder.cursor - startingOffset;
+ Platform.putInt(holder.buffer, getElementOffset(ordinal), relativeOffset);
+ }
+
+ public void writeCompactDecimal(int ordinal, Decimal input, int precision, int scale) {
+ // make sure Decimal object has the same scale as DecimalType
+ if (input.changePrecision(precision, scale)) {
+ Platform.putLong(holder.buffer, holder.cursor, input.toUnscaledLong());
+ setOffset(ordinal);
+ holder.cursor += 8;
+ } else {
+ setNullAt(ordinal);
+ }
+ }
+
+ public void write(int ordinal, Decimal input, int precision, int scale) {
+ // make sure Decimal object has the same scale as DecimalType
+ if (input.changePrecision(precision, scale)) {
+ final byte[] bytes = input.toJavaBigDecimal().unscaledValue().toByteArray();
+ assert bytes.length <= 16;
+ holder.grow(bytes.length);
+
+ // Write the bytes to the variable length portion.
+ Platform.copyMemory(
+ bytes, Platform.BYTE_ARRAY_OFFSET, holder.buffer, holder.cursor, bytes.length);
+ setOffset(ordinal);
+ holder.cursor += bytes.length;
+ } else {
+ setNullAt(ordinal);
+ }
+ }
+
+ public void write(int ordinal, UTF8String input) {
+ final int numBytes = input.numBytes();
+
+ // grow the global buffer before writing data.
+ holder.grow(numBytes);
+
+ // Write the bytes to the variable length portion.
+ input.writeToMemory(holder.buffer, holder.cursor);
+
+ setOffset(ordinal);
+
+ // move the cursor forward.
+ holder.cursor += numBytes;
+ }
+
+ public void write(int ordinal, byte[] input) {
+ // grow the global buffer before writing data.
+ holder.grow(input.length);
+
+ // Write the bytes to the variable length portion.
+ Platform.copyMemory(
+ input, Platform.BYTE_ARRAY_OFFSET, holder.buffer, holder.cursor, input.length);
+
+ setOffset(ordinal);
+
+ // move the cursor forward.
+ holder.cursor += input.length;
+ }
+
+ public void write(int ordinal, CalendarInterval input) {
+ // grow the global buffer before writing data.
+ holder.grow(16);
+
+ // Write the months and microseconds fields of Interval to the variable length portion.
+ Platform.putLong(holder.buffer, holder.cursor, input.months);
+ Platform.putLong(holder.buffer, holder.cursor + 8, input.microseconds);
+
+ setOffset(ordinal);
+
+ // move the cursor forward.
+ holder.cursor += 16;
+ }
+
+
+
+ // If this array is already an UnsafeArrayData, we don't need to go through all the
+ // elements; we can write it out directly.
+ public static void directWrite(BufferHolder holder, UnsafeArrayData input) {
+ final int numBytes = input.getSizeInBytes();
+
+ // grow the global buffer before writing data.
+ holder.grow(numBytes);
+
+ // Writes the array content to the variable length portion.
+ input.writeToMemory(holder.buffer, holder.cursor);
+
+ holder.cursor += numBytes;
+ }
+}
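A hedged sketch of driving UnsafeArrayWriter by hand, mirroring the code that GenerateUnsafeProjection emits further down: the caller writes the 4-byte numElements header itself, then lets the writer manage the per-element offsets (a negative offset marks a null element). Again this leans only on APIs shown in this patch.

    import org.apache.spark.sql.catalyst.expressions.UnsafeArrayData;
    import org.apache.spark.sql.catalyst.expressions.UnsafeReaders;
    import org.apache.spark.sql.catalyst.expressions.codegen.BufferHolder;
    import org.apache.spark.sql.catalyst.expressions.codegen.UnsafeArrayWriter;
    import org.apache.spark.unsafe.Platform;

    public class ArrayWriterSketch {
      public static void main(String[] args) {
        final BufferHolder holder = new BufferHolder();
        final UnsafeArrayWriter writer = new UnsafeArrayWriter();
        holder.reset();

        final int numElements = 2;
        holder.grow(4);                                     // numElements header
        Platform.putInt(holder.buffer, holder.cursor, numElements);
        holder.cursor += 4;

        writer.initialize(holder, numElements, 4);          // 4-byte int elements
        writer.setOffset(0);
        Platform.putInt(holder.buffer, holder.cursor, 11);  // element 0
        holder.cursor += 4;
        writer.setNullAt(1);                                // element 1 is null

        final UnsafeArrayData array = UnsafeReaders.readArray(
            holder.buffer, Platform.BYTE_ARRAY_OFFSET, holder.totalSize());
        System.out.println(array.numElements());  // 2
        System.out.println(array.getInt(0));      // 11
        System.out.println(array.isNullAt(1));    // true
      }
    }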
diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/codegen/UnsafeRowWriter.java b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/codegen/UnsafeRowWriter.java
new file mode 100644
index 0000000000..8b7debd440
--- /dev/null
+++ b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/codegen/UnsafeRowWriter.java
@@ -0,0 +1,199 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.catalyst.expressions.codegen;
+
+import org.apache.spark.sql.catalyst.expressions.UnsafeRow;
+import org.apache.spark.sql.types.Decimal;
+import org.apache.spark.unsafe.Platform;
+import org.apache.spark.unsafe.array.ByteArrayMethods;
+import org.apache.spark.unsafe.bitset.BitSetMethods;
+import org.apache.spark.unsafe.types.CalendarInterval;
+import org.apache.spark.unsafe.types.UTF8String;
+
+/**
+ * A helper class to write data into the global row buffer using the `UnsafeRow` format,
+ * used by {@link org.apache.spark.sql.catalyst.expressions.codegen.GenerateUnsafeProjection}.
+ */
+public class UnsafeRowWriter {
+
+ private BufferHolder holder;
+ // The offset into the global buffer where we start to write this row.
+ private int startingOffset;
+ private int nullBitsSize;
+
+ public void initialize(BufferHolder holder, int numFields) {
+ this.holder = holder;
+ this.startingOffset = holder.cursor;
+ this.nullBitsSize = UnsafeRow.calculateBitSetWidthInBytes(numFields);
+
+ // grow the global buffer to make sure it has enough space to write fixed-length data.
+ final int fixedSize = nullBitsSize + 8 * numFields;
+ holder.grow(fixedSize);
+ holder.cursor += fixedSize;
+
+ // zero-out the null bits region
+ for (int i = 0; i < nullBitsSize; i += 8) {
+ Platform.putLong(holder.buffer, startingOffset + i, 0L);
+ }
+ }
+
+ private void zeroOutPaddingBytes(int numBytes) {
+ if ((numBytes & 0x07) > 0) {
+ Platform.putLong(holder.buffer, holder.cursor + ((numBytes >> 3) << 3), 0L);
+ }
+ }
+
+ public void setNullAt(int ordinal) {
+ BitSetMethods.set(holder.buffer, startingOffset, ordinal);
+ Platform.putLong(holder.buffer, getFieldOffset(ordinal), 0L);
+ }
+
+ public long getFieldOffset(int ordinal) {
+ return startingOffset + nullBitsSize + 8 * ordinal;
+ }
+
+ public void setOffsetAndSize(int ordinal, long size) {
+ setOffsetAndSize(ordinal, holder.cursor, size);
+ }
+
+ public void setOffsetAndSize(int ordinal, long currentCursor, long size) {
+ final long relativeOffset = currentCursor - startingOffset;
+ final long fieldOffset = getFieldOffset(ordinal);
+ final long offsetAndSize = (relativeOffset << 32) | size;
+
+ Platform.putLong(holder.buffer, fieldOffset, offsetAndSize);
+ }
+
+ // Do word alignment for this row and grow the row buffer if needed.
+ // TODO: remove this after we make unsafe array data word-aligned.
+ public void alignToWords(int numBytes) {
+ final int remainder = numBytes & 0x07;
+
+ if (remainder > 0) {
+ final int paddingBytes = 8 - remainder;
+ holder.grow(paddingBytes);
+
+ for (int i = 0; i < paddingBytes; i++) {
+ Platform.putByte(holder.buffer, holder.cursor, (byte) 0);
+ holder.cursor++;
+ }
+ }
+ }
+
+ public void writeCompactDecimal(int ordinal, Decimal input, int precision, int scale) {
+ // make sure Decimal object has the same scale as DecimalType
+ if (input.changePrecision(precision, scale)) {
+ Platform.putLong(holder.buffer, getFieldOffset(ordinal), input.toUnscaledLong());
+ } else {
+ setNullAt(ordinal);
+ }
+ }
+
+ public void write(int ordinal, Decimal input, int precision, int scale) {
+ // grow the global buffer before writing data.
+ holder.grow(16);
+
+ // zero-out the bytes
+ Platform.putLong(holder.buffer, holder.cursor, 0L);
+ Platform.putLong(holder.buffer, holder.cursor + 8, 0L);
+
+ // Make sure Decimal object has the same scale as DecimalType.
+ // Note that a null Decimal object may be passed in to set this field to null.
+ if (input == null || !input.changePrecision(precision, scale)) {
+ BitSetMethods.set(holder.buffer, startingOffset, ordinal);
+ // keep the offset for future update
+ setOffsetAndSize(ordinal, 0L);
+ } else {
+ final byte[] bytes = input.toJavaBigDecimal().unscaledValue().toByteArray();
+ assert bytes.length <= 16;
+
+ // Write the bytes to the variable length portion.
+ Platform.copyMemory(
+ bytes, Platform.BYTE_ARRAY_OFFSET, holder.buffer, holder.cursor, bytes.length);
+ setOffsetAndSize(ordinal, bytes.length);
+ }
+
+ // move the cursor forward.
+ holder.cursor += 16;
+ }
+
+ public void write(int ordinal, UTF8String input) {
+ final int numBytes = input.numBytes();
+ final int roundedSize = ByteArrayMethods.roundNumberOfBytesToNearestWord(numBytes);
+
+ // grow the global buffer before writing data.
+ holder.grow(roundedSize);
+
+ zeroOutPaddingBytes(numBytes);
+
+ // Write the bytes to the variable length portion.
+ input.writeToMemory(holder.buffer, holder.cursor);
+
+ setOffsetAndSize(ordinal, numBytes);
+
+ // move the cursor forward.
+ holder.cursor += roundedSize;
+ }
+
+ public void write(int ordinal, byte[] input) {
+ final int numBytes = input.length;
+ final int roundedSize = ByteArrayMethods.roundNumberOfBytesToNearestWord(numBytes);
+
+ // grow the global buffer before writing data.
+ holder.grow(roundedSize);
+
+ zeroOutPaddingBytes(numBytes);
+
+ // Write the bytes to the variable length portion.
+ Platform.copyMemory(input, Platform.BYTE_ARRAY_OFFSET, holder.buffer, holder.cursor, numBytes);
+
+ setOffsetAndSize(ordinal, numBytes);
+
+ // move the cursor forward.
+ holder.cursor += roundedSize;
+ }
+
+ public void write(int ordinal, CalendarInterval input) {
+ // grow the global buffer before writing data.
+ holder.grow(16);
+
+ // Write the months and microseconds fields of Interval to the variable length portion.
+ Platform.putLong(holder.buffer, holder.cursor, input.months);
+ Platform.putLong(holder.buffer, holder.cursor + 8, input.microseconds);
+
+ setOffsetAndSize(ordinal, 16);
+
+ // move the cursor forward.
+ holder.cursor += 16;
+ }
+
+
+
+ // If this struct is already an UnsafeRow, we don't need to go through all the fields;
+ // we can write it out directly.
+ public static void directWrite(BufferHolder holder, UnsafeRow input) {
+ // No need to zero-out the bytes, as an UnsafeRow is always word-aligned.
+ final int numBytes = input.getSizeInBytes();
+ // grow the global buffer before writing data.
+ holder.grow(numBytes);
+ // Write the bytes to the variable length portion.
+ input.writeToMemory(holder.buffer, holder.cursor);
+ // move the cursor forward.
+ holder.cursor += numBytes;
+ }
+}
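A matching sketch for the row writer, assembling a two-field row the same way the generated code does: primitives go straight into their 8-byte fixed-length slots, while variable-length values land in the variable-length region with an (offset, size) word recorded in their slot. The pointTo overload taking a byte[] is the one the generated projection uses below.

    import org.apache.spark.sql.catalyst.expressions.UnsafeRow;
    import org.apache.spark.sql.catalyst.expressions.codegen.BufferHolder;
    import org.apache.spark.sql.catalyst.expressions.codegen.UnsafeRowWriter;
    import org.apache.spark.unsafe.Platform;
    import org.apache.spark.unsafe.types.UTF8String;

    public class RowWriterSketch {
      public static void main(String[] args) {
        final BufferHolder holder = new BufferHolder();
        final UnsafeRowWriter writer = new UnsafeRowWriter();

        holder.reset();
        writer.initialize(holder, 2);  // null bits plus two 8-byte fixed-length slots

        // Field 0: a primitive long goes straight into its fixed-length slot.
        Platform.putLong(holder.buffer, writer.getFieldOffset(0), 42L);

        // Field 1: a string goes into the variable-length region; the writer
        // records its (relative offset, size) in the field's fixed-length slot.
        writer.write(1, UTF8String.fromString("hello"));

        final UnsafeRow row = new UnsafeRow();
        row.pointTo(holder.buffer, 2, holder.totalSize());
        System.out.println(row.getLong(0));        // 42
        System.out.println(row.getUTF8String(1));  // hello
      }
    }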
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala
index da3103b4eb..9a28781133 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala
@@ -272,6 +272,7 @@ class CodeGenContext {
* 64kb code size limit in JVM
*
* @param row the variable name of row that is used by expressions
+ * @param expressions the code snippets that evaluate the expressions.
*/
def splitExpressions(row: String, expressions: Seq[String]): String = {
val blocks = new ArrayBuffer[String]()
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeProjection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeProjection.scala
index 55562facf9..99bf50a845 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeProjection.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeProjection.scala
@@ -393,10 +393,292 @@ object GenerateUnsafeProjection extends CodeGenerator[Seq[Expression], UnsafePro
case _ => input
}
+ private val rowWriterClass = classOf[UnsafeRowWriter].getName
+ private val arrayWriterClass = classOf[UnsafeArrayWriter].getName
+
+ // TODO: if the nullability of the field is correct, we can use it to skip the null check.
+ private def writeStructToBuffer(
+ ctx: CodeGenContext,
+ input: String,
+ fieldTypes: Seq[DataType],
+ bufferHolder: String): String = {
+ val fieldEvals = fieldTypes.zipWithIndex.map { case (dt, i) =>
+ val fieldName = ctx.freshName("fieldName")
+ val code = s"final ${ctx.javaType(dt)} $fieldName = ${ctx.getValue(input, dt, i.toString)};"
+ val isNull = s"$input.isNullAt($i)"
+ GeneratedExpressionCode(code, isNull, fieldName)
+ }
+
+ s"""
+ if ($input instanceof UnsafeRow) {
+ $rowWriterClass.directWrite($bufferHolder, (UnsafeRow) $input);
+ } else {
+ ${writeExpressionsToBuffer(ctx, input, fieldEvals, fieldTypes, bufferHolder)}
+ }
+ """
+ }
+
+ private def writeExpressionsToBuffer(
+ ctx: CodeGenContext,
+ row: String,
+ inputs: Seq[GeneratedExpressionCode],
+ inputTypes: Seq[DataType],
+ bufferHolder: String): String = {
+ val rowWriter = ctx.freshName("rowWriter")
+ ctx.addMutableState(rowWriterClass, rowWriter, s"this.$rowWriter = new $rowWriterClass();")
+
+ val writeFields = inputs.zip(inputTypes).zipWithIndex.map {
+ case ((input, dt), index) =>
+ val tmpCursor = ctx.freshName("tmpCursor")
+
+ val setNull = dt match {
+ case t: DecimalType if t.precision > Decimal.MAX_LONG_DIGITS =>
+ // Can't call setNullAt() for DecimalType with precision larger than 18.
+ s"$rowWriter.write($index, null, ${t.precision}, ${t.scale});"
+ case _ => s"$rowWriter.setNullAt($index);"
+ }
+
+ val writeField = dt match {
+ case t: StructType =>
+ s"""
+ // Remember the current cursor so that we can calculate how many bytes are
+ // written later.
+ final int $tmpCursor = $bufferHolder.cursor;
+ ${writeStructToBuffer(ctx, input.primitive, t.map(_.dataType), bufferHolder)}
+ $rowWriter.setOffsetAndSize($index, $tmpCursor, $bufferHolder.cursor - $tmpCursor);
+ """
+
+ case a @ ArrayType(et, _) =>
+ s"""
+ // Remember the current cursor so that we can calculate how many bytes are
+ // written later.
+ final int $tmpCursor = $bufferHolder.cursor;
+ ${writeArrayToBuffer(ctx, input.primitive, et, bufferHolder)}
+ $rowWriter.setOffsetAndSize($index, $tmpCursor, $bufferHolder.cursor - $tmpCursor);
+ $rowWriter.alignToWords($bufferHolder.cursor - $tmpCursor);
+ """
+
+ case m @ MapType(kt, vt, _) =>
+ s"""
+ // Remember the current cursor so that we can calculate how many bytes are
+ // written later.
+ final int $tmpCursor = $bufferHolder.cursor;
+ ${writeMapToBuffer(ctx, input.primitive, kt, vt, bufferHolder)}
+ $rowWriter.setOffsetAndSize($index, $tmpCursor, $bufferHolder.cursor - $tmpCursor);
+ $rowWriter.alignToWords($bufferHolder.cursor - $tmpCursor);
+ """
+
+ case _ if ctx.isPrimitiveType(dt) =>
+ val fieldOffset = ctx.freshName("fieldOffset")
+ s"""
+ final long $fieldOffset = $rowWriter.getFieldOffset($index);
+ Platform.putLong($bufferHolder.buffer, $fieldOffset, 0L);
+ ${writePrimitiveType(ctx, input.primitive, dt, s"$bufferHolder.buffer", fieldOffset)}
+ """
+
+ case t: DecimalType if t.precision <= Decimal.MAX_LONG_DIGITS =>
+ s"$rowWriter.writeCompactDecimal($index, ${input.primitive}, " +
+ s"${t.precision}, ${t.scale});"
+
+ case t: DecimalType =>
+ s"$rowWriter.write($index, ${input.primitive}, ${t.precision}, ${t.scale});"
+
+ case NullType => ""
+
+ case _ => s"$rowWriter.write($index, ${input.primitive});"
+ }
+
+ s"""
+ ${input.code}
+ if (${input.isNull}) {
+ $setNull
+ } else {
+ $writeField
+ }
+ """
+ }
+
+ s"""
+ $rowWriter.initialize($bufferHolder, ${inputs.length});
+ ${ctx.splitExpressions(row, writeFields)}
+ """
+ }
+
+ // TODO: if the nullability of the array element is correct, we can use it to skip the null check.
+ private def writeArrayToBuffer(
+ ctx: CodeGenContext,
+ input: String,
+ elementType: DataType,
+ bufferHolder: String,
+ needHeader: Boolean = true): String = {
+ val arrayWriter = ctx.freshName("arrayWriter")
+ ctx.addMutableState(arrayWriterClass, arrayWriter,
+ s"this.$arrayWriter = new $arrayWriterClass();")
+ val numElements = ctx.freshName("numElements")
+ val index = ctx.freshName("index")
+ val element = ctx.freshName("element")
+
+ val jt = ctx.javaType(elementType)
+
+ val fixedElementSize = elementType match {
+ case t: DecimalType if t.precision <= Decimal.MAX_LONG_DIGITS => 8
+ case _ if ctx.isPrimitiveType(jt) => elementType.defaultSize
+ case _ => 0
+ }
+
+ val writeElement = elementType match {
+ case t: StructType =>
+ s"""
+ $arrayWriter.setOffset($index);
+ ${writeStructToBuffer(ctx, element, t.map(_.dataType), bufferHolder)}
+ """
+
+ case a @ ArrayType(et, _) =>
+ s"""
+ $arrayWriter.setOffset($index);
+ ${writeArrayToBuffer(ctx, element, et, bufferHolder)}
+ """
+
+ case m @ MapType(kt, vt, _) =>
+ s"""
+ $arrayWriter.setOffset($index);
+ ${writeMapToBuffer(ctx, element, kt, vt, bufferHolder)}
+ """
+
+ case _ if ctx.isPrimitiveType(elementType) =>
+ // Should we do word alignment here?
+ val dataSize = elementType.defaultSize
+
+ s"""
+ $arrayWriter.setOffset($index);
+ ${writePrimitiveType(ctx, element, elementType,
+ s"$bufferHolder.buffer", s"$bufferHolder.cursor")}
+ $bufferHolder.cursor += $dataSize;
+ """
+
+ case t: DecimalType if t.precision <= Decimal.MAX_LONG_DIGITS =>
+ s"$arrayWriter.writeCompactDecimal($index, $element, ${t.precision}, ${t.scale});"
+
+ case t: DecimalType =>
+ s"$arrayWriter.write($index, $element, ${t.precision}, ${t.scale});"
+
+ case NullType => ""
+
+ case _ => s"$arrayWriter.write($index, $element);"
+ }
+
+ val writeHeader = if (needHeader) {
+ // If a header is required, we need to write the number of elements into the first 4 bytes.
+ s"""
+ $bufferHolder.grow(4);
+ Platform.putInt($bufferHolder.buffer, $bufferHolder.cursor, $numElements);
+ $bufferHolder.cursor += 4;
+ """
+ } else ""
+
+ s"""
+ final int $numElements = $input.numElements();
+ $writeHeader
+ if ($input instanceof UnsafeArrayData) {
+ $arrayWriterClass.directWrite($bufferHolder, (UnsafeArrayData) $input);
+ } else {
+ $arrayWriter.initialize($bufferHolder, $numElements, $fixedElementSize);
+
+ for (int $index = 0; $index < $numElements; $index++) {
+ if ($input.isNullAt($index)) {
+ $arrayWriter.setNullAt($index);
+ } else {
+ final $jt $element = ${ctx.getValue(input, elementType, index)};
+ $writeElement
+ }
+ }
+ }
+ """
+ }
+
+ // TODO: if the nullability of the value element is correct, we can use it to skip the null check.
+ private def writeMapToBuffer(
+ ctx: CodeGenContext,
+ input: String,
+ keyType: DataType,
+ valueType: DataType,
+ bufferHolder: String): String = {
+ val keys = ctx.freshName("keys")
+ val values = ctx.freshName("values")
+ val tmpCursor = ctx.freshName("tmpCursor")
+
+
+ // Writes out the unsafe map according to the format described in `UnsafeMapData`.
+ s"""
+ final ArrayData $keys = $input.keyArray();
+ final ArrayData $values = $input.valueArray();
+
+ $bufferHolder.grow(8);
+
+ // Write the numElements into first 4 bytes.
+ Platform.putInt($bufferHolder.buffer, $bufferHolder.cursor, $keys.numElements());
+
+ $bufferHolder.cursor += 8;
+ // Remember the current cursor so that we can write numBytes of key array later.
+ final int $tmpCursor = $bufferHolder.cursor;
+
+ ${writeArrayToBuffer(ctx, keys, keyType, bufferHolder, needHeader = false)}
+ // Write the numBytes of key array into second 4 bytes.
+ Platform.putInt($bufferHolder.buffer, $tmpCursor - 4, $bufferHolder.cursor - $tmpCursor);
+
+ ${writeArrayToBuffer(ctx, values, valueType, bufferHolder, needHeader = false)}
+ """
+ }
+
+ private def writePrimitiveType(
+ ctx: CodeGenContext,
+ input: String,
+ dt: DataType,
+ buffer: String,
+ offset: String) = {
+ assert(ctx.isPrimitiveType(dt))
+
+ val putMethod = s"put${ctx.primitiveTypeName(dt)}"
+
+ dt match {
+ case FloatType | DoubleType =>
+ val normalized = ctx.freshName("normalized")
+ val boxedType = ctx.boxedType(dt)
+ val handleNaN =
+ s"""
+ final ${ctx.javaType(dt)} $normalized;
+ if ($boxedType.isNaN($input)) {
+ $normalized = $boxedType.NaN;
+ } else {
+ $normalized = $input;
+ }
+ """
+
+ s"""
+ $handleNaN
+ Platform.$putMethod($buffer, $offset, $normalized);
+ """
+ case _ => s"Platform.$putMethod($buffer, $offset, $input);"
+ }
+ }
+
def createCode(ctx: CodeGenContext, expressions: Seq[Expression]): GeneratedExpressionCode = {
val exprEvals = expressions.map(e => e.gen(ctx))
val exprTypes = expressions.map(_.dataType)
- createCodeForStruct(ctx, "i", exprEvals, exprTypes)
+
+ val result = ctx.freshName("result")
+ ctx.addMutableState("UnsafeRow", result, s"this.$result = new UnsafeRow();")
+ val bufferHolder = ctx.freshName("bufferHolder")
+ val holderClass = classOf[BufferHolder].getName
+ ctx.addMutableState(holderClass, bufferHolder, s"this.$bufferHolder = new $holderClass();")
+
+ val code =
+ s"""
+ $bufferHolder.reset();
+ ${writeExpressionsToBuffer(ctx, "i", exprEvals, exprTypes, bufferHolder)}
+ $result.pointTo($bufferHolder.buffer, ${expressions.length}, $bufferHolder.totalSize());
+ """
+ GeneratedExpressionCode(code, "false", result)
}
protected def canonicalize(in: Seq[Expression]): Seq[Expression] =
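Putting the pieces together, here is a hand-written approximation of the projection class that createCode and writeExpressionsToBuffer now emit, specialized to a single nullable string column. Variable names are invented for readability; the real generated source uses fresh names from CodeGenContext.

    import org.apache.spark.sql.catalyst.InternalRow;
    import org.apache.spark.sql.catalyst.expressions.UnsafeRow;
    import org.apache.spark.sql.catalyst.expressions.codegen.BufferHolder;
    import org.apache.spark.sql.catalyst.expressions.codegen.UnsafeRowWriter;

    public class GeneratedProjectionSketch {
      // Mutable state, as registered via ctx.addMutableState in the template above.
      private final UnsafeRow result = new UnsafeRow();
      private final BufferHolder bufferHolder = new BufferHolder();
      private final UnsafeRowWriter rowWriter = new UnsafeRowWriter();

      public UnsafeRow apply(InternalRow i) {
        bufferHolder.reset();
        rowWriter.initialize(bufferHolder, 1);

        // Inlined per-field code from writeExpressionsToBuffer, for field 0.
        if (i.isNullAt(0)) {
          rowWriter.setNullAt(0);
        } else {
          rowWriter.write(0, i.getUTF8String(0));
        }

        result.pointTo(bufferHolder.buffer, 1, bufferHolder.totalSize());
        return result;
      }
    }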
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/UnsafeRowConverterSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/UnsafeRowConverterSuite.scala
index 8c72203193..c991cd86d2 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/UnsafeRowConverterSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/UnsafeRowConverterSuite.scala
@@ -18,7 +18,6 @@
package org.apache.spark.sql.catalyst.expressions
import java.sql.{Date, Timestamp}
-import java.util.Arrays
import org.scalatest.Matchers
@@ -43,7 +42,7 @@ class UnsafeRowConverterSuite extends SparkFunSuite with Matchers {
row.setInt(2, 2)
val unsafeRow: UnsafeRow = converter.apply(row)
- assert(converter.apply(row).getSizeInBytes === 8 + (3 * 8))
+ assert(unsafeRow.getSizeInBytes === 8 + (3 * 8))
assert(unsafeRow.getLong(0) === 0)
assert(unsafeRow.getLong(1) === 1)
assert(unsafeRow.getInt(2) === 2)
@@ -62,6 +61,13 @@ class UnsafeRowConverterSuite extends SparkFunSuite with Matchers {
assert(unsafeRowCopy.getLong(0) === 0)
assert(unsafeRowCopy.getLong(1) === 1)
assert(unsafeRowCopy.getInt(2) === 2)
+
+ // Make sure the converter can be reused, i.e. we correctly reset all state.
+ val unsafeRow2: UnsafeRow = converter.apply(row)
+ assert(unsafeRow2.getSizeInBytes === 8 + (3 * 8))
+ assert(unsafeRow2.getLong(0) === 0)
+ assert(unsafeRow2.getLong(1) === 1)
+ assert(unsafeRow2.getInt(2) === 2)
}
test("basic conversion with primitive, string and binary types") {
@@ -176,7 +182,6 @@ class UnsafeRowConverterSuite extends SparkFunSuite with Matchers {
r
}
- // todo: we reuse the UnsafeRow in projection, so these tests are meaningless.
val setToNullAfterCreation = converter.apply(rowWithNoNullColumns)
assert(setToNullAfterCreation.isNullAt(0) === rowWithNoNullColumns.isNullAt(0))
assert(setToNullAfterCreation.getBoolean(1) === rowWithNoNullColumns.getBoolean(1))
@@ -192,7 +197,6 @@ class UnsafeRowConverterSuite extends SparkFunSuite with Matchers {
rowWithNoNullColumns.getDecimal(10, 10, 0))
assert(setToNullAfterCreation.getDecimal(11, 38, 18) ===
rowWithNoNullColumns.getDecimal(11, 38, 18))
- // assert(setToNullAfterCreation.get(11) === rowWithNoNullColumns.get(11))
for (i <- fieldTypes.indices) {
// Can't call setNullAt() on DecimalType
@@ -202,8 +206,6 @@ class UnsafeRowConverterSuite extends SparkFunSuite with Matchers {
setToNullAfterCreation.setNullAt(i)
}
}
- // There are some garbage left in the var-length area
- assert(Arrays.equals(createdFromNull.getBytes, setToNullAfterCreation.getBytes()))
setToNullAfterCreation.setNullAt(0)
setToNullAfterCreation.setBoolean(1, false)
@@ -251,107 +253,274 @@ class UnsafeRowConverterSuite extends SparkFunSuite with Matchers {
assert(converter.apply(row1).getBytes === converter.apply(row2).getBytes)
}
+ test("basic conversion with struct type") {
+ val fieldTypes: Array[DataType] = Array(
+ new StructType().add("i", IntegerType),
+ new StructType().add("nest", new StructType().add("l", LongType))
+ )
+
+ val converter = UnsafeProjection.create(fieldTypes)
+
+ val row = new GenericMutableRow(fieldTypes.length)
+ row.update(0, InternalRow(1))
+ row.update(1, InternalRow(InternalRow(2L)))
+
+ val unsafeRow: UnsafeRow = converter.apply(row)
+ assert(unsafeRow.numFields == 2)
+
+ val row1 = unsafeRow.getStruct(0, 1)
+ assert(row1.getSizeInBytes == 8 + 1 * 8)
+ assert(row1.numFields == 1)
+ assert(row1.getInt(0) == 1)
+
+ val row2 = unsafeRow.getStruct(1, 1)
+ assert(row2.numFields() == 1)
+
+ val innerRow = row2.getStruct(0, 1)
+
+ {
+ assert(innerRow.getSizeInBytes == 8 + 1 * 8)
+ assert(innerRow.numFields == 1)
+ assert(innerRow.getLong(0) == 2L)
+ }
+
+ assert(row2.getSizeInBytes == 8 + 1 * 8 + innerRow.getSizeInBytes)
+
+ assert(unsafeRow.getSizeInBytes == 8 + 2 * 8 + row1.getSizeInBytes + row2.getSizeInBytes)
+ }
+
+ private def createArray(values: Any*): ArrayData = new GenericArrayData(values.toArray)
+
+ private def createMap(keys: Any*)(values: Any*): MapData = {
+ assert(keys.length == values.length)
+ new ArrayBasedMapData(createArray(keys: _*), createArray(values: _*))
+ }
+
+ private def arraySizeInRow(numBytes: Int): Int = roundedSize(4 + numBytes)
+
+ private def mapSizeInRow(numBytes: Int): Int = roundedSize(8 + numBytes)
+
+ private def testArrayInt(array: UnsafeArrayData, values: Seq[Int]): Unit = {
+ assert(array.numElements == values.length)
+ assert(array.getSizeInBytes == (4 + 4) * values.length)
+ values.zipWithIndex.foreach {
+ case (value, index) => assert(array.getInt(index) == value)
+ }
+ }
+
+ private def testMapInt(map: UnsafeMapData, keys: Seq[Int], values: Seq[Int]): Unit = {
+ assert(keys.length == values.length)
+ assert(map.numElements == keys.length)
+
+ testArrayInt(map.keyArray, keys)
+ testArrayInt(map.valueArray, values)
+
+ assert(map.getSizeInBytes == map.keyArray.getSizeInBytes + map.valueArray.getSizeInBytes)
+ }
+
test("basic conversion with array type") {
val fieldTypes: Array[DataType] = Array(
- ArrayType(LongType),
- ArrayType(ArrayType(LongType))
+ ArrayType(IntegerType),
+ ArrayType(ArrayType(IntegerType))
)
val converter = UnsafeProjection.create(fieldTypes)
- val array1 = new GenericArrayData(Array[Any](1L, 2L))
- val array2 = new GenericArrayData(Array[Any](new GenericArrayData(Array[Any](3L, 4L))))
val row = new GenericMutableRow(fieldTypes.length)
- row.update(0, array1)
- row.update(1, array2)
+ row.update(0, createArray(1, 2))
+ row.update(1, createArray(createArray(3, 4)))
val unsafeRow: UnsafeRow = converter.apply(row)
assert(unsafeRow.numFields() == 2)
- val unsafeArray1 = unsafeRow.getArray(0).asInstanceOf[UnsafeArrayData]
- assert(unsafeArray1.getSizeInBytes == 4 * 2 + 8 * 2)
- assert(unsafeArray1.numElements() == 2)
- assert(unsafeArray1.getLong(0) == 1L)
- assert(unsafeArray1.getLong(1) == 2L)
+ val unsafeArray1 = unsafeRow.getArray(0)
+ testArrayInt(unsafeArray1, Seq(1, 2))
- val unsafeArray2 = unsafeRow.getArray(1).asInstanceOf[UnsafeArrayData]
- assert(unsafeArray2.numElements() == 1)
+ val unsafeArray2 = unsafeRow.getArray(1)
+ assert(unsafeArray2.numElements == 1)
- val nestedArray = unsafeArray2.getArray(0).asInstanceOf[UnsafeArrayData]
- assert(nestedArray.getSizeInBytes == 4 * 2 + 8 * 2)
- assert(nestedArray.numElements() == 2)
- assert(nestedArray.getLong(0) == 3L)
- assert(nestedArray.getLong(1) == 4L)
+ val nestedArray = unsafeArray2.getArray(0)
+ testArrayInt(nestedArray, Seq(3, 4))
- assert(unsafeArray2.getSizeInBytes == 4 + 4 + nestedArray.getSizeInBytes)
+ assert(unsafeArray2.getSizeInBytes == 4 + (4 + nestedArray.getSizeInBytes))
- val array1Size = roundedSize(4 + unsafeArray1.getSizeInBytes)
- val array2Size = roundedSize(4 + unsafeArray2.getSizeInBytes)
+ val array1Size = arraySizeInRow(unsafeArray1.getSizeInBytes)
+ val array2Size = arraySizeInRow(unsafeArray2.getSizeInBytes)
assert(unsafeRow.getSizeInBytes == 8 + 8 * 2 + array1Size + array2Size)
}
test("basic conversion with map type") {
- def createArray(values: Any*): ArrayData = new GenericArrayData(values.toArray)
+ val fieldTypes: Array[DataType] = Array(
+ MapType(IntegerType, IntegerType),
+ MapType(IntegerType, MapType(IntegerType, IntegerType))
+ )
+ val converter = UnsafeProjection.create(fieldTypes)
- def testIntLongMap(map: UnsafeMapData, keys: Array[Int], values: Array[Long]): Unit = {
- val numElements = keys.length
- assert(map.numElements() == numElements)
+ val map1 = createMap(1, 2)(3, 4)
- val keyArray = map.keys
- assert(keyArray.getSizeInBytes == 4 * numElements + 4 * numElements)
- assert(keyArray.numElements() == numElements)
- keys.zipWithIndex.foreach { case (key, i) =>
- assert(keyArray.getInt(i) == key)
- }
+ val innerMap = createMap(5, 6)(7, 8)
+ val map2 = createMap(9)(innerMap)
- val valueArray = map.values
- assert(valueArray.getSizeInBytes == 4 * numElements + 8 * numElements)
- assert(valueArray.numElements() == numElements)
- values.zipWithIndex.foreach { case (value, i) =>
- assert(valueArray.getLong(i) == value)
- }
+ val row = new GenericMutableRow(fieldTypes.length)
+ row.update(0, map1)
+ row.update(1, map2)
- assert(map.getSizeInBytes == keyArray.getSizeInBytes + valueArray.getSizeInBytes)
+ val unsafeRow: UnsafeRow = converter.apply(row)
+ assert(unsafeRow.numFields == 2)
+
+ val unsafeMap1 = unsafeRow.getMap(0)
+ testMapInt(unsafeMap1, Seq(1, 2), Seq(3, 4))
+
+ val unsafeMap2 = unsafeRow.getMap(1)
+ assert(unsafeMap2.numElements == 1)
+
+ val keyArray = unsafeMap2.keyArray
+ testArrayInt(keyArray, Seq(9))
+
+ val valueArray = unsafeMap2.valueArray
+
+ {
+ assert(valueArray.numElements == 1)
+
+ val nestedMap = valueArray.getMap(0)
+ testMapInt(nestedMap, Seq(5, 6), Seq(7, 8))
+
+ assert(valueArray.getSizeInBytes == 4 + (8 + nestedMap.getSizeInBytes))
}
+ assert(unsafeMap2.getSizeInBytes == keyArray.getSizeInBytes + valueArray.getSizeInBytes)
+
+ val map1Size = mapSizeInRow(unsafeMap1.getSizeInBytes)
+ val map2Size = mapSizeInRow(unsafeMap2.getSizeInBytes)
+ assert(unsafeRow.getSizeInBytes == 8 + 8 * 2 + map1Size + map2Size)
+ }
+
+ test("basic conversion with struct and array") {
val fieldTypes: Array[DataType] = Array(
- MapType(IntegerType, LongType),
- MapType(IntegerType, MapType(IntegerType, LongType))
+ new StructType().add("arr", ArrayType(IntegerType)),
+ ArrayType(new StructType().add("l", LongType))
)
val converter = UnsafeProjection.create(fieldTypes)
- val map1 = new ArrayBasedMapData(createArray(1, 2), createArray(3L, 4L))
+ val row = new GenericMutableRow(fieldTypes.length)
+ row.update(0, InternalRow(createArray(1)))
+ row.update(1, createArray(InternalRow(2L)))
+
+ val unsafeRow: UnsafeRow = converter.apply(row)
+ assert(unsafeRow.numFields() == 2)
+
+ val field1 = unsafeRow.getStruct(0, 1)
+ assert(field1.numFields == 1)
+
+ val innerArray = field1.getArray(0)
+ testArrayInt(innerArray, Seq(1))
- val innerMap = new ArrayBasedMapData(createArray(5, 6), createArray(7L, 8L))
- val map2 = new ArrayBasedMapData(createArray(9), createArray(innerMap))
+ assert(field1.getSizeInBytes == 8 + 8 + arraySizeInRow(innerArray.getSizeInBytes))
+
+ val field2 = unsafeRow.getArray(1)
+ assert(field2.numElements == 1)
+
+ val innerStruct = field2.getStruct(0, 1)
+
+ {
+ assert(innerStruct.numFields == 1)
+ assert(innerStruct.getSizeInBytes == 8 + 8)
+ assert(innerStruct.getLong(0) == 2L)
+ }
+
+ assert(field2.getSizeInBytes == 4 + innerStruct.getSizeInBytes)
+
+ assert(unsafeRow.getSizeInBytes ==
+ 8 + 8 * 2 + field1.getSizeInBytes + arraySizeInRow(field2.getSizeInBytes))
+ }
+
+ test("basic conversion with struct and map") {
+ val fieldTypes: Array[DataType] = Array(
+ new StructType().add("map", MapType(IntegerType, IntegerType)),
+ MapType(IntegerType, new StructType().add("l", LongType))
+ )
+ val converter = UnsafeProjection.create(fieldTypes)
val row = new GenericMutableRow(fieldTypes.length)
- row.update(0, map1)
- row.update(1, map2)
+ row.update(0, InternalRow(createMap(1)(2)))
+ row.update(1, createMap(3)(InternalRow(4L)))
val unsafeRow: UnsafeRow = converter.apply(row)
assert(unsafeRow.numFields() == 2)
- val unsafeMap1 = unsafeRow.getMap(0).asInstanceOf[UnsafeMapData]
- testIntLongMap(unsafeMap1, Array(1, 2), Array(3L, 4L))
+ val field1 = unsafeRow.getStruct(0, 1)
+ assert(field1.numFields == 1)
- val unsafeMap2 = unsafeRow.getMap(1).asInstanceOf[UnsafeMapData]
- assert(unsafeMap2.numElements() == 1)
+ val innerMap = field1.getMap(0)
+ testMapInt(innerMap, Seq(1), Seq(2))
- val keyArray = unsafeMap2.keys
- assert(keyArray.getSizeInBytes == 4 + 4)
- assert(keyArray.numElements() == 1)
- assert(keyArray.getInt(0) == 9)
+ assert(field1.getSizeInBytes == 8 + 8 + mapSizeInRow(innerMap.getSizeInBytes))
- val valueArray = unsafeMap2.values
- assert(valueArray.numElements() == 1)
- val nestedMap = valueArray.getMap(0).asInstanceOf[UnsafeMapData]
- testIntLongMap(nestedMap, Array(5, 6), Array(7L, 8L))
- assert(valueArray.getSizeInBytes == 4 + 8 + nestedMap.getSizeInBytes)
+ val field2 = unsafeRow.getMap(1)
- assert(unsafeMap2.getSizeInBytes == keyArray.getSizeInBytes + valueArray.getSizeInBytes)
+ val keyArray = field2.keyArray
+ testArrayInt(keyArray, Seq(3))
- val map1Size = roundedSize(8 + unsafeMap1.getSizeInBytes)
- val map2Size = roundedSize(8 + unsafeMap2.getSizeInBytes)
- assert(unsafeRow.getSizeInBytes == 8 + 8 * 2 + map1Size + map2Size)
+ val valueArray = field2.valueArray
+
+ {
+ assert(valueArray.numElements == 1)
+
+ val innerStruct = valueArray.getStruct(0, 1)
+ assert(innerStruct.numFields == 1)
+ assert(innerStruct.getSizeInBytes == 8 + 8)
+ assert(innerStruct.getLong(0) == 4L)
+
+ assert(valueArray.getSizeInBytes == 4 + innerStruct.getSizeInBytes)
+ }
+
+ assert(field2.getSizeInBytes == keyArray.getSizeInBytes + valueArray.getSizeInBytes)
+
+ assert(unsafeRow.getSizeInBytes ==
+ 8 + 8 * 2 + field1.getSizeInBytes + mapSizeInRow(field2.getSizeInBytes))
+ }
+
+ test("basic conversion with array and map") {
+ val fieldTypes: Array[DataType] = Array(
+ ArrayType(MapType(IntegerType, IntegerType)),
+ MapType(IntegerType, ArrayType(IntegerType))
+ )
+ val converter = UnsafeProjection.create(fieldTypes)
+
+ val row = new GenericMutableRow(fieldTypes.length)
+ row.update(0, createArray(createMap(1)(2)))
+ row.update(1, createMap(3)(createArray(4)))
+
+ val unsafeRow: UnsafeRow = converter.apply(row)
+ assert(unsafeRow.numFields() == 2)
+
+ val field1 = unsafeRow.getArray(0)
+ assert(field1.numElements == 1)
+
+ val innerMap = field1.getMap(0)
+ testMapInt(innerMap, Seq(1), Seq(2))
+
+ assert(field1.getSizeInBytes == 4 + (8 + innerMap.getSizeInBytes))
+
+ val field2 = unsafeRow.getMap(1)
+ assert(field2.numElements == 1)
+
+ val keyArray = field2.keyArray
+ testArrayInt(keyArray, Seq(3))
+
+ val valueArray = field2.valueArray
+
+ {
+ assert(valueArray.numElements == 1)
+
+ val innerArray = valueArray.getArray(0)
+ testArrayInt(innerArray, Seq(4))
+
+ assert(valueArray.getSizeInBytes == 4 + (4 + innerArray.getSizeInBytes))
+ }
+
+ assert(field2.getSizeInBytes == keyArray.getSizeInBytes + valueArray.getSizeInBytes)
+
+ assert(unsafeRow.getSizeInBytes ==
+ 8 + 8 * 2 + arraySizeInRow(field1.getSizeInBytes) + mapSizeInRow(field2.getSizeInBytes))
}
}