diff options
-rw-r--r-- | sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeProjection.scala | 84 |
1 files changed, 24 insertions, 60 deletions
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeProjection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeProjection.scala index 03c5f449bf..55562facf9 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeProjection.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeProjection.scala @@ -206,11 +206,11 @@ object GenerateUnsafeProjection extends CodeGenerator[Seq[Expression], UnsafePro ctx.addMutableState("UnsafeArrayData", output, s"$output = new UnsafeArrayData();") val buffer = ctx.freshName("buffer") ctx.addMutableState("byte[]", buffer, s"$buffer = new byte[64];") + val tmpBuffer = ctx.freshName("tmpBuffer") val outputIsNull = ctx.freshName("isNull") val numElements = ctx.freshName("numElements") val fixedSize = ctx.freshName("fixedSize") val numBytes = ctx.freshName("numBytes") - val elements = ctx.freshName("elements") val cursor = ctx.freshName("cursor") val index = ctx.freshName("index") val elementName = ctx.freshName("elementName") @@ -224,57 +224,11 @@ object GenerateUnsafeProjection extends CodeGenerator[Seq[Expression], UnsafePro val convertedElement = createConvertCode(ctx, element, elementType) - // go through the input array to calculate how many bytes we need. - val calculateNumBytes = elementType match { - case _ if ctx.isPrimitiveType(elementType) => - // Should we do word align? - val elementSize = elementType.defaultSize - s""" - $numBytes += $elementSize * $numElements; - """ - case t: DecimalType if t.precision <= Decimal.MAX_LONG_DIGITS => - s""" - $numBytes += 8 * $numElements; - """ - case _ => - val writer = getWriter(elementType) - val elementSize = s"$writer.getSize($elements[$index])" - // TODO(davies): avoid the copy - val unsafeType = elementType match { - case _: StructType => "UnsafeRow" - case _: ArrayType => "UnsafeArrayData" - case _: MapType => "UnsafeMapData" - case _ => ctx.javaType(elementType) - } - val copy = elementType match { - // We reuse the buffer during conversion, need copy it before process next element. - case _: StructType | _: ArrayType | _: MapType => ".copy()" - case _ => "" - } - - val newElements = if (elementType == BinaryType) { - s"new byte[$numElements][]" - } else { - s"new $unsafeType[$numElements]" - } - s""" - final $unsafeType[] $elements = $newElements; - for (int $index = 0; $index < $numElements; $index++) { - ${convertedElement.code} - if (!${convertedElement.isNull}) { - $elements[$index] = ${convertedElement.primitive}$copy; - $numBytes += $elementSize; - } - } - """ - } - val writeElement = elementType match { case _ if ctx.isPrimitiveType(elementType) => // Should we do word align? val elementSize = elementType.defaultSize s""" - ${convertedElement.code} Platform.put${ctx.primitiveTypeName(elementType)}( $buffer, Platform.BYTE_ARRAY_OFFSET + $cursor, @@ -283,7 +237,6 @@ object GenerateUnsafeProjection extends CodeGenerator[Seq[Expression], UnsafePro """ case t: DecimalType if t.precision <= Decimal.MAX_LONG_DIGITS => s""" - ${convertedElement.code} Platform.putLong( $buffer, Platform.BYTE_ARRAY_OFFSET + $cursor, @@ -296,15 +249,23 @@ object GenerateUnsafeProjection extends CodeGenerator[Seq[Expression], UnsafePro $cursor += $writer.write( $buffer, Platform.BYTE_ARRAY_OFFSET + $cursor, - $elements[$index]); + ${convertedElement.primitive}); """ } - val checkNull = elementType match { - case _ if ctx.isPrimitiveType(elementType) => s"${convertedElement.isNull}" - case t: DecimalType => s"$elements[$index] == null" + - s" || !$elements[$index].changePrecision(${t.precision}, ${t.scale})" - case _ => s"$elements[$index] == null" + val checkNull = convertedElement.isNull + (elementType match { + case t: DecimalType => + s" || !${convertedElement.primitive}.changePrecision(${t.precision}, ${t.scale})" + case _ => "" + }) + + val elementSize = elementType match { + // Should we do word align for primitive types? + case _ if ctx.isPrimitiveType(elementType) => elementType.defaultSize.toString + case t: DecimalType if t.precision <= Decimal.MAX_LONG_DIGITS => "8" + case _ => + val writer = getWriter(elementType) + s"$writer.getSize(${convertedElement.primitive})" } val code = s""" @@ -318,18 +279,21 @@ object GenerateUnsafeProjection extends CodeGenerator[Seq[Expression], UnsafePro final int $fixedSize = 4 * $numElements; int $numBytes = $fixedSize; - $calculateNumBytes - - if ($numBytes > $buffer.length) { - $buffer = new byte[$numBytes]; - } - int $cursor = $fixedSize; for (int $index = 0; $index < $numElements; $index++) { + ${convertedElement.code} if ($checkNull) { // If element is null, write the negative value address into offset region. Platform.putInt($buffer, Platform.BYTE_ARRAY_OFFSET + 4 * $index, -$cursor); } else { + $numBytes += $elementSize; + if ($buffer.length < $numBytes) { + // This will not happen frequently, because the buffer is re-used. + byte[] $tmpBuffer = new byte[$numBytes * 2]; + Platform.copyMemory($buffer, Platform.BYTE_ARRAY_OFFSET, + $tmpBuffer, Platform.BYTE_ARRAY_OFFSET, $buffer.length); + $buffer = $tmpBuffer; + } Platform.putInt($buffer, Platform.BYTE_ARRAY_OFFSET + 4 * $index, $cursor); $writeElement } |