From 608353c8e8e50461fafff91a2c885dca8af3aaa8 Mon Sep 17 00:00:00 2001 From: Wenchen Fan Date: Sun, 2 Aug 2015 23:41:16 -0700 Subject: [SPARK-9404][SPARK-9542][SQL] unsafe array data and map data

This PR adds UnsafeArrayData. Currently we encode it this way: the first 4 bytes are the number of elements; then each subsequent 4 bytes is the start offset of an element, unless it is negative, in which case the element is null; the elements themselves follow. For example, [10, 11, 12, 13, null, 14] (int elements) will be encoded as: 6, 24, 28, 32, 36, -40, 40, 10, 11, 12, 13, 14 (offsets are relative to the start of the offset region, i.e. right after the numElements header). Note that when we read an UnsafeArrayData from bytes, we read the first 4 bytes as numElements and take the rest (first 4 bytes skipped) as the value region.

Unsafe map data just uses 2 unsafe array datas: the first 4 bytes are the number of elements, the second 4 bytes are the numBytes of the key array, and then follow the key array data and the value array data.

Author: Wenchen Fan

Closes #7752 from cloud-fan/unsafe-array and squashes the following commits:

3269bd7 [Wenchen Fan] fix a bug
6445289 [Wenchen Fan] add unit tests
49adf26 [Wenchen Fan] add unsafe map
20d1039 [Wenchen Fan] add comments and unsafe converter
821b8db [Wenchen Fan] add unsafe array
---
.../sql/catalyst/expressions/UnsafeArrayData.java | 333 +++++++++++++++++++++
.../sql/catalyst/expressions/UnsafeMapData.java | 66 ++++
.../sql/catalyst/expressions/UnsafeReaders.java | 48 +++
.../spark/sql/catalyst/expressions/UnsafeRow.java | 34 ++-
.../sql/catalyst/expressions/UnsafeRowWriters.java | 71 +++++
.../sql/catalyst/expressions/UnsafeWriters.java | 208 +++++++++++++
.../sql/catalyst/expressions/FromUnsafe.scala | 67 +++++
.../sql/catalyst/expressions/Projection.scala | 10 +-
.../expressions/codegen/CodeGenerator.scala | 4 +-
.../codegen/GenerateUnsafeProjection.scala | 327 +++++++++++++++++++-
.../apache/spark/sql/types/ArrayBasedMapData.scala | 15 +-
.../org/apache/spark/sql/types/ArrayData.scala | 14 +-
.../apache/spark/sql/types/GenericArrayData.scala | 10 +-
.../scala/org/apache/spark/sql/types/MapData.scala | 2 +
.../expressions/UnsafeRowConverterSuite.scala | 114 ++++++-
15 files changed, 1292 insertions(+), 31 deletions(-)
create mode 100644 sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeArrayData.java
create mode 100644 sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeMapData.java
create mode 100644 sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeReaders.java
create mode 100644 sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeWriters.java
create mode 100644 sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/FromUnsafe.scala

diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeArrayData.java b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeArrayData.java new file mode 100644 index 0000000000..0374846d71 --- /dev/null +++ b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeArrayData.java @@ -0,0 +1,333 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.
You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.catalyst.expressions;
+
+import java.math.BigDecimal;
+import java.math.BigInteger;
+
+import org.apache.spark.sql.catalyst.InternalRow;
+import org.apache.spark.sql.types.*;
+import org.apache.spark.unsafe.PlatformDependent;
+import org.apache.spark.unsafe.array.ByteArrayMethods;
+import org.apache.spark.unsafe.hash.Murmur3_x86_32;
+import org.apache.spark.unsafe.types.CalendarInterval;
+import org.apache.spark.unsafe.types.UTF8String;
+
+/**
+ * An Unsafe implementation of Array which is backed by raw memory instead of Java objects.
+ *
+ * Each array has two parts: [offsets] [values]
+ *
+ * In the `offsets` region, we store 4 bytes per element, representing the start offset of that
+ * element in the `values` region. We can get an element's length by subtracting its offset from
+ * the next element's offset. Note that an offset can be negative, which means the element is null.
+ *
+ * In the `values` region, we store the contents of the elements. Since the length of each element
+ * can be computed, elements can be variable-length.
+ *
+ * Note that when we write out this array, we should write `numElements` into the first 4 bytes,
+ * followed by the content. When we read an array in, we should read the first 4 bytes as
+ * `numElements` and take the rest as content.
+ *
+ * Instances of `UnsafeArrayData` act as pointers to array data stored in this format.
+ */
+// todo: there is a lot of duplicated code between UnsafeRow and UnsafeArrayData.
+public class UnsafeArrayData extends ArrayData {
+
+ private Object baseObject;
+ private long baseOffset;
+
+ // The number of elements in this array
+ private int numElements;
+
+ // The size of this array's backing data, in bytes
+ private int sizeInBytes;
+
+ private int getElementOffset(int ordinal) {
+ return PlatformDependent.UNSAFE.getInt(baseObject, baseOffset + ordinal * 4L);
+ }
+
+ private int getElementSize(int offset, int ordinal) {
+ if (ordinal == numElements - 1) {
+ return sizeInBytes - offset;
+ } else {
+ return Math.abs(getElementOffset(ordinal + 1)) - offset;
+ }
+ }
+
+ private void assertIndexIsValid(int ordinal) {
+ assert ordinal >= 0 : "ordinal (" + ordinal + ") should >= 0";
+ assert ordinal < numElements : "ordinal (" + ordinal + ") should < " + numElements;
+ }
+
+ /**
+ * Construct a new UnsafeArrayData. The resulting UnsafeArrayData won't be usable until
+ * `pointTo()` has been called, since the value returned by this constructor is equivalent
+ * to a null pointer.
+ */
+ public UnsafeArrayData() { }
+
+ public Object getBaseObject() { return baseObject; }
+ public long getBaseOffset() { return baseOffset; }
+ public int getSizeInBytes() { return sizeInBytes; }
+
+ @Override
+ public int numElements() { return numElements; }
+
+ /**
+ * Update this UnsafeArrayData to point to different backing data.
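+ * The stored offsets are interpreted relative to `baseOffset` itself, so callers such as
+ * UnsafeReaders.readArray read the 4-byte `numElements` header themselves and pass in a
+ * `baseOffset` that already points past it.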
+ *
+ * @param baseObject the base object
+ * @param baseOffset the offset within the base object
+ * @param numElements the number of elements in this array
+ * @param sizeInBytes the size of this array's backing data, in bytes
+ */
+ public void pointTo(Object baseObject, long baseOffset, int numElements, int sizeInBytes) {
+ assert numElements >= 0 : "numElements (" + numElements + ") should >= 0";
+ this.numElements = numElements;
+ this.baseObject = baseObject;
+ this.baseOffset = baseOffset;
+ this.sizeInBytes = sizeInBytes;
+ }
+
+ @Override
+ public boolean isNullAt(int ordinal) {
+ assertIndexIsValid(ordinal);
+ return getElementOffset(ordinal) < 0;
+ }
+
+ @Override
+ public Object get(int ordinal, DataType dataType) {
+ if (isNullAt(ordinal) || dataType instanceof NullType) {
+ return null;
+ } else if (dataType instanceof BooleanType) {
+ return getBoolean(ordinal);
+ } else if (dataType instanceof ByteType) {
+ return getByte(ordinal);
+ } else if (dataType instanceof ShortType) {
+ return getShort(ordinal);
+ } else if (dataType instanceof IntegerType) {
+ return getInt(ordinal);
+ } else if (dataType instanceof LongType) {
+ return getLong(ordinal);
+ } else if (dataType instanceof FloatType) {
+ return getFloat(ordinal);
+ } else if (dataType instanceof DoubleType) {
+ return getDouble(ordinal);
+ } else if (dataType instanceof DecimalType) {
+ DecimalType dt = (DecimalType) dataType;
+ return getDecimal(ordinal, dt.precision(), dt.scale());
+ } else if (dataType instanceof DateType) {
+ return getInt(ordinal);
+ } else if (dataType instanceof TimestampType) {
+ return getLong(ordinal);
+ } else if (dataType instanceof BinaryType) {
+ return getBinary(ordinal);
+ } else if (dataType instanceof StringType) {
+ return getUTF8String(ordinal);
+ } else if (dataType instanceof CalendarIntervalType) {
+ return getInterval(ordinal);
+ } else if (dataType instanceof StructType) {
+ return getStruct(ordinal, ((StructType) dataType).size());
+ } else if (dataType instanceof ArrayType) {
+ return getArray(ordinal);
+ } else if (dataType instanceof MapType) {
+ return getMap(ordinal);
+ } else {
+ throw new UnsupportedOperationException("Unsupported data type " + dataType.simpleString());
+ }
+ }
+
+ @Override
+ public boolean getBoolean(int ordinal) {
+ assertIndexIsValid(ordinal);
+ final int offset = getElementOffset(ordinal);
+ if (offset < 0) return false;
+ return PlatformDependent.UNSAFE.getBoolean(baseObject, baseOffset + offset);
+ }
+
+ @Override
+ public byte getByte(int ordinal) {
+ assertIndexIsValid(ordinal);
+ final int offset = getElementOffset(ordinal);
+ if (offset < 0) return 0;
+ return PlatformDependent.UNSAFE.getByte(baseObject, baseOffset + offset);
+ }
+
+ @Override
+ public short getShort(int ordinal) {
+ assertIndexIsValid(ordinal);
+ final int offset = getElementOffset(ordinal);
+ if (offset < 0) return 0;
+ return PlatformDependent.UNSAFE.getShort(baseObject, baseOffset + offset);
+ }
+
+ @Override
+ public int getInt(int ordinal) {
+ assertIndexIsValid(ordinal);
+ final int offset = getElementOffset(ordinal);
+ if (offset < 0) return 0;
+ return PlatformDependent.UNSAFE.getInt(baseObject, baseOffset + offset);
+ }
+
+ @Override
+ public long getLong(int ordinal) {
+ assertIndexIsValid(ordinal);
+ final int offset = getElementOffset(ordinal);
+ if (offset < 0) return 0;
+ return PlatformDependent.UNSAFE.getLong(baseObject, baseOffset + offset);
+ }
+
+ @Override
+ public float getFloat(int ordinal) {
+ assertIndexIsValid(ordinal);
+ final int offset = getElementOffset(ordinal);
+ if (offset < 0) return 0;
+ return
PlatformDependent.UNSAFE.getFloat(baseObject, baseOffset + offset); + } + + @Override + public double getDouble(int ordinal) { + assertIndexIsValid(ordinal); + final int offset = getElementOffset(ordinal); + if (offset < 0) return 0; + return PlatformDependent.UNSAFE.getDouble(baseObject, baseOffset + offset); + } + + @Override + public Decimal getDecimal(int ordinal, int precision, int scale) { + assertIndexIsValid(ordinal); + final int offset = getElementOffset(ordinal); + if (offset < 0) return null; + + if (precision <= Decimal.MAX_LONG_DIGITS()) { + final long value = PlatformDependent.UNSAFE.getLong(baseObject, baseOffset + offset); + return Decimal.apply(value, precision, scale); + } else { + final byte[] bytes = getBinary(ordinal); + final BigInteger bigInteger = new BigInteger(bytes); + final BigDecimal javaDecimal = new BigDecimal(bigInteger, scale); + return Decimal.apply(new scala.math.BigDecimal(javaDecimal), precision, scale); + } + } + + @Override + public UTF8String getUTF8String(int ordinal) { + assertIndexIsValid(ordinal); + final int offset = getElementOffset(ordinal); + if (offset < 0) return null; + final int size = getElementSize(offset, ordinal); + return UTF8String.fromAddress(baseObject, baseOffset + offset, size); + } + + @Override + public byte[] getBinary(int ordinal) { + assertIndexIsValid(ordinal); + final int offset = getElementOffset(ordinal); + if (offset < 0) return null; + final int size = getElementSize(offset, ordinal); + final byte[] bytes = new byte[size]; + PlatformDependent.copyMemory( + baseObject, + baseOffset + offset, + bytes, + PlatformDependent.BYTE_ARRAY_OFFSET, + size); + return bytes; + } + + @Override + public CalendarInterval getInterval(int ordinal) { + assertIndexIsValid(ordinal); + final int offset = getElementOffset(ordinal); + if (offset < 0) return null; + final int months = (int) PlatformDependent.UNSAFE.getLong(baseObject, baseOffset + offset); + final long microseconds = + PlatformDependent.UNSAFE.getLong(baseObject, baseOffset + offset + 8); + return new CalendarInterval(months, microseconds); + } + + @Override + public InternalRow getStruct(int ordinal, int numFields) { + assertIndexIsValid(ordinal); + final int offset = getElementOffset(ordinal); + if (offset < 0) return null; + final int size = getElementSize(offset, ordinal); + final UnsafeRow row = new UnsafeRow(); + row.pointTo(baseObject, baseOffset + offset, numFields, size); + return row; + } + + @Override + public ArrayData getArray(int ordinal) { + assertIndexIsValid(ordinal); + final int offset = getElementOffset(ordinal); + if (offset < 0) return null; + final int size = getElementSize(offset, ordinal); + return UnsafeReaders.readArray(baseObject, baseOffset + offset, size); + } + + @Override + public MapData getMap(int ordinal) { + assertIndexIsValid(ordinal); + final int offset = getElementOffset(ordinal); + if (offset < 0) return null; + final int size = getElementSize(offset, ordinal); + return UnsafeReaders.readMap(baseObject, baseOffset + offset, size); + } + + @Override + public int hashCode() { + return Murmur3_x86_32.hashUnsafeWords(baseObject, baseOffset, sizeInBytes, 42); + } + + @Override + public boolean equals(Object other) { + if (other instanceof UnsafeArrayData) { + UnsafeArrayData o = (UnsafeArrayData) other; + return (sizeInBytes == o.sizeInBytes) && + ByteArrayMethods.arrayEquals(baseObject, baseOffset, o.baseObject, o.baseOffset, + sizeInBytes); + } + return false; + } + + public void writeToMemory(Object target, long targetOffset) { + 
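+ // Copies the offset and value regions (sizeInBytes bytes) into the target buffer. The 4-byte
+ // numElements header is not part of this data, so callers such as ArrayWriter.write store it
+ // separately before calling this.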
PlatformDependent.copyMemory(
+ baseObject,
+ baseOffset,
+ target,
+ targetOffset,
+ sizeInBytes
+ );
+ }
+
+ @Override
+ public UnsafeArrayData copy() {
+ UnsafeArrayData arrayCopy = new UnsafeArrayData();
+ final byte[] arrayDataCopy = new byte[sizeInBytes];
+ PlatformDependent.copyMemory(
+ baseObject,
+ baseOffset,
+ arrayDataCopy,
+ PlatformDependent.BYTE_ARRAY_OFFSET,
+ sizeInBytes
+ );
+ arrayCopy.pointTo(arrayDataCopy, PlatformDependent.BYTE_ARRAY_OFFSET, numElements, sizeInBytes);
+ return arrayCopy;
+ }
+}
diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeMapData.java b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeMapData.java new file mode 100644 index 0000000000..46216054ab --- /dev/null +++ b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeMapData.java @@ -0,0 +1,66 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.catalyst.expressions;
+
+import org.apache.spark.sql.types.ArrayData;
+import org.apache.spark.sql.types.MapData;
+
+/**
+ * An Unsafe implementation of Map which is backed by raw memory instead of Java objects.
+ *
+ * Currently we just use two UnsafeArrayDatas (one for keys, one for values) to represent an
+ * UnsafeMapData.
+ */
+public class UnsafeMapData extends MapData {
+
+ public final UnsafeArrayData keys;
+ public final UnsafeArrayData values;
+ // The number of key-value pairs in this map
+ private int numElements;
+ // The size of this map's backing data, in bytes
+ private int sizeInBytes;
+
+ public int getSizeInBytes() { return sizeInBytes; }
+
+ public UnsafeMapData(UnsafeArrayData keys, UnsafeArrayData values) {
+ assert keys.numElements() == values.numElements();
+ this.sizeInBytes = keys.getSizeInBytes() + values.getSizeInBytes();
+ this.numElements = keys.numElements();
+ this.keys = keys;
+ this.values = values;
+ }
+
+ @Override
+ public int numElements() {
+ return numElements;
+ }
+
+ @Override
+ public ArrayData keyArray() {
+ return keys;
+ }
+
+ @Override
+ public ArrayData valueArray() {
+ return values;
+ }
+
+ @Override
+ public UnsafeMapData copy() {
+ return new UnsafeMapData(keys.copy(), values.copy());
+ }
+}
diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeReaders.java b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeReaders.java new file mode 100644 index 0000000000..b521b70338 --- /dev/null +++ b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeReaders.java @@ -0,0 +1,48 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.catalyst.expressions;
+
+import org.apache.spark.unsafe.PlatformDependent;
+
+public class UnsafeReaders {
+
+ public static UnsafeArrayData readArray(Object baseObject, long baseOffset, int numBytes) {
+ // Read the number of elements from the first 4 bytes.
+ final int numElements = PlatformDependent.UNSAFE.getInt(baseObject, baseOffset);
+ final UnsafeArrayData array = new UnsafeArrayData();
+ // Skip the first 4 bytes.
+ array.pointTo(baseObject, baseOffset + 4, numElements, numBytes - 4);
+ return array;
+ }
+
+ public static UnsafeMapData readMap(Object baseObject, long baseOffset, int numBytes) {
+ // Read the number of elements from the first 4 bytes.
+ final int numElements = PlatformDependent.UNSAFE.getInt(baseObject, baseOffset);
+ // Read the numBytes of the key array from the second 4 bytes.
+ final int keyArraySize = PlatformDependent.UNSAFE.getInt(baseObject, baseOffset + 4);
+ final int valueArraySize = numBytes - 8 - keyArraySize;
+
+ final UnsafeArrayData keyArray = new UnsafeArrayData();
+ keyArray.pointTo(baseObject, baseOffset + 8, numElements, keyArraySize);
+
+ final UnsafeArrayData valueArray = new UnsafeArrayData();
+ valueArray.pointTo(baseObject, baseOffset + 8 + keyArraySize, numElements, valueArraySize);
+
+ return new UnsafeMapData(keyArray, valueArray);
+ }
+}
diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeRow.java b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeRow.java index b4fc0b7b70..c5d42d73a4 100644 --- a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeRow.java +++ b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeRow.java @@ -291,6 +291,10 @@ public final class UnsafeRow extends MutableRow { return getInterval(ordinal); } else if (dataType instanceof StructType) { return getStruct(ordinal, ((StructType) dataType).size());
+ } else if (dataType instanceof ArrayType) {
+ return getArray(ordinal);
+ } else if (dataType instanceof MapType) {
+ return getMap(ordinal);
 } else { throw new UnsupportedOperationException("Unsupported data type " + dataType.simpleString()); } @@ -346,7 +350,6 @@ public final class UnsafeRow extends MutableRow { @Override public Decimal getDecimal(int ordinal, int precision, int scale) { - assertIndexIsValid(ordinal); if (isNullAt(ordinal)) { return null; } @@ -362,7 +365,6 @@ public final class UnsafeRow extends MutableRow { @Override public UTF8String getUTF8String(int ordinal) { - assertIndexIsValid(ordinal); if (isNullAt(ordinal)) return null; final long offsetAndSize = getLong(ordinal); final int offset = (int) (offsetAndSize >> 32); @@ -372,7 +374,6 @@ public final class UnsafeRow extends MutableRow { @Override public byte[] getBinary(int ordinal) { - assertIndexIsValid(ordinal); if (isNullAt(ordinal)) { return null; } else { @@ -410,7 +411,6 @@ if
(isNullAt(ordinal)) { return null; } else { - assertIndexIsValid(ordinal); final long offsetAndSize = getLong(ordinal); final int offset = (int) (offsetAndSize >> 32); final int size = (int) (offsetAndSize & ((1L << 32) - 1)); @@ -420,11 +420,33 @@ public final class UnsafeRow extends MutableRow { } } + @Override + public ArrayData getArray(int ordinal) { + if (isNullAt(ordinal)) { + return null; + } else { + final long offsetAndSize = getLong(ordinal); + final int offset = (int) (offsetAndSize >> 32); + final int size = (int) (offsetAndSize & ((1L << 32) - 1)); + return UnsafeReaders.readArray(baseObject, baseOffset + offset, size); + } + } + + @Override + public MapData getMap(int ordinal) { + if (isNullAt(ordinal)) { + return null; + } else { + final long offsetAndSize = getLong(ordinal); + final int offset = (int) (offsetAndSize >> 32); + final int size = (int) (offsetAndSize & ((1L << 32) - 1)); + return UnsafeReaders.readMap(baseObject, baseOffset + offset, size); + } + } + /** * Copies this row, returning a self-contained UnsafeRow that stores its data in an internal * byte array rather than referencing data stored in a data page. - *

- * This method is only supported on UnsafeRows that do not use ObjectPools. */ @Override public UnsafeRow copy() {
diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeRowWriters.java b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeRowWriters.java index f43a285cd6..3192873154 100644 --- a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeRowWriters.java +++ b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeRowWriters.java @@ -19,6 +19,7 @@ package org.apache.spark.sql.catalyst.expressions; import org.apache.spark.sql.catalyst.InternalRow; import org.apache.spark.sql.types.Decimal;
+import org.apache.spark.sql.types.MapData;
 import org.apache.spark.unsafe.PlatformDependent; import org.apache.spark.unsafe.array.ByteArrayMethods; import org.apache.spark.unsafe.types.ByteArray; @@ -185,4 +186,74 @@ public class UnsafeRowWriters { return 16; } }
+
+ public static class ArrayWriter {
+
+ public static int getSize(UnsafeArrayData input) {
+ // we need an extra 4 bytes to store the number of elements in this array.
+ return ByteArrayMethods.roundNumberOfBytesToNearestWord(input.getSizeInBytes() + 4);
+ }
+
+ public static int write(UnsafeRow target, int ordinal, int cursor, UnsafeArrayData input) {
+ final int numBytes = input.getSizeInBytes() + 4;
+ final long offset = target.getBaseOffset() + cursor;
+
+ // write the number of elements into the first 4 bytes.
+ PlatformDependent.UNSAFE.putInt(target.getBaseObject(), offset, input.numElements());
+
+ // zero-out the padding bytes
+ if ((numBytes & 0x07) > 0) {
+ PlatformDependent.UNSAFE.putLong(
+ target.getBaseObject(), offset + ((numBytes >> 3) << 3), 0L);
+ }
+
+ // Write the bytes to the variable length portion.
+ input.writeToMemory(target.getBaseObject(), offset + 4);
+
+ // Set the fixed length portion.
+ target.setLong(ordinal, (((long) cursor) << 32) | ((long) numBytes));
+
+ return ByteArrayMethods.roundNumberOfBytesToNearestWord(numBytes);
+ }
+ }
+
+ public static class MapWriter {
+
+ public static int getSize(UnsafeMapData input) {
+ // we need an extra 8 bytes to store the number of elements and the numBytes of the key array.
+ final int sizeInBytes = 4 + 4 + input.getSizeInBytes();
+ return ByteArrayMethods.roundNumberOfBytesToNearestWord(sizeInBytes);
+ }
+
+ public static int write(UnsafeRow target, int ordinal, int cursor, UnsafeMapData input) {
+ final long offset = target.getBaseOffset() + cursor;
+ final UnsafeArrayData keyArray = input.keys;
+ final UnsafeArrayData valueArray = input.values;
+ final int keysNumBytes = keyArray.getSizeInBytes();
+ final int valuesNumBytes = valueArray.getSizeInBytes();
+ final int numBytes = 4 + 4 + keysNumBytes + valuesNumBytes;
+
+ // write the number of elements into the first 4 bytes.
+ PlatformDependent.UNSAFE.putInt(target.getBaseObject(), offset, input.numElements());
+ // write the numBytes of the key array into the second 4 bytes.
+ PlatformDependent.UNSAFE.putInt(target.getBaseObject(), offset + 4, keysNumBytes);
+
+ // zero-out the padding bytes
+ if ((numBytes & 0x07) > 0) {
+ PlatformDependent.UNSAFE.putLong(
+ target.getBaseObject(), offset + ((numBytes >> 3) << 3), 0L);
+ }
+
+ // Write the bytes of the key array to the variable length portion.
+ keyArray.writeToMemory(target.getBaseObject(), offset + 8);
+
+ // Write the bytes of the value array to the variable length portion.
+ valueArray.writeToMemory(target.getBaseObject(), offset + 8 + keysNumBytes);
+
+ // Set the fixed length portion.
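+ // The slot packs (offset, size) into a single long: cursor in the upper 32 bits and
+ // numBytes in the lower 32. For example cursor = 32, numBytes = 24 yields
+ // (32L << 32) | 24 = 0x0000002000000018.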
+ target.setLong(ordinal, (((long) cursor) << 32) | ((long) numBytes)); + + return ByteArrayMethods.roundNumberOfBytesToNearestWord(numBytes); + } + } } diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeWriters.java b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeWriters.java new file mode 100644 index 0000000000..0e8e405d05 --- /dev/null +++ b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeWriters.java @@ -0,0 +1,208 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.catalyst.expressions; + +import org.apache.spark.sql.types.Decimal; +import org.apache.spark.unsafe.PlatformDependent; +import org.apache.spark.unsafe.array.ByteArrayMethods; +import org.apache.spark.unsafe.types.CalendarInterval; +import org.apache.spark.unsafe.types.UTF8String; + +/** + * A set of helper methods to write data into the variable length portion. + */ +public class UnsafeWriters { + public static void writeToMemory( + Object inputObject, + long inputOffset, + Object targetObject, + long targetOffset, + int numBytes) { + + // zero-out the padding bytes +// if ((numBytes & 0x07) > 0) { +// PlatformDependent.UNSAFE.putLong(targetObject, targetOffset + ((numBytes >> 3) << 3), 0L); +// } + + // Write the UnsafeData to the target memory. + PlatformDependent.copyMemory( + inputObject, + inputOffset, + targetObject, + targetOffset, + numBytes + ); + } + + public static int getRoundedSize(int size) { + //return ByteArrayMethods.roundNumberOfBytesToNearestWord(size); + // todo: do word alignment + return size; + } + + /** Writer for Decimal with precision larger than 18. */ + public static class DecimalWriter { + + public static int getSize(Decimal input) { + return 16; + } + + public static int write(Object targetObject, long targetOffset, Decimal input) { + final byte[] bytes = input.toJavaBigDecimal().unscaledValue().toByteArray(); + final int numBytes = bytes.length; + assert(numBytes <= 16); + + // zero-out the bytes + PlatformDependent.UNSAFE.putLong(targetObject, targetOffset, 0L); + PlatformDependent.UNSAFE.putLong(targetObject, targetOffset + 8, 0L); + + // Write the bytes to the variable length portion. + PlatformDependent.copyMemory(bytes, + PlatformDependent.BYTE_ARRAY_OFFSET, + targetObject, + targetOffset, + numBytes); + + return 16; + } + } + + /** Writer for UTF8String. */ + public static class UTF8StringWriter { + + public static int getSize(UTF8String input) { + return getRoundedSize(input.numBytes()); + } + + public static int write(Object targetObject, long targetOffset, UTF8String input) { + final int numBytes = input.numBytes(); + + // Write the bytes to the variable length portion. 
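+ // UTF8String exposes its base object and base offset, so its bytes can be copied
+ // directly without materializing an intermediate byte[].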
+ writeToMemory(input.getBaseObject(), input.getBaseOffset(),
+ targetObject, targetOffset, numBytes);
+
+ return getRoundedSize(numBytes);
+ }
+ }
+
+ /** Writer for binary (byte array) type. */
+ public static class BinaryWriter {
+
+ public static int getSize(byte[] input) {
+ return getRoundedSize(input.length);
+ }
+
+ public static int write(Object targetObject, long targetOffset, byte[] input) {
+ final int numBytes = input.length;
+
+ // Write the bytes to the variable length portion.
+ writeToMemory(input, PlatformDependent.BYTE_ARRAY_OFFSET,
+ targetObject, targetOffset, numBytes);
+
+ return getRoundedSize(numBytes);
+ }
+ }
+
+ /** Writer for UnsafeRow. */
+ public static class StructWriter {
+
+ public static int getSize(UnsafeRow input) {
+ return getRoundedSize(input.getSizeInBytes());
+ }
+
+ public static int write(Object targetObject, long targetOffset, UnsafeRow input) {
+ final int numBytes = input.getSizeInBytes();
+
+ // Write the bytes to the variable length portion.
+ writeToMemory(input.getBaseObject(), input.getBaseOffset(),
+ targetObject, targetOffset, numBytes);
+
+ return getRoundedSize(numBytes);
+ }
+ }
+
+ /** Writer for interval type. */
+ public static class IntervalWriter {
+
+ public static int getSize(CalendarInterval input) {
+ return 16;
+ }
+
+ public static int write(Object targetObject, long targetOffset, CalendarInterval input) {
+
+ // Write the months and microseconds fields of Interval to the variable length portion.
+ PlatformDependent.UNSAFE.putLong(targetObject, targetOffset, input.months);
+ PlatformDependent.UNSAFE.putLong(targetObject, targetOffset + 8, input.microseconds);
+
+ return 16;
+ }
+ }
+
+ /** Writer for UnsafeArrayData. */
+ public static class ArrayWriter {
+
+ public static int getSize(UnsafeArrayData input) {
+ // we need an extra 4 bytes to store the number of elements in this array.
+ return getRoundedSize(input.getSizeInBytes() + 4);
+ }
+
+ public static int write(Object targetObject, long targetOffset, UnsafeArrayData input) {
+ final int numBytes = input.getSizeInBytes();
+
+ // write the number of elements into the first 4 bytes.
+ PlatformDependent.UNSAFE.putInt(targetObject, targetOffset, input.numElements());
+
+ // Write the bytes to the variable length portion.
+ writeToMemory(input.getBaseObject(), input.getBaseOffset(),
+ targetObject, targetOffset + 4, numBytes);
+
+ return getRoundedSize(numBytes + 4);
+ }
+ }
+
+ public static class MapWriter {
+
+ public static int getSize(UnsafeMapData input) {
+ // we need an extra 8 bytes to store the number of elements and the numBytes of the key array.
+ return getRoundedSize(4 + 4 + input.getSizeInBytes());
+ }
+
+ public static int write(Object targetObject, long targetOffset, UnsafeMapData input) {
+ final UnsafeArrayData keyArray = input.keys;
+ final UnsafeArrayData valueArray = input.values;
+ final int keysNumBytes = keyArray.getSizeInBytes();
+ final int valuesNumBytes = valueArray.getSizeInBytes();
+ final int numBytes = 4 + 4 + keysNumBytes + valuesNumBytes;
+
+ // write the number of elements into the first 4 bytes.
+ PlatformDependent.UNSAFE.putInt(targetObject, targetOffset, input.numElements());
+ // write the numBytes of the key array into the second 4 bytes.
+ PlatformDependent.UNSAFE.putInt(targetObject, targetOffset + 4, keysNumBytes);
+
+ // Write the bytes of the key array to the variable length portion.
+ writeToMemory(keyArray.getBaseObject(), keyArray.getBaseOffset(),
+ targetObject, targetOffset + 8, keysNumBytes);
+
+ // Write the bytes of the value array to the variable length portion.
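+ // The resulting layout is [numElements: 4 bytes][keysNumBytes: 4 bytes][key array bytes]
+ // [value array bytes], which is exactly what UnsafeReaders.readMap expects.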
+ writeToMemory(valueArray.getBaseObject(), valueArray.getBaseOffset(), + targetObject, targetOffset + 8 + keysNumBytes, valuesNumBytes); + + return getRoundedSize(numBytes); + } + } +} diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/FromUnsafe.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/FromUnsafe.scala new file mode 100644 index 0000000000..3caf0fb341 --- /dev/null +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/FromUnsafe.scala @@ -0,0 +1,67 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.catalyst.expressions + +import org.apache.spark.sql.catalyst.expressions.codegen.CodegenFallback +import org.apache.spark.sql.types._ + +case class FromUnsafe(child: Expression) extends UnaryExpression + with ExpectsInputTypes with CodegenFallback { + + override def inputTypes: Seq[AbstractDataType] = + Seq(TypeCollection(ArrayType, StructType, MapType)) + + override def dataType: DataType = child.dataType + + private def convert(value: Any, dt: DataType): Any = dt match { + case StructType(fields) => + val row = value.asInstanceOf[UnsafeRow] + val result = new Array[Any](fields.length) + fields.map(_.dataType).zipWithIndex.foreach { case (dt, i) => + if (!row.isNullAt(i)) { + result(i) = convert(row.get(i, dt), dt) + } + } + new GenericInternalRow(result) + + case ArrayType(elementType, _) => + val array = value.asInstanceOf[UnsafeArrayData] + val length = array.numElements() + val result = new Array[Any](length) + var i = 0 + while (i < length) { + if (!array.isNullAt(i)) { + result(i) = convert(array.get(i, elementType), elementType) + } + i += 1 + } + new GenericArrayData(result) + + case MapType(kt, vt, _) => + val map = value.asInstanceOf[UnsafeMapData] + val safeKeyArray = convert(map.keys, ArrayType(kt)).asInstanceOf[GenericArrayData] + val safeValueArray = convert(map.values, ArrayType(vt)).asInstanceOf[GenericArrayData] + new ArrayBasedMapData(safeKeyArray, safeValueArray) + + case _ => value + } + + override def nullSafeEval(input: Any): Any = { + convert(input, dataType) + } +} diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Projection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Projection.scala index 83129dc12d..7964974102 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Projection.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Projection.scala @@ -151,7 +151,15 @@ object FromUnsafeProjection { * Returns an UnsafeProjection for given Array of DataTypes. 
*/ def apply(fields: Seq[DataType]): Projection = { - create(fields.zipWithIndex.map(x => new BoundReference(x._2, x._1, true)))
+ create(fields.zipWithIndex.map(x => {
+ val b = new BoundReference(x._2, x._1, true)
+ // todo: this is quite slow, maybe remove this whole projection after removing the generic
+ // getter of InternalRow?
+ b.dataType match {
+ case _: StructType | _: ArrayType | _: MapType => FromUnsafe(b)
+ case _ => b
+ }
+ }))
 } /**
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala index 03ec4b4b4e..7b41c9a3f3 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala @@ -336,7 +336,9 @@ abstract class CodeGenerator[InType <: AnyRef, OutType <: AnyRef] extends Loggin classOf[Decimal].getName, classOf[CalendarInterval].getName, classOf[ArrayData].getName,
- classOf[MapData].getName
+ classOf[UnsafeArrayData].getName,
+ classOf[MapData].getName,
+ classOf[UnsafeMapData].getName
 )) evaluator.setExtendedClass(classOf[GeneratedClass]) try {
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeProjection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeProjection.scala index 934ec3f75c..fc3ecf5451 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeProjection.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeProjection.scala @@ -19,6 +19,7 @@ package org.apache.spark.sql.catalyst.expressions.codegen import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.types._
+import org.apache.spark.unsafe.PlatformDependent
 /** * Generates a [[Projection]] that returns an [[UnsafeRow]]. @@ -37,14 +38,19 @@ object GenerateUnsafeProjection extends CodeGenerator[Seq[Expression], UnsafePro private val StructWriter = classOf[UnsafeRowWriters.StructWriter].getName private val CompactDecimalWriter = classOf[UnsafeRowWriters.CompactDecimalWriter].getName private val DecimalWriter = classOf[UnsafeRowWriters.DecimalWriter].getName
+ private val ArrayWriter = classOf[UnsafeRowWriters.ArrayWriter].getName
+ private val MapWriter = classOf[UnsafeRowWriters.MapWriter].getName
+
+ private val PlatformDependent = classOf[PlatformDependent].getName
 /** Returns true iff we support this data type. */ def canSupport(dataType: DataType): Boolean = dataType match { - case t: AtomicType if !t.isInstanceOf[DecimalType] => true + case t: AtomicType => true case _: CalendarIntervalType => true case t: StructType => t.toSeq.forall(field => canSupport(field.dataType)) case NullType => true - case t: DecimalType => true + case t: ArrayType if canSupport(t.elementType) => true + case MapType(kt, vt, _) if canSupport(kt) && canSupport(vt) => true case _ => false } @@ -59,6 +65,10 @@ s" + (${ev.isNull} ? 0 : 16)" case _: StructType => s" + (${ev.isNull} ? 0 : $StructWriter.getSize(${ev.primitive}))"
+ case _: ArrayType =>
+ s" + (${ev.isNull} ? 0 : $ArrayWriter.getSize(${ev.primitive}))"
+ case _: MapType =>
+ s" + (${ev.isNull} ?
0 : $MapWriter.getSize(${ev.primitive}))" case _ => "" } @@ -95,8 +105,12 @@ object GenerateUnsafeProjection extends CodeGenerator[Seq[Expression], UnsafePro s"$cursor += $BinaryWriter.write($primitive, $index, $cursor, ${ev.primitive})" case CalendarIntervalType => s"$cursor += $IntervalWriter.write($primitive, $index, $cursor, ${ev.primitive})" - case t: StructType => + case _: StructType => s"$cursor += $StructWriter.write($primitive, $index, $cursor, ${ev.primitive})" + case _: ArrayType => + s"$cursor += $ArrayWriter.write($primitive, $index, $cursor, ${ev.primitive})" + case _: MapType => + s"$cursor += $MapWriter.write($primitive, $index, $cursor, ${ev.primitive})" case NullType => "" case _ => throw new UnsupportedOperationException(s"Not supported DataType: $fieldType") @@ -148,7 +162,7 @@ object GenerateUnsafeProjection extends CodeGenerator[Seq[Expression], UnsafePro $ret.pointTo( $buffer, - org.apache.spark.unsafe.PlatformDependent.BYTE_ARRAY_OFFSET, + $PlatformDependent.BYTE_ARRAY_OFFSET, ${expressions.size}, $numBytes); int $cursor = $fixedSize; @@ -237,7 +251,7 @@ object GenerateUnsafeProjection extends CodeGenerator[Seq[Expression], UnsafePro | | $primitive.pointTo( | $buffer, - | org.apache.spark.unsafe.PlatformDependent.BYTE_ARRAY_OFFSET, + | $PlatformDependent.BYTE_ARRAY_OFFSET, | ${exprs.size}, | $numBytes); | int $cursor = $fixedSize; @@ -250,6 +264,303 @@ object GenerateUnsafeProjection extends CodeGenerator[Seq[Expression], UnsafePro GeneratedExpressionCode(code, isNull, primitive) } + /** + * Generates the Java code to convert a struct (backed by InternalRow) to UnsafeRow. + * + * @param ctx code generation context + * @param inputs could be the codes for expressions or input struct fields. + * @param inputTypes types of the inputs + */ + private def createCodeForStruct2( + ctx: CodeGenContext, + inputs: Seq[GeneratedExpressionCode], + inputTypes: Seq[DataType]): GeneratedExpressionCode = { + + val output = ctx.freshName("convertedStruct") + ctx.addMutableState("UnsafeRow", output, s"$output = new UnsafeRow();") + val buffer = ctx.freshName("buffer") + ctx.addMutableState("byte[]", buffer, s"$buffer = new byte[64];") + val numBytes = ctx.freshName("numBytes") + val cursor = ctx.freshName("cursor") + + val convertedFields = inputTypes.zip(inputs).map { case (dt, input) => + createConvertCode(ctx, input, dt) + } + + val fixedSize = 8 * inputTypes.length + UnsafeRow.calculateBitSetWidthInBytes(inputTypes.length) + val additionalSize = inputTypes.zip(convertedFields).map { case (dt, ev) => + genAdditionalSize(dt, ev) + }.mkString("") + + val fieldWriters = inputTypes.zip(convertedFields).zipWithIndex.map { case ((dt, ev), i) => + val update = genFieldWriter(ctx, dt, ev, output, i, cursor) + s""" + if (${ev.isNull}) { + $output.setNullAt($i); + } else { + $update; + } + """ + }.mkString("\n") + + val code = s""" + ${convertedFields.map(_.code).mkString("\n")} + + final int $numBytes = $fixedSize $additionalSize; + if ($numBytes > $buffer.length) { + $buffer = new byte[$numBytes]; + } + + $output.pointTo( + $buffer, + $PlatformDependent.BYTE_ARRAY_OFFSET, + ${inputTypes.length}, + $numBytes); + + int $cursor = $fixedSize; + + $fieldWriters + """ + GeneratedExpressionCode(code, "false", output) + } + + private def getWriter(dt: DataType) = dt match { + case StringType => classOf[UnsafeWriters.UTF8StringWriter].getName + case BinaryType => classOf[UnsafeWriters.BinaryWriter].getName + case CalendarIntervalType => classOf[UnsafeWriters.IntervalWriter].getName + case _: 
StructType => classOf[UnsafeWriters.StructWriter].getName
+ case _: ArrayType => classOf[UnsafeWriters.ArrayWriter].getName
+ case _: MapType => classOf[UnsafeWriters.MapWriter].getName
+ case _: DecimalType => classOf[UnsafeWriters.DecimalWriter].getName
+ }
+
+ private def createCodeForArray(
+ ctx: CodeGenContext,
+ input: GeneratedExpressionCode,
+ elementType: DataType): GeneratedExpressionCode = {
+ val output = ctx.freshName("convertedArray")
+ ctx.addMutableState("UnsafeArrayData", output, s"$output = new UnsafeArrayData();")
+ val buffer = ctx.freshName("buffer")
+ ctx.addMutableState("byte[]", buffer, s"$buffer = new byte[64];")
+ val outputIsNull = ctx.freshName("isNull")
+ val tmp = ctx.freshName("tmp")
+ val numElements = ctx.freshName("numElements")
+ val fixedSize = ctx.freshName("fixedSize")
+ val numBytes = ctx.freshName("numBytes")
+ val elements = ctx.freshName("elements")
+ val cursor = ctx.freshName("cursor")
+ val index = ctx.freshName("index")
+
+ val element = GeneratedExpressionCode(
+ code = "",
+ isNull = s"$tmp.isNullAt($index)",
+ primitive = s"${ctx.getValue(tmp, elementType, index)}"
+ )
+ val convertedElement: GeneratedExpressionCode = createConvertCode(ctx, element, elementType)
+
+ // go through the input array to calculate how many bytes we need.
+ val calculateNumBytes = elementType match {
+ case _ if (ctx.isPrimitiveType(elementType)) =>
+ // Should we do word align?
+ val elementSize = elementType.defaultSize
+ s"""
+ $numBytes += $elementSize * $numElements;
+ """
+ case t: DecimalType if t.precision <= Decimal.MAX_LONG_DIGITS =>
+ s"""
+ $numBytes += 8 * $numElements;
+ """
+ case _ =>
+ val writer = getWriter(elementType)
+ val elementSize = s"$writer.getSize($elements[$index])"
+ val unsafeType = elementType match {
+ case _: StructType => "UnsafeRow"
+ case _: ArrayType => "UnsafeArrayData"
+ case _: MapType => "UnsafeMapData"
+ case _ => ctx.javaType(elementType)
+ }
+ val copy = elementType match {
+ // We reuse the buffer during conversion, so we need to copy it before processing
+ // the next element.
+ case _: StructType | _: ArrayType | _: MapType => ".copy()"
+ case _ => ""
+ }
+
+ s"""
+ final $unsafeType[] $elements = new $unsafeType[$numElements];
+ for (int $index = 0; $index < $numElements; $index++) {
+ ${convertedElement.code}
+ if (!${convertedElement.isNull}) {
+ $elements[$index] = ${convertedElement.primitive}$copy;
+ $numBytes += $elementSize;
+ }
+ }
+ """
+ }
+
+ val writeElement = elementType match {
+ case _ if (ctx.isPrimitiveType(elementType)) =>
+ // Should we do word align?
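+ // For now primitive elements are packed back-to-back at defaultSize, with no
+ // per-element padding; e.g. an int element advances the cursor by 4 bytes.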
+ val elementSize = elementType.defaultSize
+ s"""
+ $PlatformDependent.UNSAFE.put${ctx.primitiveTypeName(elementType)}(
+ $buffer,
+ $PlatformDependent.BYTE_ARRAY_OFFSET + $cursor,
+ ${convertedElement.primitive});
+ $cursor += $elementSize;
+ """
+ case t: DecimalType if t.precision <= Decimal.MAX_LONG_DIGITS =>
+ s"""
+ $PlatformDependent.UNSAFE.putLong(
+ $buffer,
+ $PlatformDependent.BYTE_ARRAY_OFFSET + $cursor,
+ ${convertedElement.primitive}.toUnscaledLong());
+ $cursor += 8;
+ """
+ case _ =>
+ val writer = getWriter(elementType)
+ s"""
+ $cursor += $writer.write(
+ $buffer,
+ $PlatformDependent.BYTE_ARRAY_OFFSET + $cursor,
+ $elements[$index]);
+ """
+ }
+
+ val checkNull = elementType match {
+ case _ if ctx.isPrimitiveType(elementType) => s"${convertedElement.isNull}"
+ case t: DecimalType => s"$elements[$index] == null" +
+ s" || !$elements[$index].changePrecision(${t.precision}, ${t.scale})"
+ case _ => s"$elements[$index] == null"
+ }
+
+ val code = s"""
+ ${input.code}
+ final boolean $outputIsNull = ${input.isNull};
+ if (!$outputIsNull) {
+ final ArrayData $tmp = ${input.primitive};
+ if ($tmp instanceof UnsafeArrayData) {
+ $output = (UnsafeArrayData) $tmp;
+ } else {
+ final int $numElements = $tmp.numElements();
+ final int $fixedSize = 4 * $numElements;
+ int $numBytes = $fixedSize;
+
+ $calculateNumBytes
+
+ if ($numBytes > $buffer.length) {
+ $buffer = new byte[$numBytes];
+ }
+
+ int $cursor = $fixedSize;
+ for (int $index = 0; $index < $numElements; $index++) {
+ if ($checkNull) {
+ // If the element is null, write the negated cursor into the offset region.
+ $PlatformDependent.UNSAFE.putInt(
+ $buffer,
+ $PlatformDependent.BYTE_ARRAY_OFFSET + 4 * $index,
+ -$cursor);
+ } else {
+ $PlatformDependent.UNSAFE.putInt(
+ $buffer,
+ $PlatformDependent.BYTE_ARRAY_OFFSET + 4 * $index,
+ $cursor);
+
+ $writeElement
+ }
+ }
+
+ $output.pointTo(
+ $buffer,
+ $PlatformDependent.BYTE_ARRAY_OFFSET,
+ $numElements,
+ $numBytes);
+ }
+ }
+ """
+ GeneratedExpressionCode(code, outputIsNull, output)
+ }
+
+ private def createCodeForMap(
+ ctx: CodeGenContext,
+ input: GeneratedExpressionCode,
+ keyType: DataType,
+ valueType: DataType): GeneratedExpressionCode = {
+ val output = ctx.freshName("convertedMap")
+ val outputIsNull = ctx.freshName("isNull")
+ val tmp = ctx.freshName("tmp")
+
+ val keyArray = GeneratedExpressionCode(
+ code = "",
+ isNull = "false",
+ primitive = s"$tmp.keyArray()"
+ )
+ val valueArray = GeneratedExpressionCode(
+ code = "",
+ isNull = "false",
+ primitive = s"$tmp.valueArray()"
+ )
+ val convertedKeys: GeneratedExpressionCode = createCodeForArray(ctx, keyArray, keyType)
+ val convertedValues: GeneratedExpressionCode = createCodeForArray(ctx, valueArray, valueType)
+
+ val code = s"""
+ ${input.code}
+ final boolean $outputIsNull = ${input.isNull};
+ UnsafeMapData $output = null;
+ if (!$outputIsNull) {
+ final MapData $tmp = ${input.primitive};
+ if ($tmp instanceof UnsafeMapData) {
+ $output = (UnsafeMapData) $tmp;
+ } else {
+ ${convertedKeys.code}
+ ${convertedValues.code}
+ $output = new UnsafeMapData(${convertedKeys.primitive}, ${convertedValues.primitive});
+ }
+ }
+ """
+ GeneratedExpressionCode(code, outputIsNull, output)
+ }
+
+ /**
+ * Generates the Java code to convert a value to its unsafe version.
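+ * Structs are converted to UnsafeRow, arrays to UnsafeArrayData and maps to UnsafeMapData;
+ * values of any other type are returned unchanged.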
+ */ + private def createConvertCode( + ctx: CodeGenContext, + input: GeneratedExpressionCode, + dataType: DataType): GeneratedExpressionCode = dataType match { + case t: StructType => + val output = ctx.freshName("convertedStruct") + val outputIsNull = ctx.freshName("isNull") + val tmp = ctx.freshName("tmp") + val fieldTypes = t.fields.map(_.dataType) + val fieldEvals = fieldTypes.zipWithIndex.map { case (dt, i) => + val getFieldCode = ctx.getValue(tmp, dt, i.toString) + val fieldIsNull = s"$tmp.isNullAt($i)" + GeneratedExpressionCode("", fieldIsNull, getFieldCode) + } + val converter = createCodeForStruct2(ctx, fieldEvals, fieldTypes) + val code = s""" + ${input.code} + UnsafeRow $output = null; + final boolean $outputIsNull = ${input.isNull}; + if (!$outputIsNull) { + final InternalRow $tmp = ${input.primitive}; + if ($tmp instanceof UnsafeRow) { + $output = (UnsafeRow) $tmp; + } else { + ${converter.code} + $output = ${converter.primitive}; + } + } + """ + GeneratedExpressionCode(code, outputIsNull, output) + + case ArrayType(elementType, _) => createCodeForArray(ctx, input, elementType) + + case MapType(kt, vt, _) => createCodeForMap(ctx, input, kt, vt) + + case _ => input + } + protected def canonicalize(in: Seq[Expression]): Seq[Expression] = in.map(ExpressionCanonicalizer.execute) @@ -259,10 +570,8 @@ object GenerateUnsafeProjection extends CodeGenerator[Seq[Expression], UnsafePro protected def create(expressions: Seq[Expression]): UnsafeProjection = { val ctx = newCodeGenContext() - val isNull = ctx.freshName("retIsNull") - val primitive = ctx.freshName("retValue") - val eval = GeneratedExpressionCode("", isNull, primitive) - eval.code = createCode(ctx, eval, expressions) + val exprEvals = expressions.map(e => e.gen(ctx)) + val eval = createCodeForStruct2(ctx, exprEvals, expressions.map(_.dataType)) val code = s""" public Object generate($exprType[] exprs) { diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/ArrayBasedMapData.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/ArrayBasedMapData.scala index db4876355d..f6fa021ade 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/ArrayBasedMapData.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/ArrayBasedMapData.scala @@ -22,6 +22,9 @@ class ArrayBasedMapData(val keyArray: ArrayData, val valueArray: ArrayData) exte override def numElements(): Int = keyArray.numElements() + override def copy(): MapData = new ArrayBasedMapData(keyArray.copy(), valueArray.copy()) + + // We need to check equality of map type in tests. 
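+ // Comparing via the Scala map form makes equality insensitive to entry order; comparing
+ // the key and value arrays directly would not be.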
override def equals(o: Any): Boolean = { if (!o.isInstanceOf[ArrayBasedMapData]) { return false @@ -32,15 +35,15 @@ class ArrayBasedMapData(val keyArray: ArrayData, val valueArray: ArrayData) exte return false } - this.keyArray == other.keyArray && this.valueArray == other.valueArray + ArrayBasedMapData.toScalaMap(this) == ArrayBasedMapData.toScalaMap(other) } override def hashCode: Int = { - keyArray.hashCode() * 37 + valueArray.hashCode() + ArrayBasedMapData.toScalaMap(this).hashCode() } override def toString(): String = { - s"keys: $keyArray\nvalues: $valueArray" + s"keys: $keyArray, values: $valueArray" } } @@ -48,4 +51,10 @@ object ArrayBasedMapData { def apply(keys: Array[Any], values: Array[Any]): ArrayBasedMapData = { new ArrayBasedMapData(new GenericArrayData(keys), new GenericArrayData(values)) } + + def toScalaMap(map: ArrayBasedMapData): Map[Any, Any] = { + val keys = map.keyArray.asInstanceOf[GenericArrayData].array + val values = map.valueArray.asInstanceOf[GenericArrayData].array + keys.zip(values).toMap + } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/ArrayData.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/ArrayData.scala index c99fc23325..642c56f12d 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/ArrayData.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/ArrayData.scala @@ -17,11 +17,15 @@ package org.apache.spark.sql.types +import scala.reflect.ClassTag + import org.apache.spark.sql.catalyst.expressions.SpecializedGetters abstract class ArrayData extends SpecializedGetters with Serializable { def numElements(): Int + def copy(): ArrayData + def toBooleanArray(): Array[Boolean] = { val size = numElements() val values = new Array[Boolean](size) @@ -99,19 +103,19 @@ abstract class ArrayData extends SpecializedGetters with Serializable { values } - def toArray[T](elementType: DataType): Array[T] = { + def toArray[T: ClassTag](elementType: DataType): Array[T] = { val size = numElements() - val values = new Array[Any](size) + val values = new Array[T](size) var i = 0 while (i < size) { if (isNullAt(i)) { - values(i) = null + values(i) = null.asInstanceOf[T] } else { - values(i) = get(i, elementType) + values(i) = get(i, elementType).asInstanceOf[T] } i += 1 } - values.asInstanceOf[Array[T]] + values } // todo: specialize this. diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/GenericArrayData.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/GenericArrayData.scala index b3e75f8bad..b314acdfe3 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/GenericArrayData.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/GenericArrayData.scala @@ -17,13 +17,19 @@ package org.apache.spark.sql.types +import scala.reflect.ClassTag + import org.apache.spark.sql.catalyst.expressions.GenericSpecializedGetters -class GenericArrayData(array: Array[Any]) extends ArrayData with GenericSpecializedGetters { +class GenericArrayData(private[sql] val array: Array[Any]) + extends ArrayData with GenericSpecializedGetters { override def genericGet(ordinal: Int): Any = array(ordinal) - override def toArray[T](elementType: DataType): Array[T] = array.asInstanceOf[Array[T]] + override def copy(): ArrayData = new GenericArrayData(array.clone()) + + // todo: Array is invariant in scala, maybe use toSeq instead? 
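+ // (array.map(_.asInstanceOf[T]) builds a new Array[T] through the ClassTag, casting each
+ // element individually rather than casting the whole array at once.)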
+ override def toArray[T: ClassTag](elementType: DataType): Array[T] = array.map(_.asInstanceOf[T]) override def numElements(): Int = array.length diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/MapData.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/MapData.scala index 5514c3cd85..f50969f0f0 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/MapData.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/MapData.scala @@ -25,6 +25,8 @@ abstract class MapData extends Serializable { def valueArray(): ArrayData + def copy(): MapData + def foreach(keyType: DataType, valueType: DataType, f: (Any, Any) => Unit): Unit = { val length = numElements() val keys = keyArray() diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/UnsafeRowConverterSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/UnsafeRowConverterSuite.scala index 44f845620a..59491c5ba1 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/UnsafeRowConverterSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/UnsafeRowConverterSuite.scala @@ -31,6 +31,8 @@ import org.apache.spark.unsafe.types.UTF8String class UnsafeRowConverterSuite extends SparkFunSuite with Matchers { + private def roundedSize(size: Int) = ByteArrayMethods.roundNumberOfBytesToNearestWord(size) + test("basic conversion with only primitive types") { val fieldTypes: Array[DataType] = Array(LongType, LongType, IntegerType) val converter = UnsafeProjection.create(fieldTypes) @@ -73,8 +75,8 @@ class UnsafeRowConverterSuite extends SparkFunSuite with Matchers { val unsafeRow: UnsafeRow = converter.apply(row) assert(unsafeRow.getSizeInBytes === 8 + (8 * 3) + - ByteArrayMethods.roundNumberOfBytesToNearestWord("Hello".getBytes.length) + - ByteArrayMethods.roundNumberOfBytesToNearestWord("World".getBytes.length)) + roundedSize("Hello".getBytes.length) + + roundedSize("World".getBytes.length)) assert(unsafeRow.getLong(0) === 0) assert(unsafeRow.getString(1) === "Hello") @@ -92,8 +94,7 @@ class UnsafeRowConverterSuite extends SparkFunSuite with Matchers { row.update(3, DateTimeUtils.fromJavaTimestamp(Timestamp.valueOf("2015-05-08 08:10:25"))) val unsafeRow: UnsafeRow = converter.apply(row) - assert(unsafeRow.getSizeInBytes === 8 + (8 * 4) + - ByteArrayMethods.roundNumberOfBytesToNearestWord("Hello".getBytes.length)) + assert(unsafeRow.getSizeInBytes === 8 + (8 * 4) + roundedSize("Hello".getBytes.length)) assert(unsafeRow.getLong(0) === 0) assert(unsafeRow.getString(1) === "Hello") @@ -172,6 +173,7 @@ class UnsafeRowConverterSuite extends SparkFunSuite with Matchers { r } + // todo: we reuse the UnsafeRow in projection, so these tests are meaningless. 
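+ // (The generated projection reuses a single UnsafeRow and byte buffer across calls, so a
+ // later apply() overwrites the result of an earlier one.)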
val setToNullAfterCreation = converter.apply(rowWithNoNullColumns) assert(setToNullAfterCreation.isNullAt(0) === rowWithNoNullColumns.isNullAt(0)) assert(setToNullAfterCreation.getBoolean(1) === rowWithNoNullColumns.getBoolean(1)) @@ -235,4 +237,108 @@ class UnsafeRowConverterSuite extends SparkFunSuite with Matchers { val converter = UnsafeProjection.create(fieldTypes) assert(converter.apply(row1).getBytes === converter.apply(row2).getBytes) } + + test("basic conversion with array type") { + val fieldTypes: Array[DataType] = Array( + ArrayType(LongType), + ArrayType(ArrayType(LongType)) + ) + val converter = UnsafeProjection.create(fieldTypes) + + val array1 = new GenericArrayData(Array[Any](1L, 2L)) + val array2 = new GenericArrayData(Array[Any](new GenericArrayData(Array[Any](3L, 4L)))) + val row = new GenericMutableRow(fieldTypes.length) + row.update(0, array1) + row.update(1, array2) + + val unsafeRow: UnsafeRow = converter.apply(row) + assert(unsafeRow.numFields() == 2) + + val unsafeArray1 = unsafeRow.getArray(0).asInstanceOf[UnsafeArrayData] + assert(unsafeArray1.getSizeInBytes == 4 * 2 + 8 * 2) + assert(unsafeArray1.numElements() == 2) + assert(unsafeArray1.getLong(0) == 1L) + assert(unsafeArray1.getLong(1) == 2L) + + val unsafeArray2 = unsafeRow.getArray(1).asInstanceOf[UnsafeArrayData] + assert(unsafeArray2.numElements() == 1) + + val nestedArray = unsafeArray2.getArray(0).asInstanceOf[UnsafeArrayData] + assert(nestedArray.getSizeInBytes == 4 * 2 + 8 * 2) + assert(nestedArray.numElements() == 2) + assert(nestedArray.getLong(0) == 3L) + assert(nestedArray.getLong(1) == 4L) + + assert(unsafeArray2.getSizeInBytes == 4 + 4 + nestedArray.getSizeInBytes) + + val array1Size = roundedSize(4 + unsafeArray1.getSizeInBytes) + val array2Size = roundedSize(4 + unsafeArray2.getSizeInBytes) + assert(unsafeRow.getSizeInBytes == 8 + 8 * 2 + array1Size + array2Size) + } + + test("basic conversion with map type") { + def createArray(values: Any*): ArrayData = new GenericArrayData(values.toArray) + + def testIntLongMap(map: UnsafeMapData, keys: Array[Int], values: Array[Long]): Unit = { + val numElements = keys.length + assert(map.numElements() == numElements) + + val keyArray = map.keys + assert(keyArray.getSizeInBytes == 4 * numElements + 4 * numElements) + assert(keyArray.numElements() == numElements) + keys.zipWithIndex.foreach { case (key, i) => + assert(keyArray.getInt(i) == key) + } + + val valueArray = map.values + assert(valueArray.getSizeInBytes == 4 * numElements + 8 * numElements) + assert(valueArray.numElements() == numElements) + values.zipWithIndex.foreach { case (value, i) => + assert(valueArray.getLong(i) == value) + } + + assert(map.getSizeInBytes == keyArray.getSizeInBytes + valueArray.getSizeInBytes) + } + + val fieldTypes: Array[DataType] = Array( + MapType(IntegerType, LongType), + MapType(IntegerType, MapType(IntegerType, LongType)) + ) + val converter = UnsafeProjection.create(fieldTypes) + + val map1 = new ArrayBasedMapData(createArray(1, 2), createArray(3L, 4L)) + + val innerMap = new ArrayBasedMapData(createArray(5, 6), createArray(7L, 8L)) + val map2 = new ArrayBasedMapData(createArray(9), createArray(innerMap)) + + val row = new GenericMutableRow(fieldTypes.length) + row.update(0, map1) + row.update(1, map2) + + val unsafeRow: UnsafeRow = converter.apply(row) + assert(unsafeRow.numFields() == 2) + + val unsafeMap1 = unsafeRow.getMap(0).asInstanceOf[UnsafeMapData] + testIntLongMap(unsafeMap1, Array(1, 2), Array(3L, 4L)) + + val unsafeMap2 = 
unsafeRow.getMap(1).asInstanceOf[UnsafeMapData]
+ assert(unsafeMap2.numElements() == 1)
+
+ val keyArray = unsafeMap2.keys
+ assert(keyArray.getSizeInBytes == 4 + 4)
+ assert(keyArray.numElements() == 1)
+ assert(keyArray.getInt(0) == 9)
+
+ val valueArray = unsafeMap2.values
+ assert(valueArray.numElements() == 1)
+ val nestedMap = valueArray.getMap(0).asInstanceOf[UnsafeMapData]
+ testIntLongMap(nestedMap, Array(5, 6), Array(7L, 8L))
+ assert(valueArray.getSizeInBytes == 4 + 8 + nestedMap.getSizeInBytes)
+
+ assert(unsafeMap2.getSizeInBytes == keyArray.getSizeInBytes + valueArray.getSizeInBytes)
+
+ val map1Size = roundedSize(8 + unsafeMap1.getSizeInBytes)
+ val map2Size = roundedSize(8 + unsafeMap2.getSizeInBytes)
+ assert(unsafeRow.getSizeInBytes == 8 + 8 * 2 + map1Size + map2Size)
+ }
 }
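To make the array layout concrete, here is a minimal standalone sketch of the encoding described in the commit message. It is plain Java with an illustrative class name, and it uses java.nio.ByteBuffer in place of Spark's PlatformDependent accessors (so byte order is ByteBuffer's default rather than platform-native); it encodes [10, 11, 12, 13, null, 14] and reads it back:

    import java.nio.ByteBuffer;

    public class ArrayLayoutSketch {
      public static void main(String[] args) {
        Integer[] input = {10, 11, 12, 13, null, 14};
        int n = input.length;
        int fixedSize = 4 * n;            // offset region: one int per element
        int valueSize = 4 * 5;            // five non-null int elements
        ByteBuffer buf = ByteBuffer.allocate(4 + fixedSize + valueSize);
        buf.putInt(n);                    // numElements header
        int cursor = fixedSize;           // offsets are relative to the offset region
        for (Integer v : input) {
          buf.putInt(v == null ? -cursor : cursor);  // negative offset marks null
          if (v != null) cursor += 4;
        }
        for (Integer v : input) {
          if (v != null) buf.putInt(v);
        }
        // buf now holds: 6, 24, 28, 32, 36, -40, 40, 10, 11, 12, 13, 14
        for (int i = 0; i < n; i++) {     // read back, skipping the 4-byte header
          int off = buf.getInt(4 + 4 * i);
          System.out.println(off < 0 ? "null" : String.valueOf(buf.getInt(4 + off)));
        }
      }
    }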