From 4f1daa1ef6b36440962f3c8faea3928599e33784 Mon Sep 17 00:00:00 2001 From: Wenchen Fan Date: Thu, 10 Sep 2015 10:04:38 -0700 Subject: [SPARK-10065] [SQL] avoid the extra copy when generate unsafe array The reason for this extra copy is that we iterate the array twice: calculate elements data size and copy elements to array buffer. A simple solution is to follow `createCodeForStruct`, we can dynamically grow the buffer when needed and thus don't need to know the data size ahead. This PR also include some typo and style fixes, and did some minor refactor to make sure `input.primitive` is always variable name not code when generate unsafe code. Author: Wenchen Fan Closes #8496 from cloud-fan/avoid-copy. --- .../codegen/GenerateUnsafeProjection.scala | 84 +++++++--------------- 1 file changed, 24 insertions(+), 60 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeProjection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeProjection.scala index 03c5f449bf..55562facf9 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeProjection.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeProjection.scala @@ -206,11 +206,11 @@ object GenerateUnsafeProjection extends CodeGenerator[Seq[Expression], UnsafePro ctx.addMutableState("UnsafeArrayData", output, s"$output = new UnsafeArrayData();") val buffer = ctx.freshName("buffer") ctx.addMutableState("byte[]", buffer, s"$buffer = new byte[64];") + val tmpBuffer = ctx.freshName("tmpBuffer") val outputIsNull = ctx.freshName("isNull") val numElements = ctx.freshName("numElements") val fixedSize = ctx.freshName("fixedSize") val numBytes = ctx.freshName("numBytes") - val elements = ctx.freshName("elements") val cursor = ctx.freshName("cursor") val index = ctx.freshName("index") val elementName = ctx.freshName("elementName") @@ -224,57 +224,11 @@ object GenerateUnsafeProjection extends CodeGenerator[Seq[Expression], UnsafePro val convertedElement = createConvertCode(ctx, element, elementType) - // go through the input array to calculate how many bytes we need. - val calculateNumBytes = elementType match { - case _ if ctx.isPrimitiveType(elementType) => - // Should we do word align? - val elementSize = elementType.defaultSize - s""" - $numBytes += $elementSize * $numElements; - """ - case t: DecimalType if t.precision <= Decimal.MAX_LONG_DIGITS => - s""" - $numBytes += 8 * $numElements; - """ - case _ => - val writer = getWriter(elementType) - val elementSize = s"$writer.getSize($elements[$index])" - // TODO(davies): avoid the copy - val unsafeType = elementType match { - case _: StructType => "UnsafeRow" - case _: ArrayType => "UnsafeArrayData" - case _: MapType => "UnsafeMapData" - case _ => ctx.javaType(elementType) - } - val copy = elementType match { - // We reuse the buffer during conversion, need copy it before process next element. - case _: StructType | _: ArrayType | _: MapType => ".copy()" - case _ => "" - } - - val newElements = if (elementType == BinaryType) { - s"new byte[$numElements][]" - } else { - s"new $unsafeType[$numElements]" - } - s""" - final $unsafeType[] $elements = $newElements; - for (int $index = 0; $index < $numElements; $index++) { - ${convertedElement.code} - if (!${convertedElement.isNull}) { - $elements[$index] = ${convertedElement.primitive}$copy; - $numBytes += $elementSize; - } - } - """ - } - val writeElement = elementType match { case _ if ctx.isPrimitiveType(elementType) => // Should we do word align? val elementSize = elementType.defaultSize s""" - ${convertedElement.code} Platform.put${ctx.primitiveTypeName(elementType)}( $buffer, Platform.BYTE_ARRAY_OFFSET + $cursor, @@ -283,7 +237,6 @@ object GenerateUnsafeProjection extends CodeGenerator[Seq[Expression], UnsafePro """ case t: DecimalType if t.precision <= Decimal.MAX_LONG_DIGITS => s""" - ${convertedElement.code} Platform.putLong( $buffer, Platform.BYTE_ARRAY_OFFSET + $cursor, @@ -296,15 +249,23 @@ object GenerateUnsafeProjection extends CodeGenerator[Seq[Expression], UnsafePro $cursor += $writer.write( $buffer, Platform.BYTE_ARRAY_OFFSET + $cursor, - $elements[$index]); + ${convertedElement.primitive}); """ } - val checkNull = elementType match { - case _ if ctx.isPrimitiveType(elementType) => s"${convertedElement.isNull}" - case t: DecimalType => s"$elements[$index] == null" + - s" || !$elements[$index].changePrecision(${t.precision}, ${t.scale})" - case _ => s"$elements[$index] == null" + val checkNull = convertedElement.isNull + (elementType match { + case t: DecimalType => + s" || !${convertedElement.primitive}.changePrecision(${t.precision}, ${t.scale})" + case _ => "" + }) + + val elementSize = elementType match { + // Should we do word align for primitive types? + case _ if ctx.isPrimitiveType(elementType) => elementType.defaultSize.toString + case t: DecimalType if t.precision <= Decimal.MAX_LONG_DIGITS => "8" + case _ => + val writer = getWriter(elementType) + s"$writer.getSize(${convertedElement.primitive})" } val code = s""" @@ -318,18 +279,21 @@ object GenerateUnsafeProjection extends CodeGenerator[Seq[Expression], UnsafePro final int $fixedSize = 4 * $numElements; int $numBytes = $fixedSize; - $calculateNumBytes - - if ($numBytes > $buffer.length) { - $buffer = new byte[$numBytes]; - } - int $cursor = $fixedSize; for (int $index = 0; $index < $numElements; $index++) { + ${convertedElement.code} if ($checkNull) { // If element is null, write the negative value address into offset region. Platform.putInt($buffer, Platform.BYTE_ARRAY_OFFSET + 4 * $index, -$cursor); } else { + $numBytes += $elementSize; + if ($buffer.length < $numBytes) { + // This will not happen frequently, because the buffer is re-used. + byte[] $tmpBuffer = new byte[$numBytes * 2]; + Platform.copyMemory($buffer, Platform.BYTE_ARRAY_OFFSET, + $tmpBuffer, Platform.BYTE_ARRAY_OFFSET, $buffer.length); + $buffer = $tmpBuffer; + } Platform.putInt($buffer, Platform.BYTE_ARRAY_OFFSET + 4 * $index, $cursor); $writeElement } -- cgit v1.2.3