aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorWenchen Fan <cloud0fan@outlook.com>2015-09-10 10:04:38 -0700
committerDavies Liu <davies.liu@gmail.com>2015-09-10 10:04:38 -0700
commit4f1daa1ef6b36440962f3c8faea3928599e33784 (patch)
tree29901bdaa603b657f4ebd3b54b8f46c19f81453b
parent48817cc111a9705f40b7c842315eee24291c2198 (diff)
downloadspark-4f1daa1ef6b36440962f3c8faea3928599e33784.tar.gz
spark-4f1daa1ef6b36440962f3c8faea3928599e33784.tar.bz2
spark-4f1daa1ef6b36440962f3c8faea3928599e33784.zip
[SPARK-10065] [SQL] avoid the extra copy when generating unsafe array
The reason for this extra copy is that we iterate over the array twice: once to calculate the data size of the elements, and once to copy the elements into the array buffer. A simple solution is to follow `createCodeForStruct`: we can dynamically grow the buffer when needed and thus don't need to know the data size ahead of time. This PR also includes some typo and style fixes, and does some minor refactoring to make sure `input.primitive` is always a variable name, not code, when generating unsafe code. Author: Wenchen Fan <cloud0fan@outlook.com> Closes #8496 from cloud-fan/avoid-copy.
-rw-r--r--sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeProjection.scala84
1 files changed, 24 insertions, 60 deletions
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeProjection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeProjection.scala
index 03c5f449bf..55562facf9 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeProjection.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeProjection.scala
@@ -206,11 +206,11 @@ object GenerateUnsafeProjection extends CodeGenerator[Seq[Expression], UnsafePro
ctx.addMutableState("UnsafeArrayData", output, s"$output = new UnsafeArrayData();")
val buffer = ctx.freshName("buffer")
ctx.addMutableState("byte[]", buffer, s"$buffer = new byte[64];")
+ val tmpBuffer = ctx.freshName("tmpBuffer")
val outputIsNull = ctx.freshName("isNull")
val numElements = ctx.freshName("numElements")
val fixedSize = ctx.freshName("fixedSize")
val numBytes = ctx.freshName("numBytes")
- val elements = ctx.freshName("elements")
val cursor = ctx.freshName("cursor")
val index = ctx.freshName("index")
val elementName = ctx.freshName("elementName")
@@ -224,57 +224,11 @@ object GenerateUnsafeProjection extends CodeGenerator[Seq[Expression], UnsafePro
val convertedElement = createConvertCode(ctx, element, elementType)
- // go through the input array to calculate how many bytes we need.
- val calculateNumBytes = elementType match {
- case _ if ctx.isPrimitiveType(elementType) =>
- // Should we do word align?
- val elementSize = elementType.defaultSize
- s"""
- $numBytes += $elementSize * $numElements;
- """
- case t: DecimalType if t.precision <= Decimal.MAX_LONG_DIGITS =>
- s"""
- $numBytes += 8 * $numElements;
- """
- case _ =>
- val writer = getWriter(elementType)
- val elementSize = s"$writer.getSize($elements[$index])"
- // TODO(davies): avoid the copy
- val unsafeType = elementType match {
- case _: StructType => "UnsafeRow"
- case _: ArrayType => "UnsafeArrayData"
- case _: MapType => "UnsafeMapData"
- case _ => ctx.javaType(elementType)
- }
- val copy = elementType match {
- // We reuse the buffer during conversion, need copy it before process next element.
- case _: StructType | _: ArrayType | _: MapType => ".copy()"
- case _ => ""
- }
-
- val newElements = if (elementType == BinaryType) {
- s"new byte[$numElements][]"
- } else {
- s"new $unsafeType[$numElements]"
- }
- s"""
- final $unsafeType[] $elements = $newElements;
- for (int $index = 0; $index < $numElements; $index++) {
- ${convertedElement.code}
- if (!${convertedElement.isNull}) {
- $elements[$index] = ${convertedElement.primitive}$copy;
- $numBytes += $elementSize;
- }
- }
- """
- }
-
val writeElement = elementType match {
case _ if ctx.isPrimitiveType(elementType) =>
// Should we do word align?
val elementSize = elementType.defaultSize
s"""
- ${convertedElement.code}
Platform.put${ctx.primitiveTypeName(elementType)}(
$buffer,
Platform.BYTE_ARRAY_OFFSET + $cursor,
@@ -283,7 +237,6 @@ object GenerateUnsafeProjection extends CodeGenerator[Seq[Expression], UnsafePro
"""
case t: DecimalType if t.precision <= Decimal.MAX_LONG_DIGITS =>
s"""
- ${convertedElement.code}
Platform.putLong(
$buffer,
Platform.BYTE_ARRAY_OFFSET + $cursor,
@@ -296,15 +249,23 @@ object GenerateUnsafeProjection extends CodeGenerator[Seq[Expression], UnsafePro
$cursor += $writer.write(
$buffer,
Platform.BYTE_ARRAY_OFFSET + $cursor,
- $elements[$index]);
+ ${convertedElement.primitive});
"""
}
- val checkNull = elementType match {
- case _ if ctx.isPrimitiveType(elementType) => s"${convertedElement.isNull}"
- case t: DecimalType => s"$elements[$index] == null" +
- s" || !$elements[$index].changePrecision(${t.precision}, ${t.scale})"
- case _ => s"$elements[$index] == null"
+ val checkNull = convertedElement.isNull + (elementType match {
+ case t: DecimalType =>
+ s" || !${convertedElement.primitive}.changePrecision(${t.precision}, ${t.scale})"
+ case _ => ""
+ })
+
+ val elementSize = elementType match {
+ // Should we do word align for primitive types?
+ case _ if ctx.isPrimitiveType(elementType) => elementType.defaultSize.toString
+ case t: DecimalType if t.precision <= Decimal.MAX_LONG_DIGITS => "8"
+ case _ =>
+ val writer = getWriter(elementType)
+ s"$writer.getSize(${convertedElement.primitive})"
}
val code = s"""
@@ -318,18 +279,21 @@ object GenerateUnsafeProjection extends CodeGenerator[Seq[Expression], UnsafePro
final int $fixedSize = 4 * $numElements;
int $numBytes = $fixedSize;
- $calculateNumBytes
-
- if ($numBytes > $buffer.length) {
- $buffer = new byte[$numBytes];
- }
-
int $cursor = $fixedSize;
for (int $index = 0; $index < $numElements; $index++) {
+ ${convertedElement.code}
if ($checkNull) {
// If element is null, write the negative value address into offset region.
Platform.putInt($buffer, Platform.BYTE_ARRAY_OFFSET + 4 * $index, -$cursor);
} else {
+ $numBytes += $elementSize;
+ if ($buffer.length < $numBytes) {
+ // This will not happen frequently, because the buffer is re-used.
+ byte[] $tmpBuffer = new byte[$numBytes * 2];
+ Platform.copyMemory($buffer, Platform.BYTE_ARRAY_OFFSET,
+ $tmpBuffer, Platform.BYTE_ARRAY_OFFSET, $buffer.length);
+ $buffer = $tmpBuffer;
+ }
Platform.putInt($buffer, Platform.BYTE_ARRAY_OFFSET + 4 * $index, $cursor);
$writeElement
}