From 4e0027feaee7c028741da88d8fbc26a45fc4a268 Mon Sep 17 00:00:00 2001 From: Wenchen Fan Date: Mon, 5 Oct 2015 23:24:12 -0700 Subject: [SPARK-10585] [SQL] [FOLLOW-UP] remove no-longer-necessary code for unsafe generation This code was left there to produce a clear diff for https://github.com/apache/spark/pull/8747 Author: Wenchen Fan Closes #8991 from cloud-fan/clean. --- .../sql/catalyst/expressions/UnsafeRowWriters.java | 264 ---------------- .../sql/catalyst/expressions/UnsafeWriters.java | 193 ----------- .../codegen/GenerateUnsafeProjection.scala | 351 --------------------- 3 files changed, 808 deletions(-) delete mode 100644 sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeRowWriters.java delete mode 100644 sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeWriters.java diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeRowWriters.java b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeRowWriters.java deleted file mode 100644 index 0f1e0202aa..0000000000 --- a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeRowWriters.java +++ /dev/null @@ -1,264 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
- * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.sql.catalyst.expressions; - -import java.math.BigInteger; - -import org.apache.spark.sql.catalyst.InternalRow; -import org.apache.spark.sql.types.Decimal; -import org.apache.spark.unsafe.Platform; -import org.apache.spark.unsafe.array.ByteArrayMethods; -import org.apache.spark.unsafe.types.ByteArray; -import org.apache.spark.unsafe.types.CalendarInterval; -import org.apache.spark.unsafe.types.UTF8String; - -/** - * A set of helper methods to write data into {@link UnsafeRow}s, - * used by {@link org.apache.spark.sql.catalyst.expressions.codegen.GenerateUnsafeProjection}. - */ -public class UnsafeRowWriters { - - /** Writer for Decimal with precision under 18. */ - public static class CompactDecimalWriter { - - public static int getSize(Decimal input) { - return 0; - } - - public static int write(UnsafeRow target, int ordinal, int cursor, Decimal input) { - target.setLong(ordinal, input.toUnscaledLong()); - return 0; - } - } - - /** Writer for Decimal with precision larger than 18. 
*/ - public static class DecimalWriter { - private static final int SIZE = 16; - public static int getSize(Decimal input) { - // bounded size - return SIZE; - } - - public static int write(UnsafeRow target, int ordinal, int cursor, Decimal input) { - final Object base = target.getBaseObject(); - final long offset = target.getBaseOffset() + cursor; - // zero-out the bytes - Platform.putLong(base, offset, 0L); - Platform.putLong(base, offset + 8, 0L); - - if (input == null) { - target.setNullAt(ordinal); - // keep the offset and length for update - int fieldOffset = UnsafeRow.calculateBitSetWidthInBytes(target.numFields()) + ordinal * 8; - Platform.putLong(base, target.getBaseOffset() + fieldOffset, - ((long) cursor) << 32); - return SIZE; - } - - final BigInteger integer = input.toJavaBigDecimal().unscaledValue(); - byte[] bytes = integer.toByteArray(); - - // Write the bytes to the variable length portion. - Platform.copyMemory( - bytes, Platform.BYTE_ARRAY_OFFSET, base, target.getBaseOffset() + cursor, bytes.length); - // Set the fixed length portion. - target.setLong(ordinal, (((long) cursor) << 32) | (long) bytes.length); - - return SIZE; - } - } - - /** Writer for UTF8String. */ - public static class UTF8StringWriter { - - public static int getSize(UTF8String input) { - return ByteArrayMethods.roundNumberOfBytesToNearestWord(input.numBytes()); - } - - public static int write(UnsafeRow target, int ordinal, int cursor, UTF8String input) { - final long offset = target.getBaseOffset() + cursor; - final int numBytes = input.numBytes(); - - // zero-out the padding bytes - if ((numBytes & 0x07) > 0) { - Platform.putLong(target.getBaseObject(), offset + ((numBytes >> 3) << 3), 0L); - } - - // Write the bytes to the variable length portion. - input.writeToMemory(target.getBaseObject(), offset); - - // Set the fixed length portion. 
- target.setLong(ordinal, (((long) cursor) << 32) | ((long) numBytes)); - return ByteArrayMethods.roundNumberOfBytesToNearestWord(numBytes); - } - } - - /** Writer for binary (byte array) type. */ - public static class BinaryWriter { - - public static int getSize(byte[] input) { - return ByteArrayMethods.roundNumberOfBytesToNearestWord(input.length); - } - - public static int write(UnsafeRow target, int ordinal, int cursor, byte[] input) { - final long offset = target.getBaseOffset() + cursor; - final int numBytes = input.length; - - // zero-out the padding bytes - if ((numBytes & 0x07) > 0) { - Platform.putLong(target.getBaseObject(), offset + ((numBytes >> 3) << 3), 0L); - } - - // Write the bytes to the variable length portion. - ByteArray.writeToMemory(input, target.getBaseObject(), offset); - - // Set the fixed length portion. - target.setLong(ordinal, (((long) cursor) << 32) | ((long) numBytes)); - return ByteArrayMethods.roundNumberOfBytesToNearestWord(numBytes); - } - } - - /** - * Writer for struct type where the struct field is backed by an {@link UnsafeRow}. - * - * We throw UnsupportedOperationException for inputs that are not backed by {@link UnsafeRow}. - * Non-UnsafeRow struct fields are handled directly in - * {@link org.apache.spark.sql.catalyst.expressions.codegen.GenerateUnsafeProjection} - * by generating the Java code needed to convert them into UnsafeRow. - */ - public static class StructWriter { - public static int getSize(InternalRow input) { - int numBytes = 0; - if (input instanceof UnsafeRow) { - numBytes = ((UnsafeRow) input).getSizeInBytes(); - } else { - // This is handled directly in GenerateUnsafeProjection. 
- throw new UnsupportedOperationException(); - } - return ByteArrayMethods.roundNumberOfBytesToNearestWord(numBytes); - } - - public static int write(UnsafeRow target, int ordinal, int cursor, InternalRow input) { - int numBytes = 0; - final long offset = target.getBaseOffset() + cursor; - if (input instanceof UnsafeRow) { - final UnsafeRow row = (UnsafeRow) input; - numBytes = row.getSizeInBytes(); - - // zero-out the padding bytes - if ((numBytes & 0x07) > 0) { - Platform.putLong(target.getBaseObject(), offset + ((numBytes >> 3) << 3), 0L); - } - - // Write the bytes to the variable length portion. - row.writeToMemory(target.getBaseObject(), offset); - - // Set the fixed length portion. - target.setLong(ordinal, (((long) cursor) << 32) | ((long) numBytes)); - } else { - // This is handled directly in GenerateUnsafeProjection. - throw new UnsupportedOperationException(); - } - return ByteArrayMethods.roundNumberOfBytesToNearestWord(numBytes); - } - } - - /** Writer for interval type. */ - public static class IntervalWriter { - - public static int write(UnsafeRow target, int ordinal, int cursor, CalendarInterval input) { - final long offset = target.getBaseOffset() + cursor; - - // Write the months and microseconds fields of Interval to the variable length portion. - Platform.putLong(target.getBaseObject(), offset, input.months); - Platform.putLong(target.getBaseObject(), offset + 8, input.microseconds); - - // Set the fixed length portion. - target.setLong(ordinal, ((long) cursor) << 32); - return 16; - } - } - - public static class ArrayWriter { - - public static int getSize(UnsafeArrayData input) { - // we need extra 4 bytes the store the number of elements in this array. 
- return ByteArrayMethods.roundNumberOfBytesToNearestWord(input.getSizeInBytes() + 4); - } - - public static int write(UnsafeRow target, int ordinal, int cursor, UnsafeArrayData input) { - final int numBytes = input.getSizeInBytes() + 4; - final long offset = target.getBaseOffset() + cursor; - - // write the number of elements into first 4 bytes. - Platform.putInt(target.getBaseObject(), offset, input.numElements()); - - // zero-out the padding bytes - if ((numBytes & 0x07) > 0) { - Platform.putLong(target.getBaseObject(), offset + ((numBytes >> 3) << 3), 0L); - } - - // Write the bytes to the variable length portion. - input.writeToMemory(target.getBaseObject(), offset + 4); - - // Set the fixed length portion. - target.setLong(ordinal, (((long) cursor) << 32) | ((long) numBytes)); - - return ByteArrayMethods.roundNumberOfBytesToNearestWord(numBytes); - } - } - - public static class MapWriter { - - public static int getSize(UnsafeMapData input) { - // we need extra 8 bytes to store number of elements and numBytes of key array. - final int sizeInBytes = 4 + 4 + input.getSizeInBytes(); - return ByteArrayMethods.roundNumberOfBytesToNearestWord(sizeInBytes); - } - - public static int write(UnsafeRow target, int ordinal, int cursor, UnsafeMapData input) { - final long offset = target.getBaseOffset() + cursor; - final UnsafeArrayData keyArray = input.keyArray(); - final UnsafeArrayData valueArray = input.valueArray(); - final int keysNumBytes = keyArray.getSizeInBytes(); - final int valuesNumBytes = valueArray.getSizeInBytes(); - final int numBytes = 4 + 4 + keysNumBytes + valuesNumBytes; - - // write the number of elements into first 4 bytes. - Platform.putInt(target.getBaseObject(), offset, input.numElements()); - // write the numBytes of key array into second 4 bytes. 
- Platform.putInt(target.getBaseObject(), offset + 4, keysNumBytes); - - // zero-out the padding bytes - if ((numBytes & 0x07) > 0) { - Platform.putLong(target.getBaseObject(), offset + ((numBytes >> 3) << 3), 0L); - } - - // Write the bytes of key array to the variable length portion. - keyArray.writeToMemory(target.getBaseObject(), offset + 8); - - // Write the bytes of value array to the variable length portion. - valueArray.writeToMemory(target.getBaseObject(), offset + 8 + keysNumBytes); - - // Set the fixed length portion. - target.setLong(ordinal, (((long) cursor) << 32) | ((long) numBytes)); - - return ByteArrayMethods.roundNumberOfBytesToNearestWord(numBytes); - } - } -} diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeWriters.java b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeWriters.java deleted file mode 100644 index ce2d9c4ffb..0000000000 --- a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeWriters.java +++ /dev/null @@ -1,193 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- */ - -package org.apache.spark.sql.catalyst.expressions; - -import org.apache.spark.sql.types.Decimal; -import org.apache.spark.unsafe.Platform; -import org.apache.spark.unsafe.types.CalendarInterval; -import org.apache.spark.unsafe.types.UTF8String; - -/** - * A set of helper methods to write data into the variable length portion. - */ -public class UnsafeWriters { - public static void writeToMemory( - Object inputObject, - long inputOffset, - Object targetObject, - long targetOffset, - int numBytes) { - - // zero-out the padding bytes -// if ((numBytes & 0x07) > 0) { -// Platform.putLong(targetObject, targetOffset + ((numBytes >> 3) << 3), 0L); -// } - - // Write the UnsafeData to the target memory. - Platform.copyMemory(inputObject, inputOffset, targetObject, targetOffset, numBytes); - } - - public static int getRoundedSize(int size) { - //return ByteArrayMethods.roundNumberOfBytesToNearestWord(size); - // todo: do word alignment - return size; - } - - /** Writer for Decimal with precision larger than 18. */ - public static class DecimalWriter { - - public static int getSize(Decimal input) { - return 16; - } - - public static int write(Object targetObject, long targetOffset, Decimal input) { - final byte[] bytes = input.toJavaBigDecimal().unscaledValue().toByteArray(); - final int numBytes = bytes.length; - assert(numBytes <= 16); - - // zero-out the bytes - Platform.putLong(targetObject, targetOffset, 0L); - Platform.putLong(targetObject, targetOffset + 8, 0L); - - // Write the bytes to the variable length portion. - Platform.copyMemory(bytes, Platform.BYTE_ARRAY_OFFSET, targetObject, targetOffset, numBytes); - return 16; - } - } - - /** Writer for UTF8String. 
*/ - public static class UTF8StringWriter { - - public static int getSize(UTF8String input) { - return getRoundedSize(input.numBytes()); - } - - public static int write(Object targetObject, long targetOffset, UTF8String input) { - final int numBytes = input.numBytes(); - - // Write the bytes to the variable length portion. - writeToMemory(input.getBaseObject(), input.getBaseOffset(), - targetObject, targetOffset, numBytes); - - return getRoundedSize(numBytes); - } - } - - /** Writer for binary (byte array) type. */ - public static class BinaryWriter { - - public static int getSize(byte[] input) { - return getRoundedSize(input.length); - } - - public static int write(Object targetObject, long targetOffset, byte[] input) { - final int numBytes = input.length; - - // Write the bytes to the variable length portion. - writeToMemory(input, Platform.BYTE_ARRAY_OFFSET, targetObject, targetOffset, numBytes); - - return getRoundedSize(numBytes); - } - } - - /** Writer for UnsafeRow. */ - public static class StructWriter { - - public static int getSize(UnsafeRow input) { - return getRoundedSize(input.getSizeInBytes()); - } - - public static int write(Object targetObject, long targetOffset, UnsafeRow input) { - final int numBytes = input.getSizeInBytes(); - - // Write the bytes to the variable length portion. - writeToMemory(input.getBaseObject(), input.getBaseOffset(), - targetObject, targetOffset, numBytes); - - return getRoundedSize(numBytes); - } - } - - /** Writer for interval type. */ - public static class IntervalWriter { - - public static int getSize(UnsafeRow input) { - return 16; - } - - public static int write(Object targetObject, long targetOffset, CalendarInterval input) { - // Write the months and microseconds fields of Interval to the variable length portion. - Platform.putLong(targetObject, targetOffset, input.months); - Platform.putLong(targetObject, targetOffset + 8, input.microseconds); - return 16; - } - } - - /** Writer for UnsafeArrayData. 
*/ - public static class ArrayWriter { - - public static int getSize(UnsafeArrayData input) { - // we need extra 4 bytes the store the number of elements in this array. - return getRoundedSize(input.getSizeInBytes() + 4); - } - - public static int write(Object targetObject, long targetOffset, UnsafeArrayData input) { - final int numBytes = input.getSizeInBytes(); - - // write the number of elements into first 4 bytes. - Platform.putInt(targetObject, targetOffset, input.numElements()); - - // Write the bytes to the variable length portion. - writeToMemory( - input.getBaseObject(), input.getBaseOffset(), targetObject, targetOffset + 4, numBytes); - - return getRoundedSize(numBytes + 4); - } - } - - public static class MapWriter { - - public static int getSize(UnsafeMapData input) { - // we need extra 8 bytes to store number of elements and numBytes of key array. - return getRoundedSize(4 + 4 + input.getSizeInBytes()); - } - - public static int write(Object targetObject, long targetOffset, UnsafeMapData input) { - final UnsafeArrayData keyArray = input.keyArray(); - final UnsafeArrayData valueArray = input.valueArray(); - final int keysNumBytes = keyArray.getSizeInBytes(); - final int valuesNumBytes = valueArray.getSizeInBytes(); - final int numBytes = 4 + 4 + keysNumBytes + valuesNumBytes; - - // write the number of elements into first 4 bytes. - Platform.putInt(targetObject, targetOffset, input.numElements()); - // write the numBytes of key array into second 4 bytes. - Platform.putInt(targetObject, targetOffset + 4, keysNumBytes); - - // Write the bytes of key array to the variable length portion. - writeToMemory(keyArray.getBaseObject(), keyArray.getBaseOffset(), - targetObject, targetOffset + 8, keysNumBytes); - - // Write the bytes of value array to the variable length portion. 
- writeToMemory(valueArray.getBaseObject(), valueArray.getBaseOffset(), - targetObject, targetOffset + 8 + keysNumBytes, valuesNumBytes); - - return getRoundedSize(numBytes); - } - } -} diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeProjection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeProjection.scala index 99bf50a845..8e58cb9ad1 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeProjection.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeProjection.scala @@ -31,15 +31,6 @@ import org.apache.spark.sql.types._ */ object GenerateUnsafeProjection extends CodeGenerator[Seq[Expression], UnsafeProjection] { - private val StringWriter = classOf[UnsafeRowWriters.UTF8StringWriter].getName - private val BinaryWriter = classOf[UnsafeRowWriters.BinaryWriter].getName - private val IntervalWriter = classOf[UnsafeRowWriters.IntervalWriter].getName - private val StructWriter = classOf[UnsafeRowWriters.StructWriter].getName - private val CompactDecimalWriter = classOf[UnsafeRowWriters.CompactDecimalWriter].getName - private val DecimalWriter = classOf[UnsafeRowWriters.DecimalWriter].getName - private val ArrayWriter = classOf[UnsafeRowWriters.ArrayWriter].getName - private val MapWriter = classOf[UnsafeRowWriters.MapWriter].getName - /** Returns true iff we support this data type. */ def canSupport(dataType: DataType): Boolean = dataType match { case NullType => true @@ -51,348 +42,6 @@ object GenerateUnsafeProjection extends CodeGenerator[Seq[Expression], UnsafePro case _ => false } - def genAdditionalSize(dt: DataType, ev: GeneratedExpressionCode): String = dt match { - case t: DecimalType if t.precision > Decimal.MAX_LONG_DIGITS => - s"$DecimalWriter.getSize(${ev.primitive})" - case StringType => - s"${ev.isNull} ? 
0 : $StringWriter.getSize(${ev.primitive})" - case BinaryType => - s"${ev.isNull} ? 0 : $BinaryWriter.getSize(${ev.primitive})" - case CalendarIntervalType => - s"${ev.isNull} ? 0 : 16" - case _: StructType => - s"${ev.isNull} ? 0 : $StructWriter.getSize(${ev.primitive})" - case _: ArrayType => - s"${ev.isNull} ? 0 : $ArrayWriter.getSize(${ev.primitive})" - case _: MapType => - s"${ev.isNull} ? 0 : $MapWriter.getSize(${ev.primitive})" - case _ => "" - } - - def genFieldWriter( - ctx: CodeGenContext, - fieldType: DataType, - ev: GeneratedExpressionCode, - target: String, - index: Int, - cursor: String): String = fieldType match { - case _ if ctx.isPrimitiveType(fieldType) => - s"${ctx.setColumn(target, fieldType, index, ev.primitive)}" - case t: DecimalType if t.precision <= Decimal.MAX_LONG_DIGITS => - s""" - // make sure Decimal object has the same scale as DecimalType - if (${ev.primitive}.changePrecision(${t.precision}, ${t.scale})) { - $CompactDecimalWriter.write($target, $index, $cursor, ${ev.primitive}); - } else { - $target.setNullAt($index); - } - """ - case t: DecimalType if t.precision > Decimal.MAX_LONG_DIGITS => - s""" - // make sure Decimal object has the same scale as DecimalType - if (${ev.primitive}.changePrecision(${t.precision}, ${t.scale})) { - $cursor += $DecimalWriter.write($target, $index, $cursor, ${ev.primitive}); - } else { - $cursor += $DecimalWriter.write($target, $index, $cursor, null); - } - """ - case StringType => - s"$cursor += $StringWriter.write($target, $index, $cursor, ${ev.primitive})" - case BinaryType => - s"$cursor += $BinaryWriter.write($target, $index, $cursor, ${ev.primitive})" - case CalendarIntervalType => - s"$cursor += $IntervalWriter.write($target, $index, $cursor, ${ev.primitive})" - case _: StructType => - s"$cursor += $StructWriter.write($target, $index, $cursor, ${ev.primitive})" - case _: ArrayType => - s"$cursor += $ArrayWriter.write($target, $index, $cursor, ${ev.primitive})" - case _: MapType => - s"$cursor += 
$MapWriter.write($target, $index, $cursor, ${ev.primitive})" - case NullType => "" - case _ => - throw new UnsupportedOperationException(s"Not supported DataType: $fieldType") - } - - /** - * Generates the Java code to convert a struct (backed by InternalRow) to UnsafeRow. - * - * @param ctx code generation context - * @param inputs could be the codes for expressions or input struct fields. - * @param inputTypes types of the inputs - */ - private def createCodeForStruct( - ctx: CodeGenContext, - row: String, - inputs: Seq[GeneratedExpressionCode], - inputTypes: Seq[DataType]): GeneratedExpressionCode = { - - val fixedSize = 8 * inputTypes.length + UnsafeRow.calculateBitSetWidthInBytes(inputTypes.length) - - val output = ctx.freshName("convertedStruct") - ctx.addMutableState("UnsafeRow", output, s"this.$output = new UnsafeRow();") - val buffer = ctx.freshName("buffer") - ctx.addMutableState("byte[]", buffer, s"this.$buffer = new byte[$fixedSize];") - val cursor = ctx.freshName("cursor") - ctx.addMutableState("int", cursor, s"this.$cursor = 0;") - val tmpBuffer = ctx.freshName("tmpBuffer") - - val convertedFields = inputTypes.zip(inputs).zipWithIndex.map { case ((dt, input), i) => - val ev = createConvertCode(ctx, input, dt) - val growBuffer = if (!UnsafeRow.isFixedLength(dt)) { - val numBytes = ctx.freshName("numBytes") - s""" - int $numBytes = $cursor + (${genAdditionalSize(dt, ev)}); - if ($buffer.length < $numBytes) { - // This will not happen frequently, because the buffer is re-used. 
- byte[] $tmpBuffer = new byte[$numBytes * 2]; - Platform.copyMemory($buffer, Platform.BYTE_ARRAY_OFFSET, - $tmpBuffer, Platform.BYTE_ARRAY_OFFSET, $buffer.length); - $buffer = $tmpBuffer; - } - $output.pointTo($buffer, Platform.BYTE_ARRAY_OFFSET, ${inputTypes.length}, $numBytes); - """ - } else { - "" - } - val update = dt match { - case dt: DecimalType if dt.precision > Decimal.MAX_LONG_DIGITS => - // Can't call setNullAt() for DecimalType - s""" - if (${ev.isNull}) { - $cursor += $DecimalWriter.write($output, $i, $cursor, null); - } else { - ${genFieldWriter(ctx, dt, ev, output, i, cursor)}; - } - """ - case _ => - s""" - if (${ev.isNull}) { - $output.setNullAt($i); - } else { - ${genFieldWriter(ctx, dt, ev, output, i, cursor)}; - } - """ - } - s""" - ${ev.code} - $growBuffer - $update - """ - } - - val code = s""" - $cursor = $fixedSize; - $output.pointTo($buffer, Platform.BYTE_ARRAY_OFFSET, ${inputTypes.length}, $cursor); - ${ctx.splitExpressions(row, convertedFields)} - """ - GeneratedExpressionCode(code, "false", output) - } - - private def getWriter(dt: DataType) = dt match { - case StringType => classOf[UnsafeWriters.UTF8StringWriter].getName - case BinaryType => classOf[UnsafeWriters.BinaryWriter].getName - case CalendarIntervalType => classOf[UnsafeWriters.IntervalWriter].getName - case _: StructType => classOf[UnsafeWriters.StructWriter].getName - case _: ArrayType => classOf[UnsafeWriters.ArrayWriter].getName - case _: MapType => classOf[UnsafeWriters.MapWriter].getName - case _: DecimalType => classOf[UnsafeWriters.DecimalWriter].getName - } - - private def createCodeForArray( - ctx: CodeGenContext, - input: GeneratedExpressionCode, - elementType: DataType): GeneratedExpressionCode = { - val output = ctx.freshName("convertedArray") - ctx.addMutableState("UnsafeArrayData", output, s"$output = new UnsafeArrayData();") - val buffer = ctx.freshName("buffer") - ctx.addMutableState("byte[]", buffer, s"$buffer = new byte[64];") - val tmpBuffer = 
ctx.freshName("tmpBuffer") - val outputIsNull = ctx.freshName("isNull") - val numElements = ctx.freshName("numElements") - val fixedSize = ctx.freshName("fixedSize") - val numBytes = ctx.freshName("numBytes") - val cursor = ctx.freshName("cursor") - val index = ctx.freshName("index") - val elementName = ctx.freshName("elementName") - - val element = { - val code = s"${ctx.javaType(elementType)} $elementName = " + - s"${ctx.getValue(input.primitive, elementType, index)};" - val isNull = s"${input.primitive}.isNullAt($index)" - GeneratedExpressionCode(code, isNull, elementName) - } - - val convertedElement = createConvertCode(ctx, element, elementType) - - val writeElement = elementType match { - case _ if ctx.isPrimitiveType(elementType) => - // Should we do word align? - val elementSize = elementType.defaultSize - s""" - Platform.put${ctx.primitiveTypeName(elementType)}( - $buffer, - Platform.BYTE_ARRAY_OFFSET + $cursor, - ${convertedElement.primitive}); - $cursor += $elementSize; - """ - case t: DecimalType if t.precision <= Decimal.MAX_LONG_DIGITS => - s""" - Platform.putLong( - $buffer, - Platform.BYTE_ARRAY_OFFSET + $cursor, - ${convertedElement.primitive}.toUnscaledLong()); - $cursor += 8; - """ - case _ => - val writer = getWriter(elementType) - s""" - $cursor += $writer.write( - $buffer, - Platform.BYTE_ARRAY_OFFSET + $cursor, - ${convertedElement.primitive}); - """ - } - - val checkNull = convertedElement.isNull + (elementType match { - case t: DecimalType => - s" || !${convertedElement.primitive}.changePrecision(${t.precision}, ${t.scale})" - case _ => "" - }) - - val elementSize = elementType match { - // Should we do word align for primitive types? 
- case _ if ctx.isPrimitiveType(elementType) => elementType.defaultSize.toString - case t: DecimalType if t.precision <= Decimal.MAX_LONG_DIGITS => "8" - case _ => - val writer = getWriter(elementType) - s"$writer.getSize(${convertedElement.primitive})" - } - - val code = s""" - ${input.code} - final boolean $outputIsNull = ${input.isNull}; - if (!$outputIsNull) { - if (${input.primitive} instanceof UnsafeArrayData) { - $output = (UnsafeArrayData) ${input.primitive}; - } else { - final int $numElements = ${input.primitive}.numElements(); - final int $fixedSize = 4 * $numElements; - int $numBytes = $fixedSize; - - int $cursor = $fixedSize; - for (int $index = 0; $index < $numElements; $index++) { - ${convertedElement.code} - if ($checkNull) { - // If element is null, write the negative value address into offset region. - Platform.putInt($buffer, Platform.BYTE_ARRAY_OFFSET + 4 * $index, -$cursor); - } else { - $numBytes += $elementSize; - if ($buffer.length < $numBytes) { - // This will not happen frequently, because the buffer is re-used. 
- byte[] $tmpBuffer = new byte[$numBytes * 2]; - Platform.copyMemory($buffer, Platform.BYTE_ARRAY_OFFSET, - $tmpBuffer, Platform.BYTE_ARRAY_OFFSET, $buffer.length); - $buffer = $tmpBuffer; - } - Platform.putInt($buffer, Platform.BYTE_ARRAY_OFFSET + 4 * $index, $cursor); - $writeElement - } - } - - $output.pointTo( - $buffer, - Platform.BYTE_ARRAY_OFFSET, - $numElements, - $numBytes); - } - } - """ - GeneratedExpressionCode(code, outputIsNull, output) - } - - private def createCodeForMap( - ctx: CodeGenContext, - input: GeneratedExpressionCode, - keyType: DataType, - valueType: DataType): GeneratedExpressionCode = { - val output = ctx.freshName("convertedMap") - val outputIsNull = ctx.freshName("isNull") - val keyArrayName = ctx.freshName("keyArrayName") - val valueArrayName = ctx.freshName("valueArrayName") - - val keyArray = { - val code = s"ArrayData $keyArrayName = ${input.primitive}.keyArray();" - val isNull = "false" - GeneratedExpressionCode(code, isNull, keyArrayName) - } - - val valueArray = { - val code = s"ArrayData $valueArrayName = ${input.primitive}.valueArray();" - val isNull = "false" - GeneratedExpressionCode(code, isNull, valueArrayName) - } - - val convertedKeys = createCodeForArray(ctx, keyArray, keyType) - val convertedValues = createCodeForArray(ctx, valueArray, valueType) - - val code = s""" - ${input.code} - final boolean $outputIsNull = ${input.isNull}; - UnsafeMapData $output = null; - if (!$outputIsNull) { - if (${input.primitive} instanceof UnsafeMapData) { - $output = (UnsafeMapData) ${input.primitive}; - } else { - ${convertedKeys.code} - ${convertedValues.code} - $output = new UnsafeMapData(${convertedKeys.primitive}, ${convertedValues.primitive}); - } - } - """ - GeneratedExpressionCode(code, outputIsNull, output) - } - - /** - * Generates the java code to convert a data to its unsafe version. 
- */ - private def createConvertCode( - ctx: CodeGenContext, - input: GeneratedExpressionCode, - dataType: DataType): GeneratedExpressionCode = dataType match { - case t: StructType => - val output = ctx.freshName("convertedStruct") - val outputIsNull = ctx.freshName("isNull") - val fieldTypes = t.fields.map(_.dataType) - val fieldEvals = fieldTypes.zipWithIndex.map { case (dt, i) => - val fieldName = ctx.freshName("fieldName") - val code = s"${ctx.javaType(dt)} $fieldName = " + - s"${ctx.getValue(input.primitive, dt, i.toString)};" - val isNull = s"${input.primitive}.isNullAt($i)" - GeneratedExpressionCode(code, isNull, fieldName) - } - val converter = createCodeForStruct(ctx, input.primitive, fieldEvals, fieldTypes) - val code = s""" - ${input.code} - UnsafeRow $output = null; - final boolean $outputIsNull = ${input.isNull}; - if (!$outputIsNull) { - if (${input.primitive} instanceof UnsafeRow) { - $output = (UnsafeRow) ${input.primitive}; - } else { - ${converter.code} - $output = ${converter.primitive}; - } - } - """ - GeneratedExpressionCode(code, outputIsNull, output) - - case ArrayType(elementType, _) => createCodeForArray(ctx, input, elementType) - - case MapType(kt, vt, _) => createCodeForMap(ctx, input, kt, vt) - - case _ => input - } - private val rowWriterClass = classOf[UnsafeRowWriter].getName private val arrayWriterClass = classOf[UnsafeArrayWriter].getName -- cgit v1.2.3