aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorWenchen Fan <cloud0fan@163.com>2015-10-05 23:24:12 -0700
committerDavies Liu <davies.liu@gmail.com>2015-10-05 23:24:12 -0700
commit4e0027feaee7c028741da88d8fbc26a45fc4a268 (patch)
treef3c2c7f98fb67fae40e5335e0811a34994f74fec
parentbe7c5ff1ad02ce1c03113c98656a4e0c0c3cee83 (diff)
downloadspark-4e0027feaee7c028741da88d8fbc26a45fc4a268.tar.gz
spark-4e0027feaee7c028741da88d8fbc26a45fc4a268.tar.bz2
spark-4e0027feaee7c028741da88d8fbc26a45fc4a268.zip
[SPARK-10585] [SQL] [FOLLOW-UP] remove no-longer-necessary code for unsafe generation
These code was left there to produce clear diff for https://github.com/apache/spark/pull/8747 Author: Wenchen Fan <cloud0fan@163.com> Closes #8991 from cloud-fan/clean.
-rw-r--r--sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeRowWriters.java264
-rw-r--r--sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeWriters.java193
-rw-r--r--sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeProjection.scala351
3 files changed, 0 insertions, 808 deletions
diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeRowWriters.java b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeRowWriters.java
deleted file mode 100644
index 0f1e0202aa..0000000000
--- a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeRowWriters.java
+++ /dev/null
@@ -1,264 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one or more
- * contributor license agreements. See the NOTICE file distributed with
- * this work for additional information regarding copyright ownership.
- * The ASF licenses this file to You under the Apache License, Version 2.0
- * (the "License"); you may not use this file except in compliance with
- * the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-package org.apache.spark.sql.catalyst.expressions;
-
-import java.math.BigInteger;
-
-import org.apache.spark.sql.catalyst.InternalRow;
-import org.apache.spark.sql.types.Decimal;
-import org.apache.spark.unsafe.Platform;
-import org.apache.spark.unsafe.array.ByteArrayMethods;
-import org.apache.spark.unsafe.types.ByteArray;
-import org.apache.spark.unsafe.types.CalendarInterval;
-import org.apache.spark.unsafe.types.UTF8String;
-
-/**
- * A set of helper methods to write data into {@link UnsafeRow}s,
- * used by {@link org.apache.spark.sql.catalyst.expressions.codegen.GenerateUnsafeProjection}.
- */
-public class UnsafeRowWriters {
-
- /** Writer for Decimal with precision under 18. */
- public static class CompactDecimalWriter {
-
- public static int getSize(Decimal input) {
- return 0;
- }
-
- public static int write(UnsafeRow target, int ordinal, int cursor, Decimal input) {
- target.setLong(ordinal, input.toUnscaledLong());
- return 0;
- }
- }
-
- /** Writer for Decimal with precision larger than 18. */
- public static class DecimalWriter {
- private static final int SIZE = 16;
- public static int getSize(Decimal input) {
- // bounded size
- return SIZE;
- }
-
- public static int write(UnsafeRow target, int ordinal, int cursor, Decimal input) {
- final Object base = target.getBaseObject();
- final long offset = target.getBaseOffset() + cursor;
- // zero-out the bytes
- Platform.putLong(base, offset, 0L);
- Platform.putLong(base, offset + 8, 0L);
-
- if (input == null) {
- target.setNullAt(ordinal);
- // keep the offset and length for update
- int fieldOffset = UnsafeRow.calculateBitSetWidthInBytes(target.numFields()) + ordinal * 8;
- Platform.putLong(base, target.getBaseOffset() + fieldOffset,
- ((long) cursor) << 32);
- return SIZE;
- }
-
- final BigInteger integer = input.toJavaBigDecimal().unscaledValue();
- byte[] bytes = integer.toByteArray();
-
- // Write the bytes to the variable length portion.
- Platform.copyMemory(
- bytes, Platform.BYTE_ARRAY_OFFSET, base, target.getBaseOffset() + cursor, bytes.length);
- // Set the fixed length portion.
- target.setLong(ordinal, (((long) cursor) << 32) | (long) bytes.length);
-
- return SIZE;
- }
- }
-
- /** Writer for UTF8String. */
- public static class UTF8StringWriter {
-
- public static int getSize(UTF8String input) {
- return ByteArrayMethods.roundNumberOfBytesToNearestWord(input.numBytes());
- }
-
- public static int write(UnsafeRow target, int ordinal, int cursor, UTF8String input) {
- final long offset = target.getBaseOffset() + cursor;
- final int numBytes = input.numBytes();
-
- // zero-out the padding bytes
- if ((numBytes & 0x07) > 0) {
- Platform.putLong(target.getBaseObject(), offset + ((numBytes >> 3) << 3), 0L);
- }
-
- // Write the bytes to the variable length portion.
- input.writeToMemory(target.getBaseObject(), offset);
-
- // Set the fixed length portion.
- target.setLong(ordinal, (((long) cursor) << 32) | ((long) numBytes));
- return ByteArrayMethods.roundNumberOfBytesToNearestWord(numBytes);
- }
- }
-
- /** Writer for binary (byte array) type. */
- public static class BinaryWriter {
-
- public static int getSize(byte[] input) {
- return ByteArrayMethods.roundNumberOfBytesToNearestWord(input.length);
- }
-
- public static int write(UnsafeRow target, int ordinal, int cursor, byte[] input) {
- final long offset = target.getBaseOffset() + cursor;
- final int numBytes = input.length;
-
- // zero-out the padding bytes
- if ((numBytes & 0x07) > 0) {
- Platform.putLong(target.getBaseObject(), offset + ((numBytes >> 3) << 3), 0L);
- }
-
- // Write the bytes to the variable length portion.
- ByteArray.writeToMemory(input, target.getBaseObject(), offset);
-
- // Set the fixed length portion.
- target.setLong(ordinal, (((long) cursor) << 32) | ((long) numBytes));
- return ByteArrayMethods.roundNumberOfBytesToNearestWord(numBytes);
- }
- }
-
- /**
- * Writer for struct type where the struct field is backed by an {@link UnsafeRow}.
- *
- * We throw UnsupportedOperationException for inputs that are not backed by {@link UnsafeRow}.
- * Non-UnsafeRow struct fields are handled directly in
- * {@link org.apache.spark.sql.catalyst.expressions.codegen.GenerateUnsafeProjection}
- * by generating the Java code needed to convert them into UnsafeRow.
- */
- public static class StructWriter {
- public static int getSize(InternalRow input) {
- int numBytes = 0;
- if (input instanceof UnsafeRow) {
- numBytes = ((UnsafeRow) input).getSizeInBytes();
- } else {
- // This is handled directly in GenerateUnsafeProjection.
- throw new UnsupportedOperationException();
- }
- return ByteArrayMethods.roundNumberOfBytesToNearestWord(numBytes);
- }
-
- public static int write(UnsafeRow target, int ordinal, int cursor, InternalRow input) {
- int numBytes = 0;
- final long offset = target.getBaseOffset() + cursor;
- if (input instanceof UnsafeRow) {
- final UnsafeRow row = (UnsafeRow) input;
- numBytes = row.getSizeInBytes();
-
- // zero-out the padding bytes
- if ((numBytes & 0x07) > 0) {
- Platform.putLong(target.getBaseObject(), offset + ((numBytes >> 3) << 3), 0L);
- }
-
- // Write the bytes to the variable length portion.
- row.writeToMemory(target.getBaseObject(), offset);
-
- // Set the fixed length portion.
- target.setLong(ordinal, (((long) cursor) << 32) | ((long) numBytes));
- } else {
- // This is handled directly in GenerateUnsafeProjection.
- throw new UnsupportedOperationException();
- }
- return ByteArrayMethods.roundNumberOfBytesToNearestWord(numBytes);
- }
- }
-
- /** Writer for interval type. */
- public static class IntervalWriter {
-
- public static int write(UnsafeRow target, int ordinal, int cursor, CalendarInterval input) {
- final long offset = target.getBaseOffset() + cursor;
-
- // Write the months and microseconds fields of Interval to the variable length portion.
- Platform.putLong(target.getBaseObject(), offset, input.months);
- Platform.putLong(target.getBaseObject(), offset + 8, input.microseconds);
-
- // Set the fixed length portion.
- target.setLong(ordinal, ((long) cursor) << 32);
- return 16;
- }
- }
-
- public static class ArrayWriter {
-
- public static int getSize(UnsafeArrayData input) {
- // we need extra 4 bytes the store the number of elements in this array.
- return ByteArrayMethods.roundNumberOfBytesToNearestWord(input.getSizeInBytes() + 4);
- }
-
- public static int write(UnsafeRow target, int ordinal, int cursor, UnsafeArrayData input) {
- final int numBytes = input.getSizeInBytes() + 4;
- final long offset = target.getBaseOffset() + cursor;
-
- // write the number of elements into first 4 bytes.
- Platform.putInt(target.getBaseObject(), offset, input.numElements());
-
- // zero-out the padding bytes
- if ((numBytes & 0x07) > 0) {
- Platform.putLong(target.getBaseObject(), offset + ((numBytes >> 3) << 3), 0L);
- }
-
- // Write the bytes to the variable length portion.
- input.writeToMemory(target.getBaseObject(), offset + 4);
-
- // Set the fixed length portion.
- target.setLong(ordinal, (((long) cursor) << 32) | ((long) numBytes));
-
- return ByteArrayMethods.roundNumberOfBytesToNearestWord(numBytes);
- }
- }
-
- public static class MapWriter {
-
- public static int getSize(UnsafeMapData input) {
- // we need extra 8 bytes to store number of elements and numBytes of key array.
- final int sizeInBytes = 4 + 4 + input.getSizeInBytes();
- return ByteArrayMethods.roundNumberOfBytesToNearestWord(sizeInBytes);
- }
-
- public static int write(UnsafeRow target, int ordinal, int cursor, UnsafeMapData input) {
- final long offset = target.getBaseOffset() + cursor;
- final UnsafeArrayData keyArray = input.keyArray();
- final UnsafeArrayData valueArray = input.valueArray();
- final int keysNumBytes = keyArray.getSizeInBytes();
- final int valuesNumBytes = valueArray.getSizeInBytes();
- final int numBytes = 4 + 4 + keysNumBytes + valuesNumBytes;
-
- // write the number of elements into first 4 bytes.
- Platform.putInt(target.getBaseObject(), offset, input.numElements());
- // write the numBytes of key array into second 4 bytes.
- Platform.putInt(target.getBaseObject(), offset + 4, keysNumBytes);
-
- // zero-out the padding bytes
- if ((numBytes & 0x07) > 0) {
- Platform.putLong(target.getBaseObject(), offset + ((numBytes >> 3) << 3), 0L);
- }
-
- // Write the bytes of key array to the variable length portion.
- keyArray.writeToMemory(target.getBaseObject(), offset + 8);
-
- // Write the bytes of value array to the variable length portion.
- valueArray.writeToMemory(target.getBaseObject(), offset + 8 + keysNumBytes);
-
- // Set the fixed length portion.
- target.setLong(ordinal, (((long) cursor) << 32) | ((long) numBytes));
-
- return ByteArrayMethods.roundNumberOfBytesToNearestWord(numBytes);
- }
- }
-}
diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeWriters.java b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeWriters.java
deleted file mode 100644
index ce2d9c4ffb..0000000000
--- a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeWriters.java
+++ /dev/null
@@ -1,193 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one or more
- * contributor license agreements. See the NOTICE file distributed with
- * this work for additional information regarding copyright ownership.
- * The ASF licenses this file to You under the Apache License, Version 2.0
- * (the "License"); you may not use this file except in compliance with
- * the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-package org.apache.spark.sql.catalyst.expressions;
-
-import org.apache.spark.sql.types.Decimal;
-import org.apache.spark.unsafe.Platform;
-import org.apache.spark.unsafe.types.CalendarInterval;
-import org.apache.spark.unsafe.types.UTF8String;
-
-/**
- * A set of helper methods to write data into the variable length portion.
- */
-public class UnsafeWriters {
- public static void writeToMemory(
- Object inputObject,
- long inputOffset,
- Object targetObject,
- long targetOffset,
- int numBytes) {
-
- // zero-out the padding bytes
-// if ((numBytes & 0x07) > 0) {
-// Platform.putLong(targetObject, targetOffset + ((numBytes >> 3) << 3), 0L);
-// }
-
- // Write the UnsafeData to the target memory.
- Platform.copyMemory(inputObject, inputOffset, targetObject, targetOffset, numBytes);
- }
-
- public static int getRoundedSize(int size) {
- //return ByteArrayMethods.roundNumberOfBytesToNearestWord(size);
- // todo: do word alignment
- return size;
- }
-
- /** Writer for Decimal with precision larger than 18. */
- public static class DecimalWriter {
-
- public static int getSize(Decimal input) {
- return 16;
- }
-
- public static int write(Object targetObject, long targetOffset, Decimal input) {
- final byte[] bytes = input.toJavaBigDecimal().unscaledValue().toByteArray();
- final int numBytes = bytes.length;
- assert(numBytes <= 16);
-
- // zero-out the bytes
- Platform.putLong(targetObject, targetOffset, 0L);
- Platform.putLong(targetObject, targetOffset + 8, 0L);
-
- // Write the bytes to the variable length portion.
- Platform.copyMemory(bytes, Platform.BYTE_ARRAY_OFFSET, targetObject, targetOffset, numBytes);
- return 16;
- }
- }
-
- /** Writer for UTF8String. */
- public static class UTF8StringWriter {
-
- public static int getSize(UTF8String input) {
- return getRoundedSize(input.numBytes());
- }
-
- public static int write(Object targetObject, long targetOffset, UTF8String input) {
- final int numBytes = input.numBytes();
-
- // Write the bytes to the variable length portion.
- writeToMemory(input.getBaseObject(), input.getBaseOffset(),
- targetObject, targetOffset, numBytes);
-
- return getRoundedSize(numBytes);
- }
- }
-
- /** Writer for binary (byte array) type. */
- public static class BinaryWriter {
-
- public static int getSize(byte[] input) {
- return getRoundedSize(input.length);
- }
-
- public static int write(Object targetObject, long targetOffset, byte[] input) {
- final int numBytes = input.length;
-
- // Write the bytes to the variable length portion.
- writeToMemory(input, Platform.BYTE_ARRAY_OFFSET, targetObject, targetOffset, numBytes);
-
- return getRoundedSize(numBytes);
- }
- }
-
- /** Writer for UnsafeRow. */
- public static class StructWriter {
-
- public static int getSize(UnsafeRow input) {
- return getRoundedSize(input.getSizeInBytes());
- }
-
- public static int write(Object targetObject, long targetOffset, UnsafeRow input) {
- final int numBytes = input.getSizeInBytes();
-
- // Write the bytes to the variable length portion.
- writeToMemory(input.getBaseObject(), input.getBaseOffset(),
- targetObject, targetOffset, numBytes);
-
- return getRoundedSize(numBytes);
- }
- }
-
- /** Writer for interval type. */
- public static class IntervalWriter {
-
- public static int getSize(UnsafeRow input) {
- return 16;
- }
-
- public static int write(Object targetObject, long targetOffset, CalendarInterval input) {
- // Write the months and microseconds fields of Interval to the variable length portion.
- Platform.putLong(targetObject, targetOffset, input.months);
- Platform.putLong(targetObject, targetOffset + 8, input.microseconds);
- return 16;
- }
- }
-
- /** Writer for UnsafeArrayData. */
- public static class ArrayWriter {
-
- public static int getSize(UnsafeArrayData input) {
- // we need extra 4 bytes the store the number of elements in this array.
- return getRoundedSize(input.getSizeInBytes() + 4);
- }
-
- public static int write(Object targetObject, long targetOffset, UnsafeArrayData input) {
- final int numBytes = input.getSizeInBytes();
-
- // write the number of elements into first 4 bytes.
- Platform.putInt(targetObject, targetOffset, input.numElements());
-
- // Write the bytes to the variable length portion.
- writeToMemory(
- input.getBaseObject(), input.getBaseOffset(), targetObject, targetOffset + 4, numBytes);
-
- return getRoundedSize(numBytes + 4);
- }
- }
-
- public static class MapWriter {
-
- public static int getSize(UnsafeMapData input) {
- // we need extra 8 bytes to store number of elements and numBytes of key array.
- return getRoundedSize(4 + 4 + input.getSizeInBytes());
- }
-
- public static int write(Object targetObject, long targetOffset, UnsafeMapData input) {
- final UnsafeArrayData keyArray = input.keyArray();
- final UnsafeArrayData valueArray = input.valueArray();
- final int keysNumBytes = keyArray.getSizeInBytes();
- final int valuesNumBytes = valueArray.getSizeInBytes();
- final int numBytes = 4 + 4 + keysNumBytes + valuesNumBytes;
-
- // write the number of elements into first 4 bytes.
- Platform.putInt(targetObject, targetOffset, input.numElements());
- // write the numBytes of key array into second 4 bytes.
- Platform.putInt(targetObject, targetOffset + 4, keysNumBytes);
-
- // Write the bytes of key array to the variable length portion.
- writeToMemory(keyArray.getBaseObject(), keyArray.getBaseOffset(),
- targetObject, targetOffset + 8, keysNumBytes);
-
- // Write the bytes of value array to the variable length portion.
- writeToMemory(valueArray.getBaseObject(), valueArray.getBaseOffset(),
- targetObject, targetOffset + 8 + keysNumBytes, valuesNumBytes);
-
- return getRoundedSize(numBytes);
- }
- }
-}
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeProjection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeProjection.scala
index 99bf50a845..8e58cb9ad1 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeProjection.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeProjection.scala
@@ -31,15 +31,6 @@ import org.apache.spark.sql.types._
*/
object GenerateUnsafeProjection extends CodeGenerator[Seq[Expression], UnsafeProjection] {
- private val StringWriter = classOf[UnsafeRowWriters.UTF8StringWriter].getName
- private val BinaryWriter = classOf[UnsafeRowWriters.BinaryWriter].getName
- private val IntervalWriter = classOf[UnsafeRowWriters.IntervalWriter].getName
- private val StructWriter = classOf[UnsafeRowWriters.StructWriter].getName
- private val CompactDecimalWriter = classOf[UnsafeRowWriters.CompactDecimalWriter].getName
- private val DecimalWriter = classOf[UnsafeRowWriters.DecimalWriter].getName
- private val ArrayWriter = classOf[UnsafeRowWriters.ArrayWriter].getName
- private val MapWriter = classOf[UnsafeRowWriters.MapWriter].getName
-
/** Returns true iff we support this data type. */
def canSupport(dataType: DataType): Boolean = dataType match {
case NullType => true
@@ -51,348 +42,6 @@ object GenerateUnsafeProjection extends CodeGenerator[Seq[Expression], UnsafePro
case _ => false
}
- def genAdditionalSize(dt: DataType, ev: GeneratedExpressionCode): String = dt match {
- case t: DecimalType if t.precision > Decimal.MAX_LONG_DIGITS =>
- s"$DecimalWriter.getSize(${ev.primitive})"
- case StringType =>
- s"${ev.isNull} ? 0 : $StringWriter.getSize(${ev.primitive})"
- case BinaryType =>
- s"${ev.isNull} ? 0 : $BinaryWriter.getSize(${ev.primitive})"
- case CalendarIntervalType =>
- s"${ev.isNull} ? 0 : 16"
- case _: StructType =>
- s"${ev.isNull} ? 0 : $StructWriter.getSize(${ev.primitive})"
- case _: ArrayType =>
- s"${ev.isNull} ? 0 : $ArrayWriter.getSize(${ev.primitive})"
- case _: MapType =>
- s"${ev.isNull} ? 0 : $MapWriter.getSize(${ev.primitive})"
- case _ => ""
- }
-
- def genFieldWriter(
- ctx: CodeGenContext,
- fieldType: DataType,
- ev: GeneratedExpressionCode,
- target: String,
- index: Int,
- cursor: String): String = fieldType match {
- case _ if ctx.isPrimitiveType(fieldType) =>
- s"${ctx.setColumn(target, fieldType, index, ev.primitive)}"
- case t: DecimalType if t.precision <= Decimal.MAX_LONG_DIGITS =>
- s"""
- // make sure Decimal object has the same scale as DecimalType
- if (${ev.primitive}.changePrecision(${t.precision}, ${t.scale})) {
- $CompactDecimalWriter.write($target, $index, $cursor, ${ev.primitive});
- } else {
- $target.setNullAt($index);
- }
- """
- case t: DecimalType if t.precision > Decimal.MAX_LONG_DIGITS =>
- s"""
- // make sure Decimal object has the same scale as DecimalType
- if (${ev.primitive}.changePrecision(${t.precision}, ${t.scale})) {
- $cursor += $DecimalWriter.write($target, $index, $cursor, ${ev.primitive});
- } else {
- $cursor += $DecimalWriter.write($target, $index, $cursor, null);
- }
- """
- case StringType =>
- s"$cursor += $StringWriter.write($target, $index, $cursor, ${ev.primitive})"
- case BinaryType =>
- s"$cursor += $BinaryWriter.write($target, $index, $cursor, ${ev.primitive})"
- case CalendarIntervalType =>
- s"$cursor += $IntervalWriter.write($target, $index, $cursor, ${ev.primitive})"
- case _: StructType =>
- s"$cursor += $StructWriter.write($target, $index, $cursor, ${ev.primitive})"
- case _: ArrayType =>
- s"$cursor += $ArrayWriter.write($target, $index, $cursor, ${ev.primitive})"
- case _: MapType =>
- s"$cursor += $MapWriter.write($target, $index, $cursor, ${ev.primitive})"
- case NullType => ""
- case _ =>
- throw new UnsupportedOperationException(s"Not supported DataType: $fieldType")
- }
-
- /**
- * Generates the Java code to convert a struct (backed by InternalRow) to UnsafeRow.
- *
- * @param ctx code generation context
- * @param inputs could be the codes for expressions or input struct fields.
- * @param inputTypes types of the inputs
- */
- private def createCodeForStruct(
- ctx: CodeGenContext,
- row: String,
- inputs: Seq[GeneratedExpressionCode],
- inputTypes: Seq[DataType]): GeneratedExpressionCode = {
-
- val fixedSize = 8 * inputTypes.length + UnsafeRow.calculateBitSetWidthInBytes(inputTypes.length)
-
- val output = ctx.freshName("convertedStruct")
- ctx.addMutableState("UnsafeRow", output, s"this.$output = new UnsafeRow();")
- val buffer = ctx.freshName("buffer")
- ctx.addMutableState("byte[]", buffer, s"this.$buffer = new byte[$fixedSize];")
- val cursor = ctx.freshName("cursor")
- ctx.addMutableState("int", cursor, s"this.$cursor = 0;")
- val tmpBuffer = ctx.freshName("tmpBuffer")
-
- val convertedFields = inputTypes.zip(inputs).zipWithIndex.map { case ((dt, input), i) =>
- val ev = createConvertCode(ctx, input, dt)
- val growBuffer = if (!UnsafeRow.isFixedLength(dt)) {
- val numBytes = ctx.freshName("numBytes")
- s"""
- int $numBytes = $cursor + (${genAdditionalSize(dt, ev)});
- if ($buffer.length < $numBytes) {
- // This will not happen frequently, because the buffer is re-used.
- byte[] $tmpBuffer = new byte[$numBytes * 2];
- Platform.copyMemory($buffer, Platform.BYTE_ARRAY_OFFSET,
- $tmpBuffer, Platform.BYTE_ARRAY_OFFSET, $buffer.length);
- $buffer = $tmpBuffer;
- }
- $output.pointTo($buffer, Platform.BYTE_ARRAY_OFFSET, ${inputTypes.length}, $numBytes);
- """
- } else {
- ""
- }
- val update = dt match {
- case dt: DecimalType if dt.precision > Decimal.MAX_LONG_DIGITS =>
- // Can't call setNullAt() for DecimalType
- s"""
- if (${ev.isNull}) {
- $cursor += $DecimalWriter.write($output, $i, $cursor, null);
- } else {
- ${genFieldWriter(ctx, dt, ev, output, i, cursor)};
- }
- """
- case _ =>
- s"""
- if (${ev.isNull}) {
- $output.setNullAt($i);
- } else {
- ${genFieldWriter(ctx, dt, ev, output, i, cursor)};
- }
- """
- }
- s"""
- ${ev.code}
- $growBuffer
- $update
- """
- }
-
- val code = s"""
- $cursor = $fixedSize;
- $output.pointTo($buffer, Platform.BYTE_ARRAY_OFFSET, ${inputTypes.length}, $cursor);
- ${ctx.splitExpressions(row, convertedFields)}
- """
- GeneratedExpressionCode(code, "false", output)
- }
-
- private def getWriter(dt: DataType) = dt match {
- case StringType => classOf[UnsafeWriters.UTF8StringWriter].getName
- case BinaryType => classOf[UnsafeWriters.BinaryWriter].getName
- case CalendarIntervalType => classOf[UnsafeWriters.IntervalWriter].getName
- case _: StructType => classOf[UnsafeWriters.StructWriter].getName
- case _: ArrayType => classOf[UnsafeWriters.ArrayWriter].getName
- case _: MapType => classOf[UnsafeWriters.MapWriter].getName
- case _: DecimalType => classOf[UnsafeWriters.DecimalWriter].getName
- }
-
- private def createCodeForArray(
- ctx: CodeGenContext,
- input: GeneratedExpressionCode,
- elementType: DataType): GeneratedExpressionCode = {
- val output = ctx.freshName("convertedArray")
- ctx.addMutableState("UnsafeArrayData", output, s"$output = new UnsafeArrayData();")
- val buffer = ctx.freshName("buffer")
- ctx.addMutableState("byte[]", buffer, s"$buffer = new byte[64];")
- val tmpBuffer = ctx.freshName("tmpBuffer")
- val outputIsNull = ctx.freshName("isNull")
- val numElements = ctx.freshName("numElements")
- val fixedSize = ctx.freshName("fixedSize")
- val numBytes = ctx.freshName("numBytes")
- val cursor = ctx.freshName("cursor")
- val index = ctx.freshName("index")
- val elementName = ctx.freshName("elementName")
-
- val element = {
- val code = s"${ctx.javaType(elementType)} $elementName = " +
- s"${ctx.getValue(input.primitive, elementType, index)};"
- val isNull = s"${input.primitive}.isNullAt($index)"
- GeneratedExpressionCode(code, isNull, elementName)
- }
-
- val convertedElement = createConvertCode(ctx, element, elementType)
-
- val writeElement = elementType match {
- case _ if ctx.isPrimitiveType(elementType) =>
- // Should we do word align?
- val elementSize = elementType.defaultSize
- s"""
- Platform.put${ctx.primitiveTypeName(elementType)}(
- $buffer,
- Platform.BYTE_ARRAY_OFFSET + $cursor,
- ${convertedElement.primitive});
- $cursor += $elementSize;
- """
- case t: DecimalType if t.precision <= Decimal.MAX_LONG_DIGITS =>
- s"""
- Platform.putLong(
- $buffer,
- Platform.BYTE_ARRAY_OFFSET + $cursor,
- ${convertedElement.primitive}.toUnscaledLong());
- $cursor += 8;
- """
- case _ =>
- val writer = getWriter(elementType)
- s"""
- $cursor += $writer.write(
- $buffer,
- Platform.BYTE_ARRAY_OFFSET + $cursor,
- ${convertedElement.primitive});
- """
- }
-
- val checkNull = convertedElement.isNull + (elementType match {
- case t: DecimalType =>
- s" || !${convertedElement.primitive}.changePrecision(${t.precision}, ${t.scale})"
- case _ => ""
- })
-
- val elementSize = elementType match {
- // Should we do word align for primitive types?
- case _ if ctx.isPrimitiveType(elementType) => elementType.defaultSize.toString
- case t: DecimalType if t.precision <= Decimal.MAX_LONG_DIGITS => "8"
- case _ =>
- val writer = getWriter(elementType)
- s"$writer.getSize(${convertedElement.primitive})"
- }
-
- val code = s"""
- ${input.code}
- final boolean $outputIsNull = ${input.isNull};
- if (!$outputIsNull) {
- if (${input.primitive} instanceof UnsafeArrayData) {
- $output = (UnsafeArrayData) ${input.primitive};
- } else {
- final int $numElements = ${input.primitive}.numElements();
- final int $fixedSize = 4 * $numElements;
- int $numBytes = $fixedSize;
-
- int $cursor = $fixedSize;
- for (int $index = 0; $index < $numElements; $index++) {
- ${convertedElement.code}
- if ($checkNull) {
- // If element is null, write the negative value address into offset region.
- Platform.putInt($buffer, Platform.BYTE_ARRAY_OFFSET + 4 * $index, -$cursor);
- } else {
- $numBytes += $elementSize;
- if ($buffer.length < $numBytes) {
- // This will not happen frequently, because the buffer is re-used.
- byte[] $tmpBuffer = new byte[$numBytes * 2];
- Platform.copyMemory($buffer, Platform.BYTE_ARRAY_OFFSET,
- $tmpBuffer, Platform.BYTE_ARRAY_OFFSET, $buffer.length);
- $buffer = $tmpBuffer;
- }
- Platform.putInt($buffer, Platform.BYTE_ARRAY_OFFSET + 4 * $index, $cursor);
- $writeElement
- }
- }
-
- $output.pointTo(
- $buffer,
- Platform.BYTE_ARRAY_OFFSET,
- $numElements,
- $numBytes);
- }
- }
- """
- GeneratedExpressionCode(code, outputIsNull, output)
- }
-
- private def createCodeForMap(
- ctx: CodeGenContext,
- input: GeneratedExpressionCode,
- keyType: DataType,
- valueType: DataType): GeneratedExpressionCode = {
- val output = ctx.freshName("convertedMap")
- val outputIsNull = ctx.freshName("isNull")
- val keyArrayName = ctx.freshName("keyArrayName")
- val valueArrayName = ctx.freshName("valueArrayName")
-
- val keyArray = {
- val code = s"ArrayData $keyArrayName = ${input.primitive}.keyArray();"
- val isNull = "false"
- GeneratedExpressionCode(code, isNull, keyArrayName)
- }
-
- val valueArray = {
- val code = s"ArrayData $valueArrayName = ${input.primitive}.valueArray();"
- val isNull = "false"
- GeneratedExpressionCode(code, isNull, valueArrayName)
- }
-
- val convertedKeys = createCodeForArray(ctx, keyArray, keyType)
- val convertedValues = createCodeForArray(ctx, valueArray, valueType)
-
- val code = s"""
- ${input.code}
- final boolean $outputIsNull = ${input.isNull};
- UnsafeMapData $output = null;
- if (!$outputIsNull) {
- if (${input.primitive} instanceof UnsafeMapData) {
- $output = (UnsafeMapData) ${input.primitive};
- } else {
- ${convertedKeys.code}
- ${convertedValues.code}
- $output = new UnsafeMapData(${convertedKeys.primitive}, ${convertedValues.primitive});
- }
- }
- """
- GeneratedExpressionCode(code, outputIsNull, output)
- }
-
- /**
- * Generates the java code to convert a data to its unsafe version.
- */
- private def createConvertCode(
- ctx: CodeGenContext,
- input: GeneratedExpressionCode,
- dataType: DataType): GeneratedExpressionCode = dataType match {
- case t: StructType =>
- val output = ctx.freshName("convertedStruct")
- val outputIsNull = ctx.freshName("isNull")
- val fieldTypes = t.fields.map(_.dataType)
- val fieldEvals = fieldTypes.zipWithIndex.map { case (dt, i) =>
- val fieldName = ctx.freshName("fieldName")
- val code = s"${ctx.javaType(dt)} $fieldName = " +
- s"${ctx.getValue(input.primitive, dt, i.toString)};"
- val isNull = s"${input.primitive}.isNullAt($i)"
- GeneratedExpressionCode(code, isNull, fieldName)
- }
- val converter = createCodeForStruct(ctx, input.primitive, fieldEvals, fieldTypes)
- val code = s"""
- ${input.code}
- UnsafeRow $output = null;
- final boolean $outputIsNull = ${input.isNull};
- if (!$outputIsNull) {
- if (${input.primitive} instanceof UnsafeRow) {
- $output = (UnsafeRow) ${input.primitive};
- } else {
- ${converter.code}
- $output = ${converter.primitive};
- }
- }
- """
- GeneratedExpressionCode(code, outputIsNull, output)
-
- case ArrayType(elementType, _) => createCodeForArray(ctx, input, elementType)
-
- case MapType(kt, vt, _) => createCodeForMap(ctx, input, kt, vt)
-
- case _ => input
- }
-
private val rowWriterClass = classOf[UnsafeRowWriter].getName
private val arrayWriterClass = classOf[UnsafeArrayWriter].getName